/ tests / test_model_tools.py
test_model_tools.py
  1  """Tests for model_tools.py — function call dispatch, agent-loop interception, legacy toolsets."""
  2  
  3  import json
  4  from unittest.mock import ANY, call, patch
  5  
  6  import pytest
  7  
  8  from model_tools import (
  9      handle_function_call,
 10      get_all_tool_names,
 11      get_toolset_for_tool,
 12      _AGENT_LOOP_TOOLS,
 13      _LEGACY_TOOLSET_MAP,
 14      TOOL_TO_TOOLSET_MAP,
 15  )
 16  
 17  
 18  # =========================================================================
 19  # handle_function_call
 20  # =========================================================================
 21  
 22  class TestHandleFunctionCall:
 23      def test_agent_loop_tool_returns_error(self):
 24          for tool_name in _AGENT_LOOP_TOOLS:
 25              result = json.loads(handle_function_call(tool_name, {}))
 26              assert "error" in result
 27              assert "agent loop" in result["error"].lower()
 28  
 29      def test_unknown_tool_returns_error(self):
 30          result = json.loads(handle_function_call("totally_fake_tool_xyz", {}))
 31          assert "error" in result
 32          assert "totally_fake_tool_xyz" in result["error"]
 33  
 34      def test_exception_returns_json_error(self):
 35          # Even if something goes wrong, should return valid JSON
 36          result = handle_function_call("web_search", None)  # None args may cause issues
 37          parsed = json.loads(result)
 38          assert isinstance(parsed, dict)
 39          assert "error" in parsed
 40          assert len(parsed["error"]) > 0
 41          assert "error" in parsed["error"].lower() or "failed" in parsed["error"].lower()
 42  
 43      def test_tool_hooks_receive_session_and_tool_call_ids(self):
 44          with (
 45              patch("model_tools.registry.dispatch", return_value='{"ok":true}'),
 46              patch("hermes_cli.plugins.invoke_hook") as mock_invoke_hook,
 47          ):
 48              result = handle_function_call(
 49                  "web_search",
 50                  {"q": "test"},
 51                  task_id="task-1",
 52                  tool_call_id="call-1",
 53                  session_id="session-1",
 54              )
 55  
 56          assert result == '{"ok":true}'
 57          assert mock_invoke_hook.call_args_list == [
 58              call(
 59                  "pre_tool_call",
 60                  tool_name="web_search",
 61                  args={"q": "test"},
 62                  task_id="task-1",
 63                  session_id="session-1",
 64                  tool_call_id="call-1",
 65              ),
 66              call(
 67                  "post_tool_call",
 68                  tool_name="web_search",
 69                  args={"q": "test"},
 70                  result='{"ok":true}',
 71                  task_id="task-1",
 72                  session_id="session-1",
 73                  tool_call_id="call-1",
 74                  duration_ms=ANY,
 75              ),
 76              call(
 77                  "transform_tool_result",
 78                  tool_name="web_search",
 79                  args={"q": "test"},
 80                  result='{"ok":true}',
 81                  task_id="task-1",
 82                  session_id="session-1",
 83                  tool_call_id="call-1",
 84                  duration_ms=ANY,
 85              ),
 86          ]
 87  
 88      def test_post_tool_call_receives_non_negative_integer_duration_ms(self):
 89          """Regression: post_tool_call and transform_tool_result hooks must
 90          receive a non-negative integer ``duration_ms`` kwarg measuring
 91          dispatch latency.  Inspired by Claude Code 2.1.119, which added
 92          ``duration_ms`` to its PostToolUse hook inputs.
 93          """
 94          with (
 95              patch("model_tools.registry.dispatch", return_value='{"ok":true}'),
 96              patch("hermes_cli.plugins.invoke_hook") as mock_invoke_hook,
 97          ):
 98              handle_function_call("web_search", {"q": "test"}, task_id="t1")
 99  
100          kwargs_by_hook = {
101              c.args[0]: c.kwargs for c in mock_invoke_hook.call_args_list
102          }
103          assert "duration_ms" in kwargs_by_hook["post_tool_call"]
104          assert "duration_ms" in kwargs_by_hook["transform_tool_result"]
105  
106          post_duration = kwargs_by_hook["post_tool_call"]["duration_ms"]
107          transform_duration = kwargs_by_hook["transform_tool_result"]["duration_ms"]
108          assert isinstance(post_duration, int)
109          assert post_duration >= 0
110          # Both hooks should observe the same measured duration.
111          assert post_duration == transform_duration
112          # pre_tool_call does NOT get duration_ms (nothing has run yet).
113          assert "duration_ms" not in kwargs_by_hook["pre_tool_call"]
114  
115  
116  # =========================================================================
117  # Agent loop tools
118  # =========================================================================
119  
class TestAgentLoopTools:
    """Membership of the ``_AGENT_LOOP_TOOLS`` interception set."""

    def test_expected_tools_in_set(self):
        # Tools that handle_function_call refuses, deferring to the agent loop.
        for intercepted in ("todo", "memory", "session_search", "delegate_task"):
            assert intercepted in _AGENT_LOOP_TOOLS

    def test_no_regular_tools_in_set(self):
        # Ordinary tools must not be intercepted.
        for ordinary in ("web_search", "terminal"):
            assert ordinary not in _AGENT_LOOP_TOOLS
