| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288 |
- import argparse
- import json
- from datetime import UTC, datetime
- from dataclasses import asdict
- from pathlib import Path
- from typing import Callable, Sequence
- from okx_codex_trader.backtest import run_backtest
- from okx_codex_trader.codex_analyzer import analyze_with_codex
- from okx_codex_trader.config import Config, load_config
- from okx_codex_trader.bbmr_report import generate_bbmr_sampled_report
- from okx_codex_trader.bbsb_report import generate_bbsb_sampled_report
- from okx_codex_trader.donchian_report import DonchianConfig, generate_donchian_sampled_report
- from okx_codex_trader.ema_pullback_report import EMAPullbackConfig, generate_ema_pullback_sampled_report
- from okx_codex_trader.rsi2_report import RSI2Config, generate_rsi2_sampled_report
- from okx_codex_trader.paper_engine import apply_signal, load_state, save_state
- from okx_codex_trader.okx_client import OkxClient
- from okx_codex_trader.report import generate_backtest_report
- from okx_codex_trader.strategy import validate_signal
# Instruments every subcommand accepts via --symbol.
SUPPORTED_SYMBOLS = ("BTC-USDT-SWAP", "ETH-USDT-SWAP")


def _option(flag: str, value_type: type, default: object) -> tuple[tuple[str, ...], dict[str, object]]:
    """Describe one strategy-specific CLI option as ``(flags, add_argument kwargs)``."""
    return (flag,), {"type": value_type, "default": default}


def _params_from(*dests: str) -> Callable[[argparse.Namespace], dict[str, object]]:
    """Return a function copying the named attributes of parsed args into a fresh dict."""
    return lambda args: {dest: getattr(args, dest) for dest in dests}


# Sampled (multi-window) backtest report subcommands. For each command:
#   "parser_args"     - extra add_argument() specs, registered after the shared
#                       options added by _add_sampled_report_parser;
#   "strategy_params" - maps the parsed args to the extra keyword arguments
#                       forwarded to that command's report generator.
# Strategy defaults are read from the corresponding *Config classes.
SAMPLED_REPORT_COMMANDS = {
    "backtest-bbmr-report": {
        "parser_args": (),
        "strategy_params": _params_from(),
    },
    "backtest-bbsb-report": {
        "parser_args": (),
        "strategy_params": _params_from(),
    },
    "backtest-donchian-report": {
        "parser_args": (
            _option("--entry-window", int, DonchianConfig.entry_window),
            _option("--exit-window", int, DonchianConfig.exit_window),
            _option("--stop-loss-pct", float, DonchianConfig.stop_loss_pct),
        ),
        "strategy_params": _params_from("entry_window", "exit_window", "stop_loss_pct"),
    },
    "backtest-rsi2-report": {
        "parser_args": (
            _option("--trend-sma", int, RSI2Config.trend_sma),
            _option("--rsi-length", int, RSI2Config.rsi_length),
            _option("--rsi-long-threshold", float, RSI2Config.rsi_long_threshold),
            _option("--rsi-short-threshold", float, RSI2Config.rsi_short_threshold),
            _option("--exit-rsi", float, RSI2Config.exit_rsi),
        ),
        "strategy_params": _params_from(
            "trend_sma",
            "rsi_length",
            "rsi_long_threshold",
            "rsi_short_threshold",
            "exit_rsi",
        ),
    },
    "backtest-ema-pullback-report": {
        "parser_args": (
            _option("--fast-ema", int, EMAPullbackConfig.fast_ema),
            _option("--slow-ema", int, EMAPullbackConfig.slow_ema),
            _option("--stop-buffer-pct", float, EMAPullbackConfig.stop_buffer_pct),
        ),
        "strategy_params": _params_from("fast_ema", "slow_ema", "stop_buffer_pct"),
    },
}
def _add_sampled_report_parser(
    subparsers: argparse._SubParsersAction,
    command: str,
    parser_args: tuple[tuple[tuple[str, ...], dict[str, object]], ...],
) -> None:
    """Register *command* as a sampled-report subcommand.

    Every sampled report takes the same market, sampling and output options.
    The strategy-specific ``(flags, kwargs)`` pairs in *parser_args* are
    registered after them, in the order given.
    """
    sub = subparsers.add_parser(command)
    shared_options = (
        (("--symbol",), {"choices": SUPPORTED_SYMBOLS, "required": True}),
        (("--bar",), {"required": True}),
        (("--history-limit",), {"type": int, "required": True}),
        (("--leverage",), {"type": int, "choices": (1, 2, 3), "required": True}),
        (("--segments",), {"type": int, "required": True}),
        (("--window-size",), {"type": int, "required": True}),
        (("--output-file",), {"required": True}),
    )
    for flags, options in (*shared_options, *parser_args):
        sub.add_argument(*flags, **options)
