| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
- import base64
- import hashlib
- import hmac
- import json
- from datetime import UTC, datetime
- from decimal import Decimal, ROUND_DOWN
- from urllib.parse import urlencode
- from okx_codex_trader.config import Config
- from okx_codex_trader.models import Candle, InstrumentMeta, OrderResult, Position, TradeSignal
def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
    """Convert a quote-currency notional into an order size in contracts.

    The raw contract count is ``notional / (price * metadata.ct_val)``. It is
    rounded *down* to a whole multiple of ``metadata.lot_sz`` so the order never
    exceeds the requested notional.

    Args:
        notional: Position value to size for (callers pass margin * leverage).
        price: Price used for sizing.
        metadata: Instrument metadata providing ``ct_val``, ``lot_sz`` and ``min_sz``.

    Returns:
        The contract size as a float.

    Raises:
        ValueError: If ``price``, ``metadata.ct_val`` or ``metadata.lot_sz`` is not
            positive, or if the rounded size is zero or below ``metadata.min_sz``.
    """
    # Route floats through str() so binary artefacts (e.g. 0.1) don't leak into Decimal.
    price_d = Decimal(str(price))
    ct_val = Decimal(str(metadata.ct_val))
    lot_size = Decimal(str(metadata.lot_sz))
    # Guard the divisors: Decimal would otherwise raise DivisionByZero/InvalidOperation,
    # which callers expecting this module's ValueError would not catch.
    if price_d <= 0 or ct_val <= 0 or lot_size <= 0:
        raise ValueError("price, contract value and lot size must be positive")
    raw_size = Decimal(str(notional)) / (price_d * ct_val)
    size = (raw_size / lot_size).to_integral_value(rounding=ROUND_DOWN) * lot_size
    # A zero (or negative) size is never a valid order, even if min_sz is reported as 0.
    if size <= 0 or size < Decimal(str(metadata.min_sz)):
        raise ValueError("contract size below minimum")
    return float(size)
