refactor(mcp_common): remove risk_guard, models, env_validation, storage

This commit is contained in:
AdrianoDev
2026-04-27 17:38:44 +02:00
parent e888fc373d
commit 888a3cde84
9 changed files with 1 addition and 571 deletions
+1 -19
View File
@@ -1,19 +1 @@
from mcp_common.models import (
Event,
EventPriority,
EventType,
L1State,
L2Entry,
L3Entry,
UserInstruction,
)
__all__ = [
"L1State",
"L2Entry",
"L3Entry",
"Event",
"EventPriority",
"EventType",
"UserInstruction",
]
__all__ = []
@@ -1,80 +0,0 @@
"""CER-P5-010: env validation policy — fail-fast for mandatory vars, soft for optional.

Usage at boot of every mcp `__main__.py`:

    from mcp_common.env_validation import require_env, optional_env, summarize

    creds_file = require_env("CREDENTIALS_FILE", "deribit credentials JSON path")
    host = optional_env("HOST", default="0.0.0.0")
    summarize(["CREDENTIALS_FILE", "HOST", "PORT"])
"""
from __future__ import annotations
import logging
import os
import sys
# Module-level logger; boot scripts are expected to have configured handlers.
logger = logging.getLogger(__name__)


class MissingEnvError(RuntimeError):
    """Raised when a mandatory environment variable is absent or empty."""
def require_env(name: str, description: str = "") -> str:
    """Return the (whitespace-stripped) value of mandatory env var *name*.

    Args:
        name: environment variable to read.
        description: optional human-readable hint appended to the error message.

    Raises:
        MissingEnvError: if *name* is unset, empty, or whitespace-only. The
            error is logged first so the missing variable is visible in boot
            logs; callers in ``main()`` are expected to translate the
            exception into process exit code 2 (this function itself does
            not exit — the original docstring claimed otherwise).
    """
    val = (os.environ.get(name) or "").strip()
    if not val:
        msg = f"missing mandatory env var: {name}"
        if description:
            msg += f" ({description})"
        logger.error(msg)
        raise MissingEnvError(msg)
    return val
def optional_env(name: str, *, default: str = "") -> str:
    """Return env var *name* (stripped) or *default* when unset or blank.

    Logs at INFO level whenever a non-empty default is substituted.
    """
    value = (os.environ.get(name) or "").strip()
    if value:
        return value
    if default:
        logger.info("env %s not set, using default=%r", name, default)
    return default
def summarize(names: list[str]) -> None:
    """INFO-log the presence of each env var in *names*, masking secrets.

    A variable whose name contains SECRET/KEY/TOKEN/PASSWORD/CREDENTIAL/WALLET
    is reported only by character count, never by value.
    """
    sensitive_tokens = ("SECRET", "KEY", "TOKEN", "PASSWORD", "CREDENTIAL", "WALLET")
    for var_name in names:
        value = os.environ.get(var_name)
        if value is None:
            logger.info("env[%s]: <unset>", var_name)
        elif any(tok in var_name.upper() for tok in sensitive_tokens):
            logger.info("env[%s]: <set, %d chars>", var_name, len(value))
        else:
            logger.info("env[%s]: %s", var_name, value)
def fail_fast_if_missing(names: list[str]) -> None:
    """Verify mandatory env vars at boot; exit the process (code 2) if any miss.

    Preferred usage: call early in ``main()`` so an incomplete configuration
    aborts the boot instead of failing later at first use. The missing names
    are both logged and printed to stderr before exiting.
    """
    # Comprehension instead of manual append loop (same unset-or-blank test
    # as require_env, so the two stay consistent).
    missing = [n for n in names if not (os.environ.get(n) or "").strip()]
    if missing:
        logger.error("boot aborted: missing mandatory env vars: %s", missing)
        print(
            f"FATAL: missing mandatory env vars: {missing}",
            file=sys.stderr,
        )
        sys.exit(2)
-98
View File
@@ -1,98 +0,0 @@
from __future__ import annotations
from enum import StrEnum
from functools import total_ordering
from typing import Any
from pydantic import BaseModel, Field
@total_ordering
class EventPriority(StrEnum):
LOW = "low"
NORMAL = "normal"
HIGH = "high"
CRITICAL = "critical"
def _rank(self) -> int:
return ["low", "normal", "high", "critical"].index(self.value)
def __lt__(self, other: EventPriority) -> bool:
return self._rank() < other._rank()
class EventType(StrEnum):
    """Kind of event carried on the event queue."""

    ALERT = "alert"
    USER_INSTRUCTION = "user_instruction"
    SYSTEM = "system"
