/ test / core / component / test_component.py
test_component.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from functools import partial
  6  from typing import Any
  7  
  8  import pytest
  9  
 10  from haystack.core.component import Component, InputSocket, OutputSocket, component
 11  from haystack.core.component.component import _hook_component_init
 12  from haystack.core.errors import ComponentError
 13  from haystack.core.pipeline import Pipeline
 14  
 15  
 16  def test_correct_declaration():
 17      @component
 18      class MockComponent:
 19          def to_dict(self):
 20              return {}
 21  
 22          @classmethod
 23          def from_dict(cls, data):
 24              return cls()
 25  
 26          @component.output_types(output_value=int)
 27          def run(self, input_value: int) -> dict[str, int]:
 28              return {"output_value": input_value}
 29  
 30      # Verifies also instantiation works with no issues
 31      assert MockComponent()
 32      assert component.registry["test_component.MockComponent"] == MockComponent
 33      assert isinstance(MockComponent(), Component)
 34      assert MockComponent().__haystack_supports_async__ is False  # type: ignore[attr-defined]
 35  
 36  
 37  def test_correct_declaration_with_async():
 38      @component
 39      class MockComponent:
 40          def to_dict(self):
 41              return {}
 42  
 43          @classmethod
 44          def from_dict(cls, data):
 45              return cls()
 46  
 47          @component.output_types(output_value=int)
 48          def run(self, input_value: int) -> dict[str, int]:
 49              return {"output_value": input_value}
 50  
 51          @component.output_types(output_value=int)
 52          async def run_async(self, input_value: int) -> dict[str, int]:
 53              return {"output_value": input_value}
 54  
 55      # Verifies also instantiation works with no issues
 56      assert MockComponent()
 57      assert component.registry["test_component.MockComponent"] == MockComponent
 58      assert isinstance(MockComponent(), Component)
 59      assert MockComponent().__haystack_supports_async__ is True  # type: ignore[attr-defined]
 60  
 61  
 62  def test_correct_declaration_with_additional_readonly_property():
 63      @component
 64      class MockComponent:
 65          @property
 66          def store(self):
 67              return "test_store"
 68  
 69          def to_dict(self):
 70              return {}
 71  
 72          @classmethod
 73          def from_dict(cls, data):
 74              return cls()
 75  
 76          @component.output_types(output_value=int)
 77          def run(self, input_value: int) -> dict[str, int]:
 78              return {"output_value": input_value}
 79  
 80      # Verifies that instantiation works with no issues
 81      assert MockComponent()
 82      assert component.registry["test_component.MockComponent"] == MockComponent
 83      assert MockComponent().store == "test_store"
 84  
 85  
 86  def test_correct_declaration_with_additional_writable_property():
 87      @component
 88      class MockComponent:
 89          @property
 90          def store(self):
 91              return "test_store"
 92  
 93          @store.setter
 94          def store(self, value):
 95              self._store = value
 96  
 97          def to_dict(self):
 98              return {}
 99  
100          @classmethod
101          def from_dict(cls, data):
102              return cls()
103  
104          @component.output_types(output_value=int)
105          def run(self, input_value: int) -> dict[str, int]:
106              return {"output_value": input_value}
107  
108      # Verifies that instantiation works with no issues
109      assert component.registry["test_component.MockComponent"] == MockComponent
110      comp = MockComponent()
111      comp.store = "test_store"
112      assert comp.store == "test_store"
113  
114  
def test_missing_run():
    """Decorating a class that has no `run` method raises ComponentError at decoration time."""
    expected = r"must have a 'run\(\)' method"
    with pytest.raises(ComponentError, match=expected):

        @component
        class MockComponent:  # type: ignore[type-var]
            def another_method(self, input_value: int) -> dict[str, int]:
                return {"output_value": input_value}
122  
123  
def test_async_run_not_async():
    """A regular (non-async) `run_async` is rejected when the component is instantiated."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

        # Deliberately a plain function, although a coroutine is expected here.
        @component.output_types(value=int)
        def run_async(self, value: int) -> dict[str, int]:
            return {"value": 1}

    expected = r"must be a coroutine"
    with pytest.raises(ComponentError, match=expected):
        MockComponent()
137  
138  
def test_async_run_not_coroutine():
    """An async-generator `run_async` is rejected, because a coroutine function is required."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

        # `async def` combined with `yield` makes an async generator, not a coroutine function.
        @component.output_types(value=int)
        async def run_async(self, value: int) -> dict[str, int]:  # type: ignore[misc]
            yield {"value": 1}

    expected = r"must be a coroutine"
    with pytest.raises(ComponentError, match=expected):
        MockComponent()
