validate_external_backtest.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. from __future__ import annotations
  2. import argparse
  3. import sys
  4. from dataclasses import dataclass
  5. from pathlib import Path
  6. import pandas as pd
  7. from backtesting import Strategy
  8. from backtesting.lib import FractionalBacktest
  9. ROOT_DIR = Path(__file__).resolve().parents[1]
  10. if str(ROOT_DIR) not in sys.path:
  11. sys.path.insert(0, str(ROOT_DIR))
  12. from okx_codex_trader.models import Candle
  13. from okx_codex_trader.okx_client import OkxClient
  14. from okx_codex_trader.rsi2_report import RSI2Config, _compute_rsi, run_rsi2_segment
  15. from scripts.explore_ultrashort import (
  16. CANDLE_CACHE_DIR,
  17. INITIAL_EQUITY,
  18. LEVERAGE,
  19. get_candles_cached,
  20. load_cached_candles,
  21. max_drawdown_from_equity,
  22. )
  23. @dataclass(frozen=True)
  24. class ExternalValidation:
  25. symbol: str
  26. bar: str
  27. rows: int
  28. internal_trades: int
  29. external_trades: int
  30. internal_final_equity: float
  31. external_final_equity: float
  32. internal_return: float
  33. external_return: float
  34. internal_max_drawdown: float
  35. external_max_drawdown: float
  36. final_equity_diff: float
  37. return_diff: float
  38. max_drawdown_diff: float
  39. class BacktestingRsi2(Strategy):
  40. trend_sma = 50
  41. rsi_length = 2
  42. rsi_long_threshold = 10.0
  43. rsi_short_threshold = 90.0
  44. exit_rsi = 50.0
  45. warmup_bars = 50
  46. def init(self) -> None:
  47. self.rsi = self.I(
  48. lambda close: _compute_rsi(pd.Series(close, dtype=float), self.rsi_length),
  49. self.data.Close,
  50. overlay=False,
  51. )
  52. self.trend = self.I(
  53. lambda close: pd.Series(close, dtype=float).rolling(self.trend_sma).mean().to_numpy(),
  54. self.data.Close,
  55. overlay=False,
  56. )
  57. def next(self) -> None:
  58. index = len(self.data) - 1
  59. if index < self.warmup_bars:
  60. return
  61. rsi = self.rsi[-1]
  62. trend = self.trend[-1]
  63. if pd.isna(rsi) or pd.isna(trend):
  64. return
  65. if self.position:
  66. if (self.position.is_long and rsi >= self.exit_rsi) or (
  67. self.position.is_short and rsi <= self.exit_rsi
  68. ):
  69. self.position.close()
  70. return
  71. close = self.data.Close[-1]
  72. if close > trend and rsi <= self.rsi_long_threshold:
  73. self.buy(size=0.999999999)
  74. elif close < trend and rsi >= self.rsi_short_threshold:
  75. self.sell(size=0.999999999)
  76. def candles_to_backtesting_frame(candles: list[Candle], config: RSI2Config) -> pd.DataFrame:
  77. frame = pd.DataFrame(
  78. [
  79. {
  80. "ts": pd.to_datetime(candle.ts, unit="ms", utc=True),
  81. "Open": candle.open,
  82. "High": candle.high,
  83. "Low": candle.low,
  84. "Close": candle.close,
  85. "Volume": candle.volume,
  86. }
  87. for candle in candles
  88. ]
  89. ).set_index("ts")
  90. return frame
  91. def run_backtesting_rsi2(
  92. *,
  93. candles: list[Candle],
  94. leverage: int,
  95. warmup_bars: int,
  96. config: RSI2Config,
  97. ) -> tuple[pd.Series, pd.DataFrame]:
  98. frame = candles_to_backtesting_frame(candles, config)
  99. backtest = FractionalBacktest(
  100. frame,
  101. BacktestingRsi2,
  102. cash=config.initial_equity,
  103. margin=1 / leverage,
  104. trade_on_close=False,
  105. exclusive_orders=True,
  106. finalize_trades=False,
  107. )
  108. stats = backtest.run(
  109. trend_sma=config.trend_sma,
  110. rsi_long_threshold=config.rsi_long_threshold,
  111. rsi_short_threshold=config.rsi_short_threshold,
  112. exit_rsi=config.exit_rsi,
  113. rsi_length=config.rsi_length,
  114. warmup_bars=warmup_bars,
  115. )
  116. return stats, stats["_trades"]
  117. def validate_rsi2_with_backtesting(
  118. *,
  119. candles: list[Candle],
  120. symbol: str,
  121. bar: str,
  122. leverage: int,
  123. config: RSI2Config,
  124. ) -> ExternalValidation:
  125. warmup_bars = max(config.trend_sma, config.rsi_length + 1)
  126. internal = run_rsi2_segment(candles=candles, leverage=leverage, warmup_bars=warmup_bars, config=config)
  127. external, external_trades = run_backtesting_rsi2(
  128. candles=candles,
  129. leverage=leverage,
  130. warmup_bars=warmup_bars,
  131. config=config,
  132. )
  133. internal_final_equity = float(internal.equity_curve[-1]["equity"])
  134. external_final_equity = float(external["Equity Final [$]"])
  135. internal_return = internal_final_equity / config.initial_equity - 1.0
  136. external_return = external_final_equity / config.initial_equity - 1.0
  137. external_curve = external["_equity_curve"]["Equity"].tolist()
  138. external_max_drawdown = max_drawdown_from_equity([float(value) for value in external_curve])
  139. return ExternalValidation(
  140. symbol=symbol,
  141. bar=bar,
  142. rows=len(candles),
  143. internal_trades=internal.trade_count,
  144. external_trades=len(external_trades),
  145. internal_final_equity=internal_final_equity,
  146. external_final_equity=external_final_equity,
  147. internal_return=internal_return,
  148. external_return=external_return,
  149. internal_max_drawdown=internal.max_drawdown,
  150. external_max_drawdown=external_max_drawdown,
  151. final_equity_diff=internal_final_equity - external_final_equity,
  152. return_diff=internal_return - external_return,
  153. max_drawdown_diff=internal.max_drawdown - external_max_drawdown,
  154. )
  155. def load_validation_candles(symbol: str, bar: str, limit: int, cache_dir: Path) -> list[Candle]:
  156. cached, _ = load_cached_candles(cache_dir, symbol, bar)
  157. if cached:
  158. return cached[-limit:] if len(cached) > limit else cached
  159. return get_candles_cached(OkxClient(), symbol, bar, limit, cache_dir)
  160. def main() -> int:
  161. parser = argparse.ArgumentParser()
  162. parser.add_argument("--symbol", default="BTC-USDT-SWAP")
  163. parser.add_argument("--bar", default="15m")
  164. parser.add_argument("--limit", type=int, default=50_000)
  165. parser.add_argument("--cache-dir", type=Path, default=CANDLE_CACHE_DIR)
  166. parser.add_argument("--leverage", type=int, default=LEVERAGE)
  167. parser.add_argument("--trend-sma", type=int, default=50)
  168. parser.add_argument("--rsi-long-threshold", type=float, default=10.0)
  169. parser.add_argument("--rsi-short-threshold", type=float, default=90.0)
  170. parser.add_argument("--exit-rsi", type=float, default=50.0)
  171. args = parser.parse_args()
  172. config = RSI2Config(
  173. trend_sma=args.trend_sma,
  174. rsi_long_threshold=args.rsi_long_threshold,
  175. rsi_short_threshold=args.rsi_short_threshold,
  176. exit_rsi=args.exit_rsi,
  177. initial_equity=INITIAL_EQUITY,
  178. )
  179. candles = load_validation_candles(args.symbol, args.bar, args.limit, args.cache_dir)
  180. validation = validate_rsi2_with_backtesting(
  181. candles=candles,
  182. symbol=args.symbol,
  183. bar=args.bar,
  184. leverage=args.leverage,
  185. config=config,
  186. )
  187. print(pd.DataFrame([validation.__dict__]).to_string(index=False))
  188. return 0
  189. if __name__ == "__main__":
  190. raise SystemExit(main())