# File: middleware/failover_middleware.py
"""
Failover middleware for AI models.

Automatically routes a request to an alternative model when the requested
model fails, and tracks per-model health (failures, recoveries, cooldowns)
to improve service availability.
"""

import json
import time
import logging
import traceback
from typing import Dict, List, Any, Optional, Callable, Tuple, Set
from collections import defaultdict
import random
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

# Logging configuration
logger = logging.getLogger("failover_middleware")


class FailoverConfig:
    """Configuration of failover strategies for a model type."""

    def __init__(
        self,
        model_type: str,
        alternatives: Dict[str, List[str]],
        max_retries: int = 3,
        backoff_factor: float = 1.5,
        jitter: float = 0.1,
        cooldown_period: int = 300,  # 5 minutes
    ):
        """
        Initializes the failover configuration.

        Args:
            model_type: Model type (text, video, transcription, etc.)
            alternatives: Mapping from a primary model ID to the model IDs
                that may replace it when it fails.
            max_retries: Maximum number of retry attempts.
            backoff_factor: Exponential backoff factor.
            jitter: Random variation to avoid request storms.
            cooldown_period: Base cooling period (seconds) before a failed
                model becomes eligible again; ModelStatus.should_retry scales
                it by the model's failure count since its last success
                (capped at 5x).

        NOTE(review): max_retries, backoff_factor and jitter are stored but
        not read anywhere in this module; FailoverMiddleware performs a
        single failover attempt without backoff.
        """
        self.model_type = model_type
        self.alternatives = alternatives
        self.max_retries = max_retries
        self.backoff_factor = backoff_factor
        self.jitter = jitter
        self.cooldown_period = cooldown_period


class ModelStatus:
    """Availability status of a single model."""

    def __init__(self, model_id: str):
        self.model_id = model_id
        self.available = True
        # Failures since the last success; reset by mark_success().
        self.failure_count = 0
        # time.time() of the most recent failure (0 means "never failed").
        self.last_failure_time = 0
        # Number of transitions from unavailable back to available.
        self.recovery_count = 0
        # Lifetime failure count; never reset.
        self.cumulative_errors = 0

    def mark_failure(self) -> None:
        """Marks the model as having failed and (re)starts its cooldown."""
        self.available = False
        self.failure_count += 1
        self.cumulative_errors += 1
        self.last_failure_time = time.time()

    def mark_success(self) -> None:
        """Marks the model as functional, counting a recovery if it was down."""
        if not self.available:
            self.recovery_count += 1
        self.available = True
        self.failure_count = 0

    def should_retry(self, cooldown_period: int) -> bool:
        """
        Checks whether the model may be used.

        Available models are always eligible. An unavailable model becomes
        eligible again once ``cooldown_period * min(5, failure_count)``
        seconds have elapsed since its last failure.
        """
        if self.available:
            return True

        # Progressive cooldown: longer after repeated failures, capped at 5x.
        adjusted_cooldown = cooldown_period * min(5, self.failure_count)

        return (time.time() - self.last_failure_time) > adjusted_cooldown

    def __str__(self) -> str:
        return (f"ModelStatus(model_id={self.model_id}, available={self.available}, "
                f"failures={self.failure_count}, recoveries={self.recovery_count}, "
                f"total_errors={self.cumulative_errors})")


