/ haystack / core / component / component.py
component.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  """
  6  Attributes:
  7  
  8      component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.
  9  
All components must follow the contract below. This docstring is the source of truth for the component contract.
 11  
 12  <hr>
 13  
 14  `@component` decorator
 15  
 16  All component classes must be decorated with the `@component` decorator. This allows Haystack to discover them.
 17  
 18  <hr>
 19  
 20  `__init__(self, **kwargs)`
 21  
 22  Optional method.
 23  
 24  Components may have an `__init__` method where they define:
 25  
 26  - `self.init_parameters = {same parameters that the __init__ method received}`:
 27      In this dictionary you can store any state the components wish to be persisted when they are saved.
 28      These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
 29      Note that by default the `@component` decorator saves the arguments automatically.
 30      However, if a component sets their own `init_parameters` manually in `__init__()`, that will be used instead.
 31      Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.
 32  
 33  Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
 34  dictionaries containing only such values. Anything else (objects, functions, etc) will raise an exception at init
 35  time. If there's the need for such values, consider serializing them to a string.
 36  
 37  If you need to accept classes or callables, accept either a string import path or the callable itself. Resolve strings
 38  to objects in `__init__`, and serialize objects back to importable strings in `to_dict()` so that `from_dict()` can load
 39  them (for example, store `"module_path.symbol_name"` and load it via `importlib`). This keeps init parameters JSON
 40  serializable for pipeline save/load. See `haystack.testing.sample_components.accumulate.Accumulate` for a reference
 41  implementation.
 42  
 43  The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
 44  validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc...) refer to
 45  the `warm_up()` method.
 46  
 47  <hr>
 48  
 49  `warm_up(self)`
 50  
 51  Optional method.
 52  
 53  This method is called by Pipeline before the graph execution. Make sure to avoid double-initializations,
 54  because Pipeline will not keep track of which components it called `warm_up()` on.
 55  
 56  <hr>
 57  
`run(self, ...)`

Mandatory method.

This is the method where the main functionality of the component should be carried out. It's called by
`Pipeline.run()`.

The inputs of the component are derived from the parameters of `run()`. If `run()` accepts `**kwargs`, additional
inputs can be declared with `component.set_input_types()`. When the component should run, Pipeline will call this
method with:

- all the input values coming from other components connected to it,
- for any input that is missing, the default value of the corresponding `run()` parameter, if it exists.

`run()` must return a dictionary whose keys are the output names declared with the `@component.output_types`
decorator or with `component.set_output_types()`.

 73  
 74  """
 75  
 76  import inspect
 77  import typing
 78  from collections.abc import Callable, Coroutine, Iterator, Mapping
 79  from contextlib import contextmanager
 80  from contextvars import ContextVar
 81  from copy import deepcopy
 82  from dataclasses import dataclass
 83  from types import new_class
 84  from typing import Any, ParamSpec, Protocol, TypeVar, overload, runtime_checkable
 85  
 86  from haystack import logging
 87  from haystack.core.errors import ComponentError
 88  
 89  from .sockets import Sockets
 90  from .types import InputSocket, OutputSocket, _empty
 91  
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Type variables used to annotate the `output_types` decorator: `RunParamsT` captures the parameters of the
# decorated method, `RunReturnT` its return type (a mapping, or a coroutine that resolves to a mapping).
RunParamsT = ParamSpec("RunParamsT")
RunReturnT = TypeVar("RunReturnT", bound=Mapping[str, Any] | Coroutine[Any, Any, Mapping[str, Any]])
 96  
 97  
 98  @dataclass
 99  class PreInitHookPayload:
100      """
101      Payload for the hook called before a component instance is initialized.
102  
103      :param callback:
104          Receives the following inputs: component class and init parameter keyword args.
105      :param in_progress:
106          Flag to indicate if the hook is currently being executed.
107          Used to prevent it from being called recursively (if the component's constructor
108          instantiates another component).
109      """
110  
111      callback: Callable
112      in_progress: bool = False
113  
114  
# Holds the pre-init hook payload for the current context; `None` (the default) means that no hook is set.
_COMPONENT_PRE_INIT_HOOK: ContextVar[PreInitHookPayload | None] = ContextVar("component_pre_init_hook", default=None)
116  
117  
@contextmanager
def _hook_component_init(callback: Callable) -> Iterator[None]:
    """
    Install a callback that runs before the constructor of any component created inside the `with` block.

    The callback is handed the component class and the init parameters as a dictionary of keyword arguments;
    it may change that dictionary in place. The previous hook state is restored when the block exits, even if
    an exception was raised.

    :param callback:
        Callback function to invoke.
    """
    payload = PreInitHookPayload(callback)
    reset_token = _COMPONENT_PRE_INIT_HOOK.set(payload)
    try:
        yield
    finally:
        # Restore whatever value the context variable had before entering the block.
        _COMPONENT_PRE_INIT_HOOK.reset(reset_token)
