| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208 |
- from __future__ import annotations
- import argparse
- import math
- import sys
- from dataclasses import dataclass
- from datetime import UTC, datetime
- from pathlib import Path
- from tempfile import TemporaryDirectory
- import pandas as pd
- ROOT_DIR = Path(__file__).resolve().parents[1]
- if str(ROOT_DIR) not in sys.path:
- sys.path.insert(0, str(ROOT_DIR))
- from okx_codex_trader.models import Candle
- from okx_codex_trader.okx_client import OkxClient
- from scripts.explore_ultrashort import CANDLE_CACHE_DIR, load_cached_candles, save_cached_candles
- from scripts.validate_external_backtest import candles_to_backtesting_frame
# Bar label -> bar duration in milliseconds, derived from the bar length in
# minutes (one minute is 60_000 ms).
BAR_MS = {
    label: minutes * 60_000
    for label, minutes in (
        ("1m", 1),
        ("3m", 3),
        ("5m", 5),
        ("15m", 15),
        ("30m", 30),
        ("1H", 60),
        ("4H", 240),
        ("1D", 1_440),
    )
}
@dataclass(frozen=True)
class DataAudit:
    """Immutable summary of the checks run on one symbol/bar candle series."""

    symbol: str
    bar: str
    # Number of candles that were audited.
    rows: int
    # Formatted timestamps of the first and last candle ("" when there are none).
    first_time: str
    last_time: str
    # Count of timestamps that occur more than once.
    duplicate_timestamps: int
    # Count of consecutive steps whose time delta is zero or negative.
    non_increasing_steps: int
    # Count of consecutive steps whose delta differs from the bar length.
    unexpected_intervals: int
    # Count of rows with non-finite or inconsistent open/high/low/close values.
    invalid_ohlc_rows: int
    # Count of rows with a non-finite or negative volume.
    invalid_volume_rows: int
    # Whether the series survived a save/load cycle through the candle cache.
    cache_roundtrip_matches: bool
    # Whether the backtesting frame built from the series mirrors it row by row.
    frame_conversion_matches: bool
    # Number of live candles that were compared against a cached candle.
    live_overlap_rows: int
    # Number of those compared candles that differed from the cached copy.
    live_mismatches: int

    @property
    def passed(self) -> bool:
        """Return a truthy value only when every individual check is clean."""
        if not self.rows > 0:
            return False
        defect_counts = (
            self.duplicate_timestamps,
            self.non_increasing_steps,
            self.unexpected_intervals,
            self.invalid_ohlc_rows,
            self.invalid_volume_rows,
        )
        if any(not count == 0 for count in defect_counts):
            return False
        return (
            self.cache_roundtrip_matches
            and self.frame_conversion_matches
            and self.live_mismatches == 0
        )
def _format_ts(ts: int) -> str:
    """Render an epoch-millisecond timestamp as a ``YYYY-MM-DD HH:MM`` UTC string."""
    moment = pd.to_datetime(ts, unit="ms", utc=True)
    return moment.strftime("%Y-%m-%d %H:%M")
def _same_candle(left: Candle, right: Candle) -> bool:
    """Return True when both candles agree on symbol, timestamp, OHLC and volume.

    Fields are compared with plain ``==`` in a fixed order and the comparison
    stops at the first field that differs.
    """
    compared_fields = ("symbol", "ts", "open", "high", "low", "close", "volume")
    return all(
        getattr(left, name) == getattr(right, name) for name in compared_fields
    )
def _finite(value: float) -> bool:
    """Return True when *value* is a finite number.

    NaN and +/-infinity yield False. A value that is not numeric at all
    (for example ``None`` or a string) also yields False, so that a malformed
    field is counted as invalid by the audit instead of aborting it with a
    ``TypeError``.
    """
    try:
        return math.isfinite(value)
    except TypeError:
        # math.isfinite rejects non-numeric input; treat that as "not finite".
        return False
