# backtest.py

  1. from okx_codex_trader.models import BacktestResult, BacktestTrade, Candle
  2. from okx_codex_trader.strategy import simple_moving_average
  3. def run_backtest(candles: list[Candle], leverage: int) -> BacktestResult:
  4. if leverage is True or leverage is False or not isinstance(leverage, int) or not 1 <= leverage <= 3:
  5. raise ValueError("leverage is invalid")
  6. fast = simple_moving_average(candles, 10)
  7. slow = simple_moving_average(candles, 20)
  8. initial_equity = 10_000.0
  9. equity = initial_equity
  10. trades: list[BacktestTrade] = []
  11. wins = 0
  12. peak_equity = initial_equity
  13. max_drawdown = 0.0
  14. position: dict[str, float | str] | None = None
  15. for index in range(1, len(candles) - 1):
  16. if position is not None:
  17. entry_price = float(position["entry_price"])
  18. margin_used = float(position["margin_used"])
  19. if position["direction"] == "long":
  20. price_return = (candles[index].close - entry_price) / entry_price
  21. else:
  22. price_return = (entry_price - candles[index].close) / entry_price
  23. marked_equity = margin_used + (margin_used * leverage * price_return)
  24. if marked_equity > peak_equity:
  25. peak_equity = marked_equity
  26. drawdown = (peak_equity - marked_equity) / peak_equity
  27. if drawdown > max_drawdown:
  28. max_drawdown = drawdown
  29. if fast[index] is None or slow[index] is None:
  30. continue
  31. signal: str | None = None
  32. if fast[index - 1] is not None and slow[index - 1] is not None:
  33. if fast[index - 1] <= slow[index - 1] and fast[index] > slow[index]:
  34. signal = "long"
  35. elif fast[index - 1] >= slow[index - 1] and fast[index] < slow[index]:
  36. signal = "short"
  37. if signal is None:
  38. continue
  39. execution_price = candles[index + 1].open
  40. if position is not None and position["direction"] != signal:
  41. entry_price = float(position["entry_price"])
  42. margin_used = float(position["margin_used"])
  43. if position["direction"] == "long":
  44. price_return = (execution_price - entry_price) / entry_price
  45. else:
  46. price_return = (entry_price - execution_price) / entry_price
  47. ending_equity = margin_used + (margin_used * leverage * price_return)
  48. trades.append(
  49. BacktestTrade(
  50. direction=str(position["direction"]),
  51. entry_price=entry_price,
  52. exit_price=execution_price,
  53. margin_used=margin_used,
  54. ending_equity=ending_equity,
  55. )
  56. )
  57. equity = ending_equity
  58. if ending_equity > float(position["margin_used"]):
  59. wins += 1
  60. if equity > peak_equity:
  61. peak_equity = equity
  62. drawdown = (peak_equity - equity) / peak_equity
  63. if drawdown > max_drawdown:
  64. max_drawdown = drawdown
  65. position = None
  66. if position is None:
  67. position = {
  68. "direction": signal,
  69. "entry_price": execution_price,
  70. "margin_used": equity,
  71. }
  72. trade_count = len(trades)
  73. win_rate = wins / trade_count if trade_count else 0.0
  74. ending_equity = equity
  75. if position is not None:
  76. entry_price = float(position["entry_price"])
  77. margin_used = float(position["margin_used"])
  78. if position["direction"] == "long":
  79. price_return = (candles[-1].close - entry_price) / entry_price
  80. else:
  81. price_return = (entry_price - candles[-1].close) / entry_price
  82. ending_equity = margin_used + (margin_used * leverage * price_return)
  83. if ending_equity > peak_equity:
  84. peak_equity = ending_equity
  85. drawdown = (peak_equity - ending_equity) / peak_equity
  86. if drawdown > max_drawdown:
  87. max_drawdown = drawdown
  88. return BacktestResult(
  89. initial_equity=initial_equity,
  90. ending_equity=ending_equity,
  91. total_return=(ending_equity - initial_equity) / initial_equity,
  92. max_drawdown=max_drawdown,
  93. win_rate=win_rate,
  94. trade_count=trade_count,
  95. trades=trades,
  96. )