feat: import 6 MCP services + common workspace

This commit is contained in:
AdrianoDev
2026-04-27 17:34:14 +02:00
parent 9676f22a8e
commit 6fc3d1d94f
67 changed files with 10693 additions and 0 deletions
+29
View File
@@ -0,0 +1,29 @@
[project]
name = "mcp-alpaca"
version = "0.1.0"
requires-python = ">=3.11"
dependencies = [
"option-mcp-common",
"fastapi>=0.115",
"uvicorn[standard]>=0.30",
"httpx>=0.27",
"pydantic>=2.6",
"alpaca-py>=0.32",
"pytz>=2024.1",
]
[project.optional-dependencies]
dev = ["pytest>=8", "pytest-asyncio>=0.23"]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/mcp_alpaca"]
[tool.uv.sources]
option-mcp-common = { workspace = true }
[project.scripts]
mcp-alpaca = "mcp_alpaca.__main__:main"
@@ -0,0 +1,45 @@
from __future__ import annotations
import json
import os
import uvicorn
from option_mcp_common.auth import load_token_store_from_files
from option_mcp_common.logging import configure_root_logging
from mcp_alpaca.client import AlpacaClient
from mcp_alpaca.server import create_app
configure_root_logging()
def main():
creds_file = os.environ["ALPACA_CREDENTIALS_FILE"]
with open(creds_file) as f:
creds = json.load(f)
paper_env = os.environ.get("ALPACA_PAPER", "true").lower()
paper = paper_env not in ("0", "false", "no")
client = AlpacaClient(
api_key=creds["api_key_id"],
secret_key=creds["secret_key"],
paper=paper,
)
token_store = load_token_store_from_files(
core_token_file=os.environ.get("CORE_TOKEN_FILE"),
observer_token_file=os.environ.get("OBSERVER_TOKEN_FILE"),
)
app = create_app(client=client, token_store=token_store)
uvicorn.run(
app,
log_config=None,
host=os.environ.get("HOST", "0.0.0.0"),
port=int(os.environ.get("PORT", "9020")),
)
if __name__ == "__main__":
main()
@@ -0,0 +1,388 @@
from __future__ import annotations
import asyncio
import datetime as _dt
from typing import Any
from alpaca.data.historical import (
CryptoHistoricalDataClient,
OptionHistoricalDataClient,
StockHistoricalDataClient,
)
from alpaca.data.requests import (
CryptoBarsRequest,
CryptoLatestQuoteRequest,
CryptoLatestTradeRequest,
OptionBarsRequest,
OptionChainRequest,
OptionLatestQuoteRequest,
StockBarsRequest,
StockLatestQuoteRequest,
StockLatestTradeRequest,
StockSnapshotRequest,
)
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
from alpaca.trading.client import TradingClient
from alpaca.trading.enums import (
AssetClass,
OrderSide,
OrderStatus,
OrderType,
QueryOrderStatus,
TimeInForce,
)
from alpaca.trading.requests import (
ClosePositionRequest,
GetAssetsRequest,
GetOrdersRequest,
LimitOrderRequest,
MarketOrderRequest,
ReplaceOrderRequest,
StopOrderRequest,
)
_TF_MAP = {
"1min": TimeFrame(1, TimeFrameUnit.Minute),
"5min": TimeFrame(5, TimeFrameUnit.Minute),
"15min": TimeFrame(15, TimeFrameUnit.Minute),
"30min": TimeFrame(30, TimeFrameUnit.Minute),
"1h": TimeFrame(1, TimeFrameUnit.Hour),
"1d": TimeFrame(1, TimeFrameUnit.Day),
"1w": TimeFrame(1, TimeFrameUnit.Week),
}
_ASSET_CLASSES = {"stocks", "crypto", "options"}
def _tf(interval: str) -> TimeFrame:
if interval in _TF_MAP:
return _TF_MAP[interval]
raise ValueError(f"unsupported timeframe: {interval}")
def _asset_class_enum(ac: str) -> AssetClass:
ac = ac.lower()
if ac == "stocks":
return AssetClass.US_EQUITY
if ac == "crypto":
return AssetClass.CRYPTO
if ac == "options":
return AssetClass.US_OPTION
raise ValueError(f"invalid asset_class: {ac}")
def _serialize(obj: Any) -> Any:
"""Recursively convert pydantic/datetime objects → json-safe."""
if obj is None or isinstance(obj, (str, int, float, bool)):
return obj
if isinstance(obj, (_dt.datetime, _dt.date)):
return obj.isoformat()
if isinstance(obj, dict):
return {k: _serialize(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [_serialize(v) for v in obj]
if hasattr(obj, "model_dump"):
return _serialize(obj.model_dump())
if hasattr(obj, "__dict__"):
return _serialize(vars(obj))
return str(obj)
class AlpacaClient:
def __init__(
self,
api_key: str,
secret_key: str,
paper: bool = True,
trading: Any | None = None,
stock_data: Any | None = None,
crypto_data: Any | None = None,
option_data: Any | None = None,
) -> None:
self.api_key = api_key
self.secret_key = secret_key
self.paper = paper
self._trading = trading or TradingClient(
api_key=api_key, secret_key=secret_key, paper=paper
)
self._stock = stock_data or StockHistoricalDataClient(
api_key=api_key, secret_key=secret_key
)
self._crypto = crypto_data or CryptoHistoricalDataClient(
api_key=api_key, secret_key=secret_key
)
self._option = option_data or OptionHistoricalDataClient(
api_key=api_key, secret_key=secret_key
)
async def _run(self, fn, /, *args, **kwargs):
return await asyncio.to_thread(fn, *args, **kwargs)
# ── Account / positions ──────────────────────────────────────
async def get_account(self) -> dict:
acc = await self._run(self._trading.get_account)
return _serialize(acc)
async def get_positions(self) -> list[dict]:
pos = await self._run(self._trading.get_all_positions)
return [_serialize(p) for p in pos]
async def get_activities(self, limit: int = 50) -> list[dict]:
acts = await self._run(self._trading.get_account_activities)
data = [_serialize(a) for a in acts]
return data[:limit]
# ── Assets ──────────────────────────────────────────────────
async def get_assets(
self, asset_class: str = "stocks", status: str = "active"
) -> list[dict]:
req = GetAssetsRequest(
asset_class=_asset_class_enum(asset_class),
status=status,
)
assets = await self._run(self._trading.get_all_assets, req)
return [_serialize(a) for a in assets[:500]]
# ── Market data ─────────────────────────────────────────────
async def get_ticker(self, symbol: str, asset_class: str = "stocks") -> dict:
ac = asset_class.lower()
if ac == "stocks":
req = StockLatestTradeRequest(symbol_or_symbols=symbol)
data = await self._run(self._stock.get_stock_latest_trade, req)
trade = data.get(symbol)
q_req = StockLatestQuoteRequest(symbol_or_symbols=symbol)
qdata = await self._run(self._stock.get_stock_latest_quote, q_req)
quote = qdata.get(symbol)
return {
"symbol": symbol,
"asset_class": "stocks",
"last_price": getattr(trade, "price", None),
"bid": getattr(quote, "bid_price", None),
"ask": getattr(quote, "ask_price", None),
"bid_size": getattr(quote, "bid_size", None),
"ask_size": getattr(quote, "ask_size", None),
"timestamp": _serialize(getattr(trade, "timestamp", None)),
}
if ac == "crypto":
req = CryptoLatestTradeRequest(symbol_or_symbols=symbol)
data = await self._run(self._crypto.get_crypto_latest_trade, req)
trade = data.get(symbol)
q_req = CryptoLatestQuoteRequest(symbol_or_symbols=symbol)
qdata = await self._run(self._crypto.get_crypto_latest_quote, q_req)
quote = qdata.get(symbol)
return {
"symbol": symbol,
"asset_class": "crypto",
"last_price": getattr(trade, "price", None),
"bid": getattr(quote, "bid_price", None),
"ask": getattr(quote, "ask_price", None),
"timestamp": _serialize(getattr(trade, "timestamp", None)),
}
if ac == "options":
req = OptionLatestQuoteRequest(symbol_or_symbols=symbol)
data = await self._run(self._option.get_option_latest_quote, req)
quote = data.get(symbol)
return {
"symbol": symbol,
"asset_class": "options",
"bid": getattr(quote, "bid_price", None),
"ask": getattr(quote, "ask_price", None),
"timestamp": _serialize(getattr(quote, "timestamp", None)),
}
raise ValueError(f"invalid asset_class: {asset_class}")
async def get_bars(
self,
symbol: str,
asset_class: str = "stocks",
interval: str = "1d",
start: str | None = None,
end: str | None = None,
limit: int = 1000,
) -> dict:
tf = _tf(interval)
start_dt = _dt.datetime.fromisoformat(start) if start else (
_dt.datetime.now(_dt.UTC) - _dt.timedelta(days=30)
)
end_dt = _dt.datetime.fromisoformat(end) if end else _dt.datetime.now(_dt.UTC)
ac = asset_class.lower()
if ac == "stocks":
req = StockBarsRequest(
symbol_or_symbols=symbol, timeframe=tf,
start=start_dt, end=end_dt, limit=limit,
)
data = await self._run(self._stock.get_stock_bars, req)
elif ac == "crypto":
req = CryptoBarsRequest(
symbol_or_symbols=symbol, timeframe=tf,
start=start_dt, end=end_dt, limit=limit,
)
data = await self._run(self._crypto.get_crypto_bars, req)
elif ac == "options":
req = OptionBarsRequest(
symbol_or_symbols=symbol, timeframe=tf,
start=start_dt, end=end_dt, limit=limit,
)
data = await self._run(self._option.get_option_bars, req)
else:
raise ValueError(f"invalid asset_class: {asset_class}")
bars_dict = getattr(data, "data", {}) or {}
rows = bars_dict.get(symbol, []) or []
bars = [
{
"timestamp": _serialize(getattr(b, "timestamp", None)),
"open": getattr(b, "open", None),
"high": getattr(b, "high", None),
"low": getattr(b, "low", None),
"close": getattr(b, "close", None),
"volume": getattr(b, "volume", None),
}
for b in rows
]
return {"symbol": symbol, "asset_class": ac, "interval": interval, "bars": bars}
async def get_snapshot(self, symbol: str) -> dict:
req = StockSnapshotRequest(symbol_or_symbols=symbol)
data = await self._run(self._stock.get_stock_snapshot, req)
return _serialize(data.get(symbol))
async def get_option_chain(
self,
underlying: str,
expiry: str | None = None,
) -> dict:
kwargs: dict[str, Any] = {"underlying_symbol": underlying}
if expiry:
kwargs["expiration_date"] = _dt.date.fromisoformat(expiry)
req = OptionChainRequest(**kwargs)
data = await self._run(self._option.get_option_chain, req)
return {
"underlying": underlying,
"expiry": expiry,
"contracts": _serialize(data),
}
# ── Orders ──────────────────────────────────────────────────
async def get_open_orders(self, limit: int = 50) -> list[dict]:
req = GetOrdersRequest(status=QueryOrderStatus.OPEN, limit=limit)
orders = await self._run(self._trading.get_orders, filter=req)
return [_serialize(o) for o in orders]
async def place_order(
self,
symbol: str,
side: str,
qty: float | None = None,
notional: float | None = None,
order_type: str = "market",
limit_price: float | None = None,
stop_price: float | None = None,
tif: str = "day",
asset_class: str = "stocks",
) -> dict:
side_enum = OrderSide.BUY if side.lower() == "buy" else OrderSide.SELL
tif_enum = TimeInForce(tif.lower())
ot = order_type.lower()
common = {
"symbol": symbol,
"side": side_enum,
"time_in_force": tif_enum,
}
if qty is not None:
common["qty"] = qty
if notional is not None:
common["notional"] = notional
if ot == "market":
req = MarketOrderRequest(**common)
elif ot == "limit":
if limit_price is None:
raise ValueError("limit_price required for limit order")
req = LimitOrderRequest(**common, limit_price=limit_price)
elif ot == "stop":
if stop_price is None:
raise ValueError("stop_price required for stop order")
req = StopOrderRequest(**common, stop_price=stop_price)
else:
raise ValueError(f"unsupported order_type: {order_type}")
order = await self._run(self._trading.submit_order, req)
return _serialize(order)
async def amend_order(
self,
order_id: str,
qty: float | None = None,
limit_price: float | None = None,
stop_price: float | None = None,
tif: str | None = None,
) -> dict:
kwargs: dict[str, Any] = {}
if qty is not None:
kwargs["qty"] = qty
if limit_price is not None:
kwargs["limit_price"] = limit_price
if stop_price is not None:
kwargs["stop_price"] = stop_price
if tif is not None:
kwargs["time_in_force"] = TimeInForce(tif.lower())
req = ReplaceOrderRequest(**kwargs)
order = await self._run(self._trading.replace_order_by_id, order_id, req)
return _serialize(order)
async def cancel_order(self, order_id: str) -> dict:
await self._run(self._trading.cancel_order_by_id, order_id)
return {"order_id": order_id, "canceled": True}
async def cancel_all_orders(self) -> list[dict]:
resp = await self._run(self._trading.cancel_orders)
return [_serialize(r) for r in resp]
# ── Position close ──────────────────────────────────────────
async def close_position(
self, symbol: str, qty: float | None = None, percentage: float | None = None
) -> dict:
req = None
if qty is not None or percentage is not None:
kwargs: dict[str, Any] = {}
if qty is not None:
kwargs["qty"] = str(qty)
if percentage is not None:
kwargs["percentage"] = str(percentage)
req = ClosePositionRequest(**kwargs)
order = await self._run(
self._trading.close_position, symbol, close_options=req
)
return _serialize(order)
async def close_all_positions(self, cancel_orders: bool = True) -> list[dict]:
resp = await self._run(
self._trading.close_all_positions, cancel_orders=cancel_orders
)
return [_serialize(r) for r in resp]
# ── Clock / calendar ────────────────────────────────────────
async def get_clock(self) -> dict:
clock = await self._run(self._trading.get_clock)
return _serialize(clock)
async def get_calendar(
self, start: str | None = None, end: str | None = None
) -> list[dict]:
from alpaca.trading.requests import GetCalendarRequest
kwargs: dict[str, Any] = {}
if start:
kwargs["start"] = _dt.date.fromisoformat(start)
if end:
kwargs["end"] = _dt.date.fromisoformat(end)
req = GetCalendarRequest(**kwargs) if kwargs else None
cal = await self._run(
self._trading.get_calendar, filters=req
) if req else await self._run(self._trading.get_calendar)
return [_serialize(c) for c in cal]
@@ -0,0 +1,250 @@
from __future__ import annotations
import os
from fastapi import Depends, HTTPException
from option_mcp_common.auth import Principal, TokenStore, require_principal
from option_mcp_common.mcp_bridge import mount_mcp_endpoint
from option_mcp_common.server import build_app
from pydantic import BaseModel
from mcp_alpaca.client import AlpacaClient
# --- Body models: reads ---
class AccountReq(BaseModel):
pass
class PositionsReq(BaseModel):
pass
class ActivitiesReq(BaseModel):
limit: int = 50
class AssetsReq(BaseModel):
asset_class: str = "stocks"
status: str = "active"
class TickerReq(BaseModel):
symbol: str
asset_class: str = "stocks"
class BarsReq(BaseModel):
symbol: str
asset_class: str = "stocks"
interval: str = "1d"
start: str | None = None
end: str | None = None
limit: int = 1000
class SnapshotReq(BaseModel):
symbol: str
class OptionChainReq(BaseModel):
underlying: str
expiry: str | None = None
class OpenOrdersReq(BaseModel):
limit: int = 50
class ClockReq(BaseModel):
pass
class CalendarReq(BaseModel):
start: str | None = None
end: str | None = None
# --- Body models: writes ---
class PlaceOrderReq(BaseModel):
symbol: str
side: str
qty: float | None = None
notional: float | None = None
order_type: str = "market"
limit_price: float | None = None
stop_price: float | None = None
tif: str = "day"
asset_class: str = "stocks"
class AmendOrderReq(BaseModel):
order_id: str
qty: float | None = None
limit_price: float | None = None
stop_price: float | None = None
tif: str | None = None
class CancelOrderReq(BaseModel):
order_id: str
class CancelAllReq(BaseModel):
pass
class ClosePositionReq(BaseModel):
symbol: str
qty: float | None = None
percentage: float | None = None
class CloseAllPositionsReq(BaseModel):
cancel_orders: bool = True
# --- ACL helper ---
def _check(principal: Principal, *, core: bool = False, observer: bool = False) -> None:
allowed: set[str] = set()
if core:
allowed.add("core")
if observer:
allowed.add("observer")
if not (principal.capabilities & allowed):
raise HTTPException(status_code=403, detail="forbidden")
def create_app(*, client: AlpacaClient, token_store: TokenStore):
app = build_app(name="mcp-alpaca", version="0.1.0", token_store=token_store)
# ── Reads ──────────────────────────────────────────────
@app.post("/tools/get_account", tags=["reads"])
async def t_get_account(body: AccountReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True, observer=True)
return await client.get_account()
@app.post("/tools/get_positions", tags=["reads"])
async def t_get_positions(body: PositionsReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True, observer=True)
return {"positions": await client.get_positions()}
@app.post("/tools/get_activities", tags=["reads"])
async def t_get_activities(body: ActivitiesReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True, observer=True)
return {"activities": await client.get_activities(body.limit)}
@app.post("/tools/get_assets", tags=["reads"])
async def t_get_assets(body: AssetsReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True, observer=True)
return {"assets": await client.get_assets(body.asset_class, body.status)}
@app.post("/tools/get_ticker", tags=["reads"])
async def t_get_ticker(body: TickerReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True, observer=True)
return await client.get_ticker(body.symbol, body.asset_class)
@app.post("/tools/get_bars", tags=["reads"])
async def t_get_bars(body: BarsReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True, observer=True)
return await client.get_bars(
body.symbol, body.asset_class, body.interval, body.start, body.end, body.limit,
)
@app.post("/tools/get_snapshot", tags=["reads"])
async def t_get_snapshot(body: SnapshotReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True, observer=True)
return await client.get_snapshot(body.symbol)
@app.post("/tools/get_option_chain", tags=["reads"])
async def t_get_option_chain(body: OptionChainReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True, observer=True)
return await client.get_option_chain(body.underlying, body.expiry)
@app.post("/tools/get_open_orders", tags=["reads"])
async def t_get_open_orders(body: OpenOrdersReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True, observer=True)
return {"orders": await client.get_open_orders(body.limit)}
@app.post("/tools/get_clock", tags=["reads"])
async def t_get_clock(body: ClockReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True, observer=True)
return await client.get_clock()
@app.post("/tools/get_calendar", tags=["reads"])
async def t_get_calendar(body: CalendarReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True, observer=True)
return {"calendar": await client.get_calendar(body.start, body.end)}
# ── Writes ─────────────────────────────────────────────
@app.post("/tools/place_order", tags=["writes"])
async def t_place_order(body: PlaceOrderReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True)
return await client.place_order(
body.symbol, body.side, body.qty, body.notional,
body.order_type, body.limit_price, body.stop_price, body.tif, body.asset_class,
)
@app.post("/tools/amend_order", tags=["writes"])
async def t_amend_order(body: AmendOrderReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True)
return await client.amend_order(
body.order_id, body.qty, body.limit_price, body.stop_price, body.tif,
)
@app.post("/tools/cancel_order", tags=["writes"])
async def t_cancel_order(body: CancelOrderReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True)
return await client.cancel_order(body.order_id)
@app.post("/tools/cancel_all_orders", tags=["writes"])
async def t_cancel_all(body: CancelAllReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True)
return {"canceled": await client.cancel_all_orders()}
@app.post("/tools/close_position", tags=["writes"])
async def t_close(body: ClosePositionReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True)
return await client.close_position(body.symbol, body.qty, body.percentage)
@app.post("/tools/close_all_positions", tags=["writes"])
async def t_close_all(body: CloseAllPositionsReq, principal: Principal = Depends(require_principal)):
_check(principal, core=True)
return {"closed": await client.close_all_positions(body.cancel_orders)}
# ── MCP mount ──────────────────────────────────────────
port = int(os.environ.get("PORT", "9020"))
mount_mcp_endpoint(
app,
name="cerbero-alpaca",
version="0.1.0",
token_store=token_store,
internal_base_url=f"http://localhost:{port}",
tools=[
{"name": "get_account", "description": "Alpaca account summary (equity, cash, buying_power)."},
{"name": "get_positions", "description": "Posizioni aperte (stocks/crypto/options)."},
{"name": "get_activities", "description": "Activity log (fills, dividends, transfers)."},
{"name": "get_assets", "description": "Universo asset per asset_class."},
{"name": "get_ticker", "description": "Last trade + quote per simbolo (stocks/crypto/options)."},
{"name": "get_bars", "description": "OHLCV candles (stocks/crypto/options)."},
{"name": "get_snapshot", "description": "Snapshot completo stock (last trade+quote+bar)."},
{"name": "get_option_chain", "description": "Option chain per underlying."},
{"name": "get_open_orders", "description": "Ordini pending."},
{"name": "get_clock", "description": "Market clock (open/close, next_open)."},
{"name": "get_calendar", "description": "Calendar sessioni trading."},
{"name": "place_order", "description": "Invia ordine (CORE only)."},
{"name": "amend_order", "description": "Replace ordine esistente."},
{"name": "cancel_order", "description": "Cancella ordine."},
{"name": "cancel_all_orders", "description": "Cancella tutti ordini aperti."},
{"name": "close_position", "description": "Chiude posizione (tutta o parziale)."},
{"name": "close_all_positions", "description": "Liquida tutto il portafoglio."},
],
)
return app
+40
View File
@@ -0,0 +1,40 @@
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from mcp_alpaca.client import AlpacaClient
@pytest.fixture
def mock_trading():
return MagicMock(name="alpaca_TradingClient")
@pytest.fixture
def mock_stock():
return MagicMock(name="alpaca_StockHistoricalDataClient")
@pytest.fixture
def mock_crypto():
return MagicMock(name="alpaca_CryptoHistoricalDataClient")
@pytest.fixture
def mock_option():
return MagicMock(name="alpaca_OptionHistoricalDataClient")
@pytest.fixture
def client(mock_trading, mock_stock, mock_crypto, mock_option):
return AlpacaClient(
api_key="test_key",
secret_key="test_secret",
paper=True,
trading=mock_trading,
stock_data=mock_stock,
crypto_data=mock_crypto,
option_data=mock_option,
)
+80
View File
@@ -0,0 +1,80 @@
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
@pytest.mark.asyncio
async def test_init_paper_mode(client, mock_trading):
assert client.paper is True
assert client._trading is mock_trading
@pytest.mark.asyncio
async def test_get_account_calls_trading(client, mock_trading):
mock_trading.get_account.return_value = MagicMock(
model_dump=lambda: {"equity": 100000, "cash": 50000}
)
result = await client.get_account()
mock_trading.get_account.assert_called_once()
assert result["equity"] == 100000
@pytest.mark.asyncio
async def test_get_positions_returns_list(client, mock_trading):
pos_mock = MagicMock(model_dump=lambda: {"symbol": "AAPL", "qty": 10})
mock_trading.get_all_positions.return_value = [pos_mock]
result = await client.get_positions()
assert len(result) == 1
assert result[0]["symbol"] == "AAPL"
@pytest.mark.asyncio
async def test_place_market_order_stocks(client, mock_trading):
order_mock = MagicMock(model_dump=lambda: {"id": "o123", "symbol": "AAPL"})
mock_trading.submit_order.return_value = order_mock
result = await client.place_order(
symbol="AAPL", side="buy", qty=1, order_type="market", asset_class="stocks",
)
assert result["id"] == "o123"
assert mock_trading.submit_order.called
@pytest.mark.asyncio
async def test_place_limit_order_requires_price(client):
with pytest.raises(ValueError, match="limit_price"):
await client.place_order(
symbol="AAPL", side="buy", qty=1, order_type="limit",
)
@pytest.mark.asyncio
async def test_cancel_order(client, mock_trading):
mock_trading.cancel_order_by_id.return_value = None
result = await client.cancel_order("o1")
mock_trading.cancel_order_by_id.assert_called_once_with("o1")
assert result == {"order_id": "o1", "canceled": True}
@pytest.mark.asyncio
async def test_close_position_no_options(client, mock_trading):
order_mock = MagicMock(model_dump=lambda: {"id": "close-1"})
mock_trading.close_position.return_value = order_mock
result = await client.close_position("AAPL")
assert mock_trading.close_position.called
assert result["id"] == "close-1"
@pytest.mark.asyncio
async def test_get_clock(client, mock_trading):
clock_mock = MagicMock(model_dump=lambda: {"is_open": True, "next_close": "2026-04-21T20:00:00Z"})
mock_trading.get_clock.return_value = clock_mock
result = await client.get_clock()
assert result["is_open"] is True
@pytest.mark.asyncio
async def test_invalid_asset_class(client):
with pytest.raises(ValueError, match="invalid asset_class"):
await client.get_ticker("AAPL", asset_class="forex")
@@ -0,0 +1,111 @@
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi.testclient import TestClient
from option_mcp_common.auth import Principal, TokenStore
from mcp_alpaca.server import create_app
@pytest.fixture
def token_store():
return TokenStore(
tokens={
"core-tok": Principal("core", {"core"}),
"obs-tok": Principal("observer", {"observer"}),
}
)
@pytest.fixture
def mock_client():
c = MagicMock()
c.get_account = AsyncMock(return_value={"equity": 100000})
c.get_positions = AsyncMock(return_value=[])
c.get_activities = AsyncMock(return_value=[])
c.get_assets = AsyncMock(return_value=[])
c.get_ticker = AsyncMock(return_value={"symbol": "AAPL"})
c.get_bars = AsyncMock(return_value={"bars": []})
c.get_snapshot = AsyncMock(return_value={})
c.get_option_chain = AsyncMock(return_value={"contracts": []})
c.get_open_orders = AsyncMock(return_value=[])
c.get_clock = AsyncMock(return_value={"is_open": True})
c.get_calendar = AsyncMock(return_value=[])
c.place_order = AsyncMock(return_value={"id": "o1"})
c.amend_order = AsyncMock(return_value={"id": "o1"})
c.cancel_order = AsyncMock(return_value={"canceled": True})
c.cancel_all_orders = AsyncMock(return_value=[])
c.close_position = AsyncMock(return_value={"id": "close1"})
c.close_all_positions = AsyncMock(return_value=[])
return c
@pytest.fixture
def http(mock_client, token_store):
app = create_app(client=mock_client, token_store=token_store)
return TestClient(app)
CORE = {"Authorization": "Bearer core-tok"}
OBS = {"Authorization": "Bearer obs-tok"}
READ_ENDPOINTS = [
("/tools/get_account", {}),
("/tools/get_positions", {}),
("/tools/get_activities", {}),
("/tools/get_assets", {}),
("/tools/get_ticker", {"symbol": "AAPL"}),
("/tools/get_bars", {"symbol": "AAPL"}),
("/tools/get_snapshot", {"symbol": "AAPL"}),
("/tools/get_option_chain", {"underlying": "AAPL"}),
("/tools/get_open_orders", {}),
("/tools/get_clock", {}),
("/tools/get_calendar", {}),
]
WRITE_ENDPOINTS = [
("/tools/place_order", {"symbol": "AAPL", "side": "buy", "qty": 1}),
("/tools/amend_order", {"order_id": "o1", "qty": 2}),
("/tools/cancel_order", {"order_id": "o1"}),
("/tools/cancel_all_orders", {}),
("/tools/close_position", {"symbol": "AAPL"}),
("/tools/close_all_positions", {}),
]
@pytest.mark.parametrize("path,payload", READ_ENDPOINTS)
def test_read_core_ok(http, path, payload):
r = http.post(path, json=payload, headers=CORE)
assert r.status_code == 200, (path, r.text)
@pytest.mark.parametrize("path,payload", READ_ENDPOINTS)
def test_read_observer_ok(http, path, payload):
r = http.post(path, json=payload, headers=OBS)
assert r.status_code == 200, (path, r.text)
@pytest.mark.parametrize("path,payload", READ_ENDPOINTS)
def test_read_no_auth_401(http, path, payload):
r = http.post(path, json=payload)
assert r.status_code == 401, (path, r.text)
@pytest.mark.parametrize("path,payload", WRITE_ENDPOINTS)
def test_write_core_ok(http, path, payload):
r = http.post(path, json=payload, headers=CORE)
assert r.status_code == 200, (path, r.text)
@pytest.mark.parametrize("path,payload", WRITE_ENDPOINTS)
def test_write_observer_403(http, path, payload):
r = http.post(path, json=payload, headers=OBS)
assert r.status_code == 403, (path, r.text)
@pytest.mark.parametrize("path,payload", WRITE_ENDPOINTS)
def test_write_no_auth_401(http, path, payload):
r = http.post(path, json=payload)
assert r.status_code == 401, (path, r.text)