| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283 |
- import importlib
- import pytest
- from okx_codex_trader.models import Candle
def load_sampled_report_module():
    """Import and return the ``okx_codex_trader.sampled_report`` module.

    If the module cannot be found, the calling test is failed via
    ``pytest.fail`` with an explanatory message instead of erroring out with
    a raw ``ModuleNotFoundError``.
    """
    module_name = "okx_codex_trader.sampled_report"
    try:
        loaded = importlib.import_module(module_name)
    except ModuleNotFoundError as missing:
        # pytest.fail raises, so control never falls through to the return below.
        pytest.fail(f"missing shared sampled-report module: {missing}")
    return loaded
def build_linear_candles(count: int) -> list[Candle]:
    """Return ``count`` BTC-USDT-SWAP candles whose price rises by 1.0 per bar.

    Bar ``i`` has ``ts = i * 60_000`` (one minute apart if ``ts`` is in
    milliseconds — assumed, not checked here), ``open == close == 100.0 + i``,
    ``high``/``low`` one unit above/below the close, and
    ``volume = 1_000.0 + i``.
    """

    def make_candle(offset: int) -> Candle:
        price = 100.0 + offset
        return Candle(
            symbol="BTC-USDT-SWAP",
            ts=offset * 60_000,
            open=price,
            high=price + 1.0,
            low=price - 1.0,
            close=price,
            volume=1_000.0 + offset,
        )

    return [make_candle(offset) for offset in range(count)]
def build_segment_result(module, *, total_return: float, trade_count: int, win_rate: float, max_drawdown: float):
    """Build a ``module.SegmentResult`` fixture with the given headline metrics.

    Only the metric fields vary with the arguments. The trade list always
    holds one long trade from 100.0 to 101.0, whatever ``trade_count`` is.
    The candles are two bars from ``build_linear_candles(2)``, and there is one
    long entry marker and one long exit marker. The two-point equity curve
    starts at 10_000.0 and ends at ``10_000.0 * (1 + total_return)``.
    """
    starting_equity = 10_000.0
    single_trade = {
        "side": "Long",
        "entry_time": "2026-04-01 00:00",
        "exit_time": "2026-04-01 01:00",
        "entry_price": 100.0,
        "exit_price": 101.0,
        "pnl": 10.0,
        "return_pct": 1.0,
    }
    equity_points = [
        {"ts": 0, "equity": starting_equity, "close": 100.0},
        {"ts": 60_000, "equity": starting_equity * (1 + total_return), "close": 101.0},
    ]
    entry_markers = [{"ts": 0, "price": 100.0, "side": "long"}]
    exit_markers = [{"ts": 60_000, "price": 101.0, "side": "long"}]
    return module.SegmentResult(
        trade_count=trade_count,
        total_return=total_return,
        win_rate=win_rate,
        max_drawdown=max_drawdown,
        trades=[single_trade],
        open_position=None,
        candles=build_linear_candles(2),
        equity_curve=equity_points,
        entries=entry_markers,
        exits=exit_markers,
    )
def build_report_segment(
    module,
    *,
    result,
    index: int = 0,
    start_time: str = "2026-04-01 00:00",
    end_time: str = "2026-04-01 15:00",
    plot_div: "str | None" = None,
):
    """Build a ``module.ReportSegment`` fixture that wraps ``result``.

    Args:
        module: the sampled-report module (or any object exposing a
            ``ReportSegment`` callable that accepts these keyword arguments).
        result: the per-segment result to embed.
        index: segment position in the report.
        start_time / end_time: display strings for the segment bounds.
        plot_div: placeholder plot markup. It defaults to
            ``"<div>plot{index}</div>"`` so that segments built with different
            indices are distinguishable. With the default ``index=0`` this is
            the same ``"<div>plot0</div>"`` the fixture always produced.

    Returns:
        Whatever ``module.ReportSegment(...)`` returns.
    """
    if plot_div is None:
        # Tie the placeholder to the index so multi-segment fixtures do not all
        # share identical plot markup.
        plot_div = f"<div>plot{index}</div>"
    return module.ReportSegment(
        index=index,
        start_time=start_time,
        end_time=end_time,
        result=result,
        plot_div=plot_div,
    )
def test_sample_segments_is_deterministic():
    """Sampling twice with the same seed and parameters yields identical segments.

    ``first == second`` alone would pass vacuously if the sampler returned an
    empty list. This test therefore also checks that the requested number of
    segments comes back, that each has the expected warmup/report geometry,
    and that the segments are ordered and do not overlap.
    """
    module = load_sampled_report_module()
    candles = build_linear_candles(5_000)
    sampling_kwargs = dict(candles=candles, segments=4, window_size=300, warmup_bars=69, seed=7)

    first = module.sample_segments(**sampling_kwargs)
    second = module.sample_segments(**sampling_kwargs)

    assert first == second
    assert len(first) == 4
    # Each segment is ``warmup_bars`` of context followed by a
    # ``window_size``-bar report window.
    for segment in first:
        assert segment.report_start - segment.context_start == 69
        assert segment.report_end - segment.report_start == 300
    assert [segment.context_start for segment in first] == sorted(segment.context_start for segment in first)
    # Sampled blocks must not overlap: each ends at or before the next begins.
    for earlier, later in zip(first, first[1:]):
        assert earlier.report_end <= later.context_start
