"""Shared test fixtures for server tests.

Uses SQLite async (aiosqlite) as in-memory test database.
Overrides FastAPI's get_db dependency to inject the test session.
"""

import sys
from pathlib import Path
from collections.abc import AsyncGenerator
from unittest.mock import MagicMock

import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.pool import StaticPool

# ---------------------------------------------------------------------------
# Mock heavy optional dependencies that require system libraries.
# WeasyPrint needs GTK/Pango which may not be available in test environments.
# We mock it before any server code is imported.
# ---------------------------------------------------------------------------
if "weasyprint" not in sys.modules:
    _mock_weasyprint = MagicMock()
    sys.modules["weasyprint"] = _mock_weasyprint

# Ensure the server package is importable
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

# The server imports below must stay after the weasyprint stub and the
# sys.path tweak above, hence the E402 suppressions.
from database import Base, get_db  # noqa: E402
from main import app  # noqa: E402
from middleware.rate_limit import RateLimitMiddleware  # noqa: E402

# Several of these models are not referenced in this file. The imports are
# presumably kept so every model's table is registered on ``Base.metadata``
# before ``setup_database`` calls ``create_all`` -- verify before removing.
from models.user import User  # noqa: E402
from models.recipe import Recipe, RecipeVersion  # noqa: E402, F401
from models.task import RecipeTask, RecipeSubtask  # noqa: E402, F401
from models.measurement import Measurement  # noqa: E402, F401
from models.access_log import AccessLog  # noqa: E402, F401
from models.setting import SystemSetting, RecipeVersionAudit  # noqa: E402, F401
from models.station import Station, StationRecipeAssignment  # noqa: E402, F401
from services.auth_service import hash_password, generate_api_key  # noqa: E402

# ---------------------------------------------------------------------------
# In-memory SQLite engine for tests
# ---------------------------------------------------------------------------

TEST_DATABASE_URL = "sqlite+aiosqlite://"

test_engine = create_async_engine(
    TEST_DATABASE_URL,
    echo=False,
    connect_args={"check_same_thread": False},
    # StaticPool keeps a single connection for in-memory SQLite so that
    # create_all, fixtures, and the app share the same database.
    poolclass=StaticPool,
)

# expire_on_commit=False: ORM objects keep their loaded attribute values after
# a commit instead of being expired, so tests can keep reading fixture objects.
TestSessionFactory = async_sessionmaker(
    test_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest_asyncio.fixture(autouse=True)
async def setup_database():
    """Create all tables before each test and drop them after."""
    async with test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    async with test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)


@pytest.fixture(autouse=True)
def reset_rate_limits():
    """Clear rate limit buckets between tests to avoid 429 in test suite.

    The RateLimitMiddleware is added to the module-level ``app`` object and
    persists across tests. Walk the ASGI middleware stack to find the instance
    and clear its per-IP sliding-window dictionaries.
    """
    # ``middleware_stack`` may be None before the app has served its first
    # request (Starlette builds it lazily). In that case the loop does nothing,
    # and there is nothing to reset yet.
    middleware = app.middleware_stack
    while middleware is not None:
        if isinstance(middleware, RateLimitMiddleware):
            middleware._login_requests.clear()
            middleware._general_requests.clear()
            break
        # Each layer wraps the next one in its ``app`` attribute.
        middleware = getattr(middleware, "app", None)
    yield


