test_rsi2_report.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import pytest
  2. from okx_codex_trader import rsi2_report
  3. from okx_codex_trader.models import Candle
  4. from okx_codex_trader.rsi2_report import RSI2Config, generate_rsi2_sampled_report, run_rsi2_segment
  5. from okx_codex_trader.sampled_report import SegmentResult
  6. def make_candle(index: int, open_price: float, close: float) -> Candle:
  7. high = max(open_price, close) + 1.0
  8. low = min(open_price, close) - 1.0
  9. return Candle(
  10. symbol="BTC-USDT-SWAP",
  11. ts=index * 60_000,
  12. open=open_price,
  13. high=high,
  14. low=low,
  15. close=close,
  16. volume=1_000.0 + index,
  17. )
  18. def build_linear_candles(count: int) -> list[Candle]:
  19. candles: list[Candle] = []
  20. for index in range(count):
  21. price = 100.0 + index
  22. candles.append(make_candle(index, price, price))
  23. return candles
  24. def build_long_trade_fixture() -> list[Candle]:
  25. return [
  26. make_candle(0, 90.0, 90.0),
  27. make_candle(1, 100.0, 100.0),
  28. make_candle(2, 150.0, 150.0),
  29. make_candle(3, 149.0, 149.0),
  30. make_candle(4, 148.0, 148.0),
  31. make_candle(5, 149.0, 149.0),
  32. make_candle(6, 150.0, 150.0),
  33. ]
  34. def build_short_trade_fixture() -> list[Candle]:
  35. return [
  36. make_candle(0, 160.0, 160.0),
  37. make_candle(1, 150.0, 150.0),
  38. make_candle(2, 100.0, 100.0),
  39. make_candle(3, 101.0, 101.0),
  40. make_candle(4, 102.0, 102.0),
  41. make_candle(5, 101.0, 101.0),
  42. make_candle(6, 100.0, 100.0),
  43. ]
  44. def build_exit_priority_fixture() -> list[Candle]:
  45. return [
  46. make_candle(0, 90.0, 90.0),
  47. make_candle(1, 100.0, 100.0),
  48. make_candle(2, 150.0, 150.0),
  49. make_candle(3, 149.1, 149.1),
  50. make_candle(4, 149.0, 149.0),
  51. make_candle(5, 149.1, 149.1),
  52. make_candle(6, 149.2, 149.2),
  53. ]
  54. def build_final_bar_signal_fixture() -> list[Candle]:
  55. return [
  56. make_candle(0, 90.0, 90.0),
  57. make_candle(1, 100.0, 100.0),
  58. make_candle(2, 150.0, 150.0),
  59. make_candle(3, 149.0, 149.0),
  60. make_candle(4, 148.0, 148.0),
  61. ]
  62. def build_open_tail_fixture() -> list[Candle]:
  63. return [
  64. make_candle(0, 90.0, 90.0),
  65. make_candle(1, 100.0, 100.0),
  66. make_candle(2, 150.0, 150.0),
  67. make_candle(3, 149.0, 149.0),
  68. make_candle(4, 148.0, 148.0),
  69. make_candle(5, 149.0, 151.0),
  70. ]
  71. def build_depleted_equity_fixture() -> list[Candle]:
  72. return [
  73. make_candle(0, 90.0, 90.0),
  74. make_candle(1, 100.0, 100.0),
  75. make_candle(2, 150.0, 150.0),
  76. make_candle(3, 149.0, 149.0),
  77. make_candle(4, 148.0, 148.0),
  78. make_candle(5, 149.0, 151.0),
  79. make_candle(6, 0.0, 100.0),
  80. make_candle(7, 150.0, 150.0),
  81. make_candle(8, 149.0, 149.0),
  82. make_candle(9, 150.0, 150.0),
  83. ]
  84. def test_compute_rsi_uses_wilder_smoothing():
  85. closes = rsi2_report.pd.Series([100.0, 102.07, 103.62, 103.14, 101.69], dtype=float)
  86. rsi = rsi2_report._compute_rsi(closes, 2)
  87. assert rsi[4] == pytest.approx(34.8747591522)
  88. def test_run_rsi2_segment_produces_long_trade():
  89. result = run_rsi2_segment(
  90. candles=build_long_trade_fixture(),
  91. leverage=2,
  92. warmup_bars=4,
  93. config=RSI2Config(trend_sma=4, rsi_length=2, rsi_long_threshold=95.0),
  94. )
  95. assert isinstance(result, SegmentResult)
  96. assert result.trade_count == 1
  97. assert result.trades[0]["side"] == "Long"
  98. assert result.trades[0]["entry_price"] == pytest.approx(149.0)
  99. assert result.trades[0]["exit_price"] == pytest.approx(150.0)
  100. assert result.open_position is None
  101. def test_run_rsi2_segment_produces_short_trade():
  102. result = run_rsi2_segment(
  103. candles=build_short_trade_fixture(),
  104. leverage=2,
  105. warmup_bars=4,
  106. config=RSI2Config(trend_sma=4, rsi_length=2, rsi_short_threshold=5.0),
  107. )
  108. assert isinstance(result, SegmentResult)
  109. assert result.trade_count == 1
  110. assert result.trades[0]["side"] == "Short"
  111. assert result.trades[0]["entry_price"] == pytest.approx(101.0)
  112. assert result.trades[0]["exit_price"] == pytest.approx(100.0)
  113. assert result.open_position is None
  114. def test_run_rsi2_segment_exit_priority_is_correct():
  115. result = run_rsi2_segment(
  116. candles=build_exit_priority_fixture(),
  117. leverage=2,
  118. warmup_bars=4,
  119. config=RSI2Config(trend_sma=4, rsi_length=2, rsi_long_threshold=98.0, rsi_short_threshold=50.0, exit_rsi=96.5),
  120. )
  121. assert result.trade_count == 1
  122. assert len(result.entries) == 1
  123. assert result.trades[0]["side"] == "Long"
  124. assert result.open_position is None
  125. def test_run_rsi2_segment_does_not_generate_entry_from_final_bar():
  126. result = run_rsi2_segment(
  127. candles=build_final_bar_signal_fixture(),
  128. leverage=2,
  129. warmup_bars=4,
  130. config=RSI2Config(trend_sma=4, rsi_length=2, rsi_long_threshold=95.0),
  131. )
  132. assert result.trade_count == 0
  133. assert result.entries == []
  134. assert result.open_position is None
  135. def test_run_rsi2_segment_marks_open_position_to_market():
  136. result = run_rsi2_segment(
  137. candles=build_open_tail_fixture(),
  138. leverage=2,
  139. warmup_bars=4,
  140. config=RSI2Config(trend_sma=4, rsi_length=2, rsi_long_threshold=95.0),
  141. )
  142. assert result.trade_count == 0
  143. assert result.trades == []
  144. assert result.total_return == pytest.approx((10_268.456375838927 - 10_000.0) / 10_000.0)
  145. assert result.open_position is not None
  146. assert result.open_position["side"] == "long"
  147. def test_run_rsi2_segment_stops_after_equity_is_depleted():
  148. result = run_rsi2_segment(
  149. candles=build_depleted_equity_fixture(),
  150. leverage=2,
  151. warmup_bars=4,
  152. config=RSI2Config(trend_sma=4, rsi_length=2, rsi_long_threshold=95.0),
  153. )
  154. assert result.trade_count == 1
  155. assert result.open_position is None
  156. assert len(result.entries) == 1
  157. assert result.total_return <= -1.0
  158. def test_generate_rsi2_sampled_report_uses_shared_shell_defaults(monkeypatch, tmp_path):
  159. candles = build_linear_candles(5_000)
  160. output_file = tmp_path / "rsi2.html"
  161. recorded: dict[str, object] = {}
  162. sentinel = {
  163. "report_file": str(output_file),
  164. "segment_count": 2,
  165. "window_size": 300,
  166. "aggregate_trade_count": 4,
  167. "average_return": 0.12,
  168. }
  169. def fake_generate_sampled_report(**kwargs):
  170. recorded.update(kwargs)
  171. return sentinel
  172. monkeypatch.setattr(rsi2_report, "generate_sampled_report", fake_generate_sampled_report)
  173. result = generate_rsi2_sampled_report(
  174. candles=candles,
  175. leverage=2,
  176. output_file=output_file,
  177. symbol="BTC-USDT-SWAP",
  178. bar="3m",
  179. segments=2,
  180. window_size=300,
  181. )
  182. assert result == sentinel
  183. assert recorded["report_title"] == "RSI2 Sampled Report"
  184. assert recorded["strategy_label"] == "RSI2"
  185. assert recorded["strategy_params"] == {
  186. "trend_sma": 50,
  187. "rsi_length": 2,
  188. "rsi_long_threshold": 10.0,
  189. "rsi_short_threshold": 90.0,
  190. "exit_rsi": 50.0,
  191. }
  192. assert recorded["warmup_bars"] == 50
  193. assert callable(recorded["run_segment"])