| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251 |
- from __future__ import annotations
- from dataclasses import dataclass
- from pathlib import Path
- import pandas as pd
- from okx_codex_trader.models import Candle
- from okx_codex_trader.sampled_report import (
- SegmentResult,
- generate_sampled_report,
- mark_to_market as _mark_to_market,
- trade_equity as _trade_equity,
- )
# One-line human-readable strategy summary; passed to the shared sampled
# report generator as ``strategy_description``.
EMA_PULLBACK_STRATEGY_DESCRIPTION = (
    "EMA pullback reclaim, fast-over-slow trend bias, "
    "next-open continuation entries after a fast-EMA reclaim, "
    "opposite close-through-fast-EMA exits, "
    "stop beyond the signal candle."
)
@dataclass(frozen=True)
class EMAPullbackConfig:
    """Immutable parameter set for the EMA pullback strategy.

    Fields are declared in constructor order, so positional construction
    ``EMAPullbackConfig(fast, slow, buffer, equity)`` is supported.
    """

    # Span of the fast EMA; entries require a reclaim of it and exits trigger
    # on a close back through it.
    fast_ema: int = 20
    # Span of the slow EMA; fast above slow gives a long bias, fast below a
    # short bias.
    slow_ema: int = 50
    # Fractional distance placing the stop beyond the signal candle's extreme
    # (0.005 puts a long's stop 0.5% below the signal low, a short's 0.5%
    # above the signal high).
    stop_buffer_pct: float = 0.005
    # Starting equity for each simulated segment.
    initial_equity: float = 10000.0
- def _format_ts(ts: int) -> str:
- return pd.to_datetime(ts, unit="ms", utc=True).strftime("%Y-%m-%d %H:%M")
def _stop_fill_price(position: dict[str, object], candle: Candle) -> float | None:
    """Return the stop fill price if *candle* touches the position's stop, else ``None``.

    A long is stopped when the bar's low reaches the stop, a short when the
    bar's high does. If the bar already opens beyond the stop (a gap), the
    fill is the open rather than the stop level.
    """
    side = position["side"]
    stop_price = float(position["stop_price"])
    if side == "long" and candle.low <= stop_price:
        return candle.open if candle.open < stop_price else stop_price
    if side == "short" and candle.high >= stop_price:
        return candle.open if candle.open > stop_price else stop_price
    return None


def _close_position(
    *,
    position: dict[str, object],
    exit_price: float,
    exit_ts: int,
    leverage: int,
    trades: list[dict[str, object]],
    exits: list[dict[str, object]],
) -> tuple[float, bool]:
    """Realize *position* at *exit_price*, appending a record to *trades* and *exits*.

    Returns ``(exit_equity, is_win)``: ``exit_equity`` is what ``_trade_equity``
    reports for the closed position, and ``is_win`` is True when it exceeds the
    margin committed at entry.
    """
    margin_used = float(position["margin_used"])
    entry_price = float(position["entry_price"])
    exit_equity = _trade_equity(
        side=str(position["side"]),
        margin_used=margin_used,
        entry_price=entry_price,
        exit_price=exit_price,
        leverage=leverage,
    )
    pnl = exit_equity - margin_used
    trades.append(
        {
            "side": "Long" if position["side"] == "long" else "Short",
            "entry_time": _format_ts(int(position["entry_time"])),
            "exit_time": _format_ts(exit_ts),
            "entry_price": round(entry_price, 4),
            "exit_price": round(exit_price, 4),
            "pnl": round(pnl, 4),
            "return_pct": round(pnl / margin_used * 100, 4),
        }
    )
    exits.append({"ts": exit_ts, "price": exit_price, "side": position["side"]})
    return exit_equity, exit_equity > margin_used


