# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Iterable
from dataclasses import asdict
from datetime import datetime
from typing import Any, cast

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import (
    ComponentInfo,
    FinishReason,
    StreamingCallbackT,
    StreamingChunk,
    SyncStreamingCallbackT,
    select_streaming_callback,
)
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import (
        InferenceClient,
        TextGenerationOutput,
        TextGenerationStreamOutput,
        TextGenerationStreamOutputToken,
    )


logger = logging.getLogger(__name__)


@component
class HuggingFaceAPIGenerator:
    """
    Generates text using Hugging Face APIs.

    Use it with the following Hugging Face APIs:
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)

    **Note:** As of July 2025, the Hugging Face Inference API no longer offers generative models through the
    `text_generation` endpoint. Generative models are now only available through providers supporting the
    `chat_completion` endpoint. As a result, this component might no longer work with the Hugging Face Inference API.
    Use the `HuggingFaceAPIChatGenerator` component, which supports the `chat_completion` endpoint.

    ### Usage examples

    #### With Hugging Face Inference Endpoints
    <!-- test-ignore -->
    ```python
    from haystack.components.generators import HuggingFaceAPIGenerator
    from haystack.utils import Secret

    generator = HuggingFaceAPIGenerator(api_type="inference_endpoints",
                                        api_params={"url": "<your-inference-endpoint-url>"},
                                        token=Secret.from_token("<your-api-key>"))

    result = generator.run(prompt="What's Natural Language Processing?")
    print(result)
    ```

    #### With self-hosted text generation inference
    <!-- test-ignore -->
    ```python
    from haystack.components.generators import HuggingFaceAPIGenerator

    generator = HuggingFaceAPIGenerator(api_type="text_generation_inference",
                                        api_params={"url": "http://localhost:8080"})

    result = generator.run(prompt="What's Natural Language Processing?")
    print(result)
    ```

    #### With the free serverless inference API

    Be aware that this example might not work, as the Hugging Face Inference API no longer offers models that
    support the `text_generation` endpoint. Use the `HuggingFaceAPIChatGenerator` for generative models through
    the `chat_completion` endpoint.
 83  
 84      <!-- test-ignore -->
 85      ```python
 86      from haystack.components.generators import HuggingFaceAPIGenerator
 87      from haystack.utils import Secret
 88  
 89      generator = HuggingFaceAPIGenerator(api_type="serverless_inference_api",
 90                                          api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
 91                                          token=Secret.from_token("<your-api-key>"))
 92  
 93      result = generator.run(prompt="What's Natural Language Processing?")
 94      print(result)
 95      ```
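
    #### With streaming

    A minimal sketch of token streaming with a self-hosted TGI server; the URL is a placeholder and
    `print_chunk` is an illustrative callback, not part of the Haystack API.

    <!-- test-ignore -->
    ```python
    from haystack.components.generators import HuggingFaceAPIGenerator
    from haystack.dataclasses import StreamingChunk

    def print_chunk(chunk: StreamingChunk) -> None:
        # every chunk carries the newly generated text in `content`
        print(chunk.content, end="", flush=True)

    generator = HuggingFaceAPIGenerator(api_type="text_generation_inference",
                                        api_params={"url": "http://localhost:8080"},
                                        streaming_callback=print_chunk)

    result = generator.run(prompt="What's Natural Language Processing?")
    ```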
    """

    def __init__(
        self,
        api_type: HFGenerationAPIType | str,
        api_params: dict[str, str],
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        generation_kwargs: dict[str, Any] | None = None,
        stop_words: list[str] | None = None,
        streaming_callback: StreamingCallbackT | None = None,
    ) -> None:
        """
        Initialize the HuggingFaceAPIGenerator instance.

        :param api_type:
            The type of Hugging Face API to use. Available types:
            - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
            - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
            - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
              This might no longer work due to changes in the models offered in the Hugging Face Inference API.
              Please use the `HuggingFaceAPIChatGenerator` component instead.
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_GENERATION_INFERENCE`.
            - Other parameters specific to the chosen API type, such as `timeout`, `headers`, or `provider`.
        :param token: The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param generation_kwargs:
            A dictionary with keyword arguments to customize text generation. Some examples: `max_new_tokens`,
            `temperature`, `top_k`, `top_p`.
            See the [Hugging Face documentation](https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation)
            for more information.
        :param stop_words: An optional list of strings representing the stop words.
        :param streaming_callback: An optional callable for handling streaming responses.
        """

        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFGenerationAPIType.from_str(api_type)

        if api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API:
            logger.warning(
                "Due to changes in the models offered in Hugging Face Inference API, using this component with the "
                "Serverless Inference API might no longer work. "
                "Please use the `HuggingFaceAPIChatGenerator` component instead."
            )
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.GENERATION, token)
            model_or_url = model
        elif api_type in [HFGenerationAPIType.INFERENCE_ENDPOINTS, HFGenerationAPIType.TEXT_GENERATION_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Generation Inference or Inference Endpoints, you need to specify the `url` "
                    "parameter in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)
        # handle generation kwargs setup
        generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
        generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
        generation_kwargs["stop_sequences"].extend(stop_words or [])
        generation_kwargs.setdefault("max_new_tokens", 512)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.generation_kwargs = generation_kwargs
        self.streaming_callback = streaming_callback

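        # Pass any remaining api_params (for example `timeout` or `headers`) straight through
        # to the InferenceClient; `model` and `url` were already consumed above.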
        resolved_api_params: dict[str, Any] = {k: v for k, v in api_params.items() if k not in ("model", "url")}
        self._client = InferenceClient(
            model_or_url, token=token.resolve_value() if token else None, **resolved_api_params
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            A dictionary containing the serialized component.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            token=self.token.to_dict() if self.token else None,
            generation_kwargs=self.generation_kwargs,
            streaming_callback=callback_name,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceAPIGenerator":
        """
        Deserialize this component from a dictionary.

        :param data:
            The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        init_params = data["init_parameters"]
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            init_params["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)

    @component.output_types(replies=list[str], meta=list[dict[str, Any]])
    def run(
        self,
        prompt: str,
        streaming_callback: StreamingCallbackT | None = None,
        generation_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Invoke the text generation inference for the given prompt and generation parameters.

        :param prompt:
            A string representing the prompt.
        :param streaming_callback:
            A callback function that is called when a new token is received from the stream.
        :param generation_kwargs:
            Additional keyword arguments for text generation.
        :returns:
            A dictionary with the generated replies and metadata. Both are lists of the same length.
            - replies: A list of strings representing the generated replies.
            - meta: A list of dictionaries with the metadata collected during generation.
        """
        # update generation kwargs by merging with the default ones
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # check if streaming_callback is passed
        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
        )
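
        # `details=True` requests token-level metadata from the server. With `stream=True`,
        # the client returns an iterator of TextGenerationStreamOutput chunks rather than a
        # single TextGenerationOutput.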
        hf_output = self._client.text_generation(
            prompt, details=True, stream=streaming_callback is not None, **generation_kwargs
        )

        if streaming_callback is not None:
            # mypy doesn't know that hf_output is an Iterable[TextGenerationStreamOutput], so we cast it
            return self._stream_and_build_response(
                hf_output=cast(Iterable[TextGenerationStreamOutput], hf_output), streaming_callback=streaming_callback
            )

        # mypy doesn't know that hf_output is a TextGenerationOutput, so we cast it
        return self._build_non_streaming_response(cast(TextGenerationOutput, hf_output))

    def _stream_and_build_response(
        self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: SyncStreamingCallbackT
    ) -> dict[str, Any]:
        chunks: list[StreamingChunk] = []
        first_chunk_time = None

        # Map TGI finish reasons onto Haystack's FinishReason literals.
        finish_reason_mapping: dict[str, FinishReason] = {
            "length": "length",  # Direct match
            "eos_token": "stop",  # EOS token means natural stop
            "stop_sequence": "stop",  # Stop sequence means natural stop
        }

        component_info = ComponentInfo.from_component(self)
        for chunk in hf_output:
            token: TextGenerationStreamOutputToken = chunk.token
            # Skip special tokens (such as EOS) so they don't end up in the generated text.
            if token.special:
                continue

            chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
            if first_chunk_time is None:
                first_chunk_time = datetime.now().isoformat()

            mapped_finish_reason = (
                finish_reason_mapping.get(chunk_metadata["finish_reason"], "stop")
                if chunk_metadata.get("finish_reason")
                else None
            )
            stream_chunk = StreamingChunk(
                content=token.text,
                meta=chunk_metadata,
                component_info=component_info,
                index=0,
                start=len(chunks) == 0,
                finish_reason=mapped_finish_reason,
            )
            chunks.append(stream_chunk)
            streaming_callback(stream_chunk)

        # Guard against an empty stream (for example, when every token was special).
        last_chunk_meta = chunks[-1].meta if chunks else {}
        metadata = {
            "finish_reason": last_chunk_meta.get("finish_reason", None),
            "model": self._client.model,
            "usage": {"completion_tokens": last_chunk_meta.get("generated_tokens", 0)},
            "completion_start_time": first_chunk_time,
        }
        return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]}

    def _build_non_streaming_response(self, hf_output: "TextGenerationOutput") -> dict[str, Any]:
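        # `details` is only populated because `run` calls text_generation with `details=True`;
        # it may still be None if the server does not return generation details.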
        meta = [
            {
                "model": self._client.model,
                "finish_reason": hf_output.details.finish_reason if hf_output.details else None,
                "usage": {"completion_tokens": len(hf_output.details.tokens) if hf_output.details else 0},
            }
        ]
        return {"replies": [hf_output.generated_text], "meta": meta}