# test_agent.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import re
from collections.abc import Iterator
from datetime import datetime
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from jinja2 import TemplateSyntaxError
from openai import Stream
from openai.types.chat import ChatCompletionChunk, chat_completion_chunk

from haystack import Document, Pipeline, component, tracing
from haystack.components.agents.agent import Agent
from haystack.components.agents.state import merge_lists
from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.components.joiners.list_joiner import ListJoiner
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.core.component.types import OutputSocket
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.dataclasses.chat_message import ChatRole, TextContent
from haystack.dataclasses.streaming_chunk import StreamingChunk
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.tools import ComponentTool, Tool, tool
from haystack.tools.toolset import Toolset
from haystack.tracing.logging_tracer import LoggingTracer
from haystack.utils import Secret, serialize_callable


def _user_msg(text: str) -> str:
    """Wrap *text* in a Jinja2 ``{% message role="user" %}`` block."""
    return f'{{% message role="user" %}}{text}{{% endmessage %}}'


def _sys_msg(text: str) -> str:
    """Wrap *text* in a Jinja2 ``{% message role="system" %}`` block."""
    return f'{{% message role="system" %}}{text}{{% endmessage %}}'


def sync_streaming_callback(chunk: StreamingChunk) -> None:
    """A synchronous streaming callback."""
    pass


async def async_streaming_callback(chunk: StreamingChunk) -> None:
    """An asynchronous streaming callback."""
    pass


def weather_function(location):
    # Canned per-city results; substring match so e.g. "Berlin, Germany" also hits "berlin".
    weather_info = {
        "berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
        "paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
        "rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
    }
    for city, result in weather_info.items():
        if city in location.lower():
            return result
    return {"weather": "unknown", "temperature": 0, "unit": "celsius"}


@tool
def weather_tool_with_decorator(location: str) -> str:
    """Provides weather information for a given location."""
    return f"Weather report for {location}: 20°C, sunny"


@pytest.fixture
def weather_tool():
    """A Tool wrapping ``weather_function`` with an explicit JSON-schema signature."""
    return Tool(
        name="weather_tool",
        description="Provides weather information for a given location.",
        parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]},
        function=weather_function,
    )


@pytest.fixture
def component_tool():
    """A ComponentTool built on a trivial PromptBuilder."""
    return ComponentTool(name="parrot", description="This is a parrot.", component=PromptBuilder(template="{{parrot}}"))


@pytest.fixture
def make_agent(weather_tool):
    """Factory fixture: build an Agent around MockChatGenerator with extra kwargs."""

    def _factory(**kwargs):
        return Agent(chat_generator=MockChatGenerator(), tools=[weather_tool], **kwargs)

    return _factory


class OpenAIMockStream(Stream[ChatCompletionChunk]):
    """A fake OpenAI Stream that yields exactly one pre-built chunk."""

    def __init__(self, mock_chunk: ChatCompletionChunk, client=None, *args, **kwargs):
        client = client or MagicMock()
        super().__init__(client=client, *args, **kwargs)  # noqa: B026
        self.mock_chunk = mock_chunk

    def __stream__(self) -> Iterator[ChatCompletionChunk]:
        yield self.mock_chunk


@pytest.fixture
def openai_mock_chat_completion_chunk():
    """
    Mock the OpenAI API completion chunk response and reuse it for tests
    """
    with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
        completion = ChatCompletionChunk(
            id="foo",
            model="gpt-4",
            object="chat.completion.chunk",
            choices=[
                chat_completion_chunk.Choice(
                    finish_reason="stop",
                    logprobs=None,
                    index=0,
                    delta=chat_completion_chunk.ChoiceDelta(content="Hello", role="assistant"),
                )
            ],
            created=int(datetime.now().timestamp()),
            usage=None,
        )
        mock_chat_completion_create.return_value = OpenAIMockStream(
            completion, cast_to=None, response=None, client=None
        )
        yield mock_chat_completion_create


@component
class MockChatGeneratorWithoutTools:
    """A mock chat generator that implements ChatGenerator protocol but doesn't support tools."""

    def to_dict(self) -> dict[str, Any]:
        return {"type": "MockChatGeneratorWithoutTools", "data": {}}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MockChatGeneratorWithoutTools":
        return cls()

    @component.output_types(replies=list[ChatMessage])
    def run(self, messages: list[ChatMessage]) -> dict[str, Any]:
        return {"replies": [ChatMessage.from_assistant("Hello")]}


@component
class MockChatGeneratorWithoutRunAsync:
    """A mock chat generator that implements ChatGenerator protocol but doesn't have run_async method."""

    def to_dict(self) -> dict[str, Any]:
        return {"type": "MockChatGeneratorWithoutRunAsync", "data": {}}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MockChatGeneratorWithoutRunAsync":
        return cls()

    @component.output_types(replies=list[ChatMessage])
    def run(self, messages: list[ChatMessage], tools: list[Tool] | Toolset | None = None, **kwargs) -> dict[str, Any]:
        return {"replies": [ChatMessage.from_assistant("Hello")]}


@component
class MockChatGenerator:
    """A tool-capable mock chat generator with both sync and async run paths."""

    def to_dict(self) -> dict[str, Any]:
        # FIX: previously returned "MockChatGeneratorWithoutRunAsync" (copy-paste bug),
        # which would mis-route deserialization of this class.
        return {"type": "MockChatGenerator", "data": {}}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MockChatGenerator":
        return cls()

    @component.output_types(replies=list[ChatMessage])
    def run(self, messages: list[ChatMessage], tools: list[Tool] | Toolset | None = None, **kwargs) -> dict[str, Any]:
        return {"replies": [ChatMessage.from_assistant("Hello")]}

    @component.output_types(replies=list[ChatMessage])
    async def run_async(
        self, messages: list[ChatMessage], tools: list[Tool] | Toolset | None = None, **kwargs
    ) -> dict[str, Any]:
        return {"replies": [ChatMessage.from_assistant("Hello from run_async")]}


