Phase 2: persistence + safety controls

Aggiunge la persistenza SQLite, l'audit log a hash chain, il kill
switch coordinato e i CLI di gestione documentati in
docs/05-data-model.md e docs/07-risk-controls.md. 197 test pass,
1 skipped (sqlite3 CLI mancante), copertura totale 97%.

State (`state/`):
- 0001_init.sql con positions, instructions, decisions, dvol_history,
  manual_actions, system_state.
- db.py: connect con WAL + foreign_keys + transaction ctx, runner
  forward-only basato su PRAGMA user_version.
- models.py: record Pydantic, Decimal preservato come TEXT.
- repository.py: CRUD typed con singola connessione passata, cache
  aware, posizioni concorrenti.

Safety (`safety/`):
- audit_log.py: AuditLog append-only con SHA-256 chain e fsync,
  verify_chain riconosce ogni manomissione (payload, prev_hash,
  hash, JSON, separatori).
- kill_switch.py: arm/disarm transazionali, idempotenti, accoppiati
  all'audit chain.

Config (`config/loader.py` + `strategy.yaml`):
- Loader YAML con deep-merge di strategy.local.yaml.
- Verifica config_hash SHA-256 (riga config_hash esclusa).
- File golden strategy.yaml + esempio override.

Scripts:
- dead_man.sh: watchdog shell indipendente da Python.
- backup.py: VACUUM INTO orario con retention 30 giorni.

CLI:
- audit verify (exit 2 su tampering).
- kill-switch arm/disarm/status su SQLite reale.
- state inspect con tabella posizioni aperte.
- config hash, config validate.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-27 13:35:35 +02:00
parent fbb7753cc6
commit 263470786d
25 changed files with 3669 additions and 14 deletions
+110
View File
@@ -0,0 +1,110 @@
"""Hourly SQLite backup utility (``docs/05-data-model.md``).
Uses ``VACUUM INTO`` so the snapshot is a self-contained, defragmented
SQLite file that can be inspected without locking the live database.
Retention is enforced by deletion: any file matching
``state-YYYYMMDD-HH.sqlite`` older than ``retention_days`` is removed.
Designed to be invoked from APScheduler in the orchestrator and from
the CLI for ad-hoc backups.
"""
from __future__ import annotations
import argparse
import re
import sqlite3
from collections.abc import Iterable
from datetime import UTC, datetime, timedelta
from pathlib import Path
__all__ = ["BACKUP_FILENAME_RE", "backup_database", "prune_backups"]
BACKUP_FILENAME_RE = re.compile(r"^state-(\d{8}-\d{2})\.sqlite$")
_DEFAULT_RETENTION_DAYS = 30
def _format_backup_name(now: datetime) -> str:
return f"state-{now.astimezone(UTC):%Y%m%d-%H}.sqlite"
def backup_database(
    *,
    db_path: Path | str,
    backup_dir: Path | str,
    now: datetime | None = None,
) -> Path:
    """Create a snapshot via ``VACUUM INTO`` and return its path.

    Args:
        db_path: live SQLite database to snapshot.
        backup_dir: destination directory, created if missing.
        now: timestamp used for the filename (defaults to current UTC).

    Returns:
        Path of the written ``state-YYYYMMDD-HH.sqlite`` file.
    """
    src = Path(db_path)
    dst_dir = Path(backup_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)
    timestamp = (now or datetime.now(UTC)).astimezone(UTC)
    dst = dst_dir / _format_backup_name(timestamp)
    if dst.exists():
        # Idempotent at hour granularity: same hour = same target file.
        dst.unlink()
    conn = sqlite3.connect(str(src))
    try:
        # VACUUM INTO takes a string literal, not a bound parameter; double
        # any single quotes so a path containing ' cannot break (or inject
        # into) the statement.
        escaped = dst.as_posix().replace("'", "''")
        conn.execute(f"VACUUM INTO '{escaped}'")
    finally:
        conn.close()
    return dst
def _parse_backup_timestamp(name: str) -> datetime | None:
    """Extract the UTC hour stamp from a backup filename, or ``None``."""
    match = BACKUP_FILENAME_RE.match(name)
    if match is None:
        return None
    try:
        parsed = datetime.strptime(match.group(1), "%Y%m%d-%H")
    except ValueError:
        # Right shape but not a real date (e.g. month 13): not a backup.
        return None
    return parsed.replace(tzinfo=UTC)
def prune_backups(
    backup_dir: Path | str,
    *,
    retention_days: int = _DEFAULT_RETENTION_DAYS,
    now: datetime | None = None,
) -> list[Path]:
    """Remove backups older than ``retention_days``. Returns the deleted paths.

    Args:
        backup_dir: directory containing ``state-*.sqlite`` snapshots.
        retention_days: age threshold; strictly older files are removed.
        now: reference time (defaults to current UTC).

    Files that do not match :data:`BACKUP_FILENAME_RE` are never touched,
    so unrelated files in the directory are safe. A missing directory is
    treated as "nothing to prune" instead of raising.
    """
    root = Path(backup_dir)
    if not root.is_dir():
        # No backup was ever taken (or the path is wrong): nothing to do.
        return []
    cutoff = (now or datetime.now(UTC)).astimezone(UTC) - timedelta(days=retention_days)
    deleted: list[Path] = []
    for entry in root.iterdir():
        if not entry.is_file():
            continue
        ts = _parse_backup_timestamp(entry.name)
        if ts is None:
            continue
        if ts < cutoff:
            entry.unlink()
            deleted.append(entry)
    return deleted
def list_backups(backup_dir: Path | str) -> Iterable[Path]:
    """Return all snapshot files in *backup_dir*, sorted by name.

    The ``state-YYYYMMDD-HH`` naming makes lexicographic order equal to
    chronological order, so the oldest backup comes first. Raises
    ``FileNotFoundError`` when *backup_dir* does not exist.

    NOTE(review): not listed in ``__all__`` — confirm whether it is meant
    to be part of the public surface.
    """
    return sorted(
        (p for p in Path(backup_dir).iterdir() if BACKUP_FILENAME_RE.match(p.name)),
        key=lambda p: p.name,
    )
def _cli() -> None:
    """Ad-hoc entry point: take one snapshot, then prune old ones."""
    parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    parser.add_argument("--db", default="data/state.sqlite")
    parser.add_argument("--out", default="data/backups")
    parser.add_argument(
        "--retention-days", type=int, default=_DEFAULT_RETENTION_DAYS
    )
    ns = parser.parse_args()
    snapshot = backup_database(db_path=ns.db, backup_dir=ns.out)
    removed = prune_backups(ns.out, retention_days=ns.retention_days)
    print(f"backup -> {snapshot}")
    if removed:
        names = ", ".join(p.name for p in removed)
        print(f"pruned: {names}")


if __name__ == "__main__":
    _cli()
+107
View File
@@ -0,0 +1,107 @@
#!/usr/bin/env bash
# dead_man.sh — independent watchdog for Cerbero Bite (docs/07-risk-controls.md).
#
# Runs from cron every 5 minutes. If the engine has not written a
# HEALTH_OK event into today's JSONL log within the last
# DEAD_MAN_THRESHOLD_SECONDS (default 900 = 15 minutes), it:
# 1. Sends an alert via DEAD_MAN_ALERT_CMD (any command consuming
# a single argument: the alert text). When unset, falls back to
# writing data/log/dead-man-alert.txt so an external watcher can
# pick it up.
# 2. Arms the SQLite kill switch directly (no Python required).
# 3. Appends one line to data/audit.log (best-effort hash chain;
# verifying after recovery is the operator's job).
#
# Configuration via env vars or .env in PROJECT_ROOT:
# PROJECT_ROOT — repo root (default: parent of this file).
# DEAD_MAN_THRESHOLD_SECONDS — silence threshold (default 900).
# DEAD_MAN_ALERT_CMD — optional alert command.
#
# This script intentionally avoids Python so it survives env corruption.
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="${PROJECT_ROOT:-$(cd "${SCRIPT_DIR}/.." && pwd)}"
THRESHOLD="${DEAD_MAN_THRESHOLD_SECONDS:-900}"
LOG_DIR="${PROJECT_ROOT}/data/log"
DB_PATH="${PROJECT_ROOT}/data/state.sqlite"
AUDIT_PATH="${PROJECT_ROOT}/data/audit.log"
ALERT_FILE="${LOG_DIR}/dead-man-alert.txt"

# Path of today's engine log (UTC date), e.g. data/log/cerbero-bite-2026-04-27.jsonl.
today_log() {
  date -u +"${LOG_DIR}/cerbero-bite-%Y-%m-%d.jsonl"
}

# Print the "ts" value of the last HEALTH_OK line in $1; prints nothing when
# the file is missing or contains no HEALTH_OK entry.
# NOTE(review): the fixed-string grep assumes the serializer writes
# '"event": "HEALTH_OK"' with a space after the colon — confirm against the
# engine's JSON log format.
last_health_ts() {
  local file="$1"
  if [[ ! -f "$file" ]]; then
    echo ""
    return
  fi
  # '|| true' keeps 'set -e -o pipefail' from aborting when grep matches nothing.
  grep -F '"event": "HEALTH_OK"' "$file" 2>/dev/null \
    | tail -n 1 \
    | sed -E 's/.*"ts":[[:space:]]*"([^"]+)".*/\1/' \
    || true
}

# Send the alert via DEAD_MAN_ALERT_CMD when configured (failures tolerated),
# and always append to the fallback alert file for external watchers.
emit_alert() {
  local message="$1"
  if [[ -n "${DEAD_MAN_ALERT_CMD:-}" ]]; then
    "${DEAD_MAN_ALERT_CMD}" "$message" || true
  fi
  mkdir -p "$LOG_DIR"
  printf '%s | %s\n' "$(date -u +%FT%TZ)" "$message" >> "$ALERT_FILE"
}

# Arm the kill switch row directly with the sqlite3 CLI (no Python involved).
# Silently skipped when the DB or the sqlite3 binary is missing.
arm_kill_switch() {
  if [[ ! -f "$DB_PATH" ]] || ! command -v sqlite3 >/dev/null 2>&1; then
    return
  fi
  # COALESCE preserves the original reason/timestamp if already armed.
  sqlite3 "$DB_PATH" <<SQL || true
UPDATE system_state
SET kill_switch = 1,
    kill_reason = COALESCE(kill_reason, 'dead_man'),
    kill_at = COALESCE(kill_at, datetime('now')),
    last_health_check = datetime('now')
WHERE id = 1;
SQL
}

# Best-effort audit line: the literal prev_hash=manual / hash=manual values
# knowingly break the hash chain here; the operator re-verifies after
# recovery (see the header comment).
append_audit_line() {
  local ts
  ts="$(date -u +%FT%TZ)"
  mkdir -p "$(dirname "$AUDIT_PATH")"
  printf '%s|DEAD_MAN_TRIGGERED|{"reason":"silence>threshold"}|prev_hash=manual|hash=manual\n' "$ts" >> "$AUDIT_PATH"
}

main() {
  local log_file
  log_file="$(today_log)"
  local last_ts
  last_ts="$(last_health_ts "$log_file")"
  # No HEALTH_OK at all today: treat as total silence and trigger now.
  if [[ -z "$last_ts" ]]; then
    emit_alert "dead_man: no HEALTH_OK in $log_file"
    arm_kill_switch
    append_audit_line
    exit 1
  fi
  local last_epoch now_epoch elapsed
  # GNU 'date -d' parses the ISO timestamp; where unavailable (BSD/macOS)
  # the epoch-0 fallback makes 'elapsed' huge, which fails safe (triggers).
  last_epoch="$(date -u -d "$last_ts" +%s 2>/dev/null || echo 0)"
  now_epoch="$(date -u +%s)"
  elapsed=$(( now_epoch - last_epoch ))
  if (( elapsed > THRESHOLD )); then
    emit_alert "dead_man: ${elapsed}s since last HEALTH_OK (threshold ${THRESHOLD}s)"
    arm_kill_switch
    append_audit_line
    exit 1
  fi
  exit 0
}

main "$@"
+247 -10
View File
@@ -10,19 +10,32 @@ without changing the surface.
from __future__ import annotations from __future__ import annotations
import sys import sys
from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
import click import click
from rich.console import Console from rich.console import Console
from rich.table import Table
from cerbero_bite import __version__ from cerbero_bite import __version__
from cerbero_bite.config.loader import compute_config_hash, load_strategy
from cerbero_bite.logging import configure as configure_logging from cerbero_bite.logging import configure as configure_logging
from cerbero_bite.logging import get_logger from cerbero_bite.logging import get_logger
from cerbero_bite.safety.audit_log import AuditChainError, AuditLog
from cerbero_bite.safety.audit_log import verify_chain as verify_audit_chain
from cerbero_bite.safety.kill_switch import KillSwitch
from cerbero_bite.state import Repository, run_migrations, transaction
from cerbero_bite.state import connect as connect_state
console = Console() console = Console()
log = get_logger("cli") log = get_logger("cli")
_DEFAULT_DB_PATH = Path("data/state.sqlite")
_DEFAULT_AUDIT_PATH = Path("data/audit.log")
_DEFAULT_STRATEGY_PATH = Path("strategy.yaml")
def _phase0_notice(action: str) -> None: def _phase0_notice(action: str) -> None:
console.print(f"[yellow]\\[phase 0 placeholder][/yellow] {action}") console.print(f"[yellow]\\[phase 0 placeholder][/yellow] {action}")
@@ -85,18 +98,131 @@ def kill_switch() -> None:
"""Manage the engine kill switch.""" """Manage the engine kill switch."""
def _make_kill_switch(
    db_path: Path, audit_path: Path, *, config_version: str
) -> KillSwitch:
    """Wire a :class:`KillSwitch` against the on-disk paths.

    ``init_system_state`` is called eagerly so the CLI can be used on
    a fresh checkout before the engine ever ran.

    Args:
        db_path: SQLite state file; parent directory created if missing.
        audit_path: hash-chained audit log; parent directory created too.
        config_version: recorded when the ``system_state`` singleton is
            first initialised.
    """
    db_path.parent.mkdir(parents=True, exist_ok=True)
    audit_path.parent.mkdir(parents=True, exist_ok=True)
    # Bootstrap pass: run migrations and ensure the singleton row exists,
    # on a short-lived connection that is closed before returning.
    conn = connect_state(db_path)
    try:
        run_migrations(conn)
        repo = Repository()
        with transaction(conn):
            repo.init_system_state(
                conn, config_version=config_version, now=datetime.now(UTC)
            )
    finally:
        conn.close()
    # The KillSwitch opens a fresh connection per operation via the factory.
    return KillSwitch(
        connection_factory=lambda: connect_state(db_path),
        repository=Repository(),
        audit_log=AuditLog(audit_path),
    )
@kill_switch.command(name="arm") @kill_switch.command(name="arm")
@click.option("--reason", required=True, help="Why you are arming the kill switch.") @click.option("--reason", required=True, help="Why you are arming the kill switch.")
def kill_switch_arm(reason: str) -> None: @click.option(
"--source",
default="manual",
show_default=True,
help="Trigger label (manual, mcp_timeout, hash_chain, ...).",
)
@click.option(
"--db",
type=click.Path(dir_okay=False, path_type=Path),
default=_DEFAULT_DB_PATH,
show_default=True,
)
@click.option(
"--audit",
type=click.Path(dir_okay=False, path_type=Path),
default=_DEFAULT_AUDIT_PATH,
show_default=True,
)
@click.option(
"--config-version",
default="unknown",
show_default=True,
help="Recorded next to the kill event when the singleton is initialised.",
)
def kill_switch_arm(
reason: str, source: str, db: Path, audit: Path, config_version: str
) -> None:
"""Arm the kill switch (engine refuses new entries).""" """Arm the kill switch (engine refuses new entries)."""
_phase0_notice(f"kill-switch arm placeholder (reason: {reason!r}).") ks = _make_kill_switch(db, audit, config_version=config_version)
ks.arm(reason=reason, source=source)
console.print(f"[red]kill switch ARMED[/red] reason={reason!r} source={source}")
@kill_switch.command(name="disarm") @kill_switch.command(name="disarm")
@click.option("--reason", required=True, help="Why you are disarming.") @click.option("--reason", required=True, help="Why you are disarming.")
def kill_switch_disarm(reason: str) -> None: @click.option(
"--source",
default="manual",
show_default=True,
)
@click.option(
"--db",
type=click.Path(dir_okay=False, path_type=Path),
default=_DEFAULT_DB_PATH,
show_default=True,
)
@click.option(
"--audit",
type=click.Path(dir_okay=False, path_type=Path),
default=_DEFAULT_AUDIT_PATH,
show_default=True,
)
@click.option(
"--config-version",
default="unknown",
show_default=True,
)
def kill_switch_disarm(
reason: str, source: str, db: Path, audit: Path, config_version: str
) -> None:
"""Disarm the kill switch.""" """Disarm the kill switch."""
_phase0_notice(f"kill-switch disarm placeholder (reason: {reason!r}).") ks = _make_kill_switch(db, audit, config_version=config_version)
ks.disarm(reason=reason, source=source)
console.print(f"[green]kill switch DISARMED[/green] reason={reason!r}")
@kill_switch.command(name="status")
@click.option(
    "--db",
    type=click.Path(dir_okay=False, path_type=Path),
    default=_DEFAULT_DB_PATH,
    show_default=True,
)
def kill_switch_status(db: Path) -> None:
    """Print the current kill switch state."""
    if not db.exists():
        console.print("[yellow]state.sqlite not found — engine never ran[/yellow]")
        return
    connection = connect_state(db)
    try:
        run_migrations(connection)
        snapshot = Repository().get_system_state(connection)
    finally:
        connection.close()
    if snapshot is None:
        console.print("[yellow]system_state singleton missing[/yellow]")
        return
    if snapshot.kill_switch == 1:
        flag = "[red]ARMED[/red]"
    else:
        flag = "[green]disarmed[/green]"
    kill_at = snapshot.kill_at.isoformat() if snapshot.kill_at else "-"
    report = [
        f"kill_switch: {flag}",
        f"reason: {snapshot.kill_reason or '-'}",
        f"kill_at: {kill_at}",
        f"last_health_check: {snapshot.last_health_check.isoformat()}",
    ]
    console.print("\n".join(report))
