/ haystack / core / component / sockets.py
sockets.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from typing import Any
  6  
  7  from haystack.core.type_utils import _type_name
  8  
  9  from .types import InputSocket, OutputSocket
 10  
# Maps socket names to their socket objects; this is the shape of `Sockets.sockets_dict`.
SocketsDict = dict[str, InputSocket | OutputSocket]
# The socket *class* itself (InputSocket or OutputSocket), not an instance; `Sockets` uses it to
# label itself as "Inputs" or "Outputs" in its repr and compares it in `__eq__`.
SocketsIOType = type[InputSocket] | type[OutputSocket]
 13  
 14  
 15  class Sockets:  # noqa: PLW1641
 16      """
 17      Represents the inputs or outputs of a `Component`.
 18  
 19      Depending on the type passed to the constructor, it will represent either the inputs or the outputs of
 20      the `Component`.
 21  
 22      Usage:
 23      ```python
 24      from typing import Any
 25      from haystack.components.builders.prompt_builder import PromptBuilder
 26      from haystack.core.component.sockets import Sockets
 27      from haystack.core.component.types import InputSocket, OutputSocket
 28  
 29  
 30      prompt_template = \"""
 31      Given these documents, answer the question.\nDocuments:
 32      {% for doc in documents %}
 33          {{ doc.content }}
 34      {% endfor %}
 35  
 36      \nQuestion: {{question}}
 37      \nAnswer:
 38      \"""
 39  
 40      prompt_builder = PromptBuilder(template=prompt_template)
 41      sockets = {"question": InputSocket("question", Any), "documents": InputSocket("documents", Any)}
 42      inputs = Sockets(component=prompt_builder, sockets_dict=sockets, sockets_io_type=InputSocket)
 43      inputs
 44      # >> Inputs:
 45      # >>   - question: Any
 46      # >>   - documents: Any
 47  
 48      inputs.question
 49      # >> InputSocket(name='question', type=typing.Any, default_value=<class 'haystack.core.component.types._empty'>, ...
 50      ```
 51      """
 52  
 53      # We're using a forward declaration here to avoid a circular import.
 54      def __init__(
 55          self,
 56          component: "Component",  # type: ignore[name-defined] # noqa: F821
 57          sockets_dict: SocketsDict,
 58          sockets_io_type: SocketsIOType,
 59      ) -> None:
 60          """
 61          Create a new Sockets object.
 62  
 63          We don't do any enforcement on the types of the sockets here, the `sockets_type` is only used for
 64          the `__repr__` method.
 65          We could do without it and use the type of a random value in the `sockets` dict, but that wouldn't
 66          work for components that have no sockets at all. Either input or output.
 67  
 68          :param component:
 69              The component that these sockets belong to.
 70          :param sockets_dict:
 71              A dictionary of sockets.
 72          :param sockets_io_type:
 73              The type of the sockets.
 74          """
 75          self._sockets_io_type = sockets_io_type
 76          self._component = component
 77          self._sockets_dict = sockets_dict
 78          self.__dict__.update(sockets_dict)
 79  
 80      def __eq__(self, value: object) -> bool:
 81          if not isinstance(value, Sockets):
 82              return False
 83  
 84          return (
 85              self._sockets_io_type == value._sockets_io_type
 86              and self._component == value._component
 87              and self._sockets_dict == value._sockets_dict
 88          )
 89  
 90      def __setitem__(self, key: str, socket: InputSocket | OutputSocket) -> None:
 91          """
 92          Adds a new socket to this Sockets object.
 93  
 94          This eases a bit updating the list of sockets after Sockets has been created.
 95          That should happen only in the `component` decorator.
 96          """
 97          self._sockets_dict[key] = socket
 98          self.__dict__[key] = socket
 99  
100      def __contains__(self, key: str) -> bool:
101          return key in self._sockets_dict
102  
103      def get(self, key: str, default: InputSocket | OutputSocket | None = None) -> InputSocket | OutputSocket | None:
104          """
105          Get a socket from the Sockets object.
106  
107          :param key:
108              The name of the socket to get.
109          :param default:
110              The value to return if the key is not found.
111          :returns:
112              The socket with the given key or `default` if the key is not found.
113          """
114          return self._sockets_dict.get(key, default)
115  
116      def _component_name(self) -> str:
117          if pipeline := self._component.__haystack_added_to_pipeline__:
118              # This Component has been added in a Pipeline, let's get the name from there.
119              return pipeline.get_component_name(self._component)
120  
121          # This Component has not been added to a Pipeline yet, so we can't know its name.
122          # Let's use default __repr__. We don't call repr() directly as Components have a custom
123          # __repr__ method and that would lead to infinite recursion since we call Sockets.__repr__ in it.
124          return object.__repr__(self._component)
125  
126      def __getattribute__(self, name: Any) -> Any:
127          try:
128              sockets = object.__getattribute__(self, "_sockets")
129              if name in sockets:
130                  return sockets[name]
131          except AttributeError:
132              pass
133  
134          return object.__getattribute__(self, name)
135  
136      def __repr__(self) -> str:
137          result = ""
138          if self._sockets_io_type == InputSocket:
139              result = "Inputs:\n"
140          elif self._sockets_io_type == OutputSocket:
141              result = "Outputs:\n"
142  
143          return result + "\n".join([f"  - {n}: {_type_name(s.type)}" for n, s in self._sockets_dict.items()])