refactor(mcp_common): remove risk_guard, models, env_validation, storage
This commit is contained in:
@@ -1,19 +1 @@
|
|||||||
from mcp_common.models import (
|
__all__ = []
|
||||||
Event,
|
|
||||||
EventPriority,
|
|
||||||
EventType,
|
|
||||||
L1State,
|
|
||||||
L2Entry,
|
|
||||||
L3Entry,
|
|
||||||
UserInstruction,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"L1State",
|
|
||||||
"L2Entry",
|
|
||||||
"L3Entry",
|
|
||||||
"Event",
|
|
||||||
"EventPriority",
|
|
||||||
"EventType",
|
|
||||||
"UserInstruction",
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,80 +0,0 @@
|
|||||||
"""CER-P5-010: env validation policy — fail-fast per mandatory, soft per optional.
|
|
||||||
|
|
||||||
Usage al boot di ogni mcp `__main__.py`:
|
|
||||||
|
|
||||||
from mcp_common.env_validation import require_env, optional_env, summarize
|
|
||||||
|
|
||||||
creds_file = require_env("CREDENTIALS_FILE", "deribit credentials JSON path")
|
|
||||||
host = optional_env("HOST", default="0.0.0.0")
|
|
||||||
summarize(["CREDENTIALS_FILE", "HOST", "PORT"])
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MissingEnvError(RuntimeError):
|
|
||||||
"""Mandatory env var absent or empty."""
|
|
||||||
|
|
||||||
|
|
||||||
def require_env(name: str, description: str = "") -> str:
|
|
||||||
"""Fail-fast: raise MissingEnvError se name non presente o vuoto.
|
|
||||||
|
|
||||||
Uscita dal processo con codice 2 se chiamato dal main(). Comporta
|
|
||||||
logging chiaro del missing var prima dell'exit.
|
|
||||||
"""
|
|
||||||
val = (os.environ.get(name) or "").strip()
|
|
||||||
if not val:
|
|
||||||
msg = f"missing mandatory env var: {name}"
|
|
||||||
if description:
|
|
||||||
msg += f" ({description})"
|
|
||||||
logger.error(msg)
|
|
||||||
raise MissingEnvError(msg)
|
|
||||||
return val
|
|
||||||
|
|
||||||
|
|
||||||
def optional_env(name: str, *, default: str = "") -> str:
|
|
||||||
"""Soft: ritorna env o default. Log INFO se default usato."""
|
|
||||||
val = (os.environ.get(name) or "").strip()
|
|
||||||
if not val:
|
|
||||||
if default:
|
|
||||||
logger.info("env %s not set, using default=%r", name, default)
|
|
||||||
return default
|
|
||||||
return val
|
|
||||||
|
|
||||||
|
|
||||||
def summarize(names: list[str]) -> None:
|
|
||||||
"""Log INFO di tutti gli env rilevanti con presenza (mask se SECRET/KEY/TOKEN)."""
|
|
||||||
sensitive_tokens = ("SECRET", "KEY", "TOKEN", "PASSWORD", "CREDENTIAL", "WALLET")
|
|
||||||
for n in names:
|
|
||||||
val = os.environ.get(n)
|
|
||||||
if val is None:
|
|
||||||
logger.info("env[%s]: <unset>", n)
|
|
||||||
continue
|
|
||||||
if any(t in n.upper() for t in sensitive_tokens):
|
|
||||||
logger.info("env[%s]: <set, %d chars>", n, len(val))
|
|
||||||
else:
|
|
||||||
logger.info("env[%s]: %s", n, val)
|
|
||||||
|
|
||||||
|
|
||||||
def fail_fast_if_missing(names: list[str]) -> None:
|
|
||||||
"""Verifica lista di nomi mandatory al boot. Exit 2 se uno solo manca.
|
|
||||||
|
|
||||||
Uso preferito: early call in main() per bloccare boot se config incompleta.
|
|
||||||
"""
|
|
||||||
missing: list[str] = []
|
|
||||||
for n in names:
|
|
||||||
if not (os.environ.get(n) or "").strip():
|
|
||||||
missing.append(n)
|
|
||||||
if missing:
|
|
||||||
logger.error("boot aborted: missing mandatory env vars: %s", missing)
|
|
||||||
print(
|
|
||||||
f"FATAL: missing mandatory env vars: {missing}",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
sys.exit(2)
|
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from enum import StrEnum
|
|
||||||
from functools import total_ordering
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
@total_ordering
|
|
||||||
class EventPriority(StrEnum):
|
|
||||||
LOW = "low"
|
|
||||||
NORMAL = "normal"
|
|
||||||
HIGH = "high"
|
|
||||||
CRITICAL = "critical"
|
|
||||||
|
|
||||||
def _rank(self) -> int:
|
|
||||||
return ["low", "normal", "high", "critical"].index(self.value)
|
|
||||||
|
|
||||||
def __lt__(self, other: EventPriority) -> bool:
|
|
||||||
return self._rank() < other._rank()
|
|
||||||
|
|
||||||
|
|
||||||
class EventType(StrEnum):
|
|
||||||
ALERT = "alert"
|
|
||||||
USER_INSTRUCTION = "user_instruction"
|
|
||||||
SYSTEM = "system"
|
|
||||||
|
|
||||||
|
|
||||||
class L1State(BaseModel):
|
|
||||||
"""Singleton row with current operational state."""
|
|
||||||
|
|
||||||
updated_at: str
|
|
||||||
equity_total: float | None = None
|
|
||||||
equity_by_exchange: dict[str, float] = Field(default_factory=dict)
|
|
||||||
bias: str | None = None
|
|
||||||
pnl_day: float | None = None
|
|
||||||
pnl_total: float | None = None
|
|
||||||
capital: float | None = None
|
|
||||||
open_positions_count: int = 0
|
|
||||||
greeks_aggregate: dict[str, float] = Field(default_factory=dict)
|
|
||||||
notes: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class L2Entry(BaseModel):
|
|
||||||
"""Reasoning entry — schema matches system_prompt v2.
|
|
||||||
|
|
||||||
`authored_by_model`: identifica l'LLM che ha generato la entry (es.
|
|
||||||
"google/gemini-3-flash-preview" per core, "anthropic/claude-haiku-4-5"
|
|
||||||
per worker). None se scritto da operatore umano via UI.
|
|
||||||
"""
|
|
||||||
|
|
||||||
timestamp: str
|
|
||||||
setup: str
|
|
||||||
tesi: str | None = None
|
|
||||||
tesi_check: str | None = None
|
|
||||||
invalidation: str | None = None
|
|
||||||
esito: str
|
|
||||||
scostamento: str | None = None
|
|
||||||
scostamento_sigma: float | None = None
|
|
||||||
lezione: str | None = None
|
|
||||||
sizing_note: str | None = None
|
|
||||||
run_id: str | None = None
|
|
||||||
user_instruction_id: int | None = None
|
|
||||||
authored_by_model: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class L3Entry(BaseModel):
|
|
||||||
"""Compacted pattern from L2 batch."""
|
|
||||||
|
|
||||||
created_at: str
|
|
||||||
category: str # "pattern_errore" | "pattern_vincente" | "correlazione"
|
|
||||||
summary: str
|
|
||||||
source_l2_ids: list[int]
|
|
||||||
|
|
||||||
|
|
||||||
class Event(BaseModel):
|
|
||||||
id: int | None = None
|
|
||||||
created_at: str
|
|
||||||
expires_at: str
|
|
||||||
type: EventType
|
|
||||||
source: str
|
|
||||||
priority: EventPriority
|
|
||||||
payload: dict[str, Any]
|
|
||||||
acked_at: str | None = None
|
|
||||||
ack_outcome: str | None = None
|
|
||||||
ack_notes: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class UserInstruction(BaseModel):
|
|
||||||
id: int | None = None
|
|
||||||
created_at: str
|
|
||||||
text: str
|
|
||||||
priority: EventPriority
|
|
||||||
require_ack: bool = True
|
|
||||||
source: str = "observer"
|
|
||||||
acked_at: str | None = None
|
|
||||||
ack_outcome: str | None = None
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
"""CER-016 hard guard server-side su place_order.
|
|
||||||
|
|
||||||
Caps configurabili via env (default safety-first, mirati a ~200 EUR single,
|
|
||||||
1000 EUR aggregato, 3x max leverage).
|
|
||||||
|
|
||||||
Thresholds sono numerici semplici — l'operatore stabilisce unità (EUR/USD)
|
|
||||||
via env; il server compara su un unico campo `notional` in valore monetario.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
|
|
||||||
def _env_float(name: str, default: float) -> float:
|
|
||||||
raw = os.environ.get(name)
|
|
||||||
if not raw:
|
|
||||||
return default
|
|
||||||
try:
|
|
||||||
return float(raw)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
def _env_int(name: str, default: int) -> int:
|
|
||||||
raw = os.environ.get(name)
|
|
||||||
if not raw:
|
|
||||||
return default
|
|
||||||
try:
|
|
||||||
return int(raw)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
def max_notional() -> float:
|
|
||||||
return _env_float("CERBERO_MAX_NOTIONAL", 200.0)
|
|
||||||
|
|
||||||
|
|
||||||
def max_aggregate() -> float:
|
|
||||||
return _env_float("CERBERO_MAX_AGGREGATE", 1000.0)
|
|
||||||
|
|
||||||
|
|
||||||
def max_leverage() -> int:
|
|
||||||
return _env_int("CERBERO_MAX_LEVERAGE", 3)
|
|
||||||
|
|
||||||
|
|
||||||
def _hard_reject(reason: str) -> None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail={
|
|
||||||
"error": "HARD_PROHIBITION",
|
|
||||||
"message": reason,
|
|
||||||
"caps": {
|
|
||||||
"max_notional": max_notional(),
|
|
||||||
"max_aggregate": max_aggregate(),
|
|
||||||
"max_leverage": max_leverage(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def enforce_leverage(leverage: int | float | None) -> int:
|
|
||||||
"""Ritorna leverage applicabile. Default 3x se None. Reject se > cap."""
|
|
||||||
cap = max_leverage()
|
|
||||||
if leverage is None:
|
|
||||||
return cap
|
|
||||||
lev = int(leverage)
|
|
||||||
if lev < 1:
|
|
||||||
_hard_reject(f"leverage must be >= 1 (got {lev})")
|
|
||||||
if lev > cap:
|
|
||||||
_hard_reject(f"leverage {lev}x exceeds hard cap {cap}x")
|
|
||||||
return lev
|
|
||||||
|
|
||||||
|
|
||||||
def enforce_single_notional(notional: float, *, exchange: str, instrument: str) -> None:
|
|
||||||
cap = max_notional()
|
|
||||||
if notional > cap:
|
|
||||||
_hard_reject(
|
|
||||||
f"{exchange}.{instrument} notional {notional:.2f} exceeds single trade cap {cap:.2f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def enforce_aggregate(current_total: float, new_notional: float) -> None:
|
|
||||||
cap = max_aggregate()
|
|
||||||
total = current_total + new_notional
|
|
||||||
if total > cap:
|
|
||||||
_hard_reject(
|
|
||||||
f"aggregate notional {total:.2f} (current {current_total:.2f} + new "
|
|
||||||
f"{new_notional:.2f}) exceeds cap {cap:.2f}"
|
|
||||||
)
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import sqlite3
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Database:
|
|
||||||
path: Path
|
|
||||||
conn: sqlite3.Connection | None = None
|
|
||||||
|
|
||||||
def connect(self) -> sqlite3.Connection:
|
|
||||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
self.conn = sqlite3.connect(
|
|
||||||
str(self.path),
|
|
||||||
isolation_level=None,
|
|
||||||
check_same_thread=False,
|
|
||||||
)
|
|
||||||
self.conn.row_factory = sqlite3.Row
|
|
||||||
self.conn.execute("PRAGMA journal_mode=WAL")
|
|
||||||
self.conn.execute("PRAGMA synchronous=NORMAL")
|
|
||||||
self.conn.execute("PRAGMA foreign_keys=ON")
|
|
||||||
return self.conn
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
if self.conn is not None:
|
|
||||||
self.conn.close()
|
|
||||||
self.conn = None
|
|
||||||
|
|
||||||
|
|
||||||
def run_migrations(conn: sqlite3.Connection, migrations: dict[int, str]) -> None:
|
|
||||||
"""Idempotent migrations. `migrations` keys are monotonic version numbers."""
|
|
||||||
conn.execute(
|
|
||||||
"CREATE TABLE IF NOT EXISTS _schema_version (version INTEGER PRIMARY KEY)"
|
|
||||||
)
|
|
||||||
cur = conn.execute("SELECT COALESCE(MAX(version), 0) FROM _schema_version")
|
|
||||||
current = cur.fetchone()[0]
|
|
||||||
for version in sorted(migrations):
|
|
||||||
if version <= current:
|
|
||||||
continue
|
|
||||||
conn.executescript(migrations[version])
|
|
||||||
conn.execute("INSERT INTO _schema_version (version) VALUES (?)", (version,))
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
"""CER-P5-010 env validation tests."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from mcp_common.env_validation import (
|
|
||||||
MissingEnvError,
|
|
||||||
fail_fast_if_missing,
|
|
||||||
optional_env,
|
|
||||||
require_env,
|
|
||||||
summarize,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_require_env_present(monkeypatch):
|
|
||||||
monkeypatch.setenv("FOO_KEY", "value1")
|
|
||||||
assert require_env("FOO_KEY") == "value1"
|
|
||||||
|
|
||||||
|
|
||||||
def test_require_env_missing_raises(monkeypatch):
|
|
||||||
monkeypatch.delenv("MISSING_REQ", raising=False)
|
|
||||||
with pytest.raises(MissingEnvError):
|
|
||||||
require_env("MISSING_REQ", "critical path")
|
|
||||||
|
|
||||||
|
|
||||||
def test_require_env_empty_raises(monkeypatch):
|
|
||||||
monkeypatch.setenv("EMPTY_REQ", "")
|
|
||||||
with pytest.raises(MissingEnvError):
|
|
||||||
require_env("EMPTY_REQ")
|
|
||||||
|
|
||||||
|
|
||||||
def test_require_env_whitespace_only_raises(monkeypatch):
|
|
||||||
monkeypatch.setenv("WS_REQ", " ")
|
|
||||||
with pytest.raises(MissingEnvError):
|
|
||||||
require_env("WS_REQ")
|
|
||||||
|
|
||||||
|
|
||||||
def test_optional_env_default(monkeypatch):
|
|
||||||
monkeypatch.delenv("OPT_A", raising=False)
|
|
||||||
assert optional_env("OPT_A", default="fallback") == "fallback"
|
|
||||||
|
|
||||||
|
|
||||||
def test_optional_env_set(monkeypatch):
|
|
||||||
monkeypatch.setenv("OPT_B", "xx")
|
|
||||||
assert optional_env("OPT_B", default="fallback") == "xx"
|
|
||||||
|
|
||||||
|
|
||||||
def test_fail_fast_all_present(monkeypatch):
|
|
||||||
monkeypatch.setenv("AA", "1")
|
|
||||||
monkeypatch.setenv("BB", "2")
|
|
||||||
fail_fast_if_missing(["AA", "BB"]) # no exit
|
|
||||||
|
|
||||||
|
|
||||||
def test_fail_fast_missing_exits(monkeypatch):
|
|
||||||
monkeypatch.setenv("HAVE_IT", "1")
|
|
||||||
monkeypatch.delenv("MISSING_X", raising=False)
|
|
||||||
with pytest.raises(SystemExit) as exc:
|
|
||||||
fail_fast_if_missing(["HAVE_IT", "MISSING_X"])
|
|
||||||
assert exc.value.code == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_summarize_does_not_leak_secrets(monkeypatch, caplog):
|
|
||||||
import logging
|
|
||||||
monkeypatch.setenv("API_KEY_FOO", "super-secret-token-123456")
|
|
||||||
monkeypatch.setenv("PORT", "9000")
|
|
||||||
with caplog.at_level(logging.INFO, logger="mcp_common.env_validation"):
|
|
||||||
summarize(["API_KEY_FOO", "PORT", "NOT_SET_XYZ"])
|
|
||||||
log_text = "\n".join(caplog.messages)
|
|
||||||
assert "super-secret-token-123456" not in log_text
|
|
||||||
assert "9000" in log_text
|
|
||||||
assert "<unset>" in log_text
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
from mcp_common.models import EventPriority, EventType, L2Entry
|
|
||||||
|
|
||||||
|
|
||||||
def test_l2_entry_minimal():
|
|
||||||
entry = L2Entry(
|
|
||||||
timestamp="2026-04-17T10:30:00Z",
|
|
||||||
setup="bull put spread ETH 1800/1750 14d",
|
|
||||||
tesi="IV alta post-CPI, attesa mean-reversion",
|
|
||||||
esito="aperto"
|
|
||||||
)
|
|
||||||
assert entry.scostamento_sigma is None
|
|
||||||
assert entry.tesi_check is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_l2_entry_full():
|
|
||||||
entry = L2Entry(
|
|
||||||
timestamp="2026-04-17T10:30:00Z",
|
|
||||||
setup="bull put spread ETH 1800/1750 14d",
|
|
||||||
tesi="IV alta post-CPI",
|
|
||||||
tesi_check="ETH sopra 1820 per 24h con IV in calo",
|
|
||||||
invalidation="rottura 1800 con volume > 2x media",
|
|
||||||
esito="chiuso +12 USDC",
|
|
||||||
scostamento="nessuno",
|
|
||||||
scostamento_sigma=0.5,
|
|
||||||
lezione="supporto ha tenuto, timing ok",
|
|
||||||
sizing_note="size 80 USDC (ATR 1.3x media)",
|
|
||||||
)
|
|
||||||
assert entry.scostamento_sigma == 0.5
|
|
||||||
dump = entry.model_dump()
|
|
||||||
assert dump["lezione"] == "supporto ha tenuto, timing ok"
|
|
||||||
|
|
||||||
|
|
||||||
def test_event_priority_enum():
|
|
||||||
assert EventPriority.CRITICAL.value == "critical"
|
|
||||||
assert EventPriority.LOW < EventPriority.CRITICAL # ordering
|
|
||||||
|
|
||||||
|
|
||||||
def test_event_type_enum():
|
|
||||||
assert EventType.ALERT.value == "alert"
|
|
||||||
assert EventType.USER_INSTRUCTION.value == "user_instruction"
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
"""Tests for CER-016 risk guard."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from mcp_common.risk_guard import (
|
|
||||||
enforce_aggregate,
|
|
||||||
enforce_leverage,
|
|
||||||
enforce_single_notional,
|
|
||||||
max_aggregate,
|
|
||||||
max_leverage,
|
|
||||||
max_notional,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_defaults(monkeypatch):
|
|
||||||
for k in ("CERBERO_MAX_NOTIONAL", "CERBERO_MAX_AGGREGATE", "CERBERO_MAX_LEVERAGE"):
|
|
||||||
monkeypatch.delenv(k, raising=False)
|
|
||||||
assert max_notional() == 200.0
|
|
||||||
assert max_aggregate() == 1000.0
|
|
||||||
assert max_leverage() == 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_env_override(monkeypatch):
|
|
||||||
monkeypatch.setenv("CERBERO_MAX_NOTIONAL", "50")
|
|
||||||
monkeypatch.setenv("CERBERO_MAX_AGGREGATE", "150")
|
|
||||||
monkeypatch.setenv("CERBERO_MAX_LEVERAGE", "2")
|
|
||||||
assert max_notional() == 50.0
|
|
||||||
assert max_aggregate() == 150.0
|
|
||||||
assert max_leverage() == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_leverage_default_when_none(monkeypatch):
|
|
||||||
monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
|
|
||||||
assert enforce_leverage(None) == 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_leverage_accepts_within_cap(monkeypatch):
|
|
||||||
monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
|
|
||||||
assert enforce_leverage(2) == 2
|
|
||||||
assert enforce_leverage(3) == 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_leverage_rejects_above_cap(monkeypatch):
|
|
||||||
monkeypatch.delenv("CERBERO_MAX_LEVERAGE", raising=False)
|
|
||||||
with pytest.raises(HTTPException) as exc:
|
|
||||||
enforce_leverage(50)
|
|
||||||
assert exc.value.status_code == 403
|
|
||||||
assert exc.value.detail["error"] == "HARD_PROHIBITION"
|
|
||||||
|
|
||||||
|
|
||||||
def test_leverage_rejects_below_one(monkeypatch):
|
|
||||||
with pytest.raises(HTTPException):
|
|
||||||
enforce_leverage(0)
|
|
||||||
|
|
||||||
|
|
||||||
def test_single_notional_ok(monkeypatch):
|
|
||||||
monkeypatch.delenv("CERBERO_MAX_NOTIONAL", raising=False)
|
|
||||||
enforce_single_notional(100.0, exchange="deribit", instrument="BTC-PERPETUAL")
|
|
||||||
|
|
||||||
|
|
||||||
def test_single_notional_rejects(monkeypatch):
|
|
||||||
monkeypatch.delenv("CERBERO_MAX_NOTIONAL", raising=False)
|
|
||||||
with pytest.raises(HTTPException) as exc:
|
|
||||||
enforce_single_notional(335.0, exchange="hyperliquid", instrument="ETH")
|
|
||||||
assert exc.value.status_code == 403
|
|
||||||
assert "335" in exc.value.detail["message"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_aggregate_ok(monkeypatch):
|
|
||||||
monkeypatch.delenv("CERBERO_MAX_AGGREGATE", raising=False)
|
|
||||||
enforce_aggregate(current_total=500.0, new_notional=200.0)
|
|
||||||
|
|
||||||
|
|
||||||
def test_aggregate_rejects(monkeypatch):
|
|
||||||
monkeypatch.delenv("CERBERO_MAX_AGGREGATE", raising=False)
|
|
||||||
with pytest.raises(HTTPException) as exc:
|
|
||||||
enforce_aggregate(current_total=900.0, new_notional=200.0)
|
|
||||||
assert exc.value.status_code == 403
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from mcp_common.storage import Database, run_migrations
|
|
||||||
|
|
||||||
|
|
||||||
def test_database_creates_wal(tmp_path: Path):
|
|
||||||
db_path = tmp_path / "test.db"
|
|
||||||
db = Database(db_path)
|
|
||||||
db.connect()
|
|
||||||
# WAL mode attivo
|
|
||||||
mode = db.conn.execute("PRAGMA journal_mode").fetchone()[0]
|
|
||||||
assert mode.lower() == "wal"
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
|
|
||||||
def test_database_migrations_run_once(tmp_path: Path):
|
|
||||||
db_path = tmp_path / "test.db"
|
|
||||||
db = Database(db_path)
|
|
||||||
db.connect()
|
|
||||||
migrations = {
|
|
||||||
1: "CREATE TABLE foo (id INTEGER PRIMARY KEY, name TEXT);",
|
|
||||||
2: "ALTER TABLE foo ADD COLUMN value INTEGER DEFAULT 0;",
|
|
||||||
}
|
|
||||||
run_migrations(db.conn, migrations)
|
|
||||||
# Second run: should be no-op
|
|
||||||
run_migrations(db.conn, migrations)
|
|
||||||
cols = [r[1] for r in db.conn.execute("PRAGMA table_info(foo)").fetchall()]
|
|
||||||
assert "name" in cols
|
|
||||||
assert "value" in cols
|
|
||||||
version = db.conn.execute("SELECT MAX(version) FROM _schema_version").fetchone()[0]
|
|
||||||
assert version == 2
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
|
|
||||||
def test_database_partial_migration(tmp_path: Path):
|
|
||||||
db_path = tmp_path / "test.db"
|
|
||||||
db = Database(db_path)
|
|
||||||
db.connect()
|
|
||||||
migrations_v1 = {1: "CREATE TABLE foo (id INTEGER);"}
|
|
||||||
run_migrations(db.conn, migrations_v1)
|
|
||||||
migrations_v2 = {**migrations_v1, 2: "CREATE TABLE bar (id INTEGER);"}
|
|
||||||
run_migrations(db.conn, migrations_v2)
|
|
||||||
tables = {r[0] for r in db.conn.execute(
|
|
||||||
"SELECT name FROM sqlite_master WHERE type='table'"
|
|
||||||
).fetchall()}
|
|
||||||
assert "foo" in tables
|
|
||||||
assert "bar" in tables
|
|
||||||
db.close()
|
|
||||||
Reference in New Issue
Block a user