From bd6b03ce43a9c96cad82eb4d78b28bf68e1b432f Mon Sep 17 00:00:00 2001 From: AdrianoDev Date: Fri, 1 May 2026 08:44:28 +0200 Subject: [PATCH] feat(V2): cabla audit logging nei write endpoint dei 4 router exchange Co-Authored-By: Claude Opus 4.7 (1M context) --- src/cerbero_mcp/common/audit_helpers.py | 94 +++++++++++++++ src/cerbero_mcp/exchanges/bybit/tools.py | 2 - src/cerbero_mcp/exchanges/deribit/tools.py | 6 - .../exchanges/hyperliquid/tools.py | 1 - src/cerbero_mcp/routers/alpaca.py | 58 +++++++++- src/cerbero_mcp/routers/bybit.py | 107 ++++++++++++++++-- src/cerbero_mcp/routers/deribit.py | 58 +++++++++- src/cerbero_mcp/routers/hyperliquid.py | 50 +++++++- tests/unit/common/test_audit_helpers.py | 103 +++++++++++++++++ 9 files changed, 442 insertions(+), 37 deletions(-) create mode 100644 src/cerbero_mcp/common/audit_helpers.py create mode 100644 tests/unit/common/test_audit_helpers.py diff --git a/src/cerbero_mcp/common/audit_helpers.py b/src/cerbero_mcp/common/audit_helpers.py new file mode 100644 index 0000000..ab98ff6 --- /dev/null +++ b/src/cerbero_mcp/common/audit_helpers.py @@ -0,0 +1,94 @@ +"""Helper per cablare audit_write_op nei router. + +Pattern uso nel router:: + + @r.post("/tools/place_order") + async def _place_order( + params: t.PlaceOrderReq, + request: Request, + client: DeribitClient = Depends(get_deribit_client), + ): + return await audit_call( + request=request, + exchange="deribit", + action="place_order", + target_field="instrument_name", + params=params, + tool_fn=lambda: t.place_order(client, params, creds=...), + ) +""" +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any + +from fastapi import Request +from pydantic import BaseModel + +from cerbero_mcp.common.audit import audit_write_op + + +def _extract_target(params: BaseModel | None, target_field: str | None) -> str | None: + if params is None or target_field is None: + return None + val = getattr(params, target_field, None) + if val is None: + return None + return str(val) + + +def _safe_dump(params: BaseModel | None) -> dict[str, Any]: + if params is None: + return {} + try: + return params.model_dump(mode="json", exclude_none=True) + except Exception: + return {} + + +async def audit_call( + *, + request: Request, + exchange: str, + action: str, + tool_fn: Callable[[], Awaitable[Any]], + params: BaseModel | None = None, + target_field: str | None = None, +) -> Any: + """Esegue tool_fn e logga audit (success o error). Riraisola eccezioni.""" + actor = getattr(request.state, "environment", None) + target = _extract_target(params, target_field) + payload = _safe_dump(params) + + try: + result = await tool_fn() + except Exception as e: + audit_write_op( + actor=actor, + action=action, + exchange=exchange, + target=target, + payload=payload, + error=f"{type(e).__name__}: {e}", + ) + raise + + # Se result รจ dict, passa raw; altrimenti tenta serializzazione + audit_result: dict[str, Any] | None = None + if isinstance(result, dict): + audit_result = result + elif hasattr(result, "model_dump"): + try: + audit_result = result.model_dump(mode="json") + except Exception: + audit_result = None + + audit_write_op( + actor=actor, + action=action, + exchange=exchange, + target=target, + payload=payload, + result=audit_result, + ) + return result diff --git a/src/cerbero_mcp/exchanges/bybit/tools.py b/src/cerbero_mcp/exchanges/bybit/tools.py index 4f39bd7..5a6a79d 100644 --- a/src/cerbero_mcp/exchanges/bybit/tools.py +++ b/src/cerbero_mcp/exchanges/bybit/tools.py @@ -359,7 +359,6 @@ async def place_order( reduce_only=params.reduce_only, position_idx=params.position_idx, ) - # TODO V2: wire audit via request.state.environment in router return result @@ -370,7 +369,6 @@ async def place_combo_order( category=params.category, legs=[leg.model_dump() for leg in params.legs], ) - # TODO V2: wire audit via request.state.environment in router return result diff --git a/src/cerbero_mcp/exchanges/deribit/tools.py b/src/cerbero_mcp/exchanges/deribit/tools.py index f9b35eb..b92cba6 100644 --- a/src/cerbero_mcp/exchanges/deribit/tools.py +++ b/src/cerbero_mcp/exchanges/deribit/tools.py @@ -481,7 +481,6 @@ async def place_order( post_only=params.post_only, label=params.label, ) - # TODO V2: wire audit via request.state.environment in router return result @@ -502,29 +501,24 @@ async def place_combo_order( price=params.price, label=params.label, ) - # TODO V2: wire audit via request.state.environment in router return result async def cancel_order(client: DeribitClient, params: CancelOrderReq) -> dict: result = await client.cancel_order(params.order_id) - # TODO V2: wire audit via request.state.environment in router return result async def set_stop_loss(client: DeribitClient, params: SetStopLossReq) -> dict: result = await client.set_stop_loss(params.order_id, params.stop_price) - # TODO V2: wire audit via request.state.environment in router return result async def set_take_profit(client: DeribitClient, params: SetTakeProfitReq) -> dict: result = await client.set_take_profit(params.order_id, params.tp_price) - # TODO V2: wire audit via request.state.environment in router return result async def close_position(client: DeribitClient, params: ClosePositionReq) -> dict: result = await client.close_position(params.instrument_name) - # TODO V2: wire audit via request.state.environment in router return result diff --git a/src/cerbero_mcp/exchanges/hyperliquid/tools.py b/src/cerbero_mcp/exchanges/hyperliquid/tools.py index ea0f332..4e25de4 100644 --- a/src/cerbero_mcp/exchanges/hyperliquid/tools.py +++ b/src/cerbero_mcp/exchanges/hyperliquid/tools.py @@ -303,7 +303,6 @@ async def place_order( price=params.price, reduce_only=params.reduce_only, ) - # TODO V2: wire audit via request.state.environment in router return result diff --git a/src/cerbero_mcp/routers/alpaca.py b/src/cerbero_mcp/routers/alpaca.py index 38fc2a1..8cea2d7 100644 --- a/src/cerbero_mcp/routers/alpaca.py +++ b/src/cerbero_mcp/routers/alpaca.py @@ -11,6 +11,7 @@ from typing import Literal, cast from fastapi import APIRouter, Depends, Request from cerbero_mcp.client_registry import ClientRegistry +from cerbero_mcp.common.audit_helpers import audit_call from cerbero_mcp.exchanges.alpaca import tools as t from cerbero_mcp.exchanges.alpaca.client import AlpacaClient @@ -136,41 +137,86 @@ def make_router() -> APIRouter: client: AlpacaClient = Depends(get_alpaca_client), ): creds = _build_creds(request) - return await t.place_order(client, params, creds=creds) + return await audit_call( + request=request, + exchange="alpaca", + action="place_order", + target_field="symbol", + params=params, + tool_fn=lambda: t.place_order(client, params, creds=creds), + ) @r.post("/tools/amend_order") async def _amend_order( params: t.AmendOrderReq, + request: Request, client: AlpacaClient = Depends(get_alpaca_client), ): - return await t.amend_order(client, params) + return await audit_call( + request=request, + exchange="alpaca", + action="amend_order", + target_field="order_id", + params=params, + tool_fn=lambda: t.amend_order(client, params), + ) @r.post("/tools/cancel_order") async def _cancel_order( params: t.CancelOrderReq, + request: Request, client: AlpacaClient = Depends(get_alpaca_client), ): - return await t.cancel_order(client, params) + return await audit_call( + request=request, + exchange="alpaca", + action="cancel_order", + target_field="order_id", + params=params, + tool_fn=lambda: t.cancel_order(client, params), + ) @r.post("/tools/cancel_all_orders") async def _cancel_all_orders( params: t.CancelAllOrdersReq, + request: Request, client: AlpacaClient = Depends(get_alpaca_client), ): - return await t.cancel_all_orders(client, params) + return await audit_call( + request=request, + exchange="alpaca", + action="cancel_all_orders", + params=params, + tool_fn=lambda: t.cancel_all_orders(client, params), + ) @r.post("/tools/close_position") async def _close_position( params: t.ClosePositionReq, + request: Request, client: AlpacaClient = Depends(get_alpaca_client), ): - return await t.close_position(client, params) + return await audit_call( + request=request, + exchange="alpaca", + action="close_position", + target_field="symbol", + params=params, + tool_fn=lambda: t.close_position(client, params), + ) @r.post("/tools/close_all_positions") async def _close_all_positions( params: t.CloseAllPositionsReq, + request: Request, client: AlpacaClient = Depends(get_alpaca_client), ): - return await t.close_all_positions(client, params) + return await audit_call( + request=request, + exchange="alpaca", + action="close_all_positions", + params=params, + tool_fn=lambda: t.close_all_positions(client, params), + ) return r diff --git a/src/cerbero_mcp/routers/bybit.py b/src/cerbero_mcp/routers/bybit.py index 0beb88e..211581a 100644 --- a/src/cerbero_mcp/routers/bybit.py +++ b/src/cerbero_mcp/routers/bybit.py @@ -11,6 +11,7 @@ from typing import Literal, cast from fastapi import APIRouter, Depends, Request from cerbero_mcp.client_registry import ClientRegistry +from cerbero_mcp.common.audit_helpers import audit_call from cerbero_mcp.exchanges.bybit import tools as t from cerbero_mcp.exchanges.bybit.client import BybitClient @@ -182,7 +183,14 @@ def make_router() -> APIRouter: client: BybitClient = Depends(get_bybit_client), ): creds = _build_creds(request) - return await t.place_order(client, params, creds=creds) + return await audit_call( + request=request, + exchange="bybit", + action="place_order", + target_field="symbol", + params=params, + tool_fn=lambda: t.place_order(client, params, creds=creds), + ) @r.post("/tools/place_combo_order") async def _place_combo_order( @@ -191,49 +199,103 @@ def make_router() -> APIRouter: client: BybitClient = Depends(get_bybit_client), ): creds = _build_creds(request) - return await t.place_combo_order(client, params, creds=creds) + return await audit_call( + request=request, + exchange="bybit", + action="place_combo_order", + params=params, + tool_fn=lambda: t.place_combo_order(client, params, creds=creds), + ) @r.post("/tools/amend_order") async def _amend_order( params: t.AmendOrderReq, + request: Request, client: BybitClient = Depends(get_bybit_client), ): - return await t.amend_order(client, params) + return await audit_call( + request=request, + exchange="bybit", + action="amend_order", + target_field="symbol", + params=params, + tool_fn=lambda: t.amend_order(client, params), + ) @r.post("/tools/cancel_order") async def _cancel_order( params: t.CancelOrderReq, + request: Request, client: BybitClient = Depends(get_bybit_client), ): - return await t.cancel_order(client, params) + return await audit_call( + request=request, + exchange="bybit", + action="cancel_order", + target_field="order_id", + params=params, + tool_fn=lambda: t.cancel_order(client, params), + ) @r.post("/tools/cancel_all_orders") async def _cancel_all_orders( params: t.CancelAllReq, + request: Request, client: BybitClient = Depends(get_bybit_client), ): - return await t.cancel_all_orders(client, params) + return await audit_call( + request=request, + exchange="bybit", + action="cancel_all_orders", + target_field="symbol", + params=params, + tool_fn=lambda: t.cancel_all_orders(client, params), + ) @r.post("/tools/set_stop_loss") async def _set_stop_loss( params: t.SetStopLossReq, + request: Request, client: BybitClient = Depends(get_bybit_client), ): - return await t.set_stop_loss(client, params) + return await audit_call( + request=request, + exchange="bybit", + action="set_stop_loss", + target_field="symbol", + params=params, + tool_fn=lambda: t.set_stop_loss(client, params), + ) @r.post("/tools/set_take_profit") async def _set_take_profit( params: t.SetTakeProfitReq, + request: Request, client: BybitClient = Depends(get_bybit_client), ): - return await t.set_take_profit(client, params) + return await audit_call( + request=request, + exchange="bybit", + action="set_take_profit", + target_field="symbol", + params=params, + tool_fn=lambda: t.set_take_profit(client, params), + ) @r.post("/tools/close_position") async def _close_position( params: t.ClosePositionReq, + request: Request, client: BybitClient = Depends(get_bybit_client), ): - return await t.close_position(client, params) + return await audit_call( + request=request, + exchange="bybit", + action="close_position", + target_field="symbol", + params=params, + tool_fn=lambda: t.close_position(client, params), + ) @r.post("/tools/set_leverage") async def _set_leverage( @@ -242,20 +304,43 @@ def make_router() -> APIRouter: client: BybitClient = Depends(get_bybit_client), ): creds = _build_creds(request) - return await t.set_leverage(client, params, creds=creds) + return await audit_call( + request=request, + exchange="bybit", + action="set_leverage", + target_field="symbol", + params=params, + tool_fn=lambda: t.set_leverage(client, params, creds=creds), + ) @r.post("/tools/switch_position_mode") async def _switch_position_mode( params: t.SwitchModeReq, + request: Request, client: BybitClient = Depends(get_bybit_client), ): - return await t.switch_position_mode(client, params) + return await audit_call( + request=request, + exchange="bybit", + action="switch_position_mode", + target_field="symbol", + params=params, + tool_fn=lambda: t.switch_position_mode(client, params), + ) @r.post("/tools/transfer_asset") async def _transfer_asset( params: t.TransferReq, + request: Request, client: BybitClient = Depends(get_bybit_client), ): - return await t.transfer_asset(client, params) + return await audit_call( + request=request, + exchange="bybit", + action="transfer_asset", + target_field="coin", + params=params, + tool_fn=lambda: t.transfer_asset(client, params), + ) return r diff --git a/src/cerbero_mcp/routers/deribit.py b/src/cerbero_mcp/routers/deribit.py index 8fdaf74..62bdfea 100644 --- a/src/cerbero_mcp/routers/deribit.py +++ b/src/cerbero_mcp/routers/deribit.py @@ -11,6 +11,7 @@ from typing import Literal, cast from fastapi import APIRouter, Depends, Request from cerbero_mcp.client_registry import ClientRegistry +from cerbero_mcp.common.audit_helpers import audit_call from cerbero_mcp.exchanges.deribit import tools as t from cerbero_mcp.exchanges.deribit.client import DeribitClient @@ -249,7 +250,14 @@ def make_router() -> APIRouter: client: DeribitClient = Depends(get_deribit_client), ): creds = _build_creds(request) - return await t.place_order(client, params, creds=creds) + return await audit_call( + request=request, + exchange="deribit", + action="place_order", + target_field="instrument_name", + params=params, + tool_fn=lambda: t.place_order(client, params, creds=creds), + ) @r.post("/tools/place_combo_order") async def _place_combo_order( @@ -258,34 +266,72 @@ def make_router() -> APIRouter: client: DeribitClient = Depends(get_deribit_client), ): creds = _build_creds(request) - return await t.place_combo_order(client, params, creds=creds) + return await audit_call( + request=request, + exchange="deribit", + action="place_combo_order", + params=params, + tool_fn=lambda: t.place_combo_order(client, params, creds=creds), + ) @r.post("/tools/cancel_order") async def _cancel_order( params: t.CancelOrderReq, + request: Request, client: DeribitClient = Depends(get_deribit_client), ): - return await t.cancel_order(client, params) + return await audit_call( + request=request, + exchange="deribit", + action="cancel_order", + target_field="order_id", + params=params, + tool_fn=lambda: t.cancel_order(client, params), + ) @r.post("/tools/set_stop_loss") async def _set_stop_loss( params: t.SetStopLossReq, + request: Request, client: DeribitClient = Depends(get_deribit_client), ): - return await t.set_stop_loss(client, params) + return await audit_call( + request=request, + exchange="deribit", + action="set_stop_loss", + target_field="order_id", + params=params, + tool_fn=lambda: t.set_stop_loss(client, params), + ) @r.post("/tools/set_take_profit") async def _set_take_profit( params: t.SetTakeProfitReq, + request: Request, client: DeribitClient = Depends(get_deribit_client), ): - return await t.set_take_profit(client, params) + return await audit_call( + request=request, + exchange="deribit", + action="set_take_profit", + target_field="order_id", + params=params, + tool_fn=lambda: t.set_take_profit(client, params), + ) @r.post("/tools/close_position") async def _close_position( params: t.ClosePositionReq, + request: Request, client: DeribitClient = Depends(get_deribit_client), ): - return await t.close_position(client, params) + return await audit_call( + request=request, + exchange="deribit", + action="close_position", + target_field="instrument_name", + params=params, + tool_fn=lambda: t.close_position(client, params), + ) return r diff --git a/src/cerbero_mcp/routers/hyperliquid.py b/src/cerbero_mcp/routers/hyperliquid.py index 39cbd62..268a483 100644 --- a/src/cerbero_mcp/routers/hyperliquid.py +++ b/src/cerbero_mcp/routers/hyperliquid.py @@ -11,6 +11,7 @@ from typing import Literal, cast from fastapi import APIRouter, Depends, Request from cerbero_mcp.client_registry import ClientRegistry +from cerbero_mcp.common.audit_helpers import audit_call from cerbero_mcp.exchanges.hyperliquid import tools as t from cerbero_mcp.exchanges.hyperliquid.client import HyperliquidClient @@ -136,34 +137,73 @@ def make_router() -> APIRouter: client: HyperliquidClient = Depends(get_hyperliquid_client), ): creds = _build_creds(request) - return await t.place_order(client, params, creds=creds) + return await audit_call( + request=request, + exchange="hyperliquid", + action="place_order", + target_field="instrument", + params=params, + tool_fn=lambda: t.place_order(client, params, creds=creds), + ) @r.post("/tools/cancel_order") async def _cancel_order( params: t.CancelOrderReq, + request: Request, client: HyperliquidClient = Depends(get_hyperliquid_client), ): - return await t.cancel_order(client, params) + return await audit_call( + request=request, + exchange="hyperliquid", + action="cancel_order", + target_field="order_id", + params=params, + tool_fn=lambda: t.cancel_order(client, params), + ) @r.post("/tools/set_stop_loss") async def _set_stop_loss( params: t.SetStopLossReq, + request: Request, client: HyperliquidClient = Depends(get_hyperliquid_client), ): - return await t.set_stop_loss(client, params) + return await audit_call( + request=request, + exchange="hyperliquid", + action="set_stop_loss", + target_field="instrument", + params=params, + tool_fn=lambda: t.set_stop_loss(client, params), + ) @r.post("/tools/set_take_profit") async def _set_take_profit( params: t.SetTakeProfitReq, + request: Request, client: HyperliquidClient = Depends(get_hyperliquid_client), ): - return await t.set_take_profit(client, params) + return await audit_call( + request=request, + exchange="hyperliquid", + action="set_take_profit", + target_field="instrument", + params=params, + tool_fn=lambda: t.set_take_profit(client, params), + ) @r.post("/tools/close_position") async def _close_position( params: t.ClosePositionReq, + request: Request, client: HyperliquidClient = Depends(get_hyperliquid_client), ): - return await t.close_position(client, params) + return await audit_call( + request=request, + exchange="hyperliquid", + action="close_position", + target_field="instrument", + params=params, + tool_fn=lambda: t.close_position(client, params), + ) return r diff --git a/tests/unit/common/test_audit_helpers.py b/tests/unit/common/test_audit_helpers.py new file mode 100644 index 0000000..ec60858 --- /dev/null +++ b/tests/unit/common/test_audit_helpers.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import pytest +from pydantic import BaseModel + + +class FakeReq(BaseModel): + instrument_name: str + qty: float + + +@pytest.mark.asyncio +async def test_audit_call_logs_success(monkeypatch): + from cerbero_mcp.common.audit_helpers import audit_call + + logged = [] + + def fake_audit(**kw): + logged.append(kw) + + monkeypatch.setattr("cerbero_mcp.common.audit_helpers.audit_write_op", fake_audit) + + class FakeRequest: + class _State: + environment = "testnet" + state = _State() + + async def tool_fn(): + return {"order_id": "abc123", "state": "filled"} + + result = await audit_call( + request=FakeRequest(), # type: ignore[arg-type] + exchange="deribit", + action="place_order", + target_field="instrument_name", + params=FakeReq(instrument_name="BTC-PERPETUAL", qty=0.1), + tool_fn=tool_fn, + ) + assert result == {"order_id": "abc123", "state": "filled"} + assert len(logged) == 1 + rec = logged[0] + assert rec["actor"] == "testnet" + assert rec["exchange"] == "deribit" + assert rec["action"] == "place_order" + assert rec["target"] == "BTC-PERPETUAL" + assert rec["payload"]["qty"] == 0.1 + assert rec["result"]["order_id"] == "abc123" + assert "error" not in rec or rec.get("error") is None + + +@pytest.mark.asyncio +async def test_audit_call_logs_error_and_reraises(monkeypatch): + from cerbero_mcp.common.audit_helpers import audit_call + + logged = [] + + def fake_audit(**kw): + logged.append(kw) + + monkeypatch.setattr("cerbero_mcp.common.audit_helpers.audit_write_op", fake_audit) + + class FakeRequest: + class _State: + environment = "mainnet" + state = _State() + + async def tool_fn(): + raise RuntimeError("upstream timeout") + + with pytest.raises(RuntimeError, match="upstream timeout"): + await audit_call( + request=FakeRequest(), # type: ignore[arg-type] + exchange="deribit", + action="cancel_order", + target_field="instrument_name", + params=FakeReq(instrument_name="BTC-PERPETUAL", qty=0.0), + tool_fn=tool_fn, + ) + assert len(logged) == 1 + rec = logged[0] + assert rec["actor"] == "mainnet" + assert "RuntimeError: upstream timeout" in rec["error"] + + +@pytest.mark.asyncio +async def test_audit_call_no_params_no_target(): + from cerbero_mcp.common.audit_helpers import audit_call + + class FakeRequest: + class _State: + environment = "testnet" + state = _State() + + async def tool_fn(): + return {"ok": True} + + result = await audit_call( + request=FakeRequest(), # type: ignore[arg-type] + exchange="bybit", + action="cancel_all_orders", + tool_fn=tool_fn, + ) + assert result == {"ok": True}