- def _format_number(value: float) -> str:
- return format(Decimal(str(value)).normalize(), "f")
class OkxClient:
    """REST client for the OKX v5 API that always targets demo trading.

    Every request is HMAC-SHA256 signed with the credentials from ``config`` and
    carries the ``x-simulated-trading: 1`` header.
    """

    base_url = "https://www.okx.com"
    # Seconds before an HTTP call is abandoned; requests has no default timeout.
    request_timeout: float = 10.0

    def __init__(self, config: Config, session=None):
        """Store the config and HTTP session.

        Args:
            config: Provides ``api_key``, ``api_secret`` and ``api_passphrase``.
            session: Object with a ``requests.Session``-like ``request`` method;
                a new ``requests.Session`` is created when omitted.
        """
        self.config = config
        if session is None:
            # Imported lazily: requests is only needed when no session is injected.
            import requests

            session = requests.Session()
        self.session = session

    def _request(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, object] | None = None,
        json_body: dict[str, object] | None = None,
    ) -> list[dict[str, object]]:
        """Send a signed request and return the response's ``data`` list.

        The signature covers ``timestamp + METHOD + path(?query) + body``. The
        query string and body are serialized exactly once here and transmitted
        verbatim, so the signed bytes always equal the sent bytes.

        Returns:
            ``payload["data"]`` when it is a list, otherwise ``[]``.

        Raises:
            ValueError: On a non-JSON or non-object response, an HTTP status
                >= 400, or a payload whose ``code`` is not ``"0"``.
        """
        method = method.upper()
        timestamp = datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
        query = urlencode(params or {})
        path_with_query = f"{path}?{query}" if query else path
        body = "" if json_body is None else json.dumps(json_body, separators=(",", ":"))
        signature = base64.b64encode(
            hmac.new(
                self.config.api_secret.encode(),
                f"{timestamp}{method}{path_with_query}{body}".encode(),
                hashlib.sha256,
            ).digest()
        ).decode()
        headers = {
            "OK-ACCESS-KEY": self.config.api_key,
            "OK-ACCESS-SIGN": signature,
            "OK-ACCESS-TIMESTAMP": timestamp,
            "OK-ACCESS-PASSPHRASE": self.config.api_passphrase,
            "Content-Type": "application/json",
            "x-simulated-trading": "1",
        }
        # Send the signed string itself. Passing json= would let requests re-serialize
        # the body with ", " / ": " separators and the signature would no longer match.
        response = self.session.request(
            method,
            f"{self.base_url}{path_with_query}",
            headers=headers,
            data=body or None,
            timeout=self.request_timeout,
        )
        status = getattr(response, "status_code", 200)
        try:
            payload = response.json()
        except ValueError as err:
            # e.g. an HTML error page; report it as an API error rather than a decode crash.
            raise ValueError(f"okx returned a non-JSON response (HTTP {status})") from err
        if not isinstance(payload, dict):
            raise ValueError("okx returned an unexpected response")
        if status >= 400:
            raise ValueError(str(payload.get("msg") or "okx http error"))
        if payload.get("code") != "0":
            raise ValueError(str(payload.get("msg") or payload.get("code") or "okx api error"))
        data = payload.get("data")
        return data if isinstance(data, list) else []

    @staticmethod
    def _first(data: list[dict[str, object]], what: str) -> dict[str, object]:
        """Return ``data[0]``, raising ``ValueError`` (not ``IndexError``) when empty."""
        if not data:
            raise ValueError(f"okx returned no {what}")
        return data[0]

    @staticmethod
    def _float_or_zero(value: object) -> float:
        """Parse a numeric field, treating ``""`` or ``None`` as ``0.0``."""
        # NOTE(review): guards against empty-string numeric fields in position
        # entries (float("") would raise) — confirm against OKX responses.
        return 0.0 if value in ("", None) else float(value)

    def get_candles(self, symbol: str, bar: str, limit: int) -> list[Candle]:
        """Fetch up to ``limit`` candles of size ``bar`` from ``history-candles``.

        Entries are returned in the order OKX sends them; each row is read as
        ``[ts, open, high, low, close, volume, ...]``.
        """
        data = self._request(
            "GET",
            "/api/v5/market/history-candles",
            params={"instId": symbol, "bar": bar, "limit": limit},
        )
        return [
            Candle(
                symbol=symbol,
                ts=int(entry[0]),
                open=float(entry[1]),
                high=float(entry[2]),
                low=float(entry[3]),
                close=float(entry[4]),
                volume=float(entry[5]),
            )
            for entry in data
        ]

    def get_instrument_meta(self, symbol: str) -> InstrumentMeta:
        """Return contract value, lot size and minimum size for a SWAP instrument.

        Raises:
            ValueError: If OKX returns no instrument for ``symbol``.
        """
        data = self._request(
            "GET",
            "/api/v5/public/instruments",
            params={"instType": "SWAP", "instId": symbol},
        )
        instrument = self._first(data, f"instrument for {symbol}")
        return InstrumentMeta(
            ct_val=float(instrument["ctVal"]),
            lot_sz=float(instrument["lotSz"]),
            min_sz=float(instrument["minSz"]),
        )

    def get_last_price(self, symbol: str) -> float:
        """Return the ticker's ``last`` price for ``symbol``.

        Raises:
            ValueError: If OKX returns no ticker.
        """
        data = self._request("GET", "/api/v5/market/ticker", params={"instId": symbol})
        return float(self._first(data, f"ticker for {symbol}")["last"])

    def ensure_hedge_mode(self) -> None:
        """Raise ``ValueError`` unless the account position mode is ``long_short_mode``."""
        data = self._request("GET", "/api/v5/account/config")
        if self._first(data, "account config")["posMode"] != "long_short_mode":
            raise ValueError("hedge mode is required")

    def set_leverage(self, symbol: str, leverage: int, pos_side: str) -> None:
        """Set isolated-margin leverage for one position side of ``symbol``."""
        self._request(
            "POST",
            "/api/v5/account/set-leverage",
            json_body={
                "instId": symbol,
                "lever": str(leverage),
                "mgnMode": "isolated",
                "posSide": pos_side,
            },
        )

    def place_demo_order(self, symbol: str, signal: TradeSignal, margin_usdt: float) -> OrderResult:
        """Place an isolated-margin demo order for ``signal`` on ``symbol``.

        A ``flat`` signal is a no-op. For ``long``/``short`` the size is derived
        from ``margin_usdt * signal.leverage`` and validated *before* any account
        state is touched. It is then submitted as a limit order at
        ``signal.entry_price`` when one is given, otherwise as a market order.

        Raises:
            ValueError: For an unsupported action, a non-SWAP symbol, a size below
                the instrument minimum, or any OKX API error.
        """
        if signal.action == "flat":
            return OrderResult(
                status="noop",
                order_id=None,
                symbol=symbol,
                side=None,
                pos_side=None,
                order_type=None,
                size=None,
            )
        # Refuse anything unexpected instead of silently opening a short.
        if signal.action not in ("long", "short"):
            raise ValueError(f"unsupported signal action: {signal.action!r}")
        if not symbol.endswith("-SWAP"):
            raise ValueError("swap instrument is required")
        metadata = self.get_instrument_meta(symbol)
        price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
        # Size first, so an order that is too small fails before leverage is changed.
        size = build_contract_size(margin_usdt * signal.leverage, price, metadata)
        side, pos_side = ("buy", "long") if signal.action == "long" else ("sell", "short")
        self.ensure_hedge_mode()
        self.set_leverage(symbol, signal.leverage, pos_side)
        order_type = "market" if signal.entry_price is None else "limit"
        request_body = {
            "instId": symbol,
            "tdMode": "isolated",
            "side": side,
            "posSide": pos_side,
            "ordType": order_type,
            "sz": _format_number(size),
        }
        if signal.entry_price is not None:
            request_body["px"] = _format_number(signal.entry_price)
        data = self._request("POST", "/api/v5/trade/order", json_body=request_body)
        raw_order_id = data[0].get("ordId") if data else None
        # Report a missing id as None (not ""), matching the empty-response case.
        order_id = str(raw_order_id) if raw_order_id else None
        return OrderResult(
            status="placed",
            order_id=order_id,
            symbol=symbol,
            side=side,
            pos_side=pos_side,
            order_type=order_type,
            size=size,
        )

    def get_positions(self, symbol: str) -> list[Position]:
        """Return the account's positions for ``symbol``.

        Empty ``pos``/``avgPx`` fields are read as ``0.0``.
        """
        data = self._request("GET", "/api/v5/account/positions", params={"instId": symbol})
        return [
            Position(
                symbol=str(entry["instId"]),
                pos_side=str(entry["posSide"]),
                size=self._float_or_zero(entry["pos"]),
                avg_price=self._float_or_zero(entry["avgPx"]),
            )
            for entry in data
        ]
|