/ tests / core-tests / test_hooks.py
test_hooks.py
  1  """
  2  Tests for hooks.py.
  3  
  4  Tests the SDK hooks system:
  5  - HooksManager registration and building
  6  - HookResult SDK response conversion
  7  - Permission hook (allow/deny/interrupt)
  8  - Audit hook (logging, callbacks)
  9  - Header reminder hook
 10  - Prompt enhancement hook
 11  - Stop and SubagentStop hooks
 12  - Sandbox bypass detection
 13  """
 14  import json
 15  import sys
 16  from pathlib import Path
 17  from unittest.mock import MagicMock, AsyncMock, patch
 18  
 19  import pytest
 20  
 21  PROJECT_ROOT = Path(__file__).parent.parent.parent
 22  sys.path.insert(0, str(PROJECT_ROOT))
 23  
 24  from src.core.hooks import (
 25      HookResult,
 26      HooksManager,
 27      ToolUsageRecord,
 28      create_permission_hook,
 29      create_audit_hook,
 30      create_header_reminder_hook,
 31      create_prompt_enhancement_hook,
 32      create_stop_hook,
 33      create_subagent_stop_hook,
 34  )
 35  
 36  
 37  class TestHookResult:
 38      """Tests for HookResult dataclass."""
 39  
 40      @pytest.mark.unit
 41      def test_default_values(self):
 42          """Test default HookResult values."""
 43          result = HookResult()
 44          assert result.permission_decision is None
 45          assert result.block is False
 46          assert result.interrupt is False
 47          assert result.system_message is None
 48  
 49      @pytest.mark.unit
 50      def test_allow_response(self):
 51          """Test allow decision SDK response."""
 52          result = HookResult(permission_decision="allow")
 53          response = result.to_sdk_response("PreToolUse")
 54          assert "hookSpecificOutput" in response
 55          assert response["hookSpecificOutput"]["permissionDecision"] == "allow"
 56  
 57      @pytest.mark.unit
 58      def test_deny_response(self):
 59          """Test deny decision SDK response."""
 60          result = HookResult(
 61              permission_decision="deny",
 62              permission_reason="Not allowed",
 63          )
 64          response = result.to_sdk_response("PreToolUse")
 65          hook_output = response["hookSpecificOutput"]
 66          assert hook_output["permissionDecision"] == "deny"
 67          assert hook_output["permissionDecisionReason"] == "Not allowed"
 68  
 69      @pytest.mark.unit
 70      def test_block_response(self):
 71          """Test block decision."""
 72          result = HookResult(block=True)
 73          response = result.to_sdk_response("PreToolUse")
 74          assert response["decision"] == "block"
 75  
 76      @pytest.mark.unit
 77      def test_interrupt_response(self):
 78          """Test interrupt flag in response."""
 79          result = HookResult(
 80              permission_decision="deny",
 81              interrupt=True,
 82          )
 83          response = result.to_sdk_response("PreToolUse")
 84          assert response["hookSpecificOutput"]["interrupt"] is True
 85  
 86      @pytest.mark.unit
 87      def test_system_message(self):
 88          """Test system message in response."""
 89          result = HookResult(system_message="Remember to do X")
 90          response = result.to_sdk_response("PostToolUse")
 91          assert response["systemMessage"] == "Remember to do X"
 92  
 93      @pytest.mark.unit
 94      def test_empty_result(self):
 95          """Test empty result gives empty response."""
 96          result = HookResult()
 97          response = result.to_sdk_response("PreToolUse")
 98          assert response == {}
 99  
100  
class TestToolUsageRecord:
    """Unit tests for the ``ToolUsageRecord`` dataclass."""

    @pytest.mark.unit
    def test_basic_record(self):
        """Required fields are stored, is_error defaults to False, a timestamp is set."""
        fields = {
            "tool_name": "Read",
            "tool_id": "use_123",
            "input_data": {"file_path": "/workspace/test.txt"},
        }
        rec = ToolUsageRecord(**fields)

        assert rec.tool_name == "Read"
        assert rec.tool_id == "use_123"
        assert rec.is_error is False
        assert rec.timestamp is not None
