import importlib
import pytest
from okx_codex_trader.models import Candle
def load_sampled_report_module():
    """Import and return ``okx_codex_trader.sampled_report``.

    The import happens when this helper is called, so a missing module fails
    each test via ``pytest.fail`` instead of aborting the whole file at import.
    """
    module_name = "okx_codex_trader.sampled_report"
    try:
        sampled_report = importlib.import_module(module_name)
    except ModuleNotFoundError as missing:
        # pytest.fail raises, so the return below only runs on success.
        pytest.fail(f"missing shared sampled-report module: {missing}")
    return sampled_report
def build_linear_candles(count: int) -> list[Candle]:
    """Build ``count`` BTC-USDT-SWAP candles whose close rises by 1.0 per bar.

    Bar ``i`` has ``ts = i * 60_000`` (presumably milliseconds, i.e. one
    minute apart — confirm against ``Candle``), ``open == close == 100.0 + i``,
    ``high``/``low`` one unit above/below the close, and ``volume = 1_000.0 + i``.
    """
    return [
        Candle(
            symbol="BTC-USDT-SWAP",
            ts=offset * 60_000,
            open=100.0 + offset,
            high=100.0 + offset + 1.0,
            low=100.0 + offset - 1.0,
            close=100.0 + offset,
            volume=1_000.0 + offset,
        )
        for offset in range(count)
    ]
def build_segment_result(module, *, total_return: float, trade_count: int, win_rate: float, max_drawdown: float):
    """Return a ``module.SegmentResult`` stub carrying the given headline metrics.

    The trade list, candles and entry/exit markers are fixed single-trade
    fixtures. Of the series data, only the final equity point depends on
    ``total_return`` (starting equity 10,000 scaled by ``1 + total_return``).
    """
    starting_equity = 10_000.0
    sample_trade = {
        "side": "Long",
        "entry_time": "2026-04-01 00:00",
        "exit_time": "2026-04-01 01:00",
        "entry_price": 100.0,
        "exit_price": 101.0,
        "pnl": 10.0,
        "return_pct": 1.0,
    }
    curve = [
        {"ts": 0, "equity": starting_equity, "close": 100.0},
        {"ts": 60_000, "equity": starting_equity * (1 + total_return), "close": 101.0},
    ]
    return module.SegmentResult(
        trade_count=trade_count,
        total_return=total_return,
        win_rate=win_rate,
        max_drawdown=max_drawdown,
        trades=[sample_trade],
        open_position=None,
        candles=build_linear_candles(2),
        equity_curve=curve,
        entries=[{"ts": 0, "price": 100.0, "side": "long"}],
        exits=[{"ts": 60_000, "price": 101.0, "side": "long"}],
    )
def build_report_segment(module, *, result, index: int = 0, start_time: str = "2026-04-01 00:00", end_time: str = "2026-04-01 15:00"):
    """Wrap ``result`` in a ``module.ReportSegment`` with placeholder plot markup.

    Args:
        module: The loaded sampled-report module that provides ``ReportSegment``.
        result: The segment result to wrap.
        index: Segment position in the report (default 0).
        start_time: Display start of the segment window.
        end_time: Display end of the segment window.

    Returns:
        Whatever ``module.ReportSegment(...)`` builds from these fields.
    """
    return module.ReportSegment(
        index=index,
        start_time=start_time,
        end_time=end_time,
        result=result,
        # Static placeholder for the plot markup. A plain "..." literal must
        # stay on one physical line, otherwise it is an unterminated string.
        plot_div="<div>plot0</div>",
    )
def test_sample_segments_is_deterministic():
    """Identical inputs and seed yield identical samples, ordered by context_start."""
    module = load_sampled_report_module()
    history = build_linear_candles(5_000)
    sampling_kwargs = {
        "candles": history,
        "segments": 4,
        "window_size": 300,
        "warmup_bars": 69,
        "seed": 7,
    }
    first = module.sample_segments(**sampling_kwargs)
    second = module.sample_segments(**sampling_kwargs)
    assert first == second
    context_starts = [segment.context_start for segment in first]
    assert context_starts == sorted(context_starts)
