test_azure_document_embedder.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os
from unittest.mock import Mock, patch

import pytest
from openai import APIError

from haystack import Document
from haystack.components.embedders import AzureOpenAIDocumentEmbedder
from haystack.utils.auth import Secret
from haystack.utils.azure import default_azure_ad_token_provider


class TestAzureOpenAIDocumentEmbedder:
    """Tests for AzureOpenAIDocumentEmbedder: init defaults, (de)serialization, error handling and a live run."""

    def test_init_default(self, monkeypatch):
        """Checks the attribute values of an embedder built with only an endpoint and an API key from the env."""
        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
        embedder = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
        assert embedder.azure_deployment == "text-embedding-ada-002"
        assert embedder.model == "text-embedding-ada-002"
        assert embedder.dimensions is None
        assert embedder.organization is None
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.batch_size == 32
        assert embedder.progress_bar is True
        assert embedder.meta_fields_to_embed == []
        assert embedder.embedding_separator == "\n"
        assert embedder.default_headers == {}
        assert embedder.azure_ad_token_provider is None
        assert embedder.http_client_kwargs is None

    def test_init_with_0_max_retries(self, monkeypatch):
        """Tests that the max_retries init param is set correctly if equal 0"""
        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
        embedder = AzureOpenAIDocumentEmbedder(
            azure_endpoint="https://example-resource.azure.openai.com/", max_retries=0
        )
        assert embedder.azure_deployment == "text-embedding-ada-002"
        assert embedder.model == "text-embedding-ada-002"
        assert embedder.dimensions is None
        assert embedder.organization is None
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.batch_size == 32
        assert embedder.progress_bar is True
        assert embedder.meta_fields_to_embed == []
        assert embedder.embedding_separator == "\n"
        assert embedder.default_headers == {}
        assert embedder.azure_ad_token_provider is None
        assert embedder.max_retries == 0

    def test_to_dict(self, monkeypatch):
        """Checks the serialized form of an embedder created with default parameters."""
        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
        component = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
            "init_parameters": {
                "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
                "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
                "api_version": "2023-05-15",
                "azure_deployment": "text-embedding-ada-002",
                "dimensions": None,
                "azure_endpoint": "https://example-resource.azure.openai.com/",
                "organization": None,
                "prefix": "",
                "suffix": "",
                "batch_size": 32,
                "progress_bar": True,
                "meta_fields_to_embed": [],
                "embedding_separator": "\n",
                "max_retries": 5,
                "timeout": 30.0,
                "default_headers": {},
                "azure_ad_token_provider": None,
                "http_client_kwargs": None,
                "raise_on_failure": False,
            },
        }

    def test_to_dict_with_parameters(self, monkeypatch):
        """Checks the serialized form of an embedder created with custom parameters."""
        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
        component = AzureOpenAIDocumentEmbedder(
            azure_endpoint="https://example-resource.azure.openai.com/",
            azure_deployment="text-embedding-ada-002",
            dimensions=768,
            organization="HaystackCI",
            timeout=60.0,
            max_retries=10,
            prefix="prefix ",
            suffix=" suffix",
            default_headers={"x-custom-header": "custom-value"},
            azure_ad_token_provider=default_azure_ad_token_provider,
            http_client_kwargs={"proxy": "http://example.com:3128", "verify": False},
            raise_on_failure=True,
        )
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
            "init_parameters": {
                "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
                "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
                "api_version": "2023-05-15",
                "azure_deployment": "text-embedding-ada-002",
                "dimensions": 768,
                "azure_endpoint": "https://example-resource.azure.openai.com/",
                "organization": "HaystackCI",
                "prefix": "prefix ",
                "suffix": " suffix",
                "batch_size": 32,
                "progress_bar": True,
                "meta_fields_to_embed": [],
                "embedding_separator": "\n",
                "max_retries": 10,
                "timeout": 60.0,
                "default_headers": {"x-custom-header": "custom-value"},
                # The token provider callable is expected to serialize to its import path.
                "azure_ad_token_provider": "haystack.utils.azure.default_azure_ad_token_provider",
                "http_client_kwargs": {"proxy": "http://example.com:3128", "verify": False},
                "raise_on_failure": True,
            },
        }

    def test_from_dict(self, monkeypatch):
        """Checks that an embedder can be rebuilt from a serialized dict holding default parameters."""
        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
        data = {
            "type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
            "init_parameters": {
                "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
                "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
                "api_version": "2023-05-15",
                "azure_deployment": "text-embedding-ada-002",
                "dimensions": None,
                "azure_endpoint": "https://example-resource.azure.openai.com/",
                "organization": None,
                "prefix": "",
                "suffix": "",
                "batch_size": 32,
                "progress_bar": True,
                "meta_fields_to_embed": [],
                "embedding_separator": "\n",
                "max_retries": 5,
                "timeout": 30.0,
                "default_headers": {},
                "azure_ad_token_provider": None,
                "http_client_kwargs": None,
                "raise_on_failure": False,
            },
        }
        component = AzureOpenAIDocumentEmbedder.from_dict(data)
        assert component.azure_deployment == "text-embedding-ada-002"
        assert component.azure_endpoint == "https://example-resource.azure.openai.com/"
        assert component.api_version == "2023-05-15"
        assert component.max_retries == 5
        assert component.timeout == 30.0
        assert component.prefix == ""
        assert component.suffix == ""
        assert component.default_headers == {}
        assert component.azure_ad_token_provider is None
        assert component.http_client_kwargs is None
        assert component.raise_on_failure is False

    def test_from_dict_with_parameters(self, monkeypatch):
        """Checks that an embedder can be rebuilt from a serialized dict holding custom parameters."""
        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
        data = {
            "type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
            "init_parameters": {
                "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
                "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
                "api_version": "2023-05-15",
                "azure_deployment": "text-embedding-ada-002",
                "dimensions": 768,
                "azure_endpoint": "https://example-resource.azure.openai.com/",
                "organization": "HaystackCI",
                "prefix": "prefix ",
                "suffix": " suffix",
                "batch_size": 32,
                "progress_bar": True,
                "meta_fields_to_embed": [],
                "embedding_separator": "\n",
                "max_retries": 10,
                "timeout": 60.0,
                "default_headers": {"x-custom-header": "custom-value"},
                "azure_ad_token_provider": "haystack.utils.azure.default_azure_ad_token_provider",
                "http_client_kwargs": {"proxy": "http://example.com:3128", "verify": False},
                "raise_on_failure": True,
            },
        }
        component = AzureOpenAIDocumentEmbedder.from_dict(data)
        assert component.azure_deployment == "text-embedding-ada-002"
        assert component.azure_endpoint == "https://example-resource.azure.openai.com/"
        assert component.api_version == "2023-05-15"
        assert component.max_retries == 10
        assert component.timeout == 60.0
        assert component.prefix == "prefix "
        assert component.suffix == " suffix"
        assert component.default_headers == {"x-custom-header": "custom-value"}
        assert component.azure_ad_token_provider is not None
        assert component.http_client_kwargs == {"proxy": "http://example.com:3128", "verify": False}
        assert component.raise_on_failure is True

    def test_embed_batch_handles_exceptions_gracefully(self, caplog):
        """Checks that an APIError raised by the client is logged once instead of propagating."""
        embedder = AzureOpenAIDocumentEmbedder(
            azure_endpoint="https://test.openai.azure.com",
            api_key=Secret.from_token("fake-api-key"),
            azure_deployment="text-embedding-ada-002",
            embedding_separator=" | ",
        )

        fake_texts_to_embed = {"1": "text1", "2": "text2"}

        with patch.object(
            embedder.client.embeddings,
            "create",
            side_effect=APIError(message="Mocked error", request=Mock(), body=None),
        ):
            embedder._embed_batch(texts_to_embed=fake_texts_to_embed, batch_size=32)

        assert len(caplog.records) == 1
        assert "Failed embedding of documents 1, 2 caused by Mocked error" in caplog.text

    def test_embed_batch_raises_exception_on_failure(self):
        """Checks that the APIError propagates when the embedder is created with raise_on_failure=True."""
        embedder = AzureOpenAIDocumentEmbedder(
            azure_endpoint="https://test.openai.azure.com",
            api_key=Secret.from_token("fake-api-key"),
            azure_deployment="text-embedding-ada-002",
            raise_on_failure=True,
        )
        fake_texts_to_embed = {"1": "text1", "2": "text2"}
        with patch.object(
            embedder.client.embeddings,
            "create",
            side_effect=APIError(message="Mocked error", request=Mock(), body=None),
        ):
            with pytest.raises(APIError, match="Mocked error"):
                embedder._embed_batch(texts_to_embed=fake_texts_to_embed, batch_size=2)

    @pytest.mark.integration
    @pytest.mark.skipif(
        # Skip when EITHER variable is missing: the skip reason below states that both are
        # required, and the embedder in this test is built without an explicit key or endpoint.
        not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
        reason=(
            "Please export env variables called AZURE_OPENAI_API_KEY containing "
            "the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "
            "the Azure OpenAI endpoint URL to run this test."
        ),
    )
    def test_run(self):
        """Integration test: embeds two documents and checks the returned documents and usage metadata."""
        docs = [
            Document(content="I love cheese", meta={"topic": "Cuisine"}),
            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
        ]
        # the default model is text-embedding-ada-002 even if we don't specify it, but let's be explicit
        embedder = AzureOpenAIDocumentEmbedder(
            azure_deployment="text-embedding-ada-002",
            meta_fields_to_embed=["topic"],
            embedding_separator=" | ",
            organization="HaystackCI",
        )

        result = embedder.run(documents=docs)
        documents_with_embeddings = result["documents"]
        metadata = result["meta"]

        assert isinstance(documents_with_embeddings, list)
        assert len(documents_with_embeddings) == len(docs)
        for doc, new_doc in zip(docs, documents_with_embeddings, strict=True):
            # The input documents must be left untouched; embeddings live on new Document objects.
            assert doc.embedding is None
            assert new_doc is not doc
            assert isinstance(new_doc, Document)
            assert isinstance(new_doc.embedding, list)
            assert len(new_doc.embedding) == 1536
            assert all(isinstance(x, float) for x in new_doc.embedding)

        assert metadata["usage"]["prompt_tokens"] == 15
        assert metadata["usage"]["total_tokens"] == 15
        assert "ada" in metadata["model"]