/ api_key_middleware.py
api_key_middleware.py
1 from fastapi import Request, HTTPException, status 2 from fastapi.responses import JSONResponse 3 from typing import Callable, List, Dict, Any, Optional 4 import time 5 import json 6 import logging 7 from starlette.middleware.base import BaseHTTPMiddleware 8 from starlette.types import ASGIApp 9 10 from auth import validate_api_key, check_usage_limits, record_usage, authorize_batch_processing, authorize_advanced_models 11 from auth_models import UsageRecord, ApiKey, ApiKeyLevel 12 from database import record_api_usage 13 14 # Configuration du logging 15 logger = logging.getLogger("auth_middleware") 16 17 class APIKeyMiddleware(BaseHTTPMiddleware): 18 """Middleware pour l'authentification par clé API et le suivi d'utilisation.""" 19 20 def __init__( 21 self, 22 app: ASGIApp, 23 exclude_paths: List[str] = None, 24 exclude_prefixes: List[str] = None, 25 admin_paths: List[str] = None 26 ): 27 super().__init__(app) 28 self.exclude_paths = exclude_paths or ["/docs", "/redoc", "/openapi.json", "/auth/token", "/auth/register"] 29 self.exclude_prefixes = exclude_prefixes or ["/static/", "/assets/"] 30 self.admin_paths = admin_paths or ["/auth/admin/"] 31 32 async def dispatch(self, request: Request, call_next: Callable): 33 # Mesurer le temps de traitement 34 start_time = time.time() 35 36 # Vérifier si le chemin est exclu de l'authentification 37 path = request.url.path 38 if self._is_excluded_path(path): 39 response = await call_next(request) 40 return response 41 42 # Vérifier si le chemin nécessite des droits d'administrateur 43 if self._is_admin_path(path) and not await self._check_admin_rights(request): 44 return JSONResponse( 45 status_code=status.HTTP_403_FORBIDDEN, 46 content={"detail": "Accès réservé aux administrateurs"} 47 ) 48 49 # Extraire et valider la clé API 50 api_key = request.headers.get("X-API-Key") 51 if not api_key: 52 return JSONResponse( 53 status_code=status.HTTP_401_UNAUTHORIZED, 54 content={"detail": "Clé API requise"} 55 ) 56 57 try: 58 # Valider la clé API 59 from auth import get_api_key 60 api_key_info = get_api_key(api_key) 61 if not api_key_info: 62 return JSONResponse( 63 status_code=status.HTTP_401_UNAUTHORIZED, 64 content={"detail": "Clé API invalide"} 65 ) 66 67 # Vérifier si la clé est active 68 if not api_key_info.is_active: 69 return JSONResponse( 70 status_code=status.HTTP_403_FORBIDDEN, 71 content={"detail": "Clé API inactive"} 72 ) 73 74 # Vérifier les autorisations spécifiques selon le chemin 75 if "/api/batch" in path: 76 try: 77 authorize_batch_processing(api_key_info) 78 except HTTPException as e: 79 return JSONResponse( 80 status_code=e.status_code, 81 content={"detail": e.detail} 82 ) 83 84 # Vérifier l'accès aux modèles avancés si nécessaire 85 if request.query_params.get("model") and request.query_params.get("model") != "default": 86 try: 87 authorize_advanced_models(api_key_info) 88 except HTTPException as e: 89 return JSONResponse( 90 status_code=e.status_code, 91 content={"detail": e.detail} 92 ) 93 94 # Vérifier les limites d'utilisation pour les requêtes qui envoient du texte 95 if request.method == "POST" and ("/api/inference" in path or "/api/batch" in path): 96 try: 97 # Extraire le contenu de la requête 98 body = await self._get_request_body(request) 99 100 # Calculer la longueur du texte 101 text_length = 0 102 if "text" in body: 103 text_length = len(body["text"]) 104 elif "texts" in body: 105 text_length = sum(len(t) for t in body["texts"]) 106 107 # Vérifier les limites 108 check_usage_limits(api_key_info, text_length) 109 110 except HTTPException as e: 111 return JSONResponse( 112 status_code=e.status_code, 113 content={"detail": e.detail} 114 ) 115 except Exception as e: 116 logger.error(f"Erreur lors de la vérification des limites: {e}") 117 118 # Stocker la clé API dans la requête pour un accès ultérieur 119 request.state.api_key_info = api_key_info 120 121 # Exécuter le prochain middleware ou le gestionnaire de route 122 response = await call_next(request) 123 124 # Enregistrer l'utilisation 125 processing_time = time.time() - start_time 126 self._record_api_usage(request, response, api_key_info, processing_time) 127 128 # Ajouter des headers d'information sur l'utilisation 129 self._add_usage_headers(response, api_key_info) 130 131 return response 132 133 except Exception as e: 134 logger.error(f"Erreur dans le middleware API Key: {e}") 135 return JSONResponse( 136 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 137 content={"detail": "Erreur interne du serveur"} 138 ) 139 140 def _is_excluded_path(self, path: str) -> bool: 141 """Vérifie si le chemin est exclu de l'authentification.""" 142 if path in self.exclude_paths: 143 return True 144 145 for prefix in self.exclude_prefixes: 146 if path.startswith(prefix): 147 return True 148 149 return False 150 151 def _is_admin_path(self, path: str) -> bool: 152 """Vérifie si le chemin nécessite des droits d'administrateur.""" 153 for admin_path in self.admin_paths: 154 if path.startswith(admin_path): 155 return True 156 return False 157 158 async def _check_admin_rights(self, request: Request) -> bool: 159 """Vérifie si l'utilisateur a des droits d'administrateur.""" 160 # Cette méthode devrait vérifier le token JWT pour les droits admin 161 # Pour simplifier, nous vérifions simplement si la clé API a un niveau ENTERPRISE 162 api_key = request.headers.get("X-API-Key") 163 if not api_key: 164 return False 165 166 from auth import get_api_key 167 api_key_info = get_api_key(api_key) 168 169 if not api_key_info: 170 return False 171 172 # Vérifier si l'utilisateur associé à la clé API a le rôle admin 173 from database import get_user_by_id 174 user = get_user_by_id(api_key_info.user_id) 175 if not user: 176 return False 177 178 return "admin" in user.roles 179 180 async def _get_request_body(self, request: Request) -> Dict[str, Any]: 181 """Extrait le corps de la requête.""" 182 try: 183 body_bytes = await request.body() 184 body_str = body_bytes.decode('utf-8') 185 return json.loads(body_str) 186 except Exception as e: 187 logger.error(f"Erreur lors de l'extraction du corps de la requête: {e}") 188 return {} 189 190 def _record_api_usage(self, request: Request, response, api_key_info: ApiKey, processing_time: float): 191 """Enregistre l'utilisation de l'API.""" 192 try: 193 # Créer un enregistrement d'utilisation 194 usage_record = UsageRecord( 195 user_id=api_key_info.user_id, 196 api_key_id=api_key_info.key, 197 request_path=str(request.url.path), 198 request_method=request.method, 199 tokens_input=getattr(request.state, "tokens_input", 0), 200 tokens_output=getattr(request.state, "tokens_output", 0), 201 processing_time=processing_time, 202 status_code=response.status_code 203 ) 204 205 # Enregistrer l'utilisation 206 record_api_usage(usage_record) 207 208 except Exception as e: 209 logger.error(f"Erreur lors de l'enregistrement de l'utilisation: {e}") 210 211 def _add_usage_headers(self, response, api_key_info: ApiKey): 212 """Ajoute des headers d'information sur l'utilisation.""" 213 from auth import get_usage_limits 214 try: 215 # Récupérer les limites d'utilisation 216 limits = get_usage_limits(api_key_info.level) 217 218 # Calculer l'utilisation quotidienne et mensuelle 219 today = time.strftime("%Y-%m-%d") 220 current_month = time.strftime("%Y-%m") 221 222 daily_usage = api_key_info.usage.get(today, 0) 223 monthly_usage = sum(count for date, count in api_key_info.usage.items() if date.startswith(current_month)) 224 225 # Ajouter les headers 226 response.headers["X-Rate-Limit-Limit-Day"] = str(limits.daily_requests) 227 response.headers["X-Rate-Limit-Remaining-Day"] = str(max(0, limits.daily_requests - daily_usage)) 228 response.headers["X-Rate-Limit-Limit-Month"] = str(limits.monthly_requests) 229 response.headers["X-Rate-Limit-Remaining-Month"] = str(max(0, limits.monthly_requests - monthly_usage)) 230 response.headers["X-Rate-Limit-Type"] = api_key_info.level 231 232 except Exception as e: 233 logger.error(f"Erreur lors de l'ajout des headers d'utilisation: {e}")