diff --git a/pyproject.toml b/pyproject.toml index 5f80ab9..2124d19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,5 +37,7 @@ dev = [ "pytest>=8.0", "pytest-asyncio>=0.24", "pytest-httpx>=0.30", + "respx>=0.21", + "pytest-cov>=5.0", "ruff>=0.7", ] diff --git a/services/mcp-docugen/src/mcp_docugen/generation_store.py b/services/mcp-docugen/src/mcp_docugen/generation_store.py new file mode 100644 index 0000000..84fd1fc --- /dev/null +++ b/services/mcp-docugen/src/mcp_docugen/generation_store.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path + +import aiosqlite + +_SCHEMA = """ +CREATE TABLE IF NOT EXISTS generations ( + id TEXT PRIMARY KEY, + timestamp INTEGER NOT NULL, + template_name TEXT NOT NULL, + model TEXT NOT NULL, + tokens_in INTEGER NOT NULL, + tokens_out INTEGER NOT NULL, + cost_usd REAL NOT NULL, + success INTEGER NOT NULL, + error_msg TEXT +); +CREATE INDEX IF NOT EXISTS idx_generations_timestamp ON generations(timestamp); + +CREATE TABLE IF NOT EXISTS ephemeral_assets ( + generation_id TEXT NOT NULL, + var_name TEXT NOT NULL, + file_path TEXT NOT NULL, + mime TEXT NOT NULL, + created_at INTEGER NOT NULL, + expires_at INTEGER NOT NULL, + PRIMARY KEY (generation_id, var_name) +); +CREATE INDEX IF NOT EXISTS idx_assets_expires ON ephemeral_assets(expires_at); +""" + + +@dataclass +class GenerationRecord: + id: str + template_name: str + model: str + tokens_in: int + tokens_out: int + cost_usd: float + success: bool + error_msg: str | None + + +@dataclass +class EphemeralAssetRecord: + generation_id: str + var_name: str + file_path: str + mime: str + ttl_days: int + + +@dataclass +class EphemeralAssetInfo: + generation_id: str + var_name: str + file_path: str + mime: str + created_at: datetime + expires_at: datetime + is_expired: bool + + +def _now_ms() -> int: + return int(datetime.now(tz=timezone.utc).timestamp() * 1000) + + +class GenerationStore: + def __init__(self, db_path: Path, generated_dir: Path) -> None: + self.db_path = Path(db_path) + self.generated_dir = Path(generated_dir) + self.generated_dir.mkdir(parents=True, exist_ok=True) + + async def init(self) -> None: + async with aiosqlite.connect(self.db_path) as db: + await db.executescript(_SCHEMA) + await db.commit() + + async def record_generation(self, record: GenerationRecord) -> None: + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + """ + INSERT INTO generations + (id, timestamp, template_name, model, tokens_in, tokens_out, + cost_usd, success, error_msg) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + record.id, + _now_ms(), + record.template_name, + record.model, + record.tokens_in, + record.tokens_out, + record.cost_usd, + 1 if record.success else 0, + record.error_msg, + ), + ) + await db.commit() + + async def register_ephemeral_asset( + self, record: EphemeralAssetRecord + ) -> None: + now = _now_ms() + expires = now + int(record.ttl_days * 24 * 3600 * 1000) + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + """ + INSERT OR REPLACE INTO ephemeral_assets + (generation_id, var_name, file_path, mime, created_at, expires_at) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + record.generation_id, + record.var_name, + record.file_path, + record.mime, + now, + expires, + ), + ) + await db.commit() + + async def get_ephemeral_asset( + self, generation_id: str, filename: str + ) -> EphemeralAssetInfo | None: + async with aiosqlite.connect(self.db_path) as db: + cursor = await db.execute( + """ + SELECT generation_id, var_name, file_path, mime, created_at, expires_at + FROM ephemeral_assets + WHERE generation_id = ? + """, + (generation_id,), + ) + rows = await cursor.fetchall() + for row in rows: + path = Path(row[2]) + if path.name == filename: + created_at = datetime.fromtimestamp(row[4] / 1000, tz=timezone.utc) + expires_at = datetime.fromtimestamp(row[5] / 1000, tz=timezone.utc) + return EphemeralAssetInfo( + generation_id=row[0], + var_name=row[1], + file_path=row[2], + mime=row[3], + created_at=created_at, + expires_at=expires_at, + is_expired=expires_at < datetime.now(tz=timezone.utc), + ) + return None + + async def cleanup_expired(self) -> int: + now = _now_ms() + async with aiosqlite.connect(self.db_path) as db: + cursor = await db.execute( + "SELECT generation_id, file_path FROM ephemeral_assets WHERE expires_at < ?", + (now,), + ) + rows = await cursor.fetchall() + count = 0 + for _, fpath in rows: + try: + Path(fpath).unlink(missing_ok=True) + except OSError: + pass + count += 1 + for gen_id, _ in rows: + gdir = self.generated_dir / gen_id + if gdir.exists() and not any(gdir.iterdir()): + gdir.rmdir() + await db.execute( + "DELETE FROM ephemeral_assets WHERE expires_at < ?", + (now,), + ) + await db.commit() + return count + + async def get_stats(self) -> dict: + async with aiosqlite.connect(self.db_path) as db: + cur = await db.execute( + "SELECT COUNT(*), SUM(success), SUM(cost_usd) FROM generations" + ) + total, success, cost = await cur.fetchone() + return { + "total": total or 0, + "success": success or 0, + "failed": (total or 0) - (success or 0), + "total_cost_usd": float(cost or 0), + } diff --git a/services/mcp-docugen/src/mcp_docugen/llm_client.py b/services/mcp-docugen/src/mcp_docugen/llm_client.py new file mode 100644 index 0000000..01b0ff3 --- /dev/null +++ b/services/mcp-docugen/src/mcp_docugen/llm_client.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from time import perf_counter + +import httpx + + +class LLMError(Exception): + pass + + +class LLMTimeout(LLMError): + pass + + +class LLMUpstreamError(LLMError): + pass + + +class LLMAuthError(LLMError): + pass + + +class LLMRateLimit(LLMError): + pass + + +class LLMInvalidResponse(LLMError): + pass + + +class LLMEmptyResponse(LLMError): + pass + + +@dataclass +class LLMResponse: + text: str + model: str + tokens_in: int + tokens_out: int + cost_usd: float + latency_ms: int + + +class OpenRouterClient: + def __init__( + self, + api_key: str, + base_url: str, + timeout: float = 60, + max_retries: int = 3, + retry_base_delay: float = 1.0, + ) -> None: + self.api_key = api_key + self.base_url = base_url.rstrip("/") + self.timeout = timeout + self.max_retries = max_retries + self.retry_base_delay = retry_base_delay + + async def chat(self, model: str, system: str, user: str) -> LLMResponse: + payload = { + "model": model, + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + } + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + start = perf_counter() + last_transient_error: Exception | None = None + + for attempt in range(self.max_retries): + try: + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.post( + f"{self.base_url}/chat/completions", + headers=headers, + json=payload, + ) + except (httpx.ReadTimeout, httpx.ConnectTimeout) as exc: + last_transient_error = exc + if attempt == self.max_retries - 1: + raise LLMTimeout(str(exc)) from exc + await self._sleep_backoff(attempt, rate_limit=False) + continue + + status = response.status_code + if status == 200: + return self._parse_success(response, start) + if status in (401, 403): + raise LLMAuthError(f"status {status}: {response.text[:200]}") + if status == 429: + if attempt == self.max_retries - 1: + raise LLMRateLimit(response.text[:200]) + await self._sleep_backoff(attempt, rate_limit=True) + continue + if 500 <= status < 600: + if attempt == self.max_retries - 1: + raise LLMUpstreamError( + f"status {status}: {response.text[:200]}" + ) + await self._sleep_backoff(attempt, rate_limit=False) + continue + raise LLMUpstreamError( + f"unexpected status {status}: {response.text[:200]}" + ) + + raise LLMUpstreamError(f"retries exhausted: {last_transient_error}") + + async def _sleep_backoff(self, attempt: int, rate_limit: bool) -> None: + multiplier = 5 if rate_limit else 1 + delay = self.retry_base_delay * multiplier * (2**attempt) + if delay > 0: + await asyncio.sleep(delay) + + def _parse_success( + self, response: httpx.Response, start: float + ) -> LLMResponse: + try: + data = response.json() + choice = data["choices"][0] + text = choice["message"]["content"] + model = data.get("model", "unknown") + usage = data.get("usage", {}) + except (KeyError, IndexError, ValueError) as exc: + raise LLMInvalidResponse(str(exc)) from exc + + if not text: + raise LLMEmptyResponse("LLM returned empty content") + + latency_ms = int((perf_counter() - start) * 1000) + return LLMResponse( + text=text, + model=model, + tokens_in=int(usage.get("prompt_tokens", 0)), + tokens_out=int(usage.get("completion_tokens", 0)), + cost_usd=float(usage.get("total_cost", 0.0)), + latency_ms=latency_ms, + ) diff --git a/services/mcp-docugen/src/mcp_docugen/template_store.py b/services/mcp-docugen/src/mcp_docugen/template_store.py new file mode 100644 index 0000000..80605cd --- /dev/null +++ b/services/mcp-docugen/src/mcp_docugen/template_store.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +import base64 +import re +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path + +import aiofiles +import yaml +from pydantic import ValidationError + +from mcp_docugen.models import TemplateAsset, TemplateFrontmatter, TemplateSummary + +_SLUG_RE = re.compile(r"^[a-z0-9][a-z0-9-]*$") +_ASSET_FILENAME_RE = re.compile(r"^[A-Za-z0-9._-]+$") +_FRONTMATTER_DELIM = "---" + + +class TemplateNotFound(Exception): + pass + + +class TemplateAlreadyExists(Exception): + pass + + +class InvalidFrontmatter(Exception): + pass + + +class InvalidTemplateName(ValueError): + pass + + +@dataclass +class LoadedTemplate: + frontmatter: TemplateFrontmatter + body: str + assets: list[TemplateAsset] + updated_at: datetime + + +class TemplateStore: + def __init__(self, base_dir: Path) -> None: + self.base_dir = Path(base_dir) + self.base_dir.mkdir(parents=True, exist_ok=True) + + async def create( + self, + name: str, + frontmatter: TemplateFrontmatter, + body: str, + assets: list[dict] | None = None, + ) -> None: + self._validate_name(name) + tdir = self.base_dir / name + if tdir.exists(): + raise TemplateAlreadyExists(name) + tdir.mkdir(parents=True) + (tdir / "assets").mkdir() + if assets: + self._validate_assets(assets) + await self._write_template_file(tdir, frontmatter, body) + if assets: + await self._write_assets(tdir, assets) + + async def update( + self, + name: str, + frontmatter: TemplateFrontmatter, + body: str, + assets: list[dict] | None = None, + ) -> None: + self._validate_name(name) + tdir = self.base_dir / name + if not tdir.exists(): + raise TemplateNotFound(name) + await self._write_template_file(tdir, frontmatter, body) + if assets is not None: + assets_dir = tdir / "assets" + for f in assets_dir.iterdir(): + f.unlink() + await self._write_assets(tdir, assets) + + async def delete(self, name: str) -> None: + self._validate_name(name) + tdir = self.base_dir / name + if not tdir.exists(): + raise TemplateNotFound(name) + for f in (tdir / "assets").iterdir(): + f.unlink() + (tdir / "assets").rmdir() + (tdir / "template.md").unlink() + tdir.rmdir() + + async def get(self, name: str) -> LoadedTemplate: + self._validate_name(name) + tdir = self.base_dir / name + tpath = tdir / "template.md" + if not tpath.exists(): + raise TemplateNotFound(name) + frontmatter, body = await self._read_template_file(tpath) + assets = self._list_assets(tdir) + stat = tpath.stat() + updated_at = datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc) + return LoadedTemplate( + frontmatter=frontmatter, body=body, assets=assets, updated_at=updated_at + ) + + async def list(self) -> list[TemplateSummary]: + summaries = [] + for tdir in sorted(self.base_dir.iterdir()): + if not tdir.is_dir(): + continue + tpath = tdir / "template.md" + if not tpath.exists(): + continue + try: + frontmatter, _ = await self._read_template_file(tpath) + except InvalidFrontmatter: + continue + stat = tpath.stat() + summaries.append( + TemplateSummary( + name=frontmatter.name, + description=frontmatter.description, + updated_at=datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc), + ) + ) + return summaries + + async def read_asset(self, template_name: str, filename: str) -> bytes: + self._validate_name(template_name) + self._validate_asset_filename(filename) + path = self.base_dir / template_name / "assets" / filename + if not path.exists(): + raise FileNotFoundError(filename) + async with aiofiles.open(path, "rb") as f: + return await f.read() + + def asset_path(self, template_name: str, filename: str) -> Path: + self._validate_name(template_name) + self._validate_asset_filename(filename) + return self.base_dir / template_name / "assets" / filename + + def _validate_name(self, name: str) -> None: + if not _SLUG_RE.match(name): + raise InvalidTemplateName(name) + + def _validate_asset_filename(self, filename: str) -> None: + if not _ASSET_FILENAME_RE.match(filename): + raise ValueError(f"invalid asset filename: {filename!r}") + + def _validate_assets(self, assets: list[dict]) -> None: + for asset in assets: + self._validate_asset_filename(asset["filename"]) + + async def _write_template_file( + self, tdir: Path, frontmatter: TemplateFrontmatter, body: str + ) -> None: + yaml_text = yaml.safe_dump( + frontmatter.model_dump(exclude_none=True), sort_keys=False + ) + content = f"{_FRONTMATTER_DELIM}\n{yaml_text}{_FRONTMATTER_DELIM}\n{body}" + async with aiofiles.open(tdir / "template.md", "w") as f: + await f.write(content) + + async def _read_template_file( + self, path: Path + ) -> tuple[TemplateFrontmatter, str]: + async with aiofiles.open(path) as f: + raw = await f.read() + if not raw.startswith(_FRONTMATTER_DELIM): + raise InvalidFrontmatter("missing opening '---'") + parts = raw.split(_FRONTMATTER_DELIM, 2) + if len(parts) < 3: + raise InvalidFrontmatter("missing closing '---'") + yaml_text = parts[1] + body = parts[2].lstrip("\n") + try: + data = yaml.safe_load(yaml_text) or {} + frontmatter = TemplateFrontmatter(**data) + except (yaml.YAMLError, ValidationError) as exc: + raise InvalidFrontmatter(str(exc)) from exc + return frontmatter, body + + async def _write_assets(self, tdir: Path, assets: list[dict]) -> None: + for asset in assets: + filename = asset["filename"] + self._validate_asset_filename(filename) + data = base64.b64decode(asset["data_b64"]) + async with aiofiles.open(tdir / "assets" / filename, "wb") as f: + await f.write(data) + + def _list_assets(self, tdir: Path) -> list[TemplateAsset]: + assets_dir = tdir / "assets" + if not assets_dir.exists(): + return [] + out = [] + for f in sorted(assets_dir.iterdir()): + if not f.is_file(): + continue + out.append( + TemplateAsset( + filename=f.name, + mime=_guess_mime(f.name), + size_bytes=f.stat().st_size, + ) + ) + return out + + +def _guess_mime(filename: str) -> str: + ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else "" + return { + "png": "image/png", + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "webp": "image/webp", + }.get(ext, "application/octet-stream") diff --git a/services/mcp-docugen/tests/unit/test_generation_store.py b/services/mcp-docugen/tests/unit/test_generation_store.py new file mode 100644 index 0000000..2c35e17 --- /dev/null +++ b/services/mcp-docugen/tests/unit/test_generation_store.py @@ -0,0 +1,115 @@ +import pytest + +from mcp_docugen.generation_store import ( + EphemeralAssetRecord, + GenerationRecord, + GenerationStore, +) + + +@pytest.fixture +async def store(tmp_path): + s = GenerationStore( + db_path=tmp_path / "gen.db", generated_dir=tmp_path / "generated" + ) + await s.init() + return s + + +async def test_record_success_generation(store): + await store.record_generation( + GenerationRecord( + id="g-1", + template_name="fattura", + model="m", + tokens_in=100, + tokens_out=200, + cost_usd=0.01, + success=True, + error_msg=None, + ) + ) + stats = await store.get_stats() + assert stats["total"] == 1 + assert stats["success"] == 1 + assert stats["failed"] == 0 + + +async def test_record_failed_generation(store): + await store.record_generation( + GenerationRecord( + id="g-2", + template_name="fattura", + model="m", + tokens_in=0, + tokens_out=0, + cost_usd=0.0, + success=False, + error_msg="LLMTimeout", + ) + ) + stats = await store.get_stats() + assert stats["failed"] == 1 + + +async def test_register_ephemeral_asset(store, tmp_path): + asset_file = tmp_path / "generated" / "g-1" / "foto.png" + asset_file.parent.mkdir(parents=True) + asset_file.write_bytes(b"png-bytes") + + await store.register_ephemeral_asset( + EphemeralAssetRecord( + generation_id="g-1", + var_name="foto", + file_path=str(asset_file), + mime="image/png", + ttl_days=30, + ) + ) + asset = await store.get_ephemeral_asset("g-1", "foto.png") + assert asset is not None + assert asset.mime == "image/png" + + +async def test_get_ephemeral_asset_returns_none_if_missing(store): + asset = await store.get_ephemeral_asset("nope", "foo.png") + assert asset is None + + +async def test_cleanup_expired_removes_records_and_files(store, tmp_path): + asset_file = tmp_path / "generated" / "g-old" / "foto.png" + asset_file.parent.mkdir(parents=True) + asset_file.write_bytes(b"bytes") + + await store.register_ephemeral_asset( + EphemeralAssetRecord( + generation_id="g-old", + var_name="foto", + file_path=str(asset_file), + mime="image/png", + ttl_days=-1, + ) + ) + removed = await store.cleanup_expired() + assert removed == 1 + assert not asset_file.exists() + assert await store.get_ephemeral_asset("g-old", "foto.png") is None + + +async def test_ephemeral_asset_expired_flag(store, tmp_path): + f = tmp_path / "generated" / "g-e" / "img.png" + f.parent.mkdir(parents=True) + f.write_bytes(b"x") + + await store.register_ephemeral_asset( + EphemeralAssetRecord( + generation_id="g-e", + var_name="img", + file_path=str(f), + mime="image/png", + ttl_days=-1, + ) + ) + asset = await store.get_ephemeral_asset("g-e", "img.png") + assert asset is not None + assert asset.is_expired is True diff --git a/services/mcp-docugen/tests/unit/test_llm_client.py b/services/mcp-docugen/tests/unit/test_llm_client.py new file mode 100644 index 0000000..620ed77 --- /dev/null +++ b/services/mcp-docugen/tests/unit/test_llm_client.py @@ -0,0 +1,173 @@ +import httpx +import pytest +import respx + +from mcp_docugen.llm_client import ( + LLMAuthError, + LLMEmptyResponse, + LLMInvalidResponse, + LLMRateLimit, + LLMTimeout, + LLMUpstreamError, + OpenRouterClient, +) + + +def _success_body(text: str = "output text") -> dict: + return { + "id": "gen-1", + "choices": [{"message": {"role": "assistant", "content": text}}], + "model": "anthropic/claude-sonnet-4", + "usage": { + "prompt_tokens": 100, + "completion_tokens": 200, + "total_cost": 0.01, + }, + } + + +@respx.mock +async def test_chat_success(): + respx.post("https://openrouter.ai/api/v1/chat/completions").mock( + return_value=httpx.Response(200, json=_success_body("hello")) + ) + client = OpenRouterClient( + api_key="sk", base_url="https://openrouter.ai/api/v1", timeout=5 + ) + resp = await client.chat( + model="anthropic/claude-sonnet-4", system="sys", user="user" + ) + assert resp.text == "hello" + assert resp.tokens_in == 100 + assert resp.tokens_out == 200 + assert resp.cost_usd == 0.01 + assert resp.model == "anthropic/claude-sonnet-4" + + +@respx.mock +async def test_chat_retries_on_5xx(): + route = respx.post("https://openrouter.ai/api/v1/chat/completions").mock( + side_effect=[ + httpx.Response(503), + httpx.Response(502), + httpx.Response(200, json=_success_body()), + ] + ) + client = OpenRouterClient( + api_key="sk", + base_url="https://openrouter.ai/api/v1", + timeout=5, + retry_base_delay=0, + ) + resp = await client.chat(model="m", system="s", user="u") + assert resp.text == "output text" + assert route.call_count == 3 + + +@respx.mock +async def test_chat_exhausts_retries_5xx(): + respx.post("https://openrouter.ai/api/v1/chat/completions").mock( + return_value=httpx.Response(500) + ) + client = OpenRouterClient( + api_key="sk", + base_url="https://openrouter.ai/api/v1", + timeout=5, + retry_base_delay=0, + ) + with pytest.raises(LLMUpstreamError): + await client.chat(model="m", system="s", user="u") + + +@respx.mock +async def test_chat_retries_on_429(): + route = respx.post("https://openrouter.ai/api/v1/chat/completions").mock( + side_effect=[ + httpx.Response(429), + httpx.Response(200, json=_success_body()), + ] + ) + client = OpenRouterClient( + api_key="sk", + base_url="https://openrouter.ai/api/v1", + timeout=5, + retry_base_delay=0, + ) + resp = await client.chat(model="m", system="s", user="u") + assert route.call_count == 2 + assert resp.text == "output text" + + +@respx.mock +async def test_chat_exhausts_retries_429(): + respx.post("https://openrouter.ai/api/v1/chat/completions").mock( + return_value=httpx.Response(429) + ) + client = OpenRouterClient( + api_key="sk", + base_url="https://openrouter.ai/api/v1", + timeout=5, + retry_base_delay=0, + ) + with pytest.raises(LLMRateLimit): + await client.chat(model="m", system="s", user="u") + + +@respx.mock +async def test_chat_no_retry_on_401(): + respx.post("https://openrouter.ai/api/v1/chat/completions").mock( + return_value=httpx.Response(401) + ) + client = OpenRouterClient( + api_key="sk", + base_url="https://openrouter.ai/api/v1", + timeout=5, + retry_base_delay=0, + ) + with pytest.raises(LLMAuthError): + await client.chat(model="m", system="s", user="u") + + +@respx.mock +async def test_chat_timeout(): + respx.post("https://openrouter.ai/api/v1/chat/completions").mock( + side_effect=httpx.ReadTimeout("timeout") + ) + client = OpenRouterClient( + api_key="sk", + base_url="https://openrouter.ai/api/v1", + timeout=1, + retry_base_delay=0, + ) + with pytest.raises(LLMTimeout): + await client.chat(model="m", system="s", user="u") + + +@respx.mock +async def test_chat_invalid_response_shape(): + respx.post("https://openrouter.ai/api/v1/chat/completions").mock( + return_value=httpx.Response(200, json={"no": "choices"}) + ) + client = OpenRouterClient( + api_key="sk", + base_url="https://openrouter.ai/api/v1", + timeout=5, + retry_base_delay=0, + ) + with pytest.raises(LLMInvalidResponse): + await client.chat(model="m", system="s", user="u") + + +@respx.mock +async def test_chat_empty_content(): + respx.post("https://openrouter.ai/api/v1/chat/completions").mock( + return_value=httpx.Response(200, json=_success_body(text="")) + ) + client = OpenRouterClient( + api_key="sk", + base_url="https://openrouter.ai/api/v1", + timeout=5, + retry_base_delay=0, + ) + with pytest.raises(LLMEmptyResponse): + await client.chat(model="m", system="s", user="u") diff --git a/services/mcp-docugen/tests/unit/test_template_store.py b/services/mcp-docugen/tests/unit/test_template_store.py new file mode 100644 index 0000000..3ac1fa0 --- /dev/null +++ b/services/mcp-docugen/tests/unit/test_template_store.py @@ -0,0 +1,109 @@ +import base64 + +import pytest + +from mcp_docugen.models import TemplateFrontmatter +from mcp_docugen.template_store import ( + InvalidFrontmatter, + TemplateAlreadyExists, + TemplateNotFound, + TemplateStore, +) + + +@pytest.fixture +def store(tmp_path): + return TemplateStore(base_dir=tmp_path) + + +def _fm(name="fattura"): + return TemplateFrontmatter( + name=name, + description="Fattura commerciale", + required_variables=[{"name": "cliente", "type": "string"}], + ) + + +async def test_create_and_get(store): + fm = _fm() + await store.create(name="fattura", frontmatter=fm, body="# Body {{cliente}}") + got = await store.get("fattura") + assert got.frontmatter.name == "fattura" + assert "# Body {{cliente}}" in got.body + + +async def test_create_duplicate_raises(store): + fm = _fm() + await store.create(name="fattura", frontmatter=fm, body="body") + with pytest.raises(TemplateAlreadyExists): + await store.create(name="fattura", frontmatter=fm, body="body") + + +async def test_get_missing_raises(store): + with pytest.raises(TemplateNotFound): + await store.get("nope") + + +async def test_list_returns_summaries(store): + await store.create(name="a", frontmatter=_fm("a"), body="b") + await store.create(name="b", frontmatter=_fm("b"), body="b") + result = await store.list() + names = sorted(t.name for t in result) + assert names == ["a", "b"] + + +async def test_update_overwrites(store): + await store.create(name="f", frontmatter=_fm("f"), body="old") + new_fm = TemplateFrontmatter(name="f", description="new description") + await store.update(name="f", frontmatter=new_fm, body="new body") + got = await store.get("f") + assert got.body == "new body" + assert got.frontmatter.description == "new description" + + +async def test_update_missing_raises(store): + with pytest.raises(TemplateNotFound): + await store.update(name="nope", frontmatter=_fm("nope"), body="x") + + +async def test_delete_removes(store): + await store.create(name="f", frontmatter=_fm("f"), body="b") + await store.delete("f") + with pytest.raises(TemplateNotFound): + await store.get("f") + + +async def test_delete_missing_raises(store): + with pytest.raises(TemplateNotFound): + await store.delete("nope") + + +async def test_assets_are_saved_and_listed(store): + png_bytes = b"\x89PNG\r\n\x1a\n" + assets = [ + { + "filename": "logo.png", + "data_b64": base64.b64encode(png_bytes).decode(), + "mime": "image/png", + } + ] + await store.create(name="f", frontmatter=_fm("f"), body="b", assets=assets) + got = await store.get("f") + asset_names = [a.filename for a in got.assets] + assert "logo.png" in asset_names + content = await store.read_asset("f", "logo.png") + assert content == png_bytes + + +async def test_asset_filename_rejects_path_traversal(store): + assets = [{"filename": "../evil.png", "data_b64": "aGk=", "mime": "image/png"}] + with pytest.raises(ValueError): + await store.create(name="f", frontmatter=_fm("f"), body="b", assets=assets) + + +async def test_frontmatter_parsing_rejects_malformed_yaml(store, tmp_path): + template_dir = tmp_path / "broken" + template_dir.mkdir() + (template_dir / "template.md").write_text("---\nname: :::broken\n---\nbody") + with pytest.raises(InvalidFrontmatter): + await store.get("broken") diff --git a/uv.lock b/uv.lock index ed679b3..0885ee3 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,9 @@ members = [ dev = [ { name = "pytest", specifier = ">=8.0" }, { name = "pytest-asyncio", specifier = ">=0.24" }, + { name = "pytest-cov", specifier = ">=5.0" }, { name = "pytest-httpx", specifier = ">=0.30" }, + { name = "respx", specifier = ">=0.21" }, { name = "ruff", specifier = ">=0.7" }, ]