feat: import 6 MCP services + common workspace

This commit is contained in:
AdrianoDev
2026-04-27 17:34:14 +02:00
parent 9676f22a8e
commit 6fc3d1d94f
67 changed files with 10693 additions and 0 deletions
+23
View File
@@ -0,0 +1,23 @@
[project]
name = "option-mcp-common"
version = "0.1.0"
requires-python = ">=3.11"
dependencies = [
"fastapi>=0.115",
"uvicorn[standard]>=0.30",
"mcp>=1.0",
"httpx>=0.27",
"pydantic>=2.6",
"pydantic-settings>=2.3",
"python-json-logger>=2.0",
]
[project.optional-dependencies]
dev = ["pytest>=8", "pytest-asyncio>=0.23", "pytest-httpx>=0.30", "ruff>=0.5"]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/option_mcp_common"]
@@ -0,0 +1,19 @@
from option_mcp_common.models import (
Event,
EventPriority,
EventType,
L1State,
L2Entry,
L3Entry,
UserInstruction,
)
# Public API of the package, re-exported from option_mcp_common.models.
__all__ = [
    "L1State",
    "L2Entry",
    "L3Entry",
    "Event",
    "EventPriority",
    "EventType",
    "UserInstruction",
]
@@ -0,0 +1,98 @@
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass, field
from functools import wraps
from fastapi import HTTPException, Request, status
@dataclass
class Principal:
    """An authenticated caller and the capability set granted to it."""

    name: str
    capabilities: set[str] = field(default_factory=set)


@dataclass
class TokenStore:
    """Bearer-token → Principal lookup table."""

    tokens: dict[str, Principal]

    def get(self, token: str) -> Principal | None:
        """Return the Principal for *token*, or None when unknown."""
        try:
            return self.tokens[token]
        except KeyError:
            return None
def require_principal(request: Request) -> Principal:
    """FastAPI dependency: resolve the request's Bearer token to a Principal.

    Raises 401 when the Authorization header is absent or malformed, 403 when
    the token is not present in the app's TokenStore.
    """
    header = request.headers.get("Authorization", "")
    prefix = "Bearer "
    if not header.startswith(prefix):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "missing bearer token")
    store: TokenStore = request.app.state.token_store
    principal = store.get(header[len(prefix):].strip())
    if principal is None:
        raise HTTPException(status.HTTP_403_FORBIDDEN, "invalid token")
    return principal
