"""
Centralized configuration for Cerastes API
-------------------------------------------
This module centralizes all application configurations
and loads values from environment variables.
"""

import copy
import os
import logging
import re
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple, Set
from functools import lru_cache

# Logging configuration
logger = logging.getLogger("config")

# Definition of important paths
BASE_DIR = Path(__file__).parent.absolute()
PROMPTS_DIR = BASE_DIR / "prompts"
UPLOADS_DIR = BASE_DIR / "uploads"
RESULTS_DIR = BASE_DIR / "results"
INFERENCE_RESULTS_DIR = BASE_DIR / "inference_results"
LOGS_DIR = BASE_DIR / "logs"

# Create directories if they don't exist
for directory in [UPLOADS_DIR, RESULTS_DIR, INFERENCE_RESULTS_DIR, LOGS_DIR,
                  UPLOADS_DIR / "video", UPLOADS_DIR / "audio", UPLOADS_DIR / "text"]:
    directory.mkdir(parents=True, exist_ok=True)

# Default configuration.
# NOTE: load_config() deep-copies this dict before applying environment
# overrides, so DEFAULT_CONFIG itself is never mutated.
DEFAULT_CONFIG: Dict[str, Any] = {
    # General configuration
    "app": {
        "name": "Cerastes API",
        "version": "1.0.0",
        "environment": "development",
        "secret_key": "changeme_in_production",
        "allowed_origins": ["*"],
        "timezone": "UTC"
    },

    # Database configuration
    "database": {
        "sqlalchemy_url": "postgresql://postgres:password@localhost:5432/cerastes",
        "pool_size": 10,
        "max_overflow": 20,
        "echo": False,
        "track_modifications": False,
    },

    # AI models configuration
    "models": {
        # General LLM configuration
        "llm": {
            "default_model": "huihui-ai/DeepSeek-R1-Distill-Qwen-14B-abliterated-v2",
            "tensor_parallel_size": 1,
            "gpu_memory_utilization": 0.9,
            "quantization": None,
            "max_model_len": 24272,
            "fallback_models": [
                "meta-llama/Llama-2-7b-chat-hf",
                "facebook/opt-6.7b",
                "bigscience/bloom-7b1"
            ]
        },

        # Whisper configuration
        "whisper": {
            "default_size": "medium",  # tiny, base, small, medium, large
            "device": "cuda",  # cuda or cpu
            "language": None,  # specific language or None for auto-detection
            "batch_size": 16
        },

        # InternVideo configuration
        "internvideo": {
            "model_path": "OpenGVLab/InternVideo2_5_Chat_8B",
            "input_size": 448,
            "num_frames": 128,
            "trust_remote_code": True
        },

        # Diarization configuration
        "diarization": {
            "model_path": "pyannote/speaker-diarization-3.1",
            "huggingface_token": "",
            "min_speakers": 1,
            "max_speakers": 10
        }
    },

    # Video processing configuration
    "video": {
        "max_upload_size_mb": 500,
        "allowed_extensions": [".mp4", ".mov", ".avi", ".mkv", ".webm"],
        "extract_frames": 128,
        "max_resolution": 1080,  # resize videos if larger
        "dynamic_segmentation": True  # adapts the number of segments to video duration
    },

    # Audio processing configuration
    "audio": {
        "max_upload_size_mb": 100,
        "allowed_extensions": [".mp3", ".wav", ".flac", ".ogg", ".m4a"],
        "sample_rate": 16000,
        "channels": 1
    },

    # Segmentation configuration
    "segmentation": {
        "enabled": True,
        "language": "fr",
        "model_name": "sentence-transformers/all-MiniLM-L6-v2",
        "fallback_models": ["paraphrase-multilingual-MiniLM-L12-v2"],
        "max_segments_per_text": 10,
        "target_segments": 6,
        "similarity_threshold": 0.15
    },

    # Inference configuration
    "inference": {
        "max_new_tokens": 8000,
        "temperature": 0.53,
        "top_p": 0.93,
        "top_k": 30,
        "timeout_seconds": 300,
        "batch_parallel": True,
        "max_retries": 3,
        "max_cache_size": 50
    },

    # API configuration
    "api": {
        "host": "0.0.0.0",
        "port": 8000,
        "workers": 1,
        "debug": False,
        "max_request_size_mb": 10,
        "max_concurrent_tasks": 5,
        "result_storage_dir": "inference_results",
        "log_level": "info",
        "token_expiration_minutes": 30,
        "refresh_token_expiration_days": 7
    },

    # Authentication configuration
    "auth": {
        "jwt_algorithm": "HS256",
        "password_hash_rounds": 12,
        "require_email_verification": False,
        "allow_registration": True,
        "admin_emails": []
    },

    # Post-processors configuration
    "postprocessing": {
        "json_simplifier": {
            "enabled": False,  # Disabled by default
            "model": "huihui-ai/DeepSeek-R1-Distill-Qwen-14B-abliterated-v2",
            "system_prompt": "Translate this json {text} in plain english",
            "max_tokens": 1000,
            "temperature": 0.3,
            "apply_to": ["inference", "video", "transcription"]  # Task types concerned
        }
    },

    # External services configuration
    "services": {
        # Stripe
        "stripe": {
            "enabled": False,
            "api_key": "",
            "webhook_secret": "",
            "currency": "usd",
            "success_url": "http://localhost:8000/payment/success",
            "cancel_url": "http://localhost:8000/payment/cancel"
        },

        # Email
        "email": {
            "enabled": False,
            "smtp_server": "smtp.example.com",
            "smtp_port": 587,
            "smtp_username": "",
            "smtp_password": "",
            "sender_email": "noreply@example.com",
            "use_tls": True
        }
    }
}

