hugging_face_api.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Iterable
from dataclasses import asdict
from datetime import datetime
from typing import Any, cast

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import (
    ComponentInfo,
    FinishReason,
    StreamingCallbackT,
    StreamingChunk,
    SyncStreamingCallbackT,
    select_streaming_callback,
)
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_callable, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import (
        InferenceClient,
        TextGenerationOutput,
        TextGenerationStreamOutput,
        TextGenerationStreamOutputToken,
    )


logger = logging.getLogger(__name__)


@component
class HuggingFaceAPIGenerator:
    """
    Generates text using Hugging Face APIs.

    Use it with the following Hugging Face APIs:
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

    **Note:** As of July 2025, the Hugging Face Inference API no longer offers generative models through the
    `text_generation` endpoint. Generative models are now only available through providers supporting the
    `chat_completion` endpoint. As a result, this component might no longer work with the Hugging Face Inference API.
    Use the `HuggingFaceAPIChatGenerator` component, which supports the `chat_completion` endpoint.

    ### Usage examples

    #### With Hugging Face Inference Endpoints
    <!-- test-ignore -->
    ```python
    from haystack.components.generators import HuggingFaceAPIGenerator
    from haystack.utils import Secret

    generator = HuggingFaceAPIGenerator(api_type="inference_endpoints",
                                        api_params={"url": "<your-inference-endpoint-url>"},
                                        token=Secret.from_token("<your-api-key>"))

    result = generator.run(prompt="What's Natural Language Processing?")
    print(result)
    ```

    #### With self-hosted text generation inference
    <!-- test-ignore -->
    ```python
    from haystack.components.generators import HuggingFaceAPIGenerator

    generator = HuggingFaceAPIGenerator(api_type="text_generation_inference",
                                        api_params={"url": "http://localhost:8080"})

    result = generator.run(prompt="What's Natural Language Processing?")
    print(result)
    ```
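
    #### With streaming

    A minimal sketch of token streaming (illustrative, not part of the tested examples). It assumes a
    self-hosted TGI instance reachable at `http://localhost:8080`; `print_chunk` is an arbitrary callback name.
    <!-- test-ignore -->
    ```python
    from haystack.components.generators import HuggingFaceAPIGenerator
    from haystack.dataclasses import StreamingChunk

    def print_chunk(chunk: StreamingChunk) -> None:
        # print each token as it arrives, without buffering
        print(chunk.content, end="", flush=True)

    generator = HuggingFaceAPIGenerator(api_type="text_generation_inference",
                                        api_params={"url": "http://localhost:8080"},
                                        streaming_callback=print_chunk)

    result = generator.run(prompt="What's Natural Language Processing?")
    print(result)
    ```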

    #### With the free serverless inference API

    Be aware that this example might not work, as the Hugging Face Inference API no longer offers models that
    support the `text_generation` endpoint. Use the `HuggingFaceAPIChatGenerator` for generative models through
    the `chat_completion` endpoint.
    <!-- test-ignore -->
    ```python
    from haystack.components.generators import HuggingFaceAPIGenerator
    from haystack.utils import Secret

    generator = HuggingFaceAPIGenerator(api_type="serverless_inference_api",
                                        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                        token=Secret.from_token("<your-api-key>"))

    result = generator.run(prompt="What's Natural Language Processing?")
    print(result)
    ```
    """

    def __init__(
        self,
        api_type: HFGenerationAPIType | str,
        api_params: dict[str, str],
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        generation_kwargs: dict[str, Any] | None = None,
        stop_words: list[str] | None = None,
        streaming_callback: StreamingCallbackT | None = None,
    ) -> None:
        """
        Initialize the HuggingFaceAPIGenerator instance.

        :param api_type:
            The type of Hugging Face API to use. Available types:
            - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
            - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
            - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
              This might no longer work due to changes in the models offered in the Hugging Face Inference API.
              Please use the `HuggingFaceAPIChatGenerator` component instead.
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
              `TEXT_GENERATION_INFERENCE`.
            - Other parameters specific to the chosen API type, such as `timeout`, `headers`, or `provider`.
        :param token: The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param generation_kwargs:
            A dictionary with keyword arguments to customize text generation. Some examples: `max_new_tokens`,
            `temperature`, `top_k`, `top_p`.
            For details, see the [Hugging Face documentation](https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation).
        :param stop_words: An optional list of strings representing the stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        """

        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFGenerationAPIType.from_str(api_type)

        if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
            logger.warning(
                "Due to changes in the models offered in Hugging Face Inference API, using this component with the "
                "Serverless Inference API might no longer work. "
                "Please use the `HuggingFaceAPIChatGenerator` component instead."
            )
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.GENERATION, token)
            model_or_url = model
        elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Generation Inference or Inference Endpoints, you need to specify the `url` "
                    "parameter in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        # set up generation kwargs: merge stop words into the stop sequences and default max_new_tokens
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
        generation_kwargs["stop_sequences"].extend(stop_words or [])
        generation_kwargs.setdefault("max_new_tokens", 512)
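        # Worked example (assumed values): stop_words=["###"] combined with
        # generation_kwargs={"stop_sequences": ["\n\n"]} results in
        # stop_sequences == ["\n\n", "###"], and max_new_tokens falls back to 512 if unset.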

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback

        resolved_api_params: dict[str, Any] = {k: v for k, v in api_params.items() if k not in ("model", "url")}
        self._client = InferenceClient(
            model_or_url, token=token.resolve_value() if token else None, **resolved_api_params
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            A dictionary containing the serialized component.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            token=self.token,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceAPIGenerator":
        """
        Deserialize this component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        init_params = data["init_parameters"]
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            init_params["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)

    @component.output_types(replies=list[str], meta=list[dict[str, Any]])
    def run(
        self,
        prompt: str,
        streaming_callback: StreamingCallbackT | None = None,
        generation_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Invoke the text generation inference for the given prompt and generation parameters.

        :param prompt:
            A string representing the prompt.
        :param streaming_callback:
            A callback function that is called when a new token is received from the stream.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :returns:
            A dictionary with the generated replies and metadata. Both are lists of the same length.
            - replies: A list of strings representing the generated replies.
            - meta: A list of dictionaries containing the metadata for each reply.
        """
        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # choose between the init-time and runtime streaming callbacks
        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
        )
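        # For example (hypothetical inline callback):
        #   generator.run(prompt, streaming_callback=lambda chunk: print(chunk.content, end=""))
        # streams tokens even if the component was constructed without a callback, because the
        # runtime callback takes precedence over the init-time one.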
231 """ 232 # update generation kwargs by merging with the default ones 233 generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} 234 235 # check if streaming_callback is passed 236 streaming_callback = select_streaming_callback( 237 init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False 238 ) 239 240 hf_output = self._client.text_generation( 241 prompt, details=True, stream=streaming_callback is not None, **generation_kwargs 242 ) 243 244 if streaming_callback is not None: 245 # mypy doesn't know that hf_output is a Iterable[TextGenerationStreamOutput], so we cast it 246 return self._stream_and_build_response( 247 hf_output=cast(Iterable[TextGenerationStreamOutput], hf_output), streaming_callback=streaming_callback 248 ) 249 250 # mypy doesn't know that hf_output is a TextGenerationOutput, so we cast it 251 return self._build_non_streaming_response(cast(TextGenerationOutput, hf_output)) 252 253 def _stream_and_build_response( 254 self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: SyncStreamingCallbackT 255 ) -> dict[str, Any]: 256 chunks: list[StreamingChunk] = [] 257 first_chunk_time = None 258 259 component_info = ComponentInfo.from_component(self) 260 for chunk in hf_output: 261 token: TextGenerationStreamOutputToken = chunk.token 262 if token.special: 263 continue 264 265 chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})} 266 if first_chunk_time is None: 267 first_chunk_time = datetime.now().isoformat() 268 269 mapping: dict[str, FinishReason] = { 270 "length": "length", # Direct match 271 "eos_token": "stop", # EOS token means natural stop 272 "stop_sequence": "stop", # Stop sequence means natural stop 273 } 274 mapped_finish_reason = ( 275 mapping.get(chunk_metadata["finish_reason"], "stop") if chunk_metadata.get("finish_reason") else None 276 ) 277 stream_chunk = StreamingChunk( 278 content=token.text, 279 meta=chunk_metadata, 280 component_info=component_info, 281 index=0, 282 start=len(chunks) == 0, 283 finish_reason=mapped_finish_reason, 284 ) 285 chunks.append(stream_chunk) 286 streaming_callback(stream_chunk) 287 288 metadata = { 289 "finish_reason": chunks[-1].meta.get("finish_reason", None), 290 "model": self._client.model, 291 "usage": {"completion_tokens": chunks[-1].meta.get("generated_tokens", 0)}, 292 "completion_start_time": first_chunk_time, 293 } 294 return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]} 295 296 def _build_non_streaming_response(self, hf_output: "TextGenerationOutput") -> dict[str, Any]: 297 meta = [ 298 { 299 "model": self._client.model, 300 "finish_reason": hf_output.details.finish_reason if hf_output.details else None, 301 "usage": {"completion_tokens": len(hf_output.details.tokens) if hf_output.details else 0}, 302 } 303 ] 304 return {"replies": [hf_output.generated_text], "meta": meta}