def acl_requires(*, core: bool = False, observer: bool = False) -> Callable:
    """Decorator: require at least one matching capability.

    Args:
        core: allow principals holding the "core" capability.
        observer: allow principals holding the "observer" capability.

    The wrapped endpoint must receive a Principal either via the ``principal``
    keyword argument or positionally; otherwise HTTP 403 is raised.
    """
    allowed: set[str] = set()
    if core:
        allowed.add("core")
    if observer:
        allowed.add("observer")

    def _resolve_and_check(args, kwargs) -> None:
        # Shared between the sync and async wrappers (previously duplicated):
        # locate the Principal (kwarg first, then positional scan) and reject
        # unless it holds at least one allowed capability.
        principal = kwargs.get("principal")
        if principal is None:
            for a in args:
                if isinstance(a, Principal):
                    principal = a
                    break
        if principal is None or not (principal.capabilities & allowed):
            raise HTTPException(
                status.HTTP_403_FORBIDDEN,
                f"capability required: {allowed}",
            )

    def decorator(func: Callable) -> Callable:
        if _is_coro(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                _resolve_and_check(args, kwargs)
                # func is known to be a coroutine function here; the old
                # per-call _is_coro re-check was dead code.
                return await func(*args, **kwargs)
            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            _resolve_and_check(args, kwargs)
            return func(*args, **kwargs)
        return sync_wrapper

    return decorator
def _is_coro(func: Callable) -> bool:
import asyncio
return asyncio.iscoroutinefunction(func)
def load_token_store_from_files(
    core_token_file: str | None,
    observer_token_file: str | None,
) -> TokenStore:
    """Build a TokenStore from one-token-per-file secrets.

    Either path may be None/empty, in which case that principal is absent.
    """
    mapping: dict[str, Principal] = {}
    sources = (
        (core_token_file, "core"),
        (observer_token_file, "observer"),
    )
    for path, role in sources:
        if not path:
            continue
        with open(path) as fh:
            token = fh.read().strip()
        mapping[token] = Principal(name=role, capabilities={role})
    return TokenStore(tokens=mapping)
@@ -0,0 +1,80 @@
"""CER-P5-010: env validation policy — fail-fast per mandatory, soft per optional.
Usage al boot di ogni mcp `__main__.py`:
from option_mcp_common.env_validation import require_env, optional_env, summarize
creds_file = require_env("CREDENTIALS_FILE", "deribit credentials JSON path")
host = optional_env("HOST", default="0.0.0.0")
summarize(["CREDENTIALS_FILE", "HOST", "PORT"])
"""
from __future__ import annotations
import logging
import os
import sys
logger = logging.getLogger(__name__)


class MissingEnvError(RuntimeError):
    """Raised when a mandatory environment variable is absent or blank."""


def require_env(name: str, description: str = "") -> str:
    """Return the value of mandatory env var *name*, failing fast otherwise.

    Args:
        name: environment variable to read.
        description: optional human-readable hint appended to the error.

    Raises:
        MissingEnvError: when the variable is unset, empty, or whitespace-only.
    """
    value = (os.environ.get(name) or "").strip()
    if value:
        return value
    detail = f" ({description})" if description else ""
    message = f"missing mandatory env var: {name}{detail}"
    logger.error(message)
    raise MissingEnvError(message)
def optional_env(name: str, *, default: str = "") -> str:
    """Return env var *name* or *default*; log at INFO when the default applies."""
    raw = os.environ.get(name)
    value = raw.strip() if raw else ""
    if value:
        return value
    if default:
        logger.info("env %s not set, using default=%r", name, default)
    return default
def summarize(names: list[str]) -> None:
    """INFO-log presence of each relevant env var, masking secret-looking names."""
    sensitive_tokens = ("SECRET", "KEY", "TOKEN", "PASSWORD", "CREDENTIAL", "WALLET")
    for env_name in names:
        value = os.environ.get(env_name)
        if value is None:
            logger.info("env[%s]: <unset>", env_name)
        elif any(tok in env_name.upper() for tok in sensitive_tokens):
            # Never print the secret itself — only its length.
            logger.info("env[%s]: <set, %d chars>", env_name, len(value))
        else:
            logger.info("env[%s]: %s", env_name, value)
def fail_fast_if_missing(names: list[str]) -> None:
    """Abort the process (exit code 2) unless every env var in *names* is set.

    Preferred use: an early call in main() so boot stops on incomplete config.
    """
    missing = [n for n in names if not (os.environ.get(n) or "").strip()]
    if not missing:
        return
    logger.error("boot aborted: missing mandatory env vars: %s", missing)
    print(
        f"FATAL: missing mandatory env vars: {missing}",
        file=sys.stderr,
    )
    sys.exit(2)
@@ -0,0 +1,139 @@
from __future__ import annotations
def sma(values: list[float], period: int) -> float | None:
    """Simple moving average over the last *period* values; None if too few."""
    if len(values) < period:
        return None
    window = values[-period:]
    return sum(window) / period
def rsi(closes: list[float], period: int = 14) -> float | None:
    """Wilder RSI. None when fewer than period+1 closes are available."""
    if len(closes) < period + 1:
        return None
    deltas = [b - a for a, b in zip(closes, closes[1:])]
    gains = [max(d, 0.0) for d in deltas]
    losses = [max(-d, 0.0) for d in deltas]
    avg_gain = sum(gains[:period]) / period
    avg_loss = sum(losses[:period]) / period
    # Wilder smoothing over the remaining deltas.
    for g, l in zip(gains[period:], losses[period:]):
        avg_gain = (avg_gain * (period - 1) + g) / period
        avg_loss = (avg_loss * (period - 1) + l) / period
    if avg_loss == 0:
        # All-gain series: RSI saturates at 100.
        return 100.0
    return 100.0 - 100.0 / (1.0 + avg_gain / avg_loss)
def _ema_series(values: list[float], period: int) -> list[float]:
if len(values) < period:
return []
k = 2.0 / (period + 1)
seed = sum(values[:period]) / period
out = [seed]
for v in values[period:]:
out.append(out[-1] + k * (v - out[-1]))
return out
def macd(
closes: list[float],
fast: int = 12,
slow: int = 26,
signal: int = 9,
) -> dict[str, float | None]:
nothing: dict[str, float | None] = {"macd": None, "signal": None, "hist": None}
if len(closes) < slow + signal:
return nothing
ema_fast = _ema_series(closes, fast)
ema_slow = _ema_series(closes, slow)
offset = slow - fast
aligned_fast = ema_fast[offset:]
macd_line = [f - s for f, s in zip(aligned_fast, ema_slow, strict=False)]
if len(macd_line) < signal:
return nothing
signal_line = _ema_series(macd_line, signal)
if not signal_line:
return nothing
last_macd = macd_line[-1]
last_sig = signal_line[-1]
return {
"macd": last_macd,
"signal": last_sig,
"hist": last_macd - last_sig,
}
def atr(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    period: int = 14,
) -> float | None:
    """Wilder Average True Range; None when fewer than period+1 bars given."""
    n = len(closes)
    if n < period + 1:
        return None
    # True range: widest of today's range and the gaps versus prior close.
    true_ranges = [
        max(
            highs[i] - lows[i],
            abs(highs[i] - closes[i - 1]),
            abs(lows[i] - closes[i - 1]),
        )
        for i in range(1, n)
    ]
    if len(true_ranges) < period:
        return None
    value = sum(true_ranges[:period]) / period
    # Wilder smoothing over the remaining true ranges.
    for tr in true_ranges[period:]:
        value = (value * (period - 1) + tr) / period
    return value
def adx(
    highs: list[float],
    lows: list[float],
    closes: list[float],
    period: int = 14,
) -> dict[str, float | None]:
    """Wilder ADX with +DI / -DI.

    Returns all-None values when fewer than 2*period + 1 closes are given
    (one seed period plus one smoothing period are required).
    """
    nothing: dict[str, float | None] = {"adx": None, "+di": None, "-di": None}
    if len(closes) < 2 * period + 1:
        return nothing
    trs: list[float] = []
    plus_dms: list[float] = []
    minus_dms: list[float] = []
    for i in range(1, len(closes)):
        # True range: widest of today's range and the gaps vs prior close.
        tr = max(
            highs[i] - lows[i],
            abs(highs[i] - closes[i - 1]),
            abs(lows[i] - closes[i - 1]),
        )
        up = highs[i] - highs[i - 1]
        dn = lows[i - 1] - lows[i]
        # Directional movement: only the dominant positive move counts.
        plus_dm = up if (up > dn and up > 0) else 0.0
        minus_dm = dn if (dn > up and dn > 0) else 0.0
        trs.append(tr)
        plus_dms.append(plus_dm)
        minus_dms.append(minus_dm)
    # Wilder smoothed sums, seeded over the first `period` bars.
    atr_s = sum(trs[:period])
    pdm_s = sum(plus_dms[:period])
    mdm_s = sum(minus_dms[:period])
    dxs: list[float] = []
    pdi = mdi = 0.0
    for i in range(period, len(trs)):
        # Wilder smoothing: drop 1/period of the running sum, add new sample.
        atr_s = atr_s - atr_s / period + trs[i]
        pdm_s = pdm_s - pdm_s / period + plus_dms[i]
        mdm_s = mdm_s - mdm_s / period + minus_dms[i]
        pdi = 100.0 * pdm_s / atr_s if atr_s else 0.0
        mdi = 100.0 * mdm_s / atr_s if atr_s else 0.0
        s = pdi + mdi
        dx = 100.0 * abs(pdi - mdi) / s if s else 0.0
        dxs.append(dx)
    if len(dxs) < period:
        return nothing
    # ADX = Wilder-smoothed average of the DX series; the returned +DI/-DI
    # are the values from the final bar.
    adx_val = sum(dxs[:period]) / period
    for i in range(period, len(dxs)):
        adx_val = (adx_val * (period - 1) + dxs[i]) / period
    return {"adx": adx_val, "+di": pdi, "-di": mdi}
@@ -0,0 +1,81 @@
from __future__ import annotations
import logging
import os
import re
import sys
# pythonjsonlogger rinominato in .json; keep fallback per compat
try:
from pythonjsonlogger.json import JsonFormatter as _JsonFormatter # noqa: N813
except ImportError:
from pythonjsonlogger.jsonlogger import JsonFormatter as _JsonFormatter # noqa: N813
# (pattern, replacement) pairs applied to every log message to mask bearer
# tokens, credential-like JSON fields, and sk-style API keys.
SECRET_PATTERNS = [
    (re.compile(r"Bearer\s+[\w\-\._]+", re.IGNORECASE), "Bearer ***"),
    (re.compile(r'("api_key"\s*:\s*")[^"]+(")'), r'\1***\2'),
    (re.compile(r'("password"\s*:\s*")[^"]+(")'), r'\1***\2'),
    (re.compile(r'("private_key"\s*:\s*")[^"]+(")'), r'\1***\2'),
    (re.compile(r'("client_secret"\s*:\s*")[^"]+(")'), r'\1***\2'),
    (re.compile(r"sk-[\w]{20,}"), "sk-***"),
]


class SecretsFilter(logging.Filter):
    """Logging filter that redacts secret-looking substrings in messages."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Rewrite record.msg in place; never drops the record."""
        text = record.getMessage()
        for rx, masked in SECRET_PATTERNS:
            text = rx.sub(masked, text)
        # %-args are already interpolated by getMessage(); clear them so the
        # sanitized string is emitted verbatim.
        record.msg = text
        record.args = ()
        return True
def get_json_logger(name: str, level: int = logging.INFO) -> logging.Logger:
    """Return a JSON-formatted stderr logger for *name*; idempotent per name."""
    log = logging.getLogger(name)
    if log.handlers:
        # Already wired up by a previous call — leave it untouched.
        return log
    stream = logging.StreamHandler(sys.stderr)
    stream.setFormatter(
        _JsonFormatter("%(asctime)s %(name)s %(levelname)s %(message)s")
    )
    stream.addFilter(SecretsFilter())
    log.setLevel(level)
    log.addHandler(stream)
    log.propagate = False
    return log
def configure_root_logging(
    *,
    level: str | int | None = None,
    format_type: str | None = None,
) -> None:
    """CER-P5-009: configure the root logger with a JSON or text formatter.

    Env overrides:
      - LOG_LEVEL (default INFO)
      - LOG_FORMAT=json|text (default json — production-ready structured log)

    Applies SecretsFilter on both formats. An unknown LOG_LEVEL name now
    degrades to INFO instead of raising at boot.
    """
    lvl_raw = level if level is not None else os.environ.get("LOG_LEVEL", "INFO")
    if isinstance(lvl_raw, str):
        lvl = logging.getLevelName(lvl_raw.upper())
        # Bug fix: getLevelName returns the string "Level X" for unknown
        # names, and root.setLevel would then raise ValueError — fall back
        # to INFO so a typo in LOG_LEVEL cannot abort the service boot.
        if not isinstance(lvl, int):
            lvl = logging.INFO
    else:
        lvl = lvl_raw
    fmt = (format_type or os.environ.get("LOG_FORMAT") or "json").lower()
    root = logging.getLogger()
    # Remove existing handlers (basicConfig would otherwise leave duplicates).
    for h in list(root.handlers):
        root.removeHandler(h)
    handler = logging.StreamHandler(sys.stderr)
    if fmt == "json":
        handler.setFormatter(
            _JsonFormatter("%(asctime)s %(name)s %(levelname)s %(message)s")
        )
    else:
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
        )
    handler.addFilter(SecretsFilter())
    root.addHandler(handler)
    root.setLevel(lvl)
