/ test / core / component / test_sockets.py
test_sockets.py
 1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
 2  #
 3  # SPDX-License-Identifier: Apache-2.0
 4  
 5  import pytest
 6  
 7  from haystack.core.component.sockets import InputSocket, OutputSocket, Sockets
 8  from haystack.testing.factory import component_class
 9  
10  
class TestSockets:
    """Unit tests for `Sockets`, the attribute-style container of a component's input/output sockets."""

    @staticmethod
    def _make_input_sockets():
        """Build a two-input component and a `Sockets` view over its input sockets.

        Returns a `(component, sockets_io)` tuple; both inputs (`input_1`, `input_2`) are typed `int`.
        """
        comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
        sockets_io = Sockets(
            component=comp,
            sockets_dict=comp.__haystack_input__._sockets_dict,  # type: ignore[attr-defined]
            sockets_io_type=InputSocket,
        )
        return comp, sockets_io

    def test_init(self):
        """Each entry of the given sockets dict is exposed in the instance `__dict__`."""
        comp = component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()
        sockets: dict[str, InputSocket | OutputSocket] = {
            "input_1": InputSocket("input_1", int),
            "input_2": InputSocket("input_2", int),
        }
        sockets_io = Sockets(component=comp, sockets_dict=sockets, sockets_io_type=InputSocket)

        assert sockets_io._component == comp
        assert sockets_io._sockets_dict == sockets
        # Compare against the dict actually passed in, so the test verifies that
        # `Sockets` stores what it was given rather than the component's own sockets.
        assert "input_1" in sockets_io.__dict__
        assert sockets_io.__dict__["input_1"] == sockets["input_1"]
        assert "input_2" in sockets_io.__dict__
        assert sockets_io.__dict__["input_2"] == sockets["input_2"]

    def test_init_with_empty_sockets(self):
        """An empty sockets dict yields an instance with no sockets."""
        comp = component_class("SomeComponent")()
        sockets_io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)

        assert sockets_io._component == comp
        assert sockets_io._sockets_dict == {}

    def test_getattribute(self):
        """Sockets are reachable as plain attributes."""
        comp, sockets_io = self._make_input_sockets()

        assert sockets_io.input_1 == comp.__haystack_input__._sockets_dict["input_1"]  # type: ignore[attr-defined]
        assert sockets_io.input_2 == comp.__haystack_input__._sockets_dict["input_2"]  # type: ignore[attr-defined]

    def test_getattribute_non_existing_socket(self):
        """Accessing an unknown socket name raises `AttributeError`."""
        _, sockets_io = self._make_input_sockets()

        with pytest.raises(AttributeError):
            # Bind to `_` so the access is not a bare (lint-flagged) expression statement.
            _ = sockets_io.input_3

    def test_repr(self):
        """`repr` lists each socket name with its type."""
        _, sockets_io = self._make_input_sockets()

        assert repr(sockets_io) == "Inputs:\n  - input_1: int\n  - input_2: int"

    def test_get(self):
        """`get` mirrors `dict.get`: returns the socket, `None`, or the supplied default."""
        comp, sockets_io = self._make_input_sockets()

        assert sockets_io.get("input_1") == comp.__haystack_input__._sockets_dict["input_1"]  # type: ignore[attr-defined]
        assert sockets_io.get("input_2") == comp.__haystack_input__._sockets_dict["input_2"]  # type: ignore[attr-defined]
        assert sockets_io.get("invalid") is None
        assert sockets_io.get("invalid", InputSocket("input_2", int)) == InputSocket("input_2", int)

    def test_contains(self):
        """The `in` operator checks socket names."""
        _, sockets_io = self._make_input_sockets()

        assert "input_1" in sockets_io
        assert "input_2" in sockets_io
        assert "invalid" not in sockets_io