/ test / core / super_component / test_super_component.py
test_super_component.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from typing import Any
  6  from unittest.mock import patch
  7  
  8  import pytest
  9  
 10  from haystack import AsyncPipeline, Document, Pipeline, SuperComponent, component, super_component
 11  from haystack.components.builders import AnswerBuilder, PromptBuilder
 12  from haystack.components.generators import OpenAIGenerator
 13  from haystack.components.joiners import DocumentJoiner
 14  from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
 15  from haystack.core.pipeline.base import component_from_dict, component_to_dict
 16  from haystack.core.serialization import default_from_dict, default_to_dict
 17  from haystack.core.super_component.super_component import InvalidMappingTypeError, InvalidMappingValueError
 18  from haystack.dataclasses import GeneratedAnswer
 19  from haystack.document_stores.in_memory import InMemoryDocumentStore
 20  from haystack.document_stores.types import DuplicatePolicy
 21  from haystack.testing.sample_components import AddFixedValue, Double
 22  from haystack.utils.auth import Secret
 23  
 24  
 25  @pytest.fixture
 26  def mock_openai_generator(monkeypatch):
 27      """Create a mock OpenAI Generator for testing."""
 28  
 29      def mock_run(self: Any, prompt: str, **kwargs: Any) -> dict[str, list[str]]:
 30          return {"replies": ["This is a test response about capitals."]}
 31  
 32      monkeypatch.setattr(OpenAIGenerator, "run", mock_run)
 33      return OpenAIGenerator(api_key=Secret.from_token("test-key"))
 34  
 35  
 36  @pytest.fixture
 37  def documents():
 38      """Create test documents for the document store."""
 39      return [
 40          Document(content="Paris is the capital of France."),
 41          Document(content="Berlin is the capital of Germany."),
 42          Document(content="Rome is the capital of Italy."),
 43      ]
 44  
 45  
 46  @pytest.fixture
 47  def document_store(documents):
 48      """Create and populate a test document store."""
 49      store = InMemoryDocumentStore()
 50      store.write_documents(documents, policy=DuplicatePolicy.OVERWRITE)
 51      yield store
 52      store.shutdown()
 53  
 54  
 55  @pytest.fixture
 56  def rag_pipeline(document_store):
 57      """Create a simple RAG pipeline."""
 58  
 59      @component
 60      class FakeGenerator:
 61          @component.output_types(replies=list[str])
 62          def run(self, prompt: str, **kwargs: Any) -> dict[str, list[str]]:
 63              return {"replies": ["This is a test response about capitals."]}
 64  
 65      pipeline = Pipeline()
 66      pipeline.add_component("retriever", InMemoryBM25Retriever(document_store=document_store))
 67      pipeline.add_component(
 68          "prompt_builder",
 69          PromptBuilder(
 70              template="Given these documents: {{documents|join(', ',attribute='content')}} Answer: {{query}}",
 71              required_variables="*",
 72          ),
 73      )
 74      pipeline.add_component("llm", FakeGenerator())
 75      pipeline.add_component("answer_builder", AnswerBuilder())
 76      pipeline.add_component("joiner", DocumentJoiner())
 77  
 78      pipeline.connect("retriever", "prompt_builder.documents")
 79      pipeline.connect("prompt_builder", "llm")
 80      pipeline.connect("llm.replies", "answer_builder.replies")
 81      pipeline.connect("retriever.documents", "joiner.documents")
 82  
 83      return pipeline
 84  
 85  
 86  @pytest.fixture
 87  def async_rag_pipeline(document_store):
 88      """Create a simple asyncRAG pipeline."""
 89  
 90      @component
 91      class FakeGenerator:
 92          @component.output_types(replies=list[str])
 93          def run(self, prompt: str, **kwargs: Any) -> dict[str, list[str]]:
 94              return {"replies": ["This is a test response about capitals."]}
 95  
 96      pipeline = AsyncPipeline()
 97      pipeline.add_component("retriever", InMemoryBM25Retriever(document_store=document_store))
 98      pipeline.add_component(
 99          "prompt_builder",
100          PromptBuilder(
101              template="Given these documents: {{documents|join(', ',attribute='content')}} Answer: {{query}}",
102              required_variables="*",
103          ),
104      )
105      pipeline.add_component("llm", FakeGenerator())
106      pipeline.add_component("answer_builder", AnswerBuilder())
107      pipeline.add_component("joiner", DocumentJoiner())
108  
109      pipeline.connect("retriever", "prompt_builder.documents")
110      pipeline.connect("prompt_builder", "llm")
111      pipeline.connect("llm.replies", "answer_builder.replies")
112      pipeline.connect("retriever.documents", "joiner.documents")
113  
114      return pipeline
115  
116  
@pytest.fixture
def sample_super_component():
    """
    Return a SuperComponent wrapping a two-step pipeline, used by the visualization tests.

    Pipeline: comp1 (AddFixedValue(add=3)) -> comp2 (Double), connected result -> value.
    """
    two_step = Pipeline()
    for name, step in (("comp1", AddFixedValue(add=3)), ("comp2", Double())):
        two_step.add_component(name, step)
    two_step.connect("comp1.result", "comp2.value")
    return SuperComponent(pipeline=two_step)
