| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430 |
- import base64
- import hashlib
- import hmac
- import json
- from datetime import UTC, datetime
- from decimal import Decimal, InvalidOperation, ROUND_DOWN
- from math import isfinite
- from typing import TypeAlias
- from urllib.parse import urlencode
- from okx_codex_trader.config import Config
- from okx_codex_trader.models import Candle, InstrumentMeta, OrderResult, Position, TradeSignal
# One element of an OKX response ``data`` array as consumed in this module:
# candle rows are read positionally (list), other endpoints' rows are read
# as objects (dict).
OkxRow: TypeAlias = dict[str, object] | list[object]
def _parse_finite_decimal(value: object) -> Decimal:
    """Convert ``value`` to a finite ``Decimal`` via its ``str`` form.

    Going through ``str`` means a float such as ``0.1`` becomes
    ``Decimal("0.1")`` rather than its binary expansion. Booleans are rejected
    even though they are ``int`` subclasses.

    Raises:
        ValueError: "contract sizing inputs are invalid" if ``value`` is a
            bool, cannot be parsed, or is NaN/infinite.
    """
    if not isinstance(value, bool):
        try:
            candidate = Decimal(str(value))
        except (InvalidOperation, TypeError, ValueError):
            candidate = None
        if candidate is not None and candidate.is_finite():
            return candidate
    raise ValueError("contract sizing inputs are invalid")
def _parse_finite_float(value: object) -> float:
    """Parse a numeric response field into a finite float.

    Anything ``float()`` accepts (e.g. numeric strings) is allowed. Booleans
    are rejected even though they are ``int`` subclasses.

    Raises:
        ValueError: "okx response payload is invalid" if ``value`` is a bool,
            cannot be converted, is out of float range, or is NaN/infinite.
    """
    if isinstance(value, bool):
        raise ValueError("okx response payload is invalid")
    try:
        parsed = float(value)
    except (TypeError, ValueError, OverflowError):
        # float() of an int beyond float range (e.g. a huge JSON integer)
        # raises OverflowError instead of returning inf; fold it into the
        # same error so callers' ValueError handling still applies.
        raise ValueError("okx response payload is invalid") from None
    if not isfinite(parsed):
        raise ValueError("okx response payload is invalid")
    return parsed
def _parse_valid_leverage(value: object) -> int:
    """Return ``value`` unchanged if it is a plain ``int`` in the range 1..3.

    Booleans and non-int numerics (e.g. ``2.0`` or ``"2"``) are rejected.

    Raises:
        ValueError: "leverage is invalid" otherwise.
    """
    is_plain_int = isinstance(value, int) and not isinstance(value, bool)
    if is_plain_int and 1 <= value <= 3:
        return value
    raise ValueError("leverage is invalid")
def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
    """Convert a notional value into a lot-aligned contract quantity.

    The raw quantity ``notional / (price * metadata.ct_val)`` is floored to a
    whole number of ``metadata.lot_sz`` steps. All arithmetic is done in
    ``Decimal`` (inputs are read via their ``str`` form), and the result is
    returned as a float.

    Raises:
        ValueError: "contract sizing inputs are invalid" if any input is not a
            finite positive number; "contract size below minimum" if the
            floored quantity is below ``metadata.min_sz``.
    """
    notional_dec = _parse_finite_decimal(notional)
    price_dec = _parse_finite_decimal(price)
    # Generator keeps the original order: each attribute is read, then parsed.
    contract_value, lot_step, minimum_size = (
        _parse_finite_decimal(getattr(metadata, field_name))
        for field_name in ("ct_val", "lot_sz", "min_sz")
    )
    if min(notional_dec, price_dec, contract_value, lot_step, minimum_size) <= 0:
        raise ValueError("contract sizing inputs are invalid")
    contracts = notional_dec / (price_dec * contract_value)
    whole_lots = (contracts / lot_step).to_integral_value(rounding=ROUND_DOWN)
    sized = whole_lots * lot_step
    if sized < minimum_size:
        raise ValueError("contract size below minimum")
    return float(sized)
