feat: import 6 MCP services + common workspace
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
[project]
|
||||
name = "option-mcp-common"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"fastapi>=0.115",
|
||||
"uvicorn[standard]>=0.30",
|
||||
"mcp>=1.0",
|
||||
"httpx>=0.27",
|
||||
"pydantic>=2.6",
|
||||
"pydantic-settings>=2.3",
|
||||
"python-json-logger>=2.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = ["pytest>=8", "pytest-asyncio>=0.23", "pytest-httpx>=0.30", "ruff>=0.5"]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/option_mcp_common"]
|
||||
@@ -0,0 +1,19 @@
|
||||
from option_mcp_common.models import (
|
||||
Event,
|
||||
EventPriority,
|
||||
EventType,
|
||||
L1State,
|
||||
L2Entry,
|
||||
L3Entry,
|
||||
UserInstruction,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"L1State",
|
||||
"L2Entry",
|
||||
"L3Entry",
|
||||
"Event",
|
||||
"EventPriority",
|
||||
"EventType",
|
||||
"UserInstruction",
|
||||
]
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
|
||||
@dataclass
|
||||
class Principal:
|
||||
name: str
|
||||
capabilities: set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenStore:
|
||||
tokens: dict[str, Principal]
|
||||
|
||||
def get(self, token: str) -> Principal | None:
|
||||
return self.tokens.get(token)
|
||||
|
||||
|
||||
def require_principal(request: Request) -> Principal:
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if not auth.startswith("Bearer "):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "missing bearer token")
|
||||
token = auth[len("Bearer "):].strip()
|
||||
store: TokenStore = request.app.state.token_store
|
||||
principal = store.get(token)
|
||||
if principal is None:
|
||||
raise HTTPException(status.HTTP_403_FORBIDDEN, "invalid token")
|
||||
return principal
|
||||
|
||||
|
||||
def acl_requires(*, core: bool = False, observer: bool = False) -> Callable:
|
||||
"""Decorator: require at least one matching capability."""
|
||||
allowed: set[str] = set()
|
||||
if core:
|
||||
allowed.add("core")
|
||||
if observer:
|
||||
allowed.add("observer")
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
principal = kwargs.get("principal")
|
||||
if principal is None:
|
||||
for a in args:
|
||||
if isinstance(a, Principal):
|
||||
principal = a
|
||||
break
|
||||
if principal is None or not (principal.capabilities & allowed):
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
f"capability required: {allowed}",
|
||||
)
|
||||
return await func(*args, **kwargs) if _is_coro(func) else func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
principal = kwargs.get("principal")
|
||||
if principal is None:
|
||||
for a in args:
|
||||
if isinstance(a, Principal):
|
||||
principal = a
|
||||
break
|
||||
if principal is None or not (principal.capabilities & allowed):
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
f"capability required: {allowed}",
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if _is_coro(func) else sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _is_coro(func: Callable) -> bool:
|
||||
import asyncio
|
||||
return asyncio.iscoroutinefunction(func)
|
||||
|
||||
|
||||
def load_token_store_from_files(
|
||||
core_token_file: str | None,
|
||||
observer_token_file: str | None,
|
||||
) -> TokenStore:
|
||||
tokens: dict[str, Principal] = {}
|
||||
if core_token_file:
|
||||
with open(core_token_file) as f:
|
||||
tokens[f.read().strip()] = Principal(name="core", capabilities={"core"})
|
||||
if observer_token_file:
|
||||
with open(observer_token_file) as f:
|
||||
tokens[f.read().strip()] = Principal(
|
||||
name="observer", capabilities={"observer"}
|
||||
)
|
||||
return TokenStore(tokens=tokens)
|
||||
@@ -0,0 +1,80 @@
|
||||
"""CER-P5-010: env validation policy — fail-fast per mandatory, soft per optional.
|
||||
|
||||
Usage al boot di ogni mcp `__main__.py`:
|
||||
|
||||
from option_mcp_common.env_validation import require_env, optional_env, summarize
|
||||
|
||||
creds_file = require_env("CREDENTIALS_FILE", "deribit credentials JSON path")
|
||||
host = optional_env("HOST", default="0.0.0.0")
|
||||
summarize(["CREDENTIALS_FILE", "HOST", "PORT"])
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MissingEnvError(RuntimeError):
|
||||
"""Mandatory env var absent or empty."""
|
||||
|
||||
|
||||
def require_env(name: str, description: str = "") -> str:
|
||||
"""Fail-fast: raise MissingEnvError se name non presente o vuoto.
|
||||
|
||||
Uscita dal processo con codice 2 se chiamato dal main(). Comporta
|
||||
logging chiaro del missing var prima dell'exit.
|
||||
"""
|
||||
val = (os.environ.get(name) or "").strip()
|
||||
if not val:
|
||||
msg = f"missing mandatory env var: {name}"
|
||||
if description:
|
||||
msg += f" ({description})"
|
||||
logger.error(msg)
|
||||
raise MissingEnvError(msg)
|
||||
return val
|
||||
|
||||
|
||||
def optional_env(name: str, *, default: str = "") -> str:
|
||||
"""Soft: ritorna env o default. Log INFO se default usato."""
|
||||
val = (os.environ.get(name) or "").strip()
|
||||
if not val:
|
||||
if default:
|
||||
logger.info("env %s not set, using default=%r", name, default)
|
||||
return default
|
||||
return val
|
||||
|
||||
|
||||
def summarize(names: list[str]) -> None:
|
||||
"""Log INFO di tutti gli env rilevanti con presenza (mask se SECRET/KEY/TOKEN)."""
|
||||
sensitive_tokens = ("SECRET", "KEY", "TOKEN", "PASSWORD", "CREDENTIAL", "WALLET")
|
||||
for n in names:
|
||||
val = os.environ.get(n)
|
||||
if val is None:
|
||||
logger.info("env[%s]: <unset>", n)
|
||||
continue
|
||||
if any(t in n.upper() for t in sensitive_tokens):
|
||||
logger.info("env[%s]: <set, %d chars>", n, len(val))
|
||||
else:
|
||||
logger.info("env[%s]: %s", n, val)
|
||||
|
||||
|
||||
def fail_fast_if_missing(names: list[str]) -> None:
|
||||
"""Verifica lista di nomi mandatory al boot. Exit 2 se uno solo manca.
|
||||
|
||||
Uso preferito: early call in main() per bloccare boot se config incompleta.
|
||||
"""
|
||||
missing: list[str] = []
|
||||
for n in names:
|
||||
if not (os.environ.get(n) or "").strip():
|
||||
missing.append(n)
|
||||
if missing:
|
||||
logger.error("boot aborted: missing mandatory env vars: %s", missing)
|
||||
print(
|
||||
f"FATAL: missing mandatory env vars: {missing}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(2)
|
||||
@@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def sma(values: list[float], period: int) -> float | None:
|
||||
if len(values) < period:
|
||||
return None
|
||||
return sum(values[-period:]) / period
|
||||
|
||||
|
||||
def rsi(closes: list[float], period: int = 14) -> float | None:
|
||||
if len(closes) < period + 1:
|
||||
return None
|
||||
gains: list[float] = []
|
||||
losses: list[float] = []
|
||||
for i in range(1, len(closes)):
|
||||
delta = closes[i] - closes[i - 1]
|
||||
gains.append(max(delta, 0.0))
|
||||
losses.append(-min(delta, 0.0))
|
||||
avg_gain = sum(gains[:period]) / period
|
||||
avg_loss = sum(losses[:period]) / period
|
||||
for i in range(period, len(gains)):
|
||||
avg_gain = (avg_gain * (period - 1) + gains[i]) / period
|
||||
avg_loss = (avg_loss * (period - 1) + losses[i]) / period
|
||||
if avg_loss == 0:
|
||||
return 100.0
|
||||
rs = avg_gain / avg_loss
|
||||
return 100.0 - (100.0 / (1.0 + rs))
|
||||
|
||||
|
||||
def _ema_series(values: list[float], period: int) -> list[float]:
|
||||
if len(values) < period:
|
||||
return []
|
||||
k = 2.0 / (period + 1)
|
||||
seed = sum(values[:period]) / period
|
||||
out = [seed]
|
||||
for v in values[period:]:
|
||||
out.append(out[-1] + k * (v - out[-1]))
|
||||
return out
|
||||
|
||||
|
||||
def macd(
|
||||
closes: list[float],
|
||||
fast: int = 12,
|
||||
slow: int = 26,
|
||||
signal: int = 9,
|
||||
) -> dict[str, float | None]:
|
||||
nothing: dict[str, float | None] = {"macd": None, "signal": None, "hist": None}
|
||||
if len(closes) < slow + signal:
|
||||
return nothing
|
||||
ema_fast = _ema_series(closes, fast)
|
||||
ema_slow = _ema_series(closes, slow)
|
||||
offset = slow - fast
|
||||
aligned_fast = ema_fast[offset:]
|
||||
macd_line = [f - s for f, s in zip(aligned_fast, ema_slow, strict=False)]
|
||||
if len(macd_line) < signal:
|
||||
return nothing
|
||||
signal_line = _ema_series(macd_line, signal)
|
||||
if not signal_line:
|
||||
return nothing
|
||||
last_macd = macd_line[-1]
|
||||
last_sig = signal_line[-1]
|
||||
return {
|
||||
"macd": last_macd,
|
||||
"signal": last_sig,
|
||||
"hist": last_macd - last_sig,
|
||||
}
|
||||
|
||||
|
||||
def atr(
|
||||
highs: list[float],
|
||||
lows: list[float],
|
||||
closes: list[float],
|
||||
period: int = 14,
|
||||
) -> float | None:
|
||||
if len(closes) < period + 1:
|
||||
return None
|
||||
trs: list[float] = []
|
||||
for i in range(1, len(closes)):
|
||||
tr = max(
|
||||
highs[i] - lows[i],
|
||||
abs(highs[i] - closes[i - 1]),
|
||||
abs(lows[i] - closes[i - 1]),
|
||||
)
|
||||
trs.append(tr)
|
||||
if len(trs) < period:
|
||||
return None
|
||||
avg = sum(trs[:period]) / period
|
||||
for i in range(period, len(trs)):
|
||||
avg = (avg * (period - 1) + trs[i]) / period
|
||||
return avg
|
||||
|
||||
|
||||
def adx(
|
||||
highs: list[float],
|
||||
lows: list[float],
|
||||
closes: list[float],
|
||||
period: int = 14,
|
||||
) -> dict[str, float | None]:
|
||||
nothing: dict[str, float | None] = {"adx": None, "+di": None, "-di": None}
|
||||
if len(closes) < 2 * period + 1:
|
||||
return nothing
|
||||
trs: list[float] = []
|
||||
plus_dms: list[float] = []
|
||||
minus_dms: list[float] = []
|
||||
for i in range(1, len(closes)):
|
||||
tr = max(
|
||||
highs[i] - lows[i],
|
||||
abs(highs[i] - closes[i - 1]),
|
||||
abs(lows[i] - closes[i - 1]),
|
||||
)
|
||||
up = highs[i] - highs[i - 1]
|
||||
dn = lows[i - 1] - lows[i]
|
||||
plus_dm = up if (up > dn and up > 0) else 0.0
|
||||
minus_dm = dn if (dn > up and dn > 0) else 0.0
|
||||
trs.append(tr)
|
||||
plus_dms.append(plus_dm)
|
||||
minus_dms.append(minus_dm)
|
||||
|
||||
atr_s = sum(trs[:period])
|
||||
pdm_s = sum(plus_dms[:period])
|
||||
mdm_s = sum(minus_dms[:period])
|
||||
dxs: list[float] = []
|
||||
pdi = mdi = 0.0
|
||||
for i in range(period, len(trs)):
|
||||
atr_s = atr_s - atr_s / period + trs[i]
|
||||
pdm_s = pdm_s - pdm_s / period + plus_dms[i]
|
||||
mdm_s = mdm_s - mdm_s / period + minus_dms[i]
|
||||
pdi = 100.0 * pdm_s / atr_s if atr_s else 0.0
|
||||
mdi = 100.0 * mdm_s / atr_s if atr_s else 0.0
|
||||
s = pdi + mdi
|
||||
dx = 100.0 * abs(pdi - mdi) / s if s else 0.0
|
||||
dxs.append(dx)
|
||||
|
||||
if len(dxs) < period:
|
||||
return nothing
|
||||
adx_val = sum(dxs[:period]) / period
|
||||
for i in range(period, len(dxs)):
|
||||
adx_val = (adx_val * (period - 1) + dxs[i]) / period
|
||||
return {"adx": adx_val, "+di": pdi, "-di": mdi}
|
||||
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
# pythonjsonlogger rinominato in .json; keep fallback per compat
|
||||
try:
|
||||
from pythonjsonlogger.json import JsonFormatter as _JsonFormatter # noqa: N813
|
||||
except ImportError:
|
||||
from pythonjsonlogger.jsonlogger import JsonFormatter as _JsonFormatter # noqa: N813
|
||||
|
||||
SECRET_PATTERNS = [
|
||||
(re.compile(r"Bearer\s+[\w\-\._]+", re.IGNORECASE), "Bearer ***"),
|
||||
(re.compile(r'("api_key"\s*:\s*")[^"]+(")'), r'\1***\2'),
|
||||
(re.compile(r'("password"\s*:\s*")[^"]+(")'), r'\1***\2'),
|
||||
(re.compile(r'("private_key"\s*:\s*")[^"]+(")'), r'\1***\2'),
|
||||
(re.compile(r'("client_secret"\s*:\s*")[^"]+(")'), r'\1***\2'),
|
||||
(re.compile(r"sk-[\w]{20,}"), "sk-***"),
|
||||
]
|
||||
|
||||
|
||||
class SecretsFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
msg = record.getMessage()
|
||||
for pattern, replacement in SECRET_PATTERNS:
|
||||
msg = pattern.sub(replacement, msg)
|
||||
record.msg = msg
|
||||
record.args = () # already formatted into msg
|
||||
return True
|
||||
|
||||
|
||||
def get_json_logger(name: str, level: int = logging.INFO) -> logging.Logger:
|
||||
logger = logging.getLogger(name)
|
||||
if logger.handlers:
|
||||
return logger # already configured
|
||||
logger.setLevel(level)
|
||||
handler = logging.StreamHandler(sys.stderr)
|
||||
formatter = _JsonFormatter("%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
handler.addFilter(SecretsFilter())
|
||||
logger.addHandler(handler)
|
||||
logger.propagate = False
|
||||
return logger
|
||||
|
||||
|
||||
def configure_root_logging(
|
||||
*,
|
||||
level: str | int | None = None,
|
||||
format_type: str | None = None,
|
||||
) -> None:
|
||||
"""CER-P5-009: configura il root logger con JSON o text formatter.
|
||||
|
||||
Env overrides:
|
||||
- LOG_LEVEL (default INFO)
|
||||
- LOG_FORMAT=json|text (default json — production-ready structured log)
|
||||
|
||||
Applica SecretsFilter su entrambi i format.
|
||||
"""
|
||||
lvl_raw = level if level is not None else os.environ.get("LOG_LEVEL", "INFO")
|
||||
lvl = logging.getLevelName(lvl_raw.upper()) if isinstance(lvl_raw, str) else lvl_raw
|
||||
fmt = (format_type or os.environ.get("LOG_FORMAT") or "json").lower()
|
||||
|
||||
root = logging.getLogger()
|
||||
# Rimuovi handler esistenti (basicConfig li avrebbe lasciati duplicati)
|
||||
for h in list(root.handlers):
|
||||
root.removeHandler(h)
|
||||
|
||||
handler = logging.StreamHandler(sys.stderr)
|
||||
if fmt == "json":
|
||||
handler.setFormatter(
|
||||
_JsonFormatter("%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||
)
|
||||
else:
|
||||
handler.setFormatter(
|
||||
logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
)
|
||||
handler.addFilter(SecretsFilter())
|
||||
root.addHandler(handler)
|
||||
root.setLevel(lvl)
|
||||
@@ -0,0 +1,239 @@
|
||||
"""Bridge MCP → endpoint REST esistenti.
|
||||
|
||||
Implementa manualmente JSON-RPC 2.0 MCP su `POST /mcp` (no SSE, risposta
|
||||
diretta in body JSON). Supporta:
|
||||
- initialize
|
||||
- notifications/initialized
|
||||
- tools/list
|
||||
- tools/call
|
||||
|
||||
Claude Code config esempio:
|
||||
|
||||
{
|
||||
"mcpServers": {
|
||||
"cerbero-memory": {
|
||||
"type": "http",
|
||||
"url": "http://localhost:8080/mcp-memory/mcp",
|
||||
"headers": {"Authorization": "Bearer <observer-token>"}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from option_mcp_common.auth import TokenStore
|
||||
|
||||
MCP_PROTOCOL_VERSION = "2024-11-05"
|
||||
|
||||
|
||||
def _derive_input_schemas(app: FastAPI, tool_names: list[str]) -> dict[str, dict]:
|
||||
"""Estrae JSON schema del body Pydantic per ogni route POST /tools/<name>.
|
||||
|
||||
Risolve annotazioni lazy (PEP 563) via `typing.get_type_hints`.
|
||||
Ritorna mapping {tool_name: json_schema}. Route senza body Pydantic o non
|
||||
risolvibili vengono saltate: il chiamante userà un fallback.
|
||||
"""
|
||||
import typing
|
||||
from pydantic import BaseModel
|
||||
|
||||
names_set = set(tool_names)
|
||||
out: dict[str, dict] = {}
|
||||
for route in app.routes:
|
||||
path = getattr(route, "path", "")
|
||||
if not path.startswith("/tools/"):
|
||||
continue
|
||||
name = path[len("/tools/"):]
|
||||
if name not in names_set:
|
||||
continue
|
||||
endpoint = getattr(route, "endpoint", None)
|
||||
if endpoint is None:
|
||||
continue
|
||||
try:
|
||||
hints = typing.get_type_hints(endpoint)
|
||||
except Exception:
|
||||
continue
|
||||
for pname, ann in hints.items():
|
||||
if pname == "return":
|
||||
continue
|
||||
if isinstance(ann, type) and issubclass(ann, BaseModel):
|
||||
try:
|
||||
out[name] = ann.model_json_schema()
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def _make_proxy_handler(internal_base_url: str, tool_name: str, token: str):
|
||||
async def handler(args: dict | None) -> Any:
|
||||
async with httpx.AsyncClient(timeout=30.0) as c:
|
||||
r = await c.post(
|
||||
f"{internal_base_url}/tools/{tool_name}",
|
||||
headers={"Authorization": f"Bearer {token}"} if token else {},
|
||||
json=args or {},
|
||||
)
|
||||
if r.status_code >= 400:
|
||||
raise RuntimeError(
|
||||
f"tool {tool_name} failed: HTTP {r.status_code} — {r.text[:500]}"
|
||||
)
|
||||
try:
|
||||
return r.json()
|
||||
except Exception:
|
||||
return {"raw": r.text}
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def mount_mcp_endpoint(
|
||||
app: FastAPI,
|
||||
*,
|
||||
name: str,
|
||||
version: str,
|
||||
token_store: TokenStore,
|
||||
internal_base_url: str,
|
||||
tools: list[dict],
|
||||
) -> None:
|
||||
"""Registra un endpoint MCP JSON-RPC 2.0 su POST /mcp.
|
||||
|
||||
Ogni tool è proxato verso POST {internal_base_url}/tools/<name> con il
|
||||
Bearer token del client MCP (preservando le ACL REST esistenti).
|
||||
|
||||
Args:
|
||||
app: istanza FastAPI del service
|
||||
name: nome server MCP
|
||||
version: versione del service
|
||||
token_store: lo stesso usato dai tool REST
|
||||
internal_base_url: URL base interno (es. "http://localhost:9015")
|
||||
tools: lista di {"name": str, "description": str, "input_schema"?: dict}
|
||||
"""
|
||||
tools_by_name = {t["name"]: t for t in tools}
|
||||
|
||||
# Auto-derive input schemas from FastAPI routes (Pydantic body models).
|
||||
# Permette al LLM di conoscere i nomi dei parametri obbligatori invece di
|
||||
# indovinarli. Se il tool ha `input_schema` esplicito, vince sull'auto-derive.
|
||||
derived_schemas = _derive_input_schemas(app, [t["name"] for t in tools])
|
||||
|
||||
def _tool_defs() -> list[dict]:
|
||||
defs = []
|
||||
for t in tools:
|
||||
schema = t.get("input_schema") or derived_schemas.get(t["name"]) or {
|
||||
"type": "object",
|
||||
"additionalProperties": True,
|
||||
}
|
||||
defs.append({
|
||||
"name": t["name"],
|
||||
"description": t.get("description", t["name"]),
|
||||
"inputSchema": schema,
|
||||
})
|
||||
return defs
|
||||
|
||||
async def _handle_rpc(body: dict, token: str) -> dict | None:
|
||||
rpc_id = body.get("id")
|
||||
method = body.get("method")
|
||||
params = body.get("params") or {}
|
||||
|
||||
# Notification (no id) → no response
|
||||
if method == "notifications/initialized":
|
||||
return None
|
||||
|
||||
if method == "initialize":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": rpc_id,
|
||||
"result": {
|
||||
"protocolVersion": MCP_PROTOCOL_VERSION,
|
||||
"capabilities": {"tools": {"listChanged": False}},
|
||||
"serverInfo": {"name": name, "version": version},
|
||||
},
|
||||
}
|
||||
|
||||
if method == "tools/list":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": rpc_id,
|
||||
"result": {"tools": _tool_defs()},
|
||||
}
|
||||
|
||||
if method == "tools/call":
|
||||
tool_name = params.get("name", "")
|
||||
args = params.get("arguments") or {}
|
||||
if tool_name not in tools_by_name:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": rpc_id,
|
||||
"error": {"code": -32601, "message": f"tool non trovato: {tool_name}"},
|
||||
}
|
||||
handler = _make_proxy_handler(internal_base_url, tool_name, token)
|
||||
try:
|
||||
result = await handler(args)
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": rpc_id,
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": _to_text(result),
|
||||
}
|
||||
],
|
||||
"isError": False,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": rpc_id,
|
||||
"result": {
|
||||
"content": [{"type": "text", "text": str(e)}],
|
||||
"isError": True,
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": rpc_id,
|
||||
"error": {"code": -32601, "message": f"metodo non supportato: {method}"},
|
||||
}
|
||||
|
||||
@app.post("/mcp")
|
||||
async def mcp_entry(request: Request):
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if not auth.startswith("Bearer "):
|
||||
return JSONResponse({"error": "missing bearer token"}, status_code=401)
|
||||
token = auth[len("Bearer "):].strip()
|
||||
principal = token_store.get(token)
|
||||
if principal is None:
|
||||
return JSONResponse({"error": "invalid token"}, status_code=403)
|
||||
|
||||
body = await request.json()
|
||||
|
||||
# Batch support
|
||||
if isinstance(body, list):
|
||||
results = []
|
||||
for item in body:
|
||||
resp = await _handle_rpc(item, token)
|
||||
if resp is not None:
|
||||
results.append(resp)
|
||||
return JSONResponse(results)
|
||||
|
||||
resp = await _handle_rpc(body, token)
|
||||
if resp is None:
|
||||
# Notification (no id) → 204 no content
|
||||
return JSONResponse(None, status_code=204)
|
||||
return JSONResponse(resp)
|
||||
|
||||
|
||||
def _to_text(value: Any) -> str:
|
||||
import json
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.dumps(value, ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
return str(value)
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from functools import total_ordering
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@total_ordering
|
||||
class EventPriority(StrEnum):
|
||||
LOW = "low"
|
||||
NORMAL = "normal"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
def _rank(self) -> int:
|
||||
return ["low", "normal", "high", "critical"].index(self.value)
|
||||
|
||||
def __lt__(self, other: EventPriority) -> bool:
|
||||
return self._rank() < other._rank()
|
||||
|
||||
|
||||
class EventType(StrEnum):
|
||||
ALERT = "alert"
|
||||
USER_INSTRUCTION = "user_instruction"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class L1State(BaseModel):
|
||||
"""Singleton row with current operational state."""
|
||||
|
||||
updated_at: str
|
||||
equity_total: float | None = None
|
||||
equity_by_exchange: dict[str, float] = Field(default_factory=dict)
|
||||
bias: str | None = None
|
||||
pnl_day: float | None = None
|
||||
pnl_total: float | None = None
|
||||
capital: float | None = None
|
||||
open_positions_count: int = 0
|
||||
greeks_aggregate: dict[str, float] = Field(default_factory=dict)
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
class L2Entry(BaseModel):
|
||||
"""Reasoning entry — schema matches system_prompt v2.
|
||||
|
||||
`authored_by_model`: identifica l'LLM che ha generato la entry (es.
|
||||
"google/gemini-3-flash-preview" per core, "anthropic/claude-haiku-4-5"
|
||||
per worker). None se scritto da operatore umano via UI.
|
||||
"""
|
||||
|
||||
timestamp: str
|
||||
setup: str
|
||||
tesi: str | None = None
|
||||
tesi_check: str | None = None
|
||||
invalidation: str | None = None
|
||||
esito: str
|
||||
scostamento: str | None = None
|
||||
scostamento_sigma: float | None = None
|
||||
lezione: str | None = None
|
||||
sizing_note: str | None = None
|
||||
run_id: str | None = None
|
||||
user_instruction_id: int | None = None
|
||||
authored_by_model: str | None = None
|
||||
|
||||
|
||||
class L3Entry(BaseModel):
|
||||
"""Compacted pattern from L2 batch."""
|
||||
|
||||
created_at: str
|
||||
category: str # "pattern_errore" | "pattern_vincente" | "correlazione"
|
||||
summary: str
|
||||
source_l2_ids: list[int]
|
||||
|
||||
|
||||
class Event(BaseModel):
|
||||
id: int | None = None
|
||||
created_at: str
|
||||
expires_at: str
|
||||
type: EventType
|
||||
source: str
|
||||
priority: EventPriority
|
||||
payload: dict[str, Any]
|
||||
acked_at: str | None = None
|
||||
ack_outcome: str | None = None
|
||||
ack_notes: str | None = None
|
||||
|
||||
|
||||
class UserInstruction(BaseModel):
|
||||
id: int | None = None
|
||||
created_at: str
|
||||
text: str
|
||||
priority: EventPriority
|
||||
require_ack: bool = True
|
||||
source: str = "observer"
|
||||
acked_at: str | None = None
|
||||
ack_outcome: str | None = None
|
||||
@@ -0,0 +1,92 @@
|
||||
"""CER-016 hard guard server-side su place_order.
|
||||
|
||||
Caps configurabili via env (default safety-first, mirati a ~200 EUR single,
|
||||
1000 EUR aggregato, 3x max leverage).
|
||||
|
||||
Thresholds sono numerici semplici — l'operatore stabilisce unità (EUR/USD)
|
||||
via env; il server compara su un unico campo `notional` in valore monetario.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def _env_float(name: str, default: float) -> float:
|
||||
raw = os.environ.get(name)
|
||||
if not raw:
|
||||
return default
|
||||
try:
|
||||
return float(raw)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _env_int(name: str, default: int) -> int:
|
||||
raw = os.environ.get(name)
|
||||
if not raw:
|
||||
return default
|
||||
try:
|
||||
return int(raw)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def max_notional() -> float:
|
||||
return _env_float("CERBERO_MAX_NOTIONAL", 200.0)
|
||||
|
||||
|
||||
def max_aggregate() -> float:
|
||||
return _env_float("CERBERO_MAX_AGGREGATE", 1000.0)
|
||||
|
||||
|
||||
def max_leverage() -> int:
|
||||
return _env_int("CERBERO_MAX_LEVERAGE", 3)
|
||||
|
||||
|
||||
def _hard_reject(reason: str) -> None:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "HARD_PROHIBITION",
|
||||
"message": reason,
|
||||
"caps": {
|
||||
"max_notional": max_notional(),
|
||||
"max_aggregate": max_aggregate(),
|
||||
"max_leverage": max_leverage(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def enforce_leverage(leverage: int | float | None) -> int:
|
||||
"""Ritorna leverage applicabile. Default 3x se None. Reject se > cap."""
|
||||
cap = max_leverage()
|
||||
if leverage is None:
|
||||
return cap
|
||||
lev = int(leverage)
|
||||
if lev < 1:
|
||||
_hard_reject(f"leverage must be >= 1 (got {lev})")
|
||||
if lev > cap:
|
||||
_hard_reject(f"leverage {lev}x exceeds hard cap {cap}x")
|
||||
return lev
|
||||
|
||||
|
||||
def enforce_single_notional(notional: float, *, exchange: str, instrument: str) -> None:
|
||||
cap = max_notional()
|
||||
if notional > cap:
|
||||
_hard_reject(
|
||||
f"{exchange}.{instrument} notional {notional:.2f} exceeds single trade cap {cap:.2f}"
|
||||
)
|
||||
|
||||
|
||||
def enforce_aggregate(current_total: float, new_notional: float) -> None:
|
||||
cap = max_aggregate()
|
||||
total = current_total + new_notional
|
||||
if total > cap:
|
||||
_hard_reject(
|
||||
f"aggregate notional {total:.2f} (current {current_total:.2f} + new "
|
||||
f"{new_notional:.2f}) exceeds cap {cap:.2f}"
|
||||
)
|
||||
@@ -0,0 +1,220 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from option_mcp_common.auth import TokenStore
|
||||
|
||||
Lifespan = Callable[[FastAPI], AbstractAsyncContextManager[None]]
|
||||
|
||||
|
||||
def _error_envelope(
|
||||
*,
|
||||
type_: str,
|
||||
code: str,
|
||||
message: str,
|
||||
retryable: bool,
|
||||
suggested_fix: str | None = None,
|
||||
details: dict | None = None,
|
||||
request_id: str | None = None,
|
||||
) -> dict:
|
||||
env = {
|
||||
"error": {
|
||||
"type": type_,
|
||||
"code": code,
|
||||
"message": message,
|
||||
"retryable": retryable,
|
||||
},
|
||||
"request_id": request_id or uuid.uuid4().hex,
|
||||
"data_timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
if suggested_fix:
|
||||
env["error"]["suggested_fix"] = suggested_fix
|
||||
if details:
|
||||
env["error"]["details"] = details
|
||||
return env
|
||||
|
||||
|
||||
class _TimestampInjectorMiddleware(BaseHTTPMiddleware):
|
||||
"""CER-P5-001: inietta data_timestamp nei response tool.
|
||||
|
||||
- Dict response: body gains `data_timestamp` se mancante.
|
||||
- List of dicts: ogni item gains `data_timestamp` se mancante.
|
||||
- Header `X-Data-Timestamp` sempre presente (universale per list primitive).
|
||||
Skips /health (già popolato) e /mcp (JSON-RPC bridge) e non-JSON responses.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
path = request.url.path
|
||||
if not path.startswith("/tools/"):
|
||||
return response
|
||||
ctype = response.headers.get("content-type", "")
|
||||
if "application/json" not in ctype:
|
||||
return response
|
||||
body = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body += chunk
|
||||
ts = datetime.now(UTC).isoformat()
|
||||
try:
|
||||
data = json.loads(body) if body else None
|
||||
except Exception:
|
||||
headers = dict(response.headers)
|
||||
headers["X-Data-Timestamp"] = ts
|
||||
return Response(
|
||||
content=body,
|
||||
status_code=response.status_code,
|
||||
headers=headers,
|
||||
media_type=response.media_type,
|
||||
)
|
||||
|
||||
modified = False
|
||||
if isinstance(data, dict) and "data_timestamp" not in data:
|
||||
data["data_timestamp"] = ts
|
||||
modified = True
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
if isinstance(item, dict) and "data_timestamp" not in item:
|
||||
item["data_timestamp"] = ts
|
||||
modified = True
|
||||
|
||||
headers = dict(response.headers)
|
||||
headers["X-Data-Timestamp"] = ts
|
||||
if modified:
|
||||
new_body = json.dumps(data, default=str).encode()
|
||||
headers.pop("content-length", None)
|
||||
return Response(
|
||||
content=new_body,
|
||||
status_code=response.status_code,
|
||||
headers=headers,
|
||||
media_type="application/json",
|
||||
)
|
||||
return Response(
|
||||
content=body,
|
||||
status_code=response.status_code,
|
||||
headers=headers,
|
||||
media_type=response.media_type,
|
||||
)
|
||||
|
||||
|
||||
def build_app(
|
||||
*,
|
||||
name: str,
|
||||
version: str,
|
||||
token_store: TokenStore,
|
||||
lifespan: Lifespan | None = None,
|
||||
) -> FastAPI:
|
||||
root_path = os.getenv("ROOT_PATH", "")
|
||||
app = FastAPI(title=name, version=version, root_path=root_path, lifespan=lifespan)
|
||||
app.state.token_store = token_store
|
||||
app.state.boot_at = time.time()
|
||||
|
||||
app.add_middleware(_TimestampInjectorMiddleware)
|
||||
|
||||
@app.middleware("http")
|
||||
async def _latency_header(request: Request, call_next):
|
||||
t0 = time.perf_counter()
|
||||
response = await call_next(request)
|
||||
dur_ms = (time.perf_counter() - t0) * 1000
|
||||
response.headers["X-Duration-Ms"] = f"{dur_ms:.2f}"
|
||||
return response
|
||||
|
||||
# CER-P5-002 error envelope: exception handlers globali
|
||||
@app.exception_handler(HTTPException)
|
||||
async def _http_exc(request: Request, exc: HTTPException):
|
||||
retryable = exc.status_code in (408, 429, 502, 503, 504)
|
||||
code_map = {
|
||||
400: "BAD_REQUEST", 401: "UNAUTHORIZED", 403: "FORBIDDEN",
|
||||
404: "NOT_FOUND", 408: "TIMEOUT", 409: "CONFLICT",
|
||||
422: "VALIDATION_ERROR", 429: "RATE_LIMIT",
|
||||
500: "INTERNAL_ERROR", 502: "UPSTREAM_ERROR",
|
||||
503: "UNAVAILABLE", 504: "GATEWAY_TIMEOUT",
|
||||
}
|
||||
code = code_map.get(exc.status_code, f"HTTP_{exc.status_code}")
|
||||
message = "HTTP error"
|
||||
details: dict | None = None
|
||||
detail = exc.detail
|
||||
# Preserve rail-style detail {"error": "..", "message": ".."} as code
|
||||
if isinstance(detail, dict):
|
||||
if isinstance(detail.get("error"), str):
|
||||
code = detail["error"].upper()
|
||||
message = str(detail.get("message") or detail.get("error") or message)
|
||||
details = detail
|
||||
elif isinstance(detail, str):
|
||||
message = detail
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=_error_envelope(
|
||||
type_="http_error",
|
||||
code=code,
|
||||
message=message,
|
||||
retryable=retryable,
|
||||
details=details,
|
||||
),
|
||||
)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def _validation_exc(request: Request, exc: RequestValidationError):
|
||||
errs = exc.errors()
|
||||
first_loc = ".".join(str(x) for x in errs[0]["loc"]) if errs else "body"
|
||||
suggestion = (
|
||||
f"check field '{first_loc}': "
|
||||
+ (errs[0]["msg"] if errs else "invalid input")
|
||||
)
|
||||
# Sanitize ctx values: pydantic v2 può mettere ValueError in ctx['error'],
|
||||
# non serializzabile JSON. Riduci a stringhe.
|
||||
safe_errs: list[dict] = []
|
||||
for e in errs[:5]:
|
||||
ne: dict = {}
|
||||
for k, v in e.items():
|
||||
if k == "ctx" and isinstance(v, dict):
|
||||
ne[k] = {ck: str(cv) for ck, cv in v.items()}
|
||||
else:
|
||||
ne[k] = v
|
||||
safe_errs.append(ne)
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content=_error_envelope(
|
||||
type_="validation_error",
|
||||
code="INVALID_INPUT",
|
||||
message=f"request body validation failed on {first_loc}",
|
||||
retryable=False,
|
||||
suggested_fix=suggestion,
|
||||
details={"errors": safe_errs},
|
||||
),
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def _unhandled(request: Request, exc: Exception):
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=_error_envelope(
|
||||
type_="internal_error",
|
||||
code="UNHANDLED_EXCEPTION",
|
||||
message=f"{type(exc).__name__}: {str(exc)[:300]}",
|
||||
retryable=True,
|
||||
),
|
||||
)
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"name": name,
|
||||
"version": version,
|
||||
"uptime_seconds": int(time.time() - app.state.boot_at),
|
||||
"data_timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
|
||||
return app
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class Database:
|
||||
path: Path
|
||||
conn: sqlite3.Connection | None = None
|
||||
|
||||
def connect(self) -> sqlite3.Connection:
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.conn = sqlite3.connect(
|
||||
str(self.path),
|
||||
isolation_level=None,
|
||||
check_same_thread=False,
|
||||
)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
self.conn.execute("PRAGMA journal_mode=WAL")
|
||||
self.conn.execute("PRAGMA synchronous=NORMAL")
|
||||
self.conn.execute("PRAGMA foreign_keys=ON")
|
||||
return self.conn
|
||||
|
||||
def close(self) -> None:
|
||||
if self.conn is not None:
|
||||
self.conn.close()
|
||||
self.conn = None
|
||||
|
||||
|
||||
def run_migrations(conn: sqlite3.Connection, migrations: dict[int, str]) -> None:
|
||||
"""Idempotent migrations. `migrations` keys are monotonic version numbers."""
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS _schema_version (version INTEGER PRIMARY KEY)"
|
||||
)
|
||||
cur = conn.execute("SELECT COALESCE(MAX(version), 0) FROM _schema_version")
|
||||
current = cur.fetchone()[0]
|
||||
for version in sorted(migrations):
|
||||
if version <= current:
|
||||
continue
|
||||
conn.executescript(migrations[version])
|
||||
conn.execute("INSERT INTO _schema_version (version) VALUES (?)", (version,))
|
||||
@@ -0,0 +1,84 @@
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from option_mcp_common.auth import (
|
||||
Principal,
|
||||
TokenStore,
|
||||
acl_requires,
|
||||
require_principal,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token_store():
|
||||
return TokenStore(tokens={
|
||||
"token-core-123": Principal(name="core", capabilities={"core"}),
|
||||
"token-obs-456": Principal(name="observer", capabilities={"observer"}),
|
||||
})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(token_store):
|
||||
app = FastAPI()
|
||||
app.state.token_store = token_store
|
||||
|
||||
@app.get("/public")
|
||||
def public():
|
||||
return {"ok": True}
|
||||
|
||||
@app.get("/private")
|
||||
def private(principal: Principal = Depends(require_principal)):
|
||||
return {"name": principal.name}
|
||||
|
||||
@app.post("/core-only")
|
||||
@acl_requires(core=True, observer=False)
|
||||
def core_only(principal: Principal = Depends(require_principal)):
|
||||
return {"who": principal.name}
|
||||
|
||||
@app.post("/observer-only")
|
||||
@acl_requires(core=False, observer=True)
|
||||
def observer_only(principal: Principal = Depends(require_principal)):
|
||||
return {"who": principal.name}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_public_endpoint_no_auth(app):
|
||||
client = TestClient(app)
|
||||
assert client.get("/public").status_code == 200
|
||||
|
||||
|
||||
def test_private_without_header_401(app):
|
||||
client = TestClient(app)
|
||||
assert client.get("/private").status_code == 401
|
||||
|
||||
|
||||
def test_private_bad_token_403(app):
|
||||
client = TestClient(app)
|
||||
r = client.get("/private", headers={"Authorization": "Bearer nope"})
|
||||
assert r.status_code == 403
|
||||
|
||||
|
||||
def test_private_good_token_200(app):
|
||||
client = TestClient(app)
|
||||
r = client.get("/private", headers={"Authorization": "Bearer token-core-123"})
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"name": "core"}
|
||||
|
||||
|
||||
def test_acl_core_token_on_core_only_endpoint(app):
|
||||
client = TestClient(app)
|
||||
r = client.post("/core-only", headers={"Authorization": "Bearer token-core-123"})
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
def test_acl_observer_on_core_only_rejected(app):
|
||||
client = TestClient(app)
|
||||
r = client.post("/core-only", headers={"Authorization": "Bearer token-obs-456"})
|
||||
assert r.status_code == 403
|
||||
|
||||
|
||||
def test_acl_observer_on_observer_only_ok(app):
|
||||
client = TestClient(app)
|
||||
r = client.post("/observer-only", headers={"Authorization": "Bearer token-obs-456"})
|
||||
assert r.status_code == 200
|
||||
@@ -0,0 +1,71 @@
|
||||
"""CER-P5-010 env validation tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from option_mcp_common.env_validation import (
|
||||
MissingEnvError,
|
||||
fail_fast_if_missing,
|
||||
optional_env,
|
||||
require_env,
|
||||
summarize,
|
||||
)
|
||||
|
||||
|
||||
def test_require_env_present(monkeypatch):
|
||||
monkeypatch.setenv("FOO_KEY", "value1")
|
||||
assert require_env("FOO_KEY") == "value1"
|
||||
|
||||
|
||||
def test_require_env_missing_raises(monkeypatch):
|
||||
monkeypatch.delenv("MISSING_REQ", raising=False)
|
||||
with pytest.raises(MissingEnvError):
|
||||
require_env("MISSING_REQ", "critical path")
|
||||
|
||||
|
||||
def test_require_env_empty_raises(monkeypatch):
|
||||
monkeypatch.setenv("EMPTY_REQ", "")
|
||||
with pytest.raises(MissingEnvError):
|
||||
require_env("EMPTY_REQ")
|
||||
|
||||
|
||||
def test_require_env_whitespace_only_raises(monkeypatch):
|
||||
monkeypatch.setenv("WS_REQ", " ")
|
||||
with pytest.raises(MissingEnvError):
|
||||
require_env("WS_REQ")
|
||||
|
||||
|
||||
def test_optional_env_default(monkeypatch):
|
||||
monkeypatch.delenv("OPT_A", raising=False)
|
||||
assert optional_env("OPT_A", default="fallback") == "fallback"
|
||||
|
||||
|
||||
def test_optional_env_set(monkeypatch):
|
||||
monkeypatch.setenv("OPT_B", "xx")
|
||||
assert optional_env("OPT_B", default="fallback") == "xx"
|
||||
|
||||
|
||||
def test_fail_fast_all_present(monkeypatch):
|
||||
monkeypatch.setenv("AA", "1")
|
||||
monkeypatch.setenv("BB", "2")
|
||||
fail_fast_if_missing(["AA", "BB"]) # no exit
|
||||
|
||||
|
||||
def test_fail_fast_missing_exits(monkeypatch):
|
||||
monkeypatch.setenv("HAVE_IT", "1")
|
||||
monkeypatch.delenv("MISSING_X", raising=False)
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
fail_fast_if_missing(["HAVE_IT", "MISSING_X"])
|
||||
assert exc.value.code == 2
|
||||
|
||||
|
||||
def test_summarize_does_not_leak_secrets(monkeypatch, caplog):
|
||||
import logging
|
||||
monkeypatch.setenv("API_KEY_FOO", "super-secret-token-123456")
|
||||
monkeypatch.setenv("PORT", "9000")
|
||||
with caplog.at_level(logging.INFO, logger="option_mcp_common.env_validation"):
|
||||
summarize(["API_KEY_FOO", "PORT", "NOT_SET_XYZ"])
|
||||
log_text = "\n".join(caplog.messages)
|
||||
assert "super-secret-token-123456" not in log_text
|
||||
assert "9000" in log_text
|
||||
assert "<unset>" in log_text
|
||||
@@ -0,0 +1,80 @@
|
||||
|
||||
from option_mcp_common.indicators import adx, atr, macd, rsi, sma
|
||||
|
||||
|
||||
def test_rsi_simple():
|
||||
closes = [44, 44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84,
|
||||
46.08, 45.89, 46.03, 45.61, 46.28]
|
||||
r = rsi(closes, period=14)
|
||||
assert r is not None
|
||||
# Known textbook RSI value ballpark
|
||||
assert 65.0 < r < 75.0
|
||||
|
||||
|
||||
def test_rsi_insufficient_data():
|
||||
assert rsi([1, 2, 3], period=14) is None
|
||||
|
||||
|
||||
def test_sma_simple():
|
||||
assert sma([1, 2, 3, 4, 5], period=5) == 3.0
|
||||
assert sma([1, 2, 3], period=5) is None
|
||||
|
||||
|
||||
def test_atr_simple():
|
||||
# highs, lows, closes
|
||||
highs = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
|
||||
lows = [ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
|
||||
closes = [9.5,10.5,11.5,12.5,13.5,14.5,15.5,16.5,17.5,18.5,19.5,20.5,21.5,22.5,23.5]
|
||||
a = atr(highs, lows, closes, period=14)
|
||||
assert a is not None
|
||||
assert 0.9 < a <= 1.5
|
||||
|
||||
|
||||
def test_macd_trend_up():
|
||||
# monotonic uptrend → MACD > 0, histogram > 0
|
||||
closes = [float(i) for i in range(1, 60)]
|
||||
m = macd(closes)
|
||||
assert m["macd"] is not None
|
||||
assert m["signal"] is not None
|
||||
assert m["hist"] is not None
|
||||
assert m["macd"] > 0
|
||||
assert m["hist"] >= 0
|
||||
|
||||
|
||||
def test_macd_insufficient_data():
|
||||
m = macd([1.0, 2.0, 3.0])
|
||||
assert m == {"macd": None, "signal": None, "hist": None}
|
||||
|
||||
|
||||
def test_macd_trend_down():
|
||||
closes = [float(i) for i in range(60, 1, -1)]
|
||||
m = macd(closes)
|
||||
assert m["macd"] < 0
|
||||
assert m["hist"] <= 0
|
||||
|
||||
|
||||
def test_adx_insufficient_data():
|
||||
a = adx([1.0] * 10, [0.5] * 10, [0.7] * 10, period=14)
|
||||
assert a == {"adx": None, "+di": None, "-di": None}
|
||||
|
||||
|
||||
def test_adx_strong_uptrend():
|
||||
highs = [float(i) + 1.0 for i in range(1, 40)]
|
||||
lows = [float(i) for i in range(1, 40)]
|
||||
closes = [float(i) + 0.5 for i in range(1, 40)]
|
||||
a = adx(highs, lows, closes, period=14)
|
||||
assert a["adx"] is not None
|
||||
assert a["+di"] is not None and a["-di"] is not None
|
||||
# strong uptrend → +DI >> -DI, ADX high
|
||||
assert a["+di"] > a["-di"]
|
||||
assert a["adx"] > 50.0
|
||||
|
||||
|
||||
def test_adx_flat_market():
|
||||
highs = [10.0] * 40
|
||||
lows = [9.0] * 40
|
||||
closes = [9.5] * 40
|
||||
a = adx(highs, lows, closes, period=14)
|
||||
# no directional movement → ADX near 0
|
||||
assert a["adx"] is not None
|
||||
assert a["adx"] < 5.0
|
||||
@@ -0,0 +1,77 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from option_mcp_common.logging import (
|
||||
SecretsFilter,
|
||||
configure_root_logging,
|
||||
get_json_logger,
|
||||
)
|
||||
|
||||
|
||||
def test_secrets_filter_masks_bearer():
|
||||
f = SecretsFilter()
|
||||
rec = logging.LogRecord(
|
||||
name="t", level=logging.INFO, pathname="", lineno=0,
|
||||
msg="Got Bearer abcdef123456 from client",
|
||||
args=(), exc_info=None,
|
||||
)
|
||||
f.filter(rec)
|
||||
assert "abcdef" not in rec.msg
|
||||
assert "***" in rec.msg
|
||||
|
||||
|
||||
def test_secrets_filter_masks_api_key_json():
|
||||
f = SecretsFilter()
|
||||
rec = logging.LogRecord(
|
||||
name="t", level=logging.INFO, pathname="", lineno=0,
|
||||
msg='{"api_key": "sk-live-abc123xyz"}',
|
||||
args=(), exc_info=None,
|
||||
)
|
||||
f.filter(rec)
|
||||
assert "sk-live-abc123xyz" not in rec.msg
|
||||
|
||||
|
||||
def test_json_logger_outputs_json(capsys):
|
||||
logger = get_json_logger("test")
|
||||
logger.info("hello", extra={"user_id": 42})
|
||||
captured = capsys.readouterr()
|
||||
# output is on stderr by default for json logger
|
||||
line = (captured.err or captured.out).strip().splitlines()[-1]
|
||||
data = json.loads(line)
|
||||
assert data["message"] == "hello"
|
||||
assert data["user_id"] == 42
|
||||
|
||||
|
||||
def test_configure_root_json_format(monkeypatch, capsys):
|
||||
monkeypatch.setenv("LOG_FORMAT", "json")
|
||||
monkeypatch.setenv("LOG_LEVEL", "INFO")
|
||||
configure_root_logging()
|
||||
logging.info("root json test")
|
||||
line = capsys.readouterr().err.strip().splitlines()[-1]
|
||||
data = json.loads(line)
|
||||
assert data["message"] == "root json test"
|
||||
assert data["levelname"] == "INFO"
|
||||
|
||||
|
||||
def test_configure_root_text_format(monkeypatch, capsys):
|
||||
monkeypatch.setenv("LOG_FORMAT", "text")
|
||||
configure_root_logging()
|
||||
logging.info("root text test")
|
||||
line = capsys.readouterr().err.strip().splitlines()[-1]
|
||||
# text format non è JSON parseable
|
||||
try:
|
||||
json.loads(line)
|
||||
raise AssertionError("expected text format, got JSON")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
assert "root text test" in line
|
||||
|
||||
|
||||
def test_configure_root_applies_secrets_filter(monkeypatch, capsys):
|
||||
monkeypatch.setenv("LOG_FORMAT", "json")
|
||||
configure_root_logging()
|
||||
logging.info("calling with Bearer sk-live-leak123456 token")
|
||||
line = capsys.readouterr().err.strip().splitlines()[-1]
|
||||
data = json.loads(line)
|
||||
assert "sk-live-leak123456" not in data["message"]
|
||||
assert "***" in data["message"]
|
||||
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from option_mcp_common.auth import Principal, TokenStore, require_principal
|
||||
from option_mcp_common.mcp_bridge import _derive_input_schemas, mount_mcp_endpoint
|
||||
from option_mcp_common.server import build_app
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class EchoBody(BaseModel):
|
||||
msg: str
|
||||
n: int = 1
|
||||
|
||||
|
||||
def _make_app() -> tuple[FastAPI, TokenStore]:
|
||||
store = TokenStore(tokens={"t": Principal("obs", {"observer"})})
|
||||
app = build_app(name="t", version="v", token_store=store)
|
||||
|
||||
@app.post("/tools/echo")
|
||||
def echo(body: EchoBody, principal: Principal = Depends(require_principal)):
|
||||
return {"echo": body.msg, "n": body.n}
|
||||
|
||||
@app.post("/tools/ping")
|
||||
def ping(principal: Principal = Depends(require_principal)):
|
||||
return {"pong": True}
|
||||
|
||||
return app, store
|
||||
|
||||
|
||||
def test_derive_input_schemas_resolves_lazy_annotations():
|
||||
app, _ = _make_app()
|
||||
schemas = _derive_input_schemas(app, ["echo", "ping"])
|
||||
assert "echo" in schemas
|
||||
echo_schema = schemas["echo"]
|
||||
assert echo_schema["type"] == "object"
|
||||
assert "msg" in echo_schema["properties"]
|
||||
assert "n" in echo_schema["properties"]
|
||||
assert "msg" in echo_schema["required"]
|
||||
# ping has no Pydantic body → not in map (fallback applied by caller)
|
||||
assert "ping" not in schemas
|
||||
|
||||
|
||||
def test_mount_mcp_endpoint_exposes_derived_schemas():
|
||||
app, store = _make_app()
|
||||
mount_mcp_endpoint(
|
||||
app,
|
||||
name="test",
|
||||
version="1.0",
|
||||
token_store=store,
|
||||
internal_base_url="http://localhost:0",
|
||||
tools=[
|
||||
{"name": "echo", "description": "Echo a message."},
|
||||
{"name": "ping", "description": "Ping."},
|
||||
],
|
||||
)
|
||||
c = TestClient(app)
|
||||
r = c.post(
|
||||
"/mcp",
|
||||
headers={"Authorization": "Bearer t"},
|
||||
json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"},
|
||||
)
|
||||
assert r.status_code == 200
|
||||
tools = r.json()["result"]["tools"]
|
||||
by_name = {t["name"]: t for t in tools}
|
||||
assert set(by_name["echo"]["inputSchema"]["required"]) == {"msg"}
|
||||
# ping fallback su schema generico
|
||||
assert by_name["ping"]["inputSchema"] == {
|
||||
"type": "object",
|
||||
"additionalProperties": True,
|
||||
}
|
||||
|
||||
|
||||
def test_mount_mcp_endpoint_requires_auth():
|
||||
app, store = _make_app()
|
||||
mount_mcp_endpoint(
|
||||
app,
|
||||
name="test",
|
||||
version="1.0",
|
||||
token_store=store,
|
||||
internal_base_url="http://localhost:0",
|
||||
tools=[{"name": "echo"}],
|
||||
)
|
||||
c = TestClient(app)
|
||||
r = c.post("/mcp", json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"})
|
||||
assert r.status_code == 401
|
||||
r = c.post(
|
||||
"/mcp",
|
||||
headers={"Authorization": "Bearer WRONG"},
|
||||
json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"},
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
|
||||
def test_explicit_input_schema_overrides_derived():
|
||||
app, store = _make_app()
|
||||
custom = {"type": "object", "properties": {"custom": {"type": "string"}}, "required": ["custom"]}
|
||||
mount_mcp_endpoint(
|
||||
app,
|
||||
name="test",
|
||||
version="1.0",
|
||||
token_store=store,
|
||||
internal_base_url="http://localhost:0",
|
||||
tools=[{"name": "echo", "input_schema": custom}],
|
||||
)
|
||||
c = TestClient(app)
|
||||
r = c.post(
|
||||
"/mcp",
|
||||
headers={"Authorization": "Bearer t"},
|
||||
json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"},
|
||||
)
|
||||
assert r.json()["result"]["tools"][0]["inputSchema"] == custom
|
||||
@@ -0,0 +1,40 @@
|
||||
from option_mcp_common.models import EventPriority, EventType, L2Entry
|
||||
|
||||
|
||||
def test_l2_entry_minimal():
|
||||
entry = L2Entry(
|
||||
timestamp="2026-04-17T10:30:00Z",
|
||||
setup="bull put spread ETH 1800/1750 14d",
|
||||
tesi="IV alta post-CPI, attesa mean-reversion",
|
||||
esito="aperto"
|
||||
)
|
||||
assert entry.scostamento_sigma is None
|
||||
assert entry.tesi_check is None
|
||||
|
||||
|
||||
def test_l2_entry_full():
|
||||
entry = L2Entry(
|
||||
timestamp="2026-04-17T10:30:00Z",
|
||||
setup="bull put spread ETH 1800/1750 14d",
|
||||
tesi="IV alta post-CPI",
|
||||
tesi_check="ETH sopra 1820 per 24h con IV in calo",
|
||||
invalidation="rottura 1800 con volume > 2x media",
|
||||
esito="chiuso +12 USDC",
|
||||
scostamento="nessuno",
|
||||
scostamento_sigma=0.5,
|
||||
lezione="supporto ha tenuto, timing ok",
|
||||
sizing_note="size 80 USDC (ATR 1.3x media)",
|
||||
)
|
||||
assert entry.scostamento_sigma == 0.5
|
||||
dump = entry.model_dump()
|
||||
assert dump["lezione"] == "supporto ha tenuto, timing ok"
|
||||
|
||||
|
||||
def test_event_priority_enum():
|
||||
assert EventPriority.CRITICAL.value == "critical"
|
||||
assert EventPriority.LOW < EventPriority.CRITICAL # ordering
|
||||
|
||||
|
||||
def test_event_type_enum():
|
||||
assert EventType.ALERT.value == "alert"
|
||||
assert EventType.USER_INSTRUCTION.value == "user_instruction"
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Tests for CER-016 risk guard."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from option_mcp_common.risk_guard import (
|
||||
enforce_aggregate,
|
||||
enforce_leverage,
|
||||
enforce_single_notional,
|
||||
max_aggregate,
|
||||
max_leverage,
|
||||
max_notional,
|
||||
)
|
||||
|
||||
|
||||
def test_defaults(monkeypatch):
|
||||
for k in ("CERBERO_MAX_NOTIONAL", "CERBERO_MAX_AGGREGATE", "CERBERO_MAX_LEVERAGE"):
|
||||
monkeypatch.delenv(k, raising=False)
|
||||
assert max_notional() == 200.0
|
||||
assert max_aggregate() == 1000.0
|
||||
assert max_leverage() == 3
|
||||
|
||||
|
||||
def test_env_override(monkeypatch):
|
||||
monkeypatch.setenv("CERBERO_MAX_NOTIONAL", "50")
|
||||
monkeypatch.setenv("CERBERO_MAX_AGGREGATE", "150")
|
||||
monkeypatch.setenv("CERBERO_MAX_LEVERAGE", "2")
|
||||
assert max_notional() == 50.0
|
||||
assert max_aggregate() == 150.0
|
||||
assert max_leverage() == 2
|
||||
|
||||
|
||||
def test_leverage_default_when_none(monkeypatch):
|
||||
monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
|
||||
assert enforce_leverage(None) == 3
|
||||
|
||||
|
||||
def test_leverage_accepts_within_cap(monkeypatch):
|
||||
monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
|
||||
assert enforce_leverage(2) == 2
|
||||
assert enforce_leverage(3) == 3
|
||||
|
||||
|
||||
def test_leverage_rejects_above_cap(monkeypatch):
|
||||
monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
enforce_leverage(50)
|
||||
assert exc.value.status_code == 403
|
||||
assert exc.value.detail["error"] == "HARD_PROHIBITION"
|
||||
|
||||
|
||||
def test_leverage_rejects_below_one(monkeypatch):
|
||||
with pytest.raises(HTTPException):
|
||||
enforce_leverage(0)
|
||||
|
||||
|
||||
def test_single_notional_ok(monkeypatch):
|
||||
monkeypatch.delenv("CERBERO_MAX_NOTIONAL", raising=False)
|
||||
enforce_single_notional(100.0, exchange="deribit", instrument="BTC-PERPETUAL")
|
||||
|
||||
|
||||
def test_single_notional_rejects(monkeypatch):
|
||||
monkeypatch.delenv("CERBERO_MAX_NOTIONAL", raising=False)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
enforce_single_notional(335.0, exchange="hyperliquid", instrument="ETH")
|
||||
assert exc.value.status_code == 403
|
||||
assert "335" in exc.value.detail["message"]
|
||||
|
||||
|
||||
def test_aggregate_ok(monkeypatch):
|
||||
monkeypatch.delenv("CERBERO_MAX_AGGREGATE", raising=False)
|
||||
enforce_aggregate(current_total=500.0, new_notional=200.0)
|
||||
|
||||
|
||||
def test_aggregate_rejects(monkeypatch):
|
||||
monkeypatch.delenv("CERBERO_MAX_AGGREGATE", raising=False)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
enforce_aggregate(current_total=900.0, new_notional=200.0)
|
||||
assert exc.value.status_code == 403
|
||||
@@ -0,0 +1,90 @@
|
||||
from fastapi.testclient import TestClient
|
||||
from option_mcp_common.auth import Principal, TokenStore
|
||||
from option_mcp_common.server import build_app
|
||||
|
||||
|
||||
def test_build_app_health():
|
||||
store = TokenStore(tokens={})
|
||||
app = build_app(name="test-mcp", version="0.0.1", token_store=store)
|
||||
client = TestClient(app)
|
||||
r = client.get("/health")
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert body["status"] == "healthy"
|
||||
assert body["name"] == "test-mcp"
|
||||
assert body["version"] == "0.0.1"
|
||||
assert "uptime_seconds" in body
|
||||
assert "data_timestamp" in body
|
||||
assert r.headers.get("X-Duration-Ms") is not None
|
||||
|
||||
|
||||
def test_build_app_adds_token_store():
|
||||
store = TokenStore(tokens={"t1": Principal("x", {"core"})})
|
||||
app = build_app(name="t", version="v", token_store=store)
|
||||
assert app.state.token_store is store
|
||||
|
||||
|
||||
def test_timestamp_injector_dict_response():
|
||||
"""CER-P5-001: dict response gets data_timestamp + X-Data-Timestamp header."""
|
||||
store = TokenStore(tokens={})
|
||||
app = build_app(name="t", version="v", token_store=store)
|
||||
|
||||
@app.post("/tools/foo")
|
||||
def foo():
|
||||
return {"ok": True}
|
||||
|
||||
client = TestClient(app)
|
||||
r = client.post("/tools/foo")
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert body["ok"] is True
|
||||
assert "data_timestamp" in body
|
||||
assert r.headers.get("X-Data-Timestamp") is not None
|
||||
|
||||
|
||||
def test_timestamp_injector_list_of_dicts():
|
||||
"""CER-P5-001: list of dicts → each item gets data_timestamp."""
|
||||
store = TokenStore(tokens={})
|
||||
app = build_app(name="t", version="v", token_store=store)
|
||||
|
||||
@app.post("/tools/list_items")
|
||||
def list_items():
|
||||
return [{"x": 1}, {"x": 2}]
|
||||
|
||||
client = TestClient(app)
|
||||
r = client.post("/tools/list_items")
|
||||
body = r.json()
|
||||
assert isinstance(body, list)
|
||||
assert len(body) == 2
|
||||
for item in body:
|
||||
assert "data_timestamp" in item
|
||||
assert r.headers.get("X-Data-Timestamp") is not None
|
||||
|
||||
|
||||
def test_timestamp_injector_preserves_existing():
|
||||
"""CER-P5-001: se già presente, non override."""
|
||||
store = TokenStore(tokens={})
|
||||
app = build_app(name="t", version="v", token_store=store)
|
||||
|
||||
@app.post("/tools/already")
|
||||
def already():
|
||||
return {"data_timestamp": "2020-01-01T00:00:00Z", "x": 1}
|
||||
|
||||
client = TestClient(app)
|
||||
body = client.post("/tools/already").json()
|
||||
assert body["data_timestamp"] == "2020-01-01T00:00:00Z"
|
||||
|
||||
|
||||
def test_timestamp_injector_empty_list_gets_header_only():
|
||||
"""CER-P5-001: list vuota — no body modification, ma header presente."""
|
||||
store = TokenStore(tokens={})
|
||||
app = build_app(name="t", version="v", token_store=store)
|
||||
|
||||
@app.post("/tools/empty_list")
|
||||
def empty_list():
|
||||
return []
|
||||
|
||||
client = TestClient(app)
|
||||
r = client.post("/tools/empty_list")
|
||||
assert r.json() == []
|
||||
assert r.headers.get("X-Data-Timestamp") is not None
|
||||
@@ -0,0 +1,48 @@
|
||||
from pathlib import Path
|
||||
|
||||
from option_mcp_common.storage import Database, run_migrations
|
||||
|
||||
|
||||
def test_database_creates_wal(tmp_path: Path):
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
db.connect()
|
||||
# WAL mode attivo
|
||||
mode = db.conn.execute("PRAGMA journal_mode").fetchone()[0]
|
||||
assert mode.lower() == "wal"
|
||||
db.close()
|
||||
|
||||
|
||||
def test_database_migrations_run_once(tmp_path: Path):
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
db.connect()
|
||||
migrations = {
|
||||
1: "CREATE TABLE foo (id INTEGER PRIMARY KEY, name TEXT);",
|
||||
2: "ALTER TABLE foo ADD COLUMN value INTEGER DEFAULT 0;",
|
||||
}
|
||||
run_migrations(db.conn, migrations)
|
||||
# Second run: should be no-op
|
||||
run_migrations(db.conn, migrations)
|
||||
cols = [r[1] for r in db.conn.execute("PRAGMA table_info(foo)").fetchall()]
|
||||
assert "name" in cols
|
||||
assert "value" in cols
|
||||
version = db.conn.execute("SELECT MAX(version) FROM _schema_version").fetchone()[0]
|
||||
assert version == 2
|
||||
db.close()
|
||||
|
||||
|
||||
def test_database_partial_migration(tmp_path: Path):
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
db.connect()
|
||||
migrations_v1 = {1: "CREATE TABLE foo (id INTEGER);"}
|
||||
run_migrations(db.conn, migrations_v1)
|
||||
migrations_v2 = {**migrations_v1, 2: "CREATE TABLE bar (id INTEGER);"}
|
||||
run_migrations(db.conn, migrations_v2)
|
||||
tables = {r[0] for r in db.conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||
).fetchall()}
|
||||
assert "foo" in tables
|
||||
assert "bar" in tables
|
||||
db.close()
|
||||
Reference in New Issue
Block a user