126  
127  
@super_component
class CustomSuperComponent:
    """Decorator-based super component wrapping a single DocumentJoiner; used by the (de)serialization tests."""

    def __init__(self, var1: int, var2: str = "test"):
        # The wrapped pipeline is stored on `self.pipeline`, presumably where @super_component
        # expects to find it. TODO confirm against the decorator.
        joiner_only = Pipeline()
        joiner_only.add_component("joiner", DocumentJoiner())
        self.var1 = var1
        self.var2 = var2
        self.pipeline = joiner_only
136  
137  
138  class TestSuperComponent:
139      def test_split_component_path(self):
140          path = "router.chat_query"
141          components = SuperComponent._split_component_path(path)
142          assert components == ("router", "chat_query")
143  
144      def test_split_component_path_error(self):
145          path = "router"
146          with pytest.raises(InvalidMappingValueError):
147              SuperComponent._split_component_path(path)
148  
149      def test_invalid_input_mapping_type(self, rag_pipeline):
150          input_mapping = {"search_query": "not_a_list"}  # Should be a list
151          with pytest.raises(InvalidMappingTypeError):
152              SuperComponent(pipeline=rag_pipeline, input_mapping=input_mapping)  # type: ignore[arg-type]
153  
154      def test_invalid_input_mapping_value(self, rag_pipeline):
155          input_mapping = {"search_query": ["nonexistent_component.query"]}
156          with pytest.raises(InvalidMappingValueError):
157              SuperComponent(pipeline=rag_pipeline, input_mapping=input_mapping)
158  
159      def test_invalid_output_mapping_type(self, rag_pipeline):
160          output_mapping = {"answer_builder.answers": 123}  # Should be a string
161          with pytest.raises(InvalidMappingTypeError):
162              SuperComponent(pipeline=rag_pipeline, output_mapping=output_mapping)  # type: ignore[arg-type]
163  
164      def test_invalid_output_mapping_value(self, rag_pipeline):
165          output_mapping = {"nonexistent_component.answers": "final_answers"}
166          with pytest.raises(InvalidMappingValueError):
167              SuperComponent(pipeline=rag_pipeline, output_mapping=output_mapping)
168  
169      def test_duplicate_output_names(self, rag_pipeline):
170          output_mapping = {
171              "answer_builder.answers": "final_answers",
172              "llm.replies": "final_answers",  # Different path but same output name
173          }
174          with pytest.raises(InvalidMappingValueError):
175              SuperComponent(pipeline=rag_pipeline, output_mapping=output_mapping)
176  
177      def test_explicit_input_mapping(self, rag_pipeline):
178          input_mapping = {"search_query": ["retriever.query", "prompt_builder.query", "answer_builder.query"]}
179          wrapper = SuperComponent(pipeline=rag_pipeline, input_mapping=input_mapping)
180          input_sockets = wrapper.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
181          assert set(input_sockets.keys()) == {"search_query"}
182          assert input_sockets["search_query"].type == str
183  
184      def test_explicit_output_mapping(self, rag_pipeline):
185          output_mapping = {"answer_builder.answers": "final_answers"}
186          wrapper = SuperComponent(pipeline=rag_pipeline, output_mapping=output_mapping)
187          output_sockets = wrapper.__haystack_output__._sockets_dict  # type: ignore[attr-defined]
188          assert set(output_sockets.keys()) == {"final_answers"}
189          assert output_sockets["final_answers"].type == list[GeneratedAnswer]
190  
191      def test_auto_input_mapping(self, rag_pipeline):
192          wrapper = SuperComponent(pipeline=rag_pipeline)
193          input_sockets = wrapper.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
194          assert set(input_sockets.keys()) == {
195              "documents",
196              "filters",
197              "meta",
198              "pattern",
199              "query",
200              "reference_pattern",
201              "scale_score",
202              "template",
203              "template_variables",
204              "top_k",
205          }
206  
207      def test_auto_output_mapping(self, rag_pipeline):
208          wrapper = SuperComponent(pipeline=rag_pipeline)
209          output_sockets = wrapper.__haystack_output__._sockets_dict  # type: ignore[attr-defined]
210          assert set(output_sockets.keys()) == {"answers", "documents"}
211  
212      def test_auto_mapping_sockets(self, rag_pipeline):
213          wrapper = SuperComponent(pipeline=rag_pipeline)
214  
215          output_sockets = wrapper.__haystack_output__._sockets_dict  # type: ignore[attr-defined]
216          assert set(output_sockets.keys()) == {"answers", "documents"}
217          assert output_sockets["answers"].type == list[GeneratedAnswer]
218  
219          input_sockets = wrapper.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
220          assert set(input_sockets.keys()) == {
221              "documents",
222              "filters",
223              "meta",
224              "pattern",
225              "query",
226              "reference_pattern",
227              "scale_score",
228              "template",
229              "template_variables",
230              "top_k",
231          }
232          assert input_sockets["query"].type == str
233  
234      def test_super_component_run(self, rag_pipeline):
235          input_mapping = {"search_query": ["retriever.query", "prompt_builder.query", "answer_builder.query"]}
236          output_mapping = {"answer_builder.answers": "final_answers"}
237          wrapper = SuperComponent(pipeline=rag_pipeline, input_mapping=input_mapping, output_mapping=output_mapping)
238          wrapper.warm_up()
239          result = wrapper.run(search_query="What is the capital of France?")
240          assert "final_answers" in result
241          assert isinstance(result["final_answers"][0], GeneratedAnswer)
242  
243      @pytest.mark.asyncio
244      async def test_super_component_run_async(self, async_rag_pipeline):
245          input_mapping = {"search_query": ["retriever.query", "prompt_builder.query", "answer_builder.query"]}
246          output_mapping = {"answer_builder.answers": "final_answers"}
247          wrapper = SuperComponent(
248              pipeline=async_rag_pipeline, input_mapping=input_mapping, output_mapping=output_mapping
249          )
250          wrapper.warm_up()
251          result = await wrapper.run_async(search_query="What is the capital of France?")
252          assert "final_answers" in result
253          assert isinstance(result["final_answers"][0], GeneratedAnswer)
254  
255      def test_wrapper_serialization(self, document_store):
256          """Test serialization and deserialization of pipeline wrapper."""
257          pipeline = Pipeline()
258          pipeline.add_component("retriever", InMemoryBM25Retriever(document_store=document_store))
259  
260          wrapper = SuperComponent(
261              pipeline=pipeline,
262              input_mapping={"query": ["retriever.query"]},
263              output_mapping={"retriever.documents": "documents"},
264          )
265  
266          # Test serialization
267          serialized = wrapper.to_dict()
268          assert "type" in serialized
269          assert "init_parameters" in serialized
270          assert "pipeline" in serialized["init_parameters"]
271          assert serialized["init_parameters"]["is_pipeline_async"] is False
272  
273          # Test deserialization
274          deserialized = SuperComponent.from_dict(serialized)
275          assert isinstance(deserialized, SuperComponent)
276          assert deserialized.input_mapping == wrapper.input_mapping
277          assert deserialized.output_mapping == wrapper.output_mapping
278  
279          deserialized.warm_up()
280          result = deserialized.run(query="What is the capital of France?")
281          assert "documents" in result
282          assert result["documents"][0].content == "Paris is the capital of France."
283  
284      def test_subclass_serialization(self, rag_pipeline):
285          super_comp = SuperComponent(rag_pipeline)
286          serialized = super_comp.to_dict()
287  
288          @component
289          class CustomSuperComponent(SuperComponent):
290              def __init__(self, pipeline, instance_attribute="test"):
291                  self.instance_attribute = instance_attribute
292                  super(CustomSuperComponent, self).__init__(pipeline)  # noqa: UP008
293  
294              def to_dict(self):
295                  return default_to_dict(
296                      self, instance_attribute=self.instance_attribute, pipeline=self.pipeline.to_dict()
297                  )
298  
299              @classmethod
300              def from_dict(cls, data):
301                  data["init_parameters"]["pipeline"] = Pipeline.from_dict(data["init_parameters"]["pipeline"])
302                  return default_from_dict(cls, data)
303  
304          custom_super_component = CustomSuperComponent(rag_pipeline)
305          custom_serialized = custom_super_component.to_dict()
306  
307          assert custom_serialized["type"] == "test_super_component.CustomSuperComponent"
308          assert custom_super_component._to_super_component_dict() == serialized
309  
310      def test_super_component_non_leaf_output(self, rag_pipeline):
311          # 'retriever' is not a leaf, but should now be allowed
312          output_mapping = {"retriever.documents": "retrieved_docs", "answer_builder.answers": "final_answers"}
313          wrapper = SuperComponent(pipeline=rag_pipeline, output_mapping=output_mapping)
314          wrapper.warm_up()
315          result = wrapper.run(query="What is the capital of France?")
316          assert "final_answers" in result  # leaf output
317          assert "retrieved_docs" in result  # non-leaf output
318          assert isinstance(result["retrieved_docs"][0], Document)
319  
320      def test_custom_super_component_to_dict(self, rag_pipeline):
321          custom_super_component = CustomSuperComponent(1)
322          data = component_to_dict(custom_super_component, "custom_super_component")
323          assert data == {
324              "type": "test_super_component.CustomSuperComponent",
325              "init_parameters": {"var1": 1, "var2": "test"},
326          }
327  
328      def test_custom_super_component_from_dict(self):
329          data = {"type": "test_super_component.CustomSuperComponent", "init_parameters": {"var1": 1, "var2": "test"}}
330          custom_super_component = component_from_dict(CustomSuperComponent, data, "custom_super_component")
331          assert isinstance(custom_super_component, CustomSuperComponent)
332          assert custom_super_component.var1 == 1
333          assert custom_super_component.var2 == "test"
334  
335      @patch("haystack.core.pipeline.Pipeline.show")
336      def test_show_delegates_to_pipeline(self, mock_show, sample_super_component):
337          """Test that SuperComponent.show() correctly delegates to Pipeline.show() with all parameters"""
338  
339          server_url = "https://custom.mermaid.server"
340          params = {"theme": "dark", "format": "svg"}
341          timeout = 60
342  
343          sample_super_component.show(server_url=server_url, params=params, timeout=timeout)
344          mock_show.assert_called_once_with(server_url=server_url, params=params, timeout=timeout)
345  
346      @patch("haystack.core.pipeline.Pipeline.draw")
347      def test_draw_delegates_to_pipeline(self, mock_draw, sample_super_component, tmp_path):
348          """Test that SuperComponent.draw() correctly delegates to Pipeline.draw() with all parameters"""
349  
350          path = tmp_path / "test_pipeline.png"
351          server_url = "https://custom.mermaid.server"
352          params = {"theme": "dark", "format": "png"}
353          timeout = 60
354  
355          sample_super_component.draw(path=path, server_url=server_url, params=params, timeout=timeout)
356          mock_draw.assert_called_once_with(path=path, server_url=server_url, params=params, timeout=timeout)
357  
358      @patch("haystack.core.pipeline.Pipeline.show")
359      def test_show_with_default_parameters(self, mock_show, sample_super_component):
360          """Test that SuperComponent.show() works with default parameters"""
361  
362          sample_super_component.show()
363          mock_show.assert_called_once_with(server_url="https://mermaid.ink", params=None, timeout=30)
364  
365      @patch("haystack.core.pipeline.Pipeline.draw")
366      def test_draw_with_default_parameters(self, mock_draw, sample_super_component, tmp_path):
367          """Test that SuperComponent.draw() works with default parameters except path"""
368  
369          path = tmp_path / "test_pipeline.png"
370  
371          sample_super_component.draw(path=path)
372          mock_draw.assert_called_once_with(path=path, server_url="https://mermaid.ink", params=None, timeout=30)
373  
374      def test_input_types_reconciliation(self):
375          """Test that input types are properly reconciled when they are compatible but not identical."""
376  
377          @component
378          class TypeTestComponent:
379              @component.output_types(result_int=int, result_any=Any)
380              def run(self, input_int: int, input_any: Any) -> dict[str, Any]:
381                  return {"result_int": input_int, "result_any": input_any}
382  
383          pipeline = Pipeline()
384          pipeline.add_component("test1", TypeTestComponent())
385          pipeline.add_component("test2", TypeTestComponent())
386  
387          input_mapping = {"number": ["test1.input_int", "test2.input_any"]}
388          output_mapping = {"test2.result_int": "result_int"}
389          wrapper = SuperComponent(pipeline=pipeline, input_mapping=input_mapping, output_mapping=output_mapping)
390  
391          input_sockets = wrapper.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
392          assert "number" in input_sockets
393          assert input_sockets["number"].type == int
394  
395      def test_union_type_reconciliation(self):
396          """Test that Union types are properly reconciled when creating a SuperComponent."""
397  
398          @component
399          class UnionTypeComponent1:
400              @component.output_types(result=int | str)
401              def run(self, inp: int | str) -> dict[str, int | str]:
402                  return {"result": inp}
403  
404          @component
405          class UnionTypeComponent2:
406              @component.output_types(result=float | str)
407              def run(self, inp: float | str) -> dict[str, float | str]:
408                  return {"result": inp}
409  
410          pipeline = Pipeline()
411          pipeline.add_component("test1", UnionTypeComponent1())
412          pipeline.add_component("test2", UnionTypeComponent2())
413  
414          input_mapping = {"data": ["test1.inp", "test2.inp"]}
415          output_mapping = {"test2.result": "result"}
416          wrapper = SuperComponent(pipeline=pipeline, input_mapping=input_mapping, output_mapping=output_mapping)
417  
418          input_sockets = wrapper.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
419          assert "data" in input_sockets
420          assert input_sockets["data"].type == str
421  
422      def test_input_types_with_any(self):
423          """Test that Any type is properly handled when reconciling types."""
424  
425          @component
426          class AnyTypeComponent:
427              @component.output_types(result=str)
428              def run(self, specific: str, generic: Any) -> dict[str, str]:
429                  return {"result": specific}
430  
431          pipeline = Pipeline()
432          pipeline.add_component("test", AnyTypeComponent())
433  
434          input_mapping = {"text": ["test.specific", "test.generic"]}
435          wrapper = SuperComponent(pipeline=pipeline, input_mapping=input_mapping)
436  
437          input_sockets = wrapper.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
438          assert "text" in input_sockets
439          assert input_sockets["text"].type == str
440  
441      @pytest.mark.asyncio
442      async def test_super_component_async_serialization_deserialization(self):
443          """
444          Test for async SuperComponent serialization and deserialization.
445  
446          Test that when using the SuperComponent class, a SuperComponent based on an async pipeline can be serialized and
447          deserialized correctly.
448          """
449  
450          @component
451          class AsyncComponent:
452              @component.output_types(output=str)
453              def run(self):
454                  return {"output": "irrelevant"}
455  
456              @component.output_types(output=str)
457              async def run_async(self):
458                  return {"output": "Hello world"}
459  
460          pipeline = AsyncPipeline()
461          pipeline.add_component("hello", AsyncComponent())
462  
463          async_super_component = SuperComponent(pipeline=pipeline)
464          serialized_super_component = async_super_component.to_dict()
465          assert serialized_super_component["init_parameters"]["is_pipeline_async"] is True
466  
467          deserialized_super_component = SuperComponent.from_dict(serialized_super_component)
468          assert isinstance(deserialized_super_component.pipeline, AsyncPipeline)
469  
470          result = await deserialized_super_component.run_async()
471          assert result == {"output": "Hello world"}