# test_tool_call_guardrail_runtime.py
"""Runtime tests for tool-call loop guardrails."""

import json
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

from run_agent import AIAgent


def _make_tool_defs(*names: str) -> list[dict]:
    """Build minimal OpenAI-style function tool definitions, one per name."""
    return [
        {
            "type": "function",
            "function": {
                "name": name,
                "description": f"{name} tool",
                "parameters": {"type": "object", "properties": {}},
            },
        }
        for name in names
    ]


def _mock_tool_call(name="web_search", arguments="{}", call_id=None):
    """Return a fake tool-call object shaped like an SDK response item.

    ``arguments`` is the raw JSON string. When ``call_id`` is not given, a
    random ``call_<hex>`` id is generated.
    """
    return SimpleNamespace(
        id=call_id or f"call_{uuid.uuid4().hex[:8]}",
        type="function",
        function=SimpleNamespace(name=name, arguments=arguments),
    )


def _mock_response(content="Hello", finish_reason="stop", tool_calls=None):
    """Return a fake chat-completion response with a single choice."""
    msg = SimpleNamespace(content=content, tool_calls=tool_calls)
    choice = SimpleNamespace(message=msg, finish_reason=finish_reason)
    return SimpleNamespace(choices=[choice], model="test/model", usage=None)


def _make_agent(*tool_names: str, max_iterations: int = 10, config: dict | None = None) -> AIAgent:
    """Construct an ``AIAgent`` isolated from real tools, config and network.

    While the agent is built, tool discovery, toolset requirement checks,
    ``hermes_cli.config.load_config`` (returning ``config`` or ``{}``) and the
    ``OpenAI`` client class are patched. The agent's client is then replaced
    with a ``MagicMock``. Prompt caching, tool delay, compression and
    trajectory saving are disabled so the tests exercise only the tool loop.
    """
    with (
        patch("run_agent.get_tool_definitions", return_value=_make_tool_defs(*tool_names)),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("hermes_cli.config.load_config", return_value=config or {}),
        patch("run_agent.OpenAI"),
    ):
        agent = AIAgent(
            api_key="test-key-1234567890",
            base_url="https://openrouter.ai/api/v1",
            max_iterations=max_iterations,
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
        )
    agent.client = MagicMock()
    agent._cached_system_prompt = "You are helpful."
    agent._use_prompt_caching = False
    agent.tool_delay = 0
    agent.compression_enabled = False
    agent.save_trajectories = False
    return agent


def _seed_exact_failures(agent: AIAgent, tool_name: str, args: dict, count: int = 2) -> None:
    """Record ``count`` identical failed calls in the agent's guardrail state.

    Each call is reported to ``agent._tool_guardrails.after_call`` with the
    same args and an error payload. Later identical calls in a test then
    count as repeats.
    """
    for _ in range(count):
        agent._tool_guardrails.after_call(
            tool_name,
            args,
            json.dumps({"error": "boom"}),
            failed=True,
        )


def _hard_stop_config(**overrides) -> dict:
    """Return a config that enables guardrail warnings and hard stops.

    ``overrides`` shallow-update the ``tool_loop_guardrails`` section.
    Passing ``hard_stop_after`` therefore replaces the whole thresholds dict
    rather than merging into it.
    """
    cfg = {
        "tool_loop_guardrails": {
            "warnings_enabled": True,
            "hard_stop_enabled": True,
            "hard_stop_after": {
                "exact_failure": 2,
                "same_tool_failure": 8,
                "idempotent_no_progress": 5,
            },
        }
    }
    cfg["tool_loop_guardrails"].update(overrides)
    return cfg


def test_default_sequential_path_warns_repeated_exact_failure_without_blocking_execution():
    """With default config, a repeated failing call still runs but carries a warning."""
    agent = _make_agent("web_search")
    args = {"query": "same"}
    _seed_exact_failures(agent, "web_search", args)
    starts = []
    progress = []
    agent.tool_start_callback = lambda *a, **k: starts.append((a, k))
    agent.tool_progress_callback = lambda *a, **k: progress.append((a, k))
    tc = _mock_tool_call("web_search", json.dumps(args), "c-soft")
    msg = SimpleNamespace(content="", tool_calls=[tc])
    messages = []

    with patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})) as mock_hfc:
        agent._execute_tool_calls_sequential(msg, messages, "task-1")

    mock_hfc.assert_called_once()
    assert len(starts) == 1
    # Progress callbacks are recorded as (positional_args, kwargs); the event
    # name is the first positional argument.
    assert any(event[0][0] == "tool.completed" for event in progress)
    assert len(messages) == 1
    assert messages[0]["role"] == "tool"
    assert messages[0]["tool_call_id"] == "c-soft"
    assert "repeated_exact_failure_warning" in messages[0]["content"]
    assert "repeated_exact_failure_block" not in messages[0]["content"]
    assert agent._tool_guardrail_halt_decision is None


