# main.py
"""
Main entry point for the Cerastes API
-----------------------------------------
This module initializes the FastAPI application and mounts the various routers.
"""

import os
import logging
import time
import traceback
from pathlib import Path
from fastapi import FastAPI, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
from fastapi.openapi.utils import get_openapi

# Importing routers
from api import (
    health_router,
    error_handlers,
    inference_router,
    transcription_router,
    video_router,
    subscription_router,
    task_router,
    auth_router,
)

# Importing middlewares
from middleware import SecurityMiddleware, RateLimitMiddleware, CacheMiddleware, TranslationMiddleware, FailoverMiddleware
# Import APIKeyMiddleware from its dedicated module
from api_key_middleware import APIKeyMiddleware

# Importing configuration
from config import setup_logging, app_config, model_config, api_config

# Initial logging configuration
setup_logging()
logger = logging.getLogger("api.main")

# Creating necessary directories
for directory in ["inference_results", "uploads", "results", "logs", "cache", "translation_models"]:
    Path(directory).mkdir(parents=True, exist_ok=True)

# FastAPI application initialization
app = FastAPI(
    title="Cerastes API",
    description="API for advanced analysis of multimedia and textual content",
    version=app_config.get("version", "1.0.0"),
    docs_url=None,  # Built-in docs disabled; served by a custom route at /api/docs
    redoc_url=None,  # Built-in ReDoc disabled; served by a custom route at /api/redoc
    openapi_url="/api/openapi.json",
)

# Paths (exact match) and path prefixes that the middlewares below skip.
common_exclude_paths = [
    "/api/health",
    "/auth/token",
    "/auth/register",
    "/api/docs",
    "/api/redoc",
    "/api/openapi.json",
    "/static",
]
common_exclude_prefixes = ["/static/", "/docs/", "/assets/"]

# Middleware registration order
# -----------------------------
# Starlette's ``add_middleware`` inserts each middleware at the FRONT of the
# stack, so the middleware registered LAST is the OUTERMOST one and sees the
# request FIRST. Of the registrations in this section, a request therefore
# passes Translation -> Cache -> RateLimit -> Security -> CORS (the reverse of
# the order written here), wrapped by anything registered afterwards.
# NOTE(review): because CORSMiddleware ends up innermost, browser preflight
# (OPTIONS) requests must get through every outer middleware before CORS can
# answer them — confirm the outer middlewares let preflights pass.

# CORS configuration.
# CORS_ORIGINS is a comma-separated list. Entries are stripped so that
# "a.com, b.com" does not yield " b.com" (which never matches an Origin
# header), and empty entries are dropped.
cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "*").split(",")
    if origin.strip()
]
# With a wildcard origin, Starlette's CORSMiddleware answers credentialed
# requests by echoing the caller's Origin back together with
# Access-Control-Allow-Credentials, which would let *any* site make
# credentialed cross-origin requests. Credentials are therefore only allowed
# for an explicit origin allow-list.
cors_allow_credentials = "*" not in cors_origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_credentials=cors_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security middleware (security headers, origin/method checks).
# Registered early, so it sits close to the application (see ordering note above).
app.add_middleware(
    SecurityMiddleware,
    enable_xss_protection=True,
    enable_hsts=True,
    enable_content_type_options=True,
    enable_frame_options=True,
    enable_referrer_policy=True,
    enable_csp=True,
    enable_cors_protection=True,
    allowed_origins=cors_origins,
    allowed_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
)

# Rate limit middleware - Protection against abuse
app.add_middleware(
    RateLimitMiddleware,
    global_rate_limit=api_config.get("global_rate_limit", 1000),
    ip_rate_limit=api_config.get("ip_rate_limit", 100),
    api_key_rate_limit=api_config.get("api_key_rate_limit", 200),
    window_size=api_config.get("rate_limit_window", 60),
    exclude_paths=common_exclude_paths,
    exclude_prefixes=common_exclude_prefixes,
)

# Cache middleware - Performance improvement.
# Task endpoints are excluded because their results change over time.
app.add_middleware(
    CacheMiddleware,
    ttl=api_config.get("cache_ttl", 300),
    max_size=api_config.get("cache_max_size", 1000),
    include_prefixes=["/api/inference/", "/api/video/", "/api/transcription/"],
    exclude_paths=common_exclude_paths + ["/api/tasks"],
    exclude_prefixes=common_exclude_prefixes,
    cache_query_params=True,
    cache_by_api_key=True,
)

