| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216 |
- from __future__ import annotations
- import argparse
- import sys
- from dataclasses import dataclass
- from pathlib import Path
- import pandas as pd
- from backtesting import Strategy
- from backtesting.lib import FractionalBacktest
# Repository root: the parent of the directory holding this script. It is put at
# the front of sys.path (only once) so the ``okx_codex_trader`` and ``scripts``
# imports that follow resolve when this file is executed directly.
ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    sys.path[:0] = [str(ROOT_DIR)]
- from okx_codex_trader.models import Candle
- from okx_codex_trader.okx_client import OkxClient
- from okx_codex_trader.rsi2_report import RSI2Config, _compute_rsi, run_rsi2_segment
- from scripts.explore_ultrashort import (
- CANDLE_CACHE_DIR,
- INITIAL_EQUITY,
- LEVERAGE,
- get_candles_cached,
- load_cached_candles,
- max_drawdown_from_equity,
- )
@dataclass(frozen=True)
class ExternalValidation:
    """Side-by-side summary of one internal vs. backtesting.py RSI2 run.

    Instances are immutable. Every ``*_diff`` field holds the internal value
    minus the external one. Returns are ``final_equity / initial_equity - 1``.
    """

    # Which market / timeframe was replayed and how many candles were used.
    symbol: str
    bar: str
    rows: int
    # Trade counts as reported by each engine.
    internal_trades: int
    external_trades: int
    # Ending account equity of each engine.
    internal_final_equity: float
    external_final_equity: float
    # Total return of each engine over the whole run.
    internal_return: float
    external_return: float
    # Maximum drawdown of each engine's equity curve.
    internal_max_drawdown: float
    external_max_drawdown: float
    # Internal-minus-external deltas for the three headline metrics.
    final_equity_diff: float
    return_diff: float
    max_drawdown_diff: float
class BacktestingRsi2(Strategy):
    """RSI(2) trend-filtered mean-reversion strategy for ``backtesting.py``.

    Serves as an independent reference implementation for cross-checking the
    project's own RSI2 backtester. The class attributes are the tunable
    parameters; ``Backtest.run(**params)`` overrides them for a given run.

    Rules, evaluated once per bar after ``warmup_bars`` bars:

    * While a position is open, only the exit is checked. A long is closed once
      RSI >= ``exit_rsi`` and a short once RSI <= ``exit_rsi``. No new entry is
      placed while in a position.
    * While flat, go long when close > SMA(``trend_sma``) and
      RSI <= ``rsi_long_threshold``. Go short when close < SMA and
      RSI >= ``rsi_short_threshold``.
    """

    trend_sma = 50  # window of the simple moving average used as trend filter
    rsi_length = 2  # look-back passed to _compute_rsi
    rsi_long_threshold = 10.0
    rsi_short_threshold = 90.0
    exit_rsi = 50.0
    warmup_bars = 50  # bars skipped before any signal is evaluated

    def init(self) -> None:
        """Register the RSI and trend-SMA indicators on the close series."""
        self.rsi = self.I(
            lambda close: _compute_rsi(pd.Series(close, dtype=float), self.rsi_length),
            self.data.Close,
            overlay=False,
        )
        self.trend = self.I(
            lambda close: pd.Series(close, dtype=float).rolling(self.trend_sma).mean().to_numpy(),
            self.data.Close,
            overlay=False,
        )

    def next(self) -> None:
        """Apply the exit rule when in a position, else the entry rules."""
        # len(self.data) - 1 is the 0-based index of the bar being processed.
        if len(self.data) - 1 < self.warmup_bars:
            return
        rsi = self.rsi[-1]
        trend = self.trend[-1]
        if pd.isna(rsi) or pd.isna(trend):
            return
        if self.position:
            if (self.position.is_long and rsi >= self.exit_rsi) or (
                self.position.is_short and rsi <= self.exit_rsi
            ):
                self.position.close()
            # Never re-enter while a position is open. The backtest runs with
            # exclusive_orders=True, so a fresh buy()/sell() here would close
            # the current trade and open a new one. That splits a single holding
            # into several trades and distorts the trade count being validated.
            return
        close = self.data.Close[-1]
        if close > trend and rsi <= self.rsi_long_threshold:
            self.buy(size=0.999999999)  # (almost) all available liquidity
        elif close < trend and rsi >= self.rsi_short_threshold:
            self.sell(size=0.999999999)
