"""Local paper-trading state: JSON persistence and simulated order fills."""

import json
import math
from dataclasses import asdict
from pathlib import Path
from typing import Callable, Mapping

from okx_codex_trader.models import (
    PaperOrderResult,
    PaperPosition,
    PaperState,
    TradeSignal,
)

# Message used for every ValueError raised while loading/parsing stored state.
_INVALID_STATE = "paper state is invalid"
# Position sides accepted in stored state and opened by apply_signal.
_POSITION_SIDES = frozenset({"long", "short"})


def default_state() -> PaperState:
    """Return the initial state: 10,000 USDT cash, zero PnL, no positions."""
    return PaperState(
        cash_usdt=10_000.0,
        realized_pnl=0.0,
        positions=[],
        updated_at="1970-01-01T00:00:00Z",
    )


def load_state(path: Path) -> PaperState:
    """Load the paper state stored as JSON at *path*.

    Returns ``default_state()`` when *path* does not exist.

    Raises:
        ValueError: if the file is not valid UTF-8 JSON, or its content does
            not have the shape ``parse_state`` expects.
    """
    if not path.exists():
        return default_state()
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except ValueError as exc:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError
        # subclasses) plus the plain ValueError that json raises on Python
        # 3.11+ for integer literals longer than the int-digit limit.
        raise ValueError(_INVALID_STATE) from exc
    return parse_state(payload)


def _finite_float(value: object) -> float:
    """Convert *value* with ``float()`` and require the result to be finite.

    Raises:
        TypeError: if ``float()`` rejects the type of *value*.
        ValueError: if ``float()`` rejects *value*, if the result is NaN or
            infinite, or if *value* is an integer too large for a float.
    """
    try:
        number = float(value)  # type: ignore[arg-type]
    except OverflowError as exc:
        # float() raises OverflowError (not ValueError) for huge ints.
        raise ValueError("number out of float range") from exc
    if not math.isfinite(number):
        raise ValueError("number must be finite")
    return number


def _parse_position(entry: object) -> PaperPosition:
    """Validate one element of the stored ``positions`` list.

    Raises:
        ValueError: if *entry* is not a mapping, lacks a required key, has a
            non-finite/non-numeric number, or has a bad symbol or side.
    """
    if not isinstance(entry, Mapping):
        raise ValueError(_INVALID_STATE)
    try:
        symbol = entry["symbol"]
        side = entry["side"]
        quantity = _finite_float(entry["quantity"])
        avg_entry_price = _finite_float(entry["avg_entry_price"])
        margin_used = _finite_float(entry["margin_used"])
    except (KeyError, TypeError, ValueError) as exc:
        raise ValueError(_INVALID_STATE) from exc
    # Check the type before the set-membership test: an unhashable JSON value
    # (list/object) would otherwise raise TypeError instead of ValueError.
    if (
        not isinstance(symbol, str)
        or not isinstance(side, str)
        or side not in _POSITION_SIDES
    ):
        raise ValueError(_INVALID_STATE)
    return PaperPosition(
        symbol=symbol,
        side=side,
        quantity=quantity,
        avg_entry_price=avg_entry_price,
        margin_used=margin_used,
    )


def parse_state(payload: Mapping[str, object]) -> PaperState:
    """Validate a decoded JSON payload and build a ``PaperState`` from it.

    Expects the keys ``cash_usdt``, ``realized_pnl`` (finite numbers),
    ``updated_at`` (string) and ``positions`` (list of position mappings).

    Raises:
        ValueError: on any missing key, wrong type, or non-finite number.
    """
    try:
        cash_usdt = _finite_float(payload["cash_usdt"])
        realized_pnl = _finite_float(payload["realized_pnl"])
        updated_at = payload["updated_at"]
        positions_payload = payload["positions"]
    except (KeyError, TypeError, ValueError) as exc:
        # TypeError also covers payloads that are not mappings at all
        # (e.g. a JSON list or string at the top level).
        raise ValueError(_INVALID_STATE) from exc
    if not isinstance(updated_at, str) or not isinstance(positions_payload, list):
        raise ValueError(_INVALID_STATE)
    return PaperState(
        cash_usdt=cash_usdt,
        realized_pnl=realized_pnl,
        positions=[_parse_position(entry) for entry in positions_payload],
        updated_at=updated_at,
    )


