okx_client.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. import base64
  2. import hashlib
  3. import hmac
  4. import json
  5. from datetime import UTC, datetime
  6. from decimal import Decimal, InvalidOperation, ROUND_DOWN
  7. from math import isfinite
  8. from typing import TypeAlias
  9. from urllib.parse import urlencode
  10. from okx_codex_trader.config import Config
  11. from okx_codex_trader.models import Candle, InstrumentMeta, OrderResult, Position, TradeSignal
  12. OkxRow: TypeAlias = dict[str, object] | list[object]
  13. def _parse_finite_decimal(value: object) -> Decimal:
  14. try:
  15. parsed = Decimal(str(value))
  16. except (InvalidOperation, TypeError, ValueError):
  17. raise ValueError("contract sizing inputs are invalid") from None
  18. if not parsed.is_finite():
  19. raise ValueError("contract sizing inputs are invalid")
  20. return parsed
  21. def _parse_finite_float(value: object) -> float:
  22. try:
  23. parsed = float(value)
  24. except (TypeError, ValueError):
  25. raise ValueError("okx response payload is invalid") from None
  26. if not isfinite(parsed):
  27. raise ValueError("okx response payload is invalid")
  28. return parsed
  29. def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
  30. price_decimal = _parse_finite_decimal(price)
  31. ct_val_decimal = _parse_finite_decimal(metadata.ct_val)
  32. lot_size = _parse_finite_decimal(metadata.lot_sz)
  33. min_size = _parse_finite_decimal(metadata.min_sz)
  34. if price_decimal <= 0 or ct_val_decimal <= 0 or lot_size <= 0 or min_size <= 0:
  35. raise ValueError("contract sizing inputs are invalid")
  36. raw_size = _parse_finite_decimal(notional) / (price_decimal * ct_val_decimal)
  37. size = (raw_size / lot_size).to_integral_value(rounding=ROUND_DOWN) * lot_size
  38. if size < min_size:
  39. raise ValueError("contract size below minimum")
  40. return float(size)
  41. def _format_number(value: float) -> str:
  42. return format(Decimal(str(value)).normalize(), "f")
  43. class OkxClient:
  44. base_url = "https://www.okx.com"
  45. def __init__(self, config: Config, session=None):
  46. self.config = config
  47. if session is None:
  48. import requests
  49. session = requests.Session()
  50. self.session = session
  51. def _invalid_payload(self) -> ValueError:
  52. return ValueError("okx response payload is invalid")
  53. def _transport_error(self) -> ValueError:
  54. return ValueError("okx transport error")
  55. def _first_item(self, data: list[OkxRow]) -> dict[str, object]:
  56. if not data:
  57. raise self._invalid_payload()
  58. item = data[0]
  59. if not isinstance(item, dict):
  60. raise self._invalid_payload()
  61. return item
  62. def _request(
  63. self,
  64. method: str,
  65. path: str,
  66. *,
  67. params: dict[str, object] | None = None,
  68. json_body: dict[str, object] | None = None,
  69. ) -> list[OkxRow]:
  70. timestamp = datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
  71. query = urlencode(params or {})
  72. path_with_query = path if not query else f"{path}?{query}"
  73. body = "" if json_body is None else json.dumps(json_body, separators=(",", ":"))
  74. signature = base64.b64encode(
  75. hmac.new(
  76. self.config.api_secret.encode(),
  77. f"{timestamp}{method.upper()}{path_with_query}{body}".encode(),
  78. hashlib.sha256,
  79. ).digest()
  80. ).decode()
  81. headers = {
  82. "OK-ACCESS-KEY": self.config.api_key,
  83. "OK-ACCESS-SIGN": signature,
  84. "OK-ACCESS-TIMESTAMP": timestamp,
  85. "OK-ACCESS-PASSPHRASE": self.config.api_passphrase,
  86. "x-simulated-trading": "1",
  87. }
  88. if json_body is not None:
  89. headers["Content-Type"] = "application/json"
  90. try:
  91. response = self.session.request(
  92. method.upper(),
  93. f"{self.base_url}{path}",
  94. headers=headers,
  95. params=params,
  96. data=body if json_body is not None else None,
  97. )
  98. except Exception:
  99. raise self._transport_error() from None
  100. try:
  101. payload = response.json()
  102. except Exception:
  103. raise self._invalid_payload() from None
  104. if not isinstance(payload, dict):
  105. raise self._invalid_payload()
  106. if getattr(response, "status_code", 200) >= 400:
  107. raise ValueError(str(payload.get("msg") or "okx http error"))
  108. if payload.get("code") != "0":
  109. raise ValueError(str(payload.get("msg") or payload.get("code") or "okx api error"))
  110. data = payload.get("data")
  111. if not isinstance(data, list):
  112. raise self._invalid_payload()
  113. return data
  114. def get_candles(self, symbol: str, bar: str, limit: int) -> list[Candle]:
  115. data = self._request(
  116. "GET",
  117. "/api/v5/market/history-candles",
  118. params={"instId": symbol, "bar": bar, "limit": limit},
  119. )
  120. try:
  121. candles = [
  122. Candle(
  123. symbol=symbol,
  124. ts=int(entry[0]),
  125. open=_parse_finite_float(entry[1]),
  126. high=_parse_finite_float(entry[2]),
  127. low=_parse_finite_float(entry[3]),
  128. close=_parse_finite_float(entry[4]),
  129. volume=_parse_finite_float(entry[5]),
  130. )
  131. for entry in data
  132. ]
  133. return sorted(candles, key=lambda candle: candle.ts)
  134. except (IndexError, KeyError, TypeError, ValueError):
  135. raise self._invalid_payload() from None
  136. def get_instrument_meta(self, symbol: str) -> InstrumentMeta:
  137. data = self._request(
  138. "GET",
  139. "/api/v5/public/instruments",
  140. params={"instType": "SWAP", "instId": symbol},
  141. )
  142. instrument = self._first_item(data)
  143. try:
  144. return InstrumentMeta(
  145. ct_val=_parse_finite_float(instrument["ctVal"]),
  146. lot_sz=_parse_finite_float(instrument["lotSz"]),
  147. min_sz=_parse_finite_float(instrument["minSz"]),
  148. )
  149. except (KeyError, TypeError, ValueError):
  150. raise self._invalid_payload() from None
  151. def get_last_price(self, symbol: str) -> float:
  152. data = self._request("GET", "/api/v5/market/ticker", params={"instId": symbol})
  153. ticker = self._first_item(data)
  154. try:
  155. return _parse_finite_float(ticker["last"])
  156. except (KeyError, TypeError, ValueError):
  157. raise self._invalid_payload() from None
  158. def ensure_hedge_mode(self) -> None:
  159. data = self._request("GET", "/api/v5/account/config")
  160. config = self._first_item(data)
  161. if config.get("posMode") != "long_short_mode":
  162. raise ValueError("hedge mode is required")
  163. def set_leverage(self, symbol: str, leverage: int, pos_side: str) -> None:
  164. self._request(
  165. "POST",
  166. "/api/v5/account/set-leverage",
  167. json_body={
  168. "instId": symbol,
  169. "lever": str(leverage),
  170. "mgnMode": "isolated",
  171. "posSide": pos_side,
  172. },
  173. )
  174. def place_demo_order(self, symbol: str, signal: TradeSignal, margin_usdt: float) -> OrderResult:
  175. if signal.action == "flat":
  176. return OrderResult(
  177. status="noop",
  178. order_id=None,
  179. symbol=symbol,
  180. side=None,
  181. pos_side=None,
  182. order_type=None,
  183. size=None,
  184. )
  185. if signal.action not in {"long", "short"}:
  186. raise ValueError("action is invalid")
  187. if not symbol.endswith("-SWAP"):
  188. raise ValueError("swap instrument is required")
  189. if signal.leverage < 1 or signal.leverage > 3:
  190. raise ValueError("leverage is invalid")
  191. metadata = self.get_instrument_meta(symbol)
  192. price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
  193. side = "buy" if signal.action == "long" else "sell"
  194. pos_side = "long" if signal.action == "long" else "short"
  195. self.ensure_hedge_mode()
  196. size = build_contract_size(margin_usdt * signal.leverage, price, metadata)
  197. self.set_leverage(symbol, signal.leverage, pos_side)
  198. order_type = "market" if signal.entry_price is None else "limit"
  199. request_body = {
  200. "instId": symbol,
  201. "tdMode": "isolated",
  202. "side": side,
  203. "posSide": pos_side,
  204. "ordType": order_type,
  205. "sz": _format_number(size),
  206. }
  207. if signal.entry_price is not None:
  208. request_body["px"] = _format_number(signal.entry_price)
  209. data = self._request("POST", "/api/v5/trade/order", json_body=request_body)
  210. order = self._first_item(data)
  211. order_id = str(order.get("ordId") or "")
  212. if not order_id:
  213. raise self._invalid_payload()
  214. return OrderResult(
  215. status="placed",
  216. order_id=order_id,
  217. symbol=symbol,
  218. side=side,
  219. pos_side=pos_side,
  220. order_type=order_type,
  221. size=size,
  222. )
  223. def get_positions(self, symbol: str) -> list[Position]:
  224. data = self._request("GET", "/api/v5/account/positions", params={"instId": symbol})
  225. if not data:
  226. return []
  227. try:
  228. positions = []
  229. for entry in data:
  230. size = _parse_finite_float(entry["pos"])
  231. if size == 0.0:
  232. continue
  233. symbol = entry["instId"]
  234. pos_side = entry["posSide"]
  235. if not isinstance(symbol, str) or not isinstance(pos_side, str):
  236. raise self._invalid_payload()
  237. positions.append(
  238. Position(
  239. symbol=symbol,
  240. pos_side=pos_side,
  241. size=size,
  242. avg_price=_parse_finite_float(entry["avgPx"]),
  243. )
  244. )
  245. return positions
  246. except (KeyError, TypeError, ValueError):
  247. raise self._invalid_payload() from None