feat(V2): IBKR WebSocket layer + tick/depth snapshot cache
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,223 @@
|
||||
"""IBKR Client Portal WebSocket — persistent WSS, smd/sbd subs, snapshot cache."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from websockets import connect as websockets_connect # exposed for tests
|
||||
|
||||
from cerbero_mcp.exchanges.ibkr.oauth import OAuth1aSigner
|
||||
|
||||
|
||||
class WSError(Exception):
|
||||
"""WebSocket layer error."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TickSnapshot:
|
||||
last_price: float | None
|
||||
bid: float | None
|
||||
ask: float | None
|
||||
bid_size: float | None
|
||||
ask_size: float | None
|
||||
timestamp_ms: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class DepthSnapshot:
|
||||
bids: list[dict]
|
||||
asks: list[dict]
|
||||
timestamp_ms: int
|
||||
|
||||
|
||||
_SMD_FIELDS = ["31", "84", "86", "7295", "7296"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class IBKRWebSocket:
|
||||
signer: OAuth1aSigner
|
||||
ws_url: str
|
||||
base_url: str
|
||||
max_subs: int = 80
|
||||
idle_timeout_s: int = 300
|
||||
|
||||
_ws: Any = field(default=None, init=False, repr=False)
|
||||
_tick_cache: dict[int, TickSnapshot] = field(default_factory=dict, init=False)
|
||||
_depth_cache: dict[int, DepthSnapshot] = field(default_factory=dict, init=False)
|
||||
_subs: set[int] = field(default_factory=set, init=False)
|
||||
_depth_subs: set[int] = field(default_factory=set, init=False)
|
||||
_last_polled_at: dict[int, float] = field(default_factory=dict, init=False)
|
||||
_forced_subs: set[int] = field(default_factory=set, init=False)
|
||||
_reader_task: asyncio.Task | None = field(default=None, init=False)
|
||||
_idle_task: asyncio.Task | None = field(default=None, init=False)
|
||||
_stopped: bool = field(default=False, init=False)
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
return self._ws is not None and not getattr(self._ws, "closed", True)
|
||||
|
||||
async def start(self) -> None:
|
||||
if self.connected:
|
||||
return
|
||||
lst = await self.signer.get_live_session_token(base_url=self.base_url)
|
||||
self._ws = await websockets_connect(
|
||||
self.ws_url,
|
||||
additional_headers={"Cookie": f"api={lst}"},
|
||||
)
|
||||
self._reader_task = asyncio.create_task(self._reader_loop())
|
||||
self._idle_task = asyncio.create_task(self._idle_sweeper())
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._stopped = True
|
||||
if self._idle_task:
|
||||
self._idle_task.cancel()
|
||||
with contextlib.suppress(BaseException):
|
||||
await self._idle_task
|
||||
if self._reader_task:
|
||||
self._reader_task.cancel()
|
||||
with contextlib.suppress(BaseException):
|
||||
await self._reader_task
|
||||
if self._ws:
|
||||
with contextlib.suppress(Exception):
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
async def subscribe_tick(self, conid: int, *, forced: bool = False) -> None:
|
||||
await self._ensure_capacity(conid)
|
||||
if conid in self._subs:
|
||||
self._last_polled_at[conid] = time.monotonic()
|
||||
if forced:
|
||||
self._forced_subs.add(conid)
|
||||
return
|
||||
msg = "smd+" + str(conid) + "+" + json.dumps({"fields": _SMD_FIELDS})
|
||||
await self._ws.send(msg)
|
||||
self._subs.add(conid)
|
||||
self._last_polled_at[conid] = time.monotonic()
|
||||
if forced:
|
||||
self._forced_subs.add(conid)
|
||||
|
||||
async def subscribe_depth(
|
||||
self, conid: int, *, exchange: str = "SMART", rows: int = 5
|
||||
) -> None:
|
||||
await self._ensure_capacity(conid)
|
||||
if conid in self._depth_subs:
|
||||
self._last_polled_at[conid] = time.monotonic()
|
||||
return
|
||||
msg = f"sbd+{conid}+{exchange}+{rows}"
|
||||
await self._ws.send(msg)
|
||||
self._depth_subs.add(conid)
|
||||
self._last_polled_at[conid] = time.monotonic()
|
||||
|
||||
async def unsubscribe(self, conid: int) -> None:
|
||||
if conid in self._subs:
|
||||
await self._ws.send(f"umd+{conid}+{{}}")
|
||||
self._subs.discard(conid)
|
||||
if conid in self._depth_subs:
|
||||
await self._ws.send(f"ubd+{conid}")
|
||||
self._depth_subs.discard(conid)
|
||||
self._tick_cache.pop(conid, None)
|
||||
self._depth_cache.pop(conid, None)
|
||||
self._last_polled_at.pop(conid, None)
|
||||
self._forced_subs.discard(conid)
|
||||
|
||||
def get_tick_snapshot(self, conid: int) -> dict | None:
|
||||
snap = self._tick_cache.get(conid)
|
||||
if not snap:
|
||||
return None
|
||||
self._last_polled_at[conid] = time.monotonic()
|
||||
return {
|
||||
"conid": conid,
|
||||
"last_price": snap.last_price,
|
||||
"bid": snap.bid,
|
||||
"ask": snap.ask,
|
||||
"bid_size": snap.bid_size,
|
||||
"ask_size": snap.ask_size,
|
||||
"timestamp_ms": snap.timestamp_ms,
|
||||
}
|
||||
|
||||
def get_depth_snapshot(self, conid: int) -> dict | None:
|
||||
snap = self._depth_cache.get(conid)
|
||||
if not snap:
|
||||
return None
|
||||
self._last_polled_at[conid] = time.monotonic()
|
||||
return {
|
||||
"conid": conid,
|
||||
"bids": snap.bids,
|
||||
"asks": snap.asks,
|
||||
"timestamp_ms": snap.timestamp_ms,
|
||||
}
|
||||
|
||||
async def _ensure_capacity(self, conid: int) -> None:
|
||||
if (conid in self._subs) or (conid in self._depth_subs):
|
||||
return
|
||||
active = len(self._subs) + len(self._depth_subs)
|
||||
if active >= self.max_subs:
|
||||
raise WSError(f"IBKR_WS_SUB_LIMIT: {active}/{self.max_subs}")
|
||||
|
||||
async def _reader_loop(self) -> None:
|
||||
try:
|
||||
while not self._stopped and self._ws:
|
||||
raw = await self._ws.recv()
|
||||
try:
|
||||
msg = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
topic = msg.get("topic", "")
|
||||
if topic.startswith("smd+"):
|
||||
self._on_tick(topic, msg)
|
||||
elif topic.startswith("sbd+"):
|
||||
self._on_depth(topic, msg)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def _on_tick(self, topic: str, msg: dict) -> None:
|
||||
try:
|
||||
conid = int(topic.split("+", 1)[1])
|
||||
except (ValueError, IndexError):
|
||||
return
|
||||
|
||||
def _f(k: str) -> float | None:
|
||||
v = msg.get(k)
|
||||
try:
|
||||
return float(v) if v not in (None, "") else None
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
self._tick_cache[conid] = TickSnapshot(
|
||||
last_price=_f("31"), bid=_f("84"), ask=_f("86"),
|
||||
bid_size=_f("7295"), ask_size=_f("7296"),
|
||||
timestamp_ms=int(time.time() * 1000),
|
||||
)
|
||||
|
||||
def _on_depth(self, topic: str, msg: dict) -> None:
|
||||
try:
|
||||
conid = int(topic.split("+", 1)[1])
|
||||
except (ValueError, IndexError):
|
||||
return
|
||||
self._depth_cache[conid] = DepthSnapshot(
|
||||
bids=msg.get("bids") or [],
|
||||
asks=msg.get("asks") or [],
|
||||
timestamp_ms=int(time.time() * 1000),
|
||||
)
|
||||
|
||||
async def _idle_sweeper(self) -> None:
|
||||
try:
|
||||
while not self._stopped:
|
||||
await asyncio.sleep(30)
|
||||
now = time.monotonic()
|
||||
expired = [
|
||||
c for c in list(self._subs | self._depth_subs)
|
||||
if c not in self._forced_subs
|
||||
and now - self._last_polled_at.get(c, now) > self.idle_timeout_s
|
||||
]
|
||||
for c in expired:
|
||||
with contextlib.suppress(Exception):
|
||||
await self.unsubscribe(c)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
Reference in New Issue
Block a user