/ test / components / classifiers / test_zero_shot_document_classifier.py
test_zero_shot_document_classifier.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from unittest.mock import patch
  6  
  7  import pytest
  8  
  9  from haystack import Document, Pipeline
 10  from haystack.components.classifiers import TransformersZeroShotDocumentClassifier
 11  from haystack.components.retrievers import InMemoryBM25Retriever
 12  from haystack.utils import ComponentDevice, Secret
 13  
 14  
 15  class TestTransformersZeroShotDocumentClassifier:
 16      def test_init(self):
 17          component = TransformersZeroShotDocumentClassifier(
 18              model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
 19          )
 20          assert component.labels == ["positive", "negative"]
 21          assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
 22          assert component.multi_label is False
 23          assert component.pipeline is None
 24          assert component.classification_field is None
 25  
 26      def test_to_dict(self):
 27          component = TransformersZeroShotDocumentClassifier(
 28              model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
 29          )
 30          component_dict = component.to_dict()
 31          assert component_dict == {
 32              "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",  # noqa: E501
 33              "init_parameters": {
 34                  "model": "cross-encoder/nli-deberta-v3-xsmall",
 35                  "labels": ["positive", "negative"],
 36                  "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
 37                  "huggingface_pipeline_kwargs": {
 38                      "model": "cross-encoder/nli-deberta-v3-xsmall",
 39                      "device": ComponentDevice.resolve_device(None).to_hf(),
 40                      "task": "zero-shot-classification",
 41                  },
 42              },
 43          }
 44  
 45      def test_from_dict(self, del_hf_env_vars):
 46          data = {
 47              "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",  # noqa: E501
 48              "init_parameters": {
 49                  "model": "cross-encoder/nli-deberta-v3-xsmall",
 50                  "labels": ["positive", "negative"],
 51                  "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
 52                  "huggingface_pipeline_kwargs": {
 53                      "model": "cross-encoder/nli-deberta-v3-xsmall",
 54                      "device": ComponentDevice.resolve_device(None).to_hf(),
 55                      "task": "zero-shot-classification",
 56                  },
 57              },
 58          }
 59          component = TransformersZeroShotDocumentClassifier.from_dict(data)
 60          assert component.labels == ["positive", "negative"]
 61          assert component.pipeline is None
 62          assert component.token == Secret.from_dict(
 63              {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
 64          )
 65          assert component.huggingface_pipeline_kwargs == {
 66              "model": "cross-encoder/nli-deberta-v3-xsmall",
 67              "device": ComponentDevice.resolve_device(None).to_hf(),
 68              "task": "zero-shot-classification",
 69              "token": None,
 70          }
 71  
 72      def test_from_dict_no_default_parameters(self, del_hf_env_vars):
 73          data = {
 74              "type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier",  # noqa: E501
 75              "init_parameters": {"model": "cross-encoder/nli-deberta-v3-xsmall", "labels": ["positive", "negative"]},
 76          }
 77          component = TransformersZeroShotDocumentClassifier.from_dict(data)
 78          assert component.labels == ["positive", "negative"]
 79          assert component.pipeline is None
 80          assert component.token == Secret.from_dict(
 81              {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
 82          )
 83          assert component.huggingface_pipeline_kwargs == {
 84              "model": "cross-encoder/nli-deberta-v3-xsmall",
 85              "device": ComponentDevice.resolve_device(None).to_hf(),
 86              "task": "zero-shot-classification",
 87              "token": None,
 88          }
 89  
 90      @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
 91      def test_warm_up(self, hf_pipeline_mock):
 92          component = TransformersZeroShotDocumentClassifier(
 93              model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
 94          )
 95          component.warm_up()
 96          assert component.pipeline is not None
 97  
 98      @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
 99      @patch.object(TransformersZeroShotDocumentClassifier, "warm_up")
100      def test_run_calls_warm_up(self, warm_up_mock, hf_pipeline_mock):
101          hf_pipeline_mock.return_value = [
102              {"sequence": "That's good. I like it.", "labels": ["positive", "negative"], "scores": [0.99, 0.01]}
103          ]
104          component = TransformersZeroShotDocumentClassifier(
105              model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
106          )
107          warm_up_mock.side_effect = lambda: setattr(component, "pipeline", hf_pipeline_mock)
108          positive_documents = [Document(content="That's good. I like it.")]
109          component.run(documents=positive_documents)
110          warm_up_mock.assert_called_once()
111  
112      @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
113      def test_run_fails_with_non_document_input(self, hf_pipeline_mock):
114          hf_pipeline_mock.return_value = " "
115          component = TransformersZeroShotDocumentClassifier(
116              model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
117          )
118          text_list = ["That's good. I like it.", "That's bad. I don't like it."]
119          with pytest.raises(TypeError):
120              component.run(documents=text_list)
121  
122      @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
123      def test_run_unit(self, hf_pipeline_mock):
124          hf_pipeline_mock.return_value = [
125              {"sequence": "That's good. I like it.", "labels": ["positive", "negative"], "scores": [0.99, 0.01]},
126              {"sequence": "That's bad. I don't like it.", "labels": ["negative", "positive"], "scores": [0.99, 0.01]},
127          ]
128          component = TransformersZeroShotDocumentClassifier(
129              model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
130          )
131          component.pipeline = hf_pipeline_mock
132          positive_document = Document(content="That's good. I like it.")
133          negative_document = Document(content="That's bad. I don't like it.")
134          result = component.run(documents=[positive_document, negative_document])
135          assert component.pipeline is not None
136          assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
137          assert result["documents"][1].to_dict()["classification"]["label"] == "negative"
138          assert "classification" not in positive_document.to_dict()
139          assert "classification" not in negative_document.to_dict()
140  
141      @pytest.mark.integration
142      @pytest.mark.slow
143      def test_run(self, del_hf_env_vars):
144          component = TransformersZeroShotDocumentClassifier(
145              model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
146          )
147          positive_document = Document(content="That's good. I like it. " * 1000)
148          negative_document = Document(content="That's bad. I don't like it.")
149          result = component.run(documents=[positive_document, negative_document])
150          assert component.pipeline is not None
151          assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
152          assert result["documents"][1].to_dict()["classification"]["label"] == "negative"
153          assert "classification" not in positive_document.to_dict()
154          assert "classification" not in negative_document.to_dict()
155  
156      def test_serialization_and_deserialization_pipeline(self, in_memory_doc_store):
157          pipeline = Pipeline()
158          retriever = InMemoryBM25Retriever(document_store=in_memory_doc_store)
159          document_classifier = TransformersZeroShotDocumentClassifier(
160              model="cross-encoder/nli-deberta-v3-xsmall", labels=["positive", "negative"]
161          )
162  
163          pipeline.add_component(instance=retriever, name="retriever")
164          pipeline.add_component(instance=document_classifier, name="document_classifier")
165          pipeline.connect("retriever", "document_classifier")
166          pipeline_dump = pipeline.dumps()
167  
168          new_pipeline = Pipeline.loads(pipeline_dump)
169  
170          assert new_pipeline == pipeline