class L1State(BaseModel):
    """Singleton row with current operational state.

    Timestamps and monetary fields are stored as received; units/format are
    set by the writer — not visible here, confirm against the producer.
    """

    updated_at: str
    equity_total: float | None = None
    # Equity broken down per exchange, keyed by exchange name.
    equity_by_exchange: dict[str, float] = Field(default_factory=dict)
    bias: str | None = None
    pnl_day: float | None = None
    pnl_total: float | None = None
    capital: float | None = None
    open_positions_count: int = 0
    # Aggregated greeks keyed by greek name -- presumably delta/gamma/vega/theta; confirm.
    greeks_aggregate: dict[str, float] = Field(default_factory=dict)
    notes: str | None = None
class L2Entry(BaseModel):
    """Reasoning entry — schema matches system_prompt v2.

    `authored_by_model` identifies the LLM that generated the entry (e.g.
    "google/gemini-3-flash-preview" for core, "anthropic/claude-haiku-4-5"
    for worker). None when the entry was written by a human operator via UI.
    """

    # NOTE(review): the Italian field names are part of the stored schema —
    # renaming them would require a data migration.
    timestamp: str
    setup: str
    tesi: str | None = None  # thesis behind the setup
    tesi_check: str | None = None  # condition confirming the thesis
    invalidation: str | None = None  # condition invalidating the setup
    esito: str  # outcome, free text
    scostamento: str | None = None  # deviation from plan, if any
    scostamento_sigma: float | None = None  # deviation magnitude -- presumably in sigmas; confirm
    lezione: str | None = None  # lesson learned
    sizing_note: str | None = None
    run_id: str | None = None
    user_instruction_id: int | None = None
    authored_by_model: str | None = None
class L3Entry(BaseModel):
    """Compacted pattern distilled from a batch of L2 entries."""

    created_at: str
    # Per the original note, expected values:
    # "pattern_errore" | "pattern_vincente" | "correlazione"
    category: str
    summary: str
    # Ids of the L2 rows this pattern was compacted from.
    source_l2_ids: list[int]
class Event(BaseModel):
    """Queue event with expiry and acknowledgement bookkeeping."""

    id: int | None = None  # None before the row is persisted -- presumably DB-assigned; confirm
    created_at: str
    expires_at: str  # expiry timestamp -- consumer-side semantics not visible here
    type: EventType
    source: str
    priority: EventPriority
    payload: dict[str, Any]
    # Acknowledgement fields stay None until the event is acked.
    acked_at: str | None = None
    ack_outcome: str | None = None
    ack_notes: str | None = None
class UserInstruction(BaseModel):
    """Instruction issued by a user/observer, optionally requiring an ack."""

    id: int | None = None  # None before the row is persisted -- presumably DB-assigned; confirm
    created_at: str
    text: str
    priority: EventPriority
    require_ack: bool = True
    source: str = "observer"
    # Acknowledgement fields stay None until the instruction is acked.
    acked_at: str | None = None
    ack_outcome: str | None = None
@@ -1,92 +0,0 @@
"""CER-016 server-side hard guard on place_order.

Caps are configurable via env (defaults are safety-first, aimed at ~200 EUR
per single trade, 1000 EUR aggregate, 3x max leverage).

Thresholds are plain numbers — the operator fixes the unit (EUR/USD) via env;
the server compares against a single monetary `notional` field.
"""
from __future__ import annotations
import os
from fastapi import HTTPException
def _env_float(name: str, default: float) -> float:
    """Read *name* from the environment as float; *default* on unset/invalid."""
    raw = os.environ.get(name)
    try:
        return float(raw) if raw else default
    except (TypeError, ValueError):
        return default


def _env_int(name: str, default: int) -> int:
    """Read *name* from the environment as int; *default* on unset/invalid."""
    raw = os.environ.get(name)
    try:
        return int(raw) if raw else default
    except (TypeError, ValueError):
        return default


def max_notional() -> float:
    """Hard cap on a single trade's notional (monetary units)."""
    return _env_float("CERBERO_MAX_NOTIONAL", 200.0)


def max_aggregate() -> float:
    """Hard cap on total open notional across all positions."""
    return _env_float("CERBERO_MAX_AGGREGATE", 1000.0)


def max_leverage() -> int:
    """Hard cap on the leverage multiplier."""
    return _env_int("CERBERO_MAX_LEVERAGE", 3)
