# src/loaders/whisper_loader.py
  1  from __future__ import annotations
  2  
  3  """Whisper-based video/audio loader implementation."""
  4  
  5  import logging
  6  from collections.abc import Sequence
  7  from pathlib import Path
  8  from typing import Any
  9  
 10  import whisper
 11  from langchain_core.documents import Document
 12  
 13  from .protocol import DocumentLoader
 14  
 15  PageSpecifier = Sequence[int] | range | None
 16  
 17  
 18  class WhisperLoader(DocumentLoader):
 19      """Load video/audio files using Whisper and convert to text/markdown.
 20  
 21      This loader uses OpenAI's Whisper model to transcribe audio/video files
 22      and converts the transcription to LangChain Document objects and Markdown.
 23      """
 24  
 25      def __init__(self, config: dict[str, Any]) -> None:
 26          """Initialize the loader with a configuration dictionary.
 27  
 28          Parameters
 29          ----------
 30          config
 31              Configuration dictionary. Must contain 'file_path' key.
 32              Optional keys:
 33              - output_dir: Directory for output markdown files
 34              - model_name: Whisper model to use (default: "base")
 35              - device: Device to run on ("cpu" or "cuda", default: "cpu")
 36              - language: Language code (e.g., "en", "es"). If None, auto-detects.
 37              - Any other keys are stored and can be accessed via self.config
 38          """
 39          self.logger = logging.getLogger(__name__)
 40  
 41          file_path = config.get("file_path")
 42          if not file_path:
 43              raise ValueError("'file_path' is required in loader configuration")
 44  
 45          self.input_path = Path(file_path).expanduser().resolve()
 46          if not self.input_path.exists():
 47              raise FileNotFoundError(f"Video file not found: {self.input_path}")
 48  
 49          supported_extensions = {".mp4", ".mp3", ".wav", ".m4a", ".avi", ".mov", ".mkv"}
 50          if self.input_path.suffix.lower() not in supported_extensions:
 51              raise ValueError(
 52                  f"Unsupported file type: {self.input_path.suffix}. "
 53                  f"Supported: {', '.join(supported_extensions)}"
 54              )
 55  
 56          output_dir = config.get("output_dir")
 57          self.output_dir = Path(output_dir).expanduser().resolve() if output_dir else None
 58  
 59          self.model_name = config.get("model_name", "base")
 60          self.device = config.get("device", "cpu")
 61          self.language = config.get("language", "en")
 62          self.include_timestamps = config.get("include_timestamps", True)
 63  
 64          self._model: whisper.Whisper | None = None
 65  
 66      def _load_model(self) -> whisper.Whisper:
 67          """Load the Whisper model (lazy loading).
 68  
 69          Returns
 70          -------
 71          Loaded Whisper model instance.
 72          """
 73          if self._model is None:
 74              self.logger.info(f"Loading Whisper model: {self.model_name}")
 75              self._model = whisper.load_model(self.model_name, device=self.device)
 76          return self._model
 77  
 78      def _transcribe(self) -> dict[str, Any]:
 79          """Transcribe the video/audio file using Whisper and return full result.
 80  
 81          Returns
 82          -------
 83          Full transcription result dictionary with segments and text.
 84          """
 85          try:
 86              self.logger.info(f"Transcribing audio from: {self.input_path}")
 87              model = self._load_model()
 88  
 89              transcribe_kwargs = {
 90                  "language": self.language,
 91              }
 92  
 93              result = model.transcribe(str(self.input_path), **transcribe_kwargs)
 94  
 95              if not result.get("text", "").strip():
 96                  self.logger.warning(f"Empty transcription for {self.input_path}")
 97  
 98              return result
 99          except Exception as e:
100              raise RuntimeError(f"Failed to transcribe {self.input_path}: {e}") from e
101  
102      def _format_timestamp(self, seconds: float) -> str:
103          """Format seconds as HH:MM:SS.mmm.
104  
105          Parameters
106          ----------
107          seconds
108              Time in seconds.
109  
110          Returns
111          -------
112          Formatted timestamp string.
113          """
114          hours = int(seconds // 3600)
115          minutes = int((seconds % 3600) // 60)
116          secs = int(seconds % 60)
117          millis = int((seconds % 1) * 1000)
118          return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"
119  
120      def _to_markdown_text(self) -> str:
121          """Convert transcription to markdown text.
122  
123          Returns
124          -------
125          Markdown text representation of the transcription.
126          """
127          result = self._transcribe()
128  
129          if not result.get("text", "").strip():
130              return ""
131  
132          if not self.include_timestamps:
133              return f"# Transcription\n\n{result['text'].strip()}"
134  
135          segments = result.get("segments", [])
136          if not segments:
137              return f"# Transcription\n\n{result['text'].strip()}"
138  
139          markdown_lines = ["# Transcription\n"]
140          for segment in segments:
141              start_time = segment.get("start", 0)
142              end_time = segment.get("end", 0)
143              text = segment.get("text", "").strip()
144  
145              if not text:
146                  continue
147  
148              timestamp_str = f"[{self._format_timestamp(start_time)} - {self._format_timestamp(end_time)}]"
149              markdown_lines.append(f"{timestamp_str}\n{text}\n")
150  
151          return "\n".join(markdown_lines)
152  
153      def load_documents(self) -> list[Document]:
154          """Load the video/audio file into LangChain Document objects with markdown-formatted content.
155  
156          Returns
157          -------
158          List of Document objects with markdown transcription in page_content.
159          """
160          markdown_content = self._to_markdown_text()
161  
162          return [
163              Document(
164                  page_content=markdown_content,
165                  metadata={
166                      "source": str(self.input_path),
167                      "file_name": self.input_path.name,
168                      "loader": "WhisperLoader",
169                      "file_type": self.input_path.suffix.lower(),
170                      "model": self.model_name,
171                      "language": self.language or "auto-detected",
172                  }
173              )
174          ]
175  
176  
def create_whisper_loader(config: dict[str, Any]) -> DocumentLoader:
    """Factory: build a WhisperLoader from a configuration mapping.

    Parameters
    ----------
    config
        Configuration dictionary. Must contain:
        - file_path (required): Path to the video/audio file
        Optional keys:
        - output_dir (optional): Directory for output markdown files
        - model_name (optional): Whisper model name (default: "base")
        - device (optional): Device to run on (default: "cpu")
        - language (optional): Language code for transcription (default: "en")
        - include_timestamps (optional): Include timestamps in output (default: True)

    Returns
    -------
    DocumentLoader instance.
    """
    loader = WhisperLoader(config=config)
    return loader
197