feat(V2): migrazione alpaca completa
Task 6.7: porting alpaca da services/mcp-alpaca a src/cerbero_mcp. client.py + leverage_cap.py copiati 1:1 (default cap 1 cash). tools.py: 17 tool senza ACL/Principal/audit. Router /mcp-alpaca con 18 route (env_info + 17 tool). Builder branch alpaca: paper=(env=="testnet"), api_key viene da settings.alpaca.api_key_id. Test client + leverage_cap migrati (15 test alpaca pass). Test builder con stub SDK alpaca-py. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -36,4 +36,12 @@ async def build_client(
|
|||||||
testnet=(env == "testnet"),
|
testnet=(env == "testnet"),
|
||||||
api_wallet_address=settings.hyperliquid.api_wallet_address,
|
api_wallet_address=settings.hyperliquid.api_wallet_address,
|
||||||
)
|
)
|
||||||
|
if exchange == "alpaca":
|
||||||
|
from cerbero_mcp.exchanges.alpaca.client import AlpacaClient
|
||||||
|
|
||||||
|
return AlpacaClient(
|
||||||
|
api_key=settings.alpaca.api_key_id,
|
||||||
|
secret_key=settings.alpaca.secret_key.get_secret_value(),
|
||||||
|
paper=(env == "testnet"),
|
||||||
|
)
|
||||||
raise ValueError(f"unsupported exchange: {exchange}")
|
raise ValueError(f"unsupported exchange: {exchange}")
|
||||||
|
|||||||
@@ -0,0 +1,385 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import datetime as _dt
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from alpaca.data.historical import (
|
||||||
|
CryptoHistoricalDataClient,
|
||||||
|
OptionHistoricalDataClient,
|
||||||
|
StockHistoricalDataClient,
|
||||||
|
)
|
||||||
|
from alpaca.data.requests import (
|
||||||
|
CryptoBarsRequest,
|
||||||
|
CryptoLatestQuoteRequest,
|
||||||
|
CryptoLatestTradeRequest,
|
||||||
|
OptionBarsRequest,
|
||||||
|
OptionChainRequest,
|
||||||
|
OptionLatestQuoteRequest,
|
||||||
|
StockBarsRequest,
|
||||||
|
StockLatestQuoteRequest,
|
||||||
|
StockLatestTradeRequest,
|
||||||
|
StockSnapshotRequest,
|
||||||
|
)
|
||||||
|
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
|
||||||
|
from alpaca.trading.client import TradingClient
|
||||||
|
from alpaca.trading.enums import (
|
||||||
|
AssetClass,
|
||||||
|
OrderSide,
|
||||||
|
QueryOrderStatus,
|
||||||
|
TimeInForce,
|
||||||
|
)
|
||||||
|
from alpaca.trading.requests import (
|
||||||
|
ClosePositionRequest,
|
||||||
|
GetAssetsRequest,
|
||||||
|
GetOrdersRequest,
|
||||||
|
LimitOrderRequest,
|
||||||
|
MarketOrderRequest,
|
||||||
|
ReplaceOrderRequest,
|
||||||
|
StopOrderRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
_TF_MAP = {
|
||||||
|
"1min": TimeFrame(1, TimeFrameUnit.Minute),
|
||||||
|
"5min": TimeFrame(5, TimeFrameUnit.Minute),
|
||||||
|
"15min": TimeFrame(15, TimeFrameUnit.Minute),
|
||||||
|
"30min": TimeFrame(30, TimeFrameUnit.Minute),
|
||||||
|
"1h": TimeFrame(1, TimeFrameUnit.Hour),
|
||||||
|
"1d": TimeFrame(1, TimeFrameUnit.Day),
|
||||||
|
"1w": TimeFrame(1, TimeFrameUnit.Week),
|
||||||
|
}
|
||||||
|
|
||||||
|
_ASSET_CLASSES = {"stocks", "crypto", "options"}
|
||||||
|
|
||||||
|
|
||||||
|
def _tf(interval: str) -> TimeFrame:
|
||||||
|
if interval in _TF_MAP:
|
||||||
|
return _TF_MAP[interval]
|
||||||
|
raise ValueError(f"unsupported timeframe: {interval}")
|
||||||
|
|
||||||
|
|
||||||
|
def _asset_class_enum(ac: str) -> AssetClass:
|
||||||
|
ac = ac.lower()
|
||||||
|
if ac == "stocks":
|
||||||
|
return AssetClass.US_EQUITY
|
||||||
|
if ac == "crypto":
|
||||||
|
return AssetClass.CRYPTO
|
||||||
|
if ac == "options":
|
||||||
|
return AssetClass.US_OPTION
|
||||||
|
raise ValueError(f"invalid asset_class: {ac}")
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize(obj: Any) -> Any:
|
||||||
|
"""Recursively convert pydantic/datetime objects → json-safe."""
|
||||||
|
if obj is None or isinstance(obj, str | int | float | bool):
|
||||||
|
return obj
|
||||||
|
if isinstance(obj, _dt.datetime | _dt.date):
|
||||||
|
return obj.isoformat()
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {k: _serialize(v) for k, v in obj.items()}
|
||||||
|
if isinstance(obj, list | tuple):
|
||||||
|
return [_serialize(v) for v in obj]
|
||||||
|
if hasattr(obj, "model_dump"):
|
||||||
|
return _serialize(obj.model_dump())
|
||||||
|
if hasattr(obj, "__dict__"):
|
||||||
|
return _serialize(vars(obj))
|
||||||
|
return str(obj)
|
||||||
|
|
||||||
|
|
||||||
|
class AlpacaClient:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
secret_key: str,
|
||||||
|
paper: bool = True,
|
||||||
|
trading: Any | None = None,
|
||||||
|
stock_data: Any | None = None,
|
||||||
|
crypto_data: Any | None = None,
|
||||||
|
option_data: Any | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.api_key = api_key
|
||||||
|
self.secret_key = secret_key
|
||||||
|
self.paper = paper
|
||||||
|
self._trading = trading or TradingClient(
|
||||||
|
api_key=api_key, secret_key=secret_key, paper=paper
|
||||||
|
)
|
||||||
|
self._stock = stock_data or StockHistoricalDataClient(
|
||||||
|
api_key=api_key, secret_key=secret_key
|
||||||
|
)
|
||||||
|
self._crypto = crypto_data or CryptoHistoricalDataClient(
|
||||||
|
api_key=api_key, secret_key=secret_key
|
||||||
|
)
|
||||||
|
self._option = option_data or OptionHistoricalDataClient(
|
||||||
|
api_key=api_key, secret_key=secret_key
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _run(self, fn, /, *args, **kwargs):
|
||||||
|
return await asyncio.to_thread(fn, *args, **kwargs)
|
||||||
|
|
||||||
|
# ── Account / positions ──────────────────────────────────────
|
||||||
|
|
||||||
|
async def get_account(self) -> dict:
|
||||||
|
acc = await self._run(self._trading.get_account)
|
||||||
|
return _serialize(acc)
|
||||||
|
|
||||||
|
async def get_positions(self) -> list[dict]:
|
||||||
|
pos = await self._run(self._trading.get_all_positions)
|
||||||
|
return [_serialize(p) for p in pos]
|
||||||
|
|
||||||
|
async def get_activities(self, limit: int = 50) -> list[dict]:
|
||||||
|
acts = await self._run(self._trading.get_account_activities)
|
||||||
|
data = [_serialize(a) for a in acts]
|
||||||
|
return data[:limit]
|
||||||
|
|
||||||
|
# ── Assets ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def get_assets(
|
||||||
|
self, asset_class: str = "stocks", status: str = "active"
|
||||||
|
) -> list[dict]:
|
||||||
|
req = GetAssetsRequest(
|
||||||
|
asset_class=_asset_class_enum(asset_class),
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
|
assets = await self._run(self._trading.get_all_assets, req)
|
||||||
|
return [_serialize(a) for a in assets[:500]]
|
||||||
|
|
||||||
|
# ── Market data ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def get_ticker(self, symbol: str, asset_class: str = "stocks") -> dict:
|
||||||
|
ac = asset_class.lower()
|
||||||
|
if ac == "stocks":
|
||||||
|
req = StockLatestTradeRequest(symbol_or_symbols=symbol)
|
||||||
|
data = await self._run(self._stock.get_stock_latest_trade, req)
|
||||||
|
trade = data.get(symbol)
|
||||||
|
q_req = StockLatestQuoteRequest(symbol_or_symbols=symbol)
|
||||||
|
qdata = await self._run(self._stock.get_stock_latest_quote, q_req)
|
||||||
|
quote = qdata.get(symbol)
|
||||||
|
return {
|
||||||
|
"symbol": symbol,
|
||||||
|
"asset_class": "stocks",
|
||||||
|
"last_price": getattr(trade, "price", None),
|
||||||
|
"bid": getattr(quote, "bid_price", None),
|
||||||
|
"ask": getattr(quote, "ask_price", None),
|
||||||
|
"bid_size": getattr(quote, "bid_size", None),
|
||||||
|
"ask_size": getattr(quote, "ask_size", None),
|
||||||
|
"timestamp": _serialize(getattr(trade, "timestamp", None)),
|
||||||
|
}
|
||||||
|
if ac == "crypto":
|
||||||
|
req = CryptoLatestTradeRequest(symbol_or_symbols=symbol)
|
||||||
|
data = await self._run(self._crypto.get_crypto_latest_trade, req)
|
||||||
|
trade = data.get(symbol)
|
||||||
|
q_req = CryptoLatestQuoteRequest(symbol_or_symbols=symbol)
|
||||||
|
qdata = await self._run(self._crypto.get_crypto_latest_quote, q_req)
|
||||||
|
quote = qdata.get(symbol)
|
||||||
|
return {
|
||||||
|
"symbol": symbol,
|
||||||
|
"asset_class": "crypto",
|
||||||
|
"last_price": getattr(trade, "price", None),
|
||||||
|
"bid": getattr(quote, "bid_price", None),
|
||||||
|
"ask": getattr(quote, "ask_price", None),
|
||||||
|
"timestamp": _serialize(getattr(trade, "timestamp", None)),
|
||||||
|
}
|
||||||
|
if ac == "options":
|
||||||
|
req = OptionLatestQuoteRequest(symbol_or_symbols=symbol)
|
||||||
|
data = await self._run(self._option.get_option_latest_quote, req)
|
||||||
|
quote = data.get(symbol)
|
||||||
|
return {
|
||||||
|
"symbol": symbol,
|
||||||
|
"asset_class": "options",
|
||||||
|
"bid": getattr(quote, "bid_price", None),
|
||||||
|
"ask": getattr(quote, "ask_price", None),
|
||||||
|
"timestamp": _serialize(getattr(quote, "timestamp", None)),
|
||||||
|
}
|
||||||
|
raise ValueError(f"invalid asset_class: {asset_class}")
|
||||||
|
|
||||||
|
async def get_bars(
|
||||||
|
self,
|
||||||
|
symbol: str,
|
||||||
|
asset_class: str = "stocks",
|
||||||
|
interval: str = "1d",
|
||||||
|
start: str | None = None,
|
||||||
|
end: str | None = None,
|
||||||
|
limit: int = 1000,
|
||||||
|
) -> dict:
|
||||||
|
tf = _tf(interval)
|
||||||
|
start_dt = _dt.datetime.fromisoformat(start) if start else (
|
||||||
|
_dt.datetime.now(_dt.UTC) - _dt.timedelta(days=30)
|
||||||
|
)
|
||||||
|
end_dt = _dt.datetime.fromisoformat(end) if end else _dt.datetime.now(_dt.UTC)
|
||||||
|
ac = asset_class.lower()
|
||||||
|
if ac == "stocks":
|
||||||
|
req = StockBarsRequest(
|
||||||
|
symbol_or_symbols=symbol, timeframe=tf,
|
||||||
|
start=start_dt, end=end_dt, limit=limit,
|
||||||
|
)
|
||||||
|
data = await self._run(self._stock.get_stock_bars, req)
|
||||||
|
elif ac == "crypto":
|
||||||
|
req = CryptoBarsRequest(
|
||||||
|
symbol_or_symbols=symbol, timeframe=tf,
|
||||||
|
start=start_dt, end=end_dt, limit=limit,
|
||||||
|
)
|
||||||
|
data = await self._run(self._crypto.get_crypto_bars, req)
|
||||||
|
elif ac == "options":
|
||||||
|
req = OptionBarsRequest(
|
||||||
|
symbol_or_symbols=symbol, timeframe=tf,
|
||||||
|
start=start_dt, end=end_dt, limit=limit,
|
||||||
|
)
|
||||||
|
data = await self._run(self._option.get_option_bars, req)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"invalid asset_class: {asset_class}")
|
||||||
|
bars_dict = getattr(data, "data", {}) or {}
|
||||||
|
rows = bars_dict.get(symbol, []) or []
|
||||||
|
bars = [
|
||||||
|
{
|
||||||
|
"timestamp": _serialize(getattr(b, "timestamp", None)),
|
||||||
|
"open": getattr(b, "open", None),
|
||||||
|
"high": getattr(b, "high", None),
|
||||||
|
"low": getattr(b, "low", None),
|
||||||
|
"close": getattr(b, "close", None),
|
||||||
|
"volume": getattr(b, "volume", None),
|
||||||
|
}
|
||||||
|
for b in rows
|
||||||
|
]
|
||||||
|
return {"symbol": symbol, "asset_class": ac, "interval": interval, "bars": bars}
|
||||||
|
|
||||||
|
async def get_snapshot(self, symbol: str) -> dict:
|
||||||
|
req = StockSnapshotRequest(symbol_or_symbols=symbol)
|
||||||
|
data = await self._run(self._stock.get_stock_snapshot, req)
|
||||||
|
return _serialize(data.get(symbol))
|
||||||
|
|
||||||
|
async def get_option_chain(
|
||||||
|
self,
|
||||||
|
underlying: str,
|
||||||
|
expiry: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
kwargs: dict[str, Any] = {"underlying_symbol": underlying}
|
||||||
|
if expiry:
|
||||||
|
kwargs["expiration_date"] = _dt.date.fromisoformat(expiry)
|
||||||
|
req = OptionChainRequest(**kwargs)
|
||||||
|
data = await self._run(self._option.get_option_chain, req)
|
||||||
|
return {
|
||||||
|
"underlying": underlying,
|
||||||
|
"expiry": expiry,
|
||||||
|
"contracts": _serialize(data),
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Orders ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def get_open_orders(self, limit: int = 50) -> list[dict]:
|
||||||
|
req = GetOrdersRequest(status=QueryOrderStatus.OPEN, limit=limit)
|
||||||
|
orders = await self._run(self._trading.get_orders, filter=req)
|
||||||
|
return [_serialize(o) for o in orders]
|
||||||
|
|
||||||
|
async def place_order(
|
||||||
|
self,
|
||||||
|
symbol: str,
|
||||||
|
side: str,
|
||||||
|
qty: float | None = None,
|
||||||
|
notional: float | None = None,
|
||||||
|
order_type: str = "market",
|
||||||
|
limit_price: float | None = None,
|
||||||
|
stop_price: float | None = None,
|
||||||
|
tif: str = "day",
|
||||||
|
asset_class: str = "stocks",
|
||||||
|
) -> dict:
|
||||||
|
side_enum = OrderSide.BUY if side.lower() == "buy" else OrderSide.SELL
|
||||||
|
tif_enum = TimeInForce(tif.lower())
|
||||||
|
ot = order_type.lower()
|
||||||
|
common = {
|
||||||
|
"symbol": symbol,
|
||||||
|
"side": side_enum,
|
||||||
|
"time_in_force": tif_enum,
|
||||||
|
}
|
||||||
|
if qty is not None:
|
||||||
|
common["qty"] = qty
|
||||||
|
if notional is not None:
|
||||||
|
common["notional"] = notional
|
||||||
|
if ot == "market":
|
||||||
|
req = MarketOrderRequest(**common)
|
||||||
|
elif ot == "limit":
|
||||||
|
if limit_price is None:
|
||||||
|
raise ValueError("limit_price required for limit order")
|
||||||
|
req = LimitOrderRequest(**common, limit_price=limit_price)
|
||||||
|
elif ot == "stop":
|
||||||
|
if stop_price is None:
|
||||||
|
raise ValueError("stop_price required for stop order")
|
||||||
|
req = StopOrderRequest(**common, stop_price=stop_price)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unsupported order_type: {order_type}")
|
||||||
|
order = await self._run(self._trading.submit_order, req)
|
||||||
|
return _serialize(order)
|
||||||
|
|
||||||
|
async def amend_order(
|
||||||
|
self,
|
||||||
|
order_id: str,
|
||||||
|
qty: float | None = None,
|
||||||
|
limit_price: float | None = None,
|
||||||
|
stop_price: float | None = None,
|
||||||
|
tif: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
kwargs: dict[str, Any] = {}
|
||||||
|
if qty is not None:
|
||||||
|
kwargs["qty"] = qty
|
||||||
|
if limit_price is not None:
|
||||||
|
kwargs["limit_price"] = limit_price
|
||||||
|
if stop_price is not None:
|
||||||
|
kwargs["stop_price"] = stop_price
|
||||||
|
if tif is not None:
|
||||||
|
kwargs["time_in_force"] = TimeInForce(tif.lower())
|
||||||
|
req = ReplaceOrderRequest(**kwargs)
|
||||||
|
order = await self._run(self._trading.replace_order_by_id, order_id, req)
|
||||||
|
return _serialize(order)
|
||||||
|
|
||||||
|
async def cancel_order(self, order_id: str) -> dict:
|
||||||
|
await self._run(self._trading.cancel_order_by_id, order_id)
|
||||||
|
return {"order_id": order_id, "canceled": True}
|
||||||
|
|
||||||
|
async def cancel_all_orders(self) -> list[dict]:
|
||||||
|
resp = await self._run(self._trading.cancel_orders)
|
||||||
|
return [_serialize(r) for r in resp]
|
||||||
|
|
||||||
|
# ── Position close ──────────────────────────────────────────
|
||||||
|
|
||||||
|
async def close_position(
|
||||||
|
self, symbol: str, qty: float | None = None, percentage: float | None = None
|
||||||
|
) -> dict:
|
||||||
|
req = None
|
||||||
|
if qty is not None or percentage is not None:
|
||||||
|
kwargs: dict[str, Any] = {}
|
||||||
|
if qty is not None:
|
||||||
|
kwargs["qty"] = str(qty)
|
||||||
|
if percentage is not None:
|
||||||
|
kwargs["percentage"] = str(percentage)
|
||||||
|
req = ClosePositionRequest(**kwargs)
|
||||||
|
order = await self._run(
|
||||||
|
self._trading.close_position, symbol, close_options=req
|
||||||
|
)
|
||||||
|
return _serialize(order)
|
||||||
|
|
||||||
|
async def close_all_positions(self, cancel_orders: bool = True) -> list[dict]:
|
||||||
|
resp = await self._run(
|
||||||
|
self._trading.close_all_positions, cancel_orders=cancel_orders
|
||||||
|
)
|
||||||
|
return [_serialize(r) for r in resp]
|
||||||
|
|
||||||
|
# ── Clock / calendar ────────────────────────────────────────
|
||||||
|
|
||||||
|
async def get_clock(self) -> dict:
|
||||||
|
clock = await self._run(self._trading.get_clock)
|
||||||
|
return _serialize(clock)
|
||||||
|
|
||||||
|
async def get_calendar(
|
||||||
|
self, start: str | None = None, end: str | None = None
|
||||||
|
) -> list[dict]:
|
||||||
|
from alpaca.trading.requests import GetCalendarRequest
|
||||||
|
|
||||||
|
kwargs: dict[str, Any] = {}
|
||||||
|
if start:
|
||||||
|
kwargs["start"] = _dt.date.fromisoformat(start)
|
||||||
|
if end:
|
||||||
|
kwargs["end"] = _dt.date.fromisoformat(end)
|
||||||
|
req = GetCalendarRequest(**kwargs) if kwargs else None
|
||||||
|
cal = await self._run(
|
||||||
|
self._trading.get_calendar, filters=req
|
||||||
|
) if req else await self._run(self._trading.get_calendar)
|
||||||
|
return [_serialize(c) for c in cal]
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
"""Leverage cap server-side per place_order.
|
||||||
|
|
||||||
|
Cap letto dal secret JSON via campo `max_leverage`. Default 1 (cash) se assente.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_leverage(creds: dict) -> int:
|
||||||
|
"""Legge max_leverage dal secret. Default 1 se mancante."""
|
||||||
|
raw = creds.get("max_leverage", 1)
|
||||||
|
try:
|
||||||
|
value = int(raw)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
value = 1
|
||||||
|
return max(1, value)
|
||||||
|
|
||||||
|
|
||||||
|
def enforce_leverage(
|
||||||
|
requested: int | float | None,
|
||||||
|
*,
|
||||||
|
creds: dict,
|
||||||
|
exchange: str,
|
||||||
|
) -> int:
|
||||||
|
"""Verifica e applica leverage cap. Ritorna leverage applicabile.
|
||||||
|
|
||||||
|
Solleva HTTPException(403, LEVERAGE_CAP_EXCEEDED) se requested > cap.
|
||||||
|
Se requested is None, applica il cap come default.
|
||||||
|
"""
|
||||||
|
cap = get_max_leverage(creds)
|
||||||
|
if requested is None:
|
||||||
|
return cap
|
||||||
|
lev = int(requested)
|
||||||
|
if lev < 1:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail={
|
||||||
|
"error": "LEVERAGE_CAP_EXCEEDED",
|
||||||
|
"exchange": exchange,
|
||||||
|
"requested": lev,
|
||||||
|
"max": cap,
|
||||||
|
"reason": "leverage must be >= 1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if lev > cap:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail={
|
||||||
|
"error": "LEVERAGE_CAP_EXCEEDED",
|
||||||
|
"exchange": exchange,
|
||||||
|
"requested": lev,
|
||||||
|
"max": cap,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return lev
|
||||||
@@ -0,0 +1,279 @@
|
|||||||
|
"""Tool alpaca V2: pydantic schemas + async functions.
|
||||||
|
|
||||||
|
Ogni funzione prende (client: AlpacaClient, params: <Req>) e restituisce
|
||||||
|
un dict (o list[dict]). Pure logica, no FastAPI dependency, no ACL.
|
||||||
|
L'autenticazione bearer è gestita dal middleware in cerbero_mcp.auth;
|
||||||
|
l'audit verrà cablato dal router via request.state.environment.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from cerbero_mcp.exchanges.alpaca.client import AlpacaClient
|
||||||
|
from cerbero_mcp.exchanges.alpaca.leverage_cap import get_max_leverage
|
||||||
|
|
||||||
|
# === Schemas: reads ===
|
||||||
|
|
||||||
|
|
||||||
|
class GetAccountReq(BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GetPositionsReq(BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GetActivitiesReq(BaseModel):
|
||||||
|
limit: int = 50
|
||||||
|
|
||||||
|
|
||||||
|
class GetAssetsReq(BaseModel):
|
||||||
|
asset_class: str = "stocks"
|
||||||
|
status: str = "active"
|
||||||
|
|
||||||
|
|
||||||
|
class GetTickerReq(BaseModel):
|
||||||
|
symbol: str
|
||||||
|
asset_class: str = "stocks"
|
||||||
|
|
||||||
|
|
||||||
|
class GetBarsReq(BaseModel):
|
||||||
|
symbol: str
|
||||||
|
asset_class: str = "stocks"
|
||||||
|
interval: str = "1d"
|
||||||
|
start: str | None = None
|
||||||
|
end: str | None = None
|
||||||
|
limit: int = 1000
|
||||||
|
|
||||||
|
|
||||||
|
class GetSnapshotReq(BaseModel):
|
||||||
|
symbol: str
|
||||||
|
|
||||||
|
|
||||||
|
class GetOptionChainReq(BaseModel):
|
||||||
|
underlying: str
|
||||||
|
expiry: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GetOpenOrdersReq(BaseModel):
|
||||||
|
limit: int = 50
|
||||||
|
|
||||||
|
|
||||||
|
class GetClockReq(BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GetCalendarReq(BaseModel):
|
||||||
|
start: str | None = None
|
||||||
|
end: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# === Schemas: writes ===
|
||||||
|
|
||||||
|
|
||||||
|
class PlaceOrderReq(BaseModel):
|
||||||
|
symbol: str
|
||||||
|
side: str # "buy" | "sell"
|
||||||
|
qty: float | None = None
|
||||||
|
notional: float | None = None
|
||||||
|
order_type: str = "market"
|
||||||
|
limit_price: float | None = None
|
||||||
|
stop_price: float | None = None
|
||||||
|
tif: str = "day"
|
||||||
|
asset_class: str = "stocks"
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"json_schema_extra": {
|
||||||
|
"examples": [
|
||||||
|
{
|
||||||
|
"summary": "Market buy 1 share AAPL",
|
||||||
|
"value": {
|
||||||
|
"symbol": "AAPL",
|
||||||
|
"side": "buy",
|
||||||
|
"qty": 1,
|
||||||
|
"order_type": "market",
|
||||||
|
"asset_class": "stocks",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AmendOrderReq(BaseModel):
|
||||||
|
order_id: str
|
||||||
|
qty: float | None = None
|
||||||
|
limit_price: float | None = None
|
||||||
|
stop_price: float | None = None
|
||||||
|
tif: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CancelOrderReq(BaseModel):
|
||||||
|
order_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class CancelAllOrdersReq(BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ClosePositionReq(BaseModel):
|
||||||
|
symbol: str
|
||||||
|
qty: float | None = None
|
||||||
|
percentage: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CloseAllPositionsReq(BaseModel):
|
||||||
|
cancel_orders: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
# === Tools (reads) ===
|
||||||
|
|
||||||
|
|
||||||
|
async def environment_info(
|
||||||
|
client: AlpacaClient, *, creds: dict, env_info: Any | None = None
|
||||||
|
) -> dict:
|
||||||
|
if env_info is None:
|
||||||
|
return {
|
||||||
|
"exchange": "alpaca",
|
||||||
|
"environment": "testnet" if getattr(client, "paper", True) else "mainnet",
|
||||||
|
"source": "credentials",
|
||||||
|
"env_value": None,
|
||||||
|
"base_url": getattr(client, "base_url", None),
|
||||||
|
"max_leverage": get_max_leverage(creds),
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"exchange": env_info.exchange,
|
||||||
|
"environment": env_info.environment,
|
||||||
|
"source": env_info.source,
|
||||||
|
"env_value": env_info.env_value,
|
||||||
|
"base_url": env_info.base_url,
|
||||||
|
"max_leverage": get_max_leverage(creds),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def get_account(client: AlpacaClient, params: GetAccountReq) -> dict:
|
||||||
|
return await client.get_account()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_positions(
|
||||||
|
client: AlpacaClient, params: GetPositionsReq
|
||||||
|
) -> dict:
|
||||||
|
return {"positions": await client.get_positions()}
|
||||||
|
|
||||||
|
|
||||||
|
async def get_activities(
|
||||||
|
client: AlpacaClient, params: GetActivitiesReq
|
||||||
|
) -> dict:
|
||||||
|
return {"activities": await client.get_activities(params.limit)}
|
||||||
|
|
||||||
|
|
||||||
|
async def get_assets(client: AlpacaClient, params: GetAssetsReq) -> dict:
|
||||||
|
return {
|
||||||
|
"assets": await client.get_assets(params.asset_class, params.status)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def get_ticker(client: AlpacaClient, params: GetTickerReq) -> dict:
|
||||||
|
return await client.get_ticker(params.symbol, params.asset_class)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_bars(client: AlpacaClient, params: GetBarsReq) -> dict:
|
||||||
|
return await client.get_bars(
|
||||||
|
params.symbol,
|
||||||
|
params.asset_class,
|
||||||
|
params.interval,
|
||||||
|
params.start,
|
||||||
|
params.end,
|
||||||
|
params.limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_snapshot(
|
||||||
|
client: AlpacaClient, params: GetSnapshotReq
|
||||||
|
) -> dict:
|
||||||
|
return await client.get_snapshot(params.symbol)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_option_chain(
|
||||||
|
client: AlpacaClient, params: GetOptionChainReq
|
||||||
|
) -> dict:
|
||||||
|
return await client.get_option_chain(params.underlying, params.expiry)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_open_orders(
|
||||||
|
client: AlpacaClient, params: GetOpenOrdersReq
|
||||||
|
) -> dict:
|
||||||
|
return {"orders": await client.get_open_orders(params.limit)}
|
||||||
|
|
||||||
|
|
||||||
|
async def get_clock(client: AlpacaClient, params: GetClockReq) -> dict:
|
||||||
|
return await client.get_clock()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_calendar(
|
||||||
|
client: AlpacaClient, params: GetCalendarReq
|
||||||
|
) -> dict:
|
||||||
|
return {"calendar": await client.get_calendar(params.start, params.end)}
|
||||||
|
|
||||||
|
|
||||||
|
# === Tools (writes) ===
|
||||||
|
|
||||||
|
|
||||||
|
async def place_order(
|
||||||
|
client: AlpacaClient, params: PlaceOrderReq, *, creds: dict
|
||||||
|
) -> dict:
|
||||||
|
# Alpaca: cap default 1 (cash account). Niente leverage parametro;
|
||||||
|
# cap presente per coerenza con altri exchange e per audit.
|
||||||
|
return await client.place_order(
|
||||||
|
symbol=params.symbol,
|
||||||
|
side=params.side,
|
||||||
|
qty=params.qty,
|
||||||
|
notional=params.notional,
|
||||||
|
order_type=params.order_type,
|
||||||
|
limit_price=params.limit_price,
|
||||||
|
stop_price=params.stop_price,
|
||||||
|
tif=params.tif,
|
||||||
|
asset_class=params.asset_class,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def amend_order(
|
||||||
|
client: AlpacaClient, params: AmendOrderReq
|
||||||
|
) -> dict:
|
||||||
|
return await client.amend_order(
|
||||||
|
params.order_id,
|
||||||
|
params.qty,
|
||||||
|
params.limit_price,
|
||||||
|
params.stop_price,
|
||||||
|
params.tif,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def cancel_order(
|
||||||
|
client: AlpacaClient, params: CancelOrderReq
|
||||||
|
) -> dict:
|
||||||
|
return await client.cancel_order(params.order_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def cancel_all_orders(
|
||||||
|
client: AlpacaClient, params: CancelAllOrdersReq
|
||||||
|
) -> dict:
|
||||||
|
return {"canceled": await client.cancel_all_orders()}
|
||||||
|
|
||||||
|
|
||||||
|
async def close_position(
|
||||||
|
client: AlpacaClient, params: ClosePositionReq
|
||||||
|
) -> dict:
|
||||||
|
return await client.close_position(
|
||||||
|
params.symbol, params.qty, params.percentage
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def close_all_positions(
|
||||||
|
client: AlpacaClient, params: CloseAllPositionsReq
|
||||||
|
) -> dict:
|
||||||
|
return {
|
||||||
|
"closed": await client.close_all_positions(params.cancel_orders)
|
||||||
|
}
|
||||||
@@ -0,0 +1,176 @@
|
|||||||
|
"""Router /mcp-alpaca/* — DI per env, client e (write) creds.
|
||||||
|
|
||||||
|
Mappa 1:1 i tool di `cerbero_mcp.exchanges.alpaca.tools` a endpoint
|
||||||
|
`POST /mcp-alpaca/tools/{tool_name}`. L'autenticazione bearer è gestita
|
||||||
|
dal middleware in `cerbero_mcp.auth`; qui leggiamo solo `request.state.environment`.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
|
||||||
|
from cerbero_mcp.client_registry import ClientRegistry
|
||||||
|
from cerbero_mcp.exchanges.alpaca import tools as t
|
||||||
|
from cerbero_mcp.exchanges.alpaca.client import AlpacaClient
|
||||||
|
|
||||||
|
Environment = Literal["testnet", "mainnet"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_environment(request: Request) -> Environment:
|
||||||
|
return request.state.environment
|
||||||
|
|
||||||
|
|
||||||
|
async def get_alpaca_client(
|
||||||
|
request: Request, env: Environment = Depends(get_environment)
|
||||||
|
) -> AlpacaClient:
|
||||||
|
registry: ClientRegistry = request.app.state.registry
|
||||||
|
return await registry.get("alpaca", env)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_creds(request: Request) -> dict:
|
||||||
|
"""Costruisce dict `creds` minimale per leverage cap / metadata."""
|
||||||
|
settings = request.app.state.settings
|
||||||
|
return {
|
||||||
|
"max_leverage": settings.alpaca.max_leverage,
|
||||||
|
"api_key_id": settings.alpaca.api_key_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def make_router() -> APIRouter:
|
||||||
|
r = APIRouter(prefix="/mcp-alpaca", tags=["alpaca"])
|
||||||
|
|
||||||
|
# === READ tools ===
|
||||||
|
|
||||||
|
@r.post("/tools/environment_info")
|
||||||
|
async def _environment_info(
|
||||||
|
request: Request,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
creds = _build_creds(request)
|
||||||
|
return await t.environment_info(client, creds=creds)
|
||||||
|
|
||||||
|
@r.post("/tools/get_account")
|
||||||
|
async def _get_account(
|
||||||
|
params: t.GetAccountReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.get_account(client, params)
|
||||||
|
|
||||||
|
@r.post("/tools/get_positions")
|
||||||
|
async def _get_positions(
|
||||||
|
params: t.GetPositionsReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.get_positions(client, params)
|
||||||
|
|
||||||
|
@r.post("/tools/get_activities")
|
||||||
|
async def _get_activities(
|
||||||
|
params: t.GetActivitiesReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.get_activities(client, params)
|
||||||
|
|
||||||
|
@r.post("/tools/get_assets")
|
||||||
|
async def _get_assets(
|
||||||
|
params: t.GetAssetsReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.get_assets(client, params)
|
||||||
|
|
||||||
|
@r.post("/tools/get_ticker")
|
||||||
|
async def _get_ticker(
|
||||||
|
params: t.GetTickerReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.get_ticker(client, params)
|
||||||
|
|
||||||
|
@r.post("/tools/get_bars")
|
||||||
|
async def _get_bars(
|
||||||
|
params: t.GetBarsReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.get_bars(client, params)
|
||||||
|
|
||||||
|
@r.post("/tools/get_snapshot")
|
||||||
|
async def _get_snapshot(
|
||||||
|
params: t.GetSnapshotReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.get_snapshot(client, params)
|
||||||
|
|
||||||
|
@r.post("/tools/get_option_chain")
|
||||||
|
async def _get_option_chain(
|
||||||
|
params: t.GetOptionChainReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.get_option_chain(client, params)
|
||||||
|
|
||||||
|
@r.post("/tools/get_open_orders")
|
||||||
|
async def _get_open_orders(
|
||||||
|
params: t.GetOpenOrdersReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.get_open_orders(client, params)
|
||||||
|
|
||||||
|
@r.post("/tools/get_clock")
|
||||||
|
async def _get_clock(
|
||||||
|
params: t.GetClockReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.get_clock(client, params)
|
||||||
|
|
||||||
|
@r.post("/tools/get_calendar")
|
||||||
|
async def _get_calendar(
|
||||||
|
params: t.GetCalendarReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.get_calendar(client, params)
|
||||||
|
|
||||||
|
# === WRITE tools ===
|
||||||
|
|
||||||
|
@r.post("/tools/place_order")
|
||||||
|
async def _place_order(
|
||||||
|
params: t.PlaceOrderReq,
|
||||||
|
request: Request,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
creds = _build_creds(request)
|
||||||
|
return await t.place_order(client, params, creds=creds)
|
||||||
|
|
||||||
|
@r.post("/tools/amend_order")
|
||||||
|
async def _amend_order(
|
||||||
|
params: t.AmendOrderReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.amend_order(client, params)
|
||||||
|
|
||||||
|
@r.post("/tools/cancel_order")
|
||||||
|
async def _cancel_order(
|
||||||
|
params: t.CancelOrderReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.cancel_order(client, params)
|
||||||
|
|
||||||
|
@r.post("/tools/cancel_all_orders")
|
||||||
|
async def _cancel_all_orders(
|
||||||
|
params: t.CancelAllOrdersReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.cancel_all_orders(client, params)
|
||||||
|
|
||||||
|
@r.post("/tools/close_position")
|
||||||
|
async def _close_position(
|
||||||
|
params: t.ClosePositionReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.close_position(client, params)
|
||||||
|
|
||||||
|
@r.post("/tools/close_all_positions")
|
||||||
|
async def _close_all_positions(
|
||||||
|
params: t.CloseAllPositionsReq,
|
||||||
|
client: AlpacaClient = Depends(get_alpaca_client),
|
||||||
|
):
|
||||||
|
return await t.close_all_positions(client, params)
|
||||||
|
|
||||||
|
return r
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from cerbero_mcp.exchanges.alpaca.client import AlpacaClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_trading():
|
||||||
|
return MagicMock(name="alpaca_TradingClient")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_stock():
|
||||||
|
return MagicMock(name="alpaca_StockHistoricalDataClient")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_crypto():
|
||||||
|
return MagicMock(name="alpaca_CryptoHistoricalDataClient")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_option():
|
||||||
|
return MagicMock(name="alpaca_OptionHistoricalDataClient")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(mock_trading, mock_stock, mock_crypto, mock_option):
|
||||||
|
return AlpacaClient(
|
||||||
|
api_key="test_key",
|
||||||
|
secret_key="test_secret",
|
||||||
|
paper=True,
|
||||||
|
trading=mock_trading,
|
||||||
|
stock_data=mock_stock,
|
||||||
|
crypto_data=mock_crypto,
|
||||||
|
option_data=mock_option,
|
||||||
|
)
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_init_paper_mode(client, mock_trading):
|
||||||
|
assert client.paper is True
|
||||||
|
assert client._trading is mock_trading
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_account_calls_trading(client, mock_trading):
|
||||||
|
mock_trading.get_account.return_value = MagicMock(
|
||||||
|
model_dump=lambda: {"equity": 100000, "cash": 50000}
|
||||||
|
)
|
||||||
|
result = await client.get_account()
|
||||||
|
mock_trading.get_account.assert_called_once()
|
||||||
|
assert result["equity"] == 100000
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_positions_returns_list(client, mock_trading):
|
||||||
|
pos_mock = MagicMock(model_dump=lambda: {"symbol": "AAPL", "qty": 10})
|
||||||
|
mock_trading.get_all_positions.return_value = [pos_mock]
|
||||||
|
result = await client.get_positions()
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["symbol"] == "AAPL"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_place_market_order_stocks(client, mock_trading):
|
||||||
|
order_mock = MagicMock(model_dump=lambda: {"id": "o123", "symbol": "AAPL"})
|
||||||
|
mock_trading.submit_order.return_value = order_mock
|
||||||
|
result = await client.place_order(
|
||||||
|
symbol="AAPL", side="buy", qty=1, order_type="market", asset_class="stocks",
|
||||||
|
)
|
||||||
|
assert result["id"] == "o123"
|
||||||
|
assert mock_trading.submit_order.called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_place_limit_order_requires_price(client):
|
||||||
|
with pytest.raises(ValueError, match="limit_price"):
|
||||||
|
await client.place_order(
|
||||||
|
symbol="AAPL", side="buy", qty=1, order_type="limit",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_order(client, mock_trading):
|
||||||
|
mock_trading.cancel_order_by_id.return_value = None
|
||||||
|
result = await client.cancel_order("o1")
|
||||||
|
mock_trading.cancel_order_by_id.assert_called_once_with("o1")
|
||||||
|
assert result == {"order_id": "o1", "canceled": True}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_position_no_options(client, mock_trading):
|
||||||
|
order_mock = MagicMock(model_dump=lambda: {"id": "close-1"})
|
||||||
|
mock_trading.close_position.return_value = order_mock
|
||||||
|
result = await client.close_position("AAPL")
|
||||||
|
assert mock_trading.close_position.called
|
||||||
|
assert result["id"] == "close-1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_clock(client, mock_trading):
|
||||||
|
clock_mock = MagicMock(model_dump=lambda: {"is_open": True, "next_close": "2026-04-21T20:00:00Z"})
|
||||||
|
mock_trading.get_clock.return_value = clock_mock
|
||||||
|
result = await client.get_clock()
|
||||||
|
assert result["is_open"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_asset_class(client):
|
||||||
|
with pytest.raises(ValueError, match="invalid asset_class"):
|
||||||
|
await client.get_ticker("AAPL", asset_class="forex")
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from cerbero_mcp.exchanges.alpaca.leverage_cap import enforce_leverage, get_max_leverage
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_max_leverage_returns_creds_value():
|
||||||
|
creds = {"max_leverage": 4}
|
||||||
|
assert get_max_leverage(creds) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_max_leverage_default_when_missing():
|
||||||
|
"""Default 1 (cash) se il secret non ha max_leverage."""
|
||||||
|
assert get_max_leverage({}) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_enforce_leverage_pass_at_cap_one():
|
||||||
|
"""Alpaca cash account: cap 1, leverage 1 OK."""
|
||||||
|
creds = {"max_leverage": 1}
|
||||||
|
enforce_leverage(1, creds=creds, exchange="alpaca") # no raise
|
||||||
|
|
||||||
|
|
||||||
|
def test_enforce_leverage_reject_over_cap_one():
|
||||||
|
creds = {"max_leverage": 1}
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
enforce_leverage(2, creds=creds, exchange="alpaca")
|
||||||
|
assert exc.value.status_code == 403
|
||||||
|
assert exc.value.detail["error"] == "LEVERAGE_CAP_EXCEEDED"
|
||||||
|
assert exc.value.detail["exchange"] == "alpaca"
|
||||||
|
assert exc.value.detail["requested"] == 2
|
||||||
|
assert exc.value.detail["max"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_enforce_leverage_reject_when_below_one():
|
||||||
|
creds = {"max_leverage": 1}
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
enforce_leverage(0, creds=creds, exchange="alpaca")
|
||||||
|
assert exc.value.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
def test_enforce_leverage_default_when_none():
|
||||||
|
"""Se requested è None, applica il cap come default."""
|
||||||
|
creds = {"max_leverage": 1}
|
||||||
|
result = enforce_leverage(None, creds=creds, exchange="alpaca")
|
||||||
|
assert result == 1
|
||||||
@@ -71,6 +71,37 @@ async def test_build_client_hyperliquid_returns_correct_env(monkeypatch):
|
|||||||
assert "test" not in c_live.base_url.lower()
|
assert "test" not in c_live.base_url.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_client_alpaca_returns_correct_env(monkeypatch):
|
||||||
|
from tests.unit.test_settings import _minimal_env
|
||||||
|
|
||||||
|
for k, v in _minimal_env().items():
|
||||||
|
monkeypatch.setenv(k, v)
|
||||||
|
|
||||||
|
# Stub alpaca SDK clients per evitare connessioni reali in __init__
|
||||||
|
from cerbero_mcp.exchanges.alpaca import client as alpaca_client
|
||||||
|
|
||||||
|
class _FakeSdk:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
monkeypatch.setattr(alpaca_client, "TradingClient", _FakeSdk)
|
||||||
|
monkeypatch.setattr(alpaca_client, "StockHistoricalDataClient", _FakeSdk)
|
||||||
|
monkeypatch.setattr(alpaca_client, "CryptoHistoricalDataClient", _FakeSdk)
|
||||||
|
monkeypatch.setattr(alpaca_client, "OptionHistoricalDataClient", _FakeSdk)
|
||||||
|
|
||||||
|
from cerbero_mcp.settings import Settings
|
||||||
|
from cerbero_mcp.exchanges import build_client
|
||||||
|
|
||||||
|
s = Settings()
|
||||||
|
c_test = await build_client(s, "alpaca", "testnet")
|
||||||
|
c_live = await build_client(s, "alpaca", "mainnet")
|
||||||
|
|
||||||
|
assert c_test is not c_live
|
||||||
|
assert c_test.paper is True
|
||||||
|
assert c_live.paper is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_build_client_unknown_exchange_raises(monkeypatch):
|
async def test_build_client_unknown_exchange_raises(monkeypatch):
|
||||||
from tests.unit.test_settings import _minimal_env
|
from tests.unit.test_settings import _minimal_env
|
||||||
|
|||||||
Reference in New Issue
Block a user