# test_model_tools.py
"""Tests for model_tools.py — function call dispatch, agent-loop interception, legacy toolsets."""

import json
from unittest.mock import ANY, call, patch

import pytest

from model_tools import (
    handle_function_call,
    get_all_tool_names,
    get_toolset_for_tool,
    _AGENT_LOOP_TOOLS,
    _LEGACY_TOOLSET_MAP,
    TOOL_TO_TOOLSET_MAP,
)


# =========================================================================
# handle_function_call
# =========================================================================

class TestHandleFunctionCall:
    """Exercise handle_function_call's error paths and plugin-hook wiring."""

    def test_agent_loop_tool_returns_error(self):
        """Every name in _AGENT_LOOP_TOOLS yields a JSON error mentioning the agent loop."""
        for tool_name in _AGENT_LOOP_TOOLS:
            result = json.loads(handle_function_call(tool_name, {}))
            assert "error" in result
            assert "agent loop" in result["error"].lower()

    def test_unknown_tool_returns_error(self):
        """An unregistered tool name is echoed back inside the JSON error."""
        result = json.loads(handle_function_call("totally_fake_tool_xyz", {}))
        assert "error" in result
        assert "totally_fake_tool_xyz" in result["error"]

    def test_exception_returns_json_error(self):
        """A call with ``None`` args still returns a JSON object with a non-empty error."""
        # Even if something goes wrong, should return valid JSON
        result = handle_function_call("web_search", None)  # None args may cause issues
        parsed = json.loads(result)
        assert isinstance(parsed, dict)
        assert "error" in parsed
        assert len(parsed["error"]) > 0
        assert "error" in parsed["error"].lower() or "failed" in parsed["error"].lower()

    def test_tool_hooks_receive_session_and_tool_call_ids(self):
        """The three hooks fire in order and receive task, session and tool-call IDs."""
        with (
            patch("model_tools.registry.dispatch", return_value='{"ok":true}'),
            patch("hermes_cli.plugins.invoke_hook") as mock_invoke_hook,
        ):
            result = handle_function_call(
                "web_search",
                {"q": "test"},
                task_id="task-1",
                tool_call_id="call-1",
                session_id="session-1",
            )

        assert result == '{"ok":true}'
        # Exact list equality pins both the hook order and every kwarg;
        # duration_ms is matched with ANY because its value is a measurement.
        assert mock_invoke_hook.call_args_list == [
            call(
                "pre_tool_call",
                tool_name="web_search",
                args={"q": "test"},
                task_id="task-1",
                session_id="session-1",
                tool_call_id="call-1",
            ),
            call(
                "post_tool_call",
                tool_name="web_search",
                args={"q": "test"},
                result='{"ok":true}',
                task_id="task-1",
                session_id="session-1",
                tool_call_id="call-1",
                duration_ms=ANY,
            ),
            call(
                "transform_tool_result",
                tool_name="web_search",
                args={"q": "test"},
                result='{"ok":true}',
                task_id="task-1",
                session_id="session-1",
                tool_call_id="call-1",
                duration_ms=ANY,
            ),
        ]

    def test_post_tool_call_receives_non_negative_integer_duration_ms(self):
        """Regression: post_tool_call and transform_tool_result hooks must
        receive a non-negative integer ``duration_ms`` kwarg measuring
        dispatch latency. Inspired by Claude Code 2.1.119, which added
        ``duration_ms`` to its PostToolUse hook inputs.
        """
        with (
            patch("model_tools.registry.dispatch", return_value='{"ok":true}'),
            patch("hermes_cli.plugins.invoke_hook") as mock_invoke_hook,
        ):
            handle_function_call("web_search", {"q": "test"}, task_id="t1")

        # Index the recorded calls by hook name (first positional argument).
        kwargs_by_hook = {
            c.args[0]: c.kwargs for c in mock_invoke_hook.call_args_list
        }
        assert "duration_ms" in kwargs_by_hook["post_tool_call"]
        assert "duration_ms" in kwargs_by_hook["transform_tool_result"]

        post_duration = kwargs_by_hook["post_tool_call"]["duration_ms"]
        transform_duration = kwargs_by_hook["transform_tool_result"]["duration_ms"]
        assert isinstance(post_duration, int)
        assert post_duration >= 0
        # Both hooks should observe the same measured duration.
        assert post_duration == transform_duration
        # pre_tool_call does NOT get duration_ms (nothing has run yet).
        assert "duration_ms" not in kwargs_by_hook["pre_tool_call"]


# =========================================================================
# Agent loop tools
# =========================================================================

