8914d613ec
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
250 lines
7.0 KiB
Python
250 lines
7.0 KiB
Python
"""IBKR tool functions: Pydantic schemas + async dispatch to client/ws."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import Any
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from cerbero_mcp.exchanges.ibkr.client import _SEC_TYPE_MAP, IBKRClient, IBKRError
|
|
from cerbero_mcp.exchanges.ibkr.leverage_cap import enforce_leverage, get_max_leverage # noqa: F401
|
|
from cerbero_mcp.exchanges.ibkr.ws import IBKRWebSocket
|
|
|
|
# === Schemas: reads ===
|
|
|
|
class GetAccountReq(BaseModel):
    # Request schema for get_account; the tool accepts no arguments.
    ...
|
|
|
|
class GetPositionsReq(BaseModel):
    # Request schema for get_positions; the tool accepts no arguments.
    ...
|
|
|
|
class GetOpenOrdersReq(BaseModel):
    # Request schema for get_open_orders; the tool accepts no arguments.
    ...
|
|
|
|
class GetActivitiesReq(BaseModel):
    # Request schema for get_activities.
    # `days` is forwarded as-is to IBKRClient.get_activities; presumably a
    # look-back window in days. It is not range-checked here — confirm
    # how the client treats zero/negative values.
    days: int = 7
|
|
|
|
class GetTickerReq(BaseModel):
    # Request schema for get_ticker.
    symbol: str
    # Asset-class label forwarded to IBKRClient.get_ticker (default "stocks").
    asset_class: str = "stocks"
|
|
|
|
class GetBarsReq(BaseModel):
    # Request schema for get_bars. All four fields are forwarded positionally
    # to IBKRClient.get_bars, which defines the accepted period/bar formats.
    symbol: str
    asset_class: str = "stocks"
    # Presumably the history span (e.g. "1d") — TODO confirm with the client.
    period: str = "1d"
    # Presumably the bar size (e.g. "5min") — TODO confirm with the client.
    bar: str = "5min"
|
|
|
|
class GetSnapshotReq(BaseModel):
    # Request schema for get_snapshot; same fields as GetTickerReq.
    symbol: str
    # Asset-class label forwarded to the client (default "stocks").
    asset_class: str = "stocks"
|
|
|
|
class GetOptionChainReq(BaseModel):
    # Request schema for get_option_chain.
    underlying: str
    # Optional expiry filter forwarded to IBKRClient.get_option_chain; None is
    # passed through unchanged. The accepted format is defined by the
    # client — confirm there.
    expiry: str | None = None
|
|
|
|
class SearchContractsReq(BaseModel):
    # Request schema for search_contracts.
    symbol: str
    # Raw IBKR security type (e.g. "STK"), passed straight to
    # IBKRClient.search_contracts. Unlike the `asset_class` fields elsewhere,
    # it is not translated through _SEC_TYPE_MAP.
    sec_type: str = "STK"
|
|
|
|
class GetClockReq(BaseModel):
    # Request schema for get_clock; the tool accepts no arguments.
    ...
|
|
|
|
# === Schemas: streaming ===
|
|
|
|
class GetTickReq(BaseModel):
    # Request schema for get_tick.
    symbol: str
    # Asset-class label mapped (case-insensitively) to an IBKR security type
    # via _SEC_TYPE_MAP; unknown labels fall back to "STK".
    asset_class: str = "stocks"
|
|
|
|
class GetDepthReq(BaseModel):
    # Request schema for get_depth.
    symbol: str
    # Mapped to an IBKR security type via _SEC_TYPE_MAP ("STK" if unknown).
    asset_class: str = "stocks"
    # `rows` and `exchange` are passed to IBKRWebSocket.subscribe_depth only
    # when a new depth subscription is opened; if a depth snapshot is already
    # cached it is returned as-is and these values are not applied.
    rows: int = 5
    exchange: str = "SMART"
|
|
|
|
class SubscribeTickReq(BaseModel):
    # Request schema for subscribe_tick.
    symbol: str
    # Mapped to an IBKR security type via _SEC_TYPE_MAP ("STK" if unknown).
    asset_class: str = "stocks"
|
|
|
|
class UnsubscribeReq(BaseModel):
    # Request schema for unsubscribe. The symbol is re-resolved to a conid,
    # so asset_class must match the one used when subscribing.
    symbol: str
    # Mapped to an IBKR security type via _SEC_TYPE_MAP ("STK" if unknown).
    asset_class: str = "stocks"
|
|
|
|
# === Schemas: writes simple ===
|
|
|
|
class PlaceOrderReq(BaseModel):
    # Request schema for placing a single order. No cross-field validation
    # happens here (e.g. limit_price is not required when order_type is
    # "limit"); presumably that is enforced by the placing code — confirm.
    symbol: str
    side: str  # free-form string, presumably "buy"/"sell"; not validated here
    qty: float
    order_type: str = "market"
    limit_price: float | None = None
    stop_price: float | None = None
    tif: str = "day"  # time-in-force label
    asset_class: str = "stocks"
    # Optional explicit IBKR security type; presumably overrides the
    # asset_class mapping when set — confirm with the handler.
    sec_type: str | None = None
    exchange: str = "SMART"
    # Presumably maps to IBKR's outsideRth order flag — confirm.
    outside_rth: bool = False
|
|
|
|
class AmendOrderReq(BaseModel):
    # Request schema for amending an existing order. The optional fields
    # default to None, presumably meaning "leave unchanged" — confirm with
    # the handler.
    order_id: str
    qty: float | None = None
    limit_price: float | None = None
    stop_price: float | None = None
    tif: str | None = None
|
|
|
|
class CancelOrderReq(BaseModel):
    # Request schema for cancelling a single order by its id.
    order_id: str
|
|
|
|
class CancelAllOrdersReq(BaseModel):
    # Request schema for cancelling every open order; takes no arguments.
    ...
|
|
|
|
class ClosePositionReq(BaseModel):
    # Request schema for closing a position in `symbol`.
    symbol: str
    # None presumably closes the whole position — confirm with the handler.
    qty: float | None = None
|
|
|
|
class CloseAllPositionsReq(BaseModel):
    # Request schema for closing every position; takes no arguments.
    ...
|
|
|
|
# === Schemas: writes complex ===
|
|
|
|
class PlaceBracketOrderReq(BaseModel):
    # Request schema for a bracket order: an entry price plus stop-loss and
    # take-profit prices. Price ordering relative to `side` (e.g. stop below
    # entry for a buy) is not validated here.
    symbol: str
    side: str
    qty: float
    entry_price: float
    stop_loss: float
    take_profit: float
    tif: str = "gtc"  # note: defaults to "gtc", unlike PlaceOrderReq ("day")
    asset_class: str = "stocks"
    exchange: str = "SMART"
|
|
|
|
class OrderLeg(BaseModel):
    # One order within an OCO/OTO group (used by PlaceOcoOrderReq.legs and
    # PlaceOtoOrderReq.trigger/child).
    symbol: str
    side: str
    qty: float
    order_type: str = "limit"  # defaults to "limit", unlike PlaceOrderReq
    limit_price: float | None = None
    stop_price: float | None = None
    tif: str = "gtc"
    asset_class: str = "stocks"
|
|
|
|
class PlaceOcoOrderReq(BaseModel):
    # Request schema for a one-cancels-other group. The number of legs is
    # not constrained here: an empty or single-leg list still validates.
    legs: list[OrderLeg]
|
|
|
|
class PlaceOtoOrderReq(BaseModel):
    # Request schema for a one-triggers-other pair; presumably `child` is
    # activated once `trigger` fills — confirm with the handler.
    trigger: OrderLeg
    child: OrderLeg
|
|
|
|
# === Read tools ===
|
|
|
|
async def environment_info(
    client: IBKRClient, *, creds: dict, env_info: Any | None = None
) -> dict:
    """Describe which IBKR environment ``client`` is talking to.

    A paper-trading client is reported as ``"testnet"``, a live one as
    ``"mainnet"``. ``max_leverage`` comes from ``get_max_leverage(creds)``.
    ``env_info`` is not read here (presumably kept so every exchange shares
    one signature).
    """
    is_paper = client.paper
    info = {
        "exchange": "ibkr",
        "environment": "testnet" if is_paper else "mainnet",
        "paper": is_paper,
        "base_url": client.base_url,
        "max_leverage": get_max_leverage(creds),
    }
    return info
|
|
|
|
async def get_account(client: IBKRClient, params: GetAccountReq) -> dict:
    """Return account data exactly as ``IBKRClient.get_account`` yields it.

    ``params`` has no fields and is not read.
    """
    account = await client.get_account()
    return account
|
|
|
|
async def get_positions(client: IBKRClient, params: GetPositionsReq) -> dict:
    """Return the client's positions wrapped under a ``"positions"`` key."""
    positions = await client.get_positions()
    return {"positions": positions}