def count_invalid_ohlc(candles: list[Candle]) -> int:
    """Count candles whose open/high/low/close values are not self-consistent.

    A candle is counted when any of its four prices is not finite, or when its
    low lies above the smaller of open/close, or its high lies below the larger
    of open/close.
    """

    def _is_malformed(candle: Candle) -> bool:
        prices = (candle.open, candle.high, candle.low, candle.close)
        if any(not _finite(price) for price in prices):
            return True
        body_floor = min(candle.open, candle.close)
        body_ceiling = max(candle.open, candle.close)
        return candle.low > body_floor or candle.high < body_ceiling

    return sum(1 for candle in candles if _is_malformed(candle))
def cache_roundtrip_matches(candles: list[Candle], symbol: str, bar: str) -> bool:
    """Check that candles survive a save/load cycle through the candle cache.

    The candles are written to a temporary directory with
    ``history_exhausted=True`` and read back. The result is truthy only when
    the reloaded flag is truthy, the reloaded series has the same length, and
    every candle equals its counterpart.
    """
    with TemporaryDirectory() as scratch:
        scratch_dir = Path(scratch)
        save_cached_candles(scratch_dir, symbol, bar, candles, history_exhausted=True)
        reloaded, exhausted_flag = load_cached_candles(scratch_dir, symbol, bar)
        if not exhausted_flag:
            # Hand back the falsy flag itself, exactly as an `and` chain would.
            return exhausted_flag
        if len(reloaded) != len(candles):
            return False
        return all(
            _same_candle(original, restored)
            for original, restored in zip(candles, reloaded)
        )
def frame_conversion_matches(candles: list[Candle]) -> bool:
    """Check that the backtesting frame built from *candles* mirrors them.

    The frame is produced by ``candles_to_backtesting_frame`` with a default
    ``RSI2Config``. Returns True only when the frame has one row per candle
    and, row by row, the index timestamp (in epoch milliseconds) and the
    Open/High/Low/Close/Volume values equal the candle's fields exactly.
    """
    from okx_codex_trader.rsi2_report import RSI2Config

    frame = candles_to_backtesting_frame(candles, RSI2Config())
    if len(frame) != len(candles):
        return False
    for row, candle in zip(frame.itertuples(), candles):
        # round() rather than int(): timestamp() is float seconds, and the
        # product with 1000 can land a hair below the true millisecond value
        # (e.g. ...122.9999), which truncation would turn into a false mismatch.
        if round(row.Index.timestamp() * 1000) != candle.ts:
            return False
        frame_values = (row.Open, row.High, row.Low, row.Close, row.Volume)
        candle_values = (
            candle.open,
            candle.high,
            candle.low,
            candle.close,
            candle.volume,
        )
        if frame_values != candle_values:
            return False
    return True
def count_live_mismatches(bar: str, cached: list[Candle], live: list[Candle]) -> tuple[int, int]:
    """Compare live candles against cached candles that share a timestamp.

    Live candles whose bar has not ended yet (start time plus bar length is
    still after the current UTC time) are ignored, as are live candles with no
    cached counterpart.

    Returns:
        ``(overlap, mismatches)``: how many live candles were compared, and how
        many of those differed from the cached candle.
    """
    now_ms = int(datetime.now(UTC).timestamp() * 1000)
    bar_ms = BAR_MS[bar]
    known = {stored.ts: stored for stored in cached}
    comparable = [
        (known[fresh.ts], fresh)
        for fresh in live
        if not (now_ms < fresh.ts + bar_ms) and fresh.ts in known
    ]
    differing = sum(
        1 for stored, fresh in comparable if not _same_candle(stored, fresh)
    )
    return len(comparable), differing