def candles_to_backtesting_frame(candles: list[Candle], config: RSI2Config) -> pd.DataFrame:
    """Convert project candles into the OHLCV frame ``backtesting.py`` expects.

    Args:
        candles: Bars in the order they should be replayed. Each must expose
            ``ts`` (epoch milliseconds), ``open``, ``high``, ``low``,
            ``close`` and ``volume``. Order is preserved as given.
        config: Currently unused; accepted so the signature stays stable.

    Returns:
        A frame with ``Open``/``High``/``Low``/``Close``/``Volume`` columns,
        indexed by a UTC ``DatetimeIndex`` named ``ts``. An empty ``candles``
        list yields an empty frame with the same columns and index name,
        instead of the ``KeyError: 'ts'`` the row-wise construction raised.
    """
    # Build column-wise with an explicit index so the empty case keeps its shape.
    index = pd.DatetimeIndex(
        pd.to_datetime([candle.ts for candle in candles], unit="ms", utc=True),
        name="ts",
    )
    return pd.DataFrame(
        {
            "Open": [candle.open for candle in candles],
            "High": [candle.high for candle in candles],
            "Low": [candle.low for candle in candles],
            "Close": [candle.close for candle in candles],
            "Volume": [candle.volume for candle in candles],
        },
        index=index,
    )
def run_backtesting_rsi2(
    *,
    candles: list[Candle],
    leverage: int,
    warmup_bars: int,
    config: RSI2Config,
) -> tuple[pd.Series, pd.DataFrame]:
    """Replay ``candles`` through :class:`BacktestingRsi2` with ``backtesting.py``.

    Args:
        candles: Bars to replay, converted by ``candles_to_backtesting_frame``.
        leverage: Account leverage; passed to the engine as ``margin = 1 / leverage``.
        warmup_bars: Bars the strategy skips before evaluating signals.
        config: Source of the starting cash and the strategy parameters.

    Returns:
        ``(stats, trades)``: the Series returned by ``run`` and its
        ``"_trades"`` entry.
    """
    ohlcv = candles_to_backtesting_frame(candles, config)
    engine = FractionalBacktest(
        ohlcv,
        BacktestingRsi2,
        cash=config.initial_equity,
        margin=1 / leverage,  # backtesting.py models leverage as a margin ratio
        trade_on_close=False,  # market orders fill on the following bar
        exclusive_orders=True,  # at most one trade open at any time
        finalize_trades=False,  # trades still open at the end are not force-closed
    )
    strategy_params = {
        "trend_sma": config.trend_sma,
        "rsi_long_threshold": config.rsi_long_threshold,
        "rsi_short_threshold": config.rsi_short_threshold,
        "exit_rsi": config.exit_rsi,
        "rsi_length": config.rsi_length,
        "warmup_bars": warmup_bars,
    }
    result = engine.run(**strategy_params)
    trade_log = result["_trades"]
    return result, trade_log
def validate_rsi2_with_backtesting(
    *,
    candles: list[Candle],
    symbol: str,
    bar: str,
    leverage: int,
    config: RSI2Config,
) -> ExternalValidation:
    """Run the internal RSI2 backtest and the backtesting.py port on the same data.

    Both engines receive identical candles, leverage, config and warm-up length.
    Their headline metrics are returned together with internal-minus-external
    differences.
    """
    # Warm-up covers both the trend SMA window and rsi_length + 1 bars.
    warmup = max(config.trend_sma, config.rsi_length + 1)

    internal = run_rsi2_segment(candles=candles, leverage=leverage, warmup_bars=warmup, config=config)
    ext_stats, ext_trades = run_backtesting_rsi2(
        candles=candles,
        leverage=leverage,
        warmup_bars=warmup,
        config=config,
    )

    start_equity = config.initial_equity
    int_final = float(internal.equity_curve[-1]["equity"])
    ext_final = float(ext_stats["Equity Final [$]"])
    int_ret = int_final / start_equity - 1.0
    ext_ret = ext_final / start_equity - 1.0

    ext_equity_points = ext_stats["_equity_curve"]["Equity"].tolist()
    ext_dd = max_drawdown_from_equity([float(point) for point in ext_equity_points])
    int_dd = internal.max_drawdown

    return ExternalValidation(
        symbol=symbol,
        bar=bar,
        rows=len(candles),
        internal_trades=internal.trade_count,
        external_trades=len(ext_trades),
        internal_final_equity=int_final,
        external_final_equity=ext_final,
        internal_return=int_ret,
        external_return=ext_ret,
        internal_max_drawdown=int_dd,
        external_max_drawdown=ext_dd,
        final_equity_diff=int_final - ext_final,
        return_diff=int_ret - ext_ret,
        max_drawdown_diff=int_dd - ext_dd,
    )
