/ main.py
main.py
  1  """
  2  Main entry point for the Cerastes API
  3  -----------------------------------------
  4  This module initializes the FastAPI application and mounts the various routers.
  5  """
  6  
  7  import os
  8  import logging
  9  import time
 10  import traceback
 11  from pathlib import Path
 12  from fastapi import FastAPI, Request, Depends
 13  from fastapi.middleware.cors import CORSMiddleware
 14  from fastapi.responses import JSONResponse
 15  from fastapi.staticfiles import StaticFiles
 16  from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
 17  from fastapi.openapi.utils import get_openapi
 18  
 19  # Importing routers
 20  from api import (
 21      health_router,
 22      error_handlers,
 23      inference_router,
 24      transcription_router, 
 25      video_router,
 26      subscription_router,
 27      task_router,
 28      auth_router
 29  )
 30  
 31  # Importing middlewares
 32  from middleware import SecurityMiddleware, RateLimitMiddleware, CacheMiddleware, TranslationMiddleware, FailoverMiddleware
# Import APIKeyMiddleware from its dedicated module
 34  from api_key_middleware import APIKeyMiddleware
 35  
 36  # Importing configuration
 37  from config import setup_logging, app_config, model_config, api_config
 38  
 39  # Initial logging configuration
 40  setup_logging()
 41  logger = logging.getLogger("api.main")
 42  
 43  # Creating necessary directories
 44  for directory in ["inference_results", "uploads", "results", "logs", "cache", "translation_models"]:
 45      Path(directory).mkdir(parents=True, exist_ok=True)
 46  
 47  # FastAPI application initialization
 48  app = FastAPI(
 49      title="Cerastes API",
 50      description="API for advanced analysis of multimedia and textual content",
 51      version=app_config.get("version", "1.0.0"),
 52      docs_url=None,  # Disabled by default, redirected to /api/docs
 53      redoc_url=None,  # Disabled by default, redirected to /api/redoc
 54      openapi_url="/api/openapi.json"
 55  )
 56  
 57  # Common excluded paths configuration
 58  common_exclude_paths = [
 59      "/api/health", 
 60      "/auth/token", 
 61      "/auth/register", 
 62      "/api/docs", 
 63      "/api/redoc", 
 64      "/api/openapi.json",
 65      "/static"
 66  ]
 67  common_exclude_prefixes = ["/static/", "/docs/", "/assets/"]
 68  
 69  # CORS configuration
 70  cors_origins = os.getenv("CORS_ORIGINS", "*").split(",")
 71  app.add_middleware(
 72      CORSMiddleware,
 73      allow_origins=cors_origins,
 74      allow_credentials=True,
 75      allow_methods=["*"],
 76      allow_headers=["*"],
 77  )
 78  
# Security middleware - registered with every protection flag enabled and the
# same origin list as the CORS middleware.
# NOTE(review): this was described as "first in the chain", but Starlette runs
# middleware in reverse registration order (the last one added is the
# outermost and sees the request first) - confirm the intended order.
app.add_middleware(
    SecurityMiddleware,
    enable_xss_protection=True,
    enable_hsts=True,
    enable_content_type_options=True,
    enable_frame_options=True,
    enable_referrer_policy=True,
    enable_csp=True,
    enable_cors_protection=True,
    allowed_origins=cors_origins,
    allowed_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"]
)

# Rate limit middleware - Protection against abuse.
# Limits and window size come from api_config, with the defaults shown here
# used when a key is missing.
app.add_middleware(
    RateLimitMiddleware,
    global_rate_limit=api_config.get("global_rate_limit", 1000),
    ip_rate_limit=api_config.get("ip_rate_limit", 100),
    api_key_rate_limit=api_config.get("api_key_rate_limit", 200),
    window_size=api_config.get("rate_limit_window", 60),
    exclude_paths=common_exclude_paths,
    exclude_prefixes=common_exclude_prefixes
)

# Cache middleware - Performance improvement.
# Only the inference, video and transcription prefixes are included; on top of
# the common exclusions, "/api/tasks" is excluded as well.
app.add_middleware(
    CacheMiddleware,
    ttl=api_config.get("cache_ttl", 300),
    max_size=api_config.get("cache_max_size", 1000),
    include_prefixes=["/api/inference/", "/api/video/", "/api/transcription/"],
    exclude_paths=common_exclude_paths + ["/api/tasks"],
    exclude_prefixes=common_exclude_prefixes,
    cache_query_params=True,
    cache_by_api_key=True
)

# Translation middleware - Internationalization.
# text_field_names lists the field names handed to the middleware;
# presumably these are the payload fields it translates - verify against
# TranslationMiddleware.
app.add_middleware(
    TranslationMiddleware,
    exclude_paths=common_exclude_paths,
    exclude_prefixes=common_exclude_prefixes,
    text_field_names=["text", "content", "prompt", "transcription", "question"]
)

