/ haystack / components / query / query_expander.py
query_expander.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from typing import Any
  6  
  7  from haystack import default_from_dict, default_to_dict, logging
  8  from haystack.components.builders.prompt_builder import PromptBuilder
  9  from haystack.components.generators.chat.openai import OpenAIChatGenerator
 10  from haystack.components.generators.chat.types import ChatGenerator
 11  from haystack.core.component import component
 12  from haystack.core.serialization import component_to_dict
 13  from haystack.dataclasses.chat_message import ChatMessage
 14  from haystack.utils import deserialize_chatgenerator_inplace
 15  from haystack.utils.misc import _parse_dict_from_json
 16  
# Module-level logger; messages in this module use `{name}` placeholders filled from keyword arguments.
logger = logging.getLogger(__name__)


# Default prompt used by QueryExpander when no custom `prompt_template` is passed to its constructor.
# It is rendered through PromptBuilder with two variables:
#   - `query`: the user query (whitespace-stripped before rendering)
#   - `n_expansions`: how many alternative queries to request
# The model is asked to reply with a JSON object of the form {"queries": [...]}, which is the shape
# QueryExpander parses; surplus queries beyond `n_expansions` are truncated by the component.
# NOTE(review): the few-shot examples below contain 3-4 queries each regardless of `n_expansions`;
# the model may anchor on those lengths rather than on the requested count.
DEFAULT_PROMPT_TEMPLATE = """
You are part of an information system that processes user queries for retrieval.
You have to expand a given query into {{ n_expansions }} queries that are
semantically similar to improve retrieval recall.

Structure:
Follow the structure shown below in examples to generate expanded queries.

Examples:
1.  Query: "climate change effects"
    {"queries": ["impact of climate change", "consequences of global warming", "effects of environmental changes"]}

2.  Query: "machine learning algorithms"
    {"queries": ["neural networks", "clustering techniques", "supervised learning methods", "deep learning models"]}

3.  Query: "open source NLP frameworks"
    {"queries": ["natural language processing tools", "free nlp libraries", "open-source NLP platforms"]}

Guidelines:
- Generate queries that use different words and phrasings
- Include synonyms and related terms
- Maintain the same core meaning and intent
- Make queries that are likely to retrieve relevant information the original might miss
- Focus on variations that would work well with keyword-based search
- Respond in the same language as the input query

Your Task:
Query: "{{ query }}"

You *must* respond with a JSON object containing a "queries" array with the expanded queries.
Example: {"queries": ["query1", "query2", "query3"]}"""
 51  
 52  
 53  @component
 54  class QueryExpander:
 55      """
 56      A component that returns a list of semantically similar queries to improve retrieval recall in RAG systems.
 57  
 58      The component uses a chat generator to expand queries. The chat generator is expected to return a JSON response
 59      with the following structure:
 60      ```json
 61      {"queries": ["expanded query 1", "expanded query 2", "expanded query 3"]}
 62      ```
 63  
 64      ### Usage example
 65  
 66      ```python
 67      from haystack.components.generators.chat.openai import OpenAIChatGenerator
 68      from haystack.components.query import QueryExpander
 69  
 70      expander = QueryExpander(
 71          chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"),
 72          n_expansions=3
 73      )
 74  
 75      result = expander.run(query="green energy sources")
 76      print(result["queries"])
 77      # Output: ['alternative query 1', 'alternative query 2', 'alternative query 3', 'green energy sources']
 78      # Note: Up to 3 additional queries + 1 original query (if include_original_query=True)
 79  
 80      # To control total number of queries:
 81      expander = QueryExpander(n_expansions=2, include_original_query=True)  # Up to 3 total
 82      # or
 83      expander = QueryExpander(n_expansions=3, include_original_query=False)  # Exactly 3 total
 84      ```
 85      """
 86  
 87      def __init__(
 88          self,
 89          *,
 90          chat_generator: ChatGenerator | None = None,
 91          prompt_template: str | None = None,
 92          n_expansions: int = 4,
 93          include_original_query: bool = True,
 94      ) -> None:
 95          """
 96          Initialize the QueryExpander component.
 97  
 98          :param chat_generator: The chat generator component to use for query expansion.
 99              If None, a default OpenAIChatGenerator with gpt-4.1-mini model is used.
100          :param prompt_template: Custom [PromptBuilder](https://docs.haystack.deepset.ai/docs/promptbuilder)
101              template for query expansion. The template should instruct the LLM to return a JSON response with the
102              structure: `{"queries": ["query1", "query2", "query3"]}`. The template should include 'query' and
103              'n_expansions' variables.
104          :param n_expansions: Number of alternative queries to generate (default: 4).
105          :param include_original_query: Whether to include the original query in the output.
106          """
107          if n_expansions <= 0:
108              raise ValueError("n_expansions must be positive")
109  
110          self.n_expansions = n_expansions
111          self.include_original_query = include_original_query
112  
113          if chat_generator is None:
114              self.chat_generator: ChatGenerator = OpenAIChatGenerator(
115                  model="gpt-4.1-mini",
116                  generation_kwargs={
117                      "temperature": 0.7,
118                      "response_format": {
119                          "type": "json_schema",
120                          "json_schema": {
121                              "name": "query_expansion",
122                              "schema": {
123                                  "type": "object",
124                                  "properties": {"queries": {"type": "array", "items": {"type": "string"}}},
125                                  "required": ["queries"],
126                                  "additionalProperties": False,
127                              },
128                          },
129                      },
130                      "seed": 42,
131                  },
132              )
133          else:
134              self.chat_generator = chat_generator
135  
136          self._is_warmed_up = False
137          self.prompt_template = prompt_template or DEFAULT_PROMPT_TEMPLATE
138  
139          # Check if required variables are present in the template
140          if "query" not in self.prompt_template:
141              logger.warning(
142                  "The prompt template does not contain the 'query' variable. This may cause issues during execution."
143              )
144          if "n_expansions" not in self.prompt_template:
145              logger.warning(
146                  "The prompt template does not contain the 'n_expansions' variable. "
147                  "This may cause issues during execution."
148              )
149  
150          self._prompt_builder = PromptBuilder(
151              template=self.prompt_template, required_variables=["n_expansions", "query"]
152          )
153  
154      def to_dict(self) -> dict[str, Any]:
155          """
156          Serializes the component to a dictionary.
157  
158          :return: Dictionary with serialized data.
159          """
160          return default_to_dict(
161              self,
162              chat_generator=component_to_dict(self.chat_generator, name="chat_generator"),
163              prompt_template=self.prompt_template,
164              n_expansions=self.n_expansions,
165              include_original_query=self.include_original_query,
166          )
167  
168      @classmethod
169      def from_dict(cls, data: dict[str, Any]) -> "QueryExpander":
170          """
171          Deserializes the component from a dictionary.
172  
173          :param data: Dictionary with serialized data.
174          :return: Deserialized component.
175          """
176          init_params = data.get("init_parameters", {})
177  
178          deserialize_chatgenerator_inplace(init_params, key="chat_generator")
179  
180          return default_from_dict(cls, data)
181  
182      @component.output_types(queries=list[str])
183      def run(self, query: str, n_expansions: int | None = None) -> dict[str, list[str]]:
184          """
185          Expand the input query into multiple semantically similar queries.
186  
187          The language of the original query is preserved in the expanded queries.
188  
189          :param query: The original query to expand.
190          :param n_expansions: Number of additional queries to generate (not including the original).
191              If None, uses the value from initialization. Can be 0 to generate no additional queries.
192          :return: Dictionary with "queries" key containing the list of expanded queries.
193              If include_original_query=True, the original query will be included in addition
194              to the n_expansions alternative queries.
195          :raises ValueError: If n_expansions is not positive (less than or equal to 0).
196          """
197  
198          if not self._is_warmed_up:
199              self.warm_up()
200  
201          response = {"queries": [query] if self.include_original_query else []}
202  
203          if not query.strip():
204              logger.warning("Empty query provided to QueryExpander")
205              return response
206  
207          expansion_count = n_expansions if n_expansions is not None else self.n_expansions
208          if expansion_count <= 0:
209              raise ValueError("n_expansions must be positive")
210  
211          try:
212              prompt_result = self._prompt_builder.run(query=query.strip(), n_expansions=expansion_count)
213              generator_result = self.chat_generator.run(messages=[ChatMessage.from_user(prompt_result["prompt"])])
214  
215              if not generator_result.get("replies") or len(generator_result["replies"]) == 0:
216                  logger.warning("ChatGenerator returned no replies for query: {query}", query=query)
217                  return response
218  
219              expanded_text = generator_result["replies"][0].text.strip()
220              expanded_queries = self._parse_expanded_queries(expanded_text)
221  
222              # Limit the number of expanded queries to the requested amount
223              if len(expanded_queries) > expansion_count:
224                  logger.warning(
225                      "Generated {generated_count} queries but only {requested_count} were requested. "
226                      "Truncating to the first {requested_count} queries. ",
227                      generated_count=len(expanded_queries),
228                      requested_count=expansion_count,
229                  )
230                  expanded_queries = expanded_queries[:expansion_count]
231  
232              # Add original query if requested and remove duplicates
233              if self.include_original_query:
234                  expanded_queries_lower = [q.lower() for q in expanded_queries]
235                  if query.lower() not in expanded_queries_lower:
236                      expanded_queries.append(query)
237  
238              response["queries"] = expanded_queries
239              return response
240  
241          except Exception as e:
242              # Fallback: return original query to maintain pipeline functionality
243              logger.exception("Failed to expand query {query}: {error}", query=query, error=str(e))
244              return response
245  
246      def warm_up(self) -> None:
247          """
248          Warm up the LLM provider component.
249          """
250          if not self._is_warmed_up:
251              if hasattr(self.chat_generator, "warm_up"):
252                  self.chat_generator.warm_up()
253              self._is_warmed_up = True
254  
255      @staticmethod
256      def _parse_expanded_queries(generator_response: str) -> list[str]:
257          """
258          Parse the generator response to extract individual expanded queries.
259  
260          :param generator_response: The raw text response from the generator.
261          :return: List of parsed expanded queries.
262          """
263          parsed = _parse_dict_from_json(generator_response, expected_keys=["queries"], raise_on_failure=False)
264  
265          if parsed is None:
266              return []
267  
268          queries = []
269          for item in parsed["queries"]:
270              if isinstance(item, str) and item.strip():
271                  queries.append(item.strip())
272              else:
273                  logger.warning("Skipping non-string or empty query in response: {item}", item=item)
274  
275          return queries