"""Audit cached OKX candles for data-integrity problems.

For each requested symbol/bar pair the cached candles are checked for
duplicate, out-of-order, or irregularly spaced timestamps, invalid OHLC and
volume values, a lossless cache save/load round trip, and a lossless
conversion to the backtesting frame. With ``--fetch-latest``, they are also
compared against freshly fetched, already-closed candles. A summary table is
printed, and the process exits with status 0 only when every audit passes.
"""

from __future__ import annotations

import argparse
import math
import sys
from dataclasses import asdict, dataclass
from datetime import UTC, datetime
from pathlib import Path
from tempfile import TemporaryDirectory

import pandas as pd

# Put the directory one level above this file on sys.path so the project
# packages imported below resolve even when it is not already importable.
ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))

from okx_codex_trader.models import Candle  # noqa: E402
from okx_codex_trader.okx_client import OkxClient  # noqa: E402
from scripts.explore_ultrashort import (  # noqa: E402
    CANDLE_CACHE_DIR,
    load_cached_candles,
    save_cached_candles,
)
from scripts.validate_external_backtest import candles_to_backtesting_frame  # noqa: E402

# Expected spacing, in milliseconds, between consecutive candles of each
# supported bar size.
BAR_MS = {
    "1m": 60_000,
    "3m": 180_000,
    "5m": 300_000,
    "15m": 900_000,
    "30m": 1_800_000,
    "1H": 3_600_000,
    "4H": 14_400_000,
    "1D": 86_400_000,
}


@dataclass(frozen=True)
class DataAudit:
    """Outcome of auditing one symbol/bar candle series."""

    symbol: str
    bar: str
    rows: int  # number of cached candles audited
    first_time: str  # formatted time of the first candle ("" when there are none)
    last_time: str  # formatted time of the last candle ("" when there are none)
    duplicate_timestamps: int  # rows whose timestamp already appeared earlier
    non_increasing_steps: int  # adjacent pairs whose timestamp did not increase
    unexpected_intervals: int  # adjacent pairs not spaced exactly BAR_MS[bar] apart
    invalid_ohlc_rows: int  # rows with non-finite or inconsistent open/high/low/close
    invalid_volume_rows: int  # rows with non-finite or negative volume
    cache_roundtrip_matches: bool  # save + reload through the cache is lossless
    frame_conversion_matches: bool  # conversion to the backtesting frame is lossless
    live_overlap_rows: int  # closed live candles whose timestamp is also cached
    live_mismatches: int  # overlapping candles whose cached values differ from live

    @property
    def passed(self) -> bool:
        """Return True when the series is non-empty and every check is clean.

        ``live_overlap_rows`` is informational only; a run without live data
        can still pass.
        """
        return (
            self.rows > 0
            and self.duplicate_timestamps == 0
            and self.non_increasing_steps == 0
            and self.unexpected_intervals == 0
            and self.invalid_ohlc_rows == 0
            and self.invalid_volume_rows == 0
            and self.cache_roundtrip_matches
            and self.frame_conversion_matches
            and self.live_mismatches == 0
        )


def _format_ts(ts: int) -> str:
    """Format an epoch-milliseconds timestamp as a UTC ``YYYY-MM-DD HH:MM`` string."""
    return pd.to_datetime(ts, unit="ms", utc=True).strftime("%Y-%m-%d %H:%M")


def _same_candle(left: Candle, right: Candle) -> bool:
    """Return True when both candles have exactly equal symbol, ts and OHLCV values."""
    return (
        left.symbol == right.symbol
        and left.ts == right.ts
        and left.open == right.open
        and left.high == right.high
        and left.low == right.low
        and left.close == right.close
        and left.volume == right.volume
    )


def _finite(value: float) -> bool:
    """Return True when *value* is neither infinite nor NaN."""
    return math.isfinite(value)


