component.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

"""
Attributes:

    component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.

All components must follow the contract below. This docstring is the source of truth for components contract.

<hr>

`@component` decorator

All component classes must be decorated with the `@component` decorator. This allows Haystack to discover them.

<hr>

`__init__(self, **kwargs)`

Optional method.

Components may have an `__init__` method where they define:

- `self.init_parameters = {same parameters that the __init__ method received}`:
    In this dictionary you can store any state the components wish to be persisted when they are saved.
    These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
    Note that by default the `@component` decorator saves the arguments automatically.
    However, if a component sets their own `init_parameters` manually in `__init__()`, that will be used instead.
    Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.

Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
dictionaries containing only such values. Anything else (objects, functions, etc) will raise an exception at init
time. If there's the need for such values, consider serializing them to a string.

If you need to accept classes or callables, accept either a string import path or the callable itself. Resolve strings
to objects in `__init__`, and serialize objects back to importable strings in `to_dict()` so that `from_dict()` can load
them (for example, store `"module_path.symbol_name"` and load it via `importlib`). This keeps init parameters JSON
serializable for pipeline save/load. See `haystack.testing.sample_components.accumulate.Accumulate` for a reference
implementation.

The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc...) refer to
the `warm_up()` method.

<hr>

`warm_up(self)`

Optional method.

This method is called by Pipeline before the graph execution. Make sure to avoid double-initializations,
because Pipeline will not keep track of which components it called `warm_up()` on.

<hr>

`run(self, ...)`

Mandatory method.

This is the method where the main functionality of the component should be carried out. It's called by
`Pipeline.run()`.

The parameters of `run()` define the component's input sockets: every parameter (except `self`, `*args` and
`**kwargs`) becomes an input socket, and a parameter's default value becomes the default value of its socket.
If `run()` accepts `**kwargs`, additional input sockets can be declared in `__init__` with
`component.set_input_types()` or `component.set_input_type()`.

`run()` must return a dictionary whose keys are the names of the component's output sockets. Output sockets are
declared either by decorating `run()` with `@component.output_types(...)` or by calling
`component.set_output_types()` in `__init__`.

Components may also define an `async def run_async(self, ...)` coroutine. If present, it must accept exactly the same
parameters as `run()` and declare the same output types.

"""

import inspect
import typing
from collections.abc import Callable, Coroutine, Iterator, Mapping
from contextlib import contextmanager
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass
from types import new_class
from typing import Any, ParamSpec, Protocol, TypeVar, overload, runtime_checkable

from haystack import logging
from haystack.core.errors import ComponentError

from .sockets import Sockets
from .types import InputSocket, OutputSocket, _empty

logger = logging.getLogger(__name__)

RunParamsT = ParamSpec("RunParamsT")
RunReturnT = TypeVar("RunReturnT", bound=Mapping[str, Any] | Coroutine[Any, Any, Mapping[str, Any]])


@dataclass
class PreInitHookPayload:
    """
    Payload for the hook called before a component instance is initialized.

    :param callback:
        Receives the following inputs: component class and init parameter keyword args.
    :param in_progress:
        Flag to indicate if the hook is currently being executed.
        Used to prevent it from being called recursively (if the component's constructor
        instantiates another component).
    """

    callback: Callable
    in_progress: bool = False


# Holds the active pre-init hook (if any). A ContextVar keeps the hook scoped to the current context,
# so it is set and reset by `_hook_component_init` without leaking into other contexts.
_COMPONENT_PRE_INIT_HOOK: ContextVar[PreInitHookPayload | None] = ContextVar("component_pre_init_hook", default=None)


@contextmanager
def _hook_component_init(callback: Callable) -> Iterator[None]:
    """
    Context manager to set a callback that will be invoked before a component's constructor is called.

    The callback receives the component class and the init parameters (as keyword arguments) and can modify the init
    parameters in place.

    :param callback:
        Callback function to invoke.
    """
    token = _COMPONENT_PRE_INIT_HOOK.set(PreInitHookPayload(callback))
    try:
        yield
    finally:
        # Restore whatever hook (or None) was active before entering the context.
        _COMPONENT_PRE_INIT_HOOK.reset(token)


