# middleware/translation_middleware.py
"""
Middleware for language detection and automatic translation.

This module detects the language of incoming text, translates it to English
when necessary, and translates the response back into the original language.
"""
  6  
import json
import logging
import time
from typing import Any, Callable, Dict, List, Optional, Union

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

# For language detection
from langdetect import LangDetectException, detect, detect_langs
from langdetect.detector_factory import DetectorFactory

# For translation
import torch
from transformers import MarianMTModel, MarianTokenizer

DetectorFactory.seed = 0  # For consistent results
 24  
 25  # Logging configuration
 26  logger = logging.getLogger("translation_middleware")
 27  
 28  # Models storage path
 29  MODELS_CACHE_DIR = "translation_models"
 30  
 31  # List of supported languages - ISO 639-1 codes
 32  SUPPORTED_LANGUAGES = {
 33      "fr": "French",
 34      "es": "Spanish",
 35      "de": "German",
 36      "it": "Italian",
 37      "pt": "Portuguese",
 38      "nl": "Dutch",
 39      "ru": "Russian",
 40      "zh": "Chinese",
 41      "ja": "Japanese",
 42      "ar": "Arabic",
 43      # Add other languages as needed
 44  }
 45  
 46  # Minimum confidence threshold for language detection
 47  LANGUAGE_DETECTION_THRESHOLD = 0.85
 48  
 49  class TranslationManager:
 50      """Translation manager using Hugging Face models."""
 51      
 52      def __init__(self):
 53          self.tokenizers = {}
 54          self.models = {}
 55          self.device = "cuda" if torch.cuda.is_available() else "cpu"
 56          logger.info(f"TranslationManager initialized on {self.device}")
 57      
 58      def get_model_name(self, source_lang: str, target_lang: str) -> str:
 59          """Returns the model name for the language pair."""
 60          if source_lang == "en" and target_lang in SUPPORTED_LANGUAGES:
 61              return f"Helsinki-NLP/opus-mt-en-{target_lang}"
 62          elif target_lang == "en" and source_lang in SUPPORTED_LANGUAGES:
 63              return f"Helsinki-NLP/opus-mt-{source_lang}-en"
 64          else:
 65              # Fallback for pairs not directly supported
 66              return f"Helsinki-NLP/opus-mt-mul-en" if target_lang == "en" else f"Helsinki-NLP/opus-mt-en-mul"
 67      
 68      def load_model(self, source_lang: str, target_lang: str) -> None:
 69          """Loads a translation model for a language pair."""
 70          model_key = f"{source_lang}-{target_lang}"
 71          
 72          if model_key in self.models:
 73              return
 74          
 75          model_name = self.get_model_name(source_lang, target_lang)
 76          logger.info(f"Loading translation model {model_name}")
 77          
 78          try:
 79              tokenizer = MarianTokenizer.from_pretrained(model_name, cache_dir=MODELS_CACHE_DIR)
 80              model = MarianMTModel.from_pretrained(model_name, cache_dir=MODELS_CACHE_DIR)
 81              
 82              # Move to GPU if available
 83              if self.device == "cuda":
 84                  model = model.to(self.device)
 85              
 86              self.tokenizers[model_key] = tokenizer
 87              self.models[model_key] = model
 88              logger.info(f"Model {model_name} loaded successfully")
 89          except Exception as e:
 90              logger.error(f"Error loading model {model_name}: {e}")
 91              raise
 92      
 93      def translate(self, text: str, source_lang: str, target_lang: str) -> str:
 94          """Translates text from a source language to a target language."""
 95          if source_lang == target_lang:
 96              return text
 97          
 98          model_key = f"{source_lang}-{target_lang}"
 99          