134  
135  
@runtime_checkable
class Component(Protocol):
    """
    Structural type of a Haystack component, mainly meant for type checking tools.

    A custom component satisfies the `Component` protocol as soon as it exposes a `run` method.
    Neither the parameters nor the return value of that method are verified, so both of the
    following are accepted:

        def run(self, param: str) -> dict[str, Any]:
            ...

        def run(self, **kwargs):
            ...

    This keeps the type checking rather weak, but other parts of the code make sure that they
    are dealing with actual Components.

    Since the protocol is runtime checkable, the following assertion is possible:

        isinstance(MyComponent(), Component)
    """

    # `run` is declared below as a method taking `*args: Any, **kwargs: Any`, which makes it compatible with any
    # input signature: its type is equivalent to `Callable[..., Mapping[str, Any]]`.
    # See https://typing.python.org/en/latest/spec/callables.html#meaning-of-in-callable.
    #
    # Declaring it as an attribute instead (`run: Callable[..., dict[str, Any]]`) leads to type errors, because the
    # protocol would then expect a settable attribute `run` while components implement it as a read-only method.
    # Given this code:
    #
    # from haystack import Pipeline, component
    # @component
    # class MyComponent:
    #     @component.output_types(out=str)
    #     def run(self):
    #         return {"out": "Hello, world!"}
    # pipeline = Pipeline()
    # pipeline.add_component("my_component", MyComponent())
    #
    # mypy reports:
    # error: Argument 2 to "add_component" of "PipelineBase" has incompatible type "MyComponent"; expected "Component"
    # [arg-type]
    # note: Protocol member Component.run expected settable variable, got read-only attribute

    def run(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:  # noqa: D102
        ...
185  
186  
class ComponentMeta(type):
    """
    Metaclass of the classes decorated with `@component`.

    On instantiation it runs the pre-init hook (if one is set) and then attaches the input and output sockets
    (`__haystack_input__` / `__haystack_output__`) to the new instance.
    """

    @staticmethod
    def _positional_to_kwargs(cls_type: type, args: tuple[Any, ...]) -> dict[str, Any]:
        """
        Convert positional arguments to keyword arguments based on the signature of the `__init__` method.

        :param cls_type: The class whose `__init__` signature is inspected.
        :param args: The positional arguments received by the constructor call.
        :returns: A dictionary mapping the name of each matched `__init__` parameter to the given value.
        :raises ComponentError: If a positional argument maps to a variadic positional (`*args`) parameter.
        :raises TypeError: If more positional arguments are given than `__init__` accepts positionally.
        """
        init_signature = inspect.signature(cls_type.__init__)  # type:ignore[misc]
        init_params = [(name, info) for name, info in init_signature.parameters.items() if name != "self"]
        positional_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY)

        out: dict[str, Any] = {}
        for index, arg in enumerate(args):
            if index < len(init_params):
                name, info = init_params[index]
                if info.kind == inspect.Parameter.VAR_POSITIONAL:
                    raise ComponentError(
                        "Pre-init hooks do not support components with variadic positional args in their init method"
                    )
                if info.kind in positional_kinds:
                    out[name] = arg
                    continue

            # Surplus positional arguments must not be dropped silently or bound to keyword-only parameters:
            # fail like a plain constructor call would. An explicit raise also survives `python -O`.
            accepted = sum(1 for _, info in init_params if info.kind in positional_kinds)
            raise TypeError(
                f"{cls_type.__name__}.__init__() takes at most {accepted} positional argument(s) "
                f"but {len(args)} were given"
            )
        return out

    @staticmethod
    def _parse_and_set_output_sockets(instance: Any) -> None:
        """
        Populate `__haystack_output__` on the instance, unless the constructor already did it.

        :param instance: The component instance being created.
        :raises ComponentError: If `run` and `run_async` carry different output type specifications.
        """
        has_async_run = hasattr(instance, "run_async")

        # If `component.set_output_types()` was called in the component constructor,
        # `__haystack_output__` is already populated, no need to do anything.
        if not hasattr(instance, "__haystack_output__"):
            # If that's not the case, we need to populate `__haystack_output__`
            #
            # If either of the run methods were decorated, they'll have a field assigned that
            # stores the output specification. If both run methods were decorated, we ensure that
            # outputs are the same. We deepcopy the content of the cache to transfer ownership from
            # the class method to the actual instance, so that different instances of the same class
            # won't share this data.

            run_output_types = getattr(instance.run, "_output_types_cache", {})
            async_run_output_types = getattr(instance.run_async, "_output_types_cache", {}) if has_async_run else {}

            if has_async_run and run_output_types != async_run_output_types:
                raise ComponentError("Output type specifications of 'run' and 'run_async' methods must be the same")
            output_types_cache = run_output_types

            instance.__haystack_output__ = Sockets(instance, deepcopy(output_types_cache), OutputSocket)

    @staticmethod
    def _parse_and_set_input_sockets(component_cls: type, instance: Any) -> None:
        """
        Create the input sockets of the instance from the signature of the `run` method.

        Sockets already declared in the constructor are kept. If the class also defines `run_async`, its
        signature must be the same as the one of `run`.

        :param component_cls: The component class, used to look up `run` and `run_async`.
        :param instance: The component instance being created.
        :raises ComponentError: If a `run` parameter clashes with a different socket of the same name that
            was declared in the constructor, or if `run` and `run_async` have different parameters.
        """

        def inner(method: Callable[..., Any], sockets: Sockets) -> inspect.Signature:
            run_signature = inspect.signature(method)
            try:
                # TypeError is raised if the argument is not of a type that can contain annotations
                run_hints = typing.get_type_hints(method)
            except TypeError:
                run_hints = None

            for param_name, param_info in run_signature.parameters.items():
                if param_name == "self" or param_info.kind in (
                    inspect.Parameter.VAR_POSITIONAL,
                    inspect.Parameter.VAR_KEYWORD,
                ):
                    continue

                # We prefer the type annotation from inspect.signature, but if it's a string we need to resolve it
                # using the hints. The type annotation can be a string if the component is using postponed evaluation
                # of annotations.
                annotation = param_info.annotation
                if isinstance(annotation, str) and run_hints is not None:
                    annotation = run_hints.get(param_name, annotation)

                socket_kwargs = {"name": param_name, "type": annotation}
                # Identity check: `!=` would call the `__ne__` of the default value, which may not return a
                # plain bool (array-like defaults, for example).
                if param_info.default is not inspect.Parameter.empty:
                    socket_kwargs["default_value"] = param_info.default

                new_socket = InputSocket(**socket_kwargs)

                # Also ensure that new sockets don't override existing ones.
                existing_socket = sockets.get(param_name)
                if existing_socket is not None and existing_socket != new_socket:
                    raise ComponentError(
                        "set_input_types()/set_input_type() cannot override the parameters of the 'run' method"
                    )

                sockets[param_name] = new_socket

            return run_signature

        # Create the sockets if set_input_types() wasn't called in the constructor.
        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)

        inner(getattr(component_cls, "run"), instance.__haystack_input__)  # noqa: B009

        # Ensure that the sockets are the same for the async method, if it exists.
        async_run = getattr(component_cls, "run_async", None)
        if async_run is not None:
            run_sockets = Sockets(instance, {}, InputSocket)
            async_run_sockets = Sockets(instance, {}, InputSocket)

            # Can't use the sockets from above as they might contain
            # values set with set_input_types().
            run_sig = inner(getattr(component_cls, "run"), run_sockets)  # noqa: B009
            async_run_sig = inner(async_run, async_run_sockets)

            if async_run_sockets != run_sockets or run_sig != async_run_sig:
                sig_diff = _compare_run_methods_signatures(run_sig, async_run_sig)
                raise ComponentError(
                    f"Parameters of 'run' and 'run_async' methods must be the same.\nDifferences found:\n{sig_diff}"
                )

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        """
        This method is called when clients instantiate a Component and runs before __new__ and __init__.

        :raises TypeError: If a pre-init hook is active and an init parameter is given both positionally and
            by keyword, or too many positional arguments are given.
        :raises ComponentError: If `run_async` is not a coroutine function or the sockets cannot be created.
        """
        pre_init_hook = _COMPONENT_PRE_INIT_HOOK.get()
        if pre_init_hook is None or pre_init_hook.in_progress:
            # This will call __new__ then __init__, giving us back the Component instance
            instance = super().__call__(*args, **kwargs)
        else:
            try:
                pre_init_hook.in_progress = True
                named_positional_args = ComponentMeta._positional_to_kwargs(cls, args)
                overlap = named_positional_args.keys() & kwargs.keys()
                if overlap:
                    # Explicit raise instead of `assert`: the check must also hold under `python -O`, otherwise
                    # the positional value would silently replace the keyword one.
                    raise TypeError(
                        f"{cls.__name__}.__init__() got multiple values for argument(s): {sorted(overlap)}"
                    )
                kwargs.update(named_positional_args)
                # The callback may modify `kwargs` in place before they reach the constructor.
                pre_init_hook.callback(cls, kwargs)
                instance = super().__call__(**kwargs)
            finally:
                pre_init_hook.in_progress = False

        # Before returning, we have the chance to modify the newly created
        # Component instance, so we take the chance and set up the I/O sockets
        has_async_run = hasattr(instance, "run_async")
        if has_async_run and not inspect.iscoroutinefunction(instance.run_async):
            raise ComponentError(f"Method 'run_async' of component '{cls.__name__}' must be a coroutine")
        instance.__haystack_supports_async__ = has_async_run

        ComponentMeta._parse_and_set_input_sockets(cls, instance)
        ComponentMeta._parse_and_set_output_sockets(instance)

        # Since a Component can't be used in multiple Pipelines at the same time
        # we need to know if it's already owned by a Pipeline when adding it to one.
        # We use this flag to check that.
        instance.__haystack_added_to_pipeline__ = None

        return instance
