import base64
import hashlib
import hmac
import json
from datetime import UTC, datetime
from decimal import Decimal, ROUND_DOWN
from urllib.parse import urlencode

from okx_codex_trader.config import Config
from okx_codex_trader.models import Candle, InstrumentMeta, OrderResult, Position, TradeSignal


def _positive_decimal(value: float, name: str) -> Decimal:
    """Return *value* as a Decimal, requiring a finite number greater than zero.

    Raises:
        ValueError: if *value* is zero, negative, NaN or infinite.
    """
    number = Decimal(str(value))
    if not number.is_finite() or number <= 0:
        raise ValueError(f"{name} must be positive")
    return number


def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
    """Return how many contracts *notional* buys at *price*, rounded down to the lot size.

    The raw count is ``notional / (price * metadata.ct_val)``; it is truncated to a
    whole multiple of ``metadata.lot_sz``. Arithmetic is done in ``Decimal`` built
    from ``str`` of each float so lot rounding is not skewed by binary float error.

    Raises:
        ValueError: if *notional* is not finite; if *price*, ``ct_val`` or ``lot_sz``
            is not a finite positive number (these would otherwise divide by zero
            or yield a meaningless size); or if the rounded size is below
            ``metadata.min_sz``.
    """
    amount = Decimal(str(notional))
    if not amount.is_finite():
        raise ValueError("notional must be finite")
    unit_price = _positive_decimal(price, "price")
    contract_value = _positive_decimal(metadata.ct_val, "contract value")
    lot_size = _positive_decimal(metadata.lot_sz, "lot size")
    raw_size = amount / (unit_price * contract_value)
    size = (raw_size / lot_size).to_integral_value(rounding=ROUND_DOWN) * lot_size
    if size < Decimal(str(metadata.min_sz)):
        raise ValueError("contract size below minimum")
    return float(size)


def _format_number(value: float) -> str:
    """Render *value* as plain decimal text with no exponent or trailing zeros.

    For example ``100.0`` -> ``"100"`` and ``1e-05`` -> ``"0.00001"``.
    """
    return format(Decimal(str(value)).normalize(), "f")


