Files
TieMeasureFlow/server/middleware/rate_limit.py
T
Adriano 86df67f2e5 perf: scale workers + per-tablet rate limiting for 20 concurrent users
The default 2-worker gunicorn could only serve 2 concurrent tablet requests,
queueing the rest, and the rate limiter saw every tablet as the same Nginx
container IP, so 20 users would have collectively burned through the
100 req/min general bucket.

- gunicorn: 5 workers x 4 gthread, --forwarded-allow-ips=*, access log
- uvicorn: 4 workers, --proxy-headers, --forwarded-allow-ips=*
- RateLimitMiddleware: resolve real client IP from
  X-Forwarded-For -> X-Real-IP -> request.client.host
- Bump rate_limit_general 100 -> 300 req/min/IP (per tablet now)
- Flask: ProxyFix(x_for=1, x_proto=1, x_host=1) so request.remote_addr
  is the tablet IP, not the Nginx IP
- APIClient: forward X-Forwarded-For + X-Real-IP to FastAPI for both
  JSON and multipart/files calls; safe no-op outside request context
- 12 new tests (7 server + 5 client) covering header precedence,
  forwarding behavior and ProxyFix install

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-25 12:07:43 +02:00

124 lines
4.6 KiB
Python

"""Rate limiting middleware for FastAPI.
Implements in-memory sliding window rate limiting per client IP.
Configurable limits for login and general endpoints.
"""
import time
from collections import defaultdict
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from config import settings
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware that enforces per-IP rate limits using a sliding window.

    - Login endpoint (POST /api/auth/login): limited to ``rate_limit_login``
      requests per minute and governed *only* by that bucket.
    - All other endpoints: limited to ``rate_limit_general`` requests per minute.

    Returns HTTP 429 with a ``Retry-After`` header when a limit is exceeded.

    NOTE(review): state is held in per-process dicts, so with multiple
    gunicorn/uvicorn workers each worker enforces its own independent limit —
    the effective limit per IP is roughly ``limit * workers``. Confirm this is
    acceptable for the deployment described in the commit (5/4 workers).
    """

    LOGIN_PATH = "/api/auth/login"
    WINDOW_SECONDS = 60

    def __init__(self, app) -> None:
        super().__init__(app)
        # Per-bucket mapping of {client_ip: [request timestamps in window]}.
        self._login_requests: dict[str, list[float]] = defaultdict(list)
        self._general_requests: dict[str, list[float]] = defaultdict(list)
        self._request_count = 0  # counter used to trigger periodic eviction

    @staticmethod
    def _client_ip(request: Request) -> str:
        """Resolve the originating client IP, honoring proxy headers.

        Order of precedence: ``X-Forwarded-For`` (first hop), ``X-Real-IP``,
        ``request.client.host``. Required because Nginx and the Flask client
        sit between the tablet and the API; without parsing these headers
        every tablet shares one bucket.

        Falls back to ``"unknown"`` when no transport-level client address is
        available (e.g. in some test clients), which pools those requests
        into a single shared bucket.
        """
        xff = request.headers.get("x-forwarded-for")
        if xff:
            # Leftmost entry is the originating client as seen by the first
            # trusted proxy; safe here because Nginx sets/overwrites it.
            first = xff.split(",")[0].strip()
            if first:
                return first
        real = request.headers.get("x-real-ip")
        if real:
            return real.strip()
        return request.client.host if request.client else "unknown"

    def _clean_window(self, timestamps: list[float], now: float) -> list[float]:
        """Return only the timestamps inside the current sliding window."""
        cutoff = now - self.WINDOW_SECONDS
        return [t for t in timestamps if t > cutoff]

    def _evict_stale_ips(self, bucket: dict[str, list[float]], now: float) -> None:
        """Drop IP entries with no in-window timestamps (memory-leak prevention)."""
        cutoff = now - self.WINDOW_SECONDS
        stale_ips = [
            ip
            for ip, timestamps in bucket.items()
            if not timestamps or max(timestamps) <= cutoff
        ]
        for ip in stale_ips:
            del bucket[ip]

    def _check_rate_limit(
        self,
        bucket: dict[str, list[float]],
        client_ip: str,
        limit: int,
        now: float,
    ) -> tuple[bool, int]:
        """Check whether a request from ``client_ip`` is within ``limit``.

        Records the request timestamp when allowed.

        Returns:
            Tuple of ``(allowed, retry_after_seconds)``; ``retry_after_seconds``
            is 0 when allowed, otherwise at least 1.
        """
        bucket[client_ip] = self._clean_window(bucket[client_ip], now)
        if len(bucket[client_ip]) >= limit:
            # Seconds until the oldest in-window request expires; +1 rounds up
            # so clients never retry a moment too early.
            oldest = bucket[client_ip][0]
            retry_after = int(oldest + self.WINDOW_SECONDS - now) + 1
            return False, max(retry_after, 1)
        bucket[client_ip].append(now)
        return True, 0

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        client_ip = self._client_ip(request)
        now = time.time()
        path = request.url.path

        # Periodic eviction: every 100 requests, remove stale IP buckets.
        self._request_count += 1
        if self._request_count % 100 == 0:
            self._evict_stale_ips(self._login_requests, now)
            self._evict_stale_ips(self._general_requests, now)

        # Login POSTs are governed solely by the stricter login bucket.
        if path == self.LOGIN_PATH and request.method == "POST":
            allowed, retry_after = self._check_rate_limit(
                self._login_requests, client_ip, settings.rate_limit_login, now
            )
            if not allowed:
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Too many login attempts. Please try again later."},
                    headers={"Retry-After": str(retry_after)},
                )
            # BUGFIX: previously an allowed login POST fell through and was
            # *also* charged against the general bucket, double-counting every
            # login — contradicting the documented per-bucket contract.
            return await call_next(request)

        # All other requests: general per-IP rate limit.
        allowed, retry_after = self._check_rate_limit(
            self._general_requests, client_ip, settings.rate_limit_general, now
        )
        if not allowed:
            return JSONResponse(
                status_code=429,
                content={"detail": "Too many requests. Please try again later."},
                headers={"Retry-After": str(retry_after)},
            )
        return await call_next(request)