116  
117  
class TestHooksManager:
    """Unit tests for ``HooksManager`` registration, config building and record keeping."""

    @staticmethod
    def _config_with(register: str, **kwargs):
        """Register one AsyncMock via ``HooksManager.<register>`` and return the built config."""
        mgr = HooksManager()
        getattr(mgr, register)(AsyncMock(), **kwargs)
        return mgr.build_hooks_config()

    @staticmethod
    def _read_record():
        """Return a minimal ToolUsageRecord describing a ``Read`` call."""
        return ToolUsageRecord(tool_name="Read", tool_id="1", input_data={})

    @pytest.mark.unit
    def test_empty_manager(self):
        """A manager with nothing registered builds an empty config."""
        assert HooksManager().build_hooks_config() == {}

    @pytest.mark.unit
    def test_add_pre_tool_hook(self):
        """A PreToolUse registration produces exactly one PreToolUse entry."""
        cfg = self._config_with("add_pre_tool_hook", matcher="Bash")
        assert "PreToolUse" in cfg
        assert len(cfg["PreToolUse"]) == 1

    @pytest.mark.unit
    def test_add_post_tool_hook(self):
        """A PostToolUse registration appears in the config."""
        assert "PostToolUse" in self._config_with("add_post_tool_hook")

    @pytest.mark.unit
    def test_add_user_prompt_hook(self):
        """A UserPromptSubmit registration appears in the config."""
        assert "UserPromptSubmit" in self._config_with("add_user_prompt_hook")

    @pytest.mark.unit
    def test_add_stop_hook(self):
        """A Stop registration appears in the config."""
        assert "Stop" in self._config_with("add_stop_hook")

    @pytest.mark.unit
    def test_add_subagent_stop_hook(self):
        """A SubagentStop registration appears in the config."""
        assert "SubagentStop" in self._config_with("add_subagent_stop_hook")

    @pytest.mark.unit
    def test_multiple_hooks(self):
        """Two PreToolUse registrations yield two PreToolUse entries."""
        mgr = HooksManager()
        for pattern in ("Bash", "Write"):
            mgr.add_pre_tool_hook(AsyncMock(), matcher=pattern)
        assert len(mgr.build_hooks_config()["PreToolUse"]) == 2

    @pytest.mark.unit
    def test_tool_usage_records(self):
        """tool_usage_records starts empty and reflects appended records."""
        mgr = HooksManager()
        assert mgr.tool_usage_records == []

        # Seed the private backing list directly; there is no public append API here.
        mgr._tool_usage_records.append(self._read_record())
        assert len(mgr.tool_usage_records) == 1

    @pytest.mark.unit
    def test_clear_records(self):
        """clear_records() empties the recorded tool usage."""
        mgr = HooksManager()
        mgr._tool_usage_records.append(self._read_record())
        mgr.clear_records()
        assert len(mgr.tool_usage_records) == 0

    @pytest.mark.unit
    def test_set_permission_check_callback(self):
        """set_permission_check_callback stores the exact callable given."""
        mgr = HooksManager()
        cb = MagicMock()
        mgr.set_permission_check_callback(cb)
        assert mgr._on_permission_check is cb
