浏览代码

fix: validate hedge precondition ordering

lxy 1 月之前
父节点
当前提交
dcbad083c0
共有 2 个文件被更改,包括 55 次插入10 次删除
  1. 9 1
      okx_codex_trader/okx_client.py
  2. 46 9
      tests/test_okx_client.py

+ 9 - 1
okx_codex_trader/okx_client.py

@@ -16,6 +16,8 @@ OkxRow: TypeAlias = dict[str, object] | list[object]
 
 
 def _parse_finite_decimal(value: object) -> Decimal:
+    if isinstance(value, bool):
+        raise ValueError("contract sizing inputs are invalid")
     try:
         parsed = Decimal(str(value))
     except (InvalidOperation, TypeError, ValueError):
@@ -26,6 +28,8 @@ def _parse_finite_decimal(value: object) -> Decimal:
 
 
 def _parse_finite_float(value: object) -> float:
+    if isinstance(value, bool):
+        raise ValueError("okx response payload is invalid")
     try:
         parsed = float(value)
     except (TypeError, ValueError):
@@ -197,6 +201,8 @@ class OkxClient:
     def set_leverage(self, symbol: str, leverage: int, pos_side: str) -> None:
         if not symbol.endswith("-SWAP"):
             raise ValueError("swap instrument is required")
+        if isinstance(leverage, bool):
+            raise ValueError("leverage is invalid")
         if leverage < 1 or leverage > 3:
             raise ValueError("leverage is invalid")
         if pos_side not in {"long", "short"}:
@@ -227,6 +233,8 @@ class OkxClient:
             raise ValueError("action is invalid")
         if not symbol.endswith("-SWAP"):
             raise ValueError("swap instrument is required")
+        if isinstance(signal.leverage, bool):
+            raise ValueError("leverage is invalid")
         if signal.leverage < 1 or signal.leverage > 3:
             raise ValueError("leverage is invalid")
         try:
@@ -235,11 +243,11 @@ class OkxClient:
             raise ValueError("margin_usdt is invalid") from None
         if margin_value <= 0:
             raise ValueError("margin_usdt is invalid")
+        self.ensure_hedge_mode()
         metadata = self.get_instrument_meta(symbol)
         price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
         side = "buy" if signal.action == "long" else "sell"
         pos_side = "long" if signal.action == "long" else "short"
-        self.ensure_hedge_mode()
         size = build_contract_size(margin_value * signal.leverage, price, metadata)
         self.set_leverage(symbol, signal.leverage, pos_side)
         order_type = "market" if signal.entry_price is None else "limit"

+ 46 - 9
tests/test_okx_client.py

