from __future__ import annotations import json from dataclasses import asdict, dataclass from pathlib import Path import re from typing import Literal from okx_codex_trader.models import InstrumentMeta, Position from okx_codex_trader.okx_client import OkxClient, build_contract_size PositionSide = Literal["flat", "long", "short"] @dataclass(frozen=True) class TargetPosition: side: PositionSide unit: float known: bool reason: str contracts: float | None = None @dataclass(frozen=True) class RuntimeState: last_candle_ts: int | None nextgen_active_legs: tuple[str, ...] micro_side: Literal["long", "short"] | None @dataclass(frozen=True) class PlannedAction: action: Literal["noop", "open", "increase", "reduce", "close", "reverse"] side: PositionSide unit: float reduce_only: bool @dataclass(frozen=True) class ExecutionPlan: target: TargetPosition current: TargetPosition actions: tuple[PlannedAction, ...] @dataclass(frozen=True) class RenderedOrder: action: str margin_usdt: float body: dict[str, object] EMPTY_STATE = RuntimeState(last_candle_ts=None, nextgen_active_legs=(), micro_side=None) CLIENT_ORDER_ID_MAX_LENGTH = 32 def load_runtime_state(path: Path) -> RuntimeState: if not path.exists(): return EMPTY_STATE payload = json.loads(path.read_text(encoding="utf-8")) return RuntimeState( last_candle_ts=payload["last_candle_ts"], nextgen_active_legs=tuple(payload["nextgen_active_legs"]), micro_side=payload["micro_side"], ) def save_runtime_state(path: Path, state: RuntimeState) -> None: path.parent.mkdir(parents=True, exist_ok=True) path.write_text(json.dumps(asdict(state), indent=2, sort_keys=True) + "\n", encoding="utf-8") def _decision_candle_ts(payload: dict[str, object]) -> int: active_engine = str(payload["decision"]["active_engine"]) if active_engine == "nextgen": nextgen = payload["nextgen"] if "data" in nextgen: return int(nextgen["data"]["decision_candle_ts"]) return int(nextgen["decision"]["decision_candle_ts"]) return int(payload["micro"]["decision_candle_ts"]) 
def target_from_signal(payload: dict[str, object], state: RuntimeState) -> tuple[RuntimeState, TargetPosition]:
    """Advance ``state`` with a signal payload and derive the target position.

    If the payload's decision candle is not newer than ``state.last_candle_ts``,
    it has already been processed. In that case the state is returned unchanged
    and the target is recomputed from it.

    For the ``nextgen`` engine:

    - every leg with ``signal`` set becomes active;
    - an already-active leg with ``exit_signal`` set (and no ``signal``) is
      deactivated;
    - the target is the sum of the suggested weights of the active legs.

    Any other engine yields an unknown target, because micro exit state is not
    persisted yet.

    Returns:
        ``(next_state, target)``.
    """
    candle_ts = _decision_candle_ts(payload)
    if state.last_candle_ts is not None and candle_ts <= state.last_candle_ts:
        return state, target_from_state(payload, state)
    active_engine = str(payload["decision"]["active_engine"])
    if active_engine == "nextgen":
        active = set(state.nextgen_active_legs)
        weights: dict[str, float] = {}
        for leg in payload["nextgen"]["legs"]:
            leg_id = str(leg["leg_id"])
            weights[leg_id] = float(leg["suggested_weight"])
            if bool(leg["signal"]):
                active.add(leg_id)
            elif leg_id in active and bool(leg["exit_signal"]):
                active.remove(leg_id)
        next_state = RuntimeState(
            last_candle_ts=candle_ts,
            nextgen_active_legs=tuple(sorted(active)),
            micro_side=None,
        )
        return next_state, _nextgen_target(next_state, weights)
    return (
        RuntimeState(last_candle_ts=candle_ts, nextgen_active_legs=(), micro_side=state.micro_side),
        TargetPosition(
            side="flat",
            unit=0.0,
            known=False,
            reason="micro target position requires persistent micro exit state before live execution",
        ),
    )


def target_from_state(payload: dict[str, object], state: RuntimeState) -> TargetPosition:
    """Recompute the target position from ``state`` without advancing it.

    Only the ``nextgen`` engine is supported. Leg weights are taken from
    ``payload`` and summed over ``state.nextgen_active_legs``. Other engines
    yield an unknown target.
    """
    active_engine = str(payload["decision"]["active_engine"])
    if active_engine != "nextgen":
        return TargetPosition(
            side="flat",
            unit=0.0,
            known=False,
            reason="micro target position requires persistent micro exit state before live execution",
        )
    weights = {str(leg["leg_id"]): float(leg["suggested_weight"]) for leg in payload["nextgen"]["legs"]}
    return _nextgen_target(state, weights)


