# test/components/embedders/test_hugging_face_api_text_embedder.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import os
  6  import random
  7  import sys
  8  from unittest.mock import MagicMock, patch
  9  
 10  import pytest
 11  from huggingface_hub.utils import RepositoryNotFoundError
 12  from numpy import array
 13  
 14  from haystack.components.embedders import HuggingFaceAPITextEmbedder
 15  from haystack.utils.auth import Secret
 16  from haystack.utils.hf import HFEmbeddingAPIType
 17  
 18  
 19  @pytest.fixture
 20  def mock_check_valid_model():
 21      with patch(
 22          "haystack.components.embedders.hugging_face_api_text_embedder.check_valid_model", MagicMock(return_value=None)
 23      ) as mock:
 24          yield mock
 25  
 26  
 27  class TestHuggingFaceAPITextEmbedder:
 28      def test_init_invalid_api_type(self):
 29          with pytest.raises(ValueError):
 30              HuggingFaceAPITextEmbedder(api_type="invalid_api_type", api_params={})
 31  
 32      def test_init_serverless(self, mock_check_valid_model):
 33          model = "BAAI/bge-small-en-v1.5"
 34          embedder = HuggingFaceAPITextEmbedder(
 35              api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model}
 36          )
 37  
 38          assert embedder.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API
 39          assert embedder.api_params == {"model": model}
 40          assert embedder.prefix == ""
 41          assert embedder.suffix == ""
 42          assert embedder.truncate
 43          assert not embedder.normalize
 44  
 45      def test_init_serverless_invalid_model(self, mock_check_valid_model):
 46          mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id", response=MagicMock())
 47          with pytest.raises(RepositoryNotFoundError):
 48              HuggingFaceAPITextEmbedder(
 49                  api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"}
 50              )
 51  
 52      def test_init_serverless_no_model(self):
 53          with pytest.raises(ValueError):
 54              HuggingFaceAPITextEmbedder(
 55                  api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"}
 56              )
 57  
 58      def test_init_tei(self):
 59          url = "https://some_model.com"
 60  
 61          embedder = HuggingFaceAPITextEmbedder(
 62              api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"url": url}
 63          )
 64  
 65          assert embedder.api_type == HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE
 66          assert embedder.api_params == {"url": url}
 67          assert embedder.prefix == ""
 68          assert embedder.suffix == ""
 69          assert embedder.truncate
 70          assert not embedder.normalize
 71  
 72      def test_init_tei_invalid_url(self):
 73          with pytest.raises(ValueError):
 74              HuggingFaceAPITextEmbedder(
 75                  api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"url": "invalid_url"}
 76              )
 77  
 78      def test_init_tei_no_url(self):
 79          with pytest.raises(ValueError):
 80              HuggingFaceAPITextEmbedder(
 81                  api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"param": "irrelevant"}
 82              )
 83  
 84      def test_to_dict(self, mock_check_valid_model):
 85          embedder = HuggingFaceAPITextEmbedder(
 86              api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
 87              api_params={"model": "BAAI/bge-small-en-v1.5"},
 88              prefix="prefix",
 89              suffix="suffix",
 90              truncate=False,
 91              normalize=True,
 92          )
 93  
 94          data = embedder.to_dict()
 95  
 96          assert data == {
 97              "type": "haystack.components.embedders.hugging_face_api_text_embedder.HuggingFaceAPITextEmbedder",
 98              "init_parameters": {
 99                  "api_type": "serverless_inference_api",
100                  "api_params": {"model": "BAAI/bge-small-en-v1.5"},
101                  "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
102                  "prefix": "prefix",
103                  "suffix": "suffix",
104                  "truncate": False,
105                  "normalize": True,
106              },
107          }
108  
109      def test_from_dict(self, mock_check_valid_model):
110          data = {
111              "type": "haystack.components.embedders.hugging_face_api_text_embedder.HuggingFaceAPITextEmbedder",
112              "init_parameters": {
113                  "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
114                  "api_params": {"model": "BAAI/bge-small-en-v1.5"},
115                  "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
116                  "prefix": "prefix",
117                  "suffix": "suffix",
118                  "truncate": False,
119                  "normalize": True,
120              },
121          }
122  
123          embedder = HuggingFaceAPITextEmbedder.from_dict(data)
124  
125          assert embedder.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API
126          assert embedder.api_params == {"model": "BAAI/bge-small-en-v1.5"}
127          assert embedder.prefix == "prefix"
128          assert embedder.suffix == "suffix"
129          assert not embedder.truncate
130          assert embedder.normalize
131  
132      def test_run_wrong_input_format(self, mock_check_valid_model):
133          embedder = HuggingFaceAPITextEmbedder(
134              api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
135          )
136  
137          list_integers_input = [1, 2, 3]
138  
139          with pytest.raises(TypeError):
140              embedder.run(text=list_integers_input)
141  
142      def test_run(self, mock_check_valid_model, caplog):
143          with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
144              mock_embedding_patch.return_value = array([[random.random() for _ in range(384)]])
145  
146              embedder = HuggingFaceAPITextEmbedder(
147                  api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
148                  api_params={"model": "BAAI/bge-small-en-v1.5"},
149                  token=Secret.from_token("fake-api-token"),
150                  prefix="prefix ",
151                  suffix=" suffix",
152              )
153  
154              result = embedder.run(text="The food was delicious")
155  
156              mock_embedding_patch.assert_called_once_with(
157                  text="prefix The food was delicious suffix", truncate=None, normalize=None
158              )
159  
160          assert len(result["embedding"]) == 384
161          assert all(isinstance(x, float) for x in result["embedding"])
162  
163          # Check that warnings about ignoring truncate and normalize are raised
164          assert len(caplog.records) == 2
165          assert "truncate" in caplog.records[0].message
166          assert "normalize" in caplog.records[1].message
167  
168      @pytest.mark.asyncio
169      async def test_run_async(self, mock_check_valid_model, caplog):
170          with patch("huggingface_hub.AsyncInferenceClient.feature_extraction") as mock_embedding_patch:
171              mock_embedding_patch.return_value = array([[random.random() for _ in range(384)]])
172  
173              embedder = HuggingFaceAPITextEmbedder(
174                  api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
175                  api_params={"model": "BAAI/bge-small-en-v1.5"},
176                  token=Secret.from_token("fake-api-token"),
177                  prefix="prefix ",
178                  suffix=" suffix",
179              )
180  
181              result = await embedder.run_async(text="The food was delicious")
182  
183              mock_embedding_patch.assert_called_once_with(
184                  text="prefix The food was delicious suffix", truncate=None, normalize=None
185              )
186  
187          assert len(result["embedding"]) == 384
188          assert all(isinstance(x, float) for x in result["embedding"])
189  
190          # Check that warnings about ignoring truncate and normalize are raised
191          assert len(caplog.records) == 2
192          assert "truncate" in caplog.records[0].message
193          assert "normalize" in caplog.records[1].message
194  
195      def test_run_wrong_embedding_shape(self, mock_check_valid_model):
196          # embedding ndim > 2
197          with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
198              mock_embedding_patch.return_value = array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]])
199  
200              embedder = HuggingFaceAPITextEmbedder(
201                  api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
202              )
203  
204              with pytest.raises(ValueError):
205                  embedder.run(text="The food was delicious")
206  
207          # embedding ndim == 2 but shape[0] != 1
208          with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
209              mock_embedding_patch.return_value = array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
210  
211              embedder = HuggingFaceAPITextEmbedder(
212                  api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
213              )
214  
215              with pytest.raises(ValueError):
216                  embedder.run(text="The food was delicious")
217  
218      @pytest.mark.integration
219      @pytest.mark.slow
220      @pytest.mark.flaky(reruns=3, reruns_delay=10)
221      @pytest.mark.skipif(
222          not os.environ.get("HF_API_TOKEN", None),
223          reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
224      )
225      @pytest.mark.skipif(sys.platform != "linux", reason="We only test on Linux to avoid overloading the HF server")
226      def test_live_run_serverless(self):
227          embedder = HuggingFaceAPITextEmbedder(
228              api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
229              api_params={"model": "sentence-transformers/all-MiniLM-L6-v2"},
230          )
231          embedder._client.timeout = 10  # we want to fail fast if the server is not responding
232          result = embedder.run(text="The food was delicious")
233  
234          assert len(result["embedding"]) == 384
235          assert all(isinstance(x, float) for x in result["embedding"])
236  
237      @pytest.mark.integration
238      @pytest.mark.asyncio
239      @pytest.mark.slow
240      @pytest.mark.flaky(reruns=3, reruns_delay=10)
241      @pytest.mark.skipif(os.environ.get("HF_API_TOKEN", "") == "", reason="HF_API_TOKEN is not set")
242      @pytest.mark.skipif(sys.platform != "linux", reason="We only test on Linux to avoid overloading the HF server")
243      async def test_live_run_async_serverless(self):
244          model_name = "sentence-transformers/all-MiniLM-L6-v2"
245  
246          embedder = HuggingFaceAPITextEmbedder(
247              api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model_name}
248          )
249          embedder._client.timeout = 10  # we want to fail fast if the server is not responding
250  
251          text = "This is a test sentence for embedding."
252          result = await embedder.run_async(text=text)
253  
254          assert "embedding" in result
255          assert isinstance(result["embedding"], list)
256          assert all(isinstance(x, float) for x in result["embedding"])
257          assert len(result["embedding"]) == 384  # MiniLM-L6-v2 has 384 dimensions