import base64
import hashlib
import hmac
import json
from datetime import UTC, datetime
from decimal import Decimal, InvalidOperation, ROUND_DOWN
from math import isfinite
from typing import TypeAlias
from urllib.parse import urlencode

from okx_codex_trader.config import Config
from okx_codex_trader.models import Candle, InstrumentMeta, OrderResult, Position, TradeSignal

# One entry of an OKX response's "data" array: an object for most endpoints,
# a positional array for the candle endpoint (see get_candles).
OkxRow: TypeAlias = dict[str, object] | list[object]


def _parse_finite_decimal(value: object) -> Decimal:
    """Convert *value* to a finite ``Decimal`` for contract sizing.

    ``bool`` is rejected explicitly because it is an ``int`` subclass. The
    value is routed through ``str()`` so a float such as ``0.1`` becomes
    ``Decimal("0.1")`` rather than its full binary expansion.

    Raises:
        ValueError: "contract sizing inputs are invalid" for bools, values
            that cannot be parsed, and NaN/infinite values.
    """
    if isinstance(value, bool):
        raise ValueError("contract sizing inputs are invalid")
    try:
        parsed = Decimal(str(value))
    except (InvalidOperation, TypeError, ValueError):
        raise ValueError("contract sizing inputs are invalid") from None
    if not parsed.is_finite():
        raise ValueError("contract sizing inputs are invalid")
    return parsed


def _parse_finite_float(value: object) -> float:
    """Convert an OKX response field to a finite ``float``.

    Raises:
        ValueError: "okx response payload is invalid" for bools, values
            ``float()`` rejects, and NaN/infinite values.
    """
    if isinstance(value, bool):
        raise ValueError("okx response payload is invalid")
    try:
        parsed = float(value)
    except (TypeError, ValueError):
        raise ValueError("okx response payload is invalid") from None
    if not isfinite(parsed):
        raise ValueError("okx response payload is invalid")
    return parsed


def _parse_valid_leverage(value: object) -> int:
    """Return *value* if it is a plain ``int`` in the range 1..3 inclusive.

    Raises:
        ValueError: "leverage is invalid" for bools, non-ints, or values
            outside 1..3.
    """
    if isinstance(value, bool) or not isinstance(value, int):
        raise ValueError("leverage is invalid")
    if value < 1 or value > 3:
        raise ValueError("leverage is invalid")
    return value


def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
    """Return the number of contracts for a *notional* position at *price*.

    The raw size ``notional / (price * ct_val)`` is rounded DOWN to a whole
    multiple of ``metadata.lot_sz``, so the order never exceeds *notional*.
    All arithmetic is done in ``Decimal`` to avoid float rounding drift.

    Raises:
        ValueError: "contract sizing inputs are invalid" if any input is
            non-finite or not strictly positive; "contract size below
            minimum" if the rounded size is smaller than ``metadata.min_sz``.
    """
    notional_decimal = _parse_finite_decimal(notional)
    price_decimal = _parse_finite_decimal(price)
    ct_val_decimal = _parse_finite_decimal(metadata.ct_val)
    lot_size = _parse_finite_decimal(metadata.lot_sz)
    min_size = _parse_finite_decimal(metadata.min_sz)
    if (
        notional_decimal <= 0
        or price_decimal <= 0
        or ct_val_decimal <= 0
        or lot_size <= 0
        or min_size <= 0
    ):
        raise ValueError("contract sizing inputs are invalid")
    raw_size = notional_decimal / (price_decimal * ct_val_decimal)
    # Floor to a whole number of lots.
    size = (raw_size / lot_size).to_integral_value(rounding=ROUND_DOWN) * lot_size
    if size < min_size:
        raise ValueError("contract size below minimum")
    return float(size)


def _format_number(value: float) -> str:
    """Render *value* as a plain decimal string for an OKX request field.

    ``normalize()`` strips trailing zeros and the ``"f"`` format prevents
    scientific notation, e.g. ``100.0 -> "100"`` and ``1e-07 -> "0.0000001"``.
    """
    return format(Decimal(str(value)).normalize(), "f")


