فهرست منبع

test: add external backtest validation

lxy 1 ماه پیش
والد
کامیت
0289834dec
2فایلهای تغییر یافته به همراه282 افزوده شده و 0 حذف شده
  1. 216 0
      scripts/validate_external_backtest.py
  2. 66 0
      tests/test_validate_external_backtest.py

+ 216 - 0
scripts/validate_external_backtest.py

@@ -0,0 +1,216 @@
+from __future__ import annotations
+
+import argparse
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+
+import pandas as pd
+from backtesting import Strategy
+from backtesting.lib import FractionalBacktest
+
+ROOT_DIR = Path(__file__).resolve().parents[1]
+if str(ROOT_DIR) not in sys.path:
+    sys.path.insert(0, str(ROOT_DIR))
+
+from okx_codex_trader.models import Candle
+from okx_codex_trader.okx_client import OkxClient
+from okx_codex_trader.rsi2_report import RSI2Config, _compute_rsi, run_rsi2_segment
+from scripts.explore_ultrashort import (
+    CANDLE_CACHE_DIR,
+    INITIAL_EQUITY,
+    LEVERAGE,
+    get_candles_cached,
+    load_cached_candles,
+    max_drawdown_from_equity,
+)
+
+
+@dataclass(frozen=True)
+class ExternalValidation:
+    symbol: str
+    bar: str
+    rows: int
+    internal_trades: int
+    external_trades: int
+    internal_final_equity: float
+    external_final_equity: float
+    internal_return: float
+    external_return: float
+    internal_max_drawdown: float
+    external_max_drawdown: float
+    final_equity_diff: float
+    return_diff: float
+    max_drawdown_diff: float
+
+
+class BacktestingRsi2(Strategy):
+    trend_sma = 50
+    rsi_length = 2
+    rsi_long_threshold = 10.0
+    rsi_short_threshold = 90.0
+    exit_rsi = 50.0
+    warmup_bars = 50
+
+    def init(self) -> None:
+        self.rsi = self.I(
+            lambda close: _compute_rsi(pd.Series(close, dtype=float), self.rsi_length),
+            self.data.Close,
+            overlay=False,
+        )
+        self.trend = self.I(
+            lambda close: pd.Series(close, dtype=float).rolling(self.trend_sma).mean().to_numpy(),
+            self.data.Close,
+            overlay=False,
+        )
+
+    def next(self) -> None:
+        index = len(self.data) - 1
+        if index < self.warmup_bars:
+            return
+
+        rsi = self.rsi[-1]
+        trend = self.trend[-1]
+        if pd.isna(rsi) or pd.isna(trend):
+            return
+
+        if self.position:
+            if (self.position.is_long and rsi >= self.exit_rsi) or (
+                self.position.is_short and rsi <= self.exit_rsi
+            ):
+                self.position.close()
+            return
+
+        close = self.data.Close[-1]
+        if close > trend and rsi <= self.rsi_long_threshold:
+            self.buy(size=0.999999999)
+        elif close < trend and rsi >= self.rsi_short_threshold:
+            self.sell(size=0.999999999)
+
+
+def candles_to_backtesting_frame(candles: list[Candle], config: RSI2Config) -> pd.DataFrame:
+    frame = pd.DataFrame(
+        [
+            {
+                "ts": pd.to_datetime(candle.ts, unit="ms", utc=True),
+                "Open": candle.open,
+                "High": candle.high,
+                "Low": candle.low,
+                "Close": candle.close,
+                "Volume": candle.volume,
+            }
+            for candle in candles
+        ]
+    ).set_index("ts")
+    return frame
+
+
+def run_backtesting_rsi2(
+    *,
+    candles: list[Candle],
+    leverage: int,
+    warmup_bars: int,
+    config: RSI2Config,
+) -> tuple[pd.Series, pd.DataFrame]:
+    frame = candles_to_backtesting_frame(candles, config)
+    backtest = FractionalBacktest(
+        frame,
+        BacktestingRsi2,
+        cash=config.initial_equity,
+        margin=1 / leverage,
+        trade_on_close=False,
+        exclusive_orders=True,
+        finalize_trades=False,
+    )
+    stats = backtest.run(
+        trend_sma=config.trend_sma,
+        rsi_long_threshold=config.rsi_long_threshold,
+        rsi_short_threshold=config.rsi_short_threshold,
+        exit_rsi=config.exit_rsi,
+        rsi_length=config.rsi_length,
+        warmup_bars=warmup_bars,
+    )
+    return stats, stats["_trades"]
+
+
+def validate_rsi2_with_backtesting(
+    *,
+    candles: list[Candle],
+    symbol: str,
+    bar: str,
+    leverage: int,
+    config: RSI2Config,
+) -> ExternalValidation:
+    warmup_bars = max(config.trend_sma, config.rsi_length + 1)
+    internal = run_rsi2_segment(candles=candles, leverage=leverage, warmup_bars=warmup_bars, config=config)
+    external, external_trades = run_backtesting_rsi2(
+        candles=candles,
+        leverage=leverage,
+        warmup_bars=warmup_bars,
+        config=config,
+    )
+    internal_final_equity = float(internal.equity_curve[-1]["equity"])
+    external_final_equity = float(external["Equity Final [$]"])
+    internal_return = internal_final_equity / config.initial_equity - 1.0
+    external_return = external_final_equity / config.initial_equity - 1.0
+    external_curve = external["_equity_curve"]["Equity"].tolist()
+    external_max_drawdown = max_drawdown_from_equity([float(value) for value in external_curve])
+    return ExternalValidation(
+        symbol=symbol,
+        bar=bar,
+        rows=len(candles),
+        internal_trades=internal.trade_count,
+        external_trades=len(external_trades),
+        internal_final_equity=internal_final_equity,
+        external_final_equity=external_final_equity,
+        internal_return=internal_return,
+        external_return=external_return,
+        internal_max_drawdown=internal.max_drawdown,
+        external_max_drawdown=external_max_drawdown,
+        final_equity_diff=internal_final_equity - external_final_equity,
+        return_diff=internal_return - external_return,
+        max_drawdown_diff=internal.max_drawdown - external_max_drawdown,
+    )
+
+
+def load_validation_candles(symbol: str, bar: str, limit: int, cache_dir: Path) -> list[Candle]:
+    cached, _ = load_cached_candles(cache_dir, symbol, bar)
+    if cached:
+        return cached[-limit:] if len(cached) > limit else cached
+    return get_candles_cached(OkxClient(), symbol, bar, limit, cache_dir)
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--symbol", default="BTC-USDT-SWAP")
+    parser.add_argument("--bar", default="15m")
+    parser.add_argument("--limit", type=int, default=50_000)
+    parser.add_argument("--cache-dir", type=Path, default=CANDLE_CACHE_DIR)
+    parser.add_argument("--leverage", type=int, default=LEVERAGE)
+    parser.add_argument("--trend-sma", type=int, default=50)
+    parser.add_argument("--rsi-long-threshold", type=float, default=10.0)
+    parser.add_argument("--rsi-short-threshold", type=float, default=90.0)
+    parser.add_argument("--exit-rsi", type=float, default=50.0)
+    args = parser.parse_args()
+
+    config = RSI2Config(
+        trend_sma=args.trend_sma,
+        rsi_long_threshold=args.rsi_long_threshold,
+        rsi_short_threshold=args.rsi_short_threshold,
+        exit_rsi=args.exit_rsi,
+        initial_equity=INITIAL_EQUITY,
+    )
+    candles = load_validation_candles(args.symbol, args.bar, args.limit, args.cache_dir)
+    validation = validate_rsi2_with_backtesting(
+        candles=candles,
+        symbol=args.symbol,
+        bar=args.bar,
+        leverage=args.leverage,
+        config=config,
+    )
+    print(pd.DataFrame([validation.__dict__]).to_string(index=False))
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())

