Kaynağa Gözat

feat: add deterministic sma backtest engine

lxy 1 ay önce
ebeveyn
işleme
2bbaa55b9d

+ 87 - 0
okx_codex_trader/backtest.py

@@ -0,0 +1,87 @@
+from okx_codex_trader.models import BacktestResult, BacktestTrade, Candle
+from okx_codex_trader.strategy import simple_moving_average
+
+
+def run_backtest(candles: list[Candle], leverage: int) -> BacktestResult:
+    if leverage is True or leverage is False or not isinstance(leverage, int) or not 1 <= leverage <= 3:
+        raise ValueError("leverage is invalid")
+
+    fast = simple_moving_average(candles, 10)
+    slow = simple_moving_average(candles, 20)
+
+    initial_equity = 10_000.0
+    equity = initial_equity
+    trades: list[BacktestTrade] = []
+    wins = 0
+    peak_equity = initial_equity
+    max_drawdown = 0.0
+    position: dict[str, float | str] | None = None
+
+    for index in range(len(candles) - 1):
+        if fast[index] is None or slow[index] is None:
+            continue
+
+        signal: str | None = None
+        if index == 19:
+            if fast[index] > slow[index]:
+                signal = "long"
+            elif fast[index] < slow[index]:
+                signal = "short"
+        elif fast[index - 1] is not None and slow[index - 1] is not None:
+            if fast[index - 1] <= slow[index - 1] and fast[index] > slow[index]:
+                signal = "long"
+            elif fast[index - 1] >= slow[index - 1] and fast[index] < slow[index]:
+                signal = "short"
+
+        if signal is None:
+            continue
+
+        execution_price = candles[index + 1].open
+
+        if position is not None and position["direction"] != signal:
+            entry_price = float(position["entry_price"])
+            margin_used = float(position["margin_used"])
+            if position["direction"] == "long":
+                price_return = (execution_price - entry_price) / entry_price
+            else:
+                price_return = (entry_price - execution_price) / entry_price
+
+            ending_equity = margin_used + (margin_used * leverage * price_return)
+            trades.append(
+                BacktestTrade(
+                    direction=str(position["direction"]),
+                    entry_price=entry_price,
+                    exit_price=execution_price,
+                    margin_used=margin_used,
+                    ending_equity=ending_equity,
+                )
+            )
+            equity = ending_equity
+            if ending_equity > float(position["margin_used"]):
+                wins += 1
+            if equity > peak_equity:
+                peak_equity = equity
+            drawdown = (peak_equity - equity) / peak_equity
+            if drawdown > max_drawdown:
+                max_drawdown = drawdown
+            position = None
+
+        if position is None:
+            position = {
+                "direction": signal,
+                "entry_price": execution_price,
+                "margin_used": equity,
+            }
+
+    trade_count = len(trades)
+    win_rate = wins / trade_count if trade_count else 0.0
+
+    return BacktestResult(
+        initial_equity=initial_equity,
+        ending_equity=equity,
+        total_return=(equity - initial_equity) / initial_equity,
+        max_drawdown=max_drawdown,
+        win_rate=win_rate,
+        trade_count=trade_count,
+        trades=trades,
+    )

+ 17 - 1
okx_codex_trader/strategy.py

@@ -1,6 +1,22 @@
 from typing import Mapping
 
-from okx_codex_trader.models import TradeSignal
+from okx_codex_trader.models import Candle, TradeSignal
+
+
+def simple_moving_average(candles: list[Candle], window: int) -> list[float | None]:
+    averages: list[float | None] = []
+    running_total = 0.0
+
+    for index, candle in enumerate(candles):
+        running_total += candle.close
+        if index >= window:
+            running_total -= candles[index - window].close
+        if index + 1 < window:
+            averages.append(None)
+            continue
+        averages.append(running_total / window)
+
+    return averages
 
 
 def validate_signal(payload: Mapping[str, object]) -> TradeSignal:

+ 90 - 0
tests/test_backtest.py

@@ -0,0 +1,90 @@
+from okx_codex_trader.backtest import run_backtest
+from okx_codex_trader.models import Candle
+
+
+def build_crossing_series() -> list[Candle]:
+    closes = [
+        80.0,
+        80.0,
+        80.0,
+        80.0,
+        80.0,
+        80.0,
+        80.0,
+        80.0,
+        80.0,
+        80.0,
+        80.0,
+        81.66666666666667,
+        83.33333333333333,
+        85.0,
+        86.66666666666667,
+        88.33333333333333,
+        90.0,
+        91.66666666666667,
+        93.33333333333333,
+        95.0,
+        90.0,
+        88.88888888888889,
+        87.77777777777777,
+        86.66666666666667,
+        85.55555555555556,
+        84.44444444444444,
+        83.33333333333333,
+        82.22222222222223,
+        81.11111111111111,
+        80.0,
+        80.0,
+        81.66666666666667,
+        83.33333333333333,
+        85.0,
+        86.66666666666667,
+        88.33333333333333,
+        90.0,
+        91.66666666666667,
+        93.33333333333333,
+        95.0,
+        95.0,
+        95.0,
+        95.0,
+        95.0,
+        95.0,
+    ]
+    opens = list(closes)
+    opens[20] = 100.0
+    opens[30] = 90.0
+    opens[40] = 80.0
+
+    candles = []
+    for index, (open_price, close_price) in enumerate(zip(opens, closes)):
+        high = max(open_price, close_price)
+        low = min(open_price, close_price)
+        candles.append(
+            Candle(
+                symbol="BTC-USDT-SWAP",
+                ts=index,
+                open=open_price,
+                high=high,
+                low=low,
+                close=close_price,
+                volume=1_000.0,
+            )
+        )
+    return candles
+
+
+def test_backtest_runs_fixed_sma_crossover_series():
+    candles = build_crossing_series()
+
+    result = run_backtest(candles=candles, leverage=2)
+
+    assert result.initial_equity == 10_000
+    assert result.trade_count == 2
+    assert result.trades[0].entry_price == candles[20].open
+    assert result.trades[0].exit_price == candles[30].open
+    assert result.trades[0].margin_used == 10_000
+    assert result.trades[1].margin_used == result.trades[0].ending_equity
+    assert result.ending_equity == result.trades[-1].ending_equity
+    assert "total_return" in result.to_dict()
+    assert "max_drawdown" in result.to_dict()
+    assert result.win_rate == 0.5