# Translation middleware - Internationalization
app.add_middleware(
    TranslationMiddleware,
    exclude_paths=common_exclude_paths,
    exclude_prefixes=common_exclude_prefixes,
    text_field_names=["text", "content", "prompt", "transcription", "question"],
)

# Failover middleware - Resilience
app.add_middleware(
    FailoverMiddleware,
    exclude_paths=common_exclude_paths,
    exclude_prefixes=common_exclude_prefixes,
    default_model_type="text",
)

# API key authentication middleware.
# Starlette's add_middleware prepends to the middleware stack, so this
# middleware (the last one registered via add_middleware) runs BEFORE the ones
# registered earlier; the @app.middleware("http") functions defined further
# below wrap it in turn.
# NOTE(review): unlike common_exclude_prefixes, these prefixes omit "/assets/",
# and "/" is not an excluded path although the frontend is mounted at "/"
# below — confirm whether frontend assets are meant to require an API key.
app.add_middleware(
    APIKeyMiddleware,
    exclude_paths=common_exclude_paths,
    exclude_prefixes=["/static/", "/docs/"],
    admin_paths=["/admin/"],
)


# Custom routes for documentation.
# NOTE(review): the asset URLs point at /static, which is only mounted when a
# local "static" directory exists (see the mount below).
@app.get("/api/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    """Serve Swagger UI for /api/openapi.json using locally hosted assets."""
    return get_swagger_ui_html(
        openapi_url="/api/openapi.json",
        title="Cerastes API - Documentation",
        swagger_js_url="/static/swagger-ui-bundle.js",
        swagger_css_url="/static/swagger-ui.css",
    )


@app.get("/api/redoc", include_in_schema=False)
async def custom_redoc_html():
    """Serve ReDoc for /api/openapi.json using a locally hosted bundle."""
    return get_redoc_html(
        openapi_url="/api/openapi.json",
        title="Cerastes API - Documentation ReDoc",
        redoc_js_url="/static/redoc.standalone.js",
    )


# OpenAPI schema customization
def custom_openapi():
    """Build the OpenAPI schema once, enrich it, and cache it on ``app.openapi_schema``.

    Adds contact, license and tag descriptions on top of FastAPI's generated schema.

    Returns:
        dict: the (cached) OpenAPI schema.
    """
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )

    # Add contact and license information
    openapi_schema["info"]["contact"] = {
        "name": "Cerastes API Support",
        "url": "https://cerastes.ai/support",
        "email": "support@cerastes.ai",
    }

    openapi_schema["info"]["license"] = {
        "name": "Dual GPL/Commercial License",
        "url": "https://cerastes.ai/license",
    }

    # Customize tags
    openapi_schema["tags"] = [
        {"name": "health", "description": "Endpoints to check API status"},
        {"name": "inference", "description": "Endpoints for text inference"},
        {"name": "transcription", "description": "Endpoints for audio transcription"},
        {"name": "video", "description": "Endpoints for video analysis"},
        {"name": "tasks", "description": "Endpoints for task management"},
        {"name": "auth", "description": "Endpoints for authentication and authorization"},
        {"name": "subscription", "description": "Endpoints for subscription management"},
    ]

    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi

# Exception message and class name are only echoed back to clients outside
# production; in production the 500 body is limited to a generic "detail" and
# the path, to avoid leaking internals. Full details are always logged.
_EXPOSE_ERROR_DETAILS = os.getenv("ENVIRONMENT", "development").strip().lower() not in ("production", "prod")


# Custom error handling
@app.middleware("http")
async def exception_handling(request: Request, call_next):
    """Catch exceptions that escape the inner middlewares/routes and return a JSON 500.

    Successful requests are logged at DEBUG level with their processing time.
    """
    start_time = time.perf_counter()

    try:
        response = await call_next(request)
    except Exception as e:
        client_host = request.client.host if request.client else "Unknown"
        # logger.exception records the current traceback along with the message.
        logger.exception(
            "Unhandled exception: %s | Path: %s | Method: %s | Client: %s",
            e,
            request.url.path,
            request.method,
            client_host,
        )

        content = {"detail": "Internal server error"}
        if _EXPOSE_ERROR_DETAILS:
            content["message"] = str(e)
            content["type"] = type(e).__name__
        content["path"] = request.url.path
        return JSONResponse(status_code=500, content=content)

    # Basic logging for successful requests
    process_time = time.perf_counter() - start_time
    logger.debug(f"{request.method} {request.url.path} - {response.status_code} - {process_time:.3f}s")
    return response


# Processing-time header (registered last, so it is the outermost middleware)
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Add an ``X-Process-Time`` header (seconds, as a string) to every response."""
    start_time = time.perf_counter()
    response = await call_next(request)
    response.headers["X-Process-Time"] = str(time.perf_counter() - start_time)
    return response


# Mounting routers
app.include_router(health_router, prefix="/api")
app.include_router(inference_router, prefix="/api")
app.include_router(video_router, prefix="/api/video")
app.include_router(transcription_router, prefix="/api/transcription")
app.include_router(subscription_router, prefix="/api/subscription")
app.include_router(task_router, prefix="/api/tasks")
app.include_router(auth_router, prefix="/auth")

# Mounting error handlers
error_handlers.register_exception_handlers(app)

# Static files - only if directory exists
static_dir = Path("static")
if static_dir.is_dir():
    app.mount("/static", StaticFiles(directory="static"), name="static")

# Frontend - mounted at root. Mounted after the routers, so API routes are
# matched first and the SPA only receives the remaining paths.
frontend_dir = Path("frontend/dist")
if frontend_dir.is_dir():
    app.mount("/", StaticFiles(directory="frontend/dist", html=True), name="frontend")


@app.on_event("startup")
async def startup_event():
    """Executed at application startup.

    Each initialization step has its own try/except so that a failing
    subsystem is logged without preventing the API from starting.
    """
    logger.info("=== Starting Cerastes API ===")

    # Database initialization
    try:
        from db.init_db import init_db
        init_db()
        logger.info("Database initialized successfully")
    except Exception as e:
        logger.exception("Error initializing database: %s", e)

    # Model manager initialization
    try:
        from model_manager import ModelManager
        ModelManager.initialize()
        logger.info("Model manager initialized successfully")

        # Preload models if configured; a failing model does not stop the others.
        if model_config.get("preload_models", False):
            logger.info("Model preloading requested...")
            preload_list = model_config.get("preload_list", [])
            for model_name in preload_list:
                try:
                    logger.info(f"Preloading model {model_name}...")
                    ModelManager.get_instance().load_model(model_name)
                except Exception as e:
                    logger.warning(f"Failed to preload model {model_name}: {str(e)}")
    except Exception as e:
        logger.exception("Error initializing model manager: %s", e)

    # Post-processors initialization
    try:
        from postprocessors.json_simplifier import JSONSimplifier
        from config import load_config

        config = load_config()
        json_simplifier = JSONSimplifier(config.get("postprocessing", {}).get("json_simplifier", {}))
        app.state.json_simplifier = json_simplifier

        if json_simplifier.enabled:
            logger.info(f"JSONSimplifier post-processor enabled for: {', '.join(json_simplifier.apply_to)}")
    except Exception as e:
        logger.exception("Error initializing JSONSimplifier: %s", e)

    # Advanced middleware initialization
    try:
        # Check model health for failover
        from middleware.failover_middleware import get_models_health
        health_report = get_models_health()
        logger.info(f"Failover middleware initialized with {len(health_report['models'])} configured models")

        # Cache initialization
        from middleware.cache_middleware import get_cache_stats
        logger.info(f"Cache middleware initialized: {get_cache_stats()}")
    except Exception as e:
        logger.warning(f"Error initializing advanced middleware: {str(e)}")

    # Startup information
    logger.info(f"Version: {app.version}")
    logger.info(f"Environment: {os.getenv('ENVIRONMENT', 'development')}")
    logger.info(f"Log level: {os.getenv('LOG_LEVEL', 'INFO')}")

    # GPU verification
    try:
        import torch
        gpu_available = torch.cuda.is_available()
        gpu_count = torch.cuda.device_count() if gpu_available else 0
        logger.info(f"GPU available: {gpu_available}, GPU count: {gpu_count}")

        for i in range(gpu_count):
            total_gb = torch.cuda.get_device_properties(i).total_memory / 1024**3
            logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}, Total memory: {total_gb:.2f} GB")
    except ImportError:
        logger.warning("PyTorch not available, operating in CPU-only mode")
    except Exception as e:
        logger.warning(f"Error checking GPUs: {str(e)}")

    # Data path verification: log each path and whether it is a writable directory
    for path_name, path in [
        ("Uploads", Path("uploads")),
        ("Results", Path("results")),
        ("Logs", Path("logs")),
        ("Cache", Path("cache")),
        ("Translation models", Path("translation_models")),
    ]:
        logger.info(f"{path_name}: {path.absolute()} ({path.exists() and path.is_dir() and os.access(path, os.W_OK)})")

    # System prompts loading
    try:
        from config import get_system_prompts
        prompts, prompt_order = get_system_prompts()
        logger.info(f"System prompts loaded: {len(prompts)} prompts in order: {', '.join(prompt_order)}")
    except Exception as e:
        logger.exception("Error loading system prompts: %s", e)

    # Mounted middleware verification.
    # app.user_middleware holds Starlette ``Middleware`` wrappers, so the
    # wrapped class is ``m.cls`` (``m.__class__`` would always be "Middleware").
    middleware_list = [getattr(m.cls, "__name__", repr(m.cls)) for m in app.user_middleware]
    logger.info(f"Active middlewares: {', '.join(middleware_list)}")

    logger.info("Cerastes API started successfully and ready to receive requests!")


