from __future__ import annotations

import math
from dataclasses import dataclass
from pathlib import Path

import pandas as pd

from okx_codex_trader.models import Candle
from okx_codex_trader.sampled_report import (
    SegmentResult,
    generate_sampled_report,
    mark_to_market as _mark_to_market,
    trade_equity as _trade_equity,
)

RSI2_STRATEGY_DESCRIPTION = (
    "Trend-filtered RSI2 mean reversion, close-vs-SMA long and short entries, "
    "RSI reversion exits at next open."
)


@dataclass(frozen=True)
class RSI2Config:
    """Tunable parameters for the trend-filtered RSI2 strategy."""

    # Window of the simple moving average of closes used as the trend filter.
    trend_sma: int = 50
    # Lookback of the Wilder-smoothed RSI (must be >= 1).
    rsi_length: int = 2
    # Go long when RSI <= this value and the close is above the trend SMA.
    rsi_long_threshold: float = 10.0
    # Go short when RSI >= this value and the close is below the trend SMA.
    rsi_short_threshold: float = 90.0
    # Longs exit once RSI >= this value; shorts exit once RSI <= this value.
    exit_rsi: float = 50.0
    # Starting equity of every simulated segment.
    initial_equity: float = 10_000.0


def _format_ts(ts: int) -> str:
    """Format a millisecond epoch timestamp as a UTC ``YYYY-MM-DD HH:MM`` string."""
    return pd.to_datetime(ts, unit="ms", utc=True).strftime("%Y-%m-%d %H:%M")


def _compute_rsi(closes: pd.Series, length: int) -> list[float]:
    """Return Wilder-smoothed RSI values aligned index-for-index with *closes*.

    The running averages are seeded with the simple mean of the first
    ``length`` gains/losses and then smoothed as
    ``(previous * (length - 1) + current) / length``. Entries before index
    ``length`` are NaN. Once a NaN enters the running averages (e.g. from a
    NaN close), every later value is NaN as well.

    When the average loss is zero the RSI is 100 if there were gains, and 50
    if both averages are zero (flat prices).

    Raises:
        ValueError: if ``length`` is less than 1.
    """
    if length < 1:
        # length == 0 would otherwise divide by zero in the smoothing step,
        # and negative lengths would silently slice nonsense seed windows.
        raise ValueError(f"rsi length must be >= 1, got {length}")

    rsi: list[float] = [math.nan] * len(closes)
    if len(closes) <= length:
        return rsi

    deltas = closes.diff()
    gains = deltas.clip(lower=0.0)
    losses = -deltas.clip(upper=0.0)

    # Index 0 of ``diff`` is NaN, so the seed window is deltas 1..length.
    # ``Series.mean`` skips NaN, which the seed relies on.
    average_gain = float(gains.iloc[1 : length + 1].mean())
    average_loss = float(losses.iloc[1 : length + 1].mean())

    # Plain lists avoid a pandas ``iloc`` lookup on every iteration.
    gain_values = gains.tolist()
    loss_values = losses.tolist()

    for index in range(length, len(closes)):
        if index > length:
            average_gain = (average_gain * (length - 1) + gain_values[index]) / length
            average_loss = (average_loss * (length - 1) + loss_values[index]) / length
        if math.isnan(average_gain) or math.isnan(average_loss):
            continue  # rsi[index] stays NaN
        if average_loss == 0.0:
            rsi[index] = 100.0 if average_gain > 0.0 else 50.0
            continue
        relative_strength = average_gain / average_loss
        rsi[index] = 100.0 - (100.0 / (1.0 + relative_strength))
    return rsi


def _trade_record(
    position: dict[str, object],
    *,
    exit_ts: int,
    exit_price: float,
    exit_equity: float,
) -> dict[str, object]:
    """Build the report row for *position* closed at ``exit_price`` / ``exit_ts``."""
    margin_used = float(position["margin_used"])
    pnl = exit_equity - margin_used
    return {
        "side": "Long" if position["side"] == "long" else "Short",
        "entry_time": _format_ts(int(position["entry_time"])),
        "exit_time": _format_ts(exit_ts),
        "entry_price": round(float(position["entry_price"]), 4),
        "exit_price": round(exit_price, 4),
        "pnl": round(pnl, 4),
        "return_pct": round(pnl / margin_used * 100, 4),
    }


