Sfoglia il codice sorgente

fix: reject boolean signal numeric values

lxy 1 mese fa
parent
commit
b4ef5ae6c9
3 ha cambiato i file con 72 aggiunte e 0 eliminazioni
  1. 10 0
      okx_codex_trader/strategy.py
  2. 14 0
      tests/test_config.py
  3. 48 0
      tests/test_strategy.py

+ 10 - 0
okx_codex_trader/strategy.py

@@ -21,27 +21,37 @@ def validate_signal(payload: Mapping[str, object]) -> TradeSignal:
         raise ValueError("signal action is invalid")
 
     confidence = payload["confidence"]
+    if isinstance(confidence, bool):
+        raise ValueError("signal confidence is invalid")
     if not isinstance(confidence, int | float) or not 0 <= float(confidence) <= 1:
         raise ValueError("signal confidence is invalid")
 
     leverage = payload["leverage"]
+    if isinstance(leverage, bool):
+        raise ValueError("signal leverage is invalid")
     if not isinstance(leverage, int) or not 1 <= leverage <= 3:
         raise ValueError("signal leverage is invalid")
 
     entry_price = payload["entry_price"]
     if entry_price is not None:
+        if isinstance(entry_price, bool):
+            raise ValueError("signal entry_price is invalid")
         if not isinstance(entry_price, int | float):
             raise ValueError("signal entry_price is invalid")
         entry_price = float(entry_price)
 
     take_profit_price = payload["take_profit_price"]
     if take_profit_price is not None:
+        if isinstance(take_profit_price, bool):
+            raise ValueError("signal take_profit_price is invalid")
         if not isinstance(take_profit_price, int | float):
             raise ValueError("signal take_profit_price is invalid")
         take_profit_price = float(take_profit_price)
 
     stop_loss_price = payload["stop_loss_price"]
     if stop_loss_price is not None:
+        if isinstance(stop_loss_price, bool):
+            raise ValueError("signal stop_loss_price is invalid")
         if not isinstance(stop_loss_price, int | float):
             raise ValueError("signal stop_loss_price is invalid")
         stop_loss_price = float(stop_loss_price)

+ 14 - 0
tests/test_config.py

@@ -15,3 +15,17 @@ def test_load_config_requires_okx_credentials(monkeypatch):
 
     with pytest.raises(ValueError):
         load_config()
+
+
+def test_load_config_uses_explicit_env_mapping():
+    config = load_config(
+        {
+            "OKX_API_KEY": "key",
+            "OKX_API_SECRET": "secret",
+            "OKX_API_PASSPHRASE": "passphrase",
+        }
+    )
+
+    assert config.api_key == "key"
+    assert config.api_secret == "secret"
+    assert config.api_passphrase == "passphrase"

+ 48 - 0
tests/test_strategy.py

@@ -3,6 +3,32 @@ import pytest
 from okx_codex_trader.strategy import validate_signal
 
 
+@pytest.mark.parametrize(
+    ("field_name", "field_value"),
+    [
+        ("confidence", True),
+        ("leverage", True),
+        ("entry_price", True),
+        ("take_profit_price", False),
+        ("stop_loss_price", True),
+    ],
+)
+def test_validate_signal_rejects_boolean_numeric_fields(field_name, field_value):
+    signal = {
+        "action": "long",
+        "confidence": 0.9,
+        "leverage": 2,
+        "entry_price": None,
+        "take_profit_price": None,
+        "stop_loss_price": None,
+        "reason": "x",
+    }
+    signal[field_name] = field_value
+
+    with pytest.raises(ValueError):
+        validate_signal(signal)
+
+
 def test_validate_signal_rejects_leverage_out_of_range():
     with pytest.raises(ValueError):
         validate_signal(
@@ -51,3 +77,25 @@ def test_validate_signal_rejects_confidence_out_of_range():
 def test_validate_signal_requires_full_shape():
     with pytest.raises(ValueError):
         validate_signal({"action": "long", "confidence": 0.9, "leverage": 2})
+
+
+def test_validate_signal_returns_trade_signal():
+    signal = validate_signal(
+        {
+            "action": "short",
+            "confidence": 0.75,
+            "leverage": 3,
+            "entry_price": 101.5,
+            "take_profit_price": 98,
+            "stop_loss_price": None,
+            "reason": "trend",
+        }
+    )
+
+    assert signal.action == "short"
+    assert signal.confidence == 0.75
+    assert signal.leverage == 3
+    assert signal.entry_price == 101.5
+    assert signal.take_profit_price == 98.0
+    assert signal.stop_loss_price is None
+    assert signal.reason == "trend"