refactor(mcp_common): remove risk_guard, models, env_validation, storage
This commit is contained in:
@@ -1,19 +1 @@
|
||||
from mcp_common.models import (
|
||||
Event,
|
||||
EventPriority,
|
||||
EventType,
|
||||
L1State,
|
||||
L2Entry,
|
||||
L3Entry,
|
||||
UserInstruction,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"L1State",
|
||||
"L2Entry",
|
||||
"L3Entry",
|
||||
"Event",
|
||||
"EventPriority",
|
||||
"EventType",
|
||||
"UserInstruction",
|
||||
]
|
||||
__all__ = []
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
"""CER-P5-010: env validation policy — fail-fast per mandatory, soft per optional.
|
||||
|
||||
Usage al boot di ogni mcp `__main__.py`:
|
||||
|
||||
from mcp_common.env_validation import require_env, optional_env, summarize
|
||||
|
||||
creds_file = require_env("CREDENTIALS_FILE", "deribit credentials JSON path")
|
||||
host = optional_env("HOST", default="0.0.0.0")
|
||||
summarize(["CREDENTIALS_FILE", "HOST", "PORT"])
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MissingEnvError(RuntimeError):
|
||||
"""Mandatory env var absent or empty."""
|
||||
|
||||
|
||||
def require_env(name: str, description: str = "") -> str:
|
||||
"""Fail-fast: raise MissingEnvError se name non presente o vuoto.
|
||||
|
||||
Uscita dal processo con codice 2 se chiamato dal main(). Comporta
|
||||
logging chiaro del missing var prima dell'exit.
|
||||
"""
|
||||
val = (os.environ.get(name) or "").strip()
|
||||
if not val:
|
||||
msg = f"missing mandatory env var: {name}"
|
||||
if description:
|
||||
msg += f" ({description})"
|
||||
logger.error(msg)
|
||||
raise MissingEnvError(msg)
|
||||
return val
|
||||
|
||||
|
||||
def optional_env(name: str, *, default: str = "") -> str:
|
||||
"""Soft: ritorna env o default. Log INFO se default usato."""
|
||||
val = (os.environ.get(name) or "").strip()
|
||||
if not val:
|
||||
if default:
|
||||
logger.info("env %s not set, using default=%r", name, default)
|
||||
return default
|
||||
return val
|
||||
|
||||
|
||||
def summarize(names: list[str]) -> None:
|
||||
"""Log INFO di tutti gli env rilevanti con presenza (mask se SECRET/KEY/TOKEN)."""
|
||||
sensitive_tokens = ("SECRET", "KEY", "TOKEN", "PASSWORD", "CREDENTIAL", "WALLET")
|
||||
for n in names:
|
||||
val = os.environ.get(n)
|
||||
if val is None:
|
||||
logger.info("env[%s]: <unset>", n)
|
||||
continue
|
||||
if any(t in n.upper() for t in sensitive_tokens):
|
||||
logger.info("env[%s]: <set, %d chars>", n, len(val))
|
||||
else:
|
||||
logger.info("env[%s]: %s", n, val)
|
||||
|
||||
|
||||
def fail_fast_if_missing(names: list[str]) -> None:
|
||||
"""Verifica lista di nomi mandatory al boot. Exit 2 se uno solo manca.
|
||||
|
||||
Uso preferito: early call in main() per bloccare boot se config incompleta.
|
||||
"""
|
||||
missing: list[str] = []
|
||||
for n in names:
|
||||
if not (os.environ.get(n) or "").strip():
|
||||
missing.append(n)
|
||||
if missing:
|
||||
logger.error("boot aborted: missing mandatory env vars: %s", missing)
|
||||
print(
|
||||
f"FATAL: missing mandatory env vars: {missing}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(2)
|
||||
@@ -1,98 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from functools import total_ordering
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@total_ordering
|
||||
class EventPriority(StrEnum):
|
||||
LOW = "low"
|
||||
NORMAL = "normal"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
def _rank(self) -> int:
|
||||
return ["low", "normal", "high", "critical"].index(self.value)
|
||||
|
||||
def __lt__(self, other: EventPriority) -> bool:
|
||||
return self._rank() < other._rank()
|
||||
|
||||
|
||||
class EventType(StrEnum):
|
||||
ALERT = "alert"
|
||||
USER_INSTRUCTION = "user_instruction"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class L1State(BaseModel):
|
||||
"""Singleton row with current operational state."""
|
||||
|
||||
updated_at: str
|
||||
equity_total: float | None = None
|
||||
equity_by_exchange: dict[str, float] = Field(default_factory=dict)
|
||||
bias: str | None = None
|
||||
pnl_day: float | None = None
|
||||
pnl_total: float | None = None
|
||||
capital: float | None = None
|
||||
open_positions_count: int = 0
|
||||
greeks_aggregate: dict[str, float] = Field(default_factory=dict)
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
class L2Entry(BaseModel):
|
||||
"""Reasoning entry — schema matches system_prompt v2.
|
||||
|
||||
`authored_by_model`: identifica l'LLM che ha generato la entry (es.
|
||||
"google/gemini-3-flash-preview" per core, "anthropic/claude-haiku-4-5"
|
||||
per worker). None se scritto da operatore umano via UI.
|
||||
"""
|
||||
|
||||
timestamp: str
|
||||
setup: str
|
||||
tesi: str | None = None
|
||||
tesi_check: str | None = None
|
||||
invalidation: str | None = None
|
||||
esito: str
|
||||
scostamento: str | None = None
|
||||
scostamento_sigma: float | None = None
|
||||
lezione: str | None = None
|
||||
sizing_note: str | None = None
|
||||
run_id: str | None = None
|
||||
user_instruction_id: int | None = None
|
||||
authored_by_model: str | None = None
|
||||
|
||||
|
||||
class L3Entry(BaseModel):
|
||||
"""Compacted pattern from L2 batch."""
|
||||
|
||||
created_at: str
|
||||
category: str # "pattern_errore" | "pattern_vincente" | "correlazione"
|
||||
summary: str
|
||||
source_l2_ids: list[int]
|
||||
|
||||
|
||||
class Event(BaseModel):
|
||||
id: int | None = None
|
||||
created_at: str
|
||||
expires_at: str
|
||||
type: EventType
|
||||
source: str
|
||||
priority: EventPriority
|
||||
payload: dict[str, Any]
|
||||
acked_at: str | None = None
|
||||
ack_outcome: str | None = None
|
||||
ack_notes: str | None = None
|
||||
|
||||
|
||||
class UserInstruction(BaseModel):
|
||||
id: int | None = None
|
||||
created_at: str
|
||||
text: str
|
||||
priority: EventPriority
|
||||
require_ack: bool = True
|
||||
source: str = "observer"
|
||||
acked_at: str | None = None
|
||||
ack_outcome: str | None = None
|
||||
@@ -1,92 +0,0 @@
|
||||
"""CER-016 hard guard server-side su place_order.
|
||||
|
||||
Caps configurabili via env (default safety-first, mirati a ~200 EUR single,
|
||||
1000 EUR aggregato, 3x max leverage).
|
||||
|
||||
Thresholds sono numerici semplici — l'operatore stabilisce unità (EUR/USD)
|
||||
via env; il server compara su un unico campo `notional` in valore monetario.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def _env_float(name: str, default: float) -> float:
|
||||
raw = os.environ.get(name)
|
||||
if not raw:
|
||||
return default
|
||||
try:
|
||||
return float(raw)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _env_int(name: str, default: int) -> int:
|
||||
raw = os.environ.get(name)
|
||||
if not raw:
|
||||
return default
|
||||
try:
|
||||
return int(raw)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def max_notional() -> float:
|
||||
return _env_float("CERBERO_MAX_NOTIONAL", 200.0)
|
||||
|
||||
|
||||
def max_aggregate() -> float:
|
||||
return _env_float("CERBERO_MAX_AGGREGATE", 1000.0)
|
||||
|
||||
|
||||
def max_leverage() -> int:
|
||||
return _env_int("CERBERO_MAX_LEVERAGE", 3)
|
||||
|
||||
|
||||
def _hard_reject(reason: str) -> None:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "HARD_PROHIBITION",
|
||||
"message": reason,
|
||||
"caps": {
|
||||
"max_notional": max_notional(),
|
||||
"max_aggregate": max_aggregate(),
|
||||
"max_leverage": max_leverage(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def enforce_leverage(leverage: int | float | None) -> int:
|
||||
"""Ritorna leverage applicabile. Default 3x se None. Reject se > cap."""
|
||||
cap = max_leverage()
|
||||
if leverage is None:
|
||||
return cap
|
||||
lev = int(leverage)
|
||||
if lev < 1:
|
||||
_hard_reject(f"leverage must be >= 1 (got {lev})")
|
||||
if lev > cap:
|
||||
_hard_reject(f"leverage {lev}x exceeds hard cap {cap}x")
|
||||
return lev
|
||||
|
||||
|
||||
def enforce_single_notional(notional: float, *, exchange: str, instrument: str) -> None:
|
||||
cap = max_notional()
|
||||
if notional > cap:
|
||||
_hard_reject(
|
||||
f"{exchange}.{instrument} notional {notional:.2f} exceeds single trade cap {cap:.2f}"
|
||||
)
|
||||
|
||||
|
||||
def enforce_aggregate(current_total: float, new_notional: float) -> None:
|
||||
cap = max_aggregate()
|
||||
total = current_total + new_notional
|
||||
if total > cap:
|
||||
_hard_reject(
|
||||
f"aggregate notional {total:.2f} (current {current_total:.2f} + new "
|
||||
f"{new_notional:.2f}) exceeds cap {cap:.2f}"
|
||||
)
|
||||
@@ -1,43 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class Database:
|
||||
path: Path
|
||||
conn: sqlite3.Connection | None = None
|
||||
|
||||
def connect(self) -> sqlite3.Connection:
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.conn = sqlite3.connect(
|
||||
str(self.path),
|
||||
isolation_level=None,
|
||||
check_same_thread=False,
|
||||
)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
self.conn.execute("PRAGMA journal_mode=WAL")
|
||||
self.conn.execute("PRAGMA synchronous=NORMAL")
|
||||
self.conn.execute("PRAGMA foreign_keys=ON")
|
||||
return self.conn
|
||||
|
||||
def close(self) -> None:
|
||||
if self.conn is not None:
|
||||
self.conn.close()
|
||||
self.conn = None
|
||||
|
||||
|
||||
def run_migrations(conn: sqlite3.Connection, migrations: dict[int, str]) -> None:
|
||||
"""Idempotent migrations. `migrations` keys are monotonic version numbers."""
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS _schema_version (version INTEGER PRIMARY KEY)"
|
||||
)
|
||||
cur = conn.execute("SELECT COALESCE(MAX(version), 0) FROM _schema_version")
|
||||
current = cur.fetchone()[0]
|
||||
for version in sorted(migrations):
|
||||
if version <= current:
|
||||
continue
|
||||
conn.executescript(migrations[version])
|
||||
conn.execute("INSERT INTO _schema_version (version) VALUES (?)", (version,))
|
||||
@@ -1,71 +0,0 @@
|
||||
"""CER-P5-010 env validation tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from mcp_common.env_validation import (
|
||||
MissingEnvError,
|
||||
fail_fast_if_missing,
|
||||
optional_env,
|
||||
require_env,
|
||||
summarize,
|
||||
)
|
||||
|
||||
|
||||
def test_require_env_present(monkeypatch):
|
||||
monkeypatch.setenv("FOO_KEY", "value1")
|
||||
assert require_env("FOO_KEY") == "value1"
|
||||
|
||||
|
||||
def test_require_env_missing_raises(monkeypatch):
|
||||
monkeypatch.delenv("MISSING_REQ", raising=False)
|
||||
with pytest.raises(MissingEnvError):
|
||||
require_env("MISSING_REQ", "critical path")
|
||||
|
||||
|
||||
def test_require_env_empty_raises(monkeypatch):
|
||||
monkeypatch.setenv("EMPTY_REQ", "")
|
||||
with pytest.raises(MissingEnvError):
|
||||
require_env("EMPTY_REQ")
|
||||
|
||||
|
||||
def test_require_env_whitespace_only_raises(monkeypatch):
|
||||
monkeypatch.setenv("WS_REQ", " ")
|
||||
with pytest.raises(MissingEnvError):
|
||||
require_env("WS_REQ")
|
||||
|
||||
|
||||
def test_optional_env_default(monkeypatch):
|
||||
monkeypatch.delenv("OPT_A", raising=False)
|
||||
assert optional_env("OPT_A", default="fallback") == "fallback"
|
||||
|
||||
|
||||
def test_optional_env_set(monkeypatch):
|
||||
monkeypatch.setenv("OPT_B", "xx")
|
||||
assert optional_env("OPT_B", default="fallback") == "xx"
|
||||
|
||||
|
||||
def test_fail_fast_all_present(monkeypatch):
|
||||
monkeypatch.setenv("AA", "1")
|
||||
monkeypatch.setenv("BB", "2")
|
||||
fail_fast_if_missing(["AA", "BB"]) # no exit
|
||||
|
||||
|
||||
def test_fail_fast_missing_exits(monkeypatch):
|
||||
monkeypatch.setenv("HAVE_IT", "1")
|
||||
monkeypatch.delenv("MISSING_X", raising=False)
|
||||
with pytest.raises(SystemExit) as exc:
|
||||
fail_fast_if_missing(["HAVE_IT", "MISSING_X"])
|
||||
assert exc.value.code == 2
|
||||
|
||||
|
||||
def test_summarize_does_not_leak_secrets(monkeypatch, caplog):
|
||||
import logging
|
||||
monkeypatch.setenv("API_KEY_FOO", "super-secret-token-123456")
|
||||
monkeypatch.setenv("PORT", "9000")
|
||||
with caplog.at_level(logging.INFO, logger="mcp_common.env_validation"):
|
||||
summarize(["API_KEY_FOO", "PORT", "NOT_SET_XYZ"])
|
||||
log_text = "\n".join(caplog.messages)
|
||||
assert "super-secret-token-123456" not in log_text
|
||||
assert "9000" in log_text
|
||||
assert "<unset>" in log_text
|
||||
@@ -1,40 +0,0 @@
|
||||
from mcp_common.models import EventPriority, EventType, L2Entry
|
||||
|
||||
|
||||
def test_l2_entry_minimal():
|
||||
entry = L2Entry(
|
||||
timestamp="2026-04-17T10:30:00Z",
|
||||
setup="bull put spread ETH 1800/1750 14d",
|
||||
tesi="IV alta post-CPI, attesa mean-reversion",
|
||||
esito="aperto"
|
||||
)
|
||||
assert entry.scostamento_sigma is None
|
||||
assert entry.tesi_check is None
|
||||
|
||||
|
||||
def test_l2_entry_full():
|
||||
entry = L2Entry(
|
||||
timestamp="2026-04-17T10:30:00Z",
|
||||
setup="bull put spread ETH 1800/1750 14d",
|
||||
tesi="IV alta post-CPI",
|
||||
tesi_check="ETH sopra 1820 per 24h con IV in calo",
|
||||
invalidation="rottura 1800 con volume > 2x media",
|
||||
esito="chiuso +12 USDC",
|
||||
scostamento="nessuno",
|
||||
scostamento_sigma=0.5,
|
||||
lezione="supporto ha tenuto, timing ok",
|
||||
sizing_note="size 80 USDC (ATR 1.3x media)",
|
||||
)
|
||||
assert entry.scostamento_sigma == 0.5
|
||||
dump = entry.model_dump()
|
||||
assert dump["lezione"] == "supporto ha tenuto, timing ok"
|
||||
|
||||
|
||||
def test_event_priority_enum():
|
||||
assert EventPriority.CRITICAL.value == "critical"
|
||||
assert EventPriority.LOW < EventPriority.CRITICAL # ordering
|
||||
|
||||
|
||||
def test_event_type_enum():
|
||||
assert EventType.ALERT.value == "alert"
|
||||
assert EventType.USER_INSTRUCTION.value == "user_instruction"
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Tests for CER-016 risk guard."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from mcp_common.risk_guard import (
|
||||
enforce_aggregate,
|
||||
enforce_leverage,
|
||||
enforce_single_notional,
|
||||
max_aggregate,
|
||||
max_leverage,
|
||||
max_notional,
|
||||
)
|
||||
|
||||
|
||||
def test_defaults(monkeypatch):
|
||||
for k in ("CERBERO_MAX_NOTIONAL", "CERBERO_MAX_AGGREGATE", "CERBERO_MAX_LEVERAGE"):
|
||||
monkeypatch.delenv(k, raising=False)
|
||||
assert max_notional() == 200.0
|
||||
assert max_aggregate() == 1000.0
|
||||
assert max_leverage() == 3
|
||||
|
||||
|
||||
def test_env_override(monkeypatch):
|
||||
monkeypatch.setenv("CERBERO_MAX_NOTIONAL", "50")
|
||||
monkeypatch.setenv("CERBERO_MAX_AGGREGATE", "150")
|
||||
monkeypatch.setenv("CERBERO_MAX_LEVERAGE", "2")
|
||||
assert max_notional() == 50.0
|
||||
assert max_aggregate() == 150.0
|
||||
assert max_leverage() == 2
|
||||
|
||||
|
||||
def test_leverage_default_when_none(monkeypatch):
|
||||
monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
|
||||
assert enforce_leverage(None) == 3
|
||||
|
||||
|
||||
def test_leverage_accepts_within_cap(monkeypatch):
|
||||
monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
|
||||
assert enforce_leverage(2) == 2
|
||||
assert enforce_leverage(3) == 3
|
||||
|
||||
|
||||
def test_leverage_rejects_above_cap(monkeypatch):
|
||||
monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
enforce_leverage(50)
|
||||
assert exc.value.status_code == 403
|
||||
assert exc.value.detail["error"] == "HARD_PROHIBITION"
|
||||
|
||||
|
||||
def test_leverage_rejects_below_one(monkeypatch):
|
||||
with pytest.raises(HTTPException):
|
||||
enforce_leverage(0)
|
||||
|
||||
|
||||
def test_single_notional_ok(monkeypatch):
|
||||
monkeypatch.delenv("CERBERO_MAX_NOTIONAL", raising=False)
|
||||
enforce_single_notional(100.0, exchange="deribit", instrument="BTC-PERPETUAL")
|
||||
|
||||
|
||||
def test_single_notional_rejects(monkeypatch):
|
||||
monkeypatch.delenv("CERBERO_MAX_NOTIONAL", raising=False)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
enforce_single_notional(335.0, exchange="hyperliquid", instrument="ETH")
|
||||
assert exc.value.status_code == 403
|
||||
assert "335" in exc.value.detail["message"]
|
||||
|
||||
|
||||
def test_aggregate_ok(monkeypatch):
|
||||
monkeypatch.delenv("CERBERO_MAX_AGGREGATE", raising=False)
|
||||
enforce_aggregate(current_total=500.0, new_notional=200.0)
|
||||
|
||||
|
||||
def test_aggregate_rejects(monkeypatch):
|
||||
monkeypatch.delenv("CERBERO_MAX_AGGREGATE", raising=False)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
enforce_aggregate(current_total=900.0, new_notional=200.0)
|
||||
assert exc.value.status_code == 403
|
||||
@@ -1,48 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
from mcp_common.storage import Database, run_migrations
|
||||
|
||||
|
||||
def test_database_creates_wal(tmp_path: Path):
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
db.connect()
|
||||
# WAL mode attivo
|
||||
mode = db.conn.execute("PRAGMA journal_mode").fetchone()[0]
|
||||
assert mode.lower() == "wal"
|
||||
db.close()
|
||||
|
||||
|
||||
def test_database_migrations_run_once(tmp_path: Path):
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
db.connect()
|
||||
migrations = {
|
||||
1: "CREATE TABLE foo (id INTEGER PRIMARY KEY, name TEXT);",
|
||||
2: "ALTER TABLE foo ADD COLUMN value INTEGER DEFAULT 0;",
|
||||
}
|
||||
run_migrations(db.conn, migrations)
|
||||
# Second run: should be no-op
|
||||
run_migrations(db.conn, migrations)
|
||||
cols = [r[1] for r in db.conn.execute("PRAGMA table_info(foo)").fetchall()]
|
||||
assert "name" in cols
|
||||
assert "value" in cols
|
||||
version = db.conn.execute("SELECT MAX(version) FROM _schema_version").fetchone()[0]
|
||||
assert version == 2
|
||||
db.close()
|
||||
|
||||
|
||||
def test_database_partial_migration(tmp_path: Path):
|
||||
db_path = tmp_path / "test.db"
|
||||
db = Database(db_path)
|
||||
db.connect()
|
||||
migrations_v1 = {1: "CREATE TABLE foo (id INTEGER);"}
|
||||
run_migrations(db.conn, migrations_v1)
|
||||
migrations_v2 = {**migrations_v1, 2: "CREATE TABLE bar (id INTEGER);"}
|
||||
run_migrations(db.conn, migrations_v2)
|
||||
tables = {r[0] for r in db.conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||
).fetchall()}
|
||||
assert "foo" in tables
|
||||
assert "bar" in tables
|
||||
db.close()
|
||||
Reference in New Issue
Block a user