class TestAgentLoopTools:
    """Check the membership of the _AGENT_LOOP_TOOLS collection."""

    def test_expected_tools_in_set(self):
        """The four agent-loop tool names are present."""
        assert "todo" in _AGENT_LOOP_TOOLS
        assert "memory" in _AGENT_LOOP_TOOLS
        assert "session_search" in _AGENT_LOOP_TOOLS
        assert "delegate_task" in _AGENT_LOOP_TOOLS

    def test_no_regular_tools_in_set(self):
        """Regular tools such as web_search and terminal are not listed."""
        assert "web_search" not in _AGENT_LOOP_TOOLS
        assert "terminal" not in _AGENT_LOOP_TOOLS


# =========================================================================
# Pre-tool-call blocking via plugin hooks
# =========================================================================

class TestPreToolCallBlocking:
    """Verify that pre_tool_call hooks can block tool execution."""

    def test_blocked_tool_returns_error_and_skips_dispatch(self, monkeypatch):
        """A block directive becomes the JSON error and dispatch is never called."""
        def fake_invoke_hook(hook_name, **kwargs):
            if hook_name == "pre_tool_call":
                return [{"action": "block", "message": "Blocked by policy"}]
            return []

        dispatch_called = False

        def fake_dispatch(*args, **kwargs):
            nonlocal dispatch_called
            dispatch_called = True
            raise AssertionError("dispatch should not run when blocked")

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr("model_tools.registry.dispatch", fake_dispatch)

        result = json.loads(handle_function_call("read_file", {"path": "test.txt"}, task_id="t1"))
        assert result == {"error": "Blocked by policy"}
        assert not dispatch_called

    def test_blocked_tool_skips_read_loop_notification(self, monkeypatch):
        """A blocked call does not invoke tools.file_tools.notify_other_tool_call."""
        notifications = []

        def fake_invoke_hook(hook_name, **kwargs):
            if hook_name == "pre_tool_call":
                return [{"action": "block", "message": "Blocked"}]
            return []

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        # The generator-throw idiom lets a lambda raise if dispatch is reached.
        monkeypatch.setattr("model_tools.registry.dispatch",
                            lambda *a, **kw: (_ for _ in ()).throw(AssertionError("should not run")))
        monkeypatch.setattr("tools.file_tools.notify_other_tool_call",
                            lambda task_id: notifications.append(task_id))

        result = json.loads(handle_function_call("web_search", {"q": "test"}, task_id="t1"))
        assert result == {"error": "Blocked"}
        assert notifications == []

    def test_invalid_hook_returns_do_not_block(self, monkeypatch):
        """Malformed hook returns should be ignored — tool executes normally."""
        def fake_invoke_hook(hook_name, **kwargs):
            if hook_name == "pre_tool_call":
                return [
                    "block",
                    {"action": "block"},  # missing message
                    {"action": "deny", "message": "nope"},
                ]
            return []

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr("model_tools.registry.dispatch",
                            lambda *a, **kw: json.dumps({"ok": True}))

        result = json.loads(handle_function_call("read_file", {"path": "test.txt"}, task_id="t1"))
        assert result == {"ok": True}

    def test_skip_flag_prevents_double_fire(self, monkeypatch):
        """When skip_pre_tool_call_hook=True, the hook does not fire again.

        The caller (e.g. run_agent._invoke_tool) has already called
        get_pre_tool_call_block_message(), which fires the hook once.
        handle_function_call must NOT fire it a second time — that was
        the classic double-fire bug where observer hooks logged every
        tool call twice.
        """
        hook_calls = []

        def fake_invoke_hook(hook_name, **kwargs):
            hook_calls.append(hook_name)
            return []

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr("model_tools.registry.dispatch",
                            lambda *a, **kw: json.dumps({"ok": True}))

        handle_function_call("web_search", {"q": "test"}, task_id="t1",
                             skip_pre_tool_call_hook=True)

        # Single-fire contract: when skip=True the caller already fired
        # pre_tool_call, so handle_function_call must not fire it again.
        assert hook_calls.count("pre_tool_call") == 0, (
            f"pre_tool_call fired {hook_calls.count('pre_tool_call')} times "
            f"with skip_pre_tool_call_hook=True; expected 0 "
            f"(caller already fired it). hook_calls={hook_calls}"
        )
        # post_tool_call and transform_tool_result still fire — only the
        # pre-call block-check path is suppressed by the skip flag.
        assert "post_tool_call" in hook_calls
        assert "transform_tool_result" in hook_calls

    def test_run_agent_pattern_fires_pre_tool_call_exactly_once(self, monkeypatch):
        """End-to-end regression for the double-fire bug.

        Mirrors run_agent._invoke_tool: first calls
        get_pre_tool_call_block_message() (which fires the hook as part of
        its block-directive poll), then calls
        handle_function_call(skip_pre_tool_call_hook=True). The plugin
        hook MUST fire exactly once across both calls — not twice as it
        did before the fix (observer plugins were seeing every tool
        execution logged twice).
        """
        from hermes_cli.plugins import get_pre_tool_call_block_message

        hook_calls = []

        def fake_invoke_hook(hook_name, **kwargs):
            hook_calls.append(hook_name)
            return []

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr("model_tools.registry.dispatch",
                            lambda *a, **kw: json.dumps({"ok": True}))

        # Step 1: caller checks for a block directive (this fires pre_tool_call once).
        block = get_pre_tool_call_block_message(
            "web_search", {"q": "test"}, task_id="t1",
        )
        assert block is None

        # Step 2: caller dispatches with skip=True so the hook isn't re-fired.
        handle_function_call(
            "web_search", {"q": "test"}, task_id="t1",
            skip_pre_tool_call_hook=True,
        )

        assert hook_calls.count("pre_tool_call") == 1, (
            f"pre_tool_call fired {hook_calls.count('pre_tool_call')} times "
            f"across the run_agent (block-check + dispatch) path; "
            f"expected exactly 1. hook_calls={hook_calls}"
        )


