okx_client.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. import base64
  2. import hashlib
  3. import hmac
  4. import json
  5. from datetime import UTC, datetime
  6. from decimal import Decimal, ROUND_DOWN
  7. from urllib.parse import urlencode
  8. from okx_codex_trader.config import Config
  9. from okx_codex_trader.models import Candle, InstrumentMeta, OrderResult, Position, TradeSignal
  def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
      """Convert a quote-currency notional into an order size in contracts.

      The raw contract count is ``notional / (price * metadata.ct_val)``. It is
      truncated (never rounded up) to a whole multiple of ``metadata.lot_sz``, so
      the resulting order never exceeds the requested notional.

      Args:
          notional: Position value in quote currency.
          price: Price used for the conversion; must be positive.
          metadata: Instrument contract value, lot size and minimum size.

      Returns:
          The order size in contracts, as a float.

      Raises:
          ValueError: If ``price``, ``metadata.ct_val`` or ``metadata.lot_sz`` is
              not positive, or if the rounded size is zero or below
              ``metadata.min_sz``.
      """
      # Going through str() gives each Decimal the float's short repr
      # (0.1 -> Decimal("0.1")) rather than its full binary expansion.
      notional_d = Decimal(str(notional))
      price_d = Decimal(str(price))
      ct_val = Decimal(str(metadata.ct_val))
      lot_size = Decimal(str(metadata.lot_sz))
      # Reject divisors that would otherwise surface as decimal.DivisionByZero /
      # InvalidOperation, or produce a negative size with a misleading message.
      if price_d <= 0:
          raise ValueError("price must be positive")
      if ct_val <= 0 or lot_size <= 0:
          raise ValueError("instrument ct_val and lot_sz must be positive")
      raw_size = notional_d / (price_d * ct_val)
      # ROUND_DOWN truncates toward zero, so we never overshoot the notional.
      size = (raw_size / lot_size).to_integral_value(rounding=ROUND_DOWN) * lot_size
      # A zero size is never a valid order, even if min_sz were reported as 0.
      if size <= 0 or size < Decimal(str(metadata.min_sz)):
          raise ValueError("contract size below minimum")
      return float(size)
  17. def _format_number(value: float) -> str:
  18. return format(Decimal(str(value)).normalize(), "f")
  class OkxClient:
      """Minimal signed REST client for the OKX v5 API, pinned to demo trading.

      Every request is HMAC-SHA256 signed with the configured API credentials and
      carries the ``x-simulated-trading: 1`` header, which selects OKX's demo
      trading environment. The HTTP session is injectable; it must provide a
      ``requests.Session``-style ``request(method, url, **kwargs)``.
      """

      base_url = "https://www.okx.com"
      # Seconds to wait for OKX before giving up. requests has no default
      # timeout, so without one a stalled connection blocks forever.
      request_timeout = 10.0

      def __init__(self, config: Config, session=None):
          """Store *config* and use *session*, or a new ``requests.Session`` if omitted."""
          self.config = config
          if session is None:
              # Imported lazily so callers that inject a session never import requests.
              import requests

              session = requests.Session()
          self.session = session

      def _request(
          self,
          method: str,
          path: str,
          *,
          params: dict[str, object] | None = None,
          json_body: dict[str, object] | None = None,
      ) -> list[dict[str, object]]:
          """Send a signed request and return the ``data`` list of the OKX envelope.

          OKX verifies ``OK-ACCESS-SIGN`` against the exact request path (including
          the query string) and the exact body it receives. Both are therefore
          serialized once here and sent verbatim, instead of letting the HTTP
          library re-encode them.

          Returns:
              The envelope's ``data`` list, or ``[]`` if it is missing or not a list.

          Raises:
              ValueError: On an HTTP status >= 400, or when the response ``code``
                  is not ``"0"``.
          """
          method = method.upper()
          # ISO-8601 UTC with millisecond precision and a "Z" suffix.
          timestamp = datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
          query = urlencode(params or {})
          path_with_query = path if not query else f"{path}?{query}"
          # json.dumps escapes non-ASCII by default, so this text has the same
          # bytes whether it is encoded as UTF-8 (for signing) or for transport.
          body = "" if json_body is None else json.dumps(json_body, separators=(",", ":"))
          signature = base64.b64encode(
              hmac.new(
                  self.config.api_secret.encode(),
                  f"{timestamp}{method}{path_with_query}{body}".encode(),
                  hashlib.sha256,
              ).digest()
          ).decode()
          headers = {
              "OK-ACCESS-KEY": self.config.api_key,
              "OK-ACCESS-SIGN": signature,
              "OK-ACCESS-TIMESTAMP": timestamp,
              "OK-ACCESS-PASSPHRASE": self.config.api_passphrase,
              "Content-Type": "application/json",
              "x-simulated-trading": "1",
          }
          # Send the signed query string and body as-is. Passing json= would let
          # requests re-serialize the body with ", " / ": " separators, so it would
          # no longer match the compact body that was signed.
          response = self.session.request(
              method,
              f"{self.base_url}{path_with_query}",
              headers=headers,
              data=body.encode() if body else None,
              timeout=self.request_timeout,
          )
          try:
              payload = response.json()
          except ValueError:
              # Non-JSON body (e.g. an HTML error page from an intermediary).
              payload = {}
          if not isinstance(payload, dict):
              payload = {}
          if getattr(response, "status_code", 200) >= 400:
              raise ValueError(str(payload.get("msg") or "okx http error"))
          if payload.get("code") != "0":
              raise ValueError(self._error_message(payload))
          data = payload.get("data")
          return data if isinstance(data, list) else []

      @staticmethod
      def _error_message(payload: dict[str, object]) -> str:
          """Pick the most specific error text available in an OKX error envelope."""
          message = payload.get("msg")
          if message:
              return str(message)
          # Per the OKX v5 docs, order endpoints report per-item results via
          # sCode/sMsg inside "data"; use sMsg when the top-level msg is empty.
          data = payload.get("data")
          if isinstance(data, list) and data and isinstance(data[0], dict) and data[0].get("sMsg"):
              return str(data[0]["sMsg"])
          return str(payload.get("code") or "okx api error")

      @staticmethod
      def _first(data: list[dict[str, object]], what: str) -> dict[str, object]:
          """Return the first row of *data*; raise ValueError naming *what* if it is empty."""
          if not data:
              raise ValueError(f"okx returned no {what}")
          return data[0]

      def get_candles(self, symbol: str, bar: str, limit: int) -> list[Candle]:
          """Fetch up to *limit* historical candles of size *bar* for *symbol*.

          Rows are returned in the order OKX sends them.
          NOTE(review): OKX documents candle data as newest-first; confirm before
          relying on chronological order.
          """
          data = self._request(
              "GET",
              "/api/v5/market/history-candles",
              params={"instId": symbol, "bar": bar, "limit": limit},
          )
          # Each row is positional: [ts, open, high, low, close, volume, ...].
          return [
              Candle(
                  symbol=symbol,
                  ts=int(entry[0]),
                  open=float(entry[1]),
                  high=float(entry[2]),
                  low=float(entry[3]),
                  close=float(entry[4]),
                  volume=float(entry[5]),
              )
              for entry in data
          ]

      def get_instrument_meta(self, symbol: str) -> InstrumentMeta:
          """Return contract value, lot size and minimum size for the SWAP *symbol*.

          Raises:
              ValueError: If OKX returns no instrument for *symbol*.
          """
          data = self._request(
              "GET",
              "/api/v5/public/instruments",
              params={"instType": "SWAP", "instId": symbol},
          )
          instrument = self._first(data, f"instrument for {symbol}")
          return InstrumentMeta(
              ct_val=float(instrument["ctVal"]),
              lot_sz=float(instrument["lotSz"]),
              min_sz=float(instrument["minSz"]),
          )

      def get_last_price(self, symbol: str) -> float:
          """Return the last traded price of *symbol* from its ticker.

          Raises:
              ValueError: If OKX returns no ticker for *symbol*.
          """
          data = self._request("GET", "/api/v5/market/ticker", params={"instId": symbol})
          return float(self._first(data, f"ticker for {symbol}")["last"])

      def ensure_hedge_mode(self) -> None:
          """Raise ValueError unless the account position mode is ``long_short_mode``."""
          data = self._request("GET", "/api/v5/account/config")
          if self._first(data, "account config")["posMode"] != "long_short_mode":
              raise ValueError("hedge mode is required")

      def set_leverage(self, symbol: str, leverage: int, pos_side: str) -> None:
          """Set isolated-margin *leverage* for the *pos_side* side of *symbol*."""
          self._request(
              "POST",
              "/api/v5/account/set-leverage",
              json_body={
                  "instId": symbol,
                  "lever": str(leverage),
                  "mgnMode": "isolated",
                  "posSide": pos_side,
              },
          )

      def place_demo_order(self, symbol: str, signal: TradeSignal, margin_usdt: float) -> OrderResult:
          """Open an isolated-margin position on *symbol* according to *signal*.

          A ``"flat"`` signal places nothing and returns a ``"noop"`` result.

          Otherwise the order is sized at ``margin_usdt * signal.leverage`` of
          notional, rounded down to the instrument lot. It is a limit order at
          ``signal.entry_price`` when one is set, else a market order sized off the
          last traded price.

          Raises:
              ValueError: For a non-swap symbol, an action other than
                  ``"long"``/``"short"``/``"flat"``, a size below the instrument
                  minimum, an account not in hedge mode, or an OKX API error.
          """
          if signal.action == "flat":
              return OrderResult(
                  status="noop",
                  order_id=None,
                  symbol=symbol,
                  side=None,
                  pos_side=None,
                  order_type=None,
                  size=None,
              )
          if not symbol.endswith("-SWAP"):
              raise ValueError("swap instrument is required")
          # Reject unknown actions explicitly; otherwise anything that is not
          # "long" would silently be sent as a short.
          if signal.action not in ("long", "short"):
              raise ValueError(f"unsupported signal action: {signal.action}")
          metadata = self.get_instrument_meta(symbol)
          price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
          is_long = signal.action == "long"
          side = "buy" if is_long else "sell"
          pos_side = "long" if is_long else "short"
          # Size the order before any account-mutating call, so an undersized
          # order fails without having changed the leverage setting.
          size = build_contract_size(margin_usdt * signal.leverage, price, metadata)
          self.ensure_hedge_mode()
          self.set_leverage(symbol, signal.leverage, pos_side)
          order_type = "market" if signal.entry_price is None else "limit"
          request_body = {
              "instId": symbol,
              "tdMode": "isolated",
              "side": side,
              "posSide": pos_side,
              "ordType": order_type,
              "sz": _format_number(size),
          }
          if signal.entry_price is not None:
              request_body["px"] = _format_number(signal.entry_price)
          data = self._request("POST", "/api/v5/trade/order", json_body=request_body)
          raw_order_id = data[0].get("ordId") if data else None
          return OrderResult(
              status="placed",
              # None (not "") when OKX did not return an order id.
              order_id=str(raw_order_id) if raw_order_id else None,
              symbol=symbol,
              side=side,
              pos_side=pos_side,
              order_type=order_type,
              size=size,
          )

      def get_positions(self, symbol: str) -> list[Position]:
          """Return the positions OKX reports for *symbol*."""
          data = self._request("GET", "/api/v5/account/positions", params={"instId": symbol})
          # NOTE(review): numeric fields may come back as empty strings
          # (presumably for zero-size positions - confirm). They are treated as 0
          # rather than letting float("") raise.
          return [
              Position(
                  symbol=str(entry["instId"]),
                  pos_side=str(entry["posSide"]),
                  size=float(entry["pos"] or 0),
                  avg_price=float(entry["avgPx"] or 0),
              )
              for entry in data
          ]