/ test / components / embedders / test_azure_document_embedder.py
test_azure_document_embedder.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import os
  6  from unittest.mock import Mock, patch
  7  
  8  import pytest
  9  from openai import APIError
 10  
 11  from haystack import Document
 12  from haystack.components.embedders import AzureOpenAIDocumentEmbedder
 13  from haystack.utils.auth import Secret
 14  from haystack.utils.azure import default_azure_ad_token_provider
 15  
 16  
 17  class TestAzureOpenAIDocumentEmbedder:
 18      def test_init_default(self, monkeypatch):
 19          monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
 20          embedder = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
 21          assert embedder.azure_deployment == "text-embedding-ada-002"
 22          assert embedder.model == "text-embedding-ada-002"
 23          assert embedder.dimensions is None
 24          assert embedder.organization is None
 25          assert embedder.prefix == ""
 26          assert embedder.suffix == ""
 27          assert embedder.batch_size == 32
 28          assert embedder.progress_bar is True
 29          assert embedder.meta_fields_to_embed == []
 30          assert embedder.embedding_separator == "\n"
 31          assert embedder.default_headers == {}
 32          assert embedder.azure_ad_token_provider is None
 33          assert embedder.http_client_kwargs is None
 34  
 35      def test_init_with_0_max_retries(self, monkeypatch):
 36          """Tests that the max_retries init param is set correctly if equal 0"""
 37          monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
 38          embedder = AzureOpenAIDocumentEmbedder(
 39              azure_endpoint="https://example-resource.azure.openai.com/", max_retries=0
 40          )
 41          assert embedder.azure_deployment == "text-embedding-ada-002"
 42          assert embedder.model == "text-embedding-ada-002"
 43          assert embedder.dimensions is None
 44          assert embedder.organization is None
 45          assert embedder.prefix == ""
 46          assert embedder.suffix == ""
 47          assert embedder.batch_size == 32
 48          assert embedder.progress_bar is True
 49          assert embedder.meta_fields_to_embed == []
 50          assert embedder.embedding_separator == "\n"
 51          assert embedder.default_headers == {}
 52          assert embedder.azure_ad_token_provider is None
 53          assert embedder.max_retries == 0
 54  
 55      def test_to_dict(self, monkeypatch):
 56          monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
 57          component = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
 58          data = component.to_dict()
 59          assert data == {
 60              "type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
 61              "init_parameters": {
 62                  "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
 63                  "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
 64                  "api_version": "2023-05-15",
 65                  "azure_deployment": "text-embedding-ada-002",
 66                  "dimensions": None,
 67                  "azure_endpoint": "https://example-resource.azure.openai.com/",
 68                  "organization": None,
 69                  "prefix": "",
 70                  "suffix": "",
 71                  "batch_size": 32,
 72                  "progress_bar": True,
 73                  "meta_fields_to_embed": [],
 74                  "embedding_separator": "\n",
 75                  "max_retries": 5,
 76                  "timeout": 30.0,
 77                  "default_headers": {},
 78                  "azure_ad_token_provider": None,
 79                  "http_client_kwargs": None,
 80                  "raise_on_failure": False,
 81              },
 82          }
 83  
 84      def test_to_dict_with_parameters(self, monkeypatch):
 85          monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
 86          component = AzureOpenAIDocumentEmbedder(
 87              azure_endpoint="https://example-resource.azure.openai.com/",
 88              azure_deployment="text-embedding-ada-002",
 89              dimensions=768,
 90              organization="HaystackCI",
 91              timeout=60.0,
 92              max_retries=10,
 93              prefix="prefix ",
 94              suffix=" suffix",
 95              default_headers={"x-custom-header": "custom-value"},
 96              azure_ad_token_provider=default_azure_ad_token_provider,
 97              http_client_kwargs={"proxy": "http://example.com:3128", "verify": False},
 98              raise_on_failure=True,
 99          )
