Procházet zdrojové kódy

fix: validate sizing and position payloads

lxy před 1 měsícem
rodič
revize
1d141092e0
2 změnil soubory, kde provedl 56 přidání a 4 odebrání
  1. 13 4
      okx_codex_trader/okx_client.py
  2. 43 0
      tests/test_okx_client.py

+ 13 - 4
okx_codex_trader/okx_client.py

@@ -15,10 +15,15 @@ OkxRow: TypeAlias = dict[str, object] | list[object]
 
 
 def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
-    raw_size = Decimal(str(notional)) / (Decimal(str(price)) * Decimal(str(metadata.ct_val)))
+    price_decimal = Decimal(str(price))
+    ct_val_decimal = Decimal(str(metadata.ct_val))
     lot_size = Decimal(str(metadata.lot_sz))
+    min_size = Decimal(str(metadata.min_sz))
+    if price_decimal <= 0 or ct_val_decimal <= 0 or lot_size <= 0 or min_size <= 0:
+        raise ValueError("contract sizing inputs are invalid")
+    raw_size = Decimal(str(notional)) / (price_decimal * ct_val_decimal)
     size = (raw_size / lot_size).to_integral_value(rounding=ROUND_DOWN) * lot_size
-    if size < Decimal(str(metadata.min_sz)):
+    if size < min_size:
         raise ValueError("contract size below minimum")
     return float(size)
 
@@ -228,10 +233,14 @@ class OkxClient:
                 size = float(entry["pos"])
                 if size == 0.0:
                     continue
+                symbol = entry["instId"]
+                pos_side = entry["posSide"]
+                if not isinstance(symbol, str) or not isinstance(pos_side, str):
+                    raise self._invalid_payload()
                 positions.append(
                     Position(
-                        symbol=str(entry["instId"]),
-                        pos_side=str(entry["posSide"]),
+                        symbol=symbol,
+                        pos_side=pos_side,
                         size=size,
                         avg_price=float(entry["avgPx"]),
                     )

+ 43 - 0
tests/test_okx_client.py

@@ -221,6 +221,23 @@ def positions_with_zero_size_malformed_avg_price_response() -> DummyResponse:
     )
 
 
+def positions_with_non_string_identity_response() -> DummyResponse:
+    return DummyResponse(
+        {
+            "code": "0",
+            "msg": "",
+            "data": [
+                {
+                    "instId": None,
+                    "posSide": ["long"],
+                    "pos": "3",
+                    "avgPx": "24900",
+                }
+            ],
+        }
+    )
+
+
 def market_long_signal() -> TradeSignal:
     return TradeSignal(
         action="long",
@@ -302,6 +319,24 @@ def test_build_contract_size_fails_below_min_size():
         build_contract_size(notional=250, price=25_100, metadata=metadata)
 
 
+@pytest.mark.parametrize(
+    ("price", "metadata"),
+    [
+        (0, InstrumentMeta(ct_val=0.01, lot_sz=1, min_sz=1)),
+        (-1, InstrumentMeta(ct_val=0.01, lot_sz=1, min_sz=1)),
+        (25_000, InstrumentMeta(ct_val=0, lot_sz=1, min_sz=1)),
+        (25_000, InstrumentMeta(ct_val=-0.01, lot_sz=1, min_sz=1)),
+        (25_000, InstrumentMeta(ct_val=0.01, lot_sz=0, min_sz=1)),
+        (25_000, InstrumentMeta(ct_val=0.01, lot_sz=-1, min_sz=1)),
+        (25_000, InstrumentMeta(ct_val=0.01, lot_sz=1, min_sz=0)),
+        (25_000, InstrumentMeta(ct_val=0.01, lot_sz=1, min_sz=-1)),
+    ],
+)
+def test_build_contract_size_rejects_non_positive_inputs(price, metadata):
+    with pytest.raises(ValueError, match="contract sizing inputs are invalid"):
+        build_contract_size(notional=250, price=price, metadata=metadata)
+
+
 def test_market_order_fetches_latest_price_before_sizing():
     session = DummySession(
         [
@@ -561,3 +596,11 @@ def test_get_positions_ignores_malformed_fields_on_zero_size_rows():
     assert len(positions) == 1
     assert positions[0].pos_side == "short"
     assert positions[0].avg_price == 24900.0
+
+
+def test_get_positions_rejects_non_string_inst_id_and_pos_side():
+    session = DummySession([positions_with_non_string_identity_response()])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="okx response payload is invalid"):
+        client.get_positions(symbol="BTC-USDT-SWAP")