sampled_report.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from html import escape
  4. from pathlib import Path
  5. from random import Random
  6. from statistics import median
  7. from typing import Callable
  8. import pandas as pd
  9. from bokeh.embed import components
  10. from bokeh.layouts import column
  11. from bokeh.plotting import figure
  12. from bokeh.resources import INLINE
  13. from okx_codex_trader.models import Candle
# Default number of candles placed in front of each report window; used as the
# default `warmup_bars` by sample_segments() and generate_sampled_report().
WARMUP_BARS = 69
# Default seed for the Random instance that shuffles candidate blocks in
# sample_segments(), so repeated runs pick the same segments.
SAMPLER_SEED = 7
@dataclass(frozen=True)
class SampledSegment:
    """Index and timestamp bounds of one sampled block of the candle history.

    Instances are produced by sample_segments(). All indices refer to the
    candle list that was passed to that function.
    """

    # Index of the first candle of the block (warm-up bars included).
    context_start: int
    # Index of the first reported candle: context_start + warmup_bars.
    report_start: int
    # End index of the block: context_start + window_size + warmup_bars.
    # Used as an exclusive slice bound by generate_sampled_report().
    report_end: int
    # `ts` of the candle at report_start (formatted as epoch milliseconds by
    # _format_ts when the report is generated).
    start_ts: int
    # `ts` of the candle at report_end - 1, i.e. the last candle of the block.
    end_ts: int
@dataclass(frozen=True)
class SegmentResult:
    """Outcome of running a strategy over one sampled segment.

    Returned by the `run_segment` callable given to generate_sampled_report().
    """

    # Number of trades; summed across segments for the aggregate trade count.
    trade_count: int
    # Segment return; averaged / ranked across segments in the summary.
    total_return: float
    # Shown in the segment metrics, rounded to 6 decimal places.
    win_rate: float
    # Shown in the segment metrics, rounded to 6 decimal places.
    max_drawdown: float
    # Trade journal rows. render_sampled_report() reads the keys "side",
    # "entry_time", "exit_time", "entry_price", "exit_price", "pnl" and
    # "return_pct" from every entry.
    trades: list[dict[str, object]]
    # Not read anywhere in this module — presumably the position still open at
    # the end of the segment, if any; confirm against the producer.
    open_position: dict[str, object] | None
    # Not read anywhere in this module — presumably the candles the segment
    # was evaluated on; confirm against the producer.
    candles: list[Candle]
    # Points plotted by build_segment_plot(); each needs "ts", "close" and
    # "equity".
    equity_curve: list[dict[str, float | int]]
    # Entry markers plotted on the price chart; each needs "ts" and "price".
    entries: list[dict[str, object]]
    # Exit markers plotted on the price chart; each needs "ts" and "price".
    exits: list[dict[str, object]]
