test_ema_pullback_report.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. import pytest
  2. from okx_codex_trader import ema_pullback_report
  3. from okx_codex_trader.ema_pullback_report import (
  4. EMAPullbackConfig,
  5. generate_ema_pullback_sampled_report,
  6. run_ema_pullback_segment,
  7. )
  8. from okx_codex_trader.models import Candle
  9. from okx_codex_trader.sampled_report import SegmentResult
  10. def make_candle(index: int, open_price: float, high: float, low: float, close: float) -> Candle:
  11. return Candle(
  12. symbol="BTC-USDT-SWAP",
  13. ts=index * 60_000,
  14. open=open_price,
  15. high=high,
  16. low=low,
  17. close=close,
  18. volume=1_000.0 + index,
  19. )
  20. def build_linear_candles(count: int) -> list[Candle]:
  21. candles: list[Candle] = []
  22. for index in range(count):
  23. price = 100.0 + index
  24. candles.append(make_candle(index, price, price + 1.0, price - 1.0, price))
  25. return candles
  26. def build_long_trade_fixture() -> list[Candle]:
  27. return [
  28. make_candle(0, 100.0, 101.0, 99.0, 100.0),
  29. make_candle(1, 102.0, 103.0, 101.0, 102.0),
  30. make_candle(2, 104.0, 105.0, 103.0, 104.0),
  31. make_candle(3, 103.0, 104.0, 102.0, 103.0),
  32. make_candle(4, 104.0, 106.0, 103.0, 105.0),
  33. make_candle(5, 104.0, 105.0, 103.2, 104.5),
  34. make_candle(6, 104.2, 104.6, 102.8, 103.0),
  35. make_candle(7, 104.5, 105.0, 104.0, 104.5),
  36. ]
  37. def build_short_trade_fixture() -> list[Candle]:
  38. return [
  39. make_candle(0, 100.0, 101.0, 99.0, 100.0),
  40. make_candle(1, 98.0, 99.0, 97.0, 98.0),
  41. make_candle(2, 96.0, 97.0, 95.0, 96.0),
  42. make_candle(3, 97.0, 98.0, 96.0, 97.0),
  43. make_candle(4, 96.0, 97.0, 94.5, 95.0),
  44. make_candle(5, 96.5, 96.6, 95.0, 95.5),
  45. make_candle(6, 95.8, 97.4, 95.6, 97.0),
  46. make_candle(7, 96.0, 96.2, 95.5, 96.0),
  47. ]
  48. def build_stop_priority_fixture() -> list[Candle]:
  49. return [
  50. make_candle(0, 100.0, 101.0, 99.0, 100.0),
  51. make_candle(1, 102.0, 103.0, 101.0, 102.0),
  52. make_candle(2, 104.0, 105.0, 103.0, 104.0),
  53. make_candle(3, 103.0, 104.0, 102.0, 103.0),
  54. make_candle(4, 104.0, 106.0, 103.0, 105.0),
  55. make_candle(5, 104.0, 104.3, 101.0, 102.0),
  56. make_candle(6, 105.0, 106.0, 104.5, 105.0),
  57. ]
  58. def build_gap_through_stop_fixture() -> list[Candle]:
  59. return [
  60. make_candle(0, 100.0, 101.0, 99.0, 100.0),
  61. make_candle(1, 102.0, 103.0, 101.0, 102.0),
  62. make_candle(2, 104.0, 105.0, 103.0, 104.0),
  63. make_candle(3, 103.0, 104.0, 102.0, 103.0),
  64. make_candle(4, 104.0, 106.0, 103.0, 105.0),
  65. make_candle(5, 102.0, 103.0, 101.0, 102.5),
  66. ]
  67. def build_final_bar_signal_fixture() -> list[Candle]:
  68. return [
  69. make_candle(0, 100.0, 101.0, 99.0, 100.0),
  70. make_candle(1, 102.0, 103.0, 101.0, 102.0),
  71. make_candle(2, 104.0, 105.0, 103.0, 104.0),
  72. make_candle(3, 103.0, 104.0, 102.0, 103.0),
  73. make_candle(4, 104.0, 106.0, 103.0, 105.0),
  74. ]
  75. def build_open_tail_fixture() -> list[Candle]:
  76. return [
  77. make_candle(0, 100.0, 101.0, 99.0, 100.0),
  78. make_candle(1, 102.0, 103.0, 101.0, 102.0),
  79. make_candle(2, 104.0, 105.0, 103.0, 104.0),
  80. make_candle(3, 103.0, 104.0, 102.0, 103.0),
  81. make_candle(4, 104.0, 106.0, 103.0, 105.0),
  82. make_candle(5, 104.0, 106.5, 103.5, 106.0),
  83. ]
  84. def build_depleted_equity_fixture() -> list[Candle]:
  85. return [
  86. make_candle(0, 100.0, 101.0, 99.0, 100.0),
  87. make_candle(1, 102.0, 103.0, 101.0, 102.0),
  88. make_candle(2, 104.0, 105.0, 103.0, 104.0),
  89. make_candle(3, 103.0, 104.0, 102.0, 103.0),
  90. make_candle(4, 104.0, 106.0, 103.0, 105.0),
  91. make_candle(5, 104.0, 104.3, 101.0, 102.0),
  92. make_candle(6, 106.0, 108.0, 105.0, 107.0),
  93. make_candle(7, 106.5, 107.0, 106.0, 106.5),
  94. ]
  95. def test_run_ema_pullback_segment_produces_long_trade():
  96. result = run_ema_pullback_segment(
  97. candles=build_long_trade_fixture(),
  98. leverage=2,
  99. warmup_bars=4,
  100. config=EMAPullbackConfig(fast_ema=2, slow_ema=4, stop_buffer_pct=0.005),
  101. )
  102. assert isinstance(result, SegmentResult)
  103. assert result.trade_count == 1
  104. assert result.trades[0]["side"] == "Long"
  105. assert result.trades[0]["entry_price"] == pytest.approx(104.0)
  106. assert result.trades[0]["exit_price"] == pytest.approx(104.5)
  107. assert result.open_position is None
  108. def test_run_ema_pullback_segment_produces_short_trade():
  109. result = run_ema_pullback_segment(
  110. candles=build_short_trade_fixture(),
  111. leverage=2,
  112. warmup_bars=4,
  113. config=EMAPullbackConfig(fast_ema=2, slow_ema=4, stop_buffer_pct=0.005),
  114. )
  115. assert isinstance(result, SegmentResult)
  116. assert result.trade_count == 1
  117. assert result.trades[0]["side"] == "Short"
  118. assert result.trades[0]["entry_price"] == pytest.approx(96.5)
  119. assert result.trades[0]["exit_price"] == pytest.approx(96.0)
  120. assert result.open_position is None
  121. def test_run_ema_pullback_segment_stop_priority_is_correct():
  122. result = run_ema_pullback_segment(
  123. candles=build_stop_priority_fixture(),
  124. leverage=2,
  125. warmup_bars=4,
  126. config=EMAPullbackConfig(fast_ema=2, slow_ema=4, stop_buffer_pct=0.005),
  127. )
  128. assert result.trade_count == 1
  129. assert len(result.entries) == 1
  130. assert result.trades[0]["exit_price"] == pytest.approx(102.485)
  131. assert result.open_position is None
  132. def test_run_ema_pullback_segment_exits_gap_through_stop_at_open():
  133. result = run_ema_pullback_segment(
  134. candles=build_gap_through_stop_fixture(),
  135. leverage=2,
  136. warmup_bars=4,
  137. config=EMAPullbackConfig(fast_ema=2, slow_ema=4, stop_buffer_pct=0.005),
  138. )
  139. assert result.trade_count == 1
  140. assert result.trades[0]["exit_price"] == pytest.approx(102.0)
  141. assert result.open_position is None
  142. def test_run_ema_pullback_segment_does_not_generate_entry_from_final_bar():
  143. result = run_ema_pullback_segment(
  144. candles=build_final_bar_signal_fixture(),
  145. leverage=2,
  146. warmup_bars=4,
  147. config=EMAPullbackConfig(fast_ema=2, slow_ema=4, stop_buffer_pct=0.005),
  148. )
  149. assert result.trade_count == 0
  150. assert result.entries == []
  151. assert result.open_position is None
  152. def test_run_ema_pullback_segment_marks_open_position_to_market():
  153. result = run_ema_pullback_segment(
  154. candles=build_open_tail_fixture(),
  155. leverage=2,
  156. warmup_bars=4,
  157. config=EMAPullbackConfig(fast_ema=2, slow_ema=4, stop_buffer_pct=0.005),
  158. )
  159. assert result.trade_count == 0
  160. assert result.trades == []
  161. assert result.total_return == pytest.approx((10_384.615384615385 - 10_000.0) / 10_000.0)
  162. assert result.open_position is not None
  163. assert result.open_position["side"] == "long"
  164. def test_run_ema_pullback_segment_does_not_reenter_after_equity_is_depleted():
  165. result = run_ema_pullback_segment(
  166. candles=build_depleted_equity_fixture(),
  167. leverage=100,
  168. warmup_bars=4,
  169. config=EMAPullbackConfig(fast_ema=2, slow_ema=4, stop_buffer_pct=0.005),
  170. )
  171. assert result.trade_count == 1
  172. assert len(result.entries) == 1
  173. assert result.open_position is None
  174. assert result.total_return <= -1.0
  175. def test_generate_ema_pullback_sampled_report_uses_shared_shell_defaults(monkeypatch, tmp_path):
  176. candles = build_linear_candles(5_000)
  177. output_file = tmp_path / "ema-pullback.html"
  178. recorded: dict[str, object] = {}
  179. sentinel = {
  180. "report_file": str(output_file),
  181. "segment_count": 2,
  182. "window_size": 300,
  183. "aggregate_trade_count": 4,
  184. "average_return": 0.12,
  185. }
  186. def fake_generate_sampled_report(**kwargs):
  187. recorded.update(kwargs)
  188. return sentinel
  189. monkeypatch.setattr(ema_pullback_report, "generate_sampled_report", fake_generate_sampled_report)
  190. result = generate_ema_pullback_sampled_report(
  191. candles=candles,
  192. leverage=2,
  193. output_file=output_file,
  194. symbol="BTC-USDT-SWAP",
  195. bar="3m",
  196. segments=2,
  197. window_size=300,
  198. )
  199. assert result == sentinel
  200. assert recorded["report_title"] == "EMA Pullback Sampled Report"
  201. assert recorded["strategy_label"] == "EMA Pullback"
  202. assert recorded["strategy_params"] == {
  203. "fast_ema": 20,
  204. "slow_ema": 50,
  205. "stop_buffer_pct": 0.005,
  206. }
  207. assert recorded["warmup_bars"] == 50
  208. assert callable(recorded["run_segment"])