def save_state(path: Path, state: PaperState) -> None:
    """Write *state* to *path* as indented JSON.

    The JSON goes to a sibling ``<name>.tmp`` file, which is then renamed over
    *path* with ``Path.replace``. On POSIX that rename is atomic, so an
    interrupted save leaves the previous state file intact rather than
    truncated. The temporary file is removed if writing or renaming fails.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        tmp_path.write_text(json.dumps(asdict(state), indent=2), encoding="utf-8")
        tmp_path.replace(path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise


def apply_signal(
    *,
    state: PaperState,
    symbol: str,
    signal: TradeSignal,
    margin_usdt: float,
    price: float,
    now: Callable[[], str],
) -> tuple[PaperState, PaperOrderResult]:
    """Simulate a fill for *signal* on *symbol* at *price*.

    A ``"flat"`` signal is a no-op: *state* is returned unchanged with a
    ``"noop"`` result. Otherwise *margin_usdt* moves from cash into the
    (*symbol*, ``signal.action``) position. The position is sized as
    ``margin_usdt * signal.leverage / price``. An existing position on the
    same symbol and side is grown, and its entry price is re-averaged by
    quantity; otherwise a new position is opened. *state* itself is not
    mutated; a new ``PaperState`` stamped with ``now()`` is returned.

    Raises:
        ValueError: if ``signal.action`` is not ``"flat"``, ``"long"`` or
            ``"short"``; if *price*, *margin_usdt* or ``signal.leverage`` is
            not a positive finite number; or if cash is below *margin_usdt*.
    """
    if signal.action == "flat":
        return state, PaperOrderResult(
            status="noop",
            symbol=symbol,
            side=None,
            price=None,
            quantity=None,
            margin_used=None,
            cash_usdt=state.cash_usdt,
        )
    if signal.action not in _POSITION_SIDES:
        # parse_state rejects any other side, so opening one here would write
        # a state file that load_state can no longer read back.
        raise ValueError(f"unsupported signal action: {signal.action!r}")
    if not (math.isfinite(price) and price > 0):
        # Guards the division below (price == 0 -> ZeroDivisionError).
        raise ValueError("price must be a positive finite number")
    if not (math.isfinite(margin_usdt) and margin_usdt > 0):
        # A negative margin would pass the cash check and *add* cash.
        raise ValueError("margin_usdt must be a positive finite number")
    if not (math.isfinite(signal.leverage) and signal.leverage > 0):
        raise ValueError("leverage must be a positive finite number")
    if state.cash_usdt < margin_usdt:
        raise ValueError("insufficient local cash")

    quantity = margin_usdt * signal.leverage / price
    positions = list(state.positions)  # copy: never mutate the caller's state
    for index, position in enumerate(positions):
        if position.symbol == symbol and position.side == signal.action:
            total_quantity = position.quantity + quantity
            # Quantity-weighted average of the old entry price and this fill.
            avg_entry_price = (
                position.quantity * position.avg_entry_price + quantity * price
            ) / total_quantity
            positions[index] = PaperPosition(
                symbol=symbol,
                side=signal.action,
                quantity=total_quantity,
                avg_entry_price=avg_entry_price,
                margin_used=position.margin_used + margin_usdt,
            )
            break
    else:
        # No open position on this symbol/side yet: open a new one.
        positions.append(
            PaperPosition(
                symbol=symbol,
                side=signal.action,
                quantity=quantity,
                avg_entry_price=price,
                margin_used=margin_usdt,
            )
        )

    next_state = PaperState(
        cash_usdt=state.cash_usdt - margin_usdt,
        realized_pnl=state.realized_pnl,
        positions=positions,
        updated_at=now(),
    )
    return next_state, PaperOrderResult(
        status="filled",
        symbol=symbol,
        side=signal.action,
        price=price,
        quantity=quantity,
        margin_used=margin_usdt,
        cash_usdt=next_state.cash_usdt,
    )