zero_shot_text_router.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

from haystack import component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret

with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
    from transformers import Pipeline as HfPipeline
    from transformers import pipeline

    from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs


@component
class TransformersZeroShotTextRouter:
    """
    Routes the text strings to different connections based on a category label.

    Specify the set of labels for categorization when initializing the component.

    ### Usage example

    ```python
    from haystack import Document
    from haystack.document_stores.in_memory import InMemoryDocumentStore
    from haystack.core.pipeline import Pipeline
    from haystack.components.routers import TransformersZeroShotTextRouter
    from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
    from haystack.components.retrievers import InMemoryEmbeddingRetriever

    document_store = InMemoryDocumentStore()
    doc_embedder = SentenceTransformersDocumentEmbedder(model="intfloat/e5-base-v2")
    docs = [
        Document(
            content="Germany, officially the Federal Republic of Germany, is a country in the western region of "
            "Central Europe. The nation's capital and most populous city is Berlin and its main financial centre "
            "is Frankfurt; the largest urban area is the Ruhr."
        ),
        Document(
            content="France, officially the French Republic, is a country located primarily in Western Europe. "
            "France is a unitary semi-presidential republic with its capital in Paris, the country's largest city "
            "and main cultural and commercial centre; other major urban areas include Marseille, Lyon, Toulouse, "
            "Lille, Bordeaux, Strasbourg, Nantes and Nice."
        )
    ]
    docs_with_embeddings = doc_embedder.run(docs)
    document_store.write_documents(docs_with_embeddings["documents"])

    p = Pipeline()
    p.add_component(instance=TransformersZeroShotTextRouter(labels=["passage", "query"]), name="text_router")
    p.add_component(
        instance=SentenceTransformersTextEmbedder(model="intfloat/e5-base-v2", prefix="passage: "),
        name="passage_embedder"
    )
    p.add_component(
        instance=SentenceTransformersTextEmbedder(model="intfloat/e5-base-v2", prefix="query: "),
        name="query_embedder"
    )
    p.add_component(
        instance=InMemoryEmbeddingRetriever(document_store=document_store),
        name="query_retriever"
    )
    p.add_component(
        instance=InMemoryEmbeddingRetriever(document_store=document_store),
        name="passage_retriever"
    )

    p.connect("text_router.passage", "passage_embedder.text")
    p.connect("passage_embedder.embedding", "passage_retriever.query_embedding")
    p.connect("text_router.query", "query_embedder.text")
    p.connect("query_embedder.embedding", "query_retriever.query_embedding")

    # Query Example
    p.run({"text_router": {"text": "What is the capital of Germany?"}})

    # Passage Example
    p.run({
        "text_router": {
            "text": "The United Kingdom of Great Britain and Northern Ireland, commonly known as the "
            "United Kingdom (UK) or Britain, is a country in Northwestern Europe, off the north-western coast of "
            "the continental mainland."
        }
    })
    ```
    """

    def __init__(
        self,
        labels: list[str],
        multi_label: bool = False,
        model: str = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33",
        device: ComponentDevice | None = None,
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        huggingface_pipeline_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """
        Initializes the TransformersZeroShotTextRouter component.

        :param labels: The list of candidate labels to classify each input text against. One output
            connection is created per label.
        :param multi_label:
            Indicates if multiple labels can be true.
            If `False`, label scores are normalized so their sum equals 1 for each sequence.
            If `True`, the labels are considered independent and probabilities are normalized for each candidate by
            doing a softmax of the entailment score vs. the contradiction score.
        :param model: The name or path of a Hugging Face model for zero-shot text classification.
        :param device: The device for loading the model. If `None`, automatically selects the default device.
            If a device or device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token: The API token used to download private models from Hugging Face.
            If `True`, uses either `HF_API_TOKEN` or `HF_TOKEN` environment variables.
            To generate these tokens, run `transformers-cli login`.
        :param huggingface_pipeline_kwargs: A dictionary of keyword arguments for initializing the Hugging Face
            zero shot text classification.
        """
        torch_and_transformers_import.check()

        self.token = token
        self.labels = labels
        self.multi_label = multi_label
        # Declare one output socket of type str per label, so the pipeline can route on them.
        component.set_output_types(self, **dict.fromkeys(labels, str))

        huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
            model=model,
            task="zero-shot-classification",
            supported_tasks=["zero-shot-classification"],
            device=device,
            token=token,
        )
        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        # The HF pipeline is created lazily in warm_up() to avoid loading the model at init time.
        self.pipeline: HfPipeline | None = None

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self) -> None:
        """
        Initializes the component by loading the Hugging Face pipeline (idempotent).
        """
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self, labels=self.labels, huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs, token=self.token
        )

        # Do not leak the resolved token value into the serialized pipeline kwargs;
        # the Secret is serialized separately via the `token` init parameter.
        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TransformersZeroShotTextRouter":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        if data["init_parameters"].get("huggingface_pipeline_kwargs") is not None:
            deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
        return default_from_dict(cls, data)

    def run(self, text: str) -> dict[str, str]:
        """
        Routes the text strings to different connections based on a category label.

        :param text: A string of text to route.
        :returns:
            A dictionary with the label as key and the text as value.

        :raises TypeError:
            If the input is not a str.
        """
        # Fail fast on bad input BEFORE lazily loading the (potentially huge) model.
        if not isinstance(text, str):
            raise TypeError("TransformersZeroShotTextRouter expects a str as input.")

        if self.pipeline is None:
            self.warm_up()

        # mypy doesn't know this is set in warm_up
        prediction = self.pipeline(  # type: ignore[misc]
            [text], candidate_labels=self.labels, multi_label=self.multi_label
        )
        # The pipeline returns scores/labels sorted per input; pick the top-scoring label.
        predicted_scores = prediction[0]["scores"]
        max_score_index = max(range(len(predicted_scores)), key=predicted_scores.__getitem__)
        label = prediction[0]["labels"][max_score_index]
        return {label: text}