ソースを参照

feat: add okx demo client and contract sizing

lxy 1 ヶ月 前
コミット
d5130891bf
2 ファイル変更482 行追加0 行削除
  1. 184 0
      okx_codex_trader/okx_client.py
  2. 298 0
      tests/test_okx_client.py

+ 184 - 0
okx_codex_trader/okx_client.py

@@ -0,0 +1,184 @@
+import base64
+import hashlib
+import hmac
+import json
+from datetime import UTC, datetime
+from decimal import Decimal, ROUND_DOWN
+from urllib.parse import urlencode
+
+from okx_codex_trader.config import Config
+from okx_codex_trader.models import Candle, InstrumentMeta, OrderResult, Position, TradeSignal
+
+
+def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
+    raw_size = Decimal(str(notional)) / (Decimal(str(price)) * Decimal(str(metadata.ct_val)))
+    lot_size = Decimal(str(metadata.lot_sz))
+    size = (raw_size / lot_size).to_integral_value(rounding=ROUND_DOWN) * lot_size
+    if size < Decimal(str(metadata.min_sz)):
+        raise ValueError("contract size below minimum")
+    return float(size)
+
+
+def _format_number(value: float) -> str:
+    return format(Decimal(str(value)).normalize(), "f")
+
+
+class OkxClient:
+    base_url = "https://www.okx.com"
+
+    def __init__(self, config: Config, session=None):
+        self.config = config
+        if session is None:
+            import requests
+
+            session = requests.Session()
+        self.session = session
+
+    def _request(
+        self,
+        method: str,
+        path: str,
+        *,
+        params: dict[str, object] | None = None,
+        json_body: dict[str, object] | None = None,
+    ) -> list[dict[str, object]]:
+        timestamp = datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
+        query = urlencode(params or {})
+        path_with_query = path if not query else f"{path}?{query}"
+        body = "" if json_body is None else json.dumps(json_body, separators=(",", ":"))
+        signature = base64.b64encode(
+            hmac.new(
+                self.config.api_secret.encode(),
+                f"{timestamp}{method.upper()}{path_with_query}{body}".encode(),
+                hashlib.sha256,
+            ).digest()
+        ).decode()
+        headers = {
+            "OK-ACCESS-KEY": self.config.api_key,
+            "OK-ACCESS-SIGN": signature,
+            "OK-ACCESS-TIMESTAMP": timestamp,
+            "OK-ACCESS-PASSPHRASE": self.config.api_passphrase,
+            "x-simulated-trading": "1",
+        }
+        response = self.session.request(
+            method.upper(),
+            f"{self.base_url}{path}",
+            headers=headers,
+            params=params,
+            json=json_body,
+        )
+        payload = response.json()
+        if getattr(response, "status_code", 200) >= 400:
+            raise ValueError(str(payload.get("msg") or "okx http error"))
+        if payload.get("code") != "0":
+            raise ValueError(str(payload.get("msg") or payload.get("code") or "okx api error"))
+        data = payload.get("data")
+        return data if isinstance(data, list) else []
+
+    def get_candles(self, symbol: str, bar: str, limit: int) -> list[Candle]:
+        data = self._request(
+            "GET",
+            "/api/v5/market/history-candles",
+            params={"instId": symbol, "bar": bar, "limit": limit},
+        )
+        return [
+            Candle(
+                symbol=symbol,
+                ts=int(entry[0]),
+                open=float(entry[1]),
+                high=float(entry[2]),
+                low=float(entry[3]),
+                close=float(entry[4]),
+                volume=float(entry[5]),
+            )
+            for entry in data
+        ]
+
+    def get_instrument_meta(self, symbol: str) -> InstrumentMeta:
+        data = self._request(
+            "GET",
+            "/api/v5/public/instruments",
+            params={"instType": "SWAP", "instId": symbol},
+        )
+        instrument = data[0]
+        return InstrumentMeta(
+            ct_val=float(instrument["ctVal"]),
+            lot_sz=float(instrument["lotSz"]),
+            min_sz=float(instrument["minSz"]),
+        )
+
+    def get_last_price(self, symbol: str) -> float:
+        data = self._request("GET", "/api/v5/market/ticker", params={"instId": symbol})
+        return float(data[0]["last"])
+
+    def ensure_hedge_mode(self) -> None:
+        data = self._request("GET", "/api/v5/account/config")
+        if data[0]["posMode"] != "long_short_mode":
+            raise ValueError("hedge mode is required")
+
+    def set_leverage(self, symbol: str, leverage: int, pos_side: str) -> None:
+        self._request(
+            "POST",
+            "/api/v5/account/set-leverage",
+            json_body={
+                "instId": symbol,
+                "lever": str(leverage),
+                "mgnMode": "isolated",
+                "posSide": pos_side,
+            },
+        )
+
+    def place_demo_order(self, symbol: str, signal: TradeSignal, margin_usdt: float) -> OrderResult:
+        if signal.action == "flat":
+            return OrderResult(
+                status="noop",
+                order_id=None,
+                symbol=symbol,
+                side=None,
+                pos_side=None,
+                order_type=None,
+                size=None,
+            )
+        if not symbol.endswith("-SWAP"):
+            raise ValueError("swap instrument is required")
+        metadata = self.get_instrument_meta(symbol)
+        price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
+        side = "buy" if signal.action == "long" else "sell"
+        pos_side = "long" if signal.action == "long" else "short"
+        self.ensure_hedge_mode()
+        self.set_leverage(symbol, signal.leverage, pos_side)
+        size = build_contract_size(margin_usdt * signal.leverage, price, metadata)
+        order_type = "market" if signal.entry_price is None else "limit"
+        request_body = {
+            "instId": symbol,
+            "tdMode": "isolated",
+            "side": side,
+            "posSide": pos_side,
+            "ordType": order_type,
+            "sz": _format_number(size),
+        }
+        if signal.entry_price is not None:
+            request_body["px"] = _format_number(signal.entry_price)
+        data = self._request("POST", "/api/v5/trade/order", json_body=request_body)
+        order_id = None if not data else str(data[0].get("ordId") or "")
+        return OrderResult(
+            status="placed",
+            order_id=order_id,
+            symbol=symbol,
+            side=side,
+            pos_side=pos_side,
+            order_type=order_type,
+            size=size,
+        )
+
+    def get_positions(self, symbol: str) -> list[Position]:
+        data = self._request("GET", "/api/v5/account/positions", params={"instId": symbol})
+        return [
+            Position(
+                symbol=str(entry["instId"]),
+                pos_side=str(entry["posSide"]),
+                size=float(entry["pos"]),
+                avg_price=float(entry["avgPx"]),
+            )
+            for entry in data
+        ]

