Phase 2: persistence + safety controls

Aggiunge la persistenza SQLite, l'audit log a hash chain, il kill
switch coordinato e i CLI di gestione documentati in
docs/05-data-model.md e docs/07-risk-controls.md. 197 test pass,
1 skipped (sqlite3 CLI mancante), copertura totale 97%.

State (`state/`):
- 0001_init.sql con positions, instructions, decisions, dvol_history,
  manual_actions, system_state.
- db.py: connect con WAL + foreign_keys + transaction ctx, runner
  forward-only basato su PRAGMA user_version.
- models.py: record Pydantic, Decimal preservato come TEXT.
- repository.py: CRUD typed con singola connessione passata, cache
  aware, posizioni concorrenti.

Safety (`safety/`):
- audit_log.py: AuditLog append-only con SHA-256 chain e fsync,
  verify_chain riconosce ogni manomissione (payload, prev_hash,
  hash, JSON, separatori).
- kill_switch.py: arm/disarm transazionali, idempotenti, accoppiati
  all'audit chain.

Config (`config/loader.py` + `strategy.yaml`):
- Loader YAML con deep-merge di strategy.local.yaml.
- Verifica config_hash SHA-256 (riga config_hash esclusa).
- File golden strategy.yaml + esempio override.

Scripts:
- dead_man.sh: watchdog shell indipendente da Python.
- backup.py: VACUUM INTO orario con retention 30 giorni.