def audit_candles(
    *,
    candles: list[Candle],
    symbol: str,
    bar: str,
    live_candles: list[Candle] | None = None,
) -> DataAudit:
    """Run every data-quality check on one candle series and collect the results.

    Args:
        candles: The series to audit, in stored order.
        symbol: Label recorded in the result and used for the cache round trip.
        bar: Bar label; must be a key of ``BAR_MS`` when *candles* is non-empty.
        live_candles: Optional candles to compare against *candles*; ``None``
            is treated as an empty list.

    Returns:
        A ``DataAudit``. For an empty series all counters are zero and both
        match flags are False.
    """
    if not candles:
        return DataAudit(
            symbol=symbol,
            bar=bar,
            rows=0,
            first_time="",
            last_time="",
            duplicate_timestamps=0,
            non_increasing_steps=0,
            unexpected_intervals=0,
            invalid_ohlc_rows=0,
            invalid_volume_rows=0,
            cache_roundtrip_matches=False,
            frame_conversion_matches=False,
            live_overlap_rows=0,
            live_mismatches=0,
        )

    stamps = [candle.ts for candle in candles]
    bar_ms = BAR_MS[bar]
    # Time delta between each candle and the one that follows it.
    steps = [later - earlier for earlier, later in zip(stamps, stamps[1:])]

    repeated = len(stamps) - len(set(stamps))
    backwards = len([step for step in steps if step <= 0])
    off_grid = len([step for step in steps if step != bar_ms])

    bad_volume = 0
    for candle in candles:
        if not _finite(candle.volume) or candle.volume < 0:
            bad_volume += 1

    overlap, mismatched = count_live_mismatches(bar, candles, live_candles or [])

    first_label = _format_ts(candles[0].ts)
    last_label = _format_ts(candles[-1].ts)
    bad_ohlc = count_invalid_ohlc(candles)
    roundtrip_ok = cache_roundtrip_matches(candles, symbol, bar)
    frame_ok = frame_conversion_matches(candles)

    return DataAudit(
        symbol=symbol,
        bar=bar,
        rows=len(candles),
        first_time=first_label,
        last_time=last_label,
        duplicate_timestamps=repeated,
        non_increasing_steps=backwards,
        unexpected_intervals=off_grid,
        invalid_ohlc_rows=bad_ohlc,
        invalid_volume_rows=bad_volume,
        cache_roundtrip_matches=roundtrip_ok,
        frame_conversion_matches=frame_ok,
        live_overlap_rows=overlap,
        live_mismatches=mismatched,
    )
def main() -> int:
    """Audit the cached candles for every requested symbol/bar pair.

    Parses the command line, loads each series from the cache directory,
    optionally fetches the latest candles for a live comparison, prints one
    result row per pair, and returns the process exit code.

    Returns:
        0 when every audited pair passed, 1 otherwise.

    Raises:
        SystemExit: with argparse's usage-error status when ``--symbols`` or
            ``--bars`` is empty, or when a bar is not a key of ``BAR_MS``.
    """
    from dataclasses import asdict

    parser = argparse.ArgumentParser()
    parser.add_argument("--symbols", default="BTC-USDT-SWAP,ETH-USDT-SWAP")
    parser.add_argument("--bars", default="1m,3m,5m,15m")
    parser.add_argument("--cache-dir", type=Path, default=CANDLE_CACHE_DIR)
    parser.add_argument("--fetch-latest", action="store_true")
    parser.add_argument("--live-limit", type=int, default=300)
    args = parser.parse_args()

    symbols = [value.strip() for value in args.symbols.split(",") if value.strip()]
    bars = [value.strip() for value in args.bars.split(",") if value.strip()]

    # With nothing to audit the result frame would have no "passed" column and
    # the final lookup would raise KeyError; reject the input up front instead.
    if not symbols or not bars:
        parser.error("--symbols and --bars must each name at least one entry")

    # An unknown bar would otherwise surface as a KeyError part-way through the
    # run, after other pairs have already been processed.
    unsupported = [bar for bar in bars if bar not in BAR_MS]
    if unsupported:
        parser.error(
            f"unsupported --bars value(s): {', '.join(unsupported)}; "
            f"choose from {', '.join(BAR_MS)}"
        )

    client = OkxClient() if args.fetch_latest else None
    rows = []
    for symbol in symbols:
        for bar in bars:
            candles, _ = load_cached_candles(args.cache_dir, symbol, bar)
            live_candles = (
                client.get_candles(symbol, bar, args.live_limit)
                if client is not None
                else None
            )
            audit = audit_candles(
                candles=candles, symbol=symbol, bar=bar, live_candles=live_candles
            )
            rows.append({**asdict(audit), "passed": audit.passed})

    frame = pd.DataFrame(rows)
    print(frame.to_string(index=False))
    return 0 if bool(frame["passed"].all()) else 1
- if __name__ == "__main__":
- raise SystemExit(main())
|