+ 298 - 0
tests/test_okx_client.py

@@ -0,0 +1,298 @@
+from dataclasses import dataclass
+from urllib.parse import urlparse
+
+import pytest
+
+from okx_codex_trader.config import Config
+from okx_codex_trader.models import InstrumentMeta, TradeSignal
+from okx_codex_trader.okx_client import OkxClient, build_contract_size
+
+
+@dataclass
+class DummyResponse:
+    payload: dict[str, object]
+    status_code: int = 200
+
+    def json(self) -> dict[str, object]:
+        return self.payload
+
+
+@dataclass
+class RecordedRequest:
+    method: str
+    url: str
+    headers: dict[str, str]
+    params: dict[str, object] | None
+    json_body: dict[str, object] | None
+
+
+class DummySession:
+    def __init__(self, responses: list[DummyResponse] | None = None):
+        self._responses = list(responses or [])
+        self.last_request: RecordedRequest | None = None
+        self.request_paths: list[str] = []
+        self.request_bodies: list[dict[str, object] | None] = []
+
+    @property
+    def last_json_body(self) -> dict[str, object] | None:
+        return self.last_request.json_body if self.last_request else None
+
+    def request(
+        self,
+        method: str,
+        url: str,
+        *,
+        headers: dict[str, str] | None = None,
+        params: dict[str, object] | None = None,
+        json: dict[str, object] | None = None,
+    ) -> DummyResponse:
+        self.last_request = RecordedRequest(
+            method=method,
+            url=url,
+            headers=headers or {},
+            params=params,
+            json_body=json,
+        )
+        self.request_paths.append(urlparse(url).path)
+        self.request_bodies.append(json)
+        if self._responses:
+            return self._responses.pop(0)
+        return candles_response()
+
+
+def sample_config() -> Config:
+    return Config(api_key="key", api_secret="secret", api_passphrase="passphrase")
+
+
+def candles_response() -> DummyResponse:
+    return DummyResponse(
+        {
+            "code": "0",
+            "msg": "",
+            "data": [
+                ["1710000000000", "25000", "25100", "24900", "25050", "100", "1000", "1000", "1"],
+            ],
+        }
+    )
+
+
+def instrument_response() -> DummyResponse:
+    return DummyResponse(
+        {
+            "code": "0",
+            "msg": "",
+            "data": [
+                {
+                    "instId": "BTC-USDT-SWAP",
+                    "instType": "SWAP",
+                    "ctVal": "0.001",
+                    "lotSz": "1",
+                    "minSz": "1",
+                }
+            ],
+        }
+    )
+
+
+def ticker_response(last: str) -> DummyResponse:
+    return DummyResponse({"code": "0", "msg": "", "data": [{"instId": "BTC-USDT-SWAP", "last": last}]})
+
+
+def account_config_response(pos_mode: str) -> DummyResponse:
+    return DummyResponse({"code": "0", "msg": "", "data": [{"posMode": pos_mode}]})
+
+
+def leverage_response() -> DummyResponse:
+    return DummyResponse({"code": "0", "msg": "", "data": [{"lever": "2"}]})
+
+
+def place_order_response() -> DummyResponse:
+    return DummyResponse({"code": "0", "msg": "", "data": [{"ordId": "123"}]})
+
+
+def error_response(code: str, msg: str) -> DummyResponse:
+    return DummyResponse({"code": code, "msg": msg, "data": []})
+
+
+def positions_response() -> DummyResponse:
+    return DummyResponse(
+        {
+            "code": "0",
+            "msg": "",
+            "data": [
+                {
+                    "instId": "BTC-USDT-SWAP",
+                    "posSide": "long",
+                    "pos": "8",
+                    "avgPx": "25000",
+                }
+            ],
+        }
+    )
+
+
+def market_long_signal() -> TradeSignal:
+    return TradeSignal(
+        action="long",
+        confidence=0.9,
+        leverage=2,
+        entry_price=None,
+        take_profit_price=26000.0,
+        stop_loss_price=24000.0,
+        reason="trend",
+    )
+
+
+def limit_short_signal() -> TradeSignal:
+    return TradeSignal(
+        action="short",
+        confidence=0.8,
+        leverage=2,
+        entry_price=25000.0,
+        take_profit_price=24000.0,
+        stop_loss_price=25500.0,
+        reason="mean reversion",
+    )
+
+
+def flat_signal() -> TradeSignal:
+    return TradeSignal(
+        action="flat",
+        confidence=0.7,
+        leverage=2,
+        entry_price=None,
+        take_profit_price=None,
+        stop_loss_price=None,
+        reason="exit",
+    )
+
+
+def test_signed_demo_request_attaches_headers():
+    session = DummySession()
+    client = OkxClient(config=sample_config(), session=session)
+
+    client.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
+
+    request = session.last_request
+    assert request is not None
+    assert request.headers["x-simulated-trading"] == "1"
+    assert request.headers["OK-ACCESS-KEY"] == "key"
+
+
+def test_build_contract_size_rounds_down_to_lot_size():
+    metadata = InstrumentMeta(ct_val=0.01, lot_sz=0.1, min_sz=0.1)
+    assert build_contract_size(notional=251, price=25_000, metadata=metadata) == 1.0
+
+
+def test_build_contract_size_fails_below_min_size():
+    metadata = InstrumentMeta(ct_val=0.01, lot_sz=1, min_sz=5)
+    with pytest.raises(ValueError):
+        build_contract_size(notional=250, price=25_100, metadata=metadata)
+
+
+def test_market_order_fetches_latest_price_before_sizing():
+    session = DummySession(
+        [
+            instrument_response(),
+            ticker_response(last="25000"),
+            account_config_response(pos_mode="long_short_mode"),
+            leverage_response(),
+            place_order_response(),
+        ]
+    )
+    client = OkxClient(config=sample_config(), session=session)
+
+    client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
+
+    assert session.request_paths == [
+        "/api/v5/public/instruments",
+        "/api/v5/market/ticker",
+        "/api/v5/account/config",
+        "/api/v5/account/set-leverage",
+        "/api/v5/trade/order",
+    ]
+
+
+def test_place_demo_order_fails_when_not_hedge_mode():
+    session = DummySession(
+        [
+            instrument_response(),
+            ticker_response(last="25000"),
+            account_config_response(pos_mode="net_mode"),
+        ]
+    )
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError):
+        client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
+
+
+def test_limit_short_order_uses_sell_and_short_pos_side():
+    session = DummySession(
+        [
+            instrument_response(),
+            account_config_response(pos_mode="long_short_mode"),
+            leverage_response(),
+            place_order_response(),
+        ]
+    )
+    client = OkxClient(config=sample_config(), session=session)
+
+    client.place_demo_order(symbol="ETH-USDT-SWAP", signal=limit_short_signal(), margin_usdt=100)
+
+    order_request = session.last_json_body
+    assert order_request is not None
+    assert order_request["ordType"] == "limit"
+    assert order_request["side"] == "sell"
+    assert order_request["posSide"] == "short"
+    assert order_request["px"] == "25000"
+    assert session.request_bodies[2]["lever"] == "2"
+    assert session.request_bodies[2]["mgnMode"] == "isolated"
+
+
+def test_flat_signal_returns_noop_without_order_submission():
+    session = DummySession([])
+    client = OkxClient(config=sample_config(), session=session)
+
+    result = client.place_demo_order(symbol="BTC-USDT-SWAP", signal=flat_signal(), margin_usdt=100)
+
+    assert result.status == "noop"
+    assert session.request_paths == []
+
+
+def test_place_demo_order_sends_computed_sz_and_ignores_tp_sl_fields():
+    session = DummySession(
+        [
+            instrument_response(),
+            ticker_response(last="25000"),
+            account_config_response(pos_mode="long_short_mode"),
+            leverage_response(),
+            place_order_response(),
+        ]
+    )
+    client = OkxClient(config=sample_config(), session=session)
+
+    client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
+
+    order_request = session.last_json_body
+    assert order_request is not None
+    assert order_request["sz"] == "8"
+    assert "tpTriggerPx" not in order_request
+    assert "slTriggerPx" not in order_request
+
+
+def test_okx_error_payload_raises_value_error():
+    session = DummySession([error_response(code="51000", msg="parameter error")])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError):
+        client.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
+
+
+def test_get_positions_returns_normalized_positions():
+    session = DummySession([positions_response()])
+    client = OkxClient(config=sample_config(), session=session)
+
+    positions = client.get_positions(symbol="BTC-USDT-SWAP")
+
+    assert positions[0].symbol == "BTC-USDT-SWAP"