class TestAgent:
    def test_output_types(self, weather_tool, component_tool, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        chat_generator = OpenAIChatGenerator()
        agent = Agent(chat_generator=chat_generator, tools=[weather_tool, component_tool])
        assert agent.__haystack_output__._sockets_dict == {
            "messages": OutputSocket(name="messages", type=list[ChatMessage], receivers=[]),
            "last_message": OutputSocket(name="last_message", type=ChatMessage, receivers=[]),
        }

    def test_to_dict(self, weather_tool, component_tool, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        generator = OpenAIChatGenerator()
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            exit_conditions=["text", "weather_tool"],
            state_schema={"foo": {"type": str}},
            tool_invoker_kwargs={"max_workers": 5, "enable_streaming_callback_passthrough": True},
        )
        serialized_agent = agent.to_dict()
        # Verify the model is truthy and serialized
        assert "model" in serialized_agent["init_parameters"]["chat_generator"]["init_parameters"]
        model_name = serialized_agent["init_parameters"]["chat_generator"]["init_parameters"]["model"]
        # Check the rest of the structure
        expected_structure = {
            "type": "haystack.components.agents.agent.Agent",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": model_name,
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                        "http_client_kwargs": None,
                    },
                },
                "tools": [
                    {
                        "type": "haystack.tools.tool.Tool",
                        "data": {
                            "name": "weather_tool",
                            "description": "Provides weather information for a given location.",
                            "parameters": {
                                "type": "object",
                                "properties": {"location": {"type": "string"}},
                                "required": ["location"],
                            },
                            "function": "test_agent.weather_function",
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                    {
                        "type": "haystack.tools.component_tool.ComponentTool",
                        "data": {
                            "component": {
                                "type": "haystack.components.builders.prompt_builder.PromptBuilder",
                                "init_parameters": {
                                    "template": "{{parrot}}",
                                    "variables": None,
                                    "required_variables": None,
                                },
                            },
                            "name": "parrot",
                            "description": "This is a parrot.",
                            "parameters": None,
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                ],
                "system_prompt": None,
                "user_prompt": None,
                "required_variables": None,
                "exit_conditions": ["text", "weather_tool"],
                "state_schema": {"foo": {"type": "str"}},
                "max_agent_steps": 100,
                "streaming_callback": None,
                "raise_on_tool_invocation_failure": False,
                "tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True},
                "confirmation_strategies": None,
            },
        }
        assert serialized_agent == expected_structure

    def test_to_dict_with_toolset(self, monkeypatch, weather_tool):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        toolset = Toolset(tools=[weather_tool])
        agent = Agent(chat_generator=OpenAIChatGenerator(), tools=toolset)
        serialized_agent = agent.to_dict()
        # Verify the model is truthy and serialized
        assert "model" in serialized_agent["init_parameters"]["chat_generator"]["init_parameters"]
        model_name = serialized_agent["init_parameters"]["chat_generator"]["init_parameters"]["model"]
        # Check the rest of the structure
        expected_structure = {
            "type": "haystack.components.agents.agent.Agent",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": model_name,
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                        "http_client_kwargs": None,
                    },
                },
                "tools": {
                    "type": "haystack.tools.toolset.Toolset",
                    "data": {
                        "tools": [
                            {
                                "type": "haystack.tools.tool.Tool",
                                "data": {
                                    "name": "weather_tool",
                                    "description": "Provides weather information for a given location.",
                                    "parameters": {
                                        "type": "object",
                                        "properties": {"location": {"type": "string"}},
                                        "required": ["location"],
                                    },
                                    "function": "test_agent.weather_function",
                                    "outputs_to_string": None,
                                    "inputs_from_state": None,
                                    "outputs_to_state": None,
                                },
                            }
                        ]
                    },
                },
                "system_prompt": None,
                "user_prompt": None,
                "required_variables": None,
                "exit_conditions": ["text"],
                "state_schema": {},
                "max_agent_steps": 100,
                "raise_on_tool_invocation_failure": False,
                "streaming_callback": None,
                "tool_invoker_kwargs": None,
                "confirmation_strategies": None,
            },
        }
        assert serialized_agent == expected_structure

    def test_agent_serialization_with_tool_decorator(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        agent = Agent(chat_generator=OpenAIChatGenerator(), tools=[weather_tool_with_decorator])
        serialized_agent = agent.to_dict()
        deserialized_agent = Agent.from_dict(serialized_agent)

        assert deserialized_agent.tools == agent.tools
        assert isinstance(deserialized_agent.chat_generator, OpenAIChatGenerator)
        # Model name should match whatever the default is - not testing specific model
        assert deserialized_agent.chat_generator.model == agent.chat_generator.model
        assert deserialized_agent.chat_generator.api_key == Secret.from_env_var("OPENAI_API_KEY")
        assert deserialized_agent.exit_conditions == ["text"]

    def test_from_dict(self, monkeypatch):
        model = "gpt-5"
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        data = {
            "type": "haystack.components.agents.agent.Agent",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": model,
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                        "http_client_kwargs": None,
                    },
                },
                "tools": [
                    {
                        "type": "haystack.tools.tool.Tool",
                        "data": {
                            "name": "weather_tool",
                            "description": "Provides weather information for a given location.",
                            "parameters": {
                                "type": "object",
                                "properties": {"location": {"type": "string"}},
                                "required": ["location"],
                            },
                            "function": "test_agent.weather_function",
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                    {
                        "type": "haystack.tools.component_tool.ComponentTool",
                        "data": {
                            "component": {
                                "type": "haystack.components.builders.prompt_builder.PromptBuilder",
                                "init_parameters": {
                                    "template": "{{parrot}}",
                                    "variables": None,
                                    "required_variables": None,
                                },
                            },
                            "name": "parrot",
                            "description": "This is a parrot.",
                            "parameters": None,
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                ],
                "system_prompt": None,
                "exit_conditions": ["text", "weather_tool"],
                "state_schema": {"foo": {"type": "str"}},
                "max_agent_steps": 100,
                "raise_on_tool_invocation_failure": False,
                "streaming_callback": None,
                "tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True},
            },
        }
        agent = Agent.from_dict(data)
        assert isinstance(agent, Agent)
        assert isinstance(agent.chat_generator, OpenAIChatGenerator)
        # from_dict should restore the model from the dict (testing backward compatibility)
        assert agent.chat_generator.model == model
        assert agent.chat_generator.api_key == Secret.from_env_var("OPENAI_API_KEY")
        assert agent.tools[0].function is weather_function
        assert isinstance(agent.tools[1]._component, PromptBuilder)
        assert agent.exit_conditions == ["text", "weather_tool"]
        assert agent.state_schema == {
            "foo": {"type": str},
            "messages": {"handler": merge_lists, "type": list[ChatMessage]},
        }
        assert agent.tool_invoker_kwargs == {"max_workers": 5, "enable_streaming_callback_passthrough": True}
        assert agent._tool_invoker.max_workers == 5
        assert agent._tool_invoker.enable_streaming_callback_passthrough is True

    def test_from_dict_with_toolset(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        data = {
            "type": "haystack.components.agents.agent.Agent",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": "gpt-4o-mini",
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                        "http_client_kwargs": None,
                    },
                },
                "tools": {
                    "type": "haystack.tools.toolset.Toolset",
                    "data": {
                        "tools": [
                            {
                                "type": "haystack.tools.tool.Tool",
                                "data": {
                                    "name": "weather_tool",
                                    "description": "Provides weather information for a given location.",
                                    "parameters": {
                                        "type": "object",
                                        "properties": {"location": {"type": "string"}},
                                        "required": ["location"],
                                    },
                                    "function": "test_agent.weather_function",
                                    "outputs_to_string": None,
                                    "inputs_from_state": None,
                                    "outputs_to_state": None,
                                },
                            }
                        ]
                    },
                },
                "system_prompt": None,
                "exit_conditions": ["text"],
                "state_schema": {},
                "max_agent_steps": 100,
                "raise_on_tool_invocation_failure": False,
                "streaming_callback": None,
                "tool_invoker_kwargs": None,
            },
        }
        agent = Agent.from_dict(data)
        assert isinstance(agent, Agent)
        assert isinstance(agent.chat_generator, OpenAIChatGenerator)
        # from_dict should restore the model from the dict (testing backward compatibility)
        assert agent.chat_generator.model == "gpt-4o-mini"
        assert agent.chat_generator.api_key == Secret.from_env_var("OPENAI_API_KEY")
        assert isinstance(agent.tools, Toolset)
        assert agent.tools[0].function is weather_function
        assert agent.exit_conditions == ["text"]

    def test_from_dict_state_schema_none(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        data = {
            "type": "haystack.components.agents.agent.Agent",
            "init_parameters": {
                "chat_generator": {
                    "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
                    "init_parameters": {
                        "model": "gpt-4o-mini",
                        "streaming_callback": None,
                        "api_base_url": None,
                        "organization": None,
                        "generation_kwargs": {},
                        "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
                        "timeout": None,
                        "max_retries": None,
                        "tools": None,
                        "tools_strict": False,
                        "http_client_kwargs": None,
                    },
                },
                "tools": [
                    {
                        "type": "haystack.tools.tool.Tool",
                        "data": {
                            "name": "weather_tool",
                            "description": "Provides weather information for a given location.",
                            "parameters": {
                                "type": "object",
                                "properties": {"location": {"type": "string"}},
                                "required": ["location"],
                            },
                            "function": "test_agent.weather_function",
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                    {
                        "type": "haystack.tools.component_tool.ComponentTool",
                        "data": {
                            "component": {
                                "type": "haystack.components.builders.prompt_builder.PromptBuilder",
                                "init_parameters": {
                                    "template": "{{parrot}}",
                                    "variables": None,
                                    "required_variables": None,
                                },
                            },
                            "name": "parrot",
                            "description": "This is a parrot.",
                            "parameters": None,
                            "outputs_to_string": None,
                            "inputs_from_state": None,
                            "outputs_to_state": None,
                        },
                    },
                ],
                "system_prompt": None,
                "exit_conditions": ["text", "weather_tool"],
                "state_schema": None,
                "max_agent_steps": 100,
                "raise_on_tool_invocation_failure": False,
                "streaming_callback": None,
                "tool_invoker_kwargs": {"max_workers": 5, "enable_streaming_callback_passthrough": True},
            },
        }
        agent = Agent.from_dict(data)
        assert agent.state_schema == {"messages": {"type": list[ChatMessage], "handler": merge_lists}}

    def test_serde(self, weather_tool, component_tool, monkeypatch):
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
        agent = Agent(
            chat_generator=generator,
            tools=[weather_tool, component_tool],
            exit_conditions=["text", "weather_tool"],
            state_schema={"foo": {"type": str}},
        )

        serialized_agent = agent.to_dict()

        init_parameters = serialized_agent["init_parameters"]

        assert serialized_agent["type"] == "haystack.components.agents.agent.Agent"
        assert (
            init_parameters["chat_generator"]["type"]
            == "haystack.components.generators.chat.openai.OpenAIChatGenerator"
        )
        assert init_parameters["streaming_callback"] is None
        assert init_parameters["tools"][0]["data"]["function"] == serialize_callable(weather_function)
        assert (
            init_parameters["tools"][1]["data"]["component"]["type"]
            == "haystack.components.builders.prompt_builder.PromptBuilder"
        )
        assert init_parameters["exit_conditions"] == ["text", "weather_tool"]

        deserialized_agent = Agent.from_dict(serialized_agent)

        assert isinstance(deserialized_agent, Agent)
        assert isinstance(deserialized_agent.chat_generator, OpenAIChatGenerator)
        assert deserialized_agent.tools[0].function is weather_function
        assert isinstance(deserialized_agent.tools[1]._component, PromptBuilder)
        assert deserialized_agent.exit_conditions == ["text", "weather_tool"]
        assert deserialized_agent.state_schema == {
            "foo": {"type": str},
            "messages": {"handler": merge_lists, "type": list[ChatMessage]},
        }

    def test_serde_with_streaming_callback(self, weather_tool, component_tool, monkeypatch):
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
        agent = Agent(
            chat_generator=generator, tools=[weather_tool, component_tool], streaming_callback=sync_streaming_callback
        )

        serialized_agent = agent.to_dict()

        init_parameters = serialized_agent["init_parameters"]
        assert init_parameters["streaming_callback"] == "test_agent.sync_streaming_callback"

        deserialized_agent = Agent.from_dict(serialized_agent)
        assert deserialized_agent.streaming_callback is sync_streaming_callback

    def test_exit_conditions_validation(self, weather_tool, component_tool, monkeypatch):
        monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
        generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))

        # Test invalid exit condition
        with pytest.raises(ValueError, match="Invalid exit conditions provided:"):
            Agent(chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["invalid_tool"])

        # Test default exit condition
        agent = Agent(chat_generator=generator, tools=[weather_tool, component_tool])
        assert agent.exit_conditions == ["text"]

        # Test multiple valid exit conditions
        agent = Agent(
            chat_generator=generator, tools=[weather_tool, component_tool], exit_conditions=["text", "weather_tool"]
        )
        assert agent.exit_conditions == ["text", "weather_tool"]

    def test_run_with_params_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        chat_generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        agent = Agent(chat_generator=chat_generator, streaming_callback=streaming_callback, tools=[weather_tool])
        response = agent.run([ChatMessage.from_user("Hello")])

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 2
        assert [isinstance(reply, ChatMessage) for reply in response["messages"]]
        assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk
        assert "last_message" in response
        assert isinstance(response["last_message"], ChatMessage)

    def test_run_with_run_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        chat_generator = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))

        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
        response = agent.run([ChatMessage.from_user("Hello")], streaming_callback=streaming_callback)

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 2
        assert [isinstance(reply, ChatMessage) for reply in response["messages"]]
        assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk
        assert "last_message" in response
        assert isinstance(response["last_message"], ChatMessage)

    def test_keep_generator_streaming(self, openai_mock_chat_completion_chunk, weather_tool):
        streaming_callback_called = False

        def streaming_callback(chunk: StreamingChunk) -> None:
            nonlocal streaming_callback_called
            streaming_callback_called = True

        chat_generator = OpenAIChatGenerator(
            api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback
        )

        agent = Agent(chat_generator=chat_generator, tools=[weather_tool])
        response = agent.run([ChatMessage.from_user("Hello")])

        # check we called the streaming callback
        assert streaming_callback_called is True

        # check that the component still returns the correct response
        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 2
        assert [isinstance(reply, ChatMessage) for reply in response["messages"]]
        assert "Hello" in response["messages"][1].text  # see openai_mock_chat_completion_chunk
        assert "last_message" in response
        assert isinstance(response["last_message"], ChatMessage)

    def test_chat_generator_must_support_tools(self, weather_tool):
        chat_generator = MockChatGeneratorWithoutTools()

        with pytest.raises(TypeError, match="MockChatGeneratorWithoutTools does not accept tools"):
            Agent(chat_generator=chat_generator, tools=[weather_tool])

    def test_no_tools_with_chat_generator_without_tools_support(self):
        chat_generator = MockChatGeneratorWithoutTools()
        agent = Agent(chat_generator=chat_generator, max_agent_steps=1)

        response = agent.run(messages=[ChatMessage.from_user("Hello")])

        assert isinstance(response, dict)
        assert "messages" in response
        assert len(response["messages"]) == 2
        assert response["messages"][0].text == "Hello"
        assert response["messages"][1].text == "Hello"
        assert response["last_message"] == response["messages"][-1]

    def test_exceed_max_steps(self, monkeypatch, weather_tool, caplog):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        generator = OpenAIChatGenerator()

        mock_messages = [
            ChatMessage.from_assistant("First response"),
            ChatMessage.from_assistant(
                tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
            ),
        ]

        agent = Agent(chat_generator=generator, tools=[weather_tool], max_agent_steps=0)

        # Patch agent.chat_generator.run to return mock_messages
        agent.chat_generator.run = MagicMock(return_value={"replies": mock_messages})

        with caplog.at_level(logging.WARNING):
            agent.run([ChatMessage.from_user("Hello")])
            assert "Agent reached maximum agent steps" in caplog.text

    def test_exit_condition_exits(self, monkeypatch, weather_tool):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        generator = OpenAIChatGenerator()

        # Mock messages where the exit condition appears in the second message
        mock_messages = [
            ChatMessage.from_assistant(
                tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
            )
        ]

        agent = Agent(chat_generator=generator, tools=[weather_tool], exit_conditions=["weather_tool"])

        # Patch agent.chat_generator.run to return mock_messages
        agent.chat_generator.run = MagicMock(return_value={"replies": mock_messages})

        result = agent.run([ChatMessage.from_user("Hello")])

        assert "messages" in result
        assert len(result["messages"]) == 3
        assert result["messages"][-2].tool_call.tool_name == "weather_tool"
        assert (
            result["messages"][-1].tool_call_result.result
            == "{'weather': 'mostly sunny', 'temperature': 7, 'unit': 'celsius'}"
        )
        assert "last_message" in result
        assert isinstance(result["last_message"], ChatMessage)
        assert result["messages"][-1] == result["last_message"]

    def test_agent_with_no_tools(self, monkeypatch, caplog):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
        generator = OpenAIChatGenerator()

        # Mock messages where the exit condition appears in the second message
        mock_messages = [ChatMessage.from_assistant("Berlin")]

        with caplog.at_level("WARNING"):
            agent = Agent(chat_generator=generator, tools=[], max_agent_steps=3)
            assert "No tools provided to the Agent." in caplog.text

        # Patch agent.chat_generator.run to return mock_messages
        agent.chat_generator.run = MagicMock(return_value={"replies": mock_messages})

        response = agent.run([ChatMessage.from_user("What is the capital of Germany?")])

        assert isinstance(response, dict)
        assert "messages" in response
        assert isinstance(response["messages"], list)
        assert len(response["messages"]) == 2
        assert [isinstance(reply, ChatMessage) for reply in response["messages"]]
        assert response["messages"][0].text == "What is the capital of Germany?"
        assert response["messages"][1].text == "Berlin"
        assert "last_message" in response
        assert isinstance(response["last_message"], ChatMessage)
        assert response["messages"][-1] == response["last_message"]

    def test_run_with_system_prompt(self, weather_tool):
        chat_generator = MockChatGeneratorWithoutRunAsync()
        agent = Agent(chat_generator=chat_generator, tools=[weather_tool], system_prompt="This is a system prompt.")
        response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")])
        assert response["messages"][0].text == "This is a system prompt."

    def test_run_with_system_prompt_run_param(self, weather_tool):
        chat_generator = MockChatGeneratorWithoutRunAsync()
        agent = Agent(
            chat_generator=chat_generator, tools=[weather_tool], system_prompt="This is the init system prompt."
        )
        response = agent.run(
            [ChatMessage.from_user("What is the weather in Berlin?")], system_prompt="This is the run system prompt."
        )
        assert response["messages"][0].text == "This is the run system prompt."
