# test_zero_shot_document_classifier.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import patch

import pytest

from haystack import Document, Pipeline
from haystack.components.classifiers import TransformersZeroShotDocumentClassifier
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.utils import ComponentDevice, Secret

# Every test below exercises the same small NLI cross-encoder checkpoint.
MODEL = "cross-encoder/nli-deberta-v3-xsmall"

# Fully qualified class path as it appears in (de)serialized component dicts.
CLASS_TYPE = "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier"  # noqa: E501


class TestTransformersZeroShotDocumentClassifier:
    """Tests for TransformersZeroShotDocumentClassifier: construction, (de)serialization, and run()."""

    def test_init(self):
        """The constructor records the labels and leaves lazy state unset."""
        classifier = TransformersZeroShotDocumentClassifier(model=MODEL, labels=["positive", "negative"])

        assert classifier.labels == ["positive", "negative"]
        assert classifier.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
        assert classifier.multi_label is False
        # The HF pipeline is only built in warm_up(), never in __init__.
        assert classifier.pipeline is None
        assert classifier.classification_field is None

    def test_to_dict(self):
        """to_dict() emits the component type plus all init parameters."""
        classifier = TransformersZeroShotDocumentClassifier(model=MODEL, labels=["positive", "negative"])

        assert classifier.to_dict() == {
            "type": CLASS_TYPE,
            "init_parameters": {
                "model": MODEL,
                "labels": ["positive", "negative"],
                "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": {
                    "model": MODEL,
                    "device": ComponentDevice.resolve_device(None).to_hf(),
                    "task": "zero-shot-classification",
                },
            },
        }

    def test_from_dict(self, del_hf_env_vars):
        """from_dict() restores labels, token, and pipeline kwargs from a full dict."""
        serialized = {
            "type": CLASS_TYPE,
            "init_parameters": {
                "model": MODEL,
                "labels": ["positive", "negative"],
                "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
                "huggingface_pipeline_kwargs": {
                    "model": MODEL,
                    "device": ComponentDevice.resolve_device(None).to_hf(),
                    "task": "zero-shot-classification",
                },
            },
        }

        classifier = TransformersZeroShotDocumentClassifier.from_dict(serialized)

        assert classifier.labels == ["positive", "negative"]
        assert classifier.pipeline is None
        assert classifier.token == Secret.from_dict(
            {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
        )
        # The secret is resolved lazily, hence "token": None inside the kwargs.
        assert classifier.huggingface_pipeline_kwargs == {
            "model": MODEL,
            "device": ComponentDevice.resolve_device(None).to_hf(),
            "task": "zero-shot-classification",
            "token": None,
        }

    def test_from_dict_no_default_parameters(self, del_hf_env_vars):
        """from_dict() fills in defaults when only model and labels are given."""
        serialized = {
            "type": CLASS_TYPE,
            "init_parameters": {"model": MODEL, "labels": ["positive", "negative"]},
        }

        classifier = TransformersZeroShotDocumentClassifier.from_dict(serialized)

        assert classifier.labels == ["positive", "negative"]
        assert classifier.pipeline is None
        assert classifier.token == Secret.from_dict(
            {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}
        )
        assert classifier.huggingface_pipeline_kwargs == {
            "model": MODEL,
            "device": ComponentDevice.resolve_device(None).to_hf(),
            "task": "zero-shot-classification",
            "token": None,
        }

    @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
    def test_warm_up(self, hf_pipeline_mock):
        """warm_up() builds the (mocked) HF pipeline."""
        classifier = TransformersZeroShotDocumentClassifier(model=MODEL, labels=["positive", "negative"])
        classifier.warm_up()
        assert classifier.pipeline is not None

    @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
    @patch.object(TransformersZeroShotDocumentClassifier, "warm_up")
    def test_run_calls_warm_up(self, warm_up_mock, hf_pipeline_mock):
        """run() triggers warm_up() when the pipeline has not been built yet."""
        hf_pipeline_mock.return_value = [
            {"sequence": "That's good. I like it.", "labels": ["positive", "negative"], "scores": [0.99, 0.01]}
        ]
        classifier = TransformersZeroShotDocumentClassifier(model=MODEL, labels=["positive", "negative"])
        # Emulate a real warm_up: it must leave a callable pipeline behind,
        # otherwise run() would fail right after calling the mock.
        warm_up_mock.side_effect = lambda: setattr(classifier, "pipeline", hf_pipeline_mock)

        classifier.run(documents=[Document(content="That's good. I like it.")])

        warm_up_mock.assert_called_once()

    @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
    def test_run_fails_with_non_document_input(self, hf_pipeline_mock):
        """run() rejects plain strings: only Document instances are accepted."""
        hf_pipeline_mock.return_value = " "
        classifier = TransformersZeroShotDocumentClassifier(model=MODEL, labels=["positive", "negative"])

        raw_strings = ["That's good. I like it.", "That's bad. I don't like it."]
        with pytest.raises(TypeError):
            classifier.run(documents=raw_strings)

    @patch("haystack.components.classifiers.zero_shot_document_classifier.pipeline")
    def test_run_unit(self, hf_pipeline_mock):
        """run() attaches the winning label to copies, not the input documents."""
        hf_pipeline_mock.return_value = [
            {"sequence": "That's good. I like it.", "labels": ["positive", "negative"], "scores": [0.99, 0.01]},
            {"sequence": "That's bad. I don't like it.", "labels": ["negative", "positive"], "scores": [0.99, 0.01]},
        ]
        classifier = TransformersZeroShotDocumentClassifier(model=MODEL, labels=["positive", "negative"])
        classifier.pipeline = hf_pipeline_mock

        positive_document = Document(content="That's good. I like it.")
        negative_document = Document(content="That's bad. I don't like it.")
        result = classifier.run(documents=[positive_document, negative_document])

        assert classifier.pipeline is not None
        assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
        assert result["documents"][1].to_dict()["classification"]["label"] == "negative"
        # The inputs themselves must stay untouched.
        assert "classification" not in positive_document.to_dict()
        assert "classification" not in negative_document.to_dict()

    @pytest.mark.integration
    @pytest.mark.slow
    def test_run(self, del_hf_env_vars):
        """End-to-end run with the real model, including an over-length document."""
        classifier = TransformersZeroShotDocumentClassifier(model=MODEL, labels=["positive", "negative"])

        positive_document = Document(content="That's good. I like it. " * 1000)
        negative_document = Document(content="That's bad. I don't like it.")
        result = classifier.run(documents=[positive_document, negative_document])

        assert classifier.pipeline is not None
        assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
        assert result["documents"][1].to_dict()["classification"]["label"] == "negative"
        assert "classification" not in positive_document.to_dict()
        assert "classification" not in negative_document.to_dict()

    def test_serialization_and_deserialization_pipeline(self, in_memory_doc_store):
        """A pipeline containing the classifier survives a dumps()/loads() round trip."""
        pipeline = Pipeline()
        pipeline.add_component(
            instance=InMemoryBM25Retriever(document_store=in_memory_doc_store), name="retriever"
        )
        pipeline.add_component(
            instance=TransformersZeroShotDocumentClassifier(model=MODEL, labels=["positive", "negative"]),
            name="document_classifier",
        )
        pipeline.connect("retriever", "document_classifier")

        restored = Pipeline.loads(pipeline.dumps())

        assert restored == pipeline