浏览代码

fix: tighten okx client boundaries

lxy 1 月之前
父节点
当前提交
4b8a75a214
共有 2 个文件被更改,包括 39 次插入9 次删除
  1. 9 3
      okx_codex_trader/okx_client.py
  2. 30 6
      tests/test_okx_client.py

+ 9 - 3
okx_codex_trader/okx_client.py

@@ -81,7 +81,9 @@ class OkxClient:
         if payload.get("code") != "0":
             raise ValueError(str(payload.get("msg") or payload.get("code") or "okx api error"))
         data = payload.get("data")
-        return data if isinstance(data, list) else []
+        if not isinstance(data, list):
+            raise self._invalid_payload()
+        return data
 
     def get_candles(self, symbol: str, bar: str, limit: int) -> list[Candle]:
         data = self._request(
@@ -160,6 +162,8 @@ class OkxClient:
             )
         if not symbol.endswith("-SWAP"):
             raise ValueError("swap instrument is required")
+        if signal.leverage < 1 or signal.leverage > 3:
+            raise ValueError("leverage is invalid")
         metadata = self.get_instrument_meta(symbol)
         price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
         side = "buy" if signal.action == "long" else "sell"
@@ -180,7 +184,9 @@ class OkxClient:
             request_body["px"] = _format_number(signal.entry_price)
         data = self._request("POST", "/api/v5/trade/order", json_body=request_body)
         order = self._first_item(data)
-        order_id = str(order.get("ordId") or "") or None
+        order_id = str(order.get("ordId") or "")
+        if not order_id:
+            raise self._invalid_payload()
         return OrderResult(
             status="placed",
             order_id=order_id,
@@ -194,7 +200,7 @@ class OkxClient:
     def get_positions(self, symbol: str) -> list[Position]:
         data = self._request("GET", "/api/v5/account/positions", params={"instId": symbol})
         if not data:
-            raise self._invalid_payload()
+            return []
         try:
             return [
                 Position(

+ 30 - 6
tests/test_okx_client.py

@@ -296,12 +296,11 @@ def test_okx_error_payload_raises_value_error():
         client.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
 
 
-def test_empty_okx_data_raises_stable_value_error():
+def test_empty_positions_data_returns_empty_list():
     session = DummySession([DummyResponse({"code": "0", "msg": "", "data": []})])
     client = OkxClient(config=sample_config(), session=session)
 
-    with pytest.raises(ValueError, match="okx response payload is invalid"):
-        client.get_positions(symbol="BTC-USDT-SWAP")
+    assert client.get_positions(symbol="BTC-USDT-SWAP") == []
 
 
 def test_malformed_numeric_field_raises_stable_value_error():
@@ -329,7 +328,15 @@ def test_malformed_numeric_field_raises_stable_value_error():
         client.get_positions(symbol="BTC-USDT-SWAP")
 
 
-def test_place_demo_order_returns_none_when_order_id_is_missing():
+def test_non_list_okx_data_raises_stable_value_error():
+    session = DummySession([DummyResponse({"code": "0", "msg": "", "data": {}})])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="okx response payload is invalid"):
+        client.get_positions(symbol="BTC-USDT-SWAP")
+
+
+def test_place_demo_order_raises_when_order_id_is_missing():
     session = DummySession(
         [
             instrument_response(),
@@ -341,9 +348,26 @@ def test_place_demo_order_returns_none_when_order_id_is_missing():
     )
     client = OkxClient(config=sample_config(), session=session)
 
-    result = client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
+    with pytest.raises(ValueError, match="okx response payload is invalid"):
+        client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
+
 
-    assert result.order_id is None
+def test_place_demo_order_rejects_invalid_leverage_before_okx():
+    session = DummySession([])
+    signal = TradeSignal(
+        action="long",
+        confidence=0.9,
+        leverage=4,
+        entry_price=None,
+        take_profit_price=None,
+        stop_loss_price=None,
+        reason="x",
+    )
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="leverage is invalid"):
+        client.place_demo_order(symbol="BTC-USDT-SWAP", signal=signal, margin_usdt=100)
+    assert session.request_paths == []
 
 
 def test_get_positions_returns_normalized_positions():