test_super_component.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

"""Tests for ``SuperComponent``: mapping validation, socket typing, execution, serialization and drawing."""

from typing import Any
from unittest.mock import patch

import pytest

from haystack import AsyncPipeline, Document, Pipeline, SuperComponent, component, super_component
from haystack.components.builders import AnswerBuilder, PromptBuilder
from haystack.components.generators import OpenAIGenerator
from haystack.components.joiners import DocumentJoiner
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.core.pipeline.base import component_from_dict, component_to_dict
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.core.super_component.super_component import InvalidMappingTypeError, InvalidMappingValueError
from haystack.dataclasses import GeneratedAnswer
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy
from haystack.testing.sample_components import AddFixedValue, Double
from haystack.utils.auth import Secret


def _build_rag_pipeline(
    pipeline: Pipeline | AsyncPipeline, document_store: InMemoryDocumentStore
) -> Pipeline | AsyncPipeline:
    """
    Add the RAG test components to ``pipeline`` and connect them.

    Shared by the sync and async pipeline fixtures so both are wired identically:
    retriever -> prompt_builder -> llm -> answer_builder, plus a joiner fed by the retriever's documents.

    :param pipeline: An empty ``Pipeline`` or ``AsyncPipeline`` to populate in place.
    :param document_store: The store the BM25 retriever reads from.
    :returns: The same ``pipeline`` object, populated and connected.
    """

    # Defined locally (as in the original fixtures) so every pipeline gets its own stub generator class.
    @component
    class FakeGenerator:
        @component.output_types(replies=list[str])
        def run(self, prompt: str, **kwargs: Any) -> dict[str, list[str]]:
            return {"replies": ["This is a test response about capitals."]}

    pipeline.add_component("retriever", InMemoryBM25Retriever(document_store=document_store))
    pipeline.add_component(
        "prompt_builder",
        PromptBuilder(
            template="Given these documents: {{documents|join(', ',attribute='content')}} Answer: {{query}}",
            required_variables="*",
        ),
    )
    pipeline.add_component("llm", FakeGenerator())
    pipeline.add_component("answer_builder", AnswerBuilder())
    pipeline.add_component("joiner", DocumentJoiner())

    pipeline.connect("retriever", "prompt_builder.documents")
    pipeline.connect("prompt_builder", "llm")
    pipeline.connect("llm.replies", "answer_builder.replies")
    pipeline.connect("retriever.documents", "joiner.documents")

    return pipeline


@pytest.fixture
def mock_openai_generator(monkeypatch):
    """Create an OpenAIGenerator whose ``run`` is patched to return a canned reply."""

    def mock_run(self: Any, prompt: str, **kwargs: Any) -> dict[str, list[str]]:
        return {"replies": ["This is a test response about capitals."]}

    monkeypatch.setattr(OpenAIGenerator, "run", mock_run)
    return OpenAIGenerator(api_key=Secret.from_token("test-key"))


@pytest.fixture
def documents():
    """Create test documents for the document store."""
    return [
        Document(content="Paris is the capital of France."),
        Document(content="Berlin is the capital of Germany."),
        Document(content="Rome is the capital of Italy."),
    ]


@pytest.fixture
def document_store(documents):
    """Create and populate a test document store, shutting it down after the test."""
    store = InMemoryDocumentStore()
    store.write_documents(documents, policy=DuplicatePolicy.OVERWRITE)
    yield store
    store.shutdown()


@pytest.fixture
def rag_pipeline(document_store):
    """Create a simple RAG pipeline."""
    return _build_rag_pipeline(Pipeline(), document_store)


@pytest.fixture
def async_rag_pipeline(document_store):
    """Create a simple async RAG pipeline."""
    return _build_rag_pipeline(AsyncPipeline(), document_store)


@pytest.fixture
def sample_super_component():
    """Creates a sample SuperComponent for testing visualization methods"""
    pipe = Pipeline()
    pipe.add_component("comp1", AddFixedValue(add=3))
    pipe.add_component("comp2", Double())
    pipe.connect("comp1.result", "comp2.value")

    return SuperComponent(pipeline=pipe)


