/ haystack / testing / sample_components / accumulate.py
accumulate.py
 1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
 2  #
 3  # SPDX-License-Identifier: Apache-2.0
 4  
 5  import builtins
 6  import sys
 7  from collections.abc import Callable
 8  from importlib import import_module
 9  from typing import Any
10  
11  from haystack.core.component import component
12  from haystack.core.errors import ComponentDeserializationError
13  from haystack.core.serialization import default_to_dict
14  
15  
16  def _default_function(first: int, second: int) -> int:
17      return first + second
18  
19  
def _resolve_function(path: str) -> Callable:
    """
    Imports and returns the callable referenced by a dotted import path.

    A path without any dot is looked up directly in the `builtins` module, because `import_module("")` would
    raise `ValueError`. `Accumulate.to_dict` emits such dot-less paths for builtins like `max`.

    :param path:
        A dotted path such as `"package.module.function"`, or a bare builtin name such as `"max"`.
    :returns:
        The attribute found at `path`.
    :raises ImportError:
        If the module part of `path` cannot be imported.
    :raises AttributeError:
        If the module does not define the final name in `path`.
    """
    module_name, _, function_name = path.rpartition(".")
    module = import_module(module_name) if module_name else builtins
    return getattr(module, function_name)


@component
class Accumulate:
    """
    Accumulates the value flowing through the connection into an internal attribute.

    The sum function can be customized. Example of how to deal with serialization when some of the parameters
    are not directly serializable.
    """

    def __init__(self, function: Callable | str | None = None) -> None:
        """
        Class constructor

        :param function:
            the function to use to accumulate the values.
            The function must take exactly two values: the current state and the incoming value.
            If it's a callable, it's used as it is.
            If it's a string, it's treated as a dotted import path (e.g. `"mypackage.mymodule.myfunc"`, or a bare
            builtin name such as `"max"`), and the function is imported from it.
            If it's None, the values are summed.
        """
        self.state = 0
        if function is None:
            function = _default_function
        elif isinstance(function, str):
            # Honor the documented string form instead of storing the raw string, which would fail in `run`.
            function = _resolve_function(function)
        self.function: Callable = function

    def to_dict(self) -> dict[str, Any]:
        """
        Converts the component to a dictionary.

        The accumulator function is stored as a string:
        - its bare name when it lives in the `builtins` module (e.g. `"max"`);
        - `"<module>.<name>"` otherwise.

        The running `state` is not serialized.

        NOTE: only functions reachable as a module-level attribute are guaranteed to round-trip through
        `from_dict`. Lambdas and nested functions serialize to a name that cannot be looked up reliably.

        :returns: The serialized component.
        :raises ValueError: If the function's module is not present in `sys.modules`.
        """
        module = sys.modules.get(self.function.__module__)
        if not module:
            raise ValueError("Could not locate the import module.")
        if module is builtins:
            function_name = self.function.__name__
        else:
            function_name = f"{module.__name__}.{self.function.__name__}"

        return default_to_dict(self, function=function_name)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Accumulate":
        """
        Loads the component from a dictionary.

        :param data: The serialized component, as produced by `to_dict`.
        :returns: A new `Accumulate` instance whose state starts at 0.
        :raises ComponentDeserializationError:
            If `type` is missing from `data` or does not name this class.
        """
        if "type" not in data:
            raise ComponentDeserializationError("Missing 'type' in component serialization data")
        if data["type"] != f"{cls.__module__}.{cls.__name__}":
            raise ComponentDeserializationError(f"Class '{data['type']}' can't be deserialized as '{cls.__name__}'")

        init_params = data.get("init_parameters", {})

        accumulator_function = None
        if "function" in init_params:
            # The helper also accepts the dot-less names `to_dict` writes for builtins.
            accumulator_function = _resolve_function(init_params["function"])

        return cls(function=accumulator_function)

    @component.output_types(value=int)
    def run(self, value: int):
        """
        Accumulates the value flowing through the connection into an internal attribute.

        The sum function can be customized.

        :param value: The value to fold into the running state.
        :returns: A dictionary with key `"value"` holding the updated state.
        """
        self.state = self.function(self.state, value)
        return {"value": self.state}