diff --git a/src/cerbero_mcp/exchanges/__init__.py b/src/cerbero_mcp/exchanges/__init__.py index 38178c2..9e06d84 100644 --- a/src/cerbero_mcp/exchanges/__init__.py +++ b/src/cerbero_mcp/exchanges/__init__.py @@ -19,4 +19,12 @@ async def build_client( client_secret=settings.deribit.client_secret.get_secret_value(), testnet=(env == "testnet"), ) + if exchange == "bybit": + from cerbero_mcp.exchanges.bybit.client import BybitClient + + return BybitClient( + api_key=settings.bybit.api_key, + api_secret=settings.bybit.api_secret.get_secret_value(), + testnet=(env == "testnet"), + ) raise ValueError(f"unsupported exchange: {exchange}") diff --git a/src/cerbero_mcp/exchanges/bybit/__init__.py b/src/cerbero_mcp/exchanges/bybit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/cerbero_mcp/exchanges/bybit/client.py b/src/cerbero_mcp/exchanges/bybit/client.py new file mode 100644 index 0000000..e60f93d --- /dev/null +++ b/src/cerbero_mcp/exchanges/bybit/client.py @@ -0,0 +1,672 @@ +from __future__ import annotations + +import asyncio +from typing import Any + +from cerbero_mcp.common import indicators as ind +from cerbero_mcp.common import microstructure as micro +from pybit.unified_trading import HTTP + + +def _f(v: Any) -> float | None: + try: + return float(v) + except (TypeError, ValueError): + return None + + +def _i(v: Any) -> int | None: + try: + return int(v) + except (TypeError, ValueError): + return None + + +class BybitClient: + def __init__( + self, + api_key: str, + api_secret: str, + testnet: bool = True, + http: Any | None = None, + ) -> None: + self.api_key = api_key + self.api_secret = api_secret + self.testnet = testnet + self._http = http or HTTP( + api_key=api_key, + api_secret=api_secret, + testnet=testnet, + ) + + async def _run(self, fn, /, **kwargs): + return await asyncio.to_thread(fn, **kwargs) + + @staticmethod + def _parse_ticker(row: dict) -> dict: + return { + "symbol": row.get("symbol"), + "last_price": _f(row.get("lastPrice")), + "mark_price": _f(row.get("markPrice")), + "bid": _f(row.get("bid1Price")), + "ask": _f(row.get("ask1Price")), + "volume_24h": _f(row.get("volume24h")), + "turnover_24h": _f(row.get("turnover24h")), + "funding_rate": _f(row.get("fundingRate")), + "open_interest": _f(row.get("openInterest")), + } + + async def get_ticker(self, symbol: str, category: str = "linear") -> dict: + resp = await self._run( + self._http.get_tickers, category=category, symbol=symbol + ) + rows = (resp.get("result") or {}).get("list") or [] + if not rows: + return {"symbol": symbol, "error": "not_found"} + return self._parse_ticker(rows[0]) + + async def get_ticker_batch( + self, symbols: list[str], category: str = "linear" + ) -> dict[str, dict]: + out: dict[str, dict] = {} + for sym in symbols: + out[sym] = await self.get_ticker(sym, category=category) + return out + + async def get_orderbook( + self, symbol: str, category: str = "linear", limit: int = 50 + ) -> dict: + resp = await self._run( + self._http.get_orderbook, category=category, symbol=symbol, limit=limit + ) + r = resp.get("result") or {} + return { + "symbol": r.get("s"), + "bids": [[float(p), float(q)] for p, q in (r.get("b") or [])], + "asks": [[float(p), float(q)] for p, q in (r.get("a") or [])], + "timestamp": r.get("ts"), + } + + async def get_historical( + self, + symbol: str, + category: str = "linear", + interval: str = "60", + start: int | None = None, + end: int | None = None, + limit: int = 1000, + ) -> dict: + kwargs = dict( + category=category, + symbol=symbol, + interval=interval, + limit=limit, + ) + if start is not None: + kwargs["start"] = start + if end is not None: + kwargs["end"] = end + resp = await self._run(self._http.get_kline, **kwargs) + rows = (resp.get("result") or {}).get("list") or [] + rows_sorted = sorted(rows, key=lambda r: int(r[0])) + candles = [ + { + "timestamp": int(r[0]), + "open": float(r[1]), + "high": float(r[2]), + "low": float(r[3]), + "close": float(r[4]), + "volume": float(r[5]), + } + for r in rows_sorted + ] + return {"symbol": symbol, "candles": candles} + + async def get_indicators( + self, + symbol: str, + category: str = "linear", + indicators: list[str] | None = None, + interval: str = "60", + start: int | None = None, + end: int | None = None, + ) -> dict: + indicators = indicators or ["rsi", "atr", "macd", "adx"] + historical = await self.get_historical( + symbol, category=category, interval=interval, start=start, end=end + ) + candles = historical.get("candles", []) + closes = [c["close"] for c in candles] + highs = [c["high"] for c in candles] + lows = [c["low"] for c in candles] + + out: dict[str, Any] = {"symbol": symbol, "category": category} + for name in indicators: + n = name.lower() + if n == "sma": + out["sma"] = ind.sma(closes, 20) + elif n == "rsi": + out["rsi"] = ind.rsi(closes) + elif n == "atr": + out["atr"] = ind.atr(highs, lows, closes) + elif n == "macd": + out["macd"] = ind.macd(closes) + elif n == "adx": + out["adx"] = ind.adx(highs, lows, closes) + else: + out[n] = None + return out + + async def get_funding_rate(self, symbol: str, category: str = "linear") -> dict: + resp = await self._run( + self._http.get_tickers, category=category, symbol=symbol + ) + rows = (resp.get("result") or {}).get("list") or [] + if not rows: + return {"symbol": symbol, "error": "not_found"} + row = rows[0] + return { + "symbol": row.get("symbol"), + "funding_rate": _f(row.get("fundingRate")), + "next_funding_time": _i(row.get("nextFundingTime")), + } + + async def get_funding_history( + self, symbol: str, category: str = "linear", limit: int = 100 + ) -> dict: + resp = await self._run( + self._http.get_funding_rate_history, + category=category, symbol=symbol, limit=limit, + ) + rows = (resp.get("result") or {}).get("list") or [] + hist = [ + { + "timestamp": int(r.get("fundingRateTimestamp", 0)), + "rate": float(r.get("fundingRate", 0)), + } + for r in rows + ] + return {"symbol": symbol, "history": hist} + + async def get_open_interest( + self, + symbol: str, + category: str = "linear", + interval: str = "5min", + limit: int = 288, + ) -> dict: + resp = await self._run( + self._http.get_open_interest, + category=category, symbol=symbol, intervalTime=interval, limit=limit, + ) + rows = (resp.get("result") or {}).get("list") or [] + points = [ + { + "timestamp": int(r.get("timestamp", 0)), + "oi": float(r.get("openInterest", 0)), + } + for r in rows + ] + current_oi = points[0]["oi"] if points else None + return { + "symbol": symbol, + "category": category, + "interval": interval, + "current_oi": current_oi, + "points": points, + } + + async def get_instruments(self, category: str = "linear", symbol: str | None = None) -> dict: + kwargs: dict[str, Any] = {"category": category} + if symbol: + kwargs["symbol"] = symbol + resp = await self._run(self._http.get_instruments_info, **kwargs) + rows = (resp.get("result") or {}).get("list") or [] + instruments = [] + for r in rows: + pf = r.get("priceFilter") or {} + lf = r.get("lotSizeFilter") or {} + instruments.append({ + "symbol": r.get("symbol"), + "status": r.get("status"), + "base_coin": r.get("baseCoin"), + "quote_coin": r.get("quoteCoin"), + "tick_size": _f(pf.get("tickSize")), + "qty_step": _f(lf.get("qtyStep")), + "min_qty": _f(lf.get("minOrderQty")), + }) + return {"category": category, "instruments": instruments} + + async def get_option_chain(self, base_coin: str, expiry: str | None = None) -> dict: + kwargs: dict[str, Any] = {"category": "option", "baseCoin": base_coin.upper()} + resp = await self._run(self._http.get_instruments_info, **kwargs) + rows = (resp.get("result") or {}).get("list") or [] + options = [] + for r in rows: + delivery = r.get("deliveryTime") + if expiry and expiry not in r.get("symbol", ""): + continue + options.append({ + "symbol": r.get("symbol"), + "base_coin": r.get("baseCoin"), + "settle_coin": r.get("settleCoin"), + "type": r.get("optionsType"), + "launch_time": int(r.get("launchTime", 0)), + "delivery_time": int(delivery) if delivery else None, + }) + return {"base_coin": base_coin.upper(), "options": options} + + async def get_positions( + self, category: str = "linear", settle_coin: str = "USDT" + ) -> list[dict]: + kwargs: dict[str, Any] = {"category": category} + if category in ("linear", "inverse"): + kwargs["settleCoin"] = settle_coin + resp = await self._run(self._http.get_positions, **kwargs) + rows = (resp.get("result") or {}).get("list") or [] + out = [] + for r in rows: + out.append({ + "symbol": r.get("symbol"), + "side": r.get("side"), + "size": _f(r.get("size")), + "entry_price": _f(r.get("avgPrice")), + "unrealized_pnl": _f(r.get("unrealisedPnl")), + "leverage": _f(r.get("leverage")), + "liquidation_price": _f(r.get("liqPrice")), + "position_value": _f(r.get("positionValue")), + }) + return out + + async def get_account_summary(self, account_type: str = "UNIFIED") -> dict: + resp = await self._run( + self._http.get_wallet_balance, accountType=account_type + ) + rows = (resp.get("result") or {}).get("list") or [] + if not rows: + return {"error": "no_account"} + a = rows[0] + coins = [] + for c in a.get("coin") or []: + coins.append({ + "coin": c.get("coin"), + "wallet_balance": _f(c.get("walletBalance")), + "equity": _f(c.get("equity")), + }) + return { + "account_type": a.get("accountType"), + "equity": _f(a.get("totalEquity")), + "wallet_balance": _f(a.get("totalWalletBalance")), + "margin_balance": _f(a.get("totalMarginBalance")), + "available_balance": _f(a.get("totalAvailableBalance")), + "unrealized_pnl": _f(a.get("totalPerpUPL")), + "coins": coins, + } + + async def get_trade_history( + self, category: str = "linear", limit: int = 50 + ) -> list[dict]: + resp = await self._run( + self._http.get_executions, category=category, limit=limit + ) + rows = (resp.get("result") or {}).get("list") or [] + return [ + { + "symbol": r.get("symbol"), + "side": r.get("side"), + "size": _f(r.get("execQty")), + "price": _f(r.get("execPrice")), + "fee": _f(r.get("execFee")), + "timestamp": _i(r.get("execTime")), + "order_id": r.get("orderId"), + } + for r in rows + ] + + async def get_open_orders( + self, + category: str = "linear", + symbol: str | None = None, + settle_coin: str = "USDT", + ) -> list[dict]: + kwargs: dict[str, Any] = {"category": category} + if category in ("linear", "inverse") and not symbol: + kwargs["settleCoin"] = settle_coin + if symbol: + kwargs["symbol"] = symbol + resp = await self._run(self._http.get_open_orders, **kwargs) + rows = (resp.get("result") or {}).get("list") or [] + return [ + { + "order_id": r.get("orderId"), + "symbol": r.get("symbol"), + "side": r.get("side"), + "qty": _f(r.get("qty")), + "price": _f(r.get("price")), + "type": r.get("orderType"), + "status": r.get("orderStatus"), + "reduce_only": bool(r.get("reduceOnly")), + } + for r in rows + ] + + async def get_orderbook_imbalance( + self, + symbol: str, + category: str = "linear", + depth: int = 10, + ) -> dict: + """Microstructure: bid/ask imbalance ratio + microprice + slope.""" + ob = await self.get_orderbook(symbol=symbol, category=category, limit=max(depth, 50)) + result = micro.orderbook_imbalance(ob.get("bids") or [], ob.get("asks") or [], depth=depth) + return { + "symbol": symbol, + "category": category, + "depth": depth, + **result, + "timestamp": ob.get("timestamp"), + } + + async def get_basis_term_structure(self, asset: str) -> dict: + """Basis curve futures (dated) vs perp + spot. Filtra contratti future + BTCUSDT / ETHUSDT con scadenza, calcola annualized basis per ognuno. + """ + import datetime as _dt + + asset = asset.upper() + spot = await self.get_ticker(f"{asset}USDT", category="spot") + perp = await self.get_ticker(f"{asset}USDT", category="linear") + sp = spot.get("last_price") + pp = perp.get("last_price") + + # Lista futures dated (linear/inverse) + instr = await self.get_instruments(category="linear") + items = (instr.get("instruments") or []) + futures = [ + x for x in items + if x.get("symbol", "").startswith(f"{asset}-") or x.get("symbol", "").startswith(f"{asset}USDT-") + ] + + rows: list[dict[str, Any]] = [] + if sp: + now_ms = int(_dt.datetime.now(_dt.UTC).timestamp() * 1000) + for f in futures[:10]: + tk = await self.get_ticker(f["symbol"], category="linear") + fp = tk.get("last_price") + expiry_ms = f.get("delivery_time") + if not fp or not expiry_ms: + continue + days = max((int(expiry_ms) - now_ms) / 86_400_000, 1) + basis_pct = 100.0 * (fp - sp) / sp + annualized = basis_pct * 365.0 / days + rows.append({ + "symbol": f["symbol"], + "expiry_ms": int(expiry_ms), + "days_to_expiry": round(days, 2), + "future_price": fp, + "basis_pct": round(basis_pct, 4), + "annualized_basis_pct": round(annualized, 4), + }) + + rows.sort(key=lambda r: r["days_to_expiry"]) + return { + "asset": asset, + "spot_price": sp, + "perp_price": pp, + "perp_basis_pct": round(100.0 * (pp - sp) / sp, 4) if (sp and pp) else None, + "term_structure": rows, + "data_timestamp": _dt.datetime.now(_dt.UTC).isoformat(), + } + + async def get_basis_spot_perp(self, asset: str) -> dict: + asset = asset.upper() + symbol = f"{asset}USDT" + spot = await self.get_ticker(symbol, category="spot") + perp = await self.get_ticker(symbol, category="linear") + sp = spot.get("last_price") + pp = perp.get("last_price") + basis_abs = basis_pct = None + if sp and pp: + basis_abs = pp - sp + basis_pct = 100.0 * basis_abs / sp + return { + "asset": asset, + "symbol": symbol, + "spot_price": sp, + "perp_price": pp, + "basis_abs": basis_abs, + "basis_pct": basis_pct, + "funding_rate": perp.get("funding_rate"), + } + + def _envelope(self, resp: dict, payload: dict) -> dict: + code = resp.get("retCode", 0) + if code != 0: + return {"error": resp.get("retMsg", "bybit_error"), "code": code} + return payload + + async def place_order( + self, + category: str, + symbol: str, + side: str, + qty: float, + order_type: str = "Limit", + price: float | None = None, + tif: str = "GTC", + reduce_only: bool = False, + position_idx: int | None = None, + ) -> dict: + kwargs: dict[str, Any] = { + "category": category, + "symbol": symbol, + "side": side, + "qty": str(qty), + "orderType": order_type, + "timeInForce": tif, + "reduceOnly": reduce_only, + } + if price is not None: + kwargs["price"] = str(price) + if position_idx is not None: + kwargs["positionIdx"] = position_idx + if category == "option": + import uuid + kwargs["orderLinkId"] = f"cerbero-{uuid.uuid4().hex[:16]}" + resp = await self._run(self._http.place_order, **kwargs) + r = resp.get("result") or {} + return self._envelope(resp, { + "order_id": r.get("orderId"), + "order_link_id": r.get("orderLinkId"), + "status": "submitted", + }) + + async def place_combo_order( + self, + category: str, + legs: list[dict[str, Any]], + ) -> dict: + """Atomic multi-leg via /v5/order/create-batch (Bybit option only). + + Bybit supporta batch_order solo su category='option'. Per perp/linear + usare loop di place_order (non atomic). + + legs: [{symbol, side, qty, order_type, price?, tif?, reduce_only?}]. + """ + if category != "option": + raise ValueError("place_combo_order: Bybit batch_order è disponibile solo su category='option'") + if len(legs) < 2: + raise ValueError("combo requires at least 2 legs") + + import uuid + request: list[dict[str, Any]] = [] + for leg in legs: + entry: dict[str, Any] = { + "symbol": leg["symbol"], + "side": leg["side"], + "qty": str(leg["qty"]), + "orderType": leg.get("order_type", "Limit"), + "timeInForce": leg.get("tif", "GTC"), + "reduceOnly": leg.get("reduce_only", False), + "orderLinkId": f"cerbero-{uuid.uuid4().hex[:16]}", + } + if leg.get("price") is not None: + entry["price"] = str(leg["price"]) + request.append(entry) + + resp = await self._run(self._http.place_batch_order, category=category, request=request) + result_list = (resp.get("result") or {}).get("list") or [] + orders = [ + { + "order_id": r.get("orderId"), + "order_link_id": r.get("orderLinkId"), + "status": "submitted", + } + for r in result_list + ] + return self._envelope(resp, {"orders": orders}) + + async def amend_order( + self, + category: str, + symbol: str, + order_id: str, + new_qty: float | None = None, + new_price: float | None = None, + ) -> dict: + kwargs: dict[str, Any] = { + "category": category, + "symbol": symbol, + "orderId": order_id, + } + if new_qty is not None: + kwargs["qty"] = str(new_qty) + if new_price is not None: + kwargs["price"] = str(new_price) + resp = await self._run(self._http.amend_order, **kwargs) + r = resp.get("result") or {} + return self._envelope(resp, { + "order_id": r.get("orderId", order_id), + "status": "amended", + }) + + async def cancel_order( + self, category: str, symbol: str, order_id: str + ) -> dict: + resp = await self._run( + self._http.cancel_order, + category=category, symbol=symbol, orderId=order_id, + ) + r = resp.get("result") or {} + return self._envelope(resp, { + "order_id": r.get("orderId", order_id), + "status": "cancelled", + }) + + async def cancel_all_orders( + self, category: str, symbol: str | None = None + ) -> dict: + kwargs: dict[str, Any] = {"category": category} + if symbol: + kwargs["symbol"] = symbol + resp = await self._run(self._http.cancel_all_orders, **kwargs) + r = resp.get("result") or {} + ids = [x.get("orderId") for x in (r.get("list") or [])] + return self._envelope(resp, { + "cancelled_ids": ids, + "count": len(ids), + }) + + async def set_stop_loss( + self, category: str, symbol: str, stop_loss: float, + position_idx: int = 0, + ) -> dict: + resp = await self._run( + self._http.set_trading_stop, + category=category, symbol=symbol, + stopLoss=str(stop_loss), positionIdx=position_idx, + ) + return self._envelope(resp, { + "symbol": symbol, "stop_loss": stop_loss, + "status": "stop_loss_set", + }) + + async def set_take_profit( + self, category: str, symbol: str, take_profit: float, + position_idx: int = 0, + ) -> dict: + resp = await self._run( + self._http.set_trading_stop, + category=category, symbol=symbol, + takeProfit=str(take_profit), positionIdx=position_idx, + ) + return self._envelope(resp, { + "symbol": symbol, "take_profit": take_profit, + "status": "take_profit_set", + }) + + async def close_position(self, category: str, symbol: str) -> dict: + positions = await self.get_positions(category=category) + target = next((p for p in positions if p["symbol"] == symbol and (p["size"] or 0) > 0), None) + if not target: + return {"error": "no_open_position", "symbol": symbol} + close_side = "Sell" if target["side"] == "Buy" else "Buy" + return await self.place_order( + category=category, + symbol=symbol, + side=close_side, + qty=target["size"], + order_type="Market", + reduce_only=True, + tif="IOC", + ) + + async def set_leverage( + self, category: str, symbol: str, leverage: int + ) -> dict: + resp = await self._run( + self._http.set_leverage, + category=category, symbol=symbol, + buyLeverage=str(leverage), sellLeverage=str(leverage), + ) + return self._envelope(resp, { + "symbol": symbol, "leverage": leverage, + "status": "leverage_set", + }) + + async def switch_position_mode( + self, category: str, symbol: str, mode: str + ) -> dict: + mode_code = 3 if mode.lower() == "hedge" else 0 + resp = await self._run( + self._http.switch_position_mode, + category=category, symbol=symbol, mode=mode_code, + ) + return self._envelope(resp, { + "symbol": symbol, "mode": mode, + "status": "mode_switched", + }) + + async def transfer_asset( + self, + coin: str, + amount: float, + from_type: str, + to_type: str, + ) -> dict: + import uuid + resp = await self._run( + self._http.create_internal_transfer, + transferId=str(uuid.uuid4()), + coin=coin, + amount=str(amount), + fromAccountType=from_type, + toAccountType=to_type, + ) + r = resp.get("result") or {} + return self._envelope(resp, { + "transfer_id": r.get("transferId"), + "coin": coin, + "amount": amount, + "status": "submitted", + }) diff --git a/src/cerbero_mcp/exchanges/bybit/leverage_cap.py b/src/cerbero_mcp/exchanges/bybit/leverage_cap.py new file mode 100644 index 0000000..d04dd51 --- /dev/null +++ b/src/cerbero_mcp/exchanges/bybit/leverage_cap.py @@ -0,0 +1,56 @@ +"""Leverage cap server-side per place_order. + +Cap letto dal secret JSON via campo `max_leverage`. Default 1 (cash) se assente. +""" +from __future__ import annotations + +from fastapi import HTTPException + + +def get_max_leverage(creds: dict) -> int: + """Legge max_leverage dal secret. Default 1 se mancante.""" + raw = creds.get("max_leverage", 1) + try: + value = int(raw) + except (TypeError, ValueError): + value = 1 + return max(1, value) + + +def enforce_leverage( + requested: int | float | None, + *, + creds: dict, + exchange: str, +) -> int: + """Verifica e applica leverage cap. Ritorna leverage applicabile. + + Solleva HTTPException(403, LEVERAGE_CAP_EXCEEDED) se requested > cap. + Se requested is None, applica il cap come default. + """ + cap = get_max_leverage(creds) + if requested is None: + return cap + lev = int(requested) + if lev < 1: + raise HTTPException( + status_code=403, + detail={ + "error": "LEVERAGE_CAP_EXCEEDED", + "exchange": exchange, + "requested": lev, + "max": cap, + "reason": "leverage must be >= 1", + }, + ) + if lev > cap: + raise HTTPException( + status_code=403, + detail={ + "error": "LEVERAGE_CAP_EXCEEDED", + "exchange": exchange, + "requested": lev, + "max": cap, + }, + ) + return lev diff --git a/src/cerbero_mcp/exchanges/bybit/tools.py b/src/cerbero_mcp/exchanges/bybit/tools.py new file mode 100644 index 0000000..4f39bd7 --- /dev/null +++ b/src/cerbero_mcp/exchanges/bybit/tools.py @@ -0,0 +1,442 @@ +"""Tool bybit V2: pydantic schemas + async functions. + +Ogni funzione prende (client: BybitClient, params: ) e restituisce +un dict (o un model Pydantic). Pure logica, no FastAPI dependency, no ACL. +L'autenticazione bearer è gestita dal middleware in cerbero_mcp.auth; +l'audit verrà cablato dal router via request.state.environment. +""" +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, Field + +from cerbero_mcp.exchanges.bybit.client import BybitClient +from cerbero_mcp.exchanges.bybit.leverage_cap import ( + enforce_leverage as _enforce_leverage, +) +from cerbero_mcp.exchanges.bybit.leverage_cap import get_max_leverage + +# === Schemas: reads === + + +class TickerReq(BaseModel): + symbol: str + category: str = "linear" + + +class TickerBatchReq(BaseModel): + symbols: list[str] + category: str = "linear" + + +class OrderbookReq(BaseModel): + symbol: str + category: str = "linear" + limit: int = 50 + + +class HistoricalReq(BaseModel): + symbol: str + category: str = "linear" + interval: str = "60" + start: int | None = None + end: int | None = None + limit: int = 1000 + + +class IndicatorsReq(BaseModel): + symbol: str + category: str = "linear" + indicators: list[str] = ["rsi", "atr", "macd", "adx"] + interval: str = "60" + start: int | None = None + end: int | None = None + + +class FundingRateReq(BaseModel): + symbol: str + category: str = "linear" + + +class FundingHistoryReq(BaseModel): + symbol: str + category: str = "linear" + limit: int = 100 + + +class OpenInterestReq(BaseModel): + symbol: str + category: str = "linear" + interval: str = "5min" + limit: int = 288 + + +class InstrumentsReq(BaseModel): + category: str = "linear" + symbol: str | None = None + + +class OptionChainReq(BaseModel): + base_coin: str + expiry: str | None = None + + +class PositionsReq(BaseModel): + category: str = "linear" + + +class AccountSummaryReq(BaseModel): + pass + + +class TradeHistoryReq(BaseModel): + category: str = "linear" + limit: int = 50 + + +class OpenOrdersReq(BaseModel): + category: str = "linear" + symbol: str | None = None + + +class BasisSpotPerpReq(BaseModel): + asset: str + + +class OrderbookImbalanceReq(BaseModel): + symbol: str + category: str = "linear" + depth: int = 10 + + +class BasisTermStructureReq(BaseModel): + asset: str + + +# === Schemas: writes === + + +class PlaceOrderReq(BaseModel): + category: str + symbol: str + side: str + qty: float + order_type: str = "Limit" + price: float | None = None + tif: str = "GTC" + reduce_only: bool = False + position_idx: int | None = None + + model_config = { + "json_schema_extra": { + "examples": [ + { + "summary": "Market buy 0.01 BTCUSDT linear perp", + "value": { + "category": "linear", + "symbol": "BTCUSDT", + "side": "Buy", + "qty": 0.01, + "order_type": "Market", + }, + } + ] + } + } + + +class ComboLegReq(BaseModel): + symbol: str + side: str + qty: float + order_type: str = "Limit" + price: float | None = None + tif: str = "GTC" + reduce_only: bool = False + + +class PlaceComboOrderReq(BaseModel): + category: str = "option" + legs: list[ComboLegReq] = Field(..., min_length=2) + + +class AmendOrderReq(BaseModel): + category: str + symbol: str + order_id: str + new_qty: float | None = None + new_price: float | None = None + + +class CancelOrderReq(BaseModel): + category: str + symbol: str + order_id: str + + +class CancelAllReq(BaseModel): + category: str + symbol: str | None = None + + +class SetStopLossReq(BaseModel): + category: str + symbol: str + stop_loss: float + position_idx: int = 0 + + +class SetTakeProfitReq(BaseModel): + category: str + symbol: str + take_profit: float + position_idx: int = 0 + + +class ClosePositionReq(BaseModel): + category: str + symbol: str + + +class SetLeverageReq(BaseModel): + category: str + symbol: str + leverage: int + + +class SwitchModeReq(BaseModel): + category: str + symbol: str + mode: str + + +class TransferReq(BaseModel): + coin: str + amount: float + from_type: str + to_type: str + + +# === Tools (reads) === + + +async def environment_info( + client: BybitClient, *, creds: dict, env_info: Any | None = None +) -> dict: + if env_info is None: + return { + "exchange": "bybit", + "environment": "testnet" if client.testnet else "mainnet", + "source": "credentials", + "env_value": None, + "base_url": getattr(client, "base_url", None), + "max_leverage": get_max_leverage(creds), + } + return { + "exchange": env_info.exchange, + "environment": env_info.environment, + "source": env_info.source, + "env_value": env_info.env_value, + "base_url": env_info.base_url, + "max_leverage": get_max_leverage(creds), + } + + +async def get_ticker(client: BybitClient, params: TickerReq) -> dict: + return await client.get_ticker(params.symbol, params.category) + + +async def get_ticker_batch(client: BybitClient, params: TickerBatchReq) -> dict: + return await client.get_ticker_batch(params.symbols, params.category) + + +async def get_orderbook(client: BybitClient, params: OrderbookReq) -> dict: + return await client.get_orderbook(params.symbol, params.category, params.limit) + + +async def get_historical(client: BybitClient, params: HistoricalReq) -> dict: + return await client.get_historical( + params.symbol, + params.category, + params.interval, + params.start, + params.end, + params.limit, + ) + + +async def get_indicators(client: BybitClient, params: IndicatorsReq) -> dict: + return await client.get_indicators( + params.symbol, + params.category, + params.indicators, + params.interval, + params.start, + params.end, + ) + + +async def get_funding_rate(client: BybitClient, params: FundingRateReq) -> dict: + return await client.get_funding_rate(params.symbol, params.category) + + +async def get_funding_history(client: BybitClient, params: FundingHistoryReq) -> dict: + return await client.get_funding_history( + params.symbol, params.category, params.limit + ) + + +async def get_open_interest(client: BybitClient, params: OpenInterestReq) -> dict: + return await client.get_open_interest( + params.symbol, params.category, params.interval, params.limit + ) + + +async def get_instruments(client: BybitClient, params: InstrumentsReq) -> dict: + return await client.get_instruments(params.category, params.symbol) + + +async def get_option_chain(client: BybitClient, params: OptionChainReq) -> dict: + return await client.get_option_chain(params.base_coin, params.expiry) + + +async def get_positions(client: BybitClient, params: PositionsReq) -> dict: + return {"positions": await client.get_positions(params.category)} + + +async def get_account_summary( + client: BybitClient, params: AccountSummaryReq +) -> dict: + return await client.get_account_summary() + + +async def get_trade_history(client: BybitClient, params: TradeHistoryReq) -> dict: + return { + "trades": await client.get_trade_history(params.category, params.limit) + } + + +async def get_open_orders(client: BybitClient, params: OpenOrdersReq) -> dict: + return { + "orders": await client.get_open_orders(params.category, params.symbol) + } + + +async def get_basis_spot_perp(client: BybitClient, params: BasisSpotPerpReq) -> dict: + return await client.get_basis_spot_perp(params.asset) + + +async def get_orderbook_imbalance( + client: BybitClient, params: OrderbookImbalanceReq +) -> dict: + return await client.get_orderbook_imbalance( + params.symbol, params.category, params.depth + ) + + +async def get_basis_term_structure( + client: BybitClient, params: BasisTermStructureReq +) -> dict: + return await client.get_basis_term_structure(params.asset) + + +# === Tools (writes) === + + +async def place_order( + client: BybitClient, params: PlaceOrderReq, *, creds: dict +) -> dict: + # Bybit non ha leverage_cap parametro per place_order; cap applicato a set_leverage. + result = await client.place_order( + category=params.category, + symbol=params.symbol, + side=params.side, + qty=params.qty, + order_type=params.order_type, + price=params.price, + tif=params.tif, + reduce_only=params.reduce_only, + position_idx=params.position_idx, + ) + # TODO V2: wire audit via request.state.environment in router + return result + + +async def place_combo_order( + client: BybitClient, params: PlaceComboOrderReq, *, creds: dict +) -> dict: + result = await client.place_combo_order( + category=params.category, + legs=[leg.model_dump() for leg in params.legs], + ) + # TODO V2: wire audit via request.state.environment in router + return result + + +async def amend_order(client: BybitClient, params: AmendOrderReq) -> dict: + result = await client.amend_order( + params.category, + params.symbol, + params.order_id, + params.new_qty, + params.new_price, + ) + return result + + +async def cancel_order(client: BybitClient, params: CancelOrderReq) -> dict: + result = await client.cancel_order( + params.category, params.symbol, params.order_id + ) + return result + + +async def cancel_all_orders(client: BybitClient, params: CancelAllReq) -> dict: + result = await client.cancel_all_orders(params.category, params.symbol) + return result + + +async def set_stop_loss(client: BybitClient, params: SetStopLossReq) -> dict: + result = await client.set_stop_loss( + params.category, params.symbol, params.stop_loss, params.position_idx + ) + return result + + +async def set_take_profit(client: BybitClient, params: SetTakeProfitReq) -> dict: + result = await client.set_take_profit( + params.category, params.symbol, params.take_profit, params.position_idx + ) + return result + + +async def close_position(client: BybitClient, params: ClosePositionReq) -> dict: + result = await client.close_position(params.category, params.symbol) + return result + + +async def set_leverage( + client: BybitClient, params: SetLeverageReq, *, creds: dict +) -> dict: + _enforce_leverage(params.leverage, creds=creds, exchange="bybit") + result = await client.set_leverage( + params.category, params.symbol, params.leverage + ) + return result + + +async def switch_position_mode( + client: BybitClient, params: SwitchModeReq +) -> dict: + result = await client.switch_position_mode( + params.category, params.symbol, params.mode + ) + return result + + +async def transfer_asset(client: BybitClient, params: TransferReq) -> dict: + result = await client.transfer_asset( + params.coin, params.amount, params.from_type, params.to_type + ) + return result diff --git a/src/cerbero_mcp/routers/bybit.py b/src/cerbero_mcp/routers/bybit.py new file mode 100644 index 0000000..6af4465 --- /dev/null +++ b/src/cerbero_mcp/routers/bybit.py @@ -0,0 +1,261 @@ +"""Router /mcp-bybit/* — DI per env, client e (write) creds. + +Mappa 1:1 i tool di `cerbero_mcp.exchanges.bybit.tools` a endpoint +`POST /mcp-bybit/tools/{tool_name}`. L'autenticazione bearer è gestita +dal middleware in `cerbero_mcp.auth`; qui leggiamo solo `request.state.environment`. +""" +from __future__ import annotations + +from typing import Literal + +from fastapi import APIRouter, Depends, Request + +from cerbero_mcp.client_registry import ClientRegistry +from cerbero_mcp.exchanges.bybit import tools as t +from cerbero_mcp.exchanges.bybit.client import BybitClient + +Environment = Literal["testnet", "mainnet"] + + +def get_environment(request: Request) -> Environment: + return request.state.environment + + +async def get_bybit_client( + request: Request, env: Environment = Depends(get_environment) +) -> BybitClient: + registry: ClientRegistry = request.app.state.registry + return await registry.get("bybit", env) + + +def _build_creds(request: Request) -> dict: + """Costruisce dict `creds` minimale per leverage cap / metadata. + + Le credenziali vere sono già iniettate nel client da ClientRegistry; + qui passiamo solo il cap di leverage e l'api_key (metadata audit). + """ + settings = request.app.state.settings + return { + "max_leverage": settings.bybit.max_leverage, + "api_key": settings.bybit.api_key, + } + + +def make_router() -> APIRouter: + r = APIRouter(prefix="/mcp-bybit", tags=["bybit"]) + + # === READ tools === + + @r.post("/tools/environment_info") + async def _environment_info( + request: Request, + client: BybitClient = Depends(get_bybit_client), + ): + creds = _build_creds(request) + return await t.environment_info(client, creds=creds) + + @r.post("/tools/get_ticker") + async def _get_ticker( + params: t.TickerReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_ticker(client, params) + + @r.post("/tools/get_ticker_batch") + async def _get_ticker_batch( + params: t.TickerBatchReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_ticker_batch(client, params) + + @r.post("/tools/get_orderbook") + async def _get_orderbook( + params: t.OrderbookReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_orderbook(client, params) + + @r.post("/tools/get_historical") + async def _get_historical( + params: t.HistoricalReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_historical(client, params) + + @r.post("/tools/get_indicators") + async def _get_indicators( + params: t.IndicatorsReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_indicators(client, params) + + @r.post("/tools/get_funding_rate") + async def _get_funding_rate( + params: t.FundingRateReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_funding_rate(client, params) + + @r.post("/tools/get_funding_history") + async def _get_funding_history( + params: t.FundingHistoryReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_funding_history(client, params) + + @r.post("/tools/get_open_interest") + async def _get_open_interest( + params: t.OpenInterestReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_open_interest(client, params) + + @r.post("/tools/get_instruments") + async def _get_instruments( + params: t.InstrumentsReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_instruments(client, params) + + @r.post("/tools/get_option_chain") + async def _get_option_chain( + params: t.OptionChainReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_option_chain(client, params) + + @r.post("/tools/get_positions") + async def _get_positions( + params: t.PositionsReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_positions(client, params) + + @r.post("/tools/get_account_summary") + async def _get_account_summary( + params: t.AccountSummaryReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_account_summary(client, params) + + @r.post("/tools/get_trade_history") + async def _get_trade_history( + params: t.TradeHistoryReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_trade_history(client, params) + + @r.post("/tools/get_open_orders") + async def _get_open_orders( + params: t.OpenOrdersReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_open_orders(client, params) + + @r.post("/tools/get_basis_spot_perp") + async def _get_basis_spot_perp( + params: t.BasisSpotPerpReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_basis_spot_perp(client, params) + + @r.post("/tools/get_orderbook_imbalance") + async def _get_orderbook_imbalance( + params: t.OrderbookImbalanceReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_orderbook_imbalance(client, params) + + @r.post("/tools/get_basis_term_structure") + async def _get_basis_term_structure( + params: t.BasisTermStructureReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.get_basis_term_structure(client, params) + + # === WRITE tools (richiedono creds per leverage cap / audit) === + + @r.post("/tools/place_order") + async def _place_order( + params: t.PlaceOrderReq, + request: Request, + client: BybitClient = Depends(get_bybit_client), + ): + creds = _build_creds(request) + return await t.place_order(client, params, creds=creds) + + @r.post("/tools/place_combo_order") + async def _place_combo_order( + params: t.PlaceComboOrderReq, + request: Request, + client: BybitClient = Depends(get_bybit_client), + ): + creds = _build_creds(request) + return await t.place_combo_order(client, params, creds=creds) + + @r.post("/tools/amend_order") + async def _amend_order( + params: t.AmendOrderReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.amend_order(client, params) + + @r.post("/tools/cancel_order") + async def _cancel_order( + params: t.CancelOrderReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.cancel_order(client, params) + + @r.post("/tools/cancel_all_orders") + async def _cancel_all_orders( + params: t.CancelAllReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.cancel_all_orders(client, params) + + @r.post("/tools/set_stop_loss") + async def _set_stop_loss( + params: t.SetStopLossReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.set_stop_loss(client, params) + + @r.post("/tools/set_take_profit") + async def _set_take_profit( + params: t.SetTakeProfitReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.set_take_profit(client, params) + + @r.post("/tools/close_position") + async def _close_position( + params: t.ClosePositionReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.close_position(client, params) + + @r.post("/tools/set_leverage") + async def _set_leverage( + params: t.SetLeverageReq, + request: Request, + client: BybitClient = Depends(get_bybit_client), + ): + creds = _build_creds(request) + return await t.set_leverage(client, params, creds=creds) + + @r.post("/tools/switch_position_mode") + async def _switch_position_mode( + params: t.SwitchModeReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.switch_position_mode(client, params) + + @r.post("/tools/transfer_asset") + async def _transfer_asset( + params: t.TransferReq, + client: BybitClient = Depends(get_bybit_client), + ): + return await t.transfer_asset(client, params) + + return r diff --git a/tests/unit/exchanges/bybit/__init__.py b/tests/unit/exchanges/bybit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/exchanges/bybit/conftest.py b/tests/unit/exchanges/bybit/conftest.py new file mode 100644 index 0000000..51508cc --- /dev/null +++ b/tests/unit/exchanges/bybit/conftest.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from cerbero_mcp.exchanges.bybit.client import BybitClient + + +@pytest.fixture +def mock_http(): + return MagicMock(name="pybit_HTTP") + + +@pytest.fixture +def client(mock_http): + return BybitClient( + api_key="test_key", + api_secret="test_secret", + testnet=True, + http=mock_http, + ) diff --git a/tests/unit/exchanges/bybit/test_client.py b/tests/unit/exchanges/bybit/test_client.py new file mode 100644 index 0000000..7d5099a --- /dev/null +++ b/tests/unit/exchanges/bybit/test_client.py @@ -0,0 +1,588 @@ +from __future__ import annotations + +import pytest +from cerbero_mcp.exchanges.bybit.client import BybitClient + + +def test_client_init_stores_attrs(client, mock_http): + assert client.testnet is True + assert client._http is mock_http + + +def test_client_init_default_http(monkeypatch): + created = {} + + class FakeHTTP: + def __init__(self, **kwargs): + created.update(kwargs) + + monkeypatch.setattr("cerbero_mcp.exchanges.bybit.client.HTTP", FakeHTTP) + BybitClient(api_key="k", api_secret="s", testnet=False) + assert created["api_key"] == "k" + assert created["api_secret"] == "s" + assert created["testnet"] is False + + +@pytest.mark.asyncio +async def test_get_ticker(client, mock_http): + mock_http.get_tickers.return_value = { + "retCode": 0, + "result": { + "list": [{ + "symbol": "BTCUSDT", + "lastPrice": "60000", + "markPrice": "60010", + "bid1Price": "59995", + "ask1Price": "60005", + "volume24h": "1500.5", + "turnover24h": "90000000", + "fundingRate": "0.0001", + "openInterest": "50000", + }] + }, + } + t = await client.get_ticker("BTCUSDT", category="linear") + mock_http.get_tickers.assert_called_once_with(category="linear", symbol="BTCUSDT") + assert t["symbol"] == "BTCUSDT" + assert t["last_price"] == 60000.0 + assert t["mark_price"] == 60010.0 + assert t["bid"] == 59995.0 + assert t["ask"] == 60005.0 + assert t["volume_24h"] == 1500.5 + assert t["funding_rate"] == 0.0001 + assert t["open_interest"] == 50000.0 + + +@pytest.mark.asyncio +async def test_get_ticker_batch(client, mock_http): + def side_effect(**kwargs): + symbol = kwargs["symbol"] + return {"retCode": 0, "result": {"list": [{ + "symbol": symbol, "lastPrice": "1", "markPrice": "1", + "bid1Price": "1", "ask1Price": "1", "volume24h": "0", + "turnover24h": "0", "fundingRate": "0", "openInterest": "0", + }]}} + mock_http.get_tickers.side_effect = side_effect + out = await client.get_ticker_batch(["BTCUSDT", "ETHUSDT"], category="linear") + assert set(out.keys()) == {"BTCUSDT", "ETHUSDT"} + assert mock_http.get_tickers.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_ticker_not_found(client, mock_http): + mock_http.get_tickers.return_value = {"retCode": 0, "result": {"list": []}} + t = await client.get_ticker("UNKNOWNUSDT", category="linear") + assert t == {"symbol": "UNKNOWNUSDT", "error": "not_found"} + + +def test_parse_helpers(): + from cerbero_mcp.exchanges.bybit.client import _f, _i + assert _f("1.5") == 1.5 + assert _f("") is None + assert _f(None) is None + assert _i("42") == 42 + assert _i("") is None + assert _i(None) is None + + +@pytest.mark.asyncio +async def test_get_orderbook(client, mock_http): + mock_http.get_orderbook.return_value = { + "retCode": 0, + "result": { + "s": "BTCUSDT", + "b": [["59990", "0.5"], ["59980", "1.0"]], + "a": [["60010", "0.3"], ["60020", "0.7"]], + "ts": 1700000000000, + }, + } + ob = await client.get_orderbook("BTCUSDT", category="linear", limit=25) + mock_http.get_orderbook.assert_called_once_with( + category="linear", symbol="BTCUSDT", limit=25 + ) + assert ob["symbol"] == "BTCUSDT" + assert ob["bids"] == [[59990.0, 0.5], [59980.0, 1.0]] + assert ob["asks"] == [[60010.0, 0.3], [60020.0, 0.7]] + assert ob["timestamp"] == 1700000000000 + + +@pytest.mark.asyncio +async def test_get_historical(client, mock_http): + mock_http.get_kline.return_value = { + "retCode": 0, + "result": { + "list": [ + ["1700000000000", "60000", "60500", "59500", "60200", "100", "6020000"], + ["1700003600000", "60200", "60700", "60000", "60400", "80", "4832000"], + ] + }, + } + out = await client.get_historical( + "BTCUSDT", category="linear", interval="60", + start=1700000000000, end=1700003600000, + ) + mock_http.get_kline.assert_called_once_with( + category="linear", symbol="BTCUSDT", interval="60", + start=1700000000000, end=1700003600000, limit=1000, + ) + assert len(out["candles"]) == 2 + c0 = out["candles"][0] + assert c0["timestamp"] == 1700000000000 + assert c0["open"] == 60000.0 + assert c0["high"] == 60500.0 + assert c0["low"] == 59500.0 + assert c0["close"] == 60200.0 + assert c0["volume"] == 100.0 + + +@pytest.mark.asyncio +async def test_get_indicators(client, mock_http): + rows = [ + [str(1700000000000 + i * 3600_000), + str(60000 + i * 10), str(60000 + i * 10 + 5), + str(60000 + i * 10 - 5), str(60000 + i * 10 + 2), + "100", "6000000"] + for i in range(35) + ] + mock_http.get_kline.return_value = {"retCode": 0, "result": {"list": rows}} + out = await client.get_indicators( + "BTCUSDT", category="linear", + indicators=["rsi", "atr", "macd", "adx"], + interval="60", + ) + assert "rsi" in out and out["rsi"] is not None + assert "atr" in out and out["atr"] is not None + assert "macd" in out and out["macd"]["macd"] is not None + assert "adx" in out and out["adx"]["adx"] is not None + + +@pytest.mark.asyncio +async def test_get_funding_rate(client, mock_http): + mock_http.get_tickers.return_value = { + "retCode": 0, + "result": {"list": [{ + "symbol": "BTCUSDT", "fundingRate": "0.0001", + "nextFundingTime": "1700003600000", + "lastPrice": "60000", "markPrice": "60000", + "bid1Price": "0", "ask1Price": "0", + "volume24h": "0", "turnover24h": "0", "openInterest": "0", + }]}, + } + out = await client.get_funding_rate("BTCUSDT", category="linear") + assert out["symbol"] == "BTCUSDT" + assert out["funding_rate"] == 0.0001 + assert out["next_funding_time"] == 1700003600000 + + +@pytest.mark.asyncio +async def test_get_funding_history(client, mock_http): + mock_http.get_funding_rate_history.return_value = { + "retCode": 0, + "result": {"list": [ + {"symbol": "BTCUSDT", "fundingRate": "0.0001", "fundingRateTimestamp": "1700000000000"}, + {"symbol": "BTCUSDT", "fundingRate": "0.00008", "fundingRateTimestamp": "1699996400000"}, + ]}, + } + out = await client.get_funding_history("BTCUSDT", category="linear", limit=50) + mock_http.get_funding_rate_history.assert_called_once_with( + category="linear", symbol="BTCUSDT", limit=50 + ) + assert len(out["history"]) == 2 + assert out["history"][0]["rate"] == 0.0001 + + +@pytest.mark.asyncio +async def test_get_open_interest(client, mock_http): + mock_http.get_open_interest.return_value = { + "retCode": 0, + "result": {"list": [ + {"openInterest": "50000", "timestamp": "1700000000000"}, + {"openInterest": "49000", "timestamp": "1699996400000"}, + ]}, + } + out = await client.get_open_interest("BTCUSDT", category="linear", interval="5min", limit=100) + mock_http.get_open_interest.assert_called_once_with( + category="linear", symbol="BTCUSDT", intervalTime="5min", limit=100 + ) + assert len(out["points"]) == 2 + assert out["current_oi"] == 50000.0 + + +@pytest.mark.asyncio +async def test_get_instruments(client, mock_http): + mock_http.get_instruments_info.return_value = { + "retCode": 0, + "result": {"list": [ + {"symbol": "BTCUSDT", "status": "Trading", "baseCoin": "BTC", + "quoteCoin": "USDT", "priceFilter": {"tickSize": "0.1"}, + "lotSizeFilter": {"qtyStep": "0.001", "minOrderQty": "0.001"}}, + ]}, + } + out = await client.get_instruments(category="linear") + mock_http.get_instruments_info.assert_called_once_with(category="linear") + assert len(out["instruments"]) == 1 + inst = out["instruments"][0] + assert inst["symbol"] == "BTCUSDT" + assert inst["tick_size"] == 0.1 + assert inst["qty_step"] == 0.001 + + +@pytest.mark.asyncio +async def test_get_option_chain(client, mock_http): + mock_http.get_instruments_info.return_value = { + "retCode": 0, + "result": {"list": [ + {"symbol": "BTC-30JUN25-50000-C", "baseCoin": "BTC", + "settleCoin": "USDC", "optionsType": "Call", + "launchTime": "1700000000000", "deliveryTime": "1719734400000"}, + {"symbol": "BTC-30JUN25-50000-P", "baseCoin": "BTC", + "settleCoin": "USDC", "optionsType": "Put", + "launchTime": "1700000000000", "deliveryTime": "1719734400000"}, + ]}, + } + out = await client.get_option_chain(base_coin="BTC") + mock_http.get_instruments_info.assert_called_once_with(category="option", baseCoin="BTC") + assert len(out["options"]) == 2 + assert out["options"][0]["type"] == "Call" + + +@pytest.mark.asyncio +async def test_get_positions(client, mock_http): + mock_http.get_positions.return_value = { + "retCode": 0, + "result": {"list": [ + {"symbol": "BTCUSDT", "side": "Buy", "size": "0.1", + "avgPrice": "60000", "unrealisedPnl": "50", + "leverage": "10", "liqPrice": "50000", "positionValue": "6000"}, + ]}, + } + out = await client.get_positions(category="linear") + mock_http.get_positions.assert_called_once_with(category="linear", settleCoin="USDT") + assert len(out) == 1 + p = out[0] + assert p["symbol"] == "BTCUSDT" + assert p["side"] == "Buy" + assert p["size"] == 0.1 + assert p["entry_price"] == 60000.0 + assert p["liquidation_price"] == 50000.0 + + +@pytest.mark.asyncio +async def test_get_account_summary(client, mock_http): + mock_http.get_wallet_balance.return_value = { + "retCode": 0, + "result": {"list": [{ + "accountType": "UNIFIED", + "totalEquity": "10000", + "totalWalletBalance": "9500", + "totalMarginBalance": "9800", + "totalAvailableBalance": "9000", + "totalPerpUPL": "200", + "coin": [ + {"coin": "USDT", "walletBalance": "9500", "equity": "9700"} + ], + }]}, + } + out = await client.get_account_summary() + mock_http.get_wallet_balance.assert_called_once_with(accountType="UNIFIED") + assert out["equity"] == 10000.0 + assert out["available_balance"] == 9000.0 + assert out["unrealized_pnl"] == 200.0 + assert len(out["coins"]) == 1 + assert out["coins"][0]["coin"] == "USDT" + + +@pytest.mark.asyncio +async def test_get_trade_history(client, mock_http): + mock_http.get_executions.return_value = { + "retCode": 0, + "result": {"list": [ + {"symbol": "BTCUSDT", "side": "Buy", "execQty": "0.01", + "execPrice": "60000", "execFee": "0.1", + "execTime": "1700000000000", "orderId": "abc"}, + ]}, + } + out = await client.get_trade_history(category="linear", limit=50) + mock_http.get_executions.assert_called_once_with(category="linear", limit=50) + assert len(out) == 1 + assert out[0]["symbol"] == "BTCUSDT" + assert out[0]["size"] == 0.01 + assert out[0]["price"] == 60000.0 + + +@pytest.mark.asyncio +async def test_get_open_orders(client, mock_http): + mock_http.get_open_orders.return_value = { + "retCode": 0, + "result": {"list": [ + {"symbol": "BTCUSDT", "orderId": "o1", "side": "Buy", + "qty": "0.1", "price": "59000", "orderType": "Limit", + "orderStatus": "New", "reduceOnly": False}, + ]}, + } + out = await client.get_open_orders(category="linear") + mock_http.get_open_orders.assert_called_once_with(category="linear", settleCoin="USDT") + assert len(out) == 1 + assert out[0]["order_id"] == "o1" + assert out[0]["price"] == 59000.0 + + +@pytest.mark.asyncio +async def test_get_basis_spot_perp(client, mock_http): + def side(**kwargs): + if kwargs["category"] == "spot": + return {"retCode": 0, "result": {"list": [{ + "symbol": "BTCUSDT", "lastPrice": "60000", "markPrice": "60000", + "bid1Price": "59995", "ask1Price": "60005", + "volume24h": "0", "turnover24h": "0", + "fundingRate": "0", "openInterest": "0", + }]}} + else: + return {"retCode": 0, "result": {"list": [{ + "symbol": "BTCUSDT", "lastPrice": "60120", "markPrice": "60120", + "bid1Price": "60115", "ask1Price": "60125", + "volume24h": "0", "turnover24h": "0", + "fundingRate": "0.0001", "openInterest": "0", + }]}} + mock_http.get_tickers.side_effect = side + out = await client.get_basis_spot_perp("BTC") + assert out["asset"] == "BTC" + assert out["spot_price"] == 60000.0 + assert out["perp_price"] == 60120.0 + assert out["basis_abs"] == 120.0 + assert round(out["basis_pct"], 3) == 0.2 + + +@pytest.mark.asyncio +async def test_place_order_limit(client, mock_http): + mock_http.place_order.return_value = { + "retCode": 0, + "result": {"orderId": "ord123", "orderLinkId": ""}, + } + out = await client.place_order( + category="linear", symbol="BTCUSDT", side="Buy", + qty=0.01, order_type="Limit", price=60000.0, tif="GTC", + ) + assert out["order_id"] == "ord123" + kwargs = mock_http.place_order.call_args.kwargs + assert kwargs["category"] == "linear" + assert kwargs["symbol"] == "BTCUSDT" + assert kwargs["side"] == "Buy" + assert kwargs["qty"] == "0.01" + assert kwargs["orderType"] == "Limit" + assert kwargs["price"] == "60000.0" + assert kwargs["timeInForce"] == "GTC" + + +@pytest.mark.asyncio +async def test_place_order_error(client, mock_http): + mock_http.place_order.return_value = {"retCode": 10001, "retMsg": "insufficient balance"} + out = await client.place_order( + category="linear", symbol="BTCUSDT", side="Buy", qty=0.01, order_type="Market" + ) + assert out.get("error") == "insufficient balance" + assert out.get("code") == 10001 + + +@pytest.mark.asyncio +async def test_amend_order(client, mock_http): + mock_http.amend_order.return_value = {"retCode": 0, "result": {"orderId": "ord1"}} + out = await client.amend_order( + category="linear", symbol="BTCUSDT", order_id="ord1", new_qty=0.02 + ) + assert out["order_id"] == "ord1" + kwargs = mock_http.amend_order.call_args.kwargs + assert kwargs["orderId"] == "ord1" + assert kwargs["qty"] == "0.02" + assert "price" not in kwargs + + +@pytest.mark.asyncio +async def test_place_order_option_adds_link_id(client, mock_http): + mock_http.place_order.return_value = { + "retCode": 0, + "result": {"orderId": "opt1", "orderLinkId": "cerbero-abc"}, + } + await client.place_order( + category="option", symbol="BTC-24APR26-96000-C-USDT", + side="Buy", qty=0.01, order_type="Limit", price=5.0, + ) + kwargs = mock_http.place_order.call_args.kwargs + assert "orderLinkId" in kwargs + assert kwargs["orderLinkId"].startswith("cerbero-") + + +@pytest.mark.asyncio +async def test_place_order_linear_no_link_id(client, mock_http): + mock_http.place_order.return_value = {"retCode": 0, "result": {"orderId": "x"}} + await client.place_order( + category="linear", symbol="BTCUSDT", side="Buy", qty=0.01, order_type="Market" + ) + kwargs = mock_http.place_order.call_args.kwargs + assert "orderLinkId" not in kwargs + + +@pytest.mark.asyncio +async def test_place_combo_order_batch_option(client, mock_http): + """Combo order via place_batch_order su category=option (atomic, 1 round-trip).""" + mock_http.place_batch_order.return_value = { + "retCode": 0, + "result": { + "list": [ + {"orderId": "ord-1", "orderLinkId": "cerbero-leg1"}, + {"orderId": "ord-2", "orderLinkId": "cerbero-leg2"}, + ] + }, + } + legs = [ + {"symbol": "BTC-30APR26-75000-C-USDT", "side": "Buy", "qty": 0.01, "order_type": "Limit", "price": 5.0}, + {"symbol": "BTC-30APR26-80000-C-USDT", "side": "Sell", "qty": 0.01, "order_type": "Limit", "price": 3.0}, + ] + out = await client.place_combo_order(category="option", legs=legs) + assert len(out["orders"]) == 2 + assert out["orders"][0]["order_id"] == "ord-1" + kwargs = mock_http.place_batch_order.call_args.kwargs + assert kwargs["category"] == "option" + request = kwargs["request"] + assert len(request) == 2 + assert request[0]["symbol"] == "BTC-30APR26-75000-C-USDT" + assert request[0]["qty"] == "0.01" + assert request[0]["orderType"] == "Limit" + # CER: orderLinkId obbligatorio per option + assert "orderLinkId" in request[0] + + +@pytest.mark.asyncio +async def test_place_combo_order_error(client, mock_http): + mock_http.place_batch_order.return_value = {"retCode": 10001, "retMsg": "invalid leg"} + out = await client.place_combo_order( + category="option", + legs=[ + {"symbol": "X", "side": "Buy", "qty": 1, "order_type": "Limit", "price": 1.0}, + {"symbol": "Y", "side": "Sell", "qty": 1, "order_type": "Limit", "price": 1.0}, + ], + ) + assert out["error"] == "invalid leg" + assert out["code"] == 10001 + + +@pytest.mark.asyncio +async def test_place_combo_order_rejects_non_option(client, mock_http): + """Bybit batch_order è disponibile solo su option category.""" + import pytest as _pytest + with _pytest.raises(ValueError, match="option"): + await client.place_combo_order( + category="linear", + legs=[ + {"symbol": "BTCUSDT", "side": "Buy", "qty": 0.01, "order_type": "Market"}, + {"symbol": "ETHUSDT", "side": "Sell", "qty": 0.01, "order_type": "Market"}, + ], + ) + + +@pytest.mark.asyncio +async def test_cancel_order(client, mock_http): + mock_http.cancel_order.return_value = {"retCode": 0, "result": {"orderId": "ord1"}} + out = await client.cancel_order(category="linear", symbol="BTCUSDT", order_id="ord1") + mock_http.cancel_order.assert_called_once_with( + category="linear", symbol="BTCUSDT", orderId="ord1" + ) + assert out["order_id"] == "ord1" + assert out["status"] == "cancelled" + + +@pytest.mark.asyncio +async def test_cancel_all_orders(client, mock_http): + mock_http.cancel_all_orders.return_value = { + "retCode": 0, + "result": {"list": [{"orderId": "o1"}, {"orderId": "o2"}]}, + } + out = await client.cancel_all_orders(category="linear", symbol="BTCUSDT") + mock_http.cancel_all_orders.assert_called_once_with( + category="linear", symbol="BTCUSDT" + ) + assert out["cancelled_ids"] == ["o1", "o2"] + + +@pytest.mark.asyncio +async def test_set_stop_loss(client, mock_http): + mock_http.set_trading_stop.return_value = {"retCode": 0, "result": {}} + out = await client.set_stop_loss( + category="linear", symbol="BTCUSDT", stop_loss=55000.0 + ) + mock_http.set_trading_stop.assert_called_once() + kwargs = mock_http.set_trading_stop.call_args.kwargs + assert kwargs["category"] == "linear" + assert kwargs["symbol"] == "BTCUSDT" + assert kwargs["stopLoss"] == "55000.0" + assert kwargs.get("positionIdx", 0) == 0 + assert out["status"] == "stop_loss_set" + + +@pytest.mark.asyncio +async def test_set_take_profit(client, mock_http): + mock_http.set_trading_stop.return_value = {"retCode": 0, "result": {}} + out = await client.set_take_profit( + category="linear", symbol="BTCUSDT", take_profit=65000.0 + ) + kwargs = mock_http.set_trading_stop.call_args.kwargs + assert kwargs["takeProfit"] == "65000.0" + assert out["status"] == "take_profit_set" + + +@pytest.mark.asyncio +async def test_close_position(client, mock_http): + mock_http.get_positions.return_value = { + "retCode": 0, "result": {"list": [ + {"symbol": "BTCUSDT", "side": "Buy", "size": "0.1", + "avgPrice": "60000", "unrealisedPnl": "0", + "leverage": "10", "liqPrice": "0", "positionValue": "6000"}, + ]}, + } + mock_http.place_order.return_value = { + "retCode": 0, "result": {"orderId": "closeord", "orderLinkId": ""}, + } + out = await client.close_position(category="linear", symbol="BTCUSDT") + assert out["status"] == "submitted" + kwargs = mock_http.place_order.call_args.kwargs + assert kwargs["side"] == "Sell" + assert kwargs["qty"] == "0.1" + assert kwargs["reduceOnly"] is True + assert kwargs["orderType"] == "Market" + + +@pytest.mark.asyncio +async def test_set_leverage(client, mock_http): + mock_http.set_leverage.return_value = {"retCode": 0, "result": {}} + out = await client.set_leverage(category="linear", symbol="BTCUSDT", leverage=5) + mock_http.set_leverage.assert_called_once_with( + category="linear", symbol="BTCUSDT", buyLeverage="5", sellLeverage="5" + ) + assert out["status"] == "leverage_set" + + +@pytest.mark.asyncio +async def test_switch_position_mode(client, mock_http): + mock_http.switch_position_mode.return_value = {"retCode": 0, "result": {}} + out = await client.switch_position_mode( + category="linear", symbol="BTCUSDT", mode="hedge" + ) + kwargs = mock_http.switch_position_mode.call_args.kwargs + assert kwargs["mode"] == 3 + assert out["status"] == "mode_switched" + + +@pytest.mark.asyncio +async def test_transfer_asset(client, mock_http): + mock_http.create_internal_transfer.return_value = { + "retCode": 0, "result": {"transferId": "tx123"}, + } + out = await client.transfer_asset( + coin="USDT", amount=100.0, from_type="UNIFIED", to_type="FUND" + ) + kwargs = mock_http.create_internal_transfer.call_args.kwargs + assert kwargs["coin"] == "USDT" + assert kwargs["amount"] == "100.0" + assert kwargs["fromAccountType"] == "UNIFIED" + assert kwargs["toAccountType"] == "FUND" + assert out["transfer_id"] == "tx123" diff --git a/tests/unit/exchanges/bybit/test_leverage_cap.py b/tests/unit/exchanges/bybit/test_leverage_cap.py new file mode 100644 index 0000000..2d47e16 --- /dev/null +++ b/tests/unit/exchanges/bybit/test_leverage_cap.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import pytest +from fastapi import HTTPException +from cerbero_mcp.exchanges.bybit.leverage_cap import enforce_leverage, get_max_leverage + + +def test_get_max_leverage_returns_creds_value(): + creds = {"max_leverage": 5} + assert get_max_leverage(creds) == 5 + + +def test_get_max_leverage_default_when_missing(): + """Default 1 (cash) se il secret non ha max_leverage.""" + assert get_max_leverage({}) == 1 + + +def test_enforce_leverage_pass_under_cap(): + creds = {"max_leverage": 3} + enforce_leverage(2, creds=creds, exchange="bybit") # no raise + + +def test_enforce_leverage_pass_at_cap(): + creds = {"max_leverage": 3} + enforce_leverage(3, creds=creds, exchange="bybit") # no raise + + +def test_enforce_leverage_reject_over_cap(): + creds = {"max_leverage": 3} + with pytest.raises(HTTPException) as exc: + enforce_leverage(10, creds=creds, exchange="bybit") + assert exc.value.status_code == 403 + assert exc.value.detail["error"] == "LEVERAGE_CAP_EXCEEDED" + assert exc.value.detail["exchange"] == "bybit" + assert exc.value.detail["requested"] == 10 + assert exc.value.detail["max"] == 3 + + +def test_enforce_leverage_reject_when_below_one(): + creds = {"max_leverage": 3} + with pytest.raises(HTTPException) as exc: + enforce_leverage(0, creds=creds, exchange="bybit") + assert exc.value.status_code == 403 + + +def test_enforce_leverage_default_when_none(): + """Se requested è None, applica il cap come default.""" + creds = {"max_leverage": 3} + result = enforce_leverage(None, creds=creds, exchange="bybit") + assert result == 3 diff --git a/tests/unit/test_exchanges_builder.py b/tests/unit/test_exchanges_builder.py index 0bc1959..cd829e3 100644 --- a/tests/unit/test_exchanges_builder.py +++ b/tests/unit/test_exchanges_builder.py @@ -22,6 +22,34 @@ async def test_build_client_deribit_returns_correct_url(monkeypatch): assert "test" not in c_live.base_url.lower() +@pytest.mark.asyncio +async def test_build_client_bybit_returns_correct_env(monkeypatch): + from tests.unit.test_settings import _minimal_env + + for k, v in _minimal_env().items(): + monkeypatch.setenv(k, v) + + # Stub pybit HTTP per evitare connessione reale durante __init__ + from cerbero_mcp.exchanges.bybit import client as bybit_client + + class _FakeHTTP: + def __init__(self, **kwargs): + self.kwargs = kwargs + + monkeypatch.setattr(bybit_client, "HTTP", _FakeHTTP) + + from cerbero_mcp.settings import Settings + from cerbero_mcp.exchanges import build_client + + s = Settings() + c_test = await build_client(s, "bybit", "testnet") + c_live = await build_client(s, "bybit", "mainnet") + + assert c_test is not c_live + assert c_test.testnet is True + assert c_live.testnet is False + + @pytest.mark.asyncio async def test_build_client_unknown_exchange_raises(monkeypatch): from tests.unit.test_settings import _minimal_env