/ api_key_middleware.py
api_key_middleware.py
  1  from fastapi import Request, HTTPException, status
  2  from fastapi.responses import JSONResponse
  3  from typing import Callable, List, Dict, Any, Optional
  4  import time
  5  import json
  6  import logging
  7  from starlette.middleware.base import BaseHTTPMiddleware
  8  from starlette.types import ASGIApp
  9  
 10  from auth import validate_api_key, check_usage_limits, record_usage, authorize_batch_processing, authorize_advanced_models
 11  from auth_models import UsageRecord, ApiKey, ApiKeyLevel
 12  from database import record_api_usage
 13  
 14  # Configuration du logging
 15  logger = logging.getLogger("auth_middleware")
 16  
 17  class APIKeyMiddleware(BaseHTTPMiddleware):
 18      """Middleware pour l'authentification par clé API et le suivi d'utilisation."""
 19      
 20      def __init__(
 21          self,
 22          app: ASGIApp,
 23          exclude_paths: List[str] = None,
 24          exclude_prefixes: List[str] = None,
 25          admin_paths: List[str] = None
 26      ):
 27          super().__init__(app)
 28          self.exclude_paths = exclude_paths or ["/docs", "/redoc", "/openapi.json", "/auth/token", "/auth/register"]
 29          self.exclude_prefixes = exclude_prefixes or ["/static/", "/assets/"]
 30          self.admin_paths = admin_paths or ["/auth/admin/"]
 31      
 32      async def dispatch(self, request: Request, call_next: Callable):
 33          # Mesurer le temps de traitement
 34          start_time = time.time()
 35          
 36          # Vérifier si le chemin est exclu de l'authentification
 37          path = request.url.path
 38          if self._is_excluded_path(path):
 39              response = await call_next(request)
 40              return response
 41          
 42          # Vérifier si le chemin nécessite des droits d'administrateur
 43          if self._is_admin_path(path) and not await self._check_admin_rights(request):
 44              return JSONResponse(
 45                  status_code=status.HTTP_403_FORBIDDEN,
 46                  content={"detail": "Accès réservé aux administrateurs"}
 47              )
 48          
 49          # Extraire et valider la clé API
 50          api_key = request.headers.get("X-API-Key")
 51          if not api_key:
 52              return JSONResponse(
 53                  status_code=status.HTTP_401_UNAUTHORIZED,
 54                  content={"detail": "Clé API requise"}
 55              )
 56          
 57          try:
 58              # Valider la clé API
 59              from auth import get_api_key
 60              api_key_info = get_api_key(api_key)
 61              if not api_key_info:
 62                  return JSONResponse(
 63                      status_code=status.HTTP_401_UNAUTHORIZED,
 64                      content={"detail": "Clé API invalide"}
 65                  )
 66              
 67              # Vérifier si la clé est active
 68              if not api_key_info.is_active:
 69                  return JSONResponse(
 70                      status_code=status.HTTP_403_FORBIDDEN,
 71                      content={"detail": "Clé API inactive"}
 72                  )
 73              
 74              # Vérifier les autorisations spécifiques selon le chemin
 75              if "/api/batch" in path:
 76                  try:
 77                      authorize_batch_processing(api_key_info)
 78                  except HTTPException as e:
 79                      return JSONResponse(
 80                          status_code=e.status_code,
 81                          content={"detail": e.detail}
 82                      )
 83              
 84              # Vérifier l'accès aux modèles avancés si nécessaire
 85              if request.query_params.get("model") and request.query_params.get("model") != "default":
 86                  try:
 87                      authorize_advanced_models(api_key_info)
 88                  except HTTPException as e:
 89                      return JSONResponse(
 90                          status_code=e.status_code,
 91                          content={"detail": e.detail}
 92                      )
 93              
 94              # Vérifier les limites d'utilisation pour les requêtes qui envoient du texte
 95              if request.method == "POST" and ("/api/inference" in path or "/api/batch" in path):
 96                  try:
 97                      # Extraire le contenu de la requête
 98                      body = await self._get_request_body(request)
 99                      
100                      # Calculer la longueur du texte
101                      text_length = 0
102                      if "text" in body:
103                          text_length = len(body["text"])
104                      elif "texts" in body:
105                          text_length = sum(len(t) for t in body["texts"])
106                      
107                      # Vérifier les limites
108                      check_usage_limits(api_key_info, text_length)
109                      
110                  except HTTPException as e:
111                      return JSONResponse(
112                          status_code=e.status_code,
113                          content={"detail": e.detail}
114                      )
115                  except Exception as e:
116                      logger.error(f"Erreur lors de la vérification des limites: {e}")
117              
118              # Stocker la clé API dans la requête pour un accès ultérieur
119              request.state.api_key_info = api_key_info
120              
121              # Exécuter le prochain middleware ou le gestionnaire de route
122              response = await call_next(request)
123              
124              # Enregistrer l'utilisation
125              processing_time = time.time() - start_time
126              self._record_api_usage(request, response, api_key_info, processing_time)
127              
128              # Ajouter des headers d'information sur l'utilisation
129              self._add_usage_headers(response, api_key_info)
130              
131              return response
132              
133          except Exception as e:
134              logger.error(f"Erreur dans le middleware API Key: {e}")
135              return JSONResponse(
136                  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
137                  content={"detail": "Erreur interne du serveur"}
138              )
139      
140      def _is_excluded_path(self, path: str) -> bool:
141          """Vérifie si le chemin est exclu de l'authentification."""
142          if path in self.exclude_paths:
143              return True
144          
145          for prefix in self.exclude_prefixes:
146              if path.startswith(prefix):
147                  return True
148          
149          return False
150      
151      def _is_admin_path(self, path: str) -> bool:
152          """Vérifie si le chemin nécessite des droits d'administrateur."""
153          for admin_path in self.admin_paths:
154              if path.startswith(admin_path):
155                  return True
156          return False
157      
158      async def _check_admin_rights(self, request: Request) -> bool:
159          """Vérifie si l'utilisateur a des droits d'administrateur."""
160          # Cette méthode devrait vérifier le token JWT pour les droits admin
161          # Pour simplifier, nous vérifions simplement si la clé API a un niveau ENTERPRISE
162          api_key = request.headers.get("X-API-Key")
163          if not api_key:
164              return False
165          
166          from auth import get_api_key
167          api_key_info = get_api_key(api_key)
168          
169          if not api_key_info:
170              return False
171          
172          # Vérifier si l'utilisateur associé à la clé API a le rôle admin
173          from database import get_user_by_id
174          user = get_user_by_id(api_key_info.user_id)
175          if not user:
176              return False
177          
178          return "admin" in user.roles
179      
180      async def _get_request_body(self, request: Request) -> Dict[str, Any]:
181          """Extrait le corps de la requête."""
182          try:
183              body_bytes = await request.body()
184              body_str = body_bytes.decode('utf-8')
185              return json.loads(body_str)
186          except Exception as e:
187              logger.error(f"Erreur lors de l'extraction du corps de la requête: {e}")
188              return {}
189      
190      def _record_api_usage(self, request: Request, response, api_key_info: ApiKey, processing_time: float):
191          """Enregistre l'utilisation de l'API."""
192          try:
193              # Créer un enregistrement d'utilisation
194              usage_record = UsageRecord(
195                  user_id=api_key_info.user_id,
196                  api_key_id=api_key_info.key,
197                  request_path=str(request.url.path),
198                  request_method=request.method,
199                  tokens_input=getattr(request.state, "tokens_input", 0),
200                  tokens_output=getattr(request.state, "tokens_output", 0),
201                  processing_time=processing_time,
202                  status_code=response.status_code
203              )
204              
205              # Enregistrer l'utilisation
206              record_api_usage(usage_record)
207              
208          except Exception as e:
209              logger.error(f"Erreur lors de l'enregistrement de l'utilisation: {e}")
210      
211      def _add_usage_headers(self, response, api_key_info: ApiKey):
212          """Ajoute des headers d'information sur l'utilisation."""
213          from auth import get_usage_limits
214          try:
215              # Récupérer les limites d'utilisation
216              limits = get_usage_limits(api_key_info.level)
217              
218              # Calculer l'utilisation quotidienne et mensuelle
219              today = time.strftime("%Y-%m-%d")
220              current_month = time.strftime("%Y-%m")
221              
222              daily_usage = api_key_info.usage.get(today, 0)
223              monthly_usage = sum(count for date, count in api_key_info.usage.items() if date.startswith(current_month))
224              
225              # Ajouter les headers
226              response.headers["X-Rate-Limit-Limit-Day"] = str(limits.daily_requests)
227              response.headers["X-Rate-Limit-Remaining-Day"] = str(max(0, limits.daily_requests - daily_usage))
228              response.headers["X-Rate-Limit-Limit-Month"] = str(limits.monthly_requests)
229              response.headers["X-Rate-Limit-Remaining-Month"] = str(max(0, limits.monthly_requests - monthly_usage))
230              response.headers["X-Rate-Limit-Type"] = api_key_info.level
231              
232          except Exception as e:
233              logger.error(f"Erreur lors de l'ajout des headers d'utilisation: {e}")