Forráskód Böngészése

feat: add config and signal validation models

lxy 1 hónapja
szülő
commit
4b4284eb99

+ 24 - 0
okx_codex_trader/config.py

@@ -0,0 +1,24 @@
+import os
+from dataclasses import dataclass
+from typing import Mapping
+
+
+@dataclass(frozen=True)
+class Config:
+    api_key: str
+    api_secret: str
+    api_passphrase: str
+
+
+def load_config(env: Mapping[str, str] | None = None) -> Config:
+    source = os.environ if env is None else env
+    api_key = source.get("OKX_API_KEY")
+    api_secret = source.get("OKX_API_SECRET")
+    api_passphrase = source.get("OKX_API_PASSPHRASE")
+    if not api_key or not api_secret or not api_passphrase:
+        raise ValueError("OKX credentials are required")
+    return Config(
+        api_key=api_key,
+        api_secret=api_secret,
+        api_passphrase=api_passphrase,
+    )

+ 73 - 0
okx_codex_trader/models.py

@@ -0,0 +1,73 @@
+from dataclasses import asdict, dataclass
+from typing import Literal
+
+
+@dataclass(frozen=True)
+class Candle:
+    symbol: str
+    ts: int
+    open: float
+    high: float
+    low: float
+    close: float
+    volume: float
+
+
+@dataclass(frozen=True)
+class TradeSignal:
+    action: Literal["long", "short", "flat"]
+    confidence: float
+    leverage: int
+    entry_price: float | None
+    take_profit_price: float | None
+    stop_loss_price: float | None
+    reason: str
+
+
+@dataclass(frozen=True)
+class InstrumentMeta:
+    ct_val: float
+    lot_sz: float
+    min_sz: float
+
+
+@dataclass(frozen=True)
+class Position:
+    symbol: str
+    pos_side: str
+    size: float
+    avg_price: float
+
+
+@dataclass(frozen=True)
+class OrderResult:
+    status: str
+    order_id: str | None
+    symbol: str
+    side: str | None
+    pos_side: str | None
+    order_type: str | None
+    size: float | None
+
+
+@dataclass(frozen=True)
+class BacktestTrade:
+    direction: str
+    entry_price: float
+    exit_price: float
+    margin_used: float
+    ending_equity: float
+
+
+@dataclass(frozen=True)
+class BacktestResult:
+    initial_equity: float
+    ending_equity: float
+    total_return: float
+    max_drawdown: float
+    win_rate: float
+    trade_count: int
+    trades: list[BacktestTrade]
+
+    def to_dict(self) -> dict[str, object]:
+        return asdict(self)

+ 61 - 0
okx_codex_trader/strategy.py

@@ -0,0 +1,61 @@
+from typing import Mapping
+
+from okx_codex_trader.models import TradeSignal
+
+
+def validate_signal(payload: Mapping[str, object]) -> TradeSignal:
+    required_keys = {
+        "action",
+        "confidence",
+        "leverage",
+        "entry_price",
+        "take_profit_price",
+        "stop_loss_price",
+        "reason",
+    }
+    if set(payload) != required_keys:
+        raise ValueError("signal shape is invalid")
+
+    action = payload["action"]
+    if action not in {"long", "short", "flat"}:
+        raise ValueError("signal action is invalid")
+
+    confidence = payload["confidence"]
+    if not isinstance(confidence, int | float) or not 0 <= float(confidence) <= 1:
+        raise ValueError("signal confidence is invalid")
+
+    leverage = payload["leverage"]
+    if not isinstance(leverage, int) or not 1 <= leverage <= 3:
+        raise ValueError("signal leverage is invalid")
+
+    entry_price = payload["entry_price"]
+    if entry_price is not None:
+        if not isinstance(entry_price, int | float):
+            raise ValueError("signal entry_price is invalid")
+        entry_price = float(entry_price)
+
+    take_profit_price = payload["take_profit_price"]
+    if take_profit_price is not None:
+        if not isinstance(take_profit_price, int | float):
+            raise ValueError("signal take_profit_price is invalid")
+        take_profit_price = float(take_profit_price)
+
+    stop_loss_price = payload["stop_loss_price"]
+    if stop_loss_price is not None:
+        if not isinstance(stop_loss_price, int | float):
+            raise ValueError("signal stop_loss_price is invalid")
+        stop_loss_price = float(stop_loss_price)
+
+    reason = payload["reason"]
+    if not isinstance(reason, str) or not reason:
+        raise ValueError("signal reason is invalid")
+
+    return TradeSignal(
+        action=action,
+        confidence=float(confidence),
+        leverage=leverage,
+        entry_price=entry_price,
+        take_profit_price=take_profit_price,
+        stop_loss_price=stop_loss_price,
+        reason=reason,
+    )

+ 12 - 0
tests/test_config.py

@@ -1,5 +1,17 @@
+import pytest
+
 import okx_codex_trader
+from okx_codex_trader.config import load_config
 
 
 def test_package_exports_version():
     assert okx_codex_trader.__version__ == "0.1.0"
+
+
+def test_load_config_requires_okx_credentials(monkeypatch):
+    monkeypatch.delenv("OKX_API_KEY", raising=False)
+    monkeypatch.delenv("OKX_API_SECRET", raising=False)
+    monkeypatch.delenv("OKX_API_PASSPHRASE", raising=False)
+
+    with pytest.raises(ValueError):
+        load_config()

+ 53 - 0
tests/test_strategy.py

@@ -0,0 +1,53 @@
+import pytest
+
+from okx_codex_trader.strategy import validate_signal
+
+
+def test_validate_signal_rejects_leverage_out_of_range():
+    with pytest.raises(ValueError):
+        validate_signal(
+            {
+                "action": "long",
+                "confidence": 0.9,
+                "leverage": 4,
+                "entry_price": None,
+                "take_profit_price": None,
+                "stop_loss_price": None,
+                "reason": "x",
+            }
+        )
+
+
+def test_validate_signal_rejects_unknown_action():
+    with pytest.raises(ValueError):
+        validate_signal(
+            {
+                "action": "hold",
+                "confidence": 0.9,
+                "leverage": 2,
+                "entry_price": None,
+                "take_profit_price": None,
+                "stop_loss_price": None,
+                "reason": "x",
+            }
+        )
+
+
+def test_validate_signal_rejects_confidence_out_of_range():
+    with pytest.raises(ValueError):
+        validate_signal(
+            {
+                "action": "long",
+                "confidence": 1.2,
+                "leverage": 2,
+                "entry_price": None,
+                "take_profit_price": None,
+                "stop_loss_price": None,
+                "reason": "x",
+            }
+        )
+
+
+def test_validate_signal_requires_full_shape():
+    with pytest.raises(ValueError):
+        validate_signal({"action": "long", "confidence": 0.9, "leverage": 2})