def _hard_reject(reason: str) -> None:
    """Abort the request with 403 HARD_PROHIBITION, echoing the active caps.

    Always raises fastapi.HTTPException — never returns. The caps are re-read
    from the environment at raise time so the response reflects current config.
    """
    raise HTTPException(
        status_code=403,
        detail={
            "error": "HARD_PROHIBITION",
            "message": reason,
            "caps": {
                "max_notional": max_notional(),
                "max_aggregate": max_aggregate(),
                "max_leverage": max_leverage(),
            },
        },
    )
def enforce_leverage(leverage: int | float | None) -> int:
    """Return the leverage to apply. Defaults to the cap when None.

    Rejects (403 via _hard_reject) leverage below 1 or above the configured
    hard cap.

    Fix: the cap check now runs on the *raw* value before int() truncation —
    previously a request for e.g. 3.9x was truncated to 3 and slipped under a
    3x hard cap. Error messages likewise report the raw requested value.
    """
    cap = max_leverage()
    if leverage is None:
        return cap
    if leverage < 1:
        _hard_reject(f"leverage must be >= 1 (got {leverage})")
    if leverage > cap:
        _hard_reject(f"leverage {leverage}x exceeds hard cap {cap}x")
    return int(leverage)
def enforce_single_notional(notional: float, *, exchange: str, instrument: str) -> None:
    """Reject (403) a single trade whose notional exceeds the per-trade cap."""
    cap = max_notional()
    if notional <= cap:
        return
    _hard_reject(
        f"{exchange}.{instrument} notional {notional:.2f} exceeds single trade cap {cap:.2f}"
    )
def enforce_aggregate(current_total: float, new_notional: float) -> None:
    """Reject (403) when current plus new notional would exceed the aggregate cap."""
    cap = max_aggregate()
    combined = current_total + new_notional
    if combined <= cap:
        return
    _hard_reject(
        f"aggregate notional {combined:.2f} (current {current_total:.2f} + new "
        f"{new_notional:.2f}) exceeds cap {cap:.2f}"
    )
-43
View File
@@ -1,43 +0,0 @@
from __future__ import annotations
import sqlite3
from dataclasses import dataclass
from pathlib import Path
@dataclass
class Database:
path: Path
conn: sqlite3.Connection | None = None
def connect(self) -> sqlite3.Connection:
self.path.parent.mkdir(parents=True, exist_ok=True)
self.conn = sqlite3.connect(
str(self.path),
isolation_level=None,
check_same_thread=False,
)
self.conn.row_factory = sqlite3.Row
self.conn.execute("PRAGMA journal_mode=WAL")
self.conn.execute("PRAGMA synchronous=NORMAL")
self.conn.execute("PRAGMA foreign_keys=ON")
return self.conn
def close(self) -> None:
if self.conn is not None:
self.conn.close()
self.conn = None
def run_migrations(conn: sqlite3.Connection, migrations: dict[int, str]) -> None:
    """Apply pending migration scripts in ascending version order.

    Idempotent: versions at or below the recorded schema version are skipped,
    and every applied version is recorded in `_schema_version`.
    """
    conn.execute(
        "CREATE TABLE IF NOT EXISTS _schema_version (version INTEGER PRIMARY KEY)"
    )
    row = conn.execute("SELECT COALESCE(MAX(version), 0) FROM _schema_version").fetchone()
    applied_up_to = row[0]
    pending = [v for v in sorted(migrations) if v > applied_up_to]
    for version in pending:
        conn.executescript(migrations[version])
        conn.execute("INSERT INTO _schema_version (version) VALUES (?)", (version,))
@@ -1,71 +0,0 @@
"""CER-P5-010 env validation tests."""
from __future__ import annotations
import pytest
from mcp_common.env_validation import (
MissingEnvError,
fail_fast_if_missing,
optional_env,
require_env,
summarize,
)
def test_require_env_present(monkeypatch):
    """A set variable is returned as-is."""
    monkeypatch.setenv("FOO_KEY", "value1")
    assert require_env("FOO_KEY") == "value1"


def test_require_env_missing_raises(monkeypatch):
    """An unset mandatory variable raises MissingEnvError."""
    monkeypatch.delenv("MISSING_REQ", raising=False)
    with pytest.raises(MissingEnvError):
        require_env("MISSING_REQ", "critical path")


def test_require_env_empty_raises(monkeypatch):
    """An empty value counts as missing."""
    monkeypatch.setenv("EMPTY_REQ", "")
    with pytest.raises(MissingEnvError):
        require_env("EMPTY_REQ")


