# test_hooks.py
"""
Tests for hooks.py.

Exercises the SDK hooks system:
- HooksManager registration and building
- HookResult SDK response conversion
- Permission hook (allow/deny/interrupt)
- Audit hook (logging, callbacks)
- Header reminder hook
- Prompt enhancement hook
- Stop and SubagentStop hooks
- Sandbox bypass detection
"""
import json
import sys
from pathlib import Path
from unittest.mock import MagicMock, AsyncMock, patch

import pytest

# Make the project root importable so that `src.core.hooks` resolves below.
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from src.core.hooks import (
    HookResult,
    HooksManager,
    ToolUsageRecord,
    create_permission_hook,
    create_audit_hook,
    create_header_reminder_hook,
    create_prompt_enhancement_hook,
    create_stop_hook,
    create_subagent_stop_hook,
)


def _allowing_permission_manager():
    """Return a mock permission manager whose ``is_allowed`` yields True."""
    permission_manager = MagicMock()
    permission_manager.is_allowed.return_value = True
    return permission_manager


def _denying_permission_manager():
    """Return a mock permission manager that denies and has no patterns."""
    permission_manager = MagicMock()
    permission_manager.is_allowed.return_value = False
    permission_manager.get_allowed_patterns_for_tool.return_value = []
    permission_manager.get_denied_patterns_for_tool.return_value = []
    return permission_manager


class TestHookResult:
    """HookResult defaults and its conversion to an SDK response dict."""

    @pytest.mark.unit
    def test_default_values(self):
        """A bare HookResult has no decision, no flags and no message."""
        blank = HookResult()
        assert blank.permission_decision is None
        assert blank.block is False
        assert blank.interrupt is False
        assert blank.system_message is None

    @pytest.mark.unit
    def test_allow_response(self):
        """An allow decision is surfaced under hookSpecificOutput."""
        payload = HookResult(permission_decision="allow").to_sdk_response(
            "PreToolUse"
        )
        assert "hookSpecificOutput" in payload
        assert payload["hookSpecificOutput"]["permissionDecision"] == "allow"

    @pytest.mark.unit
    def test_deny_response(self):
        """A deny decision carries both the decision and its reason."""
        denied = HookResult(
            permission_decision="deny",
            permission_reason="Not allowed",
        )
        specific = denied.to_sdk_response("PreToolUse")["hookSpecificOutput"]
        assert specific["permissionDecision"] == "deny"
        assert specific["permissionDecisionReason"] == "Not allowed"

    @pytest.mark.unit
    def test_block_response(self):
        """block=True maps to a top-level "block" decision."""
        payload = HookResult(block=True).to_sdk_response("PreToolUse")
        assert payload["decision"] == "block"

    @pytest.mark.unit
    def test_interrupt_response(self):
        """The interrupt flag is propagated into hookSpecificOutput."""
        interrupting = HookResult(
            permission_decision="deny",
            interrupt=True,
        )
        payload = interrupting.to_sdk_response("PreToolUse")
        assert payload["hookSpecificOutput"]["interrupt"] is True

    @pytest.mark.unit
    def test_system_message(self):
        """A system message is emitted under the systemMessage key."""
        with_message = HookResult(system_message="Remember to do X")
        payload = with_message.to_sdk_response("PostToolUse")
        assert payload["systemMessage"] == "Remember to do X"

    @pytest.mark.unit
    def test_empty_result(self):
        """A default HookResult converts to an empty dict."""
        payload = HookResult().to_sdk_response("PreToolUse")
        assert payload == {}


class TestToolUsageRecord:
    """ToolUsageRecord construction."""

    @pytest.mark.unit
    def test_basic_record(self):
        """Required fields are stored; is_error and timestamp get defaults."""
        usage = ToolUsageRecord(
            tool_name="Read",
            tool_id="use_123",
            input_data={"file_path": "/workspace/test.txt"},
        )
        assert usage.tool_name == "Read"
        assert usage.tool_id == "use_123"
        assert usage.is_error is False
        assert usage.timestamp is not None


