| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- import pytest
- from okx_codex_trader import ema_pullback_report
- from okx_codex_trader.ema_pullback_report import (
- EMAPullbackConfig,
- generate_ema_pullback_sampled_report,
- run_ema_pullback_segment,
- )
- from okx_codex_trader.models import Candle
- from okx_codex_trader.sampled_report import SegmentResult
def make_candle(index: int, open_price: float, high: float, low: float, close: float) -> Candle:
    """Build a ``BTC-USDT-SWAP`` candle for bar number ``index``.

    The timestamp (``index * 60_000``) and the volume (``1_000.0 + index``)
    are both derived from the bar index; the four prices are passed through
    unchanged.
    """
    timestamp = index * 60_000
    bar_volume = 1_000.0 + index
    return Candle(
        symbol="BTC-USDT-SWAP",
        ts=timestamp,
        open=open_price,
        high=high,
        low=low,
        close=close,
        volume=bar_volume,
    )
def build_linear_candles(count: int) -> list[Candle]:
    """Return ``count`` candles whose price climbs by 1.0 per bar from 100.0.

    Every candle opens and closes at the same level, with its high one unit
    above and its low one unit below that level.
    """
    levels = (100.0 + step for step in range(count))
    return [
        make_candle(step, level, level + 1.0, level - 1.0, level)
        for step, level in enumerate(levels)
    ]
def build_long_trade_fixture() -> list[Candle]:
    """Return the eight-bar candle series used by the long-trade test."""
    # One (open, high, low, close) row per bar; the bar index is the row position.
    ohlc_rows = [
        (100.0, 101.0, 99.0, 100.0),
        (102.0, 103.0, 101.0, 102.0),
        (104.0, 105.0, 103.0, 104.0),
        (103.0, 104.0, 102.0, 103.0),
        (104.0, 106.0, 103.0, 105.0),
        (104.0, 105.0, 103.2, 104.5),
        (104.2, 104.6, 102.8, 103.0),
        (104.5, 105.0, 104.0, 104.5),
    ]
    return [make_candle(bar, *ohlc) for bar, ohlc in enumerate(ohlc_rows)]
def build_short_trade_fixture() -> list[Candle]:
    """Return the eight-bar candle series used by the short-trade test."""
    # One (open, high, low, close) row per bar; the bar index is the row position.
    ohlc_rows = [
        (100.0, 101.0, 99.0, 100.0),
        (98.0, 99.0, 97.0, 98.0),
        (96.0, 97.0, 95.0, 96.0),
        (97.0, 98.0, 96.0, 97.0),
        (96.0, 97.0, 94.5, 95.0),
        (96.5, 96.6, 95.0, 95.5),
        (95.8, 97.4, 95.6, 97.0),
        (96.0, 96.2, 95.5, 96.0),
    ]
    return [make_candle(bar, *ohlc) for bar, ohlc in enumerate(ohlc_rows)]
def build_stop_priority_fixture() -> list[Candle]:
    """Return the seven-bar candle series used by the stop-priority test."""
    # One (open, high, low, close) row per bar; the bar index is the row position.
    ohlc_rows = [
        (100.0, 101.0, 99.0, 100.0),
        (102.0, 103.0, 101.0, 102.0),
        (104.0, 105.0, 103.0, 104.0),
        (103.0, 104.0, 102.0, 103.0),
        (104.0, 106.0, 103.0, 105.0),
        (104.0, 104.3, 101.0, 102.0),
        (105.0, 106.0, 104.5, 105.0),
    ]
    return [make_candle(bar, *ohlc) for bar, ohlc in enumerate(ohlc_rows)]
def build_gap_through_stop_fixture() -> list[Candle]:
    """Return the six-bar candle series used by the gap-through-stop test.

    The last bar opens at 102.0, two units below the previous bar's open.
    """
    # One (open, high, low, close) row per bar; the bar index is the row position.
    ohlc_rows = [
        (100.0, 101.0, 99.0, 100.0),
        (102.0, 103.0, 101.0, 102.0),
        (104.0, 105.0, 103.0, 104.0),
        (103.0, 104.0, 102.0, 103.0),
        (104.0, 106.0, 103.0, 105.0),
        (102.0, 103.0, 101.0, 102.5),
    ]
    return [make_candle(bar, *ohlc) for bar, ohlc in enumerate(ohlc_rows)]