def test_require_env_whitespace_only_raises(monkeypatch):
    """A whitespace-only value counts as missing (value is stripped)."""
    monkeypatch.setenv("WS_REQ", "   ")
    with pytest.raises(MissingEnvError):
        require_env("WS_REQ")


def test_optional_env_default(monkeypatch):
    """Unset optional variable falls back to the given default."""
    monkeypatch.delenv("OPT_A", raising=False)
    assert optional_env("OPT_A", default="fallback") == "fallback"


def test_optional_env_set(monkeypatch):
    """A set optional variable wins over the default."""
    monkeypatch.setenv("OPT_B", "xx")
    assert optional_env("OPT_B", default="fallback") == "xx"


def test_fail_fast_all_present(monkeypatch):
    """No exit when every mandatory variable is present."""
    monkeypatch.setenv("AA", "1")
    monkeypatch.setenv("BB", "2")
    fail_fast_if_missing(["AA", "BB"])  # no exit


def test_fail_fast_missing_exits(monkeypatch):
    """A single missing variable aborts the process with exit code 2."""
    monkeypatch.setenv("HAVE_IT", "1")
    monkeypatch.delenv("MISSING_X", raising=False)
    with pytest.raises(SystemExit) as exc:
        fail_fast_if_missing(["HAVE_IT", "MISSING_X"])
    assert exc.value.code == 2


def test_summarize_does_not_leak_secrets(monkeypatch, caplog):
    """Sensitive values are masked in the summary log; plain ones are shown."""
    import logging
    monkeypatch.setenv("API_KEY_FOO", "super-secret-token-123456")
    monkeypatch.setenv("PORT", "9000")
    with caplog.at_level(logging.INFO, logger="mcp_common.env_validation"):
        summarize(["API_KEY_FOO", "PORT", "NOT_SET_XYZ"])
    log_text = "\n".join(caplog.messages)
    assert "super-secret-token-123456" not in log_text
    assert "9000" in log_text
    assert "<unset>" in log_text
-40
View File
@@ -1,40 +0,0 @@
from mcp_common.models import EventPriority, EventType, L2Entry
def test_l2_entry_minimal():
    """Only the required fields are needed; optionals default to None."""
    entry = L2Entry(
        timestamp="2026-04-17T10:30:00Z",
        setup="bull put spread ETH 1800/1750 14d",
        tesi="IV alta post-CPI, attesa mean-reversion",
        esito="aperto"
    )
    assert entry.scostamento_sigma is None
    assert entry.tesi_check is None


def test_l2_entry_full():
    """A fully-populated entry round-trips through model_dump()."""
    entry = L2Entry(
        timestamp="2026-04-17T10:30:00Z",
        setup="bull put spread ETH 1800/1750 14d",
        tesi="IV alta post-CPI",
        tesi_check="ETH sopra 1820 per 24h con IV in calo",
        invalidation="rottura 1800 con volume > 2x media",
        esito="chiuso +12 USDC",
        scostamento="nessuno",
        scostamento_sigma=0.5,
        lezione="supporto ha tenuto, timing ok",
        sizing_note="size 80 USDC (ATR 1.3x media)",
    )
    assert entry.scostamento_sigma == 0.5
    dump = entry.model_dump()
    assert dump["lezione"] == "supporto ha tenuto, timing ok"


def test_event_priority_enum():
    """Priorities compare by severity rank, not alphabetically."""
    assert EventPriority.CRITICAL.value == "critical"
    assert EventPriority.LOW < EventPriority.CRITICAL  # ordering
-80
View File
@@ -1,80 +0,0 @@
"""Tests for CER-016 risk guard."""
from __future__ import annotations
import pytest
from fastapi import HTTPException
from mcp_common.risk_guard import (
enforce_aggregate,
enforce_leverage,
enforce_single_notional,
max_aggregate,
max_leverage,
max_notional,
)
def test_defaults(monkeypatch):
    """With no env overrides the safety-first defaults apply."""
    for k in ("CERBERO_MAX_NOTIONAL", "CERBERO_MAX_AGGREGATE", "CERBERO_MAX_LEVERAGE"):
        monkeypatch.delenv(k, raising=False)
    assert max_notional() == 200.0
    assert max_aggregate() == 1000.0
    assert max_leverage() == 3


