live_execution.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. from __future__ import annotations
  2. import json
  3. from dataclasses import asdict, dataclass
  4. from pathlib import Path
  5. from typing import Literal
  6. from okx_codex_trader.models import InstrumentMeta, Position
  7. from okx_codex_trader.okx_client import OkxClient, build_contract_size
# Direction of a (current or desired) position; "flat" means no exposure.
PositionSide = Literal["flat", "long", "short"]


@dataclass(frozen=True)
class TargetPosition:
    """A position expressed in strategy units.

    Used both for the desired target (derived from signals) and for the position
    currently open on OKX (see current_position_from_okx).
    """

    # Direction of the position.
    side: PositionSide
    # Size in strategy units; one unit corresponds to `margin_per_unit_usdt` of margin
    # in current_position_from_okx and render_market_order_bodies.
    unit: float
    # False when the position could not be determined; plan_position_delta then plans
    # no actions at all.
    known: bool
    # Human-readable explanation of how this position was derived (or why it is unknown).
    reason: str
@dataclass(frozen=True)
class RuntimeState:
    """Strategy state persisted between runs via save_runtime_state / load_runtime_state."""

    # Decision-candle timestamp most recently processed; signals whose candle is at or
    # before it are treated as already seen. Presumably epoch milliseconds — TODO confirm.
    last_candle_ts: int | None
    # Ids of the nextgen virtual legs currently held (stored sorted by target_from_signal).
    nextgen_active_legs: tuple[str, ...]
    # Micro-engine side; reset to None while nextgen is active, otherwise carried forward.
    micro_side: Literal["long", "short"] | None
@dataclass(frozen=True)
class PlannedAction:
    """One step produced by plan_position_delta."""

    # What the step does; "noop" steps are skipped when rendering orders.
    action: Literal["noop", "open", "increase", "reduce", "close", "reverse"]
    # Side the step acts on: the existing side for reduce/close, the new side otherwise.
    side: PositionSide
    # Size of this step in strategy units (a delta, not the resulting position size).
    unit: float
    # True for steps that only shrink an existing position (reduce/close).
    reduce_only: bool
@dataclass(frozen=True)
class ExecutionPlan:
    """The ordered steps needed to move from `current` to `target`."""

    # Desired position.
    target: TargetPosition
    # Position currently held.
    current: TargetPosition
    # Steps in execution order; empty when either position is unknown.
    actions: tuple[PlannedAction, ...]