100          # Load the model if necessary
101          if model_key not in self.models:
102              self.load_model(source_lang, target_lang)
103          
104          tokenizer = self.tokenizers[model_key]
105          model = self.models[model_key]
106          
107          try:
108              # Tokenization
109              encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
110              
111              # Move to GPU if available
112              if self.device == "cuda":
113                  encoded = {k: v.to(self.device) for k, v in encoded.items()}
114              
115              # Translation
116              translated = model.generate(**encoded)
117              
118              # Decoding
119              result = tokenizer.decode(translated[0], skip_special_tokens=True)
120              return result
121          except Exception as e:
122              logger.error(f"Error during translation: {e}")
123              return text  # In case of error, return the original text
124      
125      def detect_language(self, text: str) -> Optional[str]:
126          """Detects the language of a text."""
127          if not text or len(text) < 10:
128              return None
129          
130          try:
131              detected_lang = detect(text)
132              return detected_lang if detected_lang in SUPPORTED_LANGUAGES else None
133          except LangDetectException:
134              return None
135      
136      def close(self):
137          """Releases resources."""
138          self.models.clear()
139          self.tokenizers.clear()
140          if torch.cuda.is_available():
141              torch.cuda.empty_cache()
142  
# Module-level manager shared by every TranslationMiddleware instance. It is
# created at import time; models are loaded lazily on the first translation.
translation_manager: TranslationManager = TranslationManager()
145  
class TranslationMiddleware(BaseHTTPMiddleware):
    """Translate non-English request text to English and the JSON response back.

    Only POST requests whose path starts with one of ``include_prefixes`` are
    considered, and only when the body is JSON or has no Content-Type. The
    first non-empty string field named in ``text_field_names`` is translated
    to English before the endpoint runs. Its source language comes from the
    ``language`` query parameter or from automatic detection. Unless
    ``translate_back=false`` is passed, matching string fields of a JSON
    response are translated back into the source language.
    """

    def __init__(
        self,
        app: ASGIApp,
        exclude_paths: Optional[List[str]] = None,
        exclude_prefixes: Optional[List[str]] = None,
        text_field_names: Optional[List[str]] = None,
        include_prefixes: Optional[List[str]] = None,
    ):
        """Configure the middleware.

        Args:
            app: The wrapped ASGI application.
            exclude_paths: Exact paths that are never processed.
            exclude_prefixes: Path prefixes that are never processed.
            text_field_names: JSON keys holding translatable text, in priority order.
            include_prefixes: Path prefixes the middleware acts on. Defaults to
                ``/api/inference`` and ``/api/transcription``.
        """
        super().__init__(app)
        self.exclude_paths = exclude_paths or ["/health", "/docs", "/redoc", "/openapi.json"]
        self.exclude_prefixes = exclude_prefixes or ["/static/", "/assets/"]
        self.text_field_names = text_field_names or ["text", "content", "prompt", "transcription"]
        self.include_prefixes = include_prefixes or ["/api/inference", "/api/transcription"]

    async def dispatch(self, request: Request, call_next: Callable):
        """Translate the request text to English, run the endpoint, translate the JSON response back."""
        path = request.url.path
        # Non-JSON uploads are skipped without buffering them in memory.
        if (
            self._is_excluded_path(path)
            or not path.startswith(tuple(self.include_prefixes))
            or request.method != "POST"
            or not self._is_json_request(request)
        ):
            return await call_next(request)

        specified_source_lang = request.query_params.get("language")
        translate_back = request.query_params.get("translate_back", "true").lower() == "true"

        try:
            body = await self._get_request_body(request)
        except Exception as e:
            logger.error(f"Error retrieving request body: {e}")
            return await call_next(request)

        text_field, text_to_translate = self._find_request_text(body)
        if text_field is None:
            return await call_next(request)

        if specified_source_lang:
            source_lang = specified_source_lang.strip().lower()
        else:
            source_lang = translation_manager.detect_language(text_to_translate)

        # Undetected or already English: nothing to translate.
        if not source_lang or source_lang == "en":
            return await call_next(request)

        logger.info(f"Detected language: {source_lang}, translation enabled")

        # Model inference is blocking; run it in a worker thread, not on the event loop.
        translated_text = await run_in_threadpool(
            translation_manager.translate, text_to_translate, source_lang, "en"
        )

        modified_body = dict(body)
        modified_body[text_field] = translated_text
        self._replace_request_body(request, json.dumps(modified_body).encode())

        request.state.translation_context = {
            "needs_translation": True,
            "source_lang": source_lang,
            "original_text": text_to_translate,
            "translate_back": translate_back,
        }

        start_time = time.time()
        response = await call_next(request)
        logger.debug(f"Downstream processing took {time.time() - start_time:.3f}s")

        if not translate_back:
            return response
        if not response.headers.get("content-type", "").startswith("application/json"):
            return response
        return await self._translate_response(response, source_lang)

    def _is_excluded_path(self, path: str) -> bool:
        """Return True if ``path`` is in ``exclude_paths`` or starts with an excluded prefix."""
        return path in self.exclude_paths or path.startswith(tuple(self.exclude_prefixes))

    @staticmethod
    def _is_json_request(request: Request) -> bool:
        """Return True when the request has no Content-Type or a JSON media type."""
        content_type = request.headers.get("content-type", "").split(";", 1)[0].strip().lower()
        return not content_type or content_type == "application/json" or content_type.endswith("+json")

    async def _get_request_body(self, request: Request) -> Dict[str, Any]:
        """Return the JSON request body as a dict.

        Returns ``{}`` when the body is empty, is not UTF-8 JSON, or is not a
        JSON object. Lists and scalars are rejected because callers index it by key.
        """
        raw = await request.body()  # Starlette caches this in request._body
        if not raw:
            return {}
        try:
            parsed = json.loads(raw.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError):
            return {}
        return parsed if isinstance(parsed, dict) else {}

    def _find_request_text(self, body: Dict[str, Any]) -> tuple:
        """Return ``(field_name, text)`` for the first non-empty string field in
        ``text_field_names``, or ``(None, None)`` when there is none."""
        for field_name in self.text_field_names:
            value = body.get(field_name)
            if isinstance(value, str) and value:
                return field_name, value
        return None, None

    @staticmethod
    def _replace_request_body(request: Request, new_body: bytes) -> None:
        """Swap the buffered request body and keep the Content-Length header consistent.

        Assigning the private ``request._body`` only reaches the endpoint
        because BaseHTTPMiddleware replays a consumed body downstream.
        NOTE(review): verify this against the pinned Starlette version.
        """
        request._body = new_body
        headers = [(k, v) for k, v in request.scope["headers"] if k.lower() != b"content-length"]
        headers.append((b"content-length", str(len(new_body)).encode("latin-1")))
        request.scope["headers"] = headers

    async def _read_response_body(self, response: Response) -> bytes:
        """Return the full body of ``response``.

        Responses from ``call_next`` are streaming and have no ``.body``, so
        their ``body_iterator`` is drained. After that the original response
        cannot be sent again.
        """
        body = getattr(response, "body", None)
        if body is not None:
            return body
        chunks = []
        async for chunk in response.body_iterator:
            chunks.append(chunk.encode("utf-8") if isinstance(chunk, str) else bytes(chunk))
        return b"".join(chunks)

    @staticmethod
    def _rebuild_response(original: Response, body: bytes) -> Response:
        """Build a response carrying ``body`` with the status, headers and background task of ``original``.

        Raw headers are copied so that repeated ones (e.g. Set-Cookie)
        survive. Content-Length is recomputed when ``original`` declared one.
        """
        rebuilt = Response(
            content=body,
            status_code=original.status_code,
            background=getattr(original, "background", None),
        )
        headers = []
        had_length = False
        for key, value in original.raw_headers:
            if key.lower() == b"content-length":
                had_length = True
                continue
            headers.append((key, value))
        if had_length:
            headers.append((b"content-length", str(len(body)).encode("latin-1")))
        rebuilt.raw_headers = headers
        return rebuilt

    async def _translate_response(self, response: Response, source_lang: str) -> Response:
        """Translate the text fields of a JSON response from English back into ``source_lang``.

        If the body is not decodable JSON, or nothing needs translating, the
        original bytes are re-emitted unchanged.
        """
        raw_body = await self._read_response_body(response)
        try:
            payload = json.loads(raw_body)
        except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
            return self._rebuild_response(response, raw_body)

        translated_any = False
        for field_path, value in self._find_text_fields_to_translate(payload):
            # Very short strings are left untranslated.
            if isinstance(value, str) and len(value) > 5:
                translated_value = await run_in_threadpool(
                    translation_manager.translate, value, "en", source_lang
                )
                self._set_field_value(payload, field_path, translated_value)
                translated_any = True

        if not translated_any:
            return self._rebuild_response(response, raw_body)
        new_body = json.dumps(payload, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
        return self._rebuild_response(response, new_body)

    def _find_text_fields_to_translate(self, data: Union[Dict, List], path: Optional[List] = None) -> List:
        """Return ``(path, value)`` pairs for every translatable field in ``data``.

        A field is translatable when its key is in ``text_field_names`` and its
        value is a string, at any nesting depth. ``path`` is the list of
        dict keys / list indices leading to the field. Each path is reported once.
        """
        if path is None:
            path = []

        fields = []
        if isinstance(data, dict):
            for key, value in data.items():
                current_path = path + [key]

                if isinstance(value, str) and key in self.text_field_names:
                    fields.append((current_path, value))

                if isinstance(value, (dict, list)):
                    fields.extend(self._find_text_fields_to_translate(value, current_path))

                # Transcription segments keep their text under "text". The
                # recursion above already collects those when "text" is a
                # configured field name, so add them here only otherwise, to
                # avoid translating each segment twice.
                if key == "segments" and isinstance(value, list) and "text" not in self.text_field_names:
                    for i, segment in enumerate(value):
                        if isinstance(segment, dict) and "text" in segment:
                            fields.append((current_path + [i, "text"], segment["text"]))

        elif isinstance(data, list):
            for i, item in enumerate(data):
                fields.extend(self._find_text_fields_to_translate(item, path + [i]))

        return fields

    def _set_field_value(self, data: Union[Dict, List], path: List, value: Any):
        """Set the element reached by ``path`` (dict keys / list indices) in ``data`` to ``value``.

        An empty path is a no-op.
        """
        if not path:
            return
        current = data
        for key in path[:-1]:
            current = current[key]
        current[path[-1]] = value
351              else:
352                  current = current[key]