def test_sample_segments_rejects_undersized_history_pool():
    """Asking for more blocks than the candle history can hold raises ValueError."""
    module = load_sampled_report_module()
    # Each block spans 69 warmup + 300 report bars, so eight of them cannot
    # fit into only 1_000 candles.
    short_history = build_linear_candles(1_000)
    with pytest.raises(ValueError, match="history pool is too small"):
        module.sample_segments(
            candles=short_history,
            segments=8,
            window_size=300,
            warmup_bars=69,
            seed=7,
        )
def test_sample_segments_returns_exact_non_overlapping_block_ranges():
    """With 1_300 candles and three 369-bar blocks, the sampled ranges are pinned exactly.

    Each tuple is ``(context_start, report_start, report_end)``. Blocks are
    expected to be laid out back to back, each starting with 69 warmup bars
    followed by a 300-bar report window.
    """
    module = load_sampled_report_module()
    history = build_linear_candles(1_300)
    sampled = module.sample_segments(candles=history, segments=3, window_size=300, warmup_bars=69, seed=7)

    observed_ranges = [
        (block.context_start, block.report_start, block.report_end)
        for block in sampled
    ]
    expected_ranges = [
        (0, 69, 369),
        (369, 438, 738),
        (738, 807, 1107),
    ]
    assert observed_ranges == expected_ranges
def test_sample_segments_rejects_invalid_sampling_result(tmp_path, monkeypatch):
    """generate_sampled_report rejects a sampler result that does not match the request.

    The sampler is stubbed to return a single segment while ``segments=2`` is
    requested. The report must raise ``ValueError`` before any segment is run.
    """
    module = load_sampled_report_module()
    candles = build_linear_candles(5_000)

    def fake_sample_segments(**_):
        # Only one segment, although the report below asks for two.
        return [
            module.SampledSegment(
                context_start=0,
                report_start=69,
                report_end=369,
                start_ts=candles[69].ts,
                end_ts=candles[368].ts,
            )
        ]

    def forbidden_run_segment(**_):
        pytest.fail("run_segment should not be called for invalid samples")

    monkeypatch.setattr(module, "sample_segments", fake_sample_segments)

    with pytest.raises(ValueError, match="invalid sampling result"):
        module.generate_sampled_report(
            candles=candles,
            leverage=2,
            output_file=tmp_path / "sampled-report.html",
            symbol="BTC-USDT-SWAP",
            bar="3m",
            segments=2,
            window_size=300,
            report_title="Sampled Report",
            strategy_label="Test Strategy",
            strategy_description="Strategy description",
            strategy_params={"entry_window": 20},
            run_segment=forbidden_run_segment,
        )
def test_generate_sampled_report_passes_sliced_window_and_warmup_bars(tmp_path, monkeypatch):
    """Each run_segment call receives candles[context_start:report_end], the leverage, and warmup_bars.

    The sampler is stubbed with two fixed segments. The stub run_segment
    records the timestamps of the candles it was handed, plus the leverage
    and warmup it was called with.
    """
    module = load_sampled_report_module()
    candles = build_linear_candles(100)
    sampled = [
        module.SampledSegment(context_start=5, report_start=7, report_end=17, start_ts=candles[7].ts, end_ts=candles[16].ts),
        module.SampledSegment(context_start=20, report_start=22, report_end=32, start_ts=candles[22].ts, end_ts=candles[31].ts),
    ]
    recorded: list[dict[str, object]] = []
    monkeypatch.setattr(module, "sample_segments", lambda **_: sampled)

    def run_segment(*, candles, leverage, warmup_bars):
        recorded.append(
            {
                "ts": [bar.ts for bar in candles],
                "leverage": leverage,
                "warmup_bars": warmup_bars,
            }
        )
        return build_segment_result(module, total_return=0.1, trade_count=1, win_rate=1.0, max_drawdown=0.05)

    module.generate_sampled_report(
        candles=candles,
        leverage=3,
        output_file=tmp_path / "sampled-report.html",
        symbol="BTC-USDT-SWAP",
        bar="3m",
        segments=2,
        window_size=10,
        warmup_bars=2,
        report_title="Shared Sampled Report",
        strategy_label="Test Strategy",
        strategy_description="Strategy description",
        strategy_params={"entry_window": 20},
        run_segment=run_segment,
    )

    # The slice passed to run_segment covers the warmup context plus the report window.
    expected = [
        {
            "ts": [bar.ts for bar in candles[piece.context_start:piece.report_end]],
            "leverage": 3,
            "warmup_bars": 2,
        }
        for piece in sampled
    ]
    assert recorded == expected
