# query_expander.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

from haystack import default_from_dict, default_to_dict, logging
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.generators.chat.types import ChatGenerator
from haystack.core.component import component
from haystack.core.serialization import component_to_dict
from haystack.dataclasses.chat_message import ChatMessage
from haystack.utils import deserialize_chatgenerator_inplace
from haystack.utils.misc import _parse_dict_from_json

logger = logging.getLogger(__name__)


# Default instruction prompt rendered by PromptBuilder. The Jinja2 variables
# `{{ n_expansions }}` and `{{ query }}` are substituted at run time; the
# single-braced JSON snippets are literal examples, not template variables.
DEFAULT_PROMPT_TEMPLATE = """
You are part of an information system that processes user queries for retrieval.
You have to expand a given query into {{ n_expansions }} queries that are
semantically similar to improve retrieval recall.

Structure:
Follow the structure shown below in examples to generate expanded queries.

Examples:
1. Query: "climate change effects"
{"queries": ["impact of climate change", "consequences of global warming", "effects of environmental changes"]}

2. Query: "machine learning algorithms"
{"queries": ["neural networks", "clustering techniques", "supervised learning methods", "deep learning models"]}

3. Query: "open source NLP frameworks"
{"queries": ["natural language processing tools", "free nlp libraries", "open-source NLP platforms"]}

Guidelines:
- Generate queries that use different words and phrasings
- Include synonyms and related terms
- Maintain the same core meaning and intent
- Make queries that are likely to retrieve relevant information the original might miss
- Focus on variations that would work well with keyword-based search
- Respond in the same language as the input query

Your Task:
Query: "{{ query }}"

You *must* respond with a JSON object containing a "queries" array with the expanded queries.
Example: {"queries": ["query1", "query2", "query3"]}"""


