test_sampled_report.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. import importlib
  2. import pytest
  3. from okx_codex_trader.models import Candle
  4. def load_sampled_report_module():
  5. try:
  6. return importlib.import_module("okx_codex_trader.sampled_report")
  7. except ModuleNotFoundError as exc:
  8. pytest.fail(f"missing shared sampled-report module: {exc}")
  9. def build_linear_candles(count: int) -> list[Candle]:
  10. candles: list[Candle] = []
  11. for index in range(count):
  12. close = 100.0 + index
  13. candles.append(
  14. Candle(
  15. symbol="BTC-USDT-SWAP",
  16. ts=index * 60_000,
  17. open=close,
  18. high=close + 1.0,
  19. low=close - 1.0,
  20. close=close,
  21. volume=1_000.0 + index,
  22. )
  23. )
  24. return candles
  25. def build_segment_result(module, *, total_return: float, trade_count: int, win_rate: float, max_drawdown: float):
  26. return module.SegmentResult(
  27. trade_count=trade_count,
  28. total_return=total_return,
  29. win_rate=win_rate,
  30. max_drawdown=max_drawdown,
  31. trades=[
  32. {
  33. "side": "Long",
  34. "entry_time": "2026-04-01 00:00",
  35. "exit_time": "2026-04-01 01:00",
  36. "entry_price": 100.0,
  37. "exit_price": 101.0,
  38. "pnl": 10.0,
  39. "return_pct": 1.0,
  40. }
  41. ],
  42. open_position=None,
  43. candles=build_linear_candles(2),
  44. equity_curve=[
  45. {"ts": 0, "equity": 10_000.0, "close": 100.0},
  46. {"ts": 60_000, "equity": 10_000.0 * (1 + total_return), "close": 101.0},
  47. ],
  48. entries=[{"ts": 0, "price": 100.0, "side": "long"}],
  49. exits=[{"ts": 60_000, "price": 101.0, "side": "long"}],
  50. )
  51. def build_report_segment(module, *, result, index: int = 0, start_time: str = "2026-04-01 00:00", end_time: str = "2026-04-01 15:00"):
  52. return module.ReportSegment(
  53. index=index,
  54. start_time=start_time,
  55. end_time=end_time,
  56. result=result,
  57. plot_div="<div>plot0</div>",
  58. )
  59. def test_sample_segments_is_deterministic():
  60. module = load_sampled_report_module()
  61. candles = build_linear_candles(5_000)
  62. first = module.sample_segments(candles=candles, segments=4, window_size=300, warmup_bars=69, seed=7)
  63. second = module.sample_segments(candles=candles, segments=4, window_size=300, warmup_bars=69, seed=7)
  64. assert first == second
  65. assert [segment.context_start for segment in first] == sorted(segment.context_start for segment in first)
  66. def test_sample_segments_rejects_undersized_history_pool():
  67. module = load_sampled_report_module()
  68. with pytest.raises(ValueError, match="history pool is too small"):
  69. module.sample_segments(candles=build_linear_candles(1_000), segments=8, window_size=300, warmup_bars=69, seed=7)
  70. def test_sample_segments_returns_exact_non_overlapping_block_ranges():
  71. module = load_sampled_report_module()
  72. sampled = module.sample_segments(candles=build_linear_candles(1_300), segments=3, window_size=300, warmup_bars=69, seed=7)
  73. assert [(segment.context_start, segment.report_start, segment.report_end) for segment in sampled] == [
  74. (0, 69, 369),
  75. (369, 438, 738),
  76. (738, 807, 1107),
  77. ]
  78. def test_sample_segments_rejects_invalid_sampling_result(tmp_path, monkeypatch):
  79. module = load_sampled_report_module()
  80. candles = build_linear_candles(5_000)
  81. monkeypatch.setattr(
  82. module,
  83. "sample_segments",
  84. lambda **_: [
  85. module.SampledSegment(
  86. context_start=0,
  87. report_start=69,
  88. report_end=369,
  89. start_ts=candles[69].ts,
  90. end_ts=candles[368].ts,
  91. )
  92. ],
  93. )
  94. with pytest.raises(ValueError, match="invalid sampling result"):
  95. module.generate_sampled_report(
  96. candles=candles,
  97. leverage=2,
  98. output_file=tmp_path / "sampled-report.html",
  99. symbol="BTC-USDT-SWAP",
  100. bar="3m",
  101. segments=2,
  102. window_size=300,
  103. report_title="Sampled Report",
  104. strategy_label="Test Strategy",
  105. strategy_description="Strategy description",
  106. strategy_params={"entry_window": 20},
  107. run_segment=lambda **_: pytest.fail("run_segment should not be called for invalid samples"),
  108. )
  109. def test_generate_sampled_report_passes_sliced_window_and_warmup_bars(tmp_path, monkeypatch):
  110. module = load_sampled_report_module()
  111. candles = build_linear_candles(100)
  112. sampled = [
  113. module.SampledSegment(context_start=5, report_start=7, report_end=17, start_ts=candles[7].ts, end_ts=candles[16].ts),
  114. module.SampledSegment(context_start=20, report_start=22, report_end=32, start_ts=candles[22].ts, end_ts=candles[31].ts),
  115. ]
  116. captured_calls: list[dict[str, object]] = []
  117. monkeypatch.setattr(module, "sample_segments", lambda **_: sampled)
  118. def run_segment(*, candles, leverage, warmup_bars):
  119. captured_calls.append(
  120. {
  121. "ts": [candle.ts for candle in candles],
  122. "leverage": leverage,
  123. "warmup_bars": warmup_bars,
  124. }
  125. )
  126. return build_segment_result(module, total_return=0.1, trade_count=1, win_rate=1.0, max_drawdown=0.05)
  127. module.generate_sampled_report(
  128. candles=candles,
  129. leverage=3,
  130. output_file=tmp_path / "sampled-report.html",
  131. symbol="BTC-USDT-SWAP",
  132. bar="3m",
  133. segments=2,
  134. window_size=10,
  135. warmup_bars=2,
  136. report_title="Shared Sampled Report",
  137. strategy_label="Test Strategy",
  138. strategy_description="Strategy description",
  139. strategy_params={"entry_window": 20},
  140. run_segment=run_segment,
  141. )
  142. assert captured_calls == [
  143. {
  144. "ts": [candle.ts for candle in candles[5:17]],
  145. "leverage": 3,
  146. "warmup_bars": 2,
  147. },
  148. {
  149. "ts": [candle.ts for candle in candles[20:32]],
  150. "leverage": 3,
  151. "warmup_bars": 2,
  152. },
  153. ]
  154. def test_generate_sampled_report_aggregates_metrics(tmp_path, monkeypatch):
  155. module = load_sampled_report_module()
  156. calls = iter(
  157. [
  158. build_segment_result(module, total_return=0.1, trade_count=2, win_rate=0.5, max_drawdown=0.05),
  159. build_segment_result(module, total_return=-0.2, trade_count=3, win_rate=1 / 3, max_drawdown=0.12),
  160. ]
  161. )
  162. captured_render: dict[str, object] = {}
  163. def render_sampled_report(**kwargs):
  164. captured_render.update(kwargs)
  165. return "<html>report</html>"
  166. monkeypatch.setattr(module, "render_sampled_report", render_sampled_report)
  167. report = module.generate_sampled_report(
  168. candles=build_linear_candles(400),
  169. leverage=2,
  170. output_file=tmp_path / "sampled-report.html",
  171. symbol="BTC-USDT-SWAP",
  172. bar="3m",
  173. segments=2,
  174. window_size=10,
  175. warmup_bars=2,
  176. report_title="Shared Sampled Report",
  177. strategy_label="Test Strategy",
  178. strategy_description="Strategy description",
  179. strategy_params={"entry_window": 20, "stop_loss_pct": 0.01},
  180. run_segment=lambda **_: next(calls),
  181. )
  182. assert report == {
  183. "report_file": str(tmp_path / "sampled-report.html"),
  184. "segment_count": 2,
  185. "window_size": 10,
  186. "aggregate_trade_count": 5,
  187. "average_return": -0.05,
  188. }
  189. assert (tmp_path / "sampled-report.html").exists()
  190. assert (tmp_path / "sampled-report.html").read_text() == "<html>report</html>"
  191. assert captured_render["aggregate_summary"] == {
  192. "aggregate_trade_count": 5,
  193. "average_return": -0.05,
  194. "median_return": -0.05,
  195. "best_segment_return": 0.1,
  196. "worst_segment_return": -0.2,
  197. }
  198. assert all(isinstance(segment, module.ReportSegment) for segment in captured_render["segment_results"])
  199. assert captured_render["segment_results"][0].result.trade_count == 2
  200. assert captured_render["segment_results"][1].result.trade_count == 3
  201. def test_render_sampled_report_includes_strategy_params():
  202. module = load_sampled_report_module()
  203. result = build_segment_result(module, total_return=0.1, trade_count=3, win_rate=0.66, max_drawdown=0.05)
  204. html = module.render_sampled_report(
  205. symbol="BTC-USDT-SWAP",
  206. bar="3m",
  207. leverage=2,
  208. history_limit=5_000,
  209. segments=2,
  210. window_size=300,
  211. report_title="Shared Sampled Report",
  212. strategy_label="Donchian",
  213. strategy_description="Price breakout strategy.",
  214. strategy_params={"entry_window": 20, "stop_loss_pct": 0.01},
  215. aggregate_summary={
  216. "aggregate_trade_count": 12,
  217. "average_return": 0.1,
  218. "median_return": 0.05,
  219. "best_segment_return": 0.3,
  220. "worst_segment_return": -0.2,
  221. },
  222. segment_results=[build_report_segment(module, result=result)],
  223. bokeh_script="<script>plots</script>",
  224. )
  225. assert "Donchian sampled report" in html
  226. assert "Price breakout strategy." in html
  227. assert "Entry Window" in html
  228. assert "20" in html
  229. assert "Stop Loss Pct" in html
  230. assert "0.01" in html
  231. def test_build_segment_plot_embeds_entry_exit_markers():
  232. module = load_sampled_report_module()
  233. segment = build_segment_result(module, total_return=0.1, trade_count=1, win_rate=1.0, max_drawdown=0.05)
  234. plot = module.build_segment_plot(segment)
  235. price_plot = plot.children[0]
  236. markers = [getattr(renderer.glyph, "marker", None) for renderer in price_plot.renderers if hasattr(renderer, "glyph")]
  237. assert "triangle" in markers
  238. assert "inverted_triangle" in markers