def test_generate_sampled_report_aggregates_metrics(tmp_path, monkeypatch):
    """generate_sampled_report aggregates per-segment metrics and writes the rendered HTML.

    ``run_segment`` is stubbed to return two canned results in order, with
    returns of 0.1 and -0.2 and 2 and 3 trades. ``render_sampled_report`` is
    stubbed to capture its keyword arguments and return fixed HTML, which
    must be written to ``output_file``.
    """
    module = load_sampled_report_module()
    report_file = tmp_path / "sampled-report.html"
    calls = iter(
        [
            build_segment_result(module, total_return=0.1, trade_count=2, win_rate=0.5, max_drawdown=0.05),
            build_segment_result(module, total_return=-0.2, trade_count=3, win_rate=1 / 3, max_drawdown=0.12),
        ]
    )
    captured_render: dict[str, object] = {}

    def render_sampled_report(**kwargs):
        captured_render.update(kwargs)
        return "<html>report</html>"

    monkeypatch.setattr(module, "render_sampled_report", render_sampled_report)
    report = module.generate_sampled_report(
        candles=build_linear_candles(400),
        leverage=2,
        output_file=report_file,
        symbol="BTC-USDT-SWAP",
        bar="3m",
        segments=2,
        window_size=10,
        warmup_bars=2,
        report_title="Shared Sampled Report",
        strategy_label="Test Strategy",
        strategy_description="Strategy description",
        strategy_params={"entry_window": 20, "stop_loss_pct": 0.01},
        run_segment=lambda **_: next(calls),
    )

    # 0.1 and -0.2 are not exactly representable in binary floating point,
    # so averages and medians derived from them are compared with
    # pytest.approx. Exact == would depend on the order of operations
    # inside the implementation.
    assert report == {
        "report_file": str(report_file),
        "segment_count": 2,
        "window_size": 10,
        "aggregate_trade_count": 5,
        "average_return": pytest.approx(-0.05),
    }
    assert report_file.exists()
    assert report_file.read_text() == "<html>report</html>"
    assert captured_render["aggregate_summary"] == pytest.approx(
        {
            "aggregate_trade_count": 5,
            "average_return": -0.05,
            "median_return": -0.05,
            "best_segment_return": 0.1,
            "worst_segment_return": -0.2,
        }
    )
    segment_results = captured_render["segment_results"]
    # Exactly one rendered segment per requested segment, in sampling order.
    assert len(segment_results) == 2
    assert all(isinstance(segment, module.ReportSegment) for segment in segment_results)
    assert segment_results[0].result.trade_count == 2
    assert segment_results[1].result.trade_count == 3
def test_render_sampled_report_includes_strategy_params():
    """The rendered HTML shows the strategy label, description, and humanized parameters.

    Parameter keys are expected to appear title-cased ("entry_window" ->
    "Entry Window") next to their values.
    """
    module = load_sampled_report_module()
    segment_result = build_segment_result(module, total_return=0.1, trade_count=3, win_rate=0.66, max_drawdown=0.05)
    summary = {
        "aggregate_trade_count": 12,
        "average_return": 0.1,
        "median_return": 0.05,
        "best_segment_return": 0.3,
        "worst_segment_return": -0.2,
    }
    html = module.render_sampled_report(
        symbol="BTC-USDT-SWAP",
        bar="3m",
        leverage=2,
        history_limit=5_000,
        segments=2,
        window_size=300,
        report_title="Shared Sampled Report",
        strategy_label="Donchian",
        strategy_description="Price breakout strategy.",
        strategy_params={"entry_window": 20, "stop_loss_pct": 0.01},
        aggregate_summary=summary,
        segment_results=[build_report_segment(module, result=segment_result)],
        bokeh_script="<script>plots</script>",
    )

    # NOTE(review): "20" is a weak check. It also matches "2026-..." if the
    # segment/trade timestamps are rendered, so it may pass even when the
    # entry_window value is missing. Consider a more distinctive value once
    # the parameter markup is known.
    expected_fragments = (
        "Donchian sampled report",
        "Price breakout strategy.",
        "Entry Window",
        "20",
        "Stop Loss Pct",
        "0.01",
    )
    for fragment in expected_fragments:
        assert fragment in html
def test_build_segment_plot_embeds_entry_exit_markers():
    """The price panel of a segment plot draws triangle (entry) and inverted-triangle (exit) markers.

    The returned layout's first child is treated as the price panel. Only
    renderers that have a ``glyph`` attribute are inspected.
    """
    module = load_sampled_report_module()
    segment = build_segment_result(module, total_return=0.1, trade_count=1, win_rate=1.0, max_drawdown=0.05)
    layout = module.build_segment_plot(segment)
    price_panel = layout.children[0]

    # A list (not a set) is used on purpose: marker values may be
    # non-string specs that are not hashable.
    glyph_markers = []
    for renderer in price_panel.renderers:
        if not hasattr(renderer, "glyph"):
            continue
        glyph_markers.append(getattr(renderer.glyph, "marker", None))

    for expected_marker in ("triangle", "inverted_triangle"):
        assert expected_marker in glyph_markers
|