@main.command() @main.command()
@@ -123,9 +249,42 @@ def config() -> None:
@config.command(name="hash") @config.command(name="hash")
def config_hash() -> None: @click.option(
"""Compute and print SHA-256 of strategy.yaml.""" "--file",
_phase0_notice("config hash placeholder; will read strategy.yaml and compute SHA-256.") "yaml_path",
type=click.Path(exists=True, dir_okay=False, path_type=Path),
default=_DEFAULT_STRATEGY_PATH,
show_default=True,
)
def config_hash(yaml_path: Path) -> None:
"""Compute and print the SHA-256 of *yaml_path* (config_hash field excluded)."""
text = yaml_path.read_text(encoding="utf-8")
digest = compute_config_hash(text)
console.print(digest)
@config.command(name="validate")
@click.option(
    "--file",
    "yaml_path",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
    default=_DEFAULT_STRATEGY_PATH,
    show_default=True,
)
@click.option(
    "--enforce-hash/--no-enforce-hash",
    default=True,
    show_default=True,
    help="When enabled, the recorded config_hash must match the file body.",
)
def config_validate(yaml_path: Path, enforce_hash: bool) -> None:
    """Load and validate ``strategy.yaml`` (and any local override)."""
    # Raises (and thus exits non-zero via click) on schema errors or, with
    # --enforce-hash, on a config_hash mismatch.
    loaded = load_strategy(yaml_path, enforce_hash=enforce_hash)
    # Each implicitly-concatenated fragment ends with a space so the summary
    # fields stay separated (previously "hash=<hex>" ran into "sources=").
    console.print(
        f"[green]ok[/green] version={loaded.config.config_version} "
        f"hash={loaded.computed_hash[:16]} "
        f"sources={', '.join(p.name for p in loaded.sources)}"
    )
@main.group() @main.group()
@@ -134,9 +293,87 @@ def audit() -> None:
@audit.command(name="verify") @audit.command(name="verify")
def audit_verify() -> None: @click.option(
"""Verify audit chain integrity.""" "--file",
_phase0_notice("audit verify placeholder; will walk audit.log hash chain.") "audit_path",
type=click.Path(dir_okay=False, path_type=Path),
default=_DEFAULT_AUDIT_PATH,
show_default=True,
)
def audit_verify(audit_path: Path) -> None:
"""Walk the hash chain in *audit_path* and report tampering."""
try:
count = verify_audit_chain(audit_path)
except AuditChainError as exc:
console.print(f"[red]TAMPERED[/red]: {exc}")
sys.exit(2)
if count == 0:
console.print("[yellow]audit log empty[/yellow]")
else:
console.print(f"[green]ok[/green] {count} entries verified")
# Container group: subcommands (e.g. "state inspect") attach via @state.command().
@main.group()
def state() -> None:
    """State inspection utilities."""
@state.command(name="inspect")
@click.option(
    "--db",
    type=click.Path(dir_okay=False, path_type=Path),
    default=_DEFAULT_DB_PATH,
    show_default=True,
)
def state_inspect(db: Path) -> None:
    """Print a short snapshot of the SQLite state file."""
    if not db.exists():
        console.print("[yellow]state.sqlite not found[/yellow]")
        return
    # Read everything through one short-lived connection; migrations run
    # first so the repository always sees the current schema.
    conn = connect_state(db)
    try:
        run_migrations(conn)
        repo = Repository()
        sys_state = repo.get_system_state(conn)
        positions = repo.list_open_positions(conn)
        concurrent = repo.count_concurrent_positions(conn)
    finally:
        conn.close()
    if sys_state is None:
        console.print("[yellow]system_state singleton missing[/yellow]")
        return
    armed = "[red]ARMED[/red]" if sys_state.kill_switch == 1 else "[green]disarmed[/green]"
    console.print(
        f"engine state: kill_switch={armed}, "
        f"open positions: {concurrent}, "
        f"config_version: {sys_state.config_version}"
    )
    if not positions:
        console.print("no open positions")
        return
    table = Table(title="open positions")
    table.add_column("proposal_id")
    table.add_column("status")
    table.add_column("spread")
    table.add_column("short")
    table.add_column("long")
    table.add_column("n")
    table.add_column("expiry")
    for pos in positions:
        table.add_row(
            # First 8 chars of the proposal UUID keep the table narrow.
            str(pos.proposal_id)[:8],
            pos.status,
            pos.spread_type,
            str(pos.short_strike),
            str(pos.long_strike),
            str(pos.n_contracts),
            pos.expiry.isoformat(),
        )
    console.print(table)
