Files
Cerbero-mcp/services/common/src/mcp_common/mcp_bridge.py
T

240 lines
7.6 KiB
Python

"""Bridge MCP → endpoint REST esistenti.
Manually implements MCP JSON-RPC 2.0 on `POST /mcp` (no SSE: each response
is returned directly as a JSON body). Supported methods:
- initialize
- notifications/initialized
- tools/list
- tools/call

Example Claude Code configuration:
{
  "mcpServers": {
    "cerbero-memory": {
      "type": "http",
      "url": "http://localhost:8080/mcp-memory/mcp",
      "headers": {"Authorization": "Bearer <observer-token>"}
    }
  }
}
"""
from __future__ import annotations
from typing import Any
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from mcp_common.auth import TokenStore
# MCP protocol revision advertised as `protocolVersion` in the `initialize` response.
MCP_PROTOCOL_VERSION = "2024-11-05"
def _derive_input_schemas(app: FastAPI, tool_names: list[str]) -> dict[str, dict]:
    """Extract the JSON schema of the Pydantic body of each ``POST /tools/<name>`` route.

    Lazy (PEP 563) annotations are resolved via ``typing.get_type_hints``.
    Only routes that accept POST are considered, because the MCP proxy always
    forwards tool calls with POST. When several matching routes share a path,
    the first one registered wins (Starlette dispatches to the first match).

    Args:
        app: FastAPI application whose routes are inspected.
        tool_names: tool names to look for; other ``/tools/*`` routes are ignored.

    Returns:
        Mapping ``{tool_name: json_schema}``. Routes without a Pydantic body
        model, or whose type hints / schema cannot be resolved, are skipped:
        the caller is expected to use a fallback schema for them.
    """
    import typing

    from pydantic import BaseModel

    prefix = "/tools/"
    wanted = set(tool_names)
    out: dict[str, dict] = {}
    for route in app.routes:
        path = getattr(route, "path", "")
        if not path.startswith(prefix):
            continue
        tool = path[len(prefix):]
        if tool not in wanted or tool in out:
            continue
        # Routes without a ``methods`` attribute (e.g. mounts) are kept as before.
        methods = getattr(route, "methods", None)
        if methods is not None and "POST" not in methods:
            continue
        endpoint = getattr(route, "endpoint", None)
        if endpoint is None:
            continue
        try:
            hints = typing.get_type_hints(endpoint)
        except Exception:
            # Best effort: unresolvable annotations just mean "no schema".
            continue
        for pname, ann in hints.items():
            if pname == "return":
                continue
            try:
                is_model = isinstance(ann, type) and issubclass(ann, BaseModel)
            except TypeError:
                # Generic aliases such as ``list[int]`` pass the isinstance
                # check on Python < 3.11 but make issubclass() raise.
                is_model = False
            if not is_model:
                continue
            try:
                out[tool] = ann.model_json_schema()
            except Exception:
                # Best effort: a model that cannot be rendered falls back.
                pass
            break
    return out
def _make_proxy_handler(
    internal_base_url: str,
    tool_name: str,
    token: str,
    timeout: float = 30.0,
):
    """Build an async callable that forwards one MCP tool call to its REST endpoint.

    The returned coroutine function POSTs its ``args`` (``{}`` when falsy) as
    the JSON body to ``{internal_base_url}/tools/{tool_name}``. It forwards
    ``token`` as a Bearer credential when the token is non-empty.

    Args:
        internal_base_url: base URL of the service; a trailing ``/`` is ignored.
        tool_name: name of the REST tool (last path segment of the URL).
        token: Bearer token of the MCP client; no header is sent if empty.
        timeout: httpx client timeout in seconds (default 30.0).

    Returns:
        ``async handler(args)`` that returns the decoded JSON response body,
        or ``{"raw": <text>}`` when the body is not valid JSON. The handler
        raises ``RuntimeError`` for any HTTP status >= 400. Transport errors
        raised by httpx propagate unchanged.
    """
    url = f"{internal_base_url.rstrip('/')}/tools/{tool_name}"
    headers = {"Authorization": f"Bearer {token}"} if token else {}

    async def handler(args: dict | None) -> Any:
        # A fresh client per call: no connection reuse between tool calls.
        async with httpx.AsyncClient(timeout=timeout) as client:
            r = await client.post(url, headers=headers, json=args or {})
            if r.status_code >= 400:
                raise RuntimeError(
                    f"tool {tool_name} failed: HTTP {r.status_code}: {r.text[:500]}"
                )
            try:
                return r.json()
            except ValueError:
                # Non-JSON body (json.JSONDecodeError is a ValueError subclass).
                return {"raw": r.text}

    return handler
