test/tools/test_pipeline_tool.py
test_pipeline_tool.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import os
  6  from unittest.mock import ANY
  7  
  8  import pytest
  9  
 10  from haystack import AsyncPipeline, Document, Pipeline
 11  from haystack.components.agents import Agent
 12  from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
 13  from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
 14  from haystack.components.generators.chat import OpenAIChatGenerator
 15  from haystack.components.rankers.sentence_transformers_similarity import SentenceTransformersSimilarityRanker
 16  from haystack.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
 17  from haystack.dataclasses import ChatMessage
 18  from haystack.document_stores.in_memory import InMemoryDocumentStore
 19  from haystack.tools import PipelineTool
 20  
 21  
 22  @pytest.fixture
 23  def sample_pipeline():
 24      pipeline = Pipeline()
 25      pipeline.add_component("bm25_retriever", InMemoryBM25Retriever(document_store=InMemoryDocumentStore()))
 26      pipeline.add_component("ranker", SentenceTransformersSimilarityRanker(model="fake-model"))
 27      pipeline.connect("bm25_retriever", "ranker")
 28      return pipeline
 29  
 30  
 31  @pytest.fixture
 32  def sample_async_pipeline():
 33      pipeline = AsyncPipeline()
 34      pipeline.add_component("bm25_retriever", InMemoryBM25Retriever(document_store=InMemoryDocumentStore()))
 35      pipeline.add_component("ranker", SentenceTransformersSimilarityRanker(model="fake-model"))
 36      pipeline.connect("bm25_retriever", "ranker")
 37      return pipeline
 38  
 39  
 40  @pytest.fixture
 41  def sample_pipeline_dict():
 42      return {
 43          "metadata": {},
 44          "max_runs_per_component": 100,
 45          "components": {
 46              "bm25_retriever": {
 47                  "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
 48                  "init_parameters": {
 49                      "document_store": {
 50                          "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
 51                          "init_parameters": {
 52                              "bm25_tokenization_regex": "(?u)\\b\\w+\\b",
 53                              "bm25_algorithm": "BM25L",
 54                              "bm25_parameters": {},
 55                              "embedding_similarity_function": "dot_product",
 56                              "index": ANY,
 57                              "return_embedding": True,
 58                          },
 59                      },
 60                      "filters": None,
 61                      "top_k": 10,
 62                      "scale_score": False,
 63                      "filter_policy": "replace",
 64                  },
 65              },
 66              "ranker": {
 67                  "type": "haystack.components.rankers.sentence_transformers_similarity."
 68                  "SentenceTransformersSimilarityRanker",
 69                  "init_parameters": {
 70                      "device": {"type": "single", "device": ANY},
 71                      "model": "fake-model",
 72                      "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
 73                      "top_k": 10,
 74                      "query_prefix": "",
 75                      "query_suffix": "",
 76                      "document_prefix": "",
 77                      "document_suffix": "",
 78                      "meta_fields_to_embed": [],
 79                      "embedding_separator": "\n",
 80                      "scale_score": True,
 81                      "score_threshold": None,
 82                      "trust_remote_code": False,
 83                      "model_kwargs": None,
 84                      "tokenizer_kwargs": None,
 85                      "config_kwargs": None,
 86                      "backend": "torch",
 87                      "batch_size": 16,
 88                  },
 89              },
 90          },
 91          "connections": [{"sender": "bm25_retriever.documents", "receiver": "ranker.documents"}],
 92          "connection_type_validation": True,
 93      }
 94  
 95  
 96  class TestPipelineTool:
 97      def test_init_invalid_pipeline(self):
 98          with pytest.raises(
 99              TypeError, match="The 'pipeline' parameter must be an instance of Pipeline or AsyncPipeline."
100          ):
101              PipelineTool(pipeline="invalid_pipeline", name="test_tool", description="A test tool")
102  
103      def test_to_dict(self, sample_pipeline, sample_pipeline_dict):
104          tool = PipelineTool(
105              pipeline=sample_pipeline,
106              input_mapping={"query": ["bm25_retriever.query"]},
107              output_mapping={"ranker.documents": "documents"},
108              name="test_tool",
109              description="A test tool",
110          )
111  
112          tool_dict = tool.to_dict()
113          assert tool_dict == {
114              "type": "haystack.tools.pipeline_tool.PipelineTool",
115              "data": {
116                  "pipeline": sample_pipeline_dict,
117                  "name": "test_tool",
118                  "input_mapping": {"query": ["bm25_retriever.query"]},
119                  "output_mapping": {"ranker.documents": "documents"},
120                  "description": "A test tool",
121                  "parameters": None,
122                  "inputs_from_state": None,
123                  "outputs_to_state": None,
124                  "is_pipeline_async": False,
125                  "outputs_to_string": None,
126              },
127          }
128  
129      def test_to_dict_async_pipeline(self, sample_async_pipeline, sample_pipeline_dict):
130          tool = PipelineTool(
131              pipeline=sample_async_pipeline,
132              input_mapping={"query": ["bm25_retriever.query"]},
133              output_mapping={"ranker.documents": "documents"},
134              name="test_tool",
135              description="A test tool",
136          )
137  
138          tool_dict = tool.to_dict()
139          assert tool_dict == {
140              "type": "haystack.tools.pipeline_tool.PipelineTool",
141              "data": {
142                  "pipeline": sample_pipeline_dict,
143                  "name": "test_tool",
144                  "input_mapping": {"query": ["bm25_retriever.query"]},
145                  "output_mapping": {"ranker.documents": "documents"},
146                  "description": "A test tool",
147                  "parameters": None,
148                  "inputs_from_state": None,
149                  "outputs_to_state": None,
150                  "is_pipeline_async": True,
151                  "outputs_to_string": None,
152              },
153          }
154  
155      def test_from_dict(self, sample_pipeline):
156          tool = PipelineTool(
157              pipeline=sample_pipeline,
158              input_mapping={"query": ["bm25_retriever.query"]},
159              output_mapping={"ranker.documents": "documents"},
160              name="test_tool",
161              description="A test tool",
162          )
163  
164          tool_dict = tool.to_dict()
165          recreated_tool = PipelineTool.from_dict(tool_dict)
166  
167          assert tool.name == recreated_tool.name
168          assert tool.description == recreated_tool.description
169          assert tool._input_mapping == recreated_tool._input_mapping
170          assert tool._output_mapping == recreated_tool._output_mapping
171          assert tool.parameters == recreated_tool.parameters
172          assert isinstance(recreated_tool._pipeline, Pipeline)
173  
174      def test_from_dict_async_pipeline(self, sample_async_pipeline):
175          tool = PipelineTool(
176              pipeline=sample_async_pipeline,
177              input_mapping={"query": ["bm25_retriever.query"]},
178              output_mapping={"ranker.documents": "documents"},
179              name="test_tool",
180              description="A test tool",
181          )
182  
183          tool_dict = tool.to_dict()
184          recreated_tool = PipelineTool.from_dict(tool_dict)
185  
186          assert tool.name == recreated_tool.name
187          assert tool.description == recreated_tool.description
188          assert tool._input_mapping == recreated_tool._input_mapping
189          assert tool._output_mapping == recreated_tool._output_mapping
190          assert tool.parameters == recreated_tool.parameters
191          assert isinstance(recreated_tool._pipeline, AsyncPipeline)
192  
193      def test_auto_generated_tool_params(self, sample_pipeline):
194          tool = PipelineTool(
195              pipeline=sample_pipeline,
196              input_mapping={"query": ["bm25_retriever.query", "ranker.query"]},
197              output_mapping={"ranker.documents": "documents"},
198              name="test_tool",
199              description="A test tool",
200          )
201  
202          assert tool.parameters == {
203              "description": "A component that combines: 'bm25_retriever': Run the InMemoryBM25Retriever on the "
204              "given input data., 'ranker': Returns a list of documents ranked by their similarity "
205              "to the given query.",
206              "properties": {
207                  "query": {
208                      "description": "Provided to the 'bm25_retriever' component as: 'The query string for the Retriever."
209                      "', and Provided to the 'ranker' component as: 'The input query to compare the "
210                      "documents to.'.",
211                      "type": "string",
212                  }
213              },
214              "required": ["query"],
215              "type": "object",
216          }
217  
218      def test_auto_generated_tool_params_no_mappings(self, sample_pipeline):
219          tool = PipelineTool(pipeline=sample_pipeline, name="test_tool", description="A test tool")
220          assert tool.parameters == {
221              "description": "A component that combines: 'bm25_retriever': Run the InMemoryBM25Retriever on the given "
222              "input data., 'ranker': Returns a list of documents ranked by their similarity to the "
223              "given query.",
224              "properties": {
225                  "query": {
226                      "description": "Provided to the 'bm25_retriever' component as: 'The query string for the "
227                      "Retriever.', and Provided to the 'ranker' component as: 'The input query to "
228                      "compare the documents to.'.",
229                      "type": "string",
230                  },
231                  "filters": {
232                      "anyOf": [{"additionalProperties": True, "type": "object"}, {"type": "null"}],
233                      "description": "Provided to the 'bm25_retriever' component as: 'A dictionary with filters to "
234                      "narrow down the search space when retrieving documents.'.",
235                  },
236                  "top_k": {
237                      "anyOf": [{"type": "integer"}, {"type": "null"}],
238                      "description": "Provided to the 'bm25_retriever' component as: 'The maximum number of documents "
239                      "to return.', and Provided to the 'ranker' component as: 'The maximum number "
240                      "of documents to return.'.",
241                  },
242                  "scale_score": {
243                      "description": "Provided to the 'bm25_retriever' component as: 'When `True`, scales the score "
244                      "of retrieved documents to a range of 0 to 1, where 1 means extremely relevant."
245                      "\nWhen `False`, uses raw similarity scores.', and Provided to the 'ranker' "
246                      "component as: 'If `True`, scales the raw logit predictions using a Sigmoid "
247                      "activation function.\nIf `False`, disables scaling of the raw logit predictions."
248                      "\nIf set, overrides the value set at initialization.'.",
249                      "anyOf": [{"type": "boolean"}, {"type": "null"}],
250                  },
251                  "score_threshold": {
252                      "anyOf": [{"type": "number"}, {"type": "null"}],
253                      "description": "Provided to the 'ranker' component as: 'Use it to return documents only with "
254                      "a score above this threshold.\nIf set, overrides the value set at initialization.'"
255                      ".",
256                  },
257              },
258              "required": ["query"],
259              "type": "object",
260          }
261  
262      @pytest.mark.integration
263      @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
264      def test_live_pipeline_tool(self, in_memory_doc_store):
265          # Initialize a document store and add some documents
266          document_embedder = OpenAIDocumentEmbedder()
267          documents = [
268              Document(content="Nikola Tesla was a Serbian-American inventor and electrical engineer."),
269              Document(
270                  content="He is best known for his contributions to the design of the modern alternating current (AC) "
271                  "electricity supply system."
272              ),
273          ]
274          docs_with_embeddings = document_embedder.run(documents=documents)["documents"]
275          in_memory_doc_store.write_documents(docs_with_embeddings)
276  
277          # Build a simple retrieval pipeline
278          retrieval_pipeline = Pipeline()
279          retrieval_pipeline.add_component("embedder", OpenAITextEmbedder())
280          retrieval_pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store=in_memory_doc_store))
281  
282          retrieval_pipeline.connect("embedder.embedding", "retriever.query_embedding")
283  
284          # Wrap the pipeline as a tool
285          retriever_tool = PipelineTool(
286              pipeline=retrieval_pipeline,
287              input_mapping={"query": ["embedder.text"]},
288              output_mapping={"retriever.documents": "documents"},
289              name="document_retriever",
290              description="This tool retrieves documents relevant to Nikola Tesla from the document store",
291          )
292  
293          # Create an Agent with the tool
294          agent = Agent(
295              chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"),
296              system_prompt="For any questions about Nikola Tesla, always use the document_retriever.",
297              tools=[retriever_tool],
298          )
299  
300          # Let the Agent handle a query
301          result = agent.run([ChatMessage.from_user("Who was Nikola Tesla?")])
302  
303          assert len(result["messages"]) == 5  # System msg, User msg, Agent msg, Tool call result, Agent mgs
304          assert "nikola" in result["messages"][-1].text.lower()
305  
306      @pytest.mark.asyncio
307      @pytest.mark.integration
308      @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
309      async def test_live_async_pipeline_tool(self, in_memory_doc_store):
310          # Initialize a document store and add some documents
311          document_embedder = OpenAIDocumentEmbedder()
312          documents = [
313              Document(content="Nikola Tesla was a Serbian-American inventor and electrical engineer."),
314              Document(
315                  content="He is best known for his contributions to the design of the modern alternating current (AC) "
316                  "electricity supply system."
317              ),
318          ]
319          docs_with_embeddings = document_embedder.run(documents=documents)["documents"]
320          in_memory_doc_store.write_documents(docs_with_embeddings)
321  
322          # Build a simple retrieval pipeline
323          retrieval_pipeline = AsyncPipeline()
324          retrieval_pipeline.add_component("embedder", OpenAITextEmbedder())
325          retrieval_pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store=in_memory_doc_store))
326  
327          retrieval_pipeline.connect("embedder.embedding", "retriever.query_embedding")
328  
329          # Wrap the pipeline as a tool
330          retriever_tool = PipelineTool(
331              pipeline=retrieval_pipeline,
332              input_mapping={"query": ["embedder.text"]},
333              output_mapping={"retriever.documents": "documents"},
334              name="document_retriever",
335              description="For any questions about Nikola Tesla, always use this tool",
336          )
337  
338          # Create an Agent with the tool
339          agent = Agent(
340              chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"),
341              system_prompt="For any questions about Nikola Tesla, always use the document_retriever.",
342              tools=[retriever_tool],
343          )
344  
345          # Let the Agent handle a query
346          result = await agent.run_async([ChatMessage.from_user("Who was Nikola Tesla?")])
347  
348          assert len(result["messages"]) == 5  # System msg, User msg, Agent msg, Tool call result, Agent mgs
349          assert "nikola" in result["messages"][-1].text.lower()
350  
351      def test_pipeline_tool_with_valid_inputs_from_state(self, sample_pipeline):
352          """Test that PipelineTool accepts valid inputs_from_state mapping"""
353          tool = PipelineTool(
354              pipeline=sample_pipeline,
355              input_mapping={"query": ["bm25_retriever.query"]},
356              output_mapping={"ranker.documents": "documents"},
357              name="test_tool",
358              description="A test tool",
359              inputs_from_state={"user_query": "query"},
360          )
361          assert tool.inputs_from_state == {"user_query": "query"}
362  
363      def test_pipeline_tool_with_invalid_inputs_from_state(self, sample_pipeline):
364          """Test that PipelineTool validates inputs_from_state against pipeline inputs"""
365          with pytest.raises(ValueError, match="unknown parameter 'nonexistent'"):
366              PipelineTool(
367                  pipeline=sample_pipeline,
368                  input_mapping={"query": ["bm25_retriever.query"]},
369                  output_mapping={"ranker.documents": "documents"},
370                  name="test_tool",
371                  description="A test tool",
372                  inputs_from_state={"user_query": "nonexistent"},
373              )
374  
375      def test_pipeline_tool_with_invalid_inputs_from_state_nested_dict(self, sample_pipeline):
376          """Test that PipelineTool rejects nested dict format for inputs_from_state"""
377          with pytest.raises(TypeError, match="must be str, not dict"):
378              PipelineTool(
379                  pipeline=sample_pipeline,
380                  input_mapping={"query": ["bm25_retriever.query"]},
381                  output_mapping={"ranker.documents": "documents"},
382                  name="test_tool",
383                  description="A test tool",
384                  inputs_from_state={"user_query": {"source": "query"}},
385              )
386  
387      def test_pipeline_tool_with_valid_outputs_to_state(self, sample_pipeline):
388          """Test that PipelineTool accepts valid outputs_to_state mapping"""
389          tool = PipelineTool(
390              pipeline=sample_pipeline,
391              input_mapping={"query": ["bm25_retriever.query"]},
392              output_mapping={"ranker.documents": "documents"},
393              name="test_tool",
394              description="A test tool",
395              outputs_to_state={"result_docs": {"source": "documents"}},
396          )
397          assert tool.outputs_to_state == {"result_docs": {"source": "documents"}}
398  
399      def test_pipeline_tool_with_invalid_outputs_to_state(self, sample_pipeline):
400          """Test that PipelineTool validates outputs_to_state against pipeline outputs"""
401          with pytest.raises(ValueError, match="unknown output"):
402              PipelineTool(
403                  pipeline=sample_pipeline,
404                  input_mapping={"query": ["bm25_retriever.query"]},
405                  output_mapping={"ranker.documents": "documents"},
406                  name="test_tool",
407                  description="A test tool",
408                  outputs_to_state={"result": {"source": "nonexistent"}},
409              )