diff --git a/src/cerbero_mcp/common/candles.py b/src/cerbero_mcp/common/candles.py new file mode 100644 index 0000000..42cb16e --- /dev/null +++ b/src/cerbero_mcp/common/candles.py @@ -0,0 +1,53 @@ +"""Shared OHLCV candle model + validator for exchange historical endpoints.""" +from __future__ import annotations + +from typing import Any + +from fastapi import HTTPException +from pydantic import BaseModel, ConfigDict, ValidationError, model_validator + + +class Candle(BaseModel): + model_config = ConfigDict(extra="ignore") + + timestamp: int + open: float + high: float + low: float + close: float + volume: float + + @model_validator(mode="after") + def _check(self) -> Candle: + if self.timestamp <= 0: + raise ValueError(f"timestamp must be > 0, got {self.timestamp}") + if self.volume < 0: + raise ValueError(f"volume must be >= 0, got {self.volume}") + if self.high < max(self.open, self.close, self.low): + raise ValueError( + f"high {self.high} < max(open={self.open}, " + f"close={self.close}, low={self.low})" + ) + if self.low > min(self.open, self.close, self.high): + raise ValueError( + f"low {self.low} > min(open={self.open}, " + f"close={self.close}, high={self.high})" + ) + return self + + +def validate_candles(raw: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Coerce upstream rows into validated candle dicts, sorted by timestamp. + + Raises HTTPException(502) if any row violates OHLC consistency or schema — + upstream data corruption is mapped to a retryable error envelope. + """ + try: + candles = [Candle.model_validate(row) for row in raw] + except ValidationError as e: + raise HTTPException( + status_code=502, + detail=f"upstream returned malformed candle: {e.errors()[0]['msg']}", + ) from e + candles.sort(key=lambda c: c.timestamp) + return [c.model_dump() for c in candles] diff --git a/src/cerbero_mcp/exchanges/alpaca/client.py b/src/cerbero_mcp/exchanges/alpaca/client.py index d48a94f..94fae3c 100644 --- a/src/cerbero_mcp/exchanges/alpaca/client.py +++ b/src/cerbero_mcp/exchanges/alpaca/client.py @@ -20,6 +20,7 @@ from typing import Any import httpx +from cerbero_mcp.common.candles import validate_candles from cerbero_mcp.common.http import async_client # ── Endpoint base ──────────────────────────────────────────────── @@ -301,9 +302,17 @@ class AlpacaClient: bars_dict = (data or {}).get("bars") or {} rows = bars_dict.get(symbol) or [] - bars = [ + + def _iso_to_ms(ts: str | int | None) -> int | None: + if ts is None or isinstance(ts, int): + return ts + return int(_dt.datetime.fromisoformat( + ts.replace("Z", "+00:00") + ).timestamp() * 1000) + + candles = validate_candles([ { - "timestamp": b.get("t"), + "timestamp": _iso_to_ms(b.get("t")), "open": b.get("o"), "high": b.get("h"), "low": b.get("l"), @@ -311,12 +320,12 @@ class AlpacaClient: "volume": b.get("v"), } for b in rows - ] + ]) return { "symbol": symbol, "asset_class": ac, "interval": interval, - "bars": bars, + "candles": candles, } async def get_snapshot(self, symbol: str) -> dict: diff --git a/src/cerbero_mcp/exchanges/bybit/client.py b/src/cerbero_mcp/exchanges/bybit/client.py index 40e9833..074e929 100644 --- a/src/cerbero_mcp/exchanges/bybit/client.py +++ b/src/cerbero_mcp/exchanges/bybit/client.py @@ -22,6 +22,7 @@ import httpx from cerbero_mcp.common import indicators as ind from cerbero_mcp.common import microstructure as micro +from cerbero_mcp.common.candles import validate_candles BASE_MAINNET = "https://api.bybit.com" BASE_TESTNET = "https://api-testnet.bybit.com" @@ -254,18 +255,17 @@ class BybitClient: params["end"] = end resp = await self._request_public("GET", "/v5/market/kline", params=params) rows = (resp.get("result") or {}).get("list") or [] - rows_sorted = sorted(rows, key=lambda r: int(r[0])) - candles = [ + candles = validate_candles([ { "timestamp": int(r[0]), - "open": float(r[1]), - "high": float(r[2]), - "low": float(r[3]), - "close": float(r[4]), - "volume": float(r[5]), + "open": r[1], + "high": r[2], + "low": r[3], + "close": r[4], + "volume": r[5], } - for r in rows_sorted - ] + for r in rows + ]) return {"symbol": symbol, "candles": candles} async def get_indicators( diff --git a/src/cerbero_mcp/exchanges/deribit/client.py b/src/cerbero_mcp/exchanges/deribit/client.py index 08057d2..81c8e0b 100644 --- a/src/cerbero_mcp/exchanges/deribit/client.py +++ b/src/cerbero_mcp/exchanges/deribit/client.py @@ -11,10 +11,11 @@ from fastapi import HTTPException from cerbero_mcp.common import indicators as ind from cerbero_mcp.common import microstructure as micro from cerbero_mcp.common import options as opt +from cerbero_mcp.common.candles import validate_candles from cerbero_mcp.common.http import async_client -def _parse_deribit_response(resp) -> dict: +def _parse_deribit_response(resp: Any) -> dict[str, Any]: """Map Deribit upstream errors to a clean HTTP 502 (retryable) instead of leaking JSONDecodeError when the body is HTML (e.g. Cloudflare 5xx page).""" if resp.status_code >= 500: @@ -23,7 +24,8 @@ def _parse_deribit_response(resp) -> dict: detail=f"Deribit upstream HTTP {resp.status_code}", ) try: - return resp.json() + data: dict[str, Any] = resp.json() + return data except json.JSONDecodeError as e: raise HTTPException( status_code=502, @@ -121,10 +123,10 @@ class DeribitClient: resp = await http.get(url, params=request_params, headers=headers) data = _parse_deribit_response(resp) if "result" in data: - return data # type: ignore[no-any-return] + return data return {"result": None, "error": error_msg} - return data # type: ignore[no-any-return] + return data # ── Read tools ─────────────────────────────────────────────── @@ -418,24 +420,24 @@ class DeribitClient: }, ) r = raw.get("result") or {} - candles = [] ticks = r.get("ticks", []) or [] opens = r.get("open", []) or [] highs = r.get("high", []) or [] lows = r.get("low", []) or [] closes = r.get("close", []) or [] volumes = r.get("volume", []) or [] - for idx, ts in enumerate(ticks): - if idx >= min(len(opens), len(highs), len(lows), len(closes), len(volumes)): - break - candles.append({ - "timestamp": ts, - "open": opens[idx], - "high": highs[idx], - "low": lows[idx], - "close": closes[idx], - "volume": volumes[idx], - }) + n = min(len(ticks), len(opens), len(highs), len(lows), len(closes), len(volumes)) + candles = validate_candles([ + { + "timestamp": ticks[i], + "open": opens[i], + "high": highs[i], + "low": lows[i], + "close": closes[i], + "volume": volumes[i], + } + for i in range(n) + ]) return {"candles": candles} async def get_dvol( diff --git a/src/cerbero_mcp/exchanges/hyperliquid/client.py b/src/cerbero_mcp/exchanges/hyperliquid/client.py index d15d4ba..ac62bcb 100644 --- a/src/cerbero_mcp/exchanges/hyperliquid/client.py +++ b/src/cerbero_mcp/exchanges/hyperliquid/client.py @@ -27,6 +27,7 @@ from eth_account.messages import encode_typed_data from eth_utils import keccak, to_hex from cerbero_mcp.common import indicators as ind +from cerbero_mcp.common.candles import validate_candles from cerbero_mcp.common.http import async_client BASE_LIVE = "https://api.hyperliquid.xyz" @@ -408,18 +409,17 @@ class HyperliquidClient: }, } ) - candles = [] - for c in data: - candles.append( - { - "timestamp": c.get("t", 0), - "open": float(c.get("o", 0)), - "high": float(c.get("h", 0)), - "low": float(c.get("l", 0)), - "close": float(c.get("c", 0)), - "volume": float(c.get("v", 0)), - } - ) + candles = validate_candles([ + { + "timestamp": c.get("t"), + "open": c.get("o"), + "high": c.get("h"), + "low": c.get("l"), + "close": c.get("c"), + "volume": c.get("v"), + } + for c in data + ]) return {"candles": candles} async def get_open_orders(self) -> list[dict[str, Any]]: diff --git a/src/cerbero_mcp/exchanges/ibkr/client.py b/src/cerbero_mcp/exchanges/ibkr/client.py index 1768980..a983a11 100644 --- a/src/cerbero_mcp/exchanges/ibkr/client.py +++ b/src/cerbero_mcp/exchanges/ibkr/client.py @@ -9,6 +9,7 @@ from typing import Any import httpx +from cerbero_mcp.common.candles import validate_candles from cerbero_mcp.common.http import async_client from cerbero_mcp.exchanges.ibkr.oauth import ( IBKRAuthError, @@ -234,21 +235,22 @@ class IBKRClient: params={"conid": str(conid), "period": period, "bar": bar}, ) rows = (data or {}).get("data") or [] + candles = validate_candles([ + { + "timestamp": r.get("t"), + "open": r.get("o"), + "high": r.get("h"), + "low": r.get("l"), + "close": r.get("c"), + "volume": r.get("v"), + } + for r in rows + ]) return { "symbol": symbol, "asset_class": asset_class, "interval": bar, - "bars": [ - { - "timestamp": r.get("t"), - "open": r.get("o"), - "high": r.get("h"), - "low": r.get("l"), - "close": r.get("c"), - "volume": r.get("v"), - } - for r in rows - ], + "candles": candles, } async def get_option_chain( diff --git a/tests/unit/common/test_candles.py b/tests/unit/common/test_candles.py new file mode 100644 index 0000000..7639334 --- /dev/null +++ b/tests/unit/common/test_candles.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import pytest +from cerbero_mcp.common.candles import Candle, validate_candles +from fastapi import HTTPException + + +def test_valid_candle(): + c = Candle(timestamp=1_700_000_000_000, open=100.0, high=110.0, + low=95.0, close=105.0, volume=12.5) + assert c.high == 110.0 + + +def test_high_below_close_rejected(): + with pytest.raises(ValueError): + Candle(timestamp=1, open=100, high=90, low=80, close=95, volume=1) + + +def test_high_below_open_rejected(): + with pytest.raises(ValueError): + Candle(timestamp=1, open=100, high=90, low=80, close=85, volume=1) + + +def test_low_above_close_rejected(): + with pytest.raises(ValueError): + Candle(timestamp=1, open=100, high=110, low=105, close=102, volume=1) + + +def test_low_above_open_rejected(): + with pytest.raises(ValueError): + Candle(timestamp=1, open=95, high=110, low=100, close=105, volume=1) + + +def test_negative_volume_rejected(): + with pytest.raises(ValueError): + Candle(timestamp=1, open=100, high=110, low=90, close=105, volume=-1) + + +def test_non_positive_timestamp_rejected(): + with pytest.raises(ValueError): + Candle(timestamp=0, open=100, high=110, low=90, close=105, volume=1) + + +def test_validate_candles_sorts_by_timestamp(): + raw = [ + {"timestamp": 3, "open": 1, "high": 2, "low": 1, "close": 1, "volume": 0}, + {"timestamp": 1, "open": 1, "high": 2, "low": 1, "close": 1, "volume": 0}, + {"timestamp": 2, "open": 1, "high": 2, "low": 1, "close": 1, "volume": 0}, + ] + out = validate_candles(raw) + assert [c["timestamp"] for c in out] == [1, 2, 3] + + +def test_validate_candles_coerces_string_numerics(): + raw = [{"timestamp": "1", "open": "100", "high": "110", + "low": "90", "close": "105", "volume": "10"}] + out = validate_candles(raw) + assert out[0]["open"] == 100.0 + assert isinstance(out[0]["volume"], float) + + +def test_validate_candles_malformed_raises_http_502(): + raw = [{"timestamp": 1, "open": 100, "high": 50, "low": 90, + "close": 105, "volume": 1}] + with pytest.raises(HTTPException) as exc_info: + validate_candles(raw) + assert exc_info.value.status_code == 502 + assert "candle" in str(exc_info.value.detail).lower() + + +def test_validate_candles_empty_list(): + assert validate_candles([]) == [] diff --git a/tests/unit/exchanges/alpaca/test_client.py b/tests/unit/exchanges/alpaca/test_client.py index f9aad8a..f2ac719 100644 --- a/tests/unit/exchanges/alpaca/test_client.py +++ b/tests/unit/exchanges/alpaca/test_client.py @@ -271,8 +271,8 @@ async def test_get_bars_stocks(httpx_mock: HTTPXMock, client: AlpacaClient): ) assert result["symbol"] == "AAPL" assert result["interval"] == "1d" - assert len(result["bars"]) == 1 - assert result["bars"][0]["close"] == 175.0 + assert len(result["candles"]) == 1 + assert result["candles"][0]["close"] == 175.0 @pytest.mark.asyncio