class OkxClient:
    """Signed OKX v5 REST client whose requests carry ``x-simulated-trading: 1``."""

    base_url = "https://www.okx.com"
    # Seconds to wait for a reply; without a timeout a stalled connection would
    # block the caller indefinitely.
    timeout = 10

    def __init__(self, config: Config, session=None):
        """Store *config* (API credentials) and an HTTP session.

        When *session* is None a ``requests.Session`` is created lazily, so the
        ``requests`` import is only needed when no session is injected.
        """
        self.config = config
        if session is None:
            import requests

            session = requests.Session()
        self.session = session

    def _invalid_payload(self) -> ValueError:
        """Build the error raised whenever a reply does not have the expected shape."""
        return ValueError("okx response payload is invalid")

    def _first_item(self, data: list[dict[str, object]]) -> dict[str, object]:
        """Return the first entry of *data*, which must be a dict.

        Raises:
            ValueError: if *data* is empty or its first entry is not a dict.
        """
        if not data or not isinstance(data[0], dict):
            raise self._invalid_payload()
        return data[0]

    def _request(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, object] | None = None,
        json_body: dict[str, object] | None = None,
    ) -> list[dict[str, object]]:
        """Send a signed OKX request and return the reply's ``data`` list.

        The signature is base64(HMAC-SHA256(api_secret, timestamp + METHOD +
        path[?query] + body)). The exact body string that is signed is also the
        one transmitted; letting the HTTP library re-serialise ``json_body``
        would change its whitespace and invalidate the signature.

        Raises:
            ValueError: on an HTTP status >= 400, an OKX ``code`` other than "0",
                or a reply that is not a JSON object with a list ``data`` field.
        """
        verb = method.upper()
        timestamp = datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
        query = urlencode(params or {})
        path_with_query = path if not query else f"{path}?{query}"
        # Compact separators; this exact string is both signed and sent below.
        body = "" if json_body is None else json.dumps(json_body, separators=(",", ":"))
        signature = base64.b64encode(
            hmac.new(
                self.config.api_secret.encode(),
                f"{timestamp}{verb}{path_with_query}{body}".encode(),
                hashlib.sha256,
            ).digest()
        ).decode()
        headers = {
            "Content-Type": "application/json",
            "OK-ACCESS-KEY": self.config.api_key,
            "OK-ACCESS-SIGN": signature,
            "OK-ACCESS-TIMESTAMP": timestamp,
            "OK-ACCESS-PASSPHRASE": self.config.api_passphrase,
            "x-simulated-trading": "1",
        }
        response = self.session.request(
            verb,
            f"{self.base_url}{path}",
            headers=headers,
            params=params,
            data=body or None,
            timeout=self.timeout,
        )
        status = getattr(response, "status_code", 200)
        try:
            payload = response.json()
        except ValueError:
            # Non-JSON reply, e.g. an HTML error page from a gateway.
            if status >= 400:
                raise ValueError(f"okx http error {status}") from None
            raise self._invalid_payload() from None
        if not isinstance(payload, dict):
            raise self._invalid_payload()
        if status >= 400:
            raise ValueError(str(payload.get("msg") or "okx http error"))
        if payload.get("code") != "0":
            raise ValueError(str(payload.get("msg") or payload.get("code") or "okx api error"))
        data = payload.get("data")
        if not isinstance(data, list):
            raise self._invalid_payload()
        return data

    def get_candles(self, symbol: str, bar: str, limit: int) -> list[Candle]:
        """Fetch up to *limit* historical candles of size *bar* for *symbol*.

        Each OKX entry is read positionally as ts, open, high, low, close, volume.

        Raises:
            ValueError: if the request fails or an entry cannot be parsed.
        """
        data = self._request(
            "GET",
            "/api/v5/market/history-candles",
            params={"instId": symbol, "bar": bar, "limit": limit},
        )
        try:
            return [
                Candle(
                    symbol=symbol,
                    ts=int(entry[0]),
                    open=float(entry[1]),
                    high=float(entry[2]),
                    low=float(entry[3]),
                    close=float(entry[4]),
                    volume=float(entry[5]),
                )
                for entry in data
            ]
        except (IndexError, KeyError, TypeError, ValueError):
            raise self._invalid_payload() from None

    def get_instrument_meta(self, symbol: str) -> InstrumentMeta:
        """Return contract value, lot size and minimum size of the SWAP *symbol*.

        Raises:
            ValueError: if the request fails or the fields are missing/unparsable.
        """
        data = self._request(
            "GET",
            "/api/v5/public/instruments",
            params={"instType": "SWAP", "instId": symbol},
        )
        instrument = self._first_item(data)
        try:
            return InstrumentMeta(
                ct_val=float(instrument["ctVal"]),
                lot_sz=float(instrument["lotSz"]),
                min_sz=float(instrument["minSz"]),
            )
        except (KeyError, TypeError, ValueError):
            raise self._invalid_payload() from None

    def get_last_price(self, symbol: str) -> float:
        """Return the ``last`` field of the ticker for *symbol*.

        Raises:
            ValueError: if the request fails or ``last`` is missing/unparsable.
        """
        data = self._request("GET", "/api/v5/market/ticker", params={"instId": symbol})
        ticker = self._first_item(data)
        try:
            return float(ticker["last"])
        except (KeyError, TypeError, ValueError):
            raise self._invalid_payload() from None

    def ensure_hedge_mode(self) -> None:
        """Require the account ``posMode`` to be ``long_short_mode``.

        Raises:
            ValueError: if it is not, or if the request fails.
        """
        data = self._request("GET", "/api/v5/account/config")
        config = self._first_item(data)
        if config.get("posMode") != "long_short_mode":
            raise ValueError("hedge mode is required")

    def set_leverage(self, symbol: str, leverage: int, pos_side: str) -> None:
        """Set isolated-margin *leverage* for the *pos_side* side of *symbol*.

        Raises:
            ValueError: if the request fails.
        """
        self._request(
            "POST",
            "/api/v5/account/set-leverage",
            json_body={
                "instId": symbol,
                "lever": str(leverage),
                "mgnMode": "isolated",
                "posSide": pos_side,
            },
        )

    def place_demo_order(self, symbol: str, signal: TradeSignal, margin_usdt: float) -> OrderResult:
        """Open a position for *signal* on the swap *symbol*.

        A ``"flat"`` signal places nothing and returns a ``"noop"`` result.
        Otherwise the order notional is ``margin_usdt * signal.leverage`` in
        isolated margin: a limit order at ``signal.entry_price`` when one is
        given, else a market order sized from the last traded price.

        All local validation (symbol, action, leverage, contract size) runs before
        any account-changing request, so an order that would be rejected locally
        never alters the account leverage.

        Raises:
            ValueError: for a non-swap symbol, an action other than "long"/"short",
                leverage outside 1..3, a size below the instrument minimum, an
                account not in hedge mode, or any failed/invalid OKX reply.
        """
        if signal.action == "flat":
            return OrderResult(
                status="noop",
                order_id=None,
                symbol=symbol,
                side=None,
                pos_side=None,
                order_type=None,
                size=None,
            )
        if not symbol.endswith("-SWAP"):
            raise ValueError("swap instrument is required")
        # Reject anything else rather than letting it fall through to a short.
        if signal.action not in ("long", "short"):
            raise ValueError("signal action is invalid")
        if signal.leverage < 1 or signal.leverage > 3:
            raise ValueError("leverage is invalid")
        metadata = self.get_instrument_meta(symbol)
        price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
        side = "buy" if signal.action == "long" else "sell"
        pos_side = "long" if signal.action == "long" else "short"
        # Size locally first so a too-small order fails before touching the account.
        size = build_contract_size(margin_usdt * signal.leverage, price, metadata)
        self.ensure_hedge_mode()
        self.set_leverage(symbol, signal.leverage, pos_side)
        order_type = "market" if signal.entry_price is None else "limit"
        request_body = {
            "instId": symbol,
            "tdMode": "isolated",
            "side": side,
            "posSide": pos_side,
            "ordType": order_type,
            "sz": _format_number(size),
        }
        if signal.entry_price is not None:
            request_body["px"] = _format_number(signal.entry_price)
        data = self._request("POST", "/api/v5/trade/order", json_body=request_body)
        order = self._first_item(data)
        order_id = str(order.get("ordId") or "")
        if not order_id:
            raise self._invalid_payload()
        return OrderResult(
            status="placed",
            order_id=order_id,
            symbol=symbol,
            side=side,
            pos_side=pos_side,
            order_type=order_type,
            size=size,
        )

    def get_positions(self, symbol: str) -> list[Position]:
        """Return the open positions OKX reports for *symbol* (empty list if none).

        Raises:
            ValueError: if the request fails or an entry is missing/unparsable.
        """
        data = self._request("GET", "/api/v5/account/positions", params={"instId": symbol})
        try:
            return [
                Position(
                    symbol=str(entry["instId"]),
                    pos_side=str(entry["posSide"]),
                    size=float(entry["pos"]),
                    avg_price=float(entry["avgPx"]),
                )
                for entry in data
            ]
        except (KeyError, TypeError, ValueError):
            raise self._invalid_payload() from None