@@ -0,0 +1,239 @@
"""Bridge MCP → endpoint REST esistenti.
Implementa manualmente JSON-RPC 2.0 MCP su `POST /mcp` (no SSE, risposta
diretta in body JSON). Supporta:
- initialize
- notifications/initialized
- tools/list
- tools/call
Claude Code config esempio:
{
"mcpServers": {
"cerbero-memory": {
"type": "http",
"url": "http://localhost:8080/mcp-memory/mcp",
"headers": {"Authorization": "Bearer <observer-token>"}
}
}
}
"""
from __future__ import annotations
from typing import Any
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from option_mcp_common.auth import TokenStore
MCP_PROTOCOL_VERSION = "2024-11-05"
def _derive_input_schemas(app: FastAPI, tool_names: list[str]) -> dict[str, dict]:
    """Extract the Pydantic body JSON schema of each POST /tools/<name> route.

    Lazy annotations (PEP 563) are resolved via ``typing.get_type_hints``.
    Returns a {tool_name: json_schema} mapping. Routes without a resolvable
    Pydantic body are skipped; the caller falls back to a generic schema.
    """
    import typing

    from pydantic import BaseModel

    wanted = set(tool_names)
    schemas: dict[str, dict] = {}
    for route in app.routes:
        route_path = getattr(route, "path", "")
        if not route_path.startswith("/tools/"):
            continue
        tool = route_path[len("/tools/"):]
        if tool not in wanted:
            continue
        fn = getattr(route, "endpoint", None)
        if fn is None:
            continue
        try:
            hints = typing.get_type_hints(fn)
        except Exception:
            # Unresolvable forward references — leave this tool to the fallback.
            continue
        for param, annotation in hints.items():
            if param == "return":
                continue
            if isinstance(annotation, type) and issubclass(annotation, BaseModel):
                try:
                    schemas[tool] = annotation.model_json_schema()
                except Exception:
                    pass
                break
    return schemas
def _make_proxy_handler(internal_base_url: str, tool_name: str, token: str):
    """Build an async handler proxying one MCP tool call to the internal REST API.

    The handler POSTs the tool arguments to {internal_base_url}/tools/<tool_name>,
    forwarding the caller's bearer *token* so the existing REST ACLs still apply.

    Raises:
        RuntimeError: on any HTTP error status (>= 400) from the internal endpoint.
    """
    async def handler(args: dict | None) -> Any:
        async with httpx.AsyncClient(timeout=30.0) as c:
            r = await c.post(
                f"{internal_base_url}/tools/{tool_name}",
                headers={"Authorization": f"Bearer {token}"} if token else {},
                json=args or {},
            )
            if r.status_code >= 400:
                # Bug fix: the status code and body excerpt were concatenated
                # with no separator ("HTTP 500{...}"), garbling the message.
                raise RuntimeError(
                    f"tool {tool_name} failed: HTTP {r.status_code}: {r.text[:500]}"
                )
            try:
                return r.json()
            except Exception:
                # Non-JSON body (plain text, HTML error page): wrap it raw.
                return {"raw": r.text}
    return handler
def mount_mcp_endpoint(
    app: FastAPI,
    *,
    name: str,
    version: str,
    token_store: TokenStore,
    internal_base_url: str,
    tools: list[dict],
) -> None:
    """Register an MCP JSON-RPC 2.0 endpoint on POST /mcp.

    Each tool is proxied to POST {internal_base_url}/tools/<name> with the
    MCP client's Bearer token (preserving the existing REST ACLs).

    Args:
        app: the service's FastAPI instance
        name: MCP server name
        version: service version
        token_store: the same store used by the REST tools
        internal_base_url: internal base URL (e.g. "http://localhost:9015")
        tools: list of {"name": str, "description": str, "input_schema"?: dict}
    """
    tools_by_name = {t["name"]: t for t in tools}
    # Auto-derive input schemas from FastAPI routes (Pydantic body models).
    # Lets the LLM know the names of the mandatory parameters instead of
    # guessing them. An explicit `input_schema` on a tool wins over auto-derive.
    derived_schemas = _derive_input_schemas(app, [t["name"] for t in tools])

    def _tool_defs() -> list[dict]:
        # Build the tools/list payload; fall back to a permissive object
        # schema when neither an explicit nor a derived schema exists.
        defs = []
        for t in tools:
            schema = t.get("input_schema") or derived_schemas.get(t["name"]) or {
                "type": "object",
                "additionalProperties": True,
            }
            defs.append({
                "name": t["name"],
                "description": t.get("description", t["name"]),
                "inputSchema": schema,
            })
        return defs

    async def _handle_rpc(body: dict, token: str) -> dict | None:
        # Dispatch one JSON-RPC request; returns None for notifications.
        rpc_id = body.get("id")
        method = body.get("method")
        params = body.get("params") or {}
        # Notification (no id) → no response
        if method == "notifications/initialized":
            return None
        if method == "initialize":
            return {
                "jsonrpc": "2.0",
                "id": rpc_id,
                "result": {
                    "protocolVersion": MCP_PROTOCOL_VERSION,
                    "capabilities": {"tools": {"listChanged": False}},
                    "serverInfo": {"name": name, "version": version},
                },
            }
        if method == "tools/list":
            return {
                "jsonrpc": "2.0",
                "id": rpc_id,
                "result": {"tools": _tool_defs()},
            }
        if method == "tools/call":
            tool_name = params.get("name", "")
            args = params.get("arguments") or {}
            if tool_name not in tools_by_name:
                # JSON-RPC -32601: method (here: tool) not found.
                return {
                    "jsonrpc": "2.0",
                    "id": rpc_id,
                    "error": {"code": -32601, "message": f"tool non trovato: {tool_name}"},
                }
            handler = _make_proxy_handler(internal_base_url, tool_name, token)
            try:
                result = await handler(args)
                return {
                    "jsonrpc": "2.0",
                    "id": rpc_id,
                    "result": {
                        "content": [
                            {
                                "type": "text",
                                "text": _to_text(result),
                            }
                        ],
                        "isError": False,
                    },
                }
            except Exception as e:
                # Tool failures are reported in-band as isError=True content,
                # not as JSON-RPC protocol errors.
                return {
                    "jsonrpc": "2.0",
                    "id": rpc_id,
                    "result": {
                        "content": [{"type": "text", "text": str(e)}],
                        "isError": True,
                    },
                }
        return {
            "jsonrpc": "2.0",
            "id": rpc_id,
            "error": {"code": -32601, "message": f"metodo non supportato: {method}"},
        }

    @app.post("/mcp")
    async def mcp_entry(request: Request):
        # Bearer auth mirrors the REST endpoints: 401 missing, 403 unknown.
        auth = request.headers.get("Authorization", "")
        if not auth.startswith("Bearer "):
            return JSONResponse({"error": "missing bearer token"}, status_code=401)
        token = auth[len("Bearer "):].strip()
        principal = token_store.get(token)
        if principal is None:
            return JSONResponse({"error": "invalid token"}, status_code=403)
        body = await request.json()
        # Batch support
        if isinstance(body, list):
            results = []
            for item in body:
                resp = await _handle_rpc(item, token)
                if resp is not None:
                    results.append(resp)
            return JSONResponse(results)
        resp = await _handle_rpc(body, token)
        if resp is None:
            # Notification (no id) → 204 no content
            return JSONResponse(None, status_code=204)
        return JSONResponse(resp)
def _to_text(value: Any) -> str:
import json
if isinstance(value, str):
return value
try:
return json.dumps(value, ensure_ascii=False, indent=2)
except Exception:
return str(value)
@@ -0,0 +1,98 @@
from __future__ import annotations
from enum import StrEnum
from functools import total_ordering
from typing import Any
from pydantic import BaseModel, Field
@total_ordering
class EventPriority(StrEnum):
LOW = "low"
NORMAL = "normal"
HIGH = "high"
CRITICAL = "critical"
def _rank(self) -> int:
return ["low", "normal", "high", "critical"].index(self.value)
def __lt__(self, other: EventPriority) -> bool:
return self._rank() < other._rank()
class EventType(StrEnum):
    """Kind of queued event; discriminates the meaning of Event.payload."""

    ALERT = "alert"
    USER_INSTRUCTION = "user_instruction"
    SYSTEM = "system"