class FailoverManager:
    """Centralized manager for failover configurations and model health."""

    def __init__(self):
        # Configurations keyed by model type
        self.configs: Dict[str, FailoverConfig] = {}

        # Health status keyed by model ID
        self.model_status: Dict[str, ModelStatus] = {}

        # Bounded failover history for analysis (oldest entries are dropped)
        self.failover_history: List[Dict[str, Any]] = []
        self.history_max_size = 100

        # Metrics counters
        self.metrics = {
            "total_failovers": 0,
            "successful_failovers": 0,
            "failed_failovers": 0,
            "models_recovered": 0
        }

    def register_config(self, config: FailoverConfig) -> None:
        """
        Registers (or replaces) the configuration for ``config.model_type``.

        Also creates a ModelStatus for every primary and alternative model it
        mentions; existing statuses are kept. Calling this again with the same
        config is therefore safe and picks up newly added model IDs.
        """
        self.configs[config.model_type] = config

        for primary, alternatives in config.alternatives.items():
            self._ensure_model_status(primary)
            for alt in alternatives:
                self._ensure_model_status(alt)

    def _ensure_model_status(self, model_id: str) -> None:
        """Ensures a status exists for the given model."""
        if model_id not in self.model_status:
            self.model_status[model_id] = ModelStatus(model_id)

    def get_alternative_model(self, model_type: str, original_model: str) -> Optional[str]:
        """
        Gets an available alternative model for the given model.

        Only alternatives that have a tracked ModelStatus and whose
        ``should_retry`` passes are considered; one is picked at random.

        Args:
            model_type: Model type (text, video, transcription, etc.)
            original_model: Original model identifier

        Returns:
            Identifier of an available alternative model or None if none is available
        """
        if model_type not in self.configs:
            logger.warning(f"No failover configuration for model type: {model_type}")
            return None

        config = self.configs[model_type]

        if original_model not in config.alternatives:
            logger.warning(f"No alternatives configured for model: {original_model}")
            return None

        alternatives = config.alternatives[original_model]
        if not alternatives:
            return None

        available_alternatives = [
            alt for alt in alternatives
            if alt in self.model_status and self.model_status[alt].should_retry(config.cooldown_period)
        ]

        if not available_alternatives:
            logger.warning(f"No available alternatives for {original_model}")
            return None

        # Random choice acts as simple load balancing across healthy alternatives
        return random.choice(available_alternatives)

    def mark_model_failure(self, model_id: str) -> None:
        """Marks a model as having failed."""
        self._ensure_model_status(model_id)
        self.model_status[model_id].mark_failure()
        logger.warning(f"Model marked as failed: {model_id}")

    def mark_model_success(self, model_id: str) -> None:
        """Marks a model as functional, counting a recovery if it was down."""
        self._ensure_model_status(model_id)

        if not self.model_status[model_id].available:
            self.metrics["models_recovered"] += 1
            logger.info(f"Model recovered: {model_id}")

        self.model_status[model_id].mark_success()

    def record_failover(self, original_model: str, alternative_model: str, success: bool, error: Optional[str] = None) -> None:
        """Records a failover event in the bounded history and updates metrics."""
        event = {
            "timestamp": time.time(),
            "original_model": original_model,
            "alternative_model": alternative_model,
            "success": success,
            "error": error
        }

        self.failover_history.append(event)

        # Keep at most history_max_size entries
        if len(self.failover_history) > self.history_max_size:
            self.failover_history.pop(0)

        self.metrics["total_failovers"] += 1
        if success:
            self.metrics["successful_failovers"] += 1
        else:
            self.metrics["failed_failovers"] += 1

    def get_model_health_report(self) -> Dict[str, Any]:
        """Generates a report of the metrics and each tracked model's status."""
        report = {
            "metrics": self.metrics.copy(),
            "models": {}
        }

        for model_id, status in self.model_status.items():
            report["models"][model_id] = {
                "available": status.available,
                "failure_count": status.failure_count,
                "recovery_count": status.recovery_count,
                "last_failure": status.last_failure_time,
                "total_errors": status.cumulative_errors
            }

        return report


# Create a single instance of the manager
failover_manager = FailoverManager()

# Default failover configuration for different model types
default_text_failover = FailoverConfig(
    model_type="text",
    alternatives={
        "deepseek-coder-33b-instruct": ["deepseek-coder-6.7b-instruct", "codellama-7b-instruct"],
        "llama-3-70b-instruct": ["llama-3-8b-instruct", "mistral-7b-instruct"],
        "mistral-7b-instruct": ["llama-3-8b-instruct", "deepseek-coder-6.7b-instruct"],
        "claude-3-haiku": ["llama-3-8b-instruct", "mistral-7b-instruct"],
    }
)

default_transcription_failover = FailoverConfig(
    model_type="transcription",
    alternatives={
        "whisper-large-v3": ["whisper-medium", "whisper-small"],
        "whisper-medium": ["whisper-small", "whisper-base"],
        "whisper-small": ["whisper-base", "whisper-tiny"],
    }
)

default_video_failover = FailoverConfig(
    model_type="video",
    alternatives={
        "internvideo-14b": ["internvideo-7b", "videollama-7b"],
        "videollama-7b": ["internvideo-7b", "videollama-3b"],
    }
)

# Register default configurations
failover_manager.register_config(default_text_failover)
failover_manager.register_config(default_transcription_failover)
failover_manager.register_config(default_video_failover)


