from __future__ import annotations import time from jose import JWTError, jwt from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from config import settings # Paths that don't require authentication PUBLIC_PATHS = {"/", "/health", "/auth/login", "/docs", "/openapi.json"} PUBLIC_PREFIXES = ("/health/",) class AuthMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): path = request.url.path # Skip auth for public paths if path in PUBLIC_PATHS or any(path.startswith(p) for p in PUBLIC_PREFIXES): request.state.role = "anonymous" return await call_next(request) # Skip auth for OPTIONS (CORS preflight) if request.method == "OPTIONS": return await call_next(request) # Try Bearer token first, then cookie token = _extract_token(request) if not token: request.state.role = "anonymous" return await call_next(request) # Verify JWT payload = _verify_token(token) if payload: request.state.role = payload.get("role", "guest") else: request.state.role = "anonymous" return await call_next(request) def create_token(role: str) -> str: payload = { "role": role, "exp": time.time() + settings.jwt_expire_hours * 3600, "iat": time.time(), } return jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm) def _extract_token(request: Request) -> str | None: # 1. Authorization: Bearer header auth_header = request.headers.get("authorization", "") if auth_header.startswith("Bearer "): return auth_header[7:] # 2. httpOnly cookie return request.cookies.get("token") def _verify_token(token: str) -> dict | None: try: payload = jwt.decode( token, settings.jwt_secret, algorithms=[settings.jwt_algorithm] ) if payload.get("exp", 0) < time.time(): return None return payload except JWTError: return None # Login rate limiting (IP-based) _login_attempts: dict[str, list[float]] = {} MAX_ATTEMPTS = 5 LOCKOUT_SECONDS = 60 def check_login_rate_limit(ip: str) -> bool: """Returns True if login is allowed for this IP.""" now = time.time() attempts = _login_attempts.get(ip, []) # Clean old attempts attempts = [t for t in attempts if now - t < LOCKOUT_SECONDS] _login_attempts[ip] = attempts return len(attempts) < MAX_ATTEMPTS def record_login_attempt(ip: str): now = time.time() if ip not in _login_attempts: _login_attempts[ip] = [] _login_attempts[ip].append(now)