feat(V2): IBKR WebSocket layer + tick/depth snapshot cache
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,223 @@
|
|||||||
|
"""IBKR Client Portal WebSocket — persistent WSS, smd/sbd subs, snapshot cache."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from websockets import connect as websockets_connect # exposed for tests
|
||||||
|
|
||||||
|
from cerbero_mcp.exchanges.ibkr.oauth import OAuth1aSigner
|
||||||
|
|
||||||
|
|
||||||
|
class WSError(Exception):
|
||||||
|
"""WebSocket layer error."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TickSnapshot:
|
||||||
|
last_price: float | None
|
||||||
|
bid: float | None
|
||||||
|
ask: float | None
|
||||||
|
bid_size: float | None
|
||||||
|
ask_size: float | None
|
||||||
|
timestamp_ms: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DepthSnapshot:
|
||||||
|
bids: list[dict]
|
||||||
|
asks: list[dict]
|
||||||
|
timestamp_ms: int
|
||||||
|
|
||||||
|
|
||||||
|
_SMD_FIELDS = ["31", "84", "86", "7295", "7296"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class IBKRWebSocket:
|
||||||
|
signer: OAuth1aSigner
|
||||||
|
ws_url: str
|
||||||
|
base_url: str
|
||||||
|
max_subs: int = 80
|
||||||
|
idle_timeout_s: int = 300
|
||||||
|
|
||||||
|
_ws: Any = field(default=None, init=False, repr=False)
|
||||||
|
_tick_cache: dict[int, TickSnapshot] = field(default_factory=dict, init=False)
|
||||||
|
_depth_cache: dict[int, DepthSnapshot] = field(default_factory=dict, init=False)
|
||||||
|
_subs: set[int] = field(default_factory=set, init=False)
|
||||||
|
_depth_subs: set[int] = field(default_factory=set, init=False)
|
||||||
|
_last_polled_at: dict[int, float] = field(default_factory=dict, init=False)
|
||||||
|
_forced_subs: set[int] = field(default_factory=set, init=False)
|
||||||
|
_reader_task: asyncio.Task | None = field(default=None, init=False)
|
||||||
|
_idle_task: asyncio.Task | None = field(default=None, init=False)
|
||||||
|
_stopped: bool = field(default=False, init=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def connected(self) -> bool:
|
||||||
|
return self._ws is not None and not getattr(self._ws, "closed", True)
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
if self.connected:
|
||||||
|
return
|
||||||
|
lst = await self.signer.get_live_session_token(base_url=self.base_url)
|
||||||
|
self._ws = await websockets_connect(
|
||||||
|
self.ws_url,
|
||||||
|
additional_headers={"Cookie": f"api={lst}"},
|
||||||
|
)
|
||||||
|
self._reader_task = asyncio.create_task(self._reader_loop())
|
||||||
|
self._idle_task = asyncio.create_task(self._idle_sweeper())
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
self._stopped = True
|
||||||
|
if self._idle_task:
|
||||||
|
self._idle_task.cancel()
|
||||||
|
with contextlib.suppress(BaseException):
|
||||||
|
await self._idle_task
|
||||||
|
if self._reader_task:
|
||||||
|
self._reader_task.cancel()
|
||||||
|
with contextlib.suppress(BaseException):
|
||||||
|
await self._reader_task
|
||||||
|
if self._ws:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await self._ws.close()
|
||||||
|
self._ws = None
|
||||||
|
|
||||||
|
async def subscribe_tick(self, conid: int, *, forced: bool = False) -> None:
|
||||||
|
await self._ensure_capacity(conid)
|
||||||
|
if conid in self._subs:
|
||||||
|
self._last_polled_at[conid] = time.monotonic()
|
||||||
|
if forced:
|
||||||
|
self._forced_subs.add(conid)
|
||||||
|
return
|
||||||
|
msg = "smd+" + str(conid) + "+" + json.dumps({"fields": _SMD_FIELDS})
|
||||||
|
await self._ws.send(msg)
|
||||||
|
self._subs.add(conid)
|
||||||
|
self._last_polled_at[conid] = time.monotonic()
|
||||||
|
if forced:
|
||||||
|
self._forced_subs.add(conid)
|
||||||
|
|
||||||
|
async def subscribe_depth(
|
||||||
|
self, conid: int, *, exchange: str = "SMART", rows: int = 5
|
||||||
|
) -> None:
|
||||||
|
await self._ensure_capacity(conid)
|
||||||
|
if conid in self._depth_subs:
|
||||||
|
self._last_polled_at[conid] = time.monotonic()
|
||||||
|
return
|
||||||
|
msg = f"sbd+{conid}+{exchange}+{rows}"
|
||||||
|
await self._ws.send(msg)
|
||||||
|
self._depth_subs.add(conid)
|
||||||
|
self._last_polled_at[conid] = time.monotonic()
|
||||||
|
|
||||||
|
async def unsubscribe(self, conid: int) -> None:
|
||||||
|
if conid in self._subs:
|
||||||
|
await self._ws.send(f"umd+{conid}+{{}}")
|
||||||
|
self._subs.discard(conid)
|
||||||
|
if conid in self._depth_subs:
|
||||||
|
await self._ws.send(f"ubd+{conid}")
|
||||||
|
self._depth_subs.discard(conid)
|
||||||
|
self._tick_cache.pop(conid, None)
|
||||||
|
self._depth_cache.pop(conid, None)
|
||||||
|
self._last_polled_at.pop(conid, None)
|
||||||
|
self._forced_subs.discard(conid)
|
||||||
|
|
||||||
|
def get_tick_snapshot(self, conid: int) -> dict | None:
|
||||||
|
snap = self._tick_cache.get(conid)
|
||||||
|
if not snap:
|
||||||
|
return None
|
||||||
|
self._last_polled_at[conid] = time.monotonic()
|
||||||
|
return {
|
||||||
|
"conid": conid,
|
||||||
|
"last_price": snap.last_price,
|
||||||
|
"bid": snap.bid,
|
||||||
|
"ask": snap.ask,
|
||||||
|
"bid_size": snap.bid_size,
|
||||||
|
"ask_size": snap.ask_size,
|
||||||
|
"timestamp_ms": snap.timestamp_ms,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_depth_snapshot(self, conid: int) -> dict | None:
|
||||||
|
snap = self._depth_cache.get(conid)
|
||||||
|
if not snap:
|
||||||
|
return None
|
||||||
|
self._last_polled_at[conid] = time.monotonic()
|
||||||
|
return {
|
||||||
|
"conid": conid,
|
||||||
|
"bids": snap.bids,
|
||||||
|
"asks": snap.asks,
|
||||||
|
"timestamp_ms": snap.timestamp_ms,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _ensure_capacity(self, conid: int) -> None:
|
||||||
|
if (conid in self._subs) or (conid in self._depth_subs):
|
||||||
|
return
|
||||||
|
active = len(self._subs) + len(self._depth_subs)
|
||||||
|
if active >= self.max_subs:
|
||||||
|
raise WSError(f"IBKR_WS_SUB_LIMIT: {active}/{self.max_subs}")
|
||||||
|
|
||||||
|
async def _reader_loop(self) -> None:
|
||||||
|
try:
|
||||||
|
while not self._stopped and self._ws:
|
||||||
|
raw = await self._ws.recv()
|
||||||
|
try:
|
||||||
|
msg = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
topic = msg.get("topic", "")
|
||||||
|
if topic.startswith("smd+"):
|
||||||
|
self._on_tick(topic, msg)
|
||||||
|
elif topic.startswith("sbd+"):
|
||||||
|
self._on_depth(topic, msg)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
|
||||||
|
def _on_tick(self, topic: str, msg: dict) -> None:
|
||||||
|
try:
|
||||||
|
conid = int(topic.split("+", 1)[1])
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
return
|
||||||
|
|
||||||
|
def _f(k: str) -> float | None:
|
||||||
|
v = msg.get(k)
|
||||||
|
try:
|
||||||
|
return float(v) if v not in (None, "") else None
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
self._tick_cache[conid] = TickSnapshot(
|
||||||
|
last_price=_f("31"), bid=_f("84"), ask=_f("86"),
|
||||||
|
bid_size=_f("7295"), ask_size=_f("7296"),
|
||||||
|
timestamp_ms=int(time.time() * 1000),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_depth(self, topic: str, msg: dict) -> None:
|
||||||
|
try:
|
||||||
|
conid = int(topic.split("+", 1)[1])
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
return
|
||||||
|
self._depth_cache[conid] = DepthSnapshot(
|
||||||
|
bids=msg.get("bids") or [],
|
||||||
|
asks=msg.get("asks") or [],
|
||||||
|
timestamp_ms=int(time.time() * 1000),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _idle_sweeper(self) -> None:
|
||||||
|
try:
|
||||||
|
while not self._stopped:
|
||||||
|
await asyncio.sleep(30)
|
||||||
|
now = time.monotonic()
|
||||||
|
expired = [
|
||||||
|
c for c in list(self._subs | self._depth_subs)
|
||||||
|
if c not in self._forced_subs
|
||||||
|
and now - self._last_polled_at.get(c, now) > self.idle_timeout_s
|
||||||
|
]
|
||||||
|
for c in expired:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await self.unsubscribe(c)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from cerbero_mcp.exchanges.ibkr.ws import IBKRWebSocket, WSError
|
||||||
|
|
||||||
|
|
||||||
|
class FakeWS:
|
||||||
|
"""Bidirectional async fake for WSS messages."""
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.sent: list[str] = []
|
||||||
|
self._inbox: asyncio.Queue[str] = asyncio.Queue()
|
||||||
|
self.closed = False
|
||||||
|
async def send(self, msg: str) -> None:
|
||||||
|
self.sent.append(msg)
|
||||||
|
async def recv(self) -> str:
|
||||||
|
return await self._inbox.get()
|
||||||
|
async def close(self) -> None:
|
||||||
|
self.closed = True
|
||||||
|
async def push(self, payload: dict) -> None:
|
||||||
|
await self._inbox.put(json.dumps(payload))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fake_signer():
|
||||||
|
s = MagicMock()
|
||||||
|
s.get_live_session_token = AsyncMock(return_value="LST==")
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscribe_tick_caches_snapshot(fake_signer, monkeypatch):
|
||||||
|
fake_ws = FakeWS()
|
||||||
|
|
||||||
|
async def fake_connect(url, **kw):
|
||||||
|
return fake_ws
|
||||||
|
|
||||||
|
monkeypatch.setattr("cerbero_mcp.exchanges.ibkr.ws.websockets_connect", fake_connect)
|
||||||
|
|
||||||
|
ws = IBKRWebSocket(
|
||||||
|
signer=fake_signer,
|
||||||
|
ws_url="wss://api.ibkr.com/v1/api/ws",
|
||||||
|
base_url="https://api.ibkr.com/v1/api",
|
||||||
|
max_subs=80, idle_timeout_s=300,
|
||||||
|
)
|
||||||
|
await ws.start()
|
||||||
|
await ws.subscribe_tick(265598)
|
||||||
|
|
||||||
|
await fake_ws.push({
|
||||||
|
"topic": "smd+265598",
|
||||||
|
"31": "150.5", "84": "150.4", "86": "150.6",
|
||||||
|
"7295": "100", "7296": "200",
|
||||||
|
})
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
snap = ws.get_tick_snapshot(265598)
|
||||||
|
assert snap is not None
|
||||||
|
assert snap["last_price"] == 150.5
|
||||||
|
assert snap["bid"] == 150.4
|
||||||
|
|
||||||
|
await ws.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subscribe_limit(fake_signer, monkeypatch):
|
||||||
|
fake_ws = FakeWS()
|
||||||
|
|
||||||
|
async def fake_connect(url, **kw):
|
||||||
|
return fake_ws
|
||||||
|
|
||||||
|
monkeypatch.setattr("cerbero_mcp.exchanges.ibkr.ws.websockets_connect", fake_connect)
|
||||||
|
|
||||||
|
ws = IBKRWebSocket(
|
||||||
|
signer=fake_signer,
|
||||||
|
ws_url="wss://x", base_url="https://x",
|
||||||
|
max_subs=2, idle_timeout_s=300,
|
||||||
|
)
|
||||||
|
await ws.start()
|
||||||
|
await ws.subscribe_tick(1)
|
||||||
|
await ws.subscribe_tick(2)
|
||||||
|
with pytest.raises(WSError, match="IBKR_WS_SUB_LIMIT"):
|
||||||
|
await ws.subscribe_tick(3)
|
||||||
|
await ws.stop()
|
||||||
Reference in New Issue
Block a user