331  
332  
def _component_repr(component: Component) -> str:
    """
    Default `__repr__` installed on every component class.

    The output starts with the standard object representation, followed by the name of the component inside
    its Pipeline (only when it has been added to one) and then by its input and output sockets.
    """
    header = object.__repr__(component)
    owner = getattr(component, "__haystack_added_to_pipeline__", None)
    if owner:
        # This Component has been added in a Pipeline, let's get the name from there.
        header = f"{header}\n{owner.get_component_name(component)}"

    # The sockets are expected to be set at this point; the placeholders only show up if they are not.
    inputs = getattr(component, "__haystack_input__", "<invalid_input_sockets>")
    outputs = getattr(component, "__haystack_output__", "<invalid_output_sockets>")
    return f"{header}\n{inputs}\n{outputs}"
350  
351  
def _component_run_has_kwargs(component_cls: type) -> bool:
    """
    Tell whether the `run` method of the given class accepts variadic keyword arguments (`**kwargs`).

    :param component_cls: The class to inspect.
    :returns: `True` if `run` has a `**kwargs` parameter, `False` otherwise or if the class has no `run`.
    """
    run_method = getattr(component_cls, "run", None)
    if run_method is not None:
        for param in inspect.signature(run_method).parameters.values():
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                return True
    return False
