accumulate.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import builtins
import sys
from collections.abc import Callable
from importlib import import_module
from typing import Any

from haystack.core.component import component
from haystack.core.errors import ComponentDeserializationError
from haystack.core.serialization import default_to_dict


def _default_function(first: int, second: int) -> int:
    """Return the sum of the two values; used when no custom function is supplied."""
    return first + second


@component
class Accumulate:
    """
    Accumulates the value flowing through the connection into an internal attribute.

    The accumulation function can be customized. This component is also an example of how to deal with
    serialization when some of the init parameters are not directly serializable: the callable is serialized
    as its dotted import path and re-imported on deserialization.
    """

    def __init__(self, function: Callable | None = None) -> None:
        """
        Class constructor

        :param function:
            The callable used to accumulate the values. It is called as `function(state, value)` and its
            result becomes the new state. If `None`, plain addition is used.
            To survive a `to_dict`/`from_dict` round trip, it must be reachable as an attribute of an
            importable module (or be a builtin).
        """
        # The accumulator always starts from 0, regardless of the chosen function.
        self.state = 0
        self.function: Callable = _default_function if function is None else function

    def to_dict(self) -> dict[str, Any]:
        """
        Converts the component to a dictionary

        The function is stored as `<module>.<name>`, or as the bare `<name>` when it lives in `builtins`.

        :returns:
            The dictionary produced by `default_to_dict`, with the function replaced by its name.
        :raises ValueError:
            If the module that defines the function is not present in `sys.modules`.
        """
        module = sys.modules.get(self.function.__module__)
        if not module:
            raise ValueError("Could not locate the import module.")
        # Identity check: we want to know whether this *is* the builtins module.
        if module is builtins:
            function_name = self.function.__name__
        else:
            function_name = f"{module.__name__}.{self.function.__name__}"

        return default_to_dict(self, function=function_name)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Accumulate":
        """
        Loads the component from a dictionary

        :param data:
            The serialized component. Must contain a `type` key matching this class; the optional
            `init_parameters.function` entry is the name written by `to_dict`.
        :returns:
            A new `Accumulate` instance using the deserialized function, or the default one when no
            function is present in `data`.
        :raises ComponentDeserializationError:
            If `type` is missing or does not name this class.
        """
        if "type" not in data:
            raise ComponentDeserializationError("Missing 'type' in component serialization data")
        if data["type"] != f"{cls.__module__}.{cls.__name__}":
            raise ComponentDeserializationError(f"Class '{data['type']}' can't be deserialized as '{cls.__name__}'")

        init_params = data.get("init_parameters", {})

        accumulator_function = None
        if "function" in init_params:
            # Split on the last dot: everything before it is the module path, the rest is the attribute.
            module_name, _, function_name = init_params["function"].rpartition(".")
            # `to_dict` writes builtins without a module prefix; `import_module("")` would raise
            # ValueError, so bare names are resolved against the builtins module instead.
            # NOTE(review): this imports whatever module the serialized data names, which executes that
            # module's top-level code. Only deserialize data from trusted sources.
            module = import_module(module_name) if module_name else builtins
            accumulator_function = getattr(module, function_name)

        return cls(function=accumulator_function)

    @component.output_types(value=int)
    def run(self, value: int) -> dict[str, int]:
        """
        Accumulates the value flowing through the connection into an internal attribute.

        The sum function can be customized.

        :param value:
            The incoming value to combine with the current state.
        :returns:
            A dictionary with the updated state under the `value` key.
        """
        self.state = self.function(self.state, value)
        return {"value": self.state}