瀏覽代碼

fix: reject non-finite okx numeric values

lxy 1 月之前
父節點
當前提交
4df84d5dea
共有 2 個文件被更改,包括 136 次插入17 次删除
  1. 38 17
      okx_codex_trader/okx_client.py
  2. 98 0
      tests/test_okx_client.py

+ 38 - 17
okx_codex_trader/okx_client.py

@@ -3,7 +3,8 @@ import hashlib
 import hmac
 import json
 from datetime import UTC, datetime
-from decimal import Decimal, ROUND_DOWN
+from decimal import Decimal, InvalidOperation, ROUND_DOWN
+from math import isfinite
 from typing import TypeAlias
 from urllib.parse import urlencode
 
@@ -14,14 +15,34 @@ from okx_codex_trader.models import Candle, InstrumentMeta, OrderResult, Positio
 OkxRow: TypeAlias = dict[str, object] | list[object]
 
 
+def _parse_finite_decimal(value: object) -> Decimal:
+    try:
+        parsed = Decimal(str(value))
+    except (InvalidOperation, TypeError, ValueError):
+        raise ValueError("contract sizing inputs are invalid") from None
+    if not parsed.is_finite():
+        raise ValueError("contract sizing inputs are invalid")
+    return parsed
+
+
+def _parse_finite_float(value: object) -> float:
+    try:
+        parsed = float(value)
+    except (TypeError, ValueError):
+        raise ValueError("okx response payload is invalid") from None
+    if not isfinite(parsed):
+        raise ValueError("okx response payload is invalid")
+    return parsed
+
+
 def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
-    price_decimal = Decimal(str(price))
-    ct_val_decimal = Decimal(str(metadata.ct_val))
-    lot_size = Decimal(str(metadata.lot_sz))
-    min_size = Decimal(str(metadata.min_sz))
+    price_decimal = _parse_finite_decimal(price)
+    ct_val_decimal = _parse_finite_decimal(metadata.ct_val)
+    lot_size = _parse_finite_decimal(metadata.lot_sz)
+    min_size = _parse_finite_decimal(metadata.min_sz)
     if price_decimal <= 0 or ct_val_decimal <= 0 or lot_size <= 0 or min_size <= 0:
         raise ValueError("contract sizing inputs are invalid")
-    raw_size = Decimal(str(notional)) / (price_decimal * ct_val_decimal)
+    raw_size = _parse_finite_decimal(notional) / (price_decimal * ct_val_decimal)
     size = (raw_size / lot_size).to_integral_value(rounding=ROUND_DOWN) * lot_size
     if size < min_size:
         raise ValueError("contract size below minimum")
@@ -121,11 +142,11 @@ class OkxClient:
                 Candle(
                     symbol=symbol,
                     ts=int(entry[0]),
-                    open=float(entry[1]),
-                    high=float(entry[2]),
-                    low=float(entry[3]),
-                    close=float(entry[4]),
-                    volume=float(entry[5]),
+                    open=_parse_finite_float(entry[1]),
+                    high=_parse_finite_float(entry[2]),
+                    low=_parse_finite_float(entry[3]),
+                    close=_parse_finite_float(entry[4]),
+                    volume=_parse_finite_float(entry[5]),
                 )
                 for entry in data
             ]
@@ -142,9 +163,9 @@ class OkxClient:
         instrument = self._first_item(data)
         try:
             return InstrumentMeta(
-                ct_val=float(instrument["ctVal"]),
-                lot_sz=float(instrument["lotSz"]),
-                min_sz=float(instrument["minSz"]),
+                ct_val=_parse_finite_float(instrument["ctVal"]),
+                lot_sz=_parse_finite_float(instrument["lotSz"]),
+                min_sz=_parse_finite_float(instrument["minSz"]),
             )
         except (KeyError, TypeError, ValueError):
             raise self._invalid_payload() from None
@@ -153,7 +174,7 @@ class OkxClient:
         data = self._request("GET", "/api/v5/market/ticker", params={"instId": symbol})
         ticker = self._first_item(data)
         try:
-            return float(ticker["last"])
+            return _parse_finite_float(ticker["last"])
         except (KeyError, TypeError, ValueError):
             raise self._invalid_payload() from None
 
@@ -232,7 +253,7 @@ class OkxClient:
         try:
             positions = []
             for entry in data:
-                size = float(entry["pos"])
+                size = _parse_finite_float(entry["pos"])
                 if size == 0.0:
                     continue
                 symbol = entry["instId"]
@@ -244,7 +265,7 @@ class OkxClient:
                         symbol=symbol,
                         pos_side=pos_side,
                         size=size,
