refactor(V2): bybit client da pybit a httpx puro (parità V1)

This commit is contained in:
AdrianoDev
2026-05-01 01:35:26 +02:00
parent 95b8bcfe96
commit 6097dde4e4
4 changed files with 1149 additions and 620 deletions
+426 -208
View File
@@ -1,13 +1,33 @@
"""Bybit V5 REST API client (httpx puro, no SDK).
Implementazione diretta su `httpx.AsyncClient` per i tool Cerbero MCP V2.
Mantiene parità di interfaccia con la versione precedente basata su
`pybit.unified_trading.HTTP` per non rompere `tools.py` né i router.
Auth Bybit V5:
Header X-BAPI-SIGN = HMAC_SHA256(secret,
timestamp + api_key + recv_window + (body_json | querystring))
"""
from __future__ import annotations from __future__ import annotations
import asyncio import hashlib
import hmac
import json
import time
import uuid
from typing import Any from typing import Any
from urllib.parse import urlencode
from pybit.unified_trading import HTTP import httpx
from cerbero_mcp.common import indicators as ind from cerbero_mcp.common import indicators as ind
from cerbero_mcp.common import microstructure as micro from cerbero_mcp.common import microstructure as micro
BASE_MAINNET = "https://api.bybit.com"
BASE_TESTNET = "https://api-testnet.bybit.com"
DEFAULT_RECV_WINDOW = "5000"
DEFAULT_TIMEOUT = 15.0
def _f(v: Any) -> float | None: def _f(v: Any) -> float | None:
try: try:
@@ -23,37 +43,143 @@ def _i(v: Any) -> int | None:
return None return None
class BybitAPIError(RuntimeError):
"""Errore di trasporto Bybit V5 (non gestito a livello envelope)."""
class BybitClient: class BybitClient:
"""Async REST client per Bybit V5 (linear/inverse/spot/option)."""
def __init__( def __init__(
self, self,
api_key: str, api_key: str,
api_secret: str, api_secret: str,
testnet: bool = True, testnet: bool = True,
http: Any | None = None, http: httpx.AsyncClient | None = None,
base_url: str | None = None, base_url: str | None = None,
) -> None: ) -> None:
self.api_key = api_key self.api_key = api_key
self.api_secret = api_secret self.api_secret = api_secret
self.testnet = testnet self.testnet = testnet
# pybit HTTP non accetta `endpoint` come kwarg (vedi _V5HTTPManager.__init__: self.base_url = base_url or (BASE_TESTNET if testnet else BASE_MAINNET)
# solo `domain`/`tld`/`testnet`). Override URL applicato post-init self.recv_window = DEFAULT_RECV_WINDOW
# sovrascrivendo l'attributo `endpoint` dell'istanza HTTP. # `http` injection è usato dai test per montare un AsyncClient con
self.base_url = base_url # `httpx.MockTransport`. In produzione creiamo un client dedicato.
if http is None: self._owns_http = http is None
http = HTTP( self._http: httpx.AsyncClient = http or httpx.AsyncClient(
api_key=api_key, timeout=DEFAULT_TIMEOUT
api_secret=api_secret, )
testnet=testnet,
)
if base_url:
http.endpoint = base_url
self._http = http
async def _run(self, fn, /, **kwargs): async def aclose(self) -> None:
return await asyncio.to_thread(fn, **kwargs) """Chiude l'AsyncClient httpx se di nostra proprietà."""
if self._owns_http:
await self._http.aclose()
# ── auth helpers ───────────────────────────────────────────
def _timestamp_ms(self) -> str:
return str(int(time.time() * 1000))
def _sign(self, timestamp: str, payload: str) -> str:
msg = timestamp + self.api_key + self.recv_window + payload
return hmac.new(
self.api_secret.encode("utf-8"),
msg.encode("utf-8"),
hashlib.sha256,
).hexdigest()
def _signed_headers(self, payload: str) -> dict[str, str]:
ts = self._timestamp_ms()
sig = self._sign(ts, payload)
return {
"X-BAPI-API-KEY": self.api_key,
"X-BAPI-TIMESTAMP": ts,
"X-BAPI-RECV-WINDOW": self.recv_window,
"X-BAPI-SIGN": sig,
"Content-Type": "application/json",
}
@staticmethod @staticmethod
def _parse_ticker(row: dict) -> dict: def _clean_params(params: dict[str, Any] | None) -> dict[str, Any]:
if not params:
return {}
return {k: v for k, v in params.items() if v is not None}
@staticmethod
def _querystring(params: dict[str, Any]) -> str:
# Bybit accetta querystring nell'ordine in cui viene serializzata la
# request. Per la signature usiamo lo stesso urlencode (ordine
# inserzione dict). In Python 3.7+ dict mantiene insertion order:
# mantenere coerenza tra signature payload e URL effettivo.
return urlencode(params)
# ── request primitives ─────────────────────────────────────
async def _request_public(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
) -> dict[str, Any]:
clean = self._clean_params(params)
url = self.base_url + path
resp = await self._http.request(
method, url, params=clean if clean else None
)
return self._parse_response(resp)
async def _request_signed(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
) -> dict[str, Any]:
url = self.base_url + path
method = method.upper()
if method == "GET":
clean = self._clean_params(params)
qs = self._querystring(clean)
headers = self._signed_headers(qs)
resp = await self._http.request(
method, url, params=clean if clean else None, headers=headers
)
else:
payload_body = body or {}
body_json = json.dumps(payload_body, separators=(",", ":"))
headers = self._signed_headers(body_json)
resp = await self._http.request(
method, url, content=body_json, headers=headers
)
return self._parse_response(resp)
@staticmethod
def _parse_response(resp: httpx.Response) -> dict[str, Any]:
try:
data = resp.json()
except Exception as e: # pragma: no cover - difficilmente raggiungibile
raise BybitAPIError(
f"invalid JSON from Bybit (status={resp.status_code}): {resp.text[:200]}"
) from e
if resp.status_code >= 500:
raise BybitAPIError(
f"bybit server error {resp.status_code}: "
f"{data.get('retMsg', resp.text[:200])}"
)
if not isinstance(data, dict):
raise BybitAPIError(f"unexpected payload type: {type(data).__name__}")
return data
def _envelope(self, resp: dict[str, Any], payload: dict[str, Any]) -> dict[str, Any]:
code = resp.get("retCode", 0)
if code != 0:
return {"error": resp.get("retMsg", "bybit_error"), "code": code}
return payload
# ── parsers shared ─────────────────────────────────────────
@staticmethod
def _parse_ticker(row: dict[str, Any]) -> dict[str, Any]:
return { return {
"symbol": row.get("symbol"), "symbol": row.get("symbol"),
"last_price": _f(row.get("lastPrice")), "last_price": _f(row.get("lastPrice")),
@@ -66,9 +192,13 @@ class BybitClient:
"open_interest": _f(row.get("openInterest")), "open_interest": _f(row.get("openInterest")),
} }
# ── market data (public) ───────────────────────────────────
async def get_ticker(self, symbol: str, category: str = "linear") -> dict: async def get_ticker(self, symbol: str, category: str = "linear") -> dict:
resp = await self._run( resp = await self._request_public(
self._http.get_tickers, category=category, symbol=symbol "GET",
"/v5/market/tickers",
params={"category": category, "symbol": symbol},
) )
rows = (resp.get("result") or {}).get("list") or [] rows = (resp.get("result") or {}).get("list") or []
if not rows: if not rows:
@@ -86,8 +216,10 @@ class BybitClient:
async def get_orderbook( async def get_orderbook(
self, symbol: str, category: str = "linear", limit: int = 50 self, symbol: str, category: str = "linear", limit: int = 50
) -> dict: ) -> dict:
resp = await self._run( resp = await self._request_public(
self._http.get_orderbook, category=category, symbol=symbol, limit=limit "GET",
"/v5/market/orderbook",
params={"category": category, "symbol": symbol, "limit": limit},
) )
r = resp.get("result") or {} r = resp.get("result") or {}
return { return {
@@ -106,17 +238,17 @@ class BybitClient:
end: int | None = None, end: int | None = None,
limit: int = 1000, limit: int = 1000,
) -> dict: ) -> dict:
kwargs = dict( params: dict[str, Any] = {
category=category, "category": category,
symbol=symbol, "symbol": symbol,
interval=interval, "interval": interval,
limit=limit, "limit": limit,
) }
if start is not None: if start is not None:
kwargs["start"] = start params["start"] = start
if end is not None: if end is not None:
kwargs["end"] = end params["end"] = end
resp = await self._run(self._http.get_kline, **kwargs) resp = await self._request_public("GET", "/v5/market/kline", params=params)
rows = (resp.get("result") or {}).get("list") or [] rows = (resp.get("result") or {}).get("list") or []
rows_sorted = sorted(rows, key=lambda r: int(r[0])) rows_sorted = sorted(rows, key=lambda r: int(r[0]))
candles = [ candles = [
@@ -168,8 +300,10 @@ class BybitClient:
return out return out
async def get_funding_rate(self, symbol: str, category: str = "linear") -> dict: async def get_funding_rate(self, symbol: str, category: str = "linear") -> dict:
resp = await self._run( resp = await self._request_public(
self._http.get_tickers, category=category, symbol=symbol "GET",
"/v5/market/tickers",
params={"category": category, "symbol": symbol},
) )
rows = (resp.get("result") or {}).get("list") or [] rows = (resp.get("result") or {}).get("list") or []
if not rows: if not rows:
@@ -184,9 +318,10 @@ class BybitClient:
async def get_funding_history( async def get_funding_history(
self, symbol: str, category: str = "linear", limit: int = 100 self, symbol: str, category: str = "linear", limit: int = 100
) -> dict: ) -> dict:
resp = await self._run( resp = await self._request_public(
self._http.get_funding_rate_history, "GET",
category=category, symbol=symbol, limit=limit, "/v5/market/funding/history",
params={"category": category, "symbol": symbol, "limit": limit},
) )
rows = (resp.get("result") or {}).get("list") or [] rows = (resp.get("result") or {}).get("list") or []
hist = [ hist = [
@@ -205,9 +340,15 @@ class BybitClient:
interval: str = "5min", interval: str = "5min",
limit: int = 288, limit: int = 288,
) -> dict: ) -> dict:
resp = await self._run( resp = await self._request_public(
self._http.get_open_interest, "GET",
category=category, symbol=symbol, intervalTime=interval, limit=limit, "/v5/market/open-interest",
params={
"category": category,
"symbol": symbol,
"intervalTime": interval,
"limit": limit,
},
) )
rows = (resp.get("result") or {}).get("list") or [] rows = (resp.get("result") or {}).get("list") or []
points = [ points = [
@@ -226,71 +367,88 @@ class BybitClient:
"points": points, "points": points,
} }
async def get_instruments(self, category: str = "linear", symbol: str | None = None) -> dict: async def get_instruments(
kwargs: dict[str, Any] = {"category": category} self, category: str = "linear", symbol: str | None = None
) -> dict:
params: dict[str, Any] = {"category": category}
if symbol: if symbol:
kwargs["symbol"] = symbol params["symbol"] = symbol
resp = await self._run(self._http.get_instruments_info, **kwargs) resp = await self._request_public(
"GET", "/v5/market/instruments-info", params=params
)
rows = (resp.get("result") or {}).get("list") or [] rows = (resp.get("result") or {}).get("list") or []
instruments = [] instruments = []
for r in rows: for r in rows:
pf = r.get("priceFilter") or {} pf = r.get("priceFilter") or {}
lf = r.get("lotSizeFilter") or {} lf = r.get("lotSizeFilter") or {}
instruments.append({ instruments.append(
"symbol": r.get("symbol"), {
"status": r.get("status"), "symbol": r.get("symbol"),
"base_coin": r.get("baseCoin"), "status": r.get("status"),
"quote_coin": r.get("quoteCoin"), "base_coin": r.get("baseCoin"),
"tick_size": _f(pf.get("tickSize")), "quote_coin": r.get("quoteCoin"),
"qty_step": _f(lf.get("qtyStep")), "tick_size": _f(pf.get("tickSize")),
"min_qty": _f(lf.get("minOrderQty")), "qty_step": _f(lf.get("qtyStep")),
}) "min_qty": _f(lf.get("minOrderQty")),
}
)
return {"category": category, "instruments": instruments} return {"category": category, "instruments": instruments}
async def get_option_chain(self, base_coin: str, expiry: str | None = None) -> dict: async def get_option_chain(self, base_coin: str, expiry: str | None = None) -> dict:
kwargs: dict[str, Any] = {"category": "option", "baseCoin": base_coin.upper()} resp = await self._request_public(
resp = await self._run(self._http.get_instruments_info, **kwargs) "GET",
"/v5/market/instruments-info",
params={"category": "option", "baseCoin": base_coin.upper()},
)
rows = (resp.get("result") or {}).get("list") or [] rows = (resp.get("result") or {}).get("list") or []
options = [] options = []
for r in rows: for r in rows:
delivery = r.get("deliveryTime") delivery = r.get("deliveryTime")
if expiry and expiry not in r.get("symbol", ""): if expiry and expiry not in r.get("symbol", ""):
continue continue
options.append({ options.append(
"symbol": r.get("symbol"), {
"base_coin": r.get("baseCoin"), "symbol": r.get("symbol"),
"settle_coin": r.get("settleCoin"), "base_coin": r.get("baseCoin"),
"type": r.get("optionsType"), "settle_coin": r.get("settleCoin"),
"launch_time": int(r.get("launchTime", 0)), "type": r.get("optionsType"),
"delivery_time": int(delivery) if delivery else None, "launch_time": int(r.get("launchTime", 0)),
}) "delivery_time": int(delivery) if delivery else None,
}
)
return {"base_coin": base_coin.upper(), "options": options} return {"base_coin": base_coin.upper(), "options": options}
# ── account / positions / orders (signed) ─────────────────
async def get_positions( async def get_positions(
self, category: str = "linear", settle_coin: str = "USDT" self, category: str = "linear", settle_coin: str = "USDT"
) -> list[dict]: ) -> list[dict]:
kwargs: dict[str, Any] = {"category": category} params: dict[str, Any] = {"category": category}
if category in ("linear", "inverse"): if category in ("linear", "inverse"):
kwargs["settleCoin"] = settle_coin params["settleCoin"] = settle_coin
resp = await self._run(self._http.get_positions, **kwargs) resp = await self._request_signed("GET", "/v5/position/list", params=params)
rows = (resp.get("result") or {}).get("list") or [] rows = (resp.get("result") or {}).get("list") or []
out = [] out = []
for r in rows: for r in rows:
out.append({ out.append(
"symbol": r.get("symbol"), {
"side": r.get("side"), "symbol": r.get("symbol"),
"size": _f(r.get("size")), "side": r.get("side"),
"entry_price": _f(r.get("avgPrice")), "size": _f(r.get("size")),
"unrealized_pnl": _f(r.get("unrealisedPnl")), "entry_price": _f(r.get("avgPrice")),
"leverage": _f(r.get("leverage")), "unrealized_pnl": _f(r.get("unrealisedPnl")),
"liquidation_price": _f(r.get("liqPrice")), "leverage": _f(r.get("leverage")),
"position_value": _f(r.get("positionValue")), "liquidation_price": _f(r.get("liqPrice")),
}) "position_value": _f(r.get("positionValue")),
}
)
return out return out
async def get_account_summary(self, account_type: str = "UNIFIED") -> dict: async def get_account_summary(self, account_type: str = "UNIFIED") -> dict:
resp = await self._run( resp = await self._request_signed(
self._http.get_wallet_balance, accountType=account_type "GET",
"/v5/account/wallet-balance",
params={"accountType": account_type},
) )
rows = (resp.get("result") or {}).get("list") or [] rows = (resp.get("result") or {}).get("list") or []
if not rows: if not rows:
@@ -298,11 +456,13 @@ class BybitClient:
a = rows[0] a = rows[0]
coins = [] coins = []
for c in a.get("coin") or []: for c in a.get("coin") or []:
coins.append({ coins.append(
"coin": c.get("coin"), {
"wallet_balance": _f(c.get("walletBalance")), "coin": c.get("coin"),
"equity": _f(c.get("equity")), "wallet_balance": _f(c.get("walletBalance")),
}) "equity": _f(c.get("equity")),
}
)
return { return {
"account_type": a.get("accountType"), "account_type": a.get("accountType"),
"equity": _f(a.get("totalEquity")), "equity": _f(a.get("totalEquity")),
@@ -316,8 +476,10 @@ class BybitClient:
async def get_trade_history( async def get_trade_history(
self, category: str = "linear", limit: int = 50 self, category: str = "linear", limit: int = 50
) -> list[dict]: ) -> list[dict]:
resp = await self._run( resp = await self._request_signed(
self._http.get_executions, category=category, limit=limit "GET",
"/v5/execution/list",
params={"category": category, "limit": limit},
) )
rows = (resp.get("result") or {}).get("list") or [] rows = (resp.get("result") or {}).get("list") or []
return [ return [
@@ -339,12 +501,14 @@ class BybitClient:
symbol: str | None = None, symbol: str | None = None,
settle_coin: str = "USDT", settle_coin: str = "USDT",
) -> list[dict]: ) -> list[dict]:
kwargs: dict[str, Any] = {"category": category} params: dict[str, Any] = {"category": category}
if category in ("linear", "inverse") and not symbol: if category in ("linear", "inverse") and not symbol:
kwargs["settleCoin"] = settle_coin params["settleCoin"] = settle_coin
if symbol: if symbol:
kwargs["symbol"] = symbol params["symbol"] = symbol
resp = await self._run(self._http.get_open_orders, **kwargs) resp = await self._request_signed(
"GET", "/v5/order/realtime", params=params
)
rows = (resp.get("result") or {}).get("list") or [] rows = (resp.get("result") or {}).get("list") or []
return [ return [
{ {
@@ -360,15 +524,20 @@ class BybitClient:
for r in rows for r in rows
] ]
# ── microstructure / basis ─────────────────────────────────
async def get_orderbook_imbalance( async def get_orderbook_imbalance(
self, self,
symbol: str, symbol: str,
category: str = "linear", category: str = "linear",
depth: int = 10, depth: int = 10,
) -> dict: ) -> dict:
"""Microstructure: bid/ask imbalance ratio + microprice + slope.""" ob = await self.get_orderbook(
ob = await self.get_orderbook(symbol=symbol, category=category, limit=max(depth, 50)) symbol=symbol, category=category, limit=max(depth, 50)
result = micro.orderbook_imbalance(ob.get("bids") or [], ob.get("asks") or [], depth=depth) )
result = micro.orderbook_imbalance(
ob.get("bids") or [], ob.get("asks") or [], depth=depth
)
return { return {
"symbol": symbol, "symbol": symbol,
"category": category, "category": category,
@@ -378,9 +547,6 @@ class BybitClient:
} }
async def get_basis_term_structure(self, asset: str) -> dict: async def get_basis_term_structure(self, asset: str) -> dict:
"""Basis curve futures (dated) vs perp + spot. Filtra contratti future
BTCUSDT / ETHUSDT con scadenza, calcola annualized basis per ognuno.
"""
import datetime as _dt import datetime as _dt
asset = asset.upper() asset = asset.upper()
@@ -389,12 +555,13 @@ class BybitClient:
sp = spot.get("last_price") sp = spot.get("last_price")
pp = perp.get("last_price") pp = perp.get("last_price")
# Lista futures dated (linear/inverse)
instr = await self.get_instruments(category="linear") instr = await self.get_instruments(category="linear")
items = (instr.get("instruments") or []) items = instr.get("instruments") or []
futures = [ futures = [
x for x in items x
if x.get("symbol", "").startswith(f"{asset}-") or x.get("symbol", "").startswith(f"{asset}USDT-") for x in items
if x.get("symbol", "").startswith(f"{asset}-")
or x.get("symbol", "").startswith(f"{asset}USDT-")
] ]
rows: list[dict[str, Any]] = [] rows: list[dict[str, Any]] = []
@@ -409,21 +576,25 @@ class BybitClient:
days = max((int(expiry_ms) - now_ms) / 86_400_000, 1) days = max((int(expiry_ms) - now_ms) / 86_400_000, 1)
basis_pct = 100.0 * (fp - sp) / sp basis_pct = 100.0 * (fp - sp) / sp
annualized = basis_pct * 365.0 / days annualized = basis_pct * 365.0 / days
rows.append({ rows.append(
"symbol": f["symbol"], {
"expiry_ms": int(expiry_ms), "symbol": f["symbol"],
"days_to_expiry": round(days, 2), "expiry_ms": int(expiry_ms),
"future_price": fp, "days_to_expiry": round(days, 2),
"basis_pct": round(basis_pct, 4), "future_price": fp,
"annualized_basis_pct": round(annualized, 4), "basis_pct": round(basis_pct, 4),
}) "annualized_basis_pct": round(annualized, 4),
}
)
rows.sort(key=lambda r: r["days_to_expiry"]) rows.sort(key=lambda r: r["days_to_expiry"])
return { return {
"asset": asset, "asset": asset,
"spot_price": sp, "spot_price": sp,
"perp_price": pp, "perp_price": pp,
"perp_basis_pct": round(100.0 * (pp - sp) / sp, 4) if (sp and pp) else None, "perp_basis_pct": round(100.0 * (pp - sp) / sp, 4)
if (sp and pp)
else None,
"term_structure": rows, "term_structure": rows,
"data_timestamp": _dt.datetime.now(_dt.UTC).isoformat(), "data_timestamp": _dt.datetime.now(_dt.UTC).isoformat(),
} }
@@ -449,11 +620,7 @@ class BybitClient:
"funding_rate": perp.get("funding_rate"), "funding_rate": perp.get("funding_rate"),
} }
def _envelope(self, resp: dict, payload: dict) -> dict: # ── trading (signed, write) ────────────────────────────────
code = resp.get("retCode", 0)
if code != 0:
return {"error": resp.get("retMsg", "bybit_error"), "code": code}
return payload
async def place_order( async def place_order(
self, self,
@@ -467,7 +634,7 @@ class BybitClient:
reduce_only: bool = False, reduce_only: bool = False,
position_idx: int | None = None, position_idx: int | None = None,
) -> dict: ) -> dict:
kwargs: dict[str, Any] = { body: dict[str, Any] = {
"category": category, "category": category,
"symbol": symbol, "symbol": symbol,
"side": side, "side": side,
@@ -477,38 +644,34 @@ class BybitClient:
"reduceOnly": reduce_only, "reduceOnly": reduce_only,
} }
if price is not None: if price is not None:
kwargs["price"] = str(price) body["price"] = str(price)
if position_idx is not None: if position_idx is not None:
kwargs["positionIdx"] = position_idx body["positionIdx"] = position_idx
if category == "option": if category == "option":
import uuid body["orderLinkId"] = f"cerbero-{uuid.uuid4().hex[:16]}"
kwargs["orderLinkId"] = f"cerbero-{uuid.uuid4().hex[:16]}" resp = await self._request_signed("POST", "/v5/order/create", body=body)
resp = await self._run(self._http.place_order, **kwargs)
r = resp.get("result") or {} r = resp.get("result") or {}
return self._envelope(resp, { return self._envelope(
"order_id": r.get("orderId"), resp,
"order_link_id": r.get("orderLinkId"), {
"status": "submitted", "order_id": r.get("orderId"),
}) "order_link_id": r.get("orderLinkId"),
"status": "submitted",
},
)
async def place_combo_order( async def place_combo_order(
self, self,
category: str, category: str,
legs: list[dict[str, Any]], legs: list[dict[str, Any]],
) -> dict: ) -> dict:
"""Atomic multi-leg via /v5/order/create-batch (Bybit option only).
Bybit supporta batch_order solo su category='option'. Per perp/linear
usare loop di place_order (non atomic).
legs: [{symbol, side, qty, order_type, price?, tif?, reduce_only?}].
"""
if category != "option": if category != "option":
raise ValueError("place_combo_order: Bybit batch_order è disponibile solo su category='option'") raise ValueError(
"place_combo_order: Bybit batch_order è disponibile solo su category='option'"
)
if len(legs) < 2: if len(legs) < 2:
raise ValueError("combo requires at least 2 legs") raise ValueError("combo requires at least 2 legs")
import uuid
request: list[dict[str, Any]] = [] request: list[dict[str, Any]] = []
for leg in legs: for leg in legs:
entry: dict[str, Any] = { entry: dict[str, Any] = {
@@ -524,7 +687,10 @@ class BybitClient:
entry["price"] = str(leg["price"]) entry["price"] = str(leg["price"])
request.append(entry) request.append(entry)
resp = await self._run(self._http.place_batch_order, category=category, request=request) body = {"category": category, "request": request}
resp = await self._request_signed(
"POST", "/v5/order/create-batch", body=body
)
result_list = (resp.get("result") or {}).get("list") or [] result_list = (resp.get("result") or {}).get("list") or []
orders = [ orders = [
{ {
@@ -544,80 +710,112 @@ class BybitClient:
new_qty: float | None = None, new_qty: float | None = None,
new_price: float | None = None, new_price: float | None = None,
) -> dict: ) -> dict:
kwargs: dict[str, Any] = { body: dict[str, Any] = {
"category": category, "category": category,
"symbol": symbol, "symbol": symbol,
"orderId": order_id, "orderId": order_id,
} }
if new_qty is not None: if new_qty is not None:
kwargs["qty"] = str(new_qty) body["qty"] = str(new_qty)
if new_price is not None: if new_price is not None:
kwargs["price"] = str(new_price) body["price"] = str(new_price)
resp = await self._run(self._http.amend_order, **kwargs) resp = await self._request_signed("POST", "/v5/order/amend", body=body)
r = resp.get("result") or {} r = resp.get("result") or {}
return self._envelope(resp, { return self._envelope(
"order_id": r.get("orderId", order_id), resp,
"status": "amended", {
}) "order_id": r.get("orderId", order_id),
"status": "amended",
async def cancel_order( },
self, category: str, symbol: str, order_id: str
) -> dict:
resp = await self._run(
self._http.cancel_order,
category=category, symbol=symbol, orderId=order_id,
) )
async def cancel_order(self, category: str, symbol: str, order_id: str) -> dict:
body = {"category": category, "symbol": symbol, "orderId": order_id}
resp = await self._request_signed("POST", "/v5/order/cancel", body=body)
r = resp.get("result") or {} r = resp.get("result") or {}
return self._envelope(resp, { return self._envelope(
"order_id": r.get("orderId", order_id), resp,
"status": "cancelled", {
}) "order_id": r.get("orderId", order_id),
"status": "cancelled",
},
)
async def cancel_all_orders( async def cancel_all_orders(
self, category: str, symbol: str | None = None self, category: str, symbol: str | None = None
) -> dict: ) -> dict:
kwargs: dict[str, Any] = {"category": category} body: dict[str, Any] = {"category": category}
if symbol: if symbol:
kwargs["symbol"] = symbol body["symbol"] = symbol
resp = await self._run(self._http.cancel_all_orders, **kwargs) resp = await self._request_signed(
"POST", "/v5/order/cancel-all", body=body
)
r = resp.get("result") or {} r = resp.get("result") or {}
ids = [x.get("orderId") for x in (r.get("list") or [])] ids = [x.get("orderId") for x in (r.get("list") or [])]
return self._envelope(resp, { return self._envelope(
"cancelled_ids": ids, resp,
"count": len(ids), {
}) "cancelled_ids": ids,
"count": len(ids),
},
)
async def set_stop_loss( async def set_stop_loss(
self, category: str, symbol: str, stop_loss: float, self,
category: str,
symbol: str,
stop_loss: float,
position_idx: int = 0, position_idx: int = 0,
) -> dict: ) -> dict:
resp = await self._run( body = {
self._http.set_trading_stop, "category": category,
category=category, symbol=symbol, "symbol": symbol,
stopLoss=str(stop_loss), positionIdx=position_idx, "stopLoss": str(stop_loss),
"positionIdx": position_idx,
}
resp = await self._request_signed(
"POST", "/v5/position/trading-stop", body=body
)
return self._envelope(
resp,
{
"symbol": symbol,
"stop_loss": stop_loss,
"status": "stop_loss_set",
},
) )
return self._envelope(resp, {
"symbol": symbol, "stop_loss": stop_loss,
"status": "stop_loss_set",
})
async def set_take_profit( async def set_take_profit(
self, category: str, symbol: str, take_profit: float, self,
category: str,
symbol: str,
take_profit: float,
position_idx: int = 0, position_idx: int = 0,
) -> dict: ) -> dict:
resp = await self._run( body = {
self._http.set_trading_stop, "category": category,
category=category, symbol=symbol, "symbol": symbol,
takeProfit=str(take_profit), positionIdx=position_idx, "takeProfit": str(take_profit),
"positionIdx": position_idx,
}
resp = await self._request_signed(
"POST", "/v5/position/trading-stop", body=body
)
return self._envelope(
resp,
{
"symbol": symbol,
"take_profit": take_profit,
"status": "take_profit_set",
},
) )
return self._envelope(resp, {
"symbol": symbol, "take_profit": take_profit,
"status": "take_profit_set",
})
async def close_position(self, category: str, symbol: str) -> dict: async def close_position(self, category: str, symbol: str) -> dict:
positions = await self.get_positions(category=category) positions = await self.get_positions(category=category)
target = next((p for p in positions if p["symbol"] == symbol and (p["size"] or 0) > 0), None) target = next(
(p for p in positions if p["symbol"] == symbol and (p["size"] or 0) > 0),
None,
)
if not target: if not target:
return {"error": "no_open_position", "symbol": symbol} return {"error": "no_open_position", "symbol": symbol}
close_side = "Sell" if target["side"] == "Buy" else "Buy" close_side = "Sell" if target["side"] == "Buy" else "Buy"
@@ -634,28 +832,44 @@ class BybitClient:
async def set_leverage( async def set_leverage(
self, category: str, symbol: str, leverage: int self, category: str, symbol: str, leverage: int
) -> dict: ) -> dict:
resp = await self._run( body = {
self._http.set_leverage, "category": category,
category=category, symbol=symbol, "symbol": symbol,
buyLeverage=str(leverage), sellLeverage=str(leverage), "buyLeverage": str(leverage),
"sellLeverage": str(leverage),
}
resp = await self._request_signed(
"POST", "/v5/position/set-leverage", body=body
)
return self._envelope(
resp,
{
"symbol": symbol,
"leverage": leverage,
"status": "leverage_set",
},
) )
return self._envelope(resp, {
"symbol": symbol, "leverage": leverage,
"status": "leverage_set",
})
async def switch_position_mode( async def switch_position_mode(
self, category: str, symbol: str, mode: str self, category: str, symbol: str, mode: str
) -> dict: ) -> dict:
mode_code = 3 if mode.lower() == "hedge" else 0 mode_code = 3 if mode.lower() == "hedge" else 0
resp = await self._run( body = {
self._http.switch_position_mode, "category": category,
category=category, symbol=symbol, mode=mode_code, "symbol": symbol,
"mode": mode_code,
}
resp = await self._request_signed(
"POST", "/v5/position/switch-mode", body=body
)
return self._envelope(
resp,
{
"symbol": symbol,
"mode": mode,
"status": "mode_switched",
},
) )
return self._envelope(resp, {
"symbol": symbol, "mode": mode,
"status": "mode_switched",
})
async def transfer_asset( async def transfer_asset(
self, self,
@@ -664,19 +878,23 @@ class BybitClient:
from_type: str, from_type: str,
to_type: str, to_type: str,
) -> dict: ) -> dict:
import uuid body = {
resp = await self._run( "transferId": str(uuid.uuid4()),
self._http.create_internal_transfer, "coin": coin,
transferId=str(uuid.uuid4()), "amount": str(amount),
coin=coin, "fromAccountType": from_type,
amount=str(amount), "toAccountType": to_type,
fromAccountType=from_type, }
toAccountType=to_type, resp = await self._request_signed(
"POST", "/v5/asset/transfer/inter-transfer", body=body
) )
r = resp.get("result") or {} r = resp.get("result") or {}
return self._envelope(resp, { return self._envelope(
"transfer_id": r.get("transferId"), resp,
"coin": coin, {
"amount": amount, "transfer_id": r.get("transferId"),
"status": "submitted", "coin": coin,
}) "amount": amount,
"status": "submitted",
},
)
+5 -8
View File
@@ -1,21 +1,18 @@
from __future__ import annotations from __future__ import annotations
from unittest.mock import MagicMock
import pytest import pytest
from cerbero_mcp.exchanges.bybit.client import BybitClient from cerbero_mcp.exchanges.bybit.client import BybitClient
@pytest.fixture @pytest.fixture
def mock_http(): def client():
return MagicMock(name="pybit_HTTP") """BybitClient con base_url testnet e AsyncClient interno.
pytest-httpx intercetta le chiamate dell'AsyncClient httpx creato dal
@pytest.fixture costruttore (auto-mock), quindi non serve injection esplicita.
def client(mock_http): """
return BybitClient( return BybitClient(
api_key="test_key", api_key="test_key",
api_secret="test_secret", api_secret="test_secret",
testnet=True, testnet=True,
http=mock_http,
) )
File diff suppressed because it is too large Load Diff
+20 -32
View File
@@ -29,15 +29,8 @@ async def test_build_client_bybit_returns_correct_env(monkeypatch):
for k, v in _minimal_env().items(): for k, v in _minimal_env().items():
monkeypatch.setenv(k, v) monkeypatch.setenv(k, v)
# Stub pybit HTTP per evitare connessione reale durante __init__ # BybitClient costruisce internamente httpx.AsyncClient: nessuna
from cerbero_mcp.exchanges.bybit import client as bybit_client # connessione reale finché non si invoca un metodo di rete.
class _FakeHTTP:
def __init__(self, **kwargs):
self.kwargs = kwargs
monkeypatch.setattr(bybit_client, "HTTP", _FakeHTTP)
from cerbero_mcp.exchanges import build_client from cerbero_mcp.exchanges import build_client
from cerbero_mcp.settings import Settings from cerbero_mcp.settings import Settings
@@ -78,28 +71,22 @@ async def test_build_client_alpaca_returns_correct_env(monkeypatch):
for k, v in _minimal_env().items(): for k, v in _minimal_env().items():
monkeypatch.setenv(k, v) monkeypatch.setenv(k, v)
# Stub alpaca SDK clients per evitare connessioni reali in __init__ # AlpacaClient (V2) usa httpx puro: il costruttore non apre connessioni
from cerbero_mcp.exchanges.alpaca import client as alpaca_client # reali (httpx.AsyncClient è lazy fino alla prima request), quindi nessuno
# stub SDK è necessario.
class _FakeSdk:
def __init__(self, **kwargs):
self.kwargs = kwargs
monkeypatch.setattr(alpaca_client, "TradingClient", _FakeSdk)
monkeypatch.setattr(alpaca_client, "StockHistoricalDataClient", _FakeSdk)
monkeypatch.setattr(alpaca_client, "CryptoHistoricalDataClient", _FakeSdk)
monkeypatch.setattr(alpaca_client, "OptionHistoricalDataClient", _FakeSdk)
from cerbero_mcp.exchanges import build_client from cerbero_mcp.exchanges import build_client
from cerbero_mcp.settings import Settings from cerbero_mcp.settings import Settings
s = Settings() s = Settings()
c_test = await build_client(s, "alpaca", "testnet") c_test = await build_client(s, "alpaca", "testnet")
c_live = await build_client(s, "alpaca", "mainnet") c_live = await build_client(s, "alpaca", "mainnet")
try:
assert c_test is not c_live assert c_test is not c_live
assert c_test.paper is True assert c_test.paper is True
assert c_live.paper is False assert c_live.paper is False
finally:
await c_test.aclose()
await c_live.aclose()
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -202,8 +189,8 @@ async def test_hyperliquid_url_from_env_overrides_default(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bybit_url_from_env_overrides_default(monkeypatch): async def test_bybit_url_from_env_overrides_default(monkeypatch):
"""Bybit: pybit non accetta `endpoint` come kwarg, ma setting di """Bybit (httpx): override BYBIT_URL_TESTNET applica direttamente a
`_http.endpoint` post-init rispecchia l'override.""" `self.base_url`, usato come base di ogni richiesta REST V5."""
from tests.unit.test_settings import _minimal_env from tests.unit.test_settings import _minimal_env
env = _minimal_env(BYBIT_URL_TESTNET="https://bybit-custom.example.com") env = _minimal_env(BYBIT_URL_TESTNET="https://bybit-custom.example.com")
@@ -216,14 +203,12 @@ async def test_bybit_url_from_env_overrides_default(monkeypatch):
s = Settings() s = Settings()
c = await build_client(s, "bybit", "testnet") c = await build_client(s, "bybit", "testnet")
assert c.base_url == "https://bybit-custom.example.com" assert c.base_url == "https://bybit-custom.example.com"
# override applicato all'istanza pybit HTTP via attributo `endpoint`
assert getattr(c._http, "endpoint", None) == "https://bybit-custom.example.com"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_alpaca_url_from_env_overrides_default(monkeypatch): async def test_alpaca_url_from_env_overrides_default(monkeypatch):
"""Alpaca: TradingClient supporta url_override per trading API. """Alpaca V2 (httpx): `base_url` override applica al solo trading
Data clients (Stock/Crypto/Option) non supportano override sul costruttore.""" endpoint; data endpoints (data.alpaca.markets) restano hardcoded."""
from tests.unit.test_settings import _minimal_env from tests.unit.test_settings import _minimal_env
env = _minimal_env(ALPACA_URL_TESTNET="https://alpaca-custom.example.com") env = _minimal_env(ALPACA_URL_TESTNET="https://alpaca-custom.example.com")
@@ -235,7 +220,10 @@ async def test_alpaca_url_from_env_overrides_default(monkeypatch):
s = Settings() s = Settings()
c = await build_client(s, "alpaca", "testnet") c = await build_client(s, "alpaca", "testnet")
assert c.base_url == "https://alpaca-custom.example.com" try:
assert c.base_url == "https://alpaca-custom.example.com"
finally:
await c.aclose()
@pytest.mark.asyncio @pytest.mark.asyncio