|
|
|
|
async def get_open_orders(client: IBKRClient, params: GetOpenOrdersReq) -> dict:
    """Return the client's open orders wrapped under an ``"orders"`` key."""
    open_orders = await client.get_open_orders()
    return {"orders": open_orders}
|
|
|
|
async def get_activities(client: IBKRClient, params: GetActivitiesReq) -> dict:
    """Return account activities for ``params.days`` under an ``"activities"`` key."""
    window = params.days
    activities = await client.get_activities(window)
    return {"activities": activities}
|
|
|
|
async def get_ticker(client: IBKRClient, params: GetTickerReq) -> dict:
    """Return the client's ticker data for ``params.symbol``."""
    symbol, asset_class = params.symbol, params.asset_class
    ticker = await client.get_ticker(symbol, asset_class)
    return ticker
|
|
|
|
async def get_bars(client: IBKRClient, params: GetBarsReq) -> dict:
    """Return price bars for a symbol.

    ``period`` and ``bar`` are forwarded verbatim; their accepted formats are
    defined by ``IBKRClient.get_bars``.
    """
    call_args = (params.symbol, params.asset_class, params.period, params.bar)
    return await client.get_bars(*call_args)
|
|
|
|
async def get_snapshot(client: IBKRClient, params: GetSnapshotReq) -> dict:
    """Return a quote snapshot for ``params.symbol``.

    NOTE: this currently delegates to ``IBKRClient.get_ticker``, so its
    output is identical to ``get_ticker``'s.
    """
    symbol = params.symbol
    asset_class = params.asset_class
    snapshot = await client.get_ticker(symbol, asset_class)
    return snapshot
