search_expansion_trend_swing.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. from __future__ import annotations
  2. import argparse
  3. import json
  4. import sys
  5. from dataclasses import dataclass
  6. from pathlib import Path
  7. import pandas as pd
  8. sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
  9. from okx_codex_trader.models import Candle
  10. from okx_codex_trader.sampled_report import SegmentResult, mark_to_market, trade_equity
  11. from scripts import explore_ultrashort as explore
  12. from scripts.search_eth_btc_nextgen_variants import format_cell, markdown_table
# Default directory (relative to the working directory) for every report artifact.
OUTPUT_DIR = Path("reports/strategy-expansion")
# File-name prefix shared by all artifacts written by this script.
PREFIX = "trend-swing"
# Instruments whose local 15m candle cache is loaded and backtested.
SYMBOLS = ("BTC-USDT-SWAP", "ETH-USDT-SWAP")
# Bar sizes the 15m candles are resampled to before backtesting.
BARS = ("1H", "4H", "1D")
# Default look-back window in years (overridable with --years).
YEARS = 10.0
# Leverage passed to the equity helpers for every simulated position.
LEVERAGE = 3
# Entry + exit fee (0.04% per side) expressed as a fraction of margin at LEVERAGE.
ROUNDTRIP_COST_ON_MARGIN = 0.0004 * 2 * LEVERAGE
# (label, look-back offset) pairs used for trailing-return columns;
# the first entry ("full", None) stands for the whole sample and is skipped
# by horizon_metrics.
HORIZONS = (
    ("full", None),
    ("3y", pd.DateOffset(years=3)),
    ("1y", pd.DateOffset(years=1)),
    ("6m", pd.DateOffset(months=6)),
    ("3m", pd.DateOffset(months=3)),
)
  27. @dataclass(frozen=True)
  28. class Params:
  29. symbol: str
  30. bar: str
  31. family: str
  32. fast: int
  33. slow: int
  34. entry: int
  35. exit: int
  36. atr: int
  37. stop_atr: float
  38. take_atr: float
  39. max_hold: int
  40. @property
  41. def name(self) -> str:
  42. return (
  43. f"{self.symbol}-{self.bar}-{self.family}"
  44. f"-f{self.fast}-s{self.slow}-e{self.entry}-x{self.exit}"
  45. f"-a{self.atr}-sl{self.stop_atr}-tp{self.take_atr}-mh{self.max_hold}"
  46. )
  47. def load_15m_frame(symbol: str, years: float) -> pd.DataFrame:
  48. path = explore.CANDLE_CACHE_DIR / symbol / "15m.csv"
  49. if not path.exists():
  50. raise FileNotFoundError(f"missing local cache: {path}")
  51. frame = pd.read_csv(path)
  52. frame["ts"] = pd.to_datetime(frame["ts"], unit="ms", utc=True)
  53. frame = frame.sort_values("ts").drop_duplicates("ts", keep="last").set_index("ts")
  54. start = frame.index[-1] - pd.DateOffset(years=years)
  55. return frame[frame.index >= start]
  56. def resample_frame(frame: pd.DataFrame, bar: str) -> pd.DataFrame:
  57. rule = {"1H": "1h", "4H": "4h", "1D": "1D"}[bar]
  58. out = frame.resample(rule, label="left", closed="left").agg(
  59. open=("open", "first"),
  60. high=("high", "max"),
  61. low=("low", "min"),
  62. close=("close", "last"),
  63. volume=("volume", "sum"),
  64. )
  65. return out.dropna()
  66. def frame_to_candles(symbol: str, frame: pd.DataFrame) -> list[Candle]:
  67. return [
  68. Candle(
  69. symbol=symbol,
  70. ts=int(ts.timestamp() * 1000),
  71. open=float(row.open),
  72. high=float(row.high),
  73. low=float(row.low),
  74. close=float(row.close),
  75. volume=float(row.volume),
  76. )
  77. for ts, row in frame.iterrows()
  78. ]
  79. def true_range(highs: pd.Series, lows: pd.Series, closes: pd.Series) -> pd.Series:
  80. previous = closes.shift(1)
  81. return pd.concat([(highs - lows), (highs - previous).abs(), (lows - previous).abs()], axis=1).max(axis=1)
  82. def close_trade(
  83. *,
  84. trades: list[dict[str, object]],
  85. exits: list[dict[str, object]],
  86. position: dict[str, object],
  87. candle: Candle,
  88. exit_price: float,
  89. ) -> tuple[float, bool]:
  90. exit_equity = trade_equity(
  91. side=str(position["side"]),
  92. margin_used=float(position["margin_used"]),
  93. entry_price=float(position["entry_price"]),
  94. exit_price=exit_price,
  95. leverage=LEVERAGE,
  96. )
  97. pnl = exit_equity - float(position["margin_used"])
  98. trades.append(
  99. {
  100. "side": "Long" if position["side"] == "long" else "Short",
  101. "entry_time": explore._format_ts(int(position["entry_time"])),
  102. "exit_time": explore._format_ts(candle.ts),
  103. "entry_price": round(float(position["entry_price"]), 4),
  104. "exit_price": round(exit_price, 4),
  105. "pnl": round(pnl, 4),
  106. "return_pct": round(pnl / float(position["margin_used"]) * 100, 4),
  107. }
  108. )
  109. exits.append({"ts": candle.ts, "price": exit_price, "side": position["side"]})
  110. return exit_equity, pnl > 0.0
