test_pipeline_tool.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os
from unittest.mock import ANY

import pytest

from haystack import AsyncPipeline, Document, Pipeline
from haystack.components.agents import Agent
from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.rankers.sentence_transformers_similarity import SentenceTransformersSimilarityRanker
from haystack.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.dataclasses import ChatMessage
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.tools import PipelineTool

# Live tests talk to the OpenAI API, so they are skipped when no key is configured.
_needs_openai_key = pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")

# Tool-level description expected for the BM25 retriever + ranker sample pipeline.
_COMBINED_DESCRIPTION = (
    "A component that combines: 'bm25_retriever': Run the InMemoryBM25Retriever on the given "
    "input data., 'ranker': Returns a list of documents ranked by their similarity to the "
    "given query."
)

# Description expected for the ``query`` parameter when it feeds both components.
_QUERY_DESCRIPTION = (
    "Provided to the 'bm25_retriever' component as: 'The query string for the "
    "Retriever.', and Provided to the 'ranker' component as: 'The input query to "
    "compare the documents to.'."
)


def _wire_bm25_ranker(pipeline):
    """Add a BM25 retriever connected to a similarity ranker to ``pipeline`` and return it."""
    pipeline.add_component("bm25_retriever", InMemoryBM25Retriever(document_store=InMemoryDocumentStore()))
    pipeline.add_component("ranker", SentenceTransformersSimilarityRanker(model="fake-model"))
    pipeline.connect("bm25_retriever", "ranker")
    return pipeline


def _make_tool(pipeline, **extra):
    """Build a ``PipelineTool`` with the default test mappings, forwarding ``extra`` keyword arguments."""
    return PipelineTool(
        pipeline=pipeline,
        input_mapping={"query": ["bm25_retriever.query"]},
        output_mapping={"ranker.documents": "documents"},
        name="test_tool",
        description="A test tool",
        **extra,
    )


def _expected_tool_dict(pipeline_dict, is_async):
    """Return the serialized form expected from ``to_dict`` for a tool built by ``_make_tool``."""
    return {
        "type": "haystack.tools.pipeline_tool.PipelineTool",
        "data": {
            "pipeline": pipeline_dict,
            "name": "test_tool",
            "input_mapping": {"query": ["bm25_retriever.query"]},
            "output_mapping": {"ranker.documents": "documents"},
            "description": "A test tool",
            "parameters": None,
            "inputs_from_state": None,
            "outputs_to_state": None,
            "is_pipeline_async": is_async,
            "outputs_to_string": None,
        },
    }


def _assert_round_trip(original, pipeline_cls):
    """Serialize ``original``, rebuild it with ``from_dict`` and compare the two tools."""
    restored = PipelineTool.from_dict(original.to_dict())

    assert original.name == restored.name
    assert original.description == restored.description
    assert original._input_mapping == restored._input_mapping
    assert original._output_mapping == restored._output_mapping
    assert original.parameters == restored.parameters
    assert isinstance(restored._pipeline, pipeline_cls)


def _index_tesla_documents(doc_store):
    """Embed two documents about Nikola Tesla with OpenAI and write them to ``doc_store``."""
    embedder = OpenAIDocumentEmbedder()
    tesla_docs = [
        Document(content="Nikola Tesla was a Serbian-American inventor and electrical engineer."),
        Document(
            content="He is best known for his contributions to the design of the modern alternating current (AC) "
            "electricity supply system."
        ),
    ]
    embedded_docs = embedder.run(documents=tesla_docs)["documents"]
    doc_store.write_documents(embedded_docs)


def _make_retriever_tool(pipeline, doc_store, description):
    """Wire an embedder -> retriever flow into ``pipeline`` and wrap it as the ``document_retriever`` tool."""
    pipeline.add_component("embedder", OpenAITextEmbedder())
    pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store=doc_store))
    pipeline.connect("embedder.embedding", "retriever.query_embedding")

    return PipelineTool(
        pipeline=pipeline,
        input_mapping={"query": ["embedder.text"]},
        output_mapping={"retriever.documents": "documents"},
        name="document_retriever",
        description=description,
    )


def _make_tesla_agent(tool):
    """Create an Agent that is instructed to answer Tesla questions through ``tool``."""
    return Agent(
        chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"),
        system_prompt="For any questions about Nikola Tesla, always use the document_retriever.",
        tools=[tool],
    )


def _assert_tesla_answer(result):
    """Check the length of the agent conversation and that the final message mentions Nikola."""
    messages = result["messages"]
    assert len(messages) == 5  # System msg, User msg, Agent msg, Tool call result, Agent msg
    assert "nikola" in messages[-1].text.lower()