class TestHooksManager:
    """HooksManager hook registration, config building and record tracking."""

    @pytest.mark.unit
    def test_empty_manager(self):
        """With nothing registered the built config is an empty dict."""
        hooks = HooksManager()
        assert hooks.build_hooks_config() == {}

    @pytest.mark.unit
    def test_add_pre_tool_hook(self):
        """One registered pre-tool hook yields one PreToolUse entry."""
        hooks = HooksManager()
        handler = AsyncMock()
        hooks.add_pre_tool_hook(handler, matcher="Bash")
        built = hooks.build_hooks_config()

        assert "PreToolUse" in built
        assert len(built["PreToolUse"]) == 1

    @pytest.mark.unit
    def test_add_post_tool_hook(self):
        """A registered post-tool hook appears under PostToolUse."""
        hooks = HooksManager()
        handler = AsyncMock()
        hooks.add_post_tool_hook(handler)
        built = hooks.build_hooks_config()

        assert "PostToolUse" in built

    @pytest.mark.unit
    def test_add_user_prompt_hook(self):
        """A registered user-prompt hook appears under UserPromptSubmit."""
        hooks = HooksManager()
        handler = AsyncMock()
        hooks.add_user_prompt_hook(handler)
        built = hooks.build_hooks_config()

        assert "UserPromptSubmit" in built

    @pytest.mark.unit
    def test_add_stop_hook(self):
        """A registered stop hook appears under Stop."""
        hooks = HooksManager()
        handler = AsyncMock()
        hooks.add_stop_hook(handler)
        built = hooks.build_hooks_config()

        assert "Stop" in built

    @pytest.mark.unit
    def test_add_subagent_stop_hook(self):
        """A registered subagent-stop hook appears under SubagentStop."""
        hooks = HooksManager()
        handler = AsyncMock()
        hooks.add_subagent_stop_hook(handler)
        built = hooks.build_hooks_config()

        assert "SubagentStop" in built

    @pytest.mark.unit
    def test_multiple_hooks(self):
        """Two pre-tool hooks with different matchers are both kept."""
        hooks = HooksManager()
        hooks.add_pre_tool_hook(AsyncMock(), matcher="Bash")
        hooks.add_pre_tool_hook(AsyncMock(), matcher="Write")
        built = hooks.build_hooks_config()

        assert len(built["PreToolUse"]) == 2

    @pytest.mark.unit
    def test_tool_usage_records(self):
        """Records start empty and reflect entries added to the backing list."""
        hooks = HooksManager()
        assert hooks.tool_usage_records == []

        entry = ToolUsageRecord(tool_name="Read", tool_id="1", input_data={})
        hooks._tool_usage_records.append(entry)
        assert len(hooks.tool_usage_records) == 1

    @pytest.mark.unit
    def test_clear_records(self):
        """clear_records() leaves no usage records behind."""
        hooks = HooksManager()
        entry = ToolUsageRecord(tool_name="Read", tool_id="1", input_data={})
        hooks._tool_usage_records.append(entry)
        hooks.clear_records()
        assert len(hooks.tool_usage_records) == 0

    @pytest.mark.unit
    def test_set_permission_check_callback(self):
        """The supplied callback object is stored on the manager as-is."""
        hooks = HooksManager()
        listener = MagicMock()
        hooks.set_permission_check_callback(listener)
        assert hooks._on_permission_check is listener


