zero_shot_document_classifier.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import replace
from typing import Any

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs

with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
    from transformers import Pipeline as HfPipeline
    from transformers import pipeline


@component
class TransformersZeroShotDocumentClassifier:
    """
    Performs zero-shot classification of documents based on given labels and adds the predicted label to their metadata.

    The component uses a Hugging Face pipeline for zero-shot classification.
    Provide the model and the set of labels to be used for categorization during initialization.
    Additionally, you can configure the component to allow multiple labels to be true.

    Classification is run on the document's content field by default. If you want it to run on another field, set the
    `classification_field` to one of the document's metadata fields.

    Available models for the task of zero-shot-classification include:
    - `valhalla/distilbart-mnli-12-3`
    - `cross-encoder/nli-distilroberta-base`
    - `cross-encoder/nli-deberta-v3-xsmall`

    ### Usage example

    The following is a pipeline that classifies documents based on predefined classification labels
    retrieved from a search pipeline:

    ```python
    from haystack import Document
    from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
    from haystack.document_stores.in_memory import InMemoryDocumentStore
    from haystack.core.pipeline import Pipeline
    from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

    documents = [Document(id="0", content="Today was a nice day!"),
                 Document(id="1", content="Yesterday was a bad day!")]

    document_store = InMemoryDocumentStore()
    retriever = InMemoryBM25Retriever(document_store=document_store)
    document_classifier = TransformersZeroShotDocumentClassifier(
        model="cross-encoder/nli-deberta-v3-xsmall",
        labels=["positive", "negative"],
    )

    document_store.write_documents(documents)

    pipeline = Pipeline()
    pipeline.add_component(instance=retriever, name="retriever")
    pipeline.add_component(instance=document_classifier, name="document_classifier")
    pipeline.connect("retriever", "document_classifier")

    queries = ["How was your day today?", "How was your day yesterday?"]
    expected_predictions = ["positive", "negative"]

    for idx, query in enumerate(queries):
        result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
        assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
        assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
                == expected_predictions[idx])
    ```
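
    The component can also be used standalone, outside of a pipeline. The following is a minimal sketch
    (illustrative only; when the component is used on its own, `warm_up()` must be called to load the model
    before `run()`):

    ```python
    from haystack import Document
    from haystack.components.classifiers import TransformersZeroShotDocumentClassifier

    classifier = TransformersZeroShotDocumentClassifier(
        model="cross-encoder/nli-deberta-v3-xsmall",
        labels=["positive", "negative"],
        multi_label=True,
    )
    classifier.warm_up()

    result = classifier.run(documents=[Document(content="Today was a nice day!")])
    classification = result["documents"][0].meta["classification"]
    print(classification["label"])    # top label, for example "positive"
    print(classification["details"])  # scores for every candidate label
    ```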
    """

    def __init__(
        self,
        model: str,
        labels: list[str],
        multi_label: bool = False,
        classification_field: str | None = None,
        device: ComponentDevice | None = None,
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        huggingface_pipeline_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """
        Initializes the TransformersZeroShotDocumentClassifier.

        See the Hugging Face [website](https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli)
        for the full list of zero-shot classification (NLI) models.

        :param model:
            The name or path of a Hugging Face model for zero-shot document classification.
        :param labels:
            The set of possible class labels to classify each document into, for example,
            ["positive", "negative"]. The labels depend on the selected model.
        :param multi_label:
            Whether or not multiple candidate labels can be true.
            If `False`, the scores are normalized such that
            the sum of the label likelihoods for each sequence is 1. If `True`, the labels are considered
            independent and probabilities are normalized for each candidate by doing a softmax of the entailment
            score vs. the contradiction score.
        :param classification_field:
            Name of the document's meta field to be used for classification.
            If not set, `Document.content` is used by default.
        :param device:
            The device on which the model is loaded. If `None`, the default device is automatically
            selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this
            parameter.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param huggingface_pipeline_kwargs:
            Dictionary containing keyword arguments used to initialize the
            Hugging Face pipeline for zero-shot classification.
        """

        torch_and_transformers_import.check()

        self.classification_field = classification_field

        self.token = token
        self.labels = labels
        self.multi_label = multi_label

        huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
            model=model,
            task="zero-shot-classification",
            supported_tasks=["zero-shot-classification"],
            device=device,
            token=token,
        )

        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        self.pipeline: HfPipeline | None = None

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self) -> None:
        """
        Initializes the component by loading the Hugging Face pipeline.
        """
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            labels=self.labels,
            model=self.huggingface_pipeline_kwargs["model"],
            huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
            token=self.token.to_dict() if self.token else None,
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        # The resolved token may hold a valid secret value, so it must not be serialized.
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TransformersZeroShotDocumentClassifier":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        if data["init_parameters"].get("huggingface_pipeline_kwargs") is not None:
            deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
        return default_from_dict(cls, data)

    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document], batch_size: int = 1) -> dict[str, Any]:
        """
        Classifies the documents based on the provided labels and adds the predictions to their metadata.

        The classification results are stored in the `classification` dict within each document's metadata. The
        `label` and `score` keys hold the top prediction, and the scores for all candidate labels are available
        under the `details` key.

        :param documents:
            Documents to process.
        :param batch_size:
            Batch size used for processing the content in each document.
        :returns:
            A dictionary with the following key:
            - `documents`: A list of documents with an added metadata field called `classification`.
        """

        if self.pipeline is None:
            self.warm_up()

        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                "TransformersZeroShotDocumentClassifier expects a list of documents as input. "
                "In case you want to classify and route a text, please use the TransformersZeroShotTextRouter."
            )

        invalid_doc_ids = []

        for doc in documents:
            if self.classification_field is not None and self.classification_field not in doc.meta:
                invalid_doc_ids.append(doc.id)

        if invalid_doc_ids:
            raise ValueError(
                f"The following documents do not have the classification field '{self.classification_field}': "
                f"{', '.join(invalid_doc_ids)}"
            )

        texts = [
            (doc.content if self.classification_field is None else doc.meta[self.classification_field])
            for doc in documents
        ]

        # mypy doesn't know this is set in warm_up
        predictions = self.pipeline(  # type: ignore[misc]
            texts, self.labels, multi_label=self.multi_label, batch_size=batch_size
        )

        new_documents = []
        for prediction, document in zip(predictions, documents, strict=True):
            formatted_prediction = {
                "label": prediction["labels"][0],
                "score": prediction["scores"][0],
                "details": dict(zip(prediction["labels"], prediction["scores"], strict=True)),
            }
            # Build a copy of the document with the classification added to its metadata.
            new_meta = {**document.meta, "classification": formatted_prediction}
            new_documents.append(replace(document, meta=new_meta))

        return {"documents": new_documents}
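

# Illustrative sketch, not part of the component's API: the to_dict()/from_dict() round
# trip below is what pipeline serialization (for example, to YAML) relies on. No model
# weights are downloaded or loaded until warm_up() is called.
if __name__ == "__main__":
    classifier = TransformersZeroShotDocumentClassifier(
        model="cross-encoder/nli-deberta-v3-xsmall",
        labels=["positive", "negative"],
    )
    serialized = classifier.to_dict()  # plain dict with "type" and "init_parameters" keys
    restored = TransformersZeroShotDocumentClassifier.from_dict(serialized)
    assert restored.labels == classifier.labels
    assert restored.multi_label == classifier.multi_label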