Преглед на файлове

fix: track open-trade drawdown in backtest

lxy преди 1 месец
родител
ревизия
dd74d8fc81
променени са 2 файла, в които са добавени 75 реда и са изтрити 0 реда
  1. 14 0
      okx_codex_trader/backtest.py
  2. 61 0
      tests/test_backtest.py

+ 14 - 0
okx_codex_trader/backtest.py

@@ -18,6 +18,20 @@ def run_backtest(candles: list[Candle], leverage: int) -> BacktestResult:
     position: dict[str, float | str] | None = None
     position: dict[str, float | str] | None = None
 
 
     for index in range(1, len(candles) - 1):
     for index in range(1, len(candles) - 1):
+        if position is not None:
+            entry_price = float(position["entry_price"])
+            margin_used = float(position["margin_used"])
+            if position["direction"] == "long":
+                price_return = (candles[index].close - entry_price) / entry_price
+            else:
+                price_return = (entry_price - candles[index].close) / entry_price
+            marked_equity = margin_used + (margin_used * leverage * price_return)
+            if marked_equity > peak_equity:
+                peak_equity = marked_equity
+            drawdown = (peak_equity - marked_equity) / peak_equity
+            if drawdown > max_drawdown:
+                max_drawdown = drawdown
+
         if fast[index] is None or slow[index] is None:
         if fast[index] is None or slow[index] is None:
             continue
             continue
 
 

+ 61 - 0
tests/test_backtest.py

@@ -74,6 +74,35 @@ def build_crossing_series() -> list[Candle]:
     return candles
     return candles
 
 
 
 
+def build_open_position_series() -> list[Candle]:
+    candles = build_crossing_series()[:29]
+    return candles
+
+
+def build_drawdown_series() -> list[Candle]:
+    closes = [60.0] * 20 + [120.0, 75.0, 100.0] + [75.0] * 14
+    opens = list(closes)
+    opens[21] = 100.0
+    opens[36] = 100.0
+
+    candles = []
+    for index, (open_price, close_price) in enumerate(zip(opens, closes)):
+        high = max(open_price, close_price)
+        low = min(open_price, close_price)
+        candles.append(
+            Candle(
+                symbol="BTC-USDT-SWAP",
+                ts=index,
+                open=open_price,
+                high=high,
+                low=low,
+                close=close_price,
+                volume=1_000.0,
+            )
+        )
+    return candles
+
+
 def test_simple_moving_average_requires_full_window():
 def test_simple_moving_average_requires_full_window():
     candles = [
     candles = [
         Candle(symbol="BTC-USDT-SWAP", ts=index, open=close, high=close, low=close, close=close, volume=1_000.0)
         Candle(symbol="BTC-USDT-SWAP", ts=index, open=close, high=close, low=close, close=close, volume=1_000.0)
@@ -83,6 +112,18 @@ def test_simple_moving_average_requires_full_window():
     assert simple_moving_average(candles, 3) == [None, None, 20.0, 30.0]
     assert simple_moving_average(candles, 3) == [None, None, 20.0, 30.0]
 
 
 
 
+def test_backtest_rejects_invalid_leverage():
+    candles = build_crossing_series()
+
+    for leverage in (0, 4):
+        try:
+            run_backtest(candles=candles, leverage=leverage)
+        except ValueError as exc:
+            assert str(exc) == "leverage is invalid"
+        else:
+            raise AssertionError("expected ValueError")
+
+
 def test_backtest_runs_fixed_sma_crossover_series():
 def test_backtest_runs_fixed_sma_crossover_series():
     candles = build_crossing_series()
     candles = build_crossing_series()
 
 
@@ -98,3 +139,23 @@ def test_backtest_runs_fixed_sma_crossover_series():
     assert "total_return" in result.to_dict()
     assert "total_return" in result.to_dict()
     assert "max_drawdown" in result.to_dict()
     assert "max_drawdown" in result.to_dict()
     assert result.win_rate == 0.5
     assert result.win_rate == 0.5
+
+
+def test_backtest_does_not_force_close_open_position_at_series_end():
+    candles = build_open_position_series()
+
+    result = run_backtest(candles=candles, leverage=2)
+
+    assert result.trade_count == 0
+    assert result.trades == []
+
+
+def test_backtest_tracks_open_trade_drawdown_from_candle_close():
+    candles = build_drawdown_series()
+
+    result = run_backtest(candles=candles, leverage=2)
+
+    assert result.trade_count == 1
+    assert result.trades[0].entry_price == candles[21].open
+    assert result.trades[0].exit_price == candles[36].open
+    assert result.max_drawdown == 0.5