test_openai_text_embedder.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import contextlib
import os

import pytest
from openai.types import CreateEmbeddingResponse, Embedding

from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
from haystack.utils.auth import Secret


class TestOpenAITextEmbedder:
    """Unit and integration tests for ``OpenAITextEmbedder``."""

    @staticmethod
    def _assert_ada_002_result(result):
        """Check the embedding, model name and usage shared by the sync and async integration tests."""
        assert len(result["embedding"]) == 1536
        assert all(isinstance(x, float) for x in result["embedding"])

        assert "text" in result["meta"]["model"] and "ada" in result["meta"]["model"], (
            "The model name does not contain 'text' and 'ada'"
        )

        assert result["meta"]["usage"] == {"prompt_tokens": 6, "total_tokens": 6}, "Usage information does not match"

    def test_init_default(self, monkeypatch):
        """Constructing with no arguments picks up the API key from the environment and uses the defaults."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
        # These variables override the default timeout/retries (see
        # test_init_with_parameters_and_env_vars), so clear them to keep the test hermetic.
        monkeypatch.delenv("OPENAI_TIMEOUT", raising=False)
        monkeypatch.delenv("OPENAI_MAX_RETRIES", raising=False)
        embedder = OpenAITextEmbedder()

        assert embedder.client.api_key == "fake-api-key"
        assert embedder.model == "text-embedding-ada-002"
        assert embedder.api_base_url is None
        assert embedder.organization is None
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.client.timeout == 30
        assert embedder.client.max_retries == 5

    def test_init_with_parameters(self, monkeypatch):
        """Explicit ``timeout`` / ``max_retries`` arguments win over the corresponding environment variables."""
        monkeypatch.setenv("OPENAI_TIMEOUT", "100")
        monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
        embedder = OpenAITextEmbedder(
            api_key=Secret.from_token("fake-api-key"),
            model="model",
            api_base_url="https://my-custom-base-url.com",
            organization="fake-organization",
            prefix="prefix",
            suffix="suffix",
            timeout=40.0,
            max_retries=1,
        )
        assert embedder.client.api_key == "fake-api-key"
        assert embedder.model == "model"
        assert embedder.api_base_url == "https://my-custom-base-url.com"
        assert embedder.organization == "fake-organization"
        assert embedder.prefix == "prefix"
        assert embedder.suffix == "suffix"
        assert embedder.client.timeout == 40.0
        assert embedder.client.max_retries == 1

    def test_init_with_parameters_and_env_vars(self, monkeypatch):
        """Without explicit arguments, timeout and max retries are read from the environment variables."""
        monkeypatch.setenv("OPENAI_TIMEOUT", "100")
        monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
        embedder = OpenAITextEmbedder(
            api_key=Secret.from_token("fake-api-key"),
            model="model",
            api_base_url="https://my-custom-base-url.com",
            organization="fake-organization",
            prefix="prefix",
            suffix="suffix",
        )
        assert embedder.client.api_key == "fake-api-key"
        assert embedder.model == "model"
        assert embedder.api_base_url == "https://my-custom-base-url.com"
        assert embedder.organization == "fake-organization"
        assert embedder.prefix == "prefix"
        assert embedder.suffix == "suffix"
        assert embedder.client.timeout == 100.0
        assert embedder.client.max_retries == 10

    def test_init_fail_wo_api_key(self, monkeypatch):
        """Construction raises ``ValueError`` when no API key is available in the environment."""
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        with pytest.raises(ValueError, match="None of the .* environment variables are set"):
            OpenAITextEmbedder()

    def test_to_dict(self, monkeypatch):
        """A default component serializes to the expected dictionary."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
        component = OpenAITextEmbedder()
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
            "init_parameters": {
                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                "api_base_url": None,
                "dimensions": None,
                "model": "text-embedding-ada-002",
                "organization": None,
                "http_client_kwargs": None,
                "prefix": "",
                "suffix": "",
                "timeout": None,
                "max_retries": None,
            },
        }

    def test_to_dict_with_custom_init_parameters(self, monkeypatch):
        """Custom init parameters are all reflected in the serialized dictionary."""
        monkeypatch.setenv("ENV_VAR", "fake-api-key")
        component = OpenAITextEmbedder(
            api_key=Secret.from_env_var("ENV_VAR", strict=False),
            model="model",
            api_base_url="https://my-custom-base-url.com",
            organization="fake-organization",
            prefix="prefix",
            suffix="suffix",
            timeout=10.0,
            max_retries=2,
            http_client_kwargs={"proxy": "http://localhost:8080"},
        )
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
            "init_parameters": {
                "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                "api_base_url": "https://my-custom-base-url.com",
                "model": "model",
                "dimensions": None,
                "organization": "fake-organization",
                "http_client_kwargs": {"proxy": "http://localhost:8080"},
                "prefix": "prefix",
                "suffix": "suffix",
                "timeout": 10.0,
                "max_retries": 2,
            },
        }

    def test_from_dict(self, monkeypatch):
        """A component rebuilt from its dictionary form exposes the serialized settings."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
        data = {
            "type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
            "init_parameters": {
                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                "model": "text-embedding-ada-002",
                "api_base_url": "https://my-custom-base-url.com",
                "organization": "fake-organization",
                "http_client_kwargs": None,
                "prefix": "prefix",
                "suffix": "suffix",
            },
        }
        component = OpenAITextEmbedder.from_dict(data)
        assert component.client.api_key == "fake-api-key"
        assert component.model == "text-embedding-ada-002"
        assert component.api_base_url == "https://my-custom-base-url.com"
        assert component.organization == "fake-organization"
        assert component.http_client_kwargs is None
        assert component.prefix == "prefix"
        assert component.suffix == "suffix"

    def test_prepare_input(self, monkeypatch):
        """``_prepare_input`` builds the request payload, including the configured ``dimensions``."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
        embedder = OpenAITextEmbedder(dimensions=1536)

        inp = "The food was delicious"
        prepared_input = embedder._prepare_input(inp)
        assert prepared_input == {
            "model": "text-embedding-ada-002",
            "input": "The food was delicious",
            "encoding_format": "float",
            "dimensions": 1536,
        }

    def test_prepare_output(self, monkeypatch):
        """``_prepare_output`` turns an embedding response into the ``embedding`` / ``meta`` result dict."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")

        response = CreateEmbeddingResponse(
            data=[Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding")],
            model="text-embedding-ada-002",
            object="list",
            usage={"prompt_tokens": 6, "total_tokens": 6},
        )

        embedder = OpenAITextEmbedder()
        result = embedder._prepare_output(result=response)
        assert result == {
            "embedding": [0.1, 0.2, 0.3],
            "meta": {"model": "text-embedding-ada-002", "usage": {"prompt_tokens": 6, "total_tokens": 6}},
        }

    def test_run_wrong_input_format(self):
        """``run`` rejects a non-string input with ``TypeError``."""
        embedder = OpenAITextEmbedder(api_key=Secret.from_token("fake-api-key"))

        list_integers_input = [1, 2, 3]

        with pytest.raises(TypeError, match="OpenAITextEmbedder expects a string as an input"):
            embedder.run(text=list_integers_input)

    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY is not set")
    @pytest.mark.integration
    def test_run(self):
        """Integration test: ``run`` returns the embedding and metadata checked by the shared helper."""
        model = "text-embedding-ada-002"

        embedder = OpenAITextEmbedder(model=model, prefix="prefix ", suffix=" suffix")
        result = embedder.run(text="The food was delicious")

        self._assert_ada_002_result(result)

    @pytest.mark.asyncio
    @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY is not set")
    @pytest.mark.integration
    async def test_run_async(self):
        """Integration test: ``run_async`` returns the same shape of result as ``run``."""
        embedder = OpenAITextEmbedder(model="text-embedding-ada-002", prefix="prefix ", suffix=" suffix")
        try:
            result = await embedder.run_async(text="The food was delicious")
            self._assert_ada_002_result(result)
        finally:
            # Close the async client even if the call or an assertion fails;
            # suppress RuntimeError if the event loop is already closed.
            with contextlib.suppress(RuntimeError):
                await embedder.async_client.close()