@super_component
class CustomSuperComponent:
    """Minimal ``@super_component`` class wrapping a single-joiner pipeline, used by the (de)serialization tests."""

    def __init__(self, var1: int, var2: str = "test"):
        self.var1 = var1
        self.var2 = var2
        pipeline = Pipeline()
        pipeline.add_component("joiner", DocumentJoiner())
        self.pipeline = pipeline


class TestSuperComponent:
    def test_split_component_path(self):
        """A dotted path is split into a (component, socket) tuple."""
        path = "router.chat_query"
        components = SuperComponent._split_component_path(path)
        assert components == ("router", "chat_query")

    def test_split_component_path_error(self):
        """A path without a socket part raises InvalidMappingValueError."""
        path = "router"
        with pytest.raises(InvalidMappingValueError):
            SuperComponent._split_component_path(path)

    def test_invalid_input_mapping_type(self, rag_pipeline):
        """An input mapping whose value is not a list raises InvalidMappingTypeError."""
        input_mapping = {"search_query": "not_a_list"}  # Should be a list
        with pytest.raises(InvalidMappingTypeError):
            SuperComponent(pipeline=rag_pipeline, input_mapping=input_mapping)  # type: ignore[arg-type]

    def test_invalid_input_mapping_value(self, rag_pipeline):
        """An input mapping that targets a component missing from the pipeline raises InvalidMappingValueError."""
        input_mapping = {"search_query": ["nonexistent_component.query"]}
        with pytest.raises(InvalidMappingValueError):
            SuperComponent(pipeline=rag_pipeline, input_mapping=input_mapping)

    def test_invalid_output_mapping_type(self, rag_pipeline):
        """An output mapping whose value is not a string raises InvalidMappingTypeError."""
        output_mapping = {"answer_builder.answers": 123}  # Should be a string
        with pytest.raises(InvalidMappingTypeError):
            SuperComponent(pipeline=rag_pipeline, output_mapping=output_mapping)  # type: ignore[arg-type]

    def test_invalid_output_mapping_value(self, rag_pipeline):
        """An output mapping that reads from a component missing from the pipeline raises InvalidMappingValueError."""
        output_mapping = {"nonexistent_component.answers": "final_answers"}
        with pytest.raises(InvalidMappingValueError):
            SuperComponent(pipeline=rag_pipeline, output_mapping=output_mapping)

    def test_duplicate_output_names(self, rag_pipeline):
        """Mapping two different pipeline outputs to the same wrapper output name is rejected."""
        output_mapping = {
            "answer_builder.answers": "final_answers",
            "llm.replies": "final_answers",  # Different path but same output name
        }
        with pytest.raises(InvalidMappingValueError):
            SuperComponent(pipeline=rag_pipeline, output_mapping=output_mapping)

    def test_explicit_input_mapping(self, rag_pipeline):
        """An explicit input mapping exposes only the mapped input name, typed as ``str``."""
        input_mapping = {"search_query": ["retriever.query", "prompt_builder.query", "answer_builder.query"]}
        wrapper = SuperComponent(pipeline=rag_pipeline, input_mapping=input_mapping)
        input_sockets = wrapper.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
        assert set(input_sockets.keys()) == {"search_query"}
        assert input_sockets["search_query"].type == str

    def test_explicit_output_mapping(self, rag_pipeline):
        """An explicit output mapping exposes only the mapped output name with the source socket's type."""
        output_mapping = {"answer_builder.answers": "final_answers"}
        wrapper = SuperComponent(pipeline=rag_pipeline, output_mapping=output_mapping)
        output_sockets = wrapper.__haystack_output__._sockets_dict  # type: ignore[attr-defined]
        assert set(output_sockets.keys()) == {"final_answers"}
        assert output_sockets["final_answers"].type == list[GeneratedAnswer]

    def test_auto_input_mapping(self, rag_pipeline):
        """Without an input mapping, the wrapper exposes the expected set of input names."""
        wrapper = SuperComponent(pipeline=rag_pipeline)
        input_sockets = wrapper.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
        assert set(input_sockets.keys()) == {
            "documents",
            "filters",
            "meta",
            "pattern",
            "query",
            "reference_pattern",
            "scale_score",
            "template",
            "template_variables",
            "top_k",
        }

    def test_auto_output_mapping(self, rag_pipeline):
        """Without an output mapping, the wrapper exposes ``answers`` and ``documents``."""
        wrapper = SuperComponent(pipeline=rag_pipeline)
        output_sockets = wrapper.__haystack_output__._sockets_dict  # type: ignore[attr-defined]
        assert set(output_sockets.keys()) == {"answers", "documents"}

    def test_auto_mapping_sockets(self, rag_pipeline):
        """Auto-mapped input and output sockets carry the expected names and types."""
        wrapper = SuperComponent(pipeline=rag_pipeline)

        output_sockets = wrapper.__haystack_output__._sockets_dict  # type: ignore[attr-defined]
        assert set(output_sockets.keys()) == {"answers", "documents"}
        assert output_sockets["answers"].type == list[GeneratedAnswer]

        input_sockets = wrapper.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
        assert set(input_sockets.keys()) == {
            "documents",
            "filters",
            "meta",
            "pattern",
            "query",
            "reference_pattern",
            "scale_score",
            "template",
            "template_variables",
            "top_k",
        }
        assert input_sockets["query"].type == str

    def test_super_component_run(self, rag_pipeline):
        """Running the wrapper returns GeneratedAnswer objects under the mapped output name."""
        input_mapping = {"search_query": ["retriever.query", "prompt_builder.query", "answer_builder.query"]}
        output_mapping = {"answer_builder.answers": "final_answers"}
        wrapper = SuperComponent(pipeline=rag_pipeline, input_mapping=input_mapping, output_mapping=output_mapping)
        wrapper.warm_up()
        result = wrapper.run(search_query="What is the capital of France?")
        assert "final_answers" in result
        assert isinstance(result["final_answers"][0], GeneratedAnswer)

    @pytest.mark.asyncio
    async def test_super_component_run_async(self, async_rag_pipeline):
        """``run_async`` on a wrapper around an AsyncPipeline returns GeneratedAnswer objects."""
        input_mapping = {"search_query": ["retriever.query", "prompt_builder.query", "answer_builder.query"]}
        output_mapping = {"answer_builder.answers": "final_answers"}
        wrapper = SuperComponent(
            pipeline=async_rag_pipeline, input_mapping=input_mapping, output_mapping=output_mapping
        )
        wrapper.warm_up()
        result = await wrapper.run_async(search_query="What is the capital of France?")
        assert "final_answers" in result
        assert isinstance(result["final_answers"][0], GeneratedAnswer)

    def test_wrapper_serialization(self, document_store):
        """Test serialization and deserialization of pipeline wrapper."""
        pipeline = Pipeline()
        pipeline.add_component("retriever", InMemoryBM25Retriever(document_store=document_store))

        wrapper = SuperComponent(
            pipeline=pipeline,
            input_mapping={"query": ["retriever.query"]},
            output_mapping={"retriever.documents": "documents"},
        )

        # Test serialization
        serialized = wrapper.to_dict()
        assert "type" in serialized
        assert "init_parameters" in serialized
        assert "pipeline" in serialized["init_parameters"]
        assert serialized["init_parameters"]["is_pipeline_async"] is False

        # Test deserialization
        deserialized = SuperComponent.from_dict(serialized)
        assert isinstance(deserialized, SuperComponent)
        assert deserialized.input_mapping == wrapper.input_mapping
        assert deserialized.output_mapping == wrapper.output_mapping

        deserialized.warm_up()
        result = deserialized.run(query="What is the capital of France?")
        assert "documents" in result
        assert result["documents"][0].content == "Paris is the capital of France."

    def test_subclass_serialization(self, rag_pipeline):
        """
        Test serialization of a SuperComponent subclass with its own ``to_dict``.

        The subclass serializes under its own type name, while ``_to_super_component_dict`` still matches what a
        plain SuperComponent built from the same pipeline serializes to.
        """
        super_comp = SuperComponent(rag_pipeline)
        serialized = super_comp.to_dict()

        @component
        class CustomSuperComponent(SuperComponent):
            def __init__(self, pipeline, instance_attribute="test"):
                self.instance_attribute = instance_attribute
                super(CustomSuperComponent, self).__init__(pipeline)  # noqa: UP008

            def to_dict(self):
                return default_to_dict(
                    self, instance_attribute=self.instance_attribute, pipeline=self.pipeline.to_dict()
                )

            @classmethod
            def from_dict(cls, data):
                data["init_parameters"]["pipeline"] = Pipeline.from_dict(data["init_parameters"]["pipeline"])
                return default_from_dict(cls, data)

        custom_super_component = CustomSuperComponent(rag_pipeline)
        custom_serialized = custom_super_component.to_dict()

        assert custom_serialized["type"] == "test_super_component.CustomSuperComponent"
        assert custom_super_component._to_super_component_dict() == serialized

    def test_super_component_non_leaf_output(self, rag_pipeline):
        """Outputs of a non-leaf component can be exposed alongside leaf outputs."""
        # 'retriever' is not a leaf, but should now be allowed
        output_mapping = {"retriever.documents": "retrieved_docs", "answer_builder.answers": "final_answers"}
        wrapper = SuperComponent(pipeline=rag_pipeline, output_mapping=output_mapping)
        wrapper.warm_up()
        result = wrapper.run(query="What is the capital of France?")
        assert "final_answers" in result  # leaf output
        assert "retrieved_docs" in result  # non-leaf output
        assert isinstance(result["retrieved_docs"][0], Document)

    def test_custom_super_component_to_dict(self):
        """A ``@super_component`` class serializes to its type name and its init parameters."""
        custom_super_component = CustomSuperComponent(1)
        data = component_to_dict(custom_super_component, "custom_super_component")
        assert data == {
            "type": "test_super_component.CustomSuperComponent",
            "init_parameters": {"var1": 1, "var2": "test"},
        }

    def test_custom_super_component_from_dict(self):
        """A ``@super_component`` class is rebuilt from its serialized init parameters."""
        data = {"type": "test_super_component.CustomSuperComponent", "init_parameters": {"var1": 1, "var2": "test"}}
        custom_super_component = component_from_dict(CustomSuperComponent, data, "custom_super_component")
        assert isinstance(custom_super_component, CustomSuperComponent)
        assert custom_super_component.var1 == 1
        assert custom_super_component.var2 == "test"

    @patch("haystack.core.pipeline.Pipeline.show")
    def test_show_delegates_to_pipeline(self, mock_show, sample_super_component):
        """Test that SuperComponent.show() correctly delegates to Pipeline.show() with all parameters"""

        server_url = "https://custom.mermaid.server"
        params = {"theme": "dark", "format": "svg"}
        timeout = 60

        sample_super_component.show(server_url=server_url, params=params, timeout=timeout)
        mock_show.assert_called_once_with(server_url=server_url, params=params, timeout=timeout)

    @patch("haystack.core.pipeline.Pipeline.draw")
    def test_draw_delegates_to_pipeline(self, mock_draw, sample_super_component, tmp_path):
        """Test that SuperComponent.draw() correctly delegates to Pipeline.draw() with all parameters"""

        path = tmp_path / "test_pipeline.png"
        server_url = "https://custom.mermaid.server"
        params = {"theme": "dark", "format": "png"}
        timeout = 60

        sample_super_component.draw(path=path, server_url=server_url, params=params, timeout=timeout)
        mock_draw.assert_called_once_with(path=path, server_url=server_url, params=params, timeout=timeout)

    @patch("haystack.core.pipeline.Pipeline.show")
    def test_show_with_default_parameters(self, mock_show, sample_super_component):
        """Test that SuperComponent.show() works with default parameters"""

        sample_super_component.show()
        mock_show.assert_called_once_with(server_url="https://mermaid.ink", params=None, timeout=30)

    @patch("haystack.core.pipeline.Pipeline.draw")
    def test_draw_with_default_parameters(self, mock_draw, sample_super_component, tmp_path):
        """Test that SuperComponent.draw() works with default parameters except path"""

        path = tmp_path / "test_pipeline.png"

        sample_super_component.draw(path=path)
        mock_draw.assert_called_once_with(path=path, server_url="https://mermaid.ink", params=None, timeout=30)

    def test_input_types_reconciliation(self):
        """Test that input types are properly reconciled when they are compatible but not identical."""

        @component
        class TypeTestComponent:
            @component.output_types(result_int=int, result_any=Any)
            def run(self, input_int: int, input_any: Any) -> dict[str, Any]:
                return {"result_int": input_int, "result_any": input_any}

        pipeline = Pipeline()
        pipeline.add_component("test1", TypeTestComponent())
        pipeline.add_component("test2", TypeTestComponent())

        input_mapping = {"number": ["test1.input_int", "test2.input_any"]}
        output_mapping = {"test2.result_int": "result_int"}
        wrapper = SuperComponent(pipeline=pipeline, input_mapping=input_mapping, output_mapping=output_mapping)

        input_sockets = wrapper.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
        assert "number" in input_sockets
        assert input_sockets["number"].type == int

    def test_union_type_reconciliation(self):
        """Test that Union types are properly reconciled when creating a SuperComponent."""

        @component
        class UnionTypeComponent1:
            @component.output_types(result=int | str)
            def run(self, inp: int | str) -> dict[str, int | str]:
                return {"result": inp}

        @component
        class UnionTypeComponent2:
            @component.output_types(result=float | str)
            def run(self, inp: float | str) -> dict[str, float | str]:
                return {"result": inp}

        pipeline = Pipeline()
        pipeline.add_component("test1", UnionTypeComponent1())
        pipeline.add_component("test2", UnionTypeComponent2())

        input_mapping = {"data": ["test1.inp", "test2.inp"]}
        output_mapping = {"test2.result": "result"}
        wrapper = SuperComponent(pipeline=pipeline, input_mapping=input_mapping, output_mapping=output_mapping)

        input_sockets = wrapper.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
        assert "data" in input_sockets
        assert input_sockets["data"].type == str

    def test_input_types_with_any(self):
        """Test that Any type is properly handled when reconciling types."""

        @component
        class AnyTypeComponent:
            @component.output_types(result=str)
            def run(self, specific: str, generic: Any) -> dict[str, str]:
                return {"result": specific}

        pipeline = Pipeline()
        pipeline.add_component("test", AnyTypeComponent())

        input_mapping = {"text": ["test.specific", "test.generic"]}
        wrapper = SuperComponent(pipeline=pipeline, input_mapping=input_mapping)

        input_sockets = wrapper.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
        assert "text" in input_sockets
        assert input_sockets["text"].type == str

    @pytest.mark.asyncio
    async def test_super_component_async_serialization_deserialization(self):
        """
        Test for async SuperComponent serialization and deserialization.

        Test that when using the SuperComponent class, a SuperComponent based on an async pipeline can be serialized and
        deserialized correctly.
        """

        @component
        class AsyncComponent:
            @component.output_types(output=str)
            def run(self):
                return {"output": "irrelevant"}

            @component.output_types(output=str)
            async def run_async(self):
                return {"output": "Hello world"}

        pipeline = AsyncPipeline()
        pipeline.add_component("hello", AsyncComponent())

        async_super_component = SuperComponent(pipeline=pipeline)
        serialized_super_component = async_super_component.to_dict()
        assert serialized_super_component["init_parameters"]["is_pipeline_async"] is True

        deserialized_super_component = SuperComponent.from_dict(serialized_super_component)
        assert isinstance(deserialized_super_component.pipeline, AsyncPipeline)

        result = await deserialized_super_component.run_async()
        assert result == {"output": "Hello world"}