# Failover middleware - Resilience
app.add_middleware(
    FailoverMiddleware,
    exclude_paths=common_exclude_paths,
    exclude_prefixes=common_exclude_prefixes,
    default_model_type="text"
)

# API key authentication middleware - the last one registered with
# add_middleware.
# NOTE(review): being registered last makes it the outermost of the
# add_middleware entries in Starlette, i.e. it runs before the ones above -
# confirm this matches the intent of "last in the middleware chain".
# NOTE(review): unlike the other middlewares, this one passes its own prefix
# list, which omits "/assets/" from common_exclude_prefixes - confirm this is
# intentional.
app.add_middleware(
    APIKeyMiddleware,
    exclude_paths=common_exclude_paths,
    exclude_prefixes=["/static/", "/docs/"],
    admin_paths=["/admin/"]
)
139  
# Custom routes for documentation
@app.get("/api/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    """Serve the Swagger UI page, with its JS/CSS assets referenced under /static."""
    swagger_page = get_swagger_ui_html(
        title="Cerastes API - Documentation",
        openapi_url="/api/openapi.json",
        swagger_css_url="/static/swagger-ui.css",
        swagger_js_url="/static/swagger-ui-bundle.js",
    )
    return swagger_page
150  
@app.get("/api/redoc", include_in_schema=False)
async def custom_redoc_html():
    """Serve the ReDoc page, with its JS bundle referenced under /static."""
    redoc_page = get_redoc_html(
        title="Cerastes API - Documentation ReDoc",
        openapi_url="/api/openapi.json",
        redoc_js_url="/static/redoc.standalone.js",
    )
    return redoc_page
159  
# OpenAPI schema customization
def custom_openapi():
    """Return the OpenAPI schema, building and caching it on first use.

    The generated schema is extended with contact and license information and
    a fixed list of tag descriptions, then stored on ``app.openapi_schema`` so
    later calls return the cached value.

    Returns:
        The OpenAPI schema dictionary.
    """
    cached_schema = app.openapi_schema
    if cached_schema:
        return cached_schema

    schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )

    # Contact and license information
    info = schema["info"]
    info["contact"] = {
        "name": "Cerastes API Support",
        "url": "https://cerastes.ai/support",
        "email": "support@cerastes.ai",
    }
    info["license"] = {
        "name": "Dual GPL/Commercial License",
        "url": "https://cerastes.ai/license",
    }

    # Tag descriptions, in display order
    tag_descriptions = (
        ("health", "Endpoints to check API status"),
        ("inference", "Endpoints for text inference"),
        ("transcription", "Endpoints for audio transcription"),
        ("video", "Endpoints for video analysis"),
        ("tasks", "Endpoints for task management"),
        ("auth", "Endpoints for authentication and authorization"),
        ("subscription", "Endpoints for subscription management"),
    )
    schema["tags"] = [
        {"name": tag_name, "description": tag_description}
        for tag_name, tag_description in tag_descriptions
    ]

    app.openapi_schema = schema
    return app.openapi_schema


app.openapi = custom_openapi
220  
# Custom error handling
@app.middleware("http")
async def exception_handling(request: Request, call_next):
    """Log each request and turn unhandled exceptions into a JSON 500 response.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request to the rest of the stack.

    Returns:
        The downstream response, or a ``JSONResponse`` with status 500 when the
        downstream stack raised. The error body always has the keys ``detail``,
        ``message``, ``type`` and ``path``; when the ``ENVIRONMENT`` variable is
        ``production`` (or ``prod``) the ``message`` and ``type`` values are
        generic so that internal details are not disclosed to clients.
    """
    # perf_counter is monotonic: the measured duration cannot go negative or
    # jump when the system clock is adjusted.
    start_time = time.perf_counter()

    try:
        response = await call_next(request)
    except Exception as e:
        # Top-level boundary: log everything, never let the exception escape.
        # logger.exception attaches the traceback to the log record.
        logger.exception(f"Unhandled exception: {str(e)}")
        logger.error(f"Path: {request.url.path}")
        logger.error(f"Method: {request.method}")
        logger.error(f"Client: {request.client.host if request.client else 'Unknown'}")

        # Exception messages can contain file paths, SQL or other internals;
        # only expose them outside production.
        environment = os.getenv("ENVIRONMENT", "development").lower()
        if environment in ("production", "prod"):
            message = "An unexpected error occurred"
            error_type = "InternalServerError"
        else:
            message = str(e)
            error_type = type(e).__name__

        # Return a structured error response
        return JSONResponse(
            status_code=500,
            content={
                "detail": "Internal server error",
                "message": message,
                "type": error_type,
                "path": request.url.path,
            },
        )
    else:
        # Basic logging for successful requests
        process_time = time.perf_counter() - start_time
        logger.debug(f"{request.method} {request.url.path} - {response.status_code} - {process_time:.3f}s")
        return response
253  
# Middleware that reports the request processing time
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Add an ``X-Process-Time`` header holding the processing time in seconds.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request to the rest of the stack.

    Returns:
        The downstream response with the ``X-Process-Time`` header set.
    """
    # perf_counter is monotonic, so the elapsed time is not affected by
    # system clock adjustments (unlike time.time()).
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    return response
263  
# Mounting routers, each under its URL prefix
_router_prefixes = (
    (health_router, "/api"),
    (inference_router, "/api"),
    (video_router, "/api/video"),
    (transcription_router, "/api/transcription"),
    (subscription_router, "/api/subscription"),
    (task_router, "/api/tasks"),
    (auth_router, "/auth"),
)
for _router, _prefix in _router_prefixes:
    app.include_router(_router, prefix=_prefix)

# Mounting error handlers
error_handlers.register_exception_handlers(app)

# Static files - only mounted when the directory is present
static_dir = Path("static")
if static_dir.is_dir():
    app.mount("/static", StaticFiles(directory="static"), name="static")

# Frontend - mounted at root, only when a build is present.
# NOTE(review): the root mount is registered after the API routers, presumably
# so that it does not shadow them - keep this order.
frontend_dir = Path("frontend/dist")
if frontend_dir.is_dir():
    app.mount("/", StaticFiles(directory="frontend/dist", html=True), name="frontend")
285  
@app.on_event("startup")
async def startup_event():
    """Initialize shared resources and log diagnostics at application startup.

    Every initialization step (database, model manager and optional model
    preloading, JSONSimplifier post-processor, middleware checks, GPU probe,
    system prompts) runs in its own ``try`` block: a failure is logged and
    startup continues with the next step.
    """
    logger.info("=== Starting Cerastes API ===")

    # Database initialization
    try:
        from db.init_db import init_db
        init_db()
        logger.info("Database initialized successfully")
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback
        logger.exception(f"Error initializing database: {str(e)}")

    # Model manager initialization
    try:
        from model_manager import ModelManager
        ModelManager.initialize()
        logger.info("Model manager initialized successfully")

        # Preload models if configured
        if model_config.get("preload_models", False):
            logger.info("Model preloading requested...")
            for model_name in model_config.get("preload_list", []):
                try:
                    logger.info(f"Preloading model {model_name}...")
                    ModelManager.get_instance().load_model(model_name)
                except Exception as e:
                    # One model failing to preload must not stop the others
                    logger.warning(f"Failed to preload model {model_name}: {str(e)}")
    except Exception as e:
        logger.exception(f"Error initializing model manager: {str(e)}")

    # Post-processors initialization
    try:
        from postprocessors.json_simplifier import JSONSimplifier
        from config import load_config

        config = load_config()
        json_simplifier = JSONSimplifier(config.get("postprocessing", {}).get("json_simplifier", {}))
        # Stored on app.state so it can be reached from the application object
        app.state.json_simplifier = json_simplifier

        if json_simplifier.enabled:
            logger.info(f"JSONSimplifier post-processor enabled for: {', '.join(json_simplifier.apply_to)}")
    except Exception as e:
        logger.exception(f"Error initializing JSONSimplifier: {str(e)}")

    # Advanced middleware initialization
    try:
        # Check model health for failover
        from middleware.failover_middleware import get_models_health
        health_report = get_models_health()
        logger.info(f"Failover middleware initialized with {len(health_report['models'])} configured models")

        # Cache initialization
        from middleware.cache_middleware import get_cache_stats
        logger.info(f"Cache middleware initialized: {get_cache_stats()}")
    except Exception as e:
        logger.warning(f"Error initializing advanced middleware: {str(e)}")

    # Startup information
    logger.info(f"Version: {app.version}")
    logger.info(f"Environment: {os.getenv('ENVIRONMENT', 'development')}")
    logger.info(f"Log level: {os.getenv('LOG_LEVEL', 'INFO')}")

    # GPU verification
    try:
        import torch
        gpu_available = torch.cuda.is_available()
        gpu_count = torch.cuda.device_count() if gpu_available else 0
        logger.info(f"GPU available: {gpu_available}, GPU count: {gpu_count}")

        if gpu_available:
            for i in range(gpu_count):
                logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}, Total memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.2f} GB")
    except ImportError:
        logger.warning("PyTorch not available, operating in CPU-only mode")
    except Exception as e:
        logger.warning(f"Error checking GPUs: {str(e)}")

    # Data path verification: log each path and whether it is a writable directory
    for path_name, path in [
        ("Uploads", Path("uploads")),
        ("Results", Path("results")),
        ("Logs", Path("logs")),
        ("Cache", Path("cache")),
        ("Translation models", Path("translation_models"))
    ]:
        logger.info(f"{path_name}: {path.absolute()} ({path.exists() and path.is_dir() and os.access(path, os.W_OK)})")

    # System prompts loading
    try:
        from config import get_system_prompts
        prompts, prompt_order = get_system_prompts()
        logger.info(f"System prompts loaded: {len(prompts)} prompts in order: {', '.join(prompt_order)}")
    except Exception as e:
        logger.exception(f"Error loading system prompts: {str(e)}")

    # Mounted middleware verification.
    # app.user_middleware holds Starlette ``Middleware`` wrapper objects; the
    # actual middleware class is on their ``cls`` attribute. Using
    # ``m.__class__`` would log "Middleware" for every entry.
    middleware_list = [getattr(m, "cls", type(m)).__name__ for m in app.user_middleware]
    logger.info(f"Active middlewares: {', '.join(middleware_list)}")

    logger.info("Cerastes API started successfully and ready to receive requests!")
