Просмотр исходного кода

fix: validate symbols and margin boundaries

lxy 1 месяц назад
Родитель
Сommit
50b54e0371
2 измененных файлов с 63 добавлено и 2 удалено
  1. 15 2
      okx_codex_trader/okx_client.py
  2. 48 0
      tests/test_okx_client.py

+ 15 - 2
okx_codex_trader/okx_client.py

@@ -55,6 +55,7 @@ def _format_number(value: float) -> str:
 
 class OkxClient:
     base_url = "https://www.okx.com"
+    request_timeout = 10.0
 
     def __init__(self, config: Config, session=None):
         self.config = config
@@ -113,6 +114,7 @@ class OkxClient:
                 headers=headers,
                 params=params,
                 data=body if json_body is not None else None,
+                timeout=self.request_timeout,
             )
         except Exception:
             raise self._transport_error() from None
@@ -176,6 +178,8 @@ class OkxClient:
         data = self._request("GET", "/api/v5/market/ticker", params={"instId": symbol})
         ticker = self._first_item(data)
         try:
+            if ticker.get("instId") != symbol:
+                raise self._invalid_payload()
             return _parse_finite_float(ticker["last"])
         except (KeyError, TypeError, ValueError):
             raise self._invalid_payload() from None
@@ -221,12 +225,18 @@ class OkxClient:
             raise ValueError("swap instrument is required")
         if signal.leverage < 1 or signal.leverage > 3:
             raise ValueError("leverage is invalid")
+        try:
+            margin_value = _parse_finite_float(margin_usdt)
+        except ValueError:
+            raise ValueError("margin_usdt is invalid") from None
+        if margin_value <= 0:
+            raise ValueError("margin_usdt is invalid")
         metadata = self.get_instrument_meta(symbol)
         price = signal.entry_price if signal.entry_price is not None else self.get_last_price(symbol)
         side = "buy" if signal.action == "long" else "sell"
         pos_side = "long" if signal.action == "long" else "short"
         self.ensure_hedge_mode()
-        size = build_contract_size(margin_usdt * signal.leverage, price, metadata)
+        size = build_contract_size(margin_value * signal.leverage, price, metadata)
         self.set_leverage(symbol, signal.leverage, pos_side)
         order_type = "market" if signal.entry_price is None else "limit"
         request_body = {
@@ -255,7 +265,8 @@ class OkxClient:
         )
 
     def get_positions(self, symbol: str) -> list[Position]:
-        data = self._request("GET", "/api/v5/account/positions", params={"instId": symbol})
+        requested_symbol = symbol
+        data = self._request("GET", "/api/v5/account/positions", params={"instId": requested_symbol})
         if not data:
             return []
         try:
@@ -268,6 +279,8 @@ class OkxClient:
                 pos_side = entry["posSide"]
                 if not isinstance(symbol, str) or not isinstance(pos_side, str):
                     raise self._invalid_payload()
+                if symbol != requested_symbol:
+                    raise self._invalid_payload()
                 positions.append(
                     Position(
                         symbol=symbol,

+ 48 - 0
tests/test_okx_client.py

@@ -32,6 +32,7 @@ class RecordedRequest:
     params: dict[str, object] | None
     json_body: dict[str, object] | None
     body: str | None
+    timeout: float | None
 
 
 class DummySession:
@@ -58,6 +59,7 @@ class DummySession:
         params: dict[str, object] | None = None,
         json: dict[str, object] | None = None,
         data: str | None = None,
+        timeout: float | None = None,
     ) -> DummyResponse:
         parsed_json = json
         if parsed_json is None and data is not None:
@@ -69,6 +71,7 @@ class DummySession:
             params=params,
             json_body=parsed_json,
             body=data,
+            timeout=timeout,
         )
         self.request_paths.append(urlparse(url).path)
         self.request_bodies.append(parsed_json)
@@ -336,6 +339,23 @@ def positions_with_non_finite_numeric_response() -> DummyResponse:
     )
 
 
+def positions_with_wrong_symbol_response() -> DummyResponse:
+    return DummyResponse(
+        {
+            "code": "0",
+            "msg": "",
+            "data": [
+                {
+                    "instId": "ETH-USDT-SWAP",
+                    "posSide": "long",
+                    "pos": "1",
+                    "avgPx": "25000",
+                }
+            ],
+        }
+    )
+
+
 def market_long_signal() -> TradeSignal:
     return TradeSignal(
         action="long",
@@ -395,6 +415,8 @@ def test_signed_demo_request_attaches_headers():
         ).digest()
     ).decode()
     assert request.headers["OK-ACCESS-SIGN"] == expected_signature
+    assert request.timeout is not None
+    assert request.timeout > 0
 
 
 def test_signed_post_request_uses_actual_serialized_body_bytes():
@@ -678,6 +700,14 @@ def test_get_last_price_rejects_non_finite_numeric_field():
         client.get_last_price(symbol="BTC-USDT-SWAP")
 
 
+def test_get_last_price_rejects_mismatched_symbol():
+    session = DummySession([DummyResponse({"code": "0", "msg": "", "data": [{"instId": "ETH-USDT-SWAP", "last": "25000"}]})])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="okx response payload is invalid"):
+        client.get_last_price(symbol="BTC-USDT-SWAP")
+
+
 def test_get_instrument_meta_rejects_mismatched_symbol():
     session = DummySession([instrument_with_wrong_symbol_response()])
     client = OkxClient(config=sample_config(), session=session)
@@ -728,6 +758,16 @@ def test_place_demo_order_rejects_invalid_leverage_before_okx():
     assert session.request_paths == []
 
 
+@pytest.mark.parametrize("margin_usdt", [0, -1, float("nan"), float("inf")])
+def test_place_demo_order_rejects_invalid_margin_before_okx(margin_usdt):
+    session = DummySession([])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="margin_usdt is invalid"):
+        client.place_demo_order(symbol="BTC-USDT-SWAP", signal=market_long_signal(), margin_usdt=margin_usdt)
+    assert session.request_paths == []
+
+
 @pytest.mark.parametrize(
     ("symbol", "leverage", "pos_side", "expected_message"),
     [
@@ -811,3 +851,11 @@ def test_get_positions_rejects_non_finite_numeric_fields():
 
     with pytest.raises(ValueError, match="okx response payload is invalid"):
         client.get_positions(symbol="BTC-USDT-SWAP")
+
+
+def test_get_positions_rejects_mismatched_symbol():
+    session = DummySession([positions_with_wrong_symbol_response()])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="okx response payload is invalid"):
+        client.get_positions(symbol="BTC-USDT-SWAP")