359  
360  
def _compare_run_methods_signatures(run_sig: inspect.Signature, async_run_sig: inspect.Signature) -> str:
    """
    Describe how the signatures of the `run` and `run_async` methods differ.

    Parameters are compared pairwise, in declaration order, by name, annotation, default value and kind.

    :param run_sig: The signature of the run method
    :param async_run_sig: The signature of the run_async method

    :returns:
        One line per difference found, joined by newlines; an empty string if the signatures match
    """
    sync_items = list(run_sig.parameters.items())
    async_items = list(async_run_sig.parameters.items())
    report: list[str] = []

    sync_count, async_count = len(sync_items), len(async_items)
    if sync_count != async_count:
        report.append(f"Different number of parameters: run has {sync_count}, run_async has {async_count}")

    # Only the parameters present in both signatures are compared; surplus ones are covered by the count check.
    for sync_item, async_item in zip(sync_items, async_items):
        run_name, run_param = sync_item
        async_name, async_param = async_item

        if run_name != async_name:
            report.append(f"Parameter name mismatch: {run_name} vs {async_name}")

        if run_param.annotation != async_param.annotation:
            report.append(f"Parameter '{run_name}' type mismatch: {run_param.annotation} vs {async_param.annotation}")

        if run_param.default != async_param.default:
            report.append(
                f"Parameter '{run_name}' default value mismatch: {run_param.default} vs {async_param.default}"
            )

        if run_param.kind != async_param.kind:
            report.append(
                f"Parameter '{run_name}' kind (POSITIONAL, KEYWORD, etc.) mismatch: "
                f"{run_param.kind} vs {async_param.kind}"
            )

    return "\n".join(report)
