search_expansion_rotation.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. from __future__ import annotations
  2. import argparse
  3. import sys
  4. from dataclasses import dataclass
  5. from itertools import product
  6. from pathlib import Path
  7. import pandas as pd
  8. sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
  9. from okx_codex_trader.models import Candle
  10. from okx_codex_trader.okx_client import OkxClient
  11. from scripts import explore_ultrashort as explore
  12. OUTPUT_DIR = Path("reports/strategy-expansion")
  13. PREFIX = "rotation"
  14. SYMBOLS = ("BTC-USDT-SWAP", "ETH-USDT-SWAP", "SOL-USDT-SWAP")
  15. BASE_BAR = "15m"
  16. BARS = ("1h", "4h", "1d")
  17. OKX_BARS = {"1h": "1H", "4h": "4H", "1d": "1Dutc"}
  18. BAR_MS = {"1h": 3_600_000, "4h": 14_400_000, "1d": 86_400_000}
  19. LEVERAGE = 3
  20. INITIAL_EQUITY = 10_000.0
  21. TAKER_FEE = 0.0004
  22. HORIZONS = (
  23. ("3y", pd.DateOffset(years=3)),
  24. ("1y", pd.DateOffset(years=1)),
  25. ("6m", pd.DateOffset(months=6)),
  26. ("3m", pd.DateOffset(months=3)),
  27. )
  28. @dataclass(frozen=True)
  29. class Params:
  30. family: str
  31. bar: str
  32. lookback: int
  33. trend: int
  34. btc_trend: int
  35. rebalance: int
  36. top_n: int
  37. min_momentum: float
  38. btc_min_momentum: float
  39. vol_lookback: int
  40. max_vol: float
  41. @property
  42. def name(self) -> str:
  43. return (
  44. f"{self.family}-{self.bar}-lb{self.lookback}-tr{self.trend}-bt{self.btc_trend}"
  45. f"-rb{self.rebalance}-top{self.top_n}-mm{self.min_momentum:.3f}"
  46. f"-bm{self.btc_min_momentum:.3f}-vw{self.vol_lookback}-vc{self.max_vol:.4f}"
  47. )
  48. def load_local_candles(symbol: str, bar: str) -> list[Candle]:
  49. candles, _ = explore.load_cached_candles(explore.CANDLE_CACHE_DIR, symbol, bar)
  50. return candles
  51. def frame_from_candles(candles: list[Candle]) -> pd.DataFrame:
  52. frame = pd.DataFrame(
  53. {
  54. "ts": [pd.to_datetime(candle.ts, unit="ms", utc=True) for candle in candles],
  55. "open": [candle.open for candle in candles],
  56. "high": [candle.high for candle in candles],
  57. "low": [candle.low for candle in candles],
  58. "close": [candle.close for candle in candles],
  59. "volume": [candle.volume for candle in candles],
  60. }
  61. )
  62. return frame.set_index("ts").sort_index()
  63. def aggregate_frame(frame: pd.DataFrame, bar: str) -> pd.DataFrame:
  64. rule = {"1h": "1h", "4h": "4h", "1d": "1D"}[bar]
  65. aggregated = frame.resample(rule, label="left", closed="left").agg(
  66. {"open": "first", "high": "max", "low": "min", "close": "last", "volume": "sum"}
  67. )
  68. return aggregated.dropna()
  69. def fetch_okx_frame(symbol: str, bar: str, years: float) -> pd.DataFrame:
  70. interval_ms = BAR_MS[bar]
  71. limit = int(years * 365 * 86_400_000 / interval_ms) + 500
  72. candles = OkxClient().get_candles(symbol, OKX_BARS[bar], limit)
  73. if not candles:
  74. raise FileNotFoundError(f"missing OKX candles for {symbol} {bar}")
  75. return frame_from_candles(candles)
  76. def load_symbol_bar_frames(years: float) -> dict[tuple[str, str], pd.DataFrame]:
  77. frames: dict[tuple[str, str], pd.DataFrame] = {}
  78. for symbol in SYMBOLS:
  79. local = load_local_candles(symbol, BASE_BAR)
  80. base = frame_from_candles(local) if local else None
  81. for bar in BARS:
  82. if base is not None:
  83. frame = aggregate_frame(base, bar)
  84. else:
  85. frame = fetch_okx_frame(symbol, bar, years)
  86. cutoff = frame.index[-1] - pd.DateOffset(years=years)
  87. frames[(symbol, bar)] = frame[frame.index >= cutoff]
  88. return frames
  89. def aligned_closes(frames: dict[tuple[str, str], pd.DataFrame], params: Params) -> pd.DataFrame:
  90. required = max(params.lookback, params.trend, params.btc_trend, params.vol_lookback) + max(params.rebalance * 6, 120)
  91. full = pd.DataFrame({symbol: frames[(symbol, params.bar)]["close"] for symbol in SYMBOLS}).dropna()
  92. if len(full) >= required:
  93. return full
  94. btc_eth = pd.DataFrame({symbol: frames[(symbol, params.bar)]["close"] for symbol in SYMBOLS[:2]}).dropna()
  95. if len(btc_eth) < required:
  96. raise ValueError(f"insufficient aligned candles for {params.name}")
  97. return btc_eth
  98. def build_params() -> list[Params]:
  99. params: list[Params] = []
  100. for bar in BARS:
  101. if bar == "1h":
  102. lookbacks = (24 * 14, 24 * 30)
  103. trends = (24 * 30, 24 * 60)
  104. btc_trends = (24 * 120,)
  105. rebalances = (24 * 3, 24 * 7)
  106. vol_lookbacks = (24 * 14,)
  107. elif bar == "4h":
  108. lookbacks = (6 * 30, 6 * 60)
  109. trends = (6 * 60, 6 * 120)
  110. btc_trends = (6 * 180,)
  111. rebalances = (6 * 7, 6 * 14)
  112. vol_lookbacks = (6 * 30,)
  113. else:
  114. lookbacks = (60, 120)
  115. trends = (120, 200)
  116. btc_trends = (200,)
  117. rebalances = (14, 30)
  118. vol_lookbacks = (30,)
  119. for family, lookback, trend, btc_trend, rebalance, top_n, min_momentum, btc_min_momentum, max_vol in product(
  120. ("dual_momentum", "trend_basket"),
  121. lookbacks,
  122. trends,
  123. btc_trends,
  124. rebalances,
  125. (1, 2),
  126. (0.0, 0.03),
  127. (0.0,),
  128. (0.055,),
  129. ):
  130. if family == "dual_momentum" and top_n != 1:
  131. continue
  132. params.append(
  133. Params(
  134. family=family,
  135. bar=bar,
  136. lookback=lookback,
  137. trend=trend,
  138. btc_trend=btc_trend,
  139. rebalance=rebalance,
  140. top_n=top_n,
  141. min_momentum=min_momentum,
  142. btc_min_momentum=btc_min_momentum,
  143. vol_lookback=vol_lookbacks[0],
  144. max_vol=max_vol,
  145. )
  146. )
  147. return params
  148. def target_weights(closes: pd.DataFrame, params: Params) -> pd.DataFrame:
  149. momentum = closes / closes.shift(params.lookback) - 1.0
  150. trend = closes > closes.rolling(params.trend).mean()
  151. btc_trend = closes["BTC-USDT-SWAP"] > closes["BTC-USDT-SWAP"].rolling(params.btc_trend).mean()
  152. btc_momentum = closes["BTC-USDT-SWAP"] / closes["BTC-USDT-SWAP"].shift(params.lookback) - 1.0
  153. btc_vol = closes["BTC-USDT-SWAP"].pct_change().rolling(params.vol_lookback).std(ddof=1)
  154. risk_on = btc_trend & (btc_momentum >= params.btc_min_momentum) & (btc_vol <= params.max_vol)
  155. weights = pd.DataFrame(0.0, index=closes.index, columns=closes.columns)
  156. for index in range(max(params.lookback, params.trend, params.btc_trend, params.vol_lookback), len(closes), params.rebalance):
  157. if not bool(risk_on.iloc[index]):
  158. continue
  159. current_momentum = momentum.iloc[index]
  160. eligible = current_momentum[(current_momentum >= params.min_momentum) & trend.iloc[index]]
  161. if eligible.empty:
  162. continue
  163. if params.family == "dual_momentum":
  164. selected = eligible.sort_values(ascending=False).head(1).index
  165. else:
  166. selected = eligible.sort_values(ascending=False).head(params.top_n).index
  167. weights.loc[closes.index[index], selected] = 1.0 / len(selected)
  168. return weights.replace(0.0, pd.NA).ffill(limit=max(params.rebalance - 1, 1)).fillna(0.0)
  169. def trade_stats(weights: pd.DataFrame, closes: pd.DataFrame) -> dict[str, float | int]:
  170. returns = closes.pct_change().shift(-1)
  171. wins = 0
  172. losses = 0
  173. gross_profit = 0.0
  174. gross_loss = 0.0
  175. trades = 0
  176. for symbol in closes.columns:
  177. active = weights[symbol] > 0.0
  178. group = (active != active.shift(1)).cumsum()
  179. for _, mask in active.groupby(group):
  180. if not bool(mask.iloc[0]):
  181. continue
  182. trade_return = float((1.0 + returns.loc[mask.index, symbol].dropna() * LEVERAGE).prod() - 1.0)
  183. trade_return -= TAKER_FEE * LEVERAGE * 2.0
  184. trades += 1
  185. if trade_return > 0.0:
  186. wins += 1
  187. gross_profit += trade_return
  188. else:
  189. losses += 1
  190. gross_loss += abs(trade_return)
  191. return {
  192. "trades": trades,
  193. "win_rate": wins / trades if trades else 0.0,
  194. "profit_factor": gross_profit / gross_loss if gross_loss else 0.0,
  195. }
  196. def equity_curve(closes: pd.DataFrame, weights: pd.DataFrame) -> pd.Series:
  197. returns = closes.pct_change().fillna(0.0)
  198. executed = weights.shift(1).fillna(0.0)
  199. turnover = executed.diff().abs().sum(axis=1).fillna(executed.abs().sum(axis=1))
  200. net_returns = (executed * returns * LEVERAGE).sum(axis=1) - turnover * TAKER_FEE * LEVERAGE
  201. equity = INITIAL_EQUITY * (1.0 + net_returns).cumprod()
  202. equity.name = "equity"
  203. return equity
  204. def metrics(series: pd.Series) -> dict[str, float]:
  205. years = (series.index[-1] - series.index[0]).total_seconds() / 86_400 / 365
  206. total = float(series.iloc[-1] / series.iloc[0] - 1.0)
  207. annualized = (1.0 + total) ** (1.0 / years) - 1.0 if total > -1.0 and years > 0.0 else 0.0
  208. drawdown = float((series.cummax() - series).div(series.cummax()).max())
  209. return {
  210. "total_return": total,
  211. "annualized_return": annualized,
  212. "max_drawdown": drawdown,
  213. "calmar": annualized / drawdown if drawdown else 0.0,
  214. }
  215. def horizon_rows(name: str, series: pd.Series) -> list[dict[str, object]]:
  216. rows: list[dict[str, object]] = []
  217. end = series.index[-1]
  218. for label, offset in HORIZONS:
  219. horizon = series[series.index >= end - offset]
  220. if len(horizon) < 2:
  221. horizon = series
  222. rows.append(
  223. {
  224. "strategy": name,
  225. "horizon": label,
  226. "start": horizon.index[0].strftime("%Y-%m-%d"),
  227. "end": horizon.index[-1].strftime("%Y-%m-%d"),
  228. **metrics(horizon),
  229. }
  230. )
  231. return rows
  232. def monthly_rows(name: str, series: pd.Series) -> pd.DataFrame:
  233. monthly = series.resample("ME").last()
  234. frame = pd.DataFrame(
  235. {
  236. "strategy": name,
  237. "month": monthly.index.strftime("%Y-%m"),
  238. "start_equity": monthly.shift(1).fillna(series.iloc[0]).to_numpy(),
  239. "end_equity": monthly.to_numpy(),
  240. }
  241. )
  242. frame["return"] = frame["end_equity"] / frame["start_equity"] - 1.0
  243. return frame
  244. def markdown_table(frame: pd.DataFrame) -> str:
  245. rows = [list(frame.columns), ["---" for _ in frame.columns]]
  246. rows.extend(frame.astype(object).where(pd.notna(frame), "").values.tolist())
  247. return "\n".join("| " + " | ".join(format_cell(value) for value in row) + " |" for row in rows)
  248. def format_cell(value: object) -> str:
  249. if isinstance(value, float):
  250. return f"{value:.6g}"
  251. return str(value).replace("|", "\\|")
  252. def report_text(command: str, paths: list[Path], top: pd.DataFrame, horizons: pd.DataFrame, monthly: pd.DataFrame) -> str:
  253. best_names = set(top.head(3)["strategy"])
  254. recent = horizons[horizons["strategy"].isin(best_names)]
  255. best_monthly = monthly[monthly["strategy"].isin(best_names)]
  256. return "\n".join(
  257. [
  258. "# Rotation strategy expansion",
  259. "",
  260. f"Run command: `{command}`",
  261. "",
  262. "Output files:",
  263. *[f"- `{path}`" for path in paths],
  264. "",
  265. "Cost model: 0.04% one-way taker fee, charged on leveraged notional at each portfolio weight change.",
  266. "Universe: BTC-USDT-SWAP, ETH-USDT-SWAP, SOL-USDT-SWAP OKX perpetual swaps when the aligned history is long enough for that candidate; otherwise the row records the BTC/ETH universe used. BTC/ETH use local 15m cache aggregated upward; SOL uses OKX historical candles in memory.",
  267. "",
  268. "## Top candidates",
  269. "",
  270. markdown_table(
  271. top.head(10)[
  272. [
  273. "strategy",
  274. "family",
  275. "bar",
  276. "universe",
  277. "total_return",
  278. "annualized_return",
  279. "max_drawdown",
  280. "calmar",
  281. "trades",
  282. "win_rate",
  283. "profit_factor",
  284. "turnover_per_year",
  285. ]
  286. ]
  287. ),
  288. "",
  289. "## Recent horizons for top 3",
  290. "",
  291. markdown_table(recent),
  292. "",
  293. "## Monthly returns for top 3",
  294. "",
  295. markdown_table(best_monthly.tail(36)),
  296. "",
  297. ]
  298. )
  299. def main() -> int:
  300. parser = argparse.ArgumentParser()
  301. parser.add_argument("--years", type=float, default=8.0)
  302. parser.add_argument("--output-dir", type=Path, default=OUTPUT_DIR)
  303. parser.add_argument("--top", type=int, default=30)
  304. args = parser.parse_args()
  305. frames = load_symbol_bar_frames(args.years)
  306. totals: list[dict[str, object]] = []
  307. horizon_output: list[dict[str, object]] = []
  308. monthly_output: list[pd.DataFrame] = []
  309. params = build_params()
  310. for index, param in enumerate(params, start=1):
  311. closes = aligned_closes(frames, param)
  312. weights = target_weights(closes, param)
  313. equity = equity_curve(closes, weights)
  314. stat = trade_stats(weights, closes)
  315. years = (equity.index[-1] - equity.index[0]).total_seconds() / 86_400 / 365
  316. turnover_per_year = float(weights.shift(1).fillna(0.0).diff().abs().sum(axis=1).sum() / years)
  317. row = {
  318. "strategy": param.name,
  319. "family": param.family,
  320. "bar": param.bar,
  321. "universe": ",".join(closes.columns),
  322. "lookback": param.lookback,
  323. "trend": param.trend,
  324. "btc_trend": param.btc_trend,
  325. "rebalance": param.rebalance,
  326. "top_n": param.top_n,
  327. "min_momentum": param.min_momentum,
  328. "btc_min_momentum": param.btc_min_momentum,
  329. "vol_lookback": param.vol_lookback,
  330. "max_vol": param.max_vol,
  331. "first_candle": equity.index[0].strftime("%Y-%m-%d %H:%M"),
  332. "last_candle": equity.index[-1].strftime("%Y-%m-%d %H:%M"),
  333. "years": years,
  334. "turnover_per_year": turnover_per_year,
  335. **metrics(equity),
  336. **stat,
  337. }
  338. totals.append(row)
  339. for horizon in horizon_rows(param.name, equity):
  340. horizon_output.append(horizon)
  341. monthly_output.append(monthly_rows(param.name, equity))
  342. if index % 100 == 0:
  343. print(f"done {index}/{len(params)}")
  344. total = pd.DataFrame(totals)
  345. positive_recent = pd.DataFrame(horizon_output).pivot(index="strategy", columns="horizon", values="total_return")
  346. stable = total[
  347. total["strategy"].isin(
  348. positive_recent[
  349. (positive_recent["3y"] > 0.0)
  350. & (positive_recent["1y"] > 0.0)
  351. & (positive_recent["6m"] > -0.05)
  352. & (positive_recent["3m"] > -0.05)
  353. ].index
  354. )
  355. ]
  356. ranked = stable.sort_values(
  357. ["calmar", "max_drawdown", "annualized_return", "turnover_per_year"],
  358. ascending=[False, True, False, True],
  359. )
  360. if ranked.empty:
  361. ranked = total.sort_values(["calmar", "max_drawdown", "annualized_return"], ascending=[False, True, False])
  362. top = ranked.head(args.top)
  363. top_names = set(top["strategy"])
  364. horizons = pd.DataFrame(horizon_output)
  365. horizons = horizons[horizons["strategy"].isin(top_names)]
  366. monthly = pd.concat(monthly_output, ignore_index=True)
  367. monthly = monthly[monthly["strategy"].isin(top_names)]
  368. args.output_dir.mkdir(parents=True, exist_ok=True)
  369. total_path = args.output_dir / f"{PREFIX}-total.csv"
  370. top_path = args.output_dir / f"{PREFIX}-top.csv"
  371. horizon_path = args.output_dir / f"{PREFIX}-horizons.csv"
  372. monthly_path = args.output_dir / f"{PREFIX}-monthly.csv"
  373. report_path = args.output_dir / f"{PREFIX}-report.md"
  374. paths = [total_path, top_path, horizon_path, monthly_path, report_path]
  375. total.to_csv(total_path, index=False)
  376. top.to_csv(top_path, index=False)
  377. horizons.to_csv(horizon_path, index=False)
  378. monthly.to_csv(monthly_path, index=False)
  379. command = f"rtk .venv/bin/python scripts/search_expansion_rotation.py --years {args.years} --top {args.top}"
  380. report_path.write_text(report_text(command, paths, top, horizons, monthly), encoding="utf-8")
  381. print(top.head(10).to_string(index=False))
  382. return 0
  383. if __name__ == "__main__":
  384. raise SystemExit(main())