test_component.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import Any

import pytest

from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.component.component import _hook_component_init
from haystack.core.errors import ComponentError
from haystack.core.pipeline import Pipeline


def test_correct_declaration():
    """A sync-only component can be declared, instantiated, found in the registry and reports no async support."""

    @component
    class MockComponent:
        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

        @component.output_types(output_value=int)
        def run(self, input_value: int) -> dict[str, int]:
            return {"output_value": input_value}

    # Verifies also instantiation works with no issues
    assert MockComponent()
    assert component.registry["test_component.MockComponent"] == MockComponent
    assert isinstance(MockComponent(), Component)
    assert MockComponent().__haystack_supports_async__ is False  # type: ignore[attr-defined]


def test_correct_declaration_with_async():
    """A component that also defines a matching `run_async` coroutine reports async support."""

    @component
    class MockComponent:
        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

        @component.output_types(output_value=int)
        def run(self, input_value: int) -> dict[str, int]:
            return {"output_value": input_value}

        @component.output_types(output_value=int)
        async def run_async(self, input_value: int) -> dict[str, int]:
            return {"output_value": input_value}

    # Verifies also instantiation works with no issues
    assert MockComponent()
    assert component.registry["test_component.MockComponent"] == MockComponent
    assert isinstance(MockComponent(), Component)
    assert MockComponent().__haystack_supports_async__ is True  # type: ignore[attr-defined]


def test_correct_declaration_with_additional_readonly_property():
    """A read-only property declared on a component stays readable on its instances."""

    @component
    class MockComponent:
        @property
        def store(self):
            return "test_store"

        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

        @component.output_types(output_value=int)
        def run(self, input_value: int) -> dict[str, int]:
            return {"output_value": input_value}

    # Verifies that instantiation works with no issues
    assert MockComponent()
    assert component.registry["test_component.MockComponent"] == MockComponent
    assert MockComponent().store == "test_store"


def test_correct_declaration_with_additional_writable_property():
    """A property with a setter declared on a component can be both read and written on its instances."""

    @component
    class MockComponent:
        def __init__(self):
            # Start from a value different from the one assigned below, so the final
            # assertion can only pass if the setter actually stored the new value.
            self._store = "initial_store"

        @property
        def store(self):
            return self._store

        @store.setter
        def store(self, value):
            self._store = value

        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

        @component.output_types(output_value=int)
        def run(self, input_value: int) -> dict[str, int]:
            return {"output_value": input_value}

    # Verifies that instantiation works with no issues
    assert component.registry["test_component.MockComponent"] == MockComponent
    comp = MockComponent()
    assert comp.store == "initial_store"
    comp.store = "test_store"
    assert comp.store == "test_store"


def test_missing_run():
    """Decorating a class that has no `run` method raises a ComponentError."""
    with pytest.raises(ComponentError, match=r"must have a 'run\(\)' method"):

        @component
        class MockComponent:  # type: ignore[type-var]
            def another_method(self, input_value: int) -> dict[str, int]:
                return {"output_value": input_value}


def test_async_run_not_async():
    """Instantiation fails when `run_async` is declared as a plain (non-async) function."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

        @component.output_types(value=int)
        def run_async(self, value: int) -> dict[str, int]:
            return {"value": 1}

    with pytest.raises(ComponentError, match=r"must be a coroutine"):
        _ = MockComponent()


def test_async_run_not_coroutine():
    """Instantiation fails when `run_async` is an async generator (it yields) rather than a coroutine."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

        @component.output_types(value=int)
        async def run_async(self, value: int) -> dict[str, int]:  # type: ignore[misc]
            yield {"value": 1}

    with pytest.raises(ComponentError, match=r"must be a coroutine"):
        _ = MockComponent()


def test_parameters_mismatch_run_and_async_run():
    """Instantiation fails when `run` and `run_async` differ in parameter types, parameter set or parameter order."""
    err_msg = r"Parameters of 'run' and 'run_async' methods must be the same"

    @component
    class MockComponentMismatchingInputTypes:
        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

        @component.output_types(value=int)
        async def run_async(self, value: str) -> dict[str, int]:
            return {"value": 1}

    with pytest.raises(ComponentError, match=err_msg):
        _ = MockComponentMismatchingInputTypes()

    @component
    class MockComponentMismatchingInputs:
        @component.output_types(value=int)
        def run(self, value: int, **kwargs: Any) -> dict[str, int]:
            return {"value": 1}

        @component.output_types(value=int)
        async def run_async(self, value: int) -> dict[str, int]:
            return {"value": 1}

    with pytest.raises(ComponentError, match=err_msg):
        _ = MockComponentMismatchingInputs()

    @component
    class MockComponentMismatchingInputOrder:
        @component.output_types(value=int)
        def run(self, value: int, another: str) -> dict[str, int]:
            return {"value": 1}

        @component.output_types(value=int)
        async def run_async(self, another: str, value: int) -> dict[str, int]:
            return {"value": 1}

    with pytest.raises(ComponentError, match=err_msg):
        _ = MockComponentMismatchingInputOrder()


