| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- import json
- from dataclasses import asdict
- from pathlib import Path
- from typing import Callable, Mapping
- from okx_codex_trader.models import PaperOrderResult, PaperPosition, PaperState, TradeSignal
def default_state() -> PaperState:
    """Build the starting paper-trading state.

    The account begins with 10,000 USDT of cash, zero realized PnL and no open
    positions; ``updated_at`` is a placeholder set to the Unix epoch.
    """
    starting_cash = 10000.0
    epoch_timestamp = "1970-01-01T00:00:00Z"
    return PaperState(
        updated_at=epoch_timestamp,
        positions=[],
        realized_pnl=0.0,
        cash_usdt=starting_cash,
    )
def load_state(path: Path) -> PaperState:
    """Load the paper-trading state stored as JSON at *path*.

    A path that does not exist yields ``default_state()``. The file is decoded
    as UTF-8 (the encoding JSON interchange requires) rather than the
    platform's locale encoding, so the result does not depend on the host.

    Raises:
        ValueError: ``"paper state is invalid"`` when the file is not valid
            UTF-8 or not valid JSON; ``parse_state`` raises the same error
            for structurally invalid contents.
    """
    if not path.exists():
        return default_state()
    try:
        text = path.read_text(encoding="utf-8")
    except UnicodeDecodeError as exc:
        # Report undecodable bytes the same way as malformed JSON.
        raise ValueError("paper state is invalid") from exc
    try:
        payload = json.loads(text)
    except json.JSONDecodeError as exc:
        raise ValueError("paper state is invalid") from exc
    return parse_state(payload)
- def parse_state(payload: Mapping[str, object]) -> PaperState:
- try:
- cash_usdt = float(payload["cash_usdt"])
- realized_pnl = float(payload["realized_pnl"])
- updated_at = payload["updated_at"]
- positions_payload = payload["positions"]
- except (KeyError, TypeError, ValueError) as exc:
- raise ValueError("paper state is invalid") from exc
- if not isinstance(updated_at, str) or not isinstance(positions_payload, list):
- raise ValueError("paper state is invalid")
- positions: list[PaperPosition] = []
- for entry in positions_payload:
- if not isinstance(entry, Mapping):
- raise ValueError("paper state is invalid")
- try:
- symbol = entry["symbol"]
- side = entry["side"]
- quantity = float(entry["quantity"])
- avg_entry_price = float(entry["avg_entry_price"])
- margin_used = float(entry["margin_used"])
- except (KeyError, TypeError, ValueError) as exc:
- raise ValueError("paper state is invalid") from exc
- if not isinstance(symbol, str) or side not in {"long", "short"}:
- raise ValueError("paper state is invalid")
- positions.append(
- PaperPosition(
- symbol=symbol,
- side=side,
- quantity=quantity,
- avg_entry_price=avg_entry_price,
- margin_used=margin_used,
- )
- )
- return PaperState(
- cash_usdt=cash_usdt,
- realized_pnl=realized_pnl,
- positions=positions,
- updated_at=updated_at,
- )
def save_state(path: Path, state: PaperState) -> None:
    """Persist *state* to *path* as indented UTF-8 JSON.

    The JSON is first written to a sibling ``<name>.tmp`` file, which then
    replaces *path* via ``Path.replace``. A crash mid-write therefore leaves
    the previous state file intact, not a truncated one that ``load_state``
    would reject as invalid.
    """
    # Serialize before touching the filesystem so encoding errors change nothing.
    payload = json.dumps(asdict(state), indent=2)
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        tmp_path.write_text(payload, encoding="utf-8")
        # Path.replace uses os.replace: an atomic rename on POSIX.
        tmp_path.replace(path)
    except BaseException:
        # Do not leave a half-written temp file behind; re-raise the failure.
        tmp_path.unlink(missing_ok=True)
        raise
def _require_positive_finite(name: str, value: float) -> None:
    """Raise ValueError unless *value* is a finite number greater than zero."""
    # The chained comparison is False for NaN as well as for <= 0 and +inf.
    if not 0 < value < float("inf"):
        raise ValueError(f"{name} must be a positive finite number")


def apply_signal(
    *,
    state: PaperState,
    symbol: str,
    signal: TradeSignal,
    margin_usdt: float,
    price: float,
    now: Callable[[], str],
) -> tuple[PaperState, PaperOrderResult]:
    """Simulate filling *signal* against the local paper-trading *state*.

    A ``"flat"`` signal is a no-op. A ``"long"`` / ``"short"`` signal buys
    ``margin_usdt * signal.leverage / price`` units. It either opens a new
    position or adds to the existing position with the same symbol and side;
    in that case the entry price is quantity-weighted and the margins summed.
    The margin is deducted from cash. *state* is not mutated; a new state is
    built with ``updated_at=now()``.

    Returns:
        ``(next_state, result)``. For a no-op, *state* itself is returned with
        a ``"noop"`` result; otherwise the result has status ``"filled"``.

    Raises:
        ValueError: if the action is not ``"long"``, ``"short"`` or ``"flat"``;
            if *margin_usdt*, *price* or ``signal.leverage`` is not a positive
            finite number; or if cash is below *margin_usdt*.
    """
    if signal.action == "flat":
        return state, PaperOrderResult(
            status="noop",
            symbol=symbol,
            side=None,
            price=None,
            quantity=None,
            margin_used=None,
            cash_usdt=state.cash_usdt,
        )
    # parse_state only accepts "long"/"short" position sides, so opening a
    # position for any other action would save a state that cannot be reloaded.
    if signal.action not in ("long", "short"):
        raise ValueError(f"unsupported signal action: {signal.action!r}")
    # Without these checks a negative margin would *increase* cash, a zero
    # price raises ZeroDivisionError, and NaN slips past the cash check below.
    _require_positive_finite("margin_usdt", margin_usdt)
    _require_positive_finite("price", price)
    _require_positive_finite("leverage", signal.leverage)
    if state.cash_usdt < margin_usdt:
        raise ValueError("insufficient local cash")

    quantity = margin_usdt * signal.leverage / price
    positions = list(state.positions)  # copy so the caller's state is untouched
    for index, position in enumerate(positions):
        if position.symbol != symbol or position.side != signal.action:
            continue
        # Add to the existing same-side position with a quantity-weighted entry price.
        total_quantity = position.quantity + quantity
        avg_entry_price = (
            position.quantity * position.avg_entry_price + quantity * price
        ) / total_quantity
        positions[index] = PaperPosition(
            symbol=symbol,
            side=signal.action,
            quantity=total_quantity,
            avg_entry_price=avg_entry_price,
            margin_used=position.margin_used + margin_usdt,
        )
        break
    else:
        # No same-symbol, same-side position yet: open a new one.
        positions.append(
            PaperPosition(
                symbol=symbol,
                side=signal.action,
                quantity=quantity,
                avg_entry_price=price,
                margin_used=margin_usdt,
            )
        )
    next_state = PaperState(
        cash_usdt=state.cash_usdt - margin_usdt,
        realized_pnl=state.realized_pnl,
        positions=positions,
        updated_at=now(),
    )
    return next_state, PaperOrderResult(
        status="filled",
        symbol=symbol,
        side=signal.action,
        price=price,
        quantity=quantity,
        margin_used=margin_usdt,
        cash_usdt=next_state.cash_usdt,
    )
|