Files
TieMeasureFlow/server/tests/conftest.py
T
Adriano 5959c9c92a feat(models): add Station and StationRecipeAssignment models
TDD: test written first, confirmed failing with ModuleNotFoundError,
then model implemented; all 3 new tests pass. conftest updated to
import new models so Base.metadata.create_all picks up the tables.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-17 21:56:22 +02:00

274 lines
8.0 KiB
Python

"""Shared test fixtures for server tests.
Uses SQLite async (aiosqlite) as in-memory test database.
Overrides FastAPI's get_db dependency to inject the test session.
"""
import sys
from pathlib import Path
from collections.abc import AsyncGenerator
from unittest.mock import MagicMock
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.pool import StaticPool
# ---------------------------------------------------------------------------
# Stub out heavy optional dependencies that rely on system libraries.
# WeasyPrint needs GTK/Pango, which may be missing in test environments, so a
# MagicMock is registered under its module name before any server code is
# imported. An already-registered "weasyprint" entry is left untouched.
# ---------------------------------------------------------------------------
if "weasyprint" not in sys.modules:
    sys.modules["weasyprint"] = _mock_weasyprint = MagicMock()

# Put the directory two levels above this file at the front of sys.path so
# the server's top-level modules can be imported.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from database import Base, get_db
from main import app
from middleware.rate_limit import RateLimitMiddleware
from models.user import User
from models.recipe import Recipe, RecipeVersion
from models.task import RecipeTask, RecipeSubtask
from models.measurement import Measurement
from models.access_log import AccessLog
from models.setting import SystemSetting, RecipeVersionAudit
from models.station import Station, StationRecipeAssignment
from services.auth_service import hash_password, generate_api_key
# ---------------------------------------------------------------------------
# In-memory SQLite engine for tests
# ---------------------------------------------------------------------------
# SQLite URL without a file path, used with the aiosqlite driver.
TEST_DATABASE_URL = "sqlite+aiosqlite://"

# StaticPool keeps a single connection for in-memory SQLite so that
# create_all, fixtures, and the app share the same database.
test_engine = create_async_engine(
    TEST_DATABASE_URL,
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
    echo=False,
)