217  
218  
class TestPermissionHook:
    """Tests for create_permission_hook (PreToolUse allow/deny/interrupt)."""

    @pytest.mark.asyncio
    async def test_allowed_tool(self):
        """Test that allowed tool returns allow decision."""
        pm = MagicMock()
        pm.is_allowed.return_value = True

        hook = create_permission_hook(pm)
        result = await hook(
            {"tool_name": "Read", "tool_input": {"file_path": "/workspace/test.txt"}},
            "use_123",
            None,
        )

        assert "hookSpecificOutput" in result
        assert result["hookSpecificOutput"]["permissionDecision"] == "allow"

    @pytest.mark.asyncio
    async def test_denied_tool(self):
        """Test that denied tool returns deny decision."""
        pm = MagicMock()
        pm.is_allowed.return_value = False
        pm.get_allowed_patterns_for_tool.return_value = []
        pm.get_denied_patterns_for_tool.return_value = []

        hook = create_permission_hook(pm)
        result = await hook(
            {"tool_name": "Bash", "tool_input": {"command": "rm -rf /"}},
            "use_456",
            None,
        )

        assert result["hookSpecificOutput"]["permissionDecision"] == "deny"

    @pytest.mark.asyncio
    async def test_sandbox_bypass_blocked(self):
        """Test that sandbox bypass attempt is immediately interrupted."""
        pm = MagicMock()
        pm.is_allowed.return_value = True  # Even if allowed

        hook = create_permission_hook(pm)
        result = await hook(
            {
                "tool_name": "Bash",
                "tool_input": {
                    "command": "echo hello",
                    "dangerouslyDisableSandbox": True,
                },
            },
            "use_789",
            None,
        )

        assert result["hookSpecificOutput"]["permissionDecision"] == "deny"
        assert result["hookSpecificOutput"]["interrupt"] is True

    @pytest.mark.asyncio
    async def test_denial_count_interrupt(self):
        """Test that reaching the denial threshold, and not before, triggers interrupt."""
        pm = MagicMock()
        pm.is_allowed.return_value = False
        pm.get_allowed_patterns_for_tool.return_value = []
        pm.get_denied_patterns_for_tool.return_value = []

        max_denials = 3
        hook = create_permission_hook(pm, max_denials_before_interrupt=max_denials)

        # Deny the same tool max_denials times. Keep every response so the
        # threshold itself is checked, not just the final call.
        outputs = []
        for i in range(max_denials):
            result = await hook(
                {"tool_name": "Bash", "tool_input": {"command": "bad"}},
                f"use_{i}",
                None,
            )
            outputs.append(result["hookSpecificOutput"])

        # Every call is a denial.
        assert all(out["permissionDecision"] == "deny" for out in outputs)
        # Denials below the threshold must not interrupt. Without this check, a
        # hook that interrupts on every denial would also pass.
        for out in outputs[:-1]:
            assert not out.get("interrupt")
        # The denial that reaches the threshold interrupts.
        assert outputs[-1].get("interrupt") is True

    @pytest.mark.asyncio
    async def test_permission_check_callback(self):
        """Test that permission check callback is invoked."""
        pm = MagicMock()
        pm.is_allowed.return_value = True
        callback = MagicMock()

        hook = create_permission_hook(pm, on_permission_check=callback)
        await hook(
            {"tool_name": "Read", "tool_input": {}},
            "use_1",
            None,
        )

        callback.assert_called_once_with("Read", "allow")

    @pytest.mark.asyncio
    async def test_denial_tracker_called(self):
        """Test that denial tracker is called on deny."""
        pm = MagicMock()
        pm.is_allowed.return_value = False
        pm.get_allowed_patterns_for_tool.return_value = []
        pm.get_denied_patterns_for_tool.return_value = []
        tracker = MagicMock()

        hook = create_permission_hook(pm, denial_tracker=tracker)
        await hook(
            {"tool_name": "Bash", "tool_input": {"command": "dangerous"}},
            "use_1",
            None,
        )

        tracker.record_denial.assert_called_once()
331  
332  
class TestAuditHook:
    """Tests for create_audit_hook (post-tool auditing: log file and callback)."""

    @pytest.mark.asyncio
    async def test_audit_returns_empty(self):
        """The audit hook only observes, so it returns an empty response dict."""
        audit = create_audit_hook()
        payload = {"tool_name": "Read", "tool_result": {"content": "data"}}
        assert await audit(payload, "use_1", None) == {}

    @pytest.mark.asyncio
    async def test_audit_writes_log_file(self, tmp_path):
        """With log_file set, a JSON entry for the completed tool is written."""
        audit_log = tmp_path / "audit.log"
        audit = create_audit_hook(log_file=audit_log)
        payload = {"tool_name": "Write", "tool_result": {"is_error": False}}

        await audit(payload, "use_1", None)

        assert audit_log.exists()
        entry = json.loads(audit_log.read_text().strip())
        assert entry["tool_name"] == "Write"
        assert entry["is_error"] is False

    @pytest.mark.asyncio
    async def test_audit_calls_completion_callback(self):
        """on_tool_complete receives the tool name first and the error flag fourth."""
        on_done = MagicMock()
        audit = create_audit_hook(on_tool_complete=on_done)
        payload = {"tool_name": "Edit", "tool_result": {"is_error": True, "content": "err"}}

        await audit(payload, "use_2", None)

        on_done.assert_called_once()
        positional = on_done.call_args.args
        assert positional[0] == "Edit"
        assert positional[3] is True  # is_error
