sockets.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

from haystack.core.type_utils import _type_name

from .types import InputSocket, OutputSocket

SocketsDict = dict[str, InputSocket | OutputSocket]
SocketsIOType = type[InputSocket] | type[OutputSocket]


class Sockets:  # noqa: PLW1641
    """
    Represents the inputs or outputs of a `Component`.

    Depending on the type passed to the constructor, it will represent either the inputs or the outputs of
    the `Component`.

    Usage:
    ```python
    from typing import Any
    from haystack.components.builders.prompt_builder import PromptBuilder
    from haystack.core.component.sockets import Sockets
    from haystack.core.component.types import InputSocket, OutputSocket


    prompt_template = \"""
    Given these documents, answer the question.\nDocuments:
    {% for doc in documents %}
        {{ doc.content }}
    {% endfor %}

    \nQuestion: {{question}}
    \nAnswer:
    \"""

    prompt_builder = PromptBuilder(template=prompt_template)
    sockets = {"question": InputSocket("question", Any), "documents": InputSocket("documents", Any)}
    inputs = Sockets(component=prompt_builder, sockets_dict=sockets, sockets_io_type=InputSocket)
    inputs
    # >> Inputs:
    # >> - question: Any
    # >> - documents: Any

    inputs.question
    # >> InputSocket(name='question', type=typing.Any, default_value=<class 'haystack.core.component.types._empty'>, ...
    ```
    """

    # We're using a forward declaration here to avoid a circular import.
    def __init__(
        self,
        component: "Component",  # type: ignore[name-defined] # noqa: F821
        sockets_dict: SocketsDict,
        sockets_io_type: SocketsIOType,
    ) -> None:
        """
        Create a new Sockets object.

        We don't do any enforcement on the types of the sockets here, the `sockets_io_type` is only used for
        the `__repr__` method.
        We could do without it and use the type of a random value in the `sockets` dict, but that wouldn't
        work for components that have no sockets at all. Either input or output.

        :param component:
            The component that these sockets belong to.
        :param sockets_dict:
            A dictionary of sockets, keyed by socket name.
        :param sockets_io_type:
            The type of the sockets (`InputSocket` or `OutputSocket`).
        """
        self._sockets_io_type = sockets_io_type
        self._component = component
        # Stored by reference, not copied: `__setitem__` therefore also mutates the dict passed in here.
        self._sockets_dict = sockets_dict
        # Mirror the sockets into the instance dict so they are also visible through `vars()` / `dir()`.
        self.__dict__.update(sockets_dict)

    def __eq__(self, value: object) -> bool:
        """
        Two Sockets are equal when their IO type, owning component and sockets dict are all equal.

        :param value:
            The object to compare against.
        :returns:
            `False` for any non-`Sockets` object, otherwise the result of the field-wise comparison.
        """
        if not isinstance(value, Sockets):
            return False

        return (
            self._sockets_io_type == value._sockets_io_type
            and self._component == value._component
            and self._sockets_dict == value._sockets_dict
        )

    def __setitem__(self, key: str, socket: InputSocket | OutputSocket) -> None:
        """
        Adds a new socket to this Sockets object.

        This eases a bit updating the list of sockets after Sockets has been created.
        That should happen only in the `component` decorator.

        :param key:
            The socket name.
        :param socket:
            The socket to register under `key`.
        """
        # Keep the backing dict and the attribute mirror in sync.
        self._sockets_dict[key] = socket
        self.__dict__[key] = socket

    def __contains__(self, key: str) -> bool:
        """
        Check whether a socket with the given name exists.

        :param key:
            The socket name.
        :returns:
            `True` if `key` is a registered socket name.
        """
        return key in self._sockets_dict

    def get(self, key: str, default: InputSocket | OutputSocket | None = None) -> InputSocket | OutputSocket | None:
        """
        Get a socket from the Sockets object.

        :param key:
            The name of the socket to get.
        :param default:
            The value to return if the key is not found.
        :returns:
            The socket with the given key or `default` if the key is not found.
        """
        return self._sockets_dict.get(key, default)

    def _component_name(self) -> str:
        """
        Return a human-readable name for the owning component.

        :returns:
            The component's name in its Pipeline if it has been added to one, otherwise the default
            `object.__repr__` of the component.
        """
        if pipeline := self._component.__haystack_added_to_pipeline__:
            # This Component has been added in a Pipeline, let's get the name from there.
            return pipeline.get_component_name(self._component)

        # This Component has not been added to a Pipeline yet, so we can't know its name.
        # Let's use default __repr__. We don't call repr() directly as Components have a custom
        # __repr__ method and that would lead to infinite recursion since we call Sockets.__repr__ in it.
        return object.__repr__(self._component)

    def __getattribute__(self, name: str) -> Any:
        """
        Resolve attribute access, giving registered sockets precedence over regular attributes.

        Sockets are looked up in `_sockets_dict` first, so `inputs.question` returns the socket currently
        registered under that name. Any other name falls back to normal attribute lookup.
        """
        try:
            # `object.__getattribute__` is used to avoid recursing back into this method.
            sockets = object.__getattribute__(self, "_sockets_dict")
        except AttributeError:
            # `_sockets_dict` isn't set yet (e.g. the instance was created without running `__init__`,
            # as happens during copying/unpickling); fall through to the default lookup.
            pass
        else:
            if name in sockets:
                return sockets[name]

        return object.__getattribute__(self, name)

    def __repr__(self) -> str:
        """
        Render the sockets as an `Inputs:`/`Outputs:` header followed by one `name: type` line per socket.

        The header is omitted if the IO type is neither `InputSocket` nor `OutputSocket`.
        """
        result = ""
        if self._sockets_io_type == InputSocket:
            result = "Inputs:\n"
        elif self._sockets_io_type == OutputSocket:
            result = "Outputs:\n"

        return result + "\n".join([f" - {n}: {_type_name(s.type)}" for n, s in self._sockets_dict.items()])