/ haystack / components / joiners / list_joiner.py
list_joiner.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from itertools import chain
  6  from typing import Any
  7  
  8  from haystack import component, default_from_dict, default_to_dict
  9  from haystack.core.component.types import Variadic
 10  from haystack.utils import deserialize_type, serialize_type
 11  
 12  
 13  @component
 14  class ListJoiner:
 15      """
 16      A component that joins multiple lists into a single flat list.
 17  
 18      The ListJoiner receives multiple lists of the same type and concatenates them into a single flat list.
 19      The output order respects the pipeline's execution sequence, with earlier inputs being added first.
 20  
 21      Usage example:
 22      ```python
 23      from haystack.components.builders import ChatPromptBuilder
 24      from haystack.components.generators.chat import OpenAIChatGenerator
 25      from haystack.dataclasses import ChatMessage
 26      from haystack import Pipeline
 27      from haystack.components.joiners import ListJoiner
 28  
 29  
 30      user_message = [ChatMessage.from_user("Give a brief answer the following question: {{query}}")]
 31  
 32      feedback_prompt = \"""
 33          You are given a question and an answer.
 34          Your task is to provide a score and a brief feedback on the answer.
 35          Question: {{query}}
 36          Answer: {{response}}
 37          \"""
 38      feedback_message = [ChatMessage.from_system(feedback_prompt)]
 39  
 40      prompt_builder = ChatPromptBuilder(template=user_message)
 41      feedback_prompt_builder = ChatPromptBuilder(template=feedback_message)
 42      llm = OpenAIChatGenerator()
 43      feedback_llm = OpenAIChatGenerator()
 44  
 45      pipe = Pipeline()
 46      pipe.add_component("prompt_builder", prompt_builder)
 47      pipe.add_component("llm", llm)
 48      pipe.add_component("feedback_prompt_builder", feedback_prompt_builder)
 49      pipe.add_component("feedback_llm", feedback_llm)
 50      pipe.add_component("list_joiner", ListJoiner(list[ChatMessage]))
 51  
 52      pipe.connect("prompt_builder.prompt", "llm.messages")
 53      pipe.connect("prompt_builder.prompt", "list_joiner")
 54      pipe.connect("llm.replies", "list_joiner")
 55      pipe.connect("llm.replies", "feedback_prompt_builder.response")
 56      pipe.connect("feedback_prompt_builder.prompt", "feedback_llm.messages")
 57      pipe.connect("feedback_llm.replies", "list_joiner")
 58  
 59      query = "What is nuclear physics?"
 60      ans = pipe.run(data={"prompt_builder": {"template_variables":{"query": query}},
 61          "feedback_prompt_builder": {"template_variables":{"query": query}}})
 62  
 63      print(ans["list_joiner"]["values"])
 64      ```
 65      """
 66  
 67      def __init__(self, list_type_: type | None = None) -> None:
 68          """
 69          Creates a ListJoiner component.
 70  
 71          :param list_type_: The expected type of the lists this component will join (e.g., list[ChatMessage]).
 72              If specified, all input lists must conform to this type. If None, the component defaults to handling
 73              lists of any type including mixed types.
 74          """
 75          self.list_type_ = list_type_
 76          if list_type_ is not None:
 77              component.set_output_types(self, values=list_type_)
 78          else:
 79              component.set_output_types(self, values=list[Any])
 80  
 81      def to_dict(self) -> dict[str, Any]:
 82          """
 83          Serializes the component to a dictionary.
 84  
 85          :returns: Dictionary with serialized data.
 86          """
 87          return default_to_dict(
 88              self, list_type_=serialize_type(self.list_type_) if self.list_type_ is not None else None
 89          )
 90  
 91      @classmethod
 92      def from_dict(cls, data: dict[str, Any]) -> "ListJoiner":
 93          """
 94          Deserializes the component from a dictionary.
 95  
 96          :param data: Dictionary to deserialize from.
 97          :returns: Deserialized component.
 98          """
 99          init_parameters = data.get("init_parameters")
100          if init_parameters is not None and init_parameters.get("list_type_") is not None:
101              data["init_parameters"]["list_type_"] = deserialize_type(data["init_parameters"]["list_type_"])
102          return default_from_dict(cls, data)
103  
104      def run(self, values: Variadic[list[Any]]) -> dict[str, list[Any]]:
105          """
106          Joins multiple lists into a single flat list.
107  
108          :param values: The list to be joined.
109          :returns: Dictionary with 'values' key containing the joined list.
110          """
111          result = list(chain(*values))
112          return {"values": result}