# search_eth_relative_momentum.py (11 KB)
  1. from __future__ import annotations
  2. import argparse
  3. from dataclasses import dataclass
  4. from pathlib import Path
  5. import pandas as pd
  6. CACHE_DIR = Path("data/okx-candles")
  7. OUTPUT_DIR = Path("reports/eth-exploration")
  8. PREFIX = "eth-relative-momentum"
  9. INITIAL_EQUITY = 10_000.0
  10. TAKER_FEE = 0.0004
  11. HORIZONS = (
  12. ("full", None),
  13. ("3y", pd.DateOffset(years=3)),
  14. ("1y", pd.DateOffset(years=1)),
  15. ("6m", pd.DateOffset(months=6)),
  16. ("3m", pd.DateOffset(months=3)),
  17. )
  18. @dataclass(frozen=True)
  19. class Params:
  20. bar: str
  21. lookback: int
  22. trend: int
  23. rel_entry: float
  24. vol_quantile: float
  25. short_weight: float
  26. long_weight: float
  27. @property
  28. def name(self) -> str:
  29. return (
  30. f"eth_relmom-{self.bar}-lb{self.lookback}-tr{self.trend}"
  31. f"-re{self.rel_entry:.3f}-vq{self.vol_quantile:.1f}"
  32. f"-sw{self.short_weight:.2f}-lw{self.long_weight:.2f}"
  33. )
  34. def load_15m(symbol: str) -> pd.DataFrame:
  35. path = CACHE_DIR / symbol / "15m.csv"
  36. frame = pd.read_csv(path)
  37. frame["ts"] = pd.to_datetime(frame["ts"], unit="ms", utc=True)
  38. return frame.sort_values("ts").drop_duplicates("ts", keep="last").set_index("ts")
  39. def resample(frame: pd.DataFrame, bar: str) -> pd.DataFrame:
  40. rule = {"1H": "1h", "4H": "4h"}[bar]
  41. out = frame.resample(rule, label="left", closed="left").agg(
  42. open=("open", "first"),
  43. high=("high", "max"),
  44. low=("low", "min"),
  45. close=("close", "last"),
  46. volume=("volume", "sum"),
  47. )
  48. return out.dropna()
  49. def build_params() -> list[Params]:
  50. params: list[Params] = []
  51. for bar, lookbacks, trends in (
  52. ("1H", (24, 72, 168), (24 * 30, 24 * 60)),
  53. ("4H", (12, 42, 84), (6 * 30, 6 * 60)),
  54. ):
  55. for lookback in lookbacks:
  56. for trend in trends:
  57. for rel_entry in (0.015, 0.025, 0.04):
  58. for vol_quantile in (0.4, 0.7):
  59. for short_weight, long_weight in ((1.0, 0.0), (1.0, 0.25), (0.75, 0.25)):
  60. params.append(Params(bar, lookback, trend, rel_entry, vol_quantile, short_weight, long_weight))
  61. return params
  62. def target_position(closes: pd.DataFrame, params: Params) -> pd.Series:
  63. eth = closes["ETH-USDT-SWAP"]
  64. btc = closes["BTC-USDT-SWAP"]
  65. eth_momentum = eth / eth.shift(params.lookback) - 1.0
  66. btc_momentum = btc / btc.shift(params.lookback) - 1.0
  67. relative = eth_momentum - btc_momentum
  68. trend = eth.ewm(span=params.trend, adjust=False).mean()
  69. btc_trend = btc.ewm(span=params.trend, adjust=False).mean()
  70. eth_vol = eth.pct_change().rolling(params.lookback).std(ddof=1)
  71. vol_gate = eth_vol >= eth_vol.rolling(params.trend).quantile(params.vol_quantile)
  72. position = pd.Series(0.0, index=closes.index)
  73. short_signal = (relative <= -params.rel_entry) & (eth < trend) & vol_gate
  74. long_signal = (relative >= params.rel_entry) & (eth > trend) & (btc > btc_trend) & vol_gate
  75. position.loc[short_signal] = -params.short_weight
  76. position.loc[long_signal] = params.long_weight
  77. return position.fillna(0.0)
  78. def equity_curve(closes: pd.DataFrame, position: pd.Series) -> pd.Series:
  79. eth_returns = closes["ETH-USDT-SWAP"].pct_change().fillna(0.0)
  80. executed = position.shift(1).fillna(0.0)
  81. turnover = executed.diff().abs().fillna(executed.abs())
  82. net_returns = executed * eth_returns - turnover * TAKER_FEE
  83. equity = INITIAL_EQUITY * (1.0 + net_returns).cumprod()
  84. equity.name = "equity"
  85. return equity
  86. def trade_returns(closes: pd.DataFrame, position: pd.Series) -> list[dict[str, object]]:
  87. eth_returns = closes["ETH-USDT-SWAP"].pct_change().fillna(0.0)
  88. executed = position.shift(1).fillna(0.0)
  89. turnover = executed.diff().abs().fillna(executed.abs())
  90. net_returns = executed * eth_returns - turnover * TAKER_FEE
  91. active = executed != 0.0
  92. groups = (active.ne(active.shift(1)) | executed.ne(executed.shift(1))).cumsum()
  93. trades: list[dict[str, object]] = []
  94. for _, mask in active.groupby(groups):
  95. if not bool(mask.iloc[0]):
  96. continue
  97. index = mask.index
  98. returns = net_returns.loc[index]
  99. value = float((1.0 + returns).prod() - 1.0)
  100. side = "short" if float(executed.loc[index[0]]) < 0.0 else "long"
  101. trades.append({"side": side, "entry_time": index[0], "exit_time": index[-1], "return": value})
  102. return trades
  103. def series_metrics(series: pd.Series) -> dict[str, float]:
  104. years = (series.index[-1] - series.index[0]).total_seconds() / 86_400 / 365
  105. total = float(series.iloc[-1] / series.iloc[0] - 1.0)
  106. annualized = (1.0 + total) ** (1.0 / years) - 1.0 if total > -1.0 and years > 0.0 else 0.0
  107. drawdown = float((series.cummax() - series).div(series.cummax()).max())
  108. return {"total_return": total, "annualized_return": annualized, "max_drawdown": drawdown}
  109. def trade_metrics(trades: list[dict[str, object]], start: pd.Timestamp, end: pd.Timestamp) -> dict[str, float | int]:
  110. scoped = [float(trade["return"]) for trade in trades if start <= pd.Timestamp(trade["exit_time"]) <= end]
  111. wins = [value for value in scoped if value > 0.0]
  112. losses = [value for value in scoped if value < 0.0]
  113. gross_profit = sum(wins)
  114. gross_loss = abs(sum(losses))
  115. return {
  116. "win_rate": len(wins) / len(scoped) if scoped else 0.0,
  117. "profit_factor": gross_profit / gross_loss if gross_loss else 0.0,
  118. "trades": len(scoped),
  119. }
  120. def horizon_rows(name: str, params: Params, series: pd.Series, trades: list[dict[str, object]]) -> list[dict[str, object]]:
  121. rows: list[dict[str, object]] = []
  122. end = series.index[-1]
  123. for horizon, offset in HORIZONS:
  124. scoped = series if offset is None else series[series.index >= end - offset]
  125. if len(scoped) < 2:
  126. scoped = series
  127. start = scoped.index[0]
  128. rows.append(
  129. {
  130. "name": name,
  131. "horizon": horizon,
  132. "start": start.strftime("%Y-%m-%d"),
  133. "end": scoped.index[-1].strftime("%Y-%m-%d"),
  134. "bar": params.bar,
  135. "lookback": params.lookback,
  136. "trend": params.trend,
  137. "rel_entry": params.rel_entry,
  138. "vol_quantile": params.vol_quantile,
  139. "short_weight": params.short_weight,
  140. "long_weight": params.long_weight,
  141. **series_metrics(scoped),
  142. **trade_metrics(trades, start, scoped.index[-1]),
  143. }
  144. )
  145. return rows
  146. def markdown_table(frame: pd.DataFrame) -> str:
  147. values = [list(frame.columns), ["---" for _ in frame.columns]]
  148. values.extend(frame.astype(object).where(pd.notna(frame), "").values.tolist())
  149. lines = []
  150. for row in values:
  151. cells = []
  152. for value in row:
  153. cells.append(f"{value:.6g}" if isinstance(value, float) else str(value).replace("|", "\\|"))
  154. lines.append("| " + " | ".join(cells) + " |")
  155. return "\n".join(lines)
  156. def report_text(command: str, paths: list[Path], selected: pd.DataFrame, horizons: pd.DataFrame, qualified_count: int) -> str:
  157. names = set(selected["name"])
  158. selected_horizons = horizons[horizons["name"].isin(names)]
  159. conclusion = (
  160. "Worth continuing: at least one candidate is positive across full/3y/1y/6m/3m with controlled drawdown."
  161. if qualified_count
  162. else "Not worth continuing as a standalone direction: no candidate passed the positive full/3y/1y/6m/3m filter."
  163. )
  164. return "\n".join(
  165. [
  166. "# ETH Relative Momentum Exploration",
  167. "",
  168. f"Run command: `{command}`",
  169. "",
  170. "Output files:",
  171. *[f"- `{path}`" for path in paths],
  172. "",
  173. "Scope: offline ETH-USDT-SWAP strategy using cached OKX 15m candles resampled to 1H/4H, with BTC-USDT-SWAP only as a relative-momentum filter. No live code or exchange API path was used.",
  174. "Direction: bidirectional but short-biased; tested short-only and small-long variants.",
  175. "Cost: 0.04% taker fee on absolute notional turnover.",
  176. "",
  177. f"Conclusion: {conclusion}",
  178. "",
  179. "## Selected Candidates",
  180. "",
  181. markdown_table(selected),
  182. "",
  183. "## Required Horizons",
  184. "",
  185. markdown_table(selected_horizons),
  186. "",
  187. ]
  188. )
  189. def score(row: dict[str, object]) -> float:
  190. return (
  191. float(row["annualized_return"])
  192. - float(row["max_drawdown"])
  193. + 0.7 * float(row["return_1y"])
  194. + 0.4 * float(row["return_6m"])
  195. + 0.2 * float(row["return_3m"])
  196. )
  197. def main() -> int:
  198. parser = argparse.ArgumentParser()
  199. parser.add_argument("--output-dir", type=Path, default=OUTPUT_DIR)
  200. parser.add_argument("--top", type=int, default=25)
  201. args = parser.parse_args()
  202. args.output_dir.mkdir(parents=True, exist_ok=True)
  203. source = {symbol: load_15m(symbol) for symbol in ("ETH-USDT-SWAP", "BTC-USDT-SWAP")}
  204. closes_by_bar = {
  205. bar: pd.DataFrame(
  206. {
  207. symbol: resample(frame, bar)["close"]
  208. for symbol, frame in source.items()
  209. }
  210. ).dropna()
  211. for bar in ("1H", "4H")
  212. }
  213. totals: list[dict[str, object]] = []
  214. all_horizons: list[dict[str, object]] = []
  215. for params in build_params():
  216. closes = closes_by_bar[params.bar]
  217. position = target_position(closes, params)
  218. equity = equity_curve(closes, position)
  219. trades = trade_returns(closes, position)
  220. horizons = horizon_rows(params.name, params, equity, trades)
  221. by_horizon = {row["horizon"]: row for row in horizons}
  222. full = by_horizon["full"]
  223. row = {
  224. **full,
  225. "return_3y": float(by_horizon["3y"]["total_return"]),
  226. "return_1y": float(by_horizon["1y"]["total_return"]),
  227. "return_6m": float(by_horizon["6m"]["total_return"]),
  228. "return_3m": float(by_horizon["3m"]["total_return"]),
  229. }
  230. row["score"] = score(row)
  231. totals.append(row)
  232. all_horizons.extend(horizons)
  233. total = pd.DataFrame(totals).sort_values(["score", "annualized_return"], ascending=[False, False])
  234. qualified = total[
  235. (total["total_return"] > 0.0)
  236. & (total["return_3y"] > 0.0)
  237. & (total["return_1y"] > 0.0)
  238. & (total["return_6m"] > 0.0)
  239. & (total["return_3m"] > 0.0)
  240. & (total["max_drawdown"] <= 0.35)
  241. & (total["trades"] >= 20)
  242. & (total["profit_factor"] > 1.0)
  243. ].head(args.top)
  244. selected = qualified if len(qualified) else total.head(args.top)
  245. horizons = pd.DataFrame(all_horizons)
  246. selected_horizons = horizons[horizons["name"].isin(set(selected["name"]))]
  247. total_path = args.output_dir / f"{PREFIX}-totals.csv"
  248. selected_path = args.output_dir / f"{PREFIX}-selected.csv"
  249. horizon_path = args.output_dir / f"{PREFIX}-horizons.csv"
  250. report_path = args.output_dir / f"{PREFIX}-report.md"
  251. total.head(200).to_csv(total_path, index=False)
  252. selected.to_csv(selected_path, index=False)
  253. selected_horizons.to_csv(horizon_path, index=False)
  254. report_path.write_text(
  255. report_text(
  256. "rtk .venv/bin/python scripts/search_eth_relative_momentum.py",
  257. [total_path, selected_path, horizon_path, report_path],
  258. selected,
  259. selected_horizons,
  260. len(qualified),
  261. ),
  262. encoding="utf-8",
  263. )
  264. print(report_path)
  265. print(selected.head(10).to_string(index=False))
  266. return 0
  267. if __name__ == "__main__":
  268. raise SystemExit(main())