|
|
|
|
async def get_option_chain(client: IBKRClient, params: GetOptionChainReq) -> dict:
    """Return the option chain for ``params.underlying``, optionally filtered by expiry."""
    underlying, expiry = params.underlying, params.expiry
    chain = await client.get_option_chain(underlying, expiry)
    return chain
|
|
|
|
async def search_contracts(client: IBKRClient, params: SearchContractsReq) -> dict:
    """Search contracts by symbol and raw IBKR security type.

    Results are wrapped under a ``"contracts"`` key.
    """
    matches = await client.search_contracts(params.symbol, params.sec_type)
    return {"contracts": matches}
|
|
|
|
async def get_clock(client: IBKRClient, params: GetClockReq) -> dict:
|
|
import datetime as _dt
|
|
now = _dt.datetime.now(_dt.UTC)
|
|
return {
|
|
"timestamp": now.isoformat(),
|
|
"is_open": _dt.time(13, 30) <= now.time() <= _dt.time(20, 0)
|
|
and now.weekday() < 5,
|
|
}
|
|
|
|
|
|
# === Streaming tools ===
|
|
|
|
|
|
def _sec_type_for(asset_class: str) -> str:
    """Translate an asset-class label into an IBKR security type.

    The lookup in ``_SEC_TYPE_MAP`` is case-insensitive; unknown labels map
    to ``"STK"``.
    """
    normalized = asset_class.lower()
    return _SEC_TYPE_MAP.get(normalized, "STK")
