# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

"""Tests for ``MultiQueryTextRetriever`` wrapping an ``InMemoryBM25Retriever``."""

import os
from unittest.mock import ANY

import pytest

from haystack import Document, Pipeline
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.query import QueryExpander
from haystack.components.retrievers import InMemoryBM25Retriever, MultiQueryTextRetriever
from haystack.components.writers import DocumentWriter
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy


class TestMultiQueryTextRetriever:
    """Unit and integration tests for ``MultiQueryTextRetriever``."""

    @pytest.fixture
    def sample_documents(self):
        """Return a small energy-themed corpus; each document carries a ``category`` meta field."""
        return [
            Document(
                content="Renewable energy is energy that is collected from renewable resources.",
                meta={"category": None},
            ),
            Document(
                content="Solar energy is a type of green energy that is harnessed from the sun.",
                meta={"category": "solar"},
            ),
            Document(
                content="Wind energy is another type of green energy that is generated by wind turbines",
                meta={"category": "wind"},
            ),
            Document(
                content="Hydropower is a form of renewable energy using the flow of water to generate electricity.",
                meta={"category": "hydro"},
            ),
            Document(
                content="Geothermal energy is heat that comes from the sub-surface of the earth.",
                meta={"category": "geo"},
            ),
            Document(
                content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources.",
                meta={"category": "fossil"},
            ),
            Document(
                content="Nuclear energy is produced through nuclear reactions, typically using uranium or plutonium "
                "as fuel.",
                meta={"category": "nuclear"},
            ),
        ]

    @pytest.fixture
    def document_store_with_docs(self, sample_documents):
        """Return an ``InMemoryDocumentStore`` populated with ``sample_documents`` via a ``DocumentWriter``."""
        document_store = InMemoryDocumentStore()
        doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)
        doc_writer.run(documents=sample_documents)
        return document_store

    def test_init_with_default_parameters(self, in_memory_doc_store):
        """Without explicit arguments the wrapped retriever is stored and ``max_workers`` is 3."""
        # `in_memory_doc_store` is provided by a fixture outside this module (presumably conftest.py).
        in_memory_retriever = InMemoryBM25Retriever(document_store=in_memory_doc_store)
        retriever = MultiQueryTextRetriever(retriever=in_memory_retriever)
        assert retriever.retriever == in_memory_retriever
        assert retriever.max_workers == 3

    def test_init_with_custom_parameters(self, in_memory_doc_store):
        """A custom ``max_workers`` value is kept as given."""
        in_memory_retriever = InMemoryBM25Retriever(document_store=in_memory_doc_store)
        retriever = MultiQueryTextRetriever(retriever=in_memory_retriever, max_workers=2)
        assert retriever.retriever == in_memory_retriever
        assert retriever.max_workers == 2

    def test_to_dict(self, in_memory_doc_store):
        """Serialization nests the wrapped retriever (and its document store) under ``init_parameters``."""
        in_memory_retriever = InMemoryBM25Retriever(document_store=in_memory_doc_store)
        multi_retriever = MultiQueryTextRetriever(retriever=in_memory_retriever, max_workers=2)
        result = multi_retriever.to_dict()
        assert result == {
            "type": "haystack.components.retrievers.multi_query_text_retriever.MultiQueryTextRetriever",
            "init_parameters": {
                "retriever": {
                    "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
                    "init_parameters": {
                        "document_store": {
                            "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                            "init_parameters": {
                                # NOTE(review): assumes the fixture store uses the InMemoryDocumentStore default
                                # tokenization regex; this matches the value used in test_from_dict below.
                                "bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
                                "bm25_algorithm": "BM25L",
                                "bm25_parameters": {},
                                "embedding_similarity_function": "dot_product",
                                # the index name is generated per store instance, so match anything
                                "index": ANY,
                                "return_embedding": True,
                            },
                        },
                        "filters": None,
                        "top_k": 10,
                        "scale_score": False,
                        "filter_policy": "replace",
                    },
                },
                "max_workers": 2,
            },
        }

    def test_from_dict(self):
        """Deserialization rebuilds the wrapped ``InMemoryBM25Retriever`` and restores ``max_workers``."""
        data = {
            "type": "haystack.components.retrievers.multi_query_text_retriever.MultiQueryTextRetriever",
            "init_parameters": {
                "retriever": {
                    "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
                    "init_parameters": {
                        "document_store": {
                            "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                            "init_parameters": {
                                "bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
                                "bm25_algorithm": "BM25L",
                                "bm25_parameters": {},
                                "embedding_similarity_function": "dot_product",
                                "index": "88144fa9-6e45-4e5d-8647-4c4002d8b6db",
                                "return_embedding": True,
                            },
                        },
                        "filters": None,
                        "top_k": 10,
                        "scale_score": False,
                        "filter_policy": "replace",
                    },
                },
                "max_workers": 3,
            },
        }
        result = MultiQueryTextRetriever.from_dict(data)
        assert isinstance(result, MultiQueryTextRetriever)
        assert result.retriever.__class__.__name__ == "InMemoryBM25Retriever"
        assert result.max_workers == 3

    def test_run_with_multiple_queries(self, document_store_with_docs):
        """Running several queries returns ``Document`` objects ordered by descending score."""
        in_memory_retriever = InMemoryBM25Retriever(document_store=document_store_with_docs)
        multi_retriever = MultiQueryTextRetriever(retriever=in_memory_retriever)
        queries = ["renewable energy", "solar power", "wind turbines"]
        result = multi_retriever.run(queries=queries)

        assert "documents" in result
        assert len(result["documents"]) > 0
        assert all(isinstance(doc, Document) for doc in result["documents"])
        # the merged result list must be sorted by score, highest first
        scores = [doc.score for doc in result["documents"] if doc.score is not None]
        assert scores == sorted(scores, reverse=True)

    @pytest.mark.integration
    def test_run_with_filters(self, document_store_with_docs):
        """Filters passed via ``retriever_kwargs`` are forwarded to the wrapped retriever."""
        in_memory_retriever = InMemoryBM25Retriever(document_store=document_store_with_docs)
        filters = {"field": "category", "operator": "==", "value": "solar"}
        multi_retriever = MultiQueryTextRetriever(retriever=in_memory_retriever)
        result = multi_retriever.run(queries=["energy"], retriever_kwargs={"filters": filters})
        assert "documents" in result
        # Guard against a vacuous pass: `all()` over an empty list is True, so an empty
        # result would otherwise satisfy the filter check below.
        assert len(result["documents"]) > 0
        assert all(doc.meta.get("category") == "solar" for doc in result["documents"])

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_pipeline_integration(self, document_store_with_docs):
        """End to end: ``QueryExpander`` output feeds ``MultiQueryTextRetriever`` inside a ``Pipeline``."""
        pipeline = Pipeline()
        expander = QueryExpander(
            chat_generator=OpenAIChatGenerator(model="gpt-4.1-nano"), n_expansions=3, include_original_query=True
        )
        in_memory_retriever = InMemoryBM25Retriever(document_store=document_store_with_docs)
        multiquery_retriever = MultiQueryTextRetriever(retriever=in_memory_retriever, max_workers=3)
        pipeline.add_component("query_expander", expander)
        pipeline.add_component("multiquery_retriever", multiquery_retriever)
        pipeline.connect("query_expander.queries", "multiquery_retriever.queries")

        data = {
            "query_expander": {"query": "green energy sources"},
            "multiquery_retriever": {"retriever_kwargs": {"top_k": 3}},
        }
        results = pipeline.run(data=data, include_outputs_from={"query_expander", "multiquery_retriever"})

        assert "multiquery_retriever" in results
        assert "documents" in results["multiquery_retriever"]
        assert len(results["multiquery_retriever"]["documents"]) > 0
        assert "query_expander" in results
        assert "queries" in results["query_expander"]
        # n_expansions=3 plus the original query (include_original_query=True)
        assert len(results["query_expander"]["queries"]) == 4

        # assert that documents are sorted by score (highest first)
        scores = [doc.score for doc in results["multiquery_retriever"]["documents"] if doc.score is not None]
        assert scores == sorted(scores, reverse=True)

        # assert there are no duplicates
        contents = [doc.content for doc in results["multiquery_retriever"]["documents"]]
        assert len(contents) == len(set(contents))