+ 66 - 0
tests/test_validate_external_backtest.py

@@ -0,0 +1,66 @@
+import math
+import importlib.util
+import sys
+from pathlib import Path
+
+import pytest
+
+from okx_codex_trader.models import Candle
+from okx_codex_trader.rsi2_report import RSI2Config
+
+
+def load_validation_module():
+    path = Path(__file__).resolve().parents[1] / "scripts" / "validate_external_backtest.py"
+    spec = importlib.util.spec_from_file_location("validate_external_backtest", path)
+    assert spec is not None
+    module = importlib.util.module_from_spec(spec)
+    assert spec.loader is not None
+    sys.modules[spec.name] = module
+    spec.loader.exec_module(module)
+    return module
+
+
+def build_oscillating_candles(count: int) -> list[Candle]:
+    candles = []
+    previous_close = 100.0
+    for index in range(count):
+        close = 100.0 + index * 0.04 + math.sin(index / 3.0) * 4.0
+        open_price = previous_close
+        high = max(open_price, close) + 0.5
+        low = min(open_price, close) - 0.5
+        candles.append(
+            Candle(
+                symbol="BTC-USDT-SWAP",
+                ts=index * 60_000,
+                open=open_price,
+                high=high,
+                low=low,
+                close=close,
+                volume=1_000.0,
+            )
+        )
+        previous_close = close
+    return candles
+
+
+def test_rsi2_matches_backtesting_py_on_same_signals_and_execution_model():
+    module = load_validation_module()
+    result = module.validate_rsi2_with_backtesting(
+        candles=build_oscillating_candles(500),
+        symbol="BTC-USDT-SWAP",
+        bar="1m",
+        leverage=3,
+        config=RSI2Config(
+            trend_sma=20,
+            rsi_length=2,
+            rsi_long_threshold=20.0,
+            rsi_short_threshold=80.0,
+            exit_rsi=50.0,
+        ),
+    )
+
+    assert result.internal_trades == result.external_trades
+    assert result.internal_trades > 0
+    assert result.final_equity_diff == pytest.approx(0.0, abs=0.02)
+    assert result.return_diff == pytest.approx(0.0, abs=0.000002)
+    assert result.max_drawdown_diff == pytest.approx(0.0, abs=0.000002)