@pytest.fixture
def sample_pipeline():
    """Synchronous pipeline: BM25 retriever connected to a similarity ranker."""
    return _wire_bm25_ranker(Pipeline())


@pytest.fixture
def sample_async_pipeline():
    """Asynchronous pipeline: BM25 retriever connected to a similarity ranker."""
    return _wire_bm25_ranker(AsyncPipeline())


@pytest.fixture
def sample_pipeline_dict():
    """Serialized form that the tests expect for the sample pipelines."""
    document_store = {
        "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
        "init_parameters": {
            "bm25_tokenization_regex": "(?u)\\b\\w+\\b",
            "bm25_algorithm": "BM25L",
            "bm25_parameters": {},
            "embedding_similarity_function": "dot_product",
            "index": ANY,
            "return_embedding": True,
        },
    }
    bm25_retriever = {
        "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
        "init_parameters": {
            "document_store": document_store,
            "filters": None,
            "top_k": 10,
            "scale_score": False,
            "filter_policy": "replace",
        },
    }
    ranker = {
        "type": "haystack.components.rankers.sentence_transformers_similarity."
        "SentenceTransformersSimilarityRanker",
        "init_parameters": {
            "device": {"type": "single", "device": ANY},
            "model": "fake-model",
            "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
            "top_k": 10,
            "query_prefix": "",
            "query_suffix": "",
            "document_prefix": "",
            "document_suffix": "",
            "meta_fields_to_embed": [],
            "embedding_separator": "\n",
            "scale_score": True,
            "score_threshold": None,
            "trust_remote_code": False,
            "model_kwargs": None,
            "tokenizer_kwargs": None,
            "config_kwargs": None,
            "backend": "torch",
            "batch_size": 16,
        },
    }
    return {
        "metadata": {},
        "max_runs_per_component": 100,
        "components": {"bm25_retriever": bm25_retriever, "ranker": ranker},
        "connections": [{"sender": "bm25_retriever.documents", "receiver": "ranker.documents"}],
        "connection_type_validation": True,
    }


