test_okx_client.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. from dataclasses import dataclass
  2. from urllib.parse import urlparse
  3. import pytest
  4. from okx_codex_trader.config import Config
  5. from okx_codex_trader.models import InstrumentMeta, TradeSignal
  6. from okx_codex_trader.okx_client import OkxClient, build_contract_size
  7. @dataclass
  8. class DummyResponse:
  9. payload: dict[str, object]
  10. status_code: int = 200
  11. def json(self) -> dict[str, object]:
  12. return self.payload
  13. @dataclass
  14. class RecordedRequest:
  15. method: str
  16. url: str
  17. headers: dict[str, str]
  18. params: dict[str, object] | None
  19. json_body: dict[str, object] | None
  20. class DummySession:
  21. def __init__(self, responses: list[DummyResponse] | None = None):
  22. self._responses = list(responses or [])
  23. self.last_request: RecordedRequest | None = None
  24. self.request_paths: list[str] = []
  25. self.request_bodies: list[dict[str, object] | None] = []
  26. @property
  27. def last_json_body(self) -> dict[str, object] | None:
  28. return self.last_request.json_body if self.last_request else None
  29. def request(
  30. self,
  31. method: str,
  32. url: str,
  33. *,
  34. headers: dict[str, str] | None = None,
  35. params: dict[str, object] | None = None,
  36. json: dict[str, object] | None = None,
  37. ) -> DummyResponse:
  38. self.last_request = RecordedRequest(
  39. method=method,
  40. url=url,
  41. headers=headers or {},
  42. params=params,
  43. json_body=json,
  44. )
  45. self.request_paths.append(urlparse(url).path)
  46. self.request_bodies.append(json)
  47. if self._responses:
  48. return self._responses.pop(0)
  49. return candles_response()
  50. def sample_config() -> Config:
  51. return Config(api_key="key", api_secret="secret", api_passphrase="passphrase")
  52. def candles_response() -> DummyResponse:
  53. return DummyResponse(
  54. {
  55. "code": "0",
  56. "msg": "",
  57. "data": [
  58. ["1710000000000", "25000", "25100", "24900", "25050", "100", "1000", "1000", "1"],
  59. ],
  60. }
  61. )
  62. def instrument_response() -> DummyResponse:
  63. return DummyResponse(
  64. {
  65. "code": "0",
  66. "msg": "",
  67. "data": [
  68. {
  69. "instId": "BTC-USDT-SWAP",
  70. "instType": "SWAP",
  71. "ctVal": "0.001",
  72. "lotSz": "1",
  73. "minSz": "1",
  74. }
  75. ],
  76. }
  77. )
  78. def ticker_response(last: str) -> DummyResponse:
  79. return DummyResponse({"code": "0", "msg": "", "data": [{"instId": "BTC-USDT-SWAP", "last": last}]})
  80. def account_config_response(pos_mode: str) -> DummyResponse:
  81. return DummyResponse({"code": "0", "msg": "", "data": [{"posMode": pos_mode}]})
  82. def leverage_response() -> DummyResponse:
  83. return DummyResponse({"code": "0", "msg": "", "data": [{"lever": "2"}]})
  84. def place_order_response() -> DummyResponse:
  85. return DummyResponse({"code": "0", "msg": "", "data": [{"ordId": "123"}]})
  86. def place_order_response_without_order_id() -> DummyResponse:
  87. return DummyResponse({"code": "0", "msg": "", "data": [{}]})
  88. def error_response(code: str, msg: str) -> DummyResponse:
  89. return DummyResponse({"code": code, "msg": msg, "data": []})
  90. def positions_response() -> DummyResponse:
  91. return DummyResponse(
  92. {
  93. "code": "0",
  94. "msg": "",
  95. "data": [
  96. {
  97. "instId": "BTC-USDT-SWAP",
  98. "posSide": "long",
  99. "pos": "8",
  100. "avgPx": "25000",
  101. }
  102. ],
  103. }
  104. )
  105. def market_long_signal() -> TradeSignal:
  106. return TradeSignal(
  107. action="long",
  108. confidence=0.9,
  109. leverage=2,
  110. entry_price=None,
  111. take_profit_price=26000.0,
  112. stop_loss_price=24000.0,
  113. reason="trend",
  114. )
  115. def limit_short_signal() -> TradeSignal:
  116. return TradeSignal(
  117. action="short",
  118. confidence=0.8,
  119. leverage=2,
  120. entry_price=25000.0,
  121. take_profit_price=24000.0,
  122. stop_loss_price=25500.0,
  123. reason="mean reversion",
  124. )
  125. def flat_signal() -> TradeSignal:
  126. return TradeSignal(
  127. action="flat",
  128. confidence=0.7,
  129. leverage=2,
  130. entry_price=None,
  131. take_profit_price=None,
  132. stop_loss_price=None,
  133. reason="exit",
  134. )
  135. def test_signed_demo_request_attaches_headers():
  136. session = DummySession()
  137. client = OkxClient(config=sample_config(), session=session)
  138. client.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
  139. request = session.last_request
  140. assert request is not None
  141. assert request.headers["x-simulated-trading"] == "1"
  142. assert request.headers["OK-ACCESS-KEY"] == "key"
  143. assert request.headers["OK-ACCESS-SIGN"]
  144. assert request.headers["OK-ACCESS-TIMESTAMP"]
  145. assert request.headers["OK-ACCESS-PASSPHRASE"] == "passphrase"
  146. def test_build_contract_size_rounds_down_to_lot_size():
  147. metadata = InstrumentMeta(ct_val=0.01, lot_sz=0.1, min_sz=0.1)
  148. assert build_contract_size(notional=251, price=25_000, metadata=metadata) == 1.0
  149. def test_build_contract_size_fails_below_min_size():
  150. metadata = InstrumentMeta(ct_val=0.01, lot_sz=1, min_sz=5)
  151. with pytest.raises(ValueError):
  152. build_contract_size(notional=250, price=25_100, metadata=metadata)
  153. def test_market_order_fetches_latest_price_before_sizing():
  154. session = DummySession(
  155. [
  156. instrument_response(),
  157. ticker_response(last="25000"),
  158. account_config_response(pos_mode="long_short_mode"),
  159. leverage_response(),
  160. place_order_response(),
  161. ]
  162. )
  163. client = OkxClient(config=sample_config(), session=session)
  164. client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
  165. assert session.request_paths == [
  166. "/api/v5/public/instruments",
  167. "/api/v5/market/ticker",
  168. "/api/v5/account/config",
  169. "/api/v5/account/set-leverage",
  170. "/api/v5/trade/order",
  171. ]
  172. def test_place_demo_order_fails_when_not_hedge_mode():
  173. session = DummySession(
  174. [
  175. instrument_response(),
  176. ticker_response(last="25000"),
  177. account_config_response(pos_mode="net_mode"),
  178. ]
  179. )
  180. client = OkxClient(config=sample_config(), session=session)
  181. with pytest.raises(ValueError):
  182. client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
  183. def test_limit_short_order_uses_sell_and_short_pos_side():
  184. session = DummySession(
  185. [
  186. instrument_response(),
  187. account_config_response(pos_mode="long_short_mode"),
  188. leverage_response(),
  189. place_order_response(),
  190. ]
  191. )
  192. client = OkxClient(config=sample_config(), session=session)
  193. client.place_demo_order(symbol="ETH-USDT-SWAP", signal=limit_short_signal(), margin_usdt=100)
  194. order_request = session.last_json_body
  195. assert order_request is not None
  196. assert order_request["ordType"] == "limit"
  197. assert order_request["side"] == "sell"
  198. assert order_request["posSide"] == "short"
  199. assert order_request["px"] == "25000"
  200. assert session.request_bodies[2]["lever"] == "2"
  201. assert session.request_bodies[2]["mgnMode"] == "isolated"
  202. def test_flat_signal_returns_noop_without_order_submission():
  203. session = DummySession([])
  204. client = OkxClient(config=sample_config(), session=session)
  205. result = client.place_demo_order(symbol="BTC-USDT-SWAP", signal=flat_signal(), margin_usdt=100)
  206. assert result.status == "noop"
  207. assert session.request_paths == []
  208. def test_place_demo_order_sends_computed_sz_and_ignores_tp_sl_fields():
  209. session = DummySession(
  210. [
  211. instrument_response(),
  212. ticker_response(last="25000"),
  213. account_config_response(pos_mode="long_short_mode"),
  214. leverage_response(),
  215. place_order_response(),
  216. ]
  217. )
  218. client = OkxClient(config=sample_config(), session=session)
  219. client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
  220. order_request = session.last_json_body
  221. assert order_request is not None
  222. assert order_request["sz"] == "8"
  223. assert "tpTriggerPx" not in order_request
  224. assert "slTriggerPx" not in order_request
  225. def test_okx_error_payload_raises_value_error():
  226. session = DummySession([error_response(code="51000", msg="parameter error")])
  227. client = OkxClient(config=sample_config(), session=session)
  228. with pytest.raises(ValueError):
  229. client.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
  230. def test_empty_positions_data_returns_empty_list():
  231. session = DummySession([DummyResponse({"code": "0", "msg": "", "data": []})])
  232. client = OkxClient(config=sample_config(), session=session)
  233. assert client.get_positions(symbol="BTC-USDT-SWAP") == []
  234. def test_malformed_numeric_field_raises_stable_value_error():
  235. session = DummySession(
  236. [
  237. DummyResponse(
  238. {
  239. "code": "0",
  240. "msg": "",
  241. "data": [
  242. {
  243. "instId": "BTC-USDT-SWAP",
  244. "posSide": "long",
  245. "pos": "bad",
  246. "avgPx": "25000",
  247. }
  248. ],
  249. }
  250. )
  251. ]
  252. )
  253. client = OkxClient(config=sample_config(), session=session)
  254. with pytest.raises(ValueError, match="okx response payload is invalid"):
  255. client.get_positions(symbol="BTC-USDT-SWAP")
  256. def test_non_list_okx_data_raises_stable_value_error():
  257. session = DummySession([DummyResponse({"code": "0", "msg": "", "data": {}})])
  258. client = OkxClient(config=sample_config(), session=session)
  259. with pytest.raises(ValueError, match="okx response payload is invalid"):
  260. client.get_positions(symbol="BTC-USDT-SWAP")
  261. def test_place_demo_order_raises_when_order_id_is_missing():
  262. session = DummySession(
  263. [
  264. instrument_response(),
  265. ticker_response(last="25000"),
  266. account_config_response(pos_mode="long_short_mode"),
  267. leverage_response(),
  268. place_order_response_without_order_id(),
  269. ]
  270. )
  271. client = OkxClient(config=sample_config(), session=session)
  272. with pytest.raises(ValueError, match="okx response payload is invalid"):
  273. client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
  274. def test_place_demo_order_rejects_invalid_leverage_before_okx():
  275. session = DummySession([])
  276. signal = TradeSignal(
  277. action="long",
  278. confidence=0.9,
  279. leverage=4,
  280. entry_price=None,
  281. take_profit_price=None,
  282. stop_loss_price=None,
  283. reason="x",
  284. )
  285. client = OkxClient(config=sample_config(), session=session)
  286. with pytest.raises(ValueError, match="leverage is invalid"):
  287. client.place_demo_order(symbol="BTC-USDT-SWAP", signal=signal, margin_usdt=100)
  288. assert session.request_paths == []
  289. def test_get_positions_returns_normalized_positions():
  290. session = DummySession([positions_response()])
  291. client = OkxClient(config=sample_config(), session=session)
  292. positions = client.get_positions(symbol="BTC-USDT-SWAP")
  293. assert positions[0].symbol == "BTC-USDT-SWAP"
  294. assert positions[0].pos_side == "long"
  295. assert positions[0].size == 8.0
  296. assert positions[0].avg_price == 25000.0