def run_segment(candles: list[Candle], params: Params) -> SegmentResult:
    """Backtest one parameter set over *candles*, bar by bar.

    Signals are evaluated on a bar's close and acted on at the next bar's
    open (``pending_side`` / ``pending_exit``). While a position is open, its
    stop and take-profit levels are checked against each bar's low/high.
    The whole current equity is used as margin for every entry.

    No fee argument is passed to the equity helpers here; round-trip costs
    are subtracted afterwards by ``cost_adjusted_frame`` and ``trade_stats``.

    Args:
        candles: Bars in chronological order.
        params: Indicator windows, exit multiples and the entry-rule family
            ("donchian", "ema_cross" or "trend_pullback"; any other value
            never produces an entry).

    Returns:
        A ``SegmentResult`` covering ``candles[warmup:]``. ``equity_curve`` is
        empty when there are no bars past the warm-up, and ``open_position``
        is non-None when a position is still open on the last bar.
    """
    highs = pd.Series([c.high for c in candles], dtype=float)
    lows = pd.Series([c.low for c in candles], dtype=float)
    closes = pd.Series([c.close for c in candles], dtype=float)
    fast = closes.ewm(span=params.fast, adjust=False).mean()
    slow = closes.ewm(span=params.slow, adjust=False).mean()
    atr = true_range(highs, lows, closes).rolling(params.atr).mean()
    # Channels are shifted by one bar so the current bar is compared against
    # the extremes of the preceding window only.
    entry_high = highs.shift(1).rolling(params.entry).max()
    entry_low = lows.shift(1).rolling(params.entry).min()
    exit_high = highs.shift(1).rolling(params.exit).max()
    exit_low = lows.shift(1).rolling(params.exit).min()
    # presumably a 5-period RSI — confirm against explore._compute_rsi
    rsi = explore._compute_rsi(closes, 5)
    # Skip bars until the longest look-back window has been seen.
    # NOTE(review): the rationale for the floor of 8 is not visible here.
    warmup = max(params.slow, params.entry, params.exit, params.atr, 8)
    equity = explore.INITIAL_EQUITY
    ending_equity = equity
    peak_equity = equity
    max_drawdown = 0.0
    wins = 0
    trades: list[dict[str, object]] = []
    entries: list[dict[str, object]] = []
    exits: list[dict[str, object]] = []
    equity_curve: list[dict[str, float | int]] = []
    position: dict[str, object] | None = None
    pending_side: str | None = None
    pending_exit = False
    for index in range(warmup, len(candles)):
        candle = candles[index]
        # 1) Execute an exit signalled on the previous bar at this bar's open.
        if pending_exit and position is not None:
            equity, won = close_trade(trades=trades, exits=exits, position=position, candle=candle, exit_price=candle.open)
            wins += 1 if won else 0
            position = None
        pending_exit = False
        # 2) Execute an entry signalled on the previous bar at this bar's open.
        if pending_side is not None and position is None and equity > 0.0:
            side = pending_side
            # Stop/take distances use the ATR of the previous (signal) bar.
            current_atr = float(atr.iloc[index - 1])
            position = {
                "side": side,
                "entry_time": candle.ts,
                "entry_price": candle.open,
                "entry_index": index,
                "margin_used": equity,
                "stop_price": candle.open - params.stop_atr * current_atr if side == "long" else candle.open + params.stop_atr * current_atr,
                "take_price": candle.open + params.take_atr * current_atr if side == "long" else candle.open - params.take_atr * current_atr,
            }
            entries.append({"ts": candle.ts, "price": candle.open, "side": side})
        pending_side = None
        current_equity = equity
        # 3) Intrabar stop / take-profit check (also applies to the entry bar).
        if position is not None:
            side = str(position["side"])
            stop_hit = (side == "long" and candle.low <= float(position["stop_price"])) or (
                side == "short" and candle.high >= float(position["stop_price"])
            )
            take_hit = (side == "long" and candle.high >= float(position["take_price"])) or (
                side == "short" and candle.low <= float(position["take_price"])
            )
            if stop_hit or take_hit:
                # When both levels are touched in the same bar the stop wins.
                # NOTE(review): the fill is the level itself even if the bar
                # opened beyond it (no gap slippage) — confirm this is intended.
                exit_price = float(position["stop_price"] if stop_hit else position["take_price"])
                equity, won = close_trade(trades=trades, exits=exits, position=position, candle=candle, exit_price=exit_price)
                wins += 1 if won else 0
                current_equity = equity
                position = None
        # 4) Mark a still-open position at the bar's close.
        if position is not None:
            current_equity = mark_to_market(
                side=str(position["side"]),
                margin_used=float(position["margin_used"]),
                entry_price=float(position["entry_price"]),
                mark_price=candle.close,
                leverage=LEVERAGE,
            )
        # 5) Update the drawdown statistics and the equity curve.
        peak_equity = max(peak_equity, current_equity)
        max_drawdown = max(max_drawdown, (peak_equity - current_equity) / peak_equity)
        equity_curve.append({"ts": candle.ts, "equity": current_equity, "close": candle.close})
        ending_equity = current_equity
        # No new signals on the final bar (there is no next open to trade on)
        # or once realised equity is exhausted.
        if index == len(candles) - 1 or equity <= 0.0:
            continue
        # 6) Exit signal for an open position: close beyond the exit channel,
        #    EMA trend flip, or the maximum holding period reached.
        if position is not None:
            held = index - int(position["entry_index"])
            side = str(position["side"])
            if side == "long":
                pending_exit = candle.close < float(exit_low.iloc[index]) or fast.iloc[index] < slow.iloc[index] or held >= params.max_hold
            else:
                pending_exit = candle.close > float(exit_high.iloc[index]) or fast.iloc[index] > slow.iloc[index] or held >= params.max_hold
            continue
        # 7) Entry signal while flat, according to the rule family.
        if params.family == "donchian":
            # Channel breakout in the direction of the EMA trend.
            if candle.close > float(entry_high.iloc[index]) and fast.iloc[index] > slow.iloc[index]:
                pending_side = "long"
            elif candle.close < float(entry_low.iloc[index]) and fast.iloc[index] < slow.iloc[index]:
                pending_side = "short"
        elif params.family == "ema_cross":
            # Fast EMA crossing the slow EMA between the previous and this bar.
            prev_fast = fast.iloc[index - 1]
            prev_slow = slow.iloc[index - 1]
            if prev_fast <= prev_slow and fast.iloc[index] > slow.iloc[index]:
                pending_side = "long"
            elif prev_fast >= prev_slow and fast.iloc[index] < slow.iloc[index]:
                pending_side = "short"
        elif params.family == "trend_pullback":
            # Close at or beyond the fast EMA against the trend, gated by RSI.
            # NOTE(review): rsi is indexed with [] rather than .iloc — confirm
            # the return type of explore._compute_rsi supports positional use.
            if fast.iloc[index] > slow.iloc[index] and candle.close <= fast.iloc[index] and rsi[index] <= 45:
                pending_side = "long"
            elif fast.iloc[index] < slow.iloc[index] and candle.close >= fast.iloc[index] and rsi[index] >= 55:
                pending_side = "short"
    trade_count = len(trades)
    return SegmentResult(
        trade_count=trade_count,
        total_return=(ending_equity - explore.INITIAL_EQUITY) / explore.INITIAL_EQUITY,
        win_rate=wins / trade_count if trade_count else 0.0,
        max_drawdown=max_drawdown,
        trades=trades,
        open_position=position,
        candles=candles[warmup:],
        equity_curve=equity_curve,
        entries=entries,
        exits=exits,
    )
  224. def cost_adjusted_frame(result: SegmentResult) -> pd.DataFrame:
  225. rows = [{"ts": pd.to_datetime(result.equity_curve[0]["ts"], unit="ms", utc=True), "equity": explore.INITIAL_EQUITY}]
  226. equity = explore.INITIAL_EQUITY
  227. for trade in result.trades:
  228. equity *= 1.0 + float(trade["return_pct"]) / 100.0 - ROUNDTRIP_COST_ON_MARGIN
  229. rows.append({"ts": pd.to_datetime(str(trade["exit_time"]), utc=True), "equity": equity})
  230. return pd.DataFrame(rows)
  231. def daily_equity(frame: pd.DataFrame, start: pd.Timestamp, end: pd.Timestamp) -> pd.Series:
  232. series = frame.set_index("ts")["equity"].sort_index()
  233. index = pd.date_range(start.normalize(), end.normalize(), freq="1D", tz="UTC")
  234. return series.reindex(index.union(series.index)).sort_index().ffill().reindex(index).fillna(explore.INITIAL_EQUITY)
  235. def metrics_from_daily(series: pd.Series) -> dict[str, float]:
  236. years = (series.index[-1] - series.index[0]).total_seconds() / 86_400 / 365
  237. total = float(series.iloc[-1] / series.iloc[0] - 1.0)
  238. annual = (1.0 + total) ** (1.0 / years) - 1.0 if total > -1.0 and years > 0.0 else 0.0
  239. drawdown = explore.max_drawdown_from_equity([float(v) for v in series])
  240. return {
  241. "total_return": total,
  242. "annualized_return": annual,
  243. "max_drawdown": drawdown,
  244. "calmar": annual / drawdown if drawdown else 0.0,
  245. }
  246. def trade_stats(result: SegmentResult) -> dict[str, float | int]:
  247. returns = [float(trade["return_pct"]) / 100.0 - ROUNDTRIP_COST_ON_MARGIN for trade in result.trades]
  248. wins = [value for value in returns if value > 0.0]
  249. losses = [value for value in returns if value < 0.0]
  250. avg_win = sum(wins) / len(wins) if wins else 0.0
  251. avg_loss = abs(sum(losses) / len(losses)) if losses else 0.0
  252. return {
  253. "trades": len(returns),
  254. "win_rate": len(wins) / len(returns) if returns else 0.0,
  255. "payoff_ratio": avg_win / avg_loss if avg_loss else 0.0,
  256. }
  257. def horizon_metrics(series: pd.Series) -> dict[str, float]:
  258. out: dict[str, float] = {}
  259. end = series.index[-1]
  260. for label, offset in HORIZONS[1:]:
  261. scoped = series[series.index >= end - offset]
  262. if len(scoped) < 2:
  263. scoped = series
  264. out[f"return_{label}"] = float(scoped.iloc[-1] / scoped.iloc[0] - 1.0)
  265. return out
  266. def monthly_rows(name: str, params: Params, series: pd.Series) -> pd.DataFrame:
  267. monthly = series.resample("ME").last()
  268. frame = pd.DataFrame(
  269. {
  270. "name": name,
  271. "symbol": params.symbol,
  272. "bar": params.bar,
  273. "family": params.family,
  274. "month": monthly.index.strftime("%Y-%m"),
  275. "start_equity": monthly.shift(1).fillna(series.iloc[0]).to_numpy(),
  276. "end_equity": monthly.to_numpy(),
  277. }
  278. )
  279. frame["return"] = frame["end_equity"] / frame["start_equity"] - 1.0
  280. return frame
  281. def build_params() -> list[Params]:
  282. rows: list[Params] = []
  283. for symbol in SYMBOLS:
  284. for bar in BARS:
  285. scale = {"1H": 1, "4H": 1, "1D": 1}[bar]
  286. for family in ("donchian", "ema_cross", "trend_pullback"):
  287. for fast, slow in ((20 * scale, 80 * scale), (30 * scale, 120 * scale), (50 * scale, 200 * scale)):
  288. for entry, exit_ in ((20, 10), (55, 20)):
  289. for stop_atr, take_atr in ((2.0, 4.0), (3.0, 6.0)):
  290. max_hold = {"1H": 240, "4H": 120, "1D": 60}[bar]
  291. rows.append(
  292. Params(
  293. symbol=symbol,
  294. bar=bar,
  295. family=family,
  296. fast=fast,
  297. slow=slow,
  298. entry=entry,
  299. exit=exit_,
  300. atr=14,
  301. stop_atr=stop_atr,
  302. take_atr=take_atr,
  303. max_hold=max_hold,
  304. )
  305. )
  306. return rows
  307. def markdown_report(command: str, paths: list[Path], totals: pd.DataFrame, monthly: pd.DataFrame) -> str:
  308. top = totals.head(10)
  309. best_names = set(top.head(3)["name"])
  310. lines = [
  311. "# Trend swing expansion",
  312. "",
  313. f"Run command: `{command}`",
  314. "",
  315. "Output files:",
  316. *[f"- `{path}`" for path in paths],
  317. "",
  318. "Scope: BTC-USDT-SWAP and ETH-USDT-SWAP perpetuals, resampled from local 15m cache to 1H/4H/1D.",
  319. f"Cost: 0.04% single-side taker fee, roundtrip cost on margin = {ROUNDTRIP_COST_ON_MARGIN:.4%} at {LEVERAGE}x.",
  320. "",
  321. "## Top candidates",
  322. "",
  323. markdown_table(
  324. top[
  325. [
  326. "name",
  327. "symbol",
  328. "bar",
  329. "family",
  330. "trades",
  331. "total_return",
  332. "annualized_return",
  333. "max_drawdown",
  334. "calmar",
  335. "win_rate",
  336. "payoff_ratio",
  337. "return_3y",
  338. "return_1y",
  339. "return_6m",
  340. "return_3m",
  341. ]
  342. ]
  343. ),
  344. "",
  345. "## Monthly returns for top 3",
  346. "",
  347. markdown_table(monthly[monthly["name"].isin(best_names)].tail(120)),
  348. ]
  349. return "\n".join(lines) + "\n"
  350. def main() -> int:
  351. parser = argparse.ArgumentParser()
  352. parser.add_argument("--years", type=float, default=YEARS)
  353. parser.add_argument("--output-dir", type=Path, default=OUTPUT_DIR)
  354. parser.add_argument("--max-candidates", type=int, default=0)
  355. args = parser.parse_args()
  356. raw = {symbol: load_15m_frame(symbol, args.years) for symbol in SYMBOLS}
  357. candles = {
  358. (symbol, bar): frame_to_candles(symbol, resample_frame(raw[symbol], bar))
  359. for symbol in SYMBOLS
  360. for bar in BARS
  361. }
  362. params_grid = build_params()
  363. if args.max_candidates:
  364. params_grid = params_grid[: args.max_candidates]
  365. total_rows: list[dict[str, object]] = []
  366. monthly_frames: list[pd.DataFrame] = []
  367. for index, params in enumerate(params_grid, start=1):
  368. result = run_segment(candles[(params.symbol, params.bar)], params)
  369. frame = cost_adjusted_frame(result)
  370. start = pd.to_datetime(result.equity_curve[0]["ts"], unit="ms", utc=True)
  371. end = pd.to_datetime(result.equity_curve[-1]["ts"], unit="ms", utc=True)
  372. daily = daily_equity(frame, start, end)
  373. monthly = monthly_rows(params.name, params, daily)
  374. row = {
  375. "name": params.name,
  376. **params.__dict__,
  377. "first_candle": start.strftime("%Y-%m-%d %H:%M"),
  378. "last_candle": end.strftime("%Y-%m-%d %H:%M"),
  379. "fee_single_side": 0.0004,
  380. "roundtrip_cost_on_margin": ROUNDTRIP_COST_ON_MARGIN,
  381. "worst_month_return": float(monthly["return"].min()),
  382. **metrics_from_daily(daily),
  383. **trade_stats(result),
  384. **horizon_metrics(daily),
  385. }
  386. total_rows.append(row)
  387. monthly_frames.append(monthly)
  388. print(f"done {index}/{len(params_grid)} {params.name}", flush=True)
  389. totals = pd.DataFrame(total_rows).sort_values(
  390. ["calmar", "annualized_return", "max_drawdown", "trades"],
  391. ascending=[False, False, True, True],
  392. )
  393. monthly_all = pd.concat(monthly_frames, ignore_index=True)
  394. top3 = totals.head(3)
  395. args.output_dir.mkdir(parents=True, exist_ok=True)
  396. totals_path = args.output_dir / f"{PREFIX}-totals.csv"
  397. monthly_path = args.output_dir / f"{PREFIX}-monthly-returns.csv"
  398. top_path = args.output_dir / f"{PREFIX}-top3.csv"
  399. best_path = args.output_dir / f"{PREFIX}-best.json"
  400. report_path = args.output_dir / f"{PREFIX}-report.md"
  401. totals.to_csv(totals_path, index=False)
  402. monthly_all.to_csv(monthly_path, index=False)
  403. top3.to_csv(top_path, index=False)
  404. best_path.write_text(json.dumps(top3.to_dict(orient="records"), indent=2), encoding="utf-8")
  405. command = f"rtk .venv/bin/python {Path(__file__).as_posix()} --years {args.years}"
  406. report_path.write_text(
  407. markdown_report(command, [totals_path, monthly_path, top_path, best_path, report_path], totals, monthly_all),
  408. encoding="utf-8",
  409. )
  410. print(top3.to_string(index=False, formatters={col: format_cell for col in top3.columns}))
  411. return 0
# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())