def build_parser() -> argparse.ArgumentParser:
    """Build the ``okx-codex-trader`` argument parser.

    Exactly one subcommand is required. Each subcommand takes ``--symbol``
    restricted to ``SUPPORTED_SYMBOLS``. The sampled backtest reports are
    generated from ``SAMPLED_REPORT_COMMANDS``.
    """
    root = argparse.ArgumentParser(prog="okx-codex-trader")
    commands = root.add_subparsers(dest="command", required=True)

    def new_command(name: str) -> argparse.ArgumentParser:
        # Every subcommand operates on exactly one supported instrument.
        sub = commands.add_parser(name)
        sub.add_argument("--symbol", choices=SUPPORTED_SYMBOLS, required=True)
        return sub

    def add_candle_window(sub: argparse.ArgumentParser) -> None:
        # Candle granularity plus how many candles to request.
        sub.add_argument("--bar", required=True)
        sub.add_argument("--limit", type=int, required=True)

    def add_leverage(sub: argparse.ArgumentParser) -> None:
        sub.add_argument("--leverage", type=int, choices=(1, 2, 3), required=True)

    cmd = new_command("fetch-history")
    add_candle_window(cmd)

    cmd = new_command("backtest")
    add_candle_window(cmd)
    add_leverage(cmd)

    cmd = new_command("backtest-report")
    add_candle_window(cmd)
    add_leverage(cmd)
    cmd.add_argument("--output-file", required=True)

    for name, spec in SAMPLED_REPORT_COMMANDS.items():
        _add_sampled_report_parser(commands, name, spec["parser_args"])

    cmd = new_command("analyze")
    add_candle_window(cmd)
    cmd.add_argument("--output-file", required=True)

    cmd = new_command("paper-order")
    cmd.add_argument("--signal-file", required=True)
    cmd.add_argument("--margin-usdt", type=float, required=True)

    new_command("positions")

    cmd = new_command("okx-account")
    cmd.add_argument("--currency", default="USDT")

    cmd = new_command("okx-order")
    cmd.add_argument("--signal-file", required=True)
    cmd.add_argument("--margin-usdt", type=float, required=True)
    cmd.add_argument("--max-margin-usdt", type=float, required=True)
    cmd.add_argument("--confirm-live", action="store_true")

    return root
- def _write_text(path: str, text: str) -> None:
- Path(path).write_text(text)
- def _dump_json(payload: object) -> str:
- return json.dumps(payload, indent=2)
- def _now_iso() -> str:
- return datetime.now(UTC).isoformat(timespec="seconds").replace("+00:00", "Z")
def _read_signal_payload(path: str) -> object:
    """Read and JSON-decode the signal file at *path*.

    The file is decoded explicitly as UTF-8 rather than with the platform's
    locale-dependent default encoding.
    """
    return json.loads(Path(path).read_text(encoding="utf-8"))