def _nextgen_target(state: RuntimeState, weights: dict[str, float]) -> TargetPosition:
    """Sum the weights of the active nextgen legs into a single long target.

    Raises:
        KeyError: if an active leg has no entry in ``weights``.
    """
    unit = sum(weights[leg_id] for leg_id in state.nextgen_active_legs)
    if unit <= 0.0:
        return TargetPosition(side="flat", unit=0.0, known=True, reason="no active nextgen virtual legs")
    return TargetPosition(side="long", unit=unit, known=True, reason="active nextgen virtual legs net to one long ETH target")


def plan_position_delta(
    current: TargetPosition,
    target: TargetPosition,
    *,
    unit_tolerance: float = 0.0,
) -> ExecutionPlan:
    """Plan the actions that move ``current`` to ``target``.

    If either position is unknown, no actions are planned. A same-side
    position whose unit differs from the target by at most ``unit_tolerance``
    is a ``noop``.

    Callers comparing against a position read from OKX should pass a small
    positive tolerance. That unit is derived from mark price and contract
    rounding, so it almost never equals the target exactly. Without a
    tolerance, every cycle would emit a tiny increase or reduce.

    Raises:
        ValueError: if ``unit_tolerance`` is negative.
    """
    if unit_tolerance < 0.0:
        raise ValueError("unit_tolerance must be non-negative")
    if not current.known or not target.known:
        return ExecutionPlan(target=target, current=current, actions=())
    if current.side == target.side and (
        current.unit == target.unit or abs(current.unit - target.unit) <= unit_tolerance
    ):
        return ExecutionPlan(target=target, current=current, actions=(PlannedAction("noop", target.side, 0.0, False),))
    if current.side == "flat":
        return ExecutionPlan(target=target, current=current, actions=(PlannedAction("open", target.side, target.unit, False),))
    if target.side == "flat":
        return ExecutionPlan(target=target, current=current, actions=(PlannedAction("close", current.side, current.unit, True),))
    if current.side == target.side:
        if target.unit > current.unit:
            return ExecutionPlan(
                target=target,
                current=current,
                actions=(PlannedAction("increase", target.side, target.unit - current.unit, False),),
            )
        return ExecutionPlan(
            target=target,
            current=current,
            actions=(PlannedAction("reduce", current.side, current.unit - target.unit, True),),
        )
    # Opposite sides: fully close the current side, then open the target side.
    return ExecutionPlan(
        target=target,
        current=current,
        actions=(
            PlannedAction("close", current.side, current.unit, True),
            PlannedAction("reverse", target.side, target.unit, False),
        ),
    )


def current_position_from_okx(
    *,
    positions: list[Position],
    mark_price: float,
    metadata: InstrumentMeta,
    leverage: int,
    margin_per_unit_usdt: float,
) -> TargetPosition:
    """Normalize open OKX positions into strategy units.

    unit = contracts * ct_val * mark_price / leverage / margin_per_unit_usdt

    Positions with size <= 0 are ignored. The result is unknown when both
    hedge sides are open or when the side is neither ``long`` nor ``short``.

    Raises:
        ValueError: if any sizing input is non-positive.
    """
    if leverage <= 0 or margin_per_unit_usdt <= 0.0 or mark_price <= 0.0 or metadata.ct_val <= 0.0:
        raise ValueError("position normalization inputs are invalid")
    active = [position for position in positions if position.size > 0.0]
    if not active:
        return TargetPosition(side="flat", unit=0.0, known=True, reason="no open OKX position", contracts=0.0)
    sides = {position.pos_side for position in active}
    if len(sides) != 1:
        return TargetPosition(side="flat", unit=0.0, known=False, reason="both OKX hedge sides are open")
    side = active[0].pos_side
    if side not in {"long", "short"}:
        return TargetPosition(side="flat", unit=0.0, known=False, reason="OKX position side is unsupported")
    contracts = sum(position.size for position in active)
    notional = contracts * metadata.ct_val * mark_price
    margin = notional / leverage
    unit = margin / margin_per_unit_usdt
    return TargetPosition(
        side=side,
        unit=unit,
        known=True,
        reason="OKX position normalized by configured strategy unit margin",
        contracts=contracts,
    )


