paper_engine.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import json
  2. from dataclasses import asdict
  3. from pathlib import Path
  4. from typing import Callable, Mapping
  5. from okx_codex_trader.models import PaperOrderResult, PaperPosition, PaperState, TradeSignal
  6. def default_state() -> PaperState:
  7. return PaperState(
  8. cash_usdt=10_000.0,
  9. realized_pnl=0.0,
  10. positions=[],
  11. updated_at="1970-01-01T00:00:00Z",
  12. )
  13. def load_state(path: Path) -> PaperState:
  14. if not path.exists():
  15. return default_state()
  16. try:
  17. payload = json.loads(path.read_text())
  18. except json.JSONDecodeError as exc:
  19. raise ValueError("paper state is invalid") from exc
  20. return parse_state(payload)
  21. def parse_state(payload: Mapping[str, object]) -> PaperState:
  22. try:
  23. cash_usdt = float(payload["cash_usdt"])
  24. realized_pnl = float(payload["realized_pnl"])
  25. updated_at = payload["updated_at"]
  26. positions_payload = payload["positions"]
  27. except (KeyError, TypeError, ValueError) as exc:
  28. raise ValueError("paper state is invalid") from exc
  29. if not isinstance(updated_at, str) or not isinstance(positions_payload, list):
  30. raise ValueError("paper state is invalid")
  31. positions: list[PaperPosition] = []
  32. for entry in positions_payload:
  33. if not isinstance(entry, Mapping):
  34. raise ValueError("paper state is invalid")
  35. try:
  36. symbol = entry["symbol"]
  37. side = entry["side"]
  38. quantity = float(entry["quantity"])
  39. avg_entry_price = float(entry["avg_entry_price"])
  40. margin_used = float(entry["margin_used"])
  41. except (KeyError, TypeError, ValueError) as exc:
  42. raise ValueError("paper state is invalid") from exc
  43. if not isinstance(symbol, str) or side not in {"long", "short"}:
  44. raise ValueError("paper state is invalid")
  45. positions.append(
  46. PaperPosition(
  47. symbol=symbol,
  48. side=side,
  49. quantity=quantity,
  50. avg_entry_price=avg_entry_price,
  51. margin_used=margin_used,
  52. )
  53. )
  54. return PaperState(
  55. cash_usdt=cash_usdt,
  56. realized_pnl=realized_pnl,
  57. positions=positions,
  58. updated_at=updated_at,
  59. )
  60. def save_state(path: Path, state: PaperState) -> None:
  61. path.write_text(json.dumps(asdict(state), indent=2))
  62. def apply_signal(
  63. *,
  64. state: PaperState,
  65. symbol: str,
  66. signal: TradeSignal,
  67. margin_usdt: float,
  68. price: float,
  69. now: Callable[[], str],
  70. ) -> tuple[PaperState, PaperOrderResult]:
  71. if signal.action == "flat":
  72. return state, PaperOrderResult(
  73. status="noop",
  74. symbol=symbol,
  75. side=None,
  76. price=None,
  77. quantity=None,
  78. margin_used=None,
  79. cash_usdt=state.cash_usdt,
  80. )
  81. if state.cash_usdt < margin_usdt:
  82. raise ValueError("insufficient local cash")
  83. quantity = margin_usdt * signal.leverage / price
  84. positions = list(state.positions)
  85. for index, position in enumerate(positions):
  86. if position.symbol != symbol or position.side != signal.action:
  87. continue
  88. total_quantity = position.quantity + quantity
  89. avg_entry_price = (
  90. position.quantity * position.avg_entry_price + quantity * price
  91. ) / total_quantity
  92. positions[index] = PaperPosition(
  93. symbol=symbol,
  94. side=signal.action,
  95. quantity=total_quantity,
  96. avg_entry_price=avg_entry_price,
  97. margin_used=position.margin_used + margin_usdt,
  98. )
  99. break
  100. else:
  101. positions.append(
  102. PaperPosition(
  103. symbol=symbol,
  104. side=signal.action,
  105. quantity=quantity,
  106. avg_entry_price=price,
  107. margin_used=margin_usdt,
  108. )
  109. )
  110. next_state = PaperState(
  111. cash_usdt=state.cash_usdt - margin_usdt,
  112. realized_pnl=state.realized_pnl,
  113. positions=positions,
  114. updated_at=now(),
  115. )
  116. return next_state, PaperOrderResult(
  117. status="filled",
  118. symbol=symbol,
  119. side=signal.action,
  120. price=price,
  121. quantity=quantity,
  122. margin_used=margin_usdt,
  123. cash_usdt=next_state.cash_usdt,
  124. )