129          assert "terminal" not in _AGENT_LOOP_TOOLS
130  
131  
132  # =========================================================================
133  # Pre-tool-call blocking via plugin hooks
134  # =========================================================================
135  
class TestPreToolCallBlocking:
    """Verify that pre_tool_call hooks can block tool execution, and that the
    ``skip_pre_tool_call_hook`` flag keeps the hook from firing twice."""

    def test_blocked_tool_returns_error_and_skips_dispatch(self, monkeypatch):
        def fake_invoke_hook(hook_name, **kwargs):
            if hook_name == "pre_tool_call":
                return [{"action": "block", "message": "Blocked by policy"}]
            return []

        dispatch_called = False

        def fake_dispatch(*args, **kwargs):
            # Record the call as well as raising, in case the exception is
            # swallowed and turned into a JSON error by handle_function_call.
            nonlocal dispatch_called
            dispatch_called = True
            raise AssertionError("dispatch should not run when blocked")

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr("model_tools.registry.dispatch", fake_dispatch)

        result = json.loads(handle_function_call("read_file", {"path": "test.txt"}, task_id="t1"))
        assert result == {"error": "Blocked by policy"}
        assert not dispatch_called

    def test_blocked_tool_skips_read_loop_notification(self, monkeypatch):
        notifications = []

        def fake_invoke_hook(hook_name, **kwargs):
            if hook_name == "pre_tool_call":
                return [{"action": "block", "message": "Blocked"}]
            return []

        def must_not_dispatch(*args, **kwargs):
            raise AssertionError("should not run")

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr("model_tools.registry.dispatch", must_not_dispatch)
        monkeypatch.setattr("tools.file_tools.notify_other_tool_call",
                            lambda task_id: notifications.append(task_id))

        result = json.loads(handle_function_call("web_search", {"q": "test"}, task_id="t1"))
        assert result == {"error": "Blocked"}
        # A blocked call must not notify tools.file_tools.notify_other_tool_call.
        assert notifications == []

    def test_invalid_hook_returns_do_not_block(self, monkeypatch):
        """Malformed hook returns should be ignored — tool executes normally."""
        def fake_invoke_hook(hook_name, **kwargs):
            if hook_name == "pre_tool_call":
                return [
                    "block",                               # not a dict
                    {"action": "block"},                   # missing message
                    {"action": "deny", "message": "nope"},  # unknown action
                ]
            return []

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr("model_tools.registry.dispatch",
                            lambda *a, **kw: json.dumps({"ok": True}))

        result = json.loads(handle_function_call("read_file", {"path": "test.txt"}, task_id="t1"))
        assert result == {"ok": True}

    def test_skip_flag_prevents_double_fire(self, monkeypatch):
        """When skip_pre_tool_call_hook=True, the hook does not fire again.

        The caller (e.g. run_agent._invoke_tool) has already called
        get_pre_tool_call_block_message(), which fires the hook once.
        handle_function_call must NOT fire it a second time — that was
        the classic double-fire bug where observer hooks logged every
        tool call twice.
        """
        hook_calls = []

        def fake_invoke_hook(hook_name, **kwargs):
            hook_calls.append(hook_name)
            return []

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr("model_tools.registry.dispatch",
                            lambda *a, **kw: json.dumps({"ok": True}))

        handle_function_call("web_search", {"q": "test"}, task_id="t1",
                             skip_pre_tool_call_hook=True)

        # Single-fire contract: when skip=True the caller already fired
        # pre_tool_call, so handle_function_call must not fire it again.
        assert hook_calls.count("pre_tool_call") == 0, (
            f"pre_tool_call fired {hook_calls.count('pre_tool_call')} times "
            f"with skip_pre_tool_call_hook=True; expected 0 "
            f"(caller already fired it). hook_calls={hook_calls}"
        )
        # post_tool_call and transform_tool_result still fire — only the
        # pre-call block-check path is suppressed by the skip flag.
        assert "post_tool_call" in hook_calls
        assert "transform_tool_result" in hook_calls

    def test_run_agent_pattern_fires_pre_tool_call_exactly_once(self, monkeypatch):
        """End-to-end regression for the double-fire bug.

        Mirrors run_agent._invoke_tool: first calls
        get_pre_tool_call_block_message() (which fires the hook as part of
        its block-directive poll), then calls
        handle_function_call(skip_pre_tool_call_hook=True).  The plugin
        hook MUST fire exactly once across both calls — not twice as it
        did before the fix (observer plugins were seeing every tool
        execution logged twice).
        """
        from hermes_cli.plugins import get_pre_tool_call_block_message

        hook_calls = []

        def fake_invoke_hook(hook_name, **kwargs):
            hook_calls.append(hook_name)
            return []

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr("model_tools.registry.dispatch",
                            lambda *a, **kw: json.dumps({"ok": True}))

        # Step 1: caller checks for a block directive (this fires pre_tool_call once).
        block = get_pre_tool_call_block_message(
            "web_search", {"q": "test"}, task_id="t1",
        )
        assert block is None

        # Step 2: caller dispatches with skip=True so the hook isn't re-fired.
        handle_function_call(
            "web_search", {"q": "test"}, task_id="t1",
            skip_pre_tool_call_hook=True,
        )

        assert hook_calls.count("pre_tool_call") == 1, (
            f"pre_tool_call fired {hook_calls.count('pre_tool_call')} times "
            f"across the run_agent (block-check + dispatch) path; "
            f"expected exactly 1. hook_calls={hook_calls}"
        )
