test/components/retrievers/test_multi_query_text_retriever.py
test_multi_query_text_retriever.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import os
  6  from unittest.mock import ANY
  7  
  8  import pytest
  9  
 10  from haystack import Document, Pipeline
 11  from haystack.components.generators.chat import OpenAIChatGenerator
 12  from haystack.components.query import QueryExpander
 13  from haystack.components.retrievers import InMemoryBM25Retriever, MultiQueryTextRetriever
 14  from haystack.components.writers import DocumentWriter
 15  from haystack.document_stores.in_memory import InMemoryDocumentStore
 16  from haystack.document_stores.types import DuplicatePolicy
 17  
 18  
 19  class TestMultiQueryTextRetriever:
 20      @pytest.fixture
 21      def sample_documents(self):
 22          return [
 23              Document(
 24                  content="Renewable energy is energy that is collected from renewable resources.",
 25                  meta={"category": None},
 26              ),
 27              Document(
 28                  content="Solar energy is a type of green energy that is harnessed from the sun.",
 29                  meta={"category": "solar"},
 30              ),
 31              Document(
 32                  content="Wind energy is another type of green energy that is generated by wind turbines",
 33                  meta={"category": "wind"},
 34              ),
 35              Document(
 36                  content="Hydropower is a form of renewable energy using the flow of water to generate electricity.",
 37                  meta={"category": "hydro"},
 38              ),
 39              Document(
 40                  content="Geothermal energy is heat that comes from the sub-surface of the earth.",
 41                  meta={"category": "geo"},
 42              ),
 43              Document(
 44                  content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources.",
 45                  meta={"category": "fossil"},
 46              ),
 47              Document(
 48                  content="Nuclear energy is produced through nuclear reactions, typically using uranium or plutonium "
 49                  "as fuel.",
 50                  meta={"category": "nuclear"},
 51              ),
 52          ]
 53  
 54      @pytest.fixture
 55      def document_store_with_docs(self, sample_documents):
 56          document_store = InMemoryDocumentStore()
 57          doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)
 58          doc_writer.run(documents=sample_documents)
 59          return document_store
 60  
 61      def test_init_with_default_parameters(self, in_memory_doc_store):
 62          in_memory_retriever = InMemoryBM25Retriever(document_store=in_memory_doc_store)
 63          retriever = MultiQueryTextRetriever(retriever=in_memory_retriever)
 64          assert retriever.retriever == in_memory_retriever
 65          assert retriever.max_workers == 3
 66  
 67      def test_init_with_custom_parameters(self, in_memory_doc_store):
 68          in_memory_retriever = InMemoryBM25Retriever(document_store=in_memory_doc_store)
 69          retriever = MultiQueryTextRetriever(retriever=in_memory_retriever, max_workers=2)
 70          assert retriever.retriever == in_memory_retriever
 71          assert retriever.max_workers == 2
 72  
 73      def test_to_dict(self, in_memory_doc_store):
 74          in_memory_retriever = InMemoryBM25Retriever(document_store=in_memory_doc_store)
 75          multi_retriever = MultiQueryTextRetriever(retriever=in_memory_retriever, max_workers=2)
 76          result = multi_retriever.to_dict()
 77          assert result == {
 78              "type": "haystack.components.retrievers.multi_query_text_retriever.MultiQueryTextRetriever",
 79              "init_parameters": {
 80                  "retriever": {
 81                      "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
 82                      "init_parameters": {
 83                          "document_store": {
 84                              "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
 85                              "init_parameters": {
 86                                  "bm25_tokenization_regex": "(?u)\\b\\w+\\b",
 87                                  "bm25_algorithm": "BM25L",
 88                                  "bm25_parameters": {},
 89                                  "embedding_similarity_function": "dot_product",
 90                                  "index": ANY,
 91                                  "return_embedding": True,
 92                              },
 93                          },
 94                          "filters": None,
 95                          "top_k": 10,
 96                          "scale_score": False,
 97                          "filter_policy": "replace",
 98                      },
 99                  },
100                  "max_workers": 2,
101              },
102          }
103  
104      def test_from_dict(self):
105          data = {
106              "type": "haystack.components.retrievers.multi_query_text_retriever.MultiQueryTextRetriever",
107              "init_parameters": {
108                  "retriever": {
109                      "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
110                      "init_parameters": {
111                          "document_store": {
112                              "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
113                              "init_parameters": {
114                                  "bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
115                                  "bm25_algorithm": "BM25L",
116                                  "bm25_parameters": {},
117                                  "embedding_similarity_function": "dot_product",
118                                  "index": "88144fa9-6e45-4e5d-8647-4c4002d8b6db",
119                                  "return_embedding": True,
120                              },
121                          },
122                          "filters": None,
123                          "top_k": 10,
124                          "scale_score": False,
125                          "filter_policy": "replace",
126                      },
127                  },
128                  "max_workers": 3,
129              },
130          }
131          result = MultiQueryTextRetriever.from_dict(data)
132          assert isinstance(result, MultiQueryTextRetriever)
133          assert result.retriever.__class__.__name__ == "InMemoryBM25Retriever"
134          assert result.max_workers == 3
135  
136      def test_run_with_multiple_queries(self, document_store_with_docs):
137          in_memory_retriever = InMemoryBM25Retriever(document_store=document_store_with_docs)
138          multi_retriever = MultiQueryTextRetriever(retriever=in_memory_retriever)
139          queries = ["renewable energy", "solar power", "wind turbines"]
140          result = multi_retriever.run(queries=queries)
141  
142          assert "documents" in result
143          assert len(result["documents"]) > 0
144          assert all(isinstance(doc, Document) for doc in result["documents"])
145          scores = [doc.score for doc in result["documents"] if doc.score is not None]
146          assert scores == sorted(scores, reverse=True)
147  
148      @pytest.mark.integration
149      def test_run_with_filters(self, document_store_with_docs):
150          in_memory_retriever = InMemoryBM25Retriever(document_store=document_store_with_docs)
151          filters = {"field": "category", "operator": "==", "value": "solar"}
152          multi_retriever = MultiQueryTextRetriever(retriever=in_memory_retriever)
153          result = multi_retriever.run(queries=["energy"], retriever_kwargs={"filters": filters})
154          assert "documents" in result
155          assert all(doc.meta.get("category") == "solar" for doc in result["documents"])
156  
157      @pytest.mark.skipif(
158          not os.environ.get("OPENAI_API_KEY", None),
159          reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
160      )
161      @pytest.mark.integration
162      def test_pipeline_integration(self, document_store_with_docs):
163          pipeline = Pipeline()
164          expander = QueryExpander(
165              chat_generator=OpenAIChatGenerator(model="gpt-4.1-nano"), n_expansions=3, include_original_query=True
166          )
167          in_memory_retriever = InMemoryBM25Retriever(document_store=document_store_with_docs)
168          multiquery_retriever = MultiQueryTextRetriever(retriever=in_memory_retriever, max_workers=3)
169          pipeline.add_component("query_expander", expander)
170          pipeline.add_component("multiquery_retriever", multiquery_retriever)
171          pipeline.connect("query_expander.queries", "multiquery_retriever.queries")
172  
173          data = {
174              "query_expander": {"query": "green energy sources"},
175              "multiquery_retriever": {"retriever_kwargs": {"top_k": 3}},
176          }
177          results = pipeline.run(data=data, include_outputs_from={"query_expander", "multiquery_retriever"})
178  
179          assert "multiquery_retriever" in results
180          assert "documents" in results["multiquery_retriever"]
181          assert len(results["multiquery_retriever"]["documents"]) > 0
182          assert "query_expander" in results
183          assert "queries" in results["query_expander"]
184          assert len(results["query_expander"]["queries"]) == 4
185  
186          # assert that documents are sorted by score (highest first)
187          scores = [doc.score for doc in results["multiquery_retriever"]["documents"] if doc.score is not None]
188          assert scores == sorted(scores, reverse=True)
189  
190          # assert there are not duplicates
191          contents = [doc.content for doc in results["multiquery_retriever"]["documents"]]
192          assert len(contents) == len(set(contents))