def mount_mcp_endpoint(
    app: FastAPI,
    *,
    name: str,
    version: str,
    token_store: TokenStore,
    internal_base_url: str,
    tools: list[dict],
) -> None:
    """Register an MCP JSON-RPC 2.0 endpoint on ``POST /mcp``.

    Every tool is proxied to ``POST {internal_base_url}/tools/<name>`` with the
    MCP client's Bearer token, so the existing REST ACLs keep applying.

    Handled methods are ``initialize``, ``ping``, ``tools/list`` and
    ``tools/call``. Some messages get no JSON-RPC reply:
    ``notifications/*`` methods, requests without an ``id``, and JSON-RPC
    responses sent by the client. When nothing is left to answer, the HTTP
    reply is ``202 Accepted`` with an empty body.

    Args:
        app: FastAPI instance of the service.
        name: MCP server name reported by ``initialize``.
        version: service version reported by ``initialize``.
        token_store: the same token store used by the REST tools. A token for
            which ``token_store.get`` returns ``None`` is rejected with 403.
        internal_base_url: internal base URL (e.g. "http://localhost:9015").
        tools: list of ``{"name": str, "description": str, "input_schema"?: dict}``.
    """
    from fastapi.responses import Response

    tools_by_name = {t["name"]: t for t in tools}
    # Input schemas are auto-derived from the FastAPI routes (Pydantic body
    # models), so the LLM knows the required parameter names instead of
    # guessing them. An explicit ``input_schema`` on a tool wins. Derivation
    # is deferred to the first ``tools/list`` so that /tools routes registered
    # after this call are still picked up.
    derived_schemas: dict[str, dict] | None = None

    def _tool_defs() -> list[dict]:
        """Build the ``tools/list`` payload (explicit > derived > permissive schema)."""
        nonlocal derived_schemas
        if derived_schemas is None:
            derived_schemas = _derive_input_schemas(app, [t["name"] for t in tools])
        return [
            {
                "name": t["name"],
                "description": t.get("description", t["name"]),
                "inputSchema": (
                    t.get("input_schema")
                    or derived_schemas.get(t["name"])
                    or {"type": "object", "additionalProperties": True}
                ),
            }
            for t in tools
        ]

    def _ok(rpc_id: Any, result: dict) -> dict:
        """Build a JSON-RPC 2.0 success response."""
        return {"jsonrpc": "2.0", "id": rpc_id, "result": result}

    def _err(rpc_id: Any, code: int, message: str) -> dict:
        """Build a JSON-RPC 2.0 error response."""
        return {
            "jsonrpc": "2.0",
            "id": rpc_id,
            "error": {"code": code, "message": message},
        }

    def _text_result(rpc_id: Any, text: str, *, is_error: bool) -> dict:
        """Wrap ``text`` as a single MCP text content item."""
        return _ok(
            rpc_id,
            {"content": [{"type": "text", "text": text}], "isError": is_error},
        )

    async def _call_tool(rpc_id: Any, params: Any, token: str) -> dict:
        """Proxy a ``tools/call`` request to the matching REST endpoint."""
        if not isinstance(params, dict):
            return _err(rpc_id, -32602, "params must be an object")
        tool_name = params.get("name", "")
        args = params.get("arguments") or {}
        # Type check first: an unhashable name (e.g. a list) would make the
        # dict membership test raise TypeError.
        if not isinstance(tool_name, str) or tool_name not in tools_by_name:
            # MCP reports unknown tools as "Invalid params" (-32602).
            return _err(rpc_id, -32602, f"tool non trovato: {tool_name}")
        if not isinstance(args, dict):
            return _err(rpc_id, -32602, "arguments must be an object")
        handler = _make_proxy_handler(internal_base_url, tool_name, token)
        try:
            text = _to_text(await handler(args))
        except Exception as e:
            # Tool failures are reported in the result (isError=True), not as
            # JSON-RPC errors, so the client can show them to the model.
            return _text_result(rpc_id, str(e), is_error=True)
        return _text_result(rpc_id, text, is_error=False)

    async def _dispatch(body: dict, token: str) -> dict:
        """Execute one JSON-RPC request and build its response."""
        rpc_id = body.get("id")
        method = body.get("method")
        params = body.get("params") or {}
        if method == "initialize":
            return _ok(
                rpc_id,
                {
                    "protocolVersion": MCP_PROTOCOL_VERSION,
                    "capabilities": {"tools": {"listChanged": False}},
                    "serverInfo": {"name": name, "version": version},
                },
            )
        if method == "ping":
            return _ok(rpc_id, {})
        if method == "tools/list":
            return _ok(rpc_id, {"tools": _tool_defs()})
        if method == "tools/call":
            return await _call_tool(rpc_id, params, token)
        return _err(rpc_id, -32601, f"metodo non supportato: {method}")

    async def _handle_rpc(body: Any, token: str) -> dict | None:
        """Handle one JSON-RPC message; return ``None`` when no reply is due."""
        if not isinstance(body, dict):
            return _err(None, -32600, "invalid request")
        method = body.get("method")
        if method is None and ("result" in body or "error" in body):
            # A JSON-RPC response sent by the client: nothing to answer.
            return None
        if isinstance(method, str) and method.startswith("notifications/"):
            return None
        response = await _dispatch(body, token)
        # A request without "id" is a notification: JSON-RPC 2.0 forbids
        # replying to it, even though it is processed.
        return response if "id" in body else None

    @app.post("/mcp")
    async def mcp_entry(request: Request):
        # The auth scheme is case-insensitive (RFC 7235); the token must be non-empty.
        scheme, _, credentials = request.headers.get("Authorization", "").partition(" ")
        token = credentials.strip()
        if scheme.lower() != "bearer" or not token:
            return JSONResponse({"error": "missing bearer token"}, status_code=401)
        if token_store.get(token) is None:
            return JSONResponse({"error": "invalid token"}, status_code=403)

        try:
            body = await request.json()
        except ValueError:
            # Malformed JSON / undecodable bytes -> JSON-RPC "Parse error".
            return JSONResponse(_err(None, -32700, "parse error"), status_code=400)

        if isinstance(body, list):
            # JSON-RPC batch: an empty array is itself an invalid request.
            if not body:
                return JSONResponse(
                    _err(None, -32600, "invalid request: empty batch"),
                    status_code=400,
                )
            results = []
            for item in body:
                resp = await _handle_rpc(item, token)
                if resp is not None:
                    results.append(resp)
            if not results:
                # Only notifications/responses: no body (never an empty array).
                return Response(status_code=202)
            return JSONResponse(results)

        resp = await _handle_rpc(body, token)
        if resp is None:
            # Notification or client response: 202 Accepted with no body.
            return Response(status_code=202)
        return JSONResponse(resp)
def _to_text(value: Any) -> str:
    """Render a tool result as text for an MCP ``text`` content item.

    Strings are returned unchanged. Anything else is pretty-printed as JSON
    (2-space indent, non-ASCII characters kept as-is). Values that
    ``json.dumps`` cannot serialize fall back to ``str(value)``.
    """
    import json

    if not isinstance(value, str):
        try:
            value = json.dumps(value, ensure_ascii=False, indent=2)
        except Exception:
            value = str(value)
    return value