/ middleware / failover_middleware.py
failover_middleware.py
  1  """
  2  Failover middleware for AI models.
  3  Automatically manages failover to alternative models in case of errors,
  4  implements retry strategies, and ensures better service availability.
  5  """
  6  
  7  import json
  8  import time
  9  import logging
 10  import traceback
 11  from typing import Dict, List, Any, Optional, Callable, Tuple, Set
 12  from collections import defaultdict
 13  import random
 14  from fastapi import Request, Response
 15  from fastapi.responses import JSONResponse
 16  from starlette.middleware.base import BaseHTTPMiddleware
 17  from starlette.types import ASGIApp
 18  
 19  # Logging configuration
 20  logger = logging.getLogger("failover_middleware")
 21  
 22  class FailoverConfig:
 23      """Configuration of failover strategies for a model type."""
 24      
 25      def __init__(
 26          self,
 27          model_type: str,
 28          alternatives: Dict[str, List[str]],
 29          max_retries: int = 3,
 30          backoff_factor: float = 1.5,
 31          jitter: float = 0.1,
 32          cooldown_period: int = 300,  # 5 minutes
 33      ):
 34          """
 35          Initializes the failover configuration.
 36          
 37          Args:
 38              model_type: Model type (text, video, transcription, etc.)
 39              alternatives: Dictionary of alternative models by main model
 40              max_retries: Maximum number of retry attempts
 41              backoff_factor: Exponential backoff factor
 42              jitter: Random variation to avoid request storms
 43              cooldown_period: Cooling period before retrying a failed model (seconds)
 44          """
 45          self.model_type = model_type
 46          self.alternatives = alternatives
 47          self.max_retries = max_retries
 48          self.backoff_factor = backoff_factor
 49          self.jitter = jitter
 50          self.cooldown_period = cooldown_period
 51  
 52  class ModelStatus:
 53      """Availability status of a model."""
 54      
 55      def __init__(self, model_id: str):
 56          self.model_id = model_id
 57          self.available = True
 58          self.failure_count = 0
 59          self.last_failure_time = 0
 60          self.recovery_count = 0
 61          self.cumulative_errors = 0
 62      
 63      def mark_failure(self) -> None:
 64          """Marks the model as having failed."""
 65          self.available = False
 66          self.failure_count += 1
 67          self.cumulative_errors += 1
 68          self.last_failure_time = time.time()
 69      
 70      def mark_success(self) -> None:
 71          """Marks the model as functional."""
 72          if not self.available:
 73              self.recovery_count += 1
 74          self.available = True
 75          self.failure_count = 0
 76      
 77      def should_retry(self, cooldown_period: int) -> bool:
 78          """Checks if the model should be retried after a cooling period."""
 79          if self.available:
 80              return True
 81          
 82          # Calculate a progressive cooldown based on the number of failures
 83          adjusted_cooldown = cooldown_period * min(5, self.failure_count)
 84          
 85          # Check if the cooling time has elapsed
 86          return (time.time() - self.last_failure_time) > adjusted_cooldown
 87      
 88      def __str__(self) -> str:
 89          return (f"ModelStatus(model_id={self.model_id}, available={self.available}, "
 90                  f"failures={self.failure_count}, recoveries={self.recovery_count}, "
 91                  f"total_errors={self.cumulative_errors})")
 92  
 93  class FailoverManager:
 94      """Centralized manager for failover strategies."""
 95      
 96      def __init__(self):
 97          # Configurations by model type
 98          self.configs: Dict[str, FailoverConfig] = {}
 99          
