/ tests / run_agent / test_tool_call_guardrail_runtime.py
test_tool_call_guardrail_runtime.py
  1  """Runtime tests for tool-call loop guardrails."""
  2  
  3  import json
  4  import uuid
  5  from types import SimpleNamespace
  6  from unittest.mock import MagicMock, patch
  7  
  8  from run_agent import AIAgent
  9  
 10  
 11  def _make_tool_defs(*names: str) -> list[dict]:
 12      return [
 13          {
 14              "type": "function",
 15              "function": {
 16                  "name": name,
 17                  "description": f"{name} tool",
 18                  "parameters": {"type": "object", "properties": {}},
 19              },
 20          }
 21          for name in names
 22      ]
 23  
 24  
 25  def _mock_tool_call(name="web_search", arguments="{}", call_id=None):
 26      return SimpleNamespace(
 27          id=call_id or f"call_{uuid.uuid4().hex[:8]}",
 28          type="function",
 29          function=SimpleNamespace(name=name, arguments=arguments),
 30      )
 31  
 32  
 33  def _mock_response(content="Hello", finish_reason="stop", tool_calls=None):
 34      msg = SimpleNamespace(content=content, tool_calls=tool_calls)
 35      choice = SimpleNamespace(message=msg, finish_reason=finish_reason)
 36      return SimpleNamespace(choices=[choice], model="test/model", usage=None)
 37  
 38  
 39  def _make_agent(*tool_names: str, max_iterations: int = 10, config: dict | None = None) -> AIAgent:
 40      with (
 41          patch("run_agent.get_tool_definitions", return_value=_make_tool_defs(*tool_names)),
 42          patch("run_agent.check_toolset_requirements", return_value={}),
 43          patch("hermes_cli.config.load_config", return_value=config or {}),
 44          patch("run_agent.OpenAI"),
 45      ):
 46          agent = AIAgent(
 47              api_key="test-key-1234567890",
 48              base_url="https://openrouter.ai/api/v1",
 49              max_iterations=max_iterations,
 50              quiet_mode=True,
 51              skip_context_files=True,
 52              skip_memory=True,
 53          )
 54      agent.client = MagicMock()
 55      agent._cached_system_prompt = "You are helpful."
 56      agent._use_prompt_caching = False
 57      agent.tool_delay = 0
 58      agent.compression_enabled = False
 59      agent.save_trajectories = False
 60      return agent
 61  
 62  
 63  def _seed_exact_failures(agent: AIAgent, tool_name: str, args: dict, count: int = 2) -> None:
 64      for _ in range(count):
 65          agent._tool_guardrails.after_call(
 66              tool_name,
 67              args,
 68              json.dumps({"error": "boom"}),
 69              failed=True,
 70          )
 71  
 72  
 73  def _hard_stop_config(**overrides) -> dict:
 74      cfg = {
 75          "tool_loop_guardrails": {
 76              "warnings_enabled": True,
 77              "hard_stop_enabled": True,
 78              "hard_stop_after": {
 79                  "exact_failure": 2,
 80                  "same_tool_failure": 8,
 81                  "idempotent_no_progress": 5,
 82              },
 83          }
 84      }
 85      cfg["tool_loop_guardrails"].update(overrides)
 86      return cfg
 87  
 88  
 89  def test_default_sequential_path_warns_repeated_exact_failure_without_blocking_execution():
 90      agent = _make_agent("web_search")
 91      args = {"query": "same"}
 92      _seed_exact_failures(agent, "web_search", args)
 93      starts = []
 94      progress = []
 95      agent.tool_start_callback = lambda *a, **k: starts.append((a, k))
 96      agent.tool_progress_callback = lambda *a, **k: progress.append((a, k))
 97      tc = _mock_tool_call("web_search", json.dumps(args), "c-soft")
 98      msg = SimpleNamespace(content="", tool_calls=[tc])
 99      messages = []
