99 lines
3.1 KiB
Python
99 lines
3.1 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass, field
|
|
from functools import wraps
|
|
|
|
from fastapi import HTTPException, Request, status
|
|
|
|
|
|
@dataclass
class Principal:
    """An authenticated caller together with the capabilities it holds.

    Attributes are documented below; ``acl_requires`` grants access when this
    principal's ``capabilities`` intersects the capabilities an endpoint asks for.
    """

    # Human-readable label for the caller (the token loader uses "core"/"observer").
    name: str
    # Capability names held by this caller; every instance gets its own fresh set.
    capabilities: set[str] = field(default_factory=set)
|
|
|
|
|
|
@dataclass
class TokenStore:
    """In-memory lookup table from bearer-token strings to their Principal."""

    # Maps a raw token string to the Principal it authenticates.
    tokens: dict[str, Principal]

    def get(self, token: str) -> Principal | None:
        """Return the Principal registered for ``token``, or ``None`` if unknown."""
        if token in self.tokens:
            return self.tokens[token]
        return None
|
|
|
|
|
|
def require_principal(request: Request) -> Principal:
    """Authenticate the request's bearer token and return the matching Principal.

    The token is read from the ``Authorization`` header and looked up in the
    ``TokenStore`` stored at ``request.app.state.token_store``.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the header
            is missing, uses a scheme other than Bearer, or carries an empty
            token; 403 when the token is not known to the store.
    """
    auth = request.headers.get("Authorization", "")
    scheme, _, credentials = auth.partition(" ")
    token = credentials.strip()
    # RFC 7235: the auth-scheme is case-insensitive, so "bearer" is accepted too.
    # An empty token is rejected here so "" is never looked up in the store.
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED,
            "missing bearer token",
            # RFC 7235 requires a challenge header on every 401 response.
            headers={"WWW-Authenticate": "Bearer"},
        )
    store: TokenStore = request.app.state.token_store
    principal = store.get(token)
    if principal is None:
        # NOTE(review): RFC 6750 section 3.1 suggests 401 "invalid_token" here;
        # 403 is kept because existing clients may rely on it.
        raise HTTPException(status.HTTP_403_FORBIDDEN, "invalid token")
    return principal
|
|
|
|
|
|
def acl_requires(*, core: bool = False, observer: bool = False) -> Callable:
    """Decorator factory: require at least one matching capability.

    The decorated endpoint (sync or async) must receive a ``Principal``. It is
    found in this order:

    1. the ``principal`` keyword argument;
    2. otherwise, the first ``Principal`` instance among the positional
       arguments;
    3. otherwise, the first ``Principal`` among the other keyword arguments.

    The call is allowed only if the principal holds at least one of the
    selected capabilities.

    Args:
        core: Accept principals holding the "core" capability.
        observer: Accept principals holding the "observer" capability.

    Raises:
        ValueError: Immediately, if neither ``core`` nor ``observer`` is set.
            An empty capability set would reject every call.
        HTTPException: 403 at call time, if no principal is found or it lacks
            every allowed capability.
    """
    allowed: set[str] = {
        name for name, enabled in (("core", core), ("observer", observer)) if enabled
    }
    if not allowed:
        raise ValueError("acl_requires() needs core=True and/or observer=True")

    def _find_principal(args: tuple, kwargs: dict):
        """Return the principal passed to the endpoint, or None if absent."""
        principal = kwargs.get("principal")
        if principal is not None:
            return principal
        # FastAPI passes dependencies as keyword arguments named after the
        # parameter, so a Principal may arrive under any keyword, not only
        # "principal".
        for value in (*args, *kwargs.values()):
            if isinstance(value, Principal):
                return value
        return None

    def _authorize(args: tuple, kwargs: dict) -> None:
        """Raise 403 unless the call's principal holds an allowed capability."""
        principal = _find_principal(args, kwargs)
        if principal is None or not (principal.capabilities & allowed):
            raise HTTPException(
                status.HTTP_403_FORBIDDEN,
                f"capability required: {allowed}",
            )

    def decorator(func: Callable) -> Callable:
        # Pick the wrapper kind once, at decoration time, so async endpoints
        # stay coroutine functions and sync ones stay plain functions.
        if _is_coro(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                _authorize(args, kwargs)
                return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            _authorize(args, kwargs)
            return func(*args, **kwargs)

        return sync_wrapper

    return decorator
|
|
|
|
|
|
def _is_coro(func: Callable) -> bool:
|
|
import asyncio
|
|
return asyncio.iscoroutinefunction(func)
|
|
|
|
|
|
def _read_token_file(path: str) -> str:
    """Return the whitespace-stripped content of ``path``, rejecting empty tokens."""
    with open(path, encoding="utf-8") as f:
        token = f.read().strip()
    if not token:
        # An empty token would match any request whose bearer token is empty.
        raise ValueError(f"token file {path!r} is empty")
    return token


def load_token_store_from_files(
    core_token_file: str | None,
    observer_token_file: str | None,
) -> TokenStore:
    """Build a TokenStore from up to two single-token files.

    Each file's entire content, stripped of surrounding whitespace, is one
    bearer token. The core token maps to a principal named "core" with the
    capability {"core"}. The observer token maps to "observer" / {"observer"}.

    Args:
        core_token_file: Path to the core token file. If None or empty, no core
            token is registered.
        observer_token_file: Path to the observer token file. If None or empty,
            no observer token is registered.

    Returns:
        A TokenStore holding zero, one, or two tokens.

    Raises:
        OSError: If a given file cannot be opened or read.
        ValueError: If a file contains only whitespace, or if both files hold
            the same token. In the second case the observer entry would
            otherwise silently replace the core one.
    """
    tokens: dict[str, Principal] = {}
    for path, role in ((core_token_file, "core"), (observer_token_file, "observer")):
        if not path:
            continue
        token = _read_token_file(path)
        if token in tokens:
            raise ValueError(
                f"token in {path!r} is already registered for {tokens[token].name!r}"
            )
        tokens[token] = Principal(name=role, capabilities={role})
    return TokenStore(tokens=tokens)
|