/ haystack / components / routers / zero_shot_text_router.py
zero_shot_text_router.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from typing import Any
  6  
  7  from haystack import component, default_from_dict, default_to_dict
  8  from haystack.lazy_imports import LazyImport
  9  from haystack.utils import ComponentDevice, Secret
 10  
 11  with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
 12      from transformers import Pipeline as HfPipeline
 13      from transformers import pipeline
 14  
 15      from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs
 16  
 17  
 18  @component
 19  class TransformersZeroShotTextRouter:
 20      """
 21      Routes the text strings to different connections based on a category label.
 22  
 23      Specify the set of labels for categorization when initializing the component.
 24  
 25      ### Usage example
 26  
 27      ```python
 28      from haystack import Document
 29      from haystack.document_stores.in_memory import InMemoryDocumentStore
 30      from haystack.core.pipeline import Pipeline
 31      from haystack.components.routers import TransformersZeroShotTextRouter
 32      from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
 33      from haystack.components.retrievers import InMemoryEmbeddingRetriever
 34  
 35      document_store = InMemoryDocumentStore()
 36      doc_embedder = SentenceTransformersDocumentEmbedder(model="intfloat/e5-base-v2")
 37      docs = [
 38          Document(
 39              content="Germany, officially the Federal Republic of Germany, is a country in the western region of "
 40              "Central Europe. The nation's capital and most populous city is Berlin and its main financial centre "
 41              "is Frankfurt; the largest urban area is the Ruhr."
 42          ),
 43          Document(
 44              content="France, officially the French Republic, is a country located primarily in Western Europe. "
 45              "France is a unitary semi-presidential republic with its capital in Paris, the country's largest city "
 46              "and main cultural and commercial centre; other major urban areas include Marseille, Lyon, Toulouse, "
 47              "Lille, Bordeaux, Strasbourg, Nantes and Nice."
 48          )
 49      ]
 50      docs_with_embeddings = doc_embedder.run(docs)
 51      document_store.write_documents(docs_with_embeddings["documents"])
 52  
 53      p = Pipeline()
 54      p.add_component(instance=TransformersZeroShotTextRouter(labels=["passage", "query"]), name="text_router")
 55      p.add_component(
 56          instance=SentenceTransformersTextEmbedder(model="intfloat/e5-base-v2", prefix="passage: "),
 57          name="passage_embedder"
 58      )
 59      p.add_component(
 60          instance=SentenceTransformersTextEmbedder(model="intfloat/e5-base-v2", prefix="query: "),
 61          name="query_embedder"
 62      )
 63      p.add_component(
 64          instance=InMemoryEmbeddingRetriever(document_store=document_store),
 65          name="query_retriever"
 66      )
 67      p.add_component(
 68          instance=InMemoryEmbeddingRetriever(document_store=document_store),
 69          name="passage_retriever"
 70      )
 71  
 72      p.connect("text_router.passage", "passage_embedder.text")
 73      p.connect("passage_embedder.embedding", "passage_retriever.query_embedding")
 74      p.connect("text_router.query", "query_embedder.text")
 75      p.connect("query_embedder.embedding", "query_retriever.query_embedding")
 76  
 77      # Query Example
 78      p.run({"text_router": {"text": "What is the capital of Germany?"}})
 79  
 80      # Passage Example
 81      p.run({
 82          "text_router":{
 83              "text": "The United Kingdom of Great Britain and Northern Ireland, commonly known as the "\
 84              "United Kingdom (UK) or Britain, is a country in Northwestern Europe, off the north-western coast of "\
 85              "the continental mainland."
 86          }
 87      })
 88      ```
 89      """
 90  
 91      def __init__(
 92          self,
 93          labels: list[str],
 94          multi_label: bool = False,
 95          model: str = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33",
 96          device: ComponentDevice | None = None,
 97          token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
 98          huggingface_pipeline_kwargs: dict[str, Any] | None = None,
 99      ) -> None:
100          """
101          Initializes the TransformersZeroShotTextRouter component.
102  
103          :param labels: The set of labels to use for classification. Can be a single label,
104              a string of comma-separated labels, or a list of labels.
105          :param multi_label:
106              Indicates if multiple labels can be true.
107              If `False`, label scores are normalized so their sum equals 1 for each sequence.
108              If `True`, the labels are considered independent and probabilities are normalized for each candidate by
109              doing a softmax of the entailment score vs. the contradiction score.
110          :param model: The name or path of a Hugging Face model for zero-shot text classification.
111          :param device: The device for loading the model. If `None`, automatically selects the default device.
112              If a device or device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
113          :param token: The API token used to download private models from Hugging Face.
114              If `True`, uses either `HF_API_TOKEN` or `HF_TOKEN` environment variables.
115              To generate these tokens, run `transformers-cli login`.
116          :param huggingface_pipeline_kwargs: A dictionary of keyword arguments for initializing the Hugging Face
117              zero shot text classification.
118          """
119          torch_and_transformers_import.check()
120  
121          self.token = token
122          self.labels = labels
123          self.multi_label = multi_label
124          component.set_output_types(self, **dict.fromkeys(labels, str))
125  
126          huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
127              huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
128              model=model,
129              task="zero-shot-classification",
130              supported_tasks=["zero-shot-classification"],
131              device=device,
132              token=token,
133          )
134          self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
135          self.pipeline: HfPipeline | None = None
136  
137      def _get_telemetry_data(self) -> dict[str, Any]:
138          """
139          Data that is sent to Posthog for usage analytics.
140          """
141          if isinstance(self.huggingface_pipeline_kwargs["model"], str):
142              return {"model": self.huggingface_pipeline_kwargs["model"]}
143          return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}
144  
145      def warm_up(self) -> None:
146          """
147          Initializes the component.
148          """
149          if self.pipeline is None:
150              self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)
151  
152      def to_dict(self) -> dict[str, Any]:
153          """
154          Serializes the component to a dictionary.
155  
156          :returns:
157              Dictionary with serialized data.
158          """
159          serialization_dict = default_to_dict(
160              self, labels=self.labels, huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs, token=self.token
161          )
162  
163          huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
164          huggingface_pipeline_kwargs.pop("token", None)
165  
166          serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
167          return serialization_dict
168  
169      @classmethod
170      def from_dict(cls, data: dict[str, Any]) -> "TransformersZeroShotTextRouter":
171          """
172          Deserializes the component from a dictionary.
173  
174          :param data:
175              Dictionary to deserialize from.
176          :returns:
177              Deserialized component.
178          """
179          if data["init_parameters"].get("huggingface_pipeline_kwargs") is not None:
180              deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
181          return default_from_dict(cls, data)
182  
183      def run(self, text: str) -> dict[str, str]:
184          """
185          Routes the text strings to different connections based on a category label.
186  
187          :param text: A string of text to route.
188          :returns:
189              A dictionary with the label as key and the text as value.
190  
191          :raises TypeError:
192              If the input is not a str.
193          """
194          if self.pipeline is None:
195              self.warm_up()
196  
197          if not isinstance(text, str):
198              raise TypeError("TransformersZeroShotTextRouter expects a str as input.")
199  
200          # mypy doesn't know this is set in warm_up
201          prediction = self.pipeline(  # type: ignore[misc]
202              [text], candidate_labels=self.labels, multi_label=self.multi_label
203          )
204          predicted_scores = prediction[0]["scores"]
205          max_score_index = max(range(len(predicted_scores)), key=predicted_scores.__getitem__)
206          label = prediction[0]["labels"][max_score_index]
207          return {label: text}