401  
402  
# Type variable used by the `@component` decorator, which returns the same class type it receives.
T = TypeVar("T", bound=Component)
404  
405  
class _Component:
    """
    Implementation of the `@component` decorator and of its helper methods. See module's docstring.

    Calling an instance with a class recreates that class with the `ComponentMeta` metaclass and stores it in
    `registry`, keyed by its `module.ClassName` path.

    Args:
        cls: the class that should be used as a component.

    Returns:
        A class that can be recognized as a component.

    Raises:
        ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
    """

    def __init__(self) -> None:
        # Maps "module.ClassName" to the registered component class (used for deserialization).
        self.registry: dict[str, type] = {}

    def set_input_type(
        self,
        instance: Component,
        name: str,
        type: Any,  # noqa: A002
        default: Any = _empty,
    ) -> None:
        """
        Add a single input socket to the component instance.

        Replaces any existing input socket with the same name.

        :param instance: Component instance where the input type will be added.
        :param name: name of the input socket.
        :param type: type of the input socket.
        :param default: default value of the input socket, defaults to _empty
        :raises ComponentError: if the `run` method of the component has no `**kwargs` parameter.
        """
        if not _component_run_has_kwargs(instance.__class__):
            raise ComponentError(
                "Cannot set input types on a component that doesn't have a kwargs parameter in the 'run' method"
            )

        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)  # type: ignore
        instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)  # type: ignore

    def set_input_types(self, instance: Any, **types: type[Any]) -> None:
        """
        Method that specifies the input types when 'kwargs' is passed to the run method.

        Any input sockets previously set on the instance are replaced by the ones given here.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int) -> None:
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, **kwargs):
                return {"output_1": kwargs["value_1"], "output_2": ""}
        ```

        The explicit parameters of the `run()` method are added to the inputs declared here.

        For example:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int) -> None:
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, value_0: str, **kwargs):
                return {"output_1": kwargs["value_1"], "output_2": ""}
        ```

        would add a mandatory `value_0` parameter and keep the `value_1` and `value_2`
        parameters mandatory as specified in `set_input_types`.

        Do not use a name that is also an explicit `run()` parameter: when the component is
        instantiated, a `ComponentError` is raised if the socket derived from the `run()`
        signature differs from an already declared socket with the same name.

        :param instance: Component instance where the input types will be set.
        :param types: input names mapped to their types.
        :raises ComponentError: if the `run` method of the component has no `**kwargs` parameter.
        """
        if not _component_run_has_kwargs(instance.__class__):
            raise ComponentError(
                "Cannot set input types on a component that doesn't have a kwargs parameter in the 'run' method"
            )

        instance.__haystack_input__ = Sockets(
            instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket
        )

    def set_output_types(self, instance: Any, **types: type[Any]) -> None:
        """
        Method that specifies the output types when the 'run' method is not decorated with 'component.output_types'.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int) -> None:
                component.set_output_types(self, output_1=int, output_2=str)
                ...

            # no decorators here
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}

            # also no decorators here
            async def run_async(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```

        :param instance: Component instance where the output types will be set.
        :param types: output names mapped to their types.
        :raises ComponentError: if `run` or `run_async` is already decorated with `output_types`.
        """
        has_run_decorator = hasattr(instance.run, "_output_types_cache")
        has_run_async_decorator = hasattr(instance, "run_async") and hasattr(instance.run_async, "_output_types_cache")
        if has_run_decorator or has_run_async_decorator:
            raise ComponentError(
                "Cannot call `set_output_types` on a component that already has the 'output_types' decorator on its "
                "`run` or `run_async` methods."
            )

        instance.__haystack_output__ = Sockets(
            instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
        )

    def output_types(
        self, **types: Any
    ) -> Callable[[Callable[RunParamsT, RunReturnT]], Callable[RunParamsT, RunReturnT]]:
        """
        Decorator factory that specifies the output types of a component.

        Use as:
        ```python
        @component
        class MyComponent:
            @component.output_types(output_1=int, output_2=str)
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```

        :param types: output names mapped to their types.
        :returns: a decorator for the `run` or `run_async` method.
        """

        def output_types_decorator(run_method: Callable[RunParamsT, RunReturnT]) -> Callable[RunParamsT, RunReturnT]:
            """
            Decorator that sets the output types of the decorated method.

            This happens at class creation time, and since we don't have the decorated
            class available here, we temporarily store the output types as an attribute of
            the decorated method. The ComponentMeta metaclass will use this data to create
            sockets at instance creation time.
            """
            method_name = run_method.__name__
            if method_name not in ("run", "run_async"):
                raise ComponentError("'output_types' decorator can only be used on 'run' and 'run_async' methods")

            setattr(  # noqa: B010
                run_method,
                "_output_types_cache",
                {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()},
            )
            return run_method

        return output_types_decorator

    def _component(self, cls: type[T]) -> type[T]:
        """
        Decorator validating the structure of the component and registering it in the components registry.

        :param cls: the class to turn into a component.
        :returns: a new class with the same name, bases and namespace, using `ComponentMeta` as metaclass.
        :raises ComponentError: if the class has no `run` attribute.
        """
        logger.debug("Registering {component} as a component", component=cls)

        # Check for required methods and fail as soon as possible
        if not hasattr(cls, "run"):
            raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")

        def copy_class_namespace(namespace: dict[str, Any]) -> None:
            """
            This is the callback that `typing.new_class` will use to populate the newly created class.

            Simply copy the whole namespace from the decorated class.
            """
            for key, val in dict(cls.__dict__).items():
                # __dict__ and __weakref__ are class-bound, we should let Python recreate them.
                if key in ("__dict__", "__weakref__"):
                    continue
                namespace[key] = val

        # Recreate the decorated component class so it uses our metaclass.
        # We must explicitly redefine the type of the class to make sure language servers
        # and type checkers understand that the class is of the correct type.
        new_cls: type[T] = new_class(cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace)

        # Save the component in the class registry (for deserialization)
        class_path = f"{new_cls.__module__}.{new_cls.__name__}"
        if class_path in self.registry:
            # Corner case, but it may occur easily in notebooks when re-running cells.
            # The message is built with implicit string concatenation: a backslash continuation inside the
            # literal would embed the indentation of the following source line in the logged text.
            logger.debug(
                "Component {component} is already registered. Previous imported from '{module_name}', "
                "new imported from '{new_module_name}'",
                component=class_path,
                module_name=self.registry[class_path],
                new_module_name=new_cls,
            )
        self.registry[class_path] = new_cls
        logger.debug("Registered Component {component}", component=new_cls)

        # Override the __repr__ method with a default one
        # mypy is not happy that:
        # 1) we are assigning a method to a class
        # 2) _component_repr has a different type (Callable[[Component], str]) than the expected
        # __repr__ method (Callable[[object], str])
        new_cls.__repr__ = _component_repr  # type: ignore[assignment]

        return new_cls

    # Call signature when the decorator is used without parens (@component).
    @overload
    def __call__(self, cls: type[T]) -> type[T]: ...

    # Overload allowing the decorator to be used with parens (@component()).
    @overload
    def __call__(self) -> Callable[[type[T]], type[T]]: ...

    def __call__(self, cls: type[T] | None = None) -> type[T] | Callable[[type[T]], type[T]]:
        """
        Decorate a class as a component, with or without parentheses (`@component` or `@component()`).

        :param cls: the class to decorate, or `None` when the decorator is used with parentheses.
        :returns: the component class, or a decorator producing it when called without a class.
        """

        # We must wrap the call to the decorator in a function for it to work
        # correctly with or without parens
        def wrap(cls: type[T]) -> type[T]:
            return self._component(cls)

        # Explicit `None` check: the truthiness of a class can be customized by its metaclass.
        if cls is not None:
            # Decorator is called without parens
            return wrap(cls)

        # Decorator is called with parens
        return wrap
642  
643  
644  component = _Component()