test_okx_client.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. from dataclasses import dataclass
  2. from urllib.parse import urlparse
  3. import pytest
  4. from okx_codex_trader.config import Config
  5. from okx_codex_trader.models import InstrumentMeta, TradeSignal
  6. from okx_codex_trader.okx_client import OkxClient, build_contract_size
  @dataclass
  class DummyResponse:
      """Minimal response stand-in that hands back a pre-built payload.

      Instances are queued into ``DummySession`` and returned from its
      ``request`` method in place of a real HTTP response.
      """

      # Body returned verbatim by ``json()``.
      payload: dict[str, object]
      # Status code attribute; 200 unless the test supplies another value.
      status_code: int = 200

      def json(self) -> dict[str, object]:
          """Return the stored payload object unchanged (no copy is made)."""
          body = self.payload
          return body
  13. @dataclass
  14. class RecordedRequest:
  15. method: str
  16. url: str
  17. headers: dict[str, str]
  18. params: dict[str, object] | None
  19. json_body: dict[str, object] | None
  class DummySession:
      """Session test double that records requests and replays canned responses.

      Responses passed to the constructor are returned in order, one per
      ``request`` call. Once they run out (or if none were supplied), every
      call returns ``candles_response()`` instead.
      """

      def __init__(self, responses: list[DummyResponse] | None = None):
          """Store a private copy of *responses* and reset all recordings."""
          # Copy so that popping replies never mutates the caller's list.
          self._responses = list(responses) if responses else []
          self.last_request: RecordedRequest | None = None
          self.request_paths: list[str] = []
          self.request_bodies: list[dict[str, object] | None] = []

      @property
      def last_json_body(self) -> dict[str, object] | None:
          """JSON body of the most recent request, or None if nothing was sent."""
          recorded = self.last_request
          if not recorded:
              return None
          return recorded.json_body

      def request(
          self,
          method: str,
          url: str,
          *,
          headers: dict[str, str] | None = None,
          params: dict[str, object] | None = None,
          json: dict[str, object] | None = None,
      ) -> DummyResponse:
          """Record the call and return the next queued response.

          The URL's path component is appended to ``request_paths`` and the
          ``json`` argument (possibly None) to ``request_bodies``.
          """
          recorded = RecordedRequest(
              method=method,
              url=url,
              headers=headers or {},
              params=params,
              json_body=json,
          )
          self.last_request = recorded
          self.request_paths.append(urlparse(url).path)
          self.request_bodies.append(json)

          # Fall back to a candles payload once the scripted replies run out.
          if not self._responses:
              return candles_response()
          return self._responses.pop(0)
  def sample_config() -> Config:
      """Build a ``Config`` filled with placeholder API credentials."""
      credentials = {
          "api_key": "key",
          "api_secret": "secret",
          "api_passphrase": "passphrase",
      }
      return Config(**credentials)
  def candles_response() -> DummyResponse:
      """Return a code-"0" response whose data holds a single candle row."""
      candle_row = ["1710000000000", "25000", "25100", "24900", "25050", "100", "1000", "1000", "1"]
      payload = {
          "code": "0",
          "msg": "",
          "data": [candle_row],
      }
      return DummyResponse(payload)
  def instrument_response() -> DummyResponse:
      """Return a code-"0" response describing the BTC-USDT-SWAP instrument."""
      instrument = {
          "instId": "BTC-USDT-SWAP",
          "instType": "SWAP",
          "ctVal": "0.001",
          "lotSz": "1",
          "minSz": "1",
      }
      payload = {"code": "0", "msg": "", "data": [instrument]}
      return DummyResponse(payload)
  def ticker_response(last: str) -> DummyResponse:
      """Return a code-"0" ticker response for BTC-USDT-SWAP with *last* as its "last" field."""
      ticker = {"instId": "BTC-USDT-SWAP", "last": last}
      payload = {"code": "0", "msg": "", "data": [ticker]}
      return DummyResponse(payload)
  def account_config_response(pos_mode: str) -> DummyResponse:
      """Return a code-"0" account-config response whose "posMode" is *pos_mode*."""
      account = {"posMode": pos_mode}
      payload = {"code": "0", "msg": "", "data": [account]}
      return DummyResponse(payload)
  def leverage_response() -> DummyResponse:
      """Return a code-"0" response with a single entry reporting "lever" as "2"."""
      entry = {"lever": "2"}
      payload = {"code": "0", "msg": "", "data": [entry]}
      return DummyResponse(payload)
  def place_order_response() -> DummyResponse:
      """Return a code-"0" response with a single entry carrying "ordId" "123"."""
      entry = {"ordId": "123"}
      payload = {"code": "0", "msg": "", "data": [entry]}
      return DummyResponse(payload)
  def error_response(code: str, msg: str) -> DummyResponse:
      """Return a response with the given *code* and *msg* and an empty data list."""
      payload = {
          "code": code,
          "msg": msg,
          "data": [],
      }
      return DummyResponse(payload)
  def positions_response() -> DummyResponse:
      """Return a code-"0" response holding one long BTC-USDT-SWAP position entry."""
      position = {
          "instId": "BTC-USDT-SWAP",
          "posSide": "long",
          "pos": "8",
          "avgPx": "25000",
      }
      payload = {"code": "0", "msg": "", "data": [position]}
      return DummyResponse(payload)
  def market_long_signal() -> TradeSignal:
      """Build a "long" signal with no entry price.

      The tests in this file use it for the flow that requests a ticker
      before sizing the order.
      """
      signal = TradeSignal(
          action="long",
          confidence=0.9,
          leverage=2,
          entry_price=None,  # no explicit entry price on this signal
          take_profit_price=26000.0,
          stop_loss_price=24000.0,
          reason="trend",
      )
      return signal
  def limit_short_signal() -> TradeSignal:
      """Build a "short" signal with an explicit entry price of 25000.0."""
      signal = TradeSignal(
          action="short",
          confidence=0.8,
          leverage=2,
          entry_price=25000.0,
          take_profit_price=24000.0,
          stop_loss_price=25500.0,
          reason="mean reversion",
      )
      return signal
  def flat_signal() -> TradeSignal:
      """Build a "flat" signal with no entry, take-profit or stop-loss price."""
      signal = TradeSignal(
          action="flat",
          confidence=0.7,
          leverage=2,
          entry_price=None,
          take_profit_price=None,
          stop_loss_price=None,
          reason="exit",
      )
      return signal
  def test_signed_demo_request_attaches_headers():
      """A candles request carries the simulated-trading flag and the API key header."""
      fake_session = DummySession()
      okx = OkxClient(config=sample_config(), session=fake_session)

      okx.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)

      recorded = fake_session.last_request
      assert recorded is not None
      sent_headers = recorded.headers
      assert sent_headers["x-simulated-trading"] == "1"
      # "key" is the api_key value that sample_config() supplies.
      assert sent_headers["OK-ACCESS-KEY"] == "key"
  def test_build_contract_size_rounds_down_to_lot_size():
      """build_contract_size returns 1.0 for notional 251 at price 25_000."""
      meta = InstrumentMeta(ct_val=0.01, lot_sz=0.1, min_sz=0.1)

      # Presumably 251 / (25_000 * 0.01) = 1.004 contracts, floored to the
      # 0.1 lot step -> 1.0; confirm against build_contract_size.
      size = build_contract_size(notional=251, price=25_000, metadata=meta)

      assert size == 1.0
  def test_build_contract_size_fails_below_min_size():
      """build_contract_size raises ValueError for notional 250 at 25_100 with min_sz=5."""
      meta = InstrumentMeta(ct_val=0.01, lot_sz=1, min_sz=5)
      sizing_args = {"notional": 250, "price": 25_100, "metadata": meta}

      with pytest.raises(ValueError):
          build_contract_size(**sizing_args)
  def test_market_order_fetches_latest_price_before_sizing():
      """A market long order hits instruments, ticker, config, leverage, order - in that order."""
      scripted_replies = [
          instrument_response(),
          ticker_response(last="25000"),
          account_config_response(pos_mode="long_short_mode"),
          leverage_response(),
          place_order_response(),
      ]
      fake_session = DummySession(scripted_replies)
      okx = OkxClient(config=sample_config(), session=fake_session)

      okx.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)

      expected_paths = [
          "/api/v5/public/instruments",
          "/api/v5/market/ticker",
          "/api/v5/account/config",
          "/api/v5/account/set-leverage",
          "/api/v5/trade/order",
      ]
      assert fake_session.request_paths == expected_paths
  def test_place_demo_order_fails_when_not_hedge_mode():
      """place_demo_order raises ValueError when the account config reports "net_mode"."""
      scripted_replies = [
          instrument_response(),
          ticker_response(last="25000"),
          account_config_response(pos_mode="net_mode"),
      ]
      fake_session = DummySession(scripted_replies)
      okx = OkxClient(config=sample_config(), session=fake_session)

      with pytest.raises(ValueError):
          okx.place_demo_order(
              symbol="BTC-USDT-SWAP",
              signal=market_long_signal(),
              margin_usdt=100,
          )
  def test_limit_short_order_uses_sell_and_short_pos_side():
      """A short signal with an entry price is sent as a limit sell on the short side."""
      # No ticker reply is queued here, unlike the market-order tests.
      scripted_replies = [
          instrument_response(),
          account_config_response(pos_mode="long_short_mode"),
          leverage_response(),
          place_order_response(),
      ]
      fake_session = DummySession(scripted_replies)
      okx = OkxClient(config=sample_config(), session=fake_session)

      okx.place_demo_order(symbol="ETH-USDT-SWAP", signal=limit_short_signal(), margin_usdt=100)

      # The final request recorded by the session is the order submission.
      order_body = fake_session.last_json_body
      assert order_body is not None
      assert order_body["ordType"] == "limit"
      assert order_body["side"] == "sell"
      assert order_body["posSide"] == "short"
      assert order_body["px"] == "25000"

      # Index 2 is the third request, answered by leverage_response().
      leverage_body = fake_session.request_bodies[2]
      assert leverage_body["lever"] == "2"
      assert leverage_body["mgnMode"] == "isolated"
  def test_flat_signal_returns_noop_without_order_submission():
      """A flat signal yields a "noop" result and issues no requests at all."""
      fake_session = DummySession([])
      okx = OkxClient(config=sample_config(), session=fake_session)

      outcome = okx.place_demo_order(
          symbol="BTC-USDT-SWAP",
          signal=flat_signal(),
          margin_usdt=100,
      )

      assert outcome.status == "noop"
      assert fake_session.request_paths == []
  def test_place_demo_order_sends_computed_sz_and_ignores_tp_sl_fields():
      """The order body carries sz "8" and no take-profit / stop-loss trigger fields."""
      scripted_replies = [
          instrument_response(),
          ticker_response(last="25000"),
          account_config_response(pos_mode="long_short_mode"),
          leverage_response(),
          place_order_response(),
      ]
      fake_session = DummySession(scripted_replies)
      okx = OkxClient(config=sample_config(), session=fake_session)

      okx.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)

      order_body = fake_session.last_json_body
      assert order_body is not None
      # Presumably 100 USDT margin * 2x leverage / (25000 * 0.001 ctVal) = 8
      # contracts; confirm against the client's sizing logic.
      assert order_body["sz"] == "8"
      assert "tpTriggerPx" not in order_body
      assert "slTriggerPx" not in order_body
  def test_okx_error_payload_raises_value_error():
      """get_candles raises ValueError when the response payload has code "51000"."""
      failure = error_response(code="51000", msg="parameter error")
      fake_session = DummySession([failure])
      okx = OkxClient(config=sample_config(), session=fake_session)

      with pytest.raises(ValueError):
          okx.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
  225. def test_get_positions_returns_normalized_positions():
  226. session = DummySession([positions_response()])
  227. client = OkxClient(config=sample_config(), session=session)
  228. positions = client.get_positions(symbol="BTC-USDT-SWAP")
  229. assert positions[0].symbol == "BTC-USDT-SWAP"