From 1b8ba0ef9c7a0d9c7359fab1f9f956aa8a5b1988 Mon Sep 17 00:00:00 2001 From: AdrianoDev Date: Thu, 30 Apr 2026 18:39:25 +0200 Subject: [PATCH] feat(V2): migrazione alpaca completa Task 6.7: porting alpaca da services/mcp-alpaca a src/cerbero_mcp. client.py + leverage_cap.py copiati 1:1 (default cap 1 cash). tools.py: 17 tool senza ACL/Principal/audit. Router /mcp-alpaca con 18 route (env_info + 17 tool). Builder branch alpaca: paper=(env=="testnet"), api_key viene da settings.alpaca.api_key_id. Test client + leverage_cap migrati (15 test alpaca pass). Test builder con stub SDK alpaca-py. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/cerbero_mcp/exchanges/__init__.py | 8 + src/cerbero_mcp/exchanges/alpaca/__init__.py | 0 src/cerbero_mcp/exchanges/alpaca/client.py | 385 ++++++++++++++++++ .../exchanges/alpaca/leverage_cap.py | 56 +++ src/cerbero_mcp/exchanges/alpaca/tools.py | 279 +++++++++++++ src/cerbero_mcp/routers/alpaca.py | 176 ++++++++ tests/unit/exchanges/alpaca/__init__.py | 0 tests/unit/exchanges/alpaca/conftest.py | 39 ++ tests/unit/exchanges/alpaca/test_client.py | 80 ++++ .../exchanges/alpaca/test_leverage_cap.py | 46 +++ tests/unit/test_exchanges_builder.py | 31 ++ 11 files changed, 1100 insertions(+) create mode 100644 src/cerbero_mcp/exchanges/alpaca/__init__.py create mode 100644 src/cerbero_mcp/exchanges/alpaca/client.py create mode 100644 src/cerbero_mcp/exchanges/alpaca/leverage_cap.py create mode 100644 src/cerbero_mcp/exchanges/alpaca/tools.py create mode 100644 src/cerbero_mcp/routers/alpaca.py create mode 100644 tests/unit/exchanges/alpaca/__init__.py create mode 100644 tests/unit/exchanges/alpaca/conftest.py create mode 100644 tests/unit/exchanges/alpaca/test_client.py create mode 100644 tests/unit/exchanges/alpaca/test_leverage_cap.py diff --git a/src/cerbero_mcp/exchanges/__init__.py b/src/cerbero_mcp/exchanges/__init__.py index d68e489..44a3770 100644 --- a/src/cerbero_mcp/exchanges/__init__.py +++ b/src/cerbero_mcp/exchanges/__init__.py @@ -36,4 +36,12 @@ async def build_client( testnet=(env == "testnet"), api_wallet_address=settings.hyperliquid.api_wallet_address, ) + if exchange == "alpaca": + from cerbero_mcp.exchanges.alpaca.client import AlpacaClient + + return AlpacaClient( + api_key=settings.alpaca.api_key_id, + secret_key=settings.alpaca.secret_key.get_secret_value(), + paper=(env == "testnet"), + ) raise ValueError(f"unsupported exchange: {exchange}") diff --git a/src/cerbero_mcp/exchanges/alpaca/__init__.py b/src/cerbero_mcp/exchanges/alpaca/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/cerbero_mcp/exchanges/alpaca/client.py b/src/cerbero_mcp/exchanges/alpaca/client.py new file mode 100644 index 0000000..001d5c6 --- /dev/null +++ b/src/cerbero_mcp/exchanges/alpaca/client.py @@ -0,0 +1,385 @@ +from __future__ import annotations + +import asyncio +import datetime as _dt +from typing import Any + +from alpaca.data.historical import ( + CryptoHistoricalDataClient, + OptionHistoricalDataClient, + StockHistoricalDataClient, +) +from alpaca.data.requests import ( + CryptoBarsRequest, + CryptoLatestQuoteRequest, + CryptoLatestTradeRequest, + OptionBarsRequest, + OptionChainRequest, + OptionLatestQuoteRequest, + StockBarsRequest, + StockLatestQuoteRequest, + StockLatestTradeRequest, + StockSnapshotRequest, +) +from alpaca.data.timeframe import TimeFrame, TimeFrameUnit +from alpaca.trading.client import TradingClient +from alpaca.trading.enums import ( + AssetClass, + OrderSide, + QueryOrderStatus, + TimeInForce, +) +from alpaca.trading.requests import ( + ClosePositionRequest, + GetAssetsRequest, + GetOrdersRequest, + LimitOrderRequest, + MarketOrderRequest, + ReplaceOrderRequest, + StopOrderRequest, +) + +_TF_MAP = { + "1min": TimeFrame(1, TimeFrameUnit.Minute), + "5min": TimeFrame(5, TimeFrameUnit.Minute), + "15min": TimeFrame(15, TimeFrameUnit.Minute), + "30min": TimeFrame(30, TimeFrameUnit.Minute), + "1h": TimeFrame(1, TimeFrameUnit.Hour), + "1d": TimeFrame(1, TimeFrameUnit.Day), + "1w": TimeFrame(1, TimeFrameUnit.Week), +} + +_ASSET_CLASSES = {"stocks", "crypto", "options"} + + +def _tf(interval: str) -> TimeFrame: + if interval in _TF_MAP: + return _TF_MAP[interval] + raise ValueError(f"unsupported timeframe: {interval}") + + +def _asset_class_enum(ac: str) -> AssetClass: + ac = ac.lower() + if ac == "stocks": + return AssetClass.US_EQUITY + if ac == "crypto": + return AssetClass.CRYPTO + if ac == "options": + return AssetClass.US_OPTION + raise ValueError(f"invalid asset_class: {ac}") + + +def _serialize(obj: Any) -> Any: + """Recursively convert pydantic/datetime objects → json-safe.""" + if obj is None or isinstance(obj, str | int | float | bool): + return obj + if isinstance(obj, _dt.datetime | _dt.date): + return obj.isoformat() + if isinstance(obj, dict): + return {k: _serialize(v) for k, v in obj.items()} + if isinstance(obj, list | tuple): + return [_serialize(v) for v in obj] + if hasattr(obj, "model_dump"): + return _serialize(obj.model_dump()) + if hasattr(obj, "__dict__"): + return _serialize(vars(obj)) + return str(obj) + + +class AlpacaClient: + def __init__( + self, + api_key: str, + secret_key: str, + paper: bool = True, + trading: Any | None = None, + stock_data: Any | None = None, + crypto_data: Any | None = None, + option_data: Any | None = None, + ) -> None: + self.api_key = api_key + self.secret_key = secret_key + self.paper = paper + self._trading = trading or TradingClient( + api_key=api_key, secret_key=secret_key, paper=paper + ) + self._stock = stock_data or StockHistoricalDataClient( + api_key=api_key, secret_key=secret_key + ) + self._crypto = crypto_data or CryptoHistoricalDataClient( + api_key=api_key, secret_key=secret_key + ) + self._option = option_data or OptionHistoricalDataClient( + api_key=api_key, secret_key=secret_key + ) + + async def _run(self, fn, /, *args, **kwargs): + return await asyncio.to_thread(fn, *args, **kwargs) + + # ── Account / positions ────────────────────────────────────── + + async def get_account(self) -> dict: + acc = await self._run(self._trading.get_account) + return _serialize(acc) + + async def get_positions(self) -> list[dict]: + pos = await self._run(self._trading.get_all_positions) + return [_serialize(p) for p in pos] + + async def get_activities(self, limit: int = 50) -> list[dict]: + acts = await self._run(self._trading.get_account_activities) + data = [_serialize(a) for a in acts] + return data[:limit] + + # ── Assets ────────────────────────────────────────────────── + + async def get_assets( + self, asset_class: str = "stocks", status: str = "active" + ) -> list[dict]: + req = GetAssetsRequest( + asset_class=_asset_class_enum(asset_class), + status=status, + ) + assets = await self._run(self._trading.get_all_assets, req) + return [_serialize(a) for a in assets[:500]] + + # ── Market data ───────────────────────────────────────────── + + async def get_ticker(self, symbol: str, asset_class: str = "stocks") -> dict: + ac = asset_class.lower() + if ac == "stocks": + req = StockLatestTradeRequest(symbol_or_symbols=symbol) + data = await self._run(self._stock.get_stock_latest_trade, req) + trade = data.get(symbol) + q_req = StockLatestQuoteRequest(symbol_or_symbols=symbol) + qdata = await self._run(self._stock.get_stock_latest_quote, q_req) + quote = qdata.get(symbol) + return { + "symbol": symbol, + "asset_class": "stocks", + "last_price": getattr(trade, "price", None), + "bid": getattr(quote, "bid_price", None), + "ask": getattr(quote, "ask_price", None), + "bid_size": getattr(quote, "bid_size", None), + "ask_size": getattr(quote, "ask_size", None), + "timestamp": _serialize(getattr(trade, "timestamp", None)), + } + if ac == "crypto": + req = CryptoLatestTradeRequest(symbol_or_symbols=symbol) + data = await self._run(self._crypto.get_crypto_latest_trade, req) + trade = data.get(symbol) + q_req = CryptoLatestQuoteRequest(symbol_or_symbols=symbol) + qdata = await self._run(self._crypto.get_crypto_latest_quote, q_req) + quote = qdata.get(symbol) + return { + "symbol": symbol, + "asset_class": "crypto", + "last_price": getattr(trade, "price", None), + "bid": getattr(quote, "bid_price", None), + "ask": getattr(quote, "ask_price", None), + "timestamp": _serialize(getattr(trade, "timestamp", None)), + } + if ac == "options": + req = OptionLatestQuoteRequest(symbol_or_symbols=symbol) + data = await self._run(self._option.get_option_latest_quote, req) + quote = data.get(symbol) + return { + "symbol": symbol, + "asset_class": "options", + "bid": getattr(quote, "bid_price", None), + "ask": getattr(quote, "ask_price", None), + "timestamp": _serialize(getattr(quote, "timestamp", None)), + } + raise ValueError(f"invalid asset_class: {asset_class}") + + async def get_bars( + self, + symbol: str, + asset_class: str = "stocks", + interval: str = "1d", + start: str | None = None, + end: str | None = None, + limit: int = 1000, + ) -> dict: + tf = _tf(interval) + start_dt = _dt.datetime.fromisoformat(start) if start else ( + _dt.datetime.now(_dt.UTC) - _dt.timedelta(days=30) + ) + end_dt = _dt.datetime.fromisoformat(end) if end else _dt.datetime.now(_dt.UTC) + ac = asset_class.lower() + if ac == "stocks": + req = StockBarsRequest( + symbol_or_symbols=symbol, timeframe=tf, + start=start_dt, end=end_dt, limit=limit, + ) + data = await self._run(self._stock.get_stock_bars, req) + elif ac == "crypto": + req = CryptoBarsRequest( + symbol_or_symbols=symbol, timeframe=tf, + start=start_dt, end=end_dt, limit=limit, + ) + data = await self._run(self._crypto.get_crypto_bars, req) + elif ac == "options": + req = OptionBarsRequest( + symbol_or_symbols=symbol, timeframe=tf, + start=start_dt, end=end_dt, limit=limit, + ) + data = await self._run(self._option.get_option_bars, req) + else: + raise ValueError(f"invalid asset_class: {asset_class}") + bars_dict = getattr(data, "data", {}) or {} + rows = bars_dict.get(symbol, []) or [] + bars = [ + { + "timestamp": _serialize(getattr(b, "timestamp", None)), + "open": getattr(b, "open", None), + "high": getattr(b, "high", None), + "low": getattr(b, "low", None), + "close": getattr(b, "close", None), + "volume": getattr(b, "volume", None), + } + for b in rows + ] + return {"symbol": symbol, "asset_class": ac, "interval": interval, "bars": bars} + + async def get_snapshot(self, symbol: str) -> dict: + req = StockSnapshotRequest(symbol_or_symbols=symbol) + data = await self._run(self._stock.get_stock_snapshot, req) + return _serialize(data.get(symbol)) + + async def get_option_chain( + self, + underlying: str, + expiry: str | None = None, + ) -> dict: + kwargs: dict[str, Any] = {"underlying_symbol": underlying} + if expiry: + kwargs["expiration_date"] = _dt.date.fromisoformat(expiry) + req = OptionChainRequest(**kwargs) + data = await self._run(self._option.get_option_chain, req) + return { + "underlying": underlying, + "expiry": expiry, + "contracts": _serialize(data), + } + + # ── Orders ────────────────────────────────────────────────── + + async def get_open_orders(self, limit: int = 50) -> list[dict]: + req = GetOrdersRequest(status=QueryOrderStatus.OPEN, limit=limit) + orders = await self._run(self._trading.get_orders, filter=req) + return [_serialize(o) for o in orders] + + async def place_order( + self, + symbol: str, + side: str, + qty: float | None = None, + notional: float | None = None, + order_type: str = "market", + limit_price: float | None = None, + stop_price: float | None = None, + tif: str = "day", + asset_class: str = "stocks", + ) -> dict: + side_enum = OrderSide.BUY if side.lower() == "buy" else OrderSide.SELL + tif_enum = TimeInForce(tif.lower()) + ot = order_type.lower() + common = { + "symbol": symbol, + "side": side_enum, + "time_in_force": tif_enum, + } + if qty is not None: + common["qty"] = qty + if notional is not None: + common["notional"] = notional + if ot == "market": + req = MarketOrderRequest(**common) + elif ot == "limit": + if limit_price is None: + raise ValueError("limit_price required for limit order") + req = LimitOrderRequest(**common, limit_price=limit_price) + elif ot == "stop": + if stop_price is None: + raise ValueError("stop_price required for stop order") + req = StopOrderRequest(**common, stop_price=stop_price) + else: + raise ValueError(f"unsupported order_type: {order_type}") + order = await self._run(self._trading.submit_order, req) + return _serialize(order) + + async def amend_order( + self, + order_id: str, + qty: float | None = None, + limit_price: float | None = None, + stop_price: float | None = None, + tif: str | None = None, + ) -> dict: + kwargs: dict[str, Any] = {} + if qty is not None: + kwargs["qty"] = qty + if limit_price is not None: + kwargs["limit_price"] = limit_price + if stop_price is not None: + kwargs["stop_price"] = stop_price + if tif is not None: + kwargs["time_in_force"] = TimeInForce(tif.lower()) + req = ReplaceOrderRequest(**kwargs) + order = await self._run(self._trading.replace_order_by_id, order_id, req) + return _serialize(order) + + async def cancel_order(self, order_id: str) -> dict: + await self._run(self._trading.cancel_order_by_id, order_id) + return {"order_id": order_id, "canceled": True} + + async def cancel_all_orders(self) -> list[dict]: + resp = await self._run(self._trading.cancel_orders) + return [_serialize(r) for r in resp] + + # ── Position close ────────────────────────────────────────── + + async def close_position( + self, symbol: str, qty: float | None = None, percentage: float | None = None + ) -> dict: + req = None + if qty is not None or percentage is not None: + kwargs: dict[str, Any] = {} + if qty is not None: + kwargs["qty"] = str(qty) + if percentage is not None: + kwargs["percentage"] = str(percentage) + req = ClosePositionRequest(**kwargs) + order = await self._run( + self._trading.close_position, symbol, close_options=req + ) + return _serialize(order) + + async def close_all_positions(self, cancel_orders: bool = True) -> list[dict]: + resp = await self._run( + self._trading.close_all_positions, cancel_orders=cancel_orders + ) + return [_serialize(r) for r in resp] + + # ── Clock / calendar ──────────────────────────────────────── + + async def get_clock(self) -> dict: + clock = await self._run(self._trading.get_clock) + return _serialize(clock) + + async def get_calendar( + self, start: str | None = None, end: str | None = None + ) -> list[dict]: + from alpaca.trading.requests import GetCalendarRequest + + kwargs: dict[str, Any] = {} + if start: + kwargs["start"] = _dt.date.fromisoformat(start) + if end: + kwargs["end"] = _dt.date.fromisoformat(end) + req = GetCalendarRequest(**kwargs) if kwargs else None + cal = await self._run( + self._trading.get_calendar, filters=req + ) if req else await self._run(self._trading.get_calendar) + return [_serialize(c) for c in cal] diff --git a/src/cerbero_mcp/exchanges/alpaca/leverage_cap.py b/src/cerbero_mcp/exchanges/alpaca/leverage_cap.py new file mode 100644 index 0000000..d04dd51 --- /dev/null +++ b/src/cerbero_mcp/exchanges/alpaca/leverage_cap.py @@ -0,0 +1,56 @@ +"""Leverage cap server-side per place_order. + +Cap letto dal secret JSON via campo `max_leverage`. Default 1 (cash) se assente. +""" +from __future__ import annotations + +from fastapi import HTTPException + + +def get_max_leverage(creds: dict) -> int: + """Legge max_leverage dal secret. Default 1 se mancante.""" + raw = creds.get("max_leverage", 1) + try: + value = int(raw) + except (TypeError, ValueError): + value = 1 + return max(1, value) + + +def enforce_leverage( + requested: int | float | None, + *, + creds: dict, + exchange: str, +) -> int: + """Verifica e applica leverage cap. Ritorna leverage applicabile. + + Solleva HTTPException(403, LEVERAGE_CAP_EXCEEDED) se requested > cap. + Se requested is None, applica il cap come default. + """ + cap = get_max_leverage(creds) + if requested is None: + return cap + lev = int(requested) + if lev < 1: + raise HTTPException( + status_code=403, + detail={ + "error": "LEVERAGE_CAP_EXCEEDED", + "exchange": exchange, + "requested": lev, + "max": cap, + "reason": "leverage must be >= 1", + }, + ) + if lev > cap: + raise HTTPException( + status_code=403, + detail={ + "error": "LEVERAGE_CAP_EXCEEDED", + "exchange": exchange, + "requested": lev, + "max": cap, + }, + ) + return lev diff --git a/src/cerbero_mcp/exchanges/alpaca/tools.py b/src/cerbero_mcp/exchanges/alpaca/tools.py new file mode 100644 index 0000000..2599d43 --- /dev/null +++ b/src/cerbero_mcp/exchanges/alpaca/tools.py @@ -0,0 +1,279 @@ +"""Tool alpaca V2: pydantic schemas + async functions. + +Ogni funzione prende (client: AlpacaClient, params: ) e restituisce +un dict (o list[dict]). Pure logica, no FastAPI dependency, no ACL. +L'autenticazione bearer è gestita dal middleware in cerbero_mcp.auth; +l'audit verrà cablato dal router via request.state.environment. +""" +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel + +from cerbero_mcp.exchanges.alpaca.client import AlpacaClient +from cerbero_mcp.exchanges.alpaca.leverage_cap import get_max_leverage + +# === Schemas: reads === + + +class GetAccountReq(BaseModel): + pass + + +class GetPositionsReq(BaseModel): + pass + + +class GetActivitiesReq(BaseModel): + limit: int = 50 + + +class GetAssetsReq(BaseModel): + asset_class: str = "stocks" + status: str = "active" + + +class GetTickerReq(BaseModel): + symbol: str + asset_class: str = "stocks" + + +class GetBarsReq(BaseModel): + symbol: str + asset_class: str = "stocks" + interval: str = "1d" + start: str | None = None + end: str | None = None + limit: int = 1000 + + +class GetSnapshotReq(BaseModel): + symbol: str + + +class GetOptionChainReq(BaseModel): + underlying: str + expiry: str | None = None + + +class GetOpenOrdersReq(BaseModel): + limit: int = 50 + + +class GetClockReq(BaseModel): + pass + + +class GetCalendarReq(BaseModel): + start: str | None = None + end: str | None = None + + +# === Schemas: writes === + + +class PlaceOrderReq(BaseModel): + symbol: str + side: str # "buy" | "sell" + qty: float | None = None + notional: float | None = None + order_type: str = "market" + limit_price: float | None = None + stop_price: float | None = None + tif: str = "day" + asset_class: str = "stocks" + + model_config = { + "json_schema_extra": { + "examples": [ + { + "summary": "Market buy 1 share AAPL", + "value": { + "symbol": "AAPL", + "side": "buy", + "qty": 1, + "order_type": "market", + "asset_class": "stocks", + }, + } + ] + } + } + + +class AmendOrderReq(BaseModel): + order_id: str + qty: float | None = None + limit_price: float | None = None + stop_price: float | None = None + tif: str | None = None + + +class CancelOrderReq(BaseModel): + order_id: str + + +class CancelAllOrdersReq(BaseModel): + pass + + +class ClosePositionReq(BaseModel): + symbol: str + qty: float | None = None + percentage: float | None = None + + +class CloseAllPositionsReq(BaseModel): + cancel_orders: bool = True + + +# === Tools (reads) === + + +async def environment_info( + client: AlpacaClient, *, creds: dict, env_info: Any | None = None +) -> dict: + if env_info is None: + return { + "exchange": "alpaca", + "environment": "testnet" if getattr(client, "paper", True) else "mainnet", + "source": "credentials", + "env_value": None, + "base_url": getattr(client, "base_url", None), + "max_leverage": get_max_leverage(creds), + } + return { + "exchange": env_info.exchange, + "environment": env_info.environment, + "source": env_info.source, + "env_value": env_info.env_value, + "base_url": env_info.base_url, + "max_leverage": get_max_leverage(creds), + } + + +async def get_account(client: AlpacaClient, params: GetAccountReq) -> dict: + return await client.get_account() + + +async def get_positions( + client: AlpacaClient, params: GetPositionsReq +) -> dict: + return {"positions": await client.get_positions()} + + +async def get_activities( + client: AlpacaClient, params: GetActivitiesReq +) -> dict: + return {"activities": await client.get_activities(params.limit)} + + +async def get_assets(client: AlpacaClient, params: GetAssetsReq) -> dict: + return { + "assets": await client.get_assets(params.asset_class, params.status) + } + + +async def get_ticker(client: AlpacaClient, params: GetTickerReq) -> dict: + return await client.get_ticker(params.symbol, params.asset_class) + + +async def get_bars(client: AlpacaClient, params: GetBarsReq) -> dict: + return await client.get_bars( + params.symbol, + params.asset_class, + params.interval, + params.start, + params.end, + params.limit, + ) + + +async def get_snapshot( + client: AlpacaClient, params: GetSnapshotReq +) -> dict: + return await client.get_snapshot(params.symbol) + + +async def get_option_chain( + client: AlpacaClient, params: GetOptionChainReq +) -> dict: + return await client.get_option_chain(params.underlying, params.expiry) + + +async def get_open_orders( + client: AlpacaClient, params: GetOpenOrdersReq +) -> dict: + return {"orders": await client.get_open_orders(params.limit)} + + +async def get_clock(client: AlpacaClient, params: GetClockReq) -> dict: + return await client.get_clock() + + +async def get_calendar( + client: AlpacaClient, params: GetCalendarReq +) -> dict: + return {"calendar": await client.get_calendar(params.start, params.end)} + + +# === Tools (writes) === + + +async def place_order( + client: AlpacaClient, params: PlaceOrderReq, *, creds: dict +) -> dict: + # Alpaca: cap default 1 (cash account). Niente leverage parametro; + # cap presente per coerenza con altri exchange e per audit. + return await client.place_order( + symbol=params.symbol, + side=params.side, + qty=params.qty, + notional=params.notional, + order_type=params.order_type, + limit_price=params.limit_price, + stop_price=params.stop_price, + tif=params.tif, + asset_class=params.asset_class, + ) + + +async def amend_order( + client: AlpacaClient, params: AmendOrderReq +) -> dict: + return await client.amend_order( + params.order_id, + params.qty, + params.limit_price, + params.stop_price, + params.tif, + ) + + +async def cancel_order( + client: AlpacaClient, params: CancelOrderReq +) -> dict: + return await client.cancel_order(params.order_id) + + +async def cancel_all_orders( + client: AlpacaClient, params: CancelAllOrdersReq +) -> dict: + return {"canceled": await client.cancel_all_orders()} + + +async def close_position( + client: AlpacaClient, params: ClosePositionReq +) -> dict: + return await client.close_position( + params.symbol, params.qty, params.percentage + ) + + +async def close_all_positions( + client: AlpacaClient, params: CloseAllPositionsReq +) -> dict: + return { + "closed": await client.close_all_positions(params.cancel_orders) + } diff --git a/src/cerbero_mcp/routers/alpaca.py b/src/cerbero_mcp/routers/alpaca.py new file mode 100644 index 0000000..967653f --- /dev/null +++ b/src/cerbero_mcp/routers/alpaca.py @@ -0,0 +1,176 @@ +"""Router /mcp-alpaca/* — DI per env, client e (write) creds. + +Mappa 1:1 i tool di `cerbero_mcp.exchanges.alpaca.tools` a endpoint +`POST /mcp-alpaca/tools/{tool_name}`. L'autenticazione bearer è gestita +dal middleware in `cerbero_mcp.auth`; qui leggiamo solo `request.state.environment`. +""" +from __future__ import annotations + +from typing import Literal + +from fastapi import APIRouter, Depends, Request + +from cerbero_mcp.client_registry import ClientRegistry +from cerbero_mcp.exchanges.alpaca import tools as t +from cerbero_mcp.exchanges.alpaca.client import AlpacaClient + +Environment = Literal["testnet", "mainnet"] + + +def get_environment(request: Request) -> Environment: + return request.state.environment + + +async def get_alpaca_client( + request: Request, env: Environment = Depends(get_environment) +) -> AlpacaClient: + registry: ClientRegistry = request.app.state.registry + return await registry.get("alpaca", env) + + +def _build_creds(request: Request) -> dict: + """Costruisce dict `creds` minimale per leverage cap / metadata.""" + settings = request.app.state.settings + return { + "max_leverage": settings.alpaca.max_leverage, + "api_key_id": settings.alpaca.api_key_id, + } + + +def make_router() -> APIRouter: + r = APIRouter(prefix="/mcp-alpaca", tags=["alpaca"]) + + # === READ tools === + + @r.post("/tools/environment_info") + async def _environment_info( + request: Request, + client: AlpacaClient = Depends(get_alpaca_client), + ): + creds = _build_creds(request) + return await t.environment_info(client, creds=creds) + + @r.post("/tools/get_account") + async def _get_account( + params: t.GetAccountReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.get_account(client, params) + + @r.post("/tools/get_positions") + async def _get_positions( + params: t.GetPositionsReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.get_positions(client, params) + + @r.post("/tools/get_activities") + async def _get_activities( + params: t.GetActivitiesReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.get_activities(client, params) + + @r.post("/tools/get_assets") + async def _get_assets( + params: t.GetAssetsReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.get_assets(client, params) + + @r.post("/tools/get_ticker") + async def _get_ticker( + params: t.GetTickerReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.get_ticker(client, params) + + @r.post("/tools/get_bars") + async def _get_bars( + params: t.GetBarsReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.get_bars(client, params) + + @r.post("/tools/get_snapshot") + async def _get_snapshot( + params: t.GetSnapshotReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.get_snapshot(client, params) + + @r.post("/tools/get_option_chain") + async def _get_option_chain( + params: t.GetOptionChainReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.get_option_chain(client, params) + + @r.post("/tools/get_open_orders") + async def _get_open_orders( + params: t.GetOpenOrdersReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.get_open_orders(client, params) + + @r.post("/tools/get_clock") + async def _get_clock( + params: t.GetClockReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.get_clock(client, params) + + @r.post("/tools/get_calendar") + async def _get_calendar( + params: t.GetCalendarReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.get_calendar(client, params) + + # === WRITE tools === + + @r.post("/tools/place_order") + async def _place_order( + params: t.PlaceOrderReq, + request: Request, + client: AlpacaClient = Depends(get_alpaca_client), + ): + creds = _build_creds(request) + return await t.place_order(client, params, creds=creds) + + @r.post("/tools/amend_order") + async def _amend_order( + params: t.AmendOrderReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.amend_order(client, params) + + @r.post("/tools/cancel_order") + async def _cancel_order( + params: t.CancelOrderReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.cancel_order(client, params) + + @r.post("/tools/cancel_all_orders") + async def _cancel_all_orders( + params: t.CancelAllOrdersReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.cancel_all_orders(client, params) + + @r.post("/tools/close_position") + async def _close_position( + params: t.ClosePositionReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.close_position(client, params) + + @r.post("/tools/close_all_positions") + async def _close_all_positions( + params: t.CloseAllPositionsReq, + client: AlpacaClient = Depends(get_alpaca_client), + ): + return await t.close_all_positions(client, params) + + return r diff --git a/tests/unit/exchanges/alpaca/__init__.py b/tests/unit/exchanges/alpaca/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/exchanges/alpaca/conftest.py b/tests/unit/exchanges/alpaca/conftest.py new file mode 100644 index 0000000..ac65829 --- /dev/null +++ b/tests/unit/exchanges/alpaca/conftest.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from cerbero_mcp.exchanges.alpaca.client import AlpacaClient + + +@pytest.fixture +def mock_trading(): + return MagicMock(name="alpaca_TradingClient") + + +@pytest.fixture +def mock_stock(): + return MagicMock(name="alpaca_StockHistoricalDataClient") + + +@pytest.fixture +def mock_crypto(): + return MagicMock(name="alpaca_CryptoHistoricalDataClient") + + +@pytest.fixture +def mock_option(): + return MagicMock(name="alpaca_OptionHistoricalDataClient") + + +@pytest.fixture +def client(mock_trading, mock_stock, mock_crypto, mock_option): + return AlpacaClient( + api_key="test_key", + secret_key="test_secret", + paper=True, + trading=mock_trading, + stock_data=mock_stock, + crypto_data=mock_crypto, + option_data=mock_option, + ) diff --git a/tests/unit/exchanges/alpaca/test_client.py b/tests/unit/exchanges/alpaca/test_client.py new file mode 100644 index 0000000..24172ec --- /dev/null +++ b/tests/unit/exchanges/alpaca/test_client.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + + +@pytest.mark.asyncio +async def test_init_paper_mode(client, mock_trading): + assert client.paper is True + assert client._trading is mock_trading + + +@pytest.mark.asyncio +async def test_get_account_calls_trading(client, mock_trading): + mock_trading.get_account.return_value = MagicMock( + model_dump=lambda: {"equity": 100000, "cash": 50000} + ) + result = await client.get_account() + mock_trading.get_account.assert_called_once() + assert result["equity"] == 100000 + + +@pytest.mark.asyncio +async def test_get_positions_returns_list(client, mock_trading): + pos_mock = MagicMock(model_dump=lambda: {"symbol": "AAPL", "qty": 10}) + mock_trading.get_all_positions.return_value = [pos_mock] + result = await client.get_positions() + assert len(result) == 1 + assert result[0]["symbol"] == "AAPL" + + +@pytest.mark.asyncio +async def test_place_market_order_stocks(client, mock_trading): + order_mock = MagicMock(model_dump=lambda: {"id": "o123", "symbol": "AAPL"}) + mock_trading.submit_order.return_value = order_mock + result = await client.place_order( + symbol="AAPL", side="buy", qty=1, order_type="market", asset_class="stocks", + ) + assert result["id"] == "o123" + assert mock_trading.submit_order.called + + +@pytest.mark.asyncio +async def test_place_limit_order_requires_price(client): + with pytest.raises(ValueError, match="limit_price"): + await client.place_order( + symbol="AAPL", side="buy", qty=1, order_type="limit", + ) + + +@pytest.mark.asyncio +async def test_cancel_order(client, mock_trading): + mock_trading.cancel_order_by_id.return_value = None + result = await client.cancel_order("o1") + mock_trading.cancel_order_by_id.assert_called_once_with("o1") + assert result == {"order_id": "o1", "canceled": True} + + +@pytest.mark.asyncio +async def test_close_position_no_options(client, mock_trading): + order_mock = MagicMock(model_dump=lambda: {"id": "close-1"}) + mock_trading.close_position.return_value = order_mock + result = await client.close_position("AAPL") + assert mock_trading.close_position.called + assert result["id"] == "close-1" + + +@pytest.mark.asyncio +async def test_get_clock(client, mock_trading): + clock_mock = MagicMock(model_dump=lambda: {"is_open": True, "next_close": "2026-04-21T20:00:00Z"}) + mock_trading.get_clock.return_value = clock_mock + result = await client.get_clock() + assert result["is_open"] is True + + +@pytest.mark.asyncio +async def test_invalid_asset_class(client): + with pytest.raises(ValueError, match="invalid asset_class"): + await client.get_ticker("AAPL", asset_class="forex") diff --git a/tests/unit/exchanges/alpaca/test_leverage_cap.py b/tests/unit/exchanges/alpaca/test_leverage_cap.py new file mode 100644 index 0000000..c073e20 --- /dev/null +++ b/tests/unit/exchanges/alpaca/test_leverage_cap.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import pytest +from fastapi import HTTPException +from cerbero_mcp.exchanges.alpaca.leverage_cap import enforce_leverage, get_max_leverage + + +def test_get_max_leverage_returns_creds_value(): + creds = {"max_leverage": 4} + assert get_max_leverage(creds) == 4 + + +def test_get_max_leverage_default_when_missing(): + """Default 1 (cash) se il secret non ha max_leverage.""" + assert get_max_leverage({}) == 1 + + +def test_enforce_leverage_pass_at_cap_one(): + """Alpaca cash account: cap 1, leverage 1 OK.""" + creds = {"max_leverage": 1} + enforce_leverage(1, creds=creds, exchange="alpaca") # no raise + + +def test_enforce_leverage_reject_over_cap_one(): + creds = {"max_leverage": 1} + with pytest.raises(HTTPException) as exc: + enforce_leverage(2, creds=creds, exchange="alpaca") + assert exc.value.status_code == 403 + assert exc.value.detail["error"] == "LEVERAGE_CAP_EXCEEDED" + assert exc.value.detail["exchange"] == "alpaca" + assert exc.value.detail["requested"] == 2 + assert exc.value.detail["max"] == 1 + + +def test_enforce_leverage_reject_when_below_one(): + creds = {"max_leverage": 1} + with pytest.raises(HTTPException) as exc: + enforce_leverage(0, creds=creds, exchange="alpaca") + assert exc.value.status_code == 403 + + +def test_enforce_leverage_default_when_none(): + """Se requested è None, applica il cap come default.""" + creds = {"max_leverage": 1} + result = enforce_leverage(None, creds=creds, exchange="alpaca") + assert result == 1 diff --git a/tests/unit/test_exchanges_builder.py b/tests/unit/test_exchanges_builder.py index bc8f1c1..ec8b9bd 100644 --- a/tests/unit/test_exchanges_builder.py +++ b/tests/unit/test_exchanges_builder.py @@ -71,6 +71,37 @@ async def test_build_client_hyperliquid_returns_correct_env(monkeypatch): assert "test" not in c_live.base_url.lower() +@pytest.mark.asyncio +async def test_build_client_alpaca_returns_correct_env(monkeypatch): + from tests.unit.test_settings import _minimal_env + + for k, v in _minimal_env().items(): + monkeypatch.setenv(k, v) + + # Stub alpaca SDK clients per evitare connessioni reali in __init__ + from cerbero_mcp.exchanges.alpaca import client as alpaca_client + + class _FakeSdk: + def __init__(self, **kwargs): + self.kwargs = kwargs + + monkeypatch.setattr(alpaca_client, "TradingClient", _FakeSdk) + monkeypatch.setattr(alpaca_client, "StockHistoricalDataClient", _FakeSdk) + monkeypatch.setattr(alpaca_client, "CryptoHistoricalDataClient", _FakeSdk) + monkeypatch.setattr(alpaca_client, "OptionHistoricalDataClient", _FakeSdk) + + from cerbero_mcp.settings import Settings + from cerbero_mcp.exchanges import build_client + + s = Settings() + c_test = await build_client(s, "alpaca", "testnet") + c_live = await build_client(s, "alpaca", "mainnet") + + assert c_test is not c_live + assert c_test.paper is True + assert c_live.paper is False + + @pytest.mark.asyncio async def test_build_client_unknown_exchange_raises(monkeypatch): from tests.unit.test_settings import _minimal_env