152  
153  
def test_parameters_mismatch_run_and_async_run():
    """`run` and `run_async` must expose the same parameters: types, extra kwargs and order all count."""

    @component
    class MockComponentMismatchingInputTypes:
        # Same parameter name, different annotation.
        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

        @component.output_types(value=int)
        async def run_async(self, value: str) -> dict[str, int]:
            return {"value": 1}

    @component
    class MockComponentMismatchingInputs:
        # `run` accepts **kwargs that `run_async` lacks.
        @component.output_types(value=int)
        def run(self, value: int, **kwargs: Any) -> dict[str, int]:
            return {"value": 1}

        @component.output_types(value=int)
        async def run_async(self, value: int) -> dict[str, int]:
            return {"value": 1}

    @component
    class MockComponentMismatchingInputOrder:
        # Same parameters, swapped order.
        @component.output_types(value=int)
        def run(self, value: int, another: str) -> dict[str, int]:
            return {"value": 1}

        @component.output_types(value=int)
        async def run_async(self, another: str, value: int) -> dict[str, int]:
            return {"value": 1}

    mismatch_pattern = r"Parameters of 'run' and 'run_async' methods must be the same"
    for mismatching_class in (
        MockComponentMismatchingInputTypes,
        MockComponentMismatchingInputs,
        MockComponentMismatchingInputOrder,
    ):
        with pytest.raises(ComponentError, match=mismatch_pattern):
            mismatching_class()
195  
196  
def test_set_input_types():
    """`set_input_types` / `set_input_type` declare input sockets at runtime for a `**kwargs`-only `run`."""

    @component
    class MockComponent:
        def __init__(self, flag: bool):
            component.set_input_types(self, value=Any)
            if flag:
                component.set_input_type(self, name="another", type=str)

        @component.output_types(value=int)
        def run(self, **kwargs):
            return {"value": 1}

    # Without the flag, only the bulk-declared socket exists.
    without_extra = MockComponent(False)
    assert without_extra.__haystack_input__._sockets_dict == {"value": InputSocket("value", Any)}  # type: ignore[attr-defined]
    assert without_extra.run() == {"value": 1}

    # With the flag, the individually added socket appears next to it.
    with_extra = MockComponent(True)
    expected_sockets = {
        "value": InputSocket("value", Any),
        "another": InputSocket("another", str),
    }
    assert with_extra.__haystack_input__._sockets_dict == expected_sockets  # type: ignore[attr-defined]
    assert with_extra.run() == {"value": 1}
219  
220  
def test_set_input_types_no_kwarg():
    """Declaring input sockets at runtime requires `run` to accept `**kwargs`."""

    @component
    class MockComponent:
        def __init__(self, flag: bool):
            if not flag:
                component.set_input_types(self, value=Any)
            else:
                component.set_input_type(self, name="another", type=str)

        @component.output_types(value=int)
        def run(self, fini: bool) -> dict[str, int]:
            return {"value": 1}

    # Both the bulk setter (flag=False) and the single-socket setter (flag=True) must refuse.
    for flag in (False, True):
        with pytest.raises(ComponentError, match=r"doesn't have a kwargs parameter"):
            MockComponent(flag)
239  
240  
def test_set_input_types_overrides_run():
    """Runtime input setters may not redeclare a parameter that `run` already names explicitly."""

    @component
    class MockComponent:
        def __init__(self, state: bool):
            if not state:
                component.set_input_types(self, fini=Any)
            else:
                component.set_input_type(self, name="fini", type=str)

        @component.output_types(value=int)
        def run(self, fini: bool, **kwargs: Any) -> dict[str, int]:
            return {"value": 1}

    # `fini` is already a named parameter of `run`, so both setters must fail on it.
    for state in (False, True):
        with pytest.raises(ComponentError, match="cannot override the parameters of the 'run' method"):
            MockComponent(state)
