# haystack/components/retrievers/multi_query_embedding_retriever.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from concurrent.futures import ThreadPoolExecutor
from typing import Any

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.components.embedders.types.protocol import TextEmbedder
from haystack.components.retrievers.types import EmbeddingRetriever
from haystack.core.serialization import component_to_dict
from haystack.utils.deserialization import deserialize_component_inplace
from haystack.utils.misc import _deduplicate_documents


 15  @component
 16  class MultiQueryEmbeddingRetriever:
 17      """
 18      A component that retrieves documents using multiple queries in parallel with an embedding-based retriever.
 19  
 20      This component takes a list of text queries, converts them to embeddings using a query embedder,
 21      and then uses an embedding-based retriever to find relevant documents for each query in parallel.
 22      The results are combined and sorted by relevance score.
 23  
 24      ### Usage example
 25  
 26      ```python
 27      from haystack import Document
 28      from haystack.document_stores.in_memory import InMemoryDocumentStore
 29      from haystack.document_stores.types import DuplicatePolicy
 30      from haystack.components.embedders import SentenceTransformersTextEmbedder
 31      from haystack.components.embedders import SentenceTransformersDocumentEmbedder
 32      from haystack.components.retrievers import InMemoryEmbeddingRetriever
 33      from haystack.components.writers import DocumentWriter
 34      from haystack.components.retrievers import MultiQueryEmbeddingRetriever
 35  
 36      documents = [
 37          Document(content="Renewable energy is energy that is collected from renewable resources."),
 38          Document(content="Solar energy is a type of green energy that is harnessed from the sun."),
 39          Document(content="Wind energy is another type of green energy that is generated by wind turbines."),
 40          Document(content="Geothermal energy is heat that comes from the sub-surface of the earth."),
 41          Document(content="Biomass energy is produced from organic materials, such as plant and animal waste."),
 42          Document(content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources."),
 43      ]
 44  
 45      # Populate the document store
 46      doc_store = InMemoryDocumentStore()
 47      doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
 48      doc_writer = DocumentWriter(document_store=doc_store, policy=DuplicatePolicy.SKIP)
 49      documents = doc_embedder.run(documents)["documents"]
 50      doc_writer.run(documents=documents)
 51  
 52      # Run the multi-query retriever
 53      in_memory_retriever = InMemoryEmbeddingRetriever(document_store=doc_store, top_k=1)
 54      query_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
 55  
 56      multi_query_retriever = MultiQueryEmbeddingRetriever(
 57          retriever=in_memory_retriever,
 58          query_embedder=query_embedder,
 59          max_workers=3
 60      )
 61  
 62      queries = ["Geothermal energy", "natural gas", "turbines"]
 63      result = multi_query_retriever.run(queries=queries)
 64      for doc in result["documents"]:
 65          print(f"Content: {doc.content}, Score: {doc.score}")
 66      # >> Content: Geothermal energy is heat that comes from the sub-surface of the earth., Score: 0.8509603046266574
 67      # >> Content: Renewable energy is energy that is collected from renewable resources., Score: 0.42763211298893034
 68      # >> Content: Solar energy is a type of green energy that is harnessed from the sun., Score: 0.40077417016494354
 69      # >> Content: Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources., Score: 0.3774863680
 70      # >> Content: Wind energy is another type of green energy that is generated by wind turbines., Score: 0.30914239725622
 71      # >> Content: Biomass energy is produced from organic materials, such as plant and animal waste., Score: 0.25173074243
 72      ```
 73      """  # noqa E501
 74  
 75      def __init__(self, *, retriever: EmbeddingRetriever, query_embedder: TextEmbedder, max_workers: int = 3) -> None:
 76          """
 77          Initialize MultiQueryEmbeddingRetriever.
 78  
 79          :param retriever: The embedding-based retriever to use for document retrieval.
 80          :param query_embedder: The query embedder to convert text queries to embeddings.
 81          :param max_workers: Maximum number of worker threads for parallel processing.
 82          """
 83          self.retriever = retriever
 84          self.query_embedder = query_embedder
 85          self.max_workers = max_workers
 86          self._is_warmed_up = False
 87  
 88      def warm_up(self) -> None:
 89          """
 90          Warm up the query embedder and the retriever if any has a warm_up method.
 91          """
 92          if not self._is_warmed_up:
 93              if hasattr(self.query_embedder, "warm_up") and callable(self.query_embedder.warm_up):
 94                  self.query_embedder.warm_up()
 95              if hasattr(self.retriever, "warm_up") and callable(self.retriever.warm_up):
 96                  self.retriever.warm_up()
 97              self._is_warmed_up = True
 98  
 99      @component.output_types(documents=list[Document])
100      def run(self, queries: list[str], retriever_kwargs: dict[str, Any] | None = None) -> dict[str, list[Document]]:
101          """
102          Retrieve documents using multiple queries in parallel.
103  
104          :param queries: List of text queries to process.
105          :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method.
106          :returns:
107              A dictionary containing:
108                  - `documents`: List of retrieved documents sorted by relevance score.
109          """
110          docs: list[Document] = []
111          retriever_kwargs = retriever_kwargs or {}
112  
113          if not self._is_warmed_up:
114              self.warm_up()
115  
116          with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
117              queries_results = executor.map(lambda query: self._run_on_thread(query, retriever_kwargs), queries)
118              for result in queries_results:
119                  if not result:
120                      continue
121                  docs.extend(result)
122  
123          # de-duplicate and sort
124          docs = _deduplicate_documents(docs)
125          docs.sort(key=lambda x: x.score or 0.0, reverse=True)
126          return {"documents": docs}
127  
128      def _run_on_thread(self, query: str, retriever_kwargs: dict[str, Any] | None = None) -> list[Document] | None:
129          """
130          Process a single query on a separate thread.
131  
132          :param query: The text query to process.
133          :returns:
134              List of retrieved documents or None if no results.
135          """
136          embedding_result = self.query_embedder.run(text=query)
137          query_embedding = embedding_result["embedding"]
138          result = self.retriever.run(query_embedding=query_embedding, **(retriever_kwargs or {}))
139          if result and "documents" in result:
140              return result["documents"]
141          return None
142  
143      def to_dict(self) -> dict[str, Any]:
144          """
145          Serializes the component to a dictionary.
146  
147          :returns:
148              A dictionary representing the serialized component.
149          """
150          return default_to_dict(
151              self,
152              retriever=component_to_dict(obj=self.retriever, name="retriever"),
153              query_embedder=component_to_dict(obj=self.query_embedder, name="query_embedder"),
154              max_workers=self.max_workers,
155          )
156  
157      @classmethod
158      def from_dict(cls, data: dict[str, Any]) -> "MultiQueryEmbeddingRetriever":
159          """
160          Deserializes the component from a dictionary.
161  
162          :param data: The dictionary to deserialize from.
163          :returns:
164              The deserialized component.
165          """
166          return default_from_dict(cls, data)