class L1State(BaseModel):
    """Singleton row with current operational state."""

    # Timestamp of the last refresh (string; presumably ISO-8601 — confirm
    # against the writer service).
    updated_at: str
    equity_total: float | None = None
    equity_by_exchange: dict[str, float] = Field(default_factory=dict)
    bias: str | None = None
    pnl_day: float | None = None
    pnl_total: float | None = None
    capital: float | None = None
    open_positions_count: int = 0
    greeks_aggregate: dict[str, float] = Field(default_factory=dict)
    notes: str | None = None
class L2Entry(BaseModel):
    """Reasoning entry — schema matches system_prompt v2.

    `authored_by_model`: identifies the LLM that generated the entry (e.g.
    "google/gemini-3-flash-preview" for core, "anthropic/claude-haiku-4-5"
    for a worker). None when written by a human operator via the UI.
    """

    timestamp: str
    setup: str
    # NOTE: tesi/esito/scostamento/lezione are Italian domain terms
    # (thesis/outcome/deviation/lesson) kept as stable API field names.
    tesi: str | None = None
    tesi_check: str | None = None
    invalidation: str | None = None
    esito: str
    scostamento: str | None = None
    scostamento_sigma: float | None = None
    lezione: str | None = None
    sizing_note: str | None = None
    run_id: str | None = None
    user_instruction_id: int | None = None
    authored_by_model: str | None = None
class L3Entry(BaseModel):
    """Compacted pattern from L2 batch."""

    created_at: str
    category: str  # "pattern_errore" | "pattern_vincente" | "correlazione"
    summary: str
    # IDs of the L2 entries this pattern was compacted from.
    source_l2_ids: list[int]
class Event(BaseModel):
    """Prioritized event with an expiry and acknowledgement tracking fields."""

    # None until persisted (presumably storage-assigned — confirm with the DB layer).
    id: int | None = None
    created_at: str
    expires_at: str
    type: EventType
    source: str
    priority: EventPriority
    payload: dict[str, Any]
    # Acknowledgement fields stay None until the event is acked.
    acked_at: str | None = None
    ack_outcome: str | None = None
    ack_notes: str | None = None
class UserInstruction(BaseModel):
    """Free-text instruction from the operator, optionally requiring an ack."""

    # None until persisted (presumably storage-assigned — confirm with the DB layer).
    id: int | None = None
    created_at: str
    text: str
    priority: EventPriority
    require_ack: bool = True
    source: str = "observer"
    acked_at: str | None = None
    ack_outcome: str | None = None
@@ -0,0 +1,92 @@
"""CER-016 hard guard server-side su place_order.
Caps configurabili via env (default safety-first, mirati a ~200 EUR single,
1000 EUR aggregato, 3x max leverage).
Thresholds sono numerici semplici — l'operatore stabilisce unità (EUR/USD)
via env; il server compara su un unico campo `notional` in valore monetario.
"""
from __future__ import annotations
import os
from fastapi import HTTPException
def _env_float(name: str, default: float) -> float:
    """Parse env var *name* as float; *default* on unset, empty or garbage."""
    raw = os.environ.get(name, "")
    if raw:
        try:
            return float(raw)
        except (TypeError, ValueError):
            pass
    return default


def _env_int(name: str, default: int) -> int:
    """Parse env var *name* as int; *default* on unset, empty or garbage."""
    raw = os.environ.get(name, "")
    if raw:
        try:
            return int(raw)
        except (TypeError, ValueError):
            pass
    return default


def max_notional() -> float:
    """Hard cap on a single trade's notional (CERBERO_MAX_NOTIONAL, default 200)."""
    return _env_float("CERBERO_MAX_NOTIONAL", 200.0)


def max_aggregate() -> float:
    """Hard cap on aggregate open notional (CERBERO_MAX_AGGREGATE, default 1000)."""
    return _env_float("CERBERO_MAX_AGGREGATE", 1000.0)


def max_leverage() -> int:
    """Hard cap on leverage (CERBERO_MAX_LEVERAGE, default 3)."""
    return _env_int("CERBERO_MAX_LEVERAGE", 3)
def _hard_reject(reason: str) -> None:
    """Raise an HTTP 403 HARD_PROHIBITION carrying the current caps snapshot."""
    caps_snapshot = {
        "max_notional": max_notional(),
        "max_aggregate": max_aggregate(),
        "max_leverage": max_leverage(),
    }
    raise HTTPException(
        status_code=403,
        detail={
            "error": "HARD_PROHIBITION",
            "message": reason,
            "caps": caps_snapshot,
        },
    )
def enforce_leverage(leverage: int | float | None) -> int:
    """Return the leverage to apply.

    None defaults to the configured cap; values below 1 or above the cap are
    rejected with a 403 HARD_PROHIBITION.
    """
    cap = max_leverage()
    if leverage is None:
        return cap
    lev = int(leverage)
    if 1 <= lev <= cap:
        return lev
    if lev < 1:
        _hard_reject(f"leverage must be >= 1 (got {lev})")
    _hard_reject(f"leverage {lev}x exceeds hard cap {cap}x")
def enforce_single_notional(notional: float, *, exchange: str, instrument: str) -> None:
    """Reject (403 HARD_PROHIBITION) a single trade exceeding the notional cap."""
    cap = max_notional()
    if notional > cap:
        _hard_reject(
            f"{exchange}.{instrument} notional {notional:.2f} exceeds single trade cap {cap:.2f}"
        )


def enforce_aggregate(current_total: float, new_notional: float) -> None:
    """Reject when current exposure plus the new trade would exceed the cap."""
    cap = max_aggregate()
    total = current_total + new_notional
    if total > cap:
        _hard_reject(
            f"aggregate notional {total:.2f} (current {current_total:.2f} + new "
            f"{new_notional:.2f}) exceeds cap {cap:.2f}"
        )
@@ -0,0 +1,220 @@
from __future__ import annotations
import json
import os
import time
import uuid
from datetime import UTC, datetime
from collections.abc import Callable
from contextlib import AbstractAsyncContextManager
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response
from starlette.middleware.base import BaseHTTPMiddleware
from option_mcp_common.auth import TokenStore
Lifespan = Callable[[FastAPI], AbstractAsyncContextManager[None]]
def _error_envelope(
*,
type_: str,
code: str,
message: str,
retryable: bool,
suggested_fix: str | None = None,
details: dict | None = None,
request_id: str | None = None,
) -> dict:
env = {
"error": {
"type": type_,
"code": code,
"message": message,
"retryable": retryable,
},
"request_id": request_id or uuid.uuid4().hex,
"data_timestamp": datetime.now(UTC).isoformat(),
}
if suggested_fix:
env["error"]["suggested_fix"] = suggested_fix
if details:
env["error"]["details"] = details
return env
class _TimestampInjectorMiddleware(BaseHTTPMiddleware):
    """CER-P5-001: inject data_timestamp into tool responses.

    - Dict response: body gains `data_timestamp` when missing.
    - List of dicts: each item gains `data_timestamp` when missing.
    - Header `X-Data-Timestamp` is always present (universal for primitive lists).

    Skips /health (already populated), /mcp (JSON-RPC bridge) and non-JSON
    responses — only paths under /tools/ are rewritten.
    """

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        path = request.url.path
        # Only tool endpoints are decorated; everything else passes through.
        if not path.startswith("/tools/"):
            return response
        ctype = response.headers.get("content-type", "")
        if "application/json" not in ctype:
            return response
        # Drain the streaming body so it can be parsed and re-emitted.
        body = b""
        async for chunk in response.body_iterator:
            body += chunk
        ts = datetime.now(UTC).isoformat()
        try:
            data = json.loads(body) if body else None
        except Exception:
            # Unparseable JSON: re-emit the original body, header only.
            headers = dict(response.headers)
            headers["X-Data-Timestamp"] = ts
            return Response(
                content=body,
                status_code=response.status_code,
                headers=headers,
                media_type=response.media_type,
            )
        modified = False
        if isinstance(data, dict) and "data_timestamp" not in data:
            data["data_timestamp"] = ts
            modified = True
        elif isinstance(data, list):
            for item in data:
                if isinstance(item, dict) and "data_timestamp" not in item:
                    item["data_timestamp"] = ts
                    modified = True
        headers = dict(response.headers)
        headers["X-Data-Timestamp"] = ts
        if modified:
            new_body = json.dumps(data, default=str).encode()
            # Body length changed; drop content-length so it is recomputed.
            headers.pop("content-length", None)
            return Response(
                content=new_body,
                status_code=response.status_code,
                headers=headers,
                media_type="application/json",
            )
        return Response(
            content=body,
            status_code=response.status_code,
            headers=headers,
            media_type=response.media_type,
        )