class TestPermissionHook:
    """Behaviour of the hook produced by create_permission_hook."""

    @pytest.mark.asyncio
    async def test_allowed_tool(self):
        """A tool the permission manager allows gets an allow decision."""
        hook = create_permission_hook(_allowing_permission_manager())
        event = {
            "tool_name": "Read",
            "tool_input": {"file_path": "/workspace/test.txt"},
        }
        outcome = await hook(event, "use_123", None)

        assert "hookSpecificOutput" in outcome
        assert outcome["hookSpecificOutput"]["permissionDecision"] == "allow"

    @pytest.mark.asyncio
    async def test_denied_tool(self):
        """A tool the permission manager rejects gets a deny decision."""
        hook = create_permission_hook(_denying_permission_manager())
        event = {"tool_name": "Bash", "tool_input": {"command": "rm -rf /"}}
        outcome = await hook(event, "use_456", None)

        assert outcome["hookSpecificOutput"]["permissionDecision"] == "deny"

    @pytest.mark.asyncio
    async def test_sandbox_bypass_blocked(self):
        """dangerouslyDisableSandbox is denied and interrupted."""
        # is_allowed() returns True here: the bypass flag alone must deny.
        hook = create_permission_hook(_allowing_permission_manager())
        event = {
            "tool_name": "Bash",
            "tool_input": {
                "command": "echo hello",
                "dangerouslyDisableSandbox": True,
            },
        }
        outcome = await hook(event, "use_789", None)

        specific = outcome["hookSpecificOutput"]
        assert specific["permissionDecision"] == "deny"
        assert specific["interrupt"] is True

    @pytest.mark.asyncio
    async def test_denial_count_interrupt(self):
        """Reaching max_denials_before_interrupt sets the interrupt flag."""
        hook = create_permission_hook(
            _denying_permission_manager(),
            max_denials_before_interrupt=3,
        )

        # Issue three denied calls for the same tool; keep the last outcome.
        for attempt in range(3):
            outcome = await hook(
                {"tool_name": "Bash", "tool_input": {"command": "bad"}},
                f"use_{attempt}",
                None,
            )

        # The third denial is the one expected to carry the interrupt.
        assert outcome["hookSpecificOutput"].get("interrupt") is True

    @pytest.mark.asyncio
    async def test_permission_check_callback(self):
        """on_permission_check receives the tool name and the decision."""
        listener = MagicMock()
        hook = create_permission_hook(
            _allowing_permission_manager(),
            on_permission_check=listener,
        )
        await hook({"tool_name": "Read", "tool_input": {}}, "use_1", None)

        listener.assert_called_once_with("Read", "allow")

    @pytest.mark.asyncio
    async def test_denial_tracker_called(self):
        """A denial is reported to the tracker exactly once."""
        denial_log = MagicMock()
        hook = create_permission_hook(
            _denying_permission_manager(),
            denial_tracker=denial_log,
        )
        event = {"tool_name": "Bash", "tool_input": {"command": "dangerous"}}
        await hook(event, "use_1", None)

        denial_log.record_denial.assert_called_once()


class TestAuditHook:
    """Behaviour of the hook produced by create_audit_hook."""

    @pytest.mark.asyncio
    async def test_audit_returns_empty(self):
        """The audit hook returns an empty dict (no modifications)."""
        hook = create_audit_hook()
        event = {"tool_name": "Read", "tool_result": {"content": "data"}}
        outcome = await hook(event, "use_1", None)
        assert outcome == {}

    @pytest.mark.asyncio
    async def test_audit_writes_log_file(self, tmp_path):
        """With log_file set, a JSON entry for the tool call is written."""
        audit_path = tmp_path / "audit.log"
        hook = create_audit_hook(log_file=audit_path)

        event = {"tool_name": "Write", "tool_result": {"is_error": False}}
        await hook(event, "use_1", None)

        assert audit_path.exists()
        raw_entry = audit_path.read_text().strip()
        entry = json.loads(raw_entry)
        assert entry["tool_name"] == "Write"
        assert entry["is_error"] is False

    @pytest.mark.asyncio
    async def test_audit_calls_completion_callback(self):
        """on_tool_complete is invoked once with the name and error flag."""
        listener = MagicMock()
        hook = create_audit_hook(on_tool_complete=listener)

        event = {
            "tool_name": "Edit",
            "tool_result": {"is_error": True, "content": "err"},
        }
        await hook(event, "use_2", None)

        listener.assert_called_once()
        positional = listener.call_args[0]
        assert positional[0] == "Edit"
        assert positional[3] is True  # is_error


