Sfoglia il codice sorgente

fix: normalize okx boundary errors

lxy 1 mese fa
parent
commit
61cbf14d3a
2 ha cambiato i file con 109 aggiunte e 30 eliminazioni
  1. 55 30
      okx_codex_trader/okx_client.py
  2. 54 0
      tests/test_okx_client.py

+ 55 - 30
okx_codex_trader/okx_client.py

@@ -34,6 +34,14 @@ class OkxClient:
             session = requests.Session()
         self.session = session
 
+    def _invalid_payload(self) -> ValueError:
+        return ValueError("okx response payload is invalid")
+
+    def _first_item(self, data: list[dict[str, object]]) -> dict[str, object]:
+        if not data:
+            raise self._invalid_payload()
+        return data[0]
+
     def _request(
         self,
         method: str,
@@ -81,18 +89,21 @@ class OkxClient:
             "/api/v5/market/history-candles",
             params={"instId": symbol, "bar": bar, "limit": limit},
         )
-        return [
-            Candle(
-                symbol=symbol,
-                ts=int(entry[0]),
-                open=float(entry[1]),
-                high=float(entry[2]),
-                low=float(entry[3]),
-                close=float(entry[4]),
-                volume=float(entry[5]),
-            )
-            for entry in data
-        ]
+        try:
+            return [
+                Candle(
+                    symbol=symbol,
+                    ts=int(entry[0]),
+                    open=float(entry[1]),
+                    high=float(entry[2]),
+                    low=float(entry[3]),
+                    close=float(entry[4]),
+                    volume=float(entry[5]),
+                )
+                for entry in data
+            ]
+        except (IndexError, TypeError, ValueError):
+            raise self._invalid_payload() from None
 
     def get_instrument_meta(self, symbol: str) -> InstrumentMeta:
         data = self._request(
@@ -100,20 +111,28 @@ class OkxClient:
             "/api/v5/public/instruments",
             params={"instType": "SWAP", "instId": symbol},
         )
-        instrument = data[0]
-        return InstrumentMeta(
-            ct_val=float(instrument["ctVal"]),
-            lot_sz=float(instrument["lotSz"]),
-            min_sz=float(instrument["minSz"]),
-        )
+        instrument = self._first_item(data)
+        try:
+            return InstrumentMeta(
+                ct_val=float(instrument["ctVal"]),
+                lot_sz=float(instrument["lotSz"]),
+                min_sz=float(instrument["minSz"]),
+            )
+        except (KeyError, TypeError, ValueError):
+            raise self._invalid_payload() from None
 
     def get_last_price(self, symbol: str) -> float:
         data = self._request("GET", "/api/v5/market/ticker", params={"instId": symbol})
-        return float(data[0]["last"])
+        ticker = self._first_item(data)
+        try:
+            return float(ticker["last"])
+        except (KeyError, TypeError, ValueError):
+            raise self._invalid_payload() from None
 
     def ensure_hedge_mode(self) -> None:
         data = self._request("GET", "/api/v5/account/config")
-        if data[0]["posMode"] != "long_short_mode":
+        config = self._first_item(data)
+        if config.get("posMode") != "long_short_mode":
             raise ValueError("hedge mode is required")
 
     def set_leverage(self, symbol: str, leverage: int, pos_side: str) -> None:
@@ -160,7 +179,8 @@ class OkxClient:
         if signal.entry_price is not None:
             request_body["px"] = _format_number(signal.entry_price)
         data = self._request("POST", "/api/v5/trade/order", json_body=request_body)
-        order_id = None if not data else str(data[0].get("ordId") or "")
+        order = self._first_item(data)
+        order_id = str(order.get("ordId") or "") or None
         return OrderResult(
             status="placed",
             order_id=order_id,
@@ -173,12 +193,17 @@ class OkxClient:
 
     def get_positions(self, symbol: str) -> list[Position]:
         data = self._request("GET", "/api/v5/account/positions", params={"instId": symbol})
-        return [
-            Position(
-                symbol=str(entry["instId"]),
-                pos_side=str(entry["posSide"]),
-                size=float(entry["pos"]),
-                avg_price=float(entry["avgPx"]),
-            )
-            for entry in data
-        ]
+        if not data:
+            raise self._invalid_payload()
+        try:
+            return [
+                Position(
+                    symbol=str(entry["instId"]),
+                    pos_side=str(entry["posSide"]),
+                    size=float(entry["pos"]),
+                    avg_price=float(entry["avgPx"]),
+                )
+                for entry in data
+            ]
+        except (KeyError, TypeError, ValueError):
+            raise self._invalid_payload() from None

+ 54 - 0
tests/test_okx_client.py

@@ -110,6 +110,10 @@ def place_order_response() -> DummyResponse:
     return DummyResponse({"code": "0", "msg": "", "data": [{"ordId": "123"}]})
 
 
+def place_order_response_without_order_id() -> DummyResponse:
+    return DummyResponse({"code": "0", "msg": "", "data": [{}]})
+
+
 def error_response(code: str, msg: str) -> DummyResponse:
     return DummyResponse({"code": code, "msg": msg, "data": []})
 
@@ -292,6 +296,56 @@ def test_okx_error_payload_raises_value_error():
         client.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
 
 
+def test_empty_okx_data_raises_stable_value_error():
+    session = DummySession([DummyResponse({"code": "0", "msg": "", "data": []})])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="okx response payload is invalid"):
+        client.get_positions(symbol="BTC-USDT-SWAP")
+
+
+def test_malformed_numeric_field_raises_stable_value_error():
+    session = DummySession(
+        [
+            DummyResponse(
+                {
+                    "code": "0",
+                    "msg": "",
+                    "data": [
+                        {
+                            "instId": "BTC-USDT-SWAP",
+                            "posSide": "long",
+                            "pos": "bad",
+                            "avgPx": "25000",
+                        }
+                    ],
+                }
+            )
+        ]
+    )
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="okx response payload is invalid"):
+        client.get_positions(symbol="BTC-USDT-SWAP")
+
+
+def test_place_demo_order_returns_none_when_order_id_is_missing():
+    session = DummySession(
+        [
+            instrument_response(),
+            ticker_response(last="25000"),
+            account_config_response(pos_mode="long_short_mode"),
+            leverage_response(),
+            place_order_response_without_order_id(),
+        ]
+    )
+    client = OkxClient(config=sample_config(), session=session)
+
+    result = client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
+
+    assert result.order_id is None
+
+
 def test_get_positions_returns_normalized_positions():
     session = DummySession([positions_response()])
     client = OkxClient(config=sample_config(), session=session)