def build_app(
    *,
    name: str,
    version: str,
    token_store: TokenStore,
    lifespan: Lifespan | None = None,
) -> FastAPI:
    """Create the standard FastAPI app shared by the MCP services.

    Wires: the token store onto app.state, the timestamp-injection middleware,
    a latency header middleware, CER-P5-002 error-envelope exception handlers,
    and a /health endpoint.

    Args:
        name: service name (used as FastAPI title and in /health).
        version: service version string.
        token_store: bearer-token → principal store used by auth dependencies.
        lifespan: optional startup/shutdown async context manager.
    """
    root_path = os.getenv("ROOT_PATH", "")
    app = FastAPI(title=name, version=version, root_path=root_path, lifespan=lifespan)
    app.state.token_store = token_store
    app.state.boot_at = time.time()
    app.add_middleware(_TimestampInjectorMiddleware)

    @app.middleware("http")
    async def _latency_header(request: Request, call_next):
        # Measure wall time of every request and expose it as a header.
        t0 = time.perf_counter()
        response = await call_next(request)
        dur_ms = (time.perf_counter() - t0) * 1000
        response.headers["X-Duration-Ms"] = f"{dur_ms:.2f}"
        return response

    # CER-P5-002 error envelope: global exception handlers
    @app.exception_handler(HTTPException)
    async def _http_exc(request: Request, exc: HTTPException):
        # Only transient statuses are flagged retryable for the caller.
        retryable = exc.status_code in (408, 429, 502, 503, 504)
        code_map = {
            400: "BAD_REQUEST", 401: "UNAUTHORIZED", 403: "FORBIDDEN",
            404: "NOT_FOUND", 408: "TIMEOUT", 409: "CONFLICT",
            422: "VALIDATION_ERROR", 429: "RATE_LIMIT",
            500: "INTERNAL_ERROR", 502: "UPSTREAM_ERROR",
            503: "UNAVAILABLE", 504: "GATEWAY_TIMEOUT",
        }
        code = code_map.get(exc.status_code, f"HTTP_{exc.status_code}")
        message = "HTTP error"
        details: dict | None = None
        detail = exc.detail
        # Preserve rail-style detail {"error": "..", "message": ".."} as code
        if isinstance(detail, dict):
            if isinstance(detail.get("error"), str):
                code = detail["error"].upper()
            message = str(detail.get("message") or detail.get("error") or message)
            details = detail
        elif isinstance(detail, str):
            message = detail
        return JSONResponse(
            status_code=exc.status_code,
            content=_error_envelope(
                type_="http_error",
                code=code,
                message=message,
                retryable=retryable,
                details=details,
            ),
        )

    @app.exception_handler(RequestValidationError)
    async def _validation_exc(request: Request, exc: RequestValidationError):
        errs = exc.errors()
        first_loc = ".".join(str(x) for x in errs[0]["loc"]) if errs else "body"
        suggestion = (
            f"check field '{first_loc}': "
            + (errs[0]["msg"] if errs else "invalid input")
        )
        # Sanitize ctx values: pydantic v2 may put a ValueError in ctx['error'],
        # which is not JSON-serializable. Reduce them to strings.
        safe_errs: list[dict] = []
        for e in errs[:5]:
            ne: dict = {}
            for k, v in e.items():
                if k == "ctx" and isinstance(v, dict):
                    ne[k] = {ck: str(cv) for ck, cv in v.items()}
                else:
                    ne[k] = v
            safe_errs.append(ne)
        return JSONResponse(
            status_code=422,
            content=_error_envelope(
                type_="validation_error",
                code="INVALID_INPUT",
                message=f"request body validation failed on {first_loc}",
                retryable=False,
                suggested_fix=suggestion,
                details={"errors": safe_errs},
            ),
        )

    @app.exception_handler(Exception)
    async def _unhandled(request: Request, exc: Exception):
        # Last-resort handler: no stack trace leaks, message capped at 300 chars.
        return JSONResponse(
            status_code=500,
            content=_error_envelope(
                type_="internal_error",
                code="UNHANDLED_EXCEPTION",
                message=f"{type(exc).__name__}: {str(exc)[:300]}",
                retryable=True,
            ),
        )

    @app.get("/health")
    def health():
        # Liveness probe; data_timestamp keeps parity with tool responses.
        return {
            "status": "healthy",
            "name": name,
            "version": version,
            "uptime_seconds": int(time.time() - app.state.boot_at),
            "data_timestamp": datetime.now(UTC).isoformat(),
        }

    return app
@@ -0,0 +1,43 @@
from __future__ import annotations
import sqlite3
from dataclasses import dataclass
from pathlib import Path
@dataclass
class Database:
    """Thin SQLite wrapper holding the DB file path and an open connection."""

    path: Path
    conn: sqlite3.Connection | None = None

    def connect(self) -> sqlite3.Connection:
        """Open and configure the SQLite connection, creating parent dirs.

        Configuration: autocommit (isolation_level=None), usable across
        threads, rows addressable by column name, WAL journal, NORMAL sync,
        foreign keys enforced.
        """
        self.path.parent.mkdir(parents=True, exist_ok=True)
        connection = sqlite3.connect(
            str(self.path),
            isolation_level=None,
            check_same_thread=False,
        )
        self.conn = connection
        connection.row_factory = sqlite3.Row
        for pragma in (
            "PRAGMA journal_mode=WAL",
            "PRAGMA synchronous=NORMAL",
            "PRAGMA foreign_keys=ON",
        ):
            connection.execute(pragma)
        return connection

    def close(self) -> None:
        """Close and forget the connection; a no-op when not connected."""
        if self.conn is None:
            return
        self.conn.close()
        self.conn = None
def run_migrations(conn: sqlite3.Connection, migrations: dict[int, str]) -> None:
    """Apply pending migration scripts in ascending version order (idempotent).

    *migrations* maps monotonically increasing version numbers to SQL scripts;
    versions at or below the recorded maximum are skipped, and each applied
    version is recorded in _schema_version.
    """
    conn.execute(
        "CREATE TABLE IF NOT EXISTS _schema_version (version INTEGER PRIMARY KEY)"
    )
    row = conn.execute(
        "SELECT COALESCE(MAX(version), 0) FROM _schema_version"
    ).fetchone()
    applied_up_to = row[0]
    pending = [v for v in sorted(migrations) if v > applied_up_to]
    for version in pending:
        conn.executescript(migrations[version])
        conn.execute("INSERT INTO _schema_version (version) VALUES (?)", (version,))
+84
View File
@@ -0,0 +1,84 @@
import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from option_mcp_common.auth import (
Principal,
TokenStore,
acl_requires,
require_principal,
)
@pytest.fixture
def token_store():
    """Two principals: one holding the core capability, one observer-only."""
    return TokenStore(tokens={
        "token-core-123": Principal(name="core", capabilities={"core"}),
        "token-obs-456": Principal(name="observer", capabilities={"observer"}),
    })


@pytest.fixture
def app(token_store):
    """Minimal FastAPI app exercising require_principal and acl_requires."""
    app = FastAPI()
    app.state.token_store = token_store

    @app.get("/public")
    def public():
        # No auth dependency: reachable without a token.
        return {"ok": True}

    @app.get("/private")
    def private(principal: Principal = Depends(require_principal)):
        return {"name": principal.name}

    @app.post("/core-only")
    @acl_requires(core=True, observer=False)
    def core_only(principal: Principal = Depends(require_principal)):
        return {"who": principal.name}

    @app.post("/observer-only")
    @acl_requires(core=False, observer=True)
    def observer_only(principal: Principal = Depends(require_principal)):
        return {"who": principal.name}

    return app
