/ middleware / translation_middleware.py
translation_middleware.py
"""
Middleware for language detection and automatic translation.

This module detects the input language of text, translates it to English if
necessary, then retranslates the response into the original language.
"""

import json
import logging
import threading
import time
from typing import Dict, Any, Optional, List, Tuple, Union, Callable

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

# For language detection
from langdetect import detect, detect_langs, LangDetectException
from langdetect.detector_factory import DetectorFactory

# For translation
from transformers import MarianMTModel, MarianTokenizer
import torch

DetectorFactory.seed = 0  # For consistent results

# Logging configuration
logger = logging.getLogger("translation_middleware")

# Models storage path
MODELS_CACHE_DIR = "translation_models"

# List of supported languages - ISO 639-1 codes
SUPPORTED_LANGUAGES = {
    "fr": "French",
    "es": "Spanish",
    "de": "German",
    "it": "Italian",
    "pt": "Portuguese",
    "nl": "Dutch",
    "ru": "Russian",
    "zh": "Chinese",
    "ja": "Japanese",
    "ar": "Arabic",
    # Add other languages as needed
}

# Minimum confidence threshold for language detection
LANGUAGE_DETECTION_THRESHOLD = 0.85

# langdetect reports Chinese as "zh-cn" / "zh-tw"; map those onto the
# ISO 639-1 code used as key in SUPPORTED_LANGUAGES.
_LANGDETECT_ALIASES = {
    "zh-cn": "zh",
    "zh-tw": "zh",
}