def test_sample_segments_rejects_undersized_history_pool():
    """8 blocks of 69 warmup + 300 report bars need 2,952 candles; 1,000 must raise."""
    module = load_sampled_report_module()
    short_history = build_linear_candles(1_000)
    with pytest.raises(ValueError, match="history pool is too small"):
        module.sample_segments(
            candles=short_history,
            segments=8,
            window_size=300,
            warmup_bars=69,
            seed=7,
        )
def test_sample_segments_returns_exact_non_overlapping_block_ranges():
    """Each block is 69 warmup bars followed by 300 report bars, laid end to end."""
    module = load_sampled_report_module()
    sampled = module.sample_segments(
        candles=build_linear_candles(1_300),
        segments=3,
        window_size=300,
        warmup_bars=69,
        seed=7,
    )
    expected_ranges = [
        (0, 69, 369),
        (369, 438, 738),
        (738, 807, 1107),
    ]
    actual_ranges = [(s.context_start, s.report_start, s.report_end) for s in sampled]
    assert actual_ranges == expected_ranges
def test_sample_segments_rejects_invalid_sampling_result(tmp_path, monkeypatch):
    """generate_sampled_report raises when the sampler returns 1 segment for 2 requested.

    run_segment is wired to fail the test, so validation must happen before
    any segment is executed.
    """
    module = load_sampled_report_module()
    candles = build_linear_candles(5_000)

    def fake_sample_segments(**_):
        # Only one segment, although the report below asks for two.
        return [
            module.SampledSegment(
                context_start=0,
                report_start=69,
                report_end=369,
                start_ts=candles[69].ts,
                end_ts=candles[368].ts,
            )
        ]

    monkeypatch.setattr(module, "sample_segments", fake_sample_segments)
    forbidden_run = lambda **_: pytest.fail("run_segment should not be called for invalid samples")
    with pytest.raises(ValueError, match="invalid sampling result"):
        module.generate_sampled_report(
            candles=candles,
            leverage=2,
            output_file=tmp_path / "sampled-report.html",
            symbol="BTC-USDT-SWAP",
            bar="3m",
            segments=2,
            window_size=300,
            report_title="Sampled Report",
            strategy_label="Test Strategy",
            strategy_description="Strategy description",
            strategy_params={"entry_window": 20},
            run_segment=forbidden_run,
        )
def test_generate_sampled_report_passes_sliced_window_and_warmup_bars(tmp_path, monkeypatch):
    """run_segment gets candles[context_start:report_end], the leverage and warmup_bars."""
    module = load_sampled_report_module()
    candles = build_linear_candles(100)
    fake_samples = [
        module.SampledSegment(context_start=5, report_start=7, report_end=17, start_ts=candles[7].ts, end_ts=candles[16].ts),
        module.SampledSegment(context_start=20, report_start=22, report_end=32, start_ts=candles[22].ts, end_ts=candles[31].ts),
    ]
    seen: list[dict[str, object]] = []
    monkeypatch.setattr(module, "sample_segments", lambda **_: fake_samples)

    def record_run(*, candles, leverage, warmup_bars):
        # Parameter names must match the keywords generate_sampled_report passes.
        seen.append({"ts": [bar.ts for bar in candles], "leverage": leverage, "warmup_bars": warmup_bars})
        return build_segment_result(module, total_return=0.1, trade_count=1, win_rate=1.0, max_drawdown=0.05)

    module.generate_sampled_report(
        candles=candles,
        leverage=3,
        output_file=tmp_path / "sampled-report.html",
        symbol="BTC-USDT-SWAP",
        bar="3m",
        segments=2,
        window_size=10,
        warmup_bars=2,
        report_title="Shared Sampled Report",
        strategy_label="Test Strategy",
        strategy_description="Strategy description",
        strategy_params={"entry_window": 20},
        run_segment=record_run,
    )
    expected = [
        {"ts": [bar.ts for bar in candles[lo:hi]], "leverage": 3, "warmup_bars": 2}
        for lo, hi in ((5, 17), (20, 32))
    ]
    assert seen == expected