def _entrypoint() -> None: def _entrypoint() -> None:
+10
View File
@@ -1,5 +1,11 @@
"""Strategy configuration: schema, loader, validation.""" """Strategy configuration: schema, loader, validation."""
from cerbero_bite.config.loader import (
ConfigHashError,
LoadedConfig,
compute_config_hash,
load_strategy,
)
from cerbero_bite.config.schema import ( from cerbero_bite.config.schema import (
AssetConfig, AssetConfig,
DvolAdjustmentBand, DvolAdjustmentBand,
@@ -23,12 +29,14 @@ from cerbero_bite.config.schema import (
__all__ = [ __all__ = [
"AssetConfig", "AssetConfig",
"ConfigHashError",
"DvolAdjustmentBand", "DvolAdjustmentBand",
"EntryConfig", "EntryConfig",
"ExecutionConfig", "ExecutionConfig",
"ExitConfig", "ExitConfig",
"KellyConfig", "KellyConfig",
"LiquidityConfig", "LiquidityConfig",
"LoadedConfig",
"McpConfig", "McpConfig",
"MonitoringConfig", "MonitoringConfig",
"ShortStrikeSpec", "ShortStrikeSpec",
@@ -39,5 +47,7 @@ __all__ = [
"StrategyConfig", "StrategyConfig",
"StructureConfig", "StructureConfig",
"TelegramConfig", "TelegramConfig",
"compute_config_hash",
"golden_config", "golden_config",
"load_strategy",
] ]
+141
View File
@@ -0,0 +1,141 @@
"""YAML loader for ``strategy.yaml`` with optional local override.
* Reads ``strategy.yaml`` (golden config).
* If ``strategy.local.yaml`` exists alongside, deep-merges its keys on
top — that file is ``.gitignore``'d and used by Adriano for emergency
overrides.
* Verifies ``config_hash`` matches the SHA-256 of the YAML *minus* the
``config_hash`` line itself. A mismatch is reported via
:class:`ConfigHashError` and the orchestrator must arm the kill switch
per ``docs/07-risk-controls.md``.
The loader does *not* depend on the runtime: it returns a validated
:class:`StrategyConfig` plus the computed hash; nothing else.
"""
from __future__ import annotations
import hashlib
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import yaml
from cerbero_bite.config.schema import StrategyConfig
__all__ = [
"ConfigHashError",
"LoadedConfig",
"compute_config_hash",
"load_strategy",
]
_HASH_KEY = "config_hash"
class ConfigHashError(RuntimeError):
    """Raised when the recorded ``config_hash`` does not match the file.

    Per the module contract (see module docstring) the orchestrator must
    treat this as a hard failure and arm the kill switch.
    """
@dataclass(frozen=True)
class LoadedConfig:
    """Result of :func:`load_strategy`."""

    # Validated strategy configuration (after any local-override merge).
    config: StrategyConfig
    # SHA-256 of the main YAML with the config_hash value stripped.
    computed_hash: str
    # Files that contributed, in merge order: (strategy.yaml,) or
    # (strategy.yaml, strategy.local.yaml).
    sources: tuple[Path, ...]
def _strip_hash_line(text: str) -> str:
    """Return *text* with the value of the ``config_hash:`` line removed.

    The key itself is kept (normalised to ``config_hash:`` with its
    original indentation) so two files that differ only in the recorded
    hash canonicalise to identical text, while every other line keeps
    its position and line number.
    """
    out: list[str] = []
    for line in text.splitlines(keepends=True):
        stripped = line.lstrip()
        if stripped.startswith(f"{_HASH_KEY}:"):
            # keep the key but strip the value, so identical files with
            # different hashes still hash the same way
            indent = line[: len(line) - len(stripped)]
            out.append(f"{indent}{_HASH_KEY}:\n")
        else:
            out.append(line)
    return "".join(out)
def compute_config_hash(text: str) -> str:
    """SHA-256 (hex) of the YAML text with the ``config_hash`` value stripped."""
    canonical_text = _strip_hash_line(text)
    return hashlib.sha256(canonical_text.encode("utf-8")).hexdigest()
def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
"""Return ``base`` with values from ``override`` recursively merged in."""
out = dict(base)
for key, value in override.items():
existing = out.get(key)
if isinstance(existing, dict) and isinstance(value, dict):
out[key] = _deep_merge(existing, value)
else:
out[key] = value
return out
def _load_yaml(path: Path) -> dict[str, Any]:
    """Parse *path* as YAML, requiring a top-level mapping (empty file → {})."""
    parsed = yaml.safe_load(path.read_text(encoding="utf-8"))
    if parsed is None:
        return {}
    if not isinstance(parsed, dict):
        raise ValueError(f"{path}: expected a top-level mapping")
    return parsed
def load_strategy(
    yaml_path: Path | str,
    *,
    local_override_path: Path | str | None = None,
    enforce_hash: bool = True,
) -> LoadedConfig:
    """Load and validate a strategy YAML, optionally merging a local file.

    Args:
        yaml_path: path to ``strategy.yaml``.
        local_override_path: when ``None`` (default), use
            ``<stem>.local.yaml`` next to the main file if it exists.
            Pass a path that does not exist to disable the override
            explicitly (the previous advice to pass ``False`` would
            raise ``TypeError`` inside ``Path()``).
        enforce_hash: when ``True``, raise :class:`ConfigHashError` if
            the recorded hash does not match the file. Set to ``False``
            in test fixtures or right after a manual edit.

    Returns:
        :class:`LoadedConfig` with the validated config, the hash of the
        **main** file, and the source paths read.

    Note:
        The hash covers only ``strategy.yaml``; the local override is
        merged *after* verification and is outside the hash check — it
        is the operator's emergency escape hatch by design.
    """
    main_path = Path(yaml_path)
    text = main_path.read_text(encoding="utf-8")
    raw = _load_yaml(main_path)
    sources: list[Path] = [main_path]
    computed_hash = compute_config_hash(text)
    declared_hash = raw.get(_HASH_KEY)
    if enforce_hash and declared_hash != computed_hash:
        raise ConfigHashError(
            f"config_hash mismatch in {main_path}: "
            f"declared={declared_hash}, computed={computed_hash}"
        )
    if local_override_path is None:
        # strategy.yaml -> strategy.local.yaml in the same directory.
        local_override_path = main_path.with_name(
            main_path.stem + ".local" + main_path.suffix
        )
    override_path = Path(local_override_path)
    if override_path.is_file():
        override = _load_yaml(override_path)
        raw = _deep_merge(raw, override)
        sources.append(override_path)
    return LoadedConfig(
        config=StrategyConfig(**raw),
        computed_hash=computed_hash,
        sources=tuple(sources),
    )
+19
View File
@@ -0,0 +1,19 @@
"""Cross-cutting safety controls: kill switch, dead-man, audit chain."""
from cerbero_bite.safety.audit_log import (
GENESIS_HASH,
AuditChainError,
AuditEntry,
AuditLog,
iter_entries,
verify_chain,
)
__all__ = [
"GENESIS_HASH",
"AuditChainError",
"AuditEntry",
"AuditLog",
"iter_entries",
"verify_chain",
]
+246
View File
@@ -0,0 +1,246 @@
"""Append-only hash-chained audit log (``docs/07-risk-controls.md``).
Every line is::
<iso-ts>|<event>|<json-payload>|prev_hash=<hex>|hash=<hex>
``hash`` is ``sha256("<iso-ts>|<event>|<json-payload>|<prev_hash>")`` so
that the integrity of the file can be re-verified by walking the chain
top-to-bottom. The first line uses ``prev_hash="0" * 64``.
The writer ``flush + os.fsync`` after every append; for the audit trail,
durability beats throughput.
"""
from __future__ import annotations
import hashlib
import json
import os
from collections.abc import Iterator
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
__all__ = [
"GENESIS_HASH",
"AuditChainError",
"AuditEntry",
"AuditLog",
"verify_chain",
]
GENESIS_HASH = "0" * 64
_SEP = "|"
class AuditChainError(RuntimeError):
    """Raised when the audit chain fails verification (malformed or tampered line)."""
@dataclass(frozen=True)
class AuditEntry:
    """Parsed audit-log line."""

    # Event time, parsed back from the line's ISO-8601 field (UTC on write).
    timestamp: datetime
    # Short event label, e.g. "KILL_SWITCH_ARMED".
    event: str
    # JSON payload exactly as stored on the line.
    payload: dict[str, Any]
    # Hash of the previous line (GENESIS_HASH for the first line).
    prev_hash: str
    # SHA-256 over "ts|event|payload|prev_hash" (see _compute_hash).
    hash: str
def _canonical_payload(payload: dict[str, Any]) -> str:
"""Serialize the payload deterministically (sorted keys, no whitespace)."""
return json.dumps(payload, sort_keys=True, separators=(",", ":"), default=str)
def _compute_hash(timestamp: str, event: str, payload_json: str, prev_hash: str) -> str:
    """SHA-256 over the four fields joined by the separator — one chain link."""
    material = _SEP.join((timestamp, event, payload_json, prev_hash))
    return hashlib.sha256(material.encode("utf-8")).hexdigest()
def _format_line(
    timestamp: str, event: str, payload_json: str, prev_hash: str, line_hash: str
) -> str:
    """Render one newline-terminated audit line from its fields."""
    fields = (
        timestamp,
        event,
        payload_json,
        f"prev_hash={prev_hash}",
        f"hash={line_hash}",
    )
    return _SEP.join(fields) + "\n"
def _parse_line(line: str) -> AuditEntry:
    """Parse a stored audit line back into :class:`AuditEntry`.

    The payload may legitimately contain ``|`` characters inside JSON
    strings; we therefore split from the right for the two trailing
    fields and from the left for the timestamp + event + payload.

    Raises:
        AuditChainError: for any malformed field, *including* a
            timestamp that is not valid ISO-8601 — every parse failure
            must surface as a chain error so callers such as
            ``verify_chain`` keep their documented contract instead of
            leaking ``ValueError``.
    """
    body = line.rstrip("\n")
    # Trailing parts.
    try:
        rest, hash_part = body.rsplit(_SEP, 1)
    except ValueError as exc:
        raise AuditChainError("missing hash= field") from exc
    if not hash_part.startswith("hash="):
        raise AuditChainError("missing hash= field")
    line_hash = hash_part[len("hash=") :]
    try:
        rest, prev_part = rest.rsplit(_SEP, 1)
    except ValueError as exc:
        raise AuditChainError("missing prev_hash= field") from exc
    if not prev_part.startswith("prev_hash="):
        raise AuditChainError("missing prev_hash= field")
    prev_hash = prev_part[len("prev_hash=") :]
    # Leading parts.
    parts = rest.split(_SEP, 2)
    if len(parts) != 3:
        raise AuditChainError("malformed leading section")
    ts_str, event, payload_json = parts
    try:
        payload: dict[str, Any] = json.loads(payload_json)
    except json.JSONDecodeError as exc:
        raise AuditChainError("payload is not valid JSON") from exc
    if not isinstance(payload, dict):
        raise AuditChainError("payload must be a JSON object")
    try:
        timestamp = datetime.fromisoformat(ts_str)
    except ValueError as exc:
        # A mangled timestamp is tampering too; previously this leaked
        # a bare ValueError out of verify_chain.
        raise AuditChainError("invalid timestamp field") from exc
    return AuditEntry(
        timestamp=timestamp,
        event=event,
        payload=payload,
        prev_hash=prev_hash,
        hash=line_hash,
    )
def verify_chain(path: str | Path) -> int:
    """Re-walk *path* and raise :class:`AuditChainError` on tampering.

    Returns the number of lines verified (0 if the file does not exist
    or is empty).
    """
    log_path = Path(path)
    if not log_path.exists() or log_path.stat().st_size == 0:
        return 0
    prev = GENESIS_HASH
    verified = 0
    with log_path.open("r", encoding="utf-8") as handle:
        for lineno, raw in enumerate(handle, start=1):
            if not raw.strip():
                # Blank lines are tolerated and not counted.
                continue
            entry = _parse_line(raw)
            if entry.prev_hash != prev:
                raise AuditChainError(
                    f"line {lineno}: prev_hash mismatch "
                    f"(expected {prev}, got {entry.prev_hash})"
                )
            recomputed = _compute_hash(
                entry.timestamp.isoformat(),
                entry.event,
                _canonical_payload(entry.payload),
                entry.prev_hash,
            )
            if recomputed != entry.hash:
                raise AuditChainError(
                    f"line {lineno}: hash mismatch (expected {recomputed}, "
                    f"got {entry.hash})"
                )
            prev = entry.hash
            verified += 1
    return verified
def iter_entries(path: str | Path) -> Iterator[AuditEntry]:
    """Yield each :class:`AuditEntry` from *path* without verifying."""
    log_path = Path(path)
    if not log_path.exists():
        return
    with log_path.open("r", encoding="utf-8") as handle:
        for raw in handle:
            if not raw.strip():
                continue
            yield _parse_line(raw)
class AuditLog:
    """Writer for the hash-chained audit log.

    A single instance per process is enough; concurrent writers are not
    supported by design (the engine is the only writer). ``append`` is
    fsync'd before returning.
    """

    def __init__(self, path: str | Path) -> None:
        self._path = Path(path)
        self._path.parent.mkdir(parents=True, exist_ok=True)
        # Resume the chain from the last line already on disk, or start
        # a fresh chain at the genesis hash.
        self._last_hash: str = self._tail_hash() or GENESIS_HASH

    @property
    def path(self) -> Path:  # pragma: no cover — accessor used by callers only
        return self._path

    @property
    def last_hash(self) -> str:
        """Hash of the most recently written line (GENESIS_HASH when empty)."""
        return self._last_hash

    def _tail_hash(self) -> str | None:
        """Return the ``hash`` field of the last non-empty line, or ``None``."""
        if not self._path.exists() or self._path.stat().st_size == 0:
            return None
        # Walk from EOF to find the last non-empty line. The chunked
        # back-seek covers files larger than 4 KiB; the loop-exhausted
        # branch is reached only when a partial / no-newline file is
        # encountered (defensive — :func:`append` always writes "\n").
        with self._path.open("rb") as fh:
            fh.seek(0, os.SEEK_END)
            size = fh.tell()
            buf = b""
            offset = size
            chunk = 4096
            while offset > 0:  # pragma: no branch — terminates via break or offset==0
                read = min(chunk, offset)
                offset -= read
                fh.seek(offset)
                buf = fh.read(read) + buf
                if b"\n" in buf:
                    break
            text = buf.decode("utf-8", errors="strict")
        for line in reversed(text.splitlines()):  # pragma: no branch
            if line.strip():
                entry = _parse_line(line)
                return entry.hash
        return None  # pragma: no cover — only hit when file is all blank lines

    def append(
        self,
        *,
        event: str,
        payload: dict[str, Any] | None = None,
        now: datetime | None = None,
    ) -> AuditEntry:
        """Append one event line and return the resulting entry.

        Args:
            event: short event label (e.g. ``"KILL_SWITCH_ARMED"``).
            payload: JSON-serializable mapping; ``None`` means ``{}``.
            now: timestamp override for tests; defaults to current UTC.
        """
        ts = (now or datetime.now(UTC)).astimezone(UTC)
        ts_iso = ts.isoformat()
        payload_json = _canonical_payload(payload or {})
        prev_hash = self._last_hash
        line_hash = _compute_hash(ts_iso, event, payload_json, prev_hash)
        line = _format_line(ts_iso, event, payload_json, prev_hash, line_hash)
        with self._path.open("a", encoding="utf-8") as fh:
            fh.write(line)
            fh.flush()
            # Durability beats throughput for the audit trail.
            os.fsync(fh.fileno())
        self._last_hash = line_hash
        return AuditEntry(
            timestamp=ts,
            event=event,
            payload=dict(payload or {}),
            prev_hash=prev_hash,
            hash=line_hash,
        )
+120
View File
@@ -0,0 +1,120 @@
"""Kill switch coordinator (``docs/07-risk-controls.md``).
Encapsulates the three side effects required when arming/disarming:
1. Update the ``system_state`` row in SQLite.
2. Append a tamper-evident line to the audit log.
3. Make the new state observable via :meth:`is_armed` for the
orchestrator.
The orchestrator is the only caller; the GUI signals the same intent
via the :class:`ManualAction` queue, which the orchestrator drains and
forwards here.
"""
from __future__ import annotations
import sqlite3
from collections.abc import Callable
from datetime import UTC, datetime
from cerbero_bite.safety.audit_log import AuditLog
from cerbero_bite.state import Repository, transaction
__all__ = ["KillSwitch", "KillSwitchError"]
class KillSwitchError(RuntimeError):
    """Raised when an arm/disarm transition is invalid (empty reason,
    missing ``system_state`` singleton)."""
class KillSwitch:
    """Arm/disarm + status helper backed by ``system_state`` + audit log.

    All transitions must go through this class so the SQLite row and
    the audit chain stay in lock-step.
    """

    def __init__(
        self,
        *,
        connection_factory: Callable[[], sqlite3.Connection],
        repository: Repository,
        audit_log: AuditLog,
        clock: Callable[[], datetime] | None = None,
    ) -> None:
        # A fresh connection is opened per operation; the class holds no
        # long-lived handle.
        self._connect = connection_factory
        self._repo = repository
        self._audit = audit_log
        # Injectable clock for deterministic tests; defaults to UTC now.
        self._clock = clock or (lambda: datetime.now(UTC))

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------
    def is_armed(self) -> bool:
        """Return ``True`` when the ``system_state`` row says armed.

        A missing singleton reads as "not armed".
        """
        conn = self._connect()
        try:
            state = self._repo.get_system_state(conn)
        finally:
            conn.close()
        if state is None:
            return False
        return state.kill_switch == 1

    # ------------------------------------------------------------------
    # Transitions
    # ------------------------------------------------------------------
    def arm(self, *, reason: str, source: str) -> None:
        """Move the engine to the armed state.

        ``source`` is a short label (``"manual"``, ``"mcp_timeout"``,
        ``"hash_chain"``) that goes both into the audit payload and the
        log so we can attribute the trigger later.

        Raises:
            KillSwitchError: when ``reason`` is empty or the
                ``system_state`` singleton does not exist yet.
        """
        if not reason:
            raise KillSwitchError("reason is required to arm the kill switch")
        now = self._clock()
        conn = self._connect()
        try:
            with transaction(conn):
                state = self._repo.get_system_state(conn)
                if state is None:
                    raise KillSwitchError(
                        "system_state singleton missing — call init_system_state first"
                    )
                if state.kill_switch == 1:
                    # Idempotent: already armed — skip both the row update
                    # and the audit append (no duplicate chain events).
                    return
                self._repo.set_kill_switch(conn, armed=True, reason=reason, now=now)
        finally:
            conn.close()
        # NOTE(review): the audit append runs after the transaction has
        # committed; a crash between the two leaves the row updated with no
        # matching audit line — confirm this is acceptable per
        # docs/07-risk-controls.md.
        self._audit.append(
            event="KILL_SWITCH_ARMED",
            payload={"reason": reason, "source": source},
            now=now,
        )

    def disarm(self, *, reason: str, source: str) -> None:
        """Move the engine back to the disarmed state.

        The DB row's ``kill_reason`` is cleared (``reason=None``); the
        human-supplied disarm reason is preserved in the audit payload.

        Raises:
            KillSwitchError: when ``reason`` is empty or the
                ``system_state`` singleton does not exist yet.
        """
        if not reason:
            raise KillSwitchError("reason is required to disarm the kill switch")
        now = self._clock()
        conn = self._connect()
        try:
            with transaction(conn):
                state = self._repo.get_system_state(conn)
                if state is None:
                    raise KillSwitchError(
                        "system_state singleton missing — call init_system_state first"
                    )
                if state.kill_switch == 0:
                    # Idempotent: already disarmed — no update, no audit line.
                    return
                self._repo.set_kill_switch(conn, armed=False, reason=None, now=now)
        finally:
            conn.close()
        # Same commit-then-append ordering caveat as in arm().
        self._audit.append(
            event="KILL_SWITCH_DISARMED",
            payload={"reason": reason, "source": source},
            now=now,
        )
+27
View File
@@ -0,0 +1,27 @@
"""Persistent state: SQLite schema, migrations, typed repository."""
from cerbero_bite.state.db import connect, run_migrations, transaction
from cerbero_bite.state.models import (
DecisionRecord,
DvolSnapshot,
InstructionRecord,
ManualAction,
PositionRecord,
PositionStatus,
SystemStateRecord,
)
from cerbero_bite.state.repository import Repository
__all__ = [
"DecisionRecord",
"DvolSnapshot",
"InstructionRecord",
"ManualAction",
"PositionRecord",
"PositionStatus",
"Repository",
"SystemStateRecord",
"connect",
"run_migrations",
"transaction",
]
+102
View File
@@ -0,0 +1,102 @@
"""SQLite connection helpers and forward-only migrations.
Connections use ``sqlite3`` with a few pragmas tuned for our workload:
* ``foreign_keys=ON`` — referential integrity for ``instructions`` →
``positions``.
* ``journal_mode=WAL`` — concurrent reader (the GUI) without blocking
the engine writer.
* ``synchronous=NORMAL`` — durable enough for our use-case (we also
append-fsync the audit chain) and noticeably faster than FULL.
Migrations live in ``state/migrations/NNNN_<name>.sql``. They are
applied in numeric order, each bumps ``PRAGMA user_version`` at the
last statement of the file. ``run_migrations`` is idempotent: it
re-checks ``user_version`` and replays only what is missing.
"""
from __future__ import annotations
import re
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from importlib import resources
from pathlib import Path
__all__ = ["connect", "current_version", "list_migrations", "run_migrations"]
_MIGRATION_PATTERN = re.compile(r"^(\d{4})_[a-z0-9_]+\.sql$")
def _apply_pragmas(conn: sqlite3.Connection) -> None:
conn.execute("PRAGMA foreign_keys = ON")
conn.execute("PRAGMA journal_mode = WAL")
conn.execute("PRAGMA synchronous = NORMAL")
def connect(path: str | Path) -> sqlite3.Connection:
"""Open a connection to ``path`` with the standard pragmas applied."""
db_path = Path(path)
db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(
str(db_path),
isolation_level=None, # autocommit; we manage transactions explicitly
detect_types=sqlite3.PARSE_DECLTYPES,
)
conn.row_factory = sqlite3.Row
_apply_pragmas(conn)
return conn
@contextmanager
def transaction(conn: sqlite3.Connection) -> Iterator[sqlite3.Connection]:
    """Begin/commit/rollback wrapper that respects autocommit mode.

    Commits when the ``with`` body completes (including via early
    ``return``); rolls back and re-raises when it raises.
    """
    conn.execute("BEGIN")
    committed = False
    try:
        yield conn
        committed = True
    finally:
        # On the exception path the original exception still propagates
        # after the rollback has been issued.
        conn.execute("COMMIT" if committed else "ROLLBACK")
def current_version(conn: sqlite3.Connection) -> int:
    """Return the database's ``PRAGMA user_version`` as an ``int``."""
    (version,) = conn.execute("PRAGMA user_version").fetchone()
    return int(version)
def list_migrations() -> list[tuple[int, str, str]]:
    """Return ``[(version, filename, sql_text), ...]`` sorted by version.

    Only files matching ``NNNN_<snake_name>.sql`` inside the packaged
    ``state/migrations`` directory are considered.
    """
    pkg_files = resources.files("cerbero_bite.state.migrations")
    migrations = [
        (int(m.group(1)), entry.name, entry.read_text(encoding="utf-8"))
        for entry in pkg_files.iterdir()
        if (m := _MIGRATION_PATTERN.match(entry.name)) is not None
    ]
    # Sort by version only (stable), matching the forward-only ordering.
    migrations.sort(key=lambda triple: triple[0])
    return migrations
def run_migrations(conn: sqlite3.Connection) -> int:
    """Apply migrations whose version is greater than the current one.

    Each migration script must end by bumping ``PRAGMA user_version``
    to its own number; this is verified after every script and a
    :class:`RuntimeError` is raised on mismatch.

    Returns the new ``user_version`` after the run.
    """
    applied = current_version(conn)
    for target, name, sql in list_migrations():
        if target <= applied:
            continue  # already applied on a previous run (idempotence)
        conn.executescript(sql)
        observed = current_version(conn)
        if observed != target:
            raise RuntimeError(
                f"Migration {name} did not bump user_version to {target} "
                f"(observed {observed})"
            )
        applied = observed
    return applied
@@ -0,0 +1,101 @@
-- 0001_init.sql — initial schema for Cerbero Bite (docs/05-data-model.md)
--
-- Forward-only. Applied by run_migrations in state/db.py once
-- user_version == 0. Bumps user_version to 1 at the end.
--
-- NUMERIC columns carry Decimal values serialised as TEXT by the
-- repository layer (state/repository.py); all timestamps are ISO-8601
-- TEXT in UTC.
PRAGMA foreign_keys = ON;
-- One row per proposed/opened spread; proposal_id is the natural key.
CREATE TABLE positions (
    proposal_id TEXT PRIMARY KEY,
    spread_type TEXT NOT NULL,
    asset TEXT NOT NULL DEFAULT 'ETH',
    expiry TEXT NOT NULL,
    short_strike NUMERIC NOT NULL,
    long_strike NUMERIC NOT NULL,
    short_instrument TEXT NOT NULL,
    long_instrument TEXT NOT NULL,
    n_contracts INTEGER NOT NULL,
    spread_width_usd NUMERIC NOT NULL,
    spread_width_pct NUMERIC NOT NULL,
    credit_eth NUMERIC NOT NULL,
    credit_usd NUMERIC NOT NULL,
    max_loss_usd NUMERIC NOT NULL,
    -- Market snapshot captured at entry time.
    spot_at_entry NUMERIC NOT NULL,
    dvol_at_entry NUMERIC NOT NULL,
    delta_at_entry NUMERIC NOT NULL,
    eth_price_at_entry NUMERIC NOT NULL,
    -- Lifecycle timestamps; NULL until the corresponding transition.
    proposed_at TEXT NOT NULL,
    opened_at TEXT,
    closed_at TEXT,
    close_reason TEXT,
    debit_paid_eth NUMERIC,
    pnl_eth NUMERIC,
    pnl_usd NUMERIC,
    status TEXT NOT NULL,
    created_at TEXT NOT NULL,
    updated_at TEXT NOT NULL
);
CREATE INDEX idx_positions_status ON positions(status);
CREATE INDEX idx_positions_closed_at ON positions(closed_at);
-- Orders sent to the exchange for a position (open/close combos).
CREATE TABLE instructions (
    instruction_id TEXT PRIMARY KEY,
    proposal_id TEXT NOT NULL REFERENCES positions(proposal_id),
    kind TEXT NOT NULL,
    payload_json TEXT NOT NULL,
    sent_at TEXT NOT NULL,
    acknowledged_at TEXT,
    filled_at TEXT,
    cancelled_at TEXT,
    actual_fill_eth NUMERIC,
    actual_fees_eth NUMERIC
);
CREATE INDEX idx_instructions_proposal ON instructions(proposal_id);
-- Append-only decision journal (inputs/outputs stored as JSON TEXT).
CREATE TABLE decisions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    decision_type TEXT NOT NULL,
    proposal_id TEXT,
    timestamp TEXT NOT NULL,
    inputs_json TEXT NOT NULL,
    outputs_json TEXT NOT NULL,
    action_taken TEXT,
    notes TEXT
);
CREATE INDEX idx_decisions_timestamp ON decisions(timestamp);
CREATE INDEX idx_decisions_proposal ON decisions(proposal_id);
-- DVOL / spot time series, one row per sampling timestamp.
CREATE TABLE dvol_history (
    timestamp TEXT PRIMARY KEY,
    dvol NUMERIC NOT NULL,
    eth_spot NUMERIC NOT NULL
);
-- Operator-initiated actions queued for the engine to consume.
CREATE TABLE manual_actions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    kind TEXT NOT NULL,
    proposal_id TEXT,
    payload_json TEXT,
    created_at TEXT NOT NULL,
    consumed_at TEXT,
    consumed_by TEXT,
    result TEXT
);
CREATE INDEX idx_manual_actions_unconsumed ON manual_actions(consumed_at);
-- Singleton row (id = 1 enforced by CHECK): kill switch + health state.
CREATE TABLE system_state (
    id INTEGER PRIMARY KEY CHECK (id = 1),
    kill_switch INTEGER NOT NULL DEFAULT 0,
    kill_reason TEXT,
    kill_at TEXT,
    last_health_check TEXT NOT NULL,
    last_kelly_calib TEXT,
    config_version TEXT NOT NULL,
    started_at TEXT NOT NULL
);
PRAGMA user_version = 1;
@@ -0,0 +1 @@
"""Forward-only SQL migrations for the Cerbero Bite state database."""
+154
View File
@@ -0,0 +1,154 @@
"""Pydantic record types mirroring the SQLite tables.
Every numeric column documented as ``NUMERIC`` in
``state/migrations/0001_init.sql`` is exposed as :class:`decimal.Decimal`
on the Python side. The repository layer is responsible for serialising
to ``TEXT`` (using ``str``) when writing and parsing back when reading,
so precision is never lost via ``float`` coercion.
"""
from __future__ import annotations
from datetime import datetime
from decimal import Decimal
from typing import Literal
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
__all__ = [
"DecisionRecord",
"DvolSnapshot",
"InstructionRecord",
"ManualAction",
"PositionRecord",
"PositionStatus",
"SystemStateRecord",
]
# Lifecycle states of a position row; the orchestrator drives the
# transitions proposed → awaiting_fill → open → closing → closed
# (or → cancelled before a fill).
PositionStatus = Literal[
    "proposed",
    "awaiting_fill",
    "open",
    "closing",
    "closed",
    "cancelled",
]
class PositionRecord(BaseModel):
    """Row of the ``positions`` table.

    All ``NUMERIC`` columns surface as :class:`~decimal.Decimal`; the
    repository serialises them to TEXT on write (see module docstring).
    """

    model_config = ConfigDict(extra="forbid")
    # Identity and structure of the spread.
    proposal_id: UUID
    spread_type: str
    asset: str = "ETH"
    expiry: datetime
    short_strike: Decimal
    long_strike: Decimal
    short_instrument: str
    long_instrument: str
    n_contracts: int
    spread_width_usd: Decimal
    spread_width_pct: Decimal
    credit_eth: Decimal
    credit_usd: Decimal
    max_loss_usd: Decimal
    # Market snapshot captured at entry time.
    spot_at_entry: Decimal
    dvol_at_entry: Decimal
    delta_at_entry: Decimal
    eth_price_at_entry: Decimal
    # Lifecycle fields; None until the corresponding transition happens.
    proposed_at: datetime
    opened_at: datetime | None = None
    closed_at: datetime | None = None
    close_reason: str | None = None
    debit_paid_eth: Decimal | None = None
    pnl_eth: Decimal | None = None
    pnl_usd: Decimal | None = None
    status: PositionStatus
    created_at: datetime
    updated_at: datetime
class InstructionRecord(BaseModel):
    """Row of the ``instructions`` table.

    One row per combo order sent for a position; the ``*_at`` fields
    stay ``None`` until the exchange reports the matching event.
    """

    model_config = ConfigDict(extra="forbid")
    instruction_id: UUID
    # FK to positions.proposal_id.
    proposal_id: UUID
    kind: Literal["open_combo", "close_combo"]
    payload_json: str
    sent_at: datetime
    acknowledged_at: datetime | None = None
    filled_at: datetime | None = None
    cancelled_at: datetime | None = None
    actual_fill_eth: Decimal | None = None
    actual_fees_eth: Decimal | None = None
class DecisionRecord(BaseModel):
    """Row of the ``decisions`` table.

    ``id`` is :class:`int` and may be ``None`` before the row has been
    inserted; the repository sets it after the auto-increment fires.
    """

    model_config = ConfigDict(extra="forbid")
    id: int | None = None
    decision_type: Literal["entry_check", "exit_check", "kelly_recalib"]
    proposal_id: UUID | None = None
    timestamp: datetime
    # Full inputs/outputs of the rule evaluation, serialised as JSON.
    inputs_json: str
    outputs_json: str
    action_taken: str | None = None
    notes: str | None = None
class DvolSnapshot(BaseModel):
    """Row of the ``dvol_history`` table (one sample per timestamp)."""

    model_config = ConfigDict(extra="forbid")
    timestamp: datetime
    dvol: Decimal
    eth_spot: Decimal
class ManualAction(BaseModel):
    """Row of the ``manual_actions`` table.

    Operator-queued actions; ``consumed_at``/``consumed_by``/``result``
    are filled in once the engine processes the action.
    """

    model_config = ConfigDict(extra="forbid")
    # None before the auto-increment fires on insert.
    id: int | None = None
    kind: Literal[
        "approve_proposal",
        "reject_proposal",
        "force_close",
        "arm_kill",
        "disarm_kill",
    ]
    proposal_id: UUID | None = None
    payload_json: str | None = None
    created_at: datetime
    consumed_at: datetime | None = None
    consumed_by: str | None = None
    result: str | None = None
class SystemStateRecord(BaseModel):
    """Singleton row of the ``system_state`` table (``id`` fixed to 1)."""

    model_config = ConfigDict(extra="forbid")
    id: int = Field(default=1)
    # 0 = disarmed, 1 = armed (mirrors the INTEGER column, not a bool).
    kill_switch: int = 0
    kill_reason: str | None = None
    kill_at: datetime | None = None
    last_health_check: datetime
    last_kelly_calib: datetime | None = None
    config_version: str
    started_at: datetime
+553
View File
@@ -0,0 +1,553 @@
"""Typed CRUD layer over the SQLite database.
All methods take a :class:`sqlite3.Connection` so callers can compose a
single transaction across multiple writes (the orchestrator does this
when persisting an entry decision + the resulting proposal). The
repository never opens its own connection: that responsibility is left
to :func:`cerbero_bite.state.db.connect`.
Decimals are stored as TEXT to preserve precision (see
``state/models.py``).
"""
from __future__ import annotations
import sqlite3
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
from uuid import UUID
from cerbero_bite.state.models import (
DecisionRecord,
DvolSnapshot,
InstructionRecord,
ManualAction,
PositionRecord,
PositionStatus,
SystemStateRecord,
)
__all__ = ["Repository"]
# ---------------------------------------------------------------------------
# Encoding helpers
# ---------------------------------------------------------------------------
def _enc_dec(value: Decimal | None) -> str | None:
return None if value is None else str(value)
def _enc_dt(value: datetime | None) -> str | None:
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=UTC)
return value.astimezone(UTC).isoformat()
def _enc_uuid(value: UUID | None) -> str | None:
return None if value is None else str(value)
def _dec_dec(value: Any) -> Decimal | None:
if value is None:
return None
return Decimal(str(value))
def _dec_dt(value: Any) -> datetime | None:
if value is None:
return None
return datetime.fromisoformat(str(value))
def _dec_uuid(value: Any) -> UUID | None:
if value is None:
return None
return UUID(str(value))
# ---------------------------------------------------------------------------
# Repository
# ---------------------------------------------------------------------------
class Repository:
    """Typed CRUD wrapper. One instance per process, all calls take a conn.

    The repository never opens connections or transactions itself: the
    caller passes a connection (see :func:`cerbero_bite.state.db.connect`)
    and may wrap several calls in one ``transaction`` block.
    """

    # ------------------------------------------------------------------
    # positions
    # ------------------------------------------------------------------
    def create_position(self, conn: sqlite3.Connection, record: PositionRecord) -> None:
        """Insert a full ``positions`` row (Decimals/UUIDs/datetimes as TEXT)."""
        conn.execute(
            """
            INSERT INTO positions (
                proposal_id, spread_type, asset, expiry,
                short_strike, long_strike, short_instrument, long_instrument,
                n_contracts, spread_width_usd, spread_width_pct,
                credit_eth, credit_usd, max_loss_usd,
                spot_at_entry, dvol_at_entry, delta_at_entry, eth_price_at_entry,
                proposed_at, opened_at, closed_at, close_reason,
                debit_paid_eth, pnl_eth, pnl_usd,
                status, created_at, updated_at
            ) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)
            """,
            (
                _enc_uuid(record.proposal_id),
                record.spread_type,
                record.asset,
                _enc_dt(record.expiry),
                _enc_dec(record.short_strike),
                _enc_dec(record.long_strike),
                record.short_instrument,
                record.long_instrument,
                record.n_contracts,
                _enc_dec(record.spread_width_usd),
                _enc_dec(record.spread_width_pct),
                _enc_dec(record.credit_eth),
                _enc_dec(record.credit_usd),
                _enc_dec(record.max_loss_usd),
                _enc_dec(record.spot_at_entry),
                _enc_dec(record.dvol_at_entry),
                _enc_dec(record.delta_at_entry),
                _enc_dec(record.eth_price_at_entry),
                _enc_dt(record.proposed_at),
                _enc_dt(record.opened_at),
                _enc_dt(record.closed_at),
                record.close_reason,
                _enc_dec(record.debit_paid_eth),
                _enc_dec(record.pnl_eth),
                _enc_dec(record.pnl_usd),
                record.status,
                _enc_dt(record.created_at),
                _enc_dt(record.updated_at),
            ),
        )

    def get_position(
        self, conn: sqlite3.Connection, proposal_id: UUID
    ) -> PositionRecord | None:
        """Fetch a single position by primary key, or ``None`` if absent."""
        row = conn.execute(
            "SELECT * FROM positions WHERE proposal_id = ?",
            (_enc_uuid(proposal_id),),
        ).fetchone()
        return None if row is None else _row_to_position(row)

    def list_positions(
        self,
        conn: sqlite3.Connection,
        *,
        status: PositionStatus | None = None,
    ) -> list[PositionRecord]:
        """List positions, newest first, optionally filtered by status."""
        if status is None:
            cursor = conn.execute(
                "SELECT * FROM positions ORDER BY created_at DESC"
            )
        else:
            cursor = conn.execute(
                "SELECT * FROM positions WHERE status = ? ORDER BY created_at DESC",
                (status,),
            )
        return [_row_to_position(r) for r in cursor.fetchall()]

    def list_open_positions(self, conn: sqlite3.Connection) -> list[PositionRecord]:
        """List positions in any non-terminal state, newest first."""
        rows = conn.execute(
            "SELECT * FROM positions WHERE status IN ('open', 'awaiting_fill', "
            "'closing') ORDER BY created_at DESC"
        ).fetchall()
        return [_row_to_position(r) for r in rows]

    def update_position_status(
        self,
        conn: sqlite3.Connection,
        proposal_id: UUID,
        *,
        status: PositionStatus,
        opened_at: datetime | None = None,
        closed_at: datetime | None = None,
        close_reason: str | None = None,
        debit_paid_eth: Decimal | None = None,
        pnl_eth: Decimal | None = None,
        pnl_usd: Decimal | None = None,
        now: datetime,
    ) -> None:
        """Transition a position's status, updating only the fields provided.

        ``updated_at`` is always refreshed to ``now``. ``None`` kwargs
        leave the corresponding column untouched (they cannot be used to
        clear a previously-set value).
        """
        sets: list[str] = ["status = ?", "updated_at = ?"]
        params: list[Any] = [status, _enc_dt(now)]
        if opened_at is not None:
            sets.append("opened_at = ?")
            params.append(_enc_dt(opened_at))
        if closed_at is not None:
            sets.append("closed_at = ?")
            params.append(_enc_dt(closed_at))
        if close_reason is not None:
            sets.append("close_reason = ?")
            params.append(close_reason)
        if debit_paid_eth is not None:
            sets.append("debit_paid_eth = ?")
            params.append(_enc_dec(debit_paid_eth))
        if pnl_eth is not None:
            sets.append("pnl_eth = ?")
            params.append(_enc_dec(pnl_eth))
        if pnl_usd is not None:
            sets.append("pnl_usd = ?")
            params.append(_enc_dec(pnl_usd))
        params.append(_enc_uuid(proposal_id))
        conn.execute(
            f"UPDATE positions SET {', '.join(sets)} WHERE proposal_id = ?",
            params,
        )

    def count_concurrent_positions(self, conn: sqlite3.Connection) -> int:
        """Count positions in any non-terminal state (risk-limit check)."""
        row = conn.execute(
            "SELECT COUNT(*) FROM positions WHERE status IN "
            "('awaiting_fill', 'open', 'closing')"
        ).fetchone()
        return int(row[0])

    # ------------------------------------------------------------------
    # instructions
    # ------------------------------------------------------------------
    def create_instruction(
        self, conn: sqlite3.Connection, record: InstructionRecord
    ) -> None:
        """Insert an ``instructions`` row for an already-persisted position."""
        conn.execute(
            """
            INSERT INTO instructions (
                instruction_id, proposal_id, kind, payload_json,
                sent_at, acknowledged_at, filled_at, cancelled_at,
                actual_fill_eth, actual_fees_eth
            ) VALUES (?,?,?,?,?,?,?,?,?,?)
            """,
            (
                _enc_uuid(record.instruction_id),
                _enc_uuid(record.proposal_id),
                record.kind,
                record.payload_json,
                _enc_dt(record.sent_at),
                _enc_dt(record.acknowledged_at),
                _enc_dt(record.filled_at),
                _enc_dt(record.cancelled_at),
                _enc_dec(record.actual_fill_eth),
                _enc_dec(record.actual_fees_eth),
            ),
        )

    def update_instruction(
        self,
        conn: sqlite3.Connection,
        instruction_id: UUID,
        *,
        acknowledged_at: datetime | None = None,
        filled_at: datetime | None = None,
        cancelled_at: datetime | None = None,
        actual_fill_eth: Decimal | None = None,
        actual_fees_eth: Decimal | None = None,
    ) -> None:
        """Set exchange-reported fields; a call with all-``None`` is a no-op."""
        sets: list[str] = []
        params: list[Any] = []
        if acknowledged_at is not None:
            sets.append("acknowledged_at = ?")
            params.append(_enc_dt(acknowledged_at))
        if filled_at is not None:
            sets.append("filled_at = ?")
            params.append(_enc_dt(filled_at))
        if cancelled_at is not None:
            sets.append("cancelled_at = ?")
            params.append(_enc_dt(cancelled_at))
        if actual_fill_eth is not None:
            sets.append("actual_fill_eth = ?")
            params.append(_enc_dec(actual_fill_eth))
        if actual_fees_eth is not None:
            sets.append("actual_fees_eth = ?")
            params.append(_enc_dec(actual_fees_eth))
        if not sets:
            return
        params.append(_enc_uuid(instruction_id))
        conn.execute(
            f"UPDATE instructions SET {', '.join(sets)} "
            f"WHERE instruction_id = ?",
            params,
        )

    def list_instructions(
        self, conn: sqlite3.Connection, proposal_id: UUID
    ) -> list[InstructionRecord]:
        """List a position's instructions in the order they were sent."""
        rows = conn.execute(
            "SELECT * FROM instructions WHERE proposal_id = ? ORDER BY sent_at",
            (_enc_uuid(proposal_id),),
        ).fetchall()
        return [_row_to_instruction(r) for r in rows]

    # ------------------------------------------------------------------
    # decisions
    # ------------------------------------------------------------------
    def record_decision(
        self, conn: sqlite3.Connection, record: DecisionRecord
    ) -> int:
        """Insert a decision row and return its auto-increment id."""
        cursor = conn.execute(
            """
            INSERT INTO decisions (
                decision_type, proposal_id, timestamp,
                inputs_json, outputs_json, action_taken, notes
            ) VALUES (?,?,?,?,?,?,?)
            """,
            (
                record.decision_type,
                _enc_uuid(record.proposal_id),
                _enc_dt(record.timestamp),
                record.inputs_json,
                record.outputs_json,
                record.action_taken,
                record.notes,
            ),
        )
        return int(cursor.lastrowid or 0)

    def list_decisions(
        self,
        conn: sqlite3.Connection,
        *,
        proposal_id: UUID | None = None,
        limit: int = 100,
    ) -> list[DecisionRecord]:
        """List decisions newest-first, optionally for one proposal."""
        if proposal_id is None:
            rows = conn.execute(
                "SELECT * FROM decisions ORDER BY timestamp DESC LIMIT ?",
                (limit,),
            ).fetchall()
        else:
            rows = conn.execute(
                "SELECT * FROM decisions WHERE proposal_id = ? "
                "ORDER BY timestamp DESC LIMIT ?",
                (_enc_uuid(proposal_id), limit),
            ).fetchall()
        return [_row_to_decision(r) for r in rows]

    # ------------------------------------------------------------------
    # dvol_history
    # ------------------------------------------------------------------
    def record_dvol_snapshot(
        self, conn: sqlite3.Connection, snapshot: DvolSnapshot
    ) -> None:
        """Upsert a DVOL sample (timestamp is the primary key)."""
        conn.execute(
            "INSERT OR REPLACE INTO dvol_history(timestamp, dvol, eth_spot) "
            "VALUES (?,?,?)",
            (
                _enc_dt(snapshot.timestamp),
                _enc_dec(snapshot.dvol),
                _enc_dec(snapshot.eth_spot),
            ),
        )

    # ------------------------------------------------------------------
    # manual_actions
    # ------------------------------------------------------------------
    def enqueue_manual_action(
        self, conn: sqlite3.Connection, action: ManualAction
    ) -> int:
        """Queue an operator action and return its auto-increment id."""
        cursor = conn.execute(
            "INSERT INTO manual_actions(kind, proposal_id, payload_json, created_at) "
            "VALUES (?,?,?,?)",
            (
                action.kind,
                _enc_uuid(action.proposal_id),
                action.payload_json,
                _enc_dt(action.created_at),
            ),
        )
        return int(cursor.lastrowid or 0)

    def next_unconsumed_action(
        self, conn: sqlite3.Connection
    ) -> ManualAction | None:
        """Return the oldest unconsumed action (FIFO), or ``None``."""
        row = conn.execute(
            "SELECT * FROM manual_actions WHERE consumed_at IS NULL "
            "ORDER BY created_at ASC LIMIT 1"
        ).fetchone()
        return None if row is None else _row_to_manual(row)

    def mark_action_consumed(
        self,
        conn: sqlite3.Connection,
        action_id: int,
        *,
        consumed_by: str,
        result: str,
    now: datetime,
    ) -> None:
        """Record who consumed an action, when, and with what outcome."""
        conn.execute(
            "UPDATE manual_actions SET consumed_at = ?, consumed_by = ?, "
            "result = ? WHERE id = ?",
            (_enc_dt(now), consumed_by, result, action_id),
        )

    # ------------------------------------------------------------------
    # system_state
    # ------------------------------------------------------------------
    def init_system_state(
        self, conn: sqlite3.Connection, *, config_version: str, now: datetime
    ) -> None:
        """Insert the singleton row if it does not already exist."""
        existing = conn.execute(
            "SELECT 1 FROM system_state WHERE id = 1"
        ).fetchone()
        if existing is not None:
            return
        conn.execute(
            "INSERT INTO system_state(id, kill_switch, last_health_check, "
            "config_version, started_at) VALUES (1, 0, ?, ?, ?)",
            (_enc_dt(now), config_version, _enc_dt(now)),
        )

    def get_system_state(
        self, conn: sqlite3.Connection
    ) -> SystemStateRecord | None:
        """Load the singleton row, or ``None`` before initialisation."""
        row = conn.execute("SELECT * FROM system_state WHERE id = 1").fetchone()
        if row is None:
            return None
        return SystemStateRecord(
            id=int(row["id"]),
            kill_switch=int(row["kill_switch"]),
            kill_reason=row["kill_reason"],
            kill_at=_dec_dt(row["kill_at"]),
            last_health_check=_dec_dt_required(row["last_health_check"]),
            last_kelly_calib=_dec_dt(row["last_kelly_calib"]),
            config_version=row["config_version"],
            started_at=_dec_dt_required(row["started_at"]),
        )

    def set_kill_switch(
        self,
        conn: sqlite3.Connection,
        *,
        armed: bool,
        reason: str | None,
        now: datetime,
    ) -> None:
        """Flip the kill switch; ``kill_at`` is cleared on disarm.

        ``last_health_check`` is also refreshed so a disarm does not
        immediately re-trigger the dead-man watchdog.
        """
        conn.execute(
            "UPDATE system_state SET kill_switch = ?, kill_reason = ?, "
            "kill_at = ?, last_health_check = ? WHERE id = 1",
            (
                1 if armed else 0,
                reason,
                _enc_dt(now) if armed else None,
                _enc_dt(now),
            ),
        )

    def touch_health_check(
        self, conn: sqlite3.Connection, *, now: datetime
    ) -> None:
        """Refresh ``last_health_check`` (the dead-man watchdog heartbeat)."""
        conn.execute(
            "UPDATE system_state SET last_health_check = ? WHERE id = 1",
            (_enc_dt(now),),
        )
# ---------------------------------------------------------------------------
# Row → model converters
# ---------------------------------------------------------------------------
def _dec_dt_required(value: Any) -> datetime:
out = _dec_dt(value)
if out is None:
raise ValueError("expected non-null datetime in row")
return out
def _row_to_position(row: sqlite3.Row) -> PositionRecord:
    """Decode a ``positions`` row into a :class:`PositionRecord`.

    NOT NULL columns go through the ``*_required`` decoders so that a
    corrupted row fails loudly instead of producing a partial record.
    """
    proposal_id = _dec_uuid(row["proposal_id"])
    if proposal_id is None:
        raise ValueError("positions.proposal_id was NULL")
    return PositionRecord(
        proposal_id=proposal_id,
        spread_type=row["spread_type"],
        asset=row["asset"],
        expiry=_dec_dt_required(row["expiry"]),
        short_strike=_dec_dec_required(row["short_strike"]),
        long_strike=_dec_dec_required(row["long_strike"]),
        short_instrument=row["short_instrument"],
        long_instrument=row["long_instrument"],
        n_contracts=int(row["n_contracts"]),
        spread_width_usd=_dec_dec_required(row["spread_width_usd"]),
        spread_width_pct=_dec_dec_required(row["spread_width_pct"]),
        credit_eth=_dec_dec_required(row["credit_eth"]),
        credit_usd=_dec_dec_required(row["credit_usd"]),
        max_loss_usd=_dec_dec_required(row["max_loss_usd"]),
        spot_at_entry=_dec_dec_required(row["spot_at_entry"]),
        dvol_at_entry=_dec_dec_required(row["dvol_at_entry"]),
        delta_at_entry=_dec_dec_required(row["delta_at_entry"]),
        eth_price_at_entry=_dec_dec_required(row["eth_price_at_entry"]),
        proposed_at=_dec_dt_required(row["proposed_at"]),
        opened_at=_dec_dt(row["opened_at"]),
        closed_at=_dec_dt(row["closed_at"]),
        close_reason=row["close_reason"],
        debit_paid_eth=_dec_dec(row["debit_paid_eth"]),
        pnl_eth=_dec_dec(row["pnl_eth"]),
        pnl_usd=_dec_dec(row["pnl_usd"]),
        status=row["status"],
        created_at=_dec_dt_required(row["created_at"]),
        updated_at=_dec_dt_required(row["updated_at"]),
    )
def _row_to_instruction(row: sqlite3.Row) -> InstructionRecord:
    """Decode an ``instructions`` row into an :class:`InstructionRecord`."""
    instruction_id = _dec_uuid(row["instruction_id"])
    proposal_id = _dec_uuid(row["proposal_id"])
    # Both columns are NOT NULL in the schema; a NULL means corruption.
    if instruction_id is None or proposal_id is None:
        raise ValueError("instructions row missing required UUID")
    return InstructionRecord(
        instruction_id=instruction_id,
        proposal_id=proposal_id,
        kind=row["kind"],
        payload_json=row["payload_json"],
        sent_at=_dec_dt_required(row["sent_at"]),
        acknowledged_at=_dec_dt(row["acknowledged_at"]),
        filled_at=_dec_dt(row["filled_at"]),
        cancelled_at=_dec_dt(row["cancelled_at"]),
        actual_fill_eth=_dec_dec(row["actual_fill_eth"]),
        actual_fees_eth=_dec_dec(row["actual_fees_eth"]),
    )
def _row_to_decision(row: sqlite3.Row) -> DecisionRecord:
    """Decode a ``decisions`` row into a :class:`DecisionRecord`."""
    return DecisionRecord(
        id=int(row["id"]),
        decision_type=row["decision_type"],
        proposal_id=_dec_uuid(row["proposal_id"]),
        timestamp=_dec_dt_required(row["timestamp"]),
        inputs_json=row["inputs_json"],
        outputs_json=row["outputs_json"],
        action_taken=row["action_taken"],
        notes=row["notes"],
    )
def _row_to_manual(row: sqlite3.Row) -> ManualAction:
    """Decode a ``manual_actions`` row into a :class:`ManualAction`."""
    return ManualAction(
        id=int(row["id"]),
        kind=row["kind"],
        proposal_id=_dec_uuid(row["proposal_id"]),
        payload_json=row["payload_json"],
        created_at=_dec_dt_required(row["created_at"]),
        consumed_at=_dec_dt(row["consumed_at"]),
        consumed_by=row["consumed_by"],
        result=row["result"],
    )
def _dec_dec_required(value: Any) -> Decimal:
out = _dec_dec(value)
if out is None:
raise ValueError("expected non-null Decimal in row")
return out
+15
View File
@@ -0,0 +1,15 @@
# strategy.local.yaml — local override (gitignored).
#
# Copy to strategy.local.yaml and edit only the keys you need to
# change. Top-level sections are deep-merged onto strategy.yaml at
# load time; the merged result is logged as OVERRIDE_APPLIED.
#
# Typical use cases:
# * Halve cap_per_trade in dry-run.
# * Force max_concurrent_positions to 0 to freeze entries without
# stopping the engine.
# * Lower kelly_fraction temporarily after a drawdown.
# Example: emergency entry freeze.
# sizing:
# max_concurrent_positions: 0
+149
View File
@@ -0,0 +1,149 @@
# strategy.yaml — Cerberus Bite golden config v1.0.0
#
# Source of truth for every threshold consumed by the rule engine.
# Modifying this file is an explicit decision of Adriano. Each change
# bumps `config_version`, regenerates `config_hash` (cerbero-bite
# config hash), and lands as a separate commit with the motivation in
# the commit message.
config_version: "1.0.0"
config_hash: "a857dc4b187cbdf5ac3f04c4aad48ab7587659bc9a3139db206566e10e2fa5e5"
last_review: "2026-04-26"
last_reviewer: "Adriano"
asset:
symbol: "ETH"
exchange: "deribit"
entry:
cron: "0 14 * * MON"
skip_holidays_country: "IT"
capital_min_usd: "720"
dvol_min: "35"
dvol_max: "90"
funding_perp_abs_max_annualized: "0.80"
eth_holdings_pct_max: "0.30"
no_position_concurrent: true
exclude_macro_severity: ["high"]
exclude_macro_countries: ["US", "EU"]
trend_window_days: 30
trend_bull_threshold_pct: "0.05"
trend_bear_threshold_pct: "-0.05"
funding_bull_threshold_annualized: "0.20"
funding_bear_threshold_annualized: "-0.20"
iron_condor_dvol_min: "55"
iron_condor_adx_max: "20"
iron_condor_trend_neutral_band_pct: "0.05"
structure:
dte_target: 18
dte_min: 14
dte_max: 21
short_strike:
delta_target: "0.12"
delta_min: "0.10"
delta_max: "0.15"
distance_otm_pct_min: "0.15"
distance_otm_pct_max: "0.25"
spread_width:
target_pct_of_spot: "0.04"
min_pct_of_spot: "0.03"
max_pct_of_spot: "0.05"
credit_to_width_ratio_min: "0.30"
liquidity:
open_interest_min: 100
volume_24h_min: 20
bid_ask_spread_pct_max: "0.15"
book_depth_top3_min: 5
slippage_pct_of_credit_max: "0.08"
sizing:
kelly_fraction: "0.13"
cap_per_trade_eur: "200"
cap_aggregate_open_eur: "1000"
max_concurrent_positions: 1
max_contracts_per_trade: 4
dvol_adjustment:
- {dvol_under: "45", multiplier: "1.00"}
- {dvol_under: "60", multiplier: "0.85"}
- {dvol_under: "80", multiplier: "0.65"}
dvol_no_entry_threshold: "80"
exit:
profit_take_pct_of_credit: "0.50"
stop_loss_mark_x_credit: "2.50"
vol_stop_dvol_increase: "10"
time_stop_dte_remaining: 7
time_stop_skip_if_close_to_profit_pct: "0.70"
delta_breach_threshold: "0.30"
adverse_move_4h_pct: "0.05"
monitor_cron: "0 2,14 * * *"
user_confirmation_timeout_min: 30
escalate_on_timeout:
- "CLOSE_STOP"
- "CLOSE_VOL"
- "CLOSE_DELTA"
execution:
combo_only: true
initial_limit: "mid"
reprice_step_ticks: 1
reprice_max_steps: 3
reprice_max_steps_urgent: 5
order_tif: "GTC"
order_expiry_min: 30
ack_timeout_s: 300
monitoring:
health_check_interval_s: 300
health_failures_before_kill: 3
health_failures_before_restart: 5
daily_digest_cron: "0 8 * * *"
monthly_report_cron: "0 12 1 * *"
storage:
sqlite_path: "data/state.sqlite"
log_path: "data/log/"
log_retention_days: 365
backup_path: "data/backups/"
backup_retention_days: 30
mcp:
config_file: "~/.config/cerbero-suite/mcp.json"
call_timeout_s: 8
retry_max: 3
retry_base_delay_s: 1
required_versions:
cerbero-deribit: "^2.0.0"
cerbero-hyperliquid: "^1.5.0"
cerbero-memory: "^4.0.0"
cerbero-portfolio: "^1.2.0"
cerbero-macro: "^1.0.0"
cerbero-sentiment: "^1.0.0"
cerbero-telegram: "^1.0.0"
cerbero-brain-bridge: "^1.0.0"
telegram:
parse_mode: "MarkdownV2"
confirmation_timeout_min: 60
exit_confirmation_timeout_min: 30
backup_channel_on_critical: true
kelly_recalibration:
lookback_days: 365
min_sample_low_confidence: 30
min_sample_high_confidence: 100
weight_when_medium_confidence: "0.50"
+111
View File
@@ -0,0 +1,111 @@
"""Smoke tests for the dead_man.sh shell watchdog."""
from __future__ import annotations
import shutil
import sqlite3
import subprocess
from datetime import UTC, datetime, timedelta
from pathlib import Path
import pytest
REPO_ROOT = Path(__file__).resolve().parents[2]
SCRIPT = REPO_ROOT / "scripts" / "dead_man.sh"
_SQLITE3_BIN = shutil.which("sqlite3")
pytestmark = [
pytest.mark.integration,
]
def _require_sqlite3() -> None:
    """Skip the calling test when the ``sqlite3`` CLI binary is absent."""
    if _SQLITE3_BIN is not None:
        return
    pytest.skip("sqlite3 CLI not installed on this host")
def _setup_project(tmp_path: Path) -> Path:
    """Create a throwaway project tree containing a copy of dead_man.sh."""
    project_root = tmp_path / "project"
    for subdir in ("data/log", "scripts"):
        (project_root / subdir).mkdir(parents=True)
    shutil.copy(SCRIPT, project_root / "scripts" / "dead_man.sh")
    return project_root
def _write_health(project: Path, ts: datetime) -> None:
log_file = project / "data" / "log" / f"cerbero-bite-{ts:%Y-%m-%d}.jsonl"
log_file.write_text(
f'{{"ts": "{ts.astimezone(UTC).isoformat()}", "event": "HEALTH_OK"}}\n',
encoding="utf-8",
)
def _run(project: Path, threshold: int = 900) -> subprocess.CompletedProcess[str]:
    """Invoke dead_man.sh with a minimal, fully-controlled environment."""
    env = {
        "PATH": "/usr/bin:/bin",
        "PROJECT_ROOT": str(project),
        "DEAD_MAN_THRESHOLD_SECONDS": str(threshold),
    }
    return subprocess.run(
        ["bash", str(project / "scripts" / "dead_man.sh")],
        env=env,
        capture_output=True,
        text=True,
        check=False,
    )
def test_dead_man_exits_zero_when_recent_health_ok(tmp_path: Path) -> None:
    """A HEALTH_OK entry younger than the threshold keeps the watchdog quiet."""
    project = _setup_project(tmp_path)
    one_minute_ago = datetime.now(UTC) - timedelta(seconds=60)
    _write_health(project, one_minute_ago)
    completed = _run(project, threshold=900)
    assert completed.returncode == 0, completed.stderr
def test_dead_man_arms_kill_switch_when_silent(tmp_path: Path) -> None:
    """A stale HEALTH_OK (older than threshold) must arm the kill switch.

    Requires the sqlite3 CLI: the shell watchdog updates the database
    via the binary, not via Python.
    """
    _require_sqlite3()
    project = _setup_project(tmp_path)
    # 2000s-old heartbeat vs a 900s threshold → the watchdog must fire.
    _write_health(project, datetime.now(UTC) - timedelta(seconds=2000))
    # Pre-create the SQLite system_state singleton; otherwise the script
    # has nothing to update.
    db = project / "data" / "state.sqlite"
    conn = sqlite3.connect(str(db))
    conn.execute(
        "CREATE TABLE system_state (id INTEGER PRIMARY KEY CHECK(id=1), "
        "kill_switch INTEGER NOT NULL DEFAULT 0, kill_reason TEXT, "
        "kill_at TEXT, last_health_check TEXT NOT NULL, "
        "last_kelly_calib TEXT, config_version TEXT NOT NULL, "
        "started_at TEXT NOT NULL)"
    )
    conn.execute(
        "INSERT INTO system_state(id, last_health_check, config_version, started_at) "
        "VALUES (1, '2026-04-27', '1.0.0', '2026-04-27')"
    )
    conn.commit()
    conn.close()
    result = _run(project, threshold=900)
    # Exit code 1 signals "engine silent, kill switch armed".
    assert result.returncode == 1
    # Kill switch armed.
    conn = sqlite3.connect(str(db))
    try:
        kill = conn.execute("SELECT kill_switch FROM system_state").fetchone()[0]
    finally:
        conn.close()
    assert kill == 1
    # Alert file exists.
    alert = project / "data" / "log" / "dead-man-alert.txt"
    assert alert.exists()
    assert "dead_man" in alert.read_text(encoding="utf-8")
def test_dead_man_handles_missing_log_file(tmp_path: Path) -> None:
    """An entirely absent log file is treated like a silent engine."""
    project = _setup_project(tmp_path)
    completed = _run(project, threshold=900)
    assert completed.returncode == 1
    alert_file = project / "data" / "log" / "dead-man-alert.txt"
    assert alert_file.exists()
+273
View File
@@ -0,0 +1,273 @@
"""Audit chain writer + verifier tests."""
from __future__ import annotations
import hashlib
from datetime import UTC, datetime, timedelta
from pathlib import Path
import pytest
from cerbero_bite.safety.audit_log import (
GENESIS_HASH,
AuditChainError,
AuditLog,
iter_entries,
verify_chain,
)
def test_empty_file_verifies_with_zero_entries(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
assert verify_chain(path) == 0
def test_first_entry_uses_genesis_prev_hash(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
log = AuditLog(path)
entry = log.append(
event="ENGINE_START",
payload={"version": "1.0.0"},
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
assert entry.prev_hash == GENESIS_HASH
assert entry.hash != GENESIS_HASH
def test_chain_links_subsequent_entries(tmp_path: Path) -> None:
    """Every appended entry records the previous entry's hash as prev_hash."""
    path = tmp_path / "audit.log"
    log = AuditLog(path)
    base = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
    entries = [
        log.append(event=name, payload={"i": idx + 1}, now=base + timedelta(minutes=idx))
        for idx, name in enumerate(("A", "B", "C"))
    ]
    assert entries[1].prev_hash == entries[0].hash
    assert entries[2].prev_hash == entries[1].hash
    assert verify_chain(path) == 3
def test_iter_entries_yields_in_order(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
log = AuditLog(path)
log.append(event="A", payload={"i": 1}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
log.append(event="B", payload={"i": 2}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC))
events = [e.event for e in iter_entries(path)]
assert events == ["A", "B"]
def test_log_resumes_chain_after_reopen(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
first = AuditLog(path)
e1 = first.append(
event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
)
second = AuditLog(path)
assert second.last_hash == e1.hash
e2 = second.append(
event="B", payload={"k": "v"}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC)
)
assert e2.prev_hash == e1.hash
assert verify_chain(path) == 2
def test_payload_with_pipe_character_round_trips(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
log = AuditLog(path)
log.append(
event="NOTE",
payload={"text": "first|second|third"},
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
entries = list(iter_entries(path))
assert entries[0].payload == {"text": "first|second|third"}
assert verify_chain(path) == 1
def test_tampered_payload_breaks_chain(tmp_path: Path) -> None:
    """Hand-editing a payload on disk invalidates the stored per-line hash."""
    path = tmp_path / "audit.log"
    log = AuditLog(path)
    log.append(event="A", payload={"i": 1}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
    log.append(event="B", payload={"i": 2}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC))
    # Rewrite the first entry's payload value directly in the file.
    lines = path.read_text(encoding="utf-8").splitlines()
    lines[0] = lines[0].replace('"i":1', '"i":99')
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")
    with pytest.raises(AuditChainError, match="hash mismatch"):
        verify_chain(path)
def test_verify_chain_skips_blank_lines(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
log = AuditLog(path)
log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
raw = path.read_text(encoding="utf-8")
path.write_text("\n" + raw + "\n \n", encoding="utf-8")
# The chain still verifies despite the surrounding whitespace lines.
assert verify_chain(path) == 1
def test_prev_hash_mismatch_between_entries_is_caught(tmp_path: Path) -> None:
    """Second line's prev_hash points to a different chain — verify_chain rejects."""
    path = tmp_path / "audit.log"
    log = AuditLog(path)
    e1 = log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
    # Build a synthetic second line whose prev_hash != e1.hash but whose
    # own hash is correctly computed from that bogus prev_hash.
    fake_prev = "0" * 32 + "f" * 32  # 64 hex chars, but certainly not e1's digest
    ts2 = "2026-04-27T14:01:00+00:00"
    payload_json = "{}"
    # Pre-image ts|event|payload|prev_hash — assumed to mirror the writer's
    # hashing scheme (the "prev_hash mismatch" expectation below only holds
    # if the per-line hash itself verifies).
    raw = f"{ts2}|B|{payload_json}|{fake_prev}"
    fake_hash = hashlib.sha256(raw.encode()).hexdigest()
    # On-disk line format: ts|event|payload|prev_hash=...|hash=...
    line = f"{ts2}|B|{payload_json}|prev_hash={fake_prev}|hash={fake_hash}\n"
    with path.open("a", encoding="utf-8") as fh:
        fh.write(line)
    assert e1.hash != fake_prev  # sanity
    with pytest.raises(AuditChainError, match="prev_hash mismatch"):
        verify_chain(path)
def test_tampered_prev_hash_breaks_chain(tmp_path: Path) -> None:
    """Corrupting the second entry's prev_hash field is detected by verify_chain.

    The previous version inserted 64 extra ``f`` characters and then tried to
    "truncate to recover length" with a ``replace("X", "")`` that never
    shortened the field — the prev_hash value ended up 128 chars long.  Here
    the field value is swapped in place for a same-length bogus digest, so
    only the *value* is wrong, which is exactly what the test claims to check.
    """
    path = tmp_path / "audit.log"
    log = AuditLog(path)
    log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
    log.append(event="B", payload={}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC))
    lines = path.read_text(encoding="utf-8").splitlines()
    # Line format: ts|event|payload|prev_hash=<64hex>|hash=<64hex>
    head, marker, tail = lines[1].partition("prev_hash=")
    old_value, sep, rest = tail.partition("|")
    assert old_value != "f" * 64  # sanity: the forged value really differs
    lines[1] = head + marker + "f" * 64 + sep + rest
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")
    with pytest.raises(AuditChainError):
        verify_chain(path)
def test_malformed_line_raises_chain_error(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
path.write_text("not-a-valid-line\n", encoding="utf-8")
with pytest.raises(AuditChainError):
verify_chain(path)
def test_parser_rejects_missing_hash_field(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
path.write_text(
"2026-04-27T14:00:00+00:00|EVT|{}|prev_hash=" + "0" * 64 + "\n",
encoding="utf-8",
)
with pytest.raises(AuditChainError, match="hash="):
verify_chain(path)
def test_parser_rejects_missing_prev_hash_field(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
path.write_text(
"2026-04-27T14:00:00+00:00|EVT|{}|hash=" + "f" * 64 + "\n",
encoding="utf-8",
)
with pytest.raises(AuditChainError, match="prev_hash"):
verify_chain(path)
def test_parser_rejects_line_with_no_separators(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
path.write_text("just-a-blob|hash=" + "f" * 64 + "\n", encoding="utf-8")
with pytest.raises(AuditChainError, match="prev_hash"):
verify_chain(path)
def test_parser_rejects_malformed_leading_section(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
# Two `|` only: rsplit succeeds twice, leading parts has 1 element ≠ 3.
path.write_text(
"tooshort|prev_hash=" + "0" * 64 + "|hash=" + "f" * 64 + "\n",
encoding="utf-8",
)
with pytest.raises(AuditChainError, match="leading section"):
verify_chain(path)
def test_parser_rejects_payload_not_a_json_object(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
path.write_text(
"2026-04-27T14:00:00+00:00|EVT|[1,2]|prev_hash="
+ "0" * 64
+ "|hash="
+ "f" * 64
+ "\n",
encoding="utf-8",
)
with pytest.raises(AuditChainError, match="JSON object"):
verify_chain(path)
def test_parser_rejects_payload_with_invalid_json(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
path.write_text(
"2026-04-27T14:00:00+00:00|EVT|{not-json}|prev_hash="
+ "0" * 64
+ "|hash="
+ "f" * 64
+ "\n",
encoding="utf-8",
)
with pytest.raises(AuditChainError, match="JSON"):
verify_chain(path)
def test_iter_entries_returns_empty_when_file_missing(tmp_path: Path) -> None:
path = tmp_path / "missing.log"
assert list(iter_entries(path)) == []
def test_iter_entries_skips_blank_lines(tmp_path: Path) -> None:
path = tmp_path / "audit.log"
log = AuditLog(path)
log.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
raw = path.read_text(encoding="utf-8")
path.write_text(raw + "\n\n", encoding="utf-8")
entries = list(iter_entries(path))
assert len(entries) == 1
def test_log_resumes_chain_with_large_file(tmp_path: Path) -> None:
    """Tail-seek reads past the 4096-byte chunk boundary."""
    path = tmp_path / "audit.log"
    log = AuditLog(path)
    base = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
    # Each line ~150 chars; 50 lines is comfortably > 4096 bytes.
    for i in range(50):
        log.append(
            event=f"E{i}",
            payload={"i": i, "filler": "x" * 80},  # padding to inflate the line
            now=base + timedelta(seconds=i),
        )
    last_hash = log.last_hash
    # Reopening must recover the tip hash from the end of the (large) file.
    reopened = AuditLog(path)
    assert reopened.last_hash == last_hash
    assert verify_chain(path) == 50
def test_payload_serialisation_is_canonical(tmp_path: Path) -> None:
    """Key order in the payload dict must not influence the entry hash."""
    when = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
    first_log = AuditLog(tmp_path / "audit.log")
    entry_one = first_log.append(event="A", payload={"b": 1, "a": 2}, now=when)
    second_log = AuditLog(tmp_path / "audit_other.log")
    entry_two = second_log.append(event="A", payload={"a": 2, "b": 1}, now=when)
    # Canonical serialisation: same timestamp + event + equal dicts => same hash.
    assert entry_one.hash == entry_two.hash
+126
View File
@@ -0,0 +1,126 @@
"""Tests for scripts.backup."""
from __future__ import annotations
import importlib.util
import sys
from datetime import UTC, datetime, timedelta
from pathlib import Path
import pytest
from cerbero_bite.state import connect, run_migrations
REPO_ROOT = Path(__file__).resolve().parents[2]
def _load_backup_module() -> object:
    """Import scripts/backup.py under a private name, keeping sys.path clean."""
    source = REPO_ROOT / "scripts" / "backup.py"
    spec = importlib.util.spec_from_file_location("_cerbero_bite_backup", source)
    assert spec is not None and spec.loader is not None
    mod = importlib.util.module_from_spec(spec)
    # Register before exec so intra-module imports resolve normally.
    sys.modules[spec.name] = mod
    spec.loader.exec_module(mod)
    return mod
@pytest.fixture
def backup_mod() -> object:
return _load_backup_module()
def test_backup_database_creates_snapshot(tmp_path: Path, backup_mod: object) -> None:
db = tmp_path / "state.sqlite"
conn = connect(db)
try:
run_migrations(conn)
conn.execute(
"INSERT INTO system_state(id, last_health_check, config_version, started_at) "
"VALUES (1, '2026-04-27', '1.0.0', '2026-04-27')"
)
finally:
conn.close()
backup_dir = tmp_path / "backups"
snapshot = backup_mod.backup_database( # type: ignore[attr-defined]
db_path=db,
backup_dir=backup_dir,
now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
assert snapshot.exists()
assert snapshot.name == "state-20260427-14.sqlite"
# Snapshot is itself a valid SQLite db with the same row.
snap = connect(snapshot)
try:
rows = snap.execute(
"SELECT config_version FROM system_state WHERE id = 1"
).fetchone()
finally:
snap.close()
assert rows[0] == "1.0.0"
def test_backup_database_replaces_existing_hour_snapshot(
tmp_path: Path, backup_mod: object
) -> None:
db = tmp_path / "state.sqlite"
conn = connect(db)
try:
run_migrations(conn)
finally:
conn.close()
when = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
first = backup_mod.backup_database(db_path=db, backup_dir=tmp_path / "b", now=when) # type: ignore[attr-defined]
second = backup_mod.backup_database(db_path=db, backup_dir=tmp_path / "b", now=when) # type: ignore[attr-defined]
assert first == second
files = list((tmp_path / "b").iterdir())
assert len(files) == 1
def test_prune_backups_removes_old_files(tmp_path: Path, backup_mod: object) -> None:
    """Only stale state-*.sqlite snapshots are deleted; everything else survives."""
    backup_dir = tmp_path / "b"
    backup_dir.mkdir()
    fresh_snapshot = backup_dir / "state-20260420-10.sqlite"
    stale_snapshot = backup_dir / "state-20260101-12.sqlite"
    unrelated = backup_dir / "unrelated.txt"
    for candidate in (fresh_snapshot, stale_snapshot, unrelated):
        candidate.touch()
    deleted = backup_mod.prune_backups(  # type: ignore[attr-defined]
        backup_dir,
        retention_days=30,
        now=datetime(2026, 4, 27, tzinfo=UTC),
    )
    assert deleted == [stale_snapshot]
    assert fresh_snapshot.exists()
    assert unrelated.exists()
def test_prune_backups_ignores_unparseable_filenames(
tmp_path: Path, backup_mod: object
) -> None:
backup_dir = tmp_path / "b"
backup_dir.mkdir()
(backup_dir / "state-bogus-XX.sqlite").touch()
deleted = backup_mod.prune_backups( # type: ignore[attr-defined]
backup_dir,
retention_days=0,
now=datetime(2026, 4, 27, tzinfo=UTC) + timedelta(days=10000),
)
assert deleted == []
def test_list_backups_returns_sorted(tmp_path: Path, backup_mod: object) -> None:
    """list_backups yields snapshot paths in ascending filename order."""
    backup_dir = tmp_path / "b"
    backup_dir.mkdir()
    newer = backup_dir / "state-20260103-08.sqlite"
    older = backup_dir / "state-20260101-08.sqlite"
    newer.touch()  # created first, but sorts last by name
    older.touch()
    listed = list(backup_mod.list_backups(backup_dir))  # type: ignore[attr-defined]
    assert listed == [older, newer]
+160
View File
@@ -0,0 +1,160 @@
"""End-to-end CLI tests for the Phase 2 commands.
The commands hit real on-disk paths (tmp_path) so the assertions run
the production code paths verbatim.
"""
from __future__ import annotations
from datetime import UTC, datetime
from pathlib import Path
from click.testing import CliRunner
from cerbero_bite.cli import main as cli_main
from cerbero_bite.safety import AuditLog
from cerbero_bite.state import Repository, connect, run_migrations, transaction
def _seed_state(db_path: Path) -> None:
    """Create the schema and the system_state singleton at *db_path*."""
    connection = connect(db_path)
    try:
        run_migrations(connection)
        with transaction(connection):
            Repository().init_system_state(
                connection,
                config_version="1.0.0",
                now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
            )
    finally:
        connection.close()
def test_audit_verify_reports_ok_on_clean_chain(tmp_path: Path) -> None:
audit = AuditLog(tmp_path / "audit.log")
audit.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
audit.append(event="B", payload={}, now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC))
result = CliRunner().invoke(
cli_main, ["audit", "verify", "--file", str(tmp_path / "audit.log")]
)
assert result.exit_code == 0, result.output
assert "ok" in result.output
assert "2" in result.output
def test_audit_verify_handles_empty_file(tmp_path: Path) -> None:
target = tmp_path / "audit.log"
target.write_text("", encoding="utf-8")
result = CliRunner().invoke(cli_main, ["audit", "verify", "--file", str(target)])
assert result.exit_code == 0
assert "empty" in result.output
def test_audit_verify_exits_nonzero_on_tampering(tmp_path: Path) -> None:
    """``audit verify`` exits 2 and reports TAMPERED when a line is edited.

    The previous version first rewrote ``'"event":"A"'`` — a substring that
    never occurs, because the payload is ``{}`` and the event name lives in
    the pipe-delimited line, not the JSON (the old comment admitted as much).
    That dead mutation is dropped; the real tampering is the ``|A|`` edit.
    """
    target = tmp_path / "audit.log"
    audit = AuditLog(target)
    audit.append(event="A", payload={}, now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC))
    # Flip the event field of the only line: stored hash no longer matches.
    raw = target.read_text(encoding="utf-8")
    target.write_text(raw.replace("|A|", "|X|", 1), encoding="utf-8")
    result = CliRunner().invoke(cli_main, ["audit", "verify", "--file", str(target)])
    assert result.exit_code == 2
    assert "TAMPERED" in result.output
def test_kill_switch_status_prints_disarmed(tmp_path: Path) -> None:
db = tmp_path / "state.sqlite"
_seed_state(db)
result = CliRunner().invoke(cli_main, ["kill-switch", "status", "--db", str(db)])
assert result.exit_code == 0
assert "disarmed" in result.output
def test_kill_switch_arm_then_status_shows_armed(tmp_path: Path) -> None:
    """Arming via the CLI is visible to a subsequent ``status`` invocation."""
    db = tmp_path / "state.sqlite"
    audit = tmp_path / "audit.log"
    runner = CliRunner()
    arm_args = [
        "kill-switch",
        "arm",
        "--reason",
        "manual smoke",
        "--db",
        str(db),
        "--audit",
        str(audit),
    ]
    arm = runner.invoke(cli_main, arm_args)
    assert arm.exit_code == 0, arm.output
    status = runner.invoke(cli_main, ["kill-switch", "status", "--db", str(db)])
    assert status.exit_code == 0
    assert "ARMED" in status.output
def test_kill_switch_status_handles_missing_db(tmp_path: Path) -> None:
result = CliRunner().invoke(
cli_main, ["kill-switch", "status", "--db", str(tmp_path / "absent.sqlite")]
)
assert result.exit_code == 0
assert "not found" in result.output
def test_state_inspect_shows_no_open_positions(tmp_path: Path) -> None:
db = tmp_path / "state.sqlite"
_seed_state(db)
result = CliRunner().invoke(cli_main, ["state", "inspect", "--db", str(db)])
assert result.exit_code == 0
assert "no open positions" in result.output
def test_state_inspect_handles_missing_db(tmp_path: Path) -> None:
result = CliRunner().invoke(
cli_main, ["state", "inspect", "--db", str(tmp_path / "absent.sqlite")]
)
assert result.exit_code == 0
assert "not found" in result.output
def test_config_hash_matches_loader(tmp_path: Path) -> None:
    """``config hash`` prints exactly the digest ``compute_config_hash`` derives.

    The previous version only asserted the output was 64 chars long, which
    any sha256 hex would satisfy — it never checked agreement with the
    loader despite the test's name.
    """
    from cerbero_bite.config.loader import compute_config_hash

    target = tmp_path / "strategy.yaml"
    target.write_text(
        'config_version: "1.0.0"\nconfig_hash: "0000"\nasset:\n  symbol: ETH\n',
        encoding="utf-8",
    )
    result = CliRunner().invoke(cli_main, ["config", "hash", "--file", str(target)])
    assert result.exit_code == 0
    printed = result.output.strip()
    assert len(printed) == 64  # sha256 hex
    # CLI and loader must agree on the digest of the same file body.
    assert printed == compute_config_hash(target.read_text(encoding="utf-8"))
def test_config_validate_repo_strategy_yaml() -> None:
repo_root = Path(__file__).resolve().parents[2]
yaml_path = repo_root / "strategy.yaml"
result = CliRunner().invoke(
cli_main, ["config", "validate", "--file", str(yaml_path)]
)
assert result.exit_code == 0
assert "ok" in result.output
def test_config_validate_with_no_enforce_hash_skips_check(tmp_path: Path) -> None:
target = tmp_path / "strategy.yaml"
target.write_text(
'config_version: "1.0.0"\nconfig_hash: "wrong"\n'
'last_review: "2026-04-26"\nlast_reviewer: "test"\n',
encoding="utf-8",
)
result = CliRunner().invoke(
cli_main,
["config", "validate", "--file", str(target), "--no-enforce-hash"],
)
assert result.exit_code == 0
assert "ok" in result.output
+149
View File
@@ -0,0 +1,149 @@
"""Tests for the YAML loader and hash verification."""
from __future__ import annotations
from decimal import Decimal
from pathlib import Path
import pytest
import yaml
from cerbero_bite.config.loader import (
ConfigHashError,
compute_config_hash,
load_strategy,
)
REPO_ROOT = Path(__file__).resolve().parents[2]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _golden_yaml_skeleton(**overrides: object) -> dict[str, object]:
base = {
"config_version": "1.0.0-test",
"config_hash": "0" * 64,
"last_review": "2026-04-26",
"last_reviewer": "test",
}
base.update(overrides)
return base
def _write_with_correct_hash(path: Path, doc: dict[str, object]) -> None:
"""Write a YAML doc and patch ``config_hash`` to match the file body."""
text = yaml.safe_dump(doc, sort_keys=False)
path.write_text(text, encoding="utf-8")
new_hash = compute_config_hash(path.read_text(encoding="utf-8"))
doc = {**doc, "config_hash": new_hash}
text = yaml.safe_dump(doc, sort_keys=False)
path.write_text(text, encoding="utf-8")
# ---------------------------------------------------------------------------
# compute_config_hash
# ---------------------------------------------------------------------------
def test_compute_hash_is_independent_of_recorded_hash_value(tmp_path: Path) -> None:
    """Two files differing only in the config_hash line hash identically."""
    a = tmp_path / "a.yaml"
    b = tmp_path / "b.yaml"
    a.write_text(
        'config_version: "1.0.0"\nconfig_hash: "aaa"\nfoo: 1\n', encoding="utf-8"
    )
    b.write_text(
        'config_version: "1.0.0"\nconfig_hash: "bbb"\nfoo: 1\n', encoding="utf-8"
    )
    # Read back with an explicit encoding, matching how the files were written
    # (the previous bare read_text() relied on the platform default encoding,
    # inconsistent with the rest of this suite).
    text_a = a.read_text(encoding="utf-8")
    text_b = b.read_text(encoding="utf-8")
    assert compute_config_hash(text_a) == compute_config_hash(text_b)
# ---------------------------------------------------------------------------
# load_strategy — happy path
# ---------------------------------------------------------------------------
def test_load_repo_strategy_yaml() -> None:
    """The committed strategy.yaml validates with the recorded hash.

    The ``tmp_path`` fixture was requested but never used — this test only
    reads the repository's committed file — so the parameter is dropped.
    """
    result = load_strategy(REPO_ROOT / "strategy.yaml")
    assert result.config.config_version == "1.0.0"
    assert result.config.sizing.kelly_fraction == Decimal("0.13")
    # Recorded hash and recomputed hash must agree for the golden file.
    assert result.computed_hash == result.config.config_hash
def test_load_with_local_override_merges(tmp_path: Path) -> None:
main = tmp_path / "strategy.yaml"
_write_with_correct_hash(main, _golden_yaml_skeleton())
override = tmp_path / "strategy.local.yaml"
override.write_text(
yaml.safe_dump(
{"sizing": {"max_concurrent_positions": 0}},
sort_keys=False,
),
encoding="utf-8",
)
loaded = load_strategy(main)
assert loaded.config.sizing.max_concurrent_positions == 0
assert override in loaded.sources
def test_local_override_does_not_invalidate_main_hash(tmp_path: Path) -> None:
main = tmp_path / "strategy.yaml"
_write_with_correct_hash(main, _golden_yaml_skeleton())
(tmp_path / "strategy.local.yaml").write_text(
"sizing:\n kelly_fraction: '0.10'\n", encoding="utf-8"
)
loaded = load_strategy(main)
assert loaded.config.sizing.kelly_fraction == Decimal("0.10")
# ---------------------------------------------------------------------------
# load_strategy — error paths
# ---------------------------------------------------------------------------
def test_load_with_mismatched_hash_raises(tmp_path: Path) -> None:
main = tmp_path / "strategy.yaml"
main.write_text(
yaml.safe_dump(_golden_yaml_skeleton(), sort_keys=False), encoding="utf-8"
)
with pytest.raises(ConfigHashError, match="config_hash mismatch"):
load_strategy(main)
def test_load_with_enforce_hash_false_skips_check(tmp_path: Path) -> None:
main = tmp_path / "strategy.yaml"
main.write_text(
yaml.safe_dump(_golden_yaml_skeleton(), sort_keys=False), encoding="utf-8"
)
loaded = load_strategy(main, enforce_hash=False)
assert loaded.config.config_hash == "0" * 64
def test_load_rejects_top_level_non_mapping(tmp_path: Path) -> None:
main = tmp_path / "strategy.yaml"
main.write_text("- a\n- b\n", encoding="utf-8")
with pytest.raises(ValueError, match="top-level mapping"):
load_strategy(main, enforce_hash=False)
def test_local_override_path_pointing_to_missing_file_is_ignored(
tmp_path: Path,
) -> None:
main = tmp_path / "strategy.yaml"
_write_with_correct_hash(main, _golden_yaml_skeleton())
loaded = load_strategy(
main, local_override_path=tmp_path / "nonexistent.yaml"
)
assert main in loaded.sources
assert len(loaded.sources) == 1
def test_empty_local_override_file_is_no_op(tmp_path: Path) -> None:
main = tmp_path / "strategy.yaml"
_write_with_correct_hash(main, _golden_yaml_skeleton())
(tmp_path / "strategy.local.yaml").write_text("", encoding="utf-8")
loaded = load_strategy(main)
assert loaded.config.sizing.kelly_fraction == Decimal("0.13")
+170
View File
@@ -0,0 +1,170 @@
"""Kill switch behaviour: SQLite + audit log stay in lock-step."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from pathlib import Path
import pytest
from cerbero_bite.safety import AuditLog, verify_chain
from cerbero_bite.safety.kill_switch import KillSwitch, KillSwitchError
from cerbero_bite.state import Repository, connect, run_migrations, transaction
def _make_kill_switch(tmp_path: Path) -> tuple[KillSwitch, AuditLog, Path, Repository]:
    """Build a KillSwitch wired to a real SQLite db and audit log under *tmp_path*.

    Returns ``(kill_switch, audit_log, audit_path, repository)``.
    """
    db_path = tmp_path / "state.sqlite"
    audit_path = tmp_path / "audit.log"
    conn = connect(db_path)
    run_migrations(conn)
    repo = Repository()
    with transaction(conn):
        repo.init_system_state(
            conn, config_version="1.0.0", now=datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
        )
    conn.close()
    audit = AuditLog(audit_path)
    # Deterministic clock: successive calls yield 14:10, 14:20, ... UTC.
    times = iter(
        datetime(2026, 4, 27, 14, m, tzinfo=UTC) for m in (10, 20, 30, 40, 50)
    )
    ks = KillSwitch(
        connection_factory=lambda: connect(db_path),  # fresh connection per operation
        repository=repo,
        audit_log=audit,
        clock=lambda: next(times),
    )
    return ks, audit, audit_path, repo
def test_arm_persists_state_and_appends_audit(tmp_path: Path) -> None:
ks, _audit, audit_path, repo = _make_kill_switch(tmp_path)
assert ks.is_armed() is False
ks.arm(reason="manual test", source="manual")
assert ks.is_armed() is True
conn = connect(tmp_path / "state.sqlite")
try:
state = repo.get_system_state(conn)
finally:
conn.close()
assert state is not None
assert state.kill_switch == 1
assert state.kill_reason == "manual test"
assert state.kill_at is not None
assert verify_chain(audit_path) == 1
def test_arm_is_idempotent_on_second_call(tmp_path: Path) -> None:
    """Re-arming an already-armed switch must not append a second audit line."""
    ks, _audit, audit_path, _repo = _make_kill_switch(tmp_path)
    for reason in ("first", "second"):  # second call short-circuits
        ks.arm(reason=reason, source="manual")
    assert verify_chain(audit_path) == 1
def test_disarm_resets_kill_switch(tmp_path: Path) -> None:
ks, _audit, audit_path, repo = _make_kill_switch(tmp_path)
ks.arm(reason="test", source="manual")
ks.disarm(reason="cleared", source="manual")
assert ks.is_armed() is False
conn = connect(tmp_path / "state.sqlite")
try:
state = repo.get_system_state(conn)
finally:
conn.close()
assert state is not None
assert state.kill_at is None
# arm + disarm = 2 audit lines
assert verify_chain(audit_path) == 2
def test_disarm_when_not_armed_is_noop(tmp_path: Path) -> None:
ks, _audit, audit_path, _repo = _make_kill_switch(tmp_path)
ks.disarm(reason="nothing to do", source="manual")
assert verify_chain(audit_path) == 0
def test_arm_requires_reason(tmp_path: Path) -> None:
ks, _audit, _audit_path, _repo = _make_kill_switch(tmp_path)
with pytest.raises(KillSwitchError, match="reason is required"):
ks.arm(reason="", source="manual")
def test_arm_without_initialised_state_raises(tmp_path: Path) -> None:
db_path = tmp_path / "state.sqlite"
audit_path = tmp_path / "audit.log"
conn = connect(db_path)
run_migrations(conn)
conn.close()
ks = KillSwitch(
connection_factory=lambda: connect(db_path),
repository=Repository(),
audit_log=AuditLog(audit_path),
clock=lambda: datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
with pytest.raises(KillSwitchError, match="system_state singleton missing"):
ks.arm(reason="x", source="manual")
def test_audit_chain_records_event_kind(tmp_path: Path) -> None:
ks, _audit, audit_path, _repo = _make_kill_switch(tmp_path)
ks.arm(reason="x", source="mcp_timeout")
ks.disarm(reason="y", source="manual")
text = audit_path.read_text(encoding="utf-8")
assert "KILL_SWITCH_ARMED" in text
assert "KILL_SWITCH_DISARMED" in text
def test_is_armed_returns_false_when_singleton_missing(tmp_path: Path) -> None:
db_path = tmp_path / "state.sqlite"
audit_path = tmp_path / "audit.log"
conn = connect(db_path)
run_migrations(conn)
conn.close()
ks = KillSwitch(
connection_factory=lambda: connect(db_path),
repository=Repository(),
audit_log=AuditLog(audit_path),
clock=lambda: datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
assert ks.is_armed() is False
def test_disarm_requires_reason(tmp_path: Path) -> None:
ks, _audit, _audit_path, _repo = _make_kill_switch(tmp_path)
with pytest.raises(KillSwitchError, match="reason is required"):
ks.disarm(reason="", source="manual")
def test_disarm_without_initialised_state_raises(tmp_path: Path) -> None:
db_path = tmp_path / "state.sqlite"
audit_path = tmp_path / "audit.log"
conn = connect(db_path)
run_migrations(conn)
conn.close()
ks = KillSwitch(
connection_factory=lambda: connect(db_path),
repository=Repository(),
audit_log=AuditLog(audit_path),
clock=lambda: datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
)
with pytest.raises(KillSwitchError, match="system_state singleton missing"):
ks.disarm(reason="x", source="manual")
def test_clock_is_advanced_for_each_call(tmp_path: Path) -> None:
ks, _audit, _audit_path, repo = _make_kill_switch(tmp_path)
ks.arm(reason="x", source="manual")
ks.disarm(reason="y", source="manual")
conn = connect(tmp_path / "state.sqlite")
try:
state = repo.get_system_state(conn)
finally:
conn.close()
assert state is not None
# last_health_check should reflect the disarm time (14:20 from the fake clock).
assert state.last_health_check >= datetime(2026, 4, 27, 14, 15, tzinfo=UTC) - timedelta(
seconds=1
)
+19 -4
View File
@@ -38,14 +38,29 @@ def test_cli_status_runs(tmp_data_dir: Path) -> None:
assert "phase: 0" in result.output assert "phase: 0" in result.output
def test_cli_kill_switch_arm_placeholder(tmp_data_dir: Path) -> None: def test_cli_kill_switch_arm_persists_state(tmp_data_dir: Path) -> None:
runner = CliRunner() runner = CliRunner()
db_path = tmp_data_dir / "state.sqlite"
audit_path = tmp_data_dir / "audit.log"
result = runner.invoke( result = runner.invoke(
cli_main, cli_main,
["--log-dir", str(tmp_data_dir / "log"), "kill-switch", "arm", "--reason", "test"], [
"--log-dir",
str(tmp_data_dir / "log"),
"kill-switch",
"arm",
"--reason",
"smoke",
"--db",
str(db_path),
"--audit",
str(audit_path),
],
) )
assert result.exit_code == 0 assert result.exit_code == 0, result.output
assert "phase 0 placeholder" in result.output assert "ARMED" in result.output
assert db_path.exists()
assert audit_path.exists()
def test_cli_version_flag() -> None: def test_cli_version_flag() -> None:
+109
View File
@@ -0,0 +1,109 @@
"""Migration runner + SQLite pragma sanity tests."""
from __future__ import annotations
import sqlite3
from pathlib import Path
import pytest
from cerbero_bite.state.db import (
connect,
current_version,
list_migrations,
run_migrations,
transaction,
)
def test_list_migrations_is_ordered_and_starts_with_0001() -> None:
    """Migration files come back sorted by version and the sequence begins at 1."""
    migrations = list_migrations()
    assert migrations, "expected at least one migration file"
    versions = [version for version, *_ in migrations]
    assert versions == sorted(versions)
    assert versions[0] == 1
def test_run_migrations_creates_full_schema(tmp_path: Path) -> None:
db_path = tmp_path / "state.sqlite"
conn = connect(db_path)
try:
new_version = run_migrations(conn)
assert new_version == max(m[0] for m in list_migrations())
# Sanity: every documented table is present.
existing = {
row["name"]
for row in conn.execute(
"SELECT name FROM sqlite_master WHERE type='table'"
).fetchall()
}
assert {
"positions",
"instructions",
"decisions",
"dvol_history",
"manual_actions",
"system_state",
} <= existing
finally:
conn.close()
def test_run_migrations_is_idempotent(tmp_path: Path) -> None:
    """Re-running the migration runner leaves user_version untouched."""
    conn = connect(tmp_path / "state.sqlite")
    try:
        run_migrations(conn)
        version_after_first = current_version(conn)
        run_migrations(conn)  # second call must be a no-op
        assert current_version(conn) == version_after_first
    finally:
        conn.close()
def test_pragmas_applied(tmp_path: Path) -> None:
db_path = tmp_path / "state.sqlite"
conn = connect(db_path)
try:
run_migrations(conn)
assert conn.execute("PRAGMA foreign_keys").fetchone()[0] == 1
# WAL is sticky on the file: confirm via the journal_mode pragma.
mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
assert mode.lower() == "wal"
finally:
conn.close()
def test_transaction_rolls_back_on_exception(tmp_path: Path) -> None:
    """An exception raised inside the transaction context undoes the INSERT."""
    db_path = tmp_path / "state.sqlite"
    conn = connect(db_path)
    try:
        run_migrations(conn)
        # The RuntimeError escapes the `transaction` context manager, which
        # must roll back the pending INSERT before pytest.raises absorbs it.
        with pytest.raises(RuntimeError, match="boom"), transaction(conn):
            conn.execute(
                "INSERT INTO system_state(id, last_health_check, "
                "config_version, started_at) VALUES (1, '2026-04-27', "
                "'1.0.0', '2026-04-27')"
            )
            raise RuntimeError("boom")
        # Row was rolled back.
        rows = conn.execute("SELECT COUNT(*) FROM system_state").fetchone()[0]
        assert rows == 0
    finally:
        conn.close()
def test_foreign_key_enforcement(tmp_path: Path) -> None:
db_path = tmp_path / "state.sqlite"
conn = connect(db_path)
try:
run_migrations(conn)
# Inserting an instruction without a parent position must fail.
with pytest.raises(sqlite3.IntegrityError):
conn.execute(
"INSERT INTO instructions(instruction_id, proposal_id, kind, "
"payload_json, sent_at) VALUES (?, ?, ?, ?, ?)",
("i1", "missing", "open_combo", "{}", "2026-04-27"),
)
finally:
conn.close()
+450
View File
@@ -0,0 +1,450 @@
"""CRUD tests for :mod:`cerbero_bite.state.repository`."""
from __future__ import annotations
import sqlite3
from collections.abc import Iterator
from datetime import UTC, datetime, timedelta
from decimal import Decimal
from pathlib import Path
from uuid import uuid4
import pytest
from cerbero_bite.state import (
DecisionRecord,
DvolSnapshot,
InstructionRecord,
ManualAction,
PositionRecord,
Repository,
connect,
run_migrations,
transaction,
)
@pytest.fixture
def conn(tmp_path: Path) -> Iterator[sqlite3.Connection]:
db = tmp_path / "state.sqlite"
c = connect(db)
run_migrations(c)
yield c
c.close()
@pytest.fixture
def repo() -> Repository:
return Repository()
def _make_position(**overrides: object) -> PositionRecord:
    """Return a fully-populated PositionRecord; keyword *overrides* replace fields.

    Defaults model a plausible ETH bull-put spread so individual tests only
    override the fields they care about (typically ``status``).
    """
    base: dict[str, object] = {
        "proposal_id": uuid4(),  # unique per call so records never collide
        "spread_type": "bull_put",
        "expiry": datetime(2026, 5, 15, 8, 0, tzinfo=UTC),
        "short_strike": Decimal("2475"),
        "long_strike": Decimal("2350"),
        "short_instrument": "ETH-15MAY26-2475-P",
        "long_instrument": "ETH-15MAY26-2350-P",
        "n_contracts": 2,
        "spread_width_usd": Decimal("125"),
        "spread_width_pct": Decimal("0.0417"),
        "credit_eth": Decimal("0.030"),
        "credit_usd": Decimal("90"),
        "max_loss_usd": Decimal("160"),
        "spot_at_entry": Decimal("3000"),
        "dvol_at_entry": Decimal("50"),
        "delta_at_entry": Decimal("-0.12"),
        "eth_price_at_entry": Decimal("3000"),
        "proposed_at": datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
        "status": "proposed",
        "created_at": datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
        "updated_at": datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
    }
    base.update(overrides)
    return PositionRecord(**base)  # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# positions
# ---------------------------------------------------------------------------
def test_create_and_get_position_roundtrip(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """A persisted position reads back with identical identifiers and Decimals."""
    record = _make_position()
    with transaction(conn):
        repo.create_position(conn, record)
        fetched = repo.get_position(conn, record.proposal_id)
    assert fetched is not None
    assert fetched.proposal_id == record.proposal_id
    assert fetched.short_strike == Decimal("2475")
    # Decimal precision preserved (no float coercion).
    assert fetched.spread_width_pct == Decimal("0.0417")
def test_get_unknown_position_returns_none(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """Looking up a proposal id that was never stored yields ``None``."""
    missing = repo.get_position(conn, uuid4())
    assert missing is None
def test_list_open_positions_filters_by_status(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """Only rows whose status is ``open`` appear in ``list_open_positions``."""
    opened = _make_position(status="open")
    closed = _make_position(status="closed")
    with transaction(conn):
        repo.create_position(conn, opened)
        repo.create_position(conn, closed)
    listed = repo.list_open_positions(conn)
    assert {pos.proposal_id for pos in listed} == {opened.proposal_id}
def test_count_concurrent_positions_excludes_closed(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """Closed and cancelled positions do not count toward concurrency."""
    with transaction(conn):
        for status in ("open", "awaiting_fill", "closed", "cancelled"):
            repo.create_position(conn, _make_position(status=status))
    # Only "open" and "awaiting_fill" are considered concurrent.
    assert repo.count_concurrent_positions(conn) == 2
def test_update_position_status_with_only_status_skips_optional_fields(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """Hit the False side of every optional-field branch in ``update_position_status``."""
    pos = _make_position(status="proposed")
    with transaction(conn):
        repo.create_position(conn, pos)
        repo.update_position_status(
            conn,
            pos.proposal_id,
            status="awaiting_fill",
            now=datetime(2026, 4, 27, 14, 5, tzinfo=UTC),
        )
    reloaded = repo.get_position(conn, pos.proposal_id)
    assert reloaded is not None
    assert reloaded.status == "awaiting_fill"
    # No lifecycle timestamps were supplied, so none may have been set.
    assert reloaded.opened_at is None
    assert reloaded.closed_at is None
    assert reloaded.close_reason is None
def test_update_position_status_persists_open_then_close(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """Opening then closing a position persists every supplied field."""
    pos = _make_position(status="awaiting_fill")
    open_ts = datetime(2026, 4, 27, 14, 10, tzinfo=UTC)

    # Transition: awaiting_fill -> open.
    with transaction(conn):
        repo.create_position(conn, pos)
        repo.update_position_status(
            conn,
            pos.proposal_id,
            status="open",
            opened_at=open_ts,
            now=datetime(2026, 4, 27, 14, 11, tzinfo=UTC),
        )
    opened = repo.get_position(conn, pos.proposal_id)
    assert opened is not None
    assert opened.opened_at == open_ts

    # Transition: open -> closed, with realised P&L details.
    close_ts = datetime(2026, 5, 12, 14, 0, tzinfo=UTC)
    with transaction(conn):
        repo.update_position_status(
            conn,
            pos.proposal_id,
            status="closed",
            closed_at=close_ts,
            close_reason="CLOSE_PROFIT",
            debit_paid_eth=Decimal("0.012"),
            pnl_eth=Decimal("0.018"),
            pnl_usd=Decimal("54"),
            now=close_ts,
        )
    closed = repo.get_position(conn, pos.proposal_id)
    assert closed is not None
    assert closed.status == "closed"
    assert closed.close_reason == "CLOSE_PROFIT"
    assert closed.debit_paid_eth == Decimal("0.012")
    assert closed.pnl_eth == Decimal("0.018")
def test_naive_datetime_is_normalised_to_utc(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """Naive datetimes round-trip as timezone-aware UTC values."""
    naive = datetime(2026, 4, 27, 14, 0)
    pos = _make_position(proposed_at=naive, created_at=naive, updated_at=naive)
    with transaction(conn):
        repo.create_position(conn, pos)
    reloaded = repo.get_position(conn, pos.proposal_id)
    assert reloaded is not None
    assert reloaded.proposed_at.tzinfo is not None
    assert reloaded.proposed_at.utcoffset() == timedelta(0)
# ---------------------------------------------------------------------------
# instructions
# ---------------------------------------------------------------------------
def test_instruction_lifecycle_ack_and_fill(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """An instruction progresses sent -> acknowledged -> filled with fill data."""
    parent = _make_position(status="awaiting_fill")
    instruction = InstructionRecord(
        instruction_id=uuid4(),
        proposal_id=parent.proposal_id,
        kind="open_combo",
        payload_json='{"action":"open"}',
        sent_at=datetime(2026, 4, 27, 14, 5, tzinfo=UTC),
    )
    with transaction(conn):
        repo.create_position(conn, parent)
        repo.create_instruction(conn, instruction)
        # First update: acknowledgement only.
        repo.update_instruction(
            conn,
            instruction.instruction_id,
            acknowledged_at=datetime(2026, 4, 27, 14, 6, tzinfo=UTC),
        )
        # Second update: fill timestamp plus actual amounts.
        repo.update_instruction(
            conn,
            instruction.instruction_id,
            filled_at=datetime(2026, 4, 27, 14, 8, tzinfo=UTC),
            actual_fill_eth=Decimal("0.0298"),
            actual_fees_eth=Decimal("0.0001"),
        )
    rows = repo.list_instructions(conn, parent.proposal_id)
    assert len(rows) == 1
    stored = rows[0]
    assert stored.acknowledged_at is not None
    assert stored.filled_at is not None
    assert stored.actual_fill_eth == Decimal("0.0298")
def test_update_instruction_no_op_when_no_fields(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """Calling ``update_instruction`` with no fields leaves the row untouched."""
    parent = _make_position()
    instruction = InstructionRecord(
        instruction_id=uuid4(),
        proposal_id=parent.proposal_id,
        kind="open_combo",
        payload_json="{}",
        sent_at=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
    )
    with transaction(conn):
        repo.create_position(conn, parent)
        repo.create_instruction(conn, instruction)
        repo.update_instruction(conn, instruction.instruction_id)  # no fields
    rows = repo.list_instructions(conn, parent.proposal_id)
    assert rows[0].acknowledged_at is None
# ---------------------------------------------------------------------------
# decisions / dvol / manual_actions
# ---------------------------------------------------------------------------
def test_record_and_list_decisions(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """A recorded decision receives a positive id and is listed back."""
    entry = DecisionRecord(
        decision_type="entry_check",
        timestamp=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
        inputs_json='{"capital":1500}',
        outputs_json='{"accepted":true}',
        action_taken="propose_open",
    )
    with transaction(conn):
        new_id = repo.record_decision(conn, entry)
    assert new_id > 0
    listed = repo.list_decisions(conn)
    assert len(listed) == 1
    assert listed[0].id == new_id
    assert listed[0].action_taken == "propose_open"
def test_record_dvol_snapshot_replaces_on_duplicate_timestamp(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """Two snapshots at the same timestamp upsert rather than duplicate."""
    when = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
    first = DvolSnapshot(timestamp=when, dvol=Decimal("50"), eth_spot=Decimal("3000"))
    second = DvolSnapshot(timestamp=when, dvol=Decimal("55"), eth_spot=Decimal("3050"))
    with transaction(conn):
        repo.record_dvol_snapshot(conn, first)
        repo.record_dvol_snapshot(conn, second)
    (count,) = conn.execute("SELECT COUNT(*) FROM dvol_history").fetchone()
    assert count == 1
def test_manual_action_enqueue_consume_cycle(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """An enqueued action is returned until marked consumed, then never again."""
    parent = _make_position()
    pending = ManualAction(
        kind="approve_proposal",
        proposal_id=parent.proposal_id,
        payload_json='{"reason":"go"}',
        created_at=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
    )
    with transaction(conn):
        repo.create_position(conn, parent)
        queued_id = repo.enqueue_manual_action(conn, pending)

    fetched = repo.next_unconsumed_action(conn)
    assert fetched is not None
    assert fetched.id == queued_id

    with transaction(conn):
        repo.mark_action_consumed(
            conn,
            queued_id,
            consumed_by="orchestrator",
            result="ok",
            now=datetime(2026, 4, 27, 14, 1, tzinfo=UTC),
        )
    # The queue is drained once the action is consumed.
    assert repo.next_unconsumed_action(conn) is None
# ---------------------------------------------------------------------------
# system_state
# ---------------------------------------------------------------------------
def test_init_system_state_is_idempotent(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """Initialising system state twice leaves a single disarmed row."""
    ts = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
    with transaction(conn):
        repo.init_system_state(conn, config_version="1.0.0", now=ts)
        repo.init_system_state(conn, config_version="1.0.0", now=ts)
    current = repo.get_system_state(conn)
    assert current is not None
    assert current.kill_switch == 0
    assert current.config_version == "1.0.0"
def test_kill_switch_arm_and_disarm(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """Arming sets flag/reason/timestamp; disarming clears them again."""
    armed_at = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
    with transaction(conn):
        repo.init_system_state(conn, config_version="1.0.0", now=armed_at)
        repo.set_kill_switch(conn, armed=True, reason="manual", now=armed_at)
    armed = repo.get_system_state(conn)
    assert armed is not None
    assert armed.kill_switch == 1
    assert armed.kill_reason == "manual"
    assert armed.kill_at is not None

    disarmed_at = armed_at + timedelta(minutes=5)
    with transaction(conn):
        repo.set_kill_switch(conn, armed=False, reason=None, now=disarmed_at)
    disarmed = repo.get_system_state(conn)
    assert disarmed is not None
    assert disarmed.kill_switch == 0
    assert disarmed.kill_at is None
    # Disarming also refreshes the health-check timestamp.
    assert disarmed.last_health_check == disarmed_at
def test_get_system_state_returns_none_when_uninitialised(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """Before ``init_system_state`` runs, no state row exists."""
    state = repo.get_system_state(conn)
    assert state is None
def test_list_positions_filters_by_status(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """Passing ``status`` restricts ``list_positions`` to matching rows."""
    opened = _make_position(status="open")
    closed = _make_position(status="closed")
    with transaction(conn):
        repo.create_position(conn, opened)
        repo.create_position(conn, closed)
    matches = repo.list_positions(conn, status="closed")
    assert len(matches) == 1
    assert matches[0].proposal_id == closed.proposal_id
def test_list_positions_without_filter_returns_all(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """Omitting the status filter returns every stored position."""
    with transaction(conn):
        for status in ("open", "closed"):
            repo.create_position(conn, _make_position(status=status))
    assert len(repo.list_positions(conn)) == 2
def test_list_decisions_filters_by_proposal(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """Filtering by proposal id excludes decisions with no proposal link."""
    parent = _make_position()
    linked = DecisionRecord(
        decision_type="entry_check",
        proposal_id=parent.proposal_id,
        timestamp=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
        inputs_json="{}",
        outputs_json="{}",
        action_taken="propose_open",
    )
    unlinked = DecisionRecord(
        decision_type="exit_check",
        timestamp=datetime(2026, 4, 27, 15, 0, tzinfo=UTC),
        inputs_json="{}",
        outputs_json="{}",
        action_taken="HOLD",
    )
    with transaction(conn):
        repo.create_position(conn, parent)
        repo.record_decision(conn, linked)
        repo.record_decision(conn, unlinked)
    matches = repo.list_decisions(conn, proposal_id=parent.proposal_id)
    assert len(matches) == 1
    assert matches[0].action_taken == "propose_open"
def test_update_instruction_sets_cancelled(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """``cancelled_at`` is persisted when supplied to ``update_instruction``."""
    parent = _make_position()
    instruction = InstructionRecord(
        instruction_id=uuid4(),
        proposal_id=parent.proposal_id,
        kind="open_combo",
        payload_json="{}",
        sent_at=datetime(2026, 4, 27, 14, 0, tzinfo=UTC),
    )
    with transaction(conn):
        repo.create_position(conn, parent)
        repo.create_instruction(conn, instruction)
        repo.update_instruction(
            conn,
            instruction.instruction_id,
            cancelled_at=datetime(2026, 4, 27, 14, 30, tzinfo=UTC),
        )
    rows = repo.list_instructions(conn, parent.proposal_id)
    assert rows[0].cancelled_at is not None
def test_touch_health_check_updates_timestamp(
    conn: sqlite3.Connection, repo: Repository
) -> None:
    """``touch_health_check`` overwrites the stored health-check timestamp."""
    initialised_at = datetime(2026, 4, 27, 14, 0, tzinfo=UTC)
    touched_at = initialised_at + timedelta(minutes=10)
    with transaction(conn):
        repo.init_system_state(conn, config_version="1.0.0", now=initialised_at)
        repo.touch_health_check(conn, now=touched_at)
    current = repo.get_system_state(conn)
    assert current is not None
    assert current.last_health_check == touched_at