@dataclass(frozen=True)
class ReportSegment:
    """One segment as it is handed to render_sampled_report()."""

    # Zero-based position of the segment; written to `data-segment-index` and
    # shown to the reader as "Segment {index + 1}". Index 0 starts out active.
    index: int
    # Already formatted start of the sampled range (see _format_ts).
    start_time: str
    # Already formatted end of the sampled range (see _format_ts).
    end_time: str
    # Strategy outcome for this segment.
    result: SegmentResult
    # HTML fragment for the segment's charts; inserted into the page verbatim
    # (not escaped).
    plot_div: str
  42. def _format_ts(ts: int) -> str:
  43. return pd.to_datetime(ts, unit="ms", utc=True).strftime("%Y-%m-%d %H:%M")
  44. def sample_segments(
  45. *,
  46. candles: list[Candle],
  47. segments: int,
  48. window_size: int,
  49. warmup_bars: int = WARMUP_BARS,
  50. seed: int = SAMPLER_SEED,
  51. ) -> list[SampledSegment]:
  52. block_size = window_size + warmup_bars
  53. if len(candles) < segments * block_size:
  54. raise ValueError("history pool is too small")
  55. context_starts = list(range(0, len(candles) - block_size + 1, block_size))
  56. if len(context_starts) < segments:
  57. raise ValueError("history pool is too small")
  58. rng = Random(seed)
  59. rng.shuffle(context_starts)
  60. selected_context_starts = sorted(context_starts[:segments])
  61. return [
  62. SampledSegment(
  63. context_start=context_start,
  64. report_start=context_start + warmup_bars,
  65. report_end=context_start + block_size,
  66. start_ts=candles[context_start + warmup_bars].ts,
  67. end_ts=candles[context_start + block_size - 1].ts,
  68. )
  69. for context_start in selected_context_starts
  70. ]
  71. def trade_equity(
  72. *,
  73. side: str,
  74. margin_used: float,
  75. entry_price: float,
  76. exit_price: float,
  77. leverage: int,
  78. ) -> float:
  79. if side == "long":
  80. price_return = (exit_price - entry_price) / entry_price
  81. else:
  82. price_return = (entry_price - exit_price) / entry_price
  83. return margin_used + (margin_used * leverage * price_return)
  84. def mark_to_market(
  85. *,
  86. side: str,
  87. margin_used: float,
  88. entry_price: float,
  89. mark_price: float,
  90. leverage: int,
  91. ) -> float:
  92. return trade_equity(
  93. side=side,
  94. margin_used=margin_used,
  95. entry_price=entry_price,
  96. exit_price=mark_price,
  97. leverage=leverage,
  98. )
  99. def build_segment_plot(segment: SegmentResult):
  100. timestamps = [pd.to_datetime(point["ts"], unit="ms", utc=True) for point in segment.equity_curve]
  101. closes = [point["close"] for point in segment.equity_curve]
  102. equities = [point["equity"] for point in segment.equity_curve]
  103. price_fig = figure(height=320, sizing_mode="stretch_width", x_axis_type="datetime", title="Price")
  104. price_fig.line(timestamps, closes, line_width=2, color="#1f6f78")
  105. if segment.entries:
  106. price_fig.scatter(
  107. [pd.to_datetime(entry["ts"], unit="ms", utc=True) for entry in segment.entries],
  108. [entry["price"] for entry in segment.entries],
  109. marker="triangle",
  110. size=12,
  111. color="#1d7c44",
  112. )
  113. if segment.exits:
  114. price_fig.scatter(
  115. [pd.to_datetime(exit_point["ts"], unit="ms", utc=True) for exit_point in segment.exits],
  116. [exit_point["price"] for exit_point in segment.exits],
  117. marker="inverted_triangle",
  118. size=12,
  119. color="#a13d2d",
  120. )
  121. equity_fig = figure(height=220, sizing_mode="stretch_width", x_axis_type="datetime", title="Equity")
  122. equity_fig.line(timestamps, equities, line_width=2, color="#7b4f9d")
  123. return column(price_fig, equity_fig, sizing_mode="stretch_width")