100  
101      with patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})) as mock_hfc:
102          agent._execute_tool_calls_sequential(msg, messages, "task-1")
103  
104      mock_hfc.assert_called_once()
105      assert len(starts) == 1
106      assert any(event[0][0] == "tool.completed" for event in progress)
107      assert len(messages) == 1
108      assert messages[0]["role"] == "tool"
109      assert messages[0]["tool_call_id"] == "c-soft"
110      assert "repeated_exact_failure_warning" in messages[0]["content"]
111      assert "repeated_exact_failure_block" not in messages[0]["content"]
112      assert agent._tool_guardrail_halt_decision is None
113  
114  
def test_config_enabled_hard_stop_blocks_repeated_exact_failure_before_execution():
    """With hard stops enabled (``exact_failure`` threshold 2), a call identical
    to two seeded failures is refused before dispatch.

    The handler is not invoked and no start or progress callbacks fire. A
    single tool result carrying the block code is still recorded for the
    call id."""
    agent = _make_agent("web_search", config=_hard_stop_config())
    query_args = {"query": "same"}
    _seed_exact_failures(agent, "web_search", query_args)

    start_calls, progress_calls = [], []
    agent.tool_start_callback = lambda *a, **k: start_calls.append((a, k))
    agent.tool_progress_callback = lambda *a, **k: progress_calls.append((a, k))

    blocked_call = _mock_tool_call("web_search", json.dumps(query_args), "c-block")
    assistant_msg = SimpleNamespace(content="", tool_calls=[blocked_call])
    transcript = []

    with patch("run_agent.handle_function_call", return_value="SHOULD_NOT_RUN") as handler:
        agent._execute_tool_calls_sequential(assistant_msg, transcript, "task-1")

    handler.assert_not_called()
    assert start_calls == []
    assert progress_calls == []
    assert len(transcript) == 1
    tool_msg = transcript[0]
    assert tool_msg["role"] == "tool"
    assert tool_msg["tool_call_id"] == "c-block"
    assert "repeated_exact_failure_block" in tool_msg["content"]
137  
138  
def test_sequential_after_call_appends_guidance_to_tool_result_without_extra_messages():
    """One seeded failure plus one more identical failing call: the test expects
    the loop warning to be folded into that call's tool result. No separate
    message is appended to the transcript."""
    agent = _make_agent("web_search")
    query_args = {"query": "same"}
    _seed_exact_failures(agent, "web_search", query_args, count=1)
    warned_call = _mock_tool_call("web_search", json.dumps(query_args), "c-warn")
    transcript = []

    with patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})):
        agent._execute_tool_calls_sequential(
            SimpleNamespace(content="", tool_calls=[warned_call]),
            transcript,
            "task-1",
        )

    assert [entry["role"] for entry in transcript] == ["tool"]
    tool_msg = transcript[0]
    assert tool_msg["tool_call_id"] == "c-warn"
    for marker in ("Tool loop warning", "repeated_exact_failure_warning"):
        assert marker in tool_msg["content"]
154  
155  
def test_config_enabled_hard_stop_concurrent_path_does_not_submit_blocked_calls_and_preserves_result_order():
    """Concurrent path with hard stops enabled.

    The blocked call is never dispatched and fires no start or started events.
    The allowed call runs normally. Tool results appear in the assistant's
    original call order."""
    agent = _make_agent("web_search", config=_hard_stop_config())
    blocked_args = {"query": "blocked"}
    allowed_args = {"query": "allowed"}
    _seed_exact_failures(agent, "web_search", blocked_args)

    start_calls = []
    progress_calls = []
    # Parameter names are kept as-is: the agent may pass them by keyword.
    agent.tool_start_callback = lambda tool_call_id, name, args: start_calls.append((tool_call_id, name, args))
    agent.tool_progress_callback = lambda event, name, preview, args, **kw: progress_calls.append((event, name, args, kw))

    assistant_msg = SimpleNamespace(
        content="",
        tool_calls=[
            _mock_tool_call("web_search", json.dumps(blocked_args), "c-block"),
            _mock_tool_call("web_search", json.dumps(allowed_args), "c-allow"),
        ],
    )
    transcript = []
    dispatched = []

    def _recording_handler(name, args, task_id, **kwargs):
        dispatched.append((name, args, kwargs["tool_call_id"]))
        return json.dumps({"ok": args["query"]})

    with patch("run_agent.handle_function_call", side_effect=_recording_handler):
        agent._execute_tool_calls_concurrent(assistant_msg, transcript, "task-1")

    assert dispatched == [("web_search", allowed_args, "c-allow")]
    assert [entry["tool_call_id"] for entry in transcript] == ["c-block", "c-allow"]
    blocked_msg, allowed_msg = transcript
    assert "repeated_exact_failure_block" in blocked_msg["content"]
    assert json.loads(allowed_msg["content"]) == {"ok": "allowed"}
    assert start_calls == [("c-allow", "web_search", allowed_args)]

    def _events_of(kind):
        return [evt for evt in progress_calls if evt[0] == kind]

    assert _events_of("tool.started") == [("tool.started", "web_search", allowed_args, {})]
    completed = _events_of("tool.completed")
    assert len(completed) == 1
    assert completed[0][1] == "web_search"
