"""Rate limiting middleware for FastAPI. Implements in-memory sliding window rate limiting per client IP. Configurable limits for login and general endpoints. """ import time from collections import defaultdict from typing import Callable from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse from config import settings class RateLimitMiddleware(BaseHTTPMiddleware): """Middleware that enforces per-IP rate limits using a sliding window. - Login endpoint (/api/auth/login): limited to `rate_limit_login` req/min. - All other endpoints: limited to `rate_limit_general` req/min. Returns HTTP 429 with Retry-After header when limit exceeded. """ LOGIN_PATH = "/api/auth/login" WINDOW_SECONDS = 60 def __init__(self, app) -> None: super().__init__(app) # {ip: [timestamp, ...]} per bucket self._login_requests: dict[str, list[float]] = defaultdict(list) self._general_requests: dict[str, list[float]] = defaultdict(list) self._request_count = 0 # Counter for triggering eviction @staticmethod def _client_ip(request: Request) -> str: """Resolve the originating client IP, honoring proxy headers. Order of precedence: ``X-Forwarded-For`` (first hop), ``X-Real-IP``, ``request.client.host``. Required because Nginx and the Flask client sit between the tablet and the API; without parsing these headers every tablet shares one bucket. """ xff = request.headers.get("x-forwarded-for") if xff: first = xff.split(",")[0].strip() if first: return first real = request.headers.get("x-real-ip") if real: return real.strip() return request.client.host if request.client else "unknown" def _clean_window(self, timestamps: list[float], now: float) -> list[float]: """Remove timestamps outside the current sliding window.""" cutoff = now - self.WINDOW_SECONDS return [t for t in timestamps if t > cutoff] def _evict_stale_ips(self, bucket: dict[str, list[float]], now: float) -> None: """Remove IP entries with no timestamps in the current window (memory leak prevention).""" cutoff = now - self.WINDOW_SECONDS stale_ips = [ip for ip, timestamps in bucket.items() if not timestamps or max(timestamps) <= cutoff] for ip in stale_ips: del bucket[ip] def _check_rate_limit( self, bucket: dict[str, list[float]], client_ip: str, limit: int, now: float, ) -> tuple[bool, int]: """Check if a request is within the rate limit. Returns: Tuple of (allowed, retry_after_seconds). """ bucket[client_ip] = self._clean_window(bucket[client_ip], now) if len(bucket[client_ip]) >= limit: # Calculate seconds until the oldest request falls out of window oldest = bucket[client_ip][0] retry_after = int(oldest + self.WINDOW_SECONDS - now) + 1 return False, max(retry_after, 1) bucket[client_ip].append(now) return True, 0 async def dispatch(self, request: Request, call_next: Callable) -> Response: client_ip = self._client_ip(request) now = time.time() path = request.url.path # Periodic eviction: every 100 requests, remove stale IP buckets self._request_count += 1 if self._request_count % 100 == 0: self._evict_stale_ips(self._login_requests, now) self._evict_stale_ips(self._general_requests, now) # Check login-specific rate limit if path == self.LOGIN_PATH and request.method == "POST": allowed, retry_after = self._check_rate_limit( self._login_requests, client_ip, settings.rate_limit_login, now ) if not allowed: return JSONResponse( status_code=429, content={"detail": "Too many login attempts. Please try again later."}, headers={"Retry-After": str(retry_after)}, ) # Check general rate limit allowed, retry_after = self._check_rate_limit( self._general_requests, client_ip, settings.rate_limit_general, now ) if not allowed: return JSONResponse( status_code=429, content={"detail": "Too many requests. Please try again later."}, headers={"Retry-After": str(retry_after)}, ) return await call_next(request)