def test_generate_sampled_report_aggregates_metrics(tmp_path, monkeypatch):
    """Segment metrics are aggregated, rendered and written to output_file.

    Two stubbed segment results (+10% with 2 trades, -20% with 3 trades) are
    passed through generate_sampled_report. The test checks the returned
    summary dict, the written file and the kwargs given to
    render_sampled_report. Float aggregates use pytest.approx so the result
    does not depend on how the implementation orders its float arithmetic.
    """
    module = load_sampled_report_module()
    calls = iter(
        [
            build_segment_result(module, total_return=0.1, trade_count=2, win_rate=0.5, max_drawdown=0.05),
            build_segment_result(module, total_return=-0.2, trade_count=3, win_rate=1 / 3, max_drawdown=0.12),
        ]
    )
    captured_render: dict[str, object] = {}

    def render_sampled_report(**kwargs):
        captured_render.update(kwargs)
        return "report"

    monkeypatch.setattr(module, "render_sampled_report", render_sampled_report)
    output_file = tmp_path / "sampled-report.html"
    report = module.generate_sampled_report(
        candles=build_linear_candles(400),
        leverage=2,
        output_file=output_file,
        symbol="BTC-USDT-SWAP",
        bar="3m",
        segments=2,
        window_size=10,
        warmup_bars=2,
        report_title="Shared Sampled Report",
        strategy_label="Test Strategy",
        strategy_description="Strategy description",
        strategy_params={"entry_window": 20, "stop_loss_pct": 0.01},
        run_segment=lambda **_: next(calls),
    )
    assert report == {
        "report_file": str(output_file),
        "segment_count": 2,
        "window_size": 10,
        "aggregate_trade_count": 5,
        "average_return": pytest.approx(-0.05),
    }
    assert output_file.exists()
    assert output_file.read_text() == "report"
    assert captured_render["aggregate_summary"] == {
        "aggregate_trade_count": 5,
        "average_return": pytest.approx(-0.05),
        "median_return": pytest.approx(-0.05),
        "best_segment_return": pytest.approx(0.1),
        "worst_segment_return": pytest.approx(-0.2),
    }
    assert all(isinstance(segment, module.ReportSegment) for segment in captured_render["segment_results"])
    assert captured_render["segment_results"][0].result.trade_count == 2
    assert captured_render["segment_results"][1].result.trade_count == 3
def test_render_sampled_report_includes_strategy_params():
    """The rendered HTML shows the strategy label, description and each parameter.

    Parameter keys are expected as title-cased labels ("Entry Window",
    "Stop Loss Pct") together with their values.
    """
    import re

    module = load_sampled_report_module()
    result = build_segment_result(module, total_return=0.1, trade_count=3, win_rate=0.66, max_drawdown=0.05)
    html = module.render_sampled_report(
        symbol="BTC-USDT-SWAP",
        bar="3m",
        leverage=2,
        history_limit=5_000,
        segments=2,
        window_size=300,
        report_title="Shared Sampled Report",
        strategy_label="Donchian",
        strategy_description="Price breakout strategy.",
        strategy_params={"entry_window": 20, "stop_loss_pct": 0.01},
        aggregate_summary={
            "aggregate_trade_count": 12,
            "average_return": 0.1,
            "median_return": 0.05,
            "best_segment_return": 0.3,
            "worst_segment_return": -0.2,
        },
        segment_results=[build_report_segment(module, result=result)],
        bokeh_script="",
    )
    assert "Donchian sampled report" in html
    assert "Price breakout strategy." in html
    assert "Entry Window" in html
    # A bare `"20" in html` passes on any rendered date such as "2026-04-01".
    # Require a standalone 20 somewhere after the Entry Window label instead.
    assert re.search(r"Entry Window.*?\b20\b", html, re.DOTALL)
    assert "Stop Loss Pct" in html
    assert "0.01" in html
def test_build_segment_plot_embeds_entry_exit_markers():
    """The first child of the segment plot has triangle and inverted_triangle glyphs."""
    module = load_sampled_report_module()
    segment = build_segment_result(module, total_return=0.1, trade_count=1, win_rate=1.0, max_drawdown=0.05)
    layout = module.build_segment_plot(segment)
    price_panel = layout.children[0]
    glyph_markers = []
    for renderer in price_panel.renderers:
        if not hasattr(renderer, "glyph"):
            continue
        # Glyphs without a `marker` attribute (e.g. lines) contribute None.
        glyph_markers.append(getattr(renderer.glyph, "marker", None))
    assert "triangle" in glyph_markers
    assert "inverted_triangle" in glyph_markers