def count_invalid_ohlc(candles: list[Candle]) -> int:
    """Count candles whose OHLC prices are malformed.

    A candle is invalid when any of open/high/low/close is non-finite, or when
    low lies above, or high below, the open/close range.
    """
    invalid = 0
    for candle in candles:
        values = (candle.open, candle.high, candle.low, candle.close)
        if not all(_finite(value) for value in values):
            invalid += 1
            continue
        if candle.low > min(candle.open, candle.close) or candle.high < max(candle.open, candle.close):
            invalid += 1
    return invalid


def cache_roundtrip_matches(candles: list[Candle], symbol: str, bar: str) -> bool:
    """Return True when saving and reloading *candles* through the cache is lossless.

    The candles are written to a throwaway directory with
    ``history_exhausted=True``. The check passes only if that flag survives
    the reload and every reloaded candle equals the original, in order.
    """
    with TemporaryDirectory() as directory:
        cache_dir = Path(directory)
        save_cached_candles(cache_dir, symbol, bar, candles, history_exhausted=True)
        loaded, history_exhausted = load_cached_candles(cache_dir, symbol, bar)
    return history_exhausted and len(loaded) == len(candles) and all(
        _same_candle(left, right) for left, right in zip(candles, loaded)
    )


def frame_conversion_matches(candles: list[Candle]) -> bool:
    """Return True when the backtesting frame reproduces *candles* exactly.

    Each frame row's index timestamp (converted to epoch ms) and its
    Open/High/Low/Close/Volume columns must equal the corresponding candle.
    """
    from okx_codex_trader.rsi2_report import RSI2Config

    frame = candles_to_backtesting_frame(candles, RSI2Config())
    if len(frame) != len(candles):
        return False
    for row, candle in zip(frame.itertuples(), candles):
        # round() rather than int(): seconds * 1000 is float arithmetic and can
        # land a hair below the true integer, where truncation would report a
        # spurious off-by-one-millisecond mismatch.
        if round(row.Index.timestamp() * 1000) != candle.ts:
            return False
        if (row.Open, row.High, row.Low, row.Close, row.Volume) != (
            candle.open,
            candle.high,
            candle.low,
            candle.close,
            candle.volume,
        ):
            return False
    return True


def count_live_mismatches(bar: str, cached: list[Candle], live: list[Candle]) -> tuple[int, int]:
    """Compare closed live candles with cached candles at the same timestamps.

    Live candles whose interval has not ended yet (``ts + BAR_MS[bar]`` is
    still in the future) are skipped, and so are live candles with no cached
    counterpart.

    Returns:
        ``(overlap, mismatches)``: how many closed live candles were also
        cached, and how many of those differ from the cached copy.
    """
    now_ms = int(datetime.now(UTC).timestamp() * 1000)
    interval = BAR_MS[bar]
    cached_by_ts = {candle.ts: candle for candle in cached}
    overlap = 0
    mismatches = 0
    for candle in live:
        if now_ms < candle.ts + interval:
            continue  # candle still forming; its values may legitimately change
        cached_candle = cached_by_ts.get(candle.ts)
        if cached_candle is None:
            continue
        overlap += 1
        if not _same_candle(cached_candle, candle):
            mismatches += 1
    return overlap, mismatches