def _okx_error_message(payload: dict[str, object]) -> str:
    """Build an error message from an OKX error payload.

    Starts with the top-level ``msg`` (falling back to ``code``, then a
    generic text) and appends ``"sCode: sMsg"`` for each per-item entry in
    ``data`` that carries either field. Parts are joined with ``"; "``.
    """
    parts = [str(payload.get("msg") or payload.get("code") or "okx api error")]
    data = payload.get("data")
    if isinstance(data, list):
        for item in data:
            if not isinstance(item, dict):
                continue
            code = item.get("sCode")
            msg = item.get("sMsg")
            if code or msg:
                parts.append(f"{code}: {msg}")
    return "; ".join(parts)


class OkxClient:
    """Thin client for the OKX v5 REST API, scoped to SWAP instruments.

    Without a ``Config`` only unsigned (public) requests are made; with one,
    every request is signed and routed to demo or live trading according to
    ``config.trading_env``.
    """

    base_url = "https://www.okx.com"
    request_timeout = 10.0

    def __init__(self, config: Config | None = None, session=None):
        """Store *config* and an HTTP session.

        ``requests`` is imported lazily so that callers injecting their own
        *session* (an object with a ``request`` method) do not need it.
        """
        self.config = config
        if session is None:
            import requests

            session = requests.Session()
        self.session = session

    def _invalid_payload(self) -> ValueError:
        """Return (not raise) the uniform malformed-response error."""
        return ValueError("okx response payload is invalid")

    def _transport_error(self) -> ValueError:
        """Return (not raise) the uniform network/transport error."""
        return ValueError("okx transport error")

    def _first_item(self, data: list[OkxRow]) -> dict[str, object]:
        """Return the first row of *data*, which must exist and be a dict.

        Raises:
            ValueError: "okx response payload is invalid" otherwise.
        """
        if not data:
            raise self._invalid_payload()
        item = data[0]
        if not isinstance(item, dict):
            raise self._invalid_payload()
        return item

    @staticmethod
    def build_post_only_limit_order_body(
        *,
        symbol: str,
        action: str,
        price: object,
        size: object,
        client_order_id: str,
    ) -> dict[str, str]:
        """Build an isolated-margin post-only limit order body.

        ``action`` "long" maps to side "buy", "short" to side "sell"; it is
        also used as ``posSide``.

        Raises:
            ValueError: "action is invalid" if *action* is not long/short.
        """
        if action not in {"long", "short"}:
            raise ValueError("action is invalid")
        return {
            "instId": symbol,
            "tdMode": "isolated",
            "side": "buy" if action == "long" else "sell",
            "posSide": action,
            "ordType": "post_only",
            "px": _format_number(price),
            "sz": _format_number(size),
            "clOrdId": client_order_id,
        }

    @staticmethod
    def build_entry_batch_order_body(
        *,
        symbol: str,
        action: str,
        reference_price: object,
        margin_usdt: object,
        leverage: object,
        metadata: InstrumentMeta,
        client_order_id_prefix: str,
    ) -> list[dict[str, str]]:
        """Build three post-only entry orders laddered away from *reference_price*.

        The total notional ``margin_usdt * leverage`` is split equally across
        the three orders, placed 0.3%, 0.6% and 0.9% below the reference for a
        long (above it for a short). Client order ids are
        ``"{client_order_id_prefix}-1"`` .. ``"-3"``.

        NOTE(review): prices are not rounded to the instrument's tick size;
        OKX may reject a ``px`` that is not a tick multiple — confirm whether
        callers pre-round or whether InstrumentMeta should carry a tick size.

        Raises:
            ValueError: "action is invalid" for an unknown *action*, or any
                error from ``build_contract_size``.
        """
        # Validate up front so a bad action is not misreported as a sizing error.
        if action not in {"long", "short"}:
            raise ValueError("action is invalid")
        reference_price_decimal = _parse_finite_decimal(reference_price)
        notional_per_order = (
            _parse_finite_decimal(margin_usdt) * _parse_finite_decimal(leverage) / Decimal("3")
        )
        bodies = []
        offsets = (Decimal("0.003"), Decimal("0.006"), Decimal("0.009"))
        for index, offset in enumerate(offsets, start=1):
            multiplier = Decimal("1") - offset if action == "long" else Decimal("1") + offset
            price = reference_price_decimal * multiplier
            size = build_contract_size(notional_per_order, price, metadata)
            bodies.append(
                OkxClient.build_post_only_limit_order_body(
                    symbol=symbol,
                    action=action,
                    price=price,
                    size=size,
                    client_order_id=f"{client_order_id_prefix}-{index}",
                )
            )
        return bodies

    @staticmethod
    def build_cancel_order_body(
        *,
        symbol: str,
        order_id: str | None = None,
        client_order_id: str | None = None,
    ) -> dict[str, str]:
        """Build a cancel-order body identified by exactly one of the two ids.

        Raises:
            ValueError: if both or neither identifier is given (empty strings
                count as not given).
        """
        if bool(order_id) == bool(client_order_id):
            raise ValueError("exactly one order identifier is required")
        body = {"instId": symbol}
        if order_id:
            body["ordId"] = order_id
        if client_order_id:
            body["clOrdId"] = client_order_id
        return body

    @staticmethod
    def build_market_order_body(
        *,
        symbol: str,
        side: str,
        pos_side: str,
        size: object,
        client_order_id: str,
        reduce_only: bool,
    ) -> dict[str, str]:
        """Build an isolated-margin market order body.

        ``reduceOnly`` is included (as ``"true"``) only when *reduce_only*.

        Raises:
            ValueError: if *side* is not buy/sell or *pos_side* not long/short.
        """
        if side not in {"buy", "sell"}:
            raise ValueError("side is invalid")
        if pos_side not in {"long", "short"}:
            raise ValueError("pos_side is invalid")
        body = {
            "instId": symbol,
            "tdMode": "isolated",
            "side": side,
            "posSide": pos_side,
            "ordType": "market",
            "sz": _format_number(size),
            "clOrdId": client_order_id,
        }
        if reduce_only:
            body["reduceOnly"] = "true"
        return body

    @staticmethod
    def build_pending_orders_params(*, symbol: str) -> dict[str, str]:
        """Query parameters for listing pending SWAP orders on *symbol*."""
        return {"instType": "SWAP", "instId": symbol}

    @staticmethod
    def build_fills_params(*, symbol: str) -> dict[str, str]:
        """Query parameters for listing SWAP fills on *symbol*."""
        return {"instType": "SWAP", "instId": symbol}

    def _request(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, object] | None = None,
        json_body: dict[str, object] | None = None,
    ) -> list[OkxRow]:
        """Send one request to OKX and return the payload's ``data`` list.

        When ``self.config`` is set the request is signed: base64 of an
        HMAC-SHA256, keyed by the API secret, over
        ``timestamp + METHOD + path[?query] + body``. The query string used
        for signing is ``urlencode(params)``; the same *params* are handed to
        the session. The body is compact JSON (no spaces).

        Raises:
            ValueError: "okx transport error" if the session call raises;
                "okx response payload is invalid" if the body is not a JSON
                object or has no ``data`` list; the OKX message for an HTTP
                status >= 400 or a ``code`` other than ``"0"``.
        """
        timestamp = datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
        query = urlencode(params or {})
        path_with_query = path if not query else f"{path}?{query}"
        body = "" if json_body is None else json.dumps(json_body, separators=(",", ":"))
        headers: dict[str, str] = {}
        if self.config is not None:
            signature = base64.b64encode(
                hmac.new(
                    self.config.api_secret.encode(),
                    f"{timestamp}{method.upper()}{path_with_query}{body}".encode(),
                    hashlib.sha256,
                ).digest()
            ).decode()
            headers = {
                "OK-ACCESS-KEY": self.config.api_key,
                "OK-ACCESS-SIGN": signature,
                "OK-ACCESS-TIMESTAMP": timestamp,
                "OK-ACCESS-PASSPHRASE": self.config.api_passphrase,
                "x-simulated-trading": "1" if self.config.trading_env == "demo" else "0",
            }
        if json_body is not None:
            headers["Content-Type"] = "application/json"
        try:
            response = self.session.request(
                method.upper(),
                f"{self.base_url}{path}",
                headers=headers,
                params=params,
                data=body if json_body is not None else None,
                timeout=self.request_timeout,
            )
        except Exception:
            # Any session failure (DNS, timeout, connection reset, ...) is
            # reported uniformly without leaking library-specific types.
            raise self._transport_error() from None
        try:
            payload = response.json()
        except Exception:
            raise self._invalid_payload() from None
        if not isinstance(payload, dict):
            raise self._invalid_payload()
        if getattr(response, "status_code", 200) >= 400:
            raise ValueError(str(payload.get("msg") or "okx http error"))
        if payload.get("code") != "0":
            raise ValueError(_okx_error_message(payload))
        data = payload.get("data")
        if not isinstance(data, list):
            raise self._invalid_payload()
        return data

    def get_candles(self, symbol: str, bar: str, limit: int) -> list[Candle]:
        """Fetch up to *limit* confirmed candles for *symbol*, oldest first.

        Pages backward through ``/api/v5/market/history-candles`` (at most
        100 rows per request), passing ``after`` = oldest row ts - 1 to reach
        older data. Rows whose ``confirm`` field (index 8) is not ``"1"`` are
        skipped but still advance the cursor, so a page that only holds the
        still-forming candle does not end the scan. Duplicate timestamps are
        collapsed.

        Raises:
            ValueError: "okx response payload is invalid" for malformed rows,
                plus any error raised by ``_request``.
        """
        remaining = limit
        after: int | None = None
        candles_by_ts: dict[int, Candle] = {}
        while remaining > 0:
            page_limit = min(remaining, 100)
            params: dict[str, object] = {"instId": symbol, "bar": bar, "limit": page_limit}
            if after is not None:
                params["after"] = after
            data = self._request("GET", "/api/v5/market/history-candles", params=params)
            if not data:
                break
            try:
                row_timestamps: list[int] = []
                page = []
                for entry in data:
                    ts = int(entry[0])
                    row_timestamps.append(ts)
                    if str(entry[8]) != "1":
                        # Unconfirmed candle: exclude it, but keep its ts for paging.
                        continue
                    page.append(
                        Candle(
                            symbol=symbol,
                            ts=ts,
                            open=_parse_finite_float(entry[1]),
                            high=_parse_finite_float(entry[2]),
                            low=_parse_finite_float(entry[3]),
                            close=_parse_finite_float(entry[4]),
                            volume=_parse_finite_float(entry[5]),
                        )
                    )
            except (IndexError, KeyError, TypeError, ValueError):
                raise self._invalid_payload() from None
            for candle in page:
                candles_by_ts[candle.ts] = candle
            remaining = limit - len(candles_by_ts)
            if len(data) < page_limit:
                # A short page is treated as the end of the available history.
                break
            next_after = min(row_timestamps) - 1
            if after is not None and next_after >= after:
                # The cursor did not move to older data; stop instead of looping forever.
                break
            after = next_after
        return sorted(candles_by_ts.values(), key=lambda candle: candle.ts)[:limit]

    def get_instrument_meta(self, symbol: str) -> InstrumentMeta:
        """Return contract value, lot size and minimum size for SWAP *symbol*.

        Raises:
            ValueError: "okx response payload is invalid" if the first row is
                missing, belongs to another instrument/type, or has bad fields.
        """
        data = self._request(
            "GET",
            "/api/v5/public/instruments",
            params={"instType": "SWAP", "instId": symbol},
        )
        instrument = self._first_item(data)
        try:
            if instrument.get("instId") != symbol or instrument.get("instType") != "SWAP":
                raise self._invalid_payload()
            return InstrumentMeta(
                ct_val=_parse_finite_float(instrument["ctVal"]),
                lot_sz=_parse_finite_float(instrument["lotSz"]),
                min_sz=_parse_finite_float(instrument["minSz"]),
            )
        except (KeyError, TypeError, ValueError):
            raise self._invalid_payload() from None

    def get_last_price(self, symbol: str) -> float:
        """Return the ticker's ``last`` price for *symbol*.

        Raises:
            ValueError: "okx response payload is invalid" if the ticker row is
                missing, for another instrument, or has an unparsable price.
        """
        data = self._request("GET", "/api/v5/market/ticker", params={"instId": symbol})
        ticker = self._first_item(data)
        try:
            if ticker.get("instId") != symbol:
                raise self._invalid_payload()
            return _parse_finite_float(ticker["last"])
        except (KeyError, TypeError, ValueError):
            raise self._invalid_payload() from None

    def ensure_hedge_mode(self) -> None:
        """Require the account position mode to be ``long_short_mode``.

        Raises:
            ValueError: "hedge mode is required" for any other mode, or
                "okx response payload is invalid" if ``posMode`` is not a str.
        """
        data = self._request("GET", "/api/v5/account/config")
        config = self._first_item(data)
        pos_mode = config.get("posMode")
        if not isinstance(pos_mode, str):
            raise self._invalid_payload()
        if pos_mode != "long_short_mode":
            raise ValueError("hedge mode is required")

    def set_leverage(self, symbol: str, leverage: int, pos_side: str) -> None:
        """Set isolated-margin leverage (1..3) for one side of a SWAP.

        Raises:
            ValueError: for a non "-SWAP" symbol, invalid leverage, or a
                *pos_side* other than long/short.
        """
        if not symbol.endswith("-SWAP"):
            raise ValueError("swap instrument is required")
        leverage = _parse_valid_leverage(leverage)
        if pos_side not in {"long", "short"}:
            raise ValueError("pos_side is invalid")
        self._request(
            "POST",
            "/api/v5/account/set-leverage",
            json_body={
                "instId": symbol,
                "lever": str(leverage),
                "mgnMode": "isolated",
                "posSide": pos_side,
            },
        )

    def get_account_balance(self, currency: str = "USDT") -> dict[str, float]:
        """Return account total equity plus equity figures for *currency*.

        If *currency* has no entry in ``details``, its three figures are 0.0.

        Raises:
            ValueError: "okx response payload is invalid" if ``details`` is not
                a list of dicts or any required field is missing/unparsable.
        """
        data = self._request("GET", "/api/v5/account/balance", params={"ccy": currency})
        account = self._first_item(data)
        details = account.get("details")
        if not isinstance(details, list):
            raise self._invalid_payload()
        try:
            total_equity_usd = _parse_finite_float(account["totalEq"])
            for detail in details:
                if not isinstance(detail, dict):
                    raise self._invalid_payload()
                if detail.get("ccy") != currency:
                    continue
                return {
                    "total_equity_usd": total_equity_usd,
                    "equity": _parse_finite_float(detail["eq"]),
                    "available_equity": _parse_finite_float(detail["availEq"]),
                    "cash_balance": _parse_finite_float(detail["cashBal"]),
                }
        except (KeyError, TypeError, ValueError):
            # Missing keys surface as the same uniform error as bad values.
            raise self._invalid_payload() from None
        return {
            "total_equity_usd": total_equity_usd,
            "equity": 0.0,
            "available_equity": 0.0,
            "cash_balance": 0.0,
        }

    def place_order(self, symbol: str, signal: TradeSignal, margin_usdt: float) -> OrderResult:
        """Open a position on *symbol* according to *signal*.

        A "flat" signal returns a ``"noop"`` result without any request.
        Otherwise: validates inputs, requires hedge mode, sizes the order at
        ``margin_usdt * leverage`` notional (priced at ``signal.entry_price``
        or the last ticker price), sets leverage for the side, and submits a
        limit order when ``entry_price`` is given or a market order otherwise.

        Raises:
            ValueError: for invalid action, symbol, leverage or margin, sizing
                failures, or any API/payload error.
        """
        if signal.action == "flat":
            return OrderResult(
                status="noop",
                order_id=None,
                symbol=symbol,
                side=None,
                pos_side=None,
                order_type=None,
                size=None,
            )
        if signal.action not in {"long", "short"}:
            raise ValueError("action is invalid")
        if not symbol.endswith("-SWAP"):
            raise ValueError("swap instrument is required")
        leverage = _parse_valid_leverage(signal.leverage)
        try:
            margin_value = _parse_finite_float(margin_usdt)
            margin_decimal = _parse_finite_decimal(margin_usdt)
        except ValueError:
            raise ValueError("margin_usdt is invalid") from None
        if margin_value <= 0:
            raise ValueError("margin_usdt is invalid")
        self.ensure_hedge_mode()
        metadata = self.get_instrument_meta(symbol)
        price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
        side = "buy" if signal.action == "long" else "sell"
        pos_side = "long" if signal.action == "long" else "short"
        size = build_contract_size(margin_decimal * leverage, price, metadata)
        self.set_leverage(symbol, leverage, pos_side)
        order_type = "market" if signal.entry_price is None else "limit"
        request_body = {
            "instId": symbol,
            "tdMode": "isolated",
            "side": side,
            "posSide": pos_side,
            "ordType": order_type,
            "sz": _format_number(size),
        }
        if signal.entry_price is not None:
            request_body["px"] = _format_number(signal.entry_price)
        data = self._request("POST", "/api/v5/trade/order", json_body=request_body)
        order = self._first_item(data)
        order_id = str(order.get("ordId") or "")
        if not order_id:
            raise self._invalid_payload()
        return OrderResult(
            status="placed",
            order_id=order_id,
            symbol=symbol,
            side=side,
            pos_side=pos_side,
            order_type=order_type,
            size=size,
        )

    def submit_market_order_body(self, body: dict[str, str]) -> OrderResult:
        """Submit a prebuilt market order body (see ``build_market_order_body``).

        Raises:
            ValueError: "market order body is invalid" if a required field is
                empty or ``ordType`` is not "market"; a payload error if
                ``sz`` is not a finite number or no ``ordId`` comes back.
        """
        required = {"instId", "side", "posSide", "ordType", "sz", "clOrdId"}
        if any(not body.get(key) for key in required) or body.get("ordType") != "market":
            raise ValueError("market order body is invalid")
        _parse_finite_float(body.get("sz"))
        data = self._request("POST", "/api/v5/trade/order", json_body=body)
        order = self._first_item(data)
        order_id = str(order.get("ordId") or "")
        if not order_id:
            raise self._invalid_payload()
        return OrderResult(
            status="placed",
            order_id=order_id,
            symbol=body.get("instId"),
            side=body.get("side"),
            pos_side=body.get("posSide"),
            order_type=body.get("ordType"),
            size=_parse_finite_float(body.get("sz")),
        )

    def get_positions(self, symbol: str) -> list[Position]:
        """Return the non-zero long/short positions held on *symbol*.

        Raises:
            ValueError: "okx response payload is invalid" if a row refers to a
                different instrument, has an unknown ``posSide``, or has
                missing/unparsable fields.
        """
        requested_symbol = symbol
        data = self._request("GET", "/api/v5/account/positions", params={"instId": requested_symbol})
        if not data:
            return []
        try:
            positions = []
            for entry in data:
                size = _parse_finite_float(entry["pos"])
                if size == 0.0:
                    continue
                entry_symbol = entry["instId"]
                pos_side = entry["posSide"]
                if not isinstance(entry_symbol, str) or not isinstance(pos_side, str):
                    raise self._invalid_payload()
                if entry_symbol != requested_symbol:
                    raise self._invalid_payload()
                if pos_side not in {"long", "short"}:
                    raise self._invalid_payload()
                positions.append(
                    Position(
                        symbol=entry_symbol,
                        pos_side=pos_side,
                        size=size,
                        avg_price=_parse_finite_float(entry["avgPx"]),
                    )
                )
            return positions
        except (KeyError, TypeError, ValueError):
            raise self._invalid_payload() from None