research_metrics.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from __future__ import annotations
  2. import pandas as pd
  3. from okx_codex_trader.sampled_report import SegmentResult
  4. DEFAULT_INITIAL_EQUITY = 10_000.0
  5. DEFAULT_PRIMARY_COST = "maker_taker"
  6. DEFAULT_COSTS = (
  7. ("maker_maker", 0.0012),
  8. ("maker_taker", 0.0021),
  9. ("taker_taker", 0.0030),
  10. )
  11. DEFAULT_HORIZONS = (
  12. ("3y", pd.DateOffset(years=3)),
  13. ("1y", pd.DateOffset(years=1)),
  14. ("6m", pd.DateOffset(months=6)),
  15. ("3m", pd.DateOffset(months=3)),
  16. ("30d", pd.DateOffset(days=30)),
  17. )
  18. def format_utc_ts(ts: int) -> str:
  19. return pd.to_datetime(ts, unit="ms", utc=True).strftime("%Y-%m-%d %H:%M")
  20. def cost_equity_frame(result: SegmentResult, cost: float, initial_equity: float = DEFAULT_INITIAL_EQUITY) -> pd.DataFrame:
  21. rows = [{"ts": pd.to_datetime(result.equity_curve[0]["ts"], unit="ms", utc=True), "equity": initial_equity}]
  22. equity = initial_equity
  23. for trade in result.trades:
  24. equity *= 1.0 + float(trade["return_pct"]) / 100.0 - cost * float(trade.get("cost_weight", 1.0))
  25. rows.append({"ts": pd.to_datetime(str(trade["exit_time"]), utc=True), "equity": equity})
  26. return pd.DataFrame(rows)
  27. def max_drawdown(values: list[float]) -> float:
  28. peak = values[0]
  29. dd = 0.0
  30. for value in values:
  31. peak = max(peak, value)
  32. dd = max(dd, (peak - value) / peak if peak else 0.0)
  33. return dd
  34. def equity_metrics(frame: pd.DataFrame, first_ts: int, last_ts: int) -> dict[str, float]:
  35. years = (last_ts - first_ts) / 86_400_000 / 365
  36. total_return = float(frame["equity"].iloc[-1] / frame["equity"].iloc[0] - 1.0)
  37. annualized = (1.0 + total_return) ** (1.0 / years) - 1.0 if total_return > -1.0 and years > 0.0 else 0.0
  38. dd = max_drawdown([float(value) for value in frame["equity"]])
  39. return {
  40. "net_total_return": total_return,
  41. "net_annualized_return": annualized,
  42. "net_max_drawdown": dd,
  43. "net_calmar": annualized / dd if dd else 0.0,
  44. }
  45. def horizon_rows(
  46. frame: pd.DataFrame,
  47. last_ts: int,
  48. horizons: tuple[tuple[str, pd.DateOffset], ...] = DEFAULT_HORIZONS,
  49. ) -> list[dict[str, object]]:
  50. rows: list[dict[str, object]] = []
  51. end_time = pd.to_datetime(last_ts, unit="ms", utc=True)
  52. for label, offset in horizons:
  53. cutoff = end_time - offset
  54. before = frame[frame["ts"] <= cutoff]
  55. if len(before):
  56. start_equity = float(before["equity"].iloc[-1])
  57. start_time = cutoff
  58. after = frame[frame["ts"] > cutoff]
  59. horizon_frame = pd.concat([pd.DataFrame([{"ts": cutoff, "equity": start_equity}]), after[["ts", "equity"]]], ignore_index=True)
  60. else:
  61. horizon_frame = frame[["ts", "equity"]].copy()
  62. start_time = pd.Timestamp(horizon_frame["ts"].iloc[0])
  63. rows.append(
  64. {
  65. "horizon": label,
  66. "horizon_start": start_time.strftime("%Y-%m-%d %H:%M"),
  67. "horizon_end": end_time.strftime("%Y-%m-%d %H:%M"),
  68. **equity_metrics(horizon_frame, int(start_time.timestamp() * 1000), last_ts),
  69. }
  70. )
  71. return rows
  72. def trade_stats(trades: list[dict[str, object]]) -> dict[str, float]:
  73. if not trades:
  74. return {"avg_return_pct": 0.0, "payoff_ratio": 0.0, "profit_factor": 0.0}
  75. returns = [float(trade["return_pct"]) for trade in trades]
  76. wins = [value for value in returns if value > 0.0]
  77. losses = [-value for value in returns if value < 0.0]
  78. return {
  79. "avg_return_pct": sum(returns) / len(returns),
  80. "payoff_ratio": (sum(wins) / len(wins)) / (sum(losses) / len(losses)) if wins and losses else 0.0,
  81. "profit_factor": sum(wins) / sum(losses) if losses else 0.0,
  82. }
  83. def worst_month(frame: pd.DataFrame) -> tuple[str, float]:
  84. monthly = frame.set_index("ts")["equity"].resample("ME").last().ffill().pct_change().dropna()
  85. if not len(monthly):
  86. return "", 0.0
  87. idx = monthly.idxmin()
  88. return idx.strftime("%Y-%m"), float(monthly.loc[idx])