@dataclass(frozen=True)
class RenderedOrder:
    """An OKX market-order body rendered from one PlannedAction, plus bookkeeping."""

    # Name of the PlannedAction this order implements (e.g. "open", "close").
    action: str
    # Margin this order represents: the action's units times margin_per_unit_usdt.
    margin_usdt: float
    # Request body as built by OkxClient.build_market_order_body.
    body: dict[str, str]
  36. EMPTY_STATE = RuntimeState(last_candle_ts=None, nextgen_active_legs=(), micro_side=None)
  37. def load_runtime_state(path: Path) -> RuntimeState:
  38. if not path.exists():
  39. return EMPTY_STATE
  40. payload = json.loads(path.read_text(encoding="utf-8"))
  41. return RuntimeState(
  42. last_candle_ts=payload["last_candle_ts"],
  43. nextgen_active_legs=tuple(payload["nextgen_active_legs"]),
  44. micro_side=payload["micro_side"],
  45. )
  46. def save_runtime_state(path: Path, state: RuntimeState) -> None:
  47. path.parent.mkdir(parents=True, exist_ok=True)
  48. path.write_text(json.dumps(asdict(state), indent=2, sort_keys=True) + "\n", encoding="utf-8")
  49. def _decision_candle_ts(payload: dict[str, object]) -> int:
  50. active_engine = str(payload["decision"]["active_engine"])
  51. if active_engine == "nextgen":
  52. nextgen = payload["nextgen"]
  53. if "data" in nextgen:
  54. return int(nextgen["data"]["decision_candle_ts"])
  55. return int(nextgen["decision"]["decision_candle_ts"])
  56. return int(payload["micro"]["decision_candle_ts"])
  57. def target_from_signal(payload: dict[str, object], state: RuntimeState) -> tuple[RuntimeState, TargetPosition]:
  58. candle_ts = _decision_candle_ts(payload)
  59. if state.last_candle_ts is not None and candle_ts <= state.last_candle_ts:
  60. return state, target_from_state(payload, state)
  61. active_engine = str(payload["decision"]["active_engine"])
  62. if active_engine == "nextgen":
  63. active = set(state.nextgen_active_legs)
  64. weights: dict[str, float] = {}
  65. for leg in payload["nextgen"]["legs"]:
  66. leg_id = str(leg["leg_id"])
  67. weights[leg_id] = float(leg["suggested_weight"])
  68. if bool(leg["signal"]):
  69. active.add(leg_id)
  70. elif leg_id in active and bool(leg["exit_signal"]):
  71. active.remove(leg_id)
  72. next_state = RuntimeState(last_candle_ts=candle_ts, nextgen_active_legs=tuple(sorted(active)), micro_side=None)
  73. return next_state, _nextgen_target(next_state, weights)
  74. return (
  75. RuntimeState(last_candle_ts=candle_ts, nextgen_active_legs=(), micro_side=state.micro_side),
  76. TargetPosition(
  77. side="flat",
  78. unit=0.0,
  79. known=False,
  80. reason="micro target position requires persistent micro exit state before live execution",
  81. ),
  82. )
  83. def target_from_state(payload: dict[str, object], state: RuntimeState) -> TargetPosition:
  84. active_engine = str(payload["decision"]["active_engine"])
  85. if active_engine != "nextgen":
  86. return TargetPosition(
  87. side="flat",
  88. unit=0.0,
  89. known=False,
  90. reason="micro target position requires persistent micro exit state before live execution",
  91. )
  92. weights = {str(leg["leg_id"]): float(leg["suggested_weight"]) for leg in payload["nextgen"]["legs"]}
  93. return _nextgen_target(state, weights)
  94. def _nextgen_target(state: RuntimeState, weights: dict[str, float]) -> TargetPosition:
  95. unit = sum(weights[leg_id] for leg_id in state.nextgen_active_legs)
  96. if unit <= 0.0:
  97. return TargetPosition(side="flat", unit=0.0, known=True, reason="no active nextgen virtual legs")
  98. return TargetPosition(side="long", unit=unit, known=True, reason="active nextgen virtual legs net to one long ETH target")
  99. def plan_position_delta(current: TargetPosition, target: TargetPosition) -> ExecutionPlan:
  100. if not current.known or not target.known:
  101. return ExecutionPlan(target=target, current=current, actions=())
  102. if current.side == target.side and current.unit == target.unit:
  103. return ExecutionPlan(target=target, current=current, actions=(PlannedAction("noop", target.side, 0.0, False),))
  104. if current.side == "flat":
  105. return ExecutionPlan(target=target, current=current, actions=(PlannedAction("open", target.side, target.unit, False),))
  106. if target.side == "flat":
  107. return ExecutionPlan(target=target, current=current, actions=(PlannedAction("close", current.side, current.unit, True),))
  108. if current.side == target.side:
  109. if target.unit > current.unit:
  110. return ExecutionPlan(target=target, current=current, actions=(PlannedAction("increase", target.side, target.unit - current.unit, False),))
  111. return ExecutionPlan(target=target, current=current, actions=(PlannedAction("reduce", current.side, current.unit - target.unit, True),))
  112. return ExecutionPlan(
  113. target=target,
  114. current=current,
  115. actions=(
  116. PlannedAction("close", current.side, current.unit, True),
  117. PlannedAction("reverse", target.side, target.unit, False),
  118. ),
  119. )
  120. def current_position_from_okx(
  121. *,
  122. positions: list[Position],
  123. mark_price: float,
  124. metadata: InstrumentMeta,
  125. leverage: int,
  126. margin_per_unit_usdt: float,
  127. ) -> TargetPosition:
  128. if leverage <= 0 or margin_per_unit_usdt <= 0.0 or mark_price <= 0.0 or metadata.ct_val <= 0.0:
  129. raise ValueError("position normalization inputs are invalid")
  130. active = [position for position in positions if position.size > 0.0]
  131. if not active:
  132. return TargetPosition(side="flat", unit=0.0, known=True, reason="no open OKX position")
  133. sides = {position.pos_side for position in active}
  134. if len(sides) != 1:
  135. return TargetPosition(side="flat", unit=0.0, known=False, reason="both OKX hedge sides are open")
  136. side = active[0].pos_side
  137. if side not in {"long", "short"}:
  138. return TargetPosition(side="flat", unit=0.0, known=False, reason="OKX position side is unsupported")
  139. notional = sum(position.size for position in active) * metadata.ct_val * mark_price
  140. margin = notional / leverage
  141. unit = margin / margin_per_unit_usdt
  142. return TargetPosition(side=side, unit=unit, known=True, reason="OKX position normalized by configured strategy unit margin")
  143. def render_market_order_bodies(
  144. *,
  145. plan: ExecutionPlan,
  146. symbol: str,
  147. mark_price: float,
  148. metadata: InstrumentMeta,
  149. leverage: int,
  150. margin_per_unit_usdt: float,
  151. max_new_margin_usdt: float,
  152. client_order_id_prefix: str,
  153. ) -> tuple[RenderedOrder, ...]:
  154. if leverage <= 0 or margin_per_unit_usdt <= 0.0 or max_new_margin_usdt < 0.0:
  155. raise ValueError("order rendering inputs are invalid")
  156. rendered: list[RenderedOrder] = []
  157. new_margin = 0.0
  158. index = 1
  159. for action in plan.actions:
  160. if action.action == "noop":
  161. continue
  162. margin = action.unit * margin_per_unit_usdt
  163. if margin <= 0.0 or action.side == "flat":
  164. raise ValueError("planned action is invalid")
  165. if not action.reduce_only:
  166. new_margin += margin
  167. if new_margin > max_new_margin_usdt:
  168. raise ValueError("new margin exceeds max_new_margin_usdt")
  169. side = _okx_side(action)
  170. size = build_contract_size(margin * leverage, mark_price, metadata)
  171. rendered.append(
  172. RenderedOrder(
  173. action=action.action,
  174. margin_usdt=margin,
  175. body=OkxClient.build_market_order_body(
  176. symbol=symbol,
  177. side=side,
  178. pos_side=action.side,
  179. size=size,
  180. client_order_id=f"{client_order_id_prefix}-{index}-{action.action}",
  181. reduce_only=action.reduce_only,
  182. ),
  183. )
  184. )
  185. index += 1
  186. return tuple(rendered)
  187. def _okx_side(action: PlannedAction) -> str:
  188. if action.reduce_only:
  189. return "sell" if action.side == "long" else "buy"
  190. return "buy" if action.side == "long" else "sell"