100          data = component.to_dict()
101          assert data == {
102              "type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
103              "init_parameters": {
104                  "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
105                  "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
106                  "api_version": "2023-05-15",
107                  "azure_deployment": "text-embedding-ada-002",
108                  "dimensions": 768,
109                  "azure_endpoint": "https://example-resource.azure.openai.com/",
110                  "organization": "HaystackCI",
111                  "prefix": "prefix ",
112                  "suffix": " suffix",
113                  "batch_size": 32,
114                  "progress_bar": True,
115                  "meta_fields_to_embed": [],
116                  "embedding_separator": "\n",
117                  "max_retries": 10,
118                  "timeout": 60.0,
119                  "default_headers": {"x-custom-header": "custom-value"},
120                  "azure_ad_token_provider": "haystack.utils.azure.default_azure_ad_token_provider",
121                  "http_client_kwargs": {"proxy": "http://example.com:3128", "verify": False},
122                  "raise_on_failure": True,
123              },
124          }
125  
126      def test_from_dict(self, monkeypatch):
127          monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
128          data = {
129              "type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
130              "init_parameters": {
131                  "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
132                  "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
133                  "api_version": "2023-05-15",
134                  "azure_deployment": "text-embedding-ada-002",
135                  "dimensions": None,
136                  "azure_endpoint": "https://example-resource.azure.openai.com/",
137                  "organization": None,
138                  "prefix": "",
139                  "suffix": "",
140                  "batch_size": 32,
141                  "progress_bar": True,
142                  "meta_fields_to_embed": [],
143                  "embedding_separator": "\n",
144                  "max_retries": 5,
145                  "timeout": 30.0,
146                  "default_headers": {},
147                  "azure_ad_token_provider": None,
148                  "http_client_kwargs": None,
149                  "raise_on_failure": False,
150              },
151          }
152          component = AzureOpenAIDocumentEmbedder.from_dict(data)
153          assert component.azure_deployment == "text-embedding-ada-002"
154          assert component.azure_endpoint == "https://example-resource.azure.openai.com/"
155          assert component.api_version == "2023-05-15"
156          assert component.max_retries == 5
157          assert component.timeout == 30.0
158          assert component.prefix == ""
159          assert component.suffix == ""
160          assert component.default_headers == {}
161          assert component.azure_ad_token_provider is None
162          assert component.http_client_kwargs is None
163          assert component.raise_on_failure is False
164  
165      def test_from_dict_with_parameters(self, monkeypatch):
166          monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
167          data = {
168              "type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
169              "init_parameters": {
170                  "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
171                  "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
172                  "api_version": "2023-05-15",
173                  "azure_deployment": "text-embedding-ada-002",
174                  "dimensions": 768,
175                  "azure_endpoint": "https://example-resource.azure.openai.com/",
176                  "organization": "HaystackCI",
177                  "prefix": "prefix ",
178                  "suffix": " suffix",
179                  "batch_size": 32,
180                  "progress_bar": True,
181                  "meta_fields_to_embed": [],
182                  "embedding_separator": "\n",
183                  "max_retries": 10,
184                  "timeout": 60.0,
185                  "default_headers": {"x-custom-header": "custom-value"},
186                  "azure_ad_token_provider": "haystack.utils.azure.default_azure_ad_token_provider",
187                  "http_client_kwargs": {"proxy": "http://example.com:3128", "verify": False},
188                  "raise_on_failure": True,
189              },
190          }
191          component = AzureOpenAIDocumentEmbedder.from_dict(data)
192          assert component.azure_deployment == "text-embedding-ada-002"
193          assert component.azure_endpoint == "https://example-resource.azure.openai.com/"
194          assert component.api_version == "2023-05-15"
195          assert component.max_retries == 10
196          assert component.timeout == 60.0
197          assert component.prefix == "prefix "
198          assert component.suffix == " suffix"
199          assert component.default_headers == {"x-custom-header": "custom-value"}
200          assert component.azure_ad_token_provider is not None
201          assert component.http_client_kwargs == {"proxy": "http://example.com:3128", "verify": False}
202          assert component.raise_on_failure is True
203  
204      def test_embed_batch_handles_exceptions_gracefully(self, caplog):
205          embedder = AzureOpenAIDocumentEmbedder(
206              azure_endpoint="https://test.openai.azure.com",
207              api_key=Secret.from_token("fake-api-key"),
208              azure_deployment="text-embedding-ada-002",
209              embedding_separator=" | ",
210          )
211  
212          fake_texts_to_embed = {"1": "text1", "2": "text2"}
213  
214          with patch.object(
215              embedder.client.embeddings,
216              "create",
217              side_effect=APIError(message="Mocked error", request=Mock(), body=None),
218          ):
219              embedder._embed_batch(texts_to_embed=fake_texts_to_embed, batch_size=32)
220  
221          assert len(caplog.records) == 1
222          assert "Failed embedding of documents 1, 2 caused by Mocked error" in caplog.text
223  
224      def test_embed_batch_raises_exception_on_failure(self):
225          embedder = AzureOpenAIDocumentEmbedder(
226              azure_endpoint="https://test.openai.azure.com",
227              api_key=Secret.from_token("fake-api-key"),
228              azure_deployment="text-embedding-ada-002",
229              raise_on_failure=True,
230          )
231          fake_texts_to_embed = {"1": "text1", "2": "text2"}
232          with patch.object(
233              embedder.client.embeddings,
234              "create",
235              side_effect=APIError(message="Mocked error", request=Mock(), body=None),
236          ):
237              with pytest.raises(APIError, match="Mocked error"):
238                  embedder._embed_batch(texts_to_embed=fake_texts_to_embed, batch_size=2)
239  
240      @pytest.mark.integration
241      @pytest.mark.skipif(
242          not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
243          reason=(
244              "Please export env variables called AZURE_OPENAI_API_KEY containing "
245              "the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "
246              "the Azure OpenAI endpoint URL to run this test."
247          ),
248      )
249      def test_run(self):
250          docs = [
251              Document(content="I love cheese", meta={"topic": "Cuisine"}),
252              Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
253          ]
254          # the default model is text-embedding-ada-002 even if we don't specify it, but let's be explicit
255          embedder = AzureOpenAIDocumentEmbedder(
256              azure_deployment="text-embedding-ada-002",
257              meta_fields_to_embed=["topic"],
258              embedding_separator=" | ",
259              organization="HaystackCI",
260          )
261  
262          result = embedder.run(documents=docs)
263          documents_with_embeddings = result["documents"]
264          metadata = result["meta"]
265  
266          assert isinstance(documents_with_embeddings, list)
267          assert len(documents_with_embeddings) == len(docs)
268          for doc, new_doc in zip(docs, documents_with_embeddings, strict=True):
269              assert doc.embedding is None
270              assert new_doc is not doc
271              assert isinstance(new_doc, Document)
272              assert isinstance(new_doc.embedding, list)
273              assert len(new_doc.embedding) == 1536
274              assert all(isinstance(x, float) for x in new_doc.embedding)
275  
276          assert metadata["usage"]["prompt_tokens"] == 15
277          assert metadata["usage"]["total_tokens"] == 15
278          assert "ada" in metadata["model"]