@app.on_event("shutdown")
async def shutdown_event():
    """Executed at application shutdown; each release step is best-effort and only logs failures."""
    logger.info("=== Stopping Cerastes API ===")

    # Release model resources
    try:
        from model_manager import ModelManager
        logger.info("Releasing models from memory...")
        ModelManager.cleanup()
    except Exception as e:
        logger.error(f"Error releasing models: {str(e)}")

    # Release middleware resources
    try:
        # Translator resources
        from middleware.translation_middleware import translation_manager
        logger.info("Releasing translation models...")
        translation_manager.close()

        # Cache resources
        from middleware.cache_middleware import invalidate_cache
        logger.info("Cleaning cache...")
        invalidate_cache()

        logger.info("Middleware resources released successfully")
    except Exception as e:
        logger.error(f"Error releasing middleware resources: {str(e)}")

    # Temporary files cleanup: delete regular files older than 24h located
    # directly inside each directory (sub-directories are not traversed).
    try:
        logger.info("Cleaning temporary files...")
        temp_dirs = ["uploads", "results/transcriptions", "cache"]
        cutoff_timestamp = time.time() - 24 * 60 * 60

        for temp_dir in temp_dirs:
            directory = Path(temp_dir)
            if not directory.is_dir():
                continue
            for item_path in directory.iterdir():
                # Per-file try: a file that vanishes or cannot be removed must
                # not abort the cleanup of the remaining files.
                try:
                    if item_path.is_file() and item_path.stat().st_mtime < cutoff_timestamp:
                        item_path.unlink()
                        logger.debug(f"Temporary file deleted: {item_path}")
                except OSError as e:
                    logger.warning(f"Unable to delete {item_path}: {str(e)}")
    except Exception as e:
        logger.error(f"Error cleaning temporary files: {str(e)}")

    # Release CUDA resources
    try:
        import torch
        if torch.cuda.is_available():
            logger.info("Releasing CUDA memory...")
            torch.cuda.empty_cache()
    except ImportError:
        pass
    except Exception as e:
        logger.error(f"Error releasing CUDA memory: {str(e)}")

    # Close database connections
    try:
        from db import engine
        logger.info("Closing database connections...")
        engine.dispose()
    except Exception as e:
        logger.error(f"Error closing DB connections: {str(e)}")

    logger.info("Cerastes API shutdown completed")


# Entry point for direct execution
if __name__ == "__main__":
    import uvicorn

    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "8000"))
    reload_enabled = os.getenv("RELOAD", "false").lower() == "true"

    logger.info(f"Starting server on {host}:{port} (reload: {reload_enabled})")

    uvicorn.run(
        "main:app",
        host=host,
        port=port,
        reload=reload_enabled,
        log_level=os.getenv("LOG_LEVEL", "info").lower(),
    )