# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import io
from pathlib import Path
from typing import Any

from openai import OpenAI

from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ByteStream
from haystack.utils import Secret
from haystack.utils.http_client import init_http_client

logger = logging.getLogger(__name__)


@component
class RemoteWhisperTranscriber:
    """
    Transcribes audio files using OpenAI's Whisper API.

    The component requires an OpenAI API key; see the
    [OpenAI documentation](https://platform.openai.com/docs/api-reference/authentication) for more details.
    For the supported audio formats, languages, and other parameters, see the
    [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text).

    ### Usage example

    ```python
    from haystack.components.audio import RemoteWhisperTranscriber

    whisper = RemoteWhisperTranscriber(model="whisper-1")
    transcription = whisper.run(sources=["test/test_files/audio/answer.wav"])
    ```
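
    The component also accepts in-memory `ByteStream` objects. A minimal sketch,
    reusing the same illustrative audio file:

    ```python
    from pathlib import Path

    from haystack.components.audio import RemoteWhisperTranscriber
    from haystack.dataclasses import ByteStream

    whisper = RemoteWhisperTranscriber(model="whisper-1")
    audio = ByteStream.from_file_path(Path("test/test_files/audio/answer.wav"))
    transcription = whisper.run(sources=[audio])
    print(transcription["documents"][0].content)
    ```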
    """

    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
        model: str = "whisper-1",
        api_base_url: str | None = None,
        organization: str | None = None,
        http_client_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Creates an instance of the RemoteWhisperTranscriber component.

        :param api_key:
            OpenAI API key.
            You can set it with the `OPENAI_API_KEY` environment variable, or pass it with this parameter
            during initialization.
        :param model:
            Name of the model to use. Currently accepts only `whisper-1`.
        :param organization:
            Your OpenAI organization ID. See OpenAI's documentation on
            [Setting Up Your Organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
        :param api_base_url:
            An optional URL to use as the API base. For details, see the
            OpenAI [documentation](https://platform.openai.com/docs/api-reference/audio).
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
        :param kwargs:
            Other optional parameters for the model. These are sent directly to the OpenAI
            endpoint. See the OpenAI [documentation](https://platform.openai.com/docs/api-reference/audio) for more details.
            Some of the supported parameters are listed below; a short example follows the list:
            - `language`: The language of the input audio.
              Provide the input language in ISO-639-1 format
              to improve transcription accuracy and latency.
            - `prompt`: An optional text to guide the model's
              style or continue a previous audio segment.
              The prompt should match the audio language.
            - `response_format`: The format of the transcript
              output. This component only supports `json`.
            - `temperature`: The sampling temperature, between 0
              and 1. Higher values like 0.8 make the output more
              random, while lower values like 0.2 make it more
              focused and deterministic. If set to 0, the model
              uses log probability to automatically increase the
              temperature until certain thresholds are hit.
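
            For example, a minimal initialization sketch passing some of these optional
            parameters (the values shown are illustrative):

            ```python
            transcriber = RemoteWhisperTranscriber(
                model="whisper-1",
                language="en",  # ISO-639-1 code of the input audio
                prompt="Haystack, deepset",  # optional hint for style or vocabulary
                temperature=0.0,
            )
            ```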
        """

        self.organization = organization
        self.model = model
        self.api_base_url = api_base_url
        self.api_key = api_key
        self.http_client_kwargs = http_client_kwargs

        # Only response_format = "json" is supported
        whisper_params = kwargs
        response_format = whisper_params.get("response_format", "json")
        if response_format != "json":
            logger.warning(
                "RemoteWhisperTranscriber only supports 'response_format: json'. This parameter will be overwritten."
            )
        whisper_params["response_format"] = "json"
        self.whisper_params = whisper_params
        self.client = OpenAI(
            api_key=api_key.resolve_value(),
            organization=organization,
            base_url=api_base_url,
            http_client=init_http_client(self.http_client_kwargs, async_client=False),
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
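
        A minimal serialization round-trip sketch (the exact dictionary layout is defined by
        Haystack's `default_to_dict`/`default_from_dict` helpers):

        ```python
        whisper = RemoteWhisperTranscriber(model="whisper-1")
        data = whisper.to_dict()
        restored = RemoteWhisperTranscriber.from_dict(data)
        assert restored.model == "whisper-1"
        ```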
        """
        return default_to_dict(
            self,
            api_key=self.api_key,
            model=self.model,
            organization=self.organization,
            api_base_url=self.api_base_url,
            http_client_kwargs=self.http_client_kwargs,
            **self.whisper_params,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "RemoteWhisperTranscriber":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        return default_from_dict(cls, data)

    @component.output_types(documents=list[Document])
    def run(self, sources: list[str | Path | ByteStream]) -> dict[str, Any]:
        """
        Transcribes the list of audio files into a list of documents.

        :param sources:
            A list of file paths or `ByteStream` objects containing the audio files to transcribe.

        :returns: A dictionary with the following keys:
            - `documents`: A list of documents, one document for each file.
                The content of each document is the transcribed text.
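
        For example (the file path is illustrative; the API key is read from the environment):

        ```python
        whisper = RemoteWhisperTranscriber()
        result = whisper.run(sources=["test/test_files/audio/answer.wav"])
        doc = result["documents"][0]
        print(doc.content)  # transcribed text
        print(doc.meta["file_path"])  # original path, stored in the document metadata
        ```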
        """
        documents = []

        for source in sources:
            # Convert plain paths to ByteStream and keep the original path in the metadata
            if not isinstance(source, ByteStream):
                path = source
                source = ByteStream.from_file_path(Path(source))
                source.meta["file_path"] = path

            # The OpenAI client uses the file name to infer the audio format,
            # so in-memory streams without a path get a fallback name.
            file = io.BytesIO(source.data)
            file.name = str(source.meta["file_path"]) if "file_path" in source.meta else "__fallback__.wav"

            content = self.client.audio.transcriptions.create(file=file, model=self.model, **self.whisper_params)
            doc = Document(content=content.text, meta=source.meta)
            documents.append(doc)

        return {"documents": documents}