260  
261  
def test_set_input_types_postponed_annotations():
    """Input sockets are resolved correctly under `from __future__ import annotations`."""
    # Postponed evaluation of annotations is switched on per module, so the component under test
    # must live in another module. It is defined in
    # haystack.testing.sample_components.future_annotations and imported from the package here.
    from haystack.testing.sample_components import HelloUsingFutureAnnotations

    sockets = HelloUsingFutureAnnotations().__haystack_input__._sockets_dict  # type: ignore[attr-defined]
    assert sockets == {"word": InputSocket("word", str)}
269  
270  
def test_set_output_types():
    """`set_output_types` called in `__init__` declares the output sockets of an undecorated `run`."""

    @component
    class MockComponent:
        def __init__(self):
            component.set_output_types(self, value=int)

        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

    expected_outputs = {"value": OutputSocket("value", int)}
    instance = MockComponent()
    assert instance.__haystack_output__._sockets_dict == expected_outputs  # type: ignore[attr-defined]
289  
290  
def test_output_types_decorator_with_compatible_type():
    """`@component.output_types` on `run` produces the matching output sockets."""

    @component
    class MockComponent:
        def to_dict(self) -> dict:
            return {}

        @classmethod
        def from_dict(cls, data: dict) -> "MockComponent":
            return cls()

        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

    instance = MockComponent()
    expected_outputs = {"value": OutputSocket("value", int)}
    assert instance.__haystack_output__._sockets_dict == expected_outputs  # type: ignore[attr-defined]
307  
308  
def test_output_types_decorator_wrong_method():
    """Applying `@component.output_types` to a method other than `run` (here `to_dict`) raises ComponentError."""
    with pytest.raises(ComponentError):

        @component
        class MockComponent:
            def run(self, value: int) -> dict[str, int]:
                return {"value": 1}

            # Misplaced decorator: this is what the test expects to be rejected.
            @component.output_types(value=int)
            def to_dict(self):
                return {}

            @classmethod
            def from_dict(cls, data):
                return cls()
324  
325  
def test_output_types_decorator_and_set_output_types():
    """Declaring outputs via both the decorator on `run` and `set_output_types` is rejected."""

    @component
    class MockComponent:
        def __init__(self) -> None:
            # Conflicts with the `output_types` decorator on `run` below.
            component.set_output_types(self, value=int)

        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

    expected = "Cannot call `set_output_types`"
    with pytest.raises(ComponentError, match=expected):
        MockComponent()
338  
339  
def test_output_types_decorator_and_set_output_types_async():
    """The decorator on `run_async` alone also conflicts with `set_output_types` in `__init__`."""

    @component
    class MockComponent:
        def __init__(self) -> None:
            component.set_output_types(self, value=int)

        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

        # Only the async variant is decorated, yet the conflict must still be detected.
        @component.output_types(value=int)
        async def run_async(self, value: int) -> dict[str, int]:
            return {"value": 1}

    expected = "Cannot call `set_output_types`"
    with pytest.raises(ComponentError, match=expected):
        MockComponent()
355  
356  
def test_output_types_decorator_mismatch_run_async_run():
    """Differing `@component.output_types` declarations on `run` and `run_async` are rejected."""

    @component
    class MockComponent:
        # NOTE(review): `run` declares an int output but is annotated as and returns str. Presumably
        # only the decorator declarations (int vs str) matter for this check; confirm.
        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, str]:
            return {"value": "1"}

        @component.output_types(value=str)
        async def run_async(self, value: int) -> dict[str, str]:
            return {"value": "1"}

    mismatch = r"Output type specifications .* must be the same"
    with pytest.raises(ComponentError, match=mismatch):
        MockComponent()