def _format_number(value: float) -> str:
    """Render ``value`` as a plain decimal string for request fields.

    The value is round-tripped through ``str`` into ``Decimal``, normalized
    (trailing zeros stripped, rounded to the decimal context precision), and
    printed in fixed-point form so no exponent appears.
    """
    normalized = Decimal(str(value)).normalize()
    return f"{normalized:f}"
class OkxClient:
    """Minimal OKX v5 REST client for isolated-margin SWAP trading.

    All HTTP traffic goes through ``_request``, which signs calls when a
    ``Config`` is supplied and reports every failure as ``ValueError``.
    """

    base_url = "https://www.okx.com"
    request_timeout = 10.0
    # OKX documents 100 as the largest ``limit`` accepted by
    # /api/v5/market/history-candles. Requesting more makes every full page
    # look "short" and would stop pagination after a single request.
    history_candles_page_limit = 100

    def __init__(self, config: Config | None = None, session=None):
        """Create a client.

        Args:
            config: credentials and trading environment; when ``None`` no
                authentication headers are sent.
            session: object with a ``requests.Session``-compatible
                ``request`` method. Defaults to a new ``requests.Session``
                (``requests`` is imported lazily, only when no session is
                injected).
        """
        self.config = config
        if session is None:
            import requests

            session = requests.Session()
        self.session = session

    def _invalid_payload(self) -> ValueError:
        """Return (not raise) the error used for malformed responses."""
        return ValueError("okx response payload is invalid")

    def _transport_error(self) -> ValueError:
        """Return (not raise) the error used when the HTTP call itself fails."""
        return ValueError("okx transport error")

    def _first_item(self, data: list[OkxRow]) -> dict[str, object]:
        """Return the first row of ``data``; it must exist and be a dict."""
        if not data:
            raise self._invalid_payload()
        item = data[0]
        if not isinstance(item, dict):
            raise self._invalid_payload()
        return item

    @staticmethod
    def build_post_only_limit_order_body(
        *,
        symbol: str,
        action: str,
        price: object,
        size: object,
        client_order_id: str,
    ) -> dict[str, str]:
        """Build the JSON body for one isolated-margin post-only limit order.

        ``action`` ("long"/"short") selects both the side (buy/sell) and the
        hedge-mode ``posSide``. ``price`` and ``size`` are rendered with
        ``_format_number``.

        Raises:
            ValueError: "action is invalid" for any other ``action``.
        """
        if action not in {"long", "short"}:
            raise ValueError("action is invalid")
        return {
            "instId": symbol,
            "tdMode": "isolated",
            "side": "buy" if action == "long" else "sell",
            "posSide": action,
            "ordType": "post_only",
            "px": _format_number(price),
            "sz": _format_number(size),
            "clOrdId": client_order_id,
        }

    @staticmethod
    def build_entry_batch_order_body(
        *,
        symbol: str,
        action: str,
        reference_price: object,
        margin_usdt: object,
        leverage: object,
        metadata: InstrumentMeta,
        client_order_id_prefix: str,
    ) -> list[dict[str, str]]:
        """Build three post-only entry orders laddered away from a reference price.

        Each order's notional is ``margin_usdt * leverage / 3``. Long rungs are
        priced 0.3%, 0.6% and 0.9% below ``reference_price``; short rungs the
        same distances above it. Client order IDs are
        ``f"{client_order_id_prefix}-{n}"`` for n = 1, 2, 3.

        Raises:
            ValueError: "action is invalid" if ``action`` is not "long"/"short";
                otherwise any error from ``build_contract_size`` (non-finite or
                non-positive inputs, or a rung below the minimum size).
        """
        # Reject a bad action before sizing, so it is not masked by a sizing error.
        if action not in {"long", "short"}:
            raise ValueError("action is invalid")
        reference_price_decimal = _parse_finite_decimal(reference_price)
        notional_per_order = _parse_finite_decimal(margin_usdt) * _parse_finite_decimal(leverage) / Decimal("3")
        # NOTE(review): rung prices are not rounded to an exchange tick size
        # (InstrumentMeta as used here carries none) -- confirm OKX accepts them.
        # NOTE(review): OKX documents clOrdId as case-sensitive alphanumerics
        # (max 32 chars); the "-" separator may be rejected. Confirm before
        # changing, since callers may rebuild these IDs to cancel orders.
        bodies = []
        for index, offset in enumerate((Decimal("0.003"), Decimal("0.006"), Decimal("0.009")), start=1):
            multiplier = Decimal("1") - offset if action == "long" else Decimal("1") + offset
            price = reference_price_decimal * multiplier
            size = build_contract_size(notional_per_order, price, metadata)
            bodies.append(
                OkxClient.build_post_only_limit_order_body(
                    symbol=symbol,
                    action=action,
                    price=price,
                    size=size,
                    client_order_id=f"{client_order_id_prefix}-{index}",
                )
            )
        return bodies

    @staticmethod
    def build_cancel_order_body(
        *,
        symbol: str,
        order_id: str | None = None,
        client_order_id: str | None = None,
    ) -> dict[str, str]:
        """Build a cancel-order body identified by exactly one order identifier.

        Empty strings count as missing.

        Raises:
            ValueError: if both or neither of ``order_id`` / ``client_order_id``
                are given.
        """
        if bool(order_id) == bool(client_order_id):
            raise ValueError("exactly one order identifier is required")
        body = {"instId": symbol}
        if order_id:
            body["ordId"] = order_id
        if client_order_id:
            body["clOrdId"] = client_order_id
        return body

    @staticmethod
    def build_pending_orders_params(*, symbol: str) -> dict[str, str]:
        """Return query params for listing pending SWAP orders of ``symbol``."""
        return {"instType": "SWAP", "instId": symbol}

    @staticmethod
    def build_fills_params(*, symbol: str) -> dict[str, str]:
        """Return query params for listing SWAP fills of ``symbol``."""
        return {"instType": "SWAP", "instId": symbol}

    def _request(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, object] | None = None,
        json_body: dict[str, object] | None = None,
    ) -> list[OkxRow]:
        """Perform one REST call and return the response's ``data`` list.

        When ``config`` is set, the call is signed as
        base64(HMAC-SHA256(api_secret, timestamp + METHOD + path[?query] + body)).
        The signature is sent with the key, timestamp and passphrase headers,
        plus ``x-simulated-trading`` ("1" when ``trading_env == "demo"``).

        Raises:
            ValueError: "okx transport error" if ``session.request`` raises;
                "okx http error" (or the response ``msg``) for HTTP status
                >= 400; the response ``msg``/``code`` when ``code`` is not "0";
                "okx response payload is invalid" when the body is not a JSON
                object with a ``data`` list.
        """
        timestamp = datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
        query = urlencode(params or {})
        path_with_query = path if not query else f"{path}?{query}"
        body = "" if json_body is None else json.dumps(json_body, separators=(",", ":"))
        headers: dict[str, str] = {}
        if self.config is not None:
            # NOTE(review): the query is signed as encoded by ``urlencode``, but
            # ``params`` is handed to the session separately. The signature
            # only matches if the session encodes them identically (expected
            # for the plain str/int values used in this class -- confirm
            # for anything else).
            signature = base64.b64encode(
                hmac.new(
                    self.config.api_secret.encode(),
                    f"{timestamp}{method.upper()}{path_with_query}{body}".encode(),
                    hashlib.sha256,
                ).digest()
            ).decode()
            headers = {
                "OK-ACCESS-KEY": self.config.api_key,
                "OK-ACCESS-SIGN": signature,
                "OK-ACCESS-TIMESTAMP": timestamp,
                "OK-ACCESS-PASSPHRASE": self.config.api_passphrase,
                "x-simulated-trading": "1" if self.config.trading_env == "demo" else "0",
            }
        if json_body is not None:
            headers["Content-Type"] = "application/json"
        try:
            response = self.session.request(
                method.upper(),
                f"{self.base_url}{path}",
                headers=headers,
                params=params,
                data=body if json_body is not None else None,
                timeout=self.request_timeout,
            )
        except Exception:
            raise self._transport_error() from None
        status_code = getattr(response, "status_code", 200)
        try:
            payload = response.json()
        except Exception:
            # A non-JSON body on an HTTP error (e.g. a gateway error page) is
            # an HTTP failure, not a malformed response envelope.
            if status_code >= 400:
                raise ValueError("okx http error") from None
            raise self._invalid_payload() from None
        if not isinstance(payload, dict):
            raise self._invalid_payload()
        if status_code >= 400:
            raise ValueError(str(payload.get("msg") or "okx http error"))
        if payload.get("code") != "0":
            raise ValueError(str(payload.get("msg") or payload.get("code") or "okx api error"))
        data = payload.get("data")
        if not isinstance(data, list):
            raise self._invalid_payload()
        return data

    def get_candles(self, symbol: str, bar: str, limit: int) -> list[Candle]:
        """Fetch up to ``limit`` confirmed candles, sorted oldest first.

        Pages backwards through /api/v5/market/history-candles with the
        ``after`` cursor. Rows whose confirm flag (index 8) is not "1" are
        dropped from the result. They still count toward the raw page size and
        the cursor position, so filtering them never ends pagination early.

        Raises:
            ValueError: "okx response payload is invalid" if a row cannot be
                parsed, or any error from ``_request``.
        """
        candles_by_ts: dict[int, Candle] = {}
        after: int | None = None
        while len(candles_by_ts) < limit:
            page_limit = min(limit - len(candles_by_ts), self.history_candles_page_limit)
            params: dict[str, object] = {"instId": symbol, "bar": bar, "limit": page_limit}
            if after is not None:
                params["after"] = after
            data = self._request("GET", "/api/v5/market/history-candles", params=params)
            if not data:
                break
            try:
                row_timestamps = [int(entry[0]) for entry in data]
                page = [
                    Candle(
                        symbol=symbol,
                        ts=ts,
                        open=_parse_finite_float(entry[1]),
                        high=_parse_finite_float(entry[2]),
                        low=_parse_finite_float(entry[3]),
                        close=_parse_finite_float(entry[4]),
                        volume=_parse_finite_float(entry[5]),
                    )
                    for ts, entry in zip(row_timestamps, data)
                    if str(entry[8]) == "1"
                ]
            except (IndexError, KeyError, TypeError, ValueError):
                raise self._invalid_payload() from None
            for candle in page:
                candles_by_ts[candle.ts] = candle
            # The cursor comes from the raw rows, not just confirmed ones.
            next_after = min(row_timestamps) - 1
            # Stop if the cursor did not move back; otherwise the same page
            # would be requested forever.
            if after is not None and next_after >= after:
                break
            after = next_after
            # A short raw page means no older history is available.
            if len(data) < page_limit:
                break
        return sorted(candles_by_ts.values(), key=lambda candle: candle.ts)[:limit]

    def get_instrument_meta(self, symbol: str) -> InstrumentMeta:
        """Load contract value, lot size and minimum size for a SWAP instrument.

        Raises:
            ValueError: "okx response payload is invalid" if the first row is
                not the requested SWAP instrument or its fields are unparseable.
        """
        data = self._request(
            "GET",
            "/api/v5/public/instruments",
            params={"instType": "SWAP", "instId": symbol},
        )
        instrument = self._first_item(data)
        try:
            if instrument.get("instId") != symbol or instrument.get("instType") != "SWAP":
                raise self._invalid_payload()
            return InstrumentMeta(
                ct_val=_parse_finite_float(instrument["ctVal"]),
                lot_sz=_parse_finite_float(instrument["lotSz"]),
                min_sz=_parse_finite_float(instrument["minSz"]),
            )
        except (KeyError, TypeError, ValueError):
            raise self._invalid_payload() from None

    def get_last_price(self, symbol: str) -> float:
        """Return the ticker's ``last`` price for ``symbol``.

        Raises:
            ValueError: "okx response payload is invalid" if the ticker is for
                another instrument or ``last`` is missing or unparseable.
        """
        data = self._request("GET", "/api/v5/market/ticker", params={"instId": symbol})
        ticker = self._first_item(data)
        try:
            if ticker.get("instId") != symbol:
                raise self._invalid_payload()
            return _parse_finite_float(ticker["last"])
        except (KeyError, TypeError, ValueError):
            raise self._invalid_payload() from None

    def ensure_hedge_mode(self) -> None:
        """Raise unless the account's ``posMode`` is ``long_short_mode``.

        Raises:
            ValueError: "hedge mode is required" for any other mode;
                "okx response payload is invalid" if ``posMode`` is not a string.
        """
        data = self._request("GET", "/api/v5/account/config")
        config = self._first_item(data)
        pos_mode = config.get("posMode")
        if not isinstance(pos_mode, str):
            raise self._invalid_payload()
        if pos_mode != "long_short_mode":
            raise ValueError("hedge mode is required")

    def set_leverage(self, symbol: str, leverage: int, pos_side: str) -> None:
        """Set isolated-margin leverage for one position side of a SWAP instrument.

        Raises:
            ValueError: if ``symbol`` does not end in "-SWAP", ``leverage`` is
                not an int in 1..3, or ``pos_side`` is not "long"/"short".
        """
        if not symbol.endswith("-SWAP"):
            raise ValueError("swap instrument is required")
        leverage = _parse_valid_leverage(leverage)
        if pos_side not in {"long", "short"}:
            raise ValueError("pos_side is invalid")
        self._request(
            "POST",
            "/api/v5/account/set-leverage",
            json_body={
                "instId": symbol,
                "lever": str(leverage),
                "mgnMode": "isolated",
                "posSide": pos_side,
            },
        )

    def get_account_balance(self, currency: str = "USDT") -> dict[str, float]:
        """Return account-level and per-currency equity figures.

        The result holds ``total_equity_usd`` (account ``totalEq``) plus the
        matching ``currency`` detail row's ``eq``/``availEq``/``cashBal`` as
        ``equity``/``available_equity``/``cash_balance``. The three detail
        values are ``0.0`` when no detail row matches ``currency``.

        Raises:
            ValueError: "okx response payload is invalid" when required fields
                are missing or unparseable, or a detail row is not an object.
        """
        data = self._request("GET", "/api/v5/account/balance", params={"ccy": currency})
        account = self._first_item(data)
        details = account.get("details")
        if not isinstance(details, list):
            raise self._invalid_payload()
        try:
            for detail in details:
                if not isinstance(detail, dict):
                    raise self._invalid_payload()
                if detail.get("ccy") != currency:
                    continue
                return {
                    "total_equity_usd": _parse_finite_float(account["totalEq"]),
                    "equity": _parse_finite_float(detail["eq"]),
                    "available_equity": _parse_finite_float(detail["availEq"]),
                    "cash_balance": _parse_finite_float(detail["cashBal"]),
                }
            return {
                "total_equity_usd": _parse_finite_float(account["totalEq"]),
                "equity": 0.0,
                "available_equity": 0.0,
                "cash_balance": 0.0,
            }
        except (KeyError, TypeError, ValueError):
            # Report missing keys like any other malformed response instead of
            # letting a bare KeyError escape.
            raise self._invalid_payload() from None

    def place_order(self, symbol: str, signal: TradeSignal, margin_usdt: float) -> OrderResult:
        """Place one isolated-margin entry order for ``signal``.

        A "flat" signal returns a "noop" result without any request. Otherwise
        the method does the following:

        - verifies hedge mode and loads instrument metadata;
        - sizes the order from ``margin_usdt * leverage`` at
          ``signal.entry_price``, or at the last traded price when that is
          ``None``;
        - sets leverage for the position side;
        - submits a limit order at ``signal.entry_price``, or a market order
          when it is ``None``.

        Raises:
            ValueError: invalid action, non "-SWAP" symbol, leverage outside
                1..3, non-positive or non-finite ``margin_usdt``, account not
                in hedge mode, sizing failures, or API/payload errors.
        """
        if signal.action == "flat":
            return OrderResult(
                status="noop",
                order_id=None,
                symbol=symbol,
                side=None,
                pos_side=None,
                order_type=None,
                size=None,
            )
        if signal.action not in {"long", "short"}:
            raise ValueError("action is invalid")
        if not symbol.endswith("-SWAP"):
            raise ValueError("swap instrument is required")
        leverage = _parse_valid_leverage(signal.leverage)
        try:
            margin_value = _parse_finite_float(margin_usdt)
            margin_decimal = _parse_finite_decimal(margin_usdt)
        except (ValueError, OverflowError):
            # OverflowError: float() of an out-of-range int raises it rather
            # than ValueError.
            raise ValueError("margin_usdt is invalid") from None
        if margin_value <= 0:
            raise ValueError("margin_usdt is invalid")
        self.ensure_hedge_mode()
        metadata = self.get_instrument_meta(symbol)
        price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
        side = "buy" if signal.action == "long" else "sell"
        pos_side = "long" if signal.action == "long" else "short"
        size = build_contract_size(margin_decimal * leverage, price, metadata)
        self.set_leverage(symbol, leverage, pos_side)
        order_type = "market" if signal.entry_price is None else "limit"
        request_body = {
            "instId": symbol,
            "tdMode": "isolated",
            "side": side,
            "posSide": pos_side,
            "ordType": order_type,
            "sz": _format_number(size),
        }
        if signal.entry_price is not None:
            request_body["px"] = _format_number(signal.entry_price)
        data = self._request("POST", "/api/v5/trade/order", json_body=request_body)
        order = self._first_item(data)
        order_id = str(order.get("ordId") or "")
        if not order_id:
            raise self._invalid_payload()
        return OrderResult(
            status="placed",
            order_id=order_id,
            symbol=symbol,
            side=side,
            pos_side=pos_side,
            order_type=order_type,
            size=size,
        )

    def get_positions(self, symbol: str) -> list[Position]:
        """Return the positions in ``symbol`` whose ``pos`` is non-zero.

        Raises:
            ValueError: "okx response payload is invalid" if a row is for a
                different instrument, has a ``posSide`` other than
                "long"/"short", or has missing/unparseable fields.
        """
        requested_symbol = symbol
        data = self._request("GET", "/api/v5/account/positions", params={"instId": requested_symbol})
        if not data:
            return []
        try:
            positions = []
            for entry in data:
                size = _parse_finite_float(entry["pos"])
                if size == 0.0:
                    continue
                symbol = entry["instId"]
                pos_side = entry["posSide"]
                if not isinstance(symbol, str) or not isinstance(pos_side, str):
                    raise self._invalid_payload()
                if symbol != requested_symbol:
                    raise self._invalid_payload()
                if pos_side not in {"long", "short"}:
                    raise self._invalid_payload()
                positions.append(
                    Position(
                        symbol=symbol,
                        pos_side=pos_side,
                        size=size,
                        avg_price=_parse_finite_float(entry["avgPx"]),
                    )
                )
            return positions
        except (KeyError, TypeError, ValueError):
            raise self._invalid_payload() from None
|