# test_sockets.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import pytest

from haystack.core.component.sockets import InputSocket, OutputSocket, Sockets
from haystack.testing.factory import component_class


def _make_component():
    """Instantiate a throwaway test component declaring two ``int`` inputs, ``input_1`` and ``input_2``."""
    return component_class("SomeComponent", input_types={"input_1": int, "input_2": int})()


def _input_sockets_of(comp) -> Sockets:
    """Wrap the component's own input-socket dict in an input ``Sockets`` view."""
    return Sockets(
        component=comp,
        sockets_dict=comp.__haystack_input__._sockets_dict,  # type: ignore[attr-defined]
        sockets_io_type=InputSocket,
    )


class TestSockets:
    def test_init(self):
        """Sockets passed at construction are stored and exposed as instance attributes."""
        comp = _make_component()
        sockets: dict[str, InputSocket | OutputSocket] = {
            "input_1": InputSocket("input_1", int),
            "input_2": InputSocket("input_2", int),
        }
        io = Sockets(component=comp, sockets_dict=sockets, sockets_io_type=InputSocket)

        assert io._component == comp
        assert io._sockets_dict == sockets
        declared = comp.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
        for name, socket in sockets.items():
            # Each socket handed to Sockets must be reachable as an attribute...
            assert name in io.__dict__
            assert io.__dict__[name] == socket
            # ...and must match what the component itself declares for that input.
            assert io.__dict__[name] == declared[name]

    def test_init_with_empty_sockets(self):
        """An empty sockets dict is accepted and stored as-is."""
        comp = component_class("SomeComponent")()
        io = Sockets(component=comp, sockets_dict={}, sockets_io_type=InputSocket)

        assert io._component == comp
        assert io._sockets_dict == {}

    def test_getattribute(self):
        """Declared sockets are accessible with plain attribute syntax."""
        comp = _make_component()
        io = _input_sockets_of(comp)

        assert io.input_1 == comp.__haystack_input__._sockets_dict["input_1"]  # type: ignore[attr-defined]
        assert io.input_2 == comp.__haystack_input__._sockets_dict["input_2"]  # type: ignore[attr-defined]

    def test_getattribute_non_existing_socket(self):
        """Accessing an undeclared socket name raises AttributeError."""
        io = _input_sockets_of(_make_component())

        with pytest.raises(AttributeError):
            # Assign to `_` so the attribute access is not a bare (lint-flagged) expression.
            _ = io.input_3

    def test_repr(self):
        """repr lists each socket as ``name: type`` under an ``Inputs:`` header."""
        io = _input_sockets_of(_make_component())

        res = repr(io)
        assert res == "Inputs:\n - input_1: int\n - input_2: int"

    def test_get(self):
        """get() returns the socket, None for unknown names, or the supplied default."""
        comp = _make_component()
        io = _input_sockets_of(comp)

        assert io.get("input_1") == comp.__haystack_input__._sockets_dict["input_1"]  # type: ignore[attr-defined]
        assert io.get("input_2") == comp.__haystack_input__._sockets_dict["input_2"]  # type: ignore[attr-defined]
        assert io.get("invalid") is None
        assert io.get("invalid", InputSocket("input_2", int)) == InputSocket("input_2", int)

    def test_contains(self):
        """The ``in`` operator checks membership by socket name."""
        io = _input_sockets_of(_make_component())

        assert "input_1" in io
        assert "input_2" in io
        assert "invalid" not in io