feat(V2): ClientRegistry lazy con lock per chiave
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,41 @@
|
|||||||
|
"""Cache lazy di client exchange, una istanza per (exchange, env)."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
Environment = Literal["testnet", "mainnet"]
|
||||||
|
Builder = Callable[[str, Environment], Awaitable[Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class ClientRegistry:
|
||||||
|
def __init__(self, *, builder: Builder) -> None:
|
||||||
|
self._builder = builder
|
||||||
|
self._clients: dict[tuple[str, Environment], Any] = {}
|
||||||
|
self._locks: dict[tuple[str, Environment], asyncio.Lock] = defaultdict(
|
||||||
|
asyncio.Lock
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get(self, exchange: str, env: Environment) -> Any:
|
||||||
|
key = (exchange, env)
|
||||||
|
if key in self._clients:
|
||||||
|
return self._clients[key]
|
||||||
|
async with self._locks[key]:
|
||||||
|
if key in self._clients: # double-check
|
||||||
|
return self._clients[key]
|
||||||
|
client = await self._builder(exchange, env)
|
||||||
|
self._clients[key] = client
|
||||||
|
return client
|
||||||
|
|
||||||
|
async def aclose(self) -> None:
|
||||||
|
for client in self._clients.values():
|
||||||
|
close = getattr(client, "aclose", None)
|
||||||
|
if close is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._clients.clear()
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_registry_lazy_build():
|
||||||
|
from cerbero_mcp.client_registry import ClientRegistry
|
||||||
|
|
||||||
|
builds: list[tuple[str, str]] = []
|
||||||
|
|
||||||
|
async def fake_build(exchange: str, env: str):
|
||||||
|
builds.append((exchange, env))
|
||||||
|
|
||||||
|
class C:
|
||||||
|
async def aclose(self):
|
||||||
|
pass
|
||||||
|
return C()
|
||||||
|
|
||||||
|
reg = ClientRegistry(builder=fake_build)
|
||||||
|
c = await reg.get("deribit", "testnet")
|
||||||
|
assert builds == [("deribit", "testnet")]
|
||||||
|
assert (await reg.get("deribit", "testnet")) is c # cached
|
||||||
|
assert builds == [("deribit", "testnet")]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_registry_different_keys_different_clients():
|
||||||
|
from cerbero_mcp.client_registry import ClientRegistry
|
||||||
|
|
||||||
|
async def fake_build(exchange: str, env: str):
|
||||||
|
class C:
|
||||||
|
tag = (exchange, env)
|
||||||
|
async def aclose(self): ...
|
||||||
|
return C()
|
||||||
|
|
||||||
|
reg = ClientRegistry(builder=fake_build)
|
||||||
|
a = await reg.get("deribit", "testnet")
|
||||||
|
b = await reg.get("deribit", "mainnet")
|
||||||
|
c = await reg.get("bybit", "testnet")
|
||||||
|
assert a is not b
|
||||||
|
assert a.tag == ("deribit", "testnet")
|
||||||
|
assert b.tag == ("deribit", "mainnet")
|
||||||
|
assert c.tag == ("bybit", "testnet")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_registry_concurrent_get_one_build():
|
||||||
|
from cerbero_mcp.client_registry import ClientRegistry
|
||||||
|
|
||||||
|
counter = {"calls": 0}
|
||||||
|
|
||||||
|
async def fake_build(exchange: str, env: str):
|
||||||
|
counter["calls"] += 1
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
class C:
|
||||||
|
async def aclose(self): ...
|
||||||
|
return C()
|
||||||
|
|
||||||
|
reg = ClientRegistry(builder=fake_build)
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*[reg.get("deribit", "testnet") for _ in range(10)]
|
||||||
|
)
|
||||||
|
assert counter["calls"] == 1
|
||||||
|
assert all(r is results[0] for r in results)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_registry_aclose_calls_all():
|
||||||
|
from cerbero_mcp.client_registry import ClientRegistry
|
||||||
|
|
||||||
|
closed: list[tuple[str, str]] = []
|
||||||
|
|
||||||
|
async def fake_build(exchange: str, env: str):
|
||||||
|
class C:
|
||||||
|
async def aclose(self):
|
||||||
|
closed.append((exchange, env))
|
||||||
|
return C()
|
||||||
|
|
||||||
|
reg = ClientRegistry(builder=fake_build)
|
||||||
|
await reg.get("deribit", "testnet")
|
||||||
|
await reg.get("bybit", "mainnet")
|
||||||
|
await reg.aclose()
|
||||||
|
assert sorted(closed) == sorted([("deribit", "testnet"), ("bybit", "mainnet")])
|
||||||
Reference in New Issue
Block a user