380  
381  
class TestHeaderReminderHook:
    """Tests for create_header_reminder_hook."""

    @staticmethod
    async def _run(hook, tool_name):
        """Invoke *hook* for *tool_name* with tool id ``use_1`` and no context."""
        return await hook({"tool_name": tool_name}, "use_1", None)

    @pytest.mark.asyncio
    async def test_enabled_returns_system_message(self):
        """When enabled, the hook returns a system message mentioning headers."""
        reply = await self._run(create_header_reminder_hook(enable=True), "Read")
        assert "systemMessage" in reply
        assert "header" in reply["systemMessage"].lower()

    @pytest.mark.asyncio
    async def test_disabled_returns_empty(self):
        """When disabled, the hook returns an empty dict."""
        reply = await self._run(create_header_reminder_hook(enable=False), "Read")
        assert reply == {}

    @pytest.mark.asyncio
    async def test_skip_tools(self):
        """Tools listed in skip_tools get an empty dict even when enabled."""
        hook = create_header_reminder_hook(enable=True, skip_tools={"AskUserQuestion"})
        reply = await self._run(hook, "AskUserQuestion")
        assert reply == {}
421  
422  
class TestPromptEnhancementHook:
    """Tests for create_prompt_enhancement_hook."""

    @staticmethod
    async def _enhanced(hook, prompt):
        """Run *hook* on *prompt* and return hookSpecificOutput.updatedPrompt."""
        response = await hook({"prompt": prompt}, None, None)
        return response["hookSpecificOutput"]["updatedPrompt"]

    @pytest.mark.asyncio
    async def test_adds_timestamp(self):
        """With add_timestamp=True the original prompt is kept and a timestamp added."""
        text = await self._enhanced(create_prompt_enhancement_hook(add_timestamp=True), "Hello")
        assert "Hello" in text
        assert "[" in text  # timestamp presumably rendered in brackets

    @pytest.mark.asyncio
    async def test_adds_context(self):
        """add_context text is included alongside the original prompt."""
        hook = create_prompt_enhancement_hook(add_timestamp=False, add_context="Context: test")
        text = await self._enhanced(hook, "Do something")
        assert "Context: test" in text
        assert "Do something" in text
454  
455  
class TestStopHook:
    """Tests for create_stop_hook.

    Callbacks here are AsyncMocks, so they are checked with ``assert_awaited_*``.
    An AsyncMock counts as "called" as soon as its coroutine is created. A hook
    that forgot to ``await`` the callback would therefore still satisfy
    ``assert_called_once()``.
    """

    @pytest.mark.asyncio
    async def test_on_stop_called(self):
        """Test that the on_stop callback is called and awaited."""
        on_stop = AsyncMock()
        hook = create_stop_hook(on_stop=on_stop)
        await hook({"reason": "complete"}, None, None)
        on_stop.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_cleanup_called(self):
        """Test that the cleanup function is called and awaited."""
        cleanup = AsyncMock()
        hook = create_stop_hook(cleanup_fn=cleanup)
        await hook({}, None, None)
        cleanup.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_callback_exception_handled(self):
        """Test that exceptions in callbacks are handled gracefully."""
        on_stop = AsyncMock(side_effect=RuntimeError("cleanup failed"))
        hook = create_stop_hook(on_stop=on_stop)
        # Should not raise
        result = await hook({}, None, None)
        assert result == {}
        # The RuntimeError side effect only fires when the coroutine is awaited.
        # Confirm it was awaited, so this test really exercised the error path.
        on_stop.assert_awaited_once()
483  
484  
class TestSubagentStopHook:
    """Tests for create_subagent_stop_hook.

    Callbacks are AsyncMocks, so they are checked with ``assert_awaited_*``.
    ``assert_called_*`` would also pass for a hook that creates the coroutine but
    never awaits it.
    """

    @pytest.mark.asyncio
    async def test_callback_called(self):
        """Test that the subagent complete callback is awaited with type and result."""
        callback = AsyncMock()
        hook = create_subagent_stop_hook(on_subagent_complete=callback)
        await hook(
            {"subagent_type": "research", "result": {"output": "data"}},
            None,
            None,
        )
        callback.assert_awaited_once_with("research", {"output": "data"})

    @pytest.mark.asyncio
    async def test_no_callback(self):
        """Test hook works without callback."""
        hook = create_subagent_stop_hook()
        result = await hook({"subagent_type": "task"}, None, None)
        assert result == {}

    @pytest.mark.asyncio
    async def test_callback_exception_handled(self):
        """Test that exceptions in callback are handled gracefully."""
        callback = AsyncMock(side_effect=RuntimeError("failed"))
        hook = create_subagent_stop_hook(on_subagent_complete=callback)
        result = await hook({"subagent_type": "task", "result": {}}, None, None)
        assert result == {}
        # The side effect is raised only on await. Confirm the callback was
        # awaited, so the graceful-handling path was actually exercised.
        callback.assert_awaited_once()