def test_set_input_types():
    """`set_input_types` / `set_input_type` called in `__init__` define the input sockets of a `**kwargs` run."""

    @component
    class MockComponent:
        def __init__(self, flag: bool):
            component.set_input_types(self, value=Any)
            if flag:
                component.set_input_type(self, name="another", type=str)

        @component.output_types(value=int)
        def run(self, **kwargs):
            return {"value": 1}

    comp = MockComponent(False)
    assert comp.__haystack_input__._sockets_dict == {"value": InputSocket("value", Any)}  # type: ignore[attr-defined]
    assert comp.run() == {"value": 1}

    comp = MockComponent(True)
    assert comp.__haystack_input__._sockets_dict == {  # type: ignore[attr-defined]
        "value": InputSocket("value", Any),
        "another": InputSocket("another", str),
    }
    assert comp.run() == {"value": 1}


def test_set_input_types_no_kwarg():
    """Setting input types dynamically fails when `run` has no `**kwargs` parameter."""

    @component
    class MockComponent:
        def __init__(self, flag: bool):
            if flag:
                component.set_input_type(self, name="another", type=str)
            else:
                component.set_input_types(self, value=Any)

        @component.output_types(value=int)
        def run(self, fini: bool) -> dict[str, int]:
            return {"value": 1}

    with pytest.raises(ComponentError, match=r"doesn't have a kwargs parameter"):
        _ = MockComponent(False)

    with pytest.raises(ComponentError, match=r"doesn't have a kwargs parameter"):
        _ = MockComponent(True)


def test_set_input_types_overrides_run():
    """Setting an input type whose name collides with an explicit `run` parameter fails."""

    @component
    class MockComponent:
        def __init__(self, state: bool):
            if state:
                component.set_input_type(self, name="fini", type=str)
            else:
                component.set_input_types(self, fini=Any)

        @component.output_types(value=int)
        def run(self, fini: bool, **kwargs: Any) -> dict[str, int]:
            return {"value": 1}

    err_msg = "cannot override the parameters of the 'run' method"
    with pytest.raises(ComponentError, match=err_msg):
        _ = MockComponent(False)

    with pytest.raises(ComponentError, match=err_msg):
        _ = MockComponent(True)


def test_set_input_types_postponed_annotations():
    """Input sockets carry the real type for a component defined under postponed annotation evaluation."""
    # The component HelloUsingFutureAnnotations must live in a different module than the one where the test is defined,
    # so we can properly set up postponed evaluation of annotations using `from __future__ import annotations`.
    # For this reason, we define it in haystack.testing.sample_components.future_annotations and import it here.
    from haystack.testing.sample_components import HelloUsingFutureAnnotations

    assert HelloUsingFutureAnnotations().__haystack_input__._sockets_dict == {"word": InputSocket("word", str)}  # type: ignore[attr-defined]


def test_set_output_types():
    """`set_output_types` called in `__init__` defines the output sockets of the instance."""

    @component
    class MockComponent:
        def __init__(self):
            component.set_output_types(self, value=int)

        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

    comp = MockComponent()
    assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}  # type: ignore[attr-defined]