@component
class QueryExpander:
    """
    A component that returns a list of semantically similar queries to improve retrieval recall in RAG systems.

    The component uses a chat generator to expand queries. The chat generator is expected to return a JSON response
    with the following structure:
    ```json
    {"queries": ["expanded query 1", "expanded query 2", "expanded query 3"]}
    ```

    ### Usage example

    ```python
    from haystack.components.generators.chat.openai import OpenAIChatGenerator
    from haystack.components.query import QueryExpander

    expander = QueryExpander(
        chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"),
        n_expansions=3
    )

    result = expander.run(query="green energy sources")
    print(result["queries"])
    # Output: ['alternative query 1', 'alternative query 2', 'alternative query 3', 'green energy sources']
    # Note: Up to 3 additional queries + 1 original query (if include_original_query=True)

    # To control total number of queries:
    expander = QueryExpander(n_expansions=2, include_original_query=True)  # Up to 3 total
    # or
    expander = QueryExpander(n_expansions=3, include_original_query=False)  # Exactly 3 total
    ```
    """

    def __init__(
        self,
        *,
        chat_generator: ChatGenerator | None = None,
        prompt_template: str | None = None,
        n_expansions: int = 4,
        include_original_query: bool = True,
    ) -> None:
        """
        Initialize the QueryExpander component.

        :param chat_generator: The chat generator component to use for query expansion.
            If None, a default OpenAIChatGenerator with gpt-4.1-mini model is used.
        :param prompt_template: Custom [PromptBuilder](https://docs.haystack.deepset.ai/docs/promptbuilder)
            template for query expansion. The template should instruct the LLM to return a JSON response with the
            structure: `{"queries": ["query1", "query2", "query3"]}`. The template should include 'query' and
            'n_expansions' variables.
        :param n_expansions: Number of alternative queries to generate (default: 4).
        :param include_original_query: Whether to include the original query in the output.
        :raises ValueError: If n_expansions is not positive.
        """
        if n_expansions <= 0:
            raise ValueError("n_expansions must be positive")

        self.n_expansions = n_expansions
        self.include_original_query = include_original_query

        if chat_generator is None:
            # Default generator: JSON-schema structured output guarantees the
            # {"queries": [...]} shape; a fixed seed makes runs reproducible.
            self.chat_generator: ChatGenerator = OpenAIChatGenerator(
                model="gpt-4.1-mini",
                generation_kwargs={
                    "temperature": 0.7,
                    "response_format": {
                        "type": "json_schema",
                        "json_schema": {
                            "name": "query_expansion",
                            "schema": {
                                "type": "object",
                                "properties": {"queries": {"type": "array", "items": {"type": "string"}}},
                                "required": ["queries"],
                                "additionalProperties": False,
                            },
                        },
                    },
                    "seed": 42,
                },
            )
        else:
            self.chat_generator = chat_generator

        self._is_warmed_up = False
        self.prompt_template = prompt_template or DEFAULT_PROMPT_TEMPLATE

        # Warn early (substring check) if a custom template is missing the
        # variables the PromptBuilder will require at run time.
        if "query" not in self.prompt_template:
            logger.warning(
                "The prompt template does not contain the 'query' variable. This may cause issues during execution."
            )
        if "n_expansions" not in self.prompt_template:
            logger.warning(
                "The prompt template does not contain the 'n_expansions' variable. "
                "This may cause issues during execution."
            )

        self._prompt_builder = PromptBuilder(
            template=self.prompt_template, required_variables=["n_expansions", "query"]
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :return: Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            chat_generator=component_to_dict(self.chat_generator, name="chat_generator"),
            prompt_template=self.prompt_template,
            n_expansions=self.n_expansions,
            include_original_query=self.include_original_query,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "QueryExpander":
        """
        Deserializes the component from a dictionary.

        :param data: Dictionary with serialized data.
        :return: Deserialized component.
        """
        init_params = data.get("init_parameters", {})

        # Mutates init_params (and therefore data) in place, rebuilding the
        # nested chat generator before the generic deserialization below.
        deserialize_chatgenerator_inplace(init_params, key="chat_generator")

        return default_from_dict(cls, data)

    @component.output_types(queries=list[str])
    def run(self, query: str, n_expansions: int | None = None) -> dict[str, list[str]]:
        """
        Expand the input query into multiple semantically similar queries.

        The language of the original query is preserved in the expanded queries.

        :param query: The original query to expand.
        :param n_expansions: Number of additional queries to generate (not including the original).
            If None, uses the value from initialization. Must be positive.
        :return: Dictionary with "queries" key containing the list of expanded queries.
            If include_original_query=True, the original query will be included in addition
            to the n_expansions alternative queries.
        :raises ValueError: If n_expansions is not positive (less than or equal to 0).
        """

        if not self._is_warmed_up:
            self.warm_up()

        # Fallback result, also used on empty input and on generator failure.
        response = {"queries": [query] if self.include_original_query else []}

        if not query.strip():
            logger.warning("Empty query provided to QueryExpander")
            return response

        expansion_count = n_expansions if n_expansions is not None else self.n_expansions
        if expansion_count <= 0:
            raise ValueError("n_expansions must be positive")

        try:
            prompt_result = self._prompt_builder.run(query=query.strip(), n_expansions=expansion_count)
            generator_result = self.chat_generator.run(messages=[ChatMessage.from_user(prompt_result["prompt"])])

            if not generator_result.get("replies"):
                logger.warning("ChatGenerator returned no replies for query: {query}", query=query)
                return response

            expanded_text = generator_result["replies"][0].text.strip()
            expanded_queries = self._parse_expanded_queries(expanded_text)

            # Limit the number of expanded queries to the requested amount
            if len(expanded_queries) > expansion_count:
                logger.warning(
                    "Generated {generated_count} queries but only {requested_count} were requested. "
                    "Truncating to the first {requested_count} queries.",
                    generated_count=len(expanded_queries),
                    requested_count=expansion_count,
                )
                expanded_queries = expanded_queries[:expansion_count]

            # Add original query if requested and remove duplicates
            # (case-insensitive comparison so e.g. "NLP" vs "nlp" are not both kept).
            if self.include_original_query:
                expanded_queries_lower = [q.lower() for q in expanded_queries]
                if query.lower() not in expanded_queries_lower:
                    expanded_queries.append(query)

            response["queries"] = expanded_queries
            return response

        except Exception as e:
            # Fallback: return original query to maintain pipeline functionality
            logger.exception("Failed to expand query {query}: {error}", query=query, error=str(e))
            return response

    def warm_up(self) -> None:
        """
        Warm up the LLM provider component.
        """
        if not self._is_warmed_up:
            # Not every ChatGenerator implements warm_up, so probe first.
            if hasattr(self.chat_generator, "warm_up"):
                self.chat_generator.warm_up()
            self._is_warmed_up = True

    @staticmethod
    def _parse_expanded_queries(generator_response: str) -> list[str]:
        """
        Parse the generator response to extract individual expanded queries.

        :param generator_response: The raw text response from the generator.
        :return: List of parsed expanded queries. Empty if the response could
            not be parsed as JSON containing a "queries" key.
        """
        parsed = _parse_dict_from_json(generator_response, expected_keys=["queries"], raise_on_failure=False)

        if parsed is None:
            return []

        # Keep only non-empty string entries, stripped of surrounding whitespace.
        queries = []
        for item in parsed["queries"]:
            if isinstance(item, str) and item.strip():
                queries.append(item.strip())
            else:
                logger.warning("Skipping non-string or empty query in response: {item}", item=item)

        return queries