def _is_exit_signal(side: str, rsi_value: float, config: RSI2Config) -> bool:
    """Return True when RSI has reverted far enough to close a *side* position."""
    if side == "long":
        return rsi_value >= config.exit_rsi
    if side == "short":
        return rsi_value <= config.exit_rsi
    return False


def _entry_signal(
    close: float, trend: float, rsi_value: float, config: RSI2Config
) -> str | None:
    """Return ``"long"``, ``"short"`` or None for the bar with this close/SMA/RSI."""
    if close > trend and rsi_value <= config.rsi_long_threshold:
        return "long"
    if close < trend and rsi_value >= config.rsi_short_threshold:
        return "short"
    return None


def run_rsi2_segment(
    *,
    candles: list[Candle],
    leverage: int,
    warmup_bars: int,
    config: RSI2Config = RSI2Config(),
) -> SegmentResult:
    """Backtest the RSI2 strategy over ``candles[warmup_bars:]``.

    Indicators are computed over all of ``candles``, so the first
    ``warmup_bars`` candles only seed the SMA and RSI.

    Signals are evaluated on each bar's close and filled at the next bar's
    open. No new signal is taken on the final bar, because there is no next
    open to fill at. Each entry commits the entire realized equity as margin.
    Equity is marked to market at every close to build the equity curve and
    the maximum drawdown.

    Args:
        candles: Price history including the warmup prefix.
        leverage: Forwarded to ``trade_equity`` / ``mark_to_market``.
        warmup_bars: Number of leading candles used only for indicator warmup.
        config: Strategy parameters.

    Returns:
        A ``SegmentResult``. ``total_return`` includes the mark-to-market
        value of any position still open on the last bar; that position is
        returned as ``open_position``.

    Raises:
        ValueError: if ``config.rsi_length`` is less than 1.
    """
    closes = pd.Series([candle.close for candle in candles], dtype=float)
    trend = closes.rolling(config.trend_sma).mean().tolist()
    rsi = _compute_rsi(closes, config.rsi_length)
    last_index = len(candles) - 1

    # Realized equity: it only changes when a trade is closed.
    equity = config.initial_equity
    ending_equity = equity
    peak_equity = equity
    max_drawdown = 0.0
    wins = 0
    trades: list[dict[str, object]] = []
    entries: list[dict[str, object]] = []
    exits: list[dict[str, object]] = []
    equity_curve: list[dict[str, float | int]] = []
    position: dict[str, object] | None = None
    # Orders signalled at a close; each is valid only for the very next open.
    pending_entry_side: str | None = None
    pending_exit = False

    for index in range(warmup_bars, len(candles)):
        candle = candles[index]

        # 1) Fill an exit signalled at the previous close at this bar's open.
        if pending_exit and position is not None:
            exit_price = candle.open
            margin_used = float(position["margin_used"])
            exit_equity = _trade_equity(
                side=str(position["side"]),
                margin_used=margin_used,
                entry_price=float(position["entry_price"]),
                exit_price=exit_price,
                leverage=leverage,
            )
            trades.append(
                _trade_record(
                    position,
                    exit_ts=candle.ts,
                    exit_price=exit_price,
                    exit_equity=exit_equity,
                )
            )
            exits.append({"ts": candle.ts, "price": exit_price, "side": position["side"]})
            if exit_equity > margin_used:
                wins += 1
            equity = exit_equity
            position = None
        pending_exit = False

        # 2) Fill an entry signalled at the previous close at this bar's open.
        if pending_entry_side is not None and position is None and equity > 0.0:
            position = {
                "side": pending_entry_side,
                "entry_time": candle.ts,
                "entry_price": candle.open,
                "margin_used": equity,
            }
            entries.append({"ts": candle.ts, "price": candle.open, "side": pending_entry_side})
        pending_entry_side = None

        # 3) Mark to market at the close for the equity curve and drawdown.
        current_equity = equity
        if position is not None:
            current_equity = _mark_to_market(
                side=str(position["side"]),
                margin_used=float(position["margin_used"]),
                entry_price=float(position["entry_price"]),
                mark_price=candle.close,
                leverage=leverage,
            )
        if current_equity > peak_equity:
            peak_equity = current_equity
        max_drawdown = max(max_drawdown, (peak_equity - current_equity) / peak_equity)
        equity_curve.append({"ts": candle.ts, "equity": current_equity, "close": candle.close})
        ending_equity = current_equity

        # 4) Evaluate signals on this close for the next open.
        if index == last_index or equity <= 0.0:
            continue
        current_rsi = rsi[index]
        current_trend = trend[index]
        if math.isnan(current_rsi) or math.isnan(current_trend):
            continue  # indicators still warming up
        if position is not None:
            # While in a trade only the exit rule is evaluated. This guarantees
            # no stale entry signal can fire on the bar the position is closed.
            if _is_exit_signal(str(position["side"]), current_rsi, config):
                pending_exit = True
            continue
        pending_entry_side = _entry_signal(
            candle.close, float(current_trend), current_rsi, config
        )

    trade_count = len(trades)
    return SegmentResult(
        trade_count=trade_count,
        total_return=(ending_equity - config.initial_equity) / config.initial_equity,
        win_rate=(wins / trade_count) if trade_count else 0.0,
        max_drawdown=max_drawdown,
        trades=trades,
        open_position=position,
        candles=candles[warmup_bars:],
        equity_curve=equity_curve,
        entries=entries,
        exits=exits,
    )