370  
371  
def test_output_types_decorator_missing_async_run():
    """Decorating `run` but leaving `run_async` undecorated counts as mismatching output specifications."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

        # Intentionally left without an `output_types` decorator.
        async def run_async(self, value: int) -> dict[str, int]:
            return {"value": 1}

    mismatch = r"Output type specifications .* must be the same"
    with pytest.raises(ComponentError, match=mismatch):
        MockComponent()
384  
385  
def test_component_decorator_set_it_as_component():
    """Instances of a `@component`-decorated class pass the `isinstance(..., Component)` check."""

    @component
    class MockComponent:
        def to_dict(self) -> dict:
            return {}

        @classmethod
        def from_dict(cls, data: dict) -> "MockComponent":
            return cls()

        @component.output_types(value=int)
        def run(self, value: int) -> dict[str, int]:
            return {"value": 1}

    assert isinstance(MockComponent(), Component)
402  
403  
def test_input_has_default_value():
    """A `run` parameter with a default becomes a non-mandatory socket that carries that default."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int = 42) -> dict[str, int]:
            return {"value": value}

    value_socket = MockComponent().__haystack_input__._sockets_dict["value"]  # type: ignore[attr-defined]
    assert value_socket.default_value == 42
    assert not value_socket.is_mandatory
414  
415  
def test_keyword_only_args():
    """Keyword-only parameters of `run` are exposed as input sockets."""

    @component
    class MockComponent:
        def __init__(self):
            component.set_output_types(self, value=int)

        def run(self, *, arg: int) -> dict[str, int]:
            return {"value": arg}

    sockets = MockComponent().__haystack_input__._sockets_dict  # type: ignore[attr-defined]
    component_inputs = {}
    for socket_name, socket in sockets.items():
        component_inputs[socket_name] = {"type": socket.type}
    assert component_inputs == {"arg": {"type": int}}
431  
432  
def test_repr():
    """`repr` extends the default object repr with the input and output sockets."""

    @component
    class MockComponent:
        def __init__(self):
            component.set_output_types(self, value=int)

        def run(self, value: int) -> dict[str, int]:
            return {"value": value}

    comp = MockComponent()
    expected = "\n".join([object.__repr__(comp), "Inputs:", "  - value: int", "Outputs:", "  - value: int"])
    assert repr(comp) == expected
444  
445  
def test_repr_added_to_pipeline():
    """Once added to a pipeline, `repr` also shows the name the component was registered under."""

    @component
    class MockComponent:
        def __init__(self):
            component.set_output_types(self, value=int)

        def run(self, value: int) -> dict[str, int]:
            return {"value": value}

    pipe = Pipeline()
    comp = MockComponent()
    pipe.add_component("my_component", comp)

    expected = "\n".join(
        [object.__repr__(comp), "my_component", "Inputs:", "  - value: int", "Outputs:", "  - value: int"]
    )
    assert repr(comp) == expected
459  
460  
def test_pre_init_hooking():
    """Pre-init hooks receive the init arguments keyed by parameter name and may rewrite them.

    Positional arguments show up under their parameter names. Parameters that were not passed
    do not appear at all. Values changed by the hook are what `__init__` finally receives.
    """

    @component
    class MockComponent:
        def __init__(self, pos_arg1, pos_arg2, pos_arg3=None, *, kwarg1=1, kwarg2="string"):
            self.pos_arg1 = pos_arg1
            self.pos_arg2 = pos_arg2
            self.pos_arg3 = pos_arg3
            self.kwarg1 = kwarg1
            self.kwarg2 = kwarg2

        @component.output_types(output_value=int)
        def run(self, input_value: int) -> dict[str, int]:
            return {"output_value": input_value}

    # Records each hook invocation. The assertions inside the hooks only count if the hook
    # actually runs; without this record, a never-invoked hook would let the test pass silently.
    hook_calls = []

    def pre_init_hook(component_class, init_params, expected_params):
        hook_calls.append(component_class)
        assert component_class == MockComponent
        assert init_params == expected_params

    def pre_init_hook_modify(component_class, init_params, expected_params):
        hook_calls.append(component_class)
        assert component_class == MockComponent
        assert init_params == expected_params

        init_params["pos_arg1"] = 2
        init_params["pos_arg2"] = 0
        init_params["pos_arg3"] = "modified"
        init_params["kwarg2"] = "modified string"

    with _hook_component_init(partial(pre_init_hook, expected_params={"pos_arg1": 1, "pos_arg2": 2, "kwarg1": None})):
        _ = MockComponent(1, 2, kwarg1=None)
        assert hook_calls == [MockComponent]

    hook_calls.clear()
    with _hook_component_init(partial(pre_init_hook, expected_params={"pos_arg1": 1, "pos_arg2": 2, "pos_arg3": 0.01})):
        _ = MockComponent(pos_arg1=1, pos_arg2=2, pos_arg3=0.01)
        assert hook_calls == [MockComponent]

    hook_calls.clear()
    with _hook_component_init(
        partial(pre_init_hook_modify, expected_params={"pos_arg1": 0, "pos_arg2": 1, "pos_arg3": 0.01, "kwarg1": 0})
    ):
        c = MockComponent(0, 1, pos_arg3=0.01, kwarg1=0)
        assert hook_calls == [MockComponent]

        # The hook's edits reach `__init__`; kwarg1 was left untouched by the hook.
        assert c.pos_arg1 == 2
        assert c.pos_arg2 == 0
        assert c.pos_arg3 == "modified"
        assert c.kwarg1 == 0
        assert c.kwarg2 == "modified string"