CLI:
- audit verify (exit 2 su tampering).
- kill-switch arm/disarm/status su SQLite reale.
- state inspect con tabella posizioni aperte.
- config hash, config validate.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-27 13:35:35 +02:00
parent fbb7753cc6
commit 263470786d
25 changed files with 3669 additions and 14 deletions
+111
View File
@@ -0,0 +1,111 @@
"""Smoke tests for the dead_man.sh shell watchdog."""
from __future__ import annotations
import shutil
import sqlite3
import subprocess
from datetime import UTC, datetime, timedelta
from pathlib import Path
import pytest
REPO_ROOT = Path(__file__).resolve().parents[2]
SCRIPT = REPO_ROOT / "scripts" / "dead_man.sh"
_SQLITE3_BIN = shutil.which("sqlite3")
pytestmark = [
pytest.mark.integration,
]
def _require_sqlite3() -> None:
if _SQLITE3_BIN is None:
pytest.skip("sqlite3 CLI not installed on this host")
def _setup_project(tmp_path: Path) -> Path:
project = tmp_path / "project"
(project / "data" / "log").mkdir(parents=True)
(project / "scripts").mkdir(parents=True)
shutil.copy(SCRIPT, project / "scripts" / "dead_man.sh")
return project
def _write_health(project: Path, ts: datetime) -> None:
log_file = project / "data" / "log" / f"cerbero-bite-{ts:%Y-%m-%d}.jsonl"
log_file.write_text(
f'{{"ts": "{ts.astimezone(UTC).isoformat()}", "event": "HEALTH_OK"}}\n',
encoding="utf-8",
)
def _run(project: Path, threshold: int = 900) -> subprocess.CompletedProcess[str]:
return subprocess.run(
["bash", str(project / "scripts" / "dead_man.sh")],
env={
"PATH": "/usr/bin:/bin",
"PROJECT_ROOT": str(project),
"DEAD_MAN_THRESHOLD_SECONDS": str(threshold),
},
capture_output=True,
text=True,
check=False,
)
def test_dead_man_exits_zero_when_recent_health_ok(tmp_path: Path) -> None:
project = _setup_project(tmp_path)
_write_health(project, datetime.now(UTC) - timedelta(seconds=60))
result = _run(project, threshold=900)
assert result.returncode == 0, result.stderr
def test_dead_man_arms_kill_switch_when_silent(tmp_path: Path) -> None:
_require_sqlite3()
project = _setup_project(tmp_path)
_write_health(project, datetime.now(UTC) - timedelta(seconds=2000))
# Pre-create the SQLite system_state singleton; otherwise the script
# has nothing to update.
db = project / "data" / "state.sqlite"
conn = sqlite3.connect(str(db))
conn.execute(
"CREATE TABLE system_state (id INTEGER PRIMARY KEY CHECK(id=1), "
"kill_switch INTEGER NOT NULL DEFAULT 0, kill_reason TEXT, "
"kill_at TEXT, last_health_check TEXT NOT NULL, "
"last_kelly_calib TEXT, config_version TEXT NOT NULL, "
"started_at TEXT NOT NULL)"
)
conn.execute(
"INSERT INTO system_state(id, last_health_check, config_version, started_at) "
"VALUES (1, '2026-04-27', '1.0.0', '2026-04-27')"
)
conn.commit()
conn.close()
result = _run(project, threshold=900)
assert result.returncode == 1
# Kill switch armed.
conn = sqlite3.connect(str(db))
try:
kill = conn.execute("SELECT kill_switch FROM system_state").fetchone()[0]
finally:
conn.close()
assert kill == 1
# Alert file exists.
alert = project / "data" / "log" / "dead-man-alert.txt"
assert alert.exists()
assert "dead_man" in alert.read_text(encoding="utf-8")
def test_dead_man_handles_missing_log_file(tmp_path: Path) -> None:
project = _setup_project(tmp_path)
# No log file at all.
result = _run(project, threshold=900)
assert result.returncode == 1
alert = project / "data" / "log" / "dead-man-alert.txt"
assert alert.exists()
+273
View File
@@ -0,0 +1,273 @@
"""Audit chain writer + verifier tests."""
from __future__ import annotations
import hashlib
from datetime import UTC, datetime, timedelta
from pathlib import Path
import pytest
from cerbero_bite.safety.audit_log import (
GENESIS_HASH,
AuditChainError,
AuditLog,
iter_entries,
verify_chain,
)
def test_empty_file_verifies_with_zero_entries(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
assert verify_chain(path) == 0
def test_first_entry_uses_genesis_prev_hash(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
log = AuditLog(path)
entry = log.append(
event="ENGINE_START",
payload={"version": "1.0.0"},
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
assert entry.prev_hash == GENESIS_HASH
assert entry.hash != GENESIS_HASH
def test_chain_links_subsequent_entries(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
log = AuditLog(path)
e1 = log.append(event="A", payload={"i": 1}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
e2 = log.append(event="B", payload={"i": 2}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC))
e3 = log.append(event="C", payload={"i": 3}, now=datetime(2026, 4, 27, 14, 2, tzinfo=UTC))
assert e2.prev_hash == e1.hash
assert e3.prev_hash == e2.hash
assert verify_chain(path) == 3
def test_iter_entries_yields_in_order(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
log = AuditLog(path)
log.append(event="A", payload={"i": 1}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
log.append(event="B", payload={"i": 2}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC))
events = [e.event for e in iter_entries(path)]
assert events == ["A", "B"]
def test_log_resumes_chain_after_reopen(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
first = AuditLog(path)
e1 = first.append(
event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
)
second = AuditLog(path)
assert second.last_hash == e1.hash
e2 = second.append(
event="B", payload={"k": "v"}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC)
)
assert e2.prev_hash == e1.hash
assert verify_chain(path) == 2
def test_payload_with_pipe_character_round_trips(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
log = AuditLog(path)
log.append(
event="NOTE",
payload={"text": "first|second|third"},
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
entries = list(iter_entries(path))
assert entries[0].payload == {"text": "first|second|third"}
assert verify_chain(path) == 1
def test_tampered_payload_breaks_chain(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
log = AuditLog(path)
log.append(event="A", payload={"i": 1}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
log.append(event="B", payload={"i": 2}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC))
# Mutate the first line's payload by hand.
text = path.read_text(encoding="utf-8").splitlines()
text[0] = text[0].replace('"i":1', '"i":99')
path.write_text("\n".join(text) + "\n", encoding="utf-8")
with pytest.raises(AuditChainError, match="hash mismatch"):
verify_chain(path)
def test_verify_chain_skips_blank_lines(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
log = AuditLog(path)
log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
raw = path.read_text(encoding="utf-8")
path.write_text("\n" + raw + "\n \n", encoding="utf-8")
# The chain still verifies despite the surrounding whitespace lines.
assert verify_chain(path) == 1
def test_prev_hash_mismatch_between_entries_is_caught(tmp_path: Path) -> None:
"""Second line's prev_hash points to a different chain — verify_chain rejects."""
path = tmp_path / "audit.log"
log = AuditLog(path)
e1 = log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
# Build a synthetic second line whose prev_hash != e1.hash but whose
# own hash is correctly computed from that bogus prev_hash.
fake_prev = "0" * 32 + "f" * 32
ts2 = "2026-04-27T14:01:00+00:00"
payload_json = "{}"
raw = f"{ts2}|B|{payload_json}|{fake_prev}"
fake_hash = hashlib.sha256(raw.encode()).hexdigest()
line = f"{ts2}|B|{payload_json}|prev_hash={fake_prev}|hash={fake_hash}\n"
with path.open("a", encoding="utf-8") as fh:
fh.write(line)
assert e1.hash != fake_prev # sanity
with pytest.raises(AuditChainError, match="prev_hash mismatch"):
verify_chain(path)
def test_tampered_prev_hash_breaks_chain(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
log = AuditLog(path)
log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
log.append(event="B", payload={}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC))
# Inject an unrelated prev_hash on the second line.
lines = path.read_text(encoding="utf-8").splitlines()
lines[1] = lines[1].replace("prev_hash=", "prev_hash=" + "f" * 64 + "X")
# Truncate to recover length: replace prev_hash field with all-ff.
lines[1] = lines[1].replace("X", "")
path.write_text("\n".join(lines) + "\n", encoding="utf-8")
with pytest.raises(AuditChainError):
verify_chain(path)
def test_malformed_line_raises_chain_error(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
path.write_text("not-a-valid-line\n", encoding="utf-8")
with pytest.raises(AuditChainError):
verify_chain(path)
def test_parser_rejects_missing_hash_field(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
path.write_text(
"2026-04-27T14:00:00+00:00|EVT|{}|prev_hash=" + "0" * 64 + "\n",
encoding="utf-8",
)
with pytest.raises(AuditChainError, match="hash="):
verify_chain(path)
def test_parser_rejects_missing_prev_hash_field(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
path.write_text(
"2026-04-27T14:00:00+00:00|EVT|{}|hash=" + "f" * 64 + "\n",
encoding="utf-8",
)
with pytest.raises(AuditChainError, match="prev_hash"):
verify_chain(path)
def test_parser_rejects_line_with_no_separators(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
path.write_text("just-a-blob|hash=" + "f" * 64 + "\n", encoding="utf-8")
with pytest.raises(AuditChainError, match="prev_hash"):
verify_chain(path)
def test_parser_rejects_malformed_leading_section(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
# Two `|` only: rsplit succeeds twice, leading parts has 1 element ≠ 3.
path.write_text(
"tooshort|prev_hash=" + "0" * 64 + "|hash=" + "f" * 64 + "\n",
encoding="utf-8",
)
with pytest.raises(AuditChainError, match="leading section"):
verify_chain(path)
def test_parser_rejects_payload_not_a_json_object(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
path.write_text(
"2026-04-27T14:00:00+00:00|EVT|[1,2]|prev_hash="
+ "0" * 64
+ "|hash="
+ "f" * 64
+ "\n",
encoding="utf-8",
)
with pytest.raises(AuditChainError, match="JSON object"):
verify_chain(path)
def test_parser_rejects_payload_with_invalid_json(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
path.write_text(
"2026-04-27T14:00:00+00:00|EVT|{not-json}|prev_hash="
+ "0" * 64
+ "|hash="
+ "f" * 64
+ "\n",
encoding="utf-8",
)
with pytest.raises(AuditChainError, match="JSON"):
verify_chain(path)
def test_iter_entries_returns_empty_when_file_missing(tmp_path: Path) -> None:
path = tmp_path / "missing.log"
assert list(iter_entries(path)) == []
def test_iter_entries_skips_blank_lines(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
log = AuditLog(path)
log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
raw = path.read_text(encoding="utf-8")
path.write_text(raw + "\n\n", encoding="utf-8")
entries = list(iter_entries(path))
assert len(entries) == 1
def test_log_resumes_chain_with_large_file(tmp_path: Path) -> None:
"""Tail-seek reads past the 4096-byte chunk boundary."""
path = tmp_path / "audit.log"
log = AuditLog(path)
base = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
# Each line ~150 chars; 50 lines is comfortably > 4096 bytes.
for i in range(50):
log.append(
event=f"E{i}",
payload={"i": i, "filler": "x" * 80},
now=base + timedelta(seconds=i),
)
last_hash = log.last_hash
reopened = AuditLog(path)
assert reopened.last_hash == last_hash
assert verify_chain(path) == 50
def test_payload_serialisation_is_canonical(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
log = AuditLog(path)
# Different key order must produce identical hashes.
e1 = log.append(
event="A",
payload={"b": 1, "a": 2},
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
other = tmp_path / "audit_other.log"
log2 = AuditLog(other)
e2 = log2.append(
event="A",
payload={"a": 2, "b": 1},
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
assert e1.hash == e2.hash
+126
View File
@@ -0,0 +1,126 @@
"""Tests for scripts.backup."""
from __future__ import annotations
import importlib.util
import sys
from datetime import UTC, datetime, timedelta
from pathlib import Path
import pytest
from cerbero_bite.state import connect, run_migrations
REPO_ROOT = Path(__file__).resolve().parents[2]
def _load_backup_module() -> object:
"""Load scripts/backup.py as a module without polluting sys.path."""
spec = importlib.util.spec_from_file_location(
"_cerbero_bite_backup", REPO_ROOT / "scripts" / "backup.py"
)
assert spec is not None and spec.loader is not None
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
return module
@pytest.fixture
def backup_mod() -> object:
return _load_backup_module()
def test_backup_database_creates_snapshot(tmp_path: Path, backup_mod: object) -> None:
db = tmp_path / "state.sqlite"
conn = connect(db)
try:
run_migrations(conn)
conn.execute(
"INSERT INTO system_state(id, last_health_check, config_version, started_at) "
"VALUES (1, '2026-04-27', '1.0.0', '2026-04-27')"
)
finally:
conn.close()
backup_dir = tmp_path / "backups"
snapshot = backup_mod.backup_database( # type: ignore[attr-defined]
db_path=db,
backup_dir=backup_dir,
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
assert snapshot.exists()
assert snapshot.name == "state-20260427-14.sqlite"
# Snapshot is itself a valid SQLite db with the same row.
snap = connect(snapshot)
try:
rows = snap.execute(
"SELECT config_version FROM system_state WHERE id = 1"
).fetchone()
finally:
snap.close()
assert rows[0] == "1.0.0"
def test_backup_database_replaces_existing_hour_snapshot(
tmp_path: Path, backup_mod: object
) -> None:
db = tmp_path / "state.sqlite"
conn = connect(db)
try:
run_migrations(conn)
finally:
conn.close()
when = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
first = backup_mod.backup_database(db_path=db, backup_dir=tmp_path / "b", now=when) # type: ignore[attr-defined]
second = backup_mod.backup_database(db_path=db, backup_dir=tmp_path / "b", now=when) # type: ignore[attr-defined]
assert first == second
files = list((tmp_path / "b").iterdir())
assert len(files) == 1
def test_prune_backups_removes_old_files(tmp_path: Path, backup_mod: object) -> None:
backup_dir = tmp_path / "b"
backup_dir.mkdir()
fresh = backup_dir / "state-20260420-10.sqlite"
stale = backup_dir / "state-20260101-12.sqlite"
other = backup_dir / "unrelated.txt"
fresh.touch()
stale.touch()
other.touch()
deleted = backup_mod.prune_backups( # type: ignore[attr-defined]
backup_dir,
retention_days=30,
now=datetime(2026, 4, 27, tzinfo=UTC),
)
assert deleted == [stale]
assert fresh.exists()
assert other.exists()
def test_prune_backups_ignores_unparseable_filenames(
tmp_path: Path, backup_mod: object
) -> None:
backup_dir = tmp_path / "b"
backup_dir.mkdir()
(backup_dir / "state-bogus-XX.sqlite").touch()
deleted = backup_mod.prune_backups( # type: ignore[attr-defined]
backup_dir,
retention_days=0,
now=datetime(2026, 4, 27, tzinfo=UTC) + timedelta(days=10000),
)
assert deleted == []
def test_list_backups_returns_sorted(tmp_path: Path, backup_mod: object) -> None:
backup_dir = tmp_path / "b"
backup_dir.mkdir()
a = backup_dir / "state-20260103-08.sqlite"
b = backup_dir / "state-20260101-08.sqlite"
a.touch()
b.touch()
listed = list(backup_mod.list_backups(backup_dir)) # type: ignore[attr-defined]
assert listed == [b, a]
+160
View File
@@ -0,0 +1,160 @@
"""End-to-end CLI tests for the Phase 2 commands.
The commands hit real on-disk paths (tmp_path) so the assertions run
the production code paths verbatim.
"""
from __future__ import annotations
from datetime import UTC, datetime
from pathlib import Path
from click.testing import CliRunner
from cerbero_bite.cli import main as cli_main
from cerbero_bite.safety import AuditLog
from cerbero_bite.state import Repository, connect, run_migrations, transaction
def _seed_state(db_path: Path) -> None:
conn = connect(db_path)
try:
run_migrations(conn)
with transaction(conn):
Repository().init_system_state(
conn,
config_version="1.0.0",
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
finally:
conn.close()
def test_audit_verify_reports_ok_on_clean_chain(tmp_path: Path) -> None:
audit = AuditLog(tmp_path / "audit.log")
audit.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
audit.append(event="B", payload={}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC))
result = CliRunner().invoke(
cli_main, ["audit", "verify", "--file", str(tmp_path / "audit.log")]
)
assert result.exit_code == 0, result.output
assert "ok" in result.output
assert "2" in result.output
def test_audit_verify_handles_empty_file(tmp_path: Path) -> None:
target = tmp_path / "audit.log"
target.write_text("", encoding="utf-8")
result = CliRunner().invoke(cli_main, ["audit", "verify", "--file", str(target)])
assert result.exit_code == 0
assert "empty" in result.output
def test_audit_verify_exits_nonzero_on_tampering(tmp_path: Path) -> None:
target = tmp_path / "audit.log"
audit = AuditLog(target)
audit.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
target.write_text(
target.read_text(encoding="utf-8").replace('"event":"A"', '"event":"X"'),
encoding="utf-8",
)
# NB: we mutated the JSON payload, but the actual line still has event=A.
# Force tampering by editing the literal "A" in the line text.
raw = target.read_text(encoding="utf-8")
target.write_text(raw.replace("|A|", "|X|", 1), encoding="utf-8")
result = CliRunner().invoke(cli_main, ["audit", "verify", "--file", str(target)])
assert result.exit_code == 2
assert "TAMPERED" in result.output
def test_kill_switch_status_prints_disarmed(tmp_path: Path) -> None:
db = tmp_path / "state.sqlite"
_seed_state(db)
result = CliRunner().invoke(cli_main, ["kill-switch", "status", "--db", str(db)])
assert result.exit_code == 0
assert "disarmed" in result.output
def test_kill_switch_arm_then_status_shows_armed(tmp_path: Path) -> None:
db = tmp_path / "state.sqlite"
audit = tmp_path / "audit.log"
runner = CliRunner()
arm = runner.invoke(
cli_main,
[
"kill-switch",
"arm",
"--reason",
"manual smoke",
"--db",
str(db),
"--audit",
str(audit),
],
)
assert arm.exit_code == 0, arm.output
status = runner.invoke(cli_main, ["kill-switch", "status", "--db", str(db)])
assert status.exit_code == 0
assert "ARMED" in status.output
def test_kill_switch_status_handles_missing_db(tmp_path: Path) -> None:
result = CliRunner().invoke(
cli_main, ["kill-switch", "status", "--db", str(tmp_path / "absent.sqlite")]
)
assert result.exit_code == 0
assert "not found" in result.output
def test_state_inspect_shows_no_open_positions(tmp_path: Path) -> None:
db = tmp_path / "state.sqlite"
_seed_state(db)
result = CliRunner().invoke(cli_main, ["state", "inspect", "--db", str(db)])
assert result.exit_code == 0
assert "no open positions" in result.output
def test_state_inspect_handles_missing_db(tmp_path: Path) -> None:
result = CliRunner().invoke(
cli_main, ["state", "inspect", "--db", str(tmp_path / "absent.sqlite")]
)
assert result.exit_code == 0
assert "not found" in result.output
def test_config_hash_matches_loader(tmp_path: Path) -> None:
target = tmp_path / "strategy.yaml"
target.write_text(
'config_version: "1.0.0"\nconfig_hash: "0000"\nasset:\n symbol: ETH\n',
encoding="utf-8",
)
result = CliRunner().invoke(cli_main, ["config", "hash", "--file", str(target)])
assert result.exit_code == 0
assert len(result.output.strip()) == 64 # sha256 hex
def test_config_validate_repo_strategy_yaml() -> None:
repo_root = Path(__file__).resolve().parents[2]
yaml_path = repo_root / "strategy.yaml"
result = CliRunner().invoke(
cli_main, ["config", "validate", "--file", str(yaml_path)]
)
assert result.exit_code == 0
assert "ok" in result.output
def test_config_validate_with_no_enforce_hash_skips_check(tmp_path: Path) -> None:
target = tmp_path / "strategy.yaml"
target.write_text(
'config_version: "1.0.0"\nconfig_hash: "wrong"\n'
'last_review: "2026-04-26"\nlast_reviewer: "test"\n',
encoding="utf-8",
)
result = CliRunner().invoke(
cli_main,
["config", "validate", "--file", str(target), "--no-enforce-hash"],
)
assert result.exit_code == 0
assert "ok" in result.output
+149
View File
@@ -0,0 +1,149 @@
"""Tests for the YAML loader and hash verification."""
from __future__ import annotations
from decimal import Decimal
from pathlib import Path
import pytest
import yaml
from cerbero_bite.config.loader import (
ConfigHashError,
compute_config_hash,
load_strategy,
)
REPO_ROOT = Path(__file__).resolve().parents[2]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _golden_yaml_skeleton(**overrides: object) -> dict[str, object]:
base = {
"config_version": "1.0.0-test",
"config_hash": "0" * 64,
"last_review": "2026-04-26",
"last_reviewer": "test",
}
base.update(overrides)
return base
def _write_with_correct_hash(path: Path, doc: dict[str, object]) -> None:
"""Write a YAML doc and patch ``config_hash`` to match the file body."""
text = yaml.safe_dump(doc, sort_keys=False)
path.write_text(text, encoding="utf-8")
new_hash = compute_config_hash(path.read_text(encoding="utf-8"))
doc = {**doc, "config_hash": new_hash}
text = yaml.safe_dump(doc, sort_keys=False)
path.write_text(text, encoding="utf-8")
# ---------------------------------------------------------------------------
# compute_config_hash
# ---------------------------------------------------------------------------
def test_compute_hash_is_independent_of_recorded_hash_value(tmp_path: Path) -> None:
a = tmp_path / "a.yaml"
b = tmp_path / "b.yaml"
a.write_text(
'config_version: "1.0.0"\nconfig_hash: "aaa"\nfoo: 1\n', encoding="utf-8"
)
b.write_text(
'config_version: "1.0.0"\nconfig_hash: "bbb"\nfoo: 1\n', encoding="utf-8"
)
assert compute_config_hash(a.read_text()) == compute_config_hash(b.read_text())
# ---------------------------------------------------------------------------
# load_strategy — happy path
# ---------------------------------------------------------------------------
def test_load_repo_strategy_yaml(tmp_path: Path) -> None:
"""The committed strategy.yaml validates with the recorded hash."""
result = load_strategy(REPO_ROOT / "strategy.yaml")
assert result.config.config_version == "1.0.0"
assert result.config.sizing.kelly_fraction == Decimal("0.13")
assert result.computed_hash == result.config.config_hash
def test_load_with_local_override_merges(tmp_path: Path) -> None:
main = tmp_path / "strategy.yaml"
_write_with_correct_hash(main, _golden_yaml_skeleton())
override = tmp_path / "strategy.local.yaml"
override.write_text(
yaml.safe_dump(
{"sizing": {"max_concurrent_positions": 0}},
sort_keys=False,
),
encoding="utf-8",
)
loaded = load_strategy(main)
assert loaded.config.sizing.max_concurrent_positions == 0
assert override in loaded.sources
def test_local_override_does_not_invalidate_main_hash(tmp_path: Path) -> None:
main = tmp_path / "strategy.yaml"
_write_with_correct_hash(main, _golden_yaml_skeleton())
(tmp_path / "strategy.local.yaml").write_text(
"sizing:\n kelly_fraction: '0.10'\n", encoding="utf-8"
)
loaded = load_strategy(main)
assert loaded.config.sizing.kelly_fraction == Decimal("0.10")
# ---------------------------------------------------------------------------
# load_strategy — error paths
# ---------------------------------------------------------------------------
def test_load_with_mismatched_hash_raises(tmp_path: Path) -> None:
main = tmp_path / "strategy.yaml"
main.write_text(
yaml.safe_dump(_golden_yaml_skeleton(), sort_keys=False), encoding="utf-8"
)
with pytest.raises(ConfigHashError, match="config_hash mismatch"):
load_strategy(main)
def test_load_with_enforce_hash_false_skips_check(tmp_path: Path) -> None:
main = tmp_path / "strategy.yaml"
main.write_text(
yaml.safe_dump(_golden_yaml_skeleton(), sort_keys=False), encoding="utf-8"
)
loaded = load_strategy(main, enforce_hash=False)
assert loaded.config.config_hash == "0" * 64
def test_load_rejects_top_level_non_mapping(tmp_path: Path) -> None:
main = tmp_path / "strategy.yaml"
main.write_text("- a\n- b\n", encoding="utf-8")
with pytest.raises(ValueError, match="top-level mapping"):
load_strategy(main, enforce_hash=False)
def test_local_override_path_pointing_to_missing_file_is_ignored(
tmp_path: Path,
) -> None:
main = tmp_path / "strategy.yaml"
_write_with_correct_hash(main, _golden_yaml_skeleton())
loaded = load_strategy(
main, local_override_path=tmp_path / "nonexistent.yaml"
)
assert main in loaded.sources
assert len(loaded.sources) == 1
def test_empty_local_override_file_is_no_op(tmp_path: Path) -> None:
main = tmp_path / "strategy.yaml"
_write_with_correct_hash(main, _golden_yaml_skeleton())
(tmp_path / "strategy.local.yaml").write_text("", encoding="utf-8")
loaded = load_strategy(main)
assert loaded.config.sizing.kelly_fraction == Decimal("0.13")
+170
View File
@@ -0,0 +1,170 @@
"""Kill switch behaviour: SQLite + audit log stay in lock-step."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from pathlib import Path
import pytest
from cerbero_bite.safety import AuditLog, verify_chain
from cerbero_bite.safety.kill_switch import KillSwitch, KillSwitchError
from cerbero_bite.state import Repository, connect, run_migrations, transaction
def _make_kill_switch(tmp_path: Path) -> tuple[KillSwitch, AuditLog, Path, Repository]:
db_path = tmp_path / "state.sqlite"
audit_path = tmp_path / "audit.log"
conn = connect(db_path)
run_migrations(conn)
repo = Repository()
with transaction(conn):
repo.init_system_state(
conn, config_version="1.0.0", now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
)
conn.close()
audit = AuditLog(audit_path)
times = iter(
datetime(2026, 4, 27, 14, m, tzinfo=UTC) for m in (10, 20, 30, 40, 50)
)
ks = KillSwitch(
connection_factory=lambda: connect(db_path),
repository=repo,
audit_log=audit,
clock=lambda: next(times),
)
return ks, audit, audit_path, repo
def test_arm_persists_state_and_appends_audit(tmp_path: Path) -> None:
ks, _audit, audit_path, repo = _make_kill_switch(tmp_path)
assert ks.is_armed() is False
ks.arm(reason="manual test", source="manual")
assert ks.is_armed() is True
conn = connect(tmp_path / "state.sqlite")
try:
state = repo.get_system_state(conn)
finally:
conn.close()
assert state is not None
assert state.kill_switch == 1
assert state.kill_reason == "manual test"
assert state.kill_at is not None
assert verify_chain(audit_path) == 1
def test_arm_is_idempotent_on_second_call(tmp_path: Path) -> None:
ks, _audit, audit_path, _repo = _make_kill_switch(tmp_path)
ks.arm(reason="first", source="manual")
ks.arm(reason="second", source="manual") # no-op
# only one audit line because the second call short-circuits
assert verify_chain(audit_path) == 1
def test_disarm_resets_kill_switch(tmp_path: Path) -> None:
ks, _audit, audit_path, repo = _make_kill_switch(tmp_path)
ks.arm(reason="test", source="manual")
ks.disarm(reason="cleared", source="manual")
assert ks.is_armed() is False
conn = connect(tmp_path / "state.sqlite")
try:
state = repo.get_system_state(conn)
finally:
conn.close()
assert state is not None
assert state.kill_at is None
# arm + disarm = 2 audit lines
assert verify_chain(audit_path) == 2
def test_disarm_when_not_armed_is_noop(tmp_path: Path) -> None:
ks, _audit, audit_path, _repo = _make_kill_switch(tmp_path)
ks.disarm(reason="nothing to do", source="manual")
assert verify_chain(audit_path) == 0
def test_arm_requires_reason(tmp_path: Path) -> None:
ks, _audit, _audit_path, _repo = _make_kill_switch(tmp_path)
with pytest.raises(KillSwitchError, match="reason is required"):
ks.arm(reason="", source="manual")
def test_arm_without_initialised_state_raises(tmp_path: Path) -> None:
db_path = tmp_path / "state.sqlite"
audit_path = tmp_path / "audit.log"
conn = connect(db_path)
run_migrations(conn)
conn.close()
ks = KillSwitch(
connection_factory=lambda: connect(db_path),
repository=Repository(),
audit_log=AuditLog(audit_path),
clock=lambda: datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
with pytest.raises(KillSwitchError, match="system_state singleton missing"):
ks.arm(reason="x", source="manual")
def test_audit_chain_records_event_kind(tmp_path: Path) -> None:
ks, _audit, audit_path, _repo = _make_kill_switch(tmp_path)
ks.arm(reason="x", source="mcp_timeout")
ks.disarm(reason="y", source="manual")
text = audit_path.read_text(encoding="utf-8")
assert "KILL_SWITCH_ARMED" in text
assert "KILL_SWITCH_DISARMED" in text
def test_is_armed_returns_false_when_singleton_missing(tmp_path: Path) -> None:
db_path = tmp_path / "state.sqlite"
audit_path = tmp_path / "audit.log"
conn = connect(db_path)
run_migrations(conn)
conn.close()
ks = KillSwitch(
connection_factory=lambda: connect(db_path),
repository=Repository(),
audit_log=AuditLog(audit_path),
clock=lambda: datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
assert ks.is_armed() is False
def test_disarm_requires_reason(tmp_path: Path) -> None:
ks, _audit, _audit_path, _repo = _make_kill_switch(tmp_path)
with pytest.raises(KillSwitchError, match="reason is required"):
ks.disarm(reason="", source="manual")
def test_disarm_without_initialised_state_raises(tmp_path: Path) -> None:
db_path = tmp_path / "state.sqlite"
audit_path = tmp_path / "audit.log"
conn = connect(db_path)
run_migrations(conn)
conn.close()
ks = KillSwitch(
connection_factory=lambda: connect(db_path),
repository=Repository(),
audit_log=AuditLog(audit_path),
clock=lambda: datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
with pytest.raises(KillSwitchError, match="system_state singleton missing"):
ks.disarm(reason="x", source="manual")
def test_clock_is_advanced_for_each_call(tmp_path: Path) -> None:
ks, _audit, _audit_path, repo = _make_kill_switch(tmp_path)
ks.arm(reason="x", source="manual")
ks.disarm(reason="y", source="manual")
conn = connect(tmp_path / "state.sqlite")
try:
state = repo.get_system_state(conn)
finally:
conn.close()
assert state is not None
# last_health_check should reflect the disarm time (14:20 from the fake clock).
assert state.last_health_check >= datetime(2026, 4, 27, 14, 15, tzinfo=UTC) - timedelta(
seconds=1
)
+19 -4
View File
@@ -38,14 +38,29 @@ def test_cli_status_runs(tmp_data_dir: Path) -> None:
assert "phase: 0" in result.output
def test_cli_kill_switch_arm_placeholder(tmp_data_dir: Path) -> None:
def test_cli_kill_switch_arm_persists_state(tmp_data_dir: Path) -> None:
runner = CliRunner()
db_path = tmp_data_dir / "state.sqlite"
audit_path = tmp_data_dir / "audit.log"
result = runner.invoke(
cli_main,
["--log-dir", str(tmp_data_dir / "log"), "kill-switch", "arm", "--reason", "test"],
[
"--log-dir",
str(tmp_data_dir / "log"),
"kill-switch",
"arm",
"--reason",
"smoke",
"--db",
str(db_path),
"--audit",
str(audit_path),
],
)
assert result.exit_code == 0
assert "phase 0 placeholder" in result.output
assert result.exit_code == 0, result.output
assert "ARMED" in result.output
assert db_path.exists()
assert audit_path.exists()
def test_cli_version_flag() -> None:
+109
View File
@@ -0,0 +1,109 @@
"""Migration runner + SQLite pragma sanity tests."""
from __future__ import annotations
import sqlite3
from pathlib import Path
import pytest
from cerbero_bite.state.db import (
connect,
current_version,
list_migrations,
run_migrations,
transaction,
)
def test_list_migrations_is_ordered_and_starts_with_0001() -> None:
migs = list_migrations()
assert migs, "expected at least one migration file"
versions = [m[0] for m in migs]
assert versions == sorted(versions)
assert versions[0] == 1
def test_run_migrations_creates_full_schema(tmp_path: Path) -> None:
db_path = tmp_path / "state.sqlite"
conn = connect(db_path)
try:
new_version = run_migrations(conn)
assert new_version == max(m[0] for m in list_migrations())
# Sanity: every documented table is present.
existing = {
row["name"]
for row in conn.execute(
"SELECT name FROM sqlite_master WHERE type='table'"
).fetchall()
}
assert {
"positions",
"instructions",
"decisions",
"dvol_history",
"manual_actions",
"system_state",
} <= existing
finally:
conn.close()
def test_run_migrations_is_idempotent(tmp_path: Path) -> None:
db_path = tmp_path / "state.sqlite"
conn = connect(db_path)
try:
run_migrations(conn)
first = current_version(conn)
run_migrations(conn) # second call must be a no-op
assert current_version(conn) == first
finally:
conn.close()
def test_pragmas_applied(tmp_path: Path) -> None:
db_path = tmp_path / "state.sqlite"
conn = connect(db_path)
try:
run_migrations(conn)
assert conn.execute("PRAGMA foreign_keys").fetchone()[0] == 1
# WAL is sticky on the file: confirm via the journal_mode pragma.
mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
assert mode.lower() == "wal"
finally:
conn.close()
def test_transaction_rolls_back_on_exception(tmp_path: Path) -> None:
db_path = tmp_path / "state.sqlite"
conn = connect(db_path)
try:
run_migrations(conn)
with pytest.raises(RuntimeError, match="boom"), transaction(conn):
conn.execute(
"INSERT INTO system_state(id, last_health_check, "
"config_version, started_at) VALUES (1, '2026-04-27', "
"'1.0.0', '2026-04-27')"
)
raise RuntimeError("boom")
# Row was rolled back.
rows = conn.execute("SELECT COUNT(*) FROM system_state").fetchone()[0]
assert rows == 0
finally:
conn.close()
def test_foreign_key_enforcement(tmp_path: Path) -> None:
db_path = tmp_path / "state.sqlite"
conn = connect(db_path)
try:
run_migrations(conn)
# Inserting an instruction without a parent position must fail.
with pytest.raises(sqlite3.IntegrityError):
conn.execute(
"INSERT INTO instructions(instruction_id, proposal_id, kind, "
"payload_json, sent_at) VALUES (?, ?, ?, ?, ?)",
("i1", "missing", "open_combo", "{}", "2026-04-27"),
)
finally:
conn.close()
+450
View File
@@ -0,0 +1,450 @@
"""CRUD tests for :mod:`cerbero_bite.state.repository`."""
from __future__ import annotations
import sqlite3
from collections.abc import Iterator
from datetime import UTC, datetime, timedelta
from decimal import Decimal
from pathlib import Path
from uuid import uuid4
import pytest
from cerbero_bite.state import (
DecisionRecord,
DvolSnapshot,
InstructionRecord,
ManualAction,
PositionRecord,
Repository,
connect,
run_migrations,
transaction,
)
@pytest.fixture
def conn(tmp_path: Path) -> Iterator[sqlite3.Connection]:
db = tmp_path / "state.sqlite"
c = connect(db)
run_migrations(c)
yield c
c.close()
@pytest.fixture
def repo() -> Repository:
return Repository()
def _make_position(**overrides: object) -> PositionRecord:
base: dict[str, object] = {
"proposal_id": uuid4(),
"spread_type": "bull_put",
"expiry": datetime(2026, 5, 15, 8, 0, tzinfo=UTC),
"short_strike": Decimal("2475"),
"long_strike": Decimal("2350"),
"short_instrument": "ETH-15MAY26-2475-P",
"long_instrument": "ETH-15MAY26-2350-P",
"n_contracts": 2,
"spread_width_usd": Decimal("125"),
"spread_width_pct": Decimal("0.0417"),
"credit_eth": Decimal("0.030"),
"credit_usd": Decimal("90"),
"max_loss_usd": Decimal("160"),
"spot_at_entry": Decimal("3000"),
"dvol_at_entry": Decimal("50"),
"delta_at_entry": Decimal("-0.12"),
"eth_price_at_entry": Decimal("3000"),
"proposed_at": datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
"status": "proposed",
"created_at": datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
"updated_at": datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
}
base.update(overrides)
return PositionRecord(**base) # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# positions
# ---------------------------------------------------------------------------
def test_create_and_get_position_roundtrip(
conn: sqlite3.Connection, repo: Repository
) -> None:
record = _make_position()
with transaction(conn):
repo.create_position(conn, record)
fetched = repo.get_position(conn, record.proposal_id)
assert fetched is not None
assert fetched.proposal_id == record.proposal_id
assert fetched.short_strike == Decimal("2475")
# Decimal precision preserved (no float coercion).
assert fetched.spread_width_pct == Decimal("0.0417")
def test_get_unknown_position_returns_none(
conn: sqlite3.Connection, repo: Repository
) -> None:
assert repo.get_position(conn, uuid4()) is None
def test_list_open_positions_filters_by_status(
conn: sqlite3.Connection, repo: Repository
) -> None:
open_pos = _make_position(status="open")
closed_pos = _make_position(status="closed")
with transaction(conn):
repo.create_position(conn, open_pos)
repo.create_position(conn, closed_pos)
open_only = repo.list_open_positions(conn)
assert {p.proposal_id for p in open_only} == {open_pos.proposal_id}
def test_count_concurrent_positions_excludes_closed(
conn: sqlite3.Connection, repo: Repository
) -> None:
with transaction(conn):
repo.create_position(conn, _make_position(status="open"))
repo.create_position(conn, _make_position(status="awaiting_fill"))
repo.create_position(conn, _make_position(status="closed"))
repo.create_position(conn, _make_position(status="cancelled"))
assert repo.count_concurrent_positions(conn) == 2
def test_update_position_status_with_only_status_skips_optional_fields(
conn: sqlite3.Connection, repo: Repository
) -> None:
"""Hit the False side of every optional-field branch in ``update_position_status``."""
record = _make_position(status="proposed")
with transaction(conn):
repo.create_position(conn, record)
repo.update_position_status(
conn,
record.proposal_id,
status="awaiting_fill",
now=datetime(2026, 4, 27, 14, 5, tzinfo=UTC),
)
fetched = repo.get_position(conn, record.proposal_id)
assert fetched is not None
assert fetched.status == "awaiting_fill"
assert fetched.opened_at is None
assert fetched.closed_at is None
assert fetched.close_reason is None
def test_update_position_status_persists_open_then_close(
conn: sqlite3.Connection, repo: Repository
) -> None:
record = _make_position(status="awaiting_fill")
opened_at = datetime(2026, 4, 27, 14, 10, tzinfo=UTC)
with transaction(conn):
repo.create_position(conn, record)
repo.update_position_status(
conn,
record.proposal_id,
status="open",
opened_at=opened_at,
now=datetime(2026, 4, 27, 14, 11, tzinfo=UTC),
)
after_open = repo.get_position(conn, record.proposal_id)
assert after_open is not None
assert after_open.opened_at == opened_at
closed_at = datetime(2026, 5, 12, 14, 0, tzinfo=UTC)
with transaction(conn):
repo.update_position_status(
conn,
record.proposal_id,
status="closed",
closed_at=closed_at,
close_reason="CLOSE_PROFIT",
debit_paid_eth=Decimal("0.012"),
pnl_eth=Decimal("0.018"),
pnl_usd=Decimal("54"),
now=closed_at,
)
fetched = repo.get_position(conn, record.proposal_id)
assert fetched is not None
assert fetched.status == "closed"
assert fetched.close_reason == "CLOSE_PROFIT"
assert fetched.debit_paid_eth == Decimal("0.012")
assert fetched.pnl_eth == Decimal("0.018")
def test_naive_datetime_is_normalised_to_utc(
conn: sqlite3.Connection, repo: Repository
) -> None:
naive_now = datetime(2026, 4, 27, 14, 0)
record = _make_position(proposed_at=naive_now, created_at=naive_now, updated_at=naive_now)
with transaction(conn):
repo.create_position(conn, record)
fetched = repo.get_position(conn, record.proposal_id)
assert fetched is not None
assert fetched.proposed_at.tzinfo is not None
assert fetched.proposed_at.utcoffset() == timedelta(0)
# ---------------------------------------------------------------------------
# instructions
# ---------------------------------------------------------------------------
def test_instruction_lifecycle_ack_and_fill(
conn: sqlite3.Connection, repo: Repository
) -> None:
pos = _make_position(status="awaiting_fill")
instr = InstructionRecord(
instruction_id=uuid4(),
proposal_id=pos.proposal_id,
kind="open_combo",
payload_json='{"action":"open"}',
sent_at=datetime(2026, 4, 27, 14, 5, tzinfo=UTC),
)
with transaction(conn):
repo.create_position(conn, pos)
repo.create_instruction(conn, instr)
repo.update_instruction(
conn,
instr.instruction_id,
acknowledged_at=datetime(2026, 4, 27, 14, 6, tzinfo=UTC),
)
repo.update_instruction(
conn,
instr.instruction_id,
filled_at=datetime(2026, 4, 27, 14, 8, tzinfo=UTC),
actual_fill_eth=Decimal("0.0298"),
actual_fees_eth=Decimal("0.0001"),
)
fetched = repo.list_instructions(conn, pos.proposal_id)
assert len(fetched) == 1
assert fetched[0].acknowledged_at is not None
assert fetched[0].filled_at is not None
assert fetched[0].actual_fill_eth == Decimal("0.0298")
def test_update_instruction_no_op_when_no_fields(
conn: sqlite3.Connection, repo: Repository
) -> None:
pos = _make_position()
instr = InstructionRecord(
instruction_id=uuid4(),
proposal_id=pos.proposal_id,
kind="open_combo",
payload_json="{}",
sent_at=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
with transaction(conn):
repo.create_position(conn, pos)
repo.create_instruction(conn, instr)
repo.update_instruction(conn, instr.instruction_id) # no fields
fetched = repo.list_instructions(conn, pos.proposal_id)
assert fetched[0].acknowledged_at is None
# ---------------------------------------------------------------------------
# decisions / dvol / manual_actions
# ---------------------------------------------------------------------------
def test_record_and_list_decisions(
conn: sqlite3.Connection, repo: Repository
) -> None:
decision = DecisionRecord(
decision_type="entry_check",
timestamp=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
inputs_json='{"capital":1500}',
outputs_json='{"accepted":true}',
action_taken="propose_open",
)
with transaction(conn):
decision_id = repo.record_decision(conn, decision)
assert decision_id > 0
decisions = repo.list_decisions(conn)
assert len(decisions) == 1
assert decisions[0].id == decision_id
assert decisions[0].action_taken == "propose_open"
def test_record_dvol_snapshot_replaces_on_duplicate_timestamp(
conn: sqlite3.Connection, repo: Repository
) -> None:
ts = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
with transaction(conn):
repo.record_dvol_snapshot(
conn, DvolSnapshot(timestamp=ts, dvol=Decimal("50"), eth_spot=Decimal("3000"))
)
repo.record_dvol_snapshot(
conn, DvolSnapshot(timestamp=ts, dvol=Decimal("55"), eth_spot=Decimal("3050"))
)
rows = conn.execute("SELECT COUNT(*) FROM dvol_history").fetchone()
assert rows[0] == 1
def test_manual_action_enqueue_consume_cycle(
conn: sqlite3.Connection, repo: Repository
) -> None:
pos = _make_position()
action = ManualAction(
kind="approve_proposal",
proposal_id=pos.proposal_id,
payload_json='{"reason":"go"}',
created_at=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
with transaction(conn):
repo.create_position(conn, pos)
action_id = repo.enqueue_manual_action(conn, action)
next_action = repo.next_unconsumed_action(conn)
assert next_action is not None
assert next_action.id == action_id
with transaction(conn):
repo.mark_action_consumed(
conn,
action_id,
consumed_by="orchestrator",
result="ok",
now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC),
)
assert repo.next_unconsumed_action(conn) is None
# ---------------------------------------------------------------------------
# system_state
# ---------------------------------------------------------------------------
def test_init_system_state_is_idempotent(
conn: sqlite3.Connection, repo: Repository
) -> None:
now = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
with transaction(conn):
repo.init_system_state(conn, config_version="1.0.0", now=now)
repo.init_system_state(conn, config_version="1.0.0", now=now)
state = repo.get_system_state(conn)
assert state is not None
assert state.kill_switch == 0
assert state.config_version == "1.0.0"
def test_kill_switch_arm_and_disarm(
conn: sqlite3.Connection, repo: Repository
) -> None:
now = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
with transaction(conn):
repo.init_system_state(conn, config_version="1.0.0", now=now)
repo.set_kill_switch(conn, armed=True, reason="manual", now=now)
state = repo.get_system_state(conn)
assert state is not None
assert state.kill_switch == 1
assert state.kill_reason == "manual"
assert state.kill_at is not None
later = now + timedelta(minutes=5)
with transaction(conn):
repo.set_kill_switch(conn, armed=False, reason=None, now=later)
state = repo.get_system_state(conn)
assert state is not None
assert state.kill_switch == 0
assert state.kill_at is None
assert state.last_health_check == later
def test_get_system_state_returns_none_when_uninitialised(
conn: sqlite3.Connection, repo: Repository
) -> None:
assert repo.get_system_state(conn) is None
def test_list_positions_filters_by_status(
conn: sqlite3.Connection, repo: Repository
) -> None:
open_pos = _make_position(status="open")
closed_pos = _make_position(status="closed")
with transaction(conn):
repo.create_position(conn, open_pos)
repo.create_position(conn, closed_pos)
closed_only = repo.list_positions(conn, status="closed")
assert len(closed_only) == 1
assert closed_only[0].proposal_id == closed_pos.proposal_id
def test_list_positions_without_filter_returns_all(
conn: sqlite3.Connection, repo: Repository
) -> None:
with transaction(conn):
repo.create_position(conn, _make_position(status="open"))
repo.create_position(conn, _make_position(status="closed"))
assert len(repo.list_positions(conn)) == 2
def test_list_decisions_filters_by_proposal(
conn: sqlite3.Connection, repo: Repository
) -> None:
pos = _make_position()
with transaction(conn):
repo.create_position(conn, pos)
repo.record_decision(
conn,
DecisionRecord(
decision_type="entry_check",
proposal_id=pos.proposal_id,
timestamp=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
inputs_json="{}",
outputs_json="{}",
action_taken="propose_open",
),
)
repo.record_decision(
conn,
DecisionRecord(
decision_type="exit_check",
timestamp=datetime(2026, 4, 27, 15, 0, tzinfo=UTC),
inputs_json="{}",
outputs_json="{}",
action_taken="HOLD",
),
)
only_for_proposal = repo.list_decisions(conn, proposal_id=pos.proposal_id)
assert len(only_for_proposal) == 1
assert only_for_proposal[0].action_taken == "propose_open"
def test_update_instruction_sets_cancelled(
conn: sqlite3.Connection, repo: Repository
) -> None:
pos = _make_position()
instr = InstructionRecord(
instruction_id=uuid4(),
proposal_id=pos.proposal_id,
kind="open_combo",
payload_json="{}",
sent_at=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
with transaction(conn):
repo.create_position(conn, pos)
repo.create_instruction(conn, instr)
repo.update_instruction(
conn,
instr.instruction_id,
cancelled_at=datetime(2026, 4, 27, 14, 30, tzinfo=UTC),
)
fetched = repo.list_instructions(conn, pos.proposal_id)
assert fetched[0].cancelled_at is not None
def test_touch_health_check_updates_timestamp(
conn: sqlite3.Connection, repo: Repository
) -> None:
now = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
with transaction(conn):
repo.init_system_state(conn, config_version="1.0.0", now=now)
later = now + timedelta(minutes=10)
repo.touch_health_check(conn, now=later)
state = repo.get_system_state(conn)
assert state is not None
assert state.last_health_check == later