okx_client.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. import base64
  2. import hashlib
  3. import hmac
  4. import json
  5. from datetime import UTC, datetime
  6. from decimal import Decimal, InvalidOperation, ROUND_DOWN
  7. from math import isfinite
  8. from typing import TypeAlias
  9. from urllib.parse import urlencode
  10. from okx_codex_trader.config import Config
  11. from okx_codex_trader.models import Candle, InstrumentMeta, OrderResult, Position, TradeSignal
  12. OkxRow: TypeAlias = dict[str, object] | list[object]
  def _parse_finite_decimal(value: object) -> Decimal:
      """Convert *value* to a finite ``Decimal`` by parsing its ``str()`` form.

      Booleans are rejected explicitly even though ``bool`` subclasses ``int``.

      Raises:
          ValueError: "contract sizing inputs are invalid" when *value* is a bool,
              cannot be parsed as a decimal, or parses to NaN/Infinity.
      """
      failure = "contract sizing inputs are invalid"
      if isinstance(value, bool):
          raise ValueError(failure)
      try:
          result = Decimal(str(value))
      except (InvalidOperation, TypeError, ValueError):
          raise ValueError(failure) from None
      if result.is_finite():
          return result
      raise ValueError(failure)
  def _parse_finite_float(value: object) -> float:
      """Convert *value* to a finite ``float``.

      Booleans are rejected explicitly even though ``bool`` subclasses ``int``.

      Raises:
          ValueError: "okx response payload is invalid" when *value* is a bool,
              is not convertible by ``float()``, or converts to NaN/Infinity.
      """
      message = "okx response payload is invalid"
      if isinstance(value, bool):
          raise ValueError(message)
      try:
          number = float(value)
      except (TypeError, ValueError):
          raise ValueError(message) from None
      if isfinite(number):
          return number
      raise ValueError(message)
  def _parse_valid_leverage(value: object) -> int:
      """Return *value* unchanged if it is a plain ``int`` between 1 and 3 inclusive.

      Raises:
          ValueError: "leverage is invalid" for bools, non-ints, or out-of-range ints.
      """
      is_plain_int = isinstance(value, int) and not isinstance(value, bool)
      if is_plain_int and 1 <= value <= 3:
          return value
      raise ValueError("leverage is invalid")
  def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
      """Convert a quote-currency notional into a contract count for an instrument.

      The raw count ``notional / (price * metadata.ct_val)`` is truncated down to
      a whole multiple of ``metadata.lot_sz``. All arithmetic is performed in
      ``Decimal``; ``float`` or ``Decimal`` inputs are accepted.

      Returns:
          The truncated contract size as a ``float``.

      Raises:
          ValueError: "contract sizing inputs are invalid" if any input is not a
              finite, strictly positive number; "contract size below minimum" if
              the truncated size is smaller than ``metadata.min_sz``.
      """
      inputs = [
          _parse_finite_decimal(item)
          for item in (notional, price, metadata.ct_val, metadata.lot_sz, metadata.min_sz)
      ]
      if any(item <= 0 for item in inputs):
          raise ValueError("contract sizing inputs are invalid")
      notional_d, price_d, ct_val, lot, minimum = inputs
      # Whole number of lots, rounded toward zero so we never exceed the notional.
      lots = (notional_d / (price_d * ct_val) / lot).to_integral_value(rounding=ROUND_DOWN)
      contracts = lots * lot
      if contracts < minimum:
          raise ValueError("contract size below minimum")
      return float(contracts)
  def _format_number(value: float) -> str:
      """Render *value* as a plain decimal string: no exponent, no trailing zeros.

      The value is round-tripped through ``str()`` into ``Decimal`` so that a
      float such as ``0.1`` prints as ``"0.1"`` rather than its binary expansion.
      """
      normalized = Decimal(str(value)).normalize()
      return f"{normalized:f}"
  def _okx_error_message(payload: dict[str, object]) -> str:
      """Build a readable error string from an OKX error response envelope.

      The headline is the top-level ``msg`` (falling back to ``code``, then a
      generic text). Each dict entry in ``data`` that carries an ``sCode`` or
      ``sMsg`` is appended as ``"sCode: sMsg"``; segments are joined by ``"; "``.
      """
      headline = str(payload.get("msg") or payload.get("code") or "okx api error")
      details: list[str] = []
      items = payload.get("data")
      if isinstance(items, list):
          details = [
              f"{item.get('sCode')}: {item.get('sMsg')}"
              for item in items
              if isinstance(item, dict) and (item.get("sCode") or item.get("sMsg"))
          ]
      return "; ".join([headline, *details])
  class OkxClient:
      """Minimal OKX v5 REST client for isolated-margin perpetual swaps.

      ``_request`` performs all HTTP calls and signs them with the credentials in
      ``config`` when one is supplied. The ``build_*`` static methods only
      assemble request bodies/params and perform no I/O. API, validation and
      payload failures are reported as ``ValueError``.
      """

      base_url = "https://www.okx.com"
      # Seconds, passed as ``timeout=`` to every session request.
      request_timeout = 10.0

      def __init__(self, config: Config | None = None, session=None):
          """Store credentials and the HTTP session.

          Args:
              config: Credentials and trading environment. When ``None``, requests
                  are sent without authentication headers.
              session: Object exposing ``request(method, url, **kwargs)``. When
                  omitted, a ``requests.Session`` is created; ``requests`` is only
                  imported in that case.
          """
          self.config = config
          if session is None:
              import requests

              session = requests.Session()
          self.session = session

      def _invalid_payload(self) -> ValueError:
          """Return (not raise) the error used for malformed OKX responses."""
          return ValueError("okx response payload is invalid")

      def _transport_error(self) -> ValueError:
          """Return (not raise) the error used when the HTTP call itself fails."""
          return ValueError("okx transport error")

      def _first_item(self, data: list[OkxRow]) -> dict[str, object]:
          """Return the first row of *data*, which must exist and be a dict."""
          if not data:
              raise self._invalid_payload()
          item = data[0]
          if not isinstance(item, dict):
              raise self._invalid_payload()
          return item

      @staticmethod
      def build_post_only_limit_order_body(
          *,
          symbol: str,
          action: str,
          price: object,
          size: object,
          client_order_id: str,
      ) -> dict[str, object]:
          """Build a post-only, isolated-margin limit order request body.

          ``action`` "long" maps to side "buy" and "short" to side "sell"; it is
          also used verbatim as ``posSide``. ``price`` and ``size`` are rendered
          as plain decimal strings without exponent or trailing zeros.

          Raises:
              ValueError: if *action* is not "long" or "short".
          """
          if action not in {"long", "short"}:
              raise ValueError("action is invalid")
          return {
              "instId": symbol,
              "tdMode": "isolated",
              "side": "buy" if action == "long" else "sell",
              "posSide": action,
              "ordType": "post_only",
              "px": _format_number(price),
              "sz": _format_number(size),
              "clOrdId": client_order_id,
          }

      @staticmethod
      def build_entry_batch_order_body(
          *,
          symbol: str,
          action: str,
          reference_price: object,
          margin_usdt: object,
          leverage: object,
          metadata: InstrumentMeta,
          client_order_id_prefix: str,
      ) -> list[dict[str, str]]:
          """Build three post-only entry orders laddered away from *reference_price*.

          The notional ``margin_usdt * leverage`` is split evenly across three
          orders priced 0.3%, 0.6% and 0.9% below (long) or above (short) the
          reference price. Client order ids are ``f"{client_order_id_prefix}-{n}"``
          for n = 1, 2, 3.

          Raises:
              ValueError: if *action* is not "long"/"short", if an input is not a
                  finite positive number, or if a leg is below the minimum size.
          """
          # Validate before sizing: an unknown action would otherwise be priced as
          # a short leg and could fail with a misleading sizing error first.
          if action not in {"long", "short"}:
              raise ValueError("action is invalid")
          reference_price_decimal = _parse_finite_decimal(reference_price)
          notional_per_order = (
              _parse_finite_decimal(margin_usdt) * _parse_finite_decimal(leverage) / Decimal("3")
          )
          # NOTE(review): prices are not rounded to the instrument tick size
          # (InstrumentMeta exposes none here); OKX may reject off-tick prices.
          # NOTE(review): OKX documents clOrdId as alphanumeric (max 32 chars);
          # confirm the "-" separator below is accepted.
          bodies: list[dict[str, str]] = []
          for index, offset in enumerate((Decimal("0.003"), Decimal("0.006"), Decimal("0.009")), start=1):
              multiplier = Decimal("1") - offset if action == "long" else Decimal("1") + offset
              price = reference_price_decimal * multiplier
              size = build_contract_size(notional_per_order, price, metadata)
              bodies.append(
                  OkxClient.build_post_only_limit_order_body(
                      symbol=symbol,
                      action=action,
                      price=price,
                      size=size,
                      client_order_id=f"{client_order_id_prefix}-{index}",
                  )
              )
          return bodies

      @staticmethod
      def build_cancel_order_body(
          *,
          symbol: str,
          order_id: str | None = None,
          client_order_id: str | None = None,
      ) -> dict[str, str]:
          """Build a cancel-order body identified by exactly one order identifier.

          Raises:
              ValueError: if both or neither of *order_id* / *client_order_id*
                  are non-empty.
          """
          if bool(order_id) == bool(client_order_id):
              raise ValueError("exactly one order identifier is required")
          body = {"instId": symbol}
          if order_id:
              body["ordId"] = order_id
          if client_order_id:
              body["clOrdId"] = client_order_id
          return body

      @staticmethod
      def build_market_order_body(
          *,
          symbol: str,
          side: str,
          pos_side: str,
          size: object,
          client_order_id: str,
          reduce_only: bool,
          stop_loss_trigger_price: object | None = None,
          take_profit_trigger_price: object | None = None,
      ) -> dict[str, object]:
          """Build an isolated-margin market order body, optionally with TP/SL.

          When a stop-loss and/or take-profit trigger price is given, a single
          ``attachAlgoOrds`` entry is added whose order price is ``"-1"``.

          Raises:
              ValueError: for an invalid *side* or *pos_side*, or when TP/SL is
                  requested together with ``reduce_only``.
          """
          if side not in {"buy", "sell"}:
              raise ValueError("side is invalid")
          if pos_side not in {"long", "short"}:
              raise ValueError("pos_side is invalid")
          body: dict[str, object] = {
              "instId": symbol,
              "tdMode": "isolated",
              "side": side,
              "posSide": pos_side,
              "ordType": "market",
              "sz": _format_number(size),
              "clOrdId": client_order_id,
          }
          if reduce_only:
              body["reduceOnly"] = "true"
          if stop_loss_trigger_price is not None or take_profit_trigger_price is not None:
              if reduce_only:
                  raise ValueError("attached TP/SL is invalid for reduce-only orders")
              algo: dict[str, str] = {}
              if stop_loss_trigger_price is not None:
                  algo["slTriggerPx"] = _format_number(stop_loss_trigger_price)
                  algo["slOrdPx"] = "-1"
              if take_profit_trigger_price is not None:
                  algo["tpTriggerPx"] = _format_number(take_profit_trigger_price)
                  algo["tpOrdPx"] = "-1"
              body["attachAlgoOrds"] = [algo]
          return body

      @staticmethod
      def build_pending_orders_params(*, symbol: str) -> dict[str, str]:
          """Return query params listing pending orders for one swap instrument."""
          return {"instType": "SWAP", "instId": symbol}

      @staticmethod
      def build_fills_params(*, symbol: str) -> dict[str, str]:
          """Return query params listing fills for one swap instrument."""
          return {"instType": "SWAP", "instId": symbol}

      def _request(
          self,
          method: str,
          path: str,
          *,
          params: dict[str, object] | None = None,
          json_body: dict[str, object] | None = None,
      ) -> list[OkxRow]:
          """Send a request to the OKX REST API and return the envelope's ``data`` list.

          When ``self.config`` is set, the request is signed with base64-encoded
          HMAC-SHA256 over ``timestamp + METHOD + path[?query] + body``, and the
          OKX auth headers are attached. ``x-simulated-trading`` is "1" when
          ``config.trading_env == "demo"`` and "0" otherwise.

          Raises:
              ValueError: "okx transport error" if the session call raises;
                  "okx response payload is invalid" if the body is not a JSON
                  object with a list ``data``; the OKX message for HTTP >= 400 or
                  for a ``code`` other than "0".
          """
          timestamp = datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
          query = urlencode(params or {})
          path_with_query = path if not query else f"{path}?{query}"
          # Compact separators: the signed body must equal the bytes sent.
          body = "" if json_body is None else json.dumps(json_body, separators=(",", ":"))
          headers: dict[str, str] = {}
          if self.config is not None:
              # NOTE(review): the signature covers urlencode(params) while the
              # session encodes ``params`` itself; both must produce the same
              # query string — confirm for non-str param values.
              signature = base64.b64encode(
                  hmac.new(
                      self.config.api_secret.encode(),
                      f"{timestamp}{method.upper()}{path_with_query}{body}".encode(),
                      hashlib.sha256,
                  ).digest()
              ).decode()
              headers = {
                  "OK-ACCESS-KEY": self.config.api_key,
                  "OK-ACCESS-SIGN": signature,
                  "OK-ACCESS-TIMESTAMP": timestamp,
                  "OK-ACCESS-PASSPHRASE": self.config.api_passphrase,
                  "x-simulated-trading": "1" if self.config.trading_env == "demo" else "0",
              }
          if json_body is not None:
              headers["Content-Type"] = "application/json"
          try:
              response = self.session.request(
                  method.upper(),
                  f"{self.base_url}{path}",
                  headers=headers,
                  params=params,
                  data=body if json_body is not None else None,
                  timeout=self.request_timeout,
              )
          except Exception:
              raise self._transport_error() from None
          try:
              payload = response.json()
          except Exception:
              raise self._invalid_payload() from None
          if not isinstance(payload, dict):
              raise self._invalid_payload()
          if getattr(response, "status_code", 200) >= 400:
              raise ValueError(str(payload.get("msg") or "okx http error"))
          if payload.get("code") != "0":
              raise ValueError(_okx_error_message(payload))
          data = payload.get("data")
          if not isinstance(data, list):
              raise self._invalid_payload()
          return data

      def get_candles(self, symbol: str, bar: str, limit: int) -> list[Candle]:
          """Return up to *limit* confirmed candles from ``/market/history-candles``."""
          return self._get_candles_from_path("/api/v5/market/history-candles", symbol, bar, limit)

      def get_recent_candles(self, symbol: str, bar: str, limit: int) -> list[Candle]:
          """Return up to *limit* confirmed candles from ``/market/candles``."""
          return self._get_candles_from_path("/api/v5/market/candles", symbol, bar, limit)

      def _get_candles_from_path(self, path: str, symbol: str, bar: str, limit: int) -> list[Candle]:
          """Page backwards through a candle endpoint, collecting confirmed candles.

          Each call requests at most 100 rows. Later pages pass ``after`` set just
          below the oldest timestamp returned so far. Rows are positional:
          ``[ts, open, high, low, close, volume, ..., confirm]``; only rows whose
          index-8 flag is "1" are kept. Duplicate timestamps are de-duplicated.

          Returns:
              At most *limit* candles, sorted by ascending timestamp.

          Raises:
              ValueError: if a row is malformed (see ``_invalid_payload``) or the
                  request fails.
          """
          remaining = limit
          after: int | None = None
          candles_by_ts: dict[int, Candle] = {}
          while remaining > 0:
              page_limit = min(remaining, 100)
              params: dict[str, object] = {"instId": symbol, "bar": bar, "limit": page_limit}
              if after is not None:
                  params["after"] = after
              data = self._request("GET", path, params=params)
              if not data:
                  break
              try:
                  # The cursor comes from ALL returned rows, including unconfirmed
                  # ones, so paging still advances when a page holds no confirmed
                  # candle (e.g. the only row returned is the in-progress candle).
                  oldest_ts = min(int(entry[0]) for entry in data)
                  for entry in data:
                      if str(entry[8]) != "1":
                          continue
                      candle = Candle(
                          symbol=symbol,
                          ts=int(entry[0]),
                          open=_parse_finite_float(entry[1]),
                          high=_parse_finite_float(entry[2]),
                          low=_parse_finite_float(entry[3]),
                          close=_parse_finite_float(entry[4]),
                          volume=_parse_finite_float(entry[5]),
                      )
                      candles_by_ts[candle.ts] = candle
              except (IndexError, KeyError, TypeError, ValueError):
                  raise self._invalid_payload() from None
              remaining = limit - len(candles_by_ts)
              next_after = oldest_ts - 1
              if after is not None and next_after >= after:
                  # The endpoint did not move further back in time; stop instead
                  # of re-requesting the same window forever.
                  break
              after = next_after
              if len(data) < page_limit:
                  break
          return sorted(candles_by_ts.values(), key=lambda candle: candle.ts)[:limit]

      def get_instrument_meta(self, symbol: str) -> InstrumentMeta:
          """Fetch contract value, lot size and minimum size for a SWAP instrument.

          Raises:
              ValueError: if the response is for another instrument/type or a
                  field is missing or non-finite.
          """
          data = self._request(
              "GET",
              "/api/v5/public/instruments",
              params={"instType": "SWAP", "instId": symbol},
          )
          instrument = self._first_item(data)
          try:
              if instrument.get("instId") != symbol or instrument.get("instType") != "SWAP":
                  raise self._invalid_payload()
              return InstrumentMeta(
                  ct_val=_parse_finite_float(instrument["ctVal"]),
                  lot_sz=_parse_finite_float(instrument["lotSz"]),
                  min_sz=_parse_finite_float(instrument["minSz"]),
              )
          except (KeyError, TypeError, ValueError):
              raise self._invalid_payload() from None

      def get_last_price(self, symbol: str) -> float:
          """Return the ticker's ``last`` price for *symbol*."""
          data = self._request("GET", "/api/v5/market/ticker", params={"instId": symbol})
          ticker = self._first_item(data)
          try:
              if ticker.get("instId") != symbol:
                  raise self._invalid_payload()
              return _parse_finite_float(ticker["last"])
          except (KeyError, TypeError, ValueError):
              raise self._invalid_payload() from None

      def ensure_hedge_mode(self) -> None:
          """Raise unless the account config reports ``posMode == "long_short_mode"``.

          Raises:
              ValueError: "hedge mode is required" for any other position mode, or
                  the invalid-payload error when ``posMode`` is not a string.
          """
          data = self._request("GET", "/api/v5/account/config")
          config = self._first_item(data)
          pos_mode = config.get("posMode")
          if not isinstance(pos_mode, str):
              raise self._invalid_payload()
          if pos_mode != "long_short_mode":
              raise ValueError("hedge mode is required")

      def set_leverage(self, symbol: str, leverage: int, pos_side: str) -> None:
          """Set isolated-margin leverage (1-3) for one side of a ``-SWAP`` instrument.

          Raises:
              ValueError: for a non-swap symbol, invalid leverage, invalid
                  *pos_side*, or an API error.
          """
          if not symbol.endswith("-SWAP"):
              raise ValueError("swap instrument is required")
          leverage = _parse_valid_leverage(leverage)
          if pos_side not in {"long", "short"}:
              raise ValueError("pos_side is invalid")
          self._request(
              "POST",
              "/api/v5/account/set-leverage",
              json_body={
                  "instId": symbol,
                  "lever": str(leverage),
                  "mgnMode": "isolated",
                  "posSide": pos_side,
              },
          )

      def get_account_balance(self, currency: str = "USDT") -> dict[str, float]:
          """Return account total equity plus equity figures for *currency*.

          Returns:
              A dict with ``total_equity_usd`` (account ``totalEq``) and the
              currency detail's ``equity``/``available_equity``/``cash_balance``.
              The three detail values are 0.0 when *currency* has no detail row.

          Raises:
              ValueError: the invalid-payload error for any missing, mistyped or
                  non-finite field.
          """
          data = self._request("GET", "/api/v5/account/balance", params={"ccy": currency})
          account = self._first_item(data)
          details = account.get("details")
          if not isinstance(details, list):
              raise self._invalid_payload()
          try:
              for detail in details:
                  if not isinstance(detail, dict):
                      raise self._invalid_payload()
                  if detail.get("ccy") != currency:
                      continue
                  return {
                      "total_equity_usd": _parse_finite_float(account["totalEq"]),
                      "equity": _parse_finite_float(detail["eq"]),
                      "available_equity": _parse_finite_float(detail["availEq"]),
                      "cash_balance": _parse_finite_float(detail["cashBal"]),
                  }
              return {
                  "total_equity_usd": _parse_finite_float(account["totalEq"]),
                  "equity": 0.0,
                  "available_equity": 0.0,
                  "cash_balance": 0.0,
              }
          except (KeyError, TypeError, ValueError):
              # Match sibling readers: a missing key must not leak as a raw KeyError.
              raise self._invalid_payload() from None

      def place_order(self, symbol: str, signal: TradeSignal, margin_usdt: float) -> OrderResult:
          """Place a single isolated-margin entry order for *signal*.

          A "flat" signal returns a ``noop`` result without any request. Otherwise:
          hedge mode is verified, the size is derived from ``margin_usdt *
          leverage`` at the signal's entry price (or the ticker's last price when
          none is given), leverage is set for the position side, and a limit
          order (entry price given) or market order (no entry price) is sent.

          Raises:
              ValueError: for an invalid action, symbol, leverage or margin, a
                  sizing failure, or any API error.
          """
          if signal.action == "flat":
              return OrderResult(
                  status="noop",
                  order_id=None,
                  symbol=symbol,
                  side=None,
                  pos_side=None,
                  order_type=None,
                  size=None,
              )
          if signal.action not in {"long", "short"}:
              raise ValueError("action is invalid")
          if not symbol.endswith("-SWAP"):
              raise ValueError("swap instrument is required")
          leverage = _parse_valid_leverage(signal.leverage)
          try:
              margin_value = _parse_finite_float(margin_usdt)
              margin_decimal = _parse_finite_decimal(margin_usdt)
          except ValueError:
              raise ValueError("margin_usdt is invalid") from None
          if margin_value <= 0:
              raise ValueError("margin_usdt is invalid")
          self.ensure_hedge_mode()
          metadata = self.get_instrument_meta(symbol)
          price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
          side = "buy" if signal.action == "long" else "sell"
          pos_side = "long" if signal.action == "long" else "short"
          size = build_contract_size(margin_decimal * leverage, price, metadata)
          self.set_leverage(symbol, leverage, pos_side)
          order_type = "market" if signal.entry_price is None else "limit"
          request_body = {
              "instId": symbol,
              "tdMode": "isolated",
              "side": side,
              "posSide": pos_side,
              "ordType": order_type,
              "sz": _format_number(size),
          }
          if signal.entry_price is not None:
              request_body["px"] = _format_number(signal.entry_price)
          data = self._request("POST", "/api/v5/trade/order", json_body=request_body)
          order = self._first_item(data)
          order_id = str(order.get("ordId") or "")
          if not order_id:
              raise self._invalid_payload()
          return OrderResult(
              status="placed",
              order_id=order_id,
              symbol=symbol,
              side=side,
              pos_side=pos_side,
              order_type=order_type,
              size=size,
          )

      def submit_market_order_body(self, body: dict[str, object]) -> OrderResult:
          """Submit a pre-built market order body (see ``build_market_order_body``).

          Raises:
              ValueError: "market order body is invalid" when a required field is
                  empty, ``ordType`` is not "market", or ``sz`` is not a finite
                  positive number; otherwise any API error.
          """
          required = {"instId", "side", "posSide", "ordType", "sz", "clOrdId"}
          if any(not body.get(key) for key in required) or body.get("ordType") != "market":
              raise ValueError("market order body is invalid")
          try:
              size = _parse_finite_float(body.get("sz"))
          except ValueError:
              raise ValueError("market order body is invalid") from None
          if size <= 0:
              raise ValueError("market order body is invalid")
          data = self._request("POST", "/api/v5/trade/order", json_body=body)
          order = self._first_item(data)
          order_id = str(order.get("ordId") or "")
          if not order_id:
              raise self._invalid_payload()
          return OrderResult(
              status="placed",
              order_id=order_id,
              symbol=body.get("instId"),
              side=body.get("side"),
              pos_side=body.get("posSide"),
              order_type=body.get("ordType"),
              size=size,
          )

      def get_positions(self, symbol: str) -> list[Position]:
          """Return the non-zero long/short positions held on *symbol*.

          Rows with ``pos == 0`` are skipped. Any row for another instrument or
          with a ``posSide`` other than "long"/"short" is treated as invalid.

          Raises:
              ValueError: the invalid-payload error for malformed rows.
          """
          data = self._request("GET", "/api/v5/account/positions", params={"instId": symbol})
          positions: list[Position] = []
          try:
              for entry in data:
                  size = _parse_finite_float(entry["pos"])
                  if size == 0.0:
                      continue
                  inst_id = entry["instId"]
                  pos_side = entry["posSide"]
                  if not isinstance(inst_id, str) or not isinstance(pos_side, str):
                      raise self._invalid_payload()
                  if inst_id != symbol:
                      raise self._invalid_payload()
                  if pos_side not in {"long", "short"}:
                      raise self._invalid_payload()
                  positions.append(
                      Position(
                          symbol=inst_id,
                          pos_side=pos_side,
                          size=size,
                          avg_price=_parse_finite_float(entry["avgPx"]),
                      )
                  )
          except (KeyError, TypeError, ValueError):
              raise self._invalid_payload() from None
          return positions