class TranslationManager:
    """Translation manager using Hugging Face models.

    Tokenizers and models are cached per ``"<source>-<target>"`` key in
    ``self.tokenizers`` / ``self.models`` and loaded lazily on first use.
    """

    def __init__(self):
        self.tokenizers = {}
        self.models = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # translate() is called from worker threads by the middleware, so
        # serialize model loading to avoid loading the same pair twice.
        self._load_lock = threading.Lock()
        logger.info(f"TranslationManager initialized on {self.device}")

    def get_model_name(self, source_lang: str, target_lang: str) -> str:
        """Return the Hugging Face model name for the language pair.

        Pairs with English on one side and a supported language on the other
        map to a dedicated ``opus-mt`` model; anything else falls back to the
        multilingual ``mul`` models.
        """
        if source_lang == "en" and target_lang in SUPPORTED_LANGUAGES:
            return f"Helsinki-NLP/opus-mt-en-{target_lang}"
        if target_lang == "en" and source_lang in SUPPORTED_LANGUAGES:
            return f"Helsinki-NLP/opus-mt-{source_lang}-en"
        # Fallback for pairs not directly supported
        if target_lang == "en":
            return "Helsinki-NLP/opus-mt-mul-en"
        return "Helsinki-NLP/opus-mt-en-mul"

    def load_model(self, source_lang: str, target_lang: str) -> None:
        """Load (and cache) the translation model for a language pair.

        Does nothing when the pair is already cached.

        Raises:
            Exception: whatever ``from_pretrained`` raises; the error is
                logged and re-raised.
        """
        model_key = f"{source_lang}-{target_lang}"

        with self._load_lock:
            # Re-check under the lock: another thread may have just loaded it.
            if model_key in self.models:
                return

            model_name = self.get_model_name(source_lang, target_lang)
            logger.info(f"Loading translation model {model_name}")

            try:
                tokenizer = MarianTokenizer.from_pretrained(
                    model_name, cache_dir=MODELS_CACHE_DIR
                )
                model = MarianMTModel.from_pretrained(
                    model_name, cache_dir=MODELS_CACHE_DIR
                )

                # Move to GPU if available
                if self.device == "cuda":
                    model = model.to(self.device)

                self.tokenizers[model_key] = tokenizer
                self.models[model_key] = model
                logger.info(f"Model {model_name} loaded successfully")
            except Exception as e:
                logger.error(f"Error loading model {model_name}: {e}")
                raise

    def translate(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate ``text`` from ``source_lang`` to ``target_lang``.

        Best effort: on any failure (model loading included) the error is
        logged and the original text is returned unchanged. Input is
        truncated to 512 tokens by the tokenizer call.
        """
        if source_lang == target_lang:
            return text

        model_key = f"{source_lang}-{target_lang}"

        try:
            # Load the model if necessary. Kept inside the try so that a
            # missing/unloadable model degrades to "no translation" instead
            # of failing the whole request.
            if model_key not in self.models:
                self.load_model(source_lang, target_lang)

            tokenizer = self.tokenizers[model_key]
            model = self.models[model_key]

            # Tokenization
            encoded = tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )

            # Move to GPU if available
            if self.device == "cuda":
                encoded = {k: v.to(self.device) for k, v in encoded.items()}

            # Translation (inference only, no gradients needed)
            with torch.no_grad():
                translated = model.generate(**encoded)

            # Decoding
            return tokenizer.decode(translated[0], skip_special_tokens=True)
        except Exception as e:
            logger.error(f"Error during translation: {e}")
            return text  # In case of error, return the original text

    def detect_language(self, text: str) -> Optional[str]:
        """Detect the language of ``text``.

        Returns:
            The ISO 639-1 code when the text is at least 10 characters long,
            the best candidate reaches ``LANGUAGE_DETECTION_THRESHOLD`` and
            the language is in ``SUPPORTED_LANGUAGES``; ``None`` otherwise
            (English is not in that table, so it also yields ``None``).
        """
        if not text or len(text) < 10:
            return None

        try:
            candidates = detect_langs(text)
        except LangDetectException:
            return None

        if not candidates:
            return None

        best = max(candidates, key=lambda candidate: candidate.prob)
        if best.prob < LANGUAGE_DETECTION_THRESHOLD:
            return None

        lang = _LANGDETECT_ALIASES.get(best.lang, best.lang)
        return lang if lang in SUPPORTED_LANGUAGES else None

    def close(self):
        """Release cached models/tokenizers and free the CUDA cache."""
        self.models.clear()
        self.tokenizers.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Instantiate the translation manager
translation_manager = TranslationManager()


class TranslationMiddleware(BaseHTTPMiddleware):
    """Middleware for language detection and automatic translation.

    Only ``POST`` requests whose path starts with ``/api/inference`` or
    ``/api/transcription`` are processed. The first configured text field
    found in the JSON body is translated to English before the request is
    forwarded; JSON responses are translated back unless
    ``?translate_back=false`` is given.
    """

    def __init__(
        self,
        app: ASGIApp,
        exclude_paths: Optional[List[str]] = None,
        exclude_prefixes: Optional[List[str]] = None,
        text_field_names: Optional[List[str]] = None,
    ):
        super().__init__(app)
        self.exclude_paths = exclude_paths or ["/health", "/docs", "/redoc", "/openapi.json"]
        self.exclude_prefixes = exclude_prefixes or ["/static/", "/assets/"]
        self.text_field_names = text_field_names or ["text", "content", "prompt", "transcription"]

    async def dispatch(self, request: Request, call_next: Callable):
        """Process requests and responses with language detection and translation."""
        path = request.url.path

        # Check if this path should be excluded
        if self._is_excluded_path(path):
            return await call_next(request)

        # Check if it's an inference or transcription request
        if not path.startswith(("/api/inference", "/api/transcription")):
            return await call_next(request)

        # Check HTTP method
        if request.method != "POST":
            return await call_next(request)

        # Get language information from the request
        specified_source_lang = request.query_params.get("language")
        translate_back = request.query_params.get("translate_back", "true").lower() == "true"

        # Get the request body
        try:
            body = await self._get_request_body(request)
        except Exception as e:
            logger.error(f"Error retrieving request body: {e}")
            return await call_next(request)

        # Extract text from the first configured field present in the body
        text_field = next(
            (name for name in self.text_field_names if name in body), None
        )
        text_to_translate = body[text_field] if text_field is not None else None

        # If no text is found, continue normally
        if not text_to_translate or not isinstance(text_to_translate, str):
            return await call_next(request)

        # Get the source language (specified or detected)
        if specified_source_lang:
            # Codes are compared against lower-case ISO 639-1 keys.
            source_lang = specified_source_lang.strip().lower()
        else:
            source_lang = translation_manager.detect_language(text_to_translate)

        # If the language is not detected or is already English, continue normally
        if not source_lang or source_lang == "en":
            return await call_next(request)

        # Log the information
        logger.info(f"Detected language: {source_lang}, translation enabled")

        # Translate the text to English for processing. Model inference is
        # blocking, so run it in the thread pool instead of on the event loop.
        original_text = text_to_translate
        translated_text = await run_in_threadpool(
            translation_manager.translate, text_to_translate, source_lang, "en"
        )

        # Modify the request body with the translated text
        modified_body = dict(body)
        modified_body[text_field] = translated_text

        # Modify the request with the new body.
        # NOTE(review): relies on Starlette replaying the cached `_body` to
        # the downstream app - confirm against the installed Starlette version.
        request._body = json.dumps(modified_body).encode()

        # Add translation context for downstream handlers
        request.state.translation_context = {
            "needs_translation": True,
            "source_lang": source_lang,
            "original_text": original_text,
            "translate_back": translate_back,
        }

        # Process the request
        start_time = time.time()
        response = await call_next(request)
        logger.debug("Downstream processing took %.3fs", time.time() - start_time)

        # If no need to retranslate the response, return directly
        if not translate_back:
            return response

        # Only JSON responses are translated back
        if not response.headers.get("content-type", "").startswith("application/json"):
            return response

        # From here on the response stream is consumed, so a new response
        # must be returned in every branch.
        raw_body = await self._read_response_bytes(response)
        payload = self._parse_json(raw_body)

        if not isinstance(payload, (dict, list)):
            # Not a JSON object/array (or not decodable, e.g. compressed):
            # pass the original bytes through untouched.
            passthrough = Response(content=raw_body, status_code=response.status_code)
            self._copy_headers(response, passthrough, skip=("content-length",))
            return passthrough

        # Translate the found fields in place
        for field_path, value in self._find_text_fields_to_translate(payload):
            if isinstance(value, str) and len(value) > 5:
                translated_value = await run_in_threadpool(
                    translation_manager.translate, value, "en", source_lang
                )
                self._set_field_value(payload, field_path, translated_value)

        # Create a new response with the translated content. Content-Length
        # and Content-Type are recomputed by JSONResponse for the new body.
        translated_response = JSONResponse(
            content=payload, status_code=response.status_code
        )
        self._copy_headers(
            response, translated_response, skip=("content-length", "content-type")
        )
        return translated_response

    def _is_excluded_path(self, path: str) -> bool:
        """Check if the path is excluded from the middleware."""
        if path in self.exclude_paths:
            return True
        return any(path.startswith(prefix) for prefix in self.exclude_prefixes)

    async def _get_request_body(self, request: Request) -> Dict[str, Any]:
        """Extract the JSON request body.

        Returns:
            The decoded JSON object, or ``{}`` when the body is empty, is not
            valid JSON, or is valid JSON but not an object.

        Raises:
            UnicodeDecodeError: when the body is not valid UTF-8.
        """
        if not hasattr(request, "_body"):
            request._body = await request.body()

        body_str = request._body.decode("utf-8")
        if not body_str:
            return {}

        try:
            parsed = json.loads(body_str)
        except json.JSONDecodeError:
            return {}

        # Callers index the result by field name, so only objects are usable.
        return parsed if isinstance(parsed, dict) else {}

    async def _get_response_body(self, response: Response) -> Dict[str, Any]:
        """Extract the JSON response body.

        Consumes the response stream when the response is a streaming one.
        Returns ``{}`` when the body is not a JSON object or array.
        """
        parsed = self._parse_json(await self._read_response_bytes(response))
        return parsed if isinstance(parsed, (dict, list)) else {}

    @staticmethod
    async def _read_response_bytes(response: Response) -> bytes:
        """Return the full response body, draining ``body_iterator`` if present."""
        body_iterator = getattr(response, "body_iterator", None)
        if body_iterator is None:
            return getattr(response, "body", b"") or b""

        chunks = []
        async for chunk in body_iterator:
            chunks.append(chunk.encode("utf-8") if isinstance(chunk, str) else bytes(chunk))
        return b"".join(chunks)

    @staticmethod
    def _parse_json(raw_body: bytes) -> Any:
        """Decode ``raw_body`` as UTF-8 JSON; return ``None`` on failure."""
        try:
            return json.loads(raw_body.decode("utf-8"))
        except (UnicodeDecodeError, ValueError):
            return None

    @staticmethod
    def _copy_headers(source: Response, target: Response, skip: Tuple[str, ...] = ()) -> None:
        """Append every header of ``source`` to ``target`` except those in ``skip``.

        Headers are appended one by one so repeated headers (e.g. several
        ``Set-Cookie``) are all preserved.
        """
        for name, value in source.headers.items():
            if name.lower() not in skip:
                target.headers.append(name, value)

    def _find_text_fields_to_translate(self, data: Union[Dict, List], path: List = None) -> List:
        """Find all text fields in JSON data that need to be translated.

        Args:
            data: JSON object or array to explore.
            path: Keys/indices leading to ``data`` from the root.

        Returns:
            A list of ``(path, value)`` tuples, each path listed once, where
            ``path`` is the list of keys/indices from the root to the field.
        """
        if path is None:
            path = []

        fields = []
        seen_paths = set()

        def add(field_path: List, value: Any) -> None:
            # "segments" entries are reachable both through recursion and the
            # dedicated branch below; keep each path only once.
            marker = tuple(field_path)
            if marker not in seen_paths:
                seen_paths.add(marker)
                fields.append((field_path, value))

        if isinstance(data, dict):
            for key, value in data.items():
                current_path = path + [key]

                # If it's a potential text field and not a technical field
                if isinstance(value, str) and key in self.text_field_names:
                    add(current_path, value)

                # Recursively explore sub-structures
                if isinstance(value, (dict, list)):
                    for found_path, found_value in self._find_text_fields_to_translate(
                        value, current_path
                    ):
                        add(found_path, found_value)

                # Specifically process transcription segments
                if key == "segments" and isinstance(value, list):
                    for i, segment in enumerate(value):
                        if isinstance(segment, dict) and "text" in segment:
                            add(current_path + [i, "text"], segment["text"])

        elif isinstance(data, list):
            for i, item in enumerate(data):
                for found_path, found_value in self._find_text_fields_to_translate(
                    item, path + [i]
                ):
                    add(found_path, found_value)

        return fields

    def _set_field_value(self, data: Dict, path: List, value: Any):
        """Set the field located at ``path`` inside ``data`` to ``value``.

        ``path`` is a list of dict keys / list indices; an empty path is a
        no-op.
        """
        if not path:
            return

        *parents, last_key = path
        current = data
        for key in parents:
            current = current[key]
        current[last_key] = value