def render_sampled_report(
    *,
    symbol: str,
    bar: str,
    leverage: int,
    history_limit: int,
    segments: int,
    window_size: int,
    report_title: str,
    strategy_label: str,
    strategy_description: str,
    strategy_params: dict[str, object],
    aggregate_summary: dict[str, object],
    segment_results: list[ReportSegment],
    bokeh_script: str,
) -> str:
    """Render the sampled-backtest report as one self-contained HTML document.

    The page consists of a header, the strategy description with one card per
    strategy parameter, a grid of summary cards, one selector button per
    segment and one panel per segment (metrics, trade journal, charts). A
    small inline script switches the visible panel when a button is clicked;
    the segment with index 0 is active initially.

    Args:
        symbol: Instrument name; used in the page title and the heading.
        bar: Bar size label shown in the header.
        leverage: Leverage shown in the header.
        history_limit: Value of the "History Limit" summary card.
        segments: Value of the "Segment Count" summary card.
        window_size: Value of the "Window Size" summary card.
        report_title: Shown in the page title and above the description.
        strategy_label: Shown in the header line above the symbol.
        strategy_description: Paragraph shown in the strategy section.
        strategy_params: One card per item; keys are shown with underscores
            replaced by spaces and title-cased.
        aggregate_summary: Must contain "average_return", "median_return",
            "best_segment_return", "worst_segment_return" and
            "aggregate_trade_count".
        segment_results: Segments to render; every trade dict of a segment
            must contain "side", "entry_time", "exit_time", "entry_price",
            "exit_price", "pnl" and "return_pct".
        bokeh_script: Markup placed before the selector script; inserted
            verbatim.

    Returns:
        The complete HTML document as a string.

    Raises:
        KeyError: If a required key is missing from `aggregate_summary` or
            from a trade dict.
    """
    # Headline statistics; every label and value is HTML-escaped.
    summary_cards = "".join(
        f"""
<div class="card">
<div class="label">{escape(label)}</div>
<div class="value">{escape(str(value))}</div>
</div>
"""
        for label, value in (
            ("History Limit", history_limit),
            ("Segment Count", segments),
            ("Window Size", window_size),
            ("Average Return Across Segments", aggregate_summary["average_return"]),
            ("Median Return Across Segments", aggregate_summary["median_return"]),
            ("Best Segment Return", aggregate_summary["best_segment_return"]),
            ("Worst Segment Return", aggregate_summary["worst_segment_return"]),
            ("Aggregate Trade Count", aggregate_summary["aggregate_trade_count"]),
        )
    )
    # One card per strategy parameter, e.g. key "fast_period" -> "Fast Period".
    params_markup = "".join(
        f"""
<div class="card">
<div class="label">{escape(key.replace('_', ' ').title())}</div>
<div class="value">{escape(str(value))}</div>
</div>
"""
        for key, value in strategy_params.items()
    )
    # Selector buttons; data-segment-index ties each button to its panel.
    selector = "".join(
        f'<button class="segment-button{" active" if segment.index == 0 else ""}" data-segment-index="{segment.index}">Segment {segment.index + 1}</button>'
        for segment in segment_results
    )
    panels = []
    for segment in segment_results:
        # Trade journal rows for this segment.
        rows = "".join(
            f"""
<tr>
<td>{escape(str(trade["side"]))}</td>
<td>{escape(str(trade["entry_time"]))}</td>
<td>{escape(str(trade["exit_time"]))}</td>
<td>{escape(str(trade["entry_price"]))}</td>
<td>{escape(str(trade["exit_price"]))}</td>
<td>{escape(str(trade["pnl"]))}</td>
<td>{escape(str(trade["return_pct"]))}</td>
</tr>
"""
            for trade in segment.result.trades
        )
        # NOTE(review): segment.plot_div is inserted without escaping —
        # presumably it is trusted markup produced by Bokeh; confirm.
        panels.append(
            f"""
<section class="segment-panel{' active' if segment.index == 0 else ''}" data-segment-index="{segment.index}">
<div class="segment-metrics">
<div class="metric"><span>Sampled Range Start Time</span><strong>{escape(segment.start_time)}</strong></div>
<div class="metric"><span>Sampled Range End Time</span><strong>{escape(segment.end_time)}</strong></div>
<div class="metric"><span>Trade Count</span><strong>{escape(str(segment.result.trade_count))}</strong></div>
<div class="metric"><span>Total Return</span><strong>{escape(str(round(segment.result.total_return, 6)))}</strong></div>
<div class="metric"><span>Win Rate</span><strong>{escape(str(round(segment.result.win_rate, 6)))}</strong></div>
<div class="metric"><span>Max Drawdown</span><strong>{escape(str(round(segment.result.max_drawdown, 6)))}</strong></div>
</div>
<div class="layout">
<section class="panel">
<h3>Trade Journal</h3>
<table>
<thead>
<tr>
<th>Side</th>
<th>Entry Time</th>
<th>Exit Time</th>
<th>Entry</th>
<th>Exit</th>
<th>PnL</th>
<th>Return %</th>
</tr>
</thead>
<tbody>{rows}</tbody>
</table>
</section>
<section class="panel">{segment.plot_div}</section>
</div>
</section>
"""
        )
    # Page shell. Literal CSS/JS braces are doubled because this is an
    # f-string; bokeh_script is emitted before the selector script.
    return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>{escape(symbol)} {escape(report_title)}</title>
