test/components/embedders/test_openai_text_embedder.py
test_openai_text_embedder.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import contextlib
  6  import os
  7  
  8  import pytest
  9  from openai.types import CreateEmbeddingResponse, Embedding
 10  
 11  from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
 12  from haystack.utils.auth import Secret
 13  
 14  
 15  class TestOpenAITextEmbedder:
 16      def test_init_default(self, monkeypatch):
 17          monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
 18          embedder = OpenAITextEmbedder()
 19  
 20          assert embedder.client.api_key == "fake-api-key"
 21          assert embedder.model == "text-embedding-ada-002"
 22          assert embedder.api_base_url == None
 23          assert embedder.organization is None
 24          assert embedder.prefix == ""
 25          assert embedder.suffix == ""
 26          assert embedder.client.timeout == 30
 27          assert embedder.client.max_retries == 5
 28  
 29      def test_init_with_parameters(self, monkeypatch):
 30          monkeypatch.setenv("OPENAI_TIMEOUT", "100")
 31          monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
 32          embedder = OpenAITextEmbedder(
 33              api_key=Secret.from_token("fake-api-key"),
 34              model="model",
 35              api_base_url="https://my-custom-base-url.com",
 36              organization="fake-organization",
 37              prefix="prefix",
 38              suffix="suffix",
 39              timeout=40.0,
 40              max_retries=1,
 41          )
 42          assert embedder.client.api_key == "fake-api-key"
 43          assert embedder.model == "model"
 44          assert embedder.api_base_url == "https://my-custom-base-url.com"
 45          assert embedder.organization == "fake-organization"
 46          assert embedder.prefix == "prefix"
 47          assert embedder.suffix == "suffix"
 48          assert embedder.client.timeout == 40.0
 49          assert embedder.client.max_retries == 1
 50  
 51      def test_init_with_parameters_and_env_vars(self, monkeypatch):
 52          monkeypatch.setenv("OPENAI_TIMEOUT", "100")
 53          monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
 54          embedder = OpenAITextEmbedder(
 55              api_key=Secret.from_token("fake-api-key"),
 56              model="model",
 57              api_base_url="https://my-custom-base-url.com",
 58              organization="fake-organization",
 59              prefix="prefix",
 60              suffix="suffix",
 61          )
 62          assert embedder.client.api_key == "fake-api-key"
 63          assert embedder.model == "model"
 64          assert embedder.api_base_url == "https://my-custom-base-url.com"
 65          assert embedder.organization == "fake-organization"
 66          assert embedder.prefix == "prefix"
 67          assert embedder.suffix == "suffix"
 68          assert embedder.client.timeout == 100.0
 69          assert embedder.client.max_retries == 10
 70  
 71      def test_init_fail_wo_api_key(self, monkeypatch):
 72          monkeypatch.delenv("OPENAI_API_KEY", raising=False)
 73          with pytest.raises(ValueError, match="None of the .* environment variables are set"):
 74              OpenAITextEmbedder()
 75  
 76      def test_to_dict(self, monkeypatch):
 77          monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
 78          component = OpenAITextEmbedder()
 79          data = component.to_dict()
 80          assert data == {
 81              "type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
 82              "init_parameters": {
 83                  "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
 84                  "api_base_url": None,
 85                  "dimensions": None,
 86                  "model": "text-embedding-ada-002",
 87                  "organization": None,
 88                  "http_client_kwargs": None,
 89                  "prefix": "",
 90                  "suffix": "",
 91                  "timeout": None,
 92                  "max_retries": None,
 93              },
 94          }
 95  
 96      def test_to_dict_with_custom_init_parameters(self, monkeypatch):
 97          monkeypatch.setenv("ENV_VAR", "fake-api-key")
 98          component = OpenAITextEmbedder(
 99              api_key=Secret.from_env_var("ENV_VAR", strict=False),
100              model="model",
101              api_base_url="https://my-custom-base-url.com",
102              organization="fake-organization",
103              prefix="prefix",
104              suffix="suffix",
105              timeout=10.0,
106              max_retries=2,
107              http_client_kwargs={"proxy": "http://localhost:8080"},
108          )
109          data = component.to_dict()
110          assert data == {
111              "type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
112              "init_parameters": {
113                  "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
114                  "api_base_url": "https://my-custom-base-url.com",
115                  "model": "model",
116                  "dimensions": None,
117                  "organization": "fake-organization",
118                  "http_client_kwargs": {"proxy": "http://localhost:8080"},
119                  "prefix": "prefix",
120                  "suffix": "suffix",
121                  "timeout": 10.0,
122                  "max_retries": 2,
123              },
124          }
125  
126      def test_from_dict(self, monkeypatch):
127          monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
128          data = {
129              "type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
130              "init_parameters": {
131                  "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
132                  "model": "text-embedding-ada-002",
133                  "api_base_url": "https://my-custom-base-url.com",
134                  "organization": "fake-organization",
135                  "http_client_kwargs": None,
136                  "prefix": "prefix",
137                  "suffix": "suffix",
138              },
139          }
140          component = OpenAITextEmbedder.from_dict(data)
141          assert component.client.api_key == "fake-api-key"
142          assert component.model == "text-embedding-ada-002"
143          assert component.api_base_url == "https://my-custom-base-url.com"
144          assert component.organization == "fake-organization"
145          assert component.http_client_kwargs is None
146          assert component.prefix == "prefix"
147          assert component.suffix == "suffix"
148  
149      def test_prepare_input(self, monkeypatch):
150          monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
151          embedder = OpenAITextEmbedder(dimensions=1536)
152  
153          inp = "The food was delicious"
154          prepared_input = embedder._prepare_input(inp)
155          assert prepared_input == {
156              "model": "text-embedding-ada-002",
157              "input": "The food was delicious",
158              "encoding_format": "float",
159              "dimensions": 1536,
160          }
161  
162      def test_prepare_output(self, monkeypatch):
163          monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
164  
165          response = CreateEmbeddingResponse(
166              data=[Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding")],
167              model="text-embedding-ada-002",
168              object="list",
169              usage={"prompt_tokens": 6, "total_tokens": 6},
170          )
171  
172          embedder = OpenAITextEmbedder()
173          result = embedder._prepare_output(result=response)
174          assert result == {
175              "embedding": [0.1, 0.2, 0.3],
176              "meta": {"model": "text-embedding-ada-002", "usage": {"prompt_tokens": 6, "total_tokens": 6}},
177          }
178  
179      def test_run_wrong_input_format(self):
180          embedder = OpenAITextEmbedder(api_key=Secret.from_token("fake-api-key"))
181  
182          list_integers_input = [1, 2, 3]
183  
184          with pytest.raises(TypeError, match="OpenAITextEmbedder expects a string as an input"):
185              embedder.run(text=list_integers_input)
186  
187      @pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OPENAI_API_KEY is not set")
188      @pytest.mark.integration
189      def test_run(self):
190          model = "text-embedding-ada-002"
191  
192          embedder = OpenAITextEmbedder(model=model, prefix="prefix ", suffix=" suffix")
193          result = embedder.run(text="The food was delicious")
194  
195          assert len(result["embedding"]) == 1536
196          assert all(isinstance(x, float) for x in result["embedding"])
197  
198          assert "text" in result["meta"]["model"] and "ada" in result["meta"]["model"], (
199              "The model name does not contain 'text' and 'ada'"
200          )
201  
202          assert result["meta"]["usage"] == {"prompt_tokens": 6, "total_tokens": 6}, "Usage information does not match"
203  
204      @pytest.mark.asyncio
205      @pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OPENAI_API_KEY is not set")
206      @pytest.mark.integration
207      async def test_run_async(self):
208          embedder = OpenAITextEmbedder(model="text-embedding-ada-002", prefix="prefix ", suffix=" suffix")
209          result = await embedder.run_async(text="The food was delicious")
210  
211          assert len(result["embedding"]) == 1536
212          assert all(isinstance(x, float) for x in result["embedding"])
213  
214          assert "text" in result["meta"]["model"] and "ada" in result["meta"]["model"], (
215              "The model name does not contain 'text' and 'ada'"
216          )
217  
218          assert result["meta"]["usage"] == {"prompt_tokens": 6, "total_tokens": 6}, "Usage information does not match"
219  
220          # Close async client; suppress RuntimeError if the event loop is already closed
221          with contextlib.suppress(RuntimeError):
222              await embedder.async_client.close()