def test_output_types_decorator_with_compatible_type():
    """The `output_types` decorator on `run` defines the output sockets of the instance."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

        def to_dict(self) -> dict:
            return {}

        @classmethod
        def from_dict(cls, data: dict) -> "MockComponent":
            return cls()

    comp = MockComponent()
    assert comp.__haystack_output__._sockets_dict == {"value": OutputSocket("value", int)}  # type: ignore[attr-defined]


def test_output_types_decorator_wrong_method():
    """Applying the `output_types` decorator to a method other than `run` raises a ComponentError."""
    with pytest.raises(ComponentError):

        @component
        class MockComponent:
            def run(self, value: int) -> dict[str, int]:
                return {"value": 1}

            @component.output_types(value=int)
            def to_dict(self):
                return {}

            @classmethod
            def from_dict(cls, data):
                return cls()


def test_output_types_decorator_and_set_output_types():
    """Combining the `output_types` decorator on `run` with `set_output_types` in `__init__` fails."""

    @component
    class MockComponent:
        def __init__(self) -> None:
            component.set_output_types(self, value=int)

        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

    with pytest.raises(ComponentError, match="Cannot call `set_output_types`"):
        _ = MockComponent()


def test_output_types_decorator_and_set_output_types_async():
    """Combining the `output_types` decorator on `run_async` with `set_output_types` in `__init__` fails."""

    @component
    class MockComponent:
        def __init__(self) -> None:
            component.set_output_types(self, value=int)

        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

        @component.output_types(value=int)
        async def run_async(self, value: int) -> dict[str, int]:
            return {"value": 1}

    with pytest.raises(ComponentError, match="Cannot call `set_output_types`"):
        _ = MockComponent()


def test_output_types_decorator_mismatch_run_async_run():
    """Instantiation fails when `run` and `run_async` are decorated with different output types."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, str]:
            return {"value": "1"}

        @component.output_types(value=str)
        async def run_async(self, value: int) -> dict[str, str]:
            return {"value": "1"}

    with pytest.raises(ComponentError, match=r"Output type specifications .* must be the same"):
        _ = MockComponent()


def test_output_types_decorator_missing_async_run():
    """Instantiation fails when only `run` carries the `output_types` decorator and `run_async` does not."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

        async def run_async(self, value: int) -> dict[str, int]:
            return {"value": 1}

    with pytest.raises(ComponentError, match=r"Output type specifications .* must be the same"):
        _ = MockComponent()


def test_component_decorator_set_it_as_component():
    """An instance of a decorated class passes the `isinstance(..., Component)` check."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

        def to_dict(self) -> dict:
            return {}

        @classmethod
        def from_dict(cls, data: dict) -> "MockComponent":
            return cls()

    comp = MockComponent()
    assert isinstance(comp, Component)