def test_public_endpoint_no_auth(app):
    """/public is reachable without credentials."""
    response = TestClient(app).get("/public")
    assert response.status_code == 200


def test_private_without_header_401(app):
    """A missing Authorization header yields 401."""
    response = TestClient(app).get("/private")
    assert response.status_code == 401


def test_private_bad_token_403(app):
    """An unknown bearer token yields 403."""
    headers = {"Authorization": "Bearer nope"}
    response = TestClient(app).get("/private", headers=headers)
    assert response.status_code == 403


def test_private_good_token_200(app):
    """A valid core token resolves to its principal."""
    headers = {"Authorization": "Bearer token-core-123"}
    response = TestClient(app).get("/private", headers=headers)
    assert response.status_code == 200
    assert response.json() == {"name": "core"}


def test_acl_core_token_on_core_only_endpoint(app):
    """Core capability grants access to the core-only endpoint."""
    headers = {"Authorization": "Bearer token-core-123"}
    assert TestClient(app).post("/core-only", headers=headers).status_code == 200


def test_acl_observer_on_core_only_rejected(app):
    """Observer capability is rejected on the core-only endpoint."""
    headers = {"Authorization": "Bearer token-obs-456"}
    assert TestClient(app).post("/core-only", headers=headers).status_code == 403


def test_acl_observer_on_observer_only_ok(app):
    """Observer capability grants access to the observer-only endpoint."""
    headers = {"Authorization": "Bearer token-obs-456"}
    assert TestClient(app).post("/observer-only", headers=headers).status_code == 200
@@ -0,0 +1,71 @@
"""CER-P5-010 env validation tests."""
from __future__ import annotations
import pytest
from option_mcp_common.env_validation import (
MissingEnvError,
fail_fast_if_missing,
optional_env,
require_env,
summarize,
)
def test_require_env_present(monkeypatch):
    """A set, non-empty variable is returned verbatim."""
    monkeypatch.setenv("FOO_KEY", "value1")
    result = require_env("FOO_KEY")
    assert result == "value1"


def test_require_env_missing_raises(monkeypatch):
    """An unset mandatory variable raises MissingEnvError."""
    monkeypatch.delenv("MISSING_REQ", raising=False)
    with pytest.raises(MissingEnvError):
        require_env("MISSING_REQ", "critical path")


def test_require_env_empty_raises(monkeypatch):
    """An empty value counts as missing."""
    monkeypatch.setenv("EMPTY_REQ", "")
    with pytest.raises(MissingEnvError):
        require_env("EMPTY_REQ")


def test_require_env_whitespace_only_raises(monkeypatch):
    """Whitespace-only values count as missing."""
    monkeypatch.setenv("WS_REQ", " ")
    with pytest.raises(MissingEnvError):
        require_env("WS_REQ")


def test_optional_env_default(monkeypatch):
    """Unset optional variables fall back to the provided default."""
    monkeypatch.delenv("OPT_A", raising=False)
    result = optional_env("OPT_A", default="fallback")
    assert result == "fallback"


def test_optional_env_set(monkeypatch):
    """A set optional variable wins over the default."""
    monkeypatch.setenv("OPT_B", "xx")
    result = optional_env("OPT_B", default="fallback")
    assert result == "xx"


def test_fail_fast_all_present(monkeypatch):
    """All mandatory vars present → no SystemExit."""
    monkeypatch.setenv("AA", "1")
    monkeypatch.setenv("BB", "2")
    fail_fast_if_missing(["AA", "BB"])  # must not exit


def test_fail_fast_missing_exits(monkeypatch):
    """Any missing mandatory var aborts with exit code 2."""
    monkeypatch.setenv("HAVE_IT", "1")
    monkeypatch.delenv("MISSING_X", raising=False)
    with pytest.raises(SystemExit) as excinfo:
        fail_fast_if_missing(["HAVE_IT", "MISSING_X"])
    assert excinfo.value.code == 2


def test_summarize_does_not_leak_secrets(monkeypatch, caplog):
    """Secret-looking names are masked; plain values and unset names logged."""
    import logging

    monkeypatch.setenv("API_KEY_FOO", "super-secret-token-123456")
    monkeypatch.setenv("PORT", "9000")
    with caplog.at_level(logging.INFO, logger="option_mcp_common.env_validation"):
        summarize(["API_KEY_FOO", "PORT", "NOT_SET_XYZ"])
    joined = "\n".join(caplog.messages)
    assert "super-secret-token-123456" not in joined
    assert "9000" in joined
    assert "<unset>" in joined
+80
View File
@@ -0,0 +1,80 @@
from option_mcp_common.indicators import adx, atr, macd, rsi, sma
def test_rsi_simple():
    """RSI over a textbook 15-close series lands in the known ballpark."""
    closes = [44, 44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84,
              46.08, 45.89, 46.03, 45.61, 46.28]
    value = rsi(closes, period=14)
    assert value is not None
    # Textbook RSI for this series sits roughly around 70.
    assert 65.0 < value < 75.0


def test_rsi_insufficient_data():
    """Fewer closes than the period yields None."""
    assert rsi([1, 2, 3], period=14) is None


def test_sma_simple():
    """SMA of 1..5 over period 5 is 3.0; too little data yields None."""
    assert sma([1, 2, 3, 4, 5], period=5) == 3.0
    assert sma([1, 2, 3], period=5) is None


def test_atr_simple():
    """ATR of a steady unit-range uptrend is close to 1."""
    highs = list(range(10, 25))
    lows = list(range(9, 24))
    closes = [h - 0.5 for h in highs]
    value = atr(highs, lows, closes, period=14)
    assert value is not None
    assert 0.9 < value <= 1.5


def test_macd_trend_up():
    """A monotonic uptrend produces MACD > 0 with non-negative histogram."""
    closes = list(map(float, range(1, 60)))
    result = macd(closes)
    assert result["macd"] is not None
    assert result["signal"] is not None
    assert result["hist"] is not None
    assert result["macd"] > 0
    assert result["hist"] >= 0


def test_macd_insufficient_data():
    """Too few closes yields an all-None MACD triple."""
    assert macd([1.0, 2.0, 3.0]) == {"macd": None, "signal": None, "hist": None}


def test_macd_trend_down():
    """A monotonic downtrend produces MACD < 0 with non-positive histogram."""
    closes = list(map(float, range(60, 1, -1)))
    result = macd(closes)
    assert result["macd"] < 0
    assert result["hist"] <= 0


def test_adx_insufficient_data():
    """Too few bars yields an all-None ADX triple."""
    result = adx([1.0] * 10, [0.5] * 10, [0.7] * 10, period=14)
    assert result == {"adx": None, "+di": None, "-di": None}


def test_adx_strong_uptrend():
    """A steady uptrend gives +DI dominating -DI and a high ADX."""
    highs = [i + 1.0 for i in range(1, 40)]
    lows = list(map(float, range(1, 40)))
    closes = [i + 0.5 for i in range(1, 40)]
    result = adx(highs, lows, closes, period=14)
    assert result["adx"] is not None
    assert result["+di"] is not None and result["-di"] is not None
    assert result["+di"] > result["-di"]
    assert result["adx"] > 50.0


def test_adx_flat_market():
    """A perfectly flat market has essentially zero directional movement."""
    result = adx([10.0] * 40, [9.0] * 40, [9.5] * 40, period=14)
    assert result["adx"] is not None
    assert result["adx"] < 5.0
+77
View File
@@ -0,0 +1,77 @@
import json
import logging
from option_mcp_common.logging import (
SecretsFilter,
configure_root_logging,
get_json_logger,
)
def test_secrets_filter_masks_bearer():
    """Bearer tokens inside log messages are replaced with a mask."""
    record = logging.LogRecord(
        name="t", level=logging.INFO, pathname="", lineno=0,
        msg="Got Bearer abcdef123456 from client",
        args=(), exc_info=None,
    )
    SecretsFilter().filter(record)
    assert "abcdef" not in record.msg
    assert "***" in record.msg


