Quellcode durchsuchen

fix: validate position side and notional

lxy vor 1 Monat
Ursprung
Commit
be91f9cec7
2 geänderte Dateien mit 38 neuen und 2 gelöschten Zeilen
  1. 5 2
      okx_codex_trader/okx_client.py
  2. 33 0
      tests/test_okx_client.py

+ 5 - 2
okx_codex_trader/okx_client.py

@@ -36,13 +36,14 @@ def _parse_finite_float(value: object) -> float:
 
 
 def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
+    notional_decimal = _parse_finite_decimal(notional)
     price_decimal = _parse_finite_decimal(price)
     ct_val_decimal = _parse_finite_decimal(metadata.ct_val)
     lot_size = _parse_finite_decimal(metadata.lot_sz)
     min_size = _parse_finite_decimal(metadata.min_sz)
-    if price_decimal <= 0 or ct_val_decimal <= 0 or lot_size <= 0 or min_size <= 0:
+    if notional_decimal <= 0 or price_decimal <= 0 or ct_val_decimal <= 0 or lot_size <= 0 or min_size <= 0:
         raise ValueError("contract sizing inputs are invalid")
-    raw_size = _parse_finite_decimal(notional) / (price_decimal * ct_val_decimal)
+    raw_size = notional_decimal / (price_decimal * ct_val_decimal)
     size = (raw_size / lot_size).to_integral_value(rounding=ROUND_DOWN) * lot_size
     if size < min_size:
         raise ValueError("contract size below minimum")
@@ -281,6 +282,8 @@ class OkxClient:
                     raise self._invalid_payload()
                 if symbol != requested_symbol:
                     raise self._invalid_payload()
+                if pos_side not in {"long", "short"}:
+                    raise self._invalid_payload()
                 positions.append(
                     Position(
                         symbol=symbol,

+ 33 - 0
tests/test_okx_client.py

@@ -356,6 +356,23 @@ def positions_with_wrong_symbol_response() -> DummyResponse:
     )
 
 
+def positions_with_invalid_pos_side_response() -> DummyResponse:
+    return DummyResponse(
+        {
+            "code": "0",
+            "msg": "",
+            "data": [
+                {
+                    "instId": "BTC-USDT-SWAP",
+                    "posSide": "net",
+                    "pos": "1",
+                    "avgPx": "25000",
+                }
+            ],
+        }
+    )
+
+
 def market_long_signal() -> TradeSignal:
     return TradeSignal(
         action="long",
@@ -468,6 +485,14 @@ def test_build_contract_size_fails_below_min_size():
         build_contract_size(notional=250, price=25_100, metadata=metadata)
 
 
+@pytest.mark.parametrize("notional", [0, -1, float("nan"), float("inf")])
+def test_build_contract_size_rejects_invalid_notional(notional):
+    metadata = InstrumentMeta(ct_val=0.01, lot_sz=1, min_sz=1)
+
+    with pytest.raises(ValueError, match="contract sizing inputs are invalid"):
+        build_contract_size(notional=notional, price=25_000, metadata=metadata)
+
+
 @pytest.mark.parametrize(
     ("price", "metadata"),
     [
@@ -859,3 +884,11 @@ def test_get_positions_rejects_mismatched_symbol():
 
     with pytest.raises(ValueError, match="okx response payload is invalid"):
         client.get_positions(symbol="BTC-USDT-SWAP")
+
+
+def test_get_positions_rejects_invalid_pos_side():
+    session = DummySession([positions_with_invalid_pos_side_response()])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="okx response payload is invalid"):
+        client.get_positions(symbol="BTC-USDT-SWAP")