Phase 2: persistence + safety controls
Aggiunge la persistenza SQLite, l'audit log a hash chain, il kill switch coordinato e i CLI di gestione documentati in docs/05-data-model.md e docs/07-risk-controls.md. 197 test pass, 1 skipped (sqlite3 CLI mancante), copertura totale 97%. State (`state/`): - 0001_init.sql con positions, instructions, decisions, dvol_history, manual_actions, system_state. - db.py: connect con WAL + foreign_keys + transaction ctx, runner forward-only basato su PRAGMA user_version. - models.py: record Pydantic, Decimal preservato come TEXT. - repository.py: CRUD typed con singola connessione passata, cache aware, posizioni concorrenti. Safety (`safety/`): - audit_log.py: AuditLog append-only con SHA-256 chain e fsync, verify_chain riconosce ogni manomissione (payload, prev_hash, hash, JSON, separatori). - kill_switch.py: arm/disarm transazionali, idempotenti, accoppiati all'audit chain. Config (`config/loader.py` + `strategy.yaml`): - Loader YAML con deep-merge di strategy.local.yaml. - Verifica config_hash SHA-256 (riga config_hash esclusa). - File golden strategy.yaml + esempio override. Scripts: - dead_man.sh: watchdog shell indipendente da Python. - backup.py: VACUUM INTO orario con retention 30 giorni. CLI: - audit verify (exit 2 su tampering). - kill-switch arm/disarm/status su SQLite reale. - state inspect con tabella posizioni aperte. - config hash, config validate. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,111 @@
|
||||
"""Smoke tests for the dead_man.sh shell watchdog."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import sqlite3
|
||||
import subprocess
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
SCRIPT = REPO_ROOT / "scripts" / "dead_man.sh"
|
||||
|
||||
_SQLITE3_BIN = shutil.which("sqlite3")
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.integration,
|
||||
]
|
||||
|
||||
|
||||
def _require_sqlite3() -> None:
|
||||
if _SQLITE3_BIN is None:
|
||||
pytest.skip("sqlite3 CLI not installed on this host")
|
||||
|
||||
|
||||
def _setup_project(tmp_path: Path) -> Path:
|
||||
project = tmp_path / "project"
|
||||
(project / "data" / "log").mkdir(parents=True)
|
||||
(project / "scripts").mkdir(parents=True)
|
||||
shutil.copy(SCRIPT, project / "scripts" / "dead_man.sh")
|
||||
return project
|
||||
|
||||
|
||||
def _write_health(project: Path, ts: datetime) -> None:
|
||||
log_file = project / "data" / "log" / f"cerbero-bite-{ts:%Y-%m-%d}.jsonl"
|
||||
log_file.write_text(
|
||||
f'{{"ts": "{ts.astimezone(UTC).isoformat()}", "event": "HEALTH_OK"}}\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def _run(project: Path, threshold: int = 900) -> subprocess.CompletedProcess[str]:
|
||||
return subprocess.run(
|
||||
["bash", str(project / "scripts" / "dead_man.sh")],
|
||||
env={
|
||||
"PATH": "/usr/bin:/bin",
|
||||
"PROJECT_ROOT": str(project),
|
||||
"DEAD_MAN_THRESHOLD_SECONDS": str(threshold),
|
||||
},
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
|
||||
def test_dead_man_exits_zero_when_recent_health_ok(tmp_path: Path) -> None:
|
||||
project = _setup_project(tmp_path)
|
||||
_write_health(project, datetime.now(UTC) - timedelta(seconds=60))
|
||||
result = _run(project, threshold=900)
|
||||
assert result.returncode == 0, result.stderr
|
||||
|
||||
|
||||
def test_dead_man_arms_kill_switch_when_silent(tmp_path: Path) -> None:
|
||||
_require_sqlite3()
|
||||
project = _setup_project(tmp_path)
|
||||
_write_health(project, datetime.now(UTC) - timedelta(seconds=2000))
|
||||
|
||||
# Pre-create the SQLite system_state singleton; otherwise the script
|
||||
# has nothing to update.
|
||||
db = project / "data" / "state.sqlite"
|
||||
conn = sqlite3.connect(str(db))
|
||||
conn.execute(
|
||||
"CREATE TABLE system_state (id INTEGER PRIMARY KEY CHECK(id=1), "
|
||||
"kill_switch INTEGER NOT NULL DEFAULT 0, kill_reason TEXT, "
|
||||
"kill_at TEXT, last_health_check TEXT NOT NULL, "
|
||||
"last_kelly_calib TEXT, config_version TEXT NOT NULL, "
|
||||
"started_at TEXT NOT NULL)"
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO system_state(id, last_health_check, config_version, started_at) "
|
||||
"VALUES (1, '2026-04-27', '1.0.0', '2026-04-27')"
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
result = _run(project, threshold=900)
|
||||
assert result.returncode == 1
|
||||
|
||||
# Kill switch armed.
|
||||
conn = sqlite3.connect(str(db))
|
||||
try:
|
||||
kill = conn.execute("SELECT kill_switch FROM system_state").fetchone()[0]
|
||||
finally:
|
||||
conn.close()
|
||||
assert kill == 1
|
||||
|
||||
# Alert file exists.
|
||||
alert = project / "data" / "log" / "dead-man-alert.txt"
|
||||
assert alert.exists()
|
||||
assert "dead_man" in alert.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_dead_man_handles_missing_log_file(tmp_path: Path) -> None:
|
||||
project = _setup_project(tmp_path)
|
||||
# No log file at all.
|
||||
result = _run(project, threshold=900)
|
||||
assert result.returncode == 1
|
||||
alert = project / "data" / "log" / "dead-man-alert.txt"
|
||||
assert alert.exists()
|
||||
@@ -0,0 +1,273 @@
|
||||
"""Audit chain writer + verifier tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from cerbero_bite.safety.audit_log import (
|
||||
GENESIS_HASH,
|
||||
AuditChainError,
|
||||
AuditLog,
|
||||
iter_entries,
|
||||
verify_chain,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_file_verifies_with_zero_entries(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
assert verify_chain(path) == 0
|
||||
|
||||
|
||||
def test_first_entry_uses_genesis_prev_hash(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
log = AuditLog(path)
|
||||
entry = log.append(
|
||||
event="ENGINE_START",
|
||||
payload={"version": "1.0.0"},
|
||||
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
)
|
||||
assert entry.prev_hash == GENESIS_HASH
|
||||
assert entry.hash != GENESIS_HASH
|
||||
|
||||
|
||||
def test_chain_links_subsequent_entries(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
log = AuditLog(path)
|
||||
e1 = log.append(event="A", payload={"i": 1}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
|
||||
e2 = log.append(event="B", payload={"i": 2}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC))
|
||||
e3 = log.append(event="C", payload={"i": 3}, now=datetime(2026, 4, 27, 14, 2, tzinfo=UTC))
|
||||
assert e2.prev_hash == e1.hash
|
||||
assert e3.prev_hash == e2.hash
|
||||
assert verify_chain(path) == 3
|
||||
|
||||
|
||||
def test_iter_entries_yields_in_order(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
log = AuditLog(path)
|
||||
log.append(event="A", payload={"i": 1}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
|
||||
log.append(event="B", payload={"i": 2}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC))
|
||||
events = [e.event for e in iter_entries(path)]
|
||||
assert events == ["A", "B"]
|
||||
|
||||
|
||||
def test_log_resumes_chain_after_reopen(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
first = AuditLog(path)
|
||||
e1 = first.append(
|
||||
event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
|
||||
)
|
||||
|
||||
second = AuditLog(path)
|
||||
assert second.last_hash == e1.hash
|
||||
e2 = second.append(
|
||||
event="B", payload={"k": "v"}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC)
|
||||
)
|
||||
assert e2.prev_hash == e1.hash
|
||||
assert verify_chain(path) == 2
|
||||
|
||||
|
||||
def test_payload_with_pipe_character_round_trips(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
log = AuditLog(path)
|
||||
log.append(
|
||||
event="NOTE",
|
||||
payload={"text": "first|second|third"},
|
||||
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
)
|
||||
entries = list(iter_entries(path))
|
||||
assert entries[0].payload == {"text": "first|second|third"}
|
||||
assert verify_chain(path) == 1
|
||||
|
||||
|
||||
def test_tampered_payload_breaks_chain(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
log = AuditLog(path)
|
||||
log.append(event="A", payload={"i": 1}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
|
||||
log.append(event="B", payload={"i": 2}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC))
|
||||
|
||||
# Mutate the first line's payload by hand.
|
||||
text = path.read_text(encoding="utf-8").splitlines()
|
||||
text[0] = text[0].replace('"i":1', '"i":99')
|
||||
path.write_text("\n".join(text) + "\n", encoding="utf-8")
|
||||
|
||||
with pytest.raises(AuditChainError, match="hash mismatch"):
|
||||
verify_chain(path)
|
||||
|
||||
|
||||
def test_verify_chain_skips_blank_lines(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
log = AuditLog(path)
|
||||
log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
|
||||
raw = path.read_text(encoding="utf-8")
|
||||
path.write_text("\n" + raw + "\n \n", encoding="utf-8")
|
||||
# The chain still verifies despite the surrounding whitespace lines.
|
||||
assert verify_chain(path) == 1
|
||||
|
||||
|
||||
def test_prev_hash_mismatch_between_entries_is_caught(tmp_path: Path) -> None:
|
||||
"""Second line's prev_hash points to a different chain — verify_chain rejects."""
|
||||
path = tmp_path / "audit.log"
|
||||
log = AuditLog(path)
|
||||
e1 = log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
|
||||
|
||||
# Build a synthetic second line whose prev_hash != e1.hash but whose
|
||||
# own hash is correctly computed from that bogus prev_hash.
|
||||
fake_prev = "0" * 32 + "f" * 32
|
||||
ts2 = "2026-04-27T14:01:00+00:00"
|
||||
payload_json = "{}"
|
||||
raw = f"{ts2}|B|{payload_json}|{fake_prev}"
|
||||
fake_hash = hashlib.sha256(raw.encode()).hexdigest()
|
||||
line = f"{ts2}|B|{payload_json}|prev_hash={fake_prev}|hash={fake_hash}\n"
|
||||
with path.open("a", encoding="utf-8") as fh:
|
||||
fh.write(line)
|
||||
|
||||
assert e1.hash != fake_prev # sanity
|
||||
with pytest.raises(AuditChainError, match="prev_hash mismatch"):
|
||||
verify_chain(path)
|
||||
|
||||
|
||||
def test_tampered_prev_hash_breaks_chain(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
log = AuditLog(path)
|
||||
log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
|
||||
log.append(event="B", payload={}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC))
|
||||
|
||||
# Inject an unrelated prev_hash on the second line.
|
||||
lines = path.read_text(encoding="utf-8").splitlines()
|
||||
lines[1] = lines[1].replace("prev_hash=", "prev_hash=" + "f" * 64 + "X")
|
||||
# Truncate to recover length: replace prev_hash field with all-ff.
|
||||
lines[1] = lines[1].replace("X", "")
|
||||
path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||
|
||||
with pytest.raises(AuditChainError):
|
||||
verify_chain(path)
|
||||
|
||||
|
||||
def test_malformed_line_raises_chain_error(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
path.write_text("not-a-valid-line\n", encoding="utf-8")
|
||||
with pytest.raises(AuditChainError):
|
||||
verify_chain(path)
|
||||
|
||||
|
||||
def test_parser_rejects_missing_hash_field(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
path.write_text(
|
||||
"2026-04-27T14:00:00+00:00|EVT|{}|prev_hash=" + "0" * 64 + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
with pytest.raises(AuditChainError, match="hash="):
|
||||
verify_chain(path)
|
||||
|
||||
|
||||
def test_parser_rejects_missing_prev_hash_field(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
path.write_text(
|
||||
"2026-04-27T14:00:00+00:00|EVT|{}|hash=" + "f" * 64 + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
with pytest.raises(AuditChainError, match="prev_hash"):
|
||||
verify_chain(path)
|
||||
|
||||
|
||||
def test_parser_rejects_line_with_no_separators(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
path.write_text("just-a-blob|hash=" + "f" * 64 + "\n", encoding="utf-8")
|
||||
with pytest.raises(AuditChainError, match="prev_hash"):
|
||||
verify_chain(path)
|
||||
|
||||
|
||||
def test_parser_rejects_malformed_leading_section(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
# Two `|` only: rsplit succeeds twice, leading parts has 1 element ≠ 3.
|
||||
path.write_text(
|
||||
"tooshort|prev_hash=" + "0" * 64 + "|hash=" + "f" * 64 + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
with pytest.raises(AuditChainError, match="leading section"):
|
||||
verify_chain(path)
|
||||
|
||||
|
||||
def test_parser_rejects_payload_not_a_json_object(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
path.write_text(
|
||||
"2026-04-27T14:00:00+00:00|EVT|[1,2]|prev_hash="
|
||||
+ "0" * 64
|
||||
+ "|hash="
|
||||
+ "f" * 64
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
with pytest.raises(AuditChainError, match="JSON object"):
|
||||
verify_chain(path)
|
||||
|
||||
|
||||
def test_parser_rejects_payload_with_invalid_json(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
path.write_text(
|
||||
"2026-04-27T14:00:00+00:00|EVT|{not-json}|prev_hash="
|
||||
+ "0" * 64
|
||||
+ "|hash="
|
||||
+ "f" * 64
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
with pytest.raises(AuditChainError, match="JSON"):
|
||||
verify_chain(path)
|
||||
|
||||
|
||||
def test_iter_entries_returns_empty_when_file_missing(tmp_path: Path) -> None:
|
||||
path = tmp_path / "missing.log"
|
||||
assert list(iter_entries(path)) == []
|
||||
|
||||
|
||||
def test_iter_entries_skips_blank_lines(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
log = AuditLog(path)
|
||||
log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
|
||||
raw = path.read_text(encoding="utf-8")
|
||||
path.write_text(raw + "\n\n", encoding="utf-8")
|
||||
entries = list(iter_entries(path))
|
||||
assert len(entries) == 1
|
||||
|
||||
|
||||
def test_log_resumes_chain_with_large_file(tmp_path: Path) -> None:
|
||||
"""Tail-seek reads past the 4096-byte chunk boundary."""
|
||||
path = tmp_path / "audit.log"
|
||||
log = AuditLog(path)
|
||||
base = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
|
||||
# Each line ~150 chars; 50 lines is comfortably > 4096 bytes.
|
||||
for i in range(50):
|
||||
log.append(
|
||||
event=f"E{i}",
|
||||
payload={"i": i, "filler": "x" * 80},
|
||||
now=base + timedelta(seconds=i),
|
||||
)
|
||||
|
||||
last_hash = log.last_hash
|
||||
reopened = AuditLog(path)
|
||||
assert reopened.last_hash == last_hash
|
||||
assert verify_chain(path) == 50
|
||||
|
||||
|
||||
def test_payload_serialisation_is_canonical(tmp_path: Path) -> None:
|
||||
path = tmp_path / "audit.log"
|
||||
log = AuditLog(path)
|
||||
# Different key order must produce identical hashes.
|
||||
e1 = log.append(
|
||||
event="A",
|
||||
payload={"b": 1, "a": 2},
|
||||
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
)
|
||||
other = tmp_path / "audit_other.log"
|
||||
log2 = AuditLog(other)
|
||||
e2 = log2.append(
|
||||
event="A",
|
||||
payload={"a": 2, "b": 1},
|
||||
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
)
|
||||
assert e1.hash == e2.hash
|
||||
@@ -0,0 +1,126 @@
|
||||
"""Tests for scripts.backup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from cerbero_bite.state import connect, run_migrations
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
def _load_backup_module() -> object:
|
||||
"""Load scripts/backup.py as a module without polluting sys.path."""
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"_cerbero_bite_backup", REPO_ROOT / "scripts" / "backup.py"
|
||||
)
|
||||
assert spec is not None and spec.loader is not None
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def backup_mod() -> object:
|
||||
return _load_backup_module()
|
||||
|
||||
|
||||
def test_backup_database_creates_snapshot(tmp_path: Path, backup_mod: object) -> None:
|
||||
db = tmp_path / "state.sqlite"
|
||||
conn = connect(db)
|
||||
try:
|
||||
run_migrations(conn)
|
||||
conn.execute(
|
||||
"INSERT INTO system_state(id, last_health_check, config_version, started_at) "
|
||||
"VALUES (1, '2026-04-27', '1.0.0', '2026-04-27')"
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
backup_dir = tmp_path / "backups"
|
||||
snapshot = backup_mod.backup_database( # type: ignore[attr-defined]
|
||||
db_path=db,
|
||||
backup_dir=backup_dir,
|
||||
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
)
|
||||
assert snapshot.exists()
|
||||
assert snapshot.name == "state-20260427-14.sqlite"
|
||||
|
||||
# Snapshot is itself a valid SQLite db with the same row.
|
||||
snap = connect(snapshot)
|
||||
try:
|
||||
rows = snap.execute(
|
||||
"SELECT config_version FROM system_state WHERE id = 1"
|
||||
).fetchone()
|
||||
finally:
|
||||
snap.close()
|
||||
assert rows[0] == "1.0.0"
|
||||
|
||||
|
||||
def test_backup_database_replaces_existing_hour_snapshot(
|
||||
tmp_path: Path, backup_mod: object
|
||||
) -> None:
|
||||
db = tmp_path / "state.sqlite"
|
||||
conn = connect(db)
|
||||
try:
|
||||
run_migrations(conn)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
when = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
|
||||
first = backup_mod.backup_database(db_path=db, backup_dir=tmp_path / "b", now=when) # type: ignore[attr-defined]
|
||||
second = backup_mod.backup_database(db_path=db, backup_dir=tmp_path / "b", now=when) # type: ignore[attr-defined]
|
||||
assert first == second
|
||||
files = list((tmp_path / "b").iterdir())
|
||||
assert len(files) == 1
|
||||
|
||||
|
||||
def test_prune_backups_removes_old_files(tmp_path: Path, backup_mod: object) -> None:
|
||||
backup_dir = tmp_path / "b"
|
||||
backup_dir.mkdir()
|
||||
fresh = backup_dir / "state-20260420-10.sqlite"
|
||||
stale = backup_dir / "state-20260101-12.sqlite"
|
||||
other = backup_dir / "unrelated.txt"
|
||||
fresh.touch()
|
||||
stale.touch()
|
||||
other.touch()
|
||||
|
||||
deleted = backup_mod.prune_backups( # type: ignore[attr-defined]
|
||||
backup_dir,
|
||||
retention_days=30,
|
||||
now=datetime(2026, 4, 27, tzinfo=UTC),
|
||||
)
|
||||
assert deleted == [stale]
|
||||
assert fresh.exists()
|
||||
assert other.exists()
|
||||
|
||||
|
||||
def test_prune_backups_ignores_unparseable_filenames(
|
||||
tmp_path: Path, backup_mod: object
|
||||
) -> None:
|
||||
backup_dir = tmp_path / "b"
|
||||
backup_dir.mkdir()
|
||||
(backup_dir / "state-bogus-XX.sqlite").touch()
|
||||
deleted = backup_mod.prune_backups( # type: ignore[attr-defined]
|
||||
backup_dir,
|
||||
retention_days=0,
|
||||
now=datetime(2026, 4, 27, tzinfo=UTC) + timedelta(days=10000),
|
||||
)
|
||||
assert deleted == []
|
||||
|
||||
|
||||
def test_list_backups_returns_sorted(tmp_path: Path, backup_mod: object) -> None:
|
||||
backup_dir = tmp_path / "b"
|
||||
backup_dir.mkdir()
|
||||
a = backup_dir / "state-20260103-08.sqlite"
|
||||
b = backup_dir / "state-20260101-08.sqlite"
|
||||
a.touch()
|
||||
b.touch()
|
||||
listed = list(backup_mod.list_backups(backup_dir)) # type: ignore[attr-defined]
|
||||
assert listed == [b, a]
|
||||
@@ -0,0 +1,160 @@
|
||||
"""End-to-end CLI tests for the Phase 2 commands.
|
||||
|
||||
The commands hit real on-disk paths (tmp_path) so the assertions run
|
||||
the production code paths verbatim.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from click.testing import CliRunner
|
||||
|
||||
from cerbero_bite.cli import main as cli_main
|
||||
from cerbero_bite.safety import AuditLog
|
||||
from cerbero_bite.state import Repository, connect, run_migrations, transaction
|
||||
|
||||
|
||||
def _seed_state(db_path: Path) -> None:
|
||||
conn = connect(db_path)
|
||||
try:
|
||||
run_migrations(conn)
|
||||
with transaction(conn):
|
||||
Repository().init_system_state(
|
||||
conn,
|
||||
config_version="1.0.0",
|
||||
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_audit_verify_reports_ok_on_clean_chain(tmp_path: Path) -> None:
|
||||
audit = AuditLog(tmp_path / "audit.log")
|
||||
audit.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
|
||||
audit.append(event="B", payload={}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC))
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli_main, ["audit", "verify", "--file", str(tmp_path / "audit.log")]
|
||||
)
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "ok" in result.output
|
||||
assert "2" in result.output
|
||||
|
||||
|
||||
def test_audit_verify_handles_empty_file(tmp_path: Path) -> None:
|
||||
target = tmp_path / "audit.log"
|
||||
target.write_text("", encoding="utf-8")
|
||||
result = CliRunner().invoke(cli_main, ["audit", "verify", "--file", str(target)])
|
||||
assert result.exit_code == 0
|
||||
assert "empty" in result.output
|
||||
|
||||
|
||||
def test_audit_verify_exits_nonzero_on_tampering(tmp_path: Path) -> None:
|
||||
target = tmp_path / "audit.log"
|
||||
audit = AuditLog(target)
|
||||
audit.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
|
||||
target.write_text(
|
||||
target.read_text(encoding="utf-8").replace('"event":"A"', '"event":"X"'),
|
||||
encoding="utf-8",
|
||||
)
|
||||
# NB: we mutated the JSON payload, but the actual line still has event=A.
|
||||
# Force tampering by editing the literal "A" in the line text.
|
||||
raw = target.read_text(encoding="utf-8")
|
||||
target.write_text(raw.replace("|A|", "|X|", 1), encoding="utf-8")
|
||||
|
||||
result = CliRunner().invoke(cli_main, ["audit", "verify", "--file", str(target)])
|
||||
assert result.exit_code == 2
|
||||
assert "TAMPERED" in result.output
|
||||
|
||||
|
||||
def test_kill_switch_status_prints_disarmed(tmp_path: Path) -> None:
|
||||
db = tmp_path / "state.sqlite"
|
||||
_seed_state(db)
|
||||
result = CliRunner().invoke(cli_main, ["kill-switch", "status", "--db", str(db)])
|
||||
assert result.exit_code == 0
|
||||
assert "disarmed" in result.output
|
||||
|
||||
|
||||
def test_kill_switch_arm_then_status_shows_armed(tmp_path: Path) -> None:
|
||||
db = tmp_path / "state.sqlite"
|
||||
audit = tmp_path / "audit.log"
|
||||
runner = CliRunner()
|
||||
arm = runner.invoke(
|
||||
cli_main,
|
||||
[
|
||||
"kill-switch",
|
||||
"arm",
|
||||
"--reason",
|
||||
"manual smoke",
|
||||
"--db",
|
||||
str(db),
|
||||
"--audit",
|
||||
str(audit),
|
||||
],
|
||||
)
|
||||
assert arm.exit_code == 0, arm.output
|
||||
status = runner.invoke(cli_main, ["kill-switch", "status", "--db", str(db)])
|
||||
assert status.exit_code == 0
|
||||
assert "ARMED" in status.output
|
||||
|
||||
|
||||
def test_kill_switch_status_handles_missing_db(tmp_path: Path) -> None:
|
||||
result = CliRunner().invoke(
|
||||
cli_main, ["kill-switch", "status", "--db", str(tmp_path / "absent.sqlite")]
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert "not found" in result.output
|
||||
|
||||
|
||||
def test_state_inspect_shows_no_open_positions(tmp_path: Path) -> None:
|
||||
db = tmp_path / "state.sqlite"
|
||||
_seed_state(db)
|
||||
result = CliRunner().invoke(cli_main, ["state", "inspect", "--db", str(db)])
|
||||
assert result.exit_code == 0
|
||||
assert "no open positions" in result.output
|
||||
|
||||
|
||||
def test_state_inspect_handles_missing_db(tmp_path: Path) -> None:
|
||||
result = CliRunner().invoke(
|
||||
cli_main, ["state", "inspect", "--db", str(tmp_path / "absent.sqlite")]
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert "not found" in result.output
|
||||
|
||||
|
||||
def test_config_hash_matches_loader(tmp_path: Path) -> None:
|
||||
target = tmp_path / "strategy.yaml"
|
||||
target.write_text(
|
||||
'config_version: "1.0.0"\nconfig_hash: "0000"\nasset:\n symbol: ETH\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
result = CliRunner().invoke(cli_main, ["config", "hash", "--file", str(target)])
|
||||
assert result.exit_code == 0
|
||||
assert len(result.output.strip()) == 64 # sha256 hex
|
||||
|
||||
|
||||
def test_config_validate_repo_strategy_yaml() -> None:
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
yaml_path = repo_root / "strategy.yaml"
|
||||
result = CliRunner().invoke(
|
||||
cli_main, ["config", "validate", "--file", str(yaml_path)]
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert "ok" in result.output
|
||||
|
||||
|
||||
def test_config_validate_with_no_enforce_hash_skips_check(tmp_path: Path) -> None:
|
||||
target = tmp_path / "strategy.yaml"
|
||||
target.write_text(
|
||||
'config_version: "1.0.0"\nconfig_hash: "wrong"\n'
|
||||
'last_review: "2026-04-26"\nlast_reviewer: "test"\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
result = CliRunner().invoke(
|
||||
cli_main,
|
||||
["config", "validate", "--file", str(target), "--no-enforce-hash"],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert "ok" in result.output
|
||||
@@ -0,0 +1,149 @@
|
||||
"""Tests for the YAML loader and hash verification."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from cerbero_bite.config.loader import (
|
||||
ConfigHashError,
|
||||
compute_config_hash,
|
||||
load_strategy,
|
||||
)
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _golden_yaml_skeleton(**overrides: object) -> dict[str, object]:
|
||||
base = {
|
||||
"config_version": "1.0.0-test",
|
||||
"config_hash": "0" * 64,
|
||||
"last_review": "2026-04-26",
|
||||
"last_reviewer": "test",
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
def _write_with_correct_hash(path: Path, doc: dict[str, object]) -> None:
|
||||
"""Write a YAML doc and patch ``config_hash`` to match the file body."""
|
||||
text = yaml.safe_dump(doc, sort_keys=False)
|
||||
path.write_text(text, encoding="utf-8")
|
||||
new_hash = compute_config_hash(path.read_text(encoding="utf-8"))
|
||||
doc = {**doc, "config_hash": new_hash}
|
||||
text = yaml.safe_dump(doc, sort_keys=False)
|
||||
path.write_text(text, encoding="utf-8")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_config_hash
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_compute_hash_is_independent_of_recorded_hash_value(tmp_path: Path) -> None:
|
||||
a = tmp_path / "a.yaml"
|
||||
b = tmp_path / "b.yaml"
|
||||
a.write_text(
|
||||
'config_version: "1.0.0"\nconfig_hash: "aaa"\nfoo: 1\n', encoding="utf-8"
|
||||
)
|
||||
b.write_text(
|
||||
'config_version: "1.0.0"\nconfig_hash: "bbb"\nfoo: 1\n', encoding="utf-8"
|
||||
)
|
||||
assert compute_config_hash(a.read_text()) == compute_config_hash(b.read_text())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_strategy — happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_load_repo_strategy_yaml(tmp_path: Path) -> None:
|
||||
"""The committed strategy.yaml validates with the recorded hash."""
|
||||
result = load_strategy(REPO_ROOT / "strategy.yaml")
|
||||
assert result.config.config_version == "1.0.0"
|
||||
assert result.config.sizing.kelly_fraction == Decimal("0.13")
|
||||
assert result.computed_hash == result.config.config_hash
|
||||
|
||||
|
||||
def test_load_with_local_override_merges(tmp_path: Path) -> None:
|
||||
main = tmp_path / "strategy.yaml"
|
||||
_write_with_correct_hash(main, _golden_yaml_skeleton())
|
||||
override = tmp_path / "strategy.local.yaml"
|
||||
override.write_text(
|
||||
yaml.safe_dump(
|
||||
{"sizing": {"max_concurrent_positions": 0}},
|
||||
sort_keys=False,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
loaded = load_strategy(main)
|
||||
assert loaded.config.sizing.max_concurrent_positions == 0
|
||||
assert override in loaded.sources
|
||||
|
||||
|
||||
def test_local_override_does_not_invalidate_main_hash(tmp_path: Path) -> None:
|
||||
main = tmp_path / "strategy.yaml"
|
||||
_write_with_correct_hash(main, _golden_yaml_skeleton())
|
||||
(tmp_path / "strategy.local.yaml").write_text(
|
||||
"sizing:\n kelly_fraction: '0.10'\n", encoding="utf-8"
|
||||
)
|
||||
loaded = load_strategy(main)
|
||||
assert loaded.config.sizing.kelly_fraction == Decimal("0.10")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_strategy — error paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_load_with_mismatched_hash_raises(tmp_path: Path) -> None:
|
||||
main = tmp_path / "strategy.yaml"
|
||||
main.write_text(
|
||||
yaml.safe_dump(_golden_yaml_skeleton(), sort_keys=False), encoding="utf-8"
|
||||
)
|
||||
with pytest.raises(ConfigHashError, match="config_hash mismatch"):
|
||||
load_strategy(main)
|
||||
|
||||
|
||||
def test_load_with_enforce_hash_false_skips_check(tmp_path: Path) -> None:
|
||||
main = tmp_path / "strategy.yaml"
|
||||
main.write_text(
|
||||
yaml.safe_dump(_golden_yaml_skeleton(), sort_keys=False), encoding="utf-8"
|
||||
)
|
||||
loaded = load_strategy(main, enforce_hash=False)
|
||||
assert loaded.config.config_hash == "0" * 64
|
||||
|
||||
|
||||
def test_load_rejects_top_level_non_mapping(tmp_path: Path) -> None:
|
||||
main = tmp_path / "strategy.yaml"
|
||||
main.write_text("- a\n- b\n", encoding="utf-8")
|
||||
with pytest.raises(ValueError, match="top-level mapping"):
|
||||
load_strategy(main, enforce_hash=False)
|
||||
|
||||
|
||||
def test_local_override_path_pointing_to_missing_file_is_ignored(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
main = tmp_path / "strategy.yaml"
|
||||
_write_with_correct_hash(main, _golden_yaml_skeleton())
|
||||
loaded = load_strategy(
|
||||
main, local_override_path=tmp_path / "nonexistent.yaml"
|
||||
)
|
||||
assert main in loaded.sources
|
||||
assert len(loaded.sources) == 1
|
||||
|
||||
|
||||
def test_empty_local_override_file_is_no_op(tmp_path: Path) -> None:
|
||||
main = tmp_path / "strategy.yaml"
|
||||
_write_with_correct_hash(main, _golden_yaml_skeleton())
|
||||
(tmp_path / "strategy.local.yaml").write_text("", encoding="utf-8")
|
||||
loaded = load_strategy(main)
|
||||
assert loaded.config.sizing.kelly_fraction == Decimal("0.13")
|
||||
@@ -0,0 +1,170 @@
|
||||
"""Kill switch behaviour: SQLite + audit log stay in lock-step."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from cerbero_bite.safety import AuditLog, verify_chain
|
||||
from cerbero_bite.safety.kill_switch import KillSwitch, KillSwitchError
|
||||
from cerbero_bite.state import Repository, connect, run_migrations, transaction
|
||||
|
||||
|
||||
def _make_kill_switch(tmp_path: Path) -> tuple[KillSwitch, AuditLog, Path, Repository]:
|
||||
db_path = tmp_path / "state.sqlite"
|
||||
audit_path = tmp_path / "audit.log"
|
||||
conn = connect(db_path)
|
||||
run_migrations(conn)
|
||||
repo = Repository()
|
||||
with transaction(conn):
|
||||
repo.init_system_state(
|
||||
conn, config_version="1.0.0", now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
|
||||
)
|
||||
conn.close()
|
||||
|
||||
audit = AuditLog(audit_path)
|
||||
times = iter(
|
||||
datetime(2026, 4, 27, 14, m, tzinfo=UTC) for m in (10, 20, 30, 40, 50)
|
||||
)
|
||||
ks = KillSwitch(
|
||||
connection_factory=lambda: connect(db_path),
|
||||
repository=repo,
|
||||
audit_log=audit,
|
||||
clock=lambda: next(times),
|
||||
)
|
||||
return ks, audit, audit_path, repo
|
||||
|
||||
|
||||
def test_arm_persists_state_and_appends_audit(tmp_path: Path) -> None:
|
||||
ks, _audit, audit_path, repo = _make_kill_switch(tmp_path)
|
||||
assert ks.is_armed() is False
|
||||
|
||||
ks.arm(reason="manual test", source="manual")
|
||||
|
||||
assert ks.is_armed() is True
|
||||
conn = connect(tmp_path / "state.sqlite")
|
||||
try:
|
||||
state = repo.get_system_state(conn)
|
||||
finally:
|
||||
conn.close()
|
||||
assert state is not None
|
||||
assert state.kill_switch == 1
|
||||
assert state.kill_reason == "manual test"
|
||||
assert state.kill_at is not None
|
||||
assert verify_chain(audit_path) == 1
|
||||
|
||||
|
||||
def test_arm_is_idempotent_on_second_call(tmp_path: Path) -> None:
|
||||
ks, _audit, audit_path, _repo = _make_kill_switch(tmp_path)
|
||||
ks.arm(reason="first", source="manual")
|
||||
ks.arm(reason="second", source="manual") # no-op
|
||||
# only one audit line because the second call short-circuits
|
||||
assert verify_chain(audit_path) == 1
|
||||
|
||||
|
||||
def test_disarm_resets_kill_switch(tmp_path: Path) -> None:
|
||||
ks, _audit, audit_path, repo = _make_kill_switch(tmp_path)
|
||||
ks.arm(reason="test", source="manual")
|
||||
ks.disarm(reason="cleared", source="manual")
|
||||
assert ks.is_armed() is False
|
||||
conn = connect(tmp_path / "state.sqlite")
|
||||
try:
|
||||
state = repo.get_system_state(conn)
|
||||
finally:
|
||||
conn.close()
|
||||
assert state is not None
|
||||
assert state.kill_at is None
|
||||
# arm + disarm = 2 audit lines
|
||||
assert verify_chain(audit_path) == 2
|
||||
|
||||
|
||||
def test_disarm_when_not_armed_is_noop(tmp_path: Path) -> None:
|
||||
ks, _audit, audit_path, _repo = _make_kill_switch(tmp_path)
|
||||
ks.disarm(reason="nothing to do", source="manual")
|
||||
assert verify_chain(audit_path) == 0
|
||||
|
||||
|
||||
def test_arm_requires_reason(tmp_path: Path) -> None:
|
||||
ks, _audit, _audit_path, _repo = _make_kill_switch(tmp_path)
|
||||
with pytest.raises(KillSwitchError, match="reason is required"):
|
||||
ks.arm(reason="", source="manual")
|
||||
|
||||
|
||||
def test_arm_without_initialised_state_raises(tmp_path: Path) -> None:
|
||||
db_path = tmp_path / "state.sqlite"
|
||||
audit_path = tmp_path / "audit.log"
|
||||
conn = connect(db_path)
|
||||
run_migrations(conn)
|
||||
conn.close()
|
||||
ks = KillSwitch(
|
||||
connection_factory=lambda: connect(db_path),
|
||||
repository=Repository(),
|
||||
audit_log=AuditLog(audit_path),
|
||||
clock=lambda: datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
)
|
||||
with pytest.raises(KillSwitchError, match="system_state singleton missing"):
|
||||
ks.arm(reason="x", source="manual")
|
||||
|
||||
|
||||
def test_audit_chain_records_event_kind(tmp_path: Path) -> None:
|
||||
ks, _audit, audit_path, _repo = _make_kill_switch(tmp_path)
|
||||
ks.arm(reason="x", source="mcp_timeout")
|
||||
ks.disarm(reason="y", source="manual")
|
||||
text = audit_path.read_text(encoding="utf-8")
|
||||
assert "KILL_SWITCH_ARMED" in text
|
||||
assert "KILL_SWITCH_DISARMED" in text
|
||||
|
||||
|
||||
def test_is_armed_returns_false_when_singleton_missing(tmp_path: Path) -> None:
|
||||
db_path = tmp_path / "state.sqlite"
|
||||
audit_path = tmp_path / "audit.log"
|
||||
conn = connect(db_path)
|
||||
run_migrations(conn)
|
||||
conn.close()
|
||||
ks = KillSwitch(
|
||||
connection_factory=lambda: connect(db_path),
|
||||
repository=Repository(),
|
||||
audit_log=AuditLog(audit_path),
|
||||
clock=lambda: datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
)
|
||||
assert ks.is_armed() is False
|
||||
|
||||
|
||||
def test_disarm_requires_reason(tmp_path: Path) -> None:
|
||||
ks, _audit, _audit_path, _repo = _make_kill_switch(tmp_path)
|
||||
with pytest.raises(KillSwitchError, match="reason is required"):
|
||||
ks.disarm(reason="", source="manual")
|
||||
|
||||
|
||||
def test_disarm_without_initialised_state_raises(tmp_path: Path) -> None:
|
||||
db_path = tmp_path / "state.sqlite"
|
||||
audit_path = tmp_path / "audit.log"
|
||||
conn = connect(db_path)
|
||||
run_migrations(conn)
|
||||
conn.close()
|
||||
ks = KillSwitch(
|
||||
connection_factory=lambda: connect(db_path),
|
||||
repository=Repository(),
|
||||
audit_log=AuditLog(audit_path),
|
||||
clock=lambda: datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
)
|
||||
with pytest.raises(KillSwitchError, match="system_state singleton missing"):
|
||||
ks.disarm(reason="x", source="manual")
|
||||
|
||||
|
||||
def test_clock_is_advanced_for_each_call(tmp_path: Path) -> None:
|
||||
ks, _audit, _audit_path, repo = _make_kill_switch(tmp_path)
|
||||
ks.arm(reason="x", source="manual")
|
||||
ks.disarm(reason="y", source="manual")
|
||||
conn = connect(tmp_path / "state.sqlite")
|
||||
try:
|
||||
state = repo.get_system_state(conn)
|
||||
finally:
|
||||
conn.close()
|
||||
assert state is not None
|
||||
# last_health_check should reflect the disarm time (14:20 from the fake clock).
|
||||
assert state.last_health_check >= datetime(2026, 4, 27, 14, 15, tzinfo=UTC) - timedelta(
|
||||
seconds=1
|
||||
)
|
||||
@@ -38,14 +38,29 @@ def test_cli_status_runs(tmp_data_dir: Path) -> None:
|
||||
assert "phase: 0" in result.output
|
||||
|
||||
|
||||
def test_cli_kill_switch_arm_placeholder(tmp_data_dir: Path) -> None:
|
||||
def test_cli_kill_switch_arm_persists_state(tmp_data_dir: Path) -> None:
|
||||
runner = CliRunner()
|
||||
db_path = tmp_data_dir / "state.sqlite"
|
||||
audit_path = tmp_data_dir / "audit.log"
|
||||
result = runner.invoke(
|
||||
cli_main,
|
||||
["--log-dir", str(tmp_data_dir / "log"), "kill-switch", "arm", "--reason", "test"],
|
||||
[
|
||||
"--log-dir",
|
||||
str(tmp_data_dir / "log"),
|
||||
"kill-switch",
|
||||
"arm",
|
||||
"--reason",
|
||||
"smoke",
|
||||
"--db",
|
||||
str(db_path),
|
||||
"--audit",
|
||||
str(audit_path),
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert "phase 0 placeholder" in result.output
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "ARMED" in result.output
|
||||
assert db_path.exists()
|
||||
assert audit_path.exists()
|
||||
|
||||
|
||||
def test_cli_version_flag() -> None:
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
"""Migration runner + SQLite pragma sanity tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from cerbero_bite.state.db import (
|
||||
connect,
|
||||
current_version,
|
||||
list_migrations,
|
||||
run_migrations,
|
||||
transaction,
|
||||
)
|
||||
|
||||
|
||||
def test_list_migrations_is_ordered_and_starts_with_0001() -> None:
|
||||
migs = list_migrations()
|
||||
assert migs, "expected at least one migration file"
|
||||
versions = [m[0] for m in migs]
|
||||
assert versions == sorted(versions)
|
||||
assert versions[0] == 1
|
||||
|
||||
|
||||
def test_run_migrations_creates_full_schema(tmp_path: Path) -> None:
|
||||
db_path = tmp_path / "state.sqlite"
|
||||
conn = connect(db_path)
|
||||
try:
|
||||
new_version = run_migrations(conn)
|
||||
assert new_version == max(m[0] for m in list_migrations())
|
||||
# Sanity: every documented table is present.
|
||||
existing = {
|
||||
row["name"]
|
||||
for row in conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||
).fetchall()
|
||||
}
|
||||
assert {
|
||||
"positions",
|
||||
"instructions",
|
||||
"decisions",
|
||||
"dvol_history",
|
||||
"manual_actions",
|
||||
"system_state",
|
||||
} <= existing
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_run_migrations_is_idempotent(tmp_path: Path) -> None:
|
||||
db_path = tmp_path / "state.sqlite"
|
||||
conn = connect(db_path)
|
||||
try:
|
||||
run_migrations(conn)
|
||||
first = current_version(conn)
|
||||
run_migrations(conn) # second call must be a no-op
|
||||
assert current_version(conn) == first
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_pragmas_applied(tmp_path: Path) -> None:
|
||||
db_path = tmp_path / "state.sqlite"
|
||||
conn = connect(db_path)
|
||||
try:
|
||||
run_migrations(conn)
|
||||
assert conn.execute("PRAGMA foreign_keys").fetchone()[0] == 1
|
||||
# WAL is sticky on the file: confirm via the journal_mode pragma.
|
||||
mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
|
||||
assert mode.lower() == "wal"
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_transaction_rolls_back_on_exception(tmp_path: Path) -> None:
|
||||
db_path = tmp_path / "state.sqlite"
|
||||
conn = connect(db_path)
|
||||
try:
|
||||
run_migrations(conn)
|
||||
with pytest.raises(RuntimeError, match="boom"), transaction(conn):
|
||||
conn.execute(
|
||||
"INSERT INTO system_state(id, last_health_check, "
|
||||
"config_version, started_at) VALUES (1, '2026-04-27', "
|
||||
"'1.0.0', '2026-04-27')"
|
||||
)
|
||||
raise RuntimeError("boom")
|
||||
# Row was rolled back.
|
||||
rows = conn.execute("SELECT COUNT(*) FROM system_state").fetchone()[0]
|
||||
assert rows == 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_foreign_key_enforcement(tmp_path: Path) -> None:
|
||||
db_path = tmp_path / "state.sqlite"
|
||||
conn = connect(db_path)
|
||||
try:
|
||||
run_migrations(conn)
|
||||
# Inserting an instruction without a parent position must fail.
|
||||
with pytest.raises(sqlite3.IntegrityError):
|
||||
conn.execute(
|
||||
"INSERT INTO instructions(instruction_id, proposal_id, kind, "
|
||||
"payload_json, sent_at) VALUES (?, ?, ?, ?, ?)",
|
||||
("i1", "missing", "open_combo", "{}", "2026-04-27"),
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
@@ -0,0 +1,450 @@
|
||||
"""CRUD tests for :mod:`cerbero_bite.state.repository`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from collections.abc import Iterator
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from cerbero_bite.state import (
|
||||
DecisionRecord,
|
||||
DvolSnapshot,
|
||||
InstructionRecord,
|
||||
ManualAction,
|
||||
PositionRecord,
|
||||
Repository,
|
||||
connect,
|
||||
run_migrations,
|
||||
transaction,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conn(tmp_path: Path) -> Iterator[sqlite3.Connection]:
|
||||
db = tmp_path / "state.sqlite"
|
||||
c = connect(db)
|
||||
run_migrations(c)
|
||||
yield c
|
||||
c.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repo() -> Repository:
|
||||
return Repository()
|
||||
|
||||
|
||||
def _make_position(**overrides: object) -> PositionRecord:
|
||||
base: dict[str, object] = {
|
||||
"proposal_id": uuid4(),
|
||||
"spread_type": "bull_put",
|
||||
"expiry": datetime(2026, 5, 15, 8, 0, tzinfo=UTC),
|
||||
"short_strike": Decimal("2475"),
|
||||
"long_strike": Decimal("2350"),
|
||||
"short_instrument": "ETH-15MAY26-2475-P",
|
||||
"long_instrument": "ETH-15MAY26-2350-P",
|
||||
"n_contracts": 2,
|
||||
"spread_width_usd": Decimal("125"),
|
||||
"spread_width_pct": Decimal("0.0417"),
|
||||
"credit_eth": Decimal("0.030"),
|
||||
"credit_usd": Decimal("90"),
|
||||
"max_loss_usd": Decimal("160"),
|
||||
"spot_at_entry": Decimal("3000"),
|
||||
"dvol_at_entry": Decimal("50"),
|
||||
"delta_at_entry": Decimal("-0.12"),
|
||||
"eth_price_at_entry": Decimal("3000"),
|
||||
"proposed_at": datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
"status": "proposed",
|
||||
"created_at": datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
"updated_at": datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
}
|
||||
base.update(overrides)
|
||||
return PositionRecord(**base) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# positions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_and_get_position_roundtrip(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
record = _make_position()
|
||||
with transaction(conn):
|
||||
repo.create_position(conn, record)
|
||||
fetched = repo.get_position(conn, record.proposal_id)
|
||||
assert fetched is not None
|
||||
assert fetched.proposal_id == record.proposal_id
|
||||
assert fetched.short_strike == Decimal("2475")
|
||||
# Decimal precision preserved (no float coercion).
|
||||
assert fetched.spread_width_pct == Decimal("0.0417")
|
||||
|
||||
|
||||
def test_get_unknown_position_returns_none(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
assert repo.get_position(conn, uuid4()) is None
|
||||
|
||||
|
||||
def test_list_open_positions_filters_by_status(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
open_pos = _make_position(status="open")
|
||||
closed_pos = _make_position(status="closed")
|
||||
with transaction(conn):
|
||||
repo.create_position(conn, open_pos)
|
||||
repo.create_position(conn, closed_pos)
|
||||
open_only = repo.list_open_positions(conn)
|
||||
assert {p.proposal_id for p in open_only} == {open_pos.proposal_id}
|
||||
|
||||
|
||||
def test_count_concurrent_positions_excludes_closed(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
with transaction(conn):
|
||||
repo.create_position(conn, _make_position(status="open"))
|
||||
repo.create_position(conn, _make_position(status="awaiting_fill"))
|
||||
repo.create_position(conn, _make_position(status="closed"))
|
||||
repo.create_position(conn, _make_position(status="cancelled"))
|
||||
assert repo.count_concurrent_positions(conn) == 2
|
||||
|
||||
|
||||
def test_update_position_status_with_only_status_skips_optional_fields(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
"""Hit the False side of every optional-field branch in ``update_position_status``."""
|
||||
record = _make_position(status="proposed")
|
||||
with transaction(conn):
|
||||
repo.create_position(conn, record)
|
||||
repo.update_position_status(
|
||||
conn,
|
||||
record.proposal_id,
|
||||
status="awaiting_fill",
|
||||
now=datetime(2026, 4, 27, 14, 5, tzinfo=UTC),
|
||||
)
|
||||
fetched = repo.get_position(conn, record.proposal_id)
|
||||
assert fetched is not None
|
||||
assert fetched.status == "awaiting_fill"
|
||||
assert fetched.opened_at is None
|
||||
assert fetched.closed_at is None
|
||||
assert fetched.close_reason is None
|
||||
|
||||
|
||||
def test_update_position_status_persists_open_then_close(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
record = _make_position(status="awaiting_fill")
|
||||
opened_at = datetime(2026, 4, 27, 14, 10, tzinfo=UTC)
|
||||
with transaction(conn):
|
||||
repo.create_position(conn, record)
|
||||
repo.update_position_status(
|
||||
conn,
|
||||
record.proposal_id,
|
||||
status="open",
|
||||
opened_at=opened_at,
|
||||
now=datetime(2026, 4, 27, 14, 11, tzinfo=UTC),
|
||||
)
|
||||
after_open = repo.get_position(conn, record.proposal_id)
|
||||
assert after_open is not None
|
||||
assert after_open.opened_at == opened_at
|
||||
|
||||
closed_at = datetime(2026, 5, 12, 14, 0, tzinfo=UTC)
|
||||
with transaction(conn):
|
||||
repo.update_position_status(
|
||||
conn,
|
||||
record.proposal_id,
|
||||
status="closed",
|
||||
closed_at=closed_at,
|
||||
close_reason="CLOSE_PROFIT",
|
||||
debit_paid_eth=Decimal("0.012"),
|
||||
pnl_eth=Decimal("0.018"),
|
||||
pnl_usd=Decimal("54"),
|
||||
now=closed_at,
|
||||
)
|
||||
|
||||
fetched = repo.get_position(conn, record.proposal_id)
|
||||
assert fetched is not None
|
||||
assert fetched.status == "closed"
|
||||
assert fetched.close_reason == "CLOSE_PROFIT"
|
||||
assert fetched.debit_paid_eth == Decimal("0.012")
|
||||
assert fetched.pnl_eth == Decimal("0.018")
|
||||
|
||||
|
||||
def test_naive_datetime_is_normalised_to_utc(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
naive_now = datetime(2026, 4, 27, 14, 0)
|
||||
record = _make_position(proposed_at=naive_now, created_at=naive_now, updated_at=naive_now)
|
||||
with transaction(conn):
|
||||
repo.create_position(conn, record)
|
||||
fetched = repo.get_position(conn, record.proposal_id)
|
||||
assert fetched is not None
|
||||
assert fetched.proposed_at.tzinfo is not None
|
||||
assert fetched.proposed_at.utcoffset() == timedelta(0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# instructions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_instruction_lifecycle_ack_and_fill(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
pos = _make_position(status="awaiting_fill")
|
||||
instr = InstructionRecord(
|
||||
instruction_id=uuid4(),
|
||||
proposal_id=pos.proposal_id,
|
||||
kind="open_combo",
|
||||
payload_json='{"action":"open"}',
|
||||
sent_at=datetime(2026, 4, 27, 14, 5, tzinfo=UTC),
|
||||
)
|
||||
with transaction(conn):
|
||||
repo.create_position(conn, pos)
|
||||
repo.create_instruction(conn, instr)
|
||||
repo.update_instruction(
|
||||
conn,
|
||||
instr.instruction_id,
|
||||
acknowledged_at=datetime(2026, 4, 27, 14, 6, tzinfo=UTC),
|
||||
)
|
||||
repo.update_instruction(
|
||||
conn,
|
||||
instr.instruction_id,
|
||||
filled_at=datetime(2026, 4, 27, 14, 8, tzinfo=UTC),
|
||||
actual_fill_eth=Decimal("0.0298"),
|
||||
actual_fees_eth=Decimal("0.0001"),
|
||||
)
|
||||
|
||||
fetched = repo.list_instructions(conn, pos.proposal_id)
|
||||
assert len(fetched) == 1
|
||||
assert fetched[0].acknowledged_at is not None
|
||||
assert fetched[0].filled_at is not None
|
||||
assert fetched[0].actual_fill_eth == Decimal("0.0298")
|
||||
|
||||
|
||||
def test_update_instruction_no_op_when_no_fields(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
pos = _make_position()
|
||||
instr = InstructionRecord(
|
||||
instruction_id=uuid4(),
|
||||
proposal_id=pos.proposal_id,
|
||||
kind="open_combo",
|
||||
payload_json="{}",
|
||||
sent_at=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
)
|
||||
with transaction(conn):
|
||||
repo.create_position(conn, pos)
|
||||
repo.create_instruction(conn, instr)
|
||||
repo.update_instruction(conn, instr.instruction_id) # no fields
|
||||
|
||||
fetched = repo.list_instructions(conn, pos.proposal_id)
|
||||
assert fetched[0].acknowledged_at is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# decisions / dvol / manual_actions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_record_and_list_decisions(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
decision = DecisionRecord(
|
||||
decision_type="entry_check",
|
||||
timestamp=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
inputs_json='{"capital":1500}',
|
||||
outputs_json='{"accepted":true}',
|
||||
action_taken="propose_open",
|
||||
)
|
||||
with transaction(conn):
|
||||
decision_id = repo.record_decision(conn, decision)
|
||||
assert decision_id > 0
|
||||
decisions = repo.list_decisions(conn)
|
||||
assert len(decisions) == 1
|
||||
assert decisions[0].id == decision_id
|
||||
assert decisions[0].action_taken == "propose_open"
|
||||
|
||||
|
||||
def test_record_dvol_snapshot_replaces_on_duplicate_timestamp(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
ts = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
|
||||
with transaction(conn):
|
||||
repo.record_dvol_snapshot(
|
||||
conn, DvolSnapshot(timestamp=ts, dvol=Decimal("50"), eth_spot=Decimal("3000"))
|
||||
)
|
||||
repo.record_dvol_snapshot(
|
||||
conn, DvolSnapshot(timestamp=ts, dvol=Decimal("55"), eth_spot=Decimal("3050"))
|
||||
)
|
||||
rows = conn.execute("SELECT COUNT(*) FROM dvol_history").fetchone()
|
||||
assert rows[0] == 1
|
||||
|
||||
|
||||
def test_manual_action_enqueue_consume_cycle(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
pos = _make_position()
|
||||
action = ManualAction(
|
||||
kind="approve_proposal",
|
||||
proposal_id=pos.proposal_id,
|
||||
payload_json='{"reason":"go"}',
|
||||
created_at=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
)
|
||||
with transaction(conn):
|
||||
repo.create_position(conn, pos)
|
||||
action_id = repo.enqueue_manual_action(conn, action)
|
||||
next_action = repo.next_unconsumed_action(conn)
|
||||
assert next_action is not None
|
||||
assert next_action.id == action_id
|
||||
with transaction(conn):
|
||||
repo.mark_action_consumed(
|
||||
conn,
|
||||
action_id,
|
||||
consumed_by="orchestrator",
|
||||
result="ok",
|
||||
now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC),
|
||||
)
|
||||
assert repo.next_unconsumed_action(conn) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# system_state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_init_system_state_is_idempotent(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
now = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
|
||||
with transaction(conn):
|
||||
repo.init_system_state(conn, config_version="1.0.0", now=now)
|
||||
repo.init_system_state(conn, config_version="1.0.0", now=now)
|
||||
state = repo.get_system_state(conn)
|
||||
assert state is not None
|
||||
assert state.kill_switch == 0
|
||||
assert state.config_version == "1.0.0"
|
||||
|
||||
|
||||
def test_kill_switch_arm_and_disarm(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
now = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
|
||||
with transaction(conn):
|
||||
repo.init_system_state(conn, config_version="1.0.0", now=now)
|
||||
repo.set_kill_switch(conn, armed=True, reason="manual", now=now)
|
||||
state = repo.get_system_state(conn)
|
||||
assert state is not None
|
||||
assert state.kill_switch == 1
|
||||
assert state.kill_reason == "manual"
|
||||
assert state.kill_at is not None
|
||||
|
||||
later = now + timedelta(minutes=5)
|
||||
with transaction(conn):
|
||||
repo.set_kill_switch(conn, armed=False, reason=None, now=later)
|
||||
state = repo.get_system_state(conn)
|
||||
assert state is not None
|
||||
assert state.kill_switch == 0
|
||||
assert state.kill_at is None
|
||||
assert state.last_health_check == later
|
||||
|
||||
|
||||
def test_get_system_state_returns_none_when_uninitialised(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
assert repo.get_system_state(conn) is None
|
||||
|
||||
|
||||
def test_list_positions_filters_by_status(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
open_pos = _make_position(status="open")
|
||||
closed_pos = _make_position(status="closed")
|
||||
with transaction(conn):
|
||||
repo.create_position(conn, open_pos)
|
||||
repo.create_position(conn, closed_pos)
|
||||
closed_only = repo.list_positions(conn, status="closed")
|
||||
assert len(closed_only) == 1
|
||||
assert closed_only[0].proposal_id == closed_pos.proposal_id
|
||||
|
||||
|
||||
def test_list_positions_without_filter_returns_all(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
with transaction(conn):
|
||||
repo.create_position(conn, _make_position(status="open"))
|
||||
repo.create_position(conn, _make_position(status="closed"))
|
||||
assert len(repo.list_positions(conn)) == 2
|
||||
|
||||
|
||||
def test_list_decisions_filters_by_proposal(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
pos = _make_position()
|
||||
with transaction(conn):
|
||||
repo.create_position(conn, pos)
|
||||
repo.record_decision(
|
||||
conn,
|
||||
DecisionRecord(
|
||||
decision_type="entry_check",
|
||||
proposal_id=pos.proposal_id,
|
||||
timestamp=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
inputs_json="{}",
|
||||
outputs_json="{}",
|
||||
action_taken="propose_open",
|
||||
),
|
||||
)
|
||||
repo.record_decision(
|
||||
conn,
|
||||
DecisionRecord(
|
||||
decision_type="exit_check",
|
||||
timestamp=datetime(2026, 4, 27, 15, 0, tzinfo=UTC),
|
||||
inputs_json="{}",
|
||||
outputs_json="{}",
|
||||
action_taken="HOLD",
|
||||
),
|
||||
)
|
||||
only_for_proposal = repo.list_decisions(conn, proposal_id=pos.proposal_id)
|
||||
assert len(only_for_proposal) == 1
|
||||
assert only_for_proposal[0].action_taken == "propose_open"
|
||||
|
||||
|
||||
def test_update_instruction_sets_cancelled(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
pos = _make_position()
|
||||
instr = InstructionRecord(
|
||||
instruction_id=uuid4(),
|
||||
proposal_id=pos.proposal_id,
|
||||
kind="open_combo",
|
||||
payload_json="{}",
|
||||
sent_at=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
|
||||
)
|
||||
with transaction(conn):
|
||||
repo.create_position(conn, pos)
|
||||
repo.create_instruction(conn, instr)
|
||||
repo.update_instruction(
|
||||
conn,
|
||||
instr.instruction_id,
|
||||
cancelled_at=datetime(2026, 4, 27, 14, 30, tzinfo=UTC),
|
||||
)
|
||||
fetched = repo.list_instructions(conn, pos.proposal_id)
|
||||
assert fetched[0].cancelled_at is not None
|
||||
|
||||
|
||||
def test_touch_health_check_updates_timestamp(
|
||||
conn: sqlite3.Connection, repo: Repository
|
||||
) -> None:
|
||||
now = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
|
||||
with transaction(conn):
|
||||
repo.init_system_state(conn, config_version="1.0.0", now=now)
|
||||
later = now + timedelta(minutes=10)
|
||||
repo.touch_health_check(conn, now=later)
|
||||
state = repo.get_system_state(conn)
|
||||
assert state is not None
|
||||
assert state.last_health_check == later
|
||||
Reference in New Issue
Block a user