def test_secrets_filter_masks_api_key_json():
    """API keys embedded in JSON-looking messages are masked."""
    record = logging.LogRecord(
        name="t", level=logging.INFO, pathname="", lineno=0,
        msg='{"api_key": "sk-live-abc123xyz"}',
        args=(), exc_info=None,
    )
    SecretsFilter().filter(record)
    assert "sk-live-abc123xyz" not in record.msg


def test_json_logger_outputs_json(capsys):
    """get_json_logger emits one JSON object per record, extras included."""
    log = get_json_logger("test")
    log.info("hello", extra={"user_id": 42})
    captured = capsys.readouterr()
    # The JSON logger writes to stderr by default.
    payload = (captured.err or captured.out).strip().splitlines()[-1]
    parsed = json.loads(payload)
    assert parsed["message"] == "hello"
    assert parsed["user_id"] == 42


def test_configure_root_json_format(monkeypatch, capsys):
    """LOG_FORMAT=json makes root logging emit parseable JSON lines."""
    monkeypatch.setenv("LOG_FORMAT", "json")
    monkeypatch.setenv("LOG_LEVEL", "INFO")
    configure_root_logging()
    logging.info("root json test")
    payload = capsys.readouterr().err.strip().splitlines()[-1]
    parsed = json.loads(payload)
    assert parsed["message"] == "root json test"
    assert parsed["levelname"] == "INFO"


def test_configure_root_text_format(monkeypatch, capsys):
    """LOG_FORMAT=text makes root logging emit plain, non-JSON lines."""
    monkeypatch.setenv("LOG_FORMAT", "text")
    configure_root_logging()
    logging.info("root text test")
    payload = capsys.readouterr().err.strip().splitlines()[-1]
    # Text format must not be JSON-parseable.
    try:
        json.loads(payload)
    except json.JSONDecodeError:
        pass
    else:
        raise AssertionError("expected text format, got JSON")
    assert "root text test" in payload


def test_configure_root_applies_secrets_filter(monkeypatch, capsys):
    """Root logging configured via configure_root_logging masks secrets."""
    monkeypatch.setenv("LOG_FORMAT", "json")
    configure_root_logging()
    logging.info("calling with Bearer sk-live-leak123456 token")
    payload = capsys.readouterr().err.strip().splitlines()[-1]
    parsed = json.loads(payload)
    assert "sk-live-leak123456" not in parsed["message"]
    assert "***" in parsed["message"]
+112
View File
@@ -0,0 +1,112 @@
from __future__ import annotations
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from option_mcp_common.auth import Principal, TokenStore, require_principal
from option_mcp_common.mcp_bridge import _derive_input_schemas, mount_mcp_endpoint
from option_mcp_common.server import build_app
from pydantic import BaseModel
class EchoBody(BaseModel):
    """Request body for the /tools/echo test endpoint: required message plus optional count."""

    msg: str  # required — /tools/echo echoes this back
    n: int = 1  # optional — defaults to 1, so only "msg" is schema-required
def _make_app() -> tuple[FastAPI, TokenStore]:
    """Build a minimal app: one observer token and two authed tool routes."""
    token_store = TokenStore(tokens={"t": Principal("obs", {"observer"})})
    application = build_app(name="t", version="v", token_store=token_store)

    @application.post("/tools/echo")
    def echo(body: EchoBody, principal: Principal = Depends(require_principal)):
        return {"echo": body.msg, "n": body.n}

    @application.post("/tools/ping")
    def ping(principal: Principal = Depends(require_principal)):
        return {"pong": True}

    return application, token_store
def test_derive_input_schemas_resolves_lazy_annotations():
    """Schemas are derived from Pydantic bodies; body-less tools are omitted."""
    application, _ = _make_app()
    schemas = _derive_input_schemas(application, ["echo", "ping"])
    assert "echo" in schemas
    derived = schemas["echo"]
    assert derived["type"] == "object"
    assert "msg" in derived["properties"]
    assert "n" in derived["properties"]
    assert "msg" in derived["required"]
    # "ping" takes no Pydantic body, so no schema is derived for it
    # (the caller applies a generic fallback instead).
    assert "ping" not in schemas


def test_mount_mcp_endpoint_exposes_derived_schemas():
    """tools/list exposes derived schemas, with a generic fallback for ping."""
    application, store = _make_app()
    mount_mcp_endpoint(
        application,
        name="test",
        version="1.0",
        token_store=store,
        internal_base_url="http://localhost:0",
        tools=[
            {"name": "echo", "description": "Echo a message."},
            {"name": "ping", "description": "Ping."},
        ],
    )
    resp = TestClient(application).post(
        "/mcp",
        headers={"Authorization": "Bearer t"},
        json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"},
    )
    assert resp.status_code == 200
    listed = {tool["name"]: tool for tool in resp.json()["result"]["tools"]}
    assert set(listed["echo"]["inputSchema"]["required"]) == {"msg"}
    # ping falls back to the generic permissive schema.
    assert listed["ping"]["inputSchema"] == {
        "type": "object",
        "additionalProperties": True,
    }


