refactor(V2): IBKR WebSocket — fix stop/start cycle, guard rails, log disconnect

Code review fixes (commit 17700d2):
- _stopped reset on start() (was stuck True after stop→start)
- _require_started guard on subscribe_*/unsubscribe (clear WSError vs AttributeError)
- _reader_loop logs disconnect via logger.warning + sets _ws=None for `connected` signal
- Class docstring documents stale-snapshot behavior + deferred reconnect
- New tests: subscribe-before-start, stop→start cycle resumption

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
root
2026-05-03 21:18:57 +00:00
parent 17700d27a0
commit 6266708e15
2 changed files with 63 additions and 1 deletions
+25 -1
View File
@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import json import json
import logging
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
@@ -12,6 +13,8 @@ from websockets import connect as websockets_connect # exposed for tests
from cerbero_mcp.exchanges.ibkr.oauth import OAuth1aSigner from cerbero_mcp.exchanges.ibkr.oauth import OAuth1aSigner
logger = logging.getLogger(__name__)
class WSError(Exception): class WSError(Exception):
"""WebSocket layer error.""" """WebSocket layer error."""
@@ -39,6 +42,15 @@ _SMD_FIELDS = ["31", "84", "86", "7295", "7296"]
@dataclass @dataclass
class IBKRWebSocket: class IBKRWebSocket:
"""Persistent WSS to IBKR Client Portal with smd/sbd subs.
Snapshot lifetime: each (tick|depth) cache entry is overwritten on every
incoming message. On disconnect, the reader loop logs and exits leaving
the existing cache intact. Consumers should check `connected` before
trusting a stale snapshot, or compare `timestamp_ms` against wall clock.
Automatic reconnect is deferred to a follow-up; V1 surfaces disconnects
via `connected=False` so the higher-level tool layer can rebuild the WS.
"""
signer: OAuth1aSigner signer: OAuth1aSigner
ws_url: str ws_url: str
base_url: str base_url: str
@@ -63,6 +75,7 @@ class IBKRWebSocket:
async def start(self) -> None: async def start(self) -> None:
if self.connected: if self.connected:
return return
self._stopped = False # reset on every start (supports stop→start cycles)
lst = await self.signer.get_live_session_token(base_url=self.base_url) lst = await self.signer.get_live_session_token(base_url=self.base_url)
self._ws = await websockets_connect( self._ws = await websockets_connect(
self.ws_url, self.ws_url,
@@ -87,6 +100,7 @@ class IBKRWebSocket:
self._ws = None self._ws = None
async def subscribe_tick(self, conid: int, *, forced: bool = False) -> None: async def subscribe_tick(self, conid: int, *, forced: bool = False) -> None:
self._require_started()
await self._ensure_capacity(conid) await self._ensure_capacity(conid)
if conid in self._subs: if conid in self._subs:
self._last_polled_at[conid] = time.monotonic() self._last_polled_at[conid] = time.monotonic()
@@ -103,6 +117,7 @@ class IBKRWebSocket:
async def subscribe_depth( async def subscribe_depth(
self, conid: int, *, exchange: str = "SMART", rows: int = 5 self, conid: int, *, exchange: str = "SMART", rows: int = 5
) -> None: ) -> None:
self._require_started()
await self._ensure_capacity(conid) await self._ensure_capacity(conid)
if conid in self._depth_subs: if conid in self._depth_subs:
self._last_polled_at[conid] = time.monotonic() self._last_polled_at[conid] = time.monotonic()
@@ -113,6 +128,7 @@ class IBKRWebSocket:
self._last_polled_at[conid] = time.monotonic() self._last_polled_at[conid] = time.monotonic()
async def unsubscribe(self, conid: int) -> None: async def unsubscribe(self, conid: int) -> None:
self._require_started()
if conid in self._subs: if conid in self._subs:
await self._ws.send(f"umd+{conid}+{{}}") await self._ws.send(f"umd+{conid}+{{}}")
self._subs.discard(conid) self._subs.discard(conid)
@@ -151,6 +167,10 @@ class IBKRWebSocket:
"timestamp_ms": snap.timestamp_ms, "timestamp_ms": snap.timestamp_ms,
} }
def _require_started(self) -> None:
    """Fail fast when no websocket is attached.

    Raises:
        WSError: if ``self._ws`` is ``None``, i.e. there is currently no
            socket object to send on.
    """
    if self._ws is not None:
        return
    raise WSError("IBKR_WS_NOT_STARTED: call start() first")
async def _ensure_capacity(self, conid: int) -> None: async def _ensure_capacity(self, conid: int) -> None:
if (conid in self._subs) or (conid in self._depth_subs): if (conid in self._subs) or (conid in self._depth_subs):
return return
@@ -173,7 +193,11 @@ class IBKRWebSocket:
self._on_depth(topic, msg) self._on_depth(topic, msg)
except asyncio.CancelledError: except asyncio.CancelledError:
raise raise
except Exception: except Exception as exc:
# Disconnect / parse error / network — leave cache as-is, mark dead.
# V1: no automatic reconnect; consumers detect via `connected` being False or a stale timestamp_ms.
logger.warning("ibkr ws reader exited: %s", exc)
self._ws = None
return return
def _on_tick(self, topic: str, msg: dict) -> None: def _on_tick(self, topic: str, msg: dict) -> None:
+38
View File
@@ -84,3 +84,41 @@ async def test_subscribe_limit(fake_signer, monkeypatch):
with pytest.raises(WSError, match="IBKR_WS_SUB_LIMIT"): with pytest.raises(WSError, match="IBKR_WS_SUB_LIMIT"):
await ws.subscribe_tick(3) await ws.subscribe_tick(3)
await ws.stop() await ws.stop()
@pytest.mark.asyncio
async def test_subscribe_before_start_raises(fake_signer):
    """subscribe_tick() on a never-started socket raises IBKR_WS_NOT_STARTED."""
    unstarted = IBKRWebSocket(
        signer=fake_signer,
        ws_url="wss://x",
        base_url="https://x",
        max_subs=10,
        idle_timeout_s=300,
    )

    with pytest.raises(WSError, match="IBKR_WS_NOT_STARTED"):
        await unstarted.subscribe_tick(1)
@pytest.mark.asyncio
async def test_start_after_stop_resumes_reader(fake_signer, monkeypatch):
    """A stop() followed by start() yields a socket whose messages are processed."""
    first_conn = FakeWS()
    second_conn = FakeWS()
    # Each call to the patched connect() hands out the next fake, in order.
    pending_conns = iter([first_conn, second_conn])

    async def fake_connect(url, **kw):
        return next(pending_conns)

    monkeypatch.setattr(
        "cerbero_mcp.exchanges.ibkr.ws.websockets_connect",
        fake_connect,
    )

    client = IBKRWebSocket(
        signer=fake_signer,
        ws_url="wss://x",
        base_url="https://x",
        max_subs=10,
        idle_timeout_s=300,
    )

    # First lifecycle: connect, then shut down.
    await client.start()
    await client.stop()

    # Second lifecycle: this start() receives second_conn from fake_connect.
    await client.start()
    await client.subscribe_tick(42)
    await second_conn.push({"topic": "smd+42", "31": "100"})
    # Give the background reader a moment to consume the pushed message.
    await asyncio.sleep(0.05)

    assert client.get_tick_snapshot(42) is not None
    await client.stop()