diff --git a/src/cerbero_mcp/exchanges/ibkr/key_rotation.py b/src/cerbero_mcp/exchanges/ibkr/key_rotation.py new file mode 100644 index 0000000..8104874 --- /dev/null +++ b/src/cerbero_mcp/exchanges/ibkr/key_rotation.py @@ -0,0 +1,106 @@ +"""IBKR RSA key rotation: stage/confirm/abort with auto-rollback.""" +from __future__ import annotations + +import datetime as _dt +import hashlib +import os +import shutil +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from pathlib import Path + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa + + +def _sha256_fingerprint(pem_path: Path) -> str: + digest = hashlib.sha256(pem_path.read_bytes()).hexdigest() + return f"SHA256:{digest}" + + +@dataclass +class KeyRotationManager: + signature_key_path: str + encryption_key_path: str + + _started: bool = field(default=False, init=False) + + def _sig(self) -> Path: + return Path(self.signature_key_path) + + def _enc(self) -> Path: + return Path(self.encryption_key_path) + + async def start(self) -> dict: + sig_new = self._sig().with_suffix(self._sig().suffix + ".new") + enc_new = self._enc().with_suffix(self._enc().suffix + ".new") + + for p in (sig_new, enc_new): + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + p.write_bytes(key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + )) + os.chmod(p, 0o600) + + self._started = True + return { + "fingerprints": { + "sig": _sha256_fingerprint(sig_new), + "enc": _sha256_fingerprint(enc_new), + }, + "expires_at": ( + _dt.datetime.now(_dt.UTC) + _dt.timedelta(hours=24) + ).isoformat(), + } + + async def confirm( + self, *, validate: Callable[[], Awaitable[bool]], + ) -> dict: + sig = self._sig() + enc = self._enc() + sig_new = sig.with_suffix(sig.suffix + ".new") + enc_new = enc.with_suffix(enc.suffix + ".new") + if not (sig_new.exists() and enc_new.exists()): + raise RuntimeError("IBKR_ROTATION_NOT_STARTED") + + archive = sig.parent / ".archive" / _dt.datetime.now(_dt.UTC).strftime("%Y%m%dT%H%M%S") + archive.mkdir(parents=True, exist_ok=True) + + shutil.move(str(sig), str(archive / sig.name)) + shutil.move(str(enc), str(archive / enc.name)) + shutil.move(str(sig_new), str(sig)) + shutil.move(str(enc_new), str(enc)) + + err: BaseException | None = None + try: + ok = await validate() + except Exception as e: + ok = False + err = e + + if not ok: + shutil.move(str(sig), str(sig.with_suffix(sig.suffix + ".new"))) + shutil.move(str(enc), str(enc.with_suffix(enc.suffix + ".new"))) + shutil.move(str(archive / sig.name), str(sig)) + shutil.move(str(archive / enc.name), str(enc)) + raise RuntimeError( + f"IBKR_ROTATION_VALIDATION_FAILED: {err}" if err + else "IBKR_ROTATION_VALIDATION_FAILED" + ) + + self._started = False + return { + "rotated_at": _dt.datetime.now(_dt.UTC).isoformat(), + "old_archived_at": str(archive), + } + + async def abort(self) -> dict: + sig_new = self._sig().with_suffix(self._sig().suffix + ".new") + enc_new = self._enc().with_suffix(self._enc().suffix + ".new") + for p in (sig_new, enc_new): + if p.exists(): + p.unlink() + self._started = False + return {"aborted": True} diff --git a/tests/unit/exchanges/ibkr/test_key_rotation.py b/tests/unit/exchanges/ibkr/test_key_rotation.py new file mode 100644 index 0000000..ac82edf --- /dev/null +++ b/tests/unit/exchanges/ibkr/test_key_rotation.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import pytest +from cerbero_mcp.exchanges.ibkr.key_rotation import KeyRotationManager + + +@pytest.mark.asyncio +async def test_start_generates_new_keypair_files(tmp_path): + sig_path = tmp_path / "sig.pem" + enc_path = tmp_path / "enc.pem" + sig_path.write_bytes(b"old-sig") + enc_path.write_bytes(b"old-enc") + + mgr = KeyRotationManager( + signature_key_path=str(sig_path), + encryption_key_path=str(enc_path), + ) + out = await mgr.start() + assert "sig" in out["fingerprints"] + assert "enc" in out["fingerprints"] + assert (tmp_path / "sig.pem.new").exists() + assert (tmp_path / "enc.pem.new").exists() + + +@pytest.mark.asyncio +async def test_confirm_swap_and_validate_ok(tmp_path): + sig_path = tmp_path / "sig.pem" + enc_path = tmp_path / "enc.pem" + sig_path.write_bytes(b"old-sig") + enc_path.write_bytes(b"old-enc") + + mgr = KeyRotationManager( + signature_key_path=str(sig_path), + encryption_key_path=str(enc_path), + ) + await mgr.start() + + async def fake_validate() -> bool: + return True + out = await mgr.confirm(validate=fake_validate) + assert "rotated_at" in out + assert (tmp_path / ".archive").exists() + + +@pytest.mark.asyncio +async def test_confirm_validate_fail_rollbacks(tmp_path): + sig_path = tmp_path / "sig.pem" + enc_path = tmp_path / "enc.pem" + sig_path.write_bytes(b"old-sig") + enc_path.write_bytes(b"old-enc") + + mgr = KeyRotationManager( + signature_key_path=str(sig_path), + encryption_key_path=str(enc_path), + ) + await mgr.start() + + async def fake_validate() -> bool: + return False + with pytest.raises(RuntimeError, match="IBKR_ROTATION_VALIDATION_FAILED"): + await mgr.confirm(validate=fake_validate) + assert sig_path.read_bytes() == b"old-sig" + assert enc_path.read_bytes() == b"old-enc" + + +@pytest.mark.asyncio +async def test_abort_cleans_new_files(tmp_path): + sig_path = tmp_path / "sig.pem" + enc_path = tmp_path / "enc.pem" + sig_path.write_bytes(b"old-sig") + enc_path.write_bytes(b"old-enc") + + mgr = KeyRotationManager( + signature_key_path=str(sig_path), + encryption_key_path=str(enc_path), + ) + await mgr.start() + await mgr.abort() + assert not (tmp_path / "sig.pem.new").exists() + assert not (tmp_path / "enc.pem.new").exists()