@pytest_asyncio.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a fresh async session for direct DB manipulation in tests.

    Pending work is committed when the test finishes. If that commit fails,
    the session is rolled back and the error is re-raised. The ``async with``
    block closes the session, so no explicit ``close()`` is needed.
    """
    async with TestSessionFactory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise


@pytest_asyncio.fixture
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """Yield an httpx AsyncClient wired to the FastAPI app with test DB.

    Every request handled by the app receives the same ``db_session`` the test
    uses. Rows created through fixtures are therefore visible to endpoints, and
    vice versa.
    """

    async def _override_get_db() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    app.dependency_overrides[get_db] = _override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
            yield ac
    finally:
        # ``app`` is module-level and shared by all tests. Always reset the
        # overrides, even if client teardown raises, so later tests don't
        # inherit this test's session.
        app.dependency_overrides.clear()


# ---------------------------------------------------------------------------
# User factory helpers
# ---------------------------------------------------------------------------


async def _create_user(
    session: AsyncSession,
    username: str,
    password: str = "TestPass123",
    display_name: str | None = None,
    roles: list[str] | None = None,
    is_admin: bool = False,
    active: bool = True,
) -> User:
    """Insert a user into the test DB and return it with an API key set.

    The password is stored via ``hash_password`` and a fresh key comes from
    ``generate_api_key``. ``display_name`` defaults to ``username.title()`` and
    ``roles`` to an empty list.

    The user is flushed and refreshed, but not committed, so its ``id`` is
    populated. The caller's session owns the transaction.
    """
    api_key = generate_api_key()
    user = User(
        username=username,
        password_hash=hash_password(password),
        display_name=display_name or username.title(),
        roles=roles or [],
        is_admin=is_admin,
        active=active,
        api_key=api_key,
        language_pref="en",
        theme_pref="light",
    )
    session.add(user)
    await session.flush()
    await session.refresh(user)
    return user


@pytest_asyncio.fixture
async def admin_user(db_session: AsyncSession) -> User:
    """An active admin user with API key."""
    return await _create_user(
        db_session,
        username="admin",
        display_name="Admin User",
        roles=["Maker", "MeasurementTec", "Metrologist"],
        is_admin=True,
    )


@pytest_asyncio.fixture
async def maker_user(db_session: AsyncSession) -> User:
    """An active Maker user with API key."""
    return await _create_user(
        db_session,
        username="maker",
        display_name="Maker User",
        roles=["Maker"],
    )


@pytest_asyncio.fixture
async def measurement_tec_user(db_session: AsyncSession) -> User:
    """An active MeasurementTec user with API key."""
    return await _create_user(
        db_session,
        username="measurement_tec",
        display_name="MeasurementTec User",
        roles=["MeasurementTec"],
    )


@pytest_asyncio.fixture
async def metrologist_user(db_session: AsyncSession) -> User:
    """An active Metrologist user with API key."""
    return await _create_user(
        db_session,
        username="metrologist",
        display_name="Metrologist User",
        roles=["Metrologist"],
    )


# ---------------------------------------------------------------------------
# Recipe helper
# ---------------------------------------------------------------------------


async def create_test_recipe(
    session: AsyncSession,
    user_id: int,
    code: str = "REC-001",
    name: str = "Test Recipe",
) -> Recipe:
    """Persist a minimal recipe tree and return its ``Recipe`` row.

    The tree is:

    - one ``RecipeVersion``: version 1, marked current;
    - holding one ``RecipeTask``;
    - holding one ``RecipeSubtask``: a "diameter" measurement with nominal
      10.0 mm and limits utl=10.5, uwl=10.3, lwl=9.7, ltl=9.5.

    Each object is flushed right after it is added, so its generated ``id``
    can serve as the foreign key of the next level down. Nothing is committed
    here.
    """

    async def _stage(obj):
        # Add one ORM object to the session and flush so its id is assigned.
        session.add(obj)
        await session.flush()
        return obj

    recipe = await _stage(
        Recipe(
            code=code,
            name=name,
            description="A recipe for testing",
            created_by=user_id,
        )
    )
    version = await _stage(
        RecipeVersion(
            recipe_id=recipe.id,
            version_number=1,
            is_current=True,
            created_by=user_id,
            change_notes="Initial version",
        )
    )
    task = await _stage(
        RecipeTask(
            version_id=version.id,
            order_index=0,
            title="Test Task",
            directive="Measure the part",
            description="First measurement task",
        )
    )
    await _stage(
        RecipeSubtask(
            task_id=task.id,
            marker_number=1,
            description="Diameter measurement",
            measurement_type="diameter",
            nominal=10.0,
            utl=10.5,
            uwl=10.3,
            lwl=9.7,
            ltl=9.5,
            unit="mm",
        )
    )

    # Load ``recipe.versions`` explicitly. Presumably this lets callers read it
    # without an implicit lazy load, which async sessions cannot perform.
    await session.refresh(recipe, attribute_names=["versions"])
    return recipe


def auth_headers(user: User) -> dict[str, str]:
    """Build request headers that authenticate as *user* via its API key."""
    api_key = user.api_key
    return {"X-API-Key": api_key}