class TestHeaderReminderHook:
    """Behaviour of the hook produced by create_header_reminder_hook."""

    @pytest.mark.asyncio
    async def test_enabled_returns_system_message(self):
        """When enabled, the hook emits a system message mentioning header."""
        hook = create_header_reminder_hook(enable=True)
        outcome = await hook({"tool_name": "Read"}, "use_1", None)
        assert "systemMessage" in outcome
        assert "header" in outcome["systemMessage"].lower()

    @pytest.mark.asyncio
    async def test_disabled_returns_empty(self):
        """When disabled, the hook returns an empty dict."""
        hook = create_header_reminder_hook(enable=False)
        outcome = await hook({"tool_name": "Read"}, "use_1", None)
        assert outcome == {}

    @pytest.mark.asyncio
    async def test_skip_tools(self):
        """Tools listed in skip_tools produce an empty dict."""
        hook = create_header_reminder_hook(
            enable=True,
            skip_tools={"AskUserQuestion"},
        )
        outcome = await hook({"tool_name": "AskUserQuestion"}, "use_1", None)
        assert outcome == {}


class TestPromptEnhancementHook:
    """Behaviour of the hook produced by create_prompt_enhancement_hook."""

    @pytest.mark.asyncio
    async def test_adds_timestamp(self):
        """With add_timestamp=True the prompt is kept and a "[" appears."""
        hook = create_prompt_enhancement_hook(add_timestamp=True)
        outcome = await hook({"prompt": "Hello"}, None, None)
        rewritten = outcome["hookSpecificOutput"]["updatedPrompt"]
        assert "Hello" in rewritten
        assert "[" in rewritten  # Timestamp format

    @pytest.mark.asyncio
    async def test_adds_context(self):
        """The add_context text and the original prompt are both present."""
        hook = create_prompt_enhancement_hook(
            add_timestamp=False,
            add_context="Context: test",
        )
        outcome = await hook({"prompt": "Do something"}, None, None)
        rewritten = outcome["hookSpecificOutput"]["updatedPrompt"]
        assert "Context: test" in rewritten
        assert "Do something" in rewritten


class TestStopHook:
    """Behaviour of the hook produced by create_stop_hook."""

    @pytest.mark.asyncio
    async def test_on_stop_called(self):
        """The on_stop callback is invoked exactly once."""
        stop_listener = AsyncMock()
        hook = create_stop_hook(on_stop=stop_listener)
        await hook({"reason": "complete"}, None, None)
        stop_listener.assert_called_once()

    @pytest.mark.asyncio
    async def test_cleanup_called(self):
        """The cleanup function is invoked exactly once."""
        teardown = AsyncMock()
        hook = create_stop_hook(cleanup_fn=teardown)
        await hook({}, None, None)
        teardown.assert_called_once()

    @pytest.mark.asyncio
    async def test_callback_exception_handled(self):
        """A raising on_stop callback does not propagate out of the hook."""
        failing_listener = AsyncMock(side_effect=RuntimeError("cleanup failed"))
        hook = create_stop_hook(on_stop=failing_listener)
        # Should not raise
        outcome = await hook({}, None, None)
        assert outcome == {}


class TestSubagentStopHook:
    """Behaviour of the hook produced by create_subagent_stop_hook."""

    @pytest.mark.asyncio
    async def test_callback_called(self):
        """The callback receives the subagent type and its result."""
        listener = AsyncMock()
        hook = create_subagent_stop_hook(on_subagent_complete=listener)
        event = {"subagent_type": "research", "result": {"output": "data"}}
        await hook(event, None, None)
        listener.assert_called_once_with("research", {"output": "data"})

    @pytest.mark.asyncio
    async def test_no_callback(self):
        """Without a callback the hook still returns an empty dict."""
        hook = create_subagent_stop_hook()
        outcome = await hook({"subagent_type": "task"}, None, None)
        assert outcome == {}

    @pytest.mark.asyncio
    async def test_callback_exception_handled(self):
        """A raising callback does not propagate out of the hook."""
        failing_listener = AsyncMock(side_effect=RuntimeError("failed"))
        hook = create_subagent_stop_hook(on_subagent_complete=failing_listener)
        event = {"subagent_type": "task", "result": {}}
        outcome = await hook(event, None, None)
        assert outcome == {}