"""
JWT Authentication System for AI Server Admin
Phase 3: Security Enhancement
"""

import logging
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

import bcrypt
import jwt
from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

logger = logging.getLogger(__name__)

# JWT Configuration
# `or` rather than a getenv() default: an env var that is set but empty must not
# become an empty HS256 signing key.
_ENV_JWT_SECRET = os.getenv("JWT_SECRET_KEY")
JWT_SECRET_KEY = _ENV_JWT_SECRET or secrets.token_urlsafe(32)
if not _ENV_JWT_SECRET:
    # A per-process random key means tokens stop validating after a restart and
    # are not shared between worker processes.
    logger.warning(
        "JWT_SECRET_KEY is not set; using a random per-process signing key."
    )
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_HOURS = 24
JWT_REMEMBER_DAYS = 30

# Security
# Strict scheme: FastAPI rejects requests without a Bearer Authorization header.
security = HTTPBearer()
# Lenient scheme for dependencies that have a fallback: yields None instead of
# rejecting the request, so the fallback path can actually run.
optional_security = HTTPBearer(auto_error=False)


def _seed_user(username: str, password: str, role: str) -> Dict[str, Any]:
    """Build an in-memory user record with a freshly bcrypt-hashed password."""
    return {
        "username": username,
        "password_hash": bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8"),
        "role": role,
        "created_at": datetime.now().isoformat(),
        "last_login": None,
        "login_attempts": 0,
        "locked_until": None,
    }


# In-memory user store (in production, use a proper database)
# NOTE: the fallback passwords are well-known defaults; set ADMIN_PASSWORD /
# HYUNGI_PASSWORD in the environment for any non-development deployment.
USERS_DB = {
    "admin": _seed_user("admin", os.getenv("ADMIN_PASSWORD") or "admin123", "admin"),
    "hyungi": _seed_user("hyungi", os.getenv("HYUNGI_PASSWORD") or "hyungi123", "system"),
}

# Login attempt tracking
# Not referenced in this module (per-user counters live in USERS_DB); kept for
# callers that may import it.
LOGIN_ATTEMPTS = {}
MAX_LOGIN_ATTEMPTS = 5
LOCKOUT_DURATION_MINUTES = 15


class AuthManager:
    """Static helpers for password hashing, JWT issue/verify and account lockout."""

    @staticmethod
    def hash_password(password: str) -> str:
        """Hash password using bcrypt with a fresh salt; return the hash as str."""
        return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")

    @staticmethod
    def verify_password(password: str, password_hash: str) -> bool:
        """Verify password against a bcrypt hash.

        Returns False, rather than raising, when bcrypt rejects its inputs with
        ValueError (e.g. a malformed stored hash). A corrupt record then fails
        closed instead of surfacing as a server error.
        """
        try:
            return bcrypt.checkpw(password.encode("utf-8"), password_hash.encode("utf-8"))
        except ValueError:
            return False

    @staticmethod
    def create_jwt_token(user_data: Dict[str, Any], remember_me: bool = False) -> str:
        """Create a signed JWT for *user_data* (needs "username" and "role").

        Session tokens live JWT_EXPIRATION_HOURS; remember-me tokens live
        JWT_REMEMBER_DAYS. The "type" claim records which kind was issued.
        """
        # One timestamp so "iat" and "exp" are computed from the same instant.
        issued_at = datetime.now(timezone.utc)
        lifetime = (
            timedelta(days=JWT_REMEMBER_DAYS)
            if remember_me
            else timedelta(hours=JWT_EXPIRATION_HOURS)
        )
        payload = {
            "username": user_data["username"],
            "role": user_data["role"],
            "exp": issued_at + lifetime,
            "iat": issued_at,
            "type": "remember" if remember_me else "session",
        }
        return jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)

    @staticmethod
    def verify_jwt_token(token: str) -> Dict[str, Any]:
        """Verify and decode a JWT.

        Raises:
            HTTPException(401): "Token has expired" or "Invalid token".
        """
        try:
            return jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
        # ExpiredSignatureError subclasses InvalidTokenError, so it must come first.
        except jwt.ExpiredSignatureError:
            raise HTTPException(status_code=401, detail="Token has expired")
        except jwt.InvalidTokenError:
            raise HTTPException(status_code=401, detail="Invalid token")

    @staticmethod
    def is_account_locked(username: str) -> bool:
        """Return True while *username* is inside its lockout window.

        An expired lockout is cleared here as a side effect: locked_until is
        reset and the failed-attempt counter goes back to zero.
        Unknown users are reported as not locked.
        """
        user = USERS_DB.get(username)
        if not user or not user["locked_until"]:
            return False

        if datetime.now() < datetime.fromisoformat(user["locked_until"]):
            return True

        # Lockout has elapsed: unlock the account.
        user["locked_until"] = None
        user["login_attempts"] = 0
        return False

    @staticmethod
    def record_login_attempt(username: str, success: bool, ip_address: Optional[str] = None):
        """Update the per-user lockout counters after a login attempt.

        A success clears the counters and stamps last_login. A failure increments
        login_attempts and locks the account for LOCKOUT_DURATION_MINUTES once
        MAX_LOGIN_ATTEMPTS is reached. Unknown usernames are ignored.
        ip_address is accepted but not currently used.
        """
        user = USERS_DB.get(username)
        if not user:
            return

        if success:
            user["login_attempts"] = 0
            user["locked_until"] = None
            user["last_login"] = datetime.now().isoformat()
            return

        user["login_attempts"] += 1
        if user["login_attempts"] >= MAX_LOGIN_ATTEMPTS:
            user["locked_until"] = (
                datetime.now() + timedelta(minutes=LOCKOUT_DURATION_MINUTES)
            ).isoformat()

    @staticmethod
    def authenticate_user(username: str, password: str) -> Optional[Dict[str, Any]]:
        """Return the user record when the credentials are valid, else None.

        Does not record the attempt; callers are expected to call
        record_login_attempt themselves.

        Raises:
            HTTPException(423): the account is currently locked out.
        """
        user = USERS_DB.get(username)
        if not user:
            return None

        if AuthManager.is_account_locked(username):
            raise HTTPException(
                status_code=423,
                detail="Account locked due to too many failed attempts. Try again later.",
            )

        if AuthManager.verify_password(password, user["password_hash"]):
            return user
        return None


