validate_data_pipeline.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. from __future__ import annotations
  2. import argparse
  3. import math
  4. import sys
  5. from dataclasses import dataclass
  6. from datetime import UTC, datetime
  7. from pathlib import Path
  8. from tempfile import TemporaryDirectory
  9. import pandas as pd
  10. ROOT_DIR = Path(__file__).resolve().parents[1]
  11. if str(ROOT_DIR) not in sys.path:
  12. sys.path.insert(0, str(ROOT_DIR))
  13. from okx_codex_trader.models import Candle
  14. from okx_codex_trader.okx_client import OkxClient
  15. from scripts.explore_ultrashort import CANDLE_CACHE_DIR, load_cached_candles, save_cached_candles
  16. from scripts.validate_external_backtest import candles_to_backtesting_frame
  # Candle duration in milliseconds, keyed by bar label. Built from minute
  # counts so each entry is easy to check by eye (one minute = 60_000 ms).
  BAR_MS = {
      label: minutes * 60_000
      for label, minutes in (
          ("1m", 1),
          ("3m", 3),
          ("5m", 5),
          ("15m", 15),
          ("30m", 30),
          ("1H", 60),
          ("4H", 240),
          ("1D", 1_440),
      )
  }
  @dataclass(frozen=True)
  class DataAudit:
      """Immutable result of auditing one cached candle series."""

      symbol: str  # instrument the series belongs to
      bar: str  # bar label, a key of BAR_MS
      rows: int  # number of candles audited
      first_time: str  # formatted timestamp of the first candle ("" when empty)
      last_time: str  # formatted timestamp of the last candle ("" when empty)
      duplicate_timestamps: int  # timestamps that repeat an earlier one
      non_increasing_steps: int  # consecutive steps with a delta <= 0
      unexpected_intervals: int  # consecutive steps whose delta != the bar length
      invalid_ohlc_rows: int  # candles with non-finite or inconsistent OHLC
      invalid_volume_rows: int  # candles with non-finite or negative volume
      cache_roundtrip_matches: bool  # save/load through the cache kept every candle
      frame_conversion_matches: bool  # backtesting frame agrees with the candles
      live_overlap_rows: int  # closed live candles that also exist in the cache
      live_mismatches: int  # overlapping candles that differ from the cache

      @property
      def passed(self) -> bool:
          """Return True when the series is non-empty and every check is clean."""
          has_rows = self.rows > 0
          # Every defect counter has to be exactly zero for the audit to pass.
          counters_clean = not any(
              (
                  self.duplicate_timestamps,
                  self.non_increasing_steps,
                  self.unexpected_intervals,
                  self.invalid_ohlc_rows,
                  self.invalid_volume_rows,
                  self.live_mismatches,
              )
          )
          return (
              has_rows
              and self.cache_roundtrip_matches
              and self.frame_conversion_matches
              and counters_clean
          )
  56. def _format_ts(ts: int) -> str:
  57. return pd.to_datetime(ts, unit="ms", utc=True).strftime("%Y-%m-%d %H:%M")
  58. def _same_candle(left: Candle, right: Candle) -> bool:
  59. return (
  60. left.symbol == right.symbol
  61. and left.ts == right.ts
  62. and left.open == right.open
  63. and left.high == right.high
  64. and left.low == right.low
  65. and left.close == right.close
  66. and left.volume == right.volume
  67. )
  68. def _finite(value: float) -> bool:
  69. return math.isfinite(value)
  def count_invalid_ohlc(candles: list[Candle]) -> int:
      """Count candles whose open/high/low/close values are unusable.

      A candle is invalid when any of the four prices is non-finite, or when
      the low lies above (or the high lies below) the open/close body.
      """

      def _is_invalid(candle: Candle) -> bool:
          prices = (candle.open, candle.high, candle.low, candle.close)
          if any(not _finite(price) for price in prices):
              return True
          body_floor = min(candle.open, candle.close)
          body_ceiling = max(candle.open, candle.close)
          return candle.low > body_floor or candle.high < body_ceiling

      return sum(1 for candle in candles if _is_invalid(candle))
  def cache_roundtrip_matches(candles: list[Candle], symbol: str, bar: str) -> bool:
      """Check that candles survive a save/load cycle through the candle cache.

      The candles are written to a throwaway directory with
      ``history_exhausted=True`` and read straight back. The result is truthy
      only when the flag comes back truthy, the row count is unchanged, and
      every candle compares equal to its original in order.
      """
      with TemporaryDirectory() as scratch:
          scratch_dir = Path(scratch)
          save_cached_candles(scratch_dir, symbol, bar, candles, history_exhausted=True)
          restored, exhausted_flag = load_cached_candles(scratch_dir, symbol, bar)
          if not exhausted_flag:
              # Hand back the falsy flag itself, as an ``and`` chain would.
              return exhausted_flag
          if len(restored) != len(candles):
              return False
          for original, reloaded in zip(candles, restored):
              if not _same_candle(original, reloaded):
                  return False
          return True
  def frame_conversion_matches(candles: list[Candle]) -> bool:
      """Check that the backtesting frame built from *candles* mirrors them.

      The frame is produced by ``candles_to_backtesting_frame`` with a default
      ``RSI2Config``. It matches when it has one row per candle and each row's
      index timestamp (in epoch milliseconds) and Open/High/Low/Close/Volume
      values equal the corresponding candle's fields, in order.

      Returns:
          True when every row matches its candle, False at the first difference.
      """
      # Imported inside the function, as the original block did.
      from okx_codex_trader.rsi2_report import RSI2Config

      frame = candles_to_backtesting_frame(candles, RSI2Config())
      if len(frame) != len(candles):
          return False
      for row, candle in zip(frame.itertuples(), candles):
          # ``timestamp()`` yields float seconds; multiplied by 1000 the product
          # can land a hair below the true integer. Truncating with ``int``
          # would then be 1 ms low and report a false mismatch, so round to the
          # nearest millisecond instead.
          row_ts_ms = round(row.Index.timestamp() * 1000)
          if row_ts_ms != candle.ts:
              return False
          row_values = (row.Open, row.High, row.Low, row.Close, row.Volume)
          candle_values = (
              candle.open,
              candle.high,
              candle.low,
              candle.close,
              candle.volume,
          )
          if row_values != candle_values:
              return False
      return True
  def count_live_mismatches(bar: str, cached: list[Candle], live: list[Candle]) -> tuple[int, int]:
      """Compare freshly fetched candles with the cached ones at the same timestamp.

      Live candles whose bar interval has not fully elapsed at call time are
      skipped, as are live candles with no cached counterpart.

      Returns:
          ``(overlap, mismatches)``: how many live candles were compared and
          how many of those differ from the cached candle.
      """
      current_ms = int(datetime.now(UTC).timestamp() * 1000)
      bar_length_ms = BAR_MS[bar]
      lookup = {candle.ts: candle for candle in cached}

      # Pair every closed live candle with the cached candle at the same timestamp.
      comparable = [
          (lookup[fresh.ts], fresh)
          for fresh in live
          if current_ms >= fresh.ts + bar_length_ms and lookup.get(fresh.ts) is not None
      ]
      differing = sum(1 for stored, fresh in comparable if not _same_candle(stored, fresh))
      return len(comparable), differing
  def audit_candles(
      *,
      candles: list[Candle],
      symbol: str,
      bar: str,
      live_candles: list[Candle] | None = None,
  ) -> DataAudit:
      """Run every integrity check on one candle series and bundle the results.

      Args:
          candles: Cached candles to audit, in stored order.
          symbol: Instrument the candles belong to.
          bar: Bar label; must be a key of ``BAR_MS`` when candles is non-empty.
          live_candles: Optional freshly fetched candles to compare against.

      Returns:
          A ``DataAudit``. For an empty series, all counters are zero and both
          match flags are False, so the audit does not pass.
      """
      if not candles:
          return DataAudit(
              symbol=symbol,
              bar=bar,
              rows=0,
              first_time="",
              last_time="",
              duplicate_timestamps=0,
              non_increasing_steps=0,
              unexpected_intervals=0,
              invalid_ohlc_rows=0,
              invalid_volume_rows=0,
              cache_roundtrip_matches=False,
              frame_conversion_matches=False,
              live_overlap_rows=0,
              live_mismatches=0,
          )

      stamps = [candle.ts for candle in candles]
      bar_length_ms = BAR_MS[bar]
      deltas = [later - earlier for earlier, later in zip(stamps, stamps[1:])]

      repeated = len(stamps) - len(set(stamps))
      # One pass over the deltas feeds both step counters.
      backward_steps = 0
      off_grid_steps = 0
      for delta in deltas:
          if delta <= 0:
              backward_steps += 1
          if delta != bar_length_ms:
              off_grid_steps += 1

      bad_volume = 0
      for candle in candles:
          if _finite(candle.volume) and candle.volume >= 0:
              continue
          bad_volume += 1

      overlap, mismatched = count_live_mismatches(bar, candles, live_candles or [])

      first_candle, last_candle = candles[0], candles[-1]
      return DataAudit(
          symbol=symbol,
          bar=bar,
          rows=len(candles),
          first_time=_format_ts(first_candle.ts),
          last_time=_format_ts(last_candle.ts),
          duplicate_timestamps=repeated,
          non_increasing_steps=backward_steps,
          unexpected_intervals=off_grid_steps,
          invalid_ohlc_rows=count_invalid_ohlc(candles),
          invalid_volume_rows=bad_volume,
          cache_roundtrip_matches=cache_roundtrip_matches(candles, symbol, bar),
          frame_conversion_matches=frame_conversion_matches(candles),
          live_overlap_rows=overlap,
          live_mismatches=mismatched,
      )
  def main() -> int:
      """Audit cached candles for every requested symbol/bar pair.

      Parses the command line, loads each series from the cache directory,
      optionally fetches the latest candles for comparison, prints one table
      row per pair, and reports the overall outcome.

      Returns:
          0 when every audited pair passed, 1 otherwise.

      Raises:
          SystemExit: via ``argparse`` (exit status 2) when the arguments are
              malformed, when ``--symbols`` or ``--bars`` names nothing, or
              when a bar label is not a key of ``BAR_MS``.
      """
      parser = argparse.ArgumentParser()
      parser.add_argument("--symbols", default="BTC-USDT-SWAP,ETH-USDT-SWAP")
      parser.add_argument("--bars", default="1m,3m,5m,15m")
      parser.add_argument("--cache-dir", type=Path, default=CANDLE_CACHE_DIR)
      parser.add_argument("--fetch-latest", action="store_true")
      parser.add_argument("--live-limit", type=int, default=300)
      args = parser.parse_args()

      symbols = [value.strip() for value in args.symbols.split(",") if value.strip()]
      bars = [value.strip() for value in args.bars.split(",") if value.strip()]

      # With nothing to audit the result frame would have no "passed" column
      # and the final lookup would raise KeyError; reject the request up front.
      if not symbols or not bars:
          parser.error("--symbols and --bars must each name at least one entry")

      # An unknown bar would otherwise surface as a bare KeyError from
      # BAR_MS[bar] once cached candles exist for it.
      unknown_bars = [bar for bar in bars if bar not in BAR_MS]
      if unknown_bars:
          parser.error(
              f"unsupported bar(s): {', '.join(unknown_bars)}; "
              f"expected one of: {', '.join(BAR_MS)}"
          )

      client = OkxClient() if args.fetch_latest else None
      rows = []
      for symbol in symbols:
          for bar in bars:
              candles, _ = load_cached_candles(args.cache_dir, symbol, bar)
              live_candles = (
                  client.get_candles(symbol, bar, args.live_limit)
                  if client is not None
                  else None
              )
              audit = audit_candles(
                  candles=candles, symbol=symbol, bar=bar, live_candles=live_candles
              )
              rows.append({**vars(audit), "passed": audit.passed})

      frame = pd.DataFrame(rows)
      print(frame.to_string(index=False))
      return 0 if bool(frame["passed"].all()) else 1
  # Script entry point: the exit status is whatever main() returns.
  if __name__ == "__main__":
      sys.exit(main())