diff --git a/src/cerbero_mcp/client_registry.py b/src/cerbero_mcp/client_registry.py new file mode 100644 index 0000000..3a845a8 --- /dev/null +++ b/src/cerbero_mcp/client_registry.py @@ -0,0 +1,41 @@ +"""Cache lazy di client exchange, una istanza per (exchange, env).""" +from __future__ import annotations + +import asyncio +from collections import defaultdict +from collections.abc import Awaitable, Callable +from typing import Any, Literal + +Environment = Literal["testnet", "mainnet"] +Builder = Callable[[str, Environment], Awaitable[Any]] + + +class ClientRegistry: + def __init__(self, *, builder: Builder) -> None: + self._builder = builder + self._clients: dict[tuple[str, Environment], Any] = {} + self._locks: dict[tuple[str, Environment], asyncio.Lock] = defaultdict( + asyncio.Lock + ) + + async def get(self, exchange: str, env: Environment) -> Any: + key = (exchange, env) + if key in self._clients: + return self._clients[key] + async with self._locks[key]: + if key in self._clients: # double-check + return self._clients[key] + client = await self._builder(exchange, env) + self._clients[key] = client + return client + + async def aclose(self) -> None: + for client in self._clients.values(): + close = getattr(client, "aclose", None) + if close is None: + continue + try: + await close() + except Exception: + pass + self._clients.clear() diff --git a/tests/unit/test_client_registry.py b/tests/unit/test_client_registry.py new file mode 100644 index 0000000..bc944c8 --- /dev/null +++ b/tests/unit/test_client_registry.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import asyncio + +import pytest + + +@pytest.mark.asyncio +async def test_registry_lazy_build(): + from cerbero_mcp.client_registry import ClientRegistry + + builds: list[tuple[str, str]] = [] + + async def fake_build(exchange: str, env: str): + builds.append((exchange, env)) + + class C: + async def aclose(self): + pass + return C() + + reg = ClientRegistry(builder=fake_build) + c = await reg.get("deribit", "testnet") + assert builds == [("deribit", "testnet")] + assert (await reg.get("deribit", "testnet")) is c # cached + assert builds == [("deribit", "testnet")] + + +@pytest.mark.asyncio +async def test_registry_different_keys_different_clients(): + from cerbero_mcp.client_registry import ClientRegistry + + async def fake_build(exchange: str, env: str): + class C: + tag = (exchange, env) + async def aclose(self): ... + return C() + + reg = ClientRegistry(builder=fake_build) + a = await reg.get("deribit", "testnet") + b = await reg.get("deribit", "mainnet") + c = await reg.get("bybit", "testnet") + assert a is not b + assert a.tag == ("deribit", "testnet") + assert b.tag == ("deribit", "mainnet") + assert c.tag == ("bybit", "testnet") + + +@pytest.mark.asyncio +async def test_registry_concurrent_get_one_build(): + from cerbero_mcp.client_registry import ClientRegistry + + counter = {"calls": 0} + + async def fake_build(exchange: str, env: str): + counter["calls"] += 1 + await asyncio.sleep(0.05) + + class C: + async def aclose(self): ... + return C() + + reg = ClientRegistry(builder=fake_build) + results = await asyncio.gather( + *[reg.get("deribit", "testnet") for _ in range(10)] + ) + assert counter["calls"] == 1 + assert all(r is results[0] for r in results) + + +@pytest.mark.asyncio +async def test_registry_aclose_calls_all(): + from cerbero_mcp.client_registry import ClientRegistry + + closed: list[tuple[str, str]] = [] + + async def fake_build(exchange: str, env: str): + class C: + async def aclose(self): + closed.append((exchange, env)) + return C() + + reg = ClientRegistry(builder=fake_build) + await reg.get("deribit", "testnet") + await reg.get("bybit", "mainnet") + await reg.aclose() + assert sorted(closed) == sorted([("deribit", "testnet"), ("bybit", "mainnet")])