<style>
body {{ font-family: Inter, system-ui, sans-serif; margin: 0; background: #f5f1e8; color: #1f1c18; }}
.page {{ max-width: 1440px; margin: 0 auto; padding: 28px 24px 48px; }}
.hero {{ display:flex; justify-content:space-between; gap:24px; align-items:end; margin-bottom:20px; }}
.hero h1 {{ margin:0; font-size:36px; }}
.meta {{ color:#5f564b; }}
.strategy {{ background:#fffdf8; border:1px solid #d8cdbd; border-radius:18px; padding:18px; margin-bottom:18px; }}
.stats {{ display:grid; grid-template-columns:repeat(4, minmax(0, 1fr)); gap:12px; margin-bottom:18px; }}
.card {{ background:#fffdf8; border:1px solid #d8cdbd; border-radius:16px; padding:16px; }}
.label {{ font-size:12px; letter-spacing:.08em; text-transform:uppercase; color:#7a6f62; margin-bottom:8px; }}
.value {{ font-size:24px; font-weight:700; }}
.segment-selector {{ display:flex; flex-wrap:wrap; gap:10px; margin-bottom:18px; }}
.segment-button {{ border:1px solid #c6b7a2; background:#fffdf8; border-radius:999px; padding:10px 14px; cursor:pointer; }}
.segment-button.active {{ background:#1f1c18; color:#fffdf8; }}
.segment-panel {{ display:none; }}
.segment-panel.active {{ display:block; }}
.segment-metrics {{ display:grid; grid-template-columns:repeat(3, minmax(0,1fr)); gap:12px; margin-bottom:16px; }}
.metric {{ background:#fffdf8; border:1px solid #d8cdbd; border-radius:14px; padding:14px; }}
.metric span {{ display:block; font-size:12px; color:#6b6258; text-transform:uppercase; letter-spacing:.08em; margin-bottom:6px; }}
.metric strong {{ font-size:20px; }}
.layout {{ display:grid; grid-template-columns:1fr 1fr; gap:16px; }}
.panel {{ background:#fffdf8; border:1px solid #d8cdbd; border-radius:18px; padding:18px; overflow:auto; }}
table {{ width:100%; border-collapse:collapse; font-size:14px; }}
th, td {{ text-align:left; padding:10px 8px; border-bottom:1px solid #ece3d6; }}
th {{ color:#6b6258; font-size:12px; text-transform:uppercase; letter-spacing:.08em; }}
@media (max-width: 1100px) {{
.stats, .segment-metrics, .layout {{ grid-template-columns:1fr; }}
}}
</style>
</head>
<body>
<div class="page">
<div class="hero">
<div>
<div class="meta">{escape(strategy_label)} sampled report</div>
<h1>{escape(symbol)}</h1>
</div>
<div class="meta">Bar: {escape(bar)} · Leverage: {escape(str(leverage))}x</div>
</div>
<section class="strategy">
<strong>{escape(report_title)}</strong>
<p>{escape(strategy_description)}</p>
<section class="stats">{params_markup}</section>
</section>
<section class="stats">{summary_cards}</section>
<div class="segment-selector" id="segment-selector">{selector}</div>
{''.join(panels)}
</div>
{bokeh_script}
<script>
const buttons = Array.from(document.querySelectorAll('.segment-button'));
const panels = Array.from(document.querySelectorAll('.segment-panel'));
buttons.forEach((button) => {{
button.addEventListener('click', () => {{
const target = button.dataset.segmentIndex;
buttons.forEach((item) => item.classList.toggle('active', item === button));
panels.forEach((panel) => panel.classList.toggle('active', panel.dataset.segmentIndex === target));
}});
}});
</script>
</body>
</html>"""
  289. def generate_sampled_report(
  290. *,
  291. candles: list[Candle],
  292. leverage: int,
  293. output_file: Path,
  294. symbol: str,
  295. bar: str,
  296. segments: int,
  297. window_size: int,
  298. report_title: str,
  299. strategy_label: str,
  300. strategy_description: str,
  301. strategy_params: dict[str, object],
  302. run_segment: Callable[..., SegmentResult],
  303. warmup_bars: int = WARMUP_BARS,
  304. ) -> dict[str, object]:
  305. sampled = sample_segments(candles=candles, segments=segments, window_size=window_size, warmup_bars=warmup_bars)
  306. if len(sampled) != segments:
  307. raise ValueError("invalid sampling result")
  308. output_file.parent.mkdir(parents=True, exist_ok=True)
  309. segment_results: list[SegmentResult] = []
  310. plots = {}
  311. for index, segment in enumerate(sampled):
  312. result = run_segment(
  313. candles=candles[segment.context_start : segment.report_end],
  314. leverage=leverage,
  315. warmup_bars=warmup_bars,
  316. )
  317. segment_results.append(result)
  318. plots[f"segment_{index}"] = build_segment_plot(result)
  319. plot_script, plot_divs = components(plots)
  320. report_segments = [
  321. ReportSegment(
  322. index=index,
  323. start_time=_format_ts(segment.start_ts),
  324. end_time=_format_ts(segment.end_ts),
  325. result=result,
  326. plot_div=plot_divs[f"segment_{index}"],
  327. )
  328. for index, (segment, result) in enumerate(zip(sampled, segment_results, strict=True))
  329. ]
  330. returns = [result.total_return for result in segment_results]
  331. aggregate_summary = {
  332. "aggregate_trade_count": sum(result.trade_count for result in segment_results),
  333. "average_return": round(sum(returns) / len(returns), 6),
  334. "median_return": round(float(median(returns)), 6),
  335. "best_segment_return": round(max(returns), 6),
  336. "worst_segment_return": round(min(returns), 6),
  337. }
  338. output_file.write_text(
  339. render_sampled_report(
  340. symbol=symbol,
  341. bar=bar,
  342. leverage=leverage,
  343. history_limit=len(candles),
  344. segments=segments,
  345. window_size=window_size,
  346. report_title=report_title,
  347. strategy_label=strategy_label,
  348. strategy_description=strategy_description,
  349. strategy_params=strategy_params,
  350. aggregate_summary=aggregate_summary,
  351. segment_results=report_segments,
  352. bokeh_script=INLINE.render_js() + plot_script,
  353. )
  354. )
  355. return {
  356. "report_file": str(output_file),
  357. "segment_count": segments,
  358. "window_size": window_size,
  359. "aggregate_trade_count": aggregate_summary["aggregate_trade_count"],
  360. "average_return": aggregate_summary["average_return"],
  361. }