class TestPipelineTool:
    """Tests for construction, serialization, schema generation and live use of ``PipelineTool``."""

    def test_init_invalid_pipeline(self):
        """A non-pipeline ``pipeline`` argument is rejected with a TypeError."""
        expected_message = "The 'pipeline' parameter must be an instance of Pipeline or AsyncPipeline."
        with pytest.raises(TypeError, match=expected_message):
            PipelineTool(pipeline="invalid_pipeline", name="test_tool", description="A test tool")

    def test_to_dict(self, sample_pipeline, sample_pipeline_dict):
        """``to_dict`` on a sync pipeline tool reports ``is_pipeline_async`` as False."""
        serialized = _make_tool(sample_pipeline).to_dict()
        assert serialized == _expected_tool_dict(sample_pipeline_dict, is_async=False)

    def test_to_dict_async_pipeline(self, sample_async_pipeline, sample_pipeline_dict):
        """``to_dict`` on an async pipeline tool reports ``is_pipeline_async`` as True."""
        serialized = _make_tool(sample_async_pipeline).to_dict()
        assert serialized == _expected_tool_dict(sample_pipeline_dict, is_async=True)

    def test_from_dict(self, sample_pipeline):
        """A sync pipeline tool survives a ``to_dict``/``from_dict`` round trip."""
        _assert_round_trip(_make_tool(sample_pipeline), Pipeline)

    def test_from_dict_async_pipeline(self, sample_async_pipeline):
        """An async pipeline tool survives a ``to_dict``/``from_dict`` round trip."""
        _assert_round_trip(_make_tool(sample_async_pipeline), AsyncPipeline)

    def test_auto_generated_tool_params(self, sample_pipeline):
        """With explicit mappings, the generated schema exposes only the mapped ``query`` input."""
        tool = PipelineTool(
            pipeline=sample_pipeline,
            input_mapping={"query": ["bm25_retriever.query", "ranker.query"]},
            output_mapping={"ranker.documents": "documents"},
            name="test_tool",
            description="A test tool",
        )

        expected_schema = {
            "description": _COMBINED_DESCRIPTION,
            "properties": {"query": {"description": _QUERY_DESCRIPTION, "type": "string"}},
            "required": ["query"],
            "type": "object",
        }
        assert tool.parameters == expected_schema

    def test_auto_generated_tool_params_no_mappings(self, sample_pipeline):
        """Without mappings, the generated schema contains the five pipeline inputs listed below."""
        tool = PipelineTool(pipeline=sample_pipeline, name="test_tool", description="A test tool")

        expected_properties = {
            "query": {"description": _QUERY_DESCRIPTION, "type": "string"},
            "filters": {
                "anyOf": [{"additionalProperties": True, "type": "object"}, {"type": "null"}],
                "description": "Provided to the 'bm25_retriever' component as: 'A dictionary with filters to "
                "narrow down the search space when retrieving documents.'.",
            },
            "top_k": {
                "anyOf": [{"type": "integer"}, {"type": "null"}],
                "description": "Provided to the 'bm25_retriever' component as: 'The maximum number of documents "
                "to return.', and Provided to the 'ranker' component as: 'The maximum number "
                "of documents to return.'.",
            },
            "scale_score": {
                "description": "Provided to the 'bm25_retriever' component as: 'When `True`, scales the score "
                "of retrieved documents to a range of 0 to 1, where 1 means extremely relevant."
                "\nWhen `False`, uses raw similarity scores.', and Provided to the 'ranker' "
                "component as: 'If `True`, scales the raw logit predictions using a Sigmoid "
                "activation function.\nIf `False`, disables scaling of the raw logit predictions."
                "\nIf set, overrides the value set at initialization.'.",
                "anyOf": [{"type": "boolean"}, {"type": "null"}],
            },
            "score_threshold": {
                "anyOf": [{"type": "number"}, {"type": "null"}],
                "description": "Provided to the 'ranker' component as: 'Use it to return documents only with "
                "a score above this threshold.\nIf set, overrides the value set at initialization.'"
                ".",
            },
        }
        assert tool.parameters == {
            "description": _COMBINED_DESCRIPTION,
            "properties": expected_properties,
            "required": ["query"],
            "type": "object",
        }

    @pytest.mark.integration
    @_needs_openai_key
    def test_live_pipeline_tool(self, in_memory_doc_store):
        """An Agent answers a Tesla question through a sync retrieval pipeline tool."""
        # Populate the store before the retrieval pipeline is built on top of it.
        _index_tesla_documents(in_memory_doc_store)

        retriever_tool = _make_retriever_tool(
            Pipeline(),
            in_memory_doc_store,
            description="This tool retrieves documents relevant to Nikola Tesla from the document store",
        )
        agent = _make_tesla_agent(retriever_tool)

        result = agent.run([ChatMessage.from_user("Who was Nikola Tesla?")])

        _assert_tesla_answer(result)

    @pytest.mark.asyncio
    @pytest.mark.integration
    @_needs_openai_key
    async def test_live_async_pipeline_tool(self, in_memory_doc_store):
        """An Agent answers a Tesla question through an async retrieval pipeline tool."""
        # Populate the store before the retrieval pipeline is built on top of it.
        _index_tesla_documents(in_memory_doc_store)

        retriever_tool = _make_retriever_tool(
            AsyncPipeline(),
            in_memory_doc_store,
            description="For any questions about Nikola Tesla, always use this tool",
        )
        agent = _make_tesla_agent(retriever_tool)

        result = await agent.run_async([ChatMessage.from_user("Who was Nikola Tesla?")])

        _assert_tesla_answer(result)

    def test_pipeline_tool_with_valid_inputs_from_state(self, sample_pipeline):
        """An ``inputs_from_state`` mapping that targets a known tool input is kept as given."""
        state_inputs = {"user_query": "query"}
        tool = _make_tool(sample_pipeline, inputs_from_state={"user_query": "query"})
        assert tool.inputs_from_state == state_inputs

    def test_pipeline_tool_with_invalid_inputs_from_state(self, sample_pipeline):
        """An ``inputs_from_state`` mapping that targets an unknown parameter raises ValueError."""
        with pytest.raises(ValueError, match="unknown parameter 'nonexistent'"):
            _make_tool(sample_pipeline, inputs_from_state={"user_query": "nonexistent"})

    def test_pipeline_tool_with_invalid_inputs_from_state_nested_dict(self, sample_pipeline):
        """A nested-dict value in ``inputs_from_state`` raises TypeError."""
        with pytest.raises(TypeError, match="must be str, not dict"):
            _make_tool(sample_pipeline, inputs_from_state={"user_query": {"source": "query"}})

    def test_pipeline_tool_with_valid_outputs_to_state(self, sample_pipeline):
        """An ``outputs_to_state`` mapping that reads a known tool output is kept as given."""
        state_outputs = {"result_docs": {"source": "documents"}}
        tool = _make_tool(sample_pipeline, outputs_to_state={"result_docs": {"source": "documents"}})
        assert tool.outputs_to_state == state_outputs

    def test_pipeline_tool_with_invalid_outputs_to_state(self, sample_pipeline):
        """An ``outputs_to_state`` mapping that reads an unknown output raises ValueError."""
        with pytest.raises(ValueError, match="unknown output"):
            _make_tool(sample_pipeline, outputs_to_state={"result": {"source": "nonexistent"}})