504  
505  
def test_pre_init_hooking_variadic_positional_args():
    """A component taking ``*args`` works normally, but raises ComponentError while a pre-init hook is active."""

    @component
    class MockComponent:
        def __init__(self, *args, kwarg1=1, kwarg2="string"):
            self.args = args
            self.kwarg1 = kwarg1
            self.kwarg2 = kwarg2

        @component.output_types(output_value=int)
        def run(self, input_value: int) -> dict[str, int]:
            return {"output_value": input_value}

    def pre_init_hook(component_class, init_params, expected_params):
        assert component_class == MockComponent
        assert init_params == expected_params

    # Without a hook installed, variadic positional arguments are accepted as usual.
    instance = MockComponent(1, 2, 3, kwarg1=None)
    assert instance.args == (1, 2, 3)
    assert instance.kwarg1 is None
    assert instance.kwarg2 == "string"

    # With a hook installed, instantiating the same component raises ComponentError.
    hook = partial(pre_init_hook, expected_params={"args": (1, 2), "kwarg1": None})
    with pytest.raises(ComponentError), _hook_component_init(hook):
        MockComponent(1, 2, kwarg1=None)
532  
533  
def test_pre_init_hooking_variadic_kwargs():
    """Pre-init hooks see ``**kwargs`` entries as top-level init params and may edit or add them."""

    @component
    class MockComponent:
        def __init__(self, pos_arg1, pos_arg2=None, **kwargs):
            self.pos_arg1 = pos_arg1
            self.pos_arg2 = pos_arg2
            self.kwargs = kwargs

        @component.output_types(output_value=int)
        def run(self, input_value: int) -> dict[str, int]:
            return {"output_value": input_value}

    # Records each hook invocation so a hook that never runs cannot make the test pass silently.
    hook_calls = []

    def pre_init_hook(component_class, init_params, expected_params):
        hook_calls.append(component_class)
        assert component_class == MockComponent
        assert init_params == expected_params

    with _hook_component_init(
        partial(pre_init_hook, expected_params={"pos_arg1": 1, "kwarg1": None, "kwarg2": 10, "kwarg3": "string"})
    ):
        c = MockComponent(1, kwarg1=None, kwarg2=10, kwarg3="string")
        assert hook_calls == [MockComponent]
        assert c.pos_arg1 == 1
        assert c.pos_arg2 is None
        assert c.kwargs == {"kwarg1": None, "kwarg2": 10, "kwarg3": "string"}

    def pre_init_hook_modify(component_class, init_params, expected_params):
        hook_calls.append(component_class)
        assert component_class == MockComponent
        assert init_params == expected_params

        init_params["pos_arg1"] = 2
        init_params["pos_arg2"] = 0
        init_params["some_kwarg"] = "modified string"

    hook_calls.clear()
    with _hook_component_init(
        partial(
            pre_init_hook_modify,
            expected_params={"pos_arg1": 0, "pos_arg2": 1, "kwarg1": 999, "some_kwarg": "some_value"},
        )
    ):
        c = MockComponent(0, 1, kwarg1=999, some_kwarg="some_value")
        assert hook_calls == [MockComponent]

        # Edits to named params and to **kwargs entries both reach `__init__`.
        assert c.pos_arg1 == 2
        assert c.pos_arg2 == 0
        assert c.kwargs == {"kwarg1": 999, "some_kwarg": "modified string"}