# haystack/components/classifiers/zero_shot_document_classifier.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
import copy
from dataclasses import replace
from typing import Any

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs
 12  
 13  with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
 14      from transformers import Pipeline as HfPipeline
 15      from transformers import pipeline
 16  
 17  
 18  @component
 19  class TransformersZeroShotDocumentClassifier:
 20      """
 21      Performs zero-shot classification of documents based on given labels and adds the predicted label to their metadata.
 22  
 23      The component uses a Hugging Face pipeline for zero-shot classification.
 24      Provide the model and the set of labels to be used for categorization during initialization.
 25      Additionally, you can configure the component to allow multiple labels to be true.
 26  
 27      Classification is run on the document's content field by default. If you want it to run on another field, set the
 28      `classification_field` to one of the document's metadata fields.
 29  
 30      Available models for the task of zero-shot-classification include:
 31          - `valhalla/distilbart-mnli-12-3`
 32          - `cross-encoder/nli-distilroberta-base`
 33          - `cross-encoder/nli-deberta-v3-xsmall`
 34  
 35      ### Usage example
 36  
 37      The following is a pipeline that classifies documents based on predefined classification labels
 38      retrieved from a search pipeline:
 39  
 40      ```python
 41      from haystack import Document
 42      from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
 43      from haystack.document_stores.in_memory import InMemoryDocumentStore
 44      from haystack.core.pipeline import Pipeline
 45      from haystack.components.classifiers import TransformersZeroShotDocumentClassifier
 46  
 47      documents = [Document(id="0", content="Today was a nice day!"),
 48                   Document(id="1", content="Yesterday was a bad day!")]
 49  
 50      document_store = InMemoryDocumentStore()
 51      retriever = InMemoryBM25Retriever(document_store=document_store)
 52      document_classifier = TransformersZeroShotDocumentClassifier(
 53          model="cross-encoder/nli-deberta-v3-xsmall",
 54          labels=["positive", "negative"],
 55      )
 56  
 57      document_store.write_documents(documents)
 58  
 59      pipeline = Pipeline()
 60      pipeline.add_component(instance=retriever, name="retriever")
 61      pipeline.add_component(instance=document_classifier, name="document_classifier")
 62      pipeline.connect("retriever", "document_classifier")
 63  
 64      queries = ["How was your day today?", "How was your day yesterday?"]
 65      expected_predictions = ["positive", "negative"]
 66  
 67      for idx, query in enumerate(queries):
 68          result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
 69          assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
 70          assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
 71                  == expected_predictions[idx])
 72      ```
 73      """
 74  
 75      def __init__(
 76          self,
 77          model: str,
 78          labels: list[str],
 79          multi_label: bool = False,
 80          classification_field: str | None = None,
 81          device: ComponentDevice | None = None,
 82          token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
 83          huggingface_pipeline_kwargs: dict[str, Any] | None = None,
 84      ) -> None:
 85          """
 86          Initializes the TransformersZeroShotDocumentClassifier.
 87  
 88          See the Hugging Face [website](https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli)
 89          for the full list of zero-shot classification models (NLI) models.
 90  
 91          :param model:
 92              The name or path of a Hugging Face model for zero shot document classification.
 93          :param labels:
 94              The set of possible class labels to classify each document into, for example,
 95              ["positive", "negative"]. The labels depend on the selected model.
 96          :param multi_label:
 97              Whether or not multiple candidate labels can be true.
 98              If `False`, the scores are normalized such that
 99              the sum of the label likelihoods for each sequence is 1. If `True`, the labels are considered
100              independent and probabilities are normalized for each candidate by doing a softmax of the entailment
101              score vs. the contradiction score.
102          :param classification_field:
103              Name of document's meta field to be used for classification.
104              If not set, `Document.content` is used by default.
105          :param device:
106              The device on which the model is loaded. If `None`, the default device is automatically
107              selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
108          :param token:
109              The Hugging Face token to use as HTTP bearer authorization.
110              Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
111          :param huggingface_pipeline_kwargs:
112              Dictionary containing keyword arguments used to initialize the
113              Hugging Face pipeline for text classification.
114          """
115  
116          torch_and_transformers_import.check()
117  
118          self.classification_field = classification_field
119  
120          self.token = token
121          self.labels = labels
122          self.multi_label = multi_label
123  
124          huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
125              huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
126              model=model,
127              task="zero-shot-classification",
128              supported_tasks=["zero-shot-classification"],
129              device=device,
130              token=token,
131          )
132  
133          self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
134          self.pipeline: HfPipeline | None = None
135  
136      def _get_telemetry_data(self) -> dict[str, Any]:
137          """
138          Data that is sent to Posthog for usage analytics.
139          """
140          if isinstance(self.huggingface_pipeline_kwargs["model"], str):
141              return {"model": self.huggingface_pipeline_kwargs["model"]}
142          return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}
143  
144      def warm_up(self) -> None:
145          """
146          Initializes the component.
147          """
148          if self.pipeline is None:
149              self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)
150  
151      def to_dict(self) -> dict[str, Any]:
152          """
153          Serializes the component to a dictionary.
154  
155          :returns:
156              Dictionary with serialized data.
157          """
158          serialization_dict = default_to_dict(
159              self,
160              labels=self.labels,
161              model=self.huggingface_pipeline_kwargs["model"],
162              huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
163              token=self.token,
164          )
165  
166          huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
167          huggingface_pipeline_kwargs.pop("token", None)
168  
169          serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
170          return serialization_dict
171  
172      @classmethod
173      def from_dict(cls, data: dict[str, Any]) -> "TransformersZeroShotDocumentClassifier":
174          """
175          Deserializes the component from a dictionary.
176  
177          :param data:
178              Dictionary to deserialize from.
179          :returns:
180              Deserialized component.
181          """
182          if data["init_parameters"].get("huggingface_pipeline_kwargs") is not None:
183              deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
184          return default_from_dict(cls, data)
185  
186      @component.output_types(documents=list[Document])
187      def run(self, documents: list[Document], batch_size: int = 1) -> dict[str, Any]:
188          """
189          Classifies the documents based on the provided labels and adds them to their metadata.
190  
191          The classification results are stored in the `classification` dict within
192          each document's metadata. If `multi_label` is set to `True`, the scores for each label are available under
193          the `details` key within the `classification` dictionary.
194  
195          :param documents:
196              Documents to process.
197          :param batch_size:
198              Batch size used for processing the content in each document.
199          :returns:
200              A dictionary with the following key:
201              - `documents`: A list of documents with an added metadata field called `classification`.
202          """
203  
204          if self.pipeline is None:
205              self.warm_up()
206  
207          if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
208              raise TypeError(
209                  "TransformerZeroShotDocumentClassifier expects a list of documents as input. "
210                  "In case you want to classify and route a text, please use the TransformersZeroShotTextRouter."
211              )
212  
213          invalid_doc_ids = []
214  
215          for doc in documents:
216              if self.classification_field is not None and self.classification_field not in doc.meta:
217                  invalid_doc_ids.append(doc.id)
218  
219          if invalid_doc_ids:
220              raise ValueError(
221                  f"The following documents do not have the classification field '{self.classification_field}': "
222                  f"{', '.join(invalid_doc_ids)}"
223              )
224  
225          texts = [
226              (doc.content if self.classification_field is None else doc.meta[self.classification_field])
227              for doc in documents
228          ]
229  
230          # mypy doesn't know this is set in warm_up
231          predictions = self.pipeline(  # type: ignore[misc]
232              texts, self.labels, multi_label=self.multi_label, batch_size=batch_size
233          )
234  
235          new_documents = []
236          for prediction, document in zip(predictions, documents, strict=True):
237              formatted_prediction = {
238                  "label": prediction["labels"][0],
239                  "score": prediction["scores"][0],
240                  "details": dict(zip(prediction["labels"], prediction["scores"], strict=True)),
241              }
242              new_meta = {**document.meta, "classification": formatted_prediction}
243              new_documents.append(replace(document, meta=new_meta))
244  
245          return {"documents": new_documents}