test_okx_client.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. from dataclasses import dataclass
  2. from urllib.parse import urlparse
  3. import pytest
  4. from okx_codex_trader.config import Config
  5. from okx_codex_trader.models import InstrumentMeta, TradeSignal
  6. from okx_codex_trader.okx_client import OkxClient, build_contract_size
  7. @dataclass
  8. class DummyResponse:
  9. payload: dict[str, object]
  10. status_code: int = 200
  11. def json(self) -> dict[str, object]:
  12. return self.payload
  13. @dataclass
  14. class RecordedRequest:
  15. method: str
  16. url: str
  17. headers: dict[str, str]
  18. params: dict[str, object] | None
  19. json_body: dict[str, object] | None
  20. class DummySession:
  21. def __init__(self, responses: list[DummyResponse] | None = None):
  22. self._responses = list(responses or [])
  23. self.last_request: RecordedRequest | None = None
  24. self.request_paths: list[str] = []
  25. self.request_bodies: list[dict[str, object] | None] = []
  26. @property
  27. def last_json_body(self) -> dict[str, object] | None:
  28. return self.last_request.json_body if self.last_request else None
  29. def request(
  30. self,
  31. method: str,
  32. url: str,
  33. *,
  34. headers: dict[str, str] | None = None,
  35. params: dict[str, object] | None = None,
  36. json: dict[str, object] | None = None,
  37. ) -> DummyResponse:
  38. self.last_request = RecordedRequest(
  39. method=method,
  40. url=url,
  41. headers=headers or {},
  42. params=params,
  43. json_body=json,
  44. )
  45. self.request_paths.append(urlparse(url).path)
  46. self.request_bodies.append(json)
  47. if self._responses:
  48. return self._responses.pop(0)
  49. return candles_response()
  50. def sample_config() -> Config:
  51. return Config(api_key="key", api_secret="secret", api_passphrase="passphrase")
  52. def candles_response() -> DummyResponse:
  53. return DummyResponse(
  54. {
  55. "code": "0",
  56. "msg": "",
  57. "data": [
  58. ["1710000000000", "25000", "25100", "24900", "25050", "100", "1000", "1000", "1"],
  59. ],
  60. }
  61. )
  62. def instrument_response() -> DummyResponse:
  63. return DummyResponse(
  64. {
  65. "code": "0",
  66. "msg": "",
  67. "data": [
  68. {
  69. "instId": "BTC-USDT-SWAP",
  70. "instType": "SWAP",
  71. "ctVal": "0.001",
  72. "lotSz": "1",
  73. "minSz": "1",
  74. }
  75. ],
  76. }
  77. )
  78. def ticker_response(last: str) -> DummyResponse:
  79. return DummyResponse({"code": "0", "msg": "", "data": [{"instId": "BTC-USDT-SWAP", "last": last}]})
  80. def account_config_response(pos_mode: str) -> DummyResponse:
  81. return DummyResponse({"code": "0", "msg": "", "data": [{"posMode": pos_mode}]})
  82. def leverage_response() -> DummyResponse:
  83. return DummyResponse({"code": "0", "msg": "", "data": [{"lever": "2"}]})
  84. def place_order_response() -> DummyResponse:
  85. return DummyResponse({"code": "0", "msg": "", "data": [{"ordId": "123"}]})
  86. def error_response(code: str, msg: str) -> DummyResponse:
  87. return DummyResponse({"code": code, "msg": msg, "data": []})
  88. def positions_response() -> DummyResponse:
  89. return DummyResponse(
  90. {
  91. "code": "0",
  92. "msg": "",
  93. "data": [
  94. {
  95. "instId": "BTC-USDT-SWAP",
  96. "posSide": "long",
  97. "pos": "8",
  98. "avgPx": "25000",
  99. }
  100. ],
  101. }
  102. )
  103. def market_long_signal() -> TradeSignal:
  104. return TradeSignal(
  105. action="long",
  106. confidence=0.9,
  107. leverage=2,
  108. entry_price=None,
  109. take_profit_price=26000.0,
  110. stop_loss_price=24000.0,
  111. reason="trend",
  112. )
  113. def limit_short_signal() -> TradeSignal:
  114. return TradeSignal(
  115. action="short",
  116. confidence=0.8,
  117. leverage=2,
  118. entry_price=25000.0,
  119. take_profit_price=24000.0,
  120. stop_loss_price=25500.0,
  121. reason="mean reversion",
  122. )
  123. def flat_signal() -> TradeSignal:
  124. return TradeSignal(
  125. action="flat",
  126. confidence=0.7,
  127. leverage=2,
  128. entry_price=None,
  129. take_profit_price=None,
  130. stop_loss_price=None,
  131. reason="exit",
  132. )
  133. def test_signed_demo_request_attaches_headers():
  134. session = DummySession()
  135. client = OkxClient(config=sample_config(), session=session)
  136. client.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
  137. request = session.last_request
  138. assert request is not None
  139. assert request.headers["x-simulated-trading"] == "1"
  140. assert request.headers["OK-ACCESS-KEY"] == "key"
  141. assert request.headers["OK-ACCESS-SIGN"]
  142. assert request.headers["OK-ACCESS-TIMESTAMP"]
  143. assert request.headers["OK-ACCESS-PASSPHRASE"] == "passphrase"
  144. def test_build_contract_size_rounds_down_to_lot_size():
  145. metadata = InstrumentMeta(ct_val=0.01, lot_sz=0.1, min_sz=0.1)
  146. assert build_contract_size(notional=251, price=25_000, metadata=metadata) == 1.0
  147. def test_build_contract_size_fails_below_min_size():
  148. metadata = InstrumentMeta(ct_val=0.01, lot_sz=1, min_sz=5)
  149. with pytest.raises(ValueError):
  150. build_contract_size(notional=250, price=25_100, metadata=metadata)
  151. def test_market_order_fetches_latest_price_before_sizing():
  152. session = DummySession(
  153. [
  154. instrument_response(),
  155. ticker_response(last="25000"),
  156. account_config_response(pos_mode="long_short_mode"),
  157. leverage_response(),
  158. place_order_response(),
  159. ]
  160. )
  161. client = OkxClient(config=sample_config(), session=session)
  162. client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
  163. assert session.request_paths == [
  164. "/api/v5/public/instruments",
  165. "/api/v5/market/ticker",
  166. "/api/v5/account/config",
  167. "/api/v5/account/set-leverage",
  168. "/api/v5/trade/order",
  169. ]
  170. def test_place_demo_order_fails_when_not_hedge_mode():
  171. session = DummySession(
  172. [
  173. instrument_response(),
  174. ticker_response(last="25000"),
  175. account_config_response(pos_mode="net_mode"),
  176. ]
  177. )
  178. client = OkxClient(config=sample_config(), session=session)
  179. with pytest.raises(ValueError):
  180. client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
  181. def test_limit_short_order_uses_sell_and_short_pos_side():
  182. session = DummySession(
  183. [
  184. instrument_response(),
  185. account_config_response(pos_mode="long_short_mode"),
  186. leverage_response(),
  187. place_order_response(),
  188. ]
  189. )
  190. client = OkxClient(config=sample_config(), session=session)
  191. client.place_demo_order(symbol="ETH-USDT-SWAP", signal=limit_short_signal(), margin_usdt=100)
  192. order_request = session.last_json_body
  193. assert order_request is not None
  194. assert order_request["ordType"] == "limit"
  195. assert order_request["side"] == "sell"
  196. assert order_request["posSide"] == "short"
  197. assert order_request["px"] == "25000"
  198. assert session.request_bodies[2]["lever"] == "2"
  199. assert session.request_bodies[2]["mgnMode"] == "isolated"
  200. def test_flat_signal_returns_noop_without_order_submission():
  201. session = DummySession([])
  202. client = OkxClient(config=sample_config(), session=session)
  203. result = client.place_demo_order(symbol="BTC-USDT-SWAP", signal=flat_signal(), margin_usdt=100)
  204. assert result.status == "noop"
  205. assert session.request_paths == []
  206. def test_place_demo_order_sends_computed_sz_and_ignores_tp_sl_fields():
  207. session = DummySession(
  208. [
  209. instrument_response(),
  210. ticker_response(last="25000"),
  211. account_config_response(pos_mode="long_short_mode"),
  212. leverage_response(),
  213. place_order_response(),
  214. ]
  215. )
  216. client = OkxClient(config=sample_config(), session=session)
  217. client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=100)
  218. order_request = session.last_json_body
  219. assert order_request is not None
  220. assert order_request["sz"] == "8"
  221. assert "tpTriggerPx" not in order_request
  222. assert "slTriggerPx" not in order_request
  223. def test_okx_error_payload_raises_value_error():
  224. session = DummySession([error_response(code="51000", msg="parameter error")])
  225. client = OkxClient(config=sample_config(), session=session)
  226. with pytest.raises(ValueError):
  227. client.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
  228. def test_get_positions_returns_normalized_positions():
  229. session = DummySession([positions_response()])
  230. client = OkxClient(config=sample_config(), session=session)
  231. positions = client.get_positions(symbol="BTC-USDT-SWAP")
  232. assert positions[0].symbol == "BTC-USDT-SWAP"
  233. assert positions[0].pos_side == "long"
  234. assert positions[0].size == 8.0
  235. assert positions[0].avg_price == 25000.0