okx_client.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. import base64
  2. import hashlib
  3. import hmac
  4. import json
  5. from datetime import UTC, datetime
  6. from decimal import Decimal, ROUND_DOWN
  7. from typing import TypeAlias
  8. from urllib.parse import urlencode
  9. from okx_codex_trader.config import Config
  10. from okx_codex_trader.models import Candle, InstrumentMeta, OrderResult, Position, TradeSignal
# One element of the ``data`` list in an OKX response payload. Rows are read
# either by key (dict rows) or by position (list rows, as in get_candles).
OkxRow: TypeAlias = dict[str, object] | list[object]
  12. def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
  13. price_decimal = Decimal(str(price))
  14. ct_val_decimal = Decimal(str(metadata.ct_val))
  15. lot_size = Decimal(str(metadata.lot_sz))
  16. min_size = Decimal(str(metadata.min_sz))
  17. if price_decimal <= 0 or ct_val_decimal <= 0 or lot_size <= 0 or min_size <= 0:
  18. raise ValueError("contract sizing inputs are invalid")
  19. raw_size = Decimal(str(notional)) / (price_decimal * ct_val_decimal)
  20. size = (raw_size / lot_size).to_integral_value(rounding=ROUND_DOWN) * lot_size
  21. if size < min_size:
  22. raise ValueError("contract size below minimum")
  23. return float(size)
  24. def _format_number(value: float) -> str:
  25. return format(Decimal(str(value)).normalize(), "f")
  26. class OkxClient:
  27. base_url = "https://www.okx.com"
  28. def __init__(self, config: Config, session=None):
  29. self.config = config
  30. if session is None:
  31. import requests
  32. session = requests.Session()
  33. self.session = session
  34. def _invalid_payload(self) -> ValueError:
  35. return ValueError("okx response payload is invalid")
  36. def _transport_error(self) -> ValueError:
  37. return ValueError("okx transport error")
  38. def _first_item(self, data: list[OkxRow]) -> dict[str, object]:
  39. if not data:
  40. raise self._invalid_payload()
  41. item = data[0]
  42. if not isinstance(item, dict):
  43. raise self._invalid_payload()
  44. return item
  45. def _request(
  46. self,
  47. method: str,
  48. path: str,
  49. *,
  50. params: dict[str, object] | None = None,
  51. json_body: dict[str, object] | None = None,
  52. ) -> list[OkxRow]:
  53. timestamp = datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
  54. query = urlencode(params or {})
  55. path_with_query = path if not query else f"{path}?{query}"
  56. body = "" if json_body is None else json.dumps(json_body, separators=(",", ":"))
  57. signature = base64.b64encode(
  58. hmac.new(
  59. self.config.api_secret.encode(),
  60. f"{timestamp}{method.upper()}{path_with_query}{body}".encode(),
  61. hashlib.sha256,
  62. ).digest()
  63. ).decode()
  64. headers = {
  65. "OK-ACCESS-KEY": self.config.api_key,
  66. "OK-ACCESS-SIGN": signature,
  67. "OK-ACCESS-TIMESTAMP": timestamp,
  68. "OK-ACCESS-PASSPHRASE": self.config.api_passphrase,
  69. "x-simulated-trading": "1",
  70. }
  71. if json_body is not None:
  72. headers["Content-Type"] = "application/json"
  73. try:
  74. response = self.session.request(
  75. method.upper(),
  76. f"{self.base_url}{path}",
  77. headers=headers,
  78. params=params,
  79. data=body if json_body is not None else None,
  80. )
  81. except Exception:
  82. raise self._transport_error() from None
  83. try:
  84. payload = response.json()
  85. except Exception:
  86. raise self._invalid_payload() from None
  87. if not isinstance(payload, dict):
  88. raise self._invalid_payload()
  89. if getattr(response, "status_code", 200) >= 400:
  90. raise ValueError(str(payload.get("msg") or "okx http error"))
  91. if payload.get("code") != "0":
  92. raise ValueError(str(payload.get("msg") or payload.get("code") or "okx api error"))
  93. data = payload.get("data")
  94. if not isinstance(data, list):
  95. raise self._invalid_payload()
  96. return data
  97. def get_candles(self, symbol: str, bar: str, limit: int) -> list[Candle]:
  98. data = self._request(
  99. "GET",
  100. "/api/v5/market/history-candles",
  101. params={"instId": symbol, "bar": bar, "limit": limit},
  102. )
  103. try:
  104. candles = [
  105. Candle(
  106. symbol=symbol,
  107. ts=int(entry[0]),
  108. open=float(entry[1]),
  109. high=float(entry[2]),
  110. low=float(entry[3]),
  111. close=float(entry[4]),
  112. volume=float(entry[5]),
  113. )
  114. for entry in data
  115. ]
  116. return sorted(candles, key=lambda candle: candle.ts)
  117. except (IndexError, KeyError, TypeError, ValueError):
  118. raise self._invalid_payload() from None
  119. def get_instrument_meta(self, symbol: str) -> InstrumentMeta:
  120. data = self._request(
  121. "GET",
  122. "/api/v5/public/instruments",
  123. params={"instType": "SWAP", "instId": symbol},
  124. )
  125. instrument = self._first_item(data)
  126. try:
  127. return InstrumentMeta(
  128. ct_val=float(instrument["ctVal"]),
  129. lot_sz=float(instrument["lotSz"]),
  130. min_sz=float(instrument["minSz"]),
  131. )
  132. except (KeyError, TypeError, ValueError):
  133. raise self._invalid_payload() from None
  134. def get_last_price(self, symbol: str) -> float:
  135. data = self._request("GET", "/api/v5/market/ticker", params={"instId": symbol})
  136. ticker = self._first_item(data)
  137. try:
  138. return float(ticker["last"])
  139. except (KeyError, TypeError, ValueError):
  140. raise self._invalid_payload() from None
  141. def ensure_hedge_mode(self) -> None:
  142. data = self._request("GET", "/api/v5/account/config")
  143. config = self._first_item(data)
  144. if config.get("posMode") != "long_short_mode":
  145. raise ValueError("hedge mode is required")
  146. def set_leverage(self, symbol: str, leverage: int, pos_side: str) -> None:
  147. self._request(
  148. "POST",
  149. "/api/v5/account/set-leverage",
  150. json_body={
  151. "instId": symbol,
  152. "lever": str(leverage),
  153. "mgnMode": "isolated",
  154. "posSide": pos_side,
  155. },
  156. )
  157. def place_demo_order(self, symbol: str, signal: TradeSignal, margin_usdt: float) -> OrderResult:
  158. if signal.action == "flat":
  159. return OrderResult(
  160. status="noop",
  161. order_id=None,
  162. symbol=symbol,
  163. side=None,
  164. pos_side=None,
  165. order_type=None,
  166. size=None,
  167. )
  168. if signal.action not in {"long", "short"}:
  169. raise ValueError("action is invalid")
  170. if not symbol.endswith("-SWAP"):
  171. raise ValueError("swap instrument is required")
  172. if signal.leverage < 1 or signal.leverage > 3:
  173. raise ValueError("leverage is invalid")
  174. metadata = self.get_instrument_meta(symbol)
  175. price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
  176. side = "buy" if signal.action == "long" else "sell"
  177. pos_side = "long" if signal.action == "long" else "short"
  178. self.ensure_hedge_mode()
  179. size = build_contract_size(margin_usdt * signal.leverage, price, metadata)
  180. self.set_leverage(symbol, signal.leverage, pos_side)
  181. order_type = "market" if signal.entry_price is None else "limit"
  182. request_body = {
  183. "instId": symbol,
  184. "tdMode": "isolated",
  185. "side": side,
  186. "posSide": pos_side,
  187. "ordType": order_type,
  188. "sz": _format_number(size),
  189. }
  190. if signal.entry_price is not None:
  191. request_body["px"] = _format_number(signal.entry_price)
  192. data = self._request("POST", "/api/v5/trade/order", json_body=request_body)
  193. order = self._first_item(data)
  194. order_id = str(order.get("ordId") or "")
  195. if not order_id:
  196. raise self._invalid_payload()
  197. return OrderResult(
  198. status="placed",
  199. order_id=order_id,
  200. symbol=symbol,
  201. side=side,
  202. pos_side=pos_side,
  203. order_type=order_type,
  204. size=size,
  205. )
  206. def get_positions(self, symbol: str) -> list[Position]:
  207. data = self._request("GET", "/api/v5/account/positions", params={"instId": symbol})
  208. if not data:
  209. return []
  210. try:
  211. positions = []
  212. for entry in data:
  213. size = float(entry["pos"])
  214. if size == 0.0:
  215. continue
  216. symbol = entry["instId"]
  217. pos_side = entry["posSide"]
  218. if not isinstance(symbol, str) or not isinstance(pos_side, str):
  219. raise self._invalid_payload()
  220. positions.append(
  221. Position(
  222. symbol=symbol,
  223. pos_side=pos_side,
  224. size=size,
  225. avg_price=float(entry["avgPx"]),
  226. )
  227. )
  228. return positions
  229. except (KeyError, TypeError, ValueError):
  230. raise self._invalid_payload() from None