refactor(common): rename package option_mcp_common → mcp_common
This commit is contained in:
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
|
||||
@dataclass
|
||||
class Principal:
|
||||
name: str
|
||||
capabilities: set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenStore:
|
||||
tokens: dict[str, Principal]
|
||||
|
||||
def get(self, token: str) -> Principal | None:
|
||||
return self.tokens.get(token)
|
||||
|
||||
|
||||
def require_principal(request: Request) -> Principal:
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if not auth.startswith("Bearer "):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "missing bearer token")
|
||||
token = auth[len("Bearer "):].strip()
|
||||
store: TokenStore = request.app.state.token_store
|
||||
principal = store.get(token)
|
||||
if principal is None:
|
||||
raise HTTPException(status.HTTP_403_FORBIDDEN, "invalid token")
|
||||
return principal
|
||||
|
||||
|
||||
def acl_requires(*, core: bool = False, observer: bool = False) -> Callable:
|
||||
"""Decorator: require at least one matching capability."""
|
||||
allowed: set[str] = set()
|
||||
if core:
|
||||
allowed.add("core")
|
||||
if observer:
|
||||
allowed.add("observer")
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
principal = kwargs.get("principal")
|
||||
if principal is None:
|
||||
for a in args:
|
||||
if isinstance(a, Principal):
|
||||
principal = a
|
||||
break
|
||||
if principal is None or not (principal.capabilities & allowed):
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
f"capability required: {allowed}",
|
||||
)
|
||||
return await func(*args, **kwargs) if _is_coro(func) else func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
principal = kwargs.get("principal")
|
||||
if principal is None:
|
||||
for a in args:
|
||||
if isinstance(a, Principal):
|
||||
principal = a
|
||||
break
|
||||
if principal is None or not (principal.capabilities & allowed):
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
f"capability required: {allowed}",
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if _is_coro(func) else sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _is_coro(func: Callable) -> bool:
|
||||
import asyncio
|
||||
return asyncio.iscoroutinefunction(func)
|
||||
|
||||
|
||||
def load_token_store_from_files(
|
||||
core_token_file: str | None,
|
||||
observer_token_file: str | None,
|
||||
) -> TokenStore:
|
||||
tokens: dict[str, Principal] = {}
|
||||
if core_token_file:
|
||||
with open(core_token_file) as f:
|
||||
tokens[f.read().strip()] = Principal(name="core", capabilities={"core"})
|
||||
if observer_token_file:
|
||||
with open(observer_token_file) as f:
|
||||
tokens[f.read().strip()] = Principal(
|
||||
name="observer", capabilities={"observer"}
|
||||
)
|
||||
return TokenStore(tokens=tokens)
|
||||
Reference in New Issue
Block a user