393  
@app.on_event("shutdown")
async def shutdown_event():
    """Release resources and clean up temporary files at application shutdown.

    Each step (models, translation models, cache, temporary files, CUDA
    memory, database connections) is best-effort: a failure is logged and the
    remaining steps still run.
    """
    logger.info("=== Stopping Cerastes API ===")

    # Release model resources
    try:
        from model_manager import ModelManager
        logger.info("Releasing models from memory...")
        ModelManager.cleanup()
    except Exception as e:
        logger.error(f"Error releasing models: {str(e)}")

    # Release middleware resources. The translator and the cache are handled
    # separately so that a failure in one does not skip the other.
    middleware_released = True

    try:
        # Translator resources
        from middleware.translation_middleware import translation_manager
        logger.info("Releasing translation models...")
        translation_manager.close()
    except Exception as e:
        middleware_released = False
        logger.error(f"Error releasing middleware resources: {str(e)}")

    try:
        # Cache resources
        from middleware.cache_middleware import invalidate_cache
        logger.info("Cleaning cache...")
        invalidate_cache()
    except Exception as e:
        middleware_released = False
        logger.error(f"Error releasing middleware resources: {str(e)}")

    if middleware_released:
        logger.info("Middleware resources released successfully")

    # Temporary files cleanup: delete regular files older than 24h.
    # Ages are compared as epoch timestamps, which avoids the one-hour error
    # that naive local datetimes produce across a DST change.
    logger.info("Cleaning temporary files...")
    cutoff_timestamp = time.time() - 24 * 60 * 60

    for temp_dir in ("uploads", "results/transcriptions", "cache"):
        try:
            entries = os.listdir(temp_dir)
        except FileNotFoundError:
            # Directory does not exist: nothing to clean
            continue
        except OSError as e:
            logger.error(f"Error cleaning temporary files: {str(e)}")
            continue

        for item in entries:
            item_path = os.path.join(temp_dir, item)
            try:
                # Only regular files are removed; sub-directories are left alone
                if os.path.isfile(item_path) and os.path.getmtime(item_path) < cutoff_timestamp:
                    os.unlink(item_path)
                    logger.debug(f"Temporary file deleted: {item_path}")
            except FileNotFoundError:
                # The file disappeared between listing and deletion: already gone
                continue
            except OSError as e:
                logger.warning(f"Unable to delete {item_path}: {str(e)}")

    # Release CUDA resources
    try:
        import torch
        if torch.cuda.is_available():
            logger.info("Releasing CUDA memory...")
            torch.cuda.empty_cache()
    except ImportError:
        # PyTorch is optional: nothing to release without it
        pass
    except Exception as e:
        logger.error(f"Error releasing CUDA memory: {str(e)}")

    # Close database connections
    try:
        from db import engine
        logger.info("Closing database connections...")
        engine.dispose()
    except Exception as e:
        logger.error(f"Error closing DB connections: {str(e)}")

    logger.info("Cerastes API shutdown completed")
470  
# Entry point for direct execution
if __name__ == "__main__":
    import uvicorn

    # Server settings come from the environment, with development defaults
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "8000"))
    reload_enabled = os.getenv("RELOAD", "false").lower() == "true"

    logger.info(f"Starting server on {host}:{port} (reload: {reload_enabled})")

    server_options = {
        "host": host,
        "port": port,
        "reload": reload_enabled,
        "log_level": os.getenv("LOG_LEVEL", "info").lower(),
    }
    uvicorn.run("main:app", **server_options)