835 836 def test_run_with_tools_run_param(self, weather_tool: Tool, component_tool: Tool, monkeypatch): 837 @component 838 class MockChatGenerator: 839 tool_invoked = False 840 841 @component.output_types(replies=list[ChatMessage]) 842 def run( 843 self, messages: list[ChatMessage], tools: list[Tool] | Toolset | None = None, **kwargs 844 ) -> dict[str, Any]: 845 assert tools == [weather_tool] 846 tool_message = ChatMessage.from_assistant( 847 tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})] 848 ) 849 message = tool_message if not self.tool_invoked else ChatMessage.from_assistant("Hello") 850 self.tool_invoked = True 851 return {"replies": [message]} 852 853 chat_generator = MockChatGenerator() 854 agent = Agent(chat_generator=chat_generator, tools=[component_tool], system_prompt="This is a system prompt.") 855 tool_invoker_run_mock = MagicMock(wraps=agent._tool_invoker.run) 856 monkeypatch.setattr(agent._tool_invoker, "run", tool_invoker_run_mock) 857 agent.run([ChatMessage.from_user("What is the weather in Berlin?")], tools=[weather_tool]) 858 tool_invoker_run_mock.assert_called_once() 859 assert tool_invoker_run_mock.call_args[1]["tools"] == [weather_tool] 860 861 def test_run_with_tools_run_param_for_tool_selection(self, weather_tool: Tool, component_tool: Tool, monkeypatch): 862 @component 863 class MockChatGenerator: 864 tool_invoked = False 865 866 @component.output_types(replies=list[ChatMessage]) 867 def run( 868 self, messages: list[ChatMessage], tools: list[Tool] | Toolset | None = None, **kwargs 869 ) -> dict[str, Any]: 870 assert tools == [weather_tool] 871 tool_message = ChatMessage.from_assistant( 872 tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})] 873 ) 874 message = tool_message if not self.tool_invoked else ChatMessage.from_assistant("Hello") 875 self.tool_invoked = True 876 return {"replies": [message]} 877 878 chat_generator = MockChatGenerator() 879 agent = Agent( 880 
chat_generator=chat_generator, 881 tools=[weather_tool, component_tool], 882 system_prompt="This is a system prompt.", 883 ) 884 tool_invoker_run_mock = MagicMock(wraps=agent._tool_invoker.run) 885 monkeypatch.setattr(agent._tool_invoker, "run", tool_invoker_run_mock) 886 agent.run([ChatMessage.from_user("What is the weather in Berlin?")], tools=[weather_tool.name]) 887 tool_invoker_run_mock.assert_called_once() 888 assert tool_invoker_run_mock.call_args[1]["tools"] == [weather_tool] 889 890 def test_run_not_warmed_up(self, weather_tool): 891 """Warmup is run automatically on first run""" 892 chat_generator = MockChatGeneratorWithoutRunAsync() 893 chat_generator.warm_up = MagicMock() 894 agent = Agent(chat_generator=chat_generator, tools=[weather_tool], system_prompt="This is a system prompt.") 895 agent.run([ChatMessage.from_user("What is the weather in Berlin?")]) 896 assert agent._is_warmed_up is True 897 assert chat_generator.warm_up.call_count == 1 898 899 def test_run_no_messages(self, monkeypatch): 900 monkeypatch.setenv("OPENAI_API_KEY", "fake-key") 901 chat_generator = OpenAIChatGenerator() 902 agent = Agent(chat_generator=chat_generator, tools=[]) 903 result = agent.run([]) 904 assert result["messages"] == [] 905 906 def test_run_only_system_prompt(self, caplog): 907 chat_generator = MockChatGeneratorWithoutRunAsync() 908 agent = Agent(chat_generator=chat_generator, tools=[], system_prompt="This is a system prompt.") 909 _ = agent.run([]) 910 assert "All messages provided to the Agent component are system messages." 
in caplog.text 911 912 @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") 913 @pytest.mark.integration 914 def test_run(self, weather_tool): 915 chat_generator = OpenAIChatGenerator(model="gpt-4.1-nano") 916 agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3) 917 response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")]) 918 919 assert isinstance(response, dict) 920 assert "messages" in response 921 assert isinstance(response["messages"], list) 922 assert len(response["messages"]) == 4 923 assert [isinstance(reply, ChatMessage) for reply in response["messages"]] 924 # Loose check of message texts 925 assert response["messages"][0].text == "What is the weather in Berlin?" 926 assert response["messages"][1].text is None 927 assert response["messages"][2].text is None 928 assert response["messages"][3].text is not None 929 # Loose check of message metadata 930 assert response["messages"][0].meta == {} 931 assert response["messages"][1].meta.get("model") is not None 932 assert response["messages"][2].meta == {} 933 assert response["messages"][3].meta.get("model") is not None 934 # Loose check of tool calls and results 935 assert response["messages"][1].tool_calls[0].tool_name == "weather_tool" 936 assert response["messages"][1].tool_calls[0].arguments is not None 937 assert response["messages"][2].tool_call_results[0].result is not None 938 assert response["messages"][2].tool_call_results[0].origin is not None 939 assert "last_message" in response 940 assert isinstance(response["last_message"], ChatMessage) 941 assert response["messages"][-1] == response["last_message"] 942 943 @pytest.mark.asyncio 944 async def test_run_async_falls_back_to_run_when_chat_generator_has_no_run_async(self, weather_tool): 945 chat_generator = MockChatGeneratorWithoutRunAsync() 946 947 agent = Agent(chat_generator=chat_generator, tools=[weather_tool]) 948 949 chat_generator.run = 
MagicMock(return_value={"replies": [ChatMessage.from_assistant("Hello")]}) 950 951 result = await agent.run_async([ChatMessage.from_user("Hello")]) 952 953 expected_messages = [ 954 ChatMessage(_role=ChatRole.USER, _content=[TextContent(text="Hello")], _name=None, _meta={}) 955 ] 956 chat_generator.run.assert_called_once_with(messages=expected_messages, tools=[weather_tool]) 957 958 assert isinstance(result, dict) 959 assert "messages" in result 960 assert isinstance(result["messages"], list) 961 assert len(result["messages"]) == 2 962 assert [isinstance(reply, ChatMessage) for reply in result["messages"]] 963 assert "Hello" in result["messages"][1].text 964 assert "last_message" in result 965 assert isinstance(result["last_message"], ChatMessage) 966 assert result["messages"][-1] == result["last_message"] 967 968 @pytest.mark.asyncio 969 async def test_generation_kwargs(self): 970 chat_generator = MockChatGeneratorWithoutRunAsync() 971 972 agent = Agent(chat_generator=chat_generator) 973 974 chat_generator.run = MagicMock(return_value={"replies": [ChatMessage.from_assistant("Hello")]}) 975 976 await agent.run_async([ChatMessage.from_user("Hello")], generation_kwargs={"temperature": 0.0}) 977 978 expected_messages = [ 979 ChatMessage(_role=ChatRole.USER, _content=[TextContent(text="Hello")], _name=None, _meta={}) 980 ] 981 chat_generator.run.assert_called_once_with( 982 messages=expected_messages, generation_kwargs={"temperature": 0.0}, tools=[] 983 ) 984 985 @pytest.mark.asyncio 986 async def test_run_async_uses_chat_generator_run_async_when_available(self, weather_tool): 987 chat_generator = MockChatGenerator() 988 agent = Agent(chat_generator=chat_generator, tools=[weather_tool]) 989 990 chat_generator.run_async = AsyncMock( 991 return_value={"replies": [ChatMessage.from_assistant("Hello from run_async")]} 992 ) 993 994 result = await agent.run_async([ChatMessage.from_user("Hello")]) 995 996 expected_messages = [ 997 ChatMessage(_role=ChatRole.USER, 
_content=[TextContent(text="Hello")], _name=None, _meta={}) 998 ] 999 chat_generator.run_async.assert_called_once_with(messages=expected_messages, tools=[weather_tool]) 1000 1001 assert isinstance(result, dict) 1002 assert "messages" in result 1003 assert isinstance(result["messages"], list) 1004 assert len(result["messages"]) == 2 1005 assert [isinstance(reply, ChatMessage) for reply in result["messages"]] 1006 assert "Hello from run_async" in result["messages"][1].text 1007 assert "last_message" in result 1008 assert isinstance(result["last_message"], ChatMessage) 1009 assert result["messages"][-1] == result["last_message"] 1010 1011 @pytest.mark.integration 1012 @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") 1013 def test_agent_streaming_with_tool_call(self, weather_tool): 1014 chat_generator = OpenAIChatGenerator(model="gpt-4.1-nano") 1015 agent = Agent(chat_generator=chat_generator, tools=[weather_tool]) 1016 streaming_callback_called = False 1017 1018 def streaming_callback(chunk: StreamingChunk) -> None: 1019 nonlocal streaming_callback_called 1020 streaming_callback_called = True 1021 1022 result = agent.run( 1023 [ChatMessage.from_user("What's the weather in Paris?")], streaming_callback=streaming_callback 1024 ) 1025 1026 assert result is not None 1027 assert result["messages"] is not None 1028 assert result["last_message"] is not None 1029 assert streaming_callback_called 1030 1031 @pytest.mark.asyncio 1032 async def test_run_async_with_async_streaming_callback(self, weather_tool): 1033 chat_generator = MockChatGenerator() 1034 agent = Agent(chat_generator=chat_generator, tools=[weather_tool], streaming_callback=async_streaming_callback) 1035 1036 # This should not raise any exception 1037 result = await agent.run_async([ChatMessage.from_user("Hello")]) 1038 1039 assert "messages" in result 1040 assert len(result["messages"]) == 2 1041 assert result["messages"][1].text == "Hello from run_async" 1042 1043 def 
test_run_with_async_streaming_callback_fails(self, weather_tool): 1044 chat_generator = MockChatGenerator() 1045 agent = Agent(chat_generator=chat_generator, tools=[weather_tool], streaming_callback=async_streaming_callback) 1046 1047 with pytest.raises(ValueError, match="The init callback cannot be a coroutine"): 1048 agent.run([ChatMessage.from_user("Hello")]) 1049 1050 @pytest.mark.asyncio 1051 async def test_run_async_with_sync_streaming_callback_fails(self, weather_tool): 1052 chat_generator = MockChatGenerator() 1053 agent = Agent(chat_generator=chat_generator, tools=[weather_tool], streaming_callback=sync_streaming_callback) 1054 1055 with pytest.raises(ValueError, match="The init callback must be async compatible"): 1056 await agent.run_async([ChatMessage.from_user("Hello")]) 1057 1058 1059 class TestAgentTracing: 1060 def test_agent_tracing_span_run(self, caplog, monkeypatch, weather_tool): 1061 chat_generator = MockChatGeneratorWithoutRunAsync() 1062 agent = Agent(chat_generator=chat_generator, tools=[weather_tool]) 1063 1064 tracing.tracer.is_content_tracing_enabled = True 1065 tracing.enable_tracing(LoggingTracer()) 1066 caplog.set_level(logging.DEBUG) 1067 1068 _ = agent.run([ChatMessage.from_user("What's the weather in Paris?")]) 1069 1070 # Ensure tracing span was emitted 1071 assert any("Operation: haystack.component.run" in record.message for record in caplog.records) 1072 1073 # Check specific tags 1074 tags_records = [r for r in caplog.records if hasattr(r, "tag_name")] 1075 1076 expected_tag_names = [ 1077 "haystack.component.name", 1078 "haystack.component.type", 1079 "haystack.component.fully_qualified_type", 1080 "haystack.component.input_types", 1081 "haystack.component.input_spec", 1082 "haystack.component.output_spec", 1083 "haystack.component.input", 1084 "haystack.component.visits", 1085 "haystack.component.output", 1086 "haystack.agent.max_steps", 1087 "haystack.agent.tools", 1088 "haystack.agent.exit_conditions", 1089 
"haystack.agent.state_schema", 1090 "haystack.agent.input", 1091 "haystack.agent.output", 1092 "haystack.agent.steps_taken", 1093 ] 1094 1095 expected_tag_values = [ 1096 "chat_generator", 1097 "MockChatGeneratorWithoutRunAsync", 1098 "test_agent.MockChatGeneratorWithoutRunAsync", 1099 '{"messages": "list", "tools": "list"}', 1100 '{"messages": {"type": "list[haystack.dataclasses.chat_message.ChatMessage]", "senders": []}, "tools": {"type": "list[haystack.tools.tool.Tool] | haystack.tools.toolset.Toolset | None", "senders": []}}', # noqa: E501 1101 '{"replies": {"type": "list[haystack.dataclasses.chat_message.ChatMessage]", "receivers": []}}', 1102 '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "tools": [{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]}', # noqa: E501 1103 1, 1104 '{"replies": [{"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello"}]}]}', 1105 100, 1106 '[{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]', # noqa: E501 1107 '["text"]', 1108 '{"messages": {"type": "list[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}}', # noqa: E501 1109 '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], 
"streaming_callback": null, "break_point": null, "snapshot": null}', # noqa: E501 1110 '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello"}]}]}', # noqa: E501 1111 1, 1112 ] 1113 for idx, record in enumerate(tags_records): 1114 assert record.tag_name == expected_tag_names[idx] 1115 assert record.tag_value == expected_tag_values[idx] 1116 1117 # Clean up 1118 tracing.tracer.is_content_tracing_enabled = False 1119 tracing.disable_tracing() 1120 1121 @pytest.mark.asyncio 1122 async def test_agent_tracing_span_async_run(self, caplog, monkeypatch, weather_tool): 1123 chat_generator = MockChatGenerator() 1124 agent = Agent(chat_generator=chat_generator, tools=[weather_tool]) 1125 1126 tracing.tracer.is_content_tracing_enabled = True 1127 tracing.enable_tracing(LoggingTracer()) 1128 caplog.set_level(logging.DEBUG) 1129 1130 _ = await agent.run_async([ChatMessage.from_user("What's the weather in Paris?")]) 1131 1132 # Ensure tracing span was emitted 1133 assert any("Operation: haystack.component.run" in record.message for record in caplog.records) 1134 1135 # Check specific tags 1136 tags_records = [r for r in caplog.records if hasattr(r, "tag_name")] 1137 1138 expected_tag_names = [ 1139 "haystack.component.name", 1140 "haystack.component.type", 1141 "haystack.component.fully_qualified_type", 1142 "haystack.component.input_types", 1143 "haystack.component.input_spec", 1144 "haystack.component.output_spec", 1145 "haystack.component.input", 1146 "haystack.component.visits", 1147 "haystack.component.output", 1148 "haystack.agent.max_steps", 1149 "haystack.agent.tools", 1150 "haystack.agent.exit_conditions", 1151 "haystack.agent.state_schema", 1152 "haystack.agent.input", 1153 "haystack.agent.output", 1154 "haystack.agent.steps_taken", 1155 ] 1156 1157 expected_tag_values = [ 1158 "chat_generator", 1159 "MockChatGenerator", 1160 
"test_agent.MockChatGenerator", 1161 '{"messages": "list", "tools": "list"}', 1162 '{"messages": {"type": "list[haystack.dataclasses.chat_message.ChatMessage]", "senders": []}, "tools": {"type": "list[haystack.tools.tool.Tool] | haystack.tools.toolset.Toolset | None", "senders": []}}', # noqa: E501 1163 '{"replies": {"type": "list[haystack.dataclasses.chat_message.ChatMessage]", "receivers": []}}', 1164 '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "tools": [{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]}', # noqa: E501 1165 1, 1166 '{"replies": [{"role": "assistant", "meta": {}, "name": null, "content": [{"text": "Hello from run_async"}]}]}', # noqa: E501 1167 100, 1168 '[{"type": "haystack.tools.tool.Tool", "data": {"name": "weather_tool", "description": "Provides weather information for a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, "function": "test_agent.weather_function", "outputs_to_string": null, "inputs_from_state": null, "outputs_to_state": null}}]', # noqa: E501 1169 '["text"]', 1170 '{"messages": {"type": "list[haystack.dataclasses.chat_message.ChatMessage]", "handler": "haystack.components.agents.state.state_utils.merge_lists"}}', # noqa: E501 1171 '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}], "streaming_callback": null, "break_point": null, "snapshot": null}', # noqa: E501 1172 '{"messages": [{"role": "user", "meta": {}, "name": null, "content": [{"text": "What\'s the weather in Paris?"}]}, {"role": "assistant", 
"meta": {}, "name": null, "content": [{"text": "Hello from run_async"}]}]}', # noqa: E501 1173 1, 1174 ] 1175 for idx, record in enumerate(tags_records): 1176 assert record.tag_name == expected_tag_names[idx] 1177 assert record.tag_value == expected_tag_values[idx] 1178 1179 # Clean up 1180 tracing.tracer.is_content_tracing_enabled = False 1181 tracing.disable_tracing() 1182 1183 def test_agent_tracing_in_pipeline(self, caplog, monkeypatch, weather_tool): 1184 chat_generator = MockChatGeneratorWithoutRunAsync() 1185 agent = Agent(chat_generator=chat_generator, tools=[weather_tool]) 1186 1187 tracing.tracer.is_content_tracing_enabled = True 1188 tracing.enable_tracing(LoggingTracer()) 1189 caplog.set_level(logging.DEBUG) 1190 1191 pipeline = Pipeline() 1192 pipeline.add_component( 1193 "prompt_builder", ChatPromptBuilder(template=[ChatMessage.from_user("Hello {{location}}")]) 1194 ) 1195 pipeline.add_component("agent", agent) 1196 pipeline.connect("prompt_builder.prompt", "agent.messages") 1197 1198 pipeline.run(data={"prompt_builder": {"location": "Berlin"}}) 1199 1200 assert any("Operation: haystack.pipeline.run" in record.message for record in caplog.records) 1201 tags_records = [r for r in caplog.records if hasattr(r, "tag_name")] 1202 expected_tag_names = [ 1203 "haystack.component.name", 1204 "haystack.component.type", 1205 "haystack.component.fully_qualified_type", 1206 "haystack.component.input_types", 1207 "haystack.component.input_spec", 1208 "haystack.component.output_spec", 1209 "haystack.component.input", 1210 "haystack.component.visits", 1211 "haystack.component.output", 1212 "haystack.component.name", 1213 "haystack.component.type", 1214 "haystack.component.fully_qualified_type", 1215 "haystack.component.input_types", 1216 "haystack.component.input_spec", 1217 "haystack.component.output_spec", 1218 "haystack.component.input", 1219 "haystack.component.visits", 1220 "haystack.component.output", 1221 "haystack.agent.max_steps", 1222 
"haystack.agent.tools", 1223 "haystack.agent.exit_conditions", 1224 "haystack.agent.state_schema", 1225 "haystack.agent.input", 1226 "haystack.agent.output", 1227 "haystack.agent.steps_taken", 1228 "haystack.component.name", 1229 "haystack.component.type", 1230 "haystack.component.fully_qualified_type", 1231 "haystack.component.input_types", 1232 "haystack.component.input_spec", 1233 "haystack.component.output_spec", 1234 "haystack.component.input", 1235 "haystack.component.visits", 1236 "haystack.component.output", 1237 "haystack.pipeline.input_data", 1238 "haystack.pipeline.output_data", 1239 "haystack.pipeline.metadata", 1240 "haystack.pipeline.max_runs_per_component", 1241 ] 1242 for idx, record in enumerate(tags_records): 1243 assert record.tag_name == expected_tag_names[idx] 1244 1245 # Clean up 1246 tracing.tracer.is_content_tracing_enabled = False 1247 tracing.disable_tracing() 1248 1249 def test_agent_span_has_parent_when_in_pipeline(self, spying_tracer, weather_tool): 1250 """Test that the agent's span has the component span as its parent when running in a pipeline.""" 1251 chat_generator = MockChatGeneratorWithoutRunAsync() 1252 agent = Agent(chat_generator=chat_generator, tools=[weather_tool]) 1253 1254 pipeline = Pipeline() 1255 pipeline.add_component( 1256 "prompt_builder", ChatPromptBuilder(template=[ChatMessage.from_user("Hello {{location}}")]) 1257 ) 1258 pipeline.add_component("agent", agent) 1259 pipeline.connect("prompt_builder.prompt", "agent.messages") 1260 1261 pipeline.run(data={"prompt_builder": {"location": "Berlin"}}) 1262 1263 # Find the agent span (haystack.agent.run) 1264 agent_spans = [s for s in spying_tracer.spans if s.operation_name == "haystack.agent.run"] 1265 assert len(agent_spans) == 1 1266 agent_span = agent_spans[0] 1267 1268 # Find the agent's component span (the outer span for the Agent component) 1269 agent_component_spans = [ 1270 s 1271 for s in spying_tracer.spans 1272 if s.operation_name == "haystack.component.run" 
and s.tags.get("haystack.component.name") == "agent" 1273 ] 1274 assert len(agent_component_spans) == 1 1275 agent_component_span = agent_component_spans[0] 1276 1277 # Verify the agent span has the component span as its parent 1278 assert agent_span.parent_span is not None 1279 assert agent_span.parent_span == agent_component_span 1280 1281 1282 class TestAgentToolSelection: 1283 def test_tool_selection_by_name(self, weather_tool: Tool, component_tool: Tool): 1284 chat_generator = MockChatGenerator() 1285 agent = Agent( 1286 chat_generator=chat_generator, 1287 tools=[weather_tool, component_tool], 1288 system_prompt="This is a system prompt.", 1289 ) 1290 result = agent._select_tools([weather_tool.name]) 1291 assert result == [weather_tool] 1292 1293 def test_tool_selection_new_tool(self, weather_tool: Tool, component_tool: Tool): 1294 chat_generator = MockChatGenerator() 1295 agent = Agent(chat_generator=chat_generator, tools=[weather_tool], system_prompt="This is a system prompt.") 1296 result = agent._select_tools([component_tool]) 1297 assert result == [component_tool] 1298 1299 def test_tool_selection_existing_tools(self, weather_tool: Tool, component_tool: Tool): 1300 chat_generator = MockChatGenerator() 1301 agent = Agent( 1302 chat_generator=chat_generator, 1303 tools=[weather_tool, component_tool], 1304 system_prompt="This is a system prompt.", 1305 ) 1306 result = agent._select_tools(None) 1307 assert result == [weather_tool, component_tool] 1308 1309 def test_tool_selection_invalid_tool_name(self, weather_tool: Tool, component_tool: Tool): 1310 chat_generator = MockChatGenerator() 1311 agent = Agent( 1312 chat_generator=chat_generator, 1313 tools=[weather_tool, component_tool], 1314 system_prompt="This is a system prompt.", 1315 ) 1316 with pytest.raises( 1317 ValueError, match=("The following tool names are not valid: {'invalid_tool_name'}. 
Valid tool names are: .") 1318 ): 1319 agent._select_tools(["invalid_tool_name"]) 1320 1321 def test_tool_selection_no_tools_configured(self, weather_tool: Tool, component_tool: Tool): 1322 chat_generator = MockChatGenerator() 1323 agent = Agent(chat_generator=chat_generator, tools=[], system_prompt="This is a system prompt.") 1324 with pytest.raises(ValueError, match="No tools were configured for the Agent at initialization."): 1325 agent._select_tools([weather_tool.name]) 1326 1327 def test_tool_selection_invalid_type(self, weather_tool: Tool, component_tool: Tool): 1328 chat_generator = MockChatGenerator() 1329 agent = Agent( 1330 chat_generator=chat_generator, 1331 tools=[weather_tool, component_tool], 1332 system_prompt="This is a system prompt.", 1333 ) 1334 with pytest.raises( 1335 TypeError, 1336 match=( 1337 re.escape( 1338 "tools must be a list of Tool and/or Toolset objects, a Toolset, or a list of tool names (strings)." 1339 ) 1340 ), 1341 ): 1342 agent._select_tools("invalid_tool_name") 1343 1344 def test_tool_selection_with_list_of_toolsets(self, weather_tool: Tool, component_tool: Tool): 1345 """Test that list of Toolsets and Tools can be passed to agent.""" 1346 chat_generator = MockChatGenerator() 1347 toolset1 = Toolset([weather_tool]) 1348 standalone_tool = Tool( 1349 name="standalone", 1350 description="A standalone tool", 1351 parameters={"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]}, 1352 function=lambda x: f"Result: {x}", 1353 ) 1354 toolset2 = Toolset([component_tool]) 1355 1356 agent = Agent(chat_generator=chat_generator, tools=[toolset1, standalone_tool, toolset2]) 1357 result = agent._select_tools(None) 1358 1359 assert result == [toolset1, standalone_tool, toolset2] 1360 assert isinstance(result, list) 1361 assert len(result) == 3 1362 1363 def test_agent_serde_with_list_of_toolsets(self, weather_tool: Tool, component_tool: Tool, monkeypatch): 1364 """Test Agent serialization and deserialization with a 
list of Toolsets.""" 1365 monkeypatch.setenv("OPENAI_API_KEY", "fake-key") 1366 1367 toolset1 = Toolset([weather_tool]) 1368 toolset2 = Toolset([component_tool]) 1369 1370 generator = OpenAIChatGenerator() 1371 agent = Agent(chat_generator=generator, tools=[toolset1, toolset2]) 1372 1373 serialized_agent = agent.to_dict() 1374 1375 # Verify serialization preserves list[Toolset] structure 1376 tools_data = serialized_agent["init_parameters"]["tools"] 1377 assert isinstance(tools_data, list) 1378 assert len(tools_data) == 2 1379 assert all(isinstance(ts, dict) for ts in tools_data) 1380 assert tools_data[0]["type"] == "haystack.tools.toolset.Toolset" 1381 assert tools_data[1]["type"] == "haystack.tools.toolset.Toolset" 1382 1383 # Deserialize and verify 1384 deserialized_agent = Agent.from_dict(serialized_agent) 1385 assert isinstance(deserialized_agent.tools, list) 1386 assert len(deserialized_agent.tools) == 2 1387 assert all(isinstance(ts, Toolset) for ts in deserialized_agent.tools) 1388 1389 1390 class TestRegisterPromptVariables: 1391 def test_register_prompt_variables_warning_when_no_prompt_and_required_variables(self, make_agent, caplog): 1392 make_agent(required_variables=["name"]) 1393 assert "The parameter required_variables is provided but neither" in caplog.text 1394 1395 def test_register_prompt_variables_set_all_variables_as_required(self, make_agent): 1396 agent = make_agent(user_prompt=_user_msg("Question: {{question}}"), required_variables="*") 1397 assert agent._user_chat_prompt_builder.required_variables == "*" 1398 1399 input_names = set(agent.__haystack_input__._sockets_dict.keys()) 1400 assert "question" in input_names 1401 1402 def test_register_prompt_variables_set_required_variables_on_builder(self, make_agent): 1403 agent = make_agent(user_prompt=_user_msg("Question: {{question}}"), required_variables=["question"]) 1404 assert agent._user_chat_prompt_builder.required_variables == ["question"] 1405 1406 input_names = 
set(agent.__haystack_input__._sockets_dict.keys()) 1407 assert "question" in input_names 1408 1409 def test_register_prompt_variables_raises_on_state_schema_conflict(self, make_agent): 1410 with pytest.raises( 1411 ValueError, match="Variable 'question' from user_prompt is already defined in the state schema." 1412 ): 1413 make_agent(user_prompt=_user_msg("Question: {{question}}"), state_schema={"question": {"type": str}}) 1414 1415 def test_register_prompt_variables_raises_on_run_param_conflict(self, make_agent): 1416 with pytest.raises( 1417 ValueError, match="Variable 'system_prompt' from user_prompt conflicts with input names in the run method." 1418 ): 1419 make_agent(user_prompt=_user_msg("{{system_prompt}} is already a run parameter.")) 1420 1421 1422 class TestInitializeFreshExecution: 1423 def test_initialize_fresh_execution_raises_with_init_run_mismatch(self, make_agent): 1424 agent = make_agent(system_prompt="Plain init prompt.") 1425 with pytest.raises(ValueError, match="no system prompt builder is initialized"): 1426 agent._initialize_fresh_execution( 1427 messages=None, 1428 streaming_callback=None, 1429 requires_async=False, 1430 user_prompt=None, 1431 system_prompt=_sys_msg("Jinja2 syntax."), 1432 ) 1433 1434 agent = make_agent() 1435 with pytest.raises(ValueError, match="user_prompt is provided but the ChatPromptBuilder is not initialized"): 1436 agent._initialize_fresh_execution( 1437 messages=None, 1438 streaming_callback=None, 1439 requires_async=False, 1440 user_prompt=_user_msg("Jinja2 syntax."), 1441 system_prompt=None, 1442 ) 1443 1444 def test_initialize_fresh_execution_raises_with_wrong_role(self, make_agent): 1445 agent = make_agent(system_prompt=_user_msg("This is a user message, not system.")) 1446 with pytest.raises(ValueError, match="system_prompt must render to a system message"): 1447 agent._initialize_fresh_execution( 1448 messages=None, streaming_callback=None, requires_async=False, user_prompt=None, system_prompt=None 1449 ) 1450 
        # user_prompt rendering a system-role message is rejected.
        agent = make_agent(user_prompt=_sys_msg("This is a user message, not system."))
        with pytest.raises(ValueError, match="user_prompt must render to a user message"):
            agent._initialize_fresh_execution(
                messages=None, streaming_callback=None, requires_async=False, user_prompt=None, system_prompt=None
            )

    def test_initialize_fresh_execution_raises_with_incorrect_prompt_length(self, make_agent):
        # A template rendering two messages is invalid for either prompt slot.
        multi_message_prompt = """{% message role='system' %}You are a helpful assistant.{% endmessage %}
        {% message role='user' %}How are you?{% endmessage %}"""

        agent = make_agent(system_prompt=multi_message_prompt)
        with pytest.raises(ValueError, match="system_prompt must render to exactly one system message"):
            agent._initialize_fresh_execution(
                messages=None, streaming_callback=None, requires_async=False, user_prompt=None, system_prompt=None
            )

        agent = make_agent(user_prompt=multi_message_prompt)
        with pytest.raises(ValueError, match="user_prompt must render to exactly one user message"):
            agent._initialize_fresh_execution(
                messages=None, streaming_callback=None, requires_async=False, user_prompt=None, system_prompt=None
            )


class TestPrompts:
    def test_system_prompt_incorrect_jinja2_syntax_raises(self, make_agent):
        with pytest.raises(TemplateSyntaxError):
            make_agent(system_prompt="{% message role='system' %}Incomplete syntax.")

    def test_system_prompt_plain_string(self, make_agent):
        # A plain (non-templated) system prompt does not create a prompt builder.
        agent = make_agent(system_prompt="You are a helpful assistant.")
        assert agent._system_chat_prompt_builder is None
        result = agent.run(messages=[ChatMessage.from_user("Hi")])
        assert result["messages"][0].is_from(ChatRole.SYSTEM)
        assert result["messages"][0].text == "You are a helpful assistant."
1485 1486 def test_system_prompt_with_template_variables(self, make_agent): 1487 agent = make_agent(system_prompt=_sys_msg("You are an assistant for {{company}}. Your role is {{role}}.")) 1488 assert agent._system_chat_prompt_builder is not None 1489 assert set(agent._system_chat_prompt_builder.variables) == {"company", "role"} 1490 1491 result = agent.run(messages=[ChatMessage.from_user("Hi")], company="Acme", role="support agent") 1492 sys_msg = result["messages"][0] 1493 assert sys_msg.is_from(ChatRole.SYSTEM) 1494 assert sys_msg.text == "You are an assistant for Acme. Your role is support agent." 1495 1496 input_names = set(agent.__haystack_input__._sockets_dict.keys()) 1497 assert "company" in input_names 1498 assert "role" in input_names 1499 1500 def test_system_prompt_with_meta(self, make_agent): 1501 agent = make_agent( 1502 system_prompt="{% message role='system' meta={'key': 'value'} %}System message with meta{% endmessage %}" 1503 ) 1504 assert agent._system_chat_prompt_builder is not None 1505 1506 result = agent.run(messages=[ChatMessage.from_user("Hi")]) 1507 messages = result["messages"] 1508 assert messages[0].is_from(ChatRole.SYSTEM) 1509 assert messages[0].text == "System message with meta" 1510 assert messages[0].meta == {"key": "value"} 1511 1512 def test_system_prompt_runtime_override(self, make_agent): 1513 agent = make_agent(system_prompt=_sys_msg("You are a helpful assistant.")) 1514 result = agent.run( 1515 messages=[ChatMessage.from_user("Hi")], system_prompt=_sys_msg("You are an Haystack expert.") 1516 ) 1517 assert result["messages"][0].text == "You are an Haystack expert." 
1518 assert result["messages"][1].text == "Hi" 1519 1520 def test_user_prompt_only_variables_forwarded_to_builder(self, make_agent): 1521 agent = make_agent(user_prompt=_user_msg("Question: {{question}}")) 1522 # 'irrelevant_kwarg' is not a template variable — must not raise 1523 result = agent.run(messages=[], question="Will it snow?", irrelevant_kwarg="unused") 1524 assert "messages" in result 1525 1526 def test_user_prompt_with_template_variables(self, make_agent): 1527 agent = make_agent( 1528 user_prompt=_user_msg( 1529 "Hello {{name|upper}}, check weather for: " 1530 + "{% for c in cities %}{{c}}{% if not loop.last %}, {% endif %}{% endfor %}" 1531 + " on {{date}}?" 1532 ) 1533 ) 1534 result = agent.run(messages=[], name="Alice", cities=["Berlin", "Paris", "Rome"], date="2024-01-15") 1535 user_messages = [m for m in result["messages"] if m.is_from(ChatRole.USER)] 1536 assert user_messages[0].text == "Hello ALICE, check weather for: Berlin, Paris, Rome on 2024-01-15?" 1537 1538 input_names = set(agent.__haystack_input__._sockets_dict.keys()) 1539 assert "name" in input_names 1540 assert "cities" in input_names 1541 assert "date" in input_names 1542 1543 def test_runtime_user_prompt_overrides_init_prompt(self, make_agent): 1544 agent = make_agent(user_prompt=_user_msg("Default prompt for {{city}}.")) 1545 result = agent.run(messages=[], user_prompt=_user_msg("Runtime prompt for {{city}}."), city="Berlin") 1546 user_messages = [m for m in result["messages"] if m.is_from(ChatRole.USER)] 1547 assert user_messages[0].text == "Runtime prompt for Berlin." 
1548 1549 def test_user_prompt_appended_after_initial_messages(self, make_agent): 1550 agent = make_agent(user_prompt=_user_msg("And now: {{query}}")) 1551 initial_messages = [ChatMessage.from_user("First message")] 1552 result = agent.run(messages=initial_messages, query="What is the weather?") 1553 user_messages = [m for m in result["messages"] if m.is_from(ChatRole.USER)] 1554 assert user_messages[0].text == "First message" 1555 assert user_messages[1].text == "And now: What is the weather?" 1556 1557 def test_runtime_user_prompt_appended_after_initial_messages(self, make_agent): 1558 agent = make_agent(user_prompt=_user_msg("Init prompt: {{question}}")) 1559 initial_messages = [ChatMessage.from_user("Context message")] 1560 result = agent.run( 1561 messages=initial_messages, user_prompt=_user_msg("Follow-up: {{question}}"), question="Is it raining?" 1562 ) 1563 user_messages = [m for m in result["messages"] if m.is_from(ChatRole.USER)] 1564 assert len(user_messages) == 2 1565 assert user_messages[0].text == "Context message" 1566 assert user_messages[1].text == "Follow-up: Is it raining?" 1567 1568 def test_system_prompt_and_user_prompt(self, make_agent): 1569 agent = make_agent( 1570 system_prompt=_sys_msg("You help users of {{project}}."), 1571 user_prompt=_user_msg("Tell me about {{topic}} in the {{project}} context."), 1572 ) 1573 assert agent._system_chat_prompt_builder is not None 1574 assert agent._user_chat_prompt_builder is not None 1575 1576 result = agent.run(messages=[], project="Haystack", topic="pipelines") 1577 messages = result["messages"] 1578 assert messages[0].is_from(ChatRole.SYSTEM) 1579 assert messages[0].text == "You help users of Haystack." 1580 user_messages = [m for m in messages if m.is_from(ChatRole.USER)] 1581 assert user_messages[0].text == "Tell me about pipelines in the Haystack context." 
@pytest.mark.integration
class TestAgentUserPromptInPipeline:
    """Integration tests: an Agent with a `user_prompt` template wired inside a Pipeline."""

    @pytest.fixture
    def document_store_with_docs(self):
        """In-memory store pre-loaded with three landmark documents for BM25 retrieval."""
        store = InMemoryDocumentStore()
        store.write_documents(
            [
                Document(content="The Eiffel Tower is located in Paris."),
                Document(content="The Brandenburg Gate is in Berlin."),
                Document(content="The Colosseum is in Rome."),
            ]
        )
        return store

    @pytest.fixture
    def make_rag_pipeline(self, document_store_with_docs: InMemoryDocumentStore, make_agent):
        """Factory fixture: builds a retriever→agent pipeline, optionally with a custom user prompt."""

        def _factory(user_prompt: str | None = None):
            # Falls back to a default RAG prompt when no user_prompt is given.
            agent = make_agent(
                user_prompt=user_prompt
                or _user_msg(
                    "Use the following documents to answer the question.\n"
                    "Documents:\n{% for doc in documents %}{{doc.content}}\n{% endfor %}"
                    "Question: {{query}}"
                ),
                system_prompt="You are a knowledgeable assistant.",
                required_variables=["query", "documents"],
            )

            pp = Pipeline()
            pp.add_component("retriever", InMemoryBM25Retriever(document_store=document_store_with_docs))
            pp.add_component("agent", agent)
            # Retrieved documents feed the 'documents' template variable of the agent.
            pp.connect("retriever.documents", "agent.documents")

            return pp

        return _factory

    def test_rag_pipeline_user_prompt_init_only(self, make_rag_pipeline):
        """The init-time user prompt renders retriever documents and the query into one user message."""
        pipeline = make_rag_pipeline()
        query = "Where is the Colosseum?"
        result = pipeline.run(data={"retriever": {"query": query}, "agent": {"query": query, "messages": []}})
        assert "agent" in result
        agent_output = result["agent"]
        assert "messages" in agent_output
        assert "last_message" in agent_output

        messages = agent_output["messages"]
        assert messages[0].is_from(ChatRole.SYSTEM)
        assert messages[0].text == "You are a knowledgeable assistant."

        # Exactly one user message is produced, containing both sections of the template.
        user_messages = [m for m in messages if m.is_from(ChatRole.USER)]
        assert len(user_messages) == 1
        rendered = user_messages[0].text
        assert "Question: Where is the Colosseum?" in rendered
        assert "Documents:" in rendered

    def test_rag_pipeline_user_prompt_runtime_override(self, make_rag_pipeline):
        """A user_prompt passed via pipeline run data overrides the init-time prompt."""
        user_prompt = _user_msg(
            "Documents:\n{% for doc in documents %}{{doc.content}}\n{% endfor %}Question: {{query}}"
        )
        pipeline = make_rag_pipeline(user_prompt=user_prompt)

        query = "Where is the Eiffel Tower?"
        result = pipeline.run(
            data={
                "retriever": {"query": query},
                "agent": {
                    "user_prompt": _user_msg(
                        "OVERRIDE: Using docs:\n"
                        "{% for doc in documents %}{{doc.content}}\n{% endfor %}"
                        "Answer: {{query}}"
                    ),
                    "query": query,
                    "messages": [],
                },
            }
        )
        messages = result["agent"]["messages"]
        user_messages = [m for m in messages if m.is_from(ChatRole.USER)]
        rendered = user_messages[0].text
        # The override marker proves the runtime prompt won over the init-time one.
        assert "OVERRIDE:" in rendered
        assert "Where is the Eiffel Tower?" in rendered

    def test_rag_pipeline_messages_plus_user_prompt(self, document_store_with_docs, weather_tool):
        """Messages from an upstream ChatPromptBuilder come first; the agent's user prompt is appended after."""
        # NOTE(review): ChatPromptBuilder is already imported at module level; this
        # local import looks redundant — confirm it can be dropped.
        from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder

        chat_generator = MockChatGenerator()

        agent = Agent(
            chat_generator=chat_generator,
            tools=[weather_tool],
            user_prompt=_user_msg("Relevant docs:\n{% for doc in documents %}{{doc.content}}\n{% endfor %}"),
        )
        # Stub the generator so the test never calls a real LLM.
        chat_generator.run = MagicMock(return_value={"replies": [ChatMessage.from_assistant("Berlin")]})

        pipeline = Pipeline()
        pipeline.add_component(
            "prompt_builder", ChatPromptBuilder(template=[ChatMessage.from_user("History: {{history_note}}")])
        )
        pipeline.add_component("retriever", InMemoryBM25Retriever(document_store=document_store_with_docs))
        pipeline.add_component("agent", agent)

        # Two inputs to the agent: conversation messages and retrieved documents.
        pipeline.connect("prompt_builder.prompt", "agent.messages")
        pipeline.connect("retriever.documents", "agent.documents")

        result = pipeline.run(
            data={
                "prompt_builder": {"history_note": "User previously asked about European cities."},
                "retriever": {"query": "Brandenburg Gate"},
            }
        )
        messages = result["agent"]["messages"]
        user_messages = [m for m in messages if m.is_from(ChatRole.USER)]
        # Upstream message first, then the agent's rendered user prompt.
        assert "History:" in user_messages[0].text
        rendered = user_messages[1].text
        assert "Relevant docs:" in rendered


