backtest.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from okx_codex_trader.models import BacktestResult, BacktestTrade, Candle
  2. from okx_codex_trader.strategy import simple_moving_average
  3. def run_backtest(candles: list[Candle], leverage: int) -> BacktestResult:
  4. if leverage is True or leverage is False or not isinstance(leverage, int) or not 1 <= leverage <= 3:
  5. raise ValueError("leverage is invalid")
  6. fast = simple_moving_average(candles, 10)
  7. slow = simple_moving_average(candles, 20)
  8. initial_equity = 10_000.0
  9. equity = initial_equity
  10. trades: list[BacktestTrade] = []
  11. wins = 0
  12. peak_equity = initial_equity
  13. max_drawdown = 0.0
  14. position: dict[str, float | str] | None = None
  15. for index in range(len(candles) - 1):
  16. if fast[index] is None or slow[index] is None:
  17. continue
  18. signal: str | None = None
  19. if index == 19:
  20. if fast[index] > slow[index]:
  21. signal = "long"
  22. elif fast[index] < slow[index]:
  23. signal = "short"
  24. elif fast[index - 1] is not None and slow[index - 1] is not None:
  25. if fast[index - 1] <= slow[index - 1] and fast[index] > slow[index]:
  26. signal = "long"
  27. elif fast[index - 1] >= slow[index - 1] and fast[index] < slow[index]:
  28. signal = "short"
  29. if signal is None:
  30. continue
  31. execution_price = candles[index + 1].open
  32. if position is not None and position["direction"] != signal:
  33. entry_price = float(position["entry_price"])
  34. margin_used = float(position["margin_used"])
  35. if position["direction"] == "long":
  36. price_return = (execution_price - entry_price) / entry_price
  37. else:
  38. price_return = (entry_price - execution_price) / entry_price
  39. ending_equity = margin_used + (margin_used * leverage * price_return)
  40. trades.append(
  41. BacktestTrade(
  42. direction=str(position["direction"]),
  43. entry_price=entry_price,
  44. exit_price=execution_price,
  45. margin_used=margin_used,
  46. ending_equity=ending_equity,
  47. )
  48. )
  49. equity = ending_equity
  50. if ending_equity > float(position["margin_used"]):
  51. wins += 1
  52. if equity > peak_equity:
  53. peak_equity = equity
  54. drawdown = (peak_equity - equity) / peak_equity
  55. if drawdown > max_drawdown:
  56. max_drawdown = drawdown
  57. position = None
  58. if position is None:
  59. position = {
  60. "direction": signal,
  61. "entry_price": execution_price,
  62. "margin_used": equity,
  63. }
  64. trade_count = len(trades)
  65. win_rate = wins / trade_count if trade_count else 0.0
  66. return BacktestResult(
  67. initial_equity=initial_equity,
  68. ending_equity=equity,
  69. total_return=(equity - initial_equity) / initial_equity,
  70. max_drawdown=max_drawdown,
  71. win_rate=win_rate,
  72. trade_count=trade_count,
  73. trades=trades,
  74. )