-                        avg_price=float(entry["avgPx"]),
+                        avg_price=_parse_finite_float(entry["avgPx"]),
                     )
                 )
             return positions

+ 98 - 0
tests/test_okx_client.py

@@ -249,6 +249,57 @@ def positions_with_non_string_identity_response() -> DummyResponse:
     )
 
 
+def candles_with_non_finite_numeric_response() -> DummyResponse:
+    return DummyResponse(
+        {
+            "code": "0",
+            "msg": "",
+            "data": [
+                ["1710000000000", "NaN", "25100", "24900", "25050", "100", "1000", "1000", "1"],
+            ],
+        }
+    )
+
+
+def instrument_with_non_finite_numeric_response() -> DummyResponse:
+    return DummyResponse(
+        {
+            "code": "0",
+            "msg": "",
+            "data": [
+                {
+                    "instId": "BTC-USDT-SWAP",
+                    "instType": "SWAP",
+                    "ctVal": "NaN",
+                    "lotSz": "1",
+                    "minSz": "1",
+                }
+            ],
+        }
+    )
+
+
+def ticker_with_non_finite_numeric_response() -> DummyResponse:
+    return DummyResponse({"code": "0", "msg": "", "data": [{"instId": "BTC-USDT-SWAP", "last": "Infinity"}]})
+
+
+def positions_with_non_finite_numeric_response() -> DummyResponse:
+    return DummyResponse(
+        {
+            "code": "0",
+            "msg": "",
+            "data": [
+                {
+                    "instId": "BTC-USDT-SWAP",
+                    "posSide": "long",
+                    "pos": "1",
+                    "avgPx": "NaN",
+                }
+            ],
+        }
+    )
+
+
 def market_long_signal() -> TradeSignal:
     return TradeSignal(
         action="long",
@@ -377,6 +428,21 @@ def test_build_contract_size_rejects_non_positive_inputs(price, metadata):
         build_contract_size(notional=250, price=price, metadata=metadata)
 
 
+@pytest.mark.parametrize(
+    ("price", "metadata"),
+    [
+        (float("nan"), InstrumentMeta(ct_val=0.01, lot_sz=1, min_sz=1)),
+        (float("inf"), InstrumentMeta(ct_val=0.01, lot_sz=1, min_sz=1)),
+        (25_000, InstrumentMeta(ct_val=float("nan"), lot_sz=1, min_sz=1)),
+        (25_000, InstrumentMeta(ct_val=0.01, lot_sz=float("inf"), min_sz=1)),
+        (25_000, InstrumentMeta(ct_val=0.01, lot_sz=1, min_sz=float("-inf"))),
+    ],
+)
+def test_build_contract_size_rejects_non_finite_inputs(price, metadata):
+    with pytest.raises(ValueError, match="contract sizing inputs are invalid"):
+        build_contract_size(notional=250, price=price, metadata=metadata)
+
+
 def test_market_order_fetches_latest_price_before_sizing():
     session = DummySession(
         [
@@ -496,6 +562,14 @@ def test_okx_error_payload_raises_value_error():
         client.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
 
 
+def test_get_candles_rejects_non_finite_numeric_fields():
+    session = DummySession([candles_with_non_finite_numeric_response()])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="okx response payload is invalid"):
+        client.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
+
+
 def test_transport_failure_raises_stable_value_error():
     session = DummySession([RuntimeError("socket closed")])
     client = OkxClient(config=sample_config(), session=session)
@@ -552,6 +626,22 @@ def test_non_list_okx_data_raises_stable_value_error():
         client.get_positions(symbol="BTC-USDT-SWAP")
 
 
+def test_get_instrument_meta_rejects_non_finite_numeric_fields():
+    session = DummySession([instrument_with_non_finite_numeric_response()])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="okx response payload is invalid"):
+        client.get_instrument_meta(symbol="BTC-USDT-SWAP")
+
+
+def test_get_last_price_rejects_non_finite_numeric_field():
+    session = DummySession([ticker_with_non_finite_numeric_response()])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="okx response payload is invalid"):
+        client.get_last_price(symbol="BTC-USDT-SWAP")
+
+
 def test_place_demo_order_raises_when_order_id_is_missing():
     session = DummySession(
         [
@@ -644,3 +734,11 @@ def test_get_positions_rejects_non_string_inst_id_and_pos_side():
 
     with pytest.raises(ValueError, match="okx response payload is invalid"):
         client.get_positions(symbol="BTC-USDT-SWAP")
+
+
+def test_get_positions_rejects_non_finite_numeric_fields():
+    session = DummySession([positions_with_non_finite_numeric_response()])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="okx response payload is invalid"):
+        client.get_positions(symbol="BTC-USDT-SWAP")