diff --git a/src/cerbero_mcp/exchanges/ibkr/tools.py b/src/cerbero_mcp/exchanges/ibkr/tools.py index fa6eaa0..2809337 100644 --- a/src/cerbero_mcp/exchanges/ibkr/tools.py +++ b/src/cerbero_mcp/exchanges/ibkr/tools.py @@ -4,10 +4,11 @@ from __future__ import annotations import asyncio from typing import Any +from fastapi import HTTPException from pydantic import BaseModel from cerbero_mcp.exchanges.ibkr.client import _SEC_TYPE_MAP, IBKRClient, IBKRError -from cerbero_mcp.exchanges.ibkr.leverage_cap import enforce_leverage, get_max_leverage # noqa: F401 +from cerbero_mcp.exchanges.ibkr.leverage_cap import get_max_leverage from cerbero_mcp.exchanges.ibkr.ws import IBKRWebSocket # === Schemas: reads === @@ -247,3 +248,84 @@ async def unsubscribe( conid = await client.resolve_conid(params.symbol, sec) await ws.unsubscribe(conid) return {"symbol": params.symbol, "conid": conid, "unsubscribed": True} + + +# === Write tools: simple === + + +async def place_order( + client: IBKRClient, params: PlaceOrderReq, + *, creds: dict, last_price: float | None = None, +) -> dict: + cap = get_max_leverage(creds) + if last_price is None: + try: + ticker = await client.get_ticker(params.symbol, params.asset_class) + last_price = ticker.get("last_price") or ticker.get("ask") + except Exception: + last_price = None + if last_price: + notional = params.qty * float(last_price) + try: + account = await client.get_account() + equity = float( + (account.get("netliquidation") or {}).get("amount") or 0 + ) + except Exception: + equity = 0.0 + if equity > 0 and notional / equity > cap: + raise HTTPException( + status_code=403, + detail={ + "error": "LEVERAGE_CAP_EXCEEDED", + "exchange": "ibkr", + "requested_ratio": notional / equity, + "max": cap, + }, + ) + + return await client.place_order( + symbol=params.symbol, + side=params.side, + qty=params.qty, + order_type=params.order_type, + limit_price=params.limit_price, + stop_price=params.stop_price, + tif=params.tif, + asset_class=params.asset_class, + sec_type=params.sec_type, + exchange=params.exchange, + outside_rth=params.outside_rth, + ) + + +async def amend_order(client: IBKRClient, params: AmendOrderReq) -> dict: + return await client.amend_order( + params.order_id, + qty=params.qty, + limit_price=params.limit_price, + stop_price=params.stop_price, + tif=params.tif, + ) + + +async def cancel_order(client: IBKRClient, params: CancelOrderReq) -> dict: + return await client.cancel_order(params.order_id) + + +async def cancel_all_orders( + client: IBKRClient, params: CancelAllOrdersReq +) -> dict: + return {"canceled": await client.cancel_all_orders()} + + +async def close_position( + client: IBKRClient, params: ClosePositionReq +) -> dict: + return await client.close_position(params.symbol, params.qty) + + +async def close_all_positions( + client: IBKRClient, params: CloseAllPositionsReq +) -> dict: + return {"closed": await client.close_all_positions()} diff --git a/tests/unit/exchanges/ibkr/test_tools.py b/tests/unit/exchanges/ibkr/test_tools.py index 15880f9..d439f83 100644 --- a/tests/unit/exchanges/ibkr/test_tools.py +++ b/tests/unit/exchanges/ibkr/test_tools.py @@ -45,3 +45,44 @@ async def test_get_tick_uses_cache_or_subscribes(): ) assert res["last_price"] == 99.5 ws.subscribe_tick.assert_awaited_once_with(42) + + +@pytest.mark.asyncio +async def test_place_order_enforces_leverage(): + client = MagicMock() + client.get_account = AsyncMock(return_value={ + "netliquidation": {"amount": 10000}, + }) + client.place_order = AsyncMock(return_value={"order_id": "O1"}) + creds = {"max_leverage": 2} + res = await t.place_order( + client, t.PlaceOrderReq(symbol="AAPL", side="buy", qty=10), + creds=creds, last_price=100.0, + ) + assert res["order_id"] == "O1" + + +@pytest.mark.asyncio +async def test_cancel_order_calls_client(): + client = MagicMock() + client.cancel_order = AsyncMock(return_value={"order_id": "O1", "canceled": True}) + res = await t.cancel_order(client, t.CancelOrderReq(order_id="O1")) + assert res["canceled"] is True + + +@pytest.mark.asyncio +async def test_place_order_rejects_excessive_leverage(): + from fastapi import HTTPException + client = MagicMock() + client.get_account = AsyncMock(return_value={ + "netliquidation": {"amount": 1000}, + }) + creds = {"max_leverage": 2} + # Order notional = 100*100 = 10000 vs equity 1000 → ratio 10x >> 2x cap → 403 + with pytest.raises(HTTPException) as exc_info: + await t.place_order( + client, t.PlaceOrderReq(symbol="AAPL", side="buy", qty=100), + creds=creds, last_price=100.0, + ) + assert exc_info.value.status_code == 403 + assert exc_info.value.detail["error"] == "LEVERAGE_CAP_EXCEEDED"