ema_pullback_report.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from pathlib import Path
  4. import pandas as pd
  5. from okx_codex_trader.models import Candle
  6. from okx_codex_trader.sampled_report import (
  7. SegmentResult,
  8. generate_sampled_report,
  9. mark_to_market as _mark_to_market,
  10. trade_equity as _trade_equity,
  11. )
  12. EMA_PULLBACK_STRATEGY_DESCRIPTION = (
  13. "EMA pullback reclaim, fast-over-slow trend bias, next-open continuation entries after a fast-EMA reclaim, "
  14. "opposite close-through-fast-EMA exits, stop beyond the signal candle."
  15. )
  16. @dataclass(frozen=True)
  17. class EMAPullbackConfig:
  18. fast_ema: int = 20
  19. slow_ema: int = 50
  20. stop_buffer_pct: float = 0.005
  21. initial_equity: float = 10_000.0
  22. def _format_ts(ts: int) -> str:
  23. return pd.to_datetime(ts, unit="ms", utc=True).strftime("%Y-%m-%d %H:%M")
  24. def run_ema_pullback_segment(
  25. *,
  26. candles: list[Candle],
  27. leverage: int,
  28. warmup_bars: int,
  29. config: EMAPullbackConfig = EMAPullbackConfig(),
  30. ) -> SegmentResult:
  31. closes = pd.Series([candle.close for candle in candles], dtype=float)
  32. fast_ema = closes.ewm(span=config.fast_ema, adjust=False).mean().tolist()
  33. slow_ema = closes.ewm(span=config.slow_ema, adjust=False).mean().tolist()
  34. equity = config.initial_equity
  35. ending_equity = equity
  36. peak_equity = equity
  37. max_drawdown = 0.0
  38. wins = 0
  39. trades: list[dict[str, object]] = []
  40. entries: list[dict[str, object]] = []
  41. exits: list[dict[str, object]] = []
  42. equity_curve: list[dict[str, float | int]] = []
  43. position: dict[str, object] | None = None
  44. pending_entry: dict[str, object] | None = None
  45. pending_exit = False
  46. for index in range(warmup_bars, len(candles)):
  47. candle = candles[index]
  48. if pending_exit and position is not None:
  49. exit_price = candle.open
  50. exit_equity = _trade_equity(
  51. side=str(position["side"]),
  52. margin_used=float(position["margin_used"]),
  53. entry_price=float(position["entry_price"]),
  54. exit_price=exit_price,
  55. leverage=leverage,
  56. )
  57. trades.append(
  58. {
  59. "side": "Long" if position["side"] == "long" else "Short",
  60. "entry_time": _format_ts(int(position["entry_time"])),
  61. "exit_time": _format_ts(candle.ts),
  62. "entry_price": round(float(position["entry_price"]), 4),
  63. "exit_price": round(exit_price, 4),
  64. "pnl": round(exit_equity - float(position["margin_used"]), 4),
  65. "return_pct": round(
  66. (exit_equity - float(position["margin_used"])) / float(position["margin_used"]) * 100,
  67. 4,
  68. ),
  69. }
  70. )
  71. exits.append({"ts": candle.ts, "price": exit_price, "side": position["side"]})
  72. if exit_equity > float(position["margin_used"]):
  73. wins += 1
  74. equity = exit_equity
  75. position = None
  76. pending_exit = False
  77. if pending_entry is not None and position is None and equity > 0.0:
  78. position = {
  79. "side": str(pending_entry["side"]),
  80. "entry_time": candle.ts,
  81. "entry_price": candle.open,
  82. "margin_used": equity,
  83. "stop_price": float(pending_entry["stop_price"]),
  84. }
  85. entries.append({"ts": candle.ts, "price": candle.open, "side": str(pending_entry["side"])})
  86. pending_entry = None
  87. current_equity = equity
  88. if position is not None:
  89. stop_hit = (
  90. position["side"] == "long" and candle.low <= float(position["stop_price"])
  91. ) or (
  92. position["side"] == "short" and candle.high >= float(position["stop_price"])
  93. )
  94. if stop_hit:
  95. if position["side"] == "long" and candle.open < float(position["stop_price"]):
  96. exit_price = candle.open
  97. elif position["side"] == "short" and candle.open > float(position["stop_price"]):
  98. exit_price = candle.open
  99. else:
  100. exit_price = float(position["stop_price"])
  101. exit_equity = _trade_equity(
  102. side=str(position["side"]),
  103. margin_used=float(position["margin_used"]),
  104. entry_price=float(position["entry_price"]),
  105. exit_price=exit_price,
  106. leverage=leverage,
  107. )
  108. trades.append(
  109. {
  110. "side": "Long" if position["side"] == "long" else "Short",
  111. "entry_time": _format_ts(int(position["entry_time"])),
  112. "exit_time": _format_ts(candle.ts),
  113. "entry_price": round(float(position["entry_price"]), 4),
  114. "exit_price": round(exit_price, 4),
  115. "pnl": round(exit_equity - float(position["margin_used"]), 4),
  116. "return_pct": round(
  117. (exit_equity - float(position["margin_used"])) / float(position["margin_used"]) * 100,
  118. 4,
  119. ),
  120. }
  121. )
  122. exits.append({"ts": candle.ts, "price": exit_price, "side": position["side"]})
  123. if exit_equity > float(position["margin_used"]):
  124. wins += 1
  125. equity = exit_equity
  126. current_equity = exit_equity
  127. position = None
  128. if current_equity > peak_equity:
  129. peak_equity = current_equity
  130. max_drawdown = max(max_drawdown, (peak_equity - current_equity) / peak_equity)
  131. equity_curve.append({"ts": candle.ts, "equity": current_equity, "close": candle.close})
  132. ending_equity = current_equity
  133. continue
  134. if position is not None:
  135. current_equity = _mark_to_market(
  136. side=str(position["side"]),
  137. margin_used=float(position["margin_used"]),
  138. entry_price=float(position["entry_price"]),
  139. mark_price=candle.close,
  140. leverage=leverage,
  141. )
  142. if current_equity > peak_equity:
  143. peak_equity = current_equity
  144. max_drawdown = max(max_drawdown, (peak_equity - current_equity) / peak_equity)
  145. equity_curve.append({"ts": candle.ts, "equity": current_equity, "close": candle.close})
  146. ending_equity = current_equity
  147. if index == len(candles) - 1 or equity <= 0.0:
  148. continue
  149. current_fast = fast_ema[index]
  150. current_slow = slow_ema[index]
  151. if current_fast != current_fast or current_slow != current_slow:
  152. continue
  153. if position is not None:
  154. exit_signal = (
  155. position["side"] == "long" and candle.close < float(current_fast)
  156. ) or (
  157. position["side"] == "short" and candle.close > float(current_fast)
  158. )
  159. if exit_signal:
  160. pending_exit = True
  161. continue
  162. if float(current_fast) > float(current_slow) and candle.low <= float(current_fast) and candle.close > float(current_fast):
  163. pending_entry = {
  164. "side": "long",
  165. "stop_price": candle.low * (1 - config.stop_buffer_pct),
  166. }
  167. elif float(current_fast) < float(current_slow) and candle.high >= float(current_fast) and candle.close < float(current_fast):
  168. pending_entry = {
  169. "side": "short",
  170. "stop_price": candle.high * (1 + config.stop_buffer_pct),
  171. }
  172. trade_count = len(trades)
  173. return SegmentResult(
  174. trade_count=trade_count,
  175. total_return=(ending_equity - config.initial_equity) / config.initial_equity,
  176. win_rate=(wins / trade_count) if trade_count else 0.0,
  177. max_drawdown=max_drawdown,
  178. trades=trades,
  179. open_position=position,
  180. candles=candles[warmup_bars:],
  181. equity_curve=equity_curve,
  182. entries=entries,
  183. exits=exits,
  184. )
  185. def generate_ema_pullback_sampled_report(
  186. *,
  187. candles: list[Candle],
  188. leverage: int,
  189. output_file: Path,
  190. symbol: str,
  191. bar: str,
  192. segments: int,
  193. window_size: int,
  194. fast_ema: int = EMAPullbackConfig.fast_ema,
  195. slow_ema: int = EMAPullbackConfig.slow_ema,
  196. stop_buffer_pct: float = EMAPullbackConfig.stop_buffer_pct,
  197. ) -> dict[str, object]:
  198. config = EMAPullbackConfig(
  199. fast_ema=fast_ema,
  200. slow_ema=slow_ema,
  201. stop_buffer_pct=stop_buffer_pct,
  202. )
  203. return generate_sampled_report(
  204. candles=candles,
  205. leverage=leverage,
  206. output_file=output_file,
  207. symbol=symbol,
  208. bar=bar,
  209. segments=segments,
  210. window_size=window_size,
  211. report_title="EMA Pullback Sampled Report",
  212. strategy_label="EMA Pullback",
  213. strategy_description=EMA_PULLBACK_STRATEGY_DESCRIPTION,
  214. strategy_params={
  215. "fast_ema": config.fast_ema,
  216. "slow_ema": config.slow_ema,
  217. "stop_buffer_pct": config.stop_buffer_pct,
  218. },
  219. run_segment=lambda *, candles, leverage, warmup_bars: run_ema_pullback_segment(
  220. candles=candles,
  221. leverage=leverage,
  222. warmup_bars=warmup_bars,
  223. config=config,
  224. ),
  225. warmup_bars=max(config.fast_ema, config.slow_ema),
  226. )