multi_query_embedding_retriever.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from concurrent.futures import ThreadPoolExecutor
from typing import Any

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.components.embedders.types.protocol import TextEmbedder
from haystack.components.retrievers.types import EmbeddingRetriever
from haystack.core.serialization import component_to_dict
from haystack.utils.misc import _deduplicate_documents


@component
class MultiQueryEmbeddingRetriever:
    """
    Retrieves documents for several text queries at once using an embedding-based retriever.

    Every query is turned into an embedding by the query embedder and handed to the retriever; the
    per-query lookups run concurrently on a thread pool. The documents from all queries are then merged,
    de-duplicated and ordered from highest to lowest score.

    ### Usage example

    ```python
    from haystack import Document
    from haystack.document_stores.in_memory import InMemoryDocumentStore
    from haystack.document_stores.types import DuplicatePolicy
    from haystack.components.embedders import SentenceTransformersTextEmbedder
    from haystack.components.embedders import SentenceTransformersDocumentEmbedder
    from haystack.components.retrievers import InMemoryEmbeddingRetriever
    from haystack.components.writers import DocumentWriter
    from haystack.components.retrievers import MultiQueryEmbeddingRetriever

    documents = [
        Document(content="Renewable energy is energy that is collected from renewable resources."),
        Document(content="Solar energy is a type of green energy that is harnessed from the sun."),
        Document(content="Wind energy is another type of green energy that is generated by wind turbines."),
        Document(content="Geothermal energy is heat that comes from the sub-surface of the earth."),
        Document(content="Biomass energy is produced from organic materials, such as plant and animal waste."),
        Document(content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources."),
    ]

    # Populate the document store
    doc_store = InMemoryDocumentStore()
    doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
    doc_writer = DocumentWriter(document_store=doc_store, policy=DuplicatePolicy.SKIP)
    documents = doc_embedder.run(documents)["documents"]
    doc_writer.run(documents=documents)

    # Run the multi-query retriever
    in_memory_retriever = InMemoryEmbeddingRetriever(document_store=doc_store, top_k=1)
    query_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")

    multi_query_retriever = MultiQueryEmbeddingRetriever(
        retriever=in_memory_retriever,
        query_embedder=query_embedder,
        max_workers=3
    )

    queries = ["Geothermal energy", "natural gas", "turbines"]
    result = multi_query_retriever.run(queries=queries)
    for doc in result["documents"]:
        print(f"Content: {doc.content}, Score: {doc.score}")
    # Example output (illustrative: the scores and the number of documents depend on the model and the retriever's `top_k`)
    # >> Content: Geothermal energy is heat that comes from the sub-surface of the earth., Score: 0.8509603046266574
    # >> Content: Renewable energy is energy that is collected from renewable resources., Score: 0.42763211298893034
    # >> Content: Solar energy is a type of green energy that is harnessed from the sun., Score: 0.40077417016494354
    # >> Content: Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources., Score: 0.3774863680
    # >> Content: Wind energy is another type of green energy that is generated by wind turbines., Score: 0.30914239725622
    # >> Content: Biomass energy is produced from organic materials, such as plant and animal waste., Score: 0.25173074243
    ```
    """  # noqa E501

    def __init__(self, *, retriever: EmbeddingRetriever, query_embedder: TextEmbedder, max_workers: int = 3) -> None:
        """
        Create a MultiQueryEmbeddingRetriever.

        :param retriever: Embedding-based retriever that fetches documents for each query embedding.
        :param query_embedder: Text embedder that turns each query string into an embedding.
        :param max_workers: Upper bound on the number of threads used to process queries concurrently.
        """
        self.retriever = retriever
        self.query_embedder = query_embedder
        self.max_workers = max_workers
        # Flipped to True by `warm_up` so that the wrapped components are warmed up only once.
        self._is_warmed_up = False

    def warm_up(self) -> None:
        """
        Warm up the query embedder and the retriever, for whichever of them exposes a callable `warm_up`.

        Calling this again after a successful warm-up is a no-op.
        """
        if self._is_warmed_up:
            return

        # The embedder goes first, then the retriever.
        for wrapped in (self.query_embedder, self.retriever):
            warm_up_fn = getattr(wrapped, "warm_up", None)
            if callable(warm_up_fn):
                warm_up_fn()

        self._is_warmed_up = True

    @component.output_types(documents=list[Document])
    def run(self, queries: list[str], retriever_kwargs: dict[str, Any] | None = None) -> dict[str, list[Document]]:
        """
        Run retrieval for every query concurrently and merge the results.

        :param queries: Text queries to embed and retrieve documents for.
        :param retriever_kwargs: Optional extra keyword arguments forwarded to the retriever's `run` method
            for every query.
        :returns:
            A dictionary containing:
            - `documents`: The de-duplicated documents from all queries, sorted by descending score
              (documents without a score are treated as having a score of 0.0).
        """
        if not self._is_warmed_up:
            self.warm_up()

        extra_kwargs = retriever_kwargs or {}

        def retrieve(query: str) -> list[Document] | None:
            return self._run_on_thread(query, extra_kwargs)

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Queries that produced no result (None or an empty list) contribute nothing.
            gathered = [doc for hits in executor.map(retrieve, queries) if hits for doc in hits]

        unique_docs = _deduplicate_documents(gathered)
        ranked = sorted(unique_docs, key=lambda doc: doc.score or 0.0, reverse=True)
        return {"documents": ranked}

    def _run_on_thread(self, query: str, retriever_kwargs: dict[str, Any] | None = None) -> list[Document] | None:
        """
        Embed a single query and retrieve its documents; intended to run on a worker thread.

        :param query: The text query to process.
        :param retriever_kwargs: Optional extra keyword arguments forwarded to the retriever's `run` method.
        :returns:
            The retriever's `documents` for this query, or None when the retriever returns an empty
            result or one without a `documents` key.
        """
        query_embedding = self.query_embedder.run(text=query)["embedding"]
        extra_kwargs = retriever_kwargs or {}
        retrieval = self.retriever.run(query_embedding=query_embedding, **extra_kwargs)

        if not retrieval or "documents" not in retrieval:
            return None
        return retrieval["documents"]

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        The wrapped retriever and query embedder are serialized with `component_to_dict` and stored
        alongside `max_workers`.

        :returns:
            A dictionary representing the serialized component.
        """
        serialized_retriever = component_to_dict(obj=self.retriever, name="retriever")
        serialized_embedder = component_to_dict(obj=self.query_embedder, name="query_embedder")
        return default_to_dict(
            self,
            retriever=serialized_retriever,
            query_embedder=serialized_embedder,
            max_workers=self.max_workers,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MultiQueryEmbeddingRetriever":
        """
        Deserializes the component from a dictionary.

        Delegates entirely to `default_from_dict`.

        :param data: The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        return default_from_dict(cls, data)