270  
271  
272  # =========================================================================
273  # Legacy toolset map
274  # =========================================================================
275  
class TestLegacyToolsetMap:
    """Shape and contents of the ``_LEGACY_TOOLSET_MAP`` backward-compat table."""

    def test_expected_legacy_names(self):
        legacy_names = (
            "web_tools", "terminal_tools", "vision_tools", "moa_tools",
            "image_tools", "skills_tools", "browser_tools", "cronjob_tools",
            "rl_tools", "file_tools", "tts_tools",
        )
        for legacy in legacy_names:
            assert legacy in _LEGACY_TOOLSET_MAP, f"Missing legacy toolset: {legacy}"

    def test_values_are_lists_of_strings(self):
        # Every entry maps a toolset name to a plain list of tool-name strings.
        for toolset, members in _LEGACY_TOOLSET_MAP.items():
            assert isinstance(members, list), f"{toolset} is not a list"
            for member in members:
                assert isinstance(member, str), f"{toolset} contains non-string: {member}"
291  
292  
293  # =========================================================================
294  # Backward-compat wrappers
295  # =========================================================================
296  
class TestBackwardCompat:
    """Module-level helpers exposed by ``model_tools`` for older callers."""

    def test_get_all_tool_names_returns_list(self):
        tool_names = get_all_tool_names()
        assert isinstance(tool_names, list)
        assert tool_names, "expected at least one registered tool name"
        # Well-known tools must be present.
        for well_known in ("web_search", "terminal"):
            assert well_known in tool_names

    def test_get_toolset_for_tool(self):
        # isinstance(..., str) also rejects None.
        assert isinstance(get_toolset_for_tool("web_search"), str)

    def test_get_toolset_for_unknown_tool(self):
        assert get_toolset_for_tool("totally_nonexistent_tool") is None

    def test_tool_to_toolset_map(self):
        assert isinstance(TOOL_TO_TOOLSET_MAP, dict)
        assert TOOL_TO_TOOLSET_MAP, "TOOL_TO_TOOLSET_MAP should not be empty"
318  
319  
320  # =========================================================================
321  # _coerce_number — inf / nan must fall through to the original string
322  # (regression: fix: eliminate duplicate checkpoint entries and JSON-unsafe coercion)
323  # =========================================================================
324  
class TestCoerceNumberInfNan:
    """``_coerce_number`` must return the original string for inf/nan inputs.

    ``float('inf')`` and ``float('nan')`` are rejected by strict
    (``allow_nan=False``) JSON serialization, so coercing them would violate
    the documented "returns original string on failure" contract.
    """

    @staticmethod
    def _coerce(text):
        # Resolve the private helper at call time rather than at module import.
        from model_tools import _coerce_number
        return _coerce_number(text)

    def test_inf_returns_original_string(self):
        assert self._coerce("inf") == "inf"

    def test_negative_inf_returns_original_string(self):
        assert self._coerce("-inf") == "-inf"

    def test_nan_returns_original_string(self):
        assert self._coerce("nan") == "nan"

    def test_infinity_spelling_returns_original_string(self):
        # float() also accepts the "Infinity" spelling, which is equally unsafe.
        assert self._coerce("Infinity") == "Infinity"

    def test_coerced_result_is_strict_json_safe(self):
        """Whatever comes back for inf/nan must survive strict json.dumps."""
        for text in ("inf", "-inf", "nan", "Infinity"):
            # allow_nan=False raises ValueError on a float inf/nan value.
            json.dumps({"x": self._coerce(text)}, allow_nan=False)

    def test_normal_numbers_still_coerce(self):
        """Guard against over-correction — finite numbers are still coerced."""
        expected_by_input = {"42": 42, "3.14": 3.14, "1e3": 1000}
        for text, expected in expected_by_input.items():
            assert self._coerce(text) == expected