cli.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. import argparse
  2. import json
  3. from datetime import UTC, datetime
  4. from dataclasses import asdict
  5. from pathlib import Path
  6. from typing import Callable, Sequence
  7. from okx_codex_trader.backtest import run_backtest
  8. from okx_codex_trader.codex_analyzer import analyze_with_codex
  9. from okx_codex_trader.config import Config, load_config
  10. from okx_codex_trader.bbmr_report import generate_bbmr_sampled_report
  11. from okx_codex_trader.bbsb_report import generate_bbsb_sampled_report
  12. from okx_codex_trader.donchian_report import DonchianConfig, generate_donchian_sampled_report
  13. from okx_codex_trader.ema_pullback_report import EMAPullbackConfig, generate_ema_pullback_sampled_report
  14. from okx_codex_trader.rsi2_report import RSI2Config, generate_rsi2_sampled_report
  15. from okx_codex_trader.paper_engine import apply_signal, load_state, save_state
  16. from okx_codex_trader.okx_client import OkxClient
  17. from okx_codex_trader.report import generate_backtest_report
  18. from okx_codex_trader.strategy import validate_signal
  19. SUPPORTED_SYMBOLS = ("BTC-USDT-SWAP", "ETH-USDT-SWAP")
  20. SAMPLED_REPORT_COMMANDS = {
  21. "backtest-bbmr-report": {
  22. "parser_args": (),
  23. "strategy_params": lambda args: {},
  24. },
  25. "backtest-bbsb-report": {
  26. "parser_args": (),
  27. "strategy_params": lambda args: {},
  28. },
  29. "backtest-donchian-report": {
  30. "parser_args": (
  31. (("--entry-window",), {"type": int, "default": DonchianConfig.entry_window}),
  32. (("--exit-window",), {"type": int, "default": DonchianConfig.exit_window}),
  33. (("--stop-loss-pct",), {"type": float, "default": DonchianConfig.stop_loss_pct}),
  34. ),
  35. "strategy_params": lambda args: {
  36. "entry_window": args.entry_window,
  37. "exit_window": args.exit_window,
  38. "stop_loss_pct": args.stop_loss_pct,
  39. },
  40. },
  41. "backtest-rsi2-report": {
  42. "parser_args": (
  43. (("--trend-sma",), {"type": int, "default": RSI2Config.trend_sma}),
  44. (("--rsi-length",), {"type": int, "default": RSI2Config.rsi_length}),
  45. (("--rsi-long-threshold",), {"type": float, "default": RSI2Config.rsi_long_threshold}),
  46. (("--rsi-short-threshold",), {"type": float, "default": RSI2Config.rsi_short_threshold}),
  47. (("--exit-rsi",), {"type": float, "default": RSI2Config.exit_rsi}),
  48. ),
  49. "strategy_params": lambda args: {
  50. "trend_sma": args.trend_sma,
  51. "rsi_length": args.rsi_length,
  52. "rsi_long_threshold": args.rsi_long_threshold,
  53. "rsi_short_threshold": args.rsi_short_threshold,
  54. "exit_rsi": args.exit_rsi,
  55. },
  56. },
  57. "backtest-ema-pullback-report": {
  58. "parser_args": (
  59. (("--fast-ema",), {"type": int, "default": EMAPullbackConfig.fast_ema}),
  60. (("--slow-ema",), {"type": int, "default": EMAPullbackConfig.slow_ema}),
  61. (("--stop-buffer-pct",), {"type": float, "default": EMAPullbackConfig.stop_buffer_pct}),
  62. ),
  63. "strategy_params": lambda args: {
  64. "fast_ema": args.fast_ema,
  65. "slow_ema": args.slow_ema,
  66. "stop_buffer_pct": args.stop_buffer_pct,
  67. },
  68. },
  69. }
  70. def _add_sampled_report_parser(
  71. subparsers: argparse._SubParsersAction,
  72. command: str,
  73. parser_args: tuple[tuple[tuple[str, ...], dict[str, object]], ...],
  74. ) -> None:
  75. parser = subparsers.add_parser(command)
  76. parser.add_argument("--symbol", choices=SUPPORTED_SYMBOLS, required=True)
  77. parser.add_argument("--bar", required=True)
  78. parser.add_argument("--history-limit", type=int, required=True)
  79. parser.add_argument("--leverage", type=int, choices=(1, 2, 3), required=True)
  80. parser.add_argument("--segments", type=int, required=True)
  81. parser.add_argument("--window-size", type=int, required=True)
  82. parser.add_argument("--output-file", required=True)
  83. for flags, kwargs in parser_args:
  84. parser.add_argument(*flags, **kwargs)
  85. def build_parser() -> argparse.ArgumentParser:
  86. parser = argparse.ArgumentParser(prog="okx-codex-trader")
  87. subparsers = parser.add_subparsers(dest="command", required=True)
  88. fetch_history = subparsers.add_parser("fetch-history")
  89. fetch_history.add_argument("--symbol", choices=SUPPORTED_SYMBOLS, required=True)
  90. fetch_history.add_argument("--bar", required=True)
  91. fetch_history.add_argument("--limit", type=int, required=True)
  92. backtest = subparsers.add_parser("backtest")
  93. backtest.add_argument("--symbol", choices=SUPPORTED_SYMBOLS, required=True)
  94. backtest.add_argument("--bar", required=True)
  95. backtest.add_argument("--limit", type=int, required=True)
  96. backtest.add_argument("--leverage", type=int, choices=(1, 2, 3), required=True)
  97. backtest_report = subparsers.add_parser("backtest-report")
  98. backtest_report.add_argument("--symbol", choices=SUPPORTED_SYMBOLS, required=True)
  99. backtest_report.add_argument("--bar", required=True)
  100. backtest_report.add_argument("--limit", type=int, required=True)
  101. backtest_report.add_argument("--leverage", type=int, choices=(1, 2, 3), required=True)
  102. backtest_report.add_argument("--output-file", required=True)
  103. for command, settings in SAMPLED_REPORT_COMMANDS.items():
  104. _add_sampled_report_parser(subparsers, command, settings["parser_args"])
  105. analyze = subparsers.add_parser("analyze")
  106. analyze.add_argument("--symbol", choices=SUPPORTED_SYMBOLS, required=True)
  107. analyze.add_argument("--bar", required=True)
  108. analyze.add_argument("--limit", type=int, required=True)
  109. analyze.add_argument("--output-file", required=True)
  110. paper_order = subparsers.add_parser("paper-order")
  111. paper_order.add_argument("--symbol", choices=SUPPORTED_SYMBOLS, required=True)
  112. paper_order.add_argument("--signal-file", required=True)
  113. paper_order.add_argument("--margin-usdt", type=float, required=True)
  114. positions = subparsers.add_parser("positions")
  115. positions.add_argument("--symbol", choices=SUPPORTED_SYMBOLS, required=True)
  116. okx_account = subparsers.add_parser("okx-account")
  117. okx_account.add_argument("--symbol", choices=SUPPORTED_SYMBOLS, required=True)
  118. okx_account.add_argument("--currency", default="USDT")
  119. okx_order = subparsers.add_parser("okx-order")
  120. okx_order.add_argument("--symbol", choices=SUPPORTED_SYMBOLS, required=True)
  121. okx_order.add_argument("--signal-file", required=True)
  122. okx_order.add_argument("--margin-usdt", type=float, required=True)
  123. okx_order.add_argument("--max-margin-usdt", type=float, required=True)
  124. okx_order.add_argument("--confirm-live", action="store_true")
  125. return parser
  126. def _write_text(path: str, text: str) -> None:
  127. Path(path).write_text(text)
  128. def _dump_json(payload: object) -> str:
  129. return json.dumps(payload, indent=2)
  130. def _now_iso() -> str:
  131. return datetime.now(UTC).isoformat(timespec="seconds").replace("+00:00", "Z")
  132. def main_factory(
  133. *,
  134. load_config: Callable[[], Config] = load_config,
  135. client_factory: Callable[[], OkxClient] = OkxClient,
  136. authenticated_client_factory: Callable[[Config], OkxClient] = OkxClient,
  137. analyze_fn: Callable = analyze_with_codex,
  138. report_fn: Callable = generate_backtest_report,
  139. bbmr_report_fn: Callable = generate_bbmr_sampled_report,
  140. bbsb_report_fn: Callable = generate_bbsb_sampled_report,
  141. donchian_report_fn: Callable = generate_donchian_sampled_report,
  142. rsi2_report_fn: Callable = generate_rsi2_sampled_report,
  143. ema_pullback_report_fn: Callable = generate_ema_pullback_sampled_report,
  144. write_text: Callable[[str, str], None] = _write_text,
  145. state_path: Path = Path("paper_state.json"),
  146. now_fn: Callable[[], str] = _now_iso,
  147. ):
  148. sampled_report_generators = {
  149. "backtest-bbmr-report": bbmr_report_fn,
  150. "backtest-bbsb-report": bbsb_report_fn,
  151. "backtest-donchian-report": donchian_report_fn,
  152. "backtest-rsi2-report": rsi2_report_fn,
  153. "backtest-ema-pullback-report": ema_pullback_report_fn,
  154. }
  155. def main(argv: Sequence[str] | None = None) -> int:
  156. parser = build_parser()
  157. args = parser.parse_args(argv)
  158. client = client_factory()
  159. if args.command == "fetch-history":
  160. candles = client.get_candles(args.symbol, args.bar, args.limit)
  161. print(_dump_json([asdict(candle) for candle in candles]))
  162. return 0
  163. if args.command == "backtest":
  164. candles = client.get_candles(args.symbol, args.bar, args.limit)
  165. print(_dump_json(run_backtest(candles=candles, leverage=args.leverage).to_dict()))
  166. return 0
  167. if args.command == "backtest-report":
  168. candles = client.get_candles(args.symbol, args.bar, args.limit)
  169. report = report_fn(
  170. candles=candles,
  171. leverage=args.leverage,
  172. output_file=Path(args.output_file),
  173. symbol=args.symbol,
  174. bar=args.bar,
  175. )
  176. print(_dump_json(report))
  177. return 0
  178. if args.command in SAMPLED_REPORT_COMMANDS:
  179. candles = client.get_candles(args.symbol, args.bar, args.history_limit)
  180. report = sampled_report_generators[args.command](
  181. candles=candles,
  182. leverage=args.leverage,
  183. output_file=Path(args.output_file),
  184. symbol=args.symbol,
  185. bar=args.bar,
  186. segments=args.segments,
  187. window_size=args.window_size,
  188. **SAMPLED_REPORT_COMMANDS[args.command]["strategy_params"](args),
  189. )
  190. print(_dump_json(report))
  191. return 0
  192. if args.command == "analyze":
  193. candles = client.get_candles(args.symbol, args.bar, args.limit)
  194. signal = analyze_fn(candles=candles, symbol=args.symbol, bar=args.bar)
  195. output = _dump_json(asdict(signal))
  196. write_text(args.output_file, output)
  197. print(output)
  198. return 0
  199. if args.command == "paper-order":
  200. state = load_state(state_path)
  201. signal = validate_signal(json.loads(Path(args.signal_file).read_text()))
  202. price = signal.entry_price if signal.entry_price is not None else client.get_last_price(args.symbol)
  203. next_state, order = apply_signal(
  204. state=state,
  205. symbol=args.symbol,
  206. signal=signal,
  207. margin_usdt=args.margin_usdt,
  208. price=price,
  209. now=now_fn,
  210. )
  211. save_state(state_path, next_state)
  212. print(_dump_json(asdict(order)))
  213. return 0
  214. if args.command == "okx-account":
  215. auth_client = authenticated_client_factory(load_config())
  216. print(
  217. _dump_json(
  218. {
  219. "balance": auth_client.get_account_balance(args.currency),
  220. "positions": [asdict(position) for position in auth_client.get_positions(args.symbol)],
  221. }
  222. )
  223. )
  224. return 0
  225. if args.command == "okx-order":
  226. config = load_config()
  227. if config.trading_env == "live" and not args.confirm_live:
  228. raise ValueError("live order requires --confirm-live")
  229. if args.margin_usdt > args.max_margin_usdt:
  230. raise ValueError("margin_usdt exceeds max_margin_usdt")
  231. signal = validate_signal(json.loads(Path(args.signal_file).read_text()))
  232. order = authenticated_client_factory(config).place_order(
  233. symbol=args.symbol,
  234. signal=signal,
  235. margin_usdt=args.margin_usdt,
  236. )
  237. print(_dump_json(asdict(order)))
  238. return 0
  239. state = load_state(state_path)
  240. positions = [asdict(position) for position in state.positions if position.symbol == args.symbol]
  241. if not state_path.exists():
  242. save_state(state_path, state)
  243. print(_dump_json(positions))
  244. return 0
  245. return main
  246. main = main_factory()
  247. if __name__ == "__main__":
  248. raise SystemExit(main())