"""IBKR tool functions: Pydantic schemas + async dispatch to client/ws.""" from __future__ import annotations import asyncio from typing import Any from pydantic import BaseModel from cerbero_mcp.exchanges.ibkr.client import _SEC_TYPE_MAP, IBKRClient, IBKRError from cerbero_mcp.exchanges.ibkr.leverage_cap import enforce_leverage, get_max_leverage # noqa: F401 from cerbero_mcp.exchanges.ibkr.ws import IBKRWebSocket # === Schemas: reads === class GetAccountReq(BaseModel): pass class GetPositionsReq(BaseModel): pass class GetOpenOrdersReq(BaseModel): pass class GetActivitiesReq(BaseModel): days: int = 7 class GetTickerReq(BaseModel): symbol: str asset_class: str = "stocks" class GetBarsReq(BaseModel): symbol: str asset_class: str = "stocks" period: str = "1d" bar: str = "5min" class GetSnapshotReq(BaseModel): symbol: str asset_class: str = "stocks" class GetOptionChainReq(BaseModel): underlying: str expiry: str | None = None class SearchContractsReq(BaseModel): symbol: str sec_type: str = "STK" class GetClockReq(BaseModel): pass # === Schemas: streaming === class GetTickReq(BaseModel): symbol: str asset_class: str = "stocks" class GetDepthReq(BaseModel): symbol: str asset_class: str = "stocks" rows: int = 5 exchange: str = "SMART" class SubscribeTickReq(BaseModel): symbol: str asset_class: str = "stocks" class UnsubscribeReq(BaseModel): symbol: str asset_class: str = "stocks" # === Schemas: writes simple === class PlaceOrderReq(BaseModel): symbol: str side: str qty: float order_type: str = "market" limit_price: float | None = None stop_price: float | None = None tif: str = "day" asset_class: str = "stocks" sec_type: str | None = None exchange: str = "SMART" outside_rth: bool = False class AmendOrderReq(BaseModel): order_id: str qty: float | None = None limit_price: float | None = None stop_price: float | None = None tif: str | None = None class CancelOrderReq(BaseModel): order_id: str class CancelAllOrdersReq(BaseModel): pass class ClosePositionReq(BaseModel): symbol: str qty: float | None = None class CloseAllPositionsReq(BaseModel): pass # === Schemas: writes complex === class PlaceBracketOrderReq(BaseModel): symbol: str side: str qty: float entry_price: float stop_loss: float take_profit: float tif: str = "gtc" asset_class: str = "stocks" exchange: str = "SMART" class OrderLeg(BaseModel): symbol: str side: str qty: float order_type: str = "limit" limit_price: float | None = None stop_price: float | None = None tif: str = "gtc" asset_class: str = "stocks" class PlaceOcoOrderReq(BaseModel): legs: list[OrderLeg] class PlaceOtoOrderReq(BaseModel): trigger: OrderLeg child: OrderLeg # === Read tools === async def environment_info( client: IBKRClient, *, creds: dict, env_info: Any | None = None ) -> dict: return { "exchange": "ibkr", "environment": "testnet" if client.paper else "mainnet", "paper": client.paper, "base_url": client.base_url, "max_leverage": get_max_leverage(creds), } async def get_account(client: IBKRClient, params: GetAccountReq) -> dict: return await client.get_account() async def get_positions(client: IBKRClient, params: GetPositionsReq) -> dict: return {"positions": await client.get_positions()} async def get_open_orders(client: IBKRClient, params: GetOpenOrdersReq) -> dict: return {"orders": await client.get_open_orders()} async def get_activities(client: IBKRClient, params: GetActivitiesReq) -> dict: return {"activities": await client.get_activities(params.days)} async def get_ticker(client: IBKRClient, params: GetTickerReq) -> dict: return await client.get_ticker(params.symbol, params.asset_class) async def get_bars(client: IBKRClient, params: GetBarsReq) -> dict: return await client.get_bars( params.symbol, params.asset_class, params.period, params.bar, ) async def get_snapshot(client: IBKRClient, params: GetSnapshotReq) -> dict: return await client.get_ticker(params.symbol, params.asset_class) async def get_option_chain(client: IBKRClient, params: GetOptionChainReq) -> dict: return await client.get_option_chain(params.underlying, params.expiry) async def search_contracts(client: IBKRClient, params: SearchContractsReq) -> dict: return {"contracts": await client.search_contracts(params.symbol, params.sec_type)} async def get_clock(client: IBKRClient, params: GetClockReq) -> dict: import datetime as _dt now = _dt.datetime.now(_dt.UTC) return { "timestamp": now.isoformat(), "is_open": _dt.time(13, 30) <= now.time() <= _dt.time(20, 0) and now.weekday() < 5, } # === Streaming tools === def _sec_type_for(asset_class: str) -> str: return _SEC_TYPE_MAP.get(asset_class.lower(), "STK") async def get_tick( client: IBKRClient, params: GetTickReq, *, ws: IBKRWebSocket, timeout_s: float = 3.0, ) -> dict: sec = _sec_type_for(params.asset_class) conid = await client.resolve_conid(params.symbol, sec) snap = ws.get_tick_snapshot(conid) if snap: return {**snap, "symbol": params.symbol} await ws.subscribe_tick(conid) deadline = asyncio.get_event_loop().time() + timeout_s while asyncio.get_event_loop().time() < deadline: snap = ws.get_tick_snapshot(conid) if snap: return {**snap, "symbol": params.symbol} await asyncio.sleep(0.05) raise IBKRError(f"IBKR_TICK_TIMEOUT: {params.symbol}") async def get_depth( client: IBKRClient, params: GetDepthReq, *, ws: IBKRWebSocket, timeout_s: float = 3.0, ) -> dict: sec = _sec_type_for(params.asset_class) conid = await client.resolve_conid(params.symbol, sec) snap = ws.get_depth_snapshot(conid) if snap: return {**snap, "symbol": params.symbol} await ws.subscribe_depth(conid, exchange=params.exchange, rows=params.rows) deadline = asyncio.get_event_loop().time() + timeout_s while asyncio.get_event_loop().time() < deadline: snap = ws.get_depth_snapshot(conid) if snap: return {**snap, "symbol": params.symbol} await asyncio.sleep(0.05) raise IBKRError(f"IBKR_DEPTH_TIMEOUT: {params.symbol}") async def subscribe_tick( client: IBKRClient, params: SubscribeTickReq, *, ws: IBKRWebSocket, ) -> dict: sec = _sec_type_for(params.asset_class) conid = await client.resolve_conid(params.symbol, sec) await ws.subscribe_tick(conid, forced=True) return {"symbol": params.symbol, "conid": conid, "subscribed": True} async def unsubscribe( client: IBKRClient, params: UnsubscribeReq, *, ws: IBKRWebSocket, ) -> dict: sec = _sec_type_for(params.asset_class) conid = await client.resolve_conid(params.symbol, sec) await ws.unsubscribe(conid) return {"symbol": params.symbol, "conid": conid, "unsubscribed": True}