# Values recognized as "true" for boolean environment variables.
_TRUTHY_VALUES = {"true", "1", "yes"}


def _override(section: Dict[str, Any], key: str, env_var: str, cast=str) -> None:
    """Overwrite section[key] with cast(env value) if env_var is set and non-empty.

    An unset or empty variable leaves the default untouched, matching the
    original `if os.environ.get(VAR):` truthiness checks.
    """
    raw = os.environ.get(env_var)
    if raw:
        section[key] = cast(raw)


def _override_flag(section: Dict[str, Any], key: str, env_var: str) -> None:
    """Overwrite boolean section[key] whenever env_var is present (even empty).

    Presence semantics allow explicitly setting a flag to False, matching the
    original `if os.environ.get(VAR) is not None:` checks.
    """
    raw = os.environ.get(env_var)
    if raw is not None:
        section[key] = raw.lower() in _TRUTHY_VALUES


def _split_csv(raw: str) -> List[str]:
    """Split a comma-separated env value into a list of stripped items."""
    return [item.strip() for item in raw.split(",")]


@lru_cache()
def load_config() -> Dict[str, Any]:
    """
    Loads configuration from environment variables or uses default values.
    The function is cached to avoid reloading the configuration on each call;
    use load_config.cache_clear() to force a re-read of the environment.

    Returns:
        Dict[str, Any]: Complete configuration
    """
    # Deep copy so environment overrides never mutate DEFAULT_CONFIG itself.
    # A shallow .copy() would share (and silently modify) the nested section
    # dictionaries.
    config = copy.deepcopy(DEFAULT_CONFIG)

    # ====== General configuration ======
    app = config["app"]
    _override(app, "name", "APP_NAME")
    _override(app, "version", "APP_VERSION")
    _override(app, "environment", "ENVIRONMENT")
    _override(app, "secret_key", "SECRET_KEY")
    _override(app, "allowed_origins", "CORS_ORIGINS", _split_csv)

    # ====== Database configuration ======
    db = config["database"]
    if os.environ.get("DATABASE_URL"):
        # A full URL takes precedence over individual components.
        db["sqlalchemy_url"] = os.environ["DATABASE_URL"]
    else:
        # Building URL from components
        db_user = os.environ.get("DB_USER", "postgres")
        db_password = os.environ.get("DB_PASSWORD", "password")
        db_host = os.environ.get("DB_HOST", "localhost")
        db_port = os.environ.get("DB_PORT", "5432")
        db_name = os.environ.get("DB_NAME", "cerastes")
        db["sqlalchemy_url"] = f"postgresql://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}"

    _override(db, "pool_size", "DB_POOL_SIZE", int)
    _override(db, "max_overflow", "DB_MAX_OVERFLOW", int)
    _override_flag(db, "echo", "DB_ECHO")

    # ====== AI models configuration ======
    # LLM
    llm = config["models"]["llm"]
    _override(llm, "default_model", "MODEL_NAME")
    _override(llm, "tensor_parallel_size", "TENSOR_PARALLEL_SIZE", int)
    _override(llm, "gpu_memory_utilization", "GPU_MEMORY_UTILIZATION", float)
    _override(llm, "quantization", "QUANTIZATION")
    _override(llm, "max_model_len", "MAX_MODEL_LEN", int)

    # Whisper
    whisper = config["models"]["whisper"]
    _override(whisper, "default_size", "WHISPER_MODEL_SIZE")
    _override(whisper, "device", "WHISPER_DEVICE")
    _override(whisper, "language", "WHISPER_LANGUAGE")

    # Diarization
    diarization = config["models"]["diarization"]
    _override(diarization, "huggingface_token", "HUGGINGFACE_TOKEN")
    _override(diarization, "model_path", "DIARIZATION_MODEL")

    # ====== Segmentation configuration ======
    seg = config["segmentation"]
    _override_flag(seg, "enabled", "USE_SEGMENTATION")
    _override(seg, "language", "SEGMENTATION_LANGUAGE")
    _override(seg, "model_name", "SEGMENTATION_MODEL")

    # ====== Inference configuration ======
    inf = config["inference"]
    _override(inf, "max_new_tokens", "MAX_NEW_TOKENS", int)
    _override(inf, "temperature", "TEMPERATURE", float)
    _override(inf, "top_p", "TOP_P", float)
    _override(inf, "top_k", "TOP_K", int)
    _override(inf, "timeout_seconds", "TIMEOUT_SECONDS", int)
    _override_flag(inf, "batch_parallel", "BATCH_PARALLEL")
    _override(inf, "max_retries", "MAX_RETRIES", int)
    _override(inf, "max_cache_size", "MAX_CACHE_SIZE", int)

    # ====== API configuration ======
    api = config["api"]
    _override(api, "host", "HOST")
    _override(api, "port", "PORT", int)
    _override(api, "workers", "API_WORKERS", int)
    _override_flag(api, "debug", "API_DEBUG")
    _override(api, "max_request_size_mb", "MAX_REQUEST_SIZE_MB", int)
    _override(api, "max_concurrent_tasks", "MAX_CONCURRENT_TASKS", int)
    _override(api, "result_storage_dir", "RESULT_STORAGE_DIR")
    _override(api, "log_level", "LOG_LEVEL", str.lower)
    _override(api, "token_expiration_minutes", "ACCESS_TOKEN_EXPIRE_MINUTES", int)
    _override(api, "refresh_token_expiration_days", "REFRESH_TOKEN_EXPIRE_DAYS", int)

    # ====== Post-processors configuration ======
    # JSONSimplifier configuration
    simplifier = config["postprocessing"]["json_simplifier"]
    _override_flag(simplifier, "enabled", "JSON_SIMPLIFIER_ENABLED")
    _override(simplifier, "model", "JSON_SIMPLIFIER_MODEL")
    _override(simplifier, "system_prompt", "JSON_SIMPLIFIER_SYSTEM_PROMPT")
    _override(simplifier, "apply_to", "JSON_SIMPLIFIER_APPLY_TO", _split_csv)

    # ====== Authentication configuration ======
    auth = config["auth"]
    _override(auth, "jwt_algorithm", "JWT_ALGORITHM")
    _override(auth, "password_hash_rounds", "PASSWORD_HASH_ROUNDS", int)
    _override(auth, "admin_emails", "ADMIN_EMAILS", _split_csv)

    # ====== External services configuration ======
    # Stripe: providing an API key implicitly enables the integration
    stripe = config["services"]["stripe"]
    if os.environ.get("STRIPE_API_KEY"):
        stripe["enabled"] = True
        stripe["api_key"] = os.environ["STRIPE_API_KEY"]
    _override(stripe, "webhook_secret", "STRIPE_WEBHOOK_SECRET")

    # Email: providing an SMTP server implicitly enables the integration
    email = config["services"]["email"]
    if os.environ.get("SMTP_SERVER"):
        email["enabled"] = True
        email["smtp_server"] = os.environ["SMTP_SERVER"]
    _override(email, "smtp_port", "SMTP_PORT", int)
    _override(email, "smtp_username", "SMTP_USERNAME")
    _override(email, "smtp_password", "SMTP_PASSWORD")
    _override(email, "sender_email", "SENDER_EMAIL")

    return config