def test_env_override(monkeypatch):
    """Each cap can be overridden independently via its env var."""
    monkeypatch.setenv("CERBERO_MAX_NOTIONAL", "50")
    monkeypatch.setenv("CERBERO_MAX_AGGREGATE", "150")
    monkeypatch.setenv("CERBERO_MAX_LEVERAGE", "2")
    assert max_notional() == 50.0
    assert max_aggregate() == 150.0
    assert max_leverage() == 2


def test_leverage_default_when_none(monkeypatch):
    """None falls back to the configured cap."""
    monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
    assert enforce_leverage(None) == 3


def test_leverage_accepts_within_cap(monkeypatch):
    """Values at or under the cap pass through unchanged."""
    monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
    assert enforce_leverage(2) == 2
    assert enforce_leverage(3) == 3


def test_leverage_rejects_above_cap(monkeypatch):
    """Above-cap leverage is rejected with 403 HARD_PROHIBITION."""
    monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
    with pytest.raises(HTTPException) as exc:
        enforce_leverage(50)
    assert exc.value.status_code == 403
    assert exc.value.detail["error"] == "HARD_PROHIBITION"


def test_leverage_rejects_below_one(monkeypatch):
    """Leverage below 1 is never valid."""
    with pytest.raises(HTTPException):
        enforce_leverage(0)


def test_single_notional_ok(monkeypatch):
    """A trade under the single-trade cap is allowed (no exception)."""
    monkeypatch.delenv("CERBERO_MAX_NOTIONAL", raising=False)
    enforce_single_notional(100.0, exchange="deribit", instrument="BTC-PERPETUAL")


def test_single_notional_rejects(monkeypatch):
    """An over-cap trade is rejected and the message names the notional."""
    monkeypatch.delenv("CERBERO_MAX_NOTIONAL", raising=False)
    with pytest.raises(HTTPException) as exc:
        enforce_single_notional(335.0, exchange="hyperliquid", instrument="ETH")
    assert exc.value.status_code == 403
    assert "335" in exc.value.detail["message"]


def test_aggregate_ok(monkeypatch):
    """Current + new notional within the aggregate cap is allowed."""
    monkeypatch.delenv("CERBERO_MAX_AGGREGATE", raising=False)
    enforce_aggregate(current_total=500.0, new_notional=200.0)


def test_aggregate_rejects(monkeypatch):
    """Current + new notional over the aggregate cap is rejected with 403."""
    monkeypatch.delenv("CERBERO_MAX_AGGREGATE", raising=False)
    with pytest.raises(HTTPException) as exc:
        enforce_aggregate(current_total=900.0, new_notional=200.0)
    assert exc.value.status_code == 403
-48
View File
@@ -1,48 +0,0 @@
from pathlib import Path
from mcp_common.storage import Database, run_migrations
def test_database_creates_wal(tmp_path: Path):
    """connect() opens the database with WAL journaling enabled."""
    db_path = tmp_path / "test.db"
    db = Database(db_path)
    db.connect()
    # WAL mode active
    mode = db.conn.execute("PRAGMA journal_mode").fetchone()[0]
    assert mode.lower() == "wal"
    db.close()


def test_database_migrations_run_once(tmp_path: Path):
    """Re-running the same migration set is a no-op; version lands at 2."""
    db_path = tmp_path / "test.db"
    db = Database(db_path)
    db.connect()
    migrations = {
        1: "CREATE TABLE foo (id INTEGER PRIMARY KEY, name TEXT);",
        2: "ALTER TABLE foo ADD COLUMN value INTEGER DEFAULT 0;",
    }
    run_migrations(db.conn, migrations)
    # Second run: should be no-op
    run_migrations(db.conn, migrations)
    cols = [r[1] for r in db.conn.execute("PRAGMA table_info(foo)").fetchall()]
    assert "name" in cols
    assert "value" in cols
    version = db.conn.execute("SELECT MAX(version) FROM _schema_version").fetchone()[0]
    assert version == 2
    db.close()


def test_database_partial_migration(tmp_path: Path):
    """A later run with additional versions applies only the new ones."""
    db_path = tmp_path / "test.db"
    db = Database(db_path)
    db.connect()
    migrations_v1 = {1: "CREATE TABLE foo (id INTEGER);"}
    run_migrations(db.conn, migrations_v1)
    migrations_v2 = {**migrations_v1, 2: "CREATE TABLE bar (id INTEGER);"}
    run_migrations(db.conn, migrations_v2)
    tables = {r[0] for r in db.conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table'"
    ).fetchall()}
    assert "foo" in tables
    assert "bar" in tables
    db.close()