# okx_client.py: signed OKX REST client for demo (simulated) trading.

  1. import base64
  2. import hashlib
  3. import hmac
  4. import json
  5. from datetime import UTC, datetime
  6. from decimal import Decimal, InvalidOperation, ROUND_DOWN
  7. from math import isfinite
  8. from typing import TypeAlias
  9. from urllib.parse import urlencode
  10. from okx_codex_trader.config import Config
  11. from okx_codex_trader.models import Candle, InstrumentMeta, OrderResult, Position, TradeSignal
# One element of an OKX response's ``data`` array: a JSON object for most
# endpoints, or a positional array (candle rows are indexed by position in
# ``OkxClient.get_candles``).
OkxRow: TypeAlias = dict[str, object] | list[object]
  13. def _parse_finite_decimal(value: object) -> Decimal:
  14. try:
  15. parsed = Decimal(str(value))
  16. except (InvalidOperation, TypeError, ValueError):
  17. raise ValueError("contract sizing inputs are invalid") from None
  18. if not parsed.is_finite():
  19. raise ValueError("contract sizing inputs are invalid")
  20. return parsed
  21. def _parse_finite_float(value: object) -> float:
  22. try:
  23. parsed = float(value)
  24. except (TypeError, ValueError):
  25. raise ValueError("okx response payload is invalid") from None
  26. if not isfinite(parsed):
  27. raise ValueError("okx response payload is invalid")
  28. return parsed
  29. def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
  30. notional_decimal = _parse_finite_decimal(notional)
  31. price_decimal = _parse_finite_decimal(price)
  32. ct_val_decimal = _parse_finite_decimal(metadata.ct_val)
  33. lot_size = _parse_finite_decimal(metadata.lot_sz)
  34. min_size = _parse_finite_decimal(metadata.min_sz)
  35. if notional_decimal <= 0 or price_decimal <= 0 or ct_val_decimal <= 0 or lot_size <= 0 or min_size <= 0:
  36. raise ValueError("contract sizing inputs are invalid")
  37. raw_size = notional_decimal / (price_decimal * ct_val_decimal)
  38. size = (raw_size / lot_size).to_integral_value(rounding=ROUND_DOWN) * lot_size
  39. if size < min_size:
  40. raise ValueError("contract size below minimum")
  41. return float(size)
  42. def _format_number(value: float) -> str:
  43. return format(Decimal(str(value)).normalize(), "f")
  44. class OkxClient:
  45. base_url = "https://www.okx.com"
  46. request_timeout = 10.0
  47. def __init__(self, config: Config, session=None):
  48. self.config = config
  49. if session is None:
  50. import requests
  51. session = requests.Session()
  52. self.session = session
  53. def _invalid_payload(self) -> ValueError:
  54. return ValueError("okx response payload is invalid")
  55. def _transport_error(self) -> ValueError:
  56. return ValueError("okx transport error")
  57. def _first_item(self, data: list[OkxRow]) -> dict[str, object]:
  58. if not data:
  59. raise self._invalid_payload()
  60. item = data[0]
  61. if not isinstance(item, dict):
  62. raise self._invalid_payload()
  63. return item
  64. def _request(
  65. self,
  66. method: str,
  67. path: str,
  68. *,
  69. params: dict[str, object] | None = None,
  70. json_body: dict[str, object] | None = None,
  71. ) -> list[OkxRow]:
  72. timestamp = datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
  73. query = urlencode(params or {})
  74. path_with_query = path if not query else f"{path}?{query}"
  75. body = "" if json_body is None else json.dumps(json_body, separators=(",", ":"))
  76. signature = base64.b64encode(
  77. hmac.new(
  78. self.config.api_secret.encode(),
  79. f"{timestamp}{method.upper()}{path_with_query}{body}".encode(),
  80. hashlib.sha256,
  81. ).digest()
  82. ).decode()
  83. headers = {
  84. "OK-ACCESS-KEY": self.config.api_key,
  85. "OK-ACCESS-SIGN": signature,
  86. "OK-ACCESS-TIMESTAMP": timestamp,
  87. "OK-ACCESS-PASSPHRASE": self.config.api_passphrase,
  88. "x-simulated-trading": "1",
  89. }
  90. if json_body is not None:
  91. headers["Content-Type"] = "application/json"
  92. try:
  93. response = self.session.request(
  94. method.upper(),
  95. f"{self.base_url}{path}",
  96. headers=headers,
  97. params=params,
  98. data=body if json_body is not None else None,
  99. timeout=self.request_timeout,
  100. )
  101. except Exception:
  102. raise self._transport_error() from None
  103. try:
  104. payload = response.json()
  105. except Exception:
  106. raise self._invalid_payload() from None
  107. if not isinstance(payload, dict):
  108. raise self._invalid_payload()
  109. if getattr(response, "status_code", 200) >= 400:
  110. raise ValueError(str(payload.get("msg") or "okx http error"))
  111. if payload.get("code") != "0":
  112. raise ValueError(str(payload.get("msg") or payload.get("code") or "okx api error"))
  113. data = payload.get("data")
  114. if not isinstance(data, list):
  115. raise self._invalid_payload()
  116. return data
  117. def get_candles(self, symbol: str, bar: str, limit: int) -> list[Candle]:
  118. data = self._request(
  119. "GET",
  120. "/api/v5/market/history-candles",
  121. params={"instId": symbol, "bar": bar, "limit": limit},
  122. )
  123. try:
  124. candles = [
  125. Candle(
  126. symbol=symbol,
  127. ts=int(entry[0]),
  128. open=_parse_finite_float(entry[1]),
  129. high=_parse_finite_float(entry[2]),
  130. low=_parse_finite_float(entry[3]),
  131. close=_parse_finite_float(entry[4]),
  132. volume=_parse_finite_float(entry[5]),
  133. )
  134. for entry in data
  135. ]
  136. return sorted(candles, key=lambda candle: candle.ts)
  137. except (IndexError, KeyError, TypeError, ValueError):
  138. raise self._invalid_payload() from None
  139. def get_instrument_meta(self, symbol: str) -> InstrumentMeta:
  140. data = self._request(
  141. "GET",
  142. "/api/v5/public/instruments",
  143. params={"instType": "SWAP", "instId": symbol},
  144. )
  145. instrument = self._first_item(data)
  146. try:
  147. if instrument.get("instId") != symbol or instrument.get("instType") != "SWAP":
  148. raise self._invalid_payload()
  149. return InstrumentMeta(
  150. ct_val=_parse_finite_float(instrument["ctVal"]),
  151. lot_sz=_parse_finite_float(instrument["lotSz"]),
  152. min_sz=_parse_finite_float(instrument["minSz"]),
  153. )
  154. except (KeyError, TypeError, ValueError):
  155. raise self._invalid_payload() from None
  156. def get_last_price(self, symbol: str) -> float:
  157. data = self._request("GET", "/api/v5/market/ticker", params={"instId": symbol})
  158. ticker = self._first_item(data)
  159. try:
  160. if ticker.get("instId") != symbol:
  161. raise self._invalid_payload()
  162. return _parse_finite_float(ticker["last"])
  163. except (KeyError, TypeError, ValueError):
  164. raise self._invalid_payload() from None
  165. def ensure_hedge_mode(self) -> None:
  166. data = self._request("GET", "/api/v5/account/config")
  167. config = self._first_item(data)
  168. if config.get("posMode") != "long_short_mode":
  169. raise ValueError("hedge mode is required")
  170. def set_leverage(self, symbol: str, leverage: int, pos_side: str) -> None:
  171. if not symbol.endswith("-SWAP"):
  172. raise ValueError("swap instrument is required")
  173. if leverage < 1 or leverage > 3:
  174. raise ValueError("leverage is invalid")
  175. if pos_side not in {"long", "short"}:
  176. raise ValueError("pos_side is invalid")
  177. self._request(
  178. "POST",
  179. "/api/v5/account/set-leverage",
  180. json_body={
  181. "instId": symbol,
  182. "lever": str(leverage),
  183. "mgnMode": "isolated",
  184. "posSide": pos_side,
  185. },
  186. )
  187. def place_demo_order(self, symbol: str, signal: TradeSignal, margin_usdt: float) -> OrderResult:
  188. if signal.action == "flat":
  189. return OrderResult(
  190. status="noop",
  191. order_id=None,
  192. symbol=symbol,
  193. side=None,
  194. pos_side=None,
  195. order_type=None,
  196. size=None,
  197. )
  198. if signal.action not in {"long", "short"}:
  199. raise ValueError("action is invalid")
  200. if not symbol.endswith("-SWAP"):
  201. raise ValueError("swap instrument is required")
  202. if signal.leverage < 1 or signal.leverage > 3:
  203. raise ValueError("leverage is invalid")
  204. try:
  205. margin_value = _parse_finite_float(margin_usdt)
  206. except ValueError:
  207. raise ValueError("margin_usdt is invalid") from None
  208. if margin_value <= 0:
  209. raise ValueError("margin_usdt is invalid")
  210. metadata = self.get_instrument_meta(symbol)
  211. price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
  212. side = "buy" if signal.action == "long" else "sell"
  213. pos_side = "long" if signal.action == "long" else "short"
  214. self.ensure_hedge_mode()
  215. size = build_contract_size(margin_value * signal.leverage, price, metadata)
  216. self.set_leverage(symbol, signal.leverage, pos_side)
  217. order_type = "market" if signal.entry_price is None else "limit"
  218. request_body = {
  219. "instId": symbol,
  220. "tdMode": "isolated",
  221. "side": side,
  222. "posSide": pos_side,
  223. "ordType": order_type,
  224. "sz": _format_number(size),
  225. }
  226. if signal.entry_price is not None:
  227. request_body["px"] = _format_number(signal.entry_price)
  228. data = self._request("POST", "/api/v5/trade/order", json_body=request_body)
  229. order = self._first_item(data)
  230. order_id = str(order.get("ordId") or "")
  231. if not order_id:
  232. raise self._invalid_payload()
  233. return OrderResult(
  234. status="placed",
  235. order_id=order_id,
  236. symbol=symbol,
  237. side=side,
  238. pos_side=pos_side,
  239. order_type=order_type,
  240. size=size,
  241. )
  242. def get_positions(self, symbol: str) -> list[Position]:
  243. requested_symbol = symbol
  244. data = self._request("GET", "/api/v5/account/positions", params={"instId": requested_symbol})
  245. if not data:
  246. return []
  247. try:
  248. positions = []
  249. for entry in data:
  250. size = _parse_finite_float(entry["pos"])
  251. if size == 0.0:
  252. continue
  253. symbol = entry["instId"]
  254. pos_side = entry["posSide"]
  255. if not isinstance(symbol, str) or not isinstance(pos_side, str):
  256. raise self._invalid_payload()
  257. if symbol != requested_symbol:
  258. raise self._invalid_payload()
  259. if pos_side not in {"long", "short"}:
  260. raise self._invalid_payload()
  261. positions.append(
  262. Position(
  263. symbol=symbol,
  264. pos_side=pos_side,
  265. size=size,
  266. avg_price=_parse_finite_float(entry["avgPx"]),
  267. )
  268. )
  269. return positions
  270. except (KeyError, TypeError, ValueError):
  271. raise self._invalid_payload() from None