100          # Model status
101          self.model_status: Dict[str, ModelStatus] = {}
102          
103          # Failover history for analysis
104          self.failover_history: List[Dict[str, Any]] = []
105          self.history_max_size = 100
106          
107          # Metrics counters
108          self.metrics = {
109              "total_failovers": 0,
110              "successful_failovers": 0,
111              "failed_failovers": 0,
112              "models_recovered": 0
113          }
114      
115      def register_config(self, config: FailoverConfig) -> None:
116          """Registers a failover configuration."""
117          self.configs[config.model_type] = config
118          
119          # Initialize status for all models in this configuration
120          for primary, alternatives in config.alternatives.items():
121              self._ensure_model_status(primary)
122              for alt in alternatives:
123                  self._ensure_model_status(alt)
124      
125      def _ensure_model_status(self, model_id: str) -> None:
126          """Ensures a status exists for the given model."""
127          if model_id not in self.model_status:
128              self.model_status[model_id] = ModelStatus(model_id)
129      
130      def get_alternative_model(self, model_type: str, original_model: str) -> Optional[str]:
131          """
132          Gets an available alternative model for the given model.
133          
134          Args:
135              model_type: Model type (text, video, transcription, etc.)
136              original_model: Original model identifier
137              
138          Returns:
139              Identifier of an available alternative model or None if none is available
140          """
141          if model_type not in self.configs:
142              logger.warning(f"No failover configuration for model type: {model_type}")
143              return None
144          
145          config = self.configs[model_type]
146          
147          # Check if the original model has alternatives
148          if original_model not in config.alternatives:
149              logger.warning(f"No alternatives configured for model: {original_model}")
150              return None
151          
152          # Get the list of alternatives
153          alternatives = config.alternatives[original_model]
154          if not alternatives:
155              return None
156          
157          # Filter available alternatives
158          available_alternatives = [
159              alt for alt in alternatives 
160              if alt in self.model_status and self.model_status[alt].should_retry(config.cooldown_period)
161          ]
162          
163          if not available_alternatives:
164              logger.warning(f"No available alternatives for {original_model}")
165              return None
166          
167          # Choose an alternative randomly (simple load balancing)
168          return random.choice(available_alternatives)
169      
170      def mark_model_failure(self, model_id: str) -> None:
171          """Marks a model as having failed."""
172          self._ensure_model_status(model_id)
173          self.model_status[model_id].mark_failure()
174          logger.warning(f"Model marked as failed: {model_id}")
175      
176      def mark_model_success(self, model_id: str) -> None:
177          """Marks a model as functional."""
178          self._ensure_model_status(model_id)
179          
180          # If the model was previously failing, increment the recovery counter
181          if not self.model_status[model_id].available:
182              self.metrics["models_recovered"] += 1
183              logger.info(f"Model recovered: {model_id}")
184          
185          self.model_status[model_id].mark_success()
186      
187      def record_failover(self, original_model: str, alternative_model: str, success: bool, error: Optional[str] = None) -> None:
188          """Records a failover event for analysis."""
189          event = {
190              "timestamp": time.time(),
191              "original_model": original_model,
192              "alternative_model": alternative_model,
193              "success": success,
194              "error": error
195          }
196          
197          self.failover_history.append(event)
198          
199          # Limit the history size
200          if len(self.failover_history) > self.history_max_size:
201              self.failover_history.pop(0)
202          
203          # Update metrics
204          self.metrics["total_failovers"] += 1
205          if success:
206              self.metrics["successful_failovers"] += 1
207          else:
208              self.metrics["failed_failovers"] += 1
209      
210      def get_model_health_report(self) -> Dict[str, Any]:
211          """Generates a report on the health status of models."""
212          report = {
213              "metrics": self.metrics.copy(),
214              "models": {}
215          }
216          
217          for model_id, status in self.model_status.items():
218              report["models"][model_id] = {
219                  "available": status.available,
220                  "failure_count": status.failure_count,
221                  "recovery_count": status.recovery_count,
222                  "last_failure": status.last_failure_time,
223                  "total_errors": status.cumulative_errors
224              }
225          
226          return report
227  
# Module-wide manager instance shared by the middleware and helper functions
failover_manager = FailoverManager()

# Built-in failover tables, one per model type. Each key is a primary model
# and each value lists the models that may stand in for it.
default_text_failover = FailoverConfig(
    model_type="text",
    alternatives={
        "deepseek-coder-33b-instruct": [
            "deepseek-coder-6.7b-instruct",
            "codellama-7b-instruct",
        ],
        "llama-3-70b-instruct": [
            "llama-3-8b-instruct",
            "mistral-7b-instruct",
        ],
        "mistral-7b-instruct": [
            "llama-3-8b-instruct",
            "deepseek-coder-6.7b-instruct",
        ],
        "claude-3-haiku": [
            "llama-3-8b-instruct",
            "mistral-7b-instruct",
        ],
    },
)