def get_system_prompts() -> Tuple[Dict[str, str], List[str]]:
    """
    Loads system prompts from the prompts directory.
    Standardizes placeholders and checks their presence.

    Returns:
        tuple: A tuple containing (prompts_dict, prompt_order)
            - prompts_dict: Dictionary of prompts
            - prompt_order: Ordered list of prompt names
    """
    system_prompts: Dict[str, str] = {}
    placeholder_pattern = re.compile(r'\{([a-zA-Z0-9_]+)\}')

    # List of recognized standard placeholders
    standard_placeholders = {
        "text", "input", "query", "content", "language", "context",
        "question", "data", "json", "transcript", "audio", "video",
        "instructions", "parameters"
    }

    # Default prompts using the {text} placeholder in a standardized way
    default_prompts = {
        "system_1": "Analyze the following text: {text}",
        "system_1_2": "Analyze the coherence of the text: {text}",
        "system_1_2_1": "Evaluate the Markov dynamics of the text: {text}",
        "system_2": "Perform a Jungian analysis of the text: {text}",
        "system_3": "Perform a logical analysis of the text: {text}",
        "system_final": "Synthesize all previous analyses of the text: {text}",
        # Prompts specific to different features
        "nonverbal_analysis": "Analyze the nonverbal behaviors in the following video: {text}",
        "manipulation_analysis": "Identify manipulation strategies in the following content: {text}",
        "transcription_general_analysis": "Analyze the following transcription and identify key points: {text}",
        "image_generation": "Generate an image representing: {text}"
    }

    # List of expected prompts in standard order for the processing chain
    expected_prompts = ["system_1", "system_1_2", "system_1_2_1",
                        "system_2", "system_3", "system_final"]

    try:
        # Check/create the prompts directory
        if not PROMPTS_DIR.exists():
            logger.warning(f"Prompts directory {PROMPTS_DIR} not found. Creating directory.")
            PROMPTS_DIR.mkdir(parents=True, exist_ok=True)

            # Create default prompt files
            for prompt_name, prompt_content in default_prompts.items():
                with open(PROMPTS_DIR / f"{prompt_name}.txt", "w", encoding="utf-8") as f:
                    f.write(prompt_content)

            logger.info(f"Default prompts created in {PROMPTS_DIR}")

        # Load all available prompts and sort them explicitly by name
        prompt_files = sorted(
            list(PROMPTS_DIR.glob("*.txt")) + list(PROMPTS_DIR.glob("*.j2")),
            key=lambda x: x.stem  # Sort by name without extension
        )

        if not prompt_files:
            logger.warning(f"No prompt files found in {PROMPTS_DIR}. Using default values.")
            system_prompts = default_prompts.copy()
        else:
            # Load prompts from files (now sorted)
            for file_path in prompt_files:
                prompt_name = file_path.stem
                try:
                    with open(file_path, "r", encoding="utf-8") as f:
                        prompt_content = f.read().strip()

                    # Extract placeholders used in the prompt
                    placeholders_found = set(placeholder_pattern.findall(prompt_content))

                    # Check placeholders
                    if not placeholders_found:
                        logger.warning(f"Prompt '{prompt_name}' does not contain any placeholder. "
                                       f"Automatically adding {{text}} placeholder at the end.")
                        prompt_content += " {text}"
                        placeholders_found = {"text"}

                    # Check if used placeholders are standard
                    non_standard_placeholders = placeholders_found - standard_placeholders
                    if non_standard_placeholders:
                        logger.warning(f"Prompt '{prompt_name}' uses non-standard placeholders: "
                                       f"{', '.join(non_standard_placeholders)}. "
                                       f"Recommended standard placeholders are: {', '.join(standard_placeholders)}")

                    # Suggestion if {text} is absent but other placeholders are present
                    if "text" not in placeholders_found and placeholders_found:
                        logger.info(f"Prompt '{prompt_name}' does not use the standard {{text}} placeholder "
                                    f"but uses: {', '.join(placeholders_found)}")

                    system_prompts[prompt_name] = prompt_content
                    logger.debug(f"Prompt '{prompt_name}' loaded from {file_path} "
                                 f"with placeholders: {', '.join(placeholders_found)}")

                except Exception as e:
                    logger.error(f"Error loading prompt {file_path}: {e}")
                    # If the file cannot be read, use the default prompt if it exists
                    if prompt_name in default_prompts:
                        system_prompts[prompt_name] = default_prompts[prompt_name]
                        logger.warning(f"Using default prompt for '{prompt_name}'")

        # Check that all expected prompts are present
        missing_prompts = [p for p in expected_prompts if p not in system_prompts]
        if missing_prompts:
            logger.warning(f"Missing prompts in the sequence: {missing_prompts}")

            # Add missing prompts from default values
            for prompt_name in missing_prompts:
                if prompt_name in default_prompts:
                    system_prompts[prompt_name] = default_prompts[prompt_name]
                    logger.warning(f"Adding missing default prompt: '{prompt_name}'")

        # Check that the order is correct and reorganize if necessary:
        # chain prompts first, in their canonical order.
        ordered_prompts: Dict[str, str] = {}
        for prompt_name in expected_prompts:
            if prompt_name in system_prompts:
                ordered_prompts[prompt_name] = system_prompts[prompt_name]

        # Add all other non-standard prompts at the end
        for prompt_name, prompt_content in system_prompts.items():
            if prompt_name not in ordered_prompts:
                ordered_prompts[prompt_name] = prompt_content
                logger.info(f"Non-standard prompt detected: '{prompt_name}'")

        # Check if feature-specific prompts are present
        for func_prompt in ["nonverbal_analysis", "manipulation_analysis",
                            "transcription_general_analysis"]:
            if func_prompt not in ordered_prompts and func_prompt in default_prompts:
                ordered_prompts[func_prompt] = default_prompts[func_prompt]
                logger.info(f"Adding missing functional prompt: '{func_prompt}'")

        # Build final prompt order
        prompt_order = list(ordered_prompts.keys())

        return ordered_prompts, prompt_order

    except Exception as e:
        logger.error(f"Critical error loading prompts: {e}")
        # In case of critical error, return default prompts.
        # Return a copy (don't alias the module-level default dict) and the
        # FULL key list so the order is consistent with the normal path
        # (the original returned only the 6 chain names for a 10-prompt dict).
        return default_prompts.copy(), list(default_prompts.keys())


