feat(V2): IBKR streaming tools (tick/depth/subscribe)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,13 +1,14 @@
|
|||||||
"""IBKR tool functions: Pydantic schemas + async dispatch to client/ws."""
|
"""IBKR tool functions: Pydantic schemas + async dispatch to client/ws."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from cerbero_mcp.exchanges.ibkr.client import IBKRClient
|
from cerbero_mcp.exchanges.ibkr.client import _SEC_TYPE_MAP, IBKRClient, IBKRError
|
||||||
from cerbero_mcp.exchanges.ibkr.leverage_cap import enforce_leverage, get_max_leverage # noqa: F401
|
from cerbero_mcp.exchanges.ibkr.leverage_cap import enforce_leverage, get_max_leverage # noqa: F401
|
||||||
from cerbero_mcp.exchanges.ibkr.ws import IBKRWebSocket # noqa: F401
|
from cerbero_mcp.exchanges.ibkr.ws import IBKRWebSocket
|
||||||
|
|
||||||
# === Schemas: reads ===
|
# === Schemas: reads ===
|
||||||
|
|
||||||
@@ -183,3 +184,66 @@ async def get_clock(client: IBKRClient, params: GetClockReq) -> dict:
|
|||||||
"is_open": _dt.time(13, 30) <= now.time() <= _dt.time(20, 0)
|
"is_open": _dt.time(13, 30) <= now.time() <= _dt.time(20, 0)
|
||||||
and now.weekday() < 5,
|
and now.weekday() < 5,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# === Streaming tools ===
|
||||||
|
|
||||||
|
|
||||||
|
def _sec_type_for(asset_class: str) -> str:
|
||||||
|
return _SEC_TYPE_MAP.get(asset_class.lower(), "STK")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_tick(
|
||||||
|
client: IBKRClient, params: GetTickReq,
|
||||||
|
*, ws: IBKRWebSocket, timeout_s: float = 3.0,
|
||||||
|
) -> dict:
|
||||||
|
sec = _sec_type_for(params.asset_class)
|
||||||
|
conid = await client.resolve_conid(params.symbol, sec)
|
||||||
|
snap = ws.get_tick_snapshot(conid)
|
||||||
|
if snap:
|
||||||
|
return {**snap, "symbol": params.symbol}
|
||||||
|
await ws.subscribe_tick(conid)
|
||||||
|
deadline = asyncio.get_event_loop().time() + timeout_s
|
||||||
|
while asyncio.get_event_loop().time() < deadline:
|
||||||
|
snap = ws.get_tick_snapshot(conid)
|
||||||
|
if snap:
|
||||||
|
return {**snap, "symbol": params.symbol}
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
raise IBKRError(f"IBKR_TICK_TIMEOUT: {params.symbol}")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_depth(
|
||||||
|
client: IBKRClient, params: GetDepthReq,
|
||||||
|
*, ws: IBKRWebSocket, timeout_s: float = 3.0,
|
||||||
|
) -> dict:
|
||||||
|
sec = _sec_type_for(params.asset_class)
|
||||||
|
conid = await client.resolve_conid(params.symbol, sec)
|
||||||
|
snap = ws.get_depth_snapshot(conid)
|
||||||
|
if snap:
|
||||||
|
return {**snap, "symbol": params.symbol}
|
||||||
|
await ws.subscribe_depth(conid, exchange=params.exchange, rows=params.rows)
|
||||||
|
deadline = asyncio.get_event_loop().time() + timeout_s
|
||||||
|
while asyncio.get_event_loop().time() < deadline:
|
||||||
|
snap = ws.get_depth_snapshot(conid)
|
||||||
|
if snap:
|
||||||
|
return {**snap, "symbol": params.symbol}
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
raise IBKRError(f"IBKR_DEPTH_TIMEOUT: {params.symbol}")
|
||||||
|
|
||||||
|
|
||||||
|
async def subscribe_tick(
|
||||||
|
client: IBKRClient, params: SubscribeTickReq, *, ws: IBKRWebSocket,
|
||||||
|
) -> dict:
|
||||||
|
sec = _sec_type_for(params.asset_class)
|
||||||
|
conid = await client.resolve_conid(params.symbol, sec)
|
||||||
|
await ws.subscribe_tick(conid, forced=True)
|
||||||
|
return {"symbol": params.symbol, "conid": conid, "subscribed": True}
|
||||||
|
|
||||||
|
|
||||||
|
async def unsubscribe(
|
||||||
|
client: IBKRClient, params: UnsubscribeReq, *, ws: IBKRWebSocket,
|
||||||
|
) -> dict:
|
||||||
|
sec = _sec_type_for(params.asset_class)
|
||||||
|
conid = await client.resolve_conid(params.symbol, sec)
|
||||||
|
await ws.unsubscribe(conid)
|
||||||
|
return {"symbol": params.symbol, "conid": conid, "unsubscribed": True}
|
||||||
|
|||||||
@@ -26,3 +26,22 @@ async def test_get_account_tool_calls_client():
|
|||||||
client.get_account = AsyncMock(return_value={"netliquidation": {"amount": 10000}})
|
client.get_account = AsyncMock(return_value={"netliquidation": {"amount": 10000}})
|
||||||
res = await t.get_account(client, t.GetAccountReq())
|
res = await t.get_account(client, t.GetAccountReq())
|
||||||
assert res["netliquidation"]["amount"] == 10000
|
assert res["netliquidation"]["amount"] == 10000
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_tick_uses_cache_or_subscribes():
|
||||||
|
client = MagicMock()
|
||||||
|
client.resolve_conid = AsyncMock(return_value=42)
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.get_tick_snapshot = MagicMock(side_effect=[
|
||||||
|
None,
|
||||||
|
{"conid": 42, "last_price": 99.5, "bid": 99.4, "ask": 99.6,
|
||||||
|
"bid_size": 1, "ask_size": 1, "timestamp_ms": 1700000000000},
|
||||||
|
])
|
||||||
|
ws.subscribe_tick = AsyncMock()
|
||||||
|
|
||||||
|
res = await t.get_tick(
|
||||||
|
client, t.GetTickReq(symbol="AAPL"), ws=ws, timeout_s=0.05,
|
||||||
|
)
|
||||||
|
assert res["last_price"] == 99.5
|
||||||
|
ws.subscribe_tick.assert_awaited_once_with(42)
|
||||||
|
|||||||
Reference in New Issue
Block a user