azure_document_embedder.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any

from openai.lib.azure import AsyncAzureOpenAI, AzureADTokenProvider, AzureOpenAI

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.embedders import OpenAIDocumentEmbedder
from haystack.utils import Secret, deserialize_callable, serialize_callable
from haystack.utils.http_client import init_http_client

logger = logging.getLogger(__name__)


@component
class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
    """
    Calculates document embeddings using OpenAI models deployed on Azure.

    ### Usage example
    <!-- test-ignore -->
    ```python
    from haystack import Document
    from haystack.components.embedders import AzureOpenAIDocumentEmbedder

    doc = Document(content="I love pizza!")
    document_embedder = AzureOpenAIDocumentEmbedder()

    result = document_embedder.run([doc])
    print(result['documents'][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(  # noqa: PLR0913 (too-many-arguments)
        self,
        azure_endpoint: str | None = None,
        api_version: str | None = "2023-05-15",
        azure_deployment: str = "text-embedding-ada-002",
        dimensions: int | None = None,
        api_key: Secret | None = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
        azure_ad_token: Secret | None = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
        organization: str | None = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: list[str] | None = None,
        embedding_separator: str = "\n",
        timeout: float | None = None,
        max_retries: int | None = None,
        *,
        default_headers: dict[str, str] | None = None,
        azure_ad_token_provider: AzureADTokenProvider | None = None,
        http_client_kwargs: dict[str, Any] | None = None,
        raise_on_failure: bool = False,
    ) -> None:
        """
        Creates an AzureOpenAIDocumentEmbedder component.

        :param azure_endpoint:
            The endpoint of the model deployed on Azure.
            If not set, it is read from the `AZURE_OPENAI_ENDPOINT` environment variable.
        :param api_version:
            The version of the API to use.
        :param azure_deployment:
            The name of the model deployed on Azure. The default model is text-embedding-ada-002.
        :param dimensions:
            The number of dimensions of the resulting embeddings. Only supported in text-embedding-3
            and later models.
        :param api_key:
            The Azure OpenAI API key.
            You can set it with an environment variable `AZURE_OPENAI_API_KEY`, or pass with this
            parameter during initialization.
        :param azure_ad_token:
            Microsoft Entra ID token, see Microsoft's
            [Entra ID](https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id)
            documentation for more information. You can set it with an environment variable
            `AZURE_OPENAI_AD_TOKEN`, or pass with this parameter during initialization.
            Previously called Azure Active Directory.
        :param organization:
            Your organization ID. See OpenAI's
            [Setting Up Your Organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization)
            for more information.
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param batch_size:
            Number of documents to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar when running.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.
        :param timeout: The timeout for `AzureOpenAI` client calls, in seconds.
            If not set, defaults to either the
            `OPENAI_TIMEOUT` environment variable, or 30 seconds.
        :param max_retries: Maximum number of retries to contact AzureOpenAI after an internal error.
            If not set, defaults to either the `OPENAI_MAX_RETRIES` environment variable or to 5 retries.
        :param default_headers: Default headers to send to the AzureOpenAI client.
        :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
            every request. It can be used as the only credential, instead of `api_key` or `azure_ad_token`.
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
        :param raise_on_failure:
            Whether to raise an exception if the embedding request fails. If `False`, the component will log the error
            and continue processing the remaining documents. If `True`, it will raise an exception on failure.
        :raises ValueError:
            If no Azure endpoint is passed or found in `AZURE_OPENAI_ENDPOINT`, or if `api_key`, `azure_ad_token`
            and `azure_ad_token_provider` are all `None`.
        """
        # We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
        # with the API.

        # if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT
        azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

        # The token provider supplies an AD token on every request, so it is a complete credential on its own.
        # Only reject the configuration when none of the three credential sources is given.
        if api_key is None and azure_ad_token is None and azure_ad_token_provider is None:
            raise ValueError("Please provide an API key or an Azure Active Directory token.")

        self.api_key = api_key  # type: ignore[assignment] # mypy does not understand that api_key can be None
        self.azure_ad_token = azure_ad_token
        self.api_version = api_version
        self.azure_endpoint = azure_endpoint
        self.azure_deployment = azure_deployment
        # The deployment name doubles as the model name expected by the parent class' embedding logic.
        self.model = azure_deployment
        self.dimensions = dimensions
        self.organization = organization
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
        self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
        self.default_headers = default_headers or {}
        self.azure_ad_token_provider = azure_ad_token_provider
        self.http_client_kwargs = http_client_kwargs
        self.raise_on_failure = raise_on_failure

        # Shared arguments for the sync and async clients; secrets are resolved once here.
        client_args: dict[str, Any] = {
            "api_version": api_version,
            "azure_endpoint": azure_endpoint,
            "azure_deployment": azure_deployment,
            "azure_ad_token_provider": azure_ad_token_provider,
            "api_key": api_key.resolve_value() if api_key is not None else None,
            "azure_ad_token": azure_ad_token.resolve_value() if azure_ad_token is not None else None,
            "organization": organization,
            "timeout": self.timeout,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
        }

        self.client = AzureOpenAI(
            http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_args
        )
        self.async_client = AsyncAzureOpenAI(
            http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        The `azure_ad_token_provider` callable, if any, is stored by its serialized import path.

        :returns:
            Dictionary with serialized data.
        """
        azure_ad_token_provider_name = None
        if self.azure_ad_token_provider is not None:
            azure_ad_token_provider_name = serialize_callable(self.azure_ad_token_provider)
        return default_to_dict(
            self,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            dimensions=self.dimensions,
            organization=self.organization,
            api_version=self.api_version,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
            api_key=self.api_key,
            azure_ad_token=self.azure_ad_token,
            timeout=self.timeout,
            max_retries=self.max_retries,
            default_headers=self.default_headers,
            azure_ad_token_provider=azure_ad_token_provider_name,
            http_client_kwargs=self.http_client_kwargs,
            raise_on_failure=self.raise_on_failure,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AzureOpenAIDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        The input dictionary is not modified, so the same dictionary can be deserialized more than once.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        # Shallow-copy init parameters: writing the deserialized callable back into the caller's dict would make a
        # second from_dict() call pass a function (not a string path) to deserialize_callable.
        init_params = dict(data.get("init_parameters", {}))
        serialized_azure_ad_token_provider = init_params.get("azure_ad_token_provider")
        if serialized_azure_ad_token_provider:
            init_params["azure_ad_token_provider"] = deserialize_callable(serialized_azure_ad_token_provider)
        return default_from_dict(cls, {**data, "init_parameters": init_params})