def test_input_has_default_value():
    """A `run` parameter with a default produces a non-mandatory input socket carrying that default."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int = 42) -> dict[str, int]:
            return {"value": value}

    comp = MockComponent()
    assert comp.__haystack_input__._sockets_dict["value"].default_value == 42  # type: ignore[attr-defined]
    assert not comp.__haystack_input__._sockets_dict["value"].is_mandatory  # type: ignore[attr-defined]


def test_keyword_only_args():
    """A keyword-only `run` parameter is exposed as an input socket with its annotated type."""

    @component
    class MockComponent:
        def __init__(self):
            component.set_output_types(self, value=int)

        def run(self, *, arg: int) -> dict[str, int]:
            return {"value": arg}

    comp = MockComponent()
    component_inputs = {
        name: {"type": socket.type}
        for name, socket in comp.__haystack_input__._sockets_dict.items()  # type: ignore[attr-defined]
    }
    assert component_inputs == {"arg": {"type": int}}


def test_repr():
    """`repr()` of a component is the default object repr followed by its inputs and outputs."""

    @component
    class MockComponent:
        def __init__(self):
            component.set_output_types(self, value=int)

        def run(self, value: int) -> dict[str, int]:
            return {"value": value}

    comp = MockComponent()
    assert repr(comp) == f"{object.__repr__(comp)}\nInputs:\n  - value: int\nOutputs:\n  - value: int"


def test_repr_added_to_pipeline():
    """Once added to a pipeline, `repr()` of a component also includes the name it was registered under."""

    @component
    class MockComponent:
        def __init__(self):
            component.set_output_types(self, value=int)

        def run(self, value: int) -> dict[str, int]:
            return {"value": value}

    pipe = Pipeline()
    comp = MockComponent()
    pipe.add_component("my_component", comp)
    assert repr(comp) == f"{object.__repr__(comp)}\nmy_component\nInputs:\n  - value: int\nOutputs:\n  - value: int"


def test_pre_init_hooking():
    """The pre-init hook receives the init arguments keyed by parameter name and may modify them in place."""

    @component
    class MockComponent:
        def __init__(self, pos_arg1, pos_arg2, pos_arg3=None, *, kwarg1=1, kwarg2="string"):
            self.pos_arg1 = pos_arg1
            self.pos_arg2 = pos_arg2
            self.pos_arg3 = pos_arg3
            self.kwarg1 = kwarg1
            self.kwarg2 = kwarg2

        @component.output_types(output_value=int)
        def run(self, input_value: int) -> dict[str, int]:
            return {"output_value": input_value}

    def pre_init_hook(component_class, init_params, expected_params):
        assert component_class == MockComponent
        assert init_params == expected_params

    def pre_init_hook_modify(component_class, init_params, expected_params):
        assert component_class == MockComponent
        assert init_params == expected_params

        init_params["pos_arg1"] = 2
        init_params["pos_arg2"] = 0
        init_params["pos_arg3"] = "modified"
        init_params["kwarg2"] = "modified string"

    with _hook_component_init(partial(pre_init_hook, expected_params={"pos_arg1": 1, "pos_arg2": 2, "kwarg1": None})):
        _ = MockComponent(1, 2, kwarg1=None)

    with _hook_component_init(partial(pre_init_hook, expected_params={"pos_arg1": 1, "pos_arg2": 2, "pos_arg3": 0.01})):
        _ = MockComponent(pos_arg1=1, pos_arg2=2, pos_arg3=0.01)

    with _hook_component_init(
        partial(pre_init_hook_modify, expected_params={"pos_arg1": 0, "pos_arg2": 1, "pos_arg3": 0.01, "kwarg1": 0})
    ):
        c = MockComponent(0, 1, pos_arg3=0.01, kwarg1=0)

    # Values written by the hook win over the ones passed by the caller; untouched ones are kept.
    assert c.pos_arg1 == 2
    assert c.pos_arg2 == 0
    assert c.pos_arg3 == "modified"
    assert c.kwarg1 == 0
    assert c.kwarg2 == "modified string"


def test_pre_init_hooking_variadic_positional_args():
    """A component whose `__init__` takes `*args` works unhooked but raises ComponentError under a pre-init hook."""

    @component
    class MockComponent:
        def __init__(self, *args, kwarg1=1, kwarg2="string"):
            self.args = args
            self.kwarg1 = kwarg1
            self.kwarg2 = kwarg2

        @component.output_types(output_value=int)
        def run(self, input_value: int) -> dict[str, int]:
            return {"output_value": input_value}

    def pre_init_hook(component_class, init_params, expected_params):
        assert component_class == MockComponent
        assert init_params == expected_params

    c = MockComponent(1, 2, 3, kwarg1=None)
    assert c.args == (1, 2, 3)
    assert c.kwarg1 is None
    assert c.kwarg2 == "string"

    with (
        pytest.raises(ComponentError),
        _hook_component_init(partial(pre_init_hook, expected_params={"args": (1, 2), "kwarg1": None})),
    ):
        _ = MockComponent(1, 2, kwarg1=None)


def test_pre_init_hooking_variadic_kwargs():
    """With `**kwargs` in `__init__`, the pre-init hook sees extra keyword arguments and may modify them."""

    @component
    class MockComponent:
        def __init__(self, pos_arg1, pos_arg2=None, **kwargs):
            self.pos_arg1 = pos_arg1
            self.pos_arg2 = pos_arg2
            self.kwargs = kwargs

        @component.output_types(output_value=int)
        def run(self, input_value: int) -> dict[str, int]:
            return {"output_value": input_value}

    def pre_init_hook(component_class, init_params, expected_params):
        assert component_class == MockComponent
        assert init_params == expected_params

    with _hook_component_init(
        partial(pre_init_hook, expected_params={"pos_arg1": 1, "kwarg1": None, "kwarg2": 10, "kwarg3": "string"})
    ):
        c = MockComponent(1, kwarg1=None, kwarg2=10, kwarg3="string")
        assert c.pos_arg1 == 1
        assert c.pos_arg2 is None
        assert c.kwargs == {"kwarg1": None, "kwarg2": 10, "kwarg3": "string"}

    def pre_init_hook_modify(component_class, init_params, expected_params):
        assert component_class == MockComponent
        assert init_params == expected_params

        init_params["pos_arg1"] = 2
        init_params["pos_arg2"] = 0
        init_params["some_kwarg"] = "modified string"

    with _hook_component_init(
        partial(
            pre_init_hook_modify,
            expected_params={"pos_arg1": 0, "pos_arg2": 1, "kwarg1": 999, "some_kwarg": "some_value"},
        )
    ):
        c = MockComponent(0, 1, kwarg1=999, some_kwarg="some_value")

    assert c.pos_arg1 == 2
    assert c.pos_arg2 == 0
    assert c.kwargs == {"kwarg1": 999, "some_kwarg": "modified string"}