def load_validation_candles(symbol: str, bar: str, limit: int, cache_dir: Path) -> list[Candle]:
    """Return at most ``limit`` candles for ``symbol``/``bar``.

    Uses the candles already cached in ``cache_dir`` when there are any, keeping
    the last ``limit`` of them. A non-empty cache is used as-is even if it holds
    fewer than ``limit`` candles. Only when the cache is empty are candles
    fetched through ``get_candles_cached`` with a new ``OkxClient``.

    Raises:
        ValueError: if ``limit`` is not positive. The old slice returned the
            *whole* cache for ``limit=0`` (``cached[-0:]``) and silently dropped
            the oldest candles for negative values.
    """
    if limit <= 0:
        raise ValueError(f"limit must be a positive integer, got {limit!r}")
    cached, _ = load_cached_candles(cache_dir, symbol, bar)
    if cached:
        # Keeps the tail; presumably the cache is ordered oldest -> newest (TODO confirm).
        return cached[-limit:]
    return get_candles_cached(OkxClient(), symbol, bar, limit, cache_dir)
def main() -> int:
    """Command-line entry point for the RSI2 external validation.

    Parses the options, loads candles (cache first, OKX otherwise), runs both
    backtest engines and prints the comparison as a one-row table.

    Returns:
        0 on success, 1 when no candles are available. Invalid options make
        ``argparse`` exit with status 2.
    """
    parser = argparse.ArgumentParser(
        description="Compare the internal RSI2 backtest against a backtesting.py reimplementation."
    )
    parser.add_argument("--symbol", default="BTC-USDT-SWAP", help="OKX instrument id")
    parser.add_argument("--bar", default="15m", help="candle interval")
    parser.add_argument("--limit", type=int, default=50_000, help="maximum number of candles to replay")
    parser.add_argument("--cache-dir", type=Path, default=CANDLE_CACHE_DIR, help="candle cache directory")
    parser.add_argument("--leverage", type=int, default=LEVERAGE, help="account leverage (>= 1)")
    parser.add_argument("--trend-sma", type=int, default=50, help="trend filter SMA window")
    parser.add_argument("--rsi-long-threshold", type=float, default=10.0, help="enter long when RSI <= this")
    parser.add_argument("--rsi-short-threshold", type=float, default=90.0, help="enter short when RSI >= this")
    parser.add_argument("--exit-rsi", type=float, default=50.0, help="RSI level that closes a position")
    args = parser.parse_args()
    # Fail fast on values that would otherwise break deep inside the run:
    # a non-positive limit mis-slices the candle cache, and leverage < 1 makes
    # margin = 1 / leverage divide by zero or fall outside (0, 1].
    if args.limit <= 0:
        parser.error(f"--limit must be a positive integer, got {args.limit}")
    if args.leverage < 1:
        parser.error(f"--leverage must be >= 1, got {args.leverage}")

    config = RSI2Config(
        trend_sma=args.trend_sma,
        rsi_long_threshold=args.rsi_long_threshold,
        rsi_short_threshold=args.rsi_short_threshold,
        exit_rsi=args.exit_rsi,
        initial_equity=INITIAL_EQUITY,
    )
    candles = load_validation_candles(args.symbol, args.bar, args.limit, args.cache_dir)
    if not candles:
        print(f"no candles available for {args.symbol} {args.bar}", file=sys.stderr)
        return 1
    validation = validate_rsi2_with_backtesting(
        candles=candles,
        symbol=args.symbol,
        bar=args.bar,
        leverage=args.leverage,
        config=config,
    )
    print(pd.DataFrame([validation.__dict__]).to_string(index=False))
    return 0
# Script entry point: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|