types.py
1 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> 2 # 3 # SPDX-License-Identifier: Apache-2.0 4 5 from collections.abc import Iterable 6 from dataclasses import dataclass, field 7 from types import UnionType 8 from typing import Annotated, Any, TypeAlias, TypedDict, TypeVar, get_args 9 10 HAYSTACK_VARIADIC_ANNOTATION = "__haystack__variadic_t" 11 HAYSTACK_GREEDY_VARIADIC_ANNOTATION = "__haystack__greedy_variadic_t" 12 13 # # Generic type variable used in the Variadic container 14 T = TypeVar("T") 15 16 17 # Variadic is a custom annotation type we use to mark input types. 18 # This type doesn't do anything else than "marking" the contained 19 # type so it can be used in the `InputSocket` creation where we 20 # check that its annotation equals to HAYSTACK_VARIADIC_ANNOTATION 21 Variadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_VARIADIC_ANNOTATION] 22 23 # GreedyVariadic type is similar to Variadic. 24 # The only difference is the way it's treated by the Pipeline when input is received 25 # in a socket with this type. 26 # Instead of waiting for other inputs to be received, Components that have a GreedyVariadic 27 # input will be run right after receiving the first input. 28 # Even if there are multiple connections to that socket. 29 GreedyVariadic: TypeAlias = Annotated[Iterable[T], HAYSTACK_GREEDY_VARIADIC_ANNOTATION] 30 31 32 class _empty: 33 """Custom object for marking InputSocket.default_value as not set.""" 34 35 36 @dataclass 37 class InputSocket: 38 """ 39 Represents an input of a `Component`. 40 41 :param name: 42 The name of the input. 43 :param type: 44 The type of the input. 45 :param default_value: 46 The default value of the input. If not set, the input is mandatory. 47 :param is_lazy_variadic: 48 Whether the input is a lazy variadic or not. 49 :param is_greedy: 50 Whether the input is a greedy variadic or not. 51 :param senders: 52 The list of components that send data to this input. 53 :param wrap_input_in_list: 54 Whether to wrap the input in a list before passing it to the component. 55 Only applies to lazy variadic inputs so when is_lazy_variadic is True. 56 """ 57 58 name: str 59 type: type | UnionType 60 default_value: Any = _empty 61 is_lazy_variadic: bool = field(init=False) 62 is_greedy: bool = field(init=False) 63 senders: list[str] = field(default_factory=list) 64 wrap_input_in_list: bool = True 65 66 @property 67 def is_variadic(self) -> bool: 68 """Check if the input is variadic.""" 69 return self.is_greedy or self.is_lazy_variadic 70 71 @property 72 def is_mandatory(self) -> bool: 73 """Check if the input is mandatory.""" 74 return self.default_value == _empty 75 76 def __post_init__(self) -> None: 77 try: 78 # __metadata__ is a tuple 79 self.is_lazy_variadic = ( 80 hasattr(self.type, "__metadata__") and self.type.__metadata__[0] == HAYSTACK_VARIADIC_ANNOTATION 81 ) 82 self.is_greedy = ( 83 hasattr(self.type, "__metadata__") and self.type.__metadata__[0] == HAYSTACK_GREEDY_VARIADIC_ANNOTATION 84 ) 85 except AttributeError: 86 self.is_lazy_variadic = False 87 self.is_greedy = False 88 89 # We need to "unpack" the type inside the Variadic annotation, otherwise the pipeline connection api will try 90 # to match `Annotated[type, HAYSTACK_VARIADIC_ANNOTATION]`. 91 # 92 # Note1: Variadic is expressed as an annotation of one single type, so the return value of get_args will 93 # always be a one-item tuple. 94 # 95 # Note2: a pipeline always passes a list of items when a component input is declared as Variadic, so the 96 # type itself always wraps an iterable of the declared type. For example, Variadic[int] is eventually an 97 # alias for Iterable[int]. Since we're interested in getting the inner type `int`, we call `get_args` 98 # twice: the first time to get `list[int]` out of `Variadic`, the second time to get `int` out of `list[int]`. 99 if self.is_lazy_variadic or self.is_greedy: 100 self.type = get_args(get_args(self.type)[0])[0] 101 102 103 class InputSocketTypeDescriptor(TypedDict): 104 """ 105 Describes the type of `InputSocket`. 106 """ 107 108 type: type | UnionType 109 is_mandatory: bool 110 111 112 @dataclass 113 class OutputSocket: 114 """ 115 Represents an output of a `Component`. 116 117 :param name: 118 The name of the output. 119 :param type: 120 The type of the output. 121 :param receivers: 122 The list of components that receive the output of this component. 123 """ 124 125 name: str 126 type: type 127 receivers: list[str] = field(default_factory=list)