86df67f2e5
The default 2-worker gunicorn could only serve 2 concurrent tablet requests, queueing the rest, and the rate limiter saw every tablet as the same Nginx container IP, so 20 users would have collectively burned through the 100 req/min general bucket. - gunicorn: 5 workers x 4 gthread, --forwarded-allow-ips=*, access log - uvicorn: 4 workers, --proxy-headers, --forwarded-allow-ips=* - RateLimitMiddleware: resolve real client IP from X-Forwarded-For -> X-Real-IP -> request.client.host - Bump rate_limit_general 100 -> 300 req/min/IP (per tablet now) - Flask: ProxyFix(x_for=1, x_proto=1, x_host=1) so request.remote_addr is the tablet IP, not the Nginx IP - APIClient: forward X-Forwarded-For + X-Real-IP to FastAPI for both JSON and multipart/files calls; safe no-op outside request context - 12 new tests (7 server + 5 client) covering header precedence, forwarding behavior and ProxyFix install Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
124 lines
4.6 KiB
Python
124 lines
4.6 KiB
Python
"""Rate limiting middleware for FastAPI.
|
|
|
|
Implements in-memory sliding window rate limiting per client IP.
|
|
Configurable limits for login and general endpoints.
|
|
"""
|
|
import time
|
|
from collections import defaultdict
|
|
from typing import Callable
|
|
|
|
from fastapi import Request, Response
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.responses import JSONResponse
|
|
|
|
from config import settings
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-IP sliding-window rate limiter for FastAPI.

    Two independent buckets are maintained:

    - POSTs to ``/api/auth/login`` are capped at ``rate_limit_login`` req/min.
    - Every request (logins included) is capped at ``rate_limit_general``
      req/min.

    When a limit is exceeded the middleware short-circuits with HTTP 429 and
    a ``Retry-After`` header instead of invoking the downstream app.
    """

    LOGIN_PATH = "/api/auth/login"
    WINDOW_SECONDS = 60

    def __init__(self, app) -> None:
        super().__init__(app)
        # Per-bucket request history: {client_ip: [unix_timestamp, ...]}.
        self._login_requests: dict[str, list[float]] = defaultdict(list)
        self._general_requests: dict[str, list[float]] = defaultdict(list)
        # Total requests dispatched; drives the periodic stale-IP eviction.
        self._request_count = 0

    @staticmethod
    def _client_ip(request: Request) -> str:
        """Return the originating client IP, preferring proxy headers.

        Precedence: first entry of ``X-Forwarded-For``, then ``X-Real-IP``,
        then the socket peer. Nginx and the Flask client sit between the
        tablet and this API, so without these headers every tablet would
        collapse into a single rate-limit bucket.
        """
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:
                return first_hop
        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip.strip()
        if request.client:
            return request.client.host
        return "unknown"

    def _clean_window(self, timestamps: list[float], now: float) -> list[float]:
        """Return only the timestamps still inside the sliding window."""
        threshold = now - self.WINDOW_SECONDS
        return [stamp for stamp in timestamps if stamp > threshold]

    def _evict_stale_ips(self, bucket: dict[str, list[float]], now: float) -> None:
        """Drop IP entries with no in-window activity (bounds memory use)."""
        threshold = now - self.WINDOW_SECONDS
        idle = [
            ip
            for ip, stamps in bucket.items()
            if not stamps or max(stamps) <= threshold
        ]
        for ip in idle:
            del bucket[ip]

    def _check_rate_limit(
        self,
        bucket: dict[str, list[float]],
        client_ip: str,
        limit: int,
        now: float,
    ) -> tuple[bool, int]:
        """Record one request for ``client_ip`` if it is under ``limit``.

        Returns:
            ``(True, 0)`` when the request is allowed, otherwise
            ``(False, retry_after_seconds)``.
        """
        window = self._clean_window(bucket[client_ip], now)
        bucket[client_ip] = window

        if len(window) < limit:
            window.append(now)
            return True, 0

        # Seconds until the oldest in-window request ages out.
        wait = int(window[0] + self.WINDOW_SECONDS - now) + 1
        return False, max(wait, 1)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Enforce both limits, then hand off to the rest of the app."""
        client_ip = self._client_ip(request)
        now = time.time()
        path = request.url.path

        # Amortized cleanup: every 100th request, purge idle IPs from both
        # buckets so the in-memory maps cannot grow without bound.
        self._request_count += 1
        if self._request_count % 100 == 0:
            self._evict_stale_ips(self._login_requests, now)
            self._evict_stale_ips(self._general_requests, now)

        # Tighter login-only limit (brute-force protection).
        if path == self.LOGIN_PATH and request.method == "POST":
            ok, retry_after = self._check_rate_limit(
                self._login_requests, client_ip, settings.rate_limit_login, now
            )
            if not ok:
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Too many login attempts. Please try again later."},
                    headers={"Retry-After": str(retry_after)},
                )

        # The general limit applies to every request, logins included.
        ok, retry_after = self._check_rate_limit(
            self._general_requests, client_ip, settings.rate_limit_general, now
        )
        if not ok:
            return JSONResponse(
                status_code=429,
                content={"detail": "Too many requests. Please try again later."},
                headers={"Retry-After": str(retry_after)},
            )

        return await call_next(request)
|