def test_mount_mcp_endpoint_requires_auth():
    """/mcp rejects missing tokens with 401 and wrong tokens with 403."""
    application, store = _make_app()
    mount_mcp_endpoint(
        application,
        name="test",
        version="1.0",
        token_store=store,
        internal_base_url="http://localhost:0",
        tools=[{"name": "echo"}],
    )
    client = TestClient(application)
    anonymous = client.post("/mcp", json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"})
    assert anonymous.status_code == 401
    wrong = client.post(
        "/mcp",
        headers={"Authorization": "Bearer WRONG"},
        json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"},
    )
    assert wrong.status_code == 403


def test_explicit_input_schema_overrides_derived():
    """An explicit input_schema on a tool wins over the derived one."""
    application, store = _make_app()
    custom = {"type": "object", "properties": {"custom": {"type": "string"}}, "required": ["custom"]}
    mount_mcp_endpoint(
        application,
        name="test",
        version="1.0",
        token_store=store,
        internal_base_url="http://localhost:0",
        tools=[{"name": "echo", "input_schema": custom}],
    )
    resp = TestClient(application).post(
        "/mcp",
        headers={"Authorization": "Bearer t"},
        json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"},
    )
    assert resp.json()["result"]["tools"][0]["inputSchema"] == custom
+40
View File
@@ -0,0 +1,40 @@
from option_mcp_common.models import EventPriority, EventType, L2Entry
def test_l2_entry_minimal():
    """An L2Entry built with only required fields leaves optionals as None."""
    entry = L2Entry(
        timestamp="2026-04-17T10:30:00Z",
        setup="bull put spread ETH 1800/1750 14d",
        tesi="IV alta post-CPI, attesa mean-reversion",
        esito="aperto"
    )
    assert entry.scostamento_sigma is None
    assert entry.tesi_check is None


def test_l2_entry_full():
    """A fully-populated L2Entry round-trips all fields through model_dump."""
    entry = L2Entry(
        timestamp="2026-04-17T10:30:00Z",
        setup="bull put spread ETH 1800/1750 14d",
        tesi="IV alta post-CPI",
        tesi_check="ETH sopra 1820 per 24h con IV in calo",
        invalidation="rottura 1800 con volume > 2x media",
        esito="chiuso +12 USDC",
        scostamento="nessuno",
        scostamento_sigma=0.5,
        lezione="supporto ha tenuto, timing ok",
        sizing_note="size 80 USDC (ATR 1.3x media)",
    )
    assert entry.scostamento_sigma == 0.5
    serialized = entry.model_dump()
    assert serialized["lezione"] == "supporto ha tenuto, timing ok"


def test_event_priority_enum():
    """EventPriority exposes string values and supports ordering."""
    assert EventPriority.CRITICAL.value == "critical"
    assert EventPriority.LOW < EventPriority.CRITICAL  # ordering


def test_event_type_enum():
    """EventType members carry their lowercase string values."""
    assert EventType.ALERT.value == "alert"
    assert EventType.USER_INSTRUCTION.value == "user_instruction"
+80
View File
@@ -0,0 +1,80 @@
"""Tests for CER-016 risk guard."""
from __future__ import annotations
import pytest
from fastapi import HTTPException
from option_mcp_common.risk_guard import (
enforce_aggregate,
enforce_leverage,
enforce_single_notional,
max_aggregate,
max_leverage,
max_notional,
)
def test_defaults(monkeypatch):
    """With no env overrides, the documented default caps apply."""
    for var in ("CERBERO_MAX_NOTIONAL", "CERBERO_MAX_AGGREGATE", "CERBERO_MAX_LEVERAGE"):
        monkeypatch.delenv(var, raising=False)
    assert max_notional() == 200.0
    assert max_aggregate() == 1000.0
    assert max_leverage() == 3


def test_env_override(monkeypatch):
    """Each cap is independently overridable via its env variable."""
    monkeypatch.setenv("CERBERO_MAX_NOTIONAL", "50")
    monkeypatch.setenv("CERBERO_MAX_AGGREGATE", "150")
    monkeypatch.setenv("CERBERO_MAX_LEVERAGE", "2")
    assert max_notional() == 50.0
    assert max_aggregate() == 150.0
    assert max_leverage() == 2


def test_leverage_default_when_none(monkeypatch):
    """Passing None for leverage falls back to the configured cap."""
    monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
    assert enforce_leverage(None) == 3


def test_leverage_accepts_within_cap(monkeypatch):
    """Leverage at or below the cap passes through unchanged."""
    monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
    assert enforce_leverage(2) == 2
    assert enforce_leverage(3) == 3


def test_leverage_rejects_above_cap(monkeypatch):
    """Leverage above the cap is a 403 HARD_PROHIBITION."""
    monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
    with pytest.raises(HTTPException) as exc:
        enforce_leverage(50)
    assert exc.value.status_code == 403
    assert exc.value.detail["error"] == "HARD_PROHIBITION"


def test_leverage_rejects_below_one(monkeypatch):
    """Leverage below 1 is rejected."""
    with pytest.raises(HTTPException):
        enforce_leverage(0)


def test_single_notional_ok(monkeypatch):
    """A notional under the per-trade cap raises nothing."""
    monkeypatch.delenv("CERBERO_MAX_NOTIONAL", raising=False)
    enforce_single_notional(100.0, exchange="deribit", instrument="BTC-PERPETUAL")


def test_single_notional_rejects(monkeypatch):
    """A notional over the per-trade cap is rejected, echoing the amount."""
    monkeypatch.delenv("CERBERO_MAX_NOTIONAL", raising=False)
    with pytest.raises(HTTPException) as exc:
        enforce_single_notional(335.0, exchange="hyperliquid", instrument="ETH")
    assert exc.value.status_code == 403
    assert "335" in exc.value.detail["message"]


def test_aggregate_ok(monkeypatch):
    """New exposure staying within the aggregate cap raises nothing."""
    monkeypatch.delenv("CERBERO_MAX_AGGREGATE", raising=False)
    enforce_aggregate(current_total=500.0, new_notional=200.0)


def test_aggregate_rejects(monkeypatch):
    """New exposure breaching the aggregate cap is a 403."""
    monkeypatch.delenv("CERBERO_MAX_AGGREGATE", raising=False)
    with pytest.raises(HTTPException) as exc:
        enforce_aggregate(current_total=900.0, new_notional=200.0)
    assert exc.value.status_code == 403
+90
View File
@@ -0,0 +1,90 @@
from fastapi.testclient import TestClient
from option_mcp_common.auth import Principal, TokenStore
from option_mcp_common.server import build_app
def test_build_app_health():
    """/health reports status, identity, uptime, timestamp and duration header."""
    application = build_app(
        name="test-mcp", version="0.0.1", token_store=TokenStore(tokens={})
    )
    resp = TestClient(application).get("/health")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["status"] == "healthy"
    assert payload["name"] == "test-mcp"
    assert payload["version"] == "0.0.1"
    assert "uptime_seconds" in payload
    assert "data_timestamp" in payload
    assert resp.headers.get("X-Duration-Ms") is not None


def test_build_app_adds_token_store():
    """build_app attaches the provided token store to app state."""
    store = TokenStore(tokens={"t1": Principal("x", {"core"})})
    application = build_app(name="t", version="v", token_store=store)
    assert application.state.token_store is store


def test_timestamp_injector_dict_response():
    """CER-P5-001: dict response gets data_timestamp + X-Data-Timestamp header."""
    application = build_app(name="t", version="v", token_store=TokenStore(tokens={}))

    @application.post("/tools/foo")
    def foo():
        return {"ok": True}

    resp = TestClient(application).post("/tools/foo")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["ok"] is True
    assert "data_timestamp" in payload
    assert resp.headers.get("X-Data-Timestamp") is not None


def test_timestamp_injector_list_of_dicts():
    """CER-P5-001: list of dicts → each item gets data_timestamp."""
    application = build_app(name="t", version="v", token_store=TokenStore(tokens={}))

    @application.post("/tools/list_items")
    def list_items():
        return [{"x": 1}, {"x": 2}]

    resp = TestClient(application).post("/tools/list_items")
    payload = resp.json()
    assert isinstance(payload, list)
    assert len(payload) == 2
    for item in payload:
        assert "data_timestamp" in item
    assert resp.headers.get("X-Data-Timestamp") is not None


def test_timestamp_injector_preserves_existing():
    """CER-P5-001: an already-present data_timestamp is not overridden."""
    application = build_app(name="t", version="v", token_store=TokenStore(tokens={}))

    @application.post("/tools/already")
    def already():
        return {"data_timestamp": "2020-01-01T00:00:00Z", "x": 1}

    payload = TestClient(application).post("/tools/already").json()
    assert payload["data_timestamp"] == "2020-01-01T00:00:00Z"


def test_timestamp_injector_empty_list_gets_header_only():
    """CER-P5-001: empty list — body left untouched, but header still set."""
    application = build_app(name="t", version="v", token_store=TokenStore(tokens={}))

    @application.post("/tools/empty_list")
    def empty_list():
        return []

    resp = TestClient(application).post("/tools/empty_list")
    assert resp.json() == []
    assert resp.headers.get("X-Data-Timestamp") is not None
+48
View File
@@ -0,0 +1,48 @@
from pathlib import Path
from option_mcp_common.storage import Database, run_migrations
def test_database_creates_wal(tmp_path: Path):
    """Connecting enables SQLite WAL journal mode."""
    db = Database(tmp_path / "test.db")
    db.connect()
    # WAL mode must be active on a fresh connection.
    journal_mode = db.conn.execute("PRAGMA journal_mode").fetchone()[0]
    assert journal_mode.lower() == "wal"
    db.close()


def test_database_migrations_run_once(tmp_path: Path):
    """Re-running the same migration set is a no-op and version sticks at max."""
    db = Database(tmp_path / "test.db")
    db.connect()
    migrations = {
        1: "CREATE TABLE foo (id INTEGER PRIMARY KEY, name TEXT);",
        2: "ALTER TABLE foo ADD COLUMN value INTEGER DEFAULT 0;",
    }
    run_migrations(db.conn, migrations)
    # Second pass: nothing should be applied again.
    run_migrations(db.conn, migrations)
    columns = [row[1] for row in db.conn.execute("PRAGMA table_info(foo)").fetchall()]
    assert "name" in columns
    assert "value" in columns
    applied = db.conn.execute("SELECT MAX(version) FROM _schema_version").fetchone()[0]
    assert applied == 2
    db.close()


def test_database_partial_migration(tmp_path: Path):
    """Migrations added later are applied on top of already-run ones."""
    db = Database(tmp_path / "test.db")
    db.connect()
    migrations_v1 = {1: "CREATE TABLE foo (id INTEGER);"}
    run_migrations(db.conn, migrations_v1)
    migrations_v2 = {**migrations_v1, 2: "CREATE TABLE bar (id INTEGER);"}
    run_migrations(db.conn, migrations_v2)
    table_names = {row[0] for row in db.conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table'"
    ).fetchall()}
    assert "foo" in table_names
    assert "bar" in table_names
    db.close()