@runtime_checkable
class Component(Protocol):
    """
    Note this is only used by type checking tools.

    In order to implement the `Component` protocol, custom components need to
    have a `run` method. The signature of the method and its return value
    won't be checked, i.e. classes with the following methods:

        def run(self, param: str) -> dict[str, Any]:
            ...

    and

        def run(self, **kwargs):
            ...

    will be both considered as respecting the protocol. This makes the type
    checking much weaker, but we have other places where we ensure code is
    dealing with actual Components.

    The protocol is runtime checkable so it'll be possible to assert:

        isinstance(MyComponent(), Component)
    """

    # The following expression defines a run method compatible with any input signature.
    # Its type is equivalent to Callable[..., dict[str, Any]].
    # See https://typing.python.org/en/latest/spec/callables.html#meaning-of-in-callable.
    #
    # Using `run: Callable[..., dict[str, Any]]` directly leads to type errors: the protocol would expect a settable
    # attribute `run`, while the actual implementation is a read-only method.
    # For example:
    # from haystack import Pipeline, component
    # @component
    # class MyComponent:
    #     @component.output_types(out=str)
    #     def run(self):
    #         return {"out": "Hello, world!"}
    # pipeline = Pipeline()
    # pipeline.add_component("my_component", MyComponent())
    #
    # mypy raises:
    # error: Argument 2 to "add_component" of "PipelineBase" has incompatible type "MyComponent"; expected "Component"
    # [arg-type]
    # note: Protocol member Component.run expected settable variable, got read-only attribute

    def run(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:  # noqa: D102
        ...


class ComponentMeta(type):
    """
    Metaclass applied to every class decorated with `@component`.

    It intercepts instantiation to optionally run the pre-init hook and, once the instance exists, to build its
    input/output sockets and set the bookkeeping attributes used by pipelines.
    """

    @staticmethod
    def _positional_to_kwargs(cls_type: type, args: tuple[Any, ...]) -> dict[str, Any]:
        """
        Convert positional arguments to keyword arguments based on the signature of the `__init__` method.

        :param cls_type: The component class being instantiated.
        :param args: Positional arguments passed to the constructor.
        :returns: A mapping from `__init__` parameter name to the positional value bound to it.
        :raises ComponentError: If a positional argument would be bound to a variadic `*args` parameter.
        :raises TypeError: If more positional arguments are given than `__init__` can accept.
        """
        init_signature = inspect.signature(cls_type.__init__)  # type:ignore[misc]
        init_params = {name: info for name, info in init_signature.parameters.items() if name != "self"}

        out = {}
        for arg, (name, info) in zip(args, init_params.items(), strict=False):
            if info.kind == inspect.Parameter.VAR_POSITIONAL:
                raise ComponentError(
                    "Pre-init hooks do not support components with variadic positional args in their init method"
                )

            assert info.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY)
            out[name] = arg

        # `zip` stops at the shorter iterable: without this check, surplus positional arguments would be
        # silently dropped instead of failing like a regular call to `__init__` would.
        if len(args) > len(out):
            raise TypeError(
                f"{cls_type.__name__}.__init__() accepts at most {len(out)} positional arguments "
                f"but {len(args)} were given"
            )
        return out

    @staticmethod
    def _parse_and_set_output_sockets(instance: Any) -> None:
        """
        Populate `instance.__haystack_output__` from the `output_types` decorator, unless already set.

        :raises ComponentError: If `run` and `run_async` declare different output types.
        """
        has_async_run = hasattr(instance, "run_async")

        # If `component.set_output_types()` was called in the component constructor,
        # `__haystack_output__` is already populated, no need to do anything.
        if not hasattr(instance, "__haystack_output__"):
            # If that's not the case, we need to populate `__haystack_output__`
            #
            # If either of the run methods were decorated, they'll have a field assigned that
            # stores the output specification. If both run methods were decorated, we ensure that
            # outputs are the same. We deepcopy the content of the cache to transfer ownership from
            # the class method to the actual instance, so that different instances of the same class
            # won't share this data.

            run_output_types = getattr(instance.run, "_output_types_cache", {})
            async_run_output_types = getattr(instance.run_async, "_output_types_cache", {}) if has_async_run else {}

            if has_async_run and run_output_types != async_run_output_types:
                raise ComponentError("Output type specifications of 'run' and 'run_async' methods must be the same")
            output_types_cache = run_output_types

            instance.__haystack_output__ = Sockets(instance, deepcopy(output_types_cache), OutputSocket)

    @staticmethod
    def _parse_and_set_input_sockets(component_cls: type, instance: Any) -> None:
        """
        Build input sockets from the `run` signature and check that `run_async` (if any) matches it.

        :raises ComponentError: If a `run` parameter conflicts with a socket set via `set_input_types()`, or if
            `run` and `run_async` have different parameters.
        """

        def inner(method: Callable[..., Any], sockets: Sockets) -> inspect.Signature:
            from inspect import Parameter

            run_signature = inspect.signature(method)
            try:
                # TypeError is raised if the argument is not of a type that can contain annotations
                run_hints = typing.get_type_hints(method)
            except TypeError:
                run_hints = None

            for param_name, param_info in run_signature.parameters.items():
                if param_name == "self" or param_info.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                    continue

                # We prefer the type annotation from inspect.signature, but if it's a string we need to resolve it
                # using the hints. The type annotation can be a string if the component is using postponed evaluation
                # of annotations.
                annotation = param_info.annotation
                if isinstance(annotation, str) and run_hints is not None:
                    annotation = run_hints.get(param_name, annotation)

                socket_kwargs = {"name": param_name, "type": annotation}
                # Identity check against the sentinel: `!=` would invoke the default's own comparison,
                # which fails for objects with element-wise/ambiguous equality (e.g. arrays).
                if param_info.default is not Parameter.empty:
                    socket_kwargs["default_value"] = param_info.default

                new_socket = InputSocket(**socket_kwargs)

                # Also ensure that new sockets don't override existing ones.
                existing_socket = sockets.get(param_name)
                if existing_socket is not None and existing_socket != new_socket:
                    raise ComponentError(
                        "set_input_types()/set_input_type() cannot override the parameters of the 'run' method"
                    )

                sockets[param_name] = new_socket

            return run_signature

        # Create the sockets if set_input_types() wasn't called in the constructor.
        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)

        inner(getattr(component_cls, "run"), instance.__haystack_input__)  # noqa: B009

        # Ensure that the sockets are the same for the async method, if it exists.
        async_run = getattr(component_cls, "run_async", None)
        if async_run is not None:
            run_sockets = Sockets(instance, {}, InputSocket)
            async_run_sockets = Sockets(instance, {}, InputSocket)

            # Can't use the sockets from above as they might contain
            # values set with set_input_types().
            run_sig = inner(getattr(component_cls, "run"), run_sockets)  # noqa: B009
            async_run_sig = inner(async_run, async_run_sockets)

            if async_run_sockets != run_sockets or run_sig != async_run_sig:
                sig_diff = _compare_run_methods_signatures(run_sig, async_run_sig)
                raise ComponentError(
                    f"Parameters of 'run' and 'run_async' methods must be the same.\nDifferences found:\n{sig_diff}"
                )

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        """
        This method is called when clients instantiate a Component and runs before __new__ and __init__.
        """
        # This will call __new__ then __init__, giving us back the Component instance
        pre_init_hook = _COMPONENT_PRE_INIT_HOOK.get()
        if pre_init_hook is None or pre_init_hook.in_progress:
            instance = super().__call__(*args, **kwargs)
        else:
            try:
                pre_init_hook.in_progress = True
                # The hook only sees keyword arguments, so bind positional ones to their parameter names first.
                named_positional_args = ComponentMeta._positional_to_kwargs(cls, args)
                assert set(named_positional_args.keys()).intersection(kwargs.keys()) == set(), (
                    "positional and keyword arguments overlap"
                )
                kwargs.update(named_positional_args)
                pre_init_hook.callback(cls, kwargs)
                instance = super().__call__(**kwargs)
            finally:
                pre_init_hook.in_progress = False

        # Before returning, we have the chance to modify the newly created
        # Component instance, so we take the chance and set up the I/O sockets
        has_async_run = hasattr(instance, "run_async")
        if has_async_run and not inspect.iscoroutinefunction(instance.run_async):
            raise ComponentError(f"Method 'run_async' of component '{cls.__name__}' must be a coroutine")
        instance.__haystack_supports_async__ = has_async_run

        ComponentMeta._parse_and_set_input_sockets(cls, instance)
        ComponentMeta._parse_and_set_output_sockets(instance)

        # Since a Component can't be used in multiple Pipelines at the same time
        # we need to know if it's already owned by a Pipeline when adding it to one.
        # We use this flag to check that.
        instance.__haystack_added_to_pipeline__ = None

        return instance