def generate_rsi2_sampled_report(
    *,
    candles: list[Candle],
    leverage: int,
    output_file: Path,
    symbol: str,
    bar: str,
    segments: int,
    window_size: int,
    trend_sma: int = RSI2Config.trend_sma,
    rsi_length: int = RSI2Config.rsi_length,
    rsi_long_threshold: float = RSI2Config.rsi_long_threshold,
    rsi_short_threshold: float = RSI2Config.rsi_short_threshold,
    exit_rsi: float = RSI2Config.exit_rsi,
) -> dict[str, object]:
    """Build an RSI2 sampled report via ``generate_sampled_report``.

    The strategy parameters are bundled into an ``RSI2Config``, and
    ``run_rsi2_segment`` is supplied as the per-segment backtest. Sampling,
    output format and the contents of the returned dict are defined by
    ``generate_sampled_report`` (not visible here).
    """
    config = RSI2Config(
        trend_sma=trend_sma,
        rsi_length=rsi_length,
        rsi_long_threshold=rsi_long_threshold,
        rsi_short_threshold=rsi_short_threshold,
        exit_rsi=exit_rsi,
    )
    return generate_sampled_report(
        candles=candles,
        leverage=leverage,
        output_file=output_file,
        symbol=symbol,
        bar=bar,
        segments=segments,
        window_size=window_size,
        report_title="RSI2 Sampled Report",
        strategy_label="RSI2",
        strategy_description=RSI2_STRATEGY_DESCRIPTION,
        strategy_params={
            "trend_sma": config.trend_sma,
            "rsi_length": config.rsi_length,
            "rsi_long_threshold": config.rsi_long_threshold,
            "rsi_short_threshold": config.rsi_short_threshold,
            "exit_rsi": config.exit_rsi,
        },
        run_segment=lambda *, candles, leverage, warmup_bars: run_rsi2_segment(
            candles=candles,
            leverage=leverage,
            warmup_bars=warmup_bars,
            config=config,
        ),
        # The SMA's first valid value is at index trend_sma - 1 and the RSI's
        # at index rsi_length. This warmup therefore leaves one bar of margin,
        # presumably so the first simulated bar already has both indicators
        # (assumes generate_sampled_report forwards this value unchanged).
        warmup_bars=max(config.trend_sma, config.rsi_length + 1),
    )