def test_config_enabled_hard_stop_blocks_repeated_exact_failure_before_execution():
    """With hard stops enabled, a repeated failing call is blocked before execution."""
    agent = _make_agent("web_search", config=_hard_stop_config())
    args = {"query": "same"}
    _seed_exact_failures(agent, "web_search", args)
    starts = []
    progress = []
    agent.tool_start_callback = lambda *a, **k: starts.append((a, k))
    agent.tool_progress_callback = lambda *a, **k: progress.append((a, k))
    tc = _mock_tool_call("web_search", json.dumps(args), "c-block")
    msg = SimpleNamespace(content="", tool_calls=[tc])
    messages = []

    with patch("run_agent.handle_function_call", return_value="SHOULD_NOT_RUN") as mock_hfc:
        agent._execute_tool_calls_sequential(msg, messages, "task-1")

    # A blocked call must not execute or fire start/progress callbacks, yet
    # it still gets exactly one tool result so the call id is answered.
    mock_hfc.assert_not_called()
    assert starts == []
    assert progress == []
    assert len(messages) == 1
    assert messages[0]["role"] == "tool"
    assert messages[0]["tool_call_id"] == "c-block"
    assert "repeated_exact_failure_block" in messages[0]["content"]


def test_sequential_after_call_appends_guidance_to_tool_result_without_extra_messages():
    """Guardrail guidance is embedded in the tool result, not sent as a separate message."""
    agent = _make_agent("web_search")
    args = {"query": "same"}
    _seed_exact_failures(agent, "web_search", args, count=1)
    tc = _mock_tool_call("web_search", json.dumps(args), "c-warn")
    msg = SimpleNamespace(content="", tool_calls=[tc])
    messages = []

    with patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})):
        agent._execute_tool_calls_sequential(msg, messages, "task-1")

    assert [m["role"] for m in messages] == ["tool"]
    assert messages[0]["tool_call_id"] == "c-warn"
    assert "Tool loop warning" in messages[0]["content"]
    assert "repeated_exact_failure_warning" in messages[0]["content"]


def test_config_enabled_hard_stop_concurrent_path_does_not_submit_blocked_calls_and_preserves_result_order():
    """Concurrent path: blocked calls are skipped; results keep the original call order."""
    agent = _make_agent("web_search", config=_hard_stop_config())
    blocked_args = {"query": "blocked"}
    allowed_args = {"query": "allowed"}
    _seed_exact_failures(agent, "web_search", blocked_args)
    starts = []
    progress_events = []
    agent.tool_start_callback = lambda tool_call_id, name, args: starts.append((tool_call_id, name, args))
    agent.tool_progress_callback = lambda event, name, preview, args, **kw: progress_events.append((event, name, args, kw))
    calls = [
        _mock_tool_call("web_search", json.dumps(blocked_args), "c-block"),
        _mock_tool_call("web_search", json.dumps(allowed_args), "c-allow"),
    ]
    msg = SimpleNamespace(content="", tool_calls=calls)
    messages = []
    executed = []

    def fake_handle(name, args, task_id, **kwargs):
        # Record every call that actually reaches the handler.
        executed.append((name, args, kwargs["tool_call_id"]))
        return json.dumps({"ok": args["query"]})

    with patch("run_agent.handle_function_call", side_effect=fake_handle):
        agent._execute_tool_calls_concurrent(msg, messages, "task-1")

    assert executed == [("web_search", allowed_args, "c-allow")]
    assert [m["tool_call_id"] for m in messages] == ["c-block", "c-allow"]
    assert "repeated_exact_failure_block" in messages[0]["content"]
    assert json.loads(messages[1]["content"]) == {"ok": "allowed"}
    assert starts == [("c-allow", "web_search", allowed_args)]
    started_events = [event for event in progress_events if event[0] == "tool.started"]
    completed_events = [event for event in progress_events if event[0] == "tool.completed"]
    assert started_events == [("tool.started", "web_search", allowed_args, {})]
    assert len(completed_events) == 1
    assert completed_events[0][1] == "web_search"


