# rsi2_report.py
from __future__ import annotations

import math
from dataclasses import dataclass
from pathlib import Path

import pandas as pd

from okx_codex_trader.models import Candle
from okx_codex_trader.sampled_report import (
    SegmentResult,
    generate_sampled_report,
    mark_to_market as _mark_to_market,
    trade_equity as _trade_equity,
)
  12. RSI2_STRATEGY_DESCRIPTION = (
  13. "Trend-filtered RSI2 mean reversion, close-vs-SMA long and short entries, "
  14. "RSI reversion exits at next open."
  15. )
  16. @dataclass(frozen=True)
  17. class RSI2Config:
  18. trend_sma: int = 50
  19. rsi_length: int = 2
  20. rsi_long_threshold: float = 10.0
  21. rsi_short_threshold: float = 90.0
  22. exit_rsi: float = 50.0
  23. initial_equity: float = 10_000.0
  24. def _format_ts(ts: int) -> str:
  25. return pd.to_datetime(ts, unit="ms", utc=True).strftime("%Y-%m-%d %H:%M")
  26. def _compute_rsi(closes: pd.Series, length: int) -> list[float]:
  27. deltas = closes.diff()
  28. gains = deltas.clip(lower=0.0)
  29. losses = -deltas.clip(upper=0.0)
  30. rsi = [float("nan")] * len(closes)
  31. if len(closes) <= length:
  32. return rsi
  33. average_gain = float(gains.iloc[1 : length + 1].mean())
  34. average_loss = float(losses.iloc[1 : length + 1].mean())
  35. for index in range(length, len(closes)):
  36. if index > length:
  37. average_gain = ((average_gain * (length - 1)) + float(gains.iloc[index])) / length
  38. average_loss = ((average_loss * (length - 1)) + float(losses.iloc[index])) / length
  39. if average_gain != average_gain or average_loss != average_loss:
  40. rsi[index] = float("nan")
  41. continue
  42. if average_loss == 0.0:
  43. rsi[index] = 100.0 if average_gain > 0.0 else 50.0
  44. continue
  45. relative_strength = average_gain / average_loss
  46. rsi[index] = 100.0 - (100.0 / (1.0 + relative_strength))
  47. return rsi
  48. def run_rsi2_segment(
  49. *,
  50. candles: list[Candle],
  51. leverage: int,
  52. warmup_bars: int,
  53. config: RSI2Config = RSI2Config(),
  54. ) -> SegmentResult:
  55. closes = pd.Series([candle.close for candle in candles], dtype=float)
  56. trend = closes.rolling(config.trend_sma).mean().tolist()
  57. rsi = _compute_rsi(closes, config.rsi_length)
  58. equity = config.initial_equity
  59. ending_equity = equity
  60. peak_equity = equity
  61. max_drawdown = 0.0
  62. wins = 0
  63. trades: list[dict[str, object]] = []
  64. entries: list[dict[str, object]] = []
  65. exits: list[dict[str, object]] = []
  66. equity_curve: list[dict[str, float | int]] = []
  67. position: dict[str, object] | None = None
  68. pending_entry_side: str | None = None
  69. pending_exit = False
  70. for index in range(warmup_bars, len(candles)):
  71. candle = candles[index]
  72. if pending_exit and position is not None:
  73. exit_price = candle.open
  74. exit_equity = _trade_equity(
  75. side=str(position["side"]),
  76. margin_used=float(position["margin_used"]),
  77. entry_price=float(position["entry_price"]),
  78. exit_price=exit_price,
  79. leverage=leverage,
  80. )
  81. trades.append(
  82. {
  83. "side": "Long" if position["side"] == "long" else "Short",
  84. "entry_time": _format_ts(int(position["entry_time"])),
  85. "exit_time": _format_ts(candle.ts),
  86. "entry_price": round(float(position["entry_price"]), 4),
  87. "exit_price": round(exit_price, 4),
  88. "pnl": round(exit_equity - float(position["margin_used"]), 4),
  89. "return_pct": round(
  90. (exit_equity - float(position["margin_used"])) / float(position["margin_used"]) * 100,
  91. 4,
  92. ),
  93. }
  94. )
  95. exits.append({"ts": candle.ts, "price": exit_price, "side": position["side"]})
  96. if exit_equity > float(position["margin_used"]):
  97. wins += 1
  98. equity = exit_equity
  99. position = None
  100. pending_exit = False
  101. if pending_entry_side is not None and position is None and equity > 0.0:
  102. position = {
  103. "side": pending_entry_side,
  104. "entry_time": candle.ts,
  105. "entry_price": candle.open,
  106. "margin_used": equity,
  107. }
  108. entries.append({"ts": candle.ts, "price": candle.open, "side": pending_entry_side})
  109. pending_entry_side = None
  110. current_equity = equity
  111. if position is not None:
  112. current_equity = _mark_to_market(
  113. side=str(position["side"]),
  114. margin_used=float(position["margin_used"]),
  115. entry_price=float(position["entry_price"]),
  116. mark_price=candle.close,
  117. leverage=leverage,
  118. )
  119. if current_equity > peak_equity:
  120. peak_equity = current_equity
  121. max_drawdown = max(max_drawdown, (peak_equity - current_equity) / peak_equity)
  122. equity_curve.append({"ts": candle.ts, "equity": current_equity, "close": candle.close})
  123. ending_equity = current_equity
  124. if index == len(candles) - 1 or equity <= 0.0:
  125. continue
  126. current_rsi = rsi[index]
  127. current_trend = trend[index]
  128. if current_rsi != current_rsi or current_trend != current_trend:
  129. continue
  130. if position is not None:
  131. exit_signal = (
  132. position["side"] == "long" and current_rsi >= config.exit_rsi
  133. ) or (
  134. position["side"] == "short" and current_rsi <= config.exit_rsi
  135. )
  136. if exit_signal:
  137. pending_exit = True
  138. continue
  139. if candle.close > float(current_trend) and current_rsi <= config.rsi_long_threshold:
  140. pending_entry_side = "long"
  141. elif candle.close < float(current_trend) and current_rsi >= config.rsi_short_threshold:
  142. pending_entry_side = "short"
  143. trade_count = len(trades)
  144. return SegmentResult(
  145. trade_count=trade_count,
  146. total_return=(ending_equity - config.initial_equity) / config.initial_equity,
  147. win_rate=(wins / trade_count) if trade_count else 0.0,
  148. max_drawdown=max_drawdown,
  149. trades=trades,
  150. open_position=position,
  151. candles=candles[warmup_bars:],
  152. equity_curve=equity_curve,
  153. entries=entries,
  154. exits=exits,
  155. )
  156. def generate_rsi2_sampled_report(
  157. *,
  158. candles: list[Candle],
  159. leverage: int,
  160. output_file: Path,
  161. symbol: str,
  162. bar: str,
  163. segments: int,
  164. window_size: int,
  165. trend_sma: int = RSI2Config.trend_sma,
  166. rsi_length: int = RSI2Config.rsi_length,
  167. rsi_long_threshold: float = RSI2Config.rsi_long_threshold,
  168. rsi_short_threshold: float = RSI2Config.rsi_short_threshold,
  169. exit_rsi: float = RSI2Config.exit_rsi,
  170. ) -> dict[str, object]:
  171. config = RSI2Config(
  172. trend_sma=trend_sma,
  173. rsi_length=rsi_length,
  174. rsi_long_threshold=rsi_long_threshold,
  175. rsi_short_threshold=rsi_short_threshold,
  176. exit_rsi=exit_rsi,
  177. )
  178. return generate_sampled_report(
  179. candles=candles,
  180. leverage=leverage,
  181. output_file=output_file,
  182. symbol=symbol,
  183. bar=bar,
  184. segments=segments,
  185. window_size=window_size,
  186. report_title="RSI2 Sampled Report",
  187. strategy_label="RSI2",
  188. strategy_description=RSI2_STRATEGY_DESCRIPTION,
  189. strategy_params={
  190. "trend_sma": config.trend_sma,
  191. "rsi_length": config.rsi_length,
  192. "rsi_long_threshold": config.rsi_long_threshold,
  193. "rsi_short_threshold": config.rsi_short_threshold,
  194. "exit_rsi": config.exit_rsi,
  195. },
  196. run_segment=lambda *, candles, leverage, warmup_bars: run_rsi2_segment(
  197. candles=candles,
  198. leverage=leverage,
  199. warmup_bars=warmup_bars,
  200. config=config,
  201. ),
  202. warmup_bars=max(config.trend_sma, config.rsi_length + 1),
  203. )