Kaynağa Gözat

fix: validate action and ignore flat rows

lxy 1 ay önce
ebeveyn
işleme
64f399034e
2 değiştirilmiş dosya ile 67 ekleme ve 9 silme
  1. 15 9
      okx_codex_trader/okx_client.py
  2. 52 0
      tests/test_okx_client.py

+ 15 - 9
okx_codex_trader/okx_client.py

@@ -161,6 +161,8 @@ class OkxClient:
                 order_type=None,
                 order_type=None,
                 size=None,
                 size=None,
             )
             )
+        if signal.action not in {"long", "short"}:
+            raise ValueError("action is invalid")
         if not symbol.endswith("-SWAP"):
         if not symbol.endswith("-SWAP"):
             raise ValueError("swap instrument is required")
             raise ValueError("swap instrument is required")
         if signal.leverage < 1 or signal.leverage > 3:
         if signal.leverage < 1 or signal.leverage > 3:
@@ -203,15 +205,19 @@ class OkxClient:
         if not data:
         if not data:
             return []
             return []
         try:
         try:
-            positions = [
-                Position(
-                    symbol=str(entry["instId"]),
-                    pos_side=str(entry["posSide"]),
-                    size=float(entry["pos"]),
-                    avg_price=float(entry["avgPx"]),
+            positions = []
+            for entry in data:
+                size = float(entry["pos"])
+                if size == 0.0:
+                    continue
+                positions.append(
+                    Position(
+                        symbol=str(entry["instId"]),
+                        pos_side=str(entry["posSide"]),
+                        size=size,
+                        avg_price=float(entry["avgPx"]),
+                    )
                 )
                 )
-                for entry in data
-            ]
-            return [position for position in positions if position.size != 0.0]
+            return positions
         except (KeyError, TypeError, ValueError):
         except (KeyError, TypeError, ValueError):
             raise self._invalid_payload() from None
             raise self._invalid_payload() from None

+ 52 - 0
tests/test_okx_client.py

@@ -189,6 +189,29 @@ def positions_with_zero_size_response() -> DummyResponse:
     )
     )
 
 
 
 
+def positions_with_zero_size_malformed_avg_price_response() -> DummyResponse:
+    return DummyResponse(
+        {
+            "code": "0",
+            "msg": "",
+            "data": [
+                {
+                    "instId": "BTC-USDT-SWAP",
+                    "posSide": "long",
+                    "pos": "0",
+                    "avgPx": "bad",
+                },
+                {
+                    "instId": "BTC-USDT-SWAP",
+                    "posSide": "short",
+                    "pos": "3",
+                    "avgPx": "24900",
+                },
+            ],
+        }
+    )
+
+
 def market_long_signal() -> TradeSignal:
 def market_long_signal() -> TradeSignal:
     return TradeSignal(
     return TradeSignal(
         action="long",
         action="long",
@@ -453,6 +476,24 @@ def test_place_demo_order_rejects_invalid_leverage_before_okx():
     assert session.request_paths == []
     assert session.request_paths == []
 
 
 
 
+def test_place_demo_order_rejects_unknown_action_before_okx():
+    session = DummySession([])
+    signal = TradeSignal(
+        action="hold",
+        confidence=0.9,
+        leverage=2,
+        entry_price=None,
+        take_profit_price=None,
+        stop_loss_price=None,
+        reason="x",
+    )
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="action is invalid"):
+        client.place_demo_order(symbol="BTC-USDT-SWAP", signal=signal, margin_usdt=100)
+    assert session.request_paths == []
+
+
 def test_get_positions_returns_normalized_positions():
 def test_get_positions_returns_normalized_positions():
     session = DummySession([positions_response()])
     session = DummySession([positions_response()])
     client = OkxClient(config=sample_config(), session=session)
     client = OkxClient(config=sample_config(), session=session)
@@ -474,3 +515,14 @@ def test_get_positions_filters_zero_size_rows():
     assert len(positions) == 1
     assert len(positions) == 1
     assert positions[0].pos_side == "short"
     assert positions[0].pos_side == "short"
     assert positions[0].size == 3.0
     assert positions[0].size == 3.0
+
+
+def test_get_positions_ignores_malformed_fields_on_zero_size_rows():
+    session = DummySession([positions_with_zero_size_malformed_avg_price_response()])
+    client = OkxClient(config=sample_config(), session=session)
+
+    positions = client.get_positions(symbol="BTC-USDT-SWAP")
+
+    assert len(positions) == 1
+    assert positions[0].pos_side == "short"
+    assert positions[0].avg_price == 24900.0