def _component_repr(component: Component) -> str:
    """
    All Components override their __repr__ method with this one.

    It prints the component name and the input/output sockets.
    """
    result = object.__repr__(component)
    if pipeline := getattr(component, "__haystack_added_to_pipeline__", None):
        # This Component has been added in a Pipeline, let's get the name from there.
        result += f"\n{pipeline.get_component_name(component)}"

    # The socket attributes are normally set by ComponentMeta at instantiation time; fall back to
    # placeholder text so that repr() never raises on a partially-initialized object.
    return (
        f"{result}\n{getattr(component, '__haystack_input__', '<invalid_input_sockets>')}"
        f"\n{getattr(component, '__haystack_output__', '<invalid_output_sockets>')}"
    )


def _component_run_has_kwargs(component_cls: type) -> bool:
    """
    Return True if the class's `run` method accepts a `**kwargs` parameter, False otherwise (or if it has no `run`).
    """
    run_method = getattr(component_cls, "run", None)
    if run_method is None:
        return False
    return any(
        param.kind == inspect.Parameter.VAR_KEYWORD for param in inspect.signature(run_method).parameters.values()
    )


def _compare_run_methods_signatures(run_sig: inspect.Signature, async_run_sig: inspect.Signature) -> str:
    """
    Builds a detailed error message with the differences between the signatures of the run and run_async methods.

    :param run_sig: The signature of the run method
    :param async_run_sig: The signature of the run_async method

    :returns:
        A detailed error message if signatures don't match, empty string if they do
    """
    differences = []
    run_params = list(run_sig.parameters.items())
    async_params = list(async_run_sig.parameters.items())

    if len(run_params) != len(async_params):
        differences.append(
            f"Different number of parameters: run has {len(run_params)}, run_async has {len(async_params)}"
        )

    # Parameters are compared positionally; surplus parameters are covered by the count message above.
    for (run_name, run_param), (async_name, async_param) in zip(run_params, async_params, strict=False):
        if run_name != async_name:
            differences.append(f"Parameter name mismatch: {run_name} vs {async_name}")

        if run_param.annotation != async_param.annotation:
            differences.append(
                f"Parameter '{run_name}' type mismatch: {run_param.annotation} vs {async_param.annotation}"
            )

        if run_param.default != async_param.default:
            differences.append(
                f"Parameter '{run_name}' default value mismatch: {run_param.default} vs {async_param.default}"
            )

        if run_param.kind != async_param.kind:
            differences.append(
                f"Parameter '{run_name}' kind (POSITIONAL, KEYWORD, etc.) mismatch: "
                f"{run_param.kind} vs {async_param.kind}"
            )

    return "\n".join(differences)


