/ haystack / components / embedders / azure_document_embedder.py
azure_document_embedder.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import os
  6  from typing import Any
  7  
  8  from openai.lib.azure import AsyncAzureOpenAI, AzureADTokenProvider, AzureOpenAI
  9  
 10  from haystack import component, default_from_dict, default_to_dict, logging
 11  from haystack.components.embedders import OpenAIDocumentEmbedder
 12  from haystack.utils import Secret, deserialize_callable, serialize_callable
 13  from haystack.utils.http_client import init_http_client
 14  
 15  logger = logging.getLogger(__name__)
 16  
 17  
 18  @component
 19  class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
 20      """
 21      Calculates document embeddings using OpenAI models deployed on Azure.
 22  
 23      ### Usage example
 24      <!-- test-ignore -->
 25      ```python
 26      from haystack import Document
 27      from haystack.components.embedders import AzureOpenAIDocumentEmbedder
 28  
 29      doc = Document(content="I love pizza!")
 30      document_embedder = AzureOpenAIDocumentEmbedder()
 31  
 32      result = document_embedder.run([doc])
 33      print(result['documents'][0].embedding)
 34  
 35      # [0.017020374536514282, -0.023255806416273117, ...]
 36      ```
 37      """
 38  
 39      def __init__(  # noqa: PLR0913 (too-many-arguments)
 40          self,
 41          azure_endpoint: str | None = None,
 42          api_version: str | None = "2023-05-15",
 43          azure_deployment: str = "text-embedding-ada-002",
 44          dimensions: int | None = None,
 45          api_key: Secret | None = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
 46          azure_ad_token: Secret | None = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
 47          organization: str | None = None,
 48          prefix: str = "",
 49          suffix: str = "",
 50          batch_size: int = 32,
 51          progress_bar: bool = True,
 52          meta_fields_to_embed: list[str] | None = None,
 53          embedding_separator: str = "\n",
 54          timeout: float | None = None,
 55          max_retries: int | None = None,
 56          *,
 57          default_headers: dict[str, str] | None = None,
 58          azure_ad_token_provider: AzureADTokenProvider | None = None,
 59          http_client_kwargs: dict[str, Any] | None = None,
 60          raise_on_failure: bool = False,
 61      ) -> None:
 62          """
 63          Creates an AzureOpenAIDocumentEmbedder component.
 64  
 65          :param azure_endpoint:
 66              The endpoint of the model deployed on Azure.
 67          :param api_version:
 68              The version of the API to use.
 69          :param azure_deployment:
 70              The name of the model deployed on Azure. The default model is text-embedding-ada-002.
 71          :param dimensions:
 72              The number of dimensions of the resulting embeddings. Only supported in text-embedding-3
 73              and later models.
 74          :param api_key:
 75              The Azure OpenAI API key.
 76              You can set it with an environment variable `AZURE_OPENAI_API_KEY`, or pass with this
 77              parameter during initialization.
 78          :param azure_ad_token:
 79              Microsoft Entra ID token, see Microsoft's
 80              [Entra ID](https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id)
 81              documentation for more information. You can set it with an environment variable
 82              `AZURE_OPENAI_AD_TOKEN`, or pass with this parameter during initialization.
 83              Previously called Azure Active Directory.
 84          :param organization:
 85              Your organization ID. See OpenAI's
 86              [Setting Up Your Organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization)
 87              for more information.
 88          :param prefix:
 89              A string to add at the beginning of each text.
 90          :param suffix:
 91              A string to add at the end of each text.
 92          :param batch_size:
 93              Number of documents to embed at once.
 94          :param progress_bar:
 95              If `True`, shows a progress bar when running.
 96          :param meta_fields_to_embed:
 97              List of metadata fields to embed along with the document text.
 98          :param embedding_separator:
 99              Separator used to concatenate the metadata fields to the document text.
100          :param timeout: The timeout for `AzureOpenAI` client calls, in seconds.
101              If not set, defaults to either the
102              `OPENAI_TIMEOUT` environment variable, or 30 seconds.
103          :param max_retries: Maximum number of retries to contact AzureOpenAI after an internal error.
104              If not set, defaults to either the `OPENAI_MAX_RETRIES` environment variable or to 5 retries.
105          :param default_headers: Default headers to send to the AzureOpenAI client.
106          :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
107              every request.
108          :param http_client_kwargs:
109              A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`.
110              For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
111          :param raise_on_failure:
112              Whether to raise an exception if the embedding request fails. If `False`, the component will log the error
113              and continue processing the remaining documents. If `True`, it will raise an exception on failure.
114          """
115          # We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
116          # with the API.
117  
118          # if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT
119          azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
120          if not azure_endpoint:
121              raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")
122  
123          if api_key is None and azure_ad_token is None:
124              raise ValueError("Please provide an API key or an Azure Active Directory token.")
125  
126          self.api_key = api_key  # type: ignore[assignment] # mypy does not understand that api_key can be None
127          self.azure_ad_token = azure_ad_token
128          self.api_version = api_version
129          self.azure_endpoint = azure_endpoint
130          self.azure_deployment = azure_deployment
131          self.model = azure_deployment
132          self.dimensions = dimensions
133          self.organization = organization
134          self.prefix = prefix
135          self.suffix = suffix
136          self.batch_size = batch_size
137          self.progress_bar = progress_bar
138          self.meta_fields_to_embed = meta_fields_to_embed or []
139          self.embedding_separator = embedding_separator
140          self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
141          self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
142          self.default_headers = default_headers or {}
143          self.azure_ad_token_provider = azure_ad_token_provider
144          self.http_client_kwargs = http_client_kwargs
145          self.raise_on_failure = raise_on_failure
146  
147          client_args: dict[str, Any] = {
148              "api_version": api_version,
149              "azure_endpoint": azure_endpoint,
150              "azure_deployment": azure_deployment,
151              "azure_ad_token_provider": azure_ad_token_provider,
152              "api_key": api_key.resolve_value() if api_key is not None else None,
153              "azure_ad_token": azure_ad_token.resolve_value() if azure_ad_token is not None else None,
154              "organization": organization,
155              "timeout": self.timeout,
156              "max_retries": self.max_retries,
157              "default_headers": self.default_headers,
158          }
159  
160          self.client = AzureOpenAI(
161              http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_args
162          )
163          self.async_client = AsyncAzureOpenAI(
164              http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args
165          )
166  
167      def to_dict(self) -> dict[str, Any]:
168          """
169          Serializes the component to a dictionary.
170  
171          :returns:
172              Dictionary with serialized data.
173          """
174          azure_ad_token_provider_name = None
175          if self.azure_ad_token_provider:
176              azure_ad_token_provider_name = serialize_callable(self.azure_ad_token_provider)
177          return default_to_dict(
178              self,
179              azure_endpoint=self.azure_endpoint,
180              azure_deployment=self.azure_deployment,
181              dimensions=self.dimensions,
182              organization=self.organization,
183              api_version=self.api_version,
184              prefix=self.prefix,
185              suffix=self.suffix,
186              batch_size=self.batch_size,
187              progress_bar=self.progress_bar,
188              meta_fields_to_embed=self.meta_fields_to_embed,
189              embedding_separator=self.embedding_separator,
190              api_key=self.api_key,
191              azure_ad_token=self.azure_ad_token,
192              timeout=self.timeout,
193              max_retries=self.max_retries,
194              default_headers=self.default_headers,
195              azure_ad_token_provider=azure_ad_token_provider_name,
196              http_client_kwargs=self.http_client_kwargs,
197              raise_on_failure=self.raise_on_failure,
198          )
199  
200      @classmethod
201      def from_dict(cls, data: dict[str, Any]) -> "AzureOpenAIDocumentEmbedder":
202          """
203          Deserializes the component from a dictionary.
204  
205          :param data:
206              Dictionary to deserialize from.
207          :returns:
208              Deserialized component.
209          """
210          serialized_azure_ad_token_provider = data["init_parameters"].get("azure_ad_token_provider")
211          if serialized_azure_ad_token_provider:
212              data["init_parameters"]["azure_ad_token_provider"] = deserialize_callable(
213                  serialized_azure_ad_token_provider
214              )
215          return default_from_dict(cls, data)