def render_market_order_bodies(
    *,
    plan: ExecutionPlan,
    symbol: str,
    mark_price: float,
    metadata: InstrumentMeta,
    leverage: int,
    margin_per_unit_usdt: float,
    max_new_margin_usdt: float,
    max_total_margin_usdt: float,
    client_order_id_prefix: str,
    stop_loss_pct: float | None = None,
    take_profit_pct: float | None = None,
) -> tuple[RenderedOrder, ...]:
    """Render the non-noop actions of ``plan`` into OKX market-order bodies.

    Sizing:

    - ``close`` orders use the exact current contract count.
    - Other orders are sized from ``unit * margin_per_unit_usdt * leverage``
      at ``mark_price``.

    Stop-loss and take-profit triggers are attached only to orders that are
    not reduce-only.

    Raises:
        ValueError: on invalid sizing inputs or percentages; when the target
            margin exceeds ``max_total_margin_usdt``; when cumulative new
            margin exceeds ``max_new_margin_usdt``; when a close has no
            current contracts; or when a trigger price would be non-positive.
    """
    if (
        leverage <= 0
        or margin_per_unit_usdt <= 0.0
        or mark_price <= 0.0
        or max_new_margin_usdt < 0.0
        or max_total_margin_usdt < 0.0
    ):
        raise ValueError("order rendering inputs are invalid")
    if stop_loss_pct is not None and stop_loss_pct <= 0.0:
        raise ValueError("stop_loss_pct is invalid")
    if take_profit_pct is not None and take_profit_pct <= 0.0:
        raise ValueError("take_profit_pct is invalid")
    if plan.target.known and plan.target.unit * margin_per_unit_usdt > max_total_margin_usdt:
        raise ValueError("target margin exceeds max_total_margin_usdt")
    rendered: list[RenderedOrder] = []
    new_margin = 0.0
    index = 1
    for action in plan.actions:
        if action.action == "noop":
            continue
        margin = action.unit * margin_per_unit_usdt
        if margin <= 0.0 or action.side == "flat":
            raise ValueError("planned action is invalid")
        if not action.reduce_only:
            # Only exposure-increasing orders count against the new-margin cap.
            new_margin += margin
            if new_margin > max_new_margin_usdt:
                raise ValueError("new margin exceeds max_new_margin_usdt")
        side = _okx_side(action)
        if action.action == "close":
            if plan.current.contracts is None or plan.current.contracts <= 0.0:
                raise ValueError("current contracts are required for close orders")
            size = plan.current.contracts
        else:
            size = build_contract_size(margin * leverage, mark_price, metadata)
        stop_loss_trigger_price, take_profit_trigger_price = _protective_trigger_prices(
            action, mark_price, stop_loss_pct, take_profit_pct
        )
        rendered.append(
            RenderedOrder(
                action=action.action,
                margin_usdt=margin,
                body=OkxClient.build_market_order_body(
                    symbol=symbol,
                    side=side,
                    pos_side=action.side,
                    size=size,
                    client_order_id=market_client_order_id(client_order_id_prefix, index, action.action),
                    reduce_only=action.reduce_only,
                    stop_loss_trigger_price=stop_loss_trigger_price,
                    take_profit_trigger_price=take_profit_trigger_price,
                ),
            )
        )
        index += 1
    return tuple(rendered)


def _protective_trigger_prices(
    action: PlannedAction,
    mark_price: float,
    stop_loss_pct: float | None,
    take_profit_pct: float | None,
) -> tuple[float | None, float | None]:
    """Return ``(stop_loss, take_profit)`` trigger prices for an opening order.

    Reduce-only actions get no triggers.

    Raises:
        ValueError: if a percentage drives its trigger price to zero or below,
            e.g. a 100% stop loss on a long.
    """
    if action.reduce_only:
        return None, None
    is_long = action.side == "long"
    stop_loss_trigger_price = None
    if stop_loss_pct is not None:
        stop_loss_trigger_price = mark_price * (1.0 - stop_loss_pct if is_long else 1.0 + stop_loss_pct)
        if stop_loss_trigger_price <= 0.0:
            raise ValueError("stop_loss_pct is invalid")
    take_profit_trigger_price = None
    if take_profit_pct is not None:
        take_profit_trigger_price = mark_price * (1.0 + take_profit_pct if is_long else 1.0 - take_profit_pct)
        if take_profit_trigger_price <= 0.0:
            raise ValueError("take_profit_pct is invalid")
    return stop_loss_trigger_price, take_profit_trigger_price


def market_client_order_id(prefix: str, index: int, action: str) -> str:
    """Build an alphanumeric client order id of at most CLIENT_ORDER_ID_MAX_LENGTH chars.

    Non-alphanumeric characters are stripped. When the result would be too
    long, the prefix is truncated rather than the ``{index}{action}`` suffix.
    This keeps orders rendered from one plan (e.g. close + reverse) distinct
    even with a long prefix.

    Raises:
        ValueError: if ``prefix`` contains no alphanumeric characters.
    """
    compact_prefix = re.sub(r"[^A-Za-z0-9]", "", prefix)
    if not compact_prefix:
        raise ValueError("client order id prefix is invalid")
    suffix = re.sub(r"[^A-Za-z0-9]", "", f"{index}{action}")
    # Keep at least one prefix character even for an oversized suffix.
    room = max(CLIENT_ORDER_ID_MAX_LENGTH - len(suffix), 1)
    return (compact_prefix[:room] + suffix)[:CLIENT_ORDER_ID_MAX_LENGTH]


def _okx_side(action: PlannedAction) -> str:
    """Map a planned action to an OKX order side.

    Opening a long or reducing a short buys; opening a short or reducing a
    long sells.
    """
    if action.reduce_only:
        return "sell" if action.side == "long" else "buy"
    return "buy" if action.side == "long" else "sell"