# Factory for AsyncSession objects bound to the test engine; attributes are
# not expired on commit.
TestSessionFactory = async_sessionmaker(
    bind=test_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest_asyncio.fixture(autouse=True)
async def setup_database():
    """Give every test a freshly created schema and remove it afterwards.

    Runs automatically for each test (``autouse=True``). All tables
    registered on ``Base.metadata`` are created on ``test_engine`` before the
    test body runs and dropped again after the ``yield``.
    """
    # Set-up: create the tables inside a transaction on the test engine.
    async with test_engine.begin() as connection:
        create_tables = Base.metadata.create_all
        await connection.run_sync(create_tables)

    yield

    # Tear-down: drop the tables in a second transaction.
    async with test_engine.begin() as connection:
        drop_tables = Base.metadata.drop_all
        await connection.run_sync(drop_tables)
@pytest.fixture(autouse=True)
def reset_rate_limits():
    """Empty the rate limiter's request records before each test.

    The module-level ``app`` object is shared by every test, so its
    ``RateLimitMiddleware`` instance is shared as well. This fixture follows
    the chain of ``.app`` attributes starting at ``app.middleware_stack``
    until it reaches the first ``RateLimitMiddleware`` (or the end of the
    chain) and clears that instance's ``_login_requests`` and
    ``_general_requests`` containers. If no such middleware is found,
    nothing is cleared.
    """
    layer = app.middleware_stack

    # Advance until we are standing on the rate limiter or fall off the end.
    while layer is not None and not isinstance(layer, RateLimitMiddleware):
        layer = getattr(layer, "app", None)

    if layer is not None:
        layer._login_requests.clear()
        layer._general_requests.clear()

    yield
@pytest_asyncio.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Provide a new async session so tests can work with the DB directly.

    The session is committed once the code after ``yield`` resumes normally.
    If an exception surfaces inside the ``try`` block (for example from the
    commit itself), the session is rolled back and the exception is
    re-raised. The session is closed in every case.
    """
    db = TestSessionFactory()
    async with db:
        try:
            yield db
            await db.commit()
        except Exception:
            await db.rollback()
            raise
        finally:
            # Explicit close in addition to the context manager's own exit.
            await db.close()
@pytest_asyncio.fixture
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """Yield an httpx ``AsyncClient`` wired to the FastAPI app with the test DB.

    While the fixture is active, the app's ``get_db`` dependency is
    overridden so that request handlers receive the very same session object
    as the ``db_session`` fixture. Requests are sent in-process through an
    ``ASGITransport`` with base URL ``http://testserver``.

    All dependency overrides on ``app`` are cleared on teardown. The clearing
    happens in a ``finally`` block so that an error raised while the client
    is being set up or closed cannot leave a stale override on the shared,
    module-level ``app``.
    """

    async def _override_get_db() -> AsyncGenerator[AsyncSession, None]:
        # Hand out the fixture's session instead of opening a new one.
        yield db_session

    app.dependency_overrides[get_db] = _override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://testserver") as ac:
            yield ac
    finally:
        app.dependency_overrides.clear()
# ---------------------------------------------------------------------------
# User factory helpers
# ---------------------------------------------------------------------------
async def _create_user(
    session: AsyncSession,
    username: str,
    password: str = "TestPass123",
    display_name: str | None = None,
    roles: list[str] | None = None,
    is_admin: bool = False,
    active: bool = True,
) -> User:
    """Add a ``User`` row to the test database and return the refreshed object.

    Args:
        session: Session the user is added to. It is flushed, not committed.
        username: Login name; also the source of the fallback display name.
        password: Plain-text password, stored only as ``hash_password(password)``.
        display_name: Shown name; when empty or ``None``, ``username.title()``
            is used instead.
        roles: Role names; an empty or ``None`` value becomes ``[]``.
        is_admin: Value for the user's ``is_admin`` field.
        active: Value for the user's ``active`` field.

    Returns:
        The ``User`` instance, carrying a newly generated API key, with
        language preference ``"en"`` and theme preference ``"light"``.
    """
    new_key = generate_api_key()
    hashed = hash_password(password)
    shown_name = display_name if display_name else username.title()
    role_list = roles if roles else []

    account = User(
        username=username,
        password_hash=hashed,
        display_name=shown_name,
        roles=role_list,
        is_admin=is_admin,
        active=active,
        api_key=new_key,
        language_pref="en",
        theme_pref="light",
    )
    session.add(account)

    # Flush so the row is sent to the database, then reload it from there.
    await session.flush()
    await session.refresh(account)
    return account
@pytest_asyncio.fixture
async def admin_user(db_session: AsyncSession) -> User:
    """Return a user named "admin" with ``is_admin=True`` and all three roles.

    The user is created through ``_create_user`` on the ``db_session``
    fixture's session and therefore has an API key assigned.
    """
    admin = await _create_user(
        session=db_session,
        username="admin",
        display_name="Admin User",
        roles=["Maker", "MeasurementTec", "Metrologist"],
        is_admin=True,
    )
    return admin
@pytest_asyncio.fixture
async def maker_user(db_session: AsyncSession) -> User:
    """Return a non-admin user named "maker" whose only role is ``Maker``.

    The user is created through ``_create_user`` on the ``db_session``
    fixture's session and therefore has an API key assigned.
    """
    maker = await _create_user(
        session=db_session,
        username="maker",
        display_name="Maker User",
        roles=["Maker"],
    )
    return maker
@pytest_asyncio.fixture
async def measurement_tec_user(db_session: AsyncSession) -> User:
    """Return a non-admin user named "measurement_tec" with role ``MeasurementTec``.

    The user is created through ``_create_user`` on the ``db_session``
    fixture's session and therefore has an API key assigned.
    """
    technician = await _create_user(
        session=db_session,
        username="measurement_tec",
        display_name="MeasurementTec User",
        roles=["MeasurementTec"],
    )
    return technician
@pytest_asyncio.fixture
async def metrologist_user(db_session: AsyncSession) -> User:
    """Return a non-admin user named "metrologist" with role ``Metrologist``.

    The user is created through ``_create_user`` on the ``db_session``
    fixture's session and therefore has an API key assigned.
    """
    metrologist = await _create_user(
        session=db_session,
        username="metrologist",
        display_name="Metrologist User",
        roles=["Metrologist"],
    )
    return metrologist
# ---------------------------------------------------------------------------
# Recipe helper
# ---------------------------------------------------------------------------
async def create_test_recipe(
    session: AsyncSession,
    user_id: int,
    code: str = "REC-001",
    name: str = "Test Recipe",
) -> Recipe:
    """Build a minimal recipe graph: recipe -> version 1 -> task -> subtask.

    Each object is added and flushed before the next one is built, because
    every child is constructed with its parent's ``id``. Nothing is
    committed here.

    Args:
        session: Session used for all inserts.
        user_id: Stored as ``created_by`` on the recipe and on its version.
        code: Recipe code.
        name: Recipe name.

    Returns:
        The ``Recipe`` ORM object, refreshed so that its ``versions``
        attribute is loaded.
    """

    async def _persist(entity):
        # Add the object and flush immediately so it can be referenced by id.
        session.add(entity)
        await session.flush()
        return entity

    recipe = await _persist(
        Recipe(
            code=code,
            name=name,
            description="A recipe for testing",
            created_by=user_id,
        )
    )

    version = await _persist(
        RecipeVersion(
            recipe_id=recipe.id,
            version_number=1,
            is_current=True,
            created_by=user_id,
            change_notes="Initial version",
        )
    )

    task = await _persist(
        RecipeTask(
            version_id=version.id,
            order_index=0,
            title="Test Task",
            directive="Measure the part",
            description="First measurement task",
        )
    )

    await _persist(
        RecipeSubtask(
            task_id=task.id,
            marker_number=1,
            description="Diameter measurement",
            measurement_type="diameter",
            nominal=10.0,
            utl=10.5,
            uwl=10.3,
            lwl=9.7,
            ltl=9.5,
            unit="mm",
        )
    )

    await session.refresh(recipe, attribute_names=["versions"])
    return recipe
def auth_headers(user: User) -> dict[str, str]:
    """Build the request headers that authenticate as *user*.

    The result is a single-entry dict mapping ``X-API-Key`` to the user's
    ``api_key`` attribute.
    """
    headers = {}
    headers["X-API-Key"] = user.api_key
    return headers