class TestAgentWaitsForBlockedPredecessor:
    """
    Regression test for the scheduling bug introduced by making the 'messages'
    run parameter non-required in https://github.com/deepset-ai/haystack/pull/10638.

    Pipeline shape
    --------------
    Two paths feed into a lazy-variadic joiner that collects messages for the Agent:

        Path A (works): query → history_parser → messages_joiner.values
        Path B (blocked): files=[] → files_processor (returns {}) → attachments_builder ──╳──→ messages_joiner.values

        messages_joiner.values → agent.messages
        filters → agent.retrieval_filters (static input from pipeline.run data)

    The bug
    -------
    1. history_parser runs → sends messages to messages_joiner.
    2. files_processor runs with files=[] → returns {} (no output).
    3. attachments_builder is BLOCKED — its mandatory processed_files input never arrives.
    4. messages_joiner gets DEFER_LAST (priority=4): it has a lazy-variadic socket and attachments_builder hasn't
       executed yet, so the joiner doesn't know if more data might still come. It keeps waiting.
    5. agent gets DEFER (priority=3): retrieval_filters arrives with sender=None (static pipeline input), which
       satisfies has_any_trigger() on the first visit. The Agent has no mandatory sockets, so can_component_run()
       returns True. It also has no unresolved lazy-variadic sockets, so it gets DEFER rather than DEFER_LAST.
    6. Since DEFER (3) < DEFER_LAST (4), the scheduler picks the Agent before the joiner runs.
       The Agent executes without messages and raises:

           ValueError("No messages provided to the Agent and neither user_prompt nor system_prompt is set.")
    """

    def test_agent_waits_for_messages_when_predecessor_is_blocked(self, weather_tool):
        """The agent must not be scheduled before the joiner, even when one joiner input is blocked."""

        @component
        class HistoryParser:
            """Turns the raw query into a single-user-message history (the working path A)."""

            @component.output_types(messages=list[ChatMessage])
            def run(self, query: str) -> dict:
                return {"messages": [ChatMessage.from_user(query)]}

        @component
        class FilesProcessor:
            """Produces no output when given an empty file list."""

            @component.output_types(processed_files=list[str])
            def run(self, files: list[str]) -> dict:
                if not files:
                    return {}  # _NO_OUTPUT_PRODUCED → blocks AttachmentsBuilder
                return {"processed_files": files}

        @component
        class AttachmentsBuilder:
            """Builds attachment messages; mandatory processed_files from FilesProcessor."""

            @component.output_types(prompt=list[ChatMessage])
            def run(self, processed_files: list[str]) -> dict:
                return {"prompt": [ChatMessage.from_user(f"Files: {processed_files}")]}

        chat_generator = MockChatGenerator()
        # retrieval_filters in the state schema is the static input that triggers the
        # premature-scheduling condition described in the class docstring.
        agent = Agent(
            chat_generator=chat_generator,
            tools=[weather_tool],
            state_schema={"retrieval_filters": {"type": dict[str, Any]}},
        )
        chat_generator.run = MagicMock(return_value={"replies": [ChatMessage.from_assistant("done")]})

        pipeline = Pipeline()
        pipeline.add_component("history_parser", HistoryParser())
        pipeline.add_component("files_processor", FilesProcessor())
        pipeline.add_component("attachments_builder", AttachmentsBuilder())
        pipeline.add_component("messages_joiner", ListJoiner(list[ChatMessage]))
        pipeline.add_component("agent", agent)

        pipeline.connect("history_parser.messages", "messages_joiner.values")
        pipeline.connect("files_processor.processed_files", "attachments_builder.processed_files")
        pipeline.connect("attachments_builder.prompt", "messages_joiner.values")
        pipeline.connect("messages_joiner.values", "agent.messages")

        # files=[] → files_processor produces no output → attachments_builder BLOCKED
        # → messages_joiner stays DEFER_LAST
        # → agent (DEFER) runs first without messages → ValueError
        result = pipeline.run(
            data={
                "history_parser": {"query": "What case law applies?"},
                "files_processor": {"files": []},  # empty → no output
                "agent": {"retrieval_filters": {"field": "date", "value": "2024-01-01"}},
            }
        )
        # With the fix, the pipeline completes and the agent produces output.
        assert "agent" in result