def main_factory(
    *,
    load_config: Callable[[], Config] = load_config,
    client_factory: Callable[[], OkxClient] = OkxClient,
    authenticated_client_factory: Callable[[Config], OkxClient] = OkxClient,
    analyze_fn: Callable = analyze_with_codex,
    report_fn: Callable = generate_backtest_report,
    bbmr_report_fn: Callable = generate_bbmr_sampled_report,
    bbsb_report_fn: Callable = generate_bbsb_sampled_report,
    donchian_report_fn: Callable = generate_donchian_sampled_report,
    rsi2_report_fn: Callable = generate_rsi2_sampled_report,
    ema_pullback_report_fn: Callable = generate_ema_pullback_sampled_report,
    write_text: Callable[[str, str], None] = _write_text,
    state_path: Path = Path("paper_state.json"),
    now_fn: Callable[[], str] = _now_iso,
) -> Callable[[Sequence[str] | None], int]:
    """Build the CLI entry point with its collaborators injected.

    Every keyword has a production default, so tests can substitute fakes.

    Args:
        load_config: Returns the ``Config`` used by ``okx-account`` and
            ``okx-order``.
        client_factory: Builds the unauthenticated client used for market
            data (``get_candles`` / ``get_last_price``). It is only invoked
            by commands that actually need market data.
        authenticated_client_factory: Builds a client from ``Config`` for
            account queries and order placement.
        analyze_fn: Produces the signal for ``analyze``.
        report_fn: Generates the ``backtest-report`` report.
        bbmr_report_fn, bbsb_report_fn, donchian_report_fn, rsi2_report_fn,
        ema_pullback_report_fn: Generators for the matching sampled
            ``backtest-*-report`` commands.
        write_text: Called as ``write_text(path, text)`` to save the
            ``analyze`` output.
        state_path: Paper-trading state file read by ``load_state`` and
            written by ``save_state``.
        now_fn: Passed through to ``apply_signal`` as ``now``.

    Returns:
        ``main(argv)``, which runs one subcommand, prints its JSON result and
        returns exit status 0.
    """
    sampled_report_generators = {
        "backtest-bbmr-report": bbmr_report_fn,
        "backtest-bbsb-report": bbsb_report_fn,
        "backtest-donchian-report": donchian_report_fn,
        "backtest-rsi2-report": rsi2_report_fn,
        "backtest-ema-pullback-report": ema_pullback_report_fn,
    }

    def main(argv: Sequence[str] | None = None) -> int:
        """Run the subcommand selected by *argv* (``sys.argv[1:]`` when None).

        Raises:
            ValueError: For ``okx-order``, when a live order lacks
                ``--confirm-live``, when the margin is not a positive number,
                or when it exceeds ``--max-margin-usdt``.
            SystemExit: On argument-parsing errors (raised by argparse).
        """
        parser = build_parser()
        args = parser.parse_args(argv)

        if args.command == "fetch-history":
            candles = client_factory().get_candles(args.symbol, args.bar, args.limit)
            print(_dump_json([asdict(candle) for candle in candles]))
            return 0

        if args.command == "backtest":
            candles = client_factory().get_candles(args.symbol, args.bar, args.limit)
            print(_dump_json(run_backtest(candles=candles, leverage=args.leverage).to_dict()))
            return 0

        if args.command == "backtest-report":
            candles = client_factory().get_candles(args.symbol, args.bar, args.limit)
            report = report_fn(
                candles=candles,
                leverage=args.leverage,
                output_file=Path(args.output_file),
                symbol=args.symbol,
                bar=args.bar,
            )
            print(_dump_json(report))
            return 0

        if args.command in SAMPLED_REPORT_COMMANDS:
            candles = client_factory().get_candles(args.symbol, args.bar, args.history_limit)
            report = sampled_report_generators[args.command](
                candles=candles,
                leverage=args.leverage,
                output_file=Path(args.output_file),
                symbol=args.symbol,
                bar=args.bar,
                segments=args.segments,
                window_size=args.window_size,
                # Strategy-specific keyword arguments for this command.
                **SAMPLED_REPORT_COMMANDS[args.command]["strategy_params"](args),
            )
            print(_dump_json(report))
            return 0

        if args.command == "analyze":
            candles = client_factory().get_candles(args.symbol, args.bar, args.limit)
            signal = analyze_fn(candles=candles, symbol=args.symbol, bar=args.bar)
            output = _dump_json(asdict(signal))
            write_text(args.output_file, output)
            print(output)
            return 0

        if args.command == "paper-order":
            state = load_state(state_path)
            signal = validate_signal(_read_signal_payload(args.signal_file))
            if signal.entry_price is not None:
                price = signal.entry_price
            else:
                # The signal does not pin a price: fall back to the client's last price.
                price = client_factory().get_last_price(args.symbol)
            next_state, order = apply_signal(
                state=state,
                symbol=args.symbol,
                signal=signal,
                margin_usdt=args.margin_usdt,
                price=price,
                now=now_fn,
            )
            save_state(state_path, next_state)
            print(_dump_json(asdict(order)))
            return 0

        if args.command == "okx-account":
            auth_client = authenticated_client_factory(load_config())
            print(
                _dump_json(
                    {
                        "balance": auth_client.get_account_balance(args.currency),
                        "positions": [asdict(position) for position in auth_client.get_positions(args.symbol)],
                    }
                )
            )
            return 0

        if args.command == "okx-order":
            config = load_config()
            if config.trading_env == "live" and not args.confirm_live:
                raise ValueError("live order requires --confirm-live")
            # argparse's type=float accepts "nan", and every comparison with NaN
            # is False. The negated forms below reject NaN instead of silently
            # letting it past the size guards.
            if not args.margin_usdt > 0:
                raise ValueError("margin_usdt must be positive")
            if not args.margin_usdt <= args.max_margin_usdt:
                raise ValueError("margin_usdt exceeds max_margin_usdt")
            signal = validate_signal(_read_signal_payload(args.signal_file))
            order = authenticated_client_factory(config).place_order(
                symbol=args.symbol,
                signal=signal,
                margin_usdt=args.margin_usdt,
            )
            print(_dump_json(asdict(order)))
            return 0

        if args.command == "positions":
            state = load_state(state_path)
            positions = [asdict(position) for position in state.positions if position.symbol == args.symbol]
            if not state_path.exists():
                # Persist the state returned by load_state so the file exists
                # after the first query (presumably a default/empty state).
                save_state(state_path, state)
            print(_dump_json(positions))
            return 0

        # Unreachable while every registered subcommand is handled above.
        parser.error(f"unsupported command: {args.command}")

    return main
# Production entry point, wired to main_factory's default collaborators.
main = main_factory()

if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
|