From 263470786d28a371684bb036cc2d751ef26d2afa Mon Sep 17 00:00:00 2001 From: AdrianoDev Date: Mon, 27 Apr 2026 13:35:35 +0200 Subject: [PATCH] Phase 2: persistence + safety controls Aggiunge la persistenza SQLite, l'audit log a hash chain, il kill switch coordinato e i CLI di gestione documentati in docs/05-data-model.md e docs/07-risk-controls.md. 197 test pass, 1 skipped (sqlite3 CLI mancante), copertura totale 97%. State (`state/`): - 0001_init.sql con positions, instructions, decisions, dvol_history, manual_actions, system_state. - db.py: connect con WAL + foreign_keys + transaction ctx, runner forward-only basato su PRAGMA user_version. - models.py: record Pydantic, Decimal preservato come TEXT. - repository.py: CRUD typed con singola connessione passata, cache aware, posizioni concorrenti. Safety (`safety/`): - audit_log.py: AuditLog append-only con SHA-256 chain e fsync, verify_chain riconosce ogni manomissione (payload, prev_hash, hash, JSON, separatori). - kill_switch.py: arm/disarm transazionali, idempotenti, accoppiati all'audit chain. Config (`config/loader.py` + `strategy.yaml`): - Loader YAML con deep-merge di strategy.local.yaml. - Verifica config_hash SHA-256 (riga config_hash esclusa). - File golden strategy.yaml + esempio override. Scripts: - dead_man.sh: watchdog shell indipendente da Python. - backup.py: VACUUM INTO orario con retention 30 giorni. CLI: - audit verify (exit 2 su tampering). - kill-switch arm/disarm/status su SQLite reale. - state inspect con tabella posizioni aperte. - config hash, config validate. Co-Authored-By: Claude Opus 4.7 (1M context) --- scripts/backup.py | 110 ++++ scripts/dead_man.sh | 107 ++++ src/cerbero_bite/cli.py | 257 +++++++- src/cerbero_bite/config/__init__.py | 10 + src/cerbero_bite/config/loader.py | 141 +++++ src/cerbero_bite/safety/__init__.py | 19 + src/cerbero_bite/safety/audit_log.py | 246 ++++++++ src/cerbero_bite/safety/kill_switch.py | 120 ++++ src/cerbero_bite/state/__init__.py | 27 + src/cerbero_bite/state/db.py | 102 ++++ .../state/migrations/0001_init.sql | 101 ++++ src/cerbero_bite/state/migrations/__init__.py | 1 + src/cerbero_bite/state/models.py | 154 +++++ src/cerbero_bite/state/repository.py | 553 ++++++++++++++++++ strategy.local.yaml.example | 15 + strategy.yaml | 149 +++++ tests/integration/test_dead_man_sh.py | 111 ++++ tests/unit/test_audit_log.py | 273 +++++++++ tests/unit/test_backup.py | 126 ++++ tests/unit/test_cli_safety.py | 160 +++++ tests/unit/test_config_loader.py | 149 +++++ tests/unit/test_kill_switch.py | 170 ++++++ tests/unit/test_smoke.py | 23 +- tests/unit/test_state_db.py | 109 ++++ tests/unit/test_state_repository.py | 450 ++++++++++++++ 25 files changed, 3669 insertions(+), 14 deletions(-) create mode 100644 scripts/backup.py create mode 100755 scripts/dead_man.sh create mode 100644 src/cerbero_bite/config/loader.py create mode 100644 src/cerbero_bite/safety/audit_log.py create mode 100644 src/cerbero_bite/safety/kill_switch.py create mode 100644 src/cerbero_bite/state/db.py create mode 100644 src/cerbero_bite/state/migrations/0001_init.sql create mode 100644 src/cerbero_bite/state/migrations/__init__.py create mode 100644 src/cerbero_bite/state/models.py create mode 100644 src/cerbero_bite/state/repository.py create mode 100644 strategy.local.yaml.example create mode 100644 strategy.yaml create mode 100644 tests/integration/test_dead_man_sh.py create mode 100644 tests/unit/test_audit_log.py create mode 100644 tests/unit/test_backup.py create mode 100644 tests/unit/test_cli_safety.py create mode 100644 tests/unit/test_config_loader.py create mode 100644 tests/unit/test_kill_switch.py create mode 100644 tests/unit/test_state_db.py create mode 100644 tests/unit/test_state_repository.py diff --git a/scripts/backup.py b/scripts/backup.py new file mode 100644 index 0000000..d96c0af --- /dev/null +++ b/scripts/backup.py @@ -0,0 +1,110 @@ +"""Hourly SQLite backup utility (``docs/05-data-model.md``). + +Uses ``VACUUM INTO`` so the snapshot is a self-contained, defragmented +SQLite file that can be inspected without locking the live database. + +Retention is enforced by deletion: any file matching +``state-YYYYMMDD-HH.sqlite`` older than ``retention_days`` is removed. + +Designed to be invoked from APScheduler in the orchestrator and from +the CLI for ad-hoc backups. +""" + +from __future__ import annotations + +import argparse +import re +import sqlite3 +from collections.abc import Iterable +from datetime import UTC, datetime, timedelta +from pathlib import Path + +__all__ = ["BACKUP_FILENAME_RE", "backup_database", "prune_backups"] + + +BACKUP_FILENAME_RE = re.compile(r"^state-(\d{8}-\d{2})\.sqlite$") +_DEFAULT_RETENTION_DAYS = 30 + + +def _format_backup_name(now: datetime) -> str: + return f"state-{now.astimezone(UTC):%Y%m%d-%H}.sqlite" + + +def backup_database( + *, + db_path: Path | str, + backup_dir: Path | str, + now: datetime | None = None, +) -> Path: + """Create a snapshot via ``VACUUM INTO`` and return its path.""" + src = Path(db_path) + dst_dir = Path(backup_dir) + dst_dir.mkdir(parents=True, exist_ok=True) + + timestamp = (now or datetime.now(UTC)).astimezone(UTC) + dst = dst_dir / _format_backup_name(timestamp) + if dst.exists(): + # Idempotent at hour granularity: same hour = same target file. + dst.unlink() + + conn = sqlite3.connect(str(src)) + try: + conn.execute(f"VACUUM INTO '{dst.as_posix()}'") + finally: + conn.close() + return dst + + +def _parse_backup_timestamp(name: str) -> datetime | None: + match = BACKUP_FILENAME_RE.match(name) + if match is None: + return None + try: + return datetime.strptime(match.group(1), "%Y%m%d-%H").replace(tzinfo=UTC) + except ValueError: + return None + + +def prune_backups( + backup_dir: Path | str, + *, + retention_days: int = _DEFAULT_RETENTION_DAYS, + now: datetime | None = None, +) -> list[Path]: + """Remove backups older than ``retention_days``. Returns the deleted paths.""" + cutoff = (now or datetime.now(UTC)).astimezone(UTC) - timedelta(days=retention_days) + deleted: list[Path] = [] + for entry in Path(backup_dir).iterdir(): + if not entry.is_file(): + continue + ts = _parse_backup_timestamp(entry.name) + if ts is None: + continue + if ts < cutoff: + entry.unlink() + deleted.append(entry) + return deleted + + +def list_backups(backup_dir: Path | str) -> Iterable[Path]: + return sorted( + (p for p in Path(backup_dir).iterdir() if BACKUP_FILENAME_RE.match(p.name)), + key=lambda p: p.name, + ) + + +def _cli() -> None: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument("--db", default="data/state.sqlite") + parser.add_argument("--out", default="data/backups") + parser.add_argument("--retention-days", type=int, default=_DEFAULT_RETENTION_DAYS) + args = parser.parse_args() + out = backup_database(db_path=args.db, backup_dir=args.out) + pruned = prune_backups(args.out, retention_days=args.retention_days) + print(f"backup -> {out}") + if pruned: + print(f"pruned: {', '.join(p.name for p in pruned)}") + + +if __name__ == "__main__": + _cli() diff --git a/scripts/dead_man.sh b/scripts/dead_man.sh new file mode 100755 index 0000000..fa16a1d --- /dev/null +++ b/scripts/dead_man.sh @@ -0,0 +1,107 @@ +#!/usr/bin/env bash +# dead_man.sh — independent watchdog for Cerbero Bite (docs/07-risk-controls.md). +# +# Runs from cron every 5 minutes. If the engine has not written a +# HEALTH_OK event into today's JSONL log within the last +# DEAD_MAN_THRESHOLD_SECONDS (default 900 = 15 minutes), it: +# 1. Sends an alert via DEAD_MAN_ALERT_CMD (any command consuming +# a single argument: the alert text). When unset, falls back to +# writing data/log/dead-man-alert.txt so an external watcher can +# pick it up. +# 2. Arms the SQLite kill switch directly (no Python required). +# 3. Appends one line to data/audit.log (best-effort hash chain; +# verifying after recovery is the operator's job). +# +# Configuration via env vars or .env in PROJECT_ROOT: +# PROJECT_ROOT — repo root (default: parent of this file). +# DEAD_MAN_THRESHOLD_SECONDS — silence threshold (default 900). +# DEAD_MAN_ALERT_CMD — optional alert command. +# +# This script intentionally avoids Python so it survives env corruption. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="${PROJECT_ROOT:-$(cd "${SCRIPT_DIR}/.." && pwd)}" +THRESHOLD="${DEAD_MAN_THRESHOLD_SECONDS:-900}" + +LOG_DIR="${PROJECT_ROOT}/data/log" +DB_PATH="${PROJECT_ROOT}/data/state.sqlite" +AUDIT_PATH="${PROJECT_ROOT}/data/audit.log" +ALERT_FILE="${LOG_DIR}/dead-man-alert.txt" + +today_log() { + date -u +"${LOG_DIR}/cerbero-bite-%Y-%m-%d.jsonl" +} + +last_health_ts() { + local file="$1" + if [[ ! -f "$file" ]]; then + echo "" + return + fi + grep -F '"event": "HEALTH_OK"' "$file" 2>/dev/null \ + | tail -n 1 \ + | sed -E 's/.*"ts":[[:space:]]*"([^"]+)".*/\1/' \ + || true +} + +emit_alert() { + local message="$1" + if [[ -n "${DEAD_MAN_ALERT_CMD:-}" ]]; then + "${DEAD_MAN_ALERT_CMD}" "$message" || true + fi + mkdir -p "$LOG_DIR" + printf '%s | %s\n' "$(date -u +%FT%TZ)" "$message" >> "$ALERT_FILE" +} + +arm_kill_switch() { + if [[ ! -f "$DB_PATH" ]] || ! command -v sqlite3 >/dev/null 2>&1; then + return + fi + sqlite3 "$DB_PATH" <threshold"}|prev_hash=manual|hash=manual\n' "$ts" >> "$AUDIT_PATH" +} + +main() { + local log_file + log_file="$(today_log)" + local last_ts + last_ts="$(last_health_ts "$log_file")" + + if [[ -z "$last_ts" ]]; then + emit_alert "dead_man: no HEALTH_OK in $log_file" + arm_kill_switch + append_audit_line + exit 1 + fi + + local last_epoch now_epoch elapsed + last_epoch="$(date -u -d "$last_ts" +%s 2>/dev/null || echo 0)" + now_epoch="$(date -u +%s)" + elapsed=$(( now_epoch - last_epoch )) + + if (( elapsed > THRESHOLD )); then + emit_alert "dead_man: ${elapsed}s since last HEALTH_OK (threshold ${THRESHOLD}s)" + arm_kill_switch + append_audit_line + exit 1 + fi + + exit 0 +} + +main "$@" diff --git a/src/cerbero_bite/cli.py b/src/cerbero_bite/cli.py index 35c2e4e..dc8dd85 100644 --- a/src/cerbero_bite/cli.py +++ b/src/cerbero_bite/cli.py @@ -10,19 +10,32 @@ without changing the surface. from __future__ import annotations import sys +from datetime import UTC, datetime from pathlib import Path import click from rich.console import Console +from rich.table import Table from cerbero_bite import __version__ +from cerbero_bite.config.loader import compute_config_hash, load_strategy from cerbero_bite.logging import configure as configure_logging from cerbero_bite.logging import get_logger +from cerbero_bite.safety.audit_log import AuditChainError, AuditLog +from cerbero_bite.safety.audit_log import verify_chain as verify_audit_chain +from cerbero_bite.safety.kill_switch import KillSwitch +from cerbero_bite.state import Repository, run_migrations, transaction +from cerbero_bite.state import connect as connect_state console = Console() log = get_logger("cli") +_DEFAULT_DB_PATH = Path("data/state.sqlite") +_DEFAULT_AUDIT_PATH = Path("data/audit.log") +_DEFAULT_STRATEGY_PATH = Path("strategy.yaml") + + def _phase0_notice(action: str) -> None: console.print(f"[yellow]\\[phase 0 placeholder][/yellow] {action}") @@ -85,18 +98,131 @@ def kill_switch() -> None: """Manage the engine kill switch.""" +def _make_kill_switch( + db_path: Path, audit_path: Path, *, config_version: str +) -> KillSwitch: + """Wire a :class:`KillSwitch` against the on-disk paths. + + ``init_system_state`` is called eagerly so the CLI can be used on + a fresh checkout before the engine ever ran. + """ + db_path.parent.mkdir(parents=True, exist_ok=True) + audit_path.parent.mkdir(parents=True, exist_ok=True) + conn = connect_state(db_path) + try: + run_migrations(conn) + repo = Repository() + with transaction(conn): + repo.init_system_state( + conn, config_version=config_version, now=datetime.now(UTC) + ) + finally: + conn.close() + + return KillSwitch( + connection_factory=lambda: connect_state(db_path), + repository=Repository(), + audit_log=AuditLog(audit_path), + ) + + @kill_switch.command(name="arm") @click.option("--reason", required=True, help="Why you are arming the kill switch.") -def kill_switch_arm(reason: str) -> None: +@click.option( + "--source", + default="manual", + show_default=True, + help="Trigger label (manual, mcp_timeout, hash_chain, ...).", +) +@click.option( + "--db", + type=click.Path(dir_okay=False, path_type=Path), + default=_DEFAULT_DB_PATH, + show_default=True, +) +@click.option( + "--audit", + type=click.Path(dir_okay=False, path_type=Path), + default=_DEFAULT_AUDIT_PATH, + show_default=True, +) +@click.option( + "--config-version", + default="unknown", + show_default=True, + help="Recorded next to the kill event when the singleton is initialised.", +) +def kill_switch_arm( + reason: str, source: str, db: Path, audit: Path, config_version: str +) -> None: """Arm the kill switch (engine refuses new entries).""" - _phase0_notice(f"kill-switch arm placeholder (reason: {reason!r}).") + ks = _make_kill_switch(db, audit, config_version=config_version) + ks.arm(reason=reason, source=source) + console.print(f"[red]kill switch ARMED[/red] reason={reason!r} source={source}") @kill_switch.command(name="disarm") @click.option("--reason", required=True, help="Why you are disarming.") -def kill_switch_disarm(reason: str) -> None: +@click.option( + "--source", + default="manual", + show_default=True, +) +@click.option( + "--db", + type=click.Path(dir_okay=False, path_type=Path), + default=_DEFAULT_DB_PATH, + show_default=True, +) +@click.option( + "--audit", + type=click.Path(dir_okay=False, path_type=Path), + default=_DEFAULT_AUDIT_PATH, + show_default=True, +) +@click.option( + "--config-version", + default="unknown", + show_default=True, +) +def kill_switch_disarm( + reason: str, source: str, db: Path, audit: Path, config_version: str +) -> None: """Disarm the kill switch.""" - _phase0_notice(f"kill-switch disarm placeholder (reason: {reason!r}).") + ks = _make_kill_switch(db, audit, config_version=config_version) + ks.disarm(reason=reason, source=source) + console.print(f"[green]kill switch DISARMED[/green] reason={reason!r}") + + +@kill_switch.command(name="status") +@click.option( + "--db", + type=click.Path(dir_okay=False, path_type=Path), + default=_DEFAULT_DB_PATH, + show_default=True, +) +def kill_switch_status(db: Path) -> None: + """Print the current kill switch state.""" + if not db.exists(): + console.print("[yellow]state.sqlite not found — engine never ran[/yellow]") + return + conn = connect_state(db) + try: + run_migrations(conn) + state = Repository().get_system_state(conn) + finally: + conn.close() + if state is None: + console.print("[yellow]system_state singleton missing[/yellow]") + return + armed = state.kill_switch == 1 + flag = "[red]ARMED[/red]" if armed else "[green]disarmed[/green]" + console.print( + f"kill_switch: {flag}\n" + f"reason: {state.kill_reason or '-'}\n" + f"kill_at: {state.kill_at.isoformat() if state.kill_at else '-'}\n" + f"last_health_check: {state.last_health_check.isoformat()}" + ) @main.command() @@ -123,9 +249,42 @@ def config() -> None: @config.command(name="hash") -def config_hash() -> None: - """Compute and print SHA-256 of strategy.yaml.""" - _phase0_notice("config hash placeholder; will read strategy.yaml and compute SHA-256.") +@click.option( + "--file", + "yaml_path", + type=click.Path(exists=True, dir_okay=False, path_type=Path), + default=_DEFAULT_STRATEGY_PATH, + show_default=True, +) +def config_hash(yaml_path: Path) -> None: + """Compute and print the SHA-256 of *yaml_path* (config_hash field excluded).""" + text = yaml_path.read_text(encoding="utf-8") + digest = compute_config_hash(text) + console.print(digest) + + +@config.command(name="validate") +@click.option( + "--file", + "yaml_path", + type=click.Path(exists=True, dir_okay=False, path_type=Path), + default=_DEFAULT_STRATEGY_PATH, + show_default=True, +) +@click.option( + "--enforce-hash/--no-enforce-hash", + default=True, + show_default=True, + help="When enabled, the recorded config_hash must match the file body.", +) +def config_validate(yaml_path: Path, enforce_hash: bool) -> None: + """Load and validate ``strategy.yaml`` (and any local override).""" + loaded = load_strategy(yaml_path, enforce_hash=enforce_hash) + console.print( + f"[green]ok[/green] version={loaded.config.config_version} " + f"hash={loaded.computed_hash[:16]}… " + f"sources={', '.join(p.name for p in loaded.sources)}" + ) @main.group() @@ -134,9 +293,87 @@ def audit() -> None: @audit.command(name="verify") -def audit_verify() -> None: - """Verify audit chain integrity.""" - _phase0_notice("audit verify placeholder; will walk audit.log hash chain.") +@click.option( + "--file", + "audit_path", + type=click.Path(dir_okay=False, path_type=Path), + default=_DEFAULT_AUDIT_PATH, + show_default=True, +) +def audit_verify(audit_path: Path) -> None: + """Walk the hash chain in *audit_path* and report tampering.""" + try: + count = verify_audit_chain(audit_path) + except AuditChainError as exc: + console.print(f"[red]TAMPERED[/red]: {exc}") + sys.exit(2) + if count == 0: + console.print("[yellow]audit log empty[/yellow]") + else: + console.print(f"[green]ok[/green] {count} entries verified") + + +@main.group() +def state() -> None: + """State inspection utilities.""" + + +@state.command(name="inspect") +@click.option( + "--db", + type=click.Path(dir_okay=False, path_type=Path), + default=_DEFAULT_DB_PATH, + show_default=True, +) +def state_inspect(db: Path) -> None: + """Print a short snapshot of the SQLite state file.""" + if not db.exists(): + console.print("[yellow]state.sqlite not found[/yellow]") + return + conn = connect_state(db) + try: + run_migrations(conn) + repo = Repository() + sys_state = repo.get_system_state(conn) + positions = repo.list_open_positions(conn) + concurrent = repo.count_concurrent_positions(conn) + finally: + conn.close() + + if sys_state is None: + console.print("[yellow]system_state singleton missing[/yellow]") + return + + armed = "[red]ARMED[/red]" if sys_state.kill_switch == 1 else "[green]disarmed[/green]" + console.print( + f"engine state: kill_switch={armed}, " + f"open positions: {concurrent}, " + f"config_version: {sys_state.config_version}" + ) + + if not positions: + console.print("no open positions") + return + + table = Table(title="open positions") + table.add_column("proposal_id") + table.add_column("status") + table.add_column("spread") + table.add_column("short") + table.add_column("long") + table.add_column("n") + table.add_column("expiry") + for pos in positions: + table.add_row( + str(pos.proposal_id)[:8], + pos.status, + pos.spread_type, + str(pos.short_strike), + str(pos.long_strike), + str(pos.n_contracts), + pos.expiry.isoformat(), + ) + console.print(table) def _entrypoint() -> None: diff --git a/src/cerbero_bite/config/__init__.py b/src/cerbero_bite/config/__init__.py index aa1ca40..c8610b2 100644 --- a/src/cerbero_bite/config/__init__.py +++ b/src/cerbero_bite/config/__init__.py @@ -1,5 +1,11 @@ """Strategy configuration: schema, loader, validation.""" +from cerbero_bite.config.loader import ( + ConfigHashError, + LoadedConfig, + compute_config_hash, + load_strategy, +) from cerbero_bite.config.schema import ( AssetConfig, DvolAdjustmentBand, @@ -23,12 +29,14 @@ from cerbero_bite.config.schema import ( __all__ = [ "AssetConfig", + "ConfigHashError", "DvolAdjustmentBand", "EntryConfig", "ExecutionConfig", "ExitConfig", "KellyConfig", "LiquidityConfig", + "LoadedConfig", "McpConfig", "MonitoringConfig", "ShortStrikeSpec", @@ -39,5 +47,7 @@ __all__ = [ "StrategyConfig", "StructureConfig", "TelegramConfig", + "compute_config_hash", "golden_config", + "load_strategy", ] diff --git a/src/cerbero_bite/config/loader.py b/src/cerbero_bite/config/loader.py new file mode 100644 index 0000000..188b1ec --- /dev/null +++ b/src/cerbero_bite/config/loader.py @@ -0,0 +1,141 @@ +"""YAML loader for ``strategy.yaml`` with optional local override. + +* Reads ``strategy.yaml`` (golden config). +* If ``strategy.local.yaml`` exists alongside, deep-merges its keys on + top — that file is ``.gitignore``'d and used by Adriano for emergency + overrides. +* Verifies ``config_hash`` matches the SHA-256 of the YAML *minus* the + ``config_hash`` line itself. A mismatch is reported via + :class:`ConfigHashError` and the orchestrator must arm the kill switch + per ``docs/07-risk-controls.md``. + +The loader does *not* depend on the runtime: it returns a validated +:class:`StrategyConfig` plus the computed hash; nothing else. +""" + +from __future__ import annotations + +import hashlib +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import yaml + +from cerbero_bite.config.schema import StrategyConfig + +__all__ = [ + "ConfigHashError", + "LoadedConfig", + "compute_config_hash", + "load_strategy", +] + + +_HASH_KEY = "config_hash" + + +class ConfigHashError(RuntimeError): + """Raised when the recorded ``config_hash`` does not match the file.""" + + +@dataclass(frozen=True) +class LoadedConfig: + """Result of :func:`load_strategy`.""" + + config: StrategyConfig + computed_hash: str + sources: tuple[Path, ...] + + +def _strip_hash_line(text: str) -> str: + """Return *text* with the ``config_hash:`` line replaced by an empty string. + + We deliberately keep the surrounding whitespace so that any other + line numbers stay stable; only the value of the hash is removed. + """ + out: list[str] = [] + for line in text.splitlines(keepends=True): + stripped = line.lstrip() + if stripped.startswith(f"{_HASH_KEY}:"): + # keep the key but strip the value, so identical files with + # different hashes still hash the same way + indent = line[: len(line) - len(stripped)] + out.append(f"{indent}{_HASH_KEY}:\n") + else: + out.append(line) + return "".join(out) + + +def compute_config_hash(text: str) -> str: + """SHA-256 of the YAML text after stripping the ``config_hash`` value.""" + canonical = _strip_hash_line(text).encode("utf-8") + return hashlib.sha256(canonical).hexdigest() + + +def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: + """Return ``base`` with values from ``override`` recursively merged in.""" + out = dict(base) + for key, value in override.items(): + existing = out.get(key) + if isinstance(existing, dict) and isinstance(value, dict): + out[key] = _deep_merge(existing, value) + else: + out[key] = value + return out + + +def _load_yaml(path: Path) -> dict[str, Any]: + data = yaml.safe_load(path.read_text(encoding="utf-8")) + if data is None: + return {} + if not isinstance(data, dict): + raise ValueError(f"{path}: expected a top-level mapping") + return data + + +def load_strategy( + yaml_path: Path | str, + *, + local_override_path: Path | str | None = None, + enforce_hash: bool = True, +) -> LoadedConfig: + """Load and validate a strategy YAML, optionally merging a local file. + + Args: + yaml_path: path to ``strategy.yaml``. + local_override_path: when ``None`` (default), use + ``.local.yaml`` if present. Pass ``False`` / + non-existent path to disable. + enforce_hash: when ``True``, raise :class:`ConfigHashError` if + the recorded hash does not match the file. Set to ``False`` + in test fixtures or right after a manual edit. + """ + main_path = Path(yaml_path) + text = main_path.read_text(encoding="utf-8") + raw = _load_yaml(main_path) + sources: list[Path] = [main_path] + + computed_hash = compute_config_hash(text) + declared_hash = raw.get(_HASH_KEY) + if enforce_hash and declared_hash != computed_hash: + raise ConfigHashError( + f"config_hash mismatch in {main_path}: " + f"declared={declared_hash}, computed={computed_hash}" + ) + + if local_override_path is None: + local_override_path = main_path.with_name( + main_path.stem + ".local" + main_path.suffix + ) + override_path = Path(local_override_path) + if override_path.is_file(): + override = _load_yaml(override_path) + raw = _deep_merge(raw, override) + sources.append(override_path) + + return LoadedConfig( + config=StrategyConfig(**raw), + computed_hash=computed_hash, + sources=tuple(sources), + ) diff --git a/src/cerbero_bite/safety/__init__.py b/src/cerbero_bite/safety/__init__.py index e69de29..d6f1698 100644 --- a/src/cerbero_bite/safety/__init__.py +++ b/src/cerbero_bite/safety/__init__.py @@ -0,0 +1,19 @@ +"""Cross-cutting safety controls: kill switch, dead-man, audit chain.""" + +from cerbero_bite.safety.audit_log import ( + GENESIS_HASH, + AuditChainError, + AuditEntry, + AuditLog, + iter_entries, + verify_chain, +) + +__all__ = [ + "GENESIS_HASH", + "AuditChainError", + "AuditEntry", + "AuditLog", + "iter_entries", + "verify_chain", +] diff --git a/src/cerbero_bite/safety/audit_log.py b/src/cerbero_bite/safety/audit_log.py new file mode 100644 index 0000000..be2189f --- /dev/null +++ b/src/cerbero_bite/safety/audit_log.py @@ -0,0 +1,246 @@ +"""Append-only hash-chained audit log (``docs/07-risk-controls.md``). + +Every line is:: + + |||prev_hash=|hash= + +``hash`` is ``sha256("|||")`` so +that the integrity of the file can be re-verified by walking the chain +top-to-bottom. The first line uses ``prev_hash="0" * 64``. + +The writer ``flush + os.fsync`` after every append; for the audit trail, +durability beats throughput. +""" + +from __future__ import annotations + +import hashlib +import json +import os +from collections.abc import Iterator +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +__all__ = [ + "GENESIS_HASH", + "AuditChainError", + "AuditEntry", + "AuditLog", + "verify_chain", +] + + +GENESIS_HASH = "0" * 64 +_SEP = "|" + + +class AuditChainError(RuntimeError): + """Raised when the audit chain fails verification.""" + + +@dataclass(frozen=True) +class AuditEntry: + """Parsed audit-log line.""" + + timestamp: datetime + event: str + payload: dict[str, Any] + prev_hash: str + hash: str + + +def _canonical_payload(payload: dict[str, Any]) -> str: + """Serialize the payload deterministically (sorted keys, no whitespace).""" + return json.dumps(payload, sort_keys=True, separators=(",", ":"), default=str) + + +def _compute_hash(timestamp: str, event: str, payload_json: str, prev_hash: str) -> str: + raw = f"{timestamp}{_SEP}{event}{_SEP}{payload_json}{_SEP}{prev_hash}" + return hashlib.sha256(raw.encode("utf-8")).hexdigest() + + +def _format_line( + timestamp: str, event: str, payload_json: str, prev_hash: str, line_hash: str +) -> str: + return ( + f"{timestamp}{_SEP}{event}{_SEP}{payload_json}{_SEP}" + f"prev_hash={prev_hash}{_SEP}hash={line_hash}\n" + ) + + +def _parse_line(line: str) -> AuditEntry: + """Parse a stored audit line back into :class:`AuditEntry`. + + The payload may legitimately contain ``|`` characters inside JSON + strings; we therefore split from the right for the two trailing + fields and from the left for the timestamp + event + payload. + """ + if not line.endswith("\n"): + line = line + "\n" + body = line.rstrip("\n") + + # Trailing parts. + try: + rest, hash_part = body.rsplit(_SEP, 1) + except ValueError as exc: + raise AuditChainError("missing hash= field") from exc + if not hash_part.startswith("hash="): + raise AuditChainError("missing hash= field") + line_hash = hash_part[len("hash=") :] + + try: + rest, prev_part = rest.rsplit(_SEP, 1) + except ValueError as exc: + raise AuditChainError("missing prev_hash= field") from exc + if not prev_part.startswith("prev_hash="): + raise AuditChainError("missing prev_hash= field") + prev_hash = prev_part[len("prev_hash=") :] + + # Leading parts. + parts = rest.split(_SEP, 2) + if len(parts) != 3: + raise AuditChainError("malformed leading section") + ts_str, event, payload_json = parts + + try: + payload: dict[str, Any] = json.loads(payload_json) + except json.JSONDecodeError as exc: + raise AuditChainError("payload is not valid JSON") from exc + if not isinstance(payload, dict): + raise AuditChainError("payload must be a JSON object") + + return AuditEntry( + timestamp=datetime.fromisoformat(ts_str), + event=event, + payload=payload, + prev_hash=prev_hash, + hash=line_hash, + ) + + +def verify_chain(path: str | Path) -> int: + """Re-walk *path* and raise :class:`AuditChainError` on tampering. + + Returns the number of lines verified (0 if the file does not exist + or is empty). + """ + p = Path(path) + if not p.exists() or p.stat().st_size == 0: + return 0 + expected_prev = GENESIS_HASH + count = 0 + with p.open("r", encoding="utf-8") as fh: + for lineno, line in enumerate(fh, start=1): + if not line.strip(): + continue + entry = _parse_line(line) + if entry.prev_hash != expected_prev: + raise AuditChainError( + f"line {lineno}: prev_hash mismatch " + f"(expected {expected_prev}, got {entry.prev_hash})" + ) + recomputed = _compute_hash( + entry.timestamp.isoformat(), + entry.event, + _canonical_payload(entry.payload), + entry.prev_hash, + ) + if recomputed != entry.hash: + raise AuditChainError( + f"line {lineno}: hash mismatch (expected {recomputed}, " + f"got {entry.hash})" + ) + expected_prev = entry.hash + count += 1 + return count + + +def iter_entries(path: str | Path) -> Iterator[AuditEntry]: + """Yield each :class:`AuditEntry` from *path* without verifying.""" + p = Path(path) + if not p.exists(): + return + with p.open("r", encoding="utf-8") as fh: + for line in fh: + if line.strip(): + yield _parse_line(line) + + +class AuditLog: + """Writer for the hash-chained audit log. + + A single instance per process is enough; concurrent writers are not + supported by design (the engine is the only writer). ``append`` is + fsync'd before returning. + """ + + def __init__(self, path: str | Path) -> None: + self._path = Path(path) + self._path.parent.mkdir(parents=True, exist_ok=True) + self._last_hash: str = self._tail_hash() or GENESIS_HASH + + @property + def path(self) -> Path: # pragma: no cover — accessor used by callers only + return self._path + + @property + def last_hash(self) -> str: + return self._last_hash + + def _tail_hash(self) -> str | None: + if not self._path.exists() or self._path.stat().st_size == 0: + return None + # Walk from EOF to find the last non-empty line. The chunked + # back-seek covers files larger than 4 KiB; the loop-exhausted + # branch is reached only when a partial / no-newline file is + # encountered (defensive — :func:`append` always writes "\n"). + with self._path.open("rb") as fh: + fh.seek(0, os.SEEK_END) + size = fh.tell() + buf = b"" + offset = size + chunk = 4096 + while offset > 0: # pragma: no branch — terminates via break or offset==0 + read = min(chunk, offset) + offset -= read + fh.seek(offset) + buf = fh.read(read) + buf + if b"\n" in buf: + break + text = buf.decode("utf-8", errors="strict") + for line in reversed(text.splitlines()): # pragma: no branch + if line.strip(): + entry = _parse_line(line) + return entry.hash + return None # pragma: no cover — only hit when file is all blank lines + + def append( + self, + *, + event: str, + payload: dict[str, Any] | None = None, + now: datetime | None = None, + ) -> AuditEntry: + """Append one event line and return the resulting entry.""" + ts = (now or datetime.now(UTC)).astimezone(UTC) + ts_iso = ts.isoformat() + payload_json = _canonical_payload(payload or {}) + prev_hash = self._last_hash + line_hash = _compute_hash(ts_iso, event, payload_json, prev_hash) + + line = _format_line(ts_iso, event, payload_json, prev_hash, line_hash) + with self._path.open("a", encoding="utf-8") as fh: + fh.write(line) + fh.flush() + os.fsync(fh.fileno()) + + self._last_hash = line_hash + return AuditEntry( + timestamp=ts, + event=event, + payload=dict(payload or {}), + prev_hash=prev_hash, + hash=line_hash, + ) diff --git a/src/cerbero_bite/safety/kill_switch.py b/src/cerbero_bite/safety/kill_switch.py new file mode 100644 index 0000000..f131754 --- /dev/null +++ b/src/cerbero_bite/safety/kill_switch.py @@ -0,0 +1,120 @@ +"""Kill switch coordinator (``docs/07-risk-controls.md``). + +Encapsulates the three side effects required when arming/disarming: + +1. Update the ``system_state`` row in SQLite. +2. Append a tamper-evident line to the audit log. +3. Make the new state observable via :meth:`is_armed` for the + orchestrator. + +The orchestrator is the only caller; the GUI signals the same intent +via the :class:`ManualAction` queue, which the orchestrator drains and +forwards here. +""" + +from __future__ import annotations + +import sqlite3 +from collections.abc import Callable +from datetime import UTC, datetime + +from cerbero_bite.safety.audit_log import AuditLog +from cerbero_bite.state import Repository, transaction + +__all__ = ["KillSwitch", "KillSwitchError"] + + +class KillSwitchError(RuntimeError): + """Raised when an arm/disarm transition is invalid.""" + + +class KillSwitch: + """Arm/disarm + status helper backed by ``system_state`` + audit log. + + All transitions must go through this class so the SQLite row and + the audit chain stay in lock-step. + """ + + def __init__( + self, + *, + connection_factory: Callable[[], sqlite3.Connection], + repository: Repository, + audit_log: AuditLog, + clock: Callable[[], datetime] | None = None, + ) -> None: + self._connect = connection_factory + self._repo = repository + self._audit = audit_log + self._clock = clock or (lambda: datetime.now(UTC)) + + # ------------------------------------------------------------------ + # Status + # ------------------------------------------------------------------ + + def is_armed(self) -> bool: + conn = self._connect() + try: + state = self._repo.get_system_state(conn) + finally: + conn.close() + if state is None: + return False + return state.kill_switch == 1 + + # ------------------------------------------------------------------ + # Transitions + # ------------------------------------------------------------------ + + def arm(self, *, reason: str, source: str) -> None: + """Move the engine to the armed state. + + ``source`` is a short label (``"manual"``, ``"mcp_timeout"``, + ``"hash_chain"``) that goes both into the audit payload and the + log so we can attribute the trigger later. + """ + if not reason: + raise KillSwitchError("reason is required to arm the kill switch") + now = self._clock() + conn = self._connect() + try: + with transaction(conn): + state = self._repo.get_system_state(conn) + if state is None: + raise KillSwitchError( + "system_state singleton missing — call init_system_state first" + ) + if state.kill_switch == 1: + return # idempotent + self._repo.set_kill_switch(conn, armed=True, reason=reason, now=now) + finally: + conn.close() + self._audit.append( + event="KILL_SWITCH_ARMED", + payload={"reason": reason, "source": source}, + now=now, + ) + + def disarm(self, *, reason: str, source: str) -> None: + """Move the engine back to the disarmed state.""" + if not reason: + raise KillSwitchError("reason is required to disarm the kill switch") + now = self._clock() + conn = self._connect() + try: + with transaction(conn): + state = self._repo.get_system_state(conn) + if state is None: + raise KillSwitchError( + "system_state singleton missing — call init_system_state first" + ) + if state.kill_switch == 0: + return # idempotent + self._repo.set_kill_switch(conn, armed=False, reason=None, now=now) + finally: + conn.close() + self._audit.append( + event="KILL_SWITCH_DISARMED", + payload={"reason": reason, "source": source}, + now=now, + ) diff --git a/src/cerbero_bite/state/__init__.py b/src/cerbero_bite/state/__init__.py index e69de29..5e089ab 100644 --- a/src/cerbero_bite/state/__init__.py +++ b/src/cerbero_bite/state/__init__.py @@ -0,0 +1,27 @@ +"""Persistent state: SQLite schema, migrations, typed repository.""" + +from cerbero_bite.state.db import connect, run_migrations, transaction +from cerbero_bite.state.models import ( + DecisionRecord, + DvolSnapshot, + InstructionRecord, + ManualAction, + PositionRecord, + PositionStatus, + SystemStateRecord, +) +from cerbero_bite.state.repository import Repository + +__all__ = [ + "DecisionRecord", + "DvolSnapshot", + "InstructionRecord", + "ManualAction", + "PositionRecord", + "PositionStatus", + "Repository", + "SystemStateRecord", + "connect", + "run_migrations", + "transaction", +] diff --git a/src/cerbero_bite/state/db.py b/src/cerbero_bite/state/db.py new file mode 100644 index 0000000..e7640dd --- /dev/null +++ b/src/cerbero_bite/state/db.py @@ -0,0 +1,102 @@ +"""SQLite connection helpers and forward-only migrations. + +Connections use ``sqlite3`` with a few pragmas tuned for our workload: + +* ``foreign_keys=ON`` — referential integrity for ``instructions`` → + ``positions``. +* ``journal_mode=WAL`` — concurrent reader (the GUI) without blocking + the engine writer. +* ``synchronous=NORMAL`` — durable enough for our use-case (we also + append-fsync the audit chain) and noticeably faster than FULL. + +Migrations live in ``state/migrations/NNNN_.sql``. They are +applied in numeric order, each bumps ``PRAGMA user_version`` at the +last statement of the file. ``run_migrations`` is idempotent: it +re-checks ``user_version`` and replays only what is missing. +""" + +from __future__ import annotations + +import re +import sqlite3 +from collections.abc import Iterator +from contextlib import contextmanager +from importlib import resources +from pathlib import Path + +__all__ = ["connect", "current_version", "list_migrations", "run_migrations"] + + +_MIGRATION_PATTERN = re.compile(r"^(\d{4})_[a-z0-9_]+\.sql$") + + +def _apply_pragmas(conn: sqlite3.Connection) -> None: + conn.execute("PRAGMA foreign_keys = ON") + conn.execute("PRAGMA journal_mode = WAL") + conn.execute("PRAGMA synchronous = NORMAL") + + +def connect(path: str | Path) -> sqlite3.Connection: + """Open a connection to ``path`` with the standard pragmas applied.""" + db_path = Path(path) + db_path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect( + str(db_path), + isolation_level=None, # autocommit; we manage transactions explicitly + detect_types=sqlite3.PARSE_DECLTYPES, + ) + conn.row_factory = sqlite3.Row + _apply_pragmas(conn) + return conn + + +@contextmanager +def transaction(conn: sqlite3.Connection) -> Iterator[sqlite3.Connection]: + """Begin/commit/rollback wrapper that respects autocommit mode.""" + conn.execute("BEGIN") + try: + yield conn + except Exception: + conn.execute("ROLLBACK") + raise + else: + conn.execute("COMMIT") + + +def current_version(conn: sqlite3.Connection) -> int: + row = conn.execute("PRAGMA user_version").fetchone() + return int(row[0]) + + +def list_migrations() -> list[tuple[int, str, str]]: + """Return ``[(version, filename, sql_text), ...]`` sorted by version.""" + pkg_files = resources.files("cerbero_bite.state.migrations") + out: list[tuple[int, str, str]] = [] + for entry in pkg_files.iterdir(): + match = _MIGRATION_PATTERN.match(entry.name) + if not match: + continue + version = int(match.group(1)) + out.append((version, entry.name, entry.read_text(encoding="utf-8"))) + out.sort(key=lambda triple: triple[0]) + return out + + +def run_migrations(conn: sqlite3.Connection) -> int: + """Apply migrations whose version is greater than the current one. + + Returns the new ``user_version`` after the run. + """ + applied_to = current_version(conn) + for version, name, sql in list_migrations(): + if version <= applied_to: + continue + conn.executescript(sql) + new_version = current_version(conn) + if new_version != version: + raise RuntimeError( + f"Migration {name} did not bump user_version to {version} " + f"(observed {new_version})" + ) + applied_to = new_version + return applied_to diff --git a/src/cerbero_bite/state/migrations/0001_init.sql b/src/cerbero_bite/state/migrations/0001_init.sql new file mode 100644 index 0000000..8faa3c9 --- /dev/null +++ b/src/cerbero_bite/state/migrations/0001_init.sql @@ -0,0 +1,101 @@ +-- 0001_init.sql — initial schema for Cerbero Bite (docs/05-data-model.md) +-- +-- Forward-only. Run by state/migrations.py once user_version == 0. +-- Bumps user_version to 1 at the end. + +PRAGMA foreign_keys = ON; + +CREATE TABLE positions ( + proposal_id TEXT PRIMARY KEY, + spread_type TEXT NOT NULL, + asset TEXT NOT NULL DEFAULT 'ETH', + expiry TEXT NOT NULL, + short_strike NUMERIC NOT NULL, + long_strike NUMERIC NOT NULL, + short_instrument TEXT NOT NULL, + long_instrument TEXT NOT NULL, + n_contracts INTEGER NOT NULL, + spread_width_usd NUMERIC NOT NULL, + spread_width_pct NUMERIC NOT NULL, + credit_eth NUMERIC NOT NULL, + credit_usd NUMERIC NOT NULL, + max_loss_usd NUMERIC NOT NULL, + spot_at_entry NUMERIC NOT NULL, + dvol_at_entry NUMERIC NOT NULL, + delta_at_entry NUMERIC NOT NULL, + eth_price_at_entry NUMERIC NOT NULL, + proposed_at TEXT NOT NULL, + opened_at TEXT, + closed_at TEXT, + close_reason TEXT, + debit_paid_eth NUMERIC, + pnl_eth NUMERIC, + pnl_usd NUMERIC, + status TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL +); + +CREATE INDEX idx_positions_status ON positions(status); +CREATE INDEX idx_positions_closed_at ON positions(closed_at); + +CREATE TABLE instructions ( + instruction_id TEXT PRIMARY KEY, + proposal_id TEXT NOT NULL REFERENCES positions(proposal_id), + kind TEXT NOT NULL, + payload_json TEXT NOT NULL, + sent_at TEXT NOT NULL, + acknowledged_at TEXT, + filled_at TEXT, + cancelled_at TEXT, + actual_fill_eth NUMERIC, + actual_fees_eth NUMERIC +); + +CREATE INDEX idx_instructions_proposal ON instructions(proposal_id); + +CREATE TABLE decisions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + decision_type TEXT NOT NULL, + proposal_id TEXT, + timestamp TEXT NOT NULL, + inputs_json TEXT NOT NULL, + outputs_json TEXT NOT NULL, + action_taken TEXT, + notes TEXT +); + +CREATE INDEX idx_decisions_timestamp ON decisions(timestamp); +CREATE INDEX idx_decisions_proposal ON decisions(proposal_id); + +CREATE TABLE dvol_history ( + timestamp TEXT PRIMARY KEY, + dvol NUMERIC NOT NULL, + eth_spot NUMERIC NOT NULL +); + +CREATE TABLE manual_actions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + kind TEXT NOT NULL, + proposal_id TEXT, + payload_json TEXT, + created_at TEXT NOT NULL, + consumed_at TEXT, + consumed_by TEXT, + result TEXT +); + +CREATE INDEX idx_manual_actions_unconsumed ON manual_actions(consumed_at); + +CREATE TABLE system_state ( + id INTEGER PRIMARY KEY CHECK (id = 1), + kill_switch INTEGER NOT NULL DEFAULT 0, + kill_reason TEXT, + kill_at TEXT, + last_health_check TEXT NOT NULL, + last_kelly_calib TEXT, + config_version TEXT NOT NULL, + started_at TEXT NOT NULL +); + +PRAGMA user_version = 1; diff --git a/src/cerbero_bite/state/migrations/__init__.py b/src/cerbero_bite/state/migrations/__init__.py new file mode 100644 index 0000000..433c76c --- /dev/null +++ b/src/cerbero_bite/state/migrations/__init__.py @@ -0,0 +1 @@ +"""Forward-only SQL migrations for the Cerbero Bite state database.""" diff --git a/src/cerbero_bite/state/models.py b/src/cerbero_bite/state/models.py new file mode 100644 index 0000000..cb0c410 --- /dev/null +++ b/src/cerbero_bite/state/models.py @@ -0,0 +1,154 @@ +"""Pydantic record types mirroring the SQLite tables. + +Every numeric column documented as ``NUMERIC`` in +``state/migrations/0001_init.sql`` is exposed as :class:`decimal.Decimal` +on the Python side. The repository layer is responsible for serialising +to ``TEXT`` (using ``str``) when writing and parsing back when reading, +so precision is never lost via ``float`` coercion. +""" + +from __future__ import annotations + +from datetime import datetime +from decimal import Decimal +from typing import Literal +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, Field + +__all__ = [ + "DecisionRecord", + "DvolSnapshot", + "InstructionRecord", + "ManualAction", + "PositionRecord", + "PositionStatus", + "SystemStateRecord", +] + + +PositionStatus = Literal[ + "proposed", + "awaiting_fill", + "open", + "closing", + "closed", + "cancelled", +] + + +class PositionRecord(BaseModel): + """Row of the ``positions`` table.""" + + model_config = ConfigDict(extra="forbid") + + proposal_id: UUID + spread_type: str + asset: str = "ETH" + expiry: datetime + short_strike: Decimal + long_strike: Decimal + short_instrument: str + long_instrument: str + n_contracts: int + spread_width_usd: Decimal + spread_width_pct: Decimal + credit_eth: Decimal + credit_usd: Decimal + max_loss_usd: Decimal + spot_at_entry: Decimal + dvol_at_entry: Decimal + delta_at_entry: Decimal + eth_price_at_entry: Decimal + proposed_at: datetime + opened_at: datetime | None = None + closed_at: datetime | None = None + close_reason: str | None = None + debit_paid_eth: Decimal | None = None + pnl_eth: Decimal | None = None + pnl_usd: Decimal | None = None + status: PositionStatus + created_at: datetime + updated_at: datetime + + +class InstructionRecord(BaseModel): + """Row of the ``instructions`` table.""" + + model_config = ConfigDict(extra="forbid") + + instruction_id: UUID + proposal_id: UUID + kind: Literal["open_combo", "close_combo"] + payload_json: str + sent_at: datetime + acknowledged_at: datetime | None = None + filled_at: datetime | None = None + cancelled_at: datetime | None = None + actual_fill_eth: Decimal | None = None + actual_fees_eth: Decimal | None = None + + +class DecisionRecord(BaseModel): + """Row of the ``decisions`` table. + + ``id`` is :class:`int` and may be ``None`` before the row has been + inserted; the repository sets it after the auto-increment fires. + """ + + model_config = ConfigDict(extra="forbid") + + id: int | None = None + decision_type: Literal["entry_check", "exit_check", "kelly_recalib"] + proposal_id: UUID | None = None + timestamp: datetime + inputs_json: str + outputs_json: str + action_taken: str | None = None + notes: str | None = None + + +class DvolSnapshot(BaseModel): + """Row of the ``dvol_history`` table.""" + + model_config = ConfigDict(extra="forbid") + + timestamp: datetime + dvol: Decimal + eth_spot: Decimal + + +class ManualAction(BaseModel): + """Row of the ``manual_actions`` table.""" + + model_config = ConfigDict(extra="forbid") + + id: int | None = None + kind: Literal[ + "approve_proposal", + "reject_proposal", + "force_close", + "arm_kill", + "disarm_kill", + ] + proposal_id: UUID | None = None + payload_json: str | None = None + created_at: datetime + consumed_at: datetime | None = None + consumed_by: str | None = None + result: str | None = None + + +class SystemStateRecord(BaseModel): + """Singleton row of the ``system_state`` table.""" + + model_config = ConfigDict(extra="forbid") + + id: int = Field(default=1) + kill_switch: int = 0 + kill_reason: str | None = None + kill_at: datetime | None = None + last_health_check: datetime + last_kelly_calib: datetime | None = None + config_version: str + started_at: datetime diff --git a/src/cerbero_bite/state/repository.py b/src/cerbero_bite/state/repository.py new file mode 100644 index 0000000..53f476e --- /dev/null +++ b/src/cerbero_bite/state/repository.py @@ -0,0 +1,553 @@ +"""Typed CRUD layer over the SQLite database. + +All methods take a :class:`sqlite3.Connection` so callers can compose a +single transaction across multiple writes (the orchestrator does this +when persisting an entry decision + the resulting proposal). The +repository never opens its own connection: that responsibility is left +to :func:`cerbero_bite.state.db.connect`. + +Decimals are stored as TEXT to preserve precision (see +``state/models.py``). +""" + +from __future__ import annotations + +import sqlite3 +from datetime import UTC, datetime +from decimal import Decimal +from typing import Any +from uuid import UUID + +from cerbero_bite.state.models import ( + DecisionRecord, + DvolSnapshot, + InstructionRecord, + ManualAction, + PositionRecord, + PositionStatus, + SystemStateRecord, +) + +__all__ = ["Repository"] + + +# --------------------------------------------------------------------------- +# Encoding helpers +# --------------------------------------------------------------------------- + + +def _enc_dec(value: Decimal | None) -> str | None: + return None if value is None else str(value) + + +def _enc_dt(value: datetime | None) -> str | None: + if value is None: + return None + if value.tzinfo is None: + value = value.replace(tzinfo=UTC) + return value.astimezone(UTC).isoformat() + + +def _enc_uuid(value: UUID | None) -> str | None: + return None if value is None else str(value) + + +def _dec_dec(value: Any) -> Decimal | None: + if value is None: + return None + return Decimal(str(value)) + + +def _dec_dt(value: Any) -> datetime | None: + if value is None: + return None + return datetime.fromisoformat(str(value)) + + +def _dec_uuid(value: Any) -> UUID | None: + if value is None: + return None + return UUID(str(value)) + + +# --------------------------------------------------------------------------- +# Repository +# --------------------------------------------------------------------------- + + +class Repository: + """Typed CRUD wrapper. One instance per process, all calls take a conn.""" + + # ------------------------------------------------------------------ + # positions + # ------------------------------------------------------------------ + + def create_position(self, conn: sqlite3.Connection, record: PositionRecord) -> None: + conn.execute( + """ + INSERT INTO positions ( + proposal_id, spread_type, asset, expiry, + short_strike, long_strike, short_instrument, long_instrument, + n_contracts, spread_width_usd, spread_width_pct, + credit_eth, credit_usd, max_loss_usd, + spot_at_entry, dvol_at_entry, delta_at_entry, eth_price_at_entry, + proposed_at, opened_at, closed_at, close_reason, + debit_paid_eth, pnl_eth, pnl_usd, + status, created_at, updated_at + ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) + """, + ( + _enc_uuid(record.proposal_id), + record.spread_type, + record.asset, + _enc_dt(record.expiry), + _enc_dec(record.short_strike), + _enc_dec(record.long_strike), + record.short_instrument, + record.long_instrument, + record.n_contracts, + _enc_dec(record.spread_width_usd), + _enc_dec(record.spread_width_pct), + _enc_dec(record.credit_eth), + _enc_dec(record.credit_usd), + _enc_dec(record.max_loss_usd), + _enc_dec(record.spot_at_entry), + _enc_dec(record.dvol_at_entry), + _enc_dec(record.delta_at_entry), + _enc_dec(record.eth_price_at_entry), + _enc_dt(record.proposed_at), + _enc_dt(record.opened_at), + _enc_dt(record.closed_at), + record.close_reason, + _enc_dec(record.debit_paid_eth), + _enc_dec(record.pnl_eth), + _enc_dec(record.pnl_usd), + record.status, + _enc_dt(record.created_at), + _enc_dt(record.updated_at), + ), + ) + + def get_position( + self, conn: sqlite3.Connection, proposal_id: UUID + ) -> PositionRecord | None: + row = conn.execute( + "SELECT * FROM positions WHERE proposal_id = ?", + (_enc_uuid(proposal_id),), + ).fetchone() + return None if row is None else _row_to_position(row) + + def list_positions( + self, + conn: sqlite3.Connection, + *, + status: PositionStatus | None = None, + ) -> list[PositionRecord]: + if status is None: + cursor = conn.execute( + "SELECT * FROM positions ORDER BY created_at DESC" + ) + else: + cursor = conn.execute( + "SELECT * FROM positions WHERE status = ? ORDER BY created_at DESC", + (status,), + ) + return [_row_to_position(r) for r in cursor.fetchall()] + + def list_open_positions(self, conn: sqlite3.Connection) -> list[PositionRecord]: + rows = conn.execute( + "SELECT * FROM positions WHERE status IN ('open', 'awaiting_fill', " + "'closing') ORDER BY created_at DESC" + ).fetchall() + return [_row_to_position(r) for r in rows] + + def update_position_status( + self, + conn: sqlite3.Connection, + proposal_id: UUID, + *, + status: PositionStatus, + opened_at: datetime | None = None, + closed_at: datetime | None = None, + close_reason: str | None = None, + debit_paid_eth: Decimal | None = None, + pnl_eth: Decimal | None = None, + pnl_usd: Decimal | None = None, + now: datetime, + ) -> None: + sets: list[str] = ["status = ?", "updated_at = ?"] + params: list[Any] = [status, _enc_dt(now)] + if opened_at is not None: + sets.append("opened_at = ?") + params.append(_enc_dt(opened_at)) + if closed_at is not None: + sets.append("closed_at = ?") + params.append(_enc_dt(closed_at)) + if close_reason is not None: + sets.append("close_reason = ?") + params.append(close_reason) + if debit_paid_eth is not None: + sets.append("debit_paid_eth = ?") + params.append(_enc_dec(debit_paid_eth)) + if pnl_eth is not None: + sets.append("pnl_eth = ?") + params.append(_enc_dec(pnl_eth)) + if pnl_usd is not None: + sets.append("pnl_usd = ?") + params.append(_enc_dec(pnl_usd)) + params.append(_enc_uuid(proposal_id)) + conn.execute( + f"UPDATE positions SET {', '.join(sets)} WHERE proposal_id = ?", + params, + ) + + def count_concurrent_positions(self, conn: sqlite3.Connection) -> int: + row = conn.execute( + "SELECT COUNT(*) FROM positions WHERE status IN " + "('awaiting_fill', 'open', 'closing')" + ).fetchone() + return int(row[0]) + + # ------------------------------------------------------------------ + # instructions + # ------------------------------------------------------------------ + + def create_instruction( + self, conn: sqlite3.Connection, record: InstructionRecord + ) -> None: + conn.execute( + """ + INSERT INTO instructions ( + instruction_id, proposal_id, kind, payload_json, + sent_at, acknowledged_at, filled_at, cancelled_at, + actual_fill_eth, actual_fees_eth + ) VALUES (?,?,?,?,?,?,?,?,?,?) + """, + ( + _enc_uuid(record.instruction_id), + _enc_uuid(record.proposal_id), + record.kind, + record.payload_json, + _enc_dt(record.sent_at), + _enc_dt(record.acknowledged_at), + _enc_dt(record.filled_at), + _enc_dt(record.cancelled_at), + _enc_dec(record.actual_fill_eth), + _enc_dec(record.actual_fees_eth), + ), + ) + + def update_instruction( + self, + conn: sqlite3.Connection, + instruction_id: UUID, + *, + acknowledged_at: datetime | None = None, + filled_at: datetime | None = None, + cancelled_at: datetime | None = None, + actual_fill_eth: Decimal | None = None, + actual_fees_eth: Decimal | None = None, + ) -> None: + sets: list[str] = [] + params: list[Any] = [] + if acknowledged_at is not None: + sets.append("acknowledged_at = ?") + params.append(_enc_dt(acknowledged_at)) + if filled_at is not None: + sets.append("filled_at = ?") + params.append(_enc_dt(filled_at)) + if cancelled_at is not None: + sets.append("cancelled_at = ?") + params.append(_enc_dt(cancelled_at)) + if actual_fill_eth is not None: + sets.append("actual_fill_eth = ?") + params.append(_enc_dec(actual_fill_eth)) + if actual_fees_eth is not None: + sets.append("actual_fees_eth = ?") + params.append(_enc_dec(actual_fees_eth)) + if not sets: + return + params.append(_enc_uuid(instruction_id)) + conn.execute( + f"UPDATE instructions SET {', '.join(sets)} " + f"WHERE instruction_id = ?", + params, + ) + + def list_instructions( + self, conn: sqlite3.Connection, proposal_id: UUID + ) -> list[InstructionRecord]: + rows = conn.execute( + "SELECT * FROM instructions WHERE proposal_id = ? ORDER BY sent_at", + (_enc_uuid(proposal_id),), + ).fetchall() + return [_row_to_instruction(r) for r in rows] + + # ------------------------------------------------------------------ + # decisions + # ------------------------------------------------------------------ + + def record_decision( + self, conn: sqlite3.Connection, record: DecisionRecord + ) -> int: + cursor = conn.execute( + """ + INSERT INTO decisions ( + decision_type, proposal_id, timestamp, + inputs_json, outputs_json, action_taken, notes + ) VALUES (?,?,?,?,?,?,?) + """, + ( + record.decision_type, + _enc_uuid(record.proposal_id), + _enc_dt(record.timestamp), + record.inputs_json, + record.outputs_json, + record.action_taken, + record.notes, + ), + ) + return int(cursor.lastrowid or 0) + + def list_decisions( + self, + conn: sqlite3.Connection, + *, + proposal_id: UUID | None = None, + limit: int = 100, + ) -> list[DecisionRecord]: + if proposal_id is None: + rows = conn.execute( + "SELECT * FROM decisions ORDER BY timestamp DESC LIMIT ?", + (limit,), + ).fetchall() + else: + rows = conn.execute( + "SELECT * FROM decisions WHERE proposal_id = ? " + "ORDER BY timestamp DESC LIMIT ?", + (_enc_uuid(proposal_id), limit), + ).fetchall() + return [_row_to_decision(r) for r in rows] + + # ------------------------------------------------------------------ + # dvol_history + # ------------------------------------------------------------------ + + def record_dvol_snapshot( + self, conn: sqlite3.Connection, snapshot: DvolSnapshot + ) -> None: + conn.execute( + "INSERT OR REPLACE INTO dvol_history(timestamp, dvol, eth_spot) " + "VALUES (?,?,?)", + ( + _enc_dt(snapshot.timestamp), + _enc_dec(snapshot.dvol), + _enc_dec(snapshot.eth_spot), + ), + ) + + # ------------------------------------------------------------------ + # manual_actions + # ------------------------------------------------------------------ + + def enqueue_manual_action( + self, conn: sqlite3.Connection, action: ManualAction + ) -> int: + cursor = conn.execute( + "INSERT INTO manual_actions(kind, proposal_id, payload_json, created_at) " + "VALUES (?,?,?,?)", + ( + action.kind, + _enc_uuid(action.proposal_id), + action.payload_json, + _enc_dt(action.created_at), + ), + ) + return int(cursor.lastrowid or 0) + + def next_unconsumed_action( + self, conn: sqlite3.Connection + ) -> ManualAction | None: + row = conn.execute( + "SELECT * FROM manual_actions WHERE consumed_at IS NULL " + "ORDER BY created_at ASC LIMIT 1" + ).fetchone() + return None if row is None else _row_to_manual(row) + + def mark_action_consumed( + self, + conn: sqlite3.Connection, + action_id: int, + *, + consumed_by: str, + result: str, + now: datetime, + ) -> None: + conn.execute( + "UPDATE manual_actions SET consumed_at = ?, consumed_by = ?, " + "result = ? WHERE id = ?", + (_enc_dt(now), consumed_by, result, action_id), + ) + + # ------------------------------------------------------------------ + # system_state + # ------------------------------------------------------------------ + + def init_system_state( + self, conn: sqlite3.Connection, *, config_version: str, now: datetime + ) -> None: + """Insert the singleton row if it does not already exist.""" + existing = conn.execute( + "SELECT 1 FROM system_state WHERE id = 1" + ).fetchone() + if existing is not None: + return + conn.execute( + "INSERT INTO system_state(id, kill_switch, last_health_check, " + "config_version, started_at) VALUES (1, 0, ?, ?, ?)", + (_enc_dt(now), config_version, _enc_dt(now)), + ) + + def get_system_state( + self, conn: sqlite3.Connection + ) -> SystemStateRecord | None: + row = conn.execute("SELECT * FROM system_state WHERE id = 1").fetchone() + if row is None: + return None + return SystemStateRecord( + id=int(row["id"]), + kill_switch=int(row["kill_switch"]), + kill_reason=row["kill_reason"], + kill_at=_dec_dt(row["kill_at"]), + last_health_check=_dec_dt_required(row["last_health_check"]), + last_kelly_calib=_dec_dt(row["last_kelly_calib"]), + config_version=row["config_version"], + started_at=_dec_dt_required(row["started_at"]), + ) + + def set_kill_switch( + self, + conn: sqlite3.Connection, + *, + armed: bool, + reason: str | None, + now: datetime, + ) -> None: + conn.execute( + "UPDATE system_state SET kill_switch = ?, kill_reason = ?, " + "kill_at = ?, last_health_check = ? WHERE id = 1", + ( + 1 if armed else 0, + reason, + _enc_dt(now) if armed else None, + _enc_dt(now), + ), + ) + + def touch_health_check( + self, conn: sqlite3.Connection, *, now: datetime + ) -> None: + conn.execute( + "UPDATE system_state SET last_health_check = ? WHERE id = 1", + (_enc_dt(now),), + ) + + +# --------------------------------------------------------------------------- +# Row → model converters +# --------------------------------------------------------------------------- + + +def _dec_dt_required(value: Any) -> datetime: + out = _dec_dt(value) + if out is None: + raise ValueError("expected non-null datetime in row") + return out + + +def _row_to_position(row: sqlite3.Row) -> PositionRecord: + proposal_id = _dec_uuid(row["proposal_id"]) + if proposal_id is None: + raise ValueError("positions.proposal_id was NULL") + return PositionRecord( + proposal_id=proposal_id, + spread_type=row["spread_type"], + asset=row["asset"], + expiry=_dec_dt_required(row["expiry"]), + short_strike=_dec_dec_required(row["short_strike"]), + long_strike=_dec_dec_required(row["long_strike"]), + short_instrument=row["short_instrument"], + long_instrument=row["long_instrument"], + n_contracts=int(row["n_contracts"]), + spread_width_usd=_dec_dec_required(row["spread_width_usd"]), + spread_width_pct=_dec_dec_required(row["spread_width_pct"]), + credit_eth=_dec_dec_required(row["credit_eth"]), + credit_usd=_dec_dec_required(row["credit_usd"]), + max_loss_usd=_dec_dec_required(row["max_loss_usd"]), + spot_at_entry=_dec_dec_required(row["spot_at_entry"]), + dvol_at_entry=_dec_dec_required(row["dvol_at_entry"]), + delta_at_entry=_dec_dec_required(row["delta_at_entry"]), + eth_price_at_entry=_dec_dec_required(row["eth_price_at_entry"]), + proposed_at=_dec_dt_required(row["proposed_at"]), + opened_at=_dec_dt(row["opened_at"]), + closed_at=_dec_dt(row["closed_at"]), + close_reason=row["close_reason"], + debit_paid_eth=_dec_dec(row["debit_paid_eth"]), + pnl_eth=_dec_dec(row["pnl_eth"]), + pnl_usd=_dec_dec(row["pnl_usd"]), + status=row["status"], + created_at=_dec_dt_required(row["created_at"]), + updated_at=_dec_dt_required(row["updated_at"]), + ) + + +def _row_to_instruction(row: sqlite3.Row) -> InstructionRecord: + instruction_id = _dec_uuid(row["instruction_id"]) + proposal_id = _dec_uuid(row["proposal_id"]) + if instruction_id is None or proposal_id is None: + raise ValueError("instructions row missing required UUID") + return InstructionRecord( + instruction_id=instruction_id, + proposal_id=proposal_id, + kind=row["kind"], + payload_json=row["payload_json"], + sent_at=_dec_dt_required(row["sent_at"]), + acknowledged_at=_dec_dt(row["acknowledged_at"]), + filled_at=_dec_dt(row["filled_at"]), + cancelled_at=_dec_dt(row["cancelled_at"]), + actual_fill_eth=_dec_dec(row["actual_fill_eth"]), + actual_fees_eth=_dec_dec(row["actual_fees_eth"]), + ) + + +def _row_to_decision(row: sqlite3.Row) -> DecisionRecord: + return DecisionRecord( + id=int(row["id"]), + decision_type=row["decision_type"], + proposal_id=_dec_uuid(row["proposal_id"]), + timestamp=_dec_dt_required(row["timestamp"]), + inputs_json=row["inputs_json"], + outputs_json=row["outputs_json"], + action_taken=row["action_taken"], + notes=row["notes"], + ) + + +def _row_to_manual(row: sqlite3.Row) -> ManualAction: + return ManualAction( + id=int(row["id"]), + kind=row["kind"], + proposal_id=_dec_uuid(row["proposal_id"]), + payload_json=row["payload_json"], + created_at=_dec_dt_required(row["created_at"]), + consumed_at=_dec_dt(row["consumed_at"]), + consumed_by=row["consumed_by"], + result=row["result"], + ) + + +def _dec_dec_required(value: Any) -> Decimal: + out = _dec_dec(value) + if out is None: + raise ValueError("expected non-null Decimal in row") + return out diff --git a/strategy.local.yaml.example b/strategy.local.yaml.example new file mode 100644 index 0000000..dc71355 --- /dev/null +++ b/strategy.local.yaml.example @@ -0,0 +1,15 @@ +# strategy.local.yaml — local override (gitignored). +# +# Copy to strategy.local.yaml and edit only the keys you need to +# change. Top-level sections are deep-merged onto strategy.yaml at +# load time; the merged result is logged as OVERRIDE_APPLIED. +# +# Typical use cases: +# * Halve cap_per_trade in dry-run. +# * Force max_concurrent_positions to 0 to freeze entries without +# stopping the engine. +# * Lower kelly_fraction temporarily after a drawdown. + +# Example: emergency entry freeze. +# sizing: +# max_concurrent_positions: 0 diff --git a/strategy.yaml b/strategy.yaml new file mode 100644 index 0000000..8204ef8 --- /dev/null +++ b/strategy.yaml @@ -0,0 +1,149 @@ +# strategy.yaml — Cerberus Bite golden config v1.0.0 +# +# Source of truth for every threshold consumed by the rule engine. +# Modifying this file is an explicit decision of Adriano. Each change +# bumps `config_version`, regenerates `config_hash` (cerbero-bite +# config hash), and lands as a separate commit with the motivation in +# the commit message. + +config_version: "1.0.0" +config_hash: "a857dc4b187cbdf5ac3f04c4aad48ab7587659bc9a3139db206566e10e2fa5e5" +last_review: "2026-04-26" +last_reviewer: "Adriano" + +asset: + symbol: "ETH" + exchange: "deribit" + +entry: + cron: "0 14 * * MON" + skip_holidays_country: "IT" + + capital_min_usd: "720" + dvol_min: "35" + dvol_max: "90" + funding_perp_abs_max_annualized: "0.80" + eth_holdings_pct_max: "0.30" + no_position_concurrent: true + exclude_macro_severity: ["high"] + exclude_macro_countries: ["US", "EU"] + + trend_window_days: 30 + trend_bull_threshold_pct: "0.05" + trend_bear_threshold_pct: "-0.05" + funding_bull_threshold_annualized: "0.20" + funding_bear_threshold_annualized: "-0.20" + iron_condor_dvol_min: "55" + iron_condor_adx_max: "20" + iron_condor_trend_neutral_band_pct: "0.05" + +structure: + dte_target: 18 + dte_min: 14 + dte_max: 21 + + short_strike: + delta_target: "0.12" + delta_min: "0.10" + delta_max: "0.15" + distance_otm_pct_min: "0.15" + distance_otm_pct_max: "0.25" + + spread_width: + target_pct_of_spot: "0.04" + min_pct_of_spot: "0.03" + max_pct_of_spot: "0.05" + + credit_to_width_ratio_min: "0.30" + +liquidity: + open_interest_min: 100 + volume_24h_min: 20 + bid_ask_spread_pct_max: "0.15" + book_depth_top3_min: 5 + slippage_pct_of_credit_max: "0.08" + +sizing: + kelly_fraction: "0.13" + + cap_per_trade_eur: "200" + cap_aggregate_open_eur: "1000" + max_concurrent_positions: 1 + + max_contracts_per_trade: 4 + + dvol_adjustment: + - {dvol_under: "45", multiplier: "1.00"} + - {dvol_under: "60", multiplier: "0.85"} + - {dvol_under: "80", multiplier: "0.65"} + dvol_no_entry_threshold: "80" + +exit: + profit_take_pct_of_credit: "0.50" + stop_loss_mark_x_credit: "2.50" + vol_stop_dvol_increase: "10" + time_stop_dte_remaining: 7 + time_stop_skip_if_close_to_profit_pct: "0.70" + delta_breach_threshold: "0.30" + adverse_move_4h_pct: "0.05" + + monitor_cron: "0 2,14 * * *" + user_confirmation_timeout_min: 30 + + escalate_on_timeout: + - "CLOSE_STOP" + - "CLOSE_VOL" + - "CLOSE_DELTA" + +execution: + combo_only: true + initial_limit: "mid" + reprice_step_ticks: 1 + reprice_max_steps: 3 + reprice_max_steps_urgent: 5 + order_tif: "GTC" + order_expiry_min: 30 + ack_timeout_s: 300 + +monitoring: + health_check_interval_s: 300 + health_failures_before_kill: 3 + health_failures_before_restart: 5 + + daily_digest_cron: "0 8 * * *" + monthly_report_cron: "0 12 1 * *" + +storage: + sqlite_path: "data/state.sqlite" + log_path: "data/log/" + log_retention_days: 365 + backup_path: "data/backups/" + backup_retention_days: 30 + +mcp: + config_file: "~/.config/cerbero-suite/mcp.json" + call_timeout_s: 8 + retry_max: 3 + retry_base_delay_s: 1 + + required_versions: + cerbero-deribit: "^2.0.0" + cerbero-hyperliquid: "^1.5.0" + cerbero-memory: "^4.0.0" + cerbero-portfolio: "^1.2.0" + cerbero-macro: "^1.0.0" + cerbero-sentiment: "^1.0.0" + cerbero-telegram: "^1.0.0" + cerbero-brain-bridge: "^1.0.0" + +telegram: + parse_mode: "MarkdownV2" + confirmation_timeout_min: 60 + exit_confirmation_timeout_min: 30 + backup_channel_on_critical: true + +kelly_recalibration: + lookback_days: 365 + min_sample_low_confidence: 30 + min_sample_high_confidence: 100 + weight_when_medium_confidence: "0.50" diff --git a/tests/integration/test_dead_man_sh.py b/tests/integration/test_dead_man_sh.py new file mode 100644 index 0000000..7bbcf5e --- /dev/null +++ b/tests/integration/test_dead_man_sh.py @@ -0,0 +1,111 @@ +"""Smoke tests for the dead_man.sh shell watchdog.""" + +from __future__ import annotations + +import shutil +import sqlite3 +import subprocess +from datetime import UTC, datetime, timedelta +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[2] +SCRIPT = REPO_ROOT / "scripts" / "dead_man.sh" + +_SQLITE3_BIN = shutil.which("sqlite3") + +pytestmark = [ + pytest.mark.integration, +] + + +def _require_sqlite3() -> None: + if _SQLITE3_BIN is None: + pytest.skip("sqlite3 CLI not installed on this host") + + +def _setup_project(tmp_path: Path) -> Path: + project = tmp_path / "project" + (project / "data" / "log").mkdir(parents=True) + (project / "scripts").mkdir(parents=True) + shutil.copy(SCRIPT, project / "scripts" / "dead_man.sh") + return project + + +def _write_health(project: Path, ts: datetime) -> None: + log_file = project / "data" / "log" / f"cerbero-bite-{ts:%Y-%m-%d}.jsonl" + log_file.write_text( + f'{{"ts": "{ts.astimezone(UTC).isoformat()}", "event": "HEALTH_OK"}}\n', + encoding="utf-8", + ) + + +def _run(project: Path, threshold: int = 900) -> subprocess.CompletedProcess[str]: + return subprocess.run( + ["bash", str(project / "scripts" / "dead_man.sh")], + env={ + "PATH": "/usr/bin:/bin", + "PROJECT_ROOT": str(project), + "DEAD_MAN_THRESHOLD_SECONDS": str(threshold), + }, + capture_output=True, + text=True, + check=False, + ) + + +def test_dead_man_exits_zero_when_recent_health_ok(tmp_path: Path) -> None: + project = _setup_project(tmp_path) + _write_health(project, datetime.now(UTC) - timedelta(seconds=60)) + result = _run(project, threshold=900) + assert result.returncode == 0, result.stderr + + +def test_dead_man_arms_kill_switch_when_silent(tmp_path: Path) -> None: + _require_sqlite3() + project = _setup_project(tmp_path) + _write_health(project, datetime.now(UTC) - timedelta(seconds=2000)) + + # Pre-create the SQLite system_state singleton; otherwise the script + # has nothing to update. + db = project / "data" / "state.sqlite" + conn = sqlite3.connect(str(db)) + conn.execute( + "CREATE TABLE system_state (id INTEGER PRIMARY KEY CHECK(id=1), " + "kill_switch INTEGER NOT NULL DEFAULT 0, kill_reason TEXT, " + "kill_at TEXT, last_health_check TEXT NOT NULL, " + "last_kelly_calib TEXT, config_version TEXT NOT NULL, " + "started_at TEXT NOT NULL)" + ) + conn.execute( + "INSERT INTO system_state(id, last_health_check, config_version, started_at) " + "VALUES (1, '2026-04-27', '1.0.0', '2026-04-27')" + ) + conn.commit() + conn.close() + + result = _run(project, threshold=900) + assert result.returncode == 1 + + # Kill switch armed. + conn = sqlite3.connect(str(db)) + try: + kill = conn.execute("SELECT kill_switch FROM system_state").fetchone()[0] + finally: + conn.close() + assert kill == 1 + + # Alert file exists. + alert = project / "data" / "log" / "dead-man-alert.txt" + assert alert.exists() + assert "dead_man" in alert.read_text(encoding="utf-8") + + +def test_dead_man_handles_missing_log_file(tmp_path: Path) -> None: + project = _setup_project(tmp_path) + # No log file at all. + result = _run(project, threshold=900) + assert result.returncode == 1 + alert = project / "data" / "log" / "dead-man-alert.txt" + assert alert.exists() diff --git a/tests/unit/test_audit_log.py b/tests/unit/test_audit_log.py new file mode 100644 index 0000000..ce306d9 --- /dev/null +++ b/tests/unit/test_audit_log.py @@ -0,0 +1,273 @@ +"""Audit chain writer + verifier tests.""" + +from __future__ import annotations + +import hashlib +from datetime import UTC, datetime, timedelta +from pathlib import Path + +import pytest + +from cerbero_bite.safety.audit_log import ( + GENESIS_HASH, + AuditChainError, + AuditLog, + iter_entries, + verify_chain, +) + + +def test_empty_file_verifies_with_zero_entries(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + assert verify_chain(path) == 0 + + +def test_first_entry_uses_genesis_prev_hash(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + log = AuditLog(path) + entry = log.append( + event="ENGINE_START", + payload={"version": "1.0.0"}, + now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + ) + assert entry.prev_hash == GENESIS_HASH + assert entry.hash != GENESIS_HASH + + +def test_chain_links_subsequent_entries(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + log = AuditLog(path) + e1 = log.append(event="A", payload={"i": 1}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)) + e2 = log.append(event="B", payload={"i": 2}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC)) + e3 = log.append(event="C", payload={"i": 3}, now=datetime(2026, 4, 27, 14, 2, tzinfo=UTC)) + assert e2.prev_hash == e1.hash + assert e3.prev_hash == e2.hash + assert verify_chain(path) == 3 + + +def test_iter_entries_yields_in_order(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + log = AuditLog(path) + log.append(event="A", payload={"i": 1}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)) + log.append(event="B", payload={"i": 2}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC)) + events = [e.event for e in iter_entries(path)] + assert events == ["A", "B"] + + +def test_log_resumes_chain_after_reopen(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + first = AuditLog(path) + e1 = first.append( + event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC) + ) + + second = AuditLog(path) + assert second.last_hash == e1.hash + e2 = second.append( + event="B", payload={"k": "v"}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC) + ) + assert e2.prev_hash == e1.hash + assert verify_chain(path) == 2 + + +def test_payload_with_pipe_character_round_trips(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + log = AuditLog(path) + log.append( + event="NOTE", + payload={"text": "first|second|third"}, + now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + ) + entries = list(iter_entries(path)) + assert entries[0].payload == {"text": "first|second|third"} + assert verify_chain(path) == 1 + + +def test_tampered_payload_breaks_chain(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + log = AuditLog(path) + log.append(event="A", payload={"i": 1}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)) + log.append(event="B", payload={"i": 2}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC)) + + # Mutate the first line's payload by hand. + text = path.read_text(encoding="utf-8").splitlines() + text[0] = text[0].replace('"i":1', '"i":99') + path.write_text("\n".join(text) + "\n", encoding="utf-8") + + with pytest.raises(AuditChainError, match="hash mismatch"): + verify_chain(path) + + +def test_verify_chain_skips_blank_lines(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + log = AuditLog(path) + log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)) + raw = path.read_text(encoding="utf-8") + path.write_text("\n" + raw + "\n \n", encoding="utf-8") + # The chain still verifies despite the surrounding whitespace lines. + assert verify_chain(path) == 1 + + +def test_prev_hash_mismatch_between_entries_is_caught(tmp_path: Path) -> None: + """Second line's prev_hash points to a different chain — verify_chain rejects.""" + path = tmp_path / "audit.log" + log = AuditLog(path) + e1 = log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)) + + # Build a synthetic second line whose prev_hash != e1.hash but whose + # own hash is correctly computed from that bogus prev_hash. + fake_prev = "0" * 32 + "f" * 32 + ts2 = "2026-04-27T14:01:00+00:00" + payload_json = "{}" + raw = f"{ts2}|B|{payload_json}|{fake_prev}" + fake_hash = hashlib.sha256(raw.encode()).hexdigest() + line = f"{ts2}|B|{payload_json}|prev_hash={fake_prev}|hash={fake_hash}\n" + with path.open("a", encoding="utf-8") as fh: + fh.write(line) + + assert e1.hash != fake_prev # sanity + with pytest.raises(AuditChainError, match="prev_hash mismatch"): + verify_chain(path) + + +def test_tampered_prev_hash_breaks_chain(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + log = AuditLog(path) + log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)) + log.append(event="B", payload={}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC)) + + # Inject an unrelated prev_hash on the second line. + lines = path.read_text(encoding="utf-8").splitlines() + lines[1] = lines[1].replace("prev_hash=", "prev_hash=" + "f" * 64 + "X") + # Truncate to recover length: replace prev_hash field with all-ff. + lines[1] = lines[1].replace("X", "") + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + + with pytest.raises(AuditChainError): + verify_chain(path) + + +def test_malformed_line_raises_chain_error(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + path.write_text("not-a-valid-line\n", encoding="utf-8") + with pytest.raises(AuditChainError): + verify_chain(path) + + +def test_parser_rejects_missing_hash_field(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + path.write_text( + "2026-04-27T14:00:00+00:00|EVT|{}|prev_hash=" + "0" * 64 + "\n", + encoding="utf-8", + ) + with pytest.raises(AuditChainError, match="hash="): + verify_chain(path) + + +def test_parser_rejects_missing_prev_hash_field(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + path.write_text( + "2026-04-27T14:00:00+00:00|EVT|{}|hash=" + "f" * 64 + "\n", + encoding="utf-8", + ) + with pytest.raises(AuditChainError, match="prev_hash"): + verify_chain(path) + + +def test_parser_rejects_line_with_no_separators(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + path.write_text("just-a-blob|hash=" + "f" * 64 + "\n", encoding="utf-8") + with pytest.raises(AuditChainError, match="prev_hash"): + verify_chain(path) + + +def test_parser_rejects_malformed_leading_section(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + # Two `|` only: rsplit succeeds twice, leading parts has 1 element ≠ 3. + path.write_text( + "tooshort|prev_hash=" + "0" * 64 + "|hash=" + "f" * 64 + "\n", + encoding="utf-8", + ) + with pytest.raises(AuditChainError, match="leading section"): + verify_chain(path) + + +def test_parser_rejects_payload_not_a_json_object(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + path.write_text( + "2026-04-27T14:00:00+00:00|EVT|[1,2]|prev_hash=" + + "0" * 64 + + "|hash=" + + "f" * 64 + + "\n", + encoding="utf-8", + ) + with pytest.raises(AuditChainError, match="JSON object"): + verify_chain(path) + + +def test_parser_rejects_payload_with_invalid_json(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + path.write_text( + "2026-04-27T14:00:00+00:00|EVT|{not-json}|prev_hash=" + + "0" * 64 + + "|hash=" + + "f" * 64 + + "\n", + encoding="utf-8", + ) + with pytest.raises(AuditChainError, match="JSON"): + verify_chain(path) + + +def test_iter_entries_returns_empty_when_file_missing(tmp_path: Path) -> None: + path = tmp_path / "missing.log" + assert list(iter_entries(path)) == [] + + +def test_iter_entries_skips_blank_lines(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + log = AuditLog(path) + log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)) + raw = path.read_text(encoding="utf-8") + path.write_text(raw + "\n\n", encoding="utf-8") + entries = list(iter_entries(path)) + assert len(entries) == 1 + + +def test_log_resumes_chain_with_large_file(tmp_path: Path) -> None: + """Tail-seek reads past the 4096-byte chunk boundary.""" + path = tmp_path / "audit.log" + log = AuditLog(path) + base = datetime(2026, 4, 27, 14, 0, tzinfo=UTC) + # Each line ~150 chars; 50 lines is comfortably > 4096 bytes. + for i in range(50): + log.append( + event=f"E{i}", + payload={"i": i, "filler": "x" * 80}, + now=base + timedelta(seconds=i), + ) + + last_hash = log.last_hash + reopened = AuditLog(path) + assert reopened.last_hash == last_hash + assert verify_chain(path) == 50 + + +def test_payload_serialisation_is_canonical(tmp_path: Path) -> None: + path = tmp_path / "audit.log" + log = AuditLog(path) + # Different key order must produce identical hashes. + e1 = log.append( + event="A", + payload={"b": 1, "a": 2}, + now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + ) + other = tmp_path / "audit_other.log" + log2 = AuditLog(other) + e2 = log2.append( + event="A", + payload={"a": 2, "b": 1}, + now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + ) + assert e1.hash == e2.hash diff --git a/tests/unit/test_backup.py b/tests/unit/test_backup.py new file mode 100644 index 0000000..b1a840b --- /dev/null +++ b/tests/unit/test_backup.py @@ -0,0 +1,126 @@ +"""Tests for scripts.backup.""" + +from __future__ import annotations + +import importlib.util +import sys +from datetime import UTC, datetime, timedelta +from pathlib import Path + +import pytest + +from cerbero_bite.state import connect, run_migrations + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def _load_backup_module() -> object: + """Load scripts/backup.py as a module without polluting sys.path.""" + spec = importlib.util.spec_from_file_location( + "_cerbero_bite_backup", REPO_ROOT / "scripts" / "backup.py" + ) + assert spec is not None and spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +@pytest.fixture +def backup_mod() -> object: + return _load_backup_module() + + +def test_backup_database_creates_snapshot(tmp_path: Path, backup_mod: object) -> None: + db = tmp_path / "state.sqlite" + conn = connect(db) + try: + run_migrations(conn) + conn.execute( + "INSERT INTO system_state(id, last_health_check, config_version, started_at) " + "VALUES (1, '2026-04-27', '1.0.0', '2026-04-27')" + ) + finally: + conn.close() + + backup_dir = tmp_path / "backups" + snapshot = backup_mod.backup_database( # type: ignore[attr-defined] + db_path=db, + backup_dir=backup_dir, + now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + ) + assert snapshot.exists() + assert snapshot.name == "state-20260427-14.sqlite" + + # Snapshot is itself a valid SQLite db with the same row. + snap = connect(snapshot) + try: + rows = snap.execute( + "SELECT config_version FROM system_state WHERE id = 1" + ).fetchone() + finally: + snap.close() + assert rows[0] == "1.0.0" + + +def test_backup_database_replaces_existing_hour_snapshot( + tmp_path: Path, backup_mod: object +) -> None: + db = tmp_path / "state.sqlite" + conn = connect(db) + try: + run_migrations(conn) + finally: + conn.close() + + when = datetime(2026, 4, 27, 14, 0, tzinfo=UTC) + first = backup_mod.backup_database(db_path=db, backup_dir=tmp_path / "b", now=when) # type: ignore[attr-defined] + second = backup_mod.backup_database(db_path=db, backup_dir=tmp_path / "b", now=when) # type: ignore[attr-defined] + assert first == second + files = list((tmp_path / "b").iterdir()) + assert len(files) == 1 + + +def test_prune_backups_removes_old_files(tmp_path: Path, backup_mod: object) -> None: + backup_dir = tmp_path / "b" + backup_dir.mkdir() + fresh = backup_dir / "state-20260420-10.sqlite" + stale = backup_dir / "state-20260101-12.sqlite" + other = backup_dir / "unrelated.txt" + fresh.touch() + stale.touch() + other.touch() + + deleted = backup_mod.prune_backups( # type: ignore[attr-defined] + backup_dir, + retention_days=30, + now=datetime(2026, 4, 27, tzinfo=UTC), + ) + assert deleted == [stale] + assert fresh.exists() + assert other.exists() + + +def test_prune_backups_ignores_unparseable_filenames( + tmp_path: Path, backup_mod: object +) -> None: + backup_dir = tmp_path / "b" + backup_dir.mkdir() + (backup_dir / "state-bogus-XX.sqlite").touch() + deleted = backup_mod.prune_backups( # type: ignore[attr-defined] + backup_dir, + retention_days=0, + now=datetime(2026, 4, 27, tzinfo=UTC) + timedelta(days=10000), + ) + assert deleted == [] + + +def test_list_backups_returns_sorted(tmp_path: Path, backup_mod: object) -> None: + backup_dir = tmp_path / "b" + backup_dir.mkdir() + a = backup_dir / "state-20260103-08.sqlite" + b = backup_dir / "state-20260101-08.sqlite" + a.touch() + b.touch() + listed = list(backup_mod.list_backups(backup_dir)) # type: ignore[attr-defined] + assert listed == [b, a] diff --git a/tests/unit/test_cli_safety.py b/tests/unit/test_cli_safety.py new file mode 100644 index 0000000..847e7d9 --- /dev/null +++ b/tests/unit/test_cli_safety.py @@ -0,0 +1,160 @@ +"""End-to-end CLI tests for the Phase 2 commands. + +The commands hit real on-disk paths (tmp_path) so the assertions run +the production code paths verbatim. +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from pathlib import Path + +from click.testing import CliRunner + +from cerbero_bite.cli import main as cli_main +from cerbero_bite.safety import AuditLog +from cerbero_bite.state import Repository, connect, run_migrations, transaction + + +def _seed_state(db_path: Path) -> None: + conn = connect(db_path) + try: + run_migrations(conn) + with transaction(conn): + Repository().init_system_state( + conn, + config_version="1.0.0", + now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + ) + finally: + conn.close() + + +def test_audit_verify_reports_ok_on_clean_chain(tmp_path: Path) -> None: + audit = AuditLog(tmp_path / "audit.log") + audit.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)) + audit.append(event="B", payload={}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC)) + + result = CliRunner().invoke( + cli_main, ["audit", "verify", "--file", str(tmp_path / "audit.log")] + ) + assert result.exit_code == 0, result.output + assert "ok" in result.output + assert "2" in result.output + + +def test_audit_verify_handles_empty_file(tmp_path: Path) -> None: + target = tmp_path / "audit.log" + target.write_text("", encoding="utf-8") + result = CliRunner().invoke(cli_main, ["audit", "verify", "--file", str(target)]) + assert result.exit_code == 0 + assert "empty" in result.output + + +def test_audit_verify_exits_nonzero_on_tampering(tmp_path: Path) -> None: + target = tmp_path / "audit.log" + audit = AuditLog(target) + audit.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)) + target.write_text( + target.read_text(encoding="utf-8").replace('"event":"A"', '"event":"X"'), + encoding="utf-8", + ) + # NB: we mutated the JSON payload, but the actual line still has event=A. + # Force tampering by editing the literal "A" in the line text. + raw = target.read_text(encoding="utf-8") + target.write_text(raw.replace("|A|", "|X|", 1), encoding="utf-8") + + result = CliRunner().invoke(cli_main, ["audit", "verify", "--file", str(target)]) + assert result.exit_code == 2 + assert "TAMPERED" in result.output + + +def test_kill_switch_status_prints_disarmed(tmp_path: Path) -> None: + db = tmp_path / "state.sqlite" + _seed_state(db) + result = CliRunner().invoke(cli_main, ["kill-switch", "status", "--db", str(db)]) + assert result.exit_code == 0 + assert "disarmed" in result.output + + +def test_kill_switch_arm_then_status_shows_armed(tmp_path: Path) -> None: + db = tmp_path / "state.sqlite" + audit = tmp_path / "audit.log" + runner = CliRunner() + arm = runner.invoke( + cli_main, + [ + "kill-switch", + "arm", + "--reason", + "manual smoke", + "--db", + str(db), + "--audit", + str(audit), + ], + ) + assert arm.exit_code == 0, arm.output + status = runner.invoke(cli_main, ["kill-switch", "status", "--db", str(db)]) + assert status.exit_code == 0 + assert "ARMED" in status.output + + +def test_kill_switch_status_handles_missing_db(tmp_path: Path) -> None: + result = CliRunner().invoke( + cli_main, ["kill-switch", "status", "--db", str(tmp_path / "absent.sqlite")] + ) + assert result.exit_code == 0 + assert "not found" in result.output + + +def test_state_inspect_shows_no_open_positions(tmp_path: Path) -> None: + db = tmp_path / "state.sqlite" + _seed_state(db) + result = CliRunner().invoke(cli_main, ["state", "inspect", "--db", str(db)]) + assert result.exit_code == 0 + assert "no open positions" in result.output + + +def test_state_inspect_handles_missing_db(tmp_path: Path) -> None: + result = CliRunner().invoke( + cli_main, ["state", "inspect", "--db", str(tmp_path / "absent.sqlite")] + ) + assert result.exit_code == 0 + assert "not found" in result.output + + +def test_config_hash_matches_loader(tmp_path: Path) -> None: + target = tmp_path / "strategy.yaml" + target.write_text( + 'config_version: "1.0.0"\nconfig_hash: "0000"\nasset:\n symbol: ETH\n', + encoding="utf-8", + ) + result = CliRunner().invoke(cli_main, ["config", "hash", "--file", str(target)]) + assert result.exit_code == 0 + assert len(result.output.strip()) == 64 # sha256 hex + + +def test_config_validate_repo_strategy_yaml() -> None: + repo_root = Path(__file__).resolve().parents[2] + yaml_path = repo_root / "strategy.yaml" + result = CliRunner().invoke( + cli_main, ["config", "validate", "--file", str(yaml_path)] + ) + assert result.exit_code == 0 + assert "ok" in result.output + + +def test_config_validate_with_no_enforce_hash_skips_check(tmp_path: Path) -> None: + target = tmp_path / "strategy.yaml" + target.write_text( + 'config_version: "1.0.0"\nconfig_hash: "wrong"\n' + 'last_review: "2026-04-26"\nlast_reviewer: "test"\n', + encoding="utf-8", + ) + result = CliRunner().invoke( + cli_main, + ["config", "validate", "--file", str(target), "--no-enforce-hash"], + ) + assert result.exit_code == 0 + assert "ok" in result.output diff --git a/tests/unit/test_config_loader.py b/tests/unit/test_config_loader.py new file mode 100644 index 0000000..26510b4 --- /dev/null +++ b/tests/unit/test_config_loader.py @@ -0,0 +1,149 @@ +"""Tests for the YAML loader and hash verification.""" + +from __future__ import annotations + +from decimal import Decimal +from pathlib import Path + +import pytest +import yaml + +from cerbero_bite.config.loader import ( + ConfigHashError, + compute_config_hash, + load_strategy, +) + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _golden_yaml_skeleton(**overrides: object) -> dict[str, object]: + base = { + "config_version": "1.0.0-test", + "config_hash": "0" * 64, + "last_review": "2026-04-26", + "last_reviewer": "test", + } + base.update(overrides) + return base + + +def _write_with_correct_hash(path: Path, doc: dict[str, object]) -> None: + """Write a YAML doc and patch ``config_hash`` to match the file body.""" + text = yaml.safe_dump(doc, sort_keys=False) + path.write_text(text, encoding="utf-8") + new_hash = compute_config_hash(path.read_text(encoding="utf-8")) + doc = {**doc, "config_hash": new_hash} + text = yaml.safe_dump(doc, sort_keys=False) + path.write_text(text, encoding="utf-8") + + +# --------------------------------------------------------------------------- +# compute_config_hash +# --------------------------------------------------------------------------- + + +def test_compute_hash_is_independent_of_recorded_hash_value(tmp_path: Path) -> None: + a = tmp_path / "a.yaml" + b = tmp_path / "b.yaml" + a.write_text( + 'config_version: "1.0.0"\nconfig_hash: "aaa"\nfoo: 1\n', encoding="utf-8" + ) + b.write_text( + 'config_version: "1.0.0"\nconfig_hash: "bbb"\nfoo: 1\n', encoding="utf-8" + ) + assert compute_config_hash(a.read_text()) == compute_config_hash(b.read_text()) + + +# --------------------------------------------------------------------------- +# load_strategy — happy path +# --------------------------------------------------------------------------- + + +def test_load_repo_strategy_yaml(tmp_path: Path) -> None: + """The committed strategy.yaml validates with the recorded hash.""" + result = load_strategy(REPO_ROOT / "strategy.yaml") + assert result.config.config_version == "1.0.0" + assert result.config.sizing.kelly_fraction == Decimal("0.13") + assert result.computed_hash == result.config.config_hash + + +def test_load_with_local_override_merges(tmp_path: Path) -> None: + main = tmp_path / "strategy.yaml" + _write_with_correct_hash(main, _golden_yaml_skeleton()) + override = tmp_path / "strategy.local.yaml" + override.write_text( + yaml.safe_dump( + {"sizing": {"max_concurrent_positions": 0}}, + sort_keys=False, + ), + encoding="utf-8", + ) + loaded = load_strategy(main) + assert loaded.config.sizing.max_concurrent_positions == 0 + assert override in loaded.sources + + +def test_local_override_does_not_invalidate_main_hash(tmp_path: Path) -> None: + main = tmp_path / "strategy.yaml" + _write_with_correct_hash(main, _golden_yaml_skeleton()) + (tmp_path / "strategy.local.yaml").write_text( + "sizing:\n kelly_fraction: '0.10'\n", encoding="utf-8" + ) + loaded = load_strategy(main) + assert loaded.config.sizing.kelly_fraction == Decimal("0.10") + + +# --------------------------------------------------------------------------- +# load_strategy — error paths +# --------------------------------------------------------------------------- + + +def test_load_with_mismatched_hash_raises(tmp_path: Path) -> None: + main = tmp_path / "strategy.yaml" + main.write_text( + yaml.safe_dump(_golden_yaml_skeleton(), sort_keys=False), encoding="utf-8" + ) + with pytest.raises(ConfigHashError, match="config_hash mismatch"): + load_strategy(main) + + +def test_load_with_enforce_hash_false_skips_check(tmp_path: Path) -> None: + main = tmp_path / "strategy.yaml" + main.write_text( + yaml.safe_dump(_golden_yaml_skeleton(), sort_keys=False), encoding="utf-8" + ) + loaded = load_strategy(main, enforce_hash=False) + assert loaded.config.config_hash == "0" * 64 + + +def test_load_rejects_top_level_non_mapping(tmp_path: Path) -> None: + main = tmp_path / "strategy.yaml" + main.write_text("- a\n- b\n", encoding="utf-8") + with pytest.raises(ValueError, match="top-level mapping"): + load_strategy(main, enforce_hash=False) + + +def test_local_override_path_pointing_to_missing_file_is_ignored( + tmp_path: Path, +) -> None: + main = tmp_path / "strategy.yaml" + _write_with_correct_hash(main, _golden_yaml_skeleton()) + loaded = load_strategy( + main, local_override_path=tmp_path / "nonexistent.yaml" + ) + assert main in loaded.sources + assert len(loaded.sources) == 1 + + +def test_empty_local_override_file_is_no_op(tmp_path: Path) -> None: + main = tmp_path / "strategy.yaml" + _write_with_correct_hash(main, _golden_yaml_skeleton()) + (tmp_path / "strategy.local.yaml").write_text("", encoding="utf-8") + loaded = load_strategy(main) + assert loaded.config.sizing.kelly_fraction == Decimal("0.13") diff --git a/tests/unit/test_kill_switch.py b/tests/unit/test_kill_switch.py new file mode 100644 index 0000000..3216f8a --- /dev/null +++ b/tests/unit/test_kill_switch.py @@ -0,0 +1,170 @@ +"""Kill switch behaviour: SQLite + audit log stay in lock-step.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from pathlib import Path + +import pytest + +from cerbero_bite.safety import AuditLog, verify_chain +from cerbero_bite.safety.kill_switch import KillSwitch, KillSwitchError +from cerbero_bite.state import Repository, connect, run_migrations, transaction + + +def _make_kill_switch(tmp_path: Path) -> tuple[KillSwitch, AuditLog, Path, Repository]: + db_path = tmp_path / "state.sqlite" + audit_path = tmp_path / "audit.log" + conn = connect(db_path) + run_migrations(conn) + repo = Repository() + with transaction(conn): + repo.init_system_state( + conn, config_version="1.0.0", now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC) + ) + conn.close() + + audit = AuditLog(audit_path) + times = iter( + datetime(2026, 4, 27, 14, m, tzinfo=UTC) for m in (10, 20, 30, 40, 50) + ) + ks = KillSwitch( + connection_factory=lambda: connect(db_path), + repository=repo, + audit_log=audit, + clock=lambda: next(times), + ) + return ks, audit, audit_path, repo + + +def test_arm_persists_state_and_appends_audit(tmp_path: Path) -> None: + ks, _audit, audit_path, repo = _make_kill_switch(tmp_path) + assert ks.is_armed() is False + + ks.arm(reason="manual test", source="manual") + + assert ks.is_armed() is True + conn = connect(tmp_path / "state.sqlite") + try: + state = repo.get_system_state(conn) + finally: + conn.close() + assert state is not None + assert state.kill_switch == 1 + assert state.kill_reason == "manual test" + assert state.kill_at is not None + assert verify_chain(audit_path) == 1 + + +def test_arm_is_idempotent_on_second_call(tmp_path: Path) -> None: + ks, _audit, audit_path, _repo = _make_kill_switch(tmp_path) + ks.arm(reason="first", source="manual") + ks.arm(reason="second", source="manual") # no-op + # only one audit line because the second call short-circuits + assert verify_chain(audit_path) == 1 + + +def test_disarm_resets_kill_switch(tmp_path: Path) -> None: + ks, _audit, audit_path, repo = _make_kill_switch(tmp_path) + ks.arm(reason="test", source="manual") + ks.disarm(reason="cleared", source="manual") + assert ks.is_armed() is False + conn = connect(tmp_path / "state.sqlite") + try: + state = repo.get_system_state(conn) + finally: + conn.close() + assert state is not None + assert state.kill_at is None + # arm + disarm = 2 audit lines + assert verify_chain(audit_path) == 2 + + +def test_disarm_when_not_armed_is_noop(tmp_path: Path) -> None: + ks, _audit, audit_path, _repo = _make_kill_switch(tmp_path) + ks.disarm(reason="nothing to do", source="manual") + assert verify_chain(audit_path) == 0 + + +def test_arm_requires_reason(tmp_path: Path) -> None: + ks, _audit, _audit_path, _repo = _make_kill_switch(tmp_path) + with pytest.raises(KillSwitchError, match="reason is required"): + ks.arm(reason="", source="manual") + + +def test_arm_without_initialised_state_raises(tmp_path: Path) -> None: + db_path = tmp_path / "state.sqlite" + audit_path = tmp_path / "audit.log" + conn = connect(db_path) + run_migrations(conn) + conn.close() + ks = KillSwitch( + connection_factory=lambda: connect(db_path), + repository=Repository(), + audit_log=AuditLog(audit_path), + clock=lambda: datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + ) + with pytest.raises(KillSwitchError, match="system_state singleton missing"): + ks.arm(reason="x", source="manual") + + +def test_audit_chain_records_event_kind(tmp_path: Path) -> None: + ks, _audit, audit_path, _repo = _make_kill_switch(tmp_path) + ks.arm(reason="x", source="mcp_timeout") + ks.disarm(reason="y", source="manual") + text = audit_path.read_text(encoding="utf-8") + assert "KILL_SWITCH_ARMED" in text + assert "KILL_SWITCH_DISARMED" in text + + +def test_is_armed_returns_false_when_singleton_missing(tmp_path: Path) -> None: + db_path = tmp_path / "state.sqlite" + audit_path = tmp_path / "audit.log" + conn = connect(db_path) + run_migrations(conn) + conn.close() + ks = KillSwitch( + connection_factory=lambda: connect(db_path), + repository=Repository(), + audit_log=AuditLog(audit_path), + clock=lambda: datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + ) + assert ks.is_armed() is False + + +def test_disarm_requires_reason(tmp_path: Path) -> None: + ks, _audit, _audit_path, _repo = _make_kill_switch(tmp_path) + with pytest.raises(KillSwitchError, match="reason is required"): + ks.disarm(reason="", source="manual") + + +def test_disarm_without_initialised_state_raises(tmp_path: Path) -> None: + db_path = tmp_path / "state.sqlite" + audit_path = tmp_path / "audit.log" + conn = connect(db_path) + run_migrations(conn) + conn.close() + ks = KillSwitch( + connection_factory=lambda: connect(db_path), + repository=Repository(), + audit_log=AuditLog(audit_path), + clock=lambda: datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + ) + with pytest.raises(KillSwitchError, match="system_state singleton missing"): + ks.disarm(reason="x", source="manual") + + +def test_clock_is_advanced_for_each_call(tmp_path: Path) -> None: + ks, _audit, _audit_path, repo = _make_kill_switch(tmp_path) + ks.arm(reason="x", source="manual") + ks.disarm(reason="y", source="manual") + conn = connect(tmp_path / "state.sqlite") + try: + state = repo.get_system_state(conn) + finally: + conn.close() + assert state is not None + # last_health_check should reflect the disarm time (14:20 from the fake clock). + assert state.last_health_check >= datetime(2026, 4, 27, 14, 15, tzinfo=UTC) - timedelta( + seconds=1 + ) diff --git a/tests/unit/test_smoke.py b/tests/unit/test_smoke.py index 555f1e4..43fbb9d 100644 --- a/tests/unit/test_smoke.py +++ b/tests/unit/test_smoke.py @@ -38,14 +38,29 @@ def test_cli_status_runs(tmp_data_dir: Path) -> None: assert "phase: 0" in result.output -def test_cli_kill_switch_arm_placeholder(tmp_data_dir: Path) -> None: +def test_cli_kill_switch_arm_persists_state(tmp_data_dir: Path) -> None: runner = CliRunner() + db_path = tmp_data_dir / "state.sqlite" + audit_path = tmp_data_dir / "audit.log" result = runner.invoke( cli_main, - ["--log-dir", str(tmp_data_dir / "log"), "kill-switch", "arm", "--reason", "test"], + [ + "--log-dir", + str(tmp_data_dir / "log"), + "kill-switch", + "arm", + "--reason", + "smoke", + "--db", + str(db_path), + "--audit", + str(audit_path), + ], ) - assert result.exit_code == 0 - assert "phase 0 placeholder" in result.output + assert result.exit_code == 0, result.output + assert "ARMED" in result.output + assert db_path.exists() + assert audit_path.exists() def test_cli_version_flag() -> None: diff --git a/tests/unit/test_state_db.py b/tests/unit/test_state_db.py new file mode 100644 index 0000000..88cf70f --- /dev/null +++ b/tests/unit/test_state_db.py @@ -0,0 +1,109 @@ +"""Migration runner + SQLite pragma sanity tests.""" + +from __future__ import annotations + +import sqlite3 +from pathlib import Path + +import pytest + +from cerbero_bite.state.db import ( + connect, + current_version, + list_migrations, + run_migrations, + transaction, +) + + +def test_list_migrations_is_ordered_and_starts_with_0001() -> None: + migs = list_migrations() + assert migs, "expected at least one migration file" + versions = [m[0] for m in migs] + assert versions == sorted(versions) + assert versions[0] == 1 + + +def test_run_migrations_creates_full_schema(tmp_path: Path) -> None: + db_path = tmp_path / "state.sqlite" + conn = connect(db_path) + try: + new_version = run_migrations(conn) + assert new_version == max(m[0] for m in list_migrations()) + # Sanity: every documented table is present. + existing = { + row["name"] + for row in conn.execute( + "SELECT name FROM sqlite_master WHERE type='table'" + ).fetchall() + } + assert { + "positions", + "instructions", + "decisions", + "dvol_history", + "manual_actions", + "system_state", + } <= existing + finally: + conn.close() + + +def test_run_migrations_is_idempotent(tmp_path: Path) -> None: + db_path = tmp_path / "state.sqlite" + conn = connect(db_path) + try: + run_migrations(conn) + first = current_version(conn) + run_migrations(conn) # second call must be a no-op + assert current_version(conn) == first + finally: + conn.close() + + +def test_pragmas_applied(tmp_path: Path) -> None: + db_path = tmp_path / "state.sqlite" + conn = connect(db_path) + try: + run_migrations(conn) + assert conn.execute("PRAGMA foreign_keys").fetchone()[0] == 1 + # WAL is sticky on the file: confirm via the journal_mode pragma. + mode = conn.execute("PRAGMA journal_mode").fetchone()[0] + assert mode.lower() == "wal" + finally: + conn.close() + + +def test_transaction_rolls_back_on_exception(tmp_path: Path) -> None: + db_path = tmp_path / "state.sqlite" + conn = connect(db_path) + try: + run_migrations(conn) + with pytest.raises(RuntimeError, match="boom"), transaction(conn): + conn.execute( + "INSERT INTO system_state(id, last_health_check, " + "config_version, started_at) VALUES (1, '2026-04-27', " + "'1.0.0', '2026-04-27')" + ) + raise RuntimeError("boom") + # Row was rolled back. + rows = conn.execute("SELECT COUNT(*) FROM system_state").fetchone()[0] + assert rows == 0 + finally: + conn.close() + + +def test_foreign_key_enforcement(tmp_path: Path) -> None: + db_path = tmp_path / "state.sqlite" + conn = connect(db_path) + try: + run_migrations(conn) + # Inserting an instruction without a parent position must fail. + with pytest.raises(sqlite3.IntegrityError): + conn.execute( + "INSERT INTO instructions(instruction_id, proposal_id, kind, " + "payload_json, sent_at) VALUES (?, ?, ?, ?, ?)", + ("i1", "missing", "open_combo", "{}", "2026-04-27"), + ) + finally: + conn.close() diff --git a/tests/unit/test_state_repository.py b/tests/unit/test_state_repository.py new file mode 100644 index 0000000..1264574 --- /dev/null +++ b/tests/unit/test_state_repository.py @@ -0,0 +1,450 @@ +"""CRUD tests for :mod:`cerbero_bite.state.repository`.""" + +from __future__ import annotations + +import sqlite3 +from collections.abc import Iterator +from datetime import UTC, datetime, timedelta +from decimal import Decimal +from pathlib import Path +from uuid import uuid4 + +import pytest + +from cerbero_bite.state import ( + DecisionRecord, + DvolSnapshot, + InstructionRecord, + ManualAction, + PositionRecord, + Repository, + connect, + run_migrations, + transaction, +) + + +@pytest.fixture +def conn(tmp_path: Path) -> Iterator[sqlite3.Connection]: + db = tmp_path / "state.sqlite" + c = connect(db) + run_migrations(c) + yield c + c.close() + + +@pytest.fixture +def repo() -> Repository: + return Repository() + + +def _make_position(**overrides: object) -> PositionRecord: + base: dict[str, object] = { + "proposal_id": uuid4(), + "spread_type": "bull_put", + "expiry": datetime(2026, 5, 15, 8, 0, tzinfo=UTC), + "short_strike": Decimal("2475"), + "long_strike": Decimal("2350"), + "short_instrument": "ETH-15MAY26-2475-P", + "long_instrument": "ETH-15MAY26-2350-P", + "n_contracts": 2, + "spread_width_usd": Decimal("125"), + "spread_width_pct": Decimal("0.0417"), + "credit_eth": Decimal("0.030"), + "credit_usd": Decimal("90"), + "max_loss_usd": Decimal("160"), + "spot_at_entry": Decimal("3000"), + "dvol_at_entry": Decimal("50"), + "delta_at_entry": Decimal("-0.12"), + "eth_price_at_entry": Decimal("3000"), + "proposed_at": datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + "status": "proposed", + "created_at": datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + "updated_at": datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + } + base.update(overrides) + return PositionRecord(**base) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# positions +# --------------------------------------------------------------------------- + + +def test_create_and_get_position_roundtrip( + conn: sqlite3.Connection, repo: Repository +) -> None: + record = _make_position() + with transaction(conn): + repo.create_position(conn, record) + fetched = repo.get_position(conn, record.proposal_id) + assert fetched is not None + assert fetched.proposal_id == record.proposal_id + assert fetched.short_strike == Decimal("2475") + # Decimal precision preserved (no float coercion). + assert fetched.spread_width_pct == Decimal("0.0417") + + +def test_get_unknown_position_returns_none( + conn: sqlite3.Connection, repo: Repository +) -> None: + assert repo.get_position(conn, uuid4()) is None + + +def test_list_open_positions_filters_by_status( + conn: sqlite3.Connection, repo: Repository +) -> None: + open_pos = _make_position(status="open") + closed_pos = _make_position(status="closed") + with transaction(conn): + repo.create_position(conn, open_pos) + repo.create_position(conn, closed_pos) + open_only = repo.list_open_positions(conn) + assert {p.proposal_id for p in open_only} == {open_pos.proposal_id} + + +def test_count_concurrent_positions_excludes_closed( + conn: sqlite3.Connection, repo: Repository +) -> None: + with transaction(conn): + repo.create_position(conn, _make_position(status="open")) + repo.create_position(conn, _make_position(status="awaiting_fill")) + repo.create_position(conn, _make_position(status="closed")) + repo.create_position(conn, _make_position(status="cancelled")) + assert repo.count_concurrent_positions(conn) == 2 + + +def test_update_position_status_with_only_status_skips_optional_fields( + conn: sqlite3.Connection, repo: Repository +) -> None: + """Hit the False side of every optional-field branch in ``update_position_status``.""" + record = _make_position(status="proposed") + with transaction(conn): + repo.create_position(conn, record) + repo.update_position_status( + conn, + record.proposal_id, + status="awaiting_fill", + now=datetime(2026, 4, 27, 14, 5, tzinfo=UTC), + ) + fetched = repo.get_position(conn, record.proposal_id) + assert fetched is not None + assert fetched.status == "awaiting_fill" + assert fetched.opened_at is None + assert fetched.closed_at is None + assert fetched.close_reason is None + + +def test_update_position_status_persists_open_then_close( + conn: sqlite3.Connection, repo: Repository +) -> None: + record = _make_position(status="awaiting_fill") + opened_at = datetime(2026, 4, 27, 14, 10, tzinfo=UTC) + with transaction(conn): + repo.create_position(conn, record) + repo.update_position_status( + conn, + record.proposal_id, + status="open", + opened_at=opened_at, + now=datetime(2026, 4, 27, 14, 11, tzinfo=UTC), + ) + after_open = repo.get_position(conn, record.proposal_id) + assert after_open is not None + assert after_open.opened_at == opened_at + + closed_at = datetime(2026, 5, 12, 14, 0, tzinfo=UTC) + with transaction(conn): + repo.update_position_status( + conn, + record.proposal_id, + status="closed", + closed_at=closed_at, + close_reason="CLOSE_PROFIT", + debit_paid_eth=Decimal("0.012"), + pnl_eth=Decimal("0.018"), + pnl_usd=Decimal("54"), + now=closed_at, + ) + + fetched = repo.get_position(conn, record.proposal_id) + assert fetched is not None + assert fetched.status == "closed" + assert fetched.close_reason == "CLOSE_PROFIT" + assert fetched.debit_paid_eth == Decimal("0.012") + assert fetched.pnl_eth == Decimal("0.018") + + +def test_naive_datetime_is_normalised_to_utc( + conn: sqlite3.Connection, repo: Repository +) -> None: + naive_now = datetime(2026, 4, 27, 14, 0) + record = _make_position(proposed_at=naive_now, created_at=naive_now, updated_at=naive_now) + with transaction(conn): + repo.create_position(conn, record) + fetched = repo.get_position(conn, record.proposal_id) + assert fetched is not None + assert fetched.proposed_at.tzinfo is not None + assert fetched.proposed_at.utcoffset() == timedelta(0) + + +# --------------------------------------------------------------------------- +# instructions +# --------------------------------------------------------------------------- + + +def test_instruction_lifecycle_ack_and_fill( + conn: sqlite3.Connection, repo: Repository +) -> None: + pos = _make_position(status="awaiting_fill") + instr = InstructionRecord( + instruction_id=uuid4(), + proposal_id=pos.proposal_id, + kind="open_combo", + payload_json='{"action":"open"}', + sent_at=datetime(2026, 4, 27, 14, 5, tzinfo=UTC), + ) + with transaction(conn): + repo.create_position(conn, pos) + repo.create_instruction(conn, instr) + repo.update_instruction( + conn, + instr.instruction_id, + acknowledged_at=datetime(2026, 4, 27, 14, 6, tzinfo=UTC), + ) + repo.update_instruction( + conn, + instr.instruction_id, + filled_at=datetime(2026, 4, 27, 14, 8, tzinfo=UTC), + actual_fill_eth=Decimal("0.0298"), + actual_fees_eth=Decimal("0.0001"), + ) + + fetched = repo.list_instructions(conn, pos.proposal_id) + assert len(fetched) == 1 + assert fetched[0].acknowledged_at is not None + assert fetched[0].filled_at is not None + assert fetched[0].actual_fill_eth == Decimal("0.0298") + + +def test_update_instruction_no_op_when_no_fields( + conn: sqlite3.Connection, repo: Repository +) -> None: + pos = _make_position() + instr = InstructionRecord( + instruction_id=uuid4(), + proposal_id=pos.proposal_id, + kind="open_combo", + payload_json="{}", + sent_at=datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + ) + with transaction(conn): + repo.create_position(conn, pos) + repo.create_instruction(conn, instr) + repo.update_instruction(conn, instr.instruction_id) # no fields + + fetched = repo.list_instructions(conn, pos.proposal_id) + assert fetched[0].acknowledged_at is None + + +# --------------------------------------------------------------------------- +# decisions / dvol / manual_actions +# --------------------------------------------------------------------------- + + +def test_record_and_list_decisions( + conn: sqlite3.Connection, repo: Repository +) -> None: + decision = DecisionRecord( + decision_type="entry_check", + timestamp=datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + inputs_json='{"capital":1500}', + outputs_json='{"accepted":true}', + action_taken="propose_open", + ) + with transaction(conn): + decision_id = repo.record_decision(conn, decision) + assert decision_id > 0 + decisions = repo.list_decisions(conn) + assert len(decisions) == 1 + assert decisions[0].id == decision_id + assert decisions[0].action_taken == "propose_open" + + +def test_record_dvol_snapshot_replaces_on_duplicate_timestamp( + conn: sqlite3.Connection, repo: Repository +) -> None: + ts = datetime(2026, 4, 27, 14, 0, tzinfo=UTC) + with transaction(conn): + repo.record_dvol_snapshot( + conn, DvolSnapshot(timestamp=ts, dvol=Decimal("50"), eth_spot=Decimal("3000")) + ) + repo.record_dvol_snapshot( + conn, DvolSnapshot(timestamp=ts, dvol=Decimal("55"), eth_spot=Decimal("3050")) + ) + rows = conn.execute("SELECT COUNT(*) FROM dvol_history").fetchone() + assert rows[0] == 1 + + +def test_manual_action_enqueue_consume_cycle( + conn: sqlite3.Connection, repo: Repository +) -> None: + pos = _make_position() + action = ManualAction( + kind="approve_proposal", + proposal_id=pos.proposal_id, + payload_json='{"reason":"go"}', + created_at=datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + ) + with transaction(conn): + repo.create_position(conn, pos) + action_id = repo.enqueue_manual_action(conn, action) + next_action = repo.next_unconsumed_action(conn) + assert next_action is not None + assert next_action.id == action_id + with transaction(conn): + repo.mark_action_consumed( + conn, + action_id, + consumed_by="orchestrator", + result="ok", + now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC), + ) + assert repo.next_unconsumed_action(conn) is None + + +# --------------------------------------------------------------------------- +# system_state +# --------------------------------------------------------------------------- + + +def test_init_system_state_is_idempotent( + conn: sqlite3.Connection, repo: Repository +) -> None: + now = datetime(2026, 4, 27, 14, 0, tzinfo=UTC) + with transaction(conn): + repo.init_system_state(conn, config_version="1.0.0", now=now) + repo.init_system_state(conn, config_version="1.0.0", now=now) + state = repo.get_system_state(conn) + assert state is not None + assert state.kill_switch == 0 + assert state.config_version == "1.0.0" + + +def test_kill_switch_arm_and_disarm( + conn: sqlite3.Connection, repo: Repository +) -> None: + now = datetime(2026, 4, 27, 14, 0, tzinfo=UTC) + with transaction(conn): + repo.init_system_state(conn, config_version="1.0.0", now=now) + repo.set_kill_switch(conn, armed=True, reason="manual", now=now) + state = repo.get_system_state(conn) + assert state is not None + assert state.kill_switch == 1 + assert state.kill_reason == "manual" + assert state.kill_at is not None + + later = now + timedelta(minutes=5) + with transaction(conn): + repo.set_kill_switch(conn, armed=False, reason=None, now=later) + state = repo.get_system_state(conn) + assert state is not None + assert state.kill_switch == 0 + assert state.kill_at is None + assert state.last_health_check == later + + +def test_get_system_state_returns_none_when_uninitialised( + conn: sqlite3.Connection, repo: Repository +) -> None: + assert repo.get_system_state(conn) is None + + +def test_list_positions_filters_by_status( + conn: sqlite3.Connection, repo: Repository +) -> None: + open_pos = _make_position(status="open") + closed_pos = _make_position(status="closed") + with transaction(conn): + repo.create_position(conn, open_pos) + repo.create_position(conn, closed_pos) + closed_only = repo.list_positions(conn, status="closed") + assert len(closed_only) == 1 + assert closed_only[0].proposal_id == closed_pos.proposal_id + + +def test_list_positions_without_filter_returns_all( + conn: sqlite3.Connection, repo: Repository +) -> None: + with transaction(conn): + repo.create_position(conn, _make_position(status="open")) + repo.create_position(conn, _make_position(status="closed")) + assert len(repo.list_positions(conn)) == 2 + + +def test_list_decisions_filters_by_proposal( + conn: sqlite3.Connection, repo: Repository +) -> None: + pos = _make_position() + with transaction(conn): + repo.create_position(conn, pos) + repo.record_decision( + conn, + DecisionRecord( + decision_type="entry_check", + proposal_id=pos.proposal_id, + timestamp=datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + inputs_json="{}", + outputs_json="{}", + action_taken="propose_open", + ), + ) + repo.record_decision( + conn, + DecisionRecord( + decision_type="exit_check", + timestamp=datetime(2026, 4, 27, 15, 0, tzinfo=UTC), + inputs_json="{}", + outputs_json="{}", + action_taken="HOLD", + ), + ) + only_for_proposal = repo.list_decisions(conn, proposal_id=pos.proposal_id) + assert len(only_for_proposal) == 1 + assert only_for_proposal[0].action_taken == "propose_open" + + +def test_update_instruction_sets_cancelled( + conn: sqlite3.Connection, repo: Repository +) -> None: + pos = _make_position() + instr = InstructionRecord( + instruction_id=uuid4(), + proposal_id=pos.proposal_id, + kind="open_combo", + payload_json="{}", + sent_at=datetime(2026, 4, 27, 14, 0, tzinfo=UTC), + ) + with transaction(conn): + repo.create_position(conn, pos) + repo.create_instruction(conn, instr) + repo.update_instruction( + conn, + instr.instruction_id, + cancelled_at=datetime(2026, 4, 27, 14, 30, tzinfo=UTC), + ) + fetched = repo.list_instructions(conn, pos.proposal_id) + assert fetched[0].cancelled_at is not None + + +def test_touch_health_check_updates_timestamp( + conn: sqlite3.Connection, repo: Repository +) -> None: + now = datetime(2026, 4, 27, 14, 0, tzinfo=UTC) + with transaction(conn): + repo.init_system_state(conn, config_version="1.0.0", now=now) + later = now + timedelta(minutes=10) + repo.touch_health_check(conn, now=later) + state = repo.get_system_state(conn) + assert state is not None + assert state.last_health_check == later