Procházet zdrojové kódy

fix: normalize okx transport errors

lxy před 1 měsícem
rodič
revize
ef44053b03
2 změnil soubory, kde provedl 70 přidání a 17 odebrání
  1. 30 12
      okx_codex_trader/okx_client.py
  2. 40 5
      tests/test_okx_client.py

+ 30 - 12
okx_codex_trader/okx_client.py

@@ -4,12 +4,16 @@ import hmac
 import json
 from datetime import UTC, datetime
 from decimal import Decimal, ROUND_DOWN
+from typing import TypeAlias
 from urllib.parse import urlencode
 
 from okx_codex_trader.config import Config
 from okx_codex_trader.models import Candle, InstrumentMeta, OrderResult, Position, TradeSignal
 
 
+OkxRow: TypeAlias = dict[str, object] | list[object]
+
+
 def build_contract_size(notional: float, price: float, metadata: InstrumentMeta) -> float:
     raw_size = Decimal(str(notional)) / (Decimal(str(price)) * Decimal(str(metadata.ct_val)))
     lot_size = Decimal(str(metadata.lot_sz))
@@ -37,10 +41,16 @@ class OkxClient:
     def _invalid_payload(self) -> ValueError:
         return ValueError("okx response payload is invalid")
 
-    def _first_item(self, data: list[dict[str, object]]) -> dict[str, object]:
+    def _transport_error(self) -> ValueError:
+        return ValueError("okx transport error")
+
+    def _first_item(self, data: list[OkxRow]) -> dict[str, object]:
         if not data:
             raise self._invalid_payload()
-        return data[0]
+        item = data[0]
+        if not isinstance(item, dict):
+            raise self._invalid_payload()
+        return item
 
     def _request(
         self,
@@ -49,7 +59,7 @@ class OkxClient:
         *,
         params: dict[str, object] | None = None,
         json_body: dict[str, object] | None = None,
-    ) -> list[dict[str, object]]:
+    ) -> list[OkxRow]:
         timestamp = datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
         query = urlencode(params or {})
         path_with_query = path if not query else f"{path}?{query}"
@@ -68,14 +78,22 @@ class OkxClient:
             "OK-ACCESS-PASSPHRASE": self.config.api_passphrase,
             "x-simulated-trading": "1",
         }
-        response = self.session.request(
-            method.upper(),
-            f"{self.base_url}{path}",
-            headers=headers,
-            params=params,
-            json=json_body,
-        )
-        payload = response.json()
+        try:
+            response = self.session.request(
+                method.upper(),
+                f"{self.base_url}{path}",
+                headers=headers,
+                params=params,
+                json=json_body,
+            )
+        except Exception:
+            raise self._transport_error() from None
+        try:
+            payload = response.json()
+        except Exception:
+            raise self._invalid_payload() from None
+        if not isinstance(payload, dict):
+            raise self._invalid_payload()
         if getattr(response, "status_code", 200) >= 400:
             raise ValueError(str(payload.get("msg") or "okx http error"))
         if payload.get("code") != "0":
@@ -105,7 +123,7 @@ class OkxClient:
                 for entry in data
             ]
             return sorted(candles, key=lambda candle: candle.ts)
-        except (IndexError, TypeError, ValueError):
+        except (IndexError, KeyError, TypeError, ValueError):
             raise self._invalid_payload() from None
 
     def get_instrument_meta(self, symbol: str) -> InstrumentMeta:

+ 40 - 5
tests/test_okx_client.py

@@ -1,5 +1,8 @@
+import base64
+import hashlib
+import hmac
 from dataclasses import dataclass
-from urllib.parse import urlparse
+from urllib.parse import urlencode, urlparse
 
 import pytest
 
@@ -12,8 +15,11 @@ from okx_codex_trader.okx_client import OkxClient, build_contract_size
 class DummyResponse:
     payload: dict[str, object]
     status_code: int = 200
+    json_error: Exception | None = None
 
     def json(self) -> dict[str, object]:
+        if self.json_error is not None:
+            raise self.json_error
         return self.payload
 
 
@@ -27,7 +33,7 @@ class RecordedRequest:
 
 
 class DummySession:
-    def __init__(self, responses: list[DummyResponse] | None = None):
+    def __init__(self, responses: list[DummyResponse | Exception] | None = None):
         self._responses = list(responses or [])
         self.last_request: RecordedRequest | None = None
         self.request_paths: list[str] = []
@@ -56,7 +62,10 @@ class DummySession:
         self.request_paths.append(urlparse(url).path)
         self.request_bodies.append(json)
         if self._responses:
-            return self._responses.pop(0)
+            response = self._responses.pop(0)
+            if isinstance(response, Exception):
+                raise response
+            return response
         return candles_response()
 
 
@@ -258,9 +267,19 @@ def test_signed_demo_request_attaches_headers():
     assert request is not None
     assert request.headers["x-simulated-trading"] == "1"
     assert request.headers["OK-ACCESS-KEY"] == "key"
-    assert request.headers["OK-ACCESS-SIGN"]
-    assert request.headers["OK-ACCESS-TIMESTAMP"]
     assert request.headers["OK-ACCESS-PASSPHRASE"] == "passphrase"
+    timestamp = request.headers["OK-ACCESS-TIMESTAMP"]
+    path = urlparse(request.url).path
+    query = urlencode(request.params or {})
+    path_with_query = path if not query else f"{path}?{query}"
+    expected_signature = base64.b64encode(
+        hmac.new(
+            b"secret",
+            f"{timestamp}{request.method}{path_with_query}".encode(),
+            hashlib.sha256,
+        ).digest()
+    ).decode()
+    assert request.headers["OK-ACCESS-SIGN"] == expected_signature
 
 
 def test_get_candles_returns_chronological_ascending_order():
@@ -402,6 +421,22 @@ def test_okx_error_payload_raises_value_error():
         client.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
 
 
+def test_transport_failure_raises_stable_value_error():
+    session = DummySession([RuntimeError("socket closed")])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="okx transport error"):
+        client.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
+
+
+def test_invalid_json_raises_stable_value_error():
+    session = DummySession([DummyResponse({}, json_error=ValueError("bad json"))])
+    client = OkxClient(config=sample_config(), session=session)
+
+    with pytest.raises(ValueError, match="okx response payload is invalid"):
+        client.get_candles(symbol="BTC-USDT-SWAP", bar="1H", limit=20)
+
+
 def test_empty_positions_data_returns_empty_list():
     session = DummySession([DummyResponse({"code": "0", "msg": "", "data": []})])
     client = OkxClient(config=sample_config(), session=session)