190  
191  
def test_plugin_pre_tool_block_wins_without_counting_as_toolguard_block():
    """A plugin pre-tool-call block stops the call before dispatch.

    Exactly one tool result, paired with the blocked call and carrying the
    plugin's message, is recorded. The guardrail records no halt decision and
    still allows the same call afterwards.
    """
    agent = _make_agent("web_search")
    args = {"query": "same"}
    tc = _mock_tool_call("web_search", json.dumps(args), "c-plugin")
    msg = SimpleNamespace(content="", tool_calls=[tc])
    messages = []

    with (
        patch("hermes_cli.plugins.get_pre_tool_call_block_message", return_value="plugin policy"),
        patch("run_agent.handle_function_call", return_value="SHOULD_NOT_RUN") as mock_hfc,
    ):
        agent._execute_tool_calls_sequential(msg, messages, "task-1")

    mock_hfc.assert_not_called()
    # Assert the result exists before indexing it, and check its pairing with
    # the call id, matching the other sequential-path tests.
    assert len(messages) == 1
    assert messages[0]["role"] == "tool"
    assert messages[0]["tool_call_id"] == "c-plugin"
    assert "plugin policy" in messages[0]["content"]
    assert agent._tool_guardrail_halt_decision is None
    assert agent._tool_guardrails.before_call("web_search", args).action == "allow"
208  
209  
def test_default_run_conversation_warns_without_guardrail_halt():
    """Default config through the full conversation loop.

    Three identical failing tool turns all execute, followed by a final text
    reply. The run ends on a text response with no ``guardrail`` entry, and at
    least one tool result carries the warning code."""
    agent = _make_agent("web_search", max_iterations=10)
    repeated_arguments = json.dumps({"query": "same"})
    tool_turns = [
        _mock_response(
            content="",
            finish_reason="tool_calls",
            tool_calls=[_mock_tool_call("web_search", repeated_arguments, f"c{turn}")],
        )
        for turn in (1, 2, 3)
    ]
    final_turn = _mock_response(content="done", finish_reason="stop", tool_calls=None)
    agent.client.chat.completions.create.side_effect = tool_turns + [final_turn]

    with (
        patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})) as handler,
        patch.object(agent, "_persist_session"),
        patch.object(agent, "_save_trajectory"),
        patch.object(agent, "_cleanup_task_resources"),
    ):
        result = agent.run_conversation("search repeatedly")

    assert handler.call_count == 3
    assert result["turn_exit_reason"].startswith("text_response")
    assert "guardrail" not in result
    assert result["final_response"] == "done"
    tool_bodies = [entry["content"] for entry in result["messages"] if entry.get("role") == "tool"]
    assert any("repeated_exact_failure_warning" in body for body in tool_bodies)
238  
239  
def test_config_enabled_hard_stop_run_conversation_returns_controlled_guardrail_halt_without_top_level_error():
    """With hard stops enabled, the conversation loop halts cleanly on a
    repeated failing call.

    Two identical failing calls execute and the third API turn is blocked. The
    run ends with ``turn_exit_reason == "guardrail_halt"``, no top-level
    ``error`` key and guardrail details attached, well before
    ``max_iterations``. Every recorded assistant tool-call turn must be
    answered by matching tool results that come after it in the transcript.
    """
    agent = _make_agent("web_search", max_iterations=10, config=_hard_stop_config())
    same_args = {"query": "same"}
    responses = [
        _mock_response(
            content="",
            finish_reason="tool_calls",
            tool_calls=[_mock_tool_call("web_search", json.dumps(same_args), f"c{i}")],
        )
        for i in range(1, 10)
    ]
    agent.client.chat.completions.create.side_effect = responses

    with (
        patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})) as mock_hfc,
        patch.object(agent, "_persist_session"),
        patch.object(agent, "_save_trajectory"),
        patch.object(agent, "_cleanup_task_resources"),
    ):
        result = agent.run_conversation("search repeatedly")

    assert mock_hfc.call_count == 2
    assert result["api_calls"] == 3
    assert result["api_calls"] < agent.max_iterations
    assert result["turn_exit_reason"] == "guardrail_halt"
    assert "error" not in result
    assert result["completed"] is True
    assert "stopped retrying" in result["final_response"]
    assert result["guardrail"]["code"] == "repeated_exact_failure_block"
    assert result["guardrail"]["tool_name"] == "web_search"

    transcript = result["messages"]
    assistant_turns = [
        (index, m)
        for index, m in enumerate(transcript)
        if m.get("role") == "assistant" and m.get("tool_calls")
    ]
    # Without this, the pairing loop below passes vacuously if the transcript
    # lost its assistant tool-call turns.
    assert assistant_turns
    for index, assistant_msg in assistant_turns:
        call_ids = [tc["id"] for tc in assistant_msg["tool_calls"]]
        # Only results recorded after the requesting turn count as answers.
        following_results = [
            m
            for m in transcript[index + 1:]
            if m.get("role") == "tool" and m.get("tool_call_id") in call_ids
        ]
        assert len(following_results) == len(call_ids)