# Dependency functions
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Resolve the authenticated user from a Bearer JWT.

    Returns a dict with "username", "role" and "token_type".

    Raises:
        HTTPException(401): expired/invalid token, unknown user, or any
            unexpected failure while decoding (fails closed).
    """
    try:
        payload = AuthManager.verify_jwt_token(credentials.credentials)
        user = USERS_DB.get(payload.get("username"))
        if not user:
            raise HTTPException(status_code=401, detail="User not found")
        return {
            "username": user["username"],
            "role": user["role"],
            "token_type": payload.get("type", "session"),
        }
    except HTTPException:
        # Preserve the specific 401 detail (e.g. "Token has expired").
        raise
    except Exception as exc:
        raise HTTPException(status_code=401, detail="Invalid authentication credentials") from exc


async def require_admin_role(current_user: dict = Depends(get_current_user)):
    """Require admin or system role; raise 403 otherwise."""
    if current_user["role"] not in ("admin", "system"):
        raise HTTPException(status_code=403, detail="Admin privileges required")
    return current_user


async def require_system_role(current_user: dict = Depends(get_current_user)):
    """Require system role; raise 403 otherwise."""
    if current_user["role"] != "system":
        raise HTTPException(status_code=403, detail="System privileges required")
    return current_user


# Legacy API key support (for backward compatibility)
async def get_current_user_or_api_key(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(optional_security),
    x_api_key: Optional[str] = None,
):
    """Support both JWT and API key authentication.

    A valid Bearer JWT wins. Otherwise the key is taken from *x_api_key* or the
    "X-API-Key" header and compared in constant time against the API_KEY env
    var. A matching key authenticates as the synthetic "api_user" with system role.

    NOTE: without a Header() marker, FastAPI binds *x_api_key* as a query
    parameter; keys passed that way can end up in access logs.

    Raises:
        HTTPException(401): neither a valid JWT nor a matching API key.
    """
    # Try JWT first; `credentials` is None when no Bearer header was sent.
    if credentials:
        try:
            return await get_current_user(credentials)
        except HTTPException:
            pass

    # Fall back to API key
    api_key = x_api_key or request.headers.get("X-API-Key")
    expected_key = os.getenv("API_KEY", "test-admin-key-123")
    if (
        api_key
        and expected_key
        # Compare bytes: compare_digest rejects non-ASCII str inputs.
        and secrets.compare_digest(api_key.encode("utf-8"), expected_key.encode("utf-8"))
    ):
        return {
            "username": "api_user",
            "role": "system",
            "token_type": "api_key",
        }

    raise HTTPException(status_code=401, detail="Authentication required")


# Audit logging
class AuditLogger:
    """Emit audit events as "AUDIT: {...}" lines on stdout."""

    @staticmethod
    def _emit(log_entry: Dict[str, Any]) -> None:
        """Write one audit entry."""
        print(f"AUDIT: {log_entry}")  # In production, use proper logging

    @staticmethod
    def log_login(
        username: str,
        success: bool,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
    ):
        """Log login attempt"""
        AuditLogger._emit({
            "timestamp": datetime.now().isoformat(),
            "event": "login_attempt",
            "username": username,
            "success": success,
            "ip_address": ip_address,
            "user_agent": user_agent,
        })

    @staticmethod
    def log_admin_action(
        username: str,
        action: str,
        details: Optional[str] = None,
        ip_address: Optional[str] = None,
    ):
        """Log admin action"""
        AuditLogger._emit({
            "timestamp": datetime.now().isoformat(),
            "event": "admin_action",
            "username": username,
            "action": action,
            "details": details,
            "ip_address": ip_address,
        })