Explorar el Código

fix: validate and repair candle data pipeline

lxy hace 1 mes
padre
commit
c9c3e47846

+ 14 - 11
okx_codex_trader/okx_client.py

@@ -160,18 +160,21 @@ class OkxClient:
                 params["after"] = after
             data = self._request("GET", "/api/v5/market/history-candles", params=params)
             try:
-                page = [
-                    Candle(
-                        symbol=symbol,
-                        ts=int(entry[0]),
-                        open=_parse_finite_float(entry[1]),
-                        high=_parse_finite_float(entry[2]),
-                        low=_parse_finite_float(entry[3]),
-                        close=_parse_finite_float(entry[4]),
-                        volume=_parse_finite_float(entry[5]),
+                page = []
+                for entry in data:
+                    if str(entry[8]) != "1":
+                        continue
+                    page.append(
+                        Candle(
+                            symbol=symbol,
+                            ts=int(entry[0]),
+                            open=_parse_finite_float(entry[1]),
+                            high=_parse_finite_float(entry[2]),
+                            low=_parse_finite_float(entry[3]),
+                            close=_parse_finite_float(entry[4]),
+                            volume=_parse_finite_float(entry[5]),
+                        )
                     )
-                    for entry in data
-                ]
             except (IndexError, KeyError, TypeError, ValueError):
                 raise self._invalid_payload() from None
             if not page:

+ 24 - 0
scripts/explore_ultrashort.py

@@ -30,6 +30,16 @@ ROBUST_HISTORY_LIMIT = 50_000
 GROSS_RETURN_NOTE = "Gross-return backtest only: fees, slippage, and funding rates are excluded."
 MINUTES_PER_YEAR = 365 * 24 * 60
 CANDLE_CACHE_DIR = Path("data/okx-candles")
+CANDLE_BAR_MS = {
+    "1m": 60_000,
+    "3m": 180_000,
+    "5m": 300_000,
+    "15m": 900_000,
+    "30m": 1_800_000,
+    "1H": 3_600_000,
+    "4H": 14_400_000,
+    "1D": 86_400_000,
+}
 
 
 @dataclass(frozen=True)
@@ -112,6 +122,14 @@ def save_cached_candles(cache_dir: Path, symbol: str, bar: str, candles: list[Ca
         json.dump(meta, handle, separators=(",", ":"))
 
 
+def latest_bridge_count(cached: list[Candle], latest_last_ts: int, interval: int) -> int:
+    bridge_from_ts = max(candle.ts for candle in cached)
+    for left, right in zip(cached, cached[1:]):
+        if right.ts - left.ts != interval:
+            bridge_from_ts = left.ts
+    return ((latest_last_ts - bridge_from_ts) // interval) + 1
+
+
 def get_candles_cached(
     client: OkxClient,
     symbol: str,
@@ -122,6 +140,12 @@ def get_candles_cached(
     cached, history_exhausted = load_cached_candles(cache_dir, symbol, bar)
     if cached and (len(cached) >= limit or history_exhausted):
         latest = client.get_candles(symbol, bar, min(300, limit))
+        interval = CANDLE_BAR_MS[bar]
+        if latest:
+            latest_last_ts = max(candle.ts for candle in latest)
+            needed_latest_count = latest_bridge_count(cached, latest_last_ts, interval)
+            if needed_latest_count > len(latest):
+                latest = client.get_candles(symbol, bar, needed_latest_count)
         merged = {candle.ts: candle for candle in cached}
         for candle in latest:
             merged[candle.ts] = candle

+ 208 - 0
scripts/validate_data_pipeline.py

@@ -0,0 +1,208 @@
+from __future__ import annotations
+
+import argparse
+import math
+import sys
+from dataclasses import dataclass
+from datetime import UTC, datetime
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import pandas as pd
+
+ROOT_DIR = Path(__file__).resolve().parents[1]
+if str(ROOT_DIR) not in sys.path:
+    sys.path.insert(0, str(ROOT_DIR))
+
+from okx_codex_trader.models import Candle
+from okx_codex_trader.okx_client import OkxClient
+from scripts.explore_ultrashort import CANDLE_CACHE_DIR, load_cached_candles, save_cached_candles
+from scripts.validate_external_backtest import candles_to_backtesting_frame
+
+
+BAR_MS = {
+    "1m": 60_000,
+    "3m": 180_000,
+    "5m": 300_000,
+    "15m": 900_000,
+    "30m": 1_800_000,
+    "1H": 3_600_000,
+    "4H": 14_400_000,
+    "1D": 86_400_000,
+}
+
+
+@dataclass(frozen=True)
+class DataAudit:
+    symbol: str
+    bar: str
+    rows: int
+    first_time: str
+    last_time: str
+    duplicate_timestamps: int
+    non_increasing_steps: int
+    unexpected_intervals: int
+    invalid_ohlc_rows: int
+    invalid_volume_rows: int
+    cache_roundtrip_matches: bool
+    frame_conversion_matches: bool
+    live_overlap_rows: int
+    live_mismatches: int
+
+    @property
+    def passed(self) -> bool:
+        return (
+            self.rows > 0
+            and self.duplicate_timestamps == 0
+            and self.non_increasing_steps == 0
+            and self.unexpected_intervals == 0
+            and self.invalid_ohlc_rows == 0
+            and self.invalid_volume_rows == 0
+            and self.cache_roundtrip_matches
+            and self.frame_conversion_matches
+            and self.live_mismatches == 0
+        )
+
+
+def _format_ts(ts: int) -> str:
+    return pd.to_datetime(ts, unit="ms", utc=True).strftime("%Y-%m-%d %H:%M")
+
+
+def _same_candle(left: Candle, right: Candle) -> bool:
+    return (
+        left.symbol == right.symbol
+        and left.ts == right.ts
+        and left.open == right.open
+        and left.high == right.high
+        and left.low == right.low
+        and left.close == right.close
+        and left.volume == right.volume
+    )
+
+
+def _finite(value: float) -> bool:
+    return math.isfinite(value)
+
+
+def count_invalid_ohlc(candles: list[Candle]) -> int:
+    invalid = 0
+    for candle in candles:
+        values = (candle.open, candle.high, candle.low, candle.close)
+        if not all(_finite(value) for value in values):
+            invalid += 1
+            continue
+        if candle.low > min(candle.open, candle.close) or candle.high < max(candle.open, candle.close):
+            invalid += 1
+    return invalid
+
+
+def cache_roundtrip_matches(candles: list[Candle], symbol: str, bar: str) -> bool:
+    with TemporaryDirectory() as directory:
+        cache_dir = Path(directory)
+        save_cached_candles(cache_dir, symbol, bar, candles, history_exhausted=True)
+        loaded, history_exhausted = load_cached_candles(cache_dir, symbol, bar)
+    return history_exhausted and len(loaded) == len(candles) and all(
+        _same_candle(left, right) for left, right in zip(candles, loaded)
+    )
+
+
+def frame_conversion_matches(candles: list[Candle]) -> bool:
+    from okx_codex_trader.rsi2_report import RSI2Config
+
+    frame = candles_to_backtesting_frame(candles, RSI2Config())
+    if len(frame) != len(candles):
+        return False
+    for row, candle in zip(frame.itertuples(), candles):
+        if int(row.Index.timestamp() * 1000) != candle.ts:
+            return False
+        if (row.Open, row.High, row.Low, row.Close, row.Volume) != (
+            candle.open,
+            candle.high,
+            candle.low,
+            candle.close,
+            candle.volume,
+        ):
+            return False
+    return True
+
+
+def count_live_mismatches(bar: str, cached: list[Candle], live: list[Candle]) -> tuple[int, int]:
+    now_ms = int(datetime.now(UTC).timestamp() * 1000)
+    interval = BAR_MS[bar]
+    cached_by_ts = {candle.ts: candle for candle in cached}
+    overlap = 0
+    mismatches = 0
+    for candle in live:
+        if now_ms < candle.ts + interval:
+            continue
+        cached_candle = cached_by_ts.get(candle.ts)
+        if cached_candle is None:
+            continue
+        overlap += 1
+        if not _same_candle(cached_candle, candle):
+            mismatches += 1
+    return overlap, mismatches
+
+
+def audit_candles(
+    *,
+    candles: list[Candle],
+    symbol: str,
+    bar: str,
+    live_candles: list[Candle] | None = None,
+) -> DataAudit:
+    if not candles:
+        return DataAudit(symbol, bar, 0, "", "", 0, 0, 0, 0, 0, False, False, 0, 0)
+
+    timestamps = [candle.ts for candle in candles]
+    expected_interval = BAR_MS[bar]
+    intervals = [right - left for left, right in zip(timestamps, timestamps[1:])]
+    duplicate_timestamps = len(timestamps) - len(set(timestamps))
+    non_increasing_steps = sum(1 for interval in intervals if interval <= 0)
+    unexpected_intervals = sum(1 for interval in intervals if interval != expected_interval)
+    invalid_volume_rows = sum(1 for candle in candles if not _finite(candle.volume) or candle.volume < 0)
+    live_overlap_rows, live_mismatches = count_live_mismatches(bar, candles, live_candles or [])
+
+    return DataAudit(
+        symbol=symbol,
+        bar=bar,
+        rows=len(candles),
+        first_time=_format_ts(candles[0].ts),
+        last_time=_format_ts(candles[-1].ts),
+        duplicate_timestamps=duplicate_timestamps,
+        non_increasing_steps=non_increasing_steps,
+        unexpected_intervals=unexpected_intervals,
+        invalid_ohlc_rows=count_invalid_ohlc(candles),
+        invalid_volume_rows=invalid_volume_rows,
+        cache_roundtrip_matches=cache_roundtrip_matches(candles, symbol, bar),
+        frame_conversion_matches=frame_conversion_matches(candles),
+        live_overlap_rows=live_overlap_rows,
+        live_mismatches=live_mismatches,
+    )
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--symbols", default="BTC-USDT-SWAP,ETH-USDT-SWAP")
+    parser.add_argument("--bars", default="1m,3m,5m,15m")
+    parser.add_argument("--cache-dir", type=Path, default=CANDLE_CACHE_DIR)
+    parser.add_argument("--fetch-latest", action="store_true")
+    parser.add_argument("--live-limit", type=int, default=300)
+    args = parser.parse_args()
+
+    client = OkxClient() if args.fetch_latest else None
+    rows = []
+    for symbol in [value.strip() for value in args.symbols.split(",") if value.strip()]:
+        for bar in [value.strip() for value in args.bars.split(",") if value.strip()]:
+            candles, _ = load_cached_candles(args.cache_dir, symbol, bar)
+            live_candles = client.get_candles(symbol, bar, args.live_limit) if client is not None else None
+            audit = audit_candles(candles=candles, symbol=symbol, bar=bar, live_candles=live_candles)
+            rows.append({**audit.__dict__, "passed": audit.passed})
+
+    frame = pd.DataFrame(rows)
+    print(frame.to_string(index=False))
+    return 0 if bool(frame["passed"].all()) else 1
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())

+ 70 - 0
tests/test_explore_ultrashort.py

@@ -436,6 +436,76 @@ def test_get_candles_cached_saves_exhausted_history_and_updates_latest(tmp_path)
     assert module.load_cached_candles(tmp_path, "BTC-USDT-SWAP", "15m")[1] is True
 
 
+def test_get_candles_cached_bridges_stale_cache_gap(tmp_path):
+    module = load_explore_module()
+    module.save_cached_candles(
+        tmp_path,
+        "BTC-USDT-SWAP",
+        "3m",
+        [
+            Candle(symbol="BTC-USDT-SWAP", ts=0, open=100.0, high=101.0, low=99.0, close=100.0, volume=1.0),
+            Candle(symbol="BTC-USDT-SWAP", ts=180_000, open=101.0, high=102.0, low=100.0, close=101.0, volume=1.0),
+        ],
+        history_exhausted=True,
+    )
+
+    class Client:
+        def __init__(self):
+            self.limits: list[int] = []
+
+        def get_candles(self, symbol, bar, limit):
+            self.limits.append(limit)
+            if limit == 300:
+                return [
+                    Candle(symbol=symbol, ts=720_000, open=104.0, high=105.0, low=103.0, close=104.0, volume=1.0),
+                    Candle(symbol=symbol, ts=900_000, open=105.0, high=106.0, low=104.0, close=105.0, volume=1.0),
+                ]
+            return [
+                Candle(symbol=symbol, ts=360_000, open=102.0, high=103.0, low=101.0, close=102.0, volume=1.0),
+                Candle(symbol=symbol, ts=540_000, open=103.0, high=104.0, low=102.0, close=103.0, volume=1.0),
+                Candle(symbol=symbol, ts=720_000, open=104.0, high=105.0, low=103.0, close=104.0, volume=1.0),
+                Candle(symbol=symbol, ts=900_000, open=105.0, high=106.0, low=104.0, close=105.0, volume=1.0),
+            ]
+
+    candles = module.get_candles_cached(Client(), "BTC-USDT-SWAP", "3m", 10, tmp_path)
+
+    assert [candle.ts for candle in candles] == [0, 180_000, 360_000, 540_000, 720_000, 900_000]
+
+
+def test_get_candles_cached_repairs_existing_internal_gap(tmp_path):
+    module = load_explore_module()
+    module.save_cached_candles(
+        tmp_path,
+        "BTC-USDT-SWAP",
+        "3m",
+        [
+            Candle(symbol="BTC-USDT-SWAP", ts=0, open=100.0, high=101.0, low=99.0, close=100.0, volume=1.0),
+            Candle(symbol="BTC-USDT-SWAP", ts=180_000, open=101.0, high=102.0, low=100.0, close=101.0, volume=1.0),
+            Candle(symbol="BTC-USDT-SWAP", ts=720_000, open=104.0, high=105.0, low=103.0, close=104.0, volume=1.0),
+            Candle(symbol="BTC-USDT-SWAP", ts=900_000, open=105.0, high=106.0, low=104.0, close=105.0, volume=1.0),
+        ],
+        history_exhausted=True,
+    )
+
+    class Client:
+        def get_candles(self, symbol, bar, limit):
+            if limit == 300:
+                return [
+                    Candle(symbol=symbol, ts=720_000, open=104.0, high=105.0, low=103.0, close=104.0, volume=1.0),
+                    Candle(symbol=symbol, ts=900_000, open=105.0, high=106.0, low=104.0, close=105.0, volume=1.0),
+                ]
+            return [
+                Candle(symbol=symbol, ts=360_000, open=102.0, high=103.0, low=101.0, close=102.0, volume=1.0),
+                Candle(symbol=symbol, ts=540_000, open=103.0, high=104.0, low=102.0, close=103.0, volume=1.0),
+                Candle(symbol=symbol, ts=720_000, open=104.0, high=105.0, low=103.0, close=104.0, volume=1.0),
+                Candle(symbol=symbol, ts=900_000, open=105.0, high=106.0, low=104.0, close=105.0, volume=1.0),
+            ]
+
+    candles = module.get_candles_cached(Client(), "BTC-USDT-SWAP", "3m", 10, tmp_path)
+
+    assert [candle.ts for candle in candles] == [0, 180_000, 360_000, 540_000, 720_000, 900_000]
+
+
 def test_history_bars_for_years_counts_minute_bars():
     module = load_explore_module()
 

+ 22 - 0
tests/test_okx_client.py

@@ -500,6 +500,28 @@ def test_get_candles_returns_chronological_ascending_order():
     assert [candle.ts for candle in candles] == [1710000000000, 1710000001000]
 
 
+def test_get_candles_ignores_unconfirmed_history_rows():
+    session = DummySession(
+        [
+            DummyResponse(
+                {
+                    "code": "0",
+                    "msg": "",
+                    "data": [
+                        ["1710000001000", "25100", "25200", "25000", "25150", "110", "1100", "1100", "0"],
+                        ["1710000000000", "25000", "25100", "24900", "25050", "100", "1000", "1000", "1"],
+                    ],
+                }
+            )
+        ]
+    )
+    client = OkxClient(config=sample_config(), session=session)
+
+    candles = client.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
+
+    assert [candle.ts for candle in candles] == [1710000000000]
+
+
 def test_get_candles_paginates_when_limit_exceeds_single_page():
     session = DummySession([full_page_candles_response(), older_candles_response()])
     client = OkxClient(config=sample_config(), session=session)

+ 90 - 0
tests/test_validate_data_pipeline.py

@@ -0,0 +1,90 @@
+import importlib.util
+import sys
+from pathlib import Path
+
+from okx_codex_trader.models import Candle
+
+
+def load_data_validation_module():
+    path = Path(__file__).resolve().parents[1] / "scripts" / "validate_data_pipeline.py"
+    spec = importlib.util.spec_from_file_location("validate_data_pipeline", path)
+    assert spec is not None
+    module = importlib.util.module_from_spec(spec)
+    assert spec.loader is not None
+    sys.modules[spec.name] = module
+    spec.loader.exec_module(module)
+    return module
+
+
+def build_candles(count: int) -> list[Candle]:
+    return [
+        Candle(
+            symbol="BTC-USDT-SWAP",
+            ts=index * 60_000,
+            open=100.0 + index,
+            high=101.0 + index,
+            low=99.0 + index,
+            close=100.5 + index,
+            volume=1_000.0 + index,
+        )
+        for index in range(count)
+    ]
+
+
+def test_audit_candles_passes_clean_ohlcv_pipeline():
+    module = load_data_validation_module()
+    audit = module.audit_candles(candles=build_candles(80), symbol="BTC-USDT-SWAP", bar="1m")
+
+    assert audit.passed
+    assert audit.rows == 80
+    assert audit.duplicate_timestamps == 0
+    assert audit.unexpected_intervals == 0
+    assert audit.invalid_ohlc_rows == 0
+    assert audit.invalid_volume_rows == 0
+    assert audit.cache_roundtrip_matches
+    assert audit.frame_conversion_matches
+
+
+def test_audit_candles_fails_duplicate_gap_and_invalid_ohlcv():
+    module = load_data_validation_module()
+    candles = build_candles(5)
+    candles[2] = Candle(
+        symbol="BTC-USDT-SWAP",
+        ts=candles[1].ts,
+        open=100.0,
+        high=99.0,
+        low=101.0,
+        close=100.0,
+        volume=-1.0,
+    )
+    audit = module.audit_candles(candles=candles, symbol="BTC-USDT-SWAP", bar="1m")
+
+    assert not audit.passed
+    assert audit.duplicate_timestamps == 1
+    assert audit.non_increasing_steps == 1
+    assert audit.unexpected_intervals == 2
+    assert audit.invalid_ohlc_rows == 1
+    assert audit.invalid_volume_rows == 1
+
+
+def test_audit_candles_reports_live_overlap_mismatches():
+    module = load_data_validation_module()
+    candles = build_candles(3)
+    live = [
+        candles[0],
+        Candle(
+            symbol="BTC-USDT-SWAP",
+            ts=candles[1].ts,
+            open=candles[1].open,
+            high=candles[1].high,
+            low=candles[1].low,
+            close=candles[1].close + 1.0,
+            volume=candles[1].volume,
+        ),
+    ]
+
+    audit = module.audit_candles(candles=candles, symbol="BTC-USDT-SWAP", bar="1m", live_candles=live)
+
+    assert not audit.passed
+    assert audit.live_overlap_rows == 2
+    assert audit.live_mismatches == 1