# Initialize logging before loading configuration
def setup_logging() -> None:
    """Configure the application's logging system.

    Reads LOG_LEVEL from the environment (default INFO), installs a console
    handler via basicConfig plus a file handler writing to logs/cerastes.log,
    and quiets noisy third-party loggers.
    """
    log_level_name = os.environ.get("LOG_LEVEL", "INFO").upper()
    log_level = getattr(logging, log_level_name, logging.INFO)

    # Basic format for all handlers
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    # Root configuration
    logging.basicConfig(
        level=log_level,
        format=log_format,
        handlers=[
            logging.StreamHandler()
        ]
    )

    # Create logs directory if needed
    if not LOGS_DIR.exists():
        LOGS_DIR.mkdir(parents=True, exist_ok=True)

    # File handler
    file_handler = logging.FileHandler(LOGS_DIR / "cerastes.log")
    file_handler.setFormatter(logging.Formatter(log_format))

    # Add file handler to root logger
    logging.getLogger().addHandler(file_handler)

    # Reduce verbosity level for certain libraries
    for logger_name in ["urllib3", "PIL", "matplotlib"]:
        logging.getLogger(logger_name).setLevel(logging.WARNING)


# Configure logging before anything else
setup_logging()

# Facilitate access to configuration
config = load_config()
system_prompts, prompt_order = get_system_prompts()

# Expose configuration sections for easy import
app_config = config["app"]
db_config = config["database"]
model_config = config["models"]
video_config = config["video"]
audio_config = config["audio"]
segmentation_config = config["segmentation"]
inference_config = config["inference"]
api_config = config["api"]
auth_config = config["auth"]
services_config = config["services"]
postprocessing_config = config["postprocessing"]