def audit_candles(
    *,
    candles: list[Candle],
    symbol: str,
    bar: str,
    live_candles: list[Candle] | None = None,
) -> DataAudit:
    """Run every integrity check on one symbol/bar series.

    Args:
        candles: Cached candles, in stored order.
        symbol: Instrument id the candles belong to.
        bar: Bar size; must be a key of ``BAR_MS`` when *candles* is non-empty.
        live_candles: Optional freshly fetched candles to cross-check against.

    Returns:
        A ``DataAudit``. An empty *candles* list yields an audit that does not
        pass, and in that case the other checks are not run.
    """
    if not candles:
        return DataAudit(
            symbol=symbol,
            bar=bar,
            rows=0,
            first_time="",
            last_time="",
            duplicate_timestamps=0,
            non_increasing_steps=0,
            unexpected_intervals=0,
            invalid_ohlc_rows=0,
            invalid_volume_rows=0,
            cache_roundtrip_matches=False,
            frame_conversion_matches=False,
            live_overlap_rows=0,
            live_mismatches=0,
        )

    timestamps = [candle.ts for candle in candles]
    expected_interval = BAR_MS[bar]
    intervals = [right - left for left, right in zip(timestamps, timestamps[1:])]
    duplicate_timestamps = len(timestamps) - len(set(timestamps))
    non_increasing_steps = sum(1 for interval in intervals if interval <= 0)
    unexpected_intervals = sum(1 for interval in intervals if interval != expected_interval)
    invalid_volume_rows = sum(1 for candle in candles if not _finite(candle.volume) or candle.volume < 0)
    live_overlap_rows, live_mismatches = count_live_mismatches(bar, candles, live_candles or [])

    return DataAudit(
        symbol=symbol,
        bar=bar,
        rows=len(candles),
        first_time=_format_ts(candles[0].ts),
        last_time=_format_ts(candles[-1].ts),
        duplicate_timestamps=duplicate_timestamps,
        non_increasing_steps=non_increasing_steps,
        unexpected_intervals=unexpected_intervals,
        invalid_ohlc_rows=count_invalid_ohlc(candles),
        invalid_volume_rows=invalid_volume_rows,
        cache_roundtrip_matches=cache_roundtrip_matches(candles, symbol, bar),
        frame_conversion_matches=frame_conversion_matches(candles),
        live_overlap_rows=live_overlap_rows,
        live_mismatches=live_mismatches,
    )


def _split_csv(value: str) -> list[str]:
    """Split a comma-separated CLI value into its stripped, non-empty items."""
    return [item.strip() for item in value.split(",") if item.strip()]


def main() -> int:
    """Audit every requested symbol/bar pair and print a summary table.

    Returns:
        0 when every audit passed, otherwise 1. Invalid arguments (no
        symbols or bars, or a bar missing from ``BAR_MS``) are reported
        through ``argparse``, which exits before any data is loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--symbols", default="BTC-USDT-SWAP,ETH-USDT-SWAP")
    parser.add_argument("--bars", default="1m,3m,5m,15m")
    parser.add_argument("--cache-dir", type=Path, default=CANDLE_CACHE_DIR)
    parser.add_argument("--fetch-latest", action="store_true")
    parser.add_argument("--live-limit", type=int, default=300)
    args = parser.parse_args()

    symbols = _split_csv(args.symbols)
    bars = _split_csv(args.bars)
    # Without this check an empty list produces a DataFrame with no "passed"
    # column, and the final lookup crashes with KeyError.
    if not symbols or not bars:
        parser.error("at least one symbol and one bar are required")
    # Reject unknown bars up front instead of failing mid-run (possibly after
    # network calls) with a bare KeyError from BAR_MS.
    unsupported = [bar for bar in bars if bar not in BAR_MS]
    if unsupported:
        parser.error(f"unsupported bar(s): {', '.join(unsupported)}; expected one of: {', '.join(BAR_MS)}")

    client = OkxClient() if args.fetch_latest else None
    rows = []
    for symbol in symbols:
        for bar in bars:
            candles, _history_exhausted = load_cached_candles(args.cache_dir, symbol, bar)
            live_candles = client.get_candles(symbol, bar, args.live_limit) if client is not None else None
            audit = audit_candles(candles=candles, symbol=symbol, bar=bar, live_candles=live_candles)
            rows.append({**asdict(audit), "passed": audit.passed})

    frame = pd.DataFrame(rows)
    print(frame.to_string(index=False))
    return 0 if bool(frame["passed"].all()) else 1


if __name__ == "__main__":
    raise SystemExit(main())