T = TypeVar("T", bound=Component)


class _Component:
    """
    See module's docstring.

    Args:
        cls: the class that should be used as a component.

    Returns:
        A class that can be recognized as a component.

    Raises:
        ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
    """

    def __init__(self) -> None:
        # Maps "module.ClassName" to the decorated component class (used for deserialization).
        self.registry: dict[str, type] = {}

    def set_input_type(
        self,
        instance: Component,
        name: str,
        type: Any,  # noqa: A002
        default: Any = _empty,
    ) -> None:
        """
        Add a single input socket to the component instance.

        Replaces any existing input socket with the same name.

        :param instance: Component instance where the input type will be added.
        :param name: name of the input socket.
        :param type: type of the input socket.
        :param default: default value of the input socket, defaults to _empty
        :raises ComponentError: If the component's `run` method has no `**kwargs` parameter.
        """
        if not _component_run_has_kwargs(instance.__class__):
            raise ComponentError(
                "Cannot set input types on a component that doesn't have a kwargs parameter in the 'run' method"
            )

        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)  # type: ignore
        instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)  # type: ignore

    def set_input_types(self, instance: Any, **types: type[Any]) -> None:
        """
        Method that specifies the input types when 'kwargs' is passed to the run method.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int) -> None:
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, **kwargs):
                return {"output_1": kwargs["value_1"], "output_2": ""}
        ```

        Note that if the `run()` method also specifies some parameters, those will take precedence.

        For example:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int) -> None:
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, value_0: str, value_1: str | None = None, **kwargs):
                return {"output_1": value_1, "output_2": ""}
        ```

        would add a mandatory `value_0` parameter, make the `value_1`
        parameter optional with a default None, and keep the `value_2`
        parameter mandatory as specified in `set_input_types`.

        Calling this method replaces any input sockets previously set on the instance.

        :raises ComponentError: If the component's `run` method has no `**kwargs` parameter.
        """
        if not _component_run_has_kwargs(instance.__class__):
            raise ComponentError(
                "Cannot set input types on a component that doesn't have a kwargs parameter in the 'run' method"
            )

        instance.__haystack_input__ = Sockets(
            instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket
        )

    def set_output_types(self, instance: Any, **types: type[Any]) -> None:
        """
        Method that specifies the output types when the 'run' method is not decorated with 'component.output_types'.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int) -> None:
                component.set_output_types(self, output_1=int, output_2=str)
                ...

            # no decorators here
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}

            # also no decorators here
            async def run_async(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```

        :raises ComponentError: If `run` or `run_async` is already decorated with `component.output_types`.
        """
        has_run_decorator = hasattr(instance.run, "_output_types_cache")
        has_run_async_decorator = hasattr(instance, "run_async") and hasattr(instance.run_async, "_output_types_cache")
        if has_run_decorator or has_run_async_decorator:
            raise ComponentError(
                "Cannot call `set_output_types` on a component that already has the 'output_types' decorator on its "
                "`run` or `run_async` methods."
            )

        instance.__haystack_output__ = Sockets(
            instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
        )

    def output_types(
        self, **types: Any
    ) -> Callable[[Callable[RunParamsT, RunReturnT]], Callable[RunParamsT, RunReturnT]]:
        """
        Decorator factory that specifies the output types of a component.

        Use as:
        ```python
        @component
        class MyComponent:
            @component.output_types(output_1=int, output_2=str)
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """

        def output_types_decorator(run_method: Callable[RunParamsT, RunReturnT]) -> Callable[RunParamsT, RunReturnT]:
            """
            Decorator that sets the output types of the decorated method.

            This happens at class creation time, and since we don't have the decorated
            class available here, we temporarily store the output types as an attribute of
            the decorated method. The ComponentMeta metaclass will use this data to create
            sockets at instance creation time.

            :raises ComponentError: If the decorated method is not named `run` or `run_async`.
            """
            method_name = run_method.__name__
            if method_name not in ("run", "run_async"):
                raise ComponentError("'output_types' decorator can only be used on 'run' and 'run_async' methods")

            setattr(  # noqa: B010
                run_method,
                "_output_types_cache",
                {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()},
            )
            return run_method

        return output_types_decorator

    def _component(self, cls: type[T]) -> type[T]:
        """
        Decorator validating the structure of the component and registering it in the components registry.

        :raises ComponentError: If the class has no `run` method.
        """
        logger.debug("Registering {component} as a component", component=cls)

        # Check for required methods and fail as soon as possible
        if not hasattr(cls, "run"):
            raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")

        def copy_class_namespace(namespace: dict[str, Any]) -> None:
            """
            This is the callback that `typing.new_class` will use to populate the newly created class.

            Simply copy the whole namespace from the decorated class.
            """
            for key, val in dict(cls.__dict__).items():
                # __dict__ and __weakref__ are class-bound, we should let Python recreate them.
                if key in ("__dict__", "__weakref__"):
                    continue
                namespace[key] = val

        # Recreate the decorated component class so it uses our metaclass.
        # We must explicitly redefine the type of the class to make sure language servers
        # and type checkers understand that the class is of the correct type.
        new_cls: type[T] = new_class(cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace)

        # Save the component in the class registry (for deserialization)
        class_path = f"{new_cls.__module__}.{new_cls.__name__}"
        if class_path in self.registry:
            # Corner case, but it may occur easily in notebooks when re-running cells.
            # Implicit string concatenation (not a backslash continuation inside the literal) keeps the
            # continuation line's indentation out of the log message.
            logger.debug(
                "Component {component} is already registered. Previous imported from '{module_name}', "
                "new imported from '{new_module_name}'",
                component=class_path,
                module_name=self.registry[class_path],
                new_module_name=new_cls,
            )
        self.registry[class_path] = new_cls
        logger.debug("Registered Component {component}", component=new_cls)

        # Override the __repr__ method with a default one
        # mypy is not happy that:
        # 1) we are assigning a method to a class
        # 2) _component_repr has a different type (Callable[[Component], str]) than the expected
        #    __repr__ method (Callable[[object], str])
        new_cls.__repr__ = _component_repr  # type: ignore[assignment]

        return new_cls

    # Call signature when the decorator is used without parens (@component).
    @overload
    def __call__(self, cls: type[T]) -> type[T]: ...

    # Overload allowing the decorator to be used with parens (@component()).
    @overload
    def __call__(self) -> Callable[[type[T]], type[T]]: ...

    def __call__(self, cls: type[T] | None = None) -> type[T] | Callable[[type[T]], type[T]]:
        # We must wrap the call to the decorator in a function for it to work
        # correctly with or without parens
        def wrap(cls: type[T]) -> type[T]:
            return self._component(cls)

        # Compare against None explicitly: a class's truthiness is controlled by its metaclass
        # and must not decide which calling convention is being used.
        if cls is not None:
            # Decorator is called without parens
            return wrap(cls)

        # Decorator is called with parens
        return wrap


component = _Component()