def build_final_bar_signal_fixture() -> list[Candle]:
    """Return the five-bar candle series used by the final-bar-signal test."""
    # One (open, high, low, close) row per bar; the bar index is the row position.
    ohlc_rows = [
        (100.0, 101.0, 99.0, 100.0),
        (102.0, 103.0, 101.0, 102.0),
        (104.0, 105.0, 103.0, 104.0),
        (103.0, 104.0, 102.0, 103.0),
        (104.0, 106.0, 103.0, 105.0),
    ]
    return [make_candle(bar, *ohlc) for bar, ohlc in enumerate(ohlc_rows)]
def build_open_tail_fixture() -> list[Candle]:
    """Return the six-bar candle series used by the open-position test."""
    # One (open, high, low, close) row per bar; the bar index is the row position.
    ohlc_rows = [
        (100.0, 101.0, 99.0, 100.0),
        (102.0, 103.0, 101.0, 102.0),
        (104.0, 105.0, 103.0, 104.0),
        (103.0, 104.0, 102.0, 103.0),
        (104.0, 106.0, 103.0, 105.0),
        (104.0, 106.5, 103.5, 106.0),
    ]
    return [make_candle(bar, *ohlc) for bar, ohlc in enumerate(ohlc_rows)]
def build_depleted_equity_fixture() -> list[Candle]:
    """Return the eight-bar candle series used by the depleted-equity test."""
    # One (open, high, low, close) row per bar; the bar index is the row position.
    ohlc_rows = [
        (100.0, 101.0, 99.0, 100.0),
        (102.0, 103.0, 101.0, 102.0),
        (104.0, 105.0, 103.0, 104.0),
        (103.0, 104.0, 102.0, 103.0),
        (104.0, 106.0, 103.0, 105.0),
        (104.0, 104.3, 101.0, 102.0),
        (106.0, 108.0, 105.0, 107.0),
        (106.5, 107.0, 106.0, 106.5),
    ]
    return [make_candle(bar, *ohlc) for bar, ohlc in enumerate(ohlc_rows)]
def test_run_ema_pullback_segment_produces_long_trade():
    """The long fixture yields one closed Long trade, entered at 104.0 and exited at 104.5."""
    bars = build_long_trade_fixture()
    settings = EMAPullbackConfig(fast_ema=2, slow_ema=4, stop_buffer_pct=0.005)

    outcome = run_ema_pullback_segment(candles=bars, leverage=2, warmup_bars=4, config=settings)

    assert isinstance(outcome, SegmentResult)
    assert outcome.trade_count == 1
    # Only look at the trade once the count has been verified.
    trade = outcome.trades[0]
    assert trade["side"] == "Long"
    assert trade["entry_price"] == pytest.approx(104.0)
    assert trade["exit_price"] == pytest.approx(104.5)
    assert outcome.open_position is None
def test_run_ema_pullback_segment_produces_short_trade():
    """The short fixture yields one closed Short trade, entered at 96.5 and exited at 96.0."""
    bars = build_short_trade_fixture()
    settings = EMAPullbackConfig(fast_ema=2, slow_ema=4, stop_buffer_pct=0.005)

    outcome = run_ema_pullback_segment(candles=bars, leverage=2, warmup_bars=4, config=settings)

    assert isinstance(outcome, SegmentResult)
    assert outcome.trade_count == 1
    # Only look at the trade once the count has been verified.
    trade = outcome.trades[0]
    assert trade["side"] == "Short"
    assert trade["entry_price"] == pytest.approx(96.5)
    assert trade["exit_price"] == pytest.approx(96.0)
    assert outcome.open_position is None
def test_run_ema_pullback_segment_stop_priority_is_correct():
    """The stop-priority fixture yields a single entry that is closed at 102.485."""
    bars = build_stop_priority_fixture()
    settings = EMAPullbackConfig(fast_ema=2, slow_ema=4, stop_buffer_pct=0.005)

    outcome = run_ema_pullback_segment(candles=bars, leverage=2, warmup_bars=4, config=settings)

    assert outcome.trade_count == 1
    assert len(outcome.entries) == 1
    closed_trade = outcome.trades[0]
    assert closed_trade["exit_price"] == pytest.approx(102.485)
    assert outcome.open_position is None