@@ -443,8 +443,8 @@ def test_signed_demo_request_attaches_headers():
 def test_signed_post_request_uses_actual_serialized_body_bytes():
     session = DummySession(
         [
-            instrument_response(symbol="ETH-USDT-SWAP"),
             account_config_response(pos_mode="long_short_mode"),
+            instrument_response(symbol="ETH-USDT-SWAP"),
             leverage_response(),
             place_order_response(),
         ]
@@ -497,6 +497,17 @@ def test_build_contract_size_rejects_invalid_notional(notional):
         build_contract_size(notional=notional, price=25_000, metadata=metadata)
 
 
+def test_build_contract_size_rejects_boolean_inputs():
+    metadata = InstrumentMeta(ct_val=0.01, lot_sz=1, min_sz=1)
+
+    with pytest.raises(ValueError, match="contract sizing inputs are invalid"):
+        build_contract_size(notional=True, price=25_000, metadata=metadata)
+    with pytest.raises(ValueError, match="contract sizing inputs are invalid"):
+        build_contract_size(notional=100, price=False, metadata=metadata)
+    with pytest.raises(ValueError, match="contract sizing inputs are invalid"):
+        build_contract_size(notional=100, price=25_000, metadata=InstrumentMeta(ct_val=True, lot_sz=1, min_sz=1))
+
+
 @pytest.mark.parametrize(
     ("price", "metadata"),
     [
@@ -533,9 +544,9 @@ def test_build_contract_size_rejects_non_finite_inputs(price, metadata):
 def test_market_order_fetches_latest_price_before_sizing():
     session = DummySession(
         [
+            account_config_response(pos_mode="long_short_mode"),
             instrument_response(),
             ticker_response(last="25000"),
-            account_config_response(pos_mode="long_short_mode"),
             leverage_response(),
             place_order_response(),
         ]
@@ -545,9 +556,9 @@ def test_market_order_fetches_latest_price_before_sizing():
     client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
 
     assert session.request_paths == [
+        "/api/v5/account/config",
         "/api/v5/public/instruments",
         "/api/v5/market/ticker",
-        "/api/v5/account/config",
         "/api/v5/account/set-leverage",
         "/api/v5/trade/order",
     ]
@@ -556,8 +567,6 @@ def test_market_order_fetches_latest_price_before_sizing():
 def test_place_demo_order_fails_when_not_hedge_mode():
     session = DummySession(
         [
-            instrument_response(),
-            ticker_response(last="25000"),
             account_config_response(pos_mode="net_mode"),
         ]
     )
@@ -565,6 +574,7 @@ def test_place_demo_order_fails_when_not_hedge_mode():
 
     with pytest.raises(ValueError):
         client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
+    assert session.request_paths == ["/api/v5/account/config"]
 
 
 def test_ensure_hedge_mode_rejects_malformed_config_payload():
@@ -578,9 +588,9 @@ def test_ensure_hedge_mode_rejects_malformed_config_payload():
 def test_place_demo_order_validates_size_before_setting_leverage():
     session = DummySession(
         [
+            account_config_response(pos_mode="long_short_mode"),
             large_min_size_instrument_response(),
             ticker_response(last="25000"),
-            account_config_response(pos_mode="long_short_mode"),
         ]
     )
     client = OkxClient(config=sample_config(), session=session)
@@ -589,17 +599,17 @@ def test_place_demo_order_validates_size_before_setting_leverage():
         client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
 
     assert session.request_paths == [
+        "/api/v5/account/config",
         "/api/v5/public/instruments",
         "/api/v5/market/ticker",
-        "/api/v5/account/config",
     ]
 
 
 def test_limit_short_order_uses_sell_and_short_pos_side():
     session = DummySession(
         [
-            instrument_response(symbol="ETH-USDT-SWAP"),
             account_config_response(pos_mode="long_short_mode"),
+            instrument_response(symbol="ETH-USDT-SWAP"),
             leverage_response(),
             place_order_response(),
         ]
@@ -631,9 +641,9 @@ def test_flat_signal_returns_noop_without_order_submission():
 def test_place_demo_order_sends_computed_sz_and_ignores_tp_sl_fields():
     session = DummySession(
         [
+            account_config_response(pos_mode="long_short_mode"),
             instrument_response(),
             ticker_response(last="25000"),
-            account_config_response(pos_mode="long_short_mode"),
             leverage_response(),
             place_order_response(),
         ]
@@ -795,6 +805,24 @@ def test_place_demo_order_rejects_invalid_leverage_before_okx():
     assert session.request_paths == []
 
 
+def test_place_demo_order_rejects_boolean_leverage_before_okx():
+    session = DummySession([])
+    signal = TradeSignal(
+        action="long",
+        confidence=0.9,
+        leverage=True,
+        entry_price=None,
+        take_profit_price=None,
+        stop_loss_price=None,
+        reason="x",
+    )
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="leverage is invalid"):
+        client.place_demo_order(symbol="BTC-USDT-SWAP", signal=signal, margin_usdt=100)
+    assert session.request_paths == []
+
+
 @pytest.mark.parametrize("margin_usdt", [0, -1, float("nan"), float("inf")])
 def test_place_demo_order_rejects_invalid_margin_before_okx(margin_usdt):
     session = DummySession([])
@@ -805,6 +833,15 @@ def test_place_demo_order_rejects_invalid_margin_before_okx(margin_usdt):
     assert session.request_paths == []
 
 
+def test_place_demo_order_rejects_boolean_margin_before_okx():
+    session = DummySession([])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="margin_usdt is invalid"):
+        client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=True)
+    assert session.request_paths == []
+
+
 @pytest.mark.parametrize(
     ("symbol", "leverage", "pos_side", "expected_message"),
     [