# =========================================================================
# Legacy toolset map
# =========================================================================

class TestLegacyToolsetMap:
    """Check the keys and value types of _LEGACY_TOOLSET_MAP."""

    def test_expected_legacy_names(self):
        """Each expected legacy toolset name is a key of the map."""
        expected = [
            "web_tools", "terminal_tools", "vision_tools", "moa_tools",
            "image_tools", "skills_tools", "browser_tools", "cronjob_tools",
            "rl_tools", "file_tools", "tts_tools",
        ]
        for name in expected:
            assert name in _LEGACY_TOOLSET_MAP, f"Missing legacy toolset: {name}"

    def test_values_are_lists_of_strings(self):
        """Every map value is a list whose items are all strings."""
        for name, tools in _LEGACY_TOOLSET_MAP.items():
            assert isinstance(tools, list), f"{name} is not a list"
            for tool in tools:
                assert isinstance(tool, str), f"{name} contains non-string: {tool}"


# =========================================================================
# Backward-compat wrappers
# =========================================================================

class TestBackwardCompat:
    """Check the return types of the backward-compat lookup helpers."""

    def test_get_all_tool_names_returns_list(self):
        """get_all_tool_names returns a non-empty list containing well-known tools."""
        names = get_all_tool_names()
        assert isinstance(names, list)
        assert len(names) > 0
        # Should contain well-known tools
        assert "web_search" in names
        assert "terminal" in names

    def test_get_toolset_for_tool(self):
        """A known tool resolves to a string toolset name."""
        result = get_toolset_for_tool("web_search")
        assert result is not None
        assert isinstance(result, str)

    def test_get_toolset_for_unknown_tool(self):
        """An unknown tool resolves to None."""
        result = get_toolset_for_tool("totally_nonexistent_tool")
        assert result is None

    def test_tool_to_toolset_map(self):
        """TOOL_TO_TOOLSET_MAP is a non-empty dict."""
        assert isinstance(TOOL_TO_TOOLSET_MAP, dict)
        assert len(TOOL_TO_TOOLSET_MAP) > 0


# =========================================================================
# _coerce_number — inf / nan must fall through to the original string
# (regression: fix: eliminate duplicate checkpoint entries and JSON-unsafe coercion)
# =========================================================================

class TestCoerceNumberInfNan:
    """_coerce_number must honor its documented contract ("Returns original
    string on failure") for inf/nan inputs, because float('inf') and
    float('nan') are not JSON-compliant under strict serialization."""

    def test_inf_returns_original_string(self):
        """"inf" comes back unchanged."""
        from model_tools import _coerce_number
        assert _coerce_number("inf") == "inf"

    def test_negative_inf_returns_original_string(self):
        """"-inf" comes back unchanged."""
        from model_tools import _coerce_number
        assert _coerce_number("-inf") == "-inf"

    def test_nan_returns_original_string(self):
        """"nan" comes back unchanged."""
        from model_tools import _coerce_number
        assert _coerce_number("nan") == "nan"

    def test_infinity_spelling_returns_original_string(self):
        """"Infinity" comes back unchanged."""
        from model_tools import _coerce_number
        # Python's float() parses "Infinity" too — still not JSON-safe.
        assert _coerce_number("Infinity") == "Infinity"

    def test_coerced_result_is_strict_json_safe(self):
        """Whatever _coerce_number returns for inf/nan must round-trip
        through strict (allow_nan=False) json.dumps without raising."""
        from model_tools import _coerce_number
        for s in ("inf", "-inf", "nan", "Infinity"):
            result = _coerce_number(s)
            json.dumps({"x": result}, allow_nan=False)  # must not raise

    def test_normal_numbers_still_coerce(self):
        """Guard against over-correction — real numbers still coerce."""
        from model_tools import _coerce_number
        assert _coerce_number("42") == 42
        assert _coerce_number("3.14") == 3.14
        assert _coerce_number("1e3") == 1000