def run_ema_pullback_segment(
    *,
    candles: list[Candle],
    leverage: int,
    warmup_bars: int,
    config: EMAPullbackConfig = EMAPullbackConfig(),
) -> SegmentResult:
    """Backtest the EMA pullback strategy on one candle window.

    Both EMAs are computed over all of ``candles`` (``ewm(adjust=False)`` on the
    closes), so the first ``warmup_bars`` bars only seed them; trading is
    simulated over ``candles[warmup_bars:]``.

    Signals are evaluated on a bar's close and filled at the next bar's open:

    * flat, fast EMA above slow EMA, bar low at/below the fast EMA and close
      above it -> queue a long, stop at ``low * (1 - stop_buffer_pct)``;
    * flat, fast EMA below slow EMA, bar high at/above the fast EMA and close
      below it -> queue a short, stop at ``high * (1 + stop_buffer_pct)``;
    * in a position, a close on the wrong side of the fast EMA -> queue an exit.

    Stops are checked intrabar on every bar a position is held, including the
    entry bar; a stop-out bar generates no new signal. Each entry commits all
    realized equity as margin. Trade and mark-to-market equity come from
    ``_trade_equity`` / ``_mark_to_market``; how ``leverage`` is applied is
    defined there.

    Args:
        candles: Full candle window, warmup bars included.
        leverage: Forwarded to the equity helpers.
        warmup_bars: Number of leading bars used only to seed the EMAs.
        config: Strategy parameters and starting equity.

    Returns:
        SegmentResult with the trade log, fractional total return, win rate,
        max drawdown (fraction of running peak on the per-bar equity curve),
        the equity curve, entry/exit markers, and any still-open position.
    """
    closes = pd.Series([candle.close for candle in candles], dtype=float)
    fast_ema = closes.ewm(span=config.fast_ema, adjust=False).mean().tolist()
    slow_ema = closes.ewm(span=config.slow_ema, adjust=False).mean().tolist()
    last_index = len(candles) - 1

    equity = config.initial_equity  # realized equity (closed trades only)
    ending_equity = equity
    peak_equity = equity
    max_drawdown = 0.0
    wins = 0
    trades: list[dict[str, object]] = []
    entries: list[dict[str, object]] = []
    exits: list[dict[str, object]] = []
    equity_curve: list[dict[str, float | int]] = []
    position: dict[str, object] | None = None
    pending_entry: dict[str, object] | None = None
    pending_exit = False

    for index in range(warmup_bars, len(candles)):
        candle = candles[index]

        # Orders queued at the previous bar's close fill at this bar's open.
        if pending_exit and position is not None:
            equity, won = _close_position(
                position=position,
                exit_price=candle.open,
                exit_ts=candle.ts,
                leverage=leverage,
                trades=trades,
                exits=exits,
            )
            if won:
                wins += 1
            position = None
            pending_exit = False
        if pending_entry is not None and position is None and equity > 0.0:
            side = str(pending_entry["side"])
            position = {
                "side": side,
                "entry_time": candle.ts,
                "entry_price": candle.open,
                "margin_used": equity,
                "stop_price": float(pending_entry["stop_price"]),
            }
            entries.append({"ts": candle.ts, "price": candle.open, "side": side})
            pending_entry = None

        # Intrabar stop check; if the stop survives, mark to market at the close.
        current_equity = equity
        stopped_out = False
        if position is not None:
            stop_fill = _stop_fill_price(position, candle)
            if stop_fill is not None:
                equity, won = _close_position(
                    position=position,
                    exit_price=stop_fill,
                    exit_ts=candle.ts,
                    leverage=leverage,
                    trades=trades,
                    exits=exits,
                )
                if won:
                    wins += 1
                current_equity = equity
                position = None
                stopped_out = True
            else:
                current_equity = _mark_to_market(
                    side=str(position["side"]),
                    margin_used=float(position["margin_used"]),
                    entry_price=float(position["entry_price"]),
                    mark_price=candle.close,
                    leverage=leverage,
                )

        if current_equity > peak_equity:
            peak_equity = current_equity
        max_drawdown = max(max_drawdown, (peak_equity - current_equity) / peak_equity)
        equity_curve.append({"ts": candle.ts, "equity": current_equity, "close": candle.close})
        ending_equity = current_equity

        # No new signals on a stop-out bar, on the final bar (no next open to
        # fill at), or once realized equity is exhausted.
        if stopped_out or index == last_index or equity <= 0.0:
            continue

        current_fast = fast_ema[index]
        current_slow = slow_ema[index]
        if pd.isna(current_fast) or pd.isna(current_slow):
            continue

        if position is not None:
            side = position["side"]
            closed_through_fast = (side == "long" and candle.close < current_fast) or (
                side == "short" and candle.close > current_fast
            )
            if closed_through_fast:
                pending_exit = True
            # While holding, only exits are considered.
            continue

        if current_fast > current_slow and candle.low <= current_fast and candle.close > current_fast:
            pending_entry = {
                "side": "long",
                "stop_price": candle.low * (1 - config.stop_buffer_pct),
            }
        elif current_fast < current_slow and candle.high >= current_fast and candle.close < current_fast:
            pending_entry = {
                "side": "short",
                "stop_price": candle.high * (1 + config.stop_buffer_pct),
            }

    trade_count = len(trades)
    return SegmentResult(
        trade_count=trade_count,
        total_return=(ending_equity - config.initial_equity) / config.initial_equity,
        win_rate=(wins / trade_count) if trade_count else 0.0,
        max_drawdown=max_drawdown,
        trades=trades,
        open_position=position,
        candles=candles[warmup_bars:],
        equity_curve=equity_curve,
        entries=entries,
        exits=exits,
    )
def generate_ema_pullback_sampled_report(
    *,
    candles: list[Candle],
    leverage: int,
    output_file: Path,
    symbol: str,
    bar: str,
    segments: int,
    window_size: int,
    fast_ema: int = EMAPullbackConfig.fast_ema,
    slow_ema: int = EMAPullbackConfig.slow_ema,
    stop_buffer_pct: float = EMAPullbackConfig.stop_buffer_pct,
) -> dict[str, object]:
    """Produce an EMA pullback sampled report via ``generate_sampled_report``.

    Builds an ``EMAPullbackConfig`` from the strategy keyword arguments (the
    starting equity keeps its config default) and hands the shared report
    generator a per-segment runner bound to that config.

    Returns:
        Whatever ``generate_sampled_report`` returns.
    """
    strategy_config = EMAPullbackConfig(
        fast_ema=fast_ema,
        slow_ema=slow_ema,
        stop_buffer_pct=stop_buffer_pct,
    )

    def run_one_segment(*, candles, leverage, warmup_bars):
        # Same keyword-only call shape the report generator used with the
        # original lambda; only the config is bound here.
        return run_ema_pullback_segment(
            candles=candles,
            leverage=leverage,
            warmup_bars=warmup_bars,
            config=strategy_config,
        )

    reported_params = {
        "fast_ema": strategy_config.fast_ema,
        "slow_ema": strategy_config.slow_ema,
        "stop_buffer_pct": strategy_config.stop_buffer_pct,
    }
    # Warm up for as many bars as the longer of the two EMA spans.
    warmup = max(strategy_config.fast_ema, strategy_config.slow_ema)

    return generate_sampled_report(
        candles=candles,
        leverage=leverage,
        output_file=output_file,
        symbol=symbol,
        bar=bar,
        segments=segments,
        window_size=window_size,
        report_title="EMA Pullback Sampled Report",
        strategy_label="EMA Pullback",
        strategy_description=EMA_PULLBACK_STRATEGY_DESCRIPTION,
        strategy_params=reported_params,
        run_segment=run_one_segment,
        warmup_bars=warmup,
    )
|