From 17700d27a0d345e6198f37b2e30b58e7272d0037 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 3 May 2026 21:15:25 +0000 Subject: [PATCH] feat(V2): IBKR WebSocket layer + tick/depth snapshot cache Co-Authored-By: Claude Opus 4.7 (1M context) --- src/cerbero_mcp/exchanges/ibkr/ws.py | 223 +++++++++++++++++++++++++++ tests/unit/exchanges/ibkr/test_ws.py | 86 +++++++++++ 2 files changed, 309 insertions(+) create mode 100644 src/cerbero_mcp/exchanges/ibkr/ws.py create mode 100644 tests/unit/exchanges/ibkr/test_ws.py diff --git a/src/cerbero_mcp/exchanges/ibkr/ws.py b/src/cerbero_mcp/exchanges/ibkr/ws.py new file mode 100644 index 0000000..06b7e22 --- /dev/null +++ b/src/cerbero_mcp/exchanges/ibkr/ws.py @@ -0,0 +1,223 @@ +"""IBKR Client Portal WebSocket — persistent WSS, smd/sbd subs, snapshot cache.""" +from __future__ import annotations + +import asyncio +import contextlib +import json +import time +from dataclasses import dataclass, field +from typing import Any + +from websockets import connect as websockets_connect # exposed for tests + +from cerbero_mcp.exchanges.ibkr.oauth import OAuth1aSigner + + +class WSError(Exception): + """WebSocket layer error.""" + + +@dataclass +class TickSnapshot: + last_price: float | None + bid: float | None + ask: float | None + bid_size: float | None + ask_size: float | None + timestamp_ms: int + + +@dataclass +class DepthSnapshot: + bids: list[dict] + asks: list[dict] + timestamp_ms: int + + +_SMD_FIELDS = ["31", "84", "86", "7295", "7296"] + + +@dataclass +class IBKRWebSocket: + signer: OAuth1aSigner + ws_url: str + base_url: str + max_subs: int = 80 + idle_timeout_s: int = 300 + + _ws: Any = field(default=None, init=False, repr=False) + _tick_cache: dict[int, TickSnapshot] = field(default_factory=dict, init=False) + _depth_cache: dict[int, DepthSnapshot] = field(default_factory=dict, init=False) + _subs: set[int] = field(default_factory=set, init=False) + _depth_subs: set[int] = field(default_factory=set, init=False) + _last_polled_at: dict[int, float] = field(default_factory=dict, init=False) + _forced_subs: set[int] = field(default_factory=set, init=False) + _reader_task: asyncio.Task | None = field(default=None, init=False) + _idle_task: asyncio.Task | None = field(default=None, init=False) + _stopped: bool = field(default=False, init=False) + + @property + def connected(self) -> bool: + return self._ws is not None and not getattr(self._ws, "closed", True) + + async def start(self) -> None: + if self.connected: + return + lst = await self.signer.get_live_session_token(base_url=self.base_url) + self._ws = await websockets_connect( + self.ws_url, + additional_headers={"Cookie": f"api={lst}"}, + ) + self._reader_task = asyncio.create_task(self._reader_loop()) + self._idle_task = asyncio.create_task(self._idle_sweeper()) + + async def stop(self) -> None: + self._stopped = True + if self._idle_task: + self._idle_task.cancel() + with contextlib.suppress(BaseException): + await self._idle_task + if self._reader_task: + self._reader_task.cancel() + with contextlib.suppress(BaseException): + await self._reader_task + if self._ws: + with contextlib.suppress(Exception): + await self._ws.close() + self._ws = None + + async def subscribe_tick(self, conid: int, *, forced: bool = False) -> None: + await self._ensure_capacity(conid) + if conid in self._subs: + self._last_polled_at[conid] = time.monotonic() + if forced: + self._forced_subs.add(conid) + return + msg = "smd+" + str(conid) + "+" + json.dumps({"fields": _SMD_FIELDS}) + await self._ws.send(msg) + self._subs.add(conid) + self._last_polled_at[conid] = time.monotonic() + if forced: + self._forced_subs.add(conid) + + async def subscribe_depth( + self, conid: int, *, exchange: str = "SMART", rows: int = 5 + ) -> None: + await self._ensure_capacity(conid) + if conid in self._depth_subs: + self._last_polled_at[conid] = time.monotonic() + return + msg = f"sbd+{conid}+{exchange}+{rows}" + await self._ws.send(msg) + self._depth_subs.add(conid) + self._last_polled_at[conid] = time.monotonic() + + async def unsubscribe(self, conid: int) -> None: + if conid in self._subs: + await self._ws.send(f"umd+{conid}+{{}}") + self._subs.discard(conid) + if conid in self._depth_subs: + await self._ws.send(f"ubd+{conid}") + self._depth_subs.discard(conid) + self._tick_cache.pop(conid, None) + self._depth_cache.pop(conid, None) + self._last_polled_at.pop(conid, None) + self._forced_subs.discard(conid) + + def get_tick_snapshot(self, conid: int) -> dict | None: + snap = self._tick_cache.get(conid) + if not snap: + return None + self._last_polled_at[conid] = time.monotonic() + return { + "conid": conid, + "last_price": snap.last_price, + "bid": snap.bid, + "ask": snap.ask, + "bid_size": snap.bid_size, + "ask_size": snap.ask_size, + "timestamp_ms": snap.timestamp_ms, + } + + def get_depth_snapshot(self, conid: int) -> dict | None: + snap = self._depth_cache.get(conid) + if not snap: + return None + self._last_polled_at[conid] = time.monotonic() + return { + "conid": conid, + "bids": snap.bids, + "asks": snap.asks, + "timestamp_ms": snap.timestamp_ms, + } + + async def _ensure_capacity(self, conid: int) -> None: + if (conid in self._subs) or (conid in self._depth_subs): + return + active = len(self._subs) + len(self._depth_subs) + if active >= self.max_subs: + raise WSError(f"IBKR_WS_SUB_LIMIT: {active}/{self.max_subs}") + + async def _reader_loop(self) -> None: + try: + while not self._stopped and self._ws: + raw = await self._ws.recv() + try: + msg = json.loads(raw) + except json.JSONDecodeError: + continue + topic = msg.get("topic", "") + if topic.startswith("smd+"): + self._on_tick(topic, msg) + elif topic.startswith("sbd+"): + self._on_depth(topic, msg) + except asyncio.CancelledError: + raise + except Exception: + return + + def _on_tick(self, topic: str, msg: dict) -> None: + try: + conid = int(topic.split("+", 1)[1]) + except (ValueError, IndexError): + return + + def _f(k: str) -> float | None: + v = msg.get(k) + try: + return float(v) if v not in (None, "") else None + except (TypeError, ValueError): + return None + + self._tick_cache[conid] = TickSnapshot( + last_price=_f("31"), bid=_f("84"), ask=_f("86"), + bid_size=_f("7295"), ask_size=_f("7296"), + timestamp_ms=int(time.time() * 1000), + ) + + def _on_depth(self, topic: str, msg: dict) -> None: + try: + conid = int(topic.split("+", 1)[1]) + except (ValueError, IndexError): + return + self._depth_cache[conid] = DepthSnapshot( + bids=msg.get("bids") or [], + asks=msg.get("asks") or [], + timestamp_ms=int(time.time() * 1000), + ) + + async def _idle_sweeper(self) -> None: + try: + while not self._stopped: + await asyncio.sleep(30) + now = time.monotonic() + expired = [ + c for c in list(self._subs | self._depth_subs) + if c not in self._forced_subs + and now - self._last_polled_at.get(c, now) > self.idle_timeout_s + ] + for c in expired: + with contextlib.suppress(Exception): + await self.unsubscribe(c) + except asyncio.CancelledError: + raise diff --git a/tests/unit/exchanges/ibkr/test_ws.py b/tests/unit/exchanges/ibkr/test_ws.py new file mode 100644 index 0000000..47770b1 --- /dev/null +++ b/tests/unit/exchanges/ibkr/test_ws.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest +from cerbero_mcp.exchanges.ibkr.ws import IBKRWebSocket, WSError + + +class FakeWS: + """Bidirectional async fake for WSS messages.""" + def __init__(self) -> None: + self.sent: list[str] = [] + self._inbox: asyncio.Queue[str] = asyncio.Queue() + self.closed = False + async def send(self, msg: str) -> None: + self.sent.append(msg) + async def recv(self) -> str: + return await self._inbox.get() + async def close(self) -> None: + self.closed = True + async def push(self, payload: dict) -> None: + await self._inbox.put(json.dumps(payload)) + + +@pytest.fixture +def fake_signer(): + s = MagicMock() + s.get_live_session_token = AsyncMock(return_value="LST==") + return s + + +@pytest.mark.asyncio +async def test_subscribe_tick_caches_snapshot(fake_signer, monkeypatch): + fake_ws = FakeWS() + + async def fake_connect(url, **kw): + return fake_ws + + monkeypatch.setattr("cerbero_mcp.exchanges.ibkr.ws.websockets_connect", fake_connect) + + ws = IBKRWebSocket( + signer=fake_signer, + ws_url="wss://api.ibkr.com/v1/api/ws", + base_url="https://api.ibkr.com/v1/api", + max_subs=80, idle_timeout_s=300, + ) + await ws.start() + await ws.subscribe_tick(265598) + + await fake_ws.push({ + "topic": "smd+265598", + "31": "150.5", "84": "150.4", "86": "150.6", + "7295": "100", "7296": "200", + }) + await asyncio.sleep(0.05) + + snap = ws.get_tick_snapshot(265598) + assert snap is not None + assert snap["last_price"] == 150.5 + assert snap["bid"] == 150.4 + + await ws.stop() + + +@pytest.mark.asyncio +async def test_subscribe_limit(fake_signer, monkeypatch): + fake_ws = FakeWS() + + async def fake_connect(url, **kw): + return fake_ws + + monkeypatch.setattr("cerbero_mcp.exchanges.ibkr.ws.websockets_connect", fake_connect) + + ws = IBKRWebSocket( + signer=fake_signer, + ws_url="wss://x", base_url="https://x", + max_subs=2, idle_timeout_s=300, + ) + await ws.start() + await ws.subscribe_tick(1) + await ws.subscribe_tick(2) + with pytest.raises(WSError, match="IBKR_WS_SUB_LIMIT"): + await ws.subscribe_tick(3) + await ws.stop()