class FailoverMiddleware(BaseHTTPMiddleware):
    """
    Middleware that manages automatic failover between models in case of errors.
    Improves system resilience in the face of specific model problems.
    """

    # Body keys that may carry the model identifier, in lookup priority order.
    # Shared by _extract_model_info and _create_modified_request so the key
    # that is rewritten is always the key the model was read from.
    MODEL_BODY_KEYS = ("model", "model_id", "engine_id")

    # 4xx statuses that still trigger a failover: timeouts and rate limits
    # reflect the model's state rather than a malformed request.
    FAILOVER_CLIENT_STATUSES = frozenset({408, 429})

    def __init__(
        self,
        app: ASGIApp,
        exclude_paths: Optional[List[str]] = None,
        exclude_prefixes: Optional[List[str]] = None,
        default_model_type: str = "text",
    ):
        """
        Initializes the failover middleware.

        Args:
            app: ASGI Application
            exclude_paths: Exact paths excluded from failover processing
            exclude_prefixes: Path prefixes excluded from failover processing
            default_model_type: Model type used when no path indicator matches
        """
        super().__init__(app)
        self.exclude_paths = exclude_paths or ["/api/health", "/api/docs", "/api/redoc", "/api/openapi.json"]
        self.exclude_prefixes = exclude_prefixes or ["/static/", "/docs/"]
        self.default_model_type = default_model_type

        # Model type indicators in URLs. The first indicator found in the path
        # wins, so more specific entries must precede the generic "/inference/"
        # (otherwise embedding/image requests would be classified as "text").
        self.model_type_indicators = {
            "/transcription/": "transcription",
            "/video/": "video",
            "/inference/text": "text",
            "/inference/embedding": "embedding",
            "/inference/image": "image",
            "/inference/": "text",
        }

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Forwards the request and, on a model-side failure, retries it once
        against an alternative model.

        Flow:
          1. Excluded paths and requests without an identifiable model are
             passed straight through.
          2. A response with status < 400 marks the model healthy and is returned.
          3. Client errors (4xx other than 408/429) are returned unchanged: they
             describe the request, not the model, so they neither mark the
             model as failed nor trigger a failover.
          4. Otherwise (5xx, 408, 429, or an exception) the model is marked as
             failed exactly once and a single alternative is tried.
        """
        if self._is_excluded_path(request.url.path):
            return await call_next(request)

        model_type, model_id = self._extract_model_info(request)

        # Without a model identifier no failover is possible
        if not model_id:
            return await call_next(request)

        try:
            response = await call_next(request)
        except Exception as e:
            # Fall through to failover; the failure is recorded once below.
            logger.error(f"Exception during original request: {str(e)}", exc_info=True)
        else:
            if response.status_code < 400:
                failover_manager.mark_model_success(model_id)
                return response

            if not self._should_failover(response.status_code):
                return response

        # The original request failed on the model side
        failover_manager.mark_model_failure(model_id)

        alternative_model = failover_manager.get_alternative_model(model_type, model_id)
        if not alternative_model:
            logger.warning(f"No alternative model available for {model_id} of type {model_type}")

            return JSONResponse(
                status_code=503,
                content={
                    "detail": "The requested model is temporarily unavailable and no alternative is available",
                    "model": model_id,
                    "type": model_type,
                    "retry_after": 300  # Suggest a delay of 5 minutes
                }
            )

        logger.info(f"Attempting failover from {model_id} to {alternative_model}")

        try:
            modified_request = await self._create_modified_request(request, model_id, alternative_model)

            # NOTE(review): call_next is invoked a second time for the same
            # incoming request. Whether the downstream app can re-read the
            # (possibly rewritten) body depends on the Starlette version's
            # BaseHTTPMiddleware body handling; confirm against the pinned version.
            response = await call_next(modified_request)

            if response.status_code < 400:
                failover_manager.record_failover(model_id, alternative_model, True)
                failover_manager.mark_model_success(alternative_model)

                # Let the client know a different model served the request
                response.headers["X-Model-Failover"] = f"Original: {model_id}, Alternative: {alternative_model}"

                return response

            failover_manager.mark_model_failure(alternative_model)
            failover_manager.record_failover(model_id, alternative_model, False,
                                             f"Status code: {response.status_code}")

            return JSONResponse(
                status_code=503,
                content={
                    "detail": "All available models have failed",
                    "original_model": model_id,
                    "alternative_model": alternative_model,
                    "retry_after": 600  # Suggest a longer delay
                }
            )

        except Exception as e:
            logger.error(f"Exception during failover: {str(e)}", exc_info=True)
            failover_manager.mark_model_failure(alternative_model)
            failover_manager.record_failover(model_id, alternative_model, False, str(e))

            return JSONResponse(
                status_code=500,
                content={
                    "detail": "Error during failover processing",
                    "message": str(e),
                    "original_model": model_id,
                    "alternative_model": alternative_model
                }
            )

    def _should_failover(self, status_code: int) -> bool:
        """Returns True when an error status warrants trying an alternative model."""
        if status_code >= 500:
            return True
        return status_code in self.FAILOVER_CLIENT_STATUSES

    def _is_excluded_path(self, path: str) -> bool:
        """Checks if the path is excluded from failover processing."""
        if path in self.exclude_paths:
            return True

        for prefix in self.exclude_prefixes:
            if path.startswith(prefix):
                return True

        return False

    def _extract_model_info(self, request: Request) -> Tuple[str, Optional[str]]:
        """
        Extracts the model type and model identifier from the request.

        The type comes from the first ``model_type_indicators`` key found in
        the path (falling back to ``default_model_type``). The model ID comes
        from the ``model`` query parameter or, for POST/PUT/PATCH, from the
        first truthy value among MODEL_BODY_KEYS in a JSON-object body.

        NOTE(review): the body is read from the private ``request._body``
        cache, which is only populated if something already awaited
        ``request.body()``; otherwise body-based extraction finds nothing.

        Args:
            request: FastAPI Request

        Returns:
            Tuple containing the model type and model identifier (or None)
        """
        model_type = self.default_model_type
        for indicator, type_value in self.model_type_indicators.items():
            if indicator in request.url.path:
                model_type = type_value
                break

        model_id = request.query_params.get("model")

        if not model_id and request.method in ["POST", "PUT", "PATCH"]:
            body_bytes = getattr(request, "_body", None)
            if body_bytes:
                try:
                    body = json.loads(body_bytes.decode("utf-8"))
                except ValueError as e:
                    # Best-effort: a non-JSON body simply has no model to extract.
                    logger.debug(f"Could not parse request body for model extraction: {str(e)}")
                else:
                    if isinstance(body, dict):
                        model_id = next(
                            (body[key] for key in self.MODEL_BODY_KEYS if body.get(key)),
                            None,
                        )

        return model_type, model_id

    async def _create_modified_request(self, request: Request, original_model: str, alternative_model: str) -> Request:
        """
        Prepares the request to target the alternative model.

        Starlette's Request cannot be cloned cheaply, so the same object is
        mutated: both model IDs are stored on ``request.state`` and, for a
        JSON-object body, the key the model was read from (first truthy
        MODEL_BODY_KEYS entry) is rewritten in the cached ``request._body``.

        NOTE(review): when the model came from the query string, only
        ``request.state.alternative_model`` carries the new model; downstream
        handlers presumably need to read it. Verify.

        Args:
            request: Original request
            original_model: Original model identifier
            alternative_model: Alternative model identifier

        Returns:
            The same request object, modified in place
        """
        request.state.alternative_model = alternative_model
        request.state.original_model = original_model

        if request.method in ["POST", "PUT", "PATCH"]:
            body_bytes = getattr(request, "_body", None)
            if body_bytes:
                try:
                    body = json.loads(body_bytes.decode("utf-8"))
                except ValueError as e:
                    logger.error(f"Error when modifying the body: {str(e)}")
                else:
                    if isinstance(body, dict):
                        for key in self.MODEL_BODY_KEYS:
                            if body.get(key):
                                body[key] = alternative_model
                                request._body = json.dumps(body).encode("utf-8")
                                break

        return request


# Function to get model health
def get_models_health() -> Dict[str, Any]:
    """Gets a report on the health status of models."""
    return failover_manager.get_model_health_report()


# Function to reset a model manually
def reset_model_status(model_id: str) -> Dict[str, Any]:
    """Manually marks a tracked model as available again."""
    if model_id in failover_manager.model_status:
        failover_manager.mark_model_success(model_id)
        return {"success": True, "message": f"Model status {model_id} reset"}
    else:
        return {"success": False, "message": f"Model {model_id} not found"}


# Function to configure alternatives for a model
def configure_failover(
    model_type: str,
    model_id: str,
    alternatives: List[str],
    max_retries: int = 3,
    cooldown_period: int = 300
) -> Dict[str, Any]:
    """
    Configures failover alternatives for a model.

    Creates a new FailoverConfig for an unknown model type, or updates the
    existing one in place (the type-wide max_retries and cooldown_period are
    overwritten as well).

    Args:
        model_type: Model type
        model_id: Model identifier
        alternatives: List of alternative model identifiers
        max_retries: Maximum number of retry attempts
        cooldown_period: Cooling period in seconds

    Returns:
        Result dictionary
    """
    if model_type not in failover_manager.configs:
        config = FailoverConfig(
            model_type=model_type,
            alternatives={model_id: alternatives},
            max_retries=max_retries,
            cooldown_period=cooldown_period
        )
        failover_manager.register_config(config)
    else:
        config = failover_manager.configs[model_type]
        config.alternatives[model_id] = alternatives
        config.max_retries = max_retries
        config.cooldown_period = cooldown_period
        # Re-register so ModelStatus entries exist for any new model IDs;
        # get_alternative_model ignores alternatives without a tracked status.
        failover_manager.register_config(config)

    return {
        "success": True,
        "message": f"Failover configuration updated for {model_id} of type {model_type}"
    }