def test_plugin_pre_tool_block_wins_without_counting_as_toolguard_block():
    """A plugin pre-call block stops execution without updating guardrail state."""
    agent = _make_agent("web_search")
    args = {"query": "same"}
    tc = _mock_tool_call("web_search", json.dumps(args), "c-plugin")
    msg = SimpleNamespace(content="", tool_calls=[tc])
    messages = []

    with (
        patch("hermes_cli.plugins.get_pre_tool_call_block_message", return_value="plugin policy"),
        patch("run_agent.handle_function_call", return_value="SHOULD_NOT_RUN") as mock_hfc,
    ):
        agent._execute_tool_calls_sequential(msg, messages, "task-1")

    mock_hfc.assert_not_called()
    assert "plugin policy" in messages[0]["content"]
    # The plugin block must not leave the guardrail poised to block this call.
    assert agent._tool_guardrails.before_call("web_search", args).action == "allow"


def test_default_run_conversation_warns_without_guardrail_halt():
    """Default config: repeated failures only warn, and the turn ends with normal text."""
    agent = _make_agent("web_search", max_iterations=10)
    same_args = {"query": "same"}
    responses = [
        _mock_response(
            content="",
            finish_reason="tool_calls",
            tool_calls=[_mock_tool_call("web_search", json.dumps(same_args), f"c{i}")],
        )
        for i in range(1, 4)
    ]
    responses.append(_mock_response(content="done", finish_reason="stop", tool_calls=None))
    agent.client.chat.completions.create.side_effect = responses

    with (
        patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})) as mock_hfc,
        patch.object(agent, "_persist_session"),
        patch.object(agent, "_save_trajectory"),
        patch.object(agent, "_cleanup_task_resources"),
    ):
        result = agent.run_conversation("search repeatedly")

    assert mock_hfc.call_count == 3
    assert result["turn_exit_reason"].startswith("text_response")
    assert "guardrail" not in result
    assert result["final_response"] == "done"
    tool_contents = [m["content"] for m in result["messages"] if m.get("role") == "tool"]
    assert any("repeated_exact_failure_warning" in content for content in tool_contents)


def test_config_enabled_hard_stop_run_conversation_returns_controlled_guardrail_halt_without_top_level_error():
    """Hard-stop config: the loop halts cleanly with a guardrail result, not an error."""
    agent = _make_agent("web_search", max_iterations=10, config=_hard_stop_config())
    same_args = {"query": "same"}
    responses = [
        _mock_response(
            content="",
            finish_reason="tool_calls",
            tool_calls=[_mock_tool_call("web_search", json.dumps(same_args), f"c{i}")],
        )
        for i in range(1, 10)
    ]
    agent.client.chat.completions.create.side_effect = responses

    with (
        patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})) as mock_hfc,
        patch.object(agent, "_persist_session"),
        patch.object(agent, "_save_trajectory"),
        patch.object(agent, "_cleanup_task_resources"),
    ):
        result = agent.run_conversation("search repeatedly")

    # With "exact_failure": 2, the test expects two executed (failing) calls
    # and a third API turn whose identical call is blocked, ending the run.
    assert mock_hfc.call_count == 2
    assert result["api_calls"] == 3
    assert result["api_calls"] < agent.max_iterations
    assert result["turn_exit_reason"] == "guardrail_halt"
    assert "error" not in result
    assert result["completed"] is True
    assert "stopped retrying" in result["final_response"]
    assert result["guardrail"]["code"] == "repeated_exact_failure_block"
    assert result["guardrail"]["tool_name"] == "web_search"

    assistant_tool_calls = [m for m in result["messages"] if m.get("role") == "assistant" and m.get("tool_calls")]
    # Guard against a vacuous pass: the pairing check below proves nothing if
    # no assistant tool-call messages survived into the transcript.
    assert assistant_tool_calls, "expected assistant messages carrying tool_calls in the transcript"
    # Every tool call id issued by the assistant must have a matching tool
    # result, including the blocked call, so the transcript stays valid.
    for assistant_msg in assistant_tool_calls:
        call_ids = [tc["id"] for tc in assistant_msg["tool_calls"]]
        following_results = [m for m in result["messages"] if m.get("role") == "tool" and m.get("tool_call_id") in call_ids]
        assert len(following_results) == len(call_ids)