okx_client.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. import base64
  2. import hashlib
  3. import hmac
  4. import json
  5. from datetime import UTC, datetime
  6. from decimal import Decimal, ROUND_DOWN
  7. from urllib.parse import urlencode
  8. from okx_codex_trader.config import Config
  9. from okx_codex_trader.models import Candle, InstrumentMeta, OrderResult, Position, TradeSignal
  def build_contract_size(notional: float, price: float, metadata: "InstrumentMeta") -> float:
      """Convert an order notional into a contract count rounded down to the lot size.

      The raw count is ``notional / (price * metadata.ct_val)``. It is rounded
      *down* to a whole multiple of ``metadata.lot_sz``, so the order never
      exceeds the requested notional. All arithmetic is done in ``Decimal``
      seeded from ``str(float)`` to avoid binary floating-point residue in the
      resulting size.

      Args:
          notional: Order value to size (the caller passes margin * leverage).
          price: Price used for sizing.
          metadata: Instrument metadata providing ``ct_val``, ``lot_sz`` and
              ``min_sz``.

      Returns:
          The rounded contract size as a float.

      Raises:
          ValueError: if ``notional`` is negative or not finite; if ``price``,
              ``metadata.ct_val`` or ``metadata.lot_sz`` is not a finite
              positive number; or if the rounded size is below
              ``metadata.min_sz``.
      """
      amount = Decimal(str(notional))
      unit_price = Decimal(str(price))
      ct_val = Decimal(str(metadata.ct_val))
      lot_size = Decimal(str(metadata.lot_sz))
      # Reject degenerate inputs up front. Otherwise they surface as decimal
      # DivisionByZero/InvalidOperation, or as a misleading "below minimum".
      # ``is_finite`` is checked first because ordering comparisons on a
      # Decimal NaN raise InvalidOperation.
      if not amount.is_finite() or amount < 0:
          raise ValueError("notional must be a finite, non-negative number")
      for label, value in (("price", unit_price), ("ct_val", ct_val), ("lot_sz", lot_size)):
          if not value.is_finite() or value <= 0:
              raise ValueError(f"{label} must be a finite, positive number")
      raw_size = amount / (unit_price * ct_val)
      size = (raw_size / lot_size).to_integral_value(rounding=ROUND_DOWN) * lot_size
      if size < Decimal(str(metadata.min_sz)):
          raise ValueError("contract size below minimum")
      return float(size)
  def _format_number(value: float) -> str:
      """Render *value* as a plain decimal string suitable for an API field.

      The float goes through ``str`` (shortest repr) into ``Decimal``.
      ``normalize()`` drops trailing zeros, and the ``"f"`` format spec
      prevents exponent notation such as ``1E+2`` from appearing.
      """
      as_decimal = Decimal(str(value)).normalize()
      return f"{as_decimal:f}"
  19. class OkxClient:
  20. base_url = "https://www.okx.com"
  def __init__(self, config: Config, session=None):
      """Store the credentials config and the HTTP session used for requests.

      Args:
          config: Provides ``api_key``, ``api_secret`` and ``api_passphrase``.
          session: Object whose ``request(...)`` method performs HTTP calls.
              When omitted, a fresh ``requests.Session`` is created; ``requests``
              is only imported in that case.
      """
      self.config = config
      if session is not None:
          self.session = session
          return
      import requests
      self.session = requests.Session()
  def _invalid_payload(self) -> ValueError:
      """Create (not raise) the error used for malformed OKX responses.

      Call sites write ``raise self._invalid_payload()``, which keeps the
      raise visible at the point of failure.
      """
      message = "okx response payload is invalid"
      return ValueError(message)
  def _first_item(self, data: list[dict[str, object]]) -> dict[str, object]:
      """Return the first record of a response ``data`` list.

      Raises:
          ValueError: (built by ``_invalid_payload``) when *data* is empty.
      """
      if data:
          return data[0]
      raise self._invalid_payload()
  def _request(
      self,
      method: str,
      path: str,
      *,
      params: dict[str, object] | None = None,
      json_body: dict[str, object] | None = None,
      timeout: float = 10.0,
  ) -> list[dict[str, object]]:
      """Send a signed OKX v5 REST request and return the response ``data`` list.

      The signature is ``base64(HMAC-SHA256(api_secret, timestamp + METHOD +
      path[?query] + body))``. The exact query string and body bytes that were
      signed are the ones sent on the wire.

      Args:
          method: HTTP method (case-insensitive).
          path: API path, e.g. ``/api/v5/market/ticker``.
          params: Optional query parameters, URL-encoded with ``urlencode``.
          json_body: Optional body, serialized as compact JSON.
          timeout: Seconds to wait for the server before the session gives up.

      Returns:
          The response's ``data`` list, or ``[]`` when ``data`` is not a list.

      Raises:
          ValueError: on an HTTP status >= 400, a response that is not a JSON
              object, or an OKX ``code`` other than ``"0"``.
      """
      method = method.upper()
      timestamp = datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
      query = urlencode(params or {})
      path_with_query = path if not query else f"{path}?{query}"
      body = "" if json_body is None else json.dumps(json_body, separators=(",", ":"))
      signature = base64.b64encode(
          hmac.new(
              self.config.api_secret.encode(),
              f"{timestamp}{method}{path_with_query}{body}".encode(),
              hashlib.sha256,
          ).digest()
      ).decode()
      headers = {
          "OK-ACCESS-KEY": self.config.api_key,
          "OK-ACCESS-SIGN": signature,
          "OK-ACCESS-TIMESTAMP": timestamp,
          "OK-ACCESS-PASSPHRASE": self.config.api_passphrase,
          # OKX demo-trading flag: every request from this client targets demo.
          "x-simulated-trading": "1",
      }
      if body:
          headers["Content-Type"] = "application/json"
      # Send exactly what was signed. Passing ``json=``/``params=`` would let the
      # HTTP library re-encode them (``json=`` adds spaces after separators),
      # and the server-side signature check would then fail.
      response = self.session.request(
          method,
          f"{self.base_url}{path_with_query}",
          headers=headers,
          data=body.encode() if body else None,
          timeout=timeout,
      )
      status_code = getattr(response, "status_code", 200)
      try:
          payload = response.json()
      except ValueError:
          # Non-JSON bodies (e.g. an HTML error page from a proxy).
          payload = None
      if not isinstance(payload, dict):
          if status_code >= 400:
              raise ValueError("okx http error")
          raise self._invalid_payload()
      if status_code >= 400:
          raise ValueError(str(payload.get("msg") or "okx http error"))
      if payload.get("code") != "0":
          raise ValueError(str(payload.get("msg") or payload.get("code") or "okx api error"))
      data = payload.get("data")
      return data if isinstance(data, list) else []
  def get_candles(self, symbol: str, bar: str, limit: int) -> list[Candle]:
      """Fetch candles for *symbol* from the ``history-candles`` endpoint.

      Args:
          symbol: OKX instrument id.
          bar: Bar size string, passed through to the API unchanged.
          limit: Maximum number of rows requested.

      Returns:
          One ``Candle`` per row, in the order the API returned them. Fields
          are read positionally: ts, open, high, low, close, volume.
          NOTE(review): OKX documents candle rows as newest-first; confirm that
          callers expect that ordering.

      Raises:
          ValueError: when a row is too short or has non-numeric fields, or
              when ``_request`` fails.
      """
      query = {"instId": symbol, "bar": bar, "limit": limit}
      rows = self._request("GET", "/api/v5/market/history-candles", params=query)
      candles: list[Candle] = []
      try:
          for row in rows:
              candles.append(
                  Candle(
                      symbol=symbol,
                      ts=int(row[0]),
                      open=float(row[1]),
                      high=float(row[2]),
                      low=float(row[3]),
                      close=float(row[4]),
                      volume=float(row[5]),
                  )
              )
      except (IndexError, TypeError, ValueError):
          raise self._invalid_payload() from None
      return candles
  def get_instrument_meta(self, symbol: str) -> InstrumentMeta:
      """Look up contract value, lot size and minimum size for a SWAP instrument.

      Raises:
          ValueError: if no instrument is returned, or if ``ctVal``, ``lotSz``
              or ``minSz`` is missing or non-numeric.
      """
      query = {"instType": "SWAP", "instId": symbol}
      instrument = self._first_item(
          self._request("GET", "/api/v5/public/instruments", params=query)
      )
      try:
          meta = InstrumentMeta(
              ct_val=float(instrument["ctVal"]),
              lot_sz=float(instrument["lotSz"]),
              min_sz=float(instrument["minSz"]),
          )
      except (KeyError, TypeError, ValueError):
          raise self._invalid_payload() from None
      return meta
  def get_last_price(self, symbol: str) -> float:
      """Return the ``last`` field of the ticker for *symbol* as a float.

      Raises:
          ValueError: if the ticker list is empty, or ``last`` is missing or
              non-numeric.
      """
      query = {"instId": symbol}
      ticker = self._first_item(self._request("GET", "/api/v5/market/ticker", params=query))
      try:
          last_price = float(ticker["last"])
      except (KeyError, TypeError, ValueError):
          raise self._invalid_payload() from None
      return last_price
  def ensure_hedge_mode(self) -> None:
      """Raise unless the account's ``posMode`` is ``long_short_mode``.

      Raises:
          ValueError: "hedge mode is required" for any other position mode,
              or when the account config response is empty.
      """
      account_config = self._first_item(self._request("GET", "/api/v5/account/config"))
      if account_config.get("posMode") == "long_short_mode":
          return
      raise ValueError("hedge mode is required")
  def set_leverage(self, symbol: str, leverage: int, pos_side: str) -> None:
      """Set isolated-margin leverage for one position side of *symbol*.

      Args:
          symbol: Instrument id.
          leverage: Leverage multiplier, sent as a string.
          pos_side: Position side the leverage applies to (``long``/``short``).
      """
      # Key order is kept as-is: it determines the signed JSON body.
      payload = {
          "instId": symbol,
          "lever": str(leverage),
          "mgnMode": "isolated",
          "posSide": pos_side,
      }
      self._request("POST", "/api/v5/account/set-leverage", json_body=payload)
      return None
  def place_demo_order(self, symbol: str, signal: TradeSignal, margin_usdt: float) -> OrderResult:
      """Place an isolated-margin order on a SWAP instrument for *signal*.

      A ``"flat"`` signal is a no-op. Otherwise the order is sized from
      ``margin_usdt * signal.leverage``. With ``signal.entry_price`` set it is
      a limit order at that price; otherwise it is a market order sized at the
      last traded price. The account must be in hedge mode, and leverage is
      set for the chosen position side before the order is sent.

      Args:
          symbol: Instrument id; must end with ``-SWAP``.
          signal: Provides ``action`` (long/short/flat), ``leverage`` and
              ``entry_price``.
          margin_usdt: Margin to commit; multiplied by leverage for the notional.

      Returns:
          ``OrderResult`` with ``status="noop"`` for a flat signal, otherwise
          ``status="placed"`` and the exchange order id (``None`` if absent).

      Raises:
          ValueError: for an unknown action, a non-SWAP symbol, a size below
              the instrument minimum, an account not in hedge mode, or an
              error/rejection reported by OKX.
      """
      if signal.action == "flat":
          return OrderResult(
              status="noop",
              order_id=None,
              symbol=symbol,
              side=None,
              pos_side=None,
              order_type=None,
              size=None,
          )
      # Reject unknown actions rather than silently treating them as a short.
      if signal.action not in ("long", "short"):
          raise ValueError(f"unsupported signal action: {signal.action!r}")
      if not symbol.endswith("-SWAP"):
          raise ValueError("swap instrument is required")
      metadata = self.get_instrument_meta(symbol)
      price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
      is_long = signal.action == "long"
      side = "buy" if is_long else "sell"
      pos_side = "long" if is_long else "short"
      # Size before touching account settings, so a below-minimum order fails
      # without having changed leverage. The HTTP call sequence is unchanged.
      size = build_contract_size(margin_usdt * signal.leverage, price, metadata)
      self.ensure_hedge_mode()
      self.set_leverage(symbol, signal.leverage, pos_side)
      order_type = "market" if signal.entry_price is None else "limit"
      request_body = {
          "instId": symbol,
          "tdMode": "isolated",
          "side": side,
          "posSide": pos_side,
          "ordType": order_type,
          "sz": _format_number(size),
      }
      if signal.entry_price is not None:
          # NOTE(review): px is not rounded to the instrument tick size
          # (InstrumentMeta carries no tick size); confirm prices arrive aligned.
          request_body["px"] = _format_number(signal.entry_price)
      data = self._request("POST", "/api/v5/trade/order", json_body=request_body)
      order = self._first_item(data)
      # OKX reports per-order results in sCode/sMsg; surface a rejection
      # instead of returning a "placed" result for it.
      order_code = order.get("sCode")
      if order_code not in (None, "", "0"):
          raise ValueError(str(order.get("sMsg") or order_code))
      order_id = str(order.get("ordId") or "") or None
      return OrderResult(
          status="placed",
          order_id=order_id,
          symbol=symbol,
          side=side,
          pos_side=pos_side,
          order_type=order_type,
          size=size,
      )
  def get_positions(self, symbol: str) -> list[Position]:
      """Return the positions reported for *symbol*.

      Returns:
          One ``Position`` per entry in the response, or ``[]`` when the
          response ``data`` list is empty (no open positions).

      Raises:
          ValueError: if an entry lacks ``instId``/``posSide``/``pos``/``avgPx``
              or holds a non-numeric ``pos``/``avgPx``, or when ``_request``
              fails.
      """
      data = self._request("GET", "/api/v5/account/positions", params={"instId": symbol})
      # An empty list is how the positions endpoint reports "nothing open".
      # Treating it as a malformed payload made a flat account raise.
      # NOTE(review): OKX may send avgPx as "" for a zero-size entry, which is
      # still reported as an invalid payload here; confirm that is acceptable.
      try:
          return [
              Position(
                  symbol=str(entry["instId"]),
                  pos_side=str(entry["posSide"]),
                  size=float(entry["pos"]),
                  avg_price=float(entry["avgPx"]),
              )
              for entry in data
          ]
      except (KeyError, TypeError, ValueError):
          raise self._invalid_payload() from None