list_joiner.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from itertools import chain
from typing import Any

from haystack import component, default_from_dict, default_to_dict
from haystack.core.component.types import Variadic
from haystack.utils import deserialize_type, serialize_type


@component
class ListJoiner:
    """
    A component that joins multiple lists into a single flat list.

    The ListJoiner receives multiple lists of the same type and concatenates them into a single flat list.
    The output order respects the pipeline's execution sequence, with earlier inputs being added first.

    Usage example:
    ```python
    from haystack.components.builders import ChatPromptBuilder
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack import Pipeline
    from haystack.components.joiners import ListJoiner


    user_message = [ChatMessage.from_user("Give a brief answer to the following question: {{query}}")]

    feedback_prompt = \"""
        You are given a question and an answer.
        Your task is to provide a score and brief feedback on the answer.
        Question: {{query}}
        Answer: {{response}}
        \"""
    feedback_message = [ChatMessage.from_system(feedback_prompt)]

    prompt_builder = ChatPromptBuilder(template=user_message)
    feedback_prompt_builder = ChatPromptBuilder(template=feedback_message)
    llm = OpenAIChatGenerator()
    feedback_llm = OpenAIChatGenerator()

    pipe = Pipeline()
    pipe.add_component("prompt_builder", prompt_builder)
    pipe.add_component("llm", llm)
    pipe.add_component("feedback_prompt_builder", feedback_prompt_builder)
    pipe.add_component("feedback_llm", feedback_llm)
    pipe.add_component("list_joiner", ListJoiner(list[ChatMessage]))

    pipe.connect("prompt_builder.prompt", "llm.messages")
    pipe.connect("prompt_builder.prompt", "list_joiner")
    pipe.connect("llm.replies", "list_joiner")
    pipe.connect("llm.replies", "feedback_prompt_builder.response")
    pipe.connect("feedback_prompt_builder.prompt", "feedback_llm.messages")
    pipe.connect("feedback_llm.replies", "list_joiner")

    query = "What is nuclear physics?"
    ans = pipe.run(data={"prompt_builder": {"template_variables":{"query": query}},
        "feedback_prompt_builder": {"template_variables":{"query": query}}})

    print(ans["list_joiner"]["values"])
    ```
    """

    def __init__(self, list_type_: type | None = None) -> None:
        """
        Creates a ListJoiner component.

        :param list_type_: The expected type of the lists this component will join (e.g., list[ChatMessage]).
            It is declared as the type of the `values` output. Input lists are expected to conform to it, but
            `run` does not check this at runtime. If None, the output is declared as `list[Any]`, so lists of
            any type, including mixed types, can be joined.
        """
        self.list_type_ = list_type_
        # The output socket type is what downstream connections are validated against.
        output_type = list_type_ if list_type_ is not None else list[Any]
        component.set_output_types(self, values=output_type)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns: Dictionary with serialized data. `list_type_` is stored as a string, or None if unset.
        """
        return default_to_dict(
            self, list_type_=serialize_type(self.list_type_) if self.list_type_ is not None else None
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ListJoiner":
        """
        Deserializes the component from a dictionary.

        The passed-in dictionary is not modified.

        :param data: Dictionary to deserialize from.
        :returns: Deserialized component.
        """
        init_parameters = data.get("init_parameters")
        if init_parameters is not None and init_parameters.get("list_type_") is not None:
            # Build shallow copies instead of rewriting the caller's dict: replacing the serialized string
            # in place would make a later `from_dict` on the same dict hand a type object (not a string)
            # to `deserialize_type`.
            data = {
                **data,
                "init_parameters": {
                    **init_parameters,
                    "list_type_": deserialize_type(init_parameters["list_type_"]),
                },
            }
        return default_from_dict(cls, data)

    def run(self, values: Variadic[list[Any]]) -> dict[str, list[Any]]:
        """
        Joins multiple lists into a single flat list.

        :param values: The lists to be joined. Items keep the order of the lists as received.
        :returns: Dictionary with 'values' key containing the joined list.
        """
        return {"values": list(chain.from_iterable(values))}