Prechádzať zdrojové kódy

fix: require real sma crossovers in backtest

lxy 1 mesiac pred
rodič
commit
40e7ac25f1
2 zmenil súbory, kde vykonal 42 pridanie a 33 odobranie
  1. 1 6
      okx_codex_trader/backtest.py
  2. 41 27
      tests/test_backtest.py

+ 1 - 6
okx_codex_trader/backtest.py

@@ -22,12 +22,7 @@ def run_backtest(candles: list[Candle], leverage: int) -> BacktestResult:
             continue
 
         signal: str | None = None
-        if index == 19:
-            if fast[index] > slow[index]:
-                signal = "long"
-            elif fast[index] < slow[index]:
-                signal = "short"
-        elif fast[index - 1] is not None and slow[index - 1] is not None:
+        if fast[index - 1] is not None and slow[index - 1] is not None:
             if fast[index - 1] <= slow[index - 1] and fast[index] > slow[index]:
                 signal = "long"
             elif fast[index - 1] >= slow[index - 1] and fast[index] < slow[index]:

+ 41 - 27
tests/test_backtest.py

@@ -15,45 +15,57 @@ def build_crossing_series() -> list[Candle]:
         80.0,
         80.0,
         80.0,
+        80.55555555555556,
+        81.11111111111111,
         81.66666666666667,
+        82.22222222222223,
+        82.77777777777777,
         83.33333333333333,
+        83.88888888888889,
+        84.44444444444444,
+        85.0,
         85.0,
-        86.66666666666667,
-        88.33333333333333,
-        90.0,
-        91.66666666666667,
-        93.33333333333333,
-        95.0,
-        90.0,
-        88.88888888888889,
-        87.77777777777777,
-        86.66666666666667,
-        85.55555555555556,
         84.44444444444444,
+        83.88888888888889,
         83.33333333333333,
+        82.77777777777777,
         82.22222222222223,
+        81.66666666666667,
         81.11111111111111,
+        80.55555555555556,
         80.0,
         80.0,
+        80.55555555555556,
+        81.11111111111111,
         81.66666666666667,
+        82.22222222222223,
+        82.77777777777777,
         83.33333333333333,
+        83.88888888888889,
+        84.44444444444444,
+        85.0,
         85.0,
-        86.66666666666667,
-        88.33333333333333,
-        90.0,
-        91.66666666666667,
-        93.33333333333333,
-        95.0,
-        95.0,
-        95.0,
-        95.0,
-        95.0,
-        95.0,
+        84.44444444444444,
+        83.88888888888889,
+        83.33333333333333,
+        82.77777777777777,
+        82.22222222222223,
+        81.66666666666667,
+        81.11111111111111,
+        80.55555555555556,
+        80.0,
+        80.0,
+        80.0,
+        80.0,
+        80.0,
+        80.0,
+        80.0,
+        80.0,
     ]
     opens = list(closes)
-    opens[20] = 100.0
-    opens[30] = 90.0
-    opens[40] = 80.0
+    opens[31] = 90.0
+    opens[41] = 80.0
+    opens[51] = 70.0
 
     candles = []
     for index, (open_price, close_price) in enumerate(zip(opens, closes)):
@@ -80,10 +92,12 @@ def test_backtest_runs_fixed_sma_crossover_series():
 
     assert result.initial_equity == 10_000
     assert result.trade_count == 2
-    assert result.trades[0].entry_price == candles[20].open
-    assert result.trades[0].exit_price == candles[30].open
+    assert result.trades[0].entry_price == candles[31].open
+    assert result.trades[0].exit_price == candles[41].open
     assert result.trades[0].margin_used == 10_000
     assert result.trades[1].margin_used == result.trades[0].ending_equity
+    assert result.trades[1].entry_price == candles[41].open
+    assert result.trades[1].exit_price == candles[51].open
     assert result.ending_equity == result.trades[-1].ending_equity
     assert "total_return" in result.to_dict()
     assert "max_drawdown" in result.to_dict()