def test_run_ema_pullback_segment_exits_gap_through_stop_at_open():
    """The gap fixture yields one trade whose exit price is 102.0, the last bar's open."""
    bars = build_gap_through_stop_fixture()
    settings = EMAPullbackConfig(fast_ema=2, slow_ema=4, stop_buffer_pct=0.005)

    outcome = run_ema_pullback_segment(candles=bars, leverage=2, warmup_bars=4, config=settings)

    assert outcome.trade_count == 1
    closed_trade = outcome.trades[0]
    assert closed_trade["exit_price"] == pytest.approx(102.0)
    assert outcome.open_position is None
def test_run_ema_pullback_segment_does_not_generate_entry_from_final_bar():
    """The final-bar fixture must produce no trades, no entries and no open position."""
    bars = build_final_bar_signal_fixture()
    settings = EMAPullbackConfig(fast_ema=2, slow_ema=4, stop_buffer_pct=0.005)

    outcome = run_ema_pullback_segment(candles=bars, leverage=2, warmup_bars=4, config=settings)

    assert outcome.trade_count == 0
    assert outcome.entries == []
    assert outcome.open_position is None
def test_run_ema_pullback_segment_marks_open_position_to_market():
    """The open-tail fixture leaves a long position open and still reports a return."""
    bars = build_open_tail_fixture()
    settings = EMAPullbackConfig(fast_ema=2, slow_ema=4, stop_buffer_pct=0.005)
    # Expected return, written as (ending value - 10_000) / 10_000.
    expected_return = (10_384.615384615385 - 10_000.0) / 10_000.0

    outcome = run_ema_pullback_segment(candles=bars, leverage=2, warmup_bars=4, config=settings)

    assert outcome.trade_count == 0
    assert outcome.trades == []
    assert outcome.total_return == pytest.approx(expected_return)
    assert outcome.open_position is not None
    assert outcome.open_position["side"] == "long"
def test_run_ema_pullback_segment_does_not_reenter_after_equity_is_depleted():
    """At leverage 100 the fixture yields one entry, one trade and a return of -100% or worse."""
    bars = build_depleted_equity_fixture()
    settings = EMAPullbackConfig(fast_ema=2, slow_ema=4, stop_buffer_pct=0.005)

    outcome = run_ema_pullback_segment(candles=bars, leverage=100, warmup_bars=4, config=settings)

    assert outcome.trade_count == 1
    assert len(outcome.entries) == 1
    assert outcome.open_position is None
    assert outcome.total_return <= -1.0
def test_generate_ema_pullback_sampled_report_uses_shared_shell_defaults(monkeypatch, tmp_path):
    """The EMA pullback report delegates to ``generate_sampled_report`` with its default settings.

    ``ema_pullback_report.generate_sampled_report`` is replaced with a stub
    that records its keyword arguments and returns a canned summary. The test
    then checks that the summary is passed back unchanged and that the
    recorded title, label, strategy parameters and warmup match the defaults.
    """
    bars = build_linear_candles(5_000)
    report_path = tmp_path / "ema-pullback.html"
    captured: dict[str, object] = {}
    canned_summary = {
        "report_file": str(report_path),
        "segment_count": 2,
        "window_size": 300,
        "aggregate_trade_count": 4,
        "average_return": 0.12,
    }

    def capture_call(**kwargs):
        # Remember what the wrapper passed on, then hand back the canned summary.
        captured.update(kwargs)
        return canned_summary

    monkeypatch.setattr(ema_pullback_report, "generate_sampled_report", capture_call)

    summary = generate_ema_pullback_sampled_report(
        candles=bars,
        leverage=2,
        output_file=report_path,
        symbol="BTC-USDT-SWAP",
        bar="3m",
        segments=2,
        window_size=300,
    )

    assert summary == canned_summary

    expected_kwargs = {
        "report_title": "EMA Pullback Sampled Report",
        "strategy_label": "EMA Pullback",
        "strategy_params": {
            "fast_ema": 20,
            "slow_ema": 50,
            "stop_buffer_pct": 0.005,
        },
        "warmup_bars": 50,
    }
    for key, expected_value in expected_kwargs.items():
        assert captured[key] == expected_value
    assert callable(captured["run_segment"])
|