/ haystack / components / builders / answer_builder.py
answer_builder.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import re
  6  from dataclasses import replace
  7  from typing import Any
  8  
  9  from haystack import Document, GeneratedAnswer, component, logging
 10  from haystack.dataclasses.chat_message import ChatMessage
 11  
# Module-level logger obtained from haystack's `logging` module and named after this module.
logger = logging.getLogger(__name__)
 13  
 14  
 15  @component
 16  class AnswerBuilder:
 17      """
 18      Converts a query and Generator replies into a `GeneratedAnswer` object.
 19  
 20      AnswerBuilder parses Generator replies using custom regular expressions.
 21      Check out the usage example below to see how it works.
 22      Optionally, it can also take documents and metadata from the Generator to add to the `GeneratedAnswer` object.
 23      AnswerBuilder works with both non-chat and chat Generators.
 24  
 25      ### Usage example
 26  
 27      ```python
 28      from haystack.components.builders import AnswerBuilder
 29  
 30      builder = AnswerBuilder(pattern="Answer: (.*)")
 31      builder.run(query="What's the answer?", replies=["This is an argument. Answer: This is the answer."])
 32      ```
 33  
 34      ### Usage example with documents and reference pattern
 35  
 36      ```python
 37      from haystack import Document
 38      from haystack.components.builders import AnswerBuilder
 39  
 40      replies = ["The capital of France is Paris [2]."]
 41  
 42      docs = [
 43          Document(content="Berlin is the capital of Germany."),
 44          Document(content="Paris is the capital of France."),
 45          Document(content="Rome is the capital of Italy."),
 46      ]
 47  
 48      builder = AnswerBuilder(reference_pattern="\\[(\\d+)\\]", return_only_referenced_documents=False)
 49      result = builder.run(query="What is the capital of France?", replies=replies, documents=docs)["answers"][0]
 50  
 51      print(f"Answer: {result.data}")
 52      print("References:")
 53      for doc in result.documents:
 54          if doc.meta["referenced"]:
 55              print(f"[{doc.meta['source_index']}] {doc.content}")
 56      print("Other sources:")
 57      for doc in result.documents:
 58          if not doc.meta["referenced"]:
 59              print(f"[{doc.meta['source_index']}] {doc.content}")
 60  
 61      # >> Answer: The capital of France is Paris
 62      # >> References:
 63      # >> [2] Paris is the capital of France.
 64      # >> Other sources:
 65      # >> [1] Berlin is the capital of Germany.
 66      # >> [3] Rome is the capital of Italy.
 67      ```
 68      """
 69  
 70      def __init__(
 71          self,
 72          pattern: str | None = None,
 73          reference_pattern: str | None = None,
 74          last_message_only: bool = False,
 75          *,
 76          return_only_referenced_documents: bool = True,
 77      ) -> None:
 78          """
 79          Creates an instance of the AnswerBuilder component.
 80  
 81          :param pattern:
 82              The regular expression pattern to extract the answer text from the Generator.
 83              If not specified, the entire response is used as the answer.
 84              The regular expression can have one capture group at most.
 85              If present, the capture group text
 86              is used as the answer. If no capture group is present, the whole match is used as the answer.
 87              Examples:
 88                  `[^\\n]+$` finds "this is an answer" in a string "this is an argument.\\nthis is an answer".
 89                  `Answer: (.*)` finds "this is an answer" in a string "this is an argument. Answer: this is an answer".
 90  
 91          :param reference_pattern:
 92              The regular expression pattern used for parsing the document references.
 93              If not specified, no parsing is done, and all documents are returned.
 94              References need to be specified as indices of the input documents and start at [1].
 95              Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]".
 96              If this parameter is provided, documents metadata will contain a "referenced" key with a boolean value.
 97  
 98          :param last_message_only:
 99             If False (default value), all messages are used as the answer.
100             If True, only the last message is used as the answer.
101  
102          :param return_only_referenced_documents:
103              To be used in conjunction with `reference_pattern`.
104              If True (default value), only the documents that were actually referenced in `replies` are returned.
105              If False, all documents are returned.
106              If `reference_pattern` is not provided, this parameter has no effect, and all documents are returned.
107          """
108          if pattern:
109              AnswerBuilder._check_num_groups_in_regex(pattern)
110  
111          self.pattern = pattern
112          self.reference_pattern = reference_pattern
113          self.last_message_only = last_message_only
114          self.return_only_referenced_documents = return_only_referenced_documents
115  
116      @component.output_types(answers=list[GeneratedAnswer])
117      def run(
118          self,
119          query: str,
120          replies: list[str] | list[ChatMessage],
121          meta: list[dict[str, Any]] | None = None,
122          documents: list[Document] | None = None,
123          pattern: str | None = None,
124          reference_pattern: str | None = None,
125      ) -> dict[str, Any]:
126          """
127          Turns the output of a Generator into `GeneratedAnswer` objects using regular expressions.
128  
129          :param query:
130              The input query used as the Generator prompt.
131          :param replies:
132              The output of the Generator. Can be a list of strings or a list of `ChatMessage` objects.
133          :param meta:
134              The metadata returned by the Generator. If not specified, the generated answer will contain no metadata.
135          :param documents:
136              The documents used as the Generator inputs. If specified, they are added to
137              the `GeneratedAnswer` objects.
138              Each Document.meta includes a "source_index" key, representing its 1-based position in the input list.
139              When `reference_pattern` is provided:
140              - "referenced" key is added to the Document.meta, indicating if the document was referenced in the output.
141              - `return_only_referenced_documents` init parameter controls if all or only referenced documents are
142              returned.
143          :param pattern:
144              The regular expression pattern to extract the answer text from the Generator.
145              If not specified, the entire response is used as the answer.
146              The regular expression can have one capture group at most.
147              If present, the capture group text
148              is used as the answer. If no capture group is present, the whole match is used as the answer.
149                  Examples:
150                      `[^\\n]+$` finds "this is an answer" in a string "this is an argument.\\nthis is an answer".
151                      `Answer: (.*)` finds "this is an answer" in a string
152                      "this is an argument. Answer: this is an answer".
153          :param reference_pattern:
154              The regular expression pattern used for parsing the document references.
155              If not specified, no parsing is done, and all documents are returned.
156              References need to be specified as indices of the input documents and start at [1].
157              Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]".
158  
159          :returns: A dictionary with the following keys:
160              - `answers`: The answers received from the output of the Generator.
161          """
162          if not meta:
163              meta = [{}] * len(replies)
164          elif len(replies) != len(meta):
165              raise ValueError(f"Number of replies ({len(replies)}), and metadata ({len(meta)}) must match.")
166  
167          if pattern:
168              AnswerBuilder._check_num_groups_in_regex(pattern)
169  
170          pattern = pattern or self.pattern
171          reference_pattern = reference_pattern or self.reference_pattern
172  
173          replies_to_iterate = replies[-1:] if self.last_message_only and replies else replies
174          meta_to_iterate = meta[-1:] if self.last_message_only and meta else meta
175  
176          all_answers = []
177          for reply, given_metadata in zip(replies_to_iterate, meta_to_iterate, strict=True):
178              # Extract content from ChatMessage objects if reply is a ChatMessages, else use the string as is
179              extracted_reply = reply.text or "" if isinstance(reply, ChatMessage) else str(reply)
180              extracted_metadata = reply.meta if isinstance(reply, ChatMessage) else {}
181  
182              extracted_metadata = {**extracted_metadata, **given_metadata}
183              extracted_metadata["all_messages"] = replies
184  
185              referenced_docs = []
186              if documents:
187                  referenced_idxs = (
188                      AnswerBuilder._extract_reference_idxs(extracted_reply, reference_pattern)
189                      if reference_pattern
190                      else set()
191                  )
192                  doc_idxs = (
193                      referenced_idxs
194                      if reference_pattern and self.return_only_referenced_documents
195                      else set(range(len(documents)))
196                  )
197  
198                  for idx in doc_idxs:
199                      try:
200                          doc = documents[idx]
201                      except IndexError:
202                          logger.warning(
203                              "Document index '{index}' referenced in Generator output is out of range. ", index=idx + 1
204                          )
205                          continue
206  
207                      doc_meta: dict[str, Any] = doc.meta or {}
208                      doc_meta["source_index"] = idx + 1
209                      if reference_pattern:
210                          doc_meta["referenced"] = idx in referenced_idxs
211                      referenced_docs.append(replace(doc, meta=doc_meta))
212  
213              answer_string = AnswerBuilder._extract_answer_string(extracted_reply, pattern)
214              answer = GeneratedAnswer(
215                  data=answer_string, query=query, documents=referenced_docs, meta=extracted_metadata
216              )
217              all_answers.append(answer)
218  
219          return {"answers": all_answers}
220  
221      @staticmethod
222      def _extract_answer_string(reply: str, pattern: str | None = None) -> str:
223          """
224          Extract the answer string from the generator output using the specified pattern.
225  
226          If no pattern is specified, the whole string is used as the answer.
227  
228          :param reply:
229              The output of the Generator. A string.
230          :param pattern:
231              The regular expression pattern to use to extract the answer text from the generator output.
232          """
233          if pattern is None:
234              return reply
235  
236          if match := re.search(pattern, reply):
237              # No capture group in pattern -> use the whole match as answer
238              if not match.lastindex:
239                  return match.group(0)
240              # One capture group in pattern -> use the capture group as answer
241              return match.group(1)
242          return ""
243  
244      @staticmethod
245      def _extract_reference_idxs(reply: str, reference_pattern: str) -> set[int]:
246          document_idxs = re.findall(reference_pattern, reply)
247          return {int(idx) - 1 for idx in document_idxs}
248  
249      @staticmethod
250      def _check_num_groups_in_regex(pattern: str) -> None:
251          num_groups = re.compile(pattern).groups
252          if num_groups > 1:
253              raise ValueError(
254                  f"Pattern '{pattern}' contains multiple capture groups. "
255                  f"Please specify a pattern with at most one capture group."
256              )