|
|
|
|
|
|
async def get_tick(
    client: IBKRClient, params: GetTickReq,
    *, ws: IBKRWebSocket, timeout_s: float = 3.0,
) -> dict:
    """Return the latest tick snapshot for a symbol, subscribing on demand.

    If ``ws`` already holds a tick snapshot for the symbol's conid, it is
    returned immediately. Otherwise a tick subscription is opened and the
    snapshot is polled every 50 ms until data arrives or ``timeout_s``
    elapses. One final read is always made after the last wait.

    Returns:
        The snapshot dict with ``"symbol"`` set to ``params.symbol``.

    Raises:
        IBKRError: ``IBKR_TICK_TIMEOUT`` if no snapshot arrives in time.
    """
    sec = _sec_type_for(params.asset_class)
    conid = await client.resolve_conid(params.symbol, sec)
    snap = ws.get_tick_snapshot(conid)
    if snap:
        return {**snap, "symbol": params.symbol}
    await ws.subscribe_tick(conid)
    # get_running_loop() is the documented way to reach the loop from a coroutine.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout_s
    while True:
        snap = ws.get_tick_snapshot(conid)
        if snap:
            return {**snap, "symbol": params.symbol}
        remaining = deadline - loop.time()
        if remaining <= 0:
            break
        # Don't oversleep the deadline; the loop re-reads once after waking,
        # so data that lands during the last wait is not discarded.
        await asyncio.sleep(min(0.05, remaining))
    raise IBKRError(f"IBKR_TICK_TIMEOUT: {params.symbol}")
|
|
|
|
|
|
async def get_depth(
    client: IBKRClient, params: GetDepthReq,
    *, ws: IBKRWebSocket, timeout_s: float = 3.0,
) -> dict:
    """Return the latest depth snapshot for a symbol, subscribing on demand.

    If ``ws`` already holds a depth snapshot for the symbol's conid, it is
    returned immediately and ``params.rows`` / ``params.exchange`` are not
    applied. Otherwise a depth subscription is opened with those values, and
    the snapshot is polled every 50 ms until data arrives or ``timeout_s``
    elapses. One final read is always made after the last wait.

    Returns:
        The snapshot dict with ``"symbol"`` set to ``params.symbol``.

    Raises:
        IBKRError: ``IBKR_DEPTH_TIMEOUT`` if no snapshot arrives in time.
    """
    sec = _sec_type_for(params.asset_class)
    conid = await client.resolve_conid(params.symbol, sec)
    snap = ws.get_depth_snapshot(conid)
    if snap:
        return {**snap, "symbol": params.symbol}
    await ws.subscribe_depth(conid, exchange=params.exchange, rows=params.rows)
    # get_running_loop() is the documented way to reach the loop from a coroutine.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout_s
    while True:
        snap = ws.get_depth_snapshot(conid)
        if snap:
            return {**snap, "symbol": params.symbol}
        remaining = deadline - loop.time()
        if remaining <= 0:
            break
        # Don't oversleep the deadline; the loop re-reads once after waking,
        # so data that lands during the last wait is not discarded.
        await asyncio.sleep(min(0.05, remaining))
    raise IBKRError(f"IBKR_DEPTH_TIMEOUT: {params.symbol}")
|
|
|
|
|
|
async def subscribe_tick(
    client: IBKRClient, params: SubscribeTickReq, *, ws: IBKRWebSocket,
) -> dict:
    """Open a tick subscription (``forced=True``) and report the resolved conid.

    The meaning of ``forced`` is defined by ``IBKRWebSocket.subscribe_tick``.
    """
    conid = await client.resolve_conid(
        params.symbol, _sec_type_for(params.asset_class),
    )
    await ws.subscribe_tick(conid, forced=True)
    return dict(symbol=params.symbol, conid=conid, subscribed=True)
|
|
|
|
|
|
async def unsubscribe(
    client: IBKRClient, params: UnsubscribeReq, *, ws: IBKRWebSocket,
) -> dict:
    """Drop the websocket subscription for a symbol's conid and report it.

    The conid is resolved again from ``symbol`` and ``asset_class``, so these
    must match the values used when subscribing.
    """
    conid = await client.resolve_conid(
        params.symbol, _sec_type_for(params.asset_class),
    )
    await ws.unsubscribe(conid)
    return dict(symbol=params.symbol, conid=conid, unsubscribed=True)
|