# test_hugging_face_api_text_embedder.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os
import random
import sys
from unittest.mock import MagicMock, patch

import pytest
from huggingface_hub.utils import RepositoryNotFoundError
from numpy import array

from haystack.components.embedders import HuggingFaceAPITextEmbedder
from haystack.utils.auth import Secret
from haystack.utils.hf import HFEmbeddingAPIType


@pytest.fixture
def mock_check_valid_model():
    """Patch model-existence validation so unit tests never contact the Hugging Face Hub."""
    with patch(
        "haystack.components.embedders.hugging_face_api_text_embedder.check_valid_model", MagicMock(return_value=None)
    ) as mock:
        yield mock


class TestHuggingFaceAPITextEmbedder:
    """Unit and (opt-in) integration tests for ``HuggingFaceAPITextEmbedder``."""

    def test_init_invalid_api_type(self):
        # An unrecognized api_type string must be rejected at construction time.
        with pytest.raises(ValueError):
            HuggingFaceAPITextEmbedder(api_type="invalid_api_type", api_params={})

    def test_init_serverless(self, mock_check_valid_model):
        # Serverless init: defaults are empty prefix/suffix, truncate=True, normalize=False.
        model = "BAAI/bge-small-en-v1.5"
        embedder = HuggingFaceAPITextEmbedder(
            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model}
        )

        assert embedder.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API
        assert embedder.api_params == {"model": model}
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.truncate
        assert not embedder.normalize

    def test_init_serverless_invalid_model(self, mock_check_valid_model):
        # A model id the Hub does not know must surface as RepositoryNotFoundError.
        mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id", response=MagicMock())
        with pytest.raises(RepositoryNotFoundError):
            HuggingFaceAPITextEmbedder(
                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"}
            )

    def test_init_serverless_no_model(self):
        # Serverless API requires a "model" key in api_params.
        with pytest.raises(ValueError):
            HuggingFaceAPITextEmbedder(
                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"}
            )

    def test_init_tei(self):
        # Text Embeddings Inference init: same defaults, keyed by "url" instead of "model".
        url = "https://some_model.com"

        embedder = HuggingFaceAPITextEmbedder(
            api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"url": url}
        )

        assert embedder.api_type == HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE
        assert embedder.api_params == {"url": url}
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.truncate
        assert not embedder.normalize

    def test_init_tei_invalid_url(self):
        # A malformed URL must be rejected at construction time.
        with pytest.raises(ValueError):
            HuggingFaceAPITextEmbedder(
                api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"url": "invalid_url"}
            )

    def test_init_tei_no_url(self):
        # TEI requires a "url" key in api_params.
        with pytest.raises(ValueError):
            HuggingFaceAPITextEmbedder(
                api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"param": "irrelevant"}
            )

    def test_to_dict(self, mock_check_valid_model):
        # Serialization: all init parameters round-trip into the dict representation.
        embedder = HuggingFaceAPITextEmbedder(
            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "BAAI/bge-small-en-v1.5"},
            prefix="prefix",
            suffix="suffix",
            truncate=False,
            normalize=True,
        )

        data = embedder.to_dict()

        assert data == {
            "type": "haystack.components.embedders.hugging_face_api_text_embedder.HuggingFaceAPITextEmbedder",
            "init_parameters": {
                "api_type": "serverless_inference_api",
                "api_params": {"model": "BAAI/bge-small-en-v1.5"},
                "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
                "prefix": "prefix",
                "suffix": "suffix",
                "truncate": False,
                "normalize": True,
            },
        }

    def test_from_dict(self, mock_check_valid_model):
        # Deserialization: a serialized dict reconstructs an equivalent embedder.
        data = {
            "type": "haystack.components.embedders.hugging_face_api_text_embedder.HuggingFaceAPITextEmbedder",
            "init_parameters": {
                "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
                "api_params": {"model": "BAAI/bge-small-en-v1.5"},
                "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
                "prefix": "prefix",
                "suffix": "suffix",
                "truncate": False,
                "normalize": True,
            },
        }

        embedder = HuggingFaceAPITextEmbedder.from_dict(data)

        assert embedder.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API
        assert embedder.api_params == {"model": "BAAI/bge-small-en-v1.5"}
        assert embedder.prefix == "prefix"
        assert embedder.suffix == "suffix"
        assert not embedder.truncate
        assert embedder.normalize

    def test_run_wrong_input_format(self, mock_check_valid_model):
        # run() accepts a single string; passing a list must raise TypeError.
        embedder = HuggingFaceAPITextEmbedder(
            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
        )

        list_integers_input = [1, 2, 3]

        with pytest.raises(TypeError):
            embedder.run(text=list_integers_input)

    def test_run(self, mock_check_valid_model, caplog):
        # run() wraps the text with prefix/suffix and returns a flat list of floats.
        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
            mock_embedding_patch.return_value = array([[random.random() for _ in range(384)]])

            embedder = HuggingFaceAPITextEmbedder(
                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
                api_params={"model": "BAAI/bge-small-en-v1.5"},
                token=Secret.from_token("fake-api-token"),
                prefix="prefix ",
                suffix=" suffix",
            )

            result = embedder.run(text="The food was delicious")

            mock_embedding_patch.assert_called_once_with(
                text="prefix The food was delicious suffix", truncate=None, normalize=None
            )

            assert len(result["embedding"]) == 384
            assert all(isinstance(x, float) for x in result["embedding"])

            # Check that warnings about ignoring truncate and normalize are raised
            assert len(caplog.records) == 2
            assert "truncate" in caplog.records[0].message
            assert "normalize" in caplog.records[1].message

    @pytest.mark.asyncio
    async def test_run_async(self, mock_check_valid_model, caplog):
        # Async variant of test_run: patches the AsyncInferenceClient instead.
        with patch("huggingface_hub.AsyncInferenceClient.feature_extraction") as mock_embedding_patch:
            mock_embedding_patch.return_value = array([[random.random() for _ in range(384)]])

            embedder = HuggingFaceAPITextEmbedder(
                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
                api_params={"model": "BAAI/bge-small-en-v1.5"},
                token=Secret.from_token("fake-api-token"),
                prefix="prefix ",
                suffix=" suffix",
            )

            result = await embedder.run_async(text="The food was delicious")

            mock_embedding_patch.assert_called_once_with(
                text="prefix The food was delicious suffix", truncate=None, normalize=None
            )

            assert len(result["embedding"]) == 384
            assert all(isinstance(x, float) for x in result["embedding"])

            # Check that warnings about ignoring truncate and normalize are raised
            assert len(caplog.records) == 2
            assert "truncate" in caplog.records[0].message
            assert "normalize" in caplog.records[1].message

    def test_run_wrong_embedding_shape(self, mock_check_valid_model):
        # The component only accepts a (1, dim)-shaped response; anything else raises.
        # embedding ndim > 2
        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
            mock_embedding_patch.return_value = array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]])

            embedder = HuggingFaceAPITextEmbedder(
                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
            )

            with pytest.raises(ValueError):
                embedder.run(text="The food was delicious")

        # embedding ndim == 2 but shape[0] != 1
        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
            mock_embedding_patch.return_value = array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

            embedder = HuggingFaceAPITextEmbedder(
                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
            )

            with pytest.raises(ValueError):
                embedder.run(text="The food was delicious")

    @pytest.mark.integration
    @pytest.mark.slow
    @pytest.mark.flaky(reruns=3, reruns_delay=10)
    @pytest.mark.skipif(
        not os.environ.get("HF_API_TOKEN", None),
        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
    )
    @pytest.mark.skipif(sys.platform != "linux", reason="We only test on Linux to avoid overloading the HF server")
    def test_live_run_serverless(self):
        # Live round-trip against the serverless Inference API (opt-in via HF_API_TOKEN).
        embedder = HuggingFaceAPITextEmbedder(
            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
            api_params={"model": "sentence-transformers/all-MiniLM-L6-v2"},
        )
        embedder._client.timeout = 10  # we want to fail fast if the server is not responding
        result = embedder.run(text="The food was delicious")

        assert len(result["embedding"]) == 384
        assert all(isinstance(x, float) for x in result["embedding"])

    @pytest.mark.integration
    @pytest.mark.asyncio
    @pytest.mark.slow
    @pytest.mark.flaky(reruns=3, reruns_delay=10)
    @pytest.mark.skipif(os.environ.get("HF_API_TOKEN", "") == "", reason="HF_API_TOKEN is not set")
    @pytest.mark.skipif(sys.platform != "linux", reason="We only test on Linux to avoid overloading the HF server")
    async def test_live_run_async_serverless(self):
        # Async live round-trip against the serverless Inference API (opt-in via HF_API_TOKEN).
        model_name = "sentence-transformers/all-MiniLM-L6-v2"

        embedder = HuggingFaceAPITextEmbedder(
            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model_name}
        )
        # NOTE(review): run_async presumably goes through a separate async client; setting
        # the sync client's timeout here may have no effect on the async path — confirm
        # whether `_async_client.timeout` should be set instead.
        embedder._client.timeout = 10  # we want to fail fast if the server is not responding

        text = "This is a test sentence for embedding."
        result = await embedder.run_async(text=text)

        assert "embedding" in result
        assert isinstance(result["embedding"], list)
        assert all(isinstance(x, float) for x in result["embedding"])
        assert len(result["embedding"]) == 384  # MiniLM-L6-v2 has 384 dimensions