diff --git a/src/cerbero_mcp/exchanges/alpaca/client.py b/src/cerbero_mcp/exchanges/alpaca/client.py index 5af0988..0d448cb 100644 --- a/src/cerbero_mcp/exchanges/alpaca/client.py +++ b/src/cerbero_mcp/exchanges/alpaca/client.py @@ -1,204 +1,248 @@ +"""Alpaca client su httpx puro (V2.0.0). + +Riscrittura full-REST del client `alpaca-py` originale: 4 endpoint base +(trading, stock data, crypto data, options data), auth via header +APCA-API-KEY-ID / APCA-API-SECRET-KEY, parità completa con la versione V1 +(stesse firme, stessa shape dei dict ritornati). + +- `base_url` parametro override applica SOLO al trading endpoint + (coerente con `url_override` di alpaca-py.TradingClient). Gli endpoint + data restano hardcoded su `https://data.alpaca.markets`. +- I metodi ritornano `dict` / `list[dict]` direttamente dal JSON REST + (al posto dei modelli pydantic alpaca-py serializzati). Le chiavi sono + quelle restituite dall'API Alpaca; equivalgono al `model_dump()` dei + modelli SDK precedenti. +""" from __future__ import annotations -import asyncio import datetime as _dt from typing import Any -from alpaca.data.historical import ( - CryptoHistoricalDataClient, - OptionHistoricalDataClient, - StockHistoricalDataClient, -) -from alpaca.data.requests import ( - CryptoBarsRequest, - CryptoLatestQuoteRequest, - CryptoLatestTradeRequest, - OptionBarsRequest, - OptionChainRequest, - OptionLatestQuoteRequest, - StockBarsRequest, - StockLatestQuoteRequest, - StockLatestTradeRequest, - StockSnapshotRequest, -) -from alpaca.data.timeframe import TimeFrame, TimeFrameUnit -from alpaca.trading.client import TradingClient -from alpaca.trading.enums import ( - AssetClass, - OrderSide, - QueryOrderStatus, - TimeInForce, -) -from alpaca.trading.requests import ( - ClosePositionRequest, - GetAssetsRequest, - GetOrdersRequest, - LimitOrderRequest, - MarketOrderRequest, - ReplaceOrderRequest, - StopOrderRequest, -) +import httpx +from cerbero_mcp.common.http import async_client + +# ── Endpoint base ──────────────────────────────────────────────── +_TRADING_LIVE = "https://api.alpaca.markets" +_TRADING_PAPER = "https://paper-api.alpaca.markets" +_DATA = "https://data.alpaca.markets" + +# ── Mappa timeframe → query param Alpaca ───────────────────────── +# Alpaca v2 bars: timeframe = "1Min" / "5Min" / "15Min" / "30Min" / "1Hour" / "1Day" / "1Week" _TF_MAP = { - "1min": TimeFrame(1, TimeFrameUnit.Minute), - "5min": TimeFrame(5, TimeFrameUnit.Minute), - "15min": TimeFrame(15, TimeFrameUnit.Minute), - "30min": TimeFrame(30, TimeFrameUnit.Minute), - "1h": TimeFrame(1, TimeFrameUnit.Hour), - "1d": TimeFrame(1, TimeFrameUnit.Day), - "1w": TimeFrame(1, TimeFrameUnit.Week), + "1min": "1Min", + "5min": "5Min", + "15min": "15Min", + "30min": "30Min", + "1h": "1Hour", + "1d": "1Day", + "1w": "1Week", } -_ASSET_CLASSES = {"stocks", "crypto", "options"} +_ASSET_CLASS_MAP = { + "stocks": "us_equity", + "crypto": "crypto", + "options": "us_option", +} -def _tf(interval: str) -> TimeFrame: +def _tf(interval: str) -> str: if interval in _TF_MAP: return _TF_MAP[interval] raise ValueError(f"unsupported timeframe: {interval}") -def _asset_class_enum(ac: str) -> AssetClass: +def _asset_class_param(ac: str) -> str: ac = ac.lower() - if ac == "stocks": - return AssetClass.US_EQUITY - if ac == "crypto": - return AssetClass.CRYPTO - if ac == "options": - return AssetClass.US_OPTION + if ac in _ASSET_CLASS_MAP: + return _ASSET_CLASS_MAP[ac] raise ValueError(f"invalid asset_class: {ac}") -def _serialize(obj: Any) -> Any: - """Recursively convert pydantic/datetime objects → json-safe.""" - if obj is None or isinstance(obj, str | int | float | bool): - return obj - if isinstance(obj, _dt.datetime | _dt.date): - return obj.isoformat() - if isinstance(obj, dict): - return {k: _serialize(v) for k, v in obj.items()} - if isinstance(obj, list | tuple): - return [_serialize(v) for v in obj] - if hasattr(obj, "model_dump"): - return _serialize(obj.model_dump()) - if hasattr(obj, "__dict__"): - return _serialize(vars(obj)) - return str(obj) +def _iso(value: _dt.datetime | _dt.date | None) -> str | None: + if value is None: + return None + return value.isoformat() class AlpacaClient: + """Client httpx-based per Alpaca REST API v2. + + Auth via header `APCA-API-KEY-ID` / `APCA-API-SECRET-KEY`. + """ + def __init__( self, api_key: str, secret_key: str, paper: bool = True, base_url: str | None = None, - trading: Any | None = None, - stock_data: Any | None = None, - crypto_data: Any | None = None, - option_data: Any | None = None, + http: httpx.AsyncClient | None = None, ) -> None: self.api_key = api_key self.secret_key = secret_key self.paper = paper + # `base_url` mantenuto come attributo pubblico (test/build_client lo + # leggono). Override del solo endpoint trading; data endpoints sono + # sempre `data.alpaca.markets` (Alpaca non offre paper data feed). self.base_url = base_url - # alpaca-py TradingClient accetta `url_override` per override URL trading. - # Data clients (Stock/Crypto/Option) non supportano url_override sul costruttore; - # usano endpoint dati separati (data.alpaca.markets) — `base_url` è ignorato per essi. - if trading is None: - trading_kwargs: dict[str, Any] = { - "api_key": api_key, "secret_key": secret_key, "paper": paper, - } - if base_url: - trading_kwargs["url_override"] = base_url - trading = TradingClient(**trading_kwargs) - self._trading = trading - self._stock = stock_data or StockHistoricalDataClient( - api_key=api_key, secret_key=secret_key - ) - self._crypto = crypto_data or CryptoHistoricalDataClient( - api_key=api_key, secret_key=secret_key - ) - self._option = option_data or OptionHistoricalDataClient( - api_key=api_key, secret_key=secret_key - ) + if base_url: + self._trading_base = base_url + else: + self._trading_base = _TRADING_PAPER if paper else _TRADING_LIVE + self._data_base = _DATA + # Single long-lived AsyncClient → reuse connection pool. + self._http = http or async_client(timeout=30.0) - async def _run(self, fn, /, *args, **kwargs): - return await asyncio.to_thread(fn, *args, **kwargs) + async def aclose(self) -> None: + """Chiudi connessioni HTTP. Idempotente.""" + if not self._http.is_closed: + await self._http.aclose() + + # ── Helpers ────────────────────────────────────────────────── + + @property + def _headers(self) -> dict[str, str]: + return { + "APCA-API-KEY-ID": self.api_key, + "APCA-API-SECRET-KEY": self.secret_key, + "Accept": "application/json", + } + + async def _request( + self, + method: str, + base: str, + path: str, + *, + params: dict[str, Any] | None = None, + json_body: dict[str, Any] | None = None, + ) -> Any: + """Esegue una richiesta HTTP autenticata e ritorna il JSON parsato. + + Per response body vuoto (es. DELETE 204) ritorna `{}`. + Solleva `httpx.HTTPStatusError` su 4xx/5xx tramite raise_for_status. + """ + url = f"{base}{path}" + # httpx scarta i query params con valore None automaticamente solo se + # passati come list of tuples; con dict dobbiamo filtrare a monte. + clean_params: dict[str, Any] | None = None + if params is not None: + clean_params = {k: v for k, v in params.items() if v is not None} + if not clean_params: + clean_params = None + resp = await self._http.request( + method, + url, + params=clean_params, + json=json_body, + headers=self._headers, + ) + resp.raise_for_status() + if not resp.content: + return {} + return resp.json() # ── Account / positions ────────────────────────────────────── async def get_account(self) -> dict: - acc = await self._run(self._trading.get_account) - return _serialize(acc) # type: ignore[no-any-return] + data = await self._request("GET", self._trading_base, "/v2/account") + return dict(data) if data else {} async def get_positions(self) -> list[dict]: - pos = await self._run(self._trading.get_all_positions) - return [_serialize(p) for p in pos] + data = await self._request("GET", self._trading_base, "/v2/positions") + return list(data) if data else [] async def get_activities(self, limit: int = 50) -> list[dict]: - acts = await self._run(self._trading.get_account_activities) # type: ignore[union-attr] - data = [_serialize(a) for a in acts] - return data[:limit] + data = await self._request( + "GET", + self._trading_base, + "/v2/account/activities", + params={"page_size": limit}, + ) + items = list(data) if data else [] + return items[:limit] # ── Assets ────────────────────────────────────────────────── async def get_assets( self, asset_class: str = "stocks", status: str = "active" ) -> list[dict]: - req = GetAssetsRequest( - asset_class=_asset_class_enum(asset_class), - status=status, # type: ignore[arg-type] + data = await self._request( + "GET", + self._trading_base, + "/v2/assets", + params={ + "status": status, + "asset_class": _asset_class_param(asset_class), + }, ) - assets = await self._run(self._trading.get_all_assets, req) - return [_serialize(a) for a in assets[:500]] + items = list(data) if data else [] + return items[:500] # ── Market data ───────────────────────────────────────────── async def get_ticker(self, symbol: str, asset_class: str = "stocks") -> dict: ac = asset_class.lower() if ac == "stocks": - req = StockLatestTradeRequest(symbol_or_symbols=symbol) - data = await self._run(self._stock.get_stock_latest_trade, req) - trade = data.get(symbol) - q_req = StockLatestQuoteRequest(symbol_or_symbols=symbol) - qdata = await self._run(self._stock.get_stock_latest_quote, q_req) - quote = qdata.get(symbol) + trade_resp = await self._request( + "GET", + self._data_base, + f"/v2/stocks/{symbol}/trades/latest", + ) + quote_resp = await self._request( + "GET", + self._data_base, + f"/v2/stocks/{symbol}/quotes/latest", + ) + trade = (trade_resp or {}).get("trade") or {} + quote = (quote_resp or {}).get("quote") or {} return { "symbol": symbol, "asset_class": "stocks", - "last_price": getattr(trade, "price", None), - "bid": getattr(quote, "bid_price", None), - "ask": getattr(quote, "ask_price", None), - "bid_size": getattr(quote, "bid_size", None), - "ask_size": getattr(quote, "ask_size", None), - "timestamp": _serialize(getattr(trade, "timestamp", None)), + "last_price": trade.get("p"), + "bid": quote.get("bp"), + "ask": quote.get("ap"), + "bid_size": quote.get("bs"), + "ask_size": quote.get("as"), + "timestamp": trade.get("t"), } if ac == "crypto": - req = CryptoLatestTradeRequest(symbol_or_symbols=symbol) # type: ignore[assignment] - data = await self._run(self._crypto.get_crypto_latest_trade, req) - trade = data.get(symbol) - q_req = CryptoLatestQuoteRequest(symbol_or_symbols=symbol) # type: ignore[assignment] - qdata = await self._run(self._crypto.get_crypto_latest_quote, q_req) - quote = qdata.get(symbol) + trade_resp = await self._request( + "GET", + self._data_base, + "/v1beta3/crypto/us/latest/trades", + params={"symbols": symbol}, + ) + quote_resp = await self._request( + "GET", + self._data_base, + "/v1beta3/crypto/us/latest/quotes", + params={"symbols": symbol}, + ) + trade = ((trade_resp or {}).get("trades") or {}).get(symbol) or {} + quote = ((quote_resp or {}).get("quotes") or {}).get(symbol) or {} return { "symbol": symbol, "asset_class": "crypto", - "last_price": getattr(trade, "price", None), - "bid": getattr(quote, "bid_price", None), - "ask": getattr(quote, "ask_price", None), - "timestamp": _serialize(getattr(trade, "timestamp", None)), + "last_price": trade.get("p"), + "bid": quote.get("bp"), + "ask": quote.get("ap"), + "timestamp": trade.get("t"), } if ac == "options": - req = OptionLatestQuoteRequest(symbol_or_symbols=symbol) # type: ignore[assignment] - data = await self._run(self._option.get_option_latest_quote, req) - quote = data.get(symbol) + quote_resp = await self._request( + "GET", + self._data_base, + f"/v1beta1/options/{symbol}/quotes/latest", + ) + quote = (quote_resp or {}).get("quote") or {} return { "symbol": symbol, "asset_class": "options", - "bid": getattr(quote, "bid_price", None), - "ask": getattr(quote, "ask_price", None), - "timestamp": _serialize(getattr(quote, "timestamp", None)), + "bid": quote.get("bp"), + "ask": quote.get("ap"), + "timestamp": quote.get("t"), } raise ValueError(f"invalid asset_class: {asset_class}") @@ -212,73 +256,117 @@ class AlpacaClient: limit: int = 1000, ) -> dict: tf = _tf(interval) - start_dt = _dt.datetime.fromisoformat(start) if start else ( - _dt.datetime.now(_dt.UTC) - _dt.timedelta(days=30) + start_dt = ( + _dt.datetime.fromisoformat(start) + if start + else (_dt.datetime.now(_dt.UTC) - _dt.timedelta(days=30)) ) end_dt = _dt.datetime.fromisoformat(end) if end else _dt.datetime.now(_dt.UTC) ac = asset_class.lower() + + params: dict[str, Any] = { + "symbols": symbol, + "timeframe": tf, + "start": _iso(start_dt), + "end": _iso(end_dt), + "limit": limit, + } + if ac == "stocks": - req = StockBarsRequest( - symbol_or_symbols=symbol, timeframe=tf, - start=start_dt, end=end_dt, limit=limit, + # IEX feed di default — coerente con default alpaca-py free tier. + params["feed"] = "iex" + data = await self._request( + "GET", self._data_base, "/v2/stocks/bars", params=params ) - data = await self._run(self._stock.get_stock_bars, req) elif ac == "crypto": - req = CryptoBarsRequest( # type: ignore[assignment] - symbol_or_symbols=symbol, timeframe=tf, - start=start_dt, end=end_dt, limit=limit, + data = await self._request( + "GET", + self._data_base, + "/v1beta3/crypto/us/bars", + params=params, ) - data = await self._run(self._crypto.get_crypto_bars, req) elif ac == "options": - req = OptionBarsRequest( # type: ignore[assignment] - symbol_or_symbols=symbol, timeframe=tf, - start=start_dt, end=end_dt, limit=limit, + data = await self._request( + "GET", + self._data_base, + "/v1beta1/options/bars", + params=params, ) - data = await self._run(self._option.get_option_bars, req) else: raise ValueError(f"invalid asset_class: {asset_class}") - bars_dict = getattr(data, "data", {}) or {} - rows = bars_dict.get(symbol, []) or [] + + bars_dict = (data or {}).get("bars") or {} + rows = bars_dict.get(symbol) or [] bars = [ { - "timestamp": _serialize(getattr(b, "timestamp", None)), - "open": getattr(b, "open", None), - "high": getattr(b, "high", None), - "low": getattr(b, "low", None), - "close": getattr(b, "close", None), - "volume": getattr(b, "volume", None), + "timestamp": b.get("t"), + "open": b.get("o"), + "high": b.get("h"), + "low": b.get("l"), + "close": b.get("c"), + "volume": b.get("v"), } for b in rows ] - return {"symbol": symbol, "asset_class": ac, "interval": interval, "bars": bars} + return { + "symbol": symbol, + "asset_class": ac, + "interval": interval, + "bars": bars, + } async def get_snapshot(self, symbol: str) -> dict: - req = StockSnapshotRequest(symbol_or_symbols=symbol) - data = await self._run(self._stock.get_stock_snapshot, req) - return _serialize(data.get(symbol)) # type: ignore[no-any-return] + data = await self._request( + "GET", + self._data_base, + "/v2/stocks/snapshots", + params={"symbols": symbol}, + ) + # API ritorna {"AAPL": {snapshot}} o {"snapshots": {...}} — gestiamo + # entrambi i formati; v2/stocks/snapshots ritorna dict top-level + # symbol→snapshot. + if data is None: + return {} + if symbol in data: + return data[symbol] or {} + snaps = data.get("snapshots") or {} + return snaps.get(symbol) or {} async def get_option_chain( self, underlying: str, expiry: str | None = None, ) -> dict: - kwargs: dict[str, Any] = {"underlying_symbol": underlying} + params: dict[str, Any] = {} if expiry: - kwargs["expiration_date"] = _dt.date.fromisoformat(expiry) - req = OptionChainRequest(**kwargs) - data = await self._run(self._option.get_option_chain, req) + # Validazione date (solleva ValueError su input invalido, + # parità con V1 che usava _dt.date.fromisoformat). + _dt.date.fromisoformat(expiry) + params["expiration_date_gte"] = expiry + params["expiration_date_lte"] = expiry + data = await self._request( + "GET", + self._data_base, + f"/v1beta1/options/snapshots/{underlying}", + params=params or None, + ) + contracts = (data or {}).get("snapshots") if data else None return { "underlying": underlying, "expiry": expiry, - "contracts": _serialize(data), + "contracts": contracts if contracts is not None else (data or {}), } # ── Orders ────────────────────────────────────────────────── async def get_open_orders(self, limit: int = 50) -> list[dict]: - req = GetOrdersRequest(status=QueryOrderStatus.OPEN, limit=limit) - orders = await self._run(self._trading.get_orders, filter=req) - return [_serialize(o) for o in orders] + data = await self._request( + "GET", + self._trading_base, + "/v2/orders", + params={"status": "open", "limit": limit}, + ) + return list(data) if data else [] async def place_order( self, @@ -292,32 +380,39 @@ class AlpacaClient: tif: str = "day", asset_class: str = "stocks", ) -> dict: - side_enum = OrderSide.BUY if side.lower() == "buy" else OrderSide.SELL - tif_enum = TimeInForce(tif.lower()) ot = order_type.lower() - common = { + body: dict[str, Any] = { "symbol": symbol, - "side": side_enum, - "time_in_force": tif_enum, + "side": side.lower(), + "type": ot, + "time_in_force": tif.lower(), } if qty is not None: - common["qty"] = qty # type: ignore[assignment] + body["qty"] = str(qty) if notional is not None: - common["notional"] = notional # type: ignore[assignment] + body["notional"] = str(notional) if ot == "market": - req = MarketOrderRequest(**common) + pass elif ot == "limit": if limit_price is None: raise ValueError("limit_price required for limit order") - req = LimitOrderRequest(**common, limit_price=limit_price) # type: ignore[assignment] + body["limit_price"] = str(limit_price) elif ot == "stop": if stop_price is None: raise ValueError("stop_price required for stop order") - req = StopOrderRequest(**common, stop_price=stop_price) # type: ignore[assignment] + body["stop_price"] = str(stop_price) else: raise ValueError(f"unsupported order_type: {order_type}") - order = await self._run(self._trading.submit_order, req) - return _serialize(order) # type: ignore[no-any-return] + # `asset_class` non è un parametro REST; mantenuto in firma per parità + # con V1 (era usato solo da SDK per scegliere il request model). + _ = asset_class + data = await self._request( + "POST", + self._trading_base, + "/v2/orders", + json_body=body, + ) + return dict(data) if data else {} async def amend_order( self, @@ -327,69 +422,85 @@ class AlpacaClient: stop_price: float | None = None, tif: str | None = None, ) -> dict: - kwargs: dict[str, Any] = {} + body: dict[str, Any] = {} if qty is not None: - kwargs["qty"] = qty + body["qty"] = str(qty) if limit_price is not None: - kwargs["limit_price"] = limit_price + body["limit_price"] = str(limit_price) if stop_price is not None: - kwargs["stop_price"] = stop_price + body["stop_price"] = str(stop_price) if tif is not None: - kwargs["time_in_force"] = TimeInForce(tif.lower()) - req = ReplaceOrderRequest(**kwargs) - order = await self._run(self._trading.replace_order_by_id, order_id, req) - return _serialize(order) # type: ignore[no-any-return] + body["time_in_force"] = tif.lower() + data = await self._request( + "PATCH", + self._trading_base, + f"/v2/orders/{order_id}", + json_body=body, + ) + return dict(data) if data else {} async def cancel_order(self, order_id: str) -> dict: - await self._run(self._trading.cancel_order_by_id, order_id) + # DELETE /v2/orders/{id} → 204 No Content su success. + await self._request( + "DELETE", self._trading_base, f"/v2/orders/{order_id}" + ) return {"order_id": order_id, "canceled": True} async def cancel_all_orders(self) -> list[dict]: - resp = await self._run(self._trading.cancel_orders) - return [_serialize(r) for r in resp] + # DELETE /v2/orders → 207 Multi-Status con array di {id, status} + data = await self._request( + "DELETE", self._trading_base, "/v2/orders" + ) + return list(data) if data else [] # ── Position close ────────────────────────────────────────── async def close_position( self, symbol: str, qty: float | None = None, percentage: float | None = None ) -> dict: - req = None - if qty is not None or percentage is not None: - kwargs: dict[str, Any] = {} - if qty is not None: - kwargs["qty"] = str(qty) - if percentage is not None: - kwargs["percentage"] = str(percentage) - req = ClosePositionRequest(**kwargs) - order = await self._run( - self._trading.close_position, symbol, close_options=req + # DELETE /v2/positions/{symbol}?qty=... oppure ?percentage=... + params: dict[str, Any] = {} + if qty is not None: + params["qty"] = str(qty) + if percentage is not None: + params["percentage"] = str(percentage) + data = await self._request( + "DELETE", + self._trading_base, + f"/v2/positions/{symbol}", + params=params or None, ) - return _serialize(order) # type: ignore[no-any-return] + return dict(data) if data else {} async def close_all_positions(self, cancel_orders: bool = True) -> list[dict]: - resp = await self._run( - self._trading.close_all_positions, cancel_orders=cancel_orders + data = await self._request( + "DELETE", + self._trading_base, + "/v2/positions", + params={"cancel_orders": "true" if cancel_orders else "false"}, ) - return [_serialize(r) for r in resp] + return list(data) if data else [] # ── Clock / calendar ──────────────────────────────────────── async def get_clock(self) -> dict: - clock = await self._run(self._trading.get_clock) - return _serialize(clock) # type: ignore[no-any-return] + data = await self._request("GET", self._trading_base, "/v2/clock") + return dict(data) if data else {} async def get_calendar( self, start: str | None = None, end: str | None = None ) -> list[dict]: - from alpaca.trading.requests import GetCalendarRequest - - kwargs: dict[str, Any] = {} + params: dict[str, Any] = {} if start: - kwargs["start"] = _dt.date.fromisoformat(start) + _dt.date.fromisoformat(start) # validazione, parità V1 + params["start"] = start if end: - kwargs["end"] = _dt.date.fromisoformat(end) - req = GetCalendarRequest(**kwargs) if kwargs else None - cal = await self._run( - self._trading.get_calendar, filters=req - ) if req else await self._run(self._trading.get_calendar) - return [_serialize(c) for c in cal] + _dt.date.fromisoformat(end) + params["end"] = end + data = await self._request( + "GET", + self._trading_base, + "/v2/calendar", + params=params or None, + ) + return list(data) if data else [] diff --git a/tests/unit/exchanges/alpaca/conftest.py b/tests/unit/exchanges/alpaca/conftest.py index ac65829..99e343f 100644 --- a/tests/unit/exchanges/alpaca/conftest.py +++ b/tests/unit/exchanges/alpaca/conftest.py @@ -1,39 +1,23 @@ from __future__ import annotations -from unittest.mock import MagicMock - import pytest from cerbero_mcp.exchanges.alpaca.client import AlpacaClient @pytest.fixture -def mock_trading(): - return MagicMock(name="alpaca_TradingClient") +async def client(): + """AlpacaClient paper su httpx mock (gestito da pytest-httpx).""" + c = AlpacaClient(api_key="test_key", secret_key="test_secret", paper=True) + try: + yield c + finally: + await c.aclose() @pytest.fixture -def mock_stock(): - return MagicMock(name="alpaca_StockHistoricalDataClient") - - -@pytest.fixture -def mock_crypto(): - return MagicMock(name="alpaca_CryptoHistoricalDataClient") - - -@pytest.fixture -def mock_option(): - return MagicMock(name="alpaca_OptionHistoricalDataClient") - - -@pytest.fixture -def client(mock_trading, mock_stock, mock_crypto, mock_option): - return AlpacaClient( - api_key="test_key", - secret_key="test_secret", - paper=True, - trading=mock_trading, - stock_data=mock_stock, - crypto_data=mock_crypto, - option_data=mock_option, - ) +async def client_live(): + c = AlpacaClient(api_key="test_key", secret_key="test_secret", paper=False) + try: + yield c + finally: + await c.aclose() diff --git a/tests/unit/exchanges/alpaca/test_client.py b/tests/unit/exchanges/alpaca/test_client.py index 24172ec..f9aad8a 100644 --- a/tests/unit/exchanges/alpaca/test_client.py +++ b/tests/unit/exchanges/alpaca/test_client.py @@ -1,80 +1,417 @@ +"""Test AlpacaClient httpx-based (V2.0.0). + +Mockano gli endpoint REST tramite pytest-httpx. Coprono account/positions, +ordini (place/cancel/limit-error), close position, clock, asset class +invalida → ValueError. +""" from __future__ import annotations -from unittest.mock import MagicMock +import re import pytest +from cerbero_mcp.exchanges.alpaca.client import AlpacaClient +from pytest_httpx import HTTPXMock + +PAPER = "https://paper-api.alpaca.markets" +DATA = "https://data.alpaca.markets" @pytest.mark.asyncio -async def test_init_paper_mode(client, mock_trading): +async def test_init_paper_mode(client: AlpacaClient): assert client.paper is True - assert client._trading is mock_trading + assert client.base_url is None + assert client._trading_base == PAPER @pytest.mark.asyncio -async def test_get_account_calls_trading(client, mock_trading): - mock_trading.get_account.return_value = MagicMock( - model_dump=lambda: {"equity": 100000, "cash": 50000} +async def test_init_live_mode(client_live: AlpacaClient): + assert client_live.paper is False + assert client_live._trading_base == "https://api.alpaca.markets" + + +@pytest.mark.asyncio +async def test_init_base_url_override(): + c = AlpacaClient( + api_key="k", + secret_key="s", + paper=True, + base_url="https://alpaca-custom.example.com", + ) + try: + assert c.base_url == "https://alpaca-custom.example.com" + assert c._trading_base == "https://alpaca-custom.example.com" + # Data endpoint NON viene overridato + assert c._data_base == DATA + finally: + await c.aclose() + + +@pytest.mark.asyncio +async def test_get_account(httpx_mock: HTTPXMock, client: AlpacaClient): + httpx_mock.add_response( + url=f"{PAPER}/v2/account", + json={"id": "abc", "equity": "100000.00", "buying_power": "200000.00"}, ) result = await client.get_account() - mock_trading.get_account.assert_called_once() - assert result["equity"] == 100000 + assert result["id"] == "abc" + assert result["equity"] == "100000.00" @pytest.mark.asyncio -async def test_get_positions_returns_list(client, mock_trading): - pos_mock = MagicMock(model_dump=lambda: {"symbol": "AAPL", "qty": 10}) - mock_trading.get_all_positions.return_value = [pos_mock] +async def test_get_account_sends_auth_headers( + httpx_mock: HTTPXMock, client: AlpacaClient +): + httpx_mock.add_response(url=f"{PAPER}/v2/account", json={"id": "x"}) + await client.get_account() + req = httpx_mock.get_requests()[0] + assert req.headers["APCA-API-KEY-ID"] == "test_key" + assert req.headers["APCA-API-SECRET-KEY"] == "test_secret" + assert req.headers["Accept"] == "application/json" + + +@pytest.mark.asyncio +async def test_get_positions_returns_list( + httpx_mock: HTTPXMock, client: AlpacaClient +): + httpx_mock.add_response( + url=f"{PAPER}/v2/positions", + json=[{"symbol": "AAPL", "qty": "10", "side": "long"}], + ) result = await client.get_positions() assert len(result) == 1 assert result[0]["symbol"] == "AAPL" @pytest.mark.asyncio -async def test_place_market_order_stocks(client, mock_trading): - order_mock = MagicMock(model_dump=lambda: {"id": "o123", "symbol": "AAPL"}) - mock_trading.submit_order.return_value = order_mock +async def test_place_market_order_stocks( + httpx_mock: HTTPXMock, client: AlpacaClient +): + httpx_mock.add_response( + method="POST", + url=f"{PAPER}/v2/orders", + json={"id": "o123", "symbol": "AAPL", "status": "accepted"}, + ) result = await client.place_order( - symbol="AAPL", side="buy", qty=1, order_type="market", asset_class="stocks", + symbol="AAPL", side="buy", qty=1, order_type="market", asset_class="stocks" ) assert result["id"] == "o123" - assert mock_trading.submit_order.called + # body POST corretto + req = httpx_mock.get_requests()[0] + import json as _j + + body = _j.loads(req.content) + assert body["symbol"] == "AAPL" + assert body["side"] == "buy" + assert body["type"] == "market" + assert body["qty"] == "1" + assert body["time_in_force"] == "day" @pytest.mark.asyncio -async def test_place_limit_order_requires_price(client): +async def test_place_limit_order_requires_price(client: AlpacaClient): with pytest.raises(ValueError, match="limit_price"): await client.place_order( - symbol="AAPL", side="buy", qty=1, order_type="limit", + symbol="AAPL", side="buy", qty=1, order_type="limit" ) @pytest.mark.asyncio -async def test_cancel_order(client, mock_trading): - mock_trading.cancel_order_by_id.return_value = None +async def test_place_stop_order_requires_price(client: AlpacaClient): + with pytest.raises(ValueError, match="stop_price"): + await client.place_order( + symbol="AAPL", side="buy", qty=1, order_type="stop" + ) + + +@pytest.mark.asyncio +async def test_place_unsupported_order_type(client: AlpacaClient): + with pytest.raises(ValueError, match="unsupported order_type"): + await client.place_order( + symbol="AAPL", side="buy", qty=1, order_type="trailing_stop" + ) + + +@pytest.mark.asyncio +async def test_cancel_order(httpx_mock: HTTPXMock, client: AlpacaClient): + # 204 No Content su success + httpx_mock.add_response( + method="DELETE", + url=f"{PAPER}/v2/orders/o1", + status_code=204, + ) result = await client.cancel_order("o1") - mock_trading.cancel_order_by_id.assert_called_once_with("o1") assert result == {"order_id": "o1", "canceled": True} @pytest.mark.asyncio -async def test_close_position_no_options(client, mock_trading): - order_mock = MagicMock(model_dump=lambda: {"id": "close-1"}) - mock_trading.close_position.return_value = order_mock +async def test_cancel_all_orders(httpx_mock: HTTPXMock, client: AlpacaClient): + httpx_mock.add_response( + method="DELETE", + url=f"{PAPER}/v2/orders", + json=[ + {"id": "a", "status": 200}, + {"id": "b", "status": 200}, + ], + ) + result = await client.cancel_all_orders() + assert len(result) == 2 + + +@pytest.mark.asyncio +async def test_close_position_no_options( + httpx_mock: HTTPXMock, client: AlpacaClient +): + httpx_mock.add_response( + method="DELETE", + url=f"{PAPER}/v2/positions/AAPL", + json={"id": "close-1", "symbol": "AAPL"}, + ) result = await client.close_position("AAPL") - assert mock_trading.close_position.called assert result["id"] == "close-1" @pytest.mark.asyncio -async def test_get_clock(client, mock_trading): - clock_mock = MagicMock(model_dump=lambda: {"is_open": True, "next_close": "2026-04-21T20:00:00Z"}) - mock_trading.get_clock.return_value = clock_mock +async def test_close_position_with_qty( + httpx_mock: HTTPXMock, client: AlpacaClient +): + httpx_mock.add_response( + method="DELETE", + url=f"{PAPER}/v2/positions/AAPL?qty=5.0", + json={"id": "close-2"}, + ) + result = await client.close_position("AAPL", qty=5.0) + assert result["id"] == "close-2" + + +@pytest.mark.asyncio +async def test_close_all_positions( + httpx_mock: HTTPXMock, client: AlpacaClient +): + httpx_mock.add_response( + method="DELETE", + url=f"{PAPER}/v2/positions?cancel_orders=true", + json=[{"symbol": "AAPL", "status": 200}], + ) + result = await client.close_all_positions(cancel_orders=True) + assert len(result) == 1 + assert result[0]["symbol"] == "AAPL" + + +@pytest.mark.asyncio +async def test_get_clock(httpx_mock: HTTPXMock, client: AlpacaClient): + httpx_mock.add_response( + url=f"{PAPER}/v2/clock", + json={"is_open": True, "next_close": "2026-04-21T20:00:00Z"}, + ) result = await client.get_clock() assert result["is_open"] is True @pytest.mark.asyncio -async def test_invalid_asset_class(client): +async def test_invalid_asset_class(client: AlpacaClient): with pytest.raises(ValueError, match="invalid asset_class"): await client.get_ticker("AAPL", asset_class="forex") + + +@pytest.mark.asyncio +async def test_get_ticker_stocks(httpx_mock: HTTPXMock, client: AlpacaClient): + httpx_mock.add_response( + url=f"{DATA}/v2/stocks/AAPL/trades/latest", + json={ + "symbol": "AAPL", + "trade": {"p": 175.50, "s": 100, "t": "2026-04-18T15:30:00Z"}, + }, + ) + httpx_mock.add_response( + url=f"{DATA}/v2/stocks/AAPL/quotes/latest", + json={ + "symbol": "AAPL", + "quote": { + "bp": 175.40, + "ap": 175.55, + "bs": 50, + "as": 25, + "t": "2026-04-18T15:30:00Z", + }, + }, + ) + result = await client.get_ticker("AAPL", asset_class="stocks") + assert result["asset_class"] == "stocks" + assert result["last_price"] == 175.50 + assert result["bid"] == 175.40 + assert result["ask"] == 175.55 + + +@pytest.mark.asyncio +async def test_get_bars_stocks(httpx_mock: HTTPXMock, client: AlpacaClient): + httpx_mock.add_response( + url=re.compile(rf"^{DATA}/v2/stocks/bars\?.*"), + json={ + "bars": { + "AAPL": [ + { + "t": "2026-04-17T00:00:00Z", + "o": 170.0, + "h": 176.0, + "l": 169.5, + "c": 175.0, + "v": 1000000, + } + ] + } + }, + ) + result = await client.get_bars( + symbol="AAPL", + asset_class="stocks", + interval="1d", + start="2026-04-17T00:00:00", + end="2026-04-18T00:00:00", + limit=10, + ) + assert result["symbol"] == "AAPL" + assert result["interval"] == "1d" + assert len(result["bars"]) == 1 + assert result["bars"][0]["close"] == 175.0 + + +@pytest.mark.asyncio +async def test_get_bars_unsupported_timeframe(client: AlpacaClient): + with pytest.raises(ValueError, match="unsupported timeframe"): + await client.get_bars( + symbol="AAPL", + asset_class="stocks", + interval="3min", + ) + + +@pytest.mark.asyncio +async def test_get_bars_invalid_asset_class(client: AlpacaClient): + with pytest.raises(ValueError, match="invalid asset_class"): + await client.get_bars(symbol="AAPL", asset_class="forex") + + +@pytest.mark.asyncio +async def test_get_assets(httpx_mock: HTTPXMock, client: AlpacaClient): + httpx_mock.add_response( + url=re.compile(rf"^{PAPER}/v2/assets\?.*"), + json=[ + {"symbol": "AAPL", "tradable": True, "class": "us_equity"}, + {"symbol": "GOOG", "tradable": True, "class": "us_equity"}, + ], + ) + result = await client.get_assets(asset_class="stocks", status="active") + assert len(result) == 2 + assert result[0]["symbol"] == "AAPL" + + +@pytest.mark.asyncio +async def test_get_assets_invalid_class(client: AlpacaClient): + with pytest.raises(ValueError, match="invalid asset_class"): + await client.get_assets(asset_class="forex") + + +@pytest.mark.asyncio +async def test_get_open_orders(httpx_mock: HTTPXMock, client: AlpacaClient): + httpx_mock.add_response( + url=re.compile(rf"^{PAPER}/v2/orders\?.*"), + json=[{"id": "o1", "status": "open", "symbol": "AAPL"}], + ) + result = await client.get_open_orders(limit=10) + assert len(result) == 1 + assert result[0]["id"] == "o1" + + +@pytest.mark.asyncio +async def test_amend_order(httpx_mock: HTTPXMock, client: AlpacaClient): + httpx_mock.add_response( + method="PATCH", + url=f"{PAPER}/v2/orders/o1", + json={"id": "o1", "qty": "5", "limit_price": "180.0"}, + ) + result = await client.amend_order( + "o1", qty=5, limit_price=180.0, tif="gtc" + ) + assert result["id"] == "o1" + req = httpx_mock.get_requests()[0] + import json as _j + + body = _j.loads(req.content) + assert body["qty"] == "5" + assert body["limit_price"] == "180.0" + assert body["time_in_force"] == "gtc" + + +@pytest.mark.asyncio +async def test_get_calendar(httpx_mock: HTTPXMock, client: AlpacaClient): + httpx_mock.add_response( + url=re.compile(rf"^{PAPER}/v2/calendar.*"), + json=[{"date": "2026-04-20", "open": "09:30", "close": "16:00"}], + ) + result = await client.get_calendar(start="2026-04-20", end="2026-04-20") + assert len(result) == 1 + assert result[0]["date"] == "2026-04-20" + + +@pytest.mark.asyncio +async def test_get_calendar_no_filters( + httpx_mock: HTTPXMock, client: AlpacaClient +): + httpx_mock.add_response( + url=f"{PAPER}/v2/calendar", + json=[{"date": "2026-04-20"}], + ) + result = await client.get_calendar() + assert len(result) == 1 + + +@pytest.mark.asyncio +async def test_get_snapshot(httpx_mock: HTTPXMock, client: AlpacaClient): + httpx_mock.add_response( + url=re.compile(rf"^{DATA}/v2/stocks/snapshots\?.*"), + json={ + "AAPL": { + "latestTrade": {"p": 175.0}, + "latestQuote": {"bp": 174.9, "ap": 175.1}, + } + }, + ) + result = await client.get_snapshot("AAPL") + assert result["latestTrade"]["p"] == 175.0 + + +@pytest.mark.asyncio +async def test_get_option_chain(httpx_mock: HTTPXMock, client: AlpacaClient): + httpx_mock.add_response( + url=re.compile(rf"^{DATA}/v1beta1/options/snapshots/AAPL.*"), + json={ + "snapshots": { + "AAPL250620C00200000": { + "latestQuote": {"bp": 1.20, "ap": 1.30} + } + } + }, + ) + result = await client.get_option_chain("AAPL", expiry="2026-06-20") + assert result["underlying"] == "AAPL" + assert result["expiry"] == "2026-06-20" + assert "AAPL250620C00200000" in result["contracts"] + + +@pytest.mark.asyncio +async def test_get_activities(httpx_mock: HTTPXMock, client: AlpacaClient): + httpx_mock.add_response( + url=re.compile(rf"^{PAPER}/v2/account/activities.*"), + json=[ + {"id": "1", "activity_type": "FILL"}, + {"id": "2", "activity_type": "TRANS"}, + ], + ) + result = await client.get_activities(limit=10) + assert len(result) == 2 + + +@pytest.mark.asyncio +async def test_aclose_idempotent(client: AlpacaClient): + await client.aclose() + await client.aclose() # nessun raise