default_transcription_failover = FailoverConfig(
    model_type="transcription",
    alternatives={
        "whisper-large-v3": ["whisper-medium", "whisper-small"],
        "whisper-medium": ["whisper-small", "whisper-base"],
        "whisper-small": ["whisper-base", "whisper-tiny"],
    },
)

default_video_failover = FailoverConfig(
    model_type="video",
    alternatives={
        "internvideo-14b": ["internvideo-7b", "videollama-7b"],
        "videollama-7b": ["internvideo-7b", "videollama-3b"],
    },
)

# Make the built-in tables active at import time (text, transcription, video)
failover_manager.register_config(default_text_failover)
failover_manager.register_config(default_transcription_failover)
failover_manager.register_config(default_video_failover)
263  
class FailoverMiddleware(BaseHTTPMiddleware):
    """
    Middleware that manages automatic failover between models in case of errors.

    Every non-excluded request that names a model is first forwarded
    unchanged. When that attempt raises, or returns an error status that is
    not a plain client error, the model is marked as failed in the
    module-level ``failover_manager`` and the request is replayed once with
    an alternative model picked by the manager.

    NOTE(review): body inspection and rewriting only look at the cached
    ``request._body`` attribute. Nothing in this class reads the body, so a
    model named only in the JSON body is presumably visible only when
    something upstream has already cached it - confirm.
    NOTE(review): the replay calls ``call_next`` a second time with the same
    request object; whether the downstream app then sees the rewritten body
    depends on the Starlette version in use - confirm.
    """

    # Responses with these statuses are returned untouched: they describe a
    # problem with the request itself (422 is FastAPI's request-validation
    # error), so they must neither mark the model as failed nor trigger a
    # failover that would fail the same way.
    NON_FAILOVER_STATUS_CODES = frozenset({400, 401, 403, 404, 422})

    # HTTP methods whose JSON body may carry the model identifier.
    BODY_METHODS = ("POST", "PUT", "PATCH")

    # Body keys that may hold the model identifier, in lookup order.
    MODEL_BODY_KEYS = ("model", "model_id", "engine_id")

    def __init__(
        self,
        app: ASGIApp,
        exclude_paths: Optional[List[str]] = None,
        exclude_prefixes: Optional[List[str]] = None,
        default_model_type: str = "text",
    ):
        """
        Initializes the failover middleware.

        Args:
            app: ASGI Application
            exclude_paths: Exact paths excluded from failover processing;
                None selects the built-in defaults, an empty list excludes
                nothing
            exclude_prefixes: Path prefixes excluded from failover
                processing; None selects the built-in defaults
            default_model_type: Model type used when the path matches no
                known indicator
        """
        super().__init__(app)
        # Compare against None so that an explicit empty list is honoured
        # instead of being replaced by the defaults.
        if exclude_paths is None:
            exclude_paths = ["/api/health", "/api/docs", "/api/redoc", "/api/openapi.json"]
        if exclude_prefixes is None:
            exclude_prefixes = ["/static/", "/docs/"]
        self.exclude_paths = exclude_paths
        self.exclude_prefixes = exclude_prefixes
        self.default_model_type = default_model_type

        # Model type indicators in URLs. The first substring found in the
        # path wins, so the specific "/inference/<kind>" entries must come
        # before the generic "/inference/" fallback or they can never match.
        self.model_type_indicators = {
            "/transcription/": "transcription",
            "/video/": "video",
            "/inference/text": "text",
            "/inference/embedding": "embedding",
            "/inference/image": "image",
            "/inference/": "text",
        }

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Processes requests with failover strategy.

        Returns the downstream response when the original request succeeds
        or fails with a client error, the alternative model's response
        (tagged with an ``X-Model-Failover`` header) when the replay
        succeeds, and otherwise a JSON error response (503 or 500).
        """
        if self._is_excluded_path(request.url.path):
            return await call_next(request)

        model_type, model_id = self._extract_model_info(request)

        # If no model is specified, no failover is possible
        if not model_id:
            return await call_next(request)

        try:
            response = await call_next(request)
        except Exception as e:
            # Treated like a server-side failure; the model is marked as
            # failed exactly once below.
            logger.error(f"Exception during original request: {str(e)}")
        else:
            if response.status_code < 400:
                failover_manager.mark_model_success(model_id)
                return response

            if response.status_code in self.NON_FAILOVER_STATUS_CODES:
                return response

        # Reaching this point means the original attempt failed in a way
        # that is attributed to the model.
        failover_manager.mark_model_failure(model_id)

        alternative_model = failover_manager.get_alternative_model(model_type, model_id)
        if not alternative_model:
            logger.warning(f"No alternative model available for {model_id} of type {model_type}")

            return JSONResponse(
                status_code=503,
                content={
                    "detail": "The requested model is temporarily unavailable and no alternative is available",
                    "model": model_id,
                    "type": model_type,
                    "retry_after": 300  # Suggest a delay of 5 minutes
                }
            )

        logger.info(f"Attempting failover from {model_id} to {alternative_model}")

        try:
            modified_request = await self._create_modified_request(request, model_id, alternative_model)
            response = await call_next(modified_request)
        except Exception as e:
            logger.error(f"Exception during failover: {str(e)}")
            failover_manager.mark_model_failure(alternative_model)
            failover_manager.record_failover(model_id, alternative_model, False, str(e))

            return JSONResponse(
                status_code=500,
                content={
                    "detail": "Error during failover processing",
                    "message": str(e),
                    "original_model": model_id,
                    "alternative_model": alternative_model
                }
            )

        if response.status_code < 400:
            failover_manager.record_failover(model_id, alternative_model, True)
            failover_manager.mark_model_success(alternative_model)

            # Add a header to inform the client of the failover
            response.headers["X-Model-Failover"] = f"Original: {model_id}, Alternative: {alternative_model}"
            return response

        failover_manager.mark_model_failure(alternative_model)
        failover_manager.record_failover(model_id, alternative_model, False,
                                         f"Status code: {response.status_code}")

        return JSONResponse(
            status_code=503,
            content={
                "detail": "All available models have failed",
                "original_model": model_id,
                "alternative_model": alternative_model,
                "retry_after": 600  # Suggest a longer delay
            }
        )

    def _is_excluded_path(self, path: str) -> bool:
        """Checks if the path is excluded from failover processing."""
        if path in self.exclude_paths:
            return True
        return any(path.startswith(prefix) for prefix in self.exclude_prefixes)

    def _load_cached_json_body(self, request: Request) -> Optional[Dict[str, Any]]:
        """Returns the cached request body parsed as a JSON object, else None."""
        body_bytes = getattr(request, "_body", None)
        if not body_bytes:
            return None
        try:
            parsed = json.loads(body_bytes.decode("utf-8"))
        except (ValueError, AttributeError) as e:
            # Not UTF-8 JSON (or not bytes): best effort, nothing to inspect.
            logger.debug(f"Request body is not JSON, skipping model lookup: {str(e)}")
            return None
        return parsed if isinstance(parsed, dict) else None

    def _find_model_key(self, body: Dict[str, Any]) -> Optional[str]:
        """Returns the first known body key that holds a non-empty string."""
        for key in self.MODEL_BODY_KEYS:
            value = body.get(key)
            if isinstance(value, str) and value:
                return key
        return None

    def _extract_model_info(self, request: Request) -> Tuple[str, Optional[str]]:
        """
        Extracts the model type and model identifier from the request.

        The type comes from the first entry of ``model_type_indicators``
        whose key occurs in the URL path (``default_model_type`` when none
        does). The identifier comes from the ``model`` query parameter or,
        for POST/PUT/PATCH requests, from the cached JSON body.

        Args:
            request: FastAPI Request

        Returns:
            Tuple containing the model type and the model identifier
            (None when the request names no model)
        """
        path = request.url.path
        model_type = self.default_model_type
        for indicator, type_value in self.model_type_indicators.items():
            if indicator in path:
                model_type = type_value
                break

        model_id = request.query_params.get("model")

        if not model_id and request.method in self.BODY_METHODS:
            body = self._load_cached_json_body(request)
            if body is not None:
                key = self._find_model_key(body)
                if key is not None:
                    model_id = body[key]

        return model_type, model_id

    async def _create_modified_request(self, request: Request, original_model: str, alternative_model: str) -> Request:
        """
        Prepares the request for replay with the alternative model.

        No copy is made: the same request object is annotated through
        ``request.state`` (``alternative_model`` / ``original_model``) and,
        for POST/PUT/PATCH requests with a cached JSON object body, the key
        that names the model is rewritten in place. The query string is not
        changed.

        Args:
            request: Original request
            original_model: Original model identifier
            alternative_model: Alternative model identifier

        Returns:
            The same request object, modified
        """
        request.state.alternative_model = alternative_model
        request.state.original_model = original_model

        if request.method in self.BODY_METHODS:
            body = self._load_cached_json_body(request)
            if body is not None:
                # Use the same key selection as _extract_model_info so the
                # field that named the model is the one that gets replaced.
                key = self._find_model_key(body)
                if key is not None:
                    body[key] = alternative_model
                    # NOTE(review): the Content-Length header is not
                    # adjusted to the new body size - confirm downstream
                    # consumers do not rely on it.
                    request._body = json.dumps(body).encode("utf-8")

        return request
487  
# Public accessor for the health report of the shared manager
def get_models_health() -> Dict[str, Any]:
    """Returns the health report built by the module-level failover manager."""
    report = failover_manager.get_model_health_report()
    return report
492  
# Manual override: flag a known model as healthy again
def reset_model_status(model_id: str) -> Dict[str, Any]:
    """
    Manually resets a model's status.

    Only models the manager already tracks can be reset. The reset goes
    through ``mark_model_success``, so it is counted like any other success.

    Returns:
        Dict with a "success" flag and a human-readable "message"
    """
    if model_id not in failover_manager.model_status:
        return {"success": False, "message": f"Model {model_id} not found"}

    failover_manager.mark_model_success(model_id)
    return {"success": True, "message": f"Model status {model_id} reset"}
501  
# Function to configure alternatives for a model
def configure_failover(
    model_type: str,
    model_id: str,
    alternatives: List[str],
    max_retries: int = 3,
    cooldown_period: int = 300
) -> Dict[str, Any]:
    """
    Configures failover alternatives for a model.

    Creates a configuration for ``model_type`` when none exists; otherwise
    sets the alternatives of ``model_id`` inside the existing one. In both
    cases ``max_retries`` and ``cooldown_period`` are applied to the whole
    model type, not just to ``model_id``.

    Args:
        model_type: Model type
        model_id: Model identifier
        alternatives: List of alternative model identifiers
        max_retries: Maximum number of retry attempts
        cooldown_period: Cooling period in seconds

    Returns:
        Result dictionary with a "success" flag and a "message"
    """
    # Store our own list so later changes to the caller's list do not
    # silently alter the live configuration.
    alternatives = list(alternatives)

    config = failover_manager.configs.get(model_type)
    if config is None:
        failover_manager.register_config(
            FailoverConfig(
                model_type=model_type,
                alternatives={model_id: alternatives},
                max_retries=max_retries,
                cooldown_period=cooldown_period
            )
        )
    else:
        # Update the existing configuration
        config.alternatives[model_id] = alternatives
        config.max_retries = max_retries
        config.cooldown_period = cooldown_period

        # register_config is not re-run on this path, so give the newly
        # referenced models a status entry here; without one the manager
        # does not consider an alternative when picking a replacement.
        for tracked_model in (model_id, *alternatives):
            failover_manager._ensure_model_status(tracked_model)

    return {
        "success": True,
        "message": f"Failover configuration updated for {model_id} of type {model_type}"
    }