okx_client.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. import base64
  2. import hashlib
  3. import hmac
  4. import json
  5. from datetime import UTC, datetime
  6. from decimal import Decimal, InvalidOperation, ROUND_DOWN
  7. from math import isfinite
  8. from typing import TypeAlias
  9. from urllib.parse import urlencode
  10. from okx_codex_trader.config import Config
  11. from okx_codex_trader.models import Candle, InstrumentMeta, OrderResult, Position, TradeSignal
# One element of an OKX response ``data`` array. Most endpoints used here yield
# JSON objects (see ``OkxClient._first_item``), while the candle endpoint yields
# positional arrays (``get_candles`` indexes rows by position).
OkxRow: TypeAlias = dict[str, object] | list[object]
  13. def _parse_finite_decimal(value: object) -> Decimal:
  14. if isinstance(value, bool):
  15. raise ValueError("contract sizing inputs are invalid")
  16. try:
  17. parsed = Decimal(str(value))
  18. except (InvalidOperation, TypeError, ValueError):
  19. raise ValueError("contract sizing inputs are invalid") from None
  20. if not parsed.is_finite():
  21. raise ValueError("contract sizing inputs are invalid")
  22. return parsed
  23. def _parse_finite_float(value: object) -> float:
  24. if isinstance(value, bool):
  25. raise ValueError("okx response payload is invalid")
  26. try:
  27. parsed = float(value)
  28. except (TypeError, ValueError):
  29. raise ValueError("okx response payload is invalid") from None
  30. if not isfinite(parsed):
  31. raise ValueError("okx response payload is invalid")
  32. return parsed
  33. def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
  34. notional_decimal = _parse_finite_decimal(notional)
  35. price_decimal = _parse_finite_decimal(price)
  36. ct_val_decimal = _parse_finite_decimal(metadata.ct_val)
  37. lot_size = _parse_finite_decimal(metadata.lot_sz)
  38. min_size = _parse_finite_decimal(metadata.min_sz)
  39. if notional_decimal <= 0 or price_decimal <= 0 or ct_val_decimal <= 0 or lot_size <= 0 or min_size <= 0:
  40. raise ValueError("contract sizing inputs are invalid")
  41. raw_size = notional_decimal / (price_decimal * ct_val_decimal)
  42. size = (raw_size / lot_size).to_integral_value(rounding=ROUND_DOWN) * lot_size
  43. if size < min_size:
  44. raise ValueError("contract size below minimum")
  45. return float(size)
  46. def _format_number(value: float) -> str:
  47. return format(Decimal(str(value)).normalize(), "f")
  48. class OkxClient:
  49. base_url = "https://www.okx.com"
  50. request_timeout = 10.0
  51. def __init__(self, config: Config, session=None):
  52. self.config = config
  53. if session is None:
  54. import requests
  55. session = requests.Session()
  56. self.session = session
  57. def _invalid_payload(self) -> ValueError:
  58. return ValueError("okx response payload is invalid")
  59. def _transport_error(self) -> ValueError:
  60. return ValueError("okx transport error")
  61. def _first_item(self, data: list[OkxRow]) -> dict[str, object]:
  62. if not data:
  63. raise self._invalid_payload()
  64. item = data[0]
  65. if not isinstance(item, dict):
  66. raise self._invalid_payload()
  67. return item
  68. def _request(
  69. self,
  70. method: str,
  71. path: str,
  72. *,
  73. params: dict[str, object] | None = None,
  74. json_body: dict[str, object] | None = None,
  75. ) -> list[OkxRow]:
  76. timestamp = datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
  77. query = urlencode(params or {})
  78. path_with_query = path if not query else f"{path}?{query}"
  79. body = "" if json_body is None else json.dumps(json_body, separators=(",", ":"))
  80. signature = base64.b64encode(
  81. hmac.new(
  82. self.config.api_secret.encode(),
  83. f"{timestamp}{method.upper()}{path_with_query}{body}".encode(),
  84. hashlib.sha256,
  85. ).digest()
  86. ).decode()
  87. headers = {
  88. "OK-ACCESS-KEY": self.config.api_key,
  89. "OK-ACCESS-SIGN": signature,
  90. "OK-ACCESS-TIMESTAMP": timestamp,
  91. "OK-ACCESS-PASSPHRASE": self.config.api_passphrase,
  92. "x-simulated-trading": "1",
  93. }
  94. if json_body is not None:
  95. headers["Content-Type"] = "application/json"
  96. try:
  97. response = self.session.request(
  98. method.upper(),
  99. f"{self.base_url}{path}",
  100. headers=headers,
  101. params=params,
  102. data=body if json_body is not None else None,
  103. timeout=self.request_timeout,
  104. )
  105. except Exception:
  106. raise self._transport_error() from None
  107. try:
  108. payload = response.json()
  109. except Exception:
  110. raise self._invalid_payload() from None
  111. if not isinstance(payload, dict):
  112. raise self._invalid_payload()
  113. if getattr(response, "status_code", 200) >= 400:
  114. raise ValueError(str(payload.get("msg") or "okx http error"))
  115. if payload.get("code") != "0":
  116. raise ValueError(str(payload.get("msg") or payload.get("code") or "okx api error"))
  117. data = payload.get("data")
  118. if not isinstance(data, list):
  119. raise self._invalid_payload()
  120. return data
  121. def get_candles(self, symbol: str, bar: str, limit: int) -> list[Candle]:
  122. data = self._request(
  123. "GET",
  124. "/api/v5/market/history-candles",
  125. params={"instId": symbol, "bar": bar, "limit": limit},
  126. )
  127. try:
  128. candles = [
  129. Candle(
  130. symbol=symbol,
  131. ts=int(entry[0]),
  132. open=_parse_finite_float(entry[1]),
  133. high=_parse_finite_float(entry[2]),
  134. low=_parse_finite_float(entry[3]),
  135. close=_parse_finite_float(entry[4]),
  136. volume=_parse_finite_float(entry[5]),
  137. )
  138. for entry in data
  139. ]
  140. return sorted(candles, key=lambda candle: candle.ts)
  141. except (IndexError, KeyError, TypeError, ValueError):
  142. raise self._invalid_payload() from None
  143. def get_instrument_meta(self, symbol: str) -> InstrumentMeta:
  144. data = self._request(
  145. "GET",
  146. "/api/v5/public/instruments",
  147. params={"instType": "SWAP", "instId": symbol},
  148. )
  149. instrument = self._first_item(data)
  150. try:
  151. if instrument.get("instId") != symbol or instrument.get("instType") != "SWAP":
  152. raise self._invalid_payload()
  153. return InstrumentMeta(
  154. ct_val=_parse_finite_float(instrument["ctVal"]),
  155. lot_sz=_parse_finite_float(instrument["lotSz"]),
  156. min_sz=_parse_finite_float(instrument["minSz"]),
  157. )
  158. except (KeyError, TypeError, ValueError):
  159. raise self._invalid_payload() from None
  160. def get_last_price(self, symbol: str) -> float:
  161. data = self._request("GET", "/api/v5/market/ticker", params={"instId": symbol})
  162. ticker = self._first_item(data)
  163. try:
  164. if ticker.get("instId") != symbol:
  165. raise self._invalid_payload()
  166. return _parse_finite_float(ticker["last"])
  167. except (KeyError, TypeError, ValueError):
  168. raise self._invalid_payload() from None
  169. def ensure_hedge_mode(self) -> None:
  170. data = self._request("GET", "/api/v5/account/config")
  171. config = self._first_item(data)
  172. pos_mode = config.get("posMode")
  173. if not isinstance(pos_mode, str):
  174. raise self._invalid_payload()
  175. if pos_mode != "long_short_mode":
  176. raise ValueError("hedge mode is required")
  177. def set_leverage(self, symbol: str, leverage: int, pos_side: str) -> None:
  178. if not symbol.endswith("-SWAP"):
  179. raise ValueError("swap instrument is required")
  180. if isinstance(leverage, bool):
  181. raise ValueError("leverage is invalid")
  182. if leverage < 1 or leverage > 3:
  183. raise ValueError("leverage is invalid")
  184. if pos_side not in {"long", "short"}:
  185. raise ValueError("pos_side is invalid")
  186. self._request(
  187. "POST",
  188. "/api/v5/account/set-leverage",
  189. json_body={
  190. "instId": symbol,
  191. "lever": str(leverage),
  192. "mgnMode": "isolated",
  193. "posSide": pos_side,
  194. },
  195. )
  196. def place_demo_order(self, symbol: str, signal: TradeSignal, margin_usdt: float) -> OrderResult:
  197. if signal.action == "flat":
  198. return OrderResult(
  199. status="noop",
  200. order_id=None,
  201. symbol=symbol,
  202. side=None,
  203. pos_side=None,
  204. order_type=None,
  205. size=None,
  206. )
  207. if signal.action not in {"long", "short"}:
  208. raise ValueError("action is invalid")
  209. if not symbol.endswith("-SWAP"):
  210. raise ValueError("swap instrument is required")
  211. if isinstance(signal.leverage, bool):
  212. raise ValueError("leverage is invalid")
  213. if signal.leverage < 1 or signal.leverage > 3:
  214. raise ValueError("leverage is invalid")
  215. try:
  216. margin_value = _parse_finite_float(margin_usdt)
  217. except ValueError:
  218. raise ValueError("margin_usdt is invalid") from None
  219. if margin_value <= 0:
  220. raise ValueError("margin_usdt is invalid")
  221. self.ensure_hedge_mode()
  222. metadata = self.get_instrument_meta(symbol)
  223. price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
  224. side = "buy" if signal.action == "long" else "sell"
  225. pos_side = "long" if signal.action == "long" else "short"
  226. size = build_contract_size(margin_value * signal.leverage, price, metadata)
  227. self.set_leverage(symbol, signal.leverage, pos_side)
  228. order_type = "market" if signal.entry_price is None else "limit"
  229. request_body = {
  230. "instId": symbol,
  231. "tdMode": "isolated",
  232. "side": side,
  233. "posSide": pos_side,
  234. "ordType": order_type,
  235. "sz": _format_number(size),
  236. }
  237. if signal.entry_price is not None:
  238. request_body["px"] = _format_number(signal.entry_price)
  239. data = self._request("POST", "/api/v5/trade/order", json_body=request_body)
  240. order = self._first_item(data)
  241. order_id = str(order.get("ordId") or "")
  242. if not order_id:
  243. raise self._invalid_payload()
  244. return OrderResult(
  245. status="placed",
  246. order_id=order_id,
  247. symbol=symbol,
  248. side=side,
  249. pos_side=pos_side,
  250. order_type=order_type,
  251. size=size,
  252. )
  253. def get_positions(self, symbol: str) -> list[Position]:
  254. requested_symbol = symbol
  255. data = self._request("GET", "/api/v5/account/positions", params={"instId": requested_symbol})
  256. if not data:
  257. return []
  258. try:
  259. positions = []
  260. for entry in data:
  261. size = _parse_finite_float(entry["pos"])
  262. if size == 0.0:
  263. continue
  264. symbol = entry["instId"]
  265. pos_side = entry["posSide"]
  266. if not isinstance(symbol, str) or not isinstance(pos_side, str):
  267. raise self._invalid_payload()
  268. if symbol != requested_symbol:
  269. raise self._invalid_payload()
  270. if pos_side not in {"long", "short"}:
  271. raise self._invalid_payload()
  272. positions.append(
  273. Position(
  274. symbol=symbol,
  275. pos_side=pos_side,
  276. size=size,
  277. avg_price=_parse_finite_float(entry["avgPx"]),
  278. )
  279. )
  280. return positions
  281. except (KeyError, TypeError, ValueError):
  282. raise self._invalid_payload() from None