# src/core/tool_utils.py
  1  """
  2  Shared utility functions for tool permission handling.
  3  
  4  This module contains common functions used by both hooks.py and permissions.py
  5  for building tool call strings and actionable denial messages.
  6  """
  7  import sys
  8  from pathlib import Path
  9  from typing import Any, Optional, Protocol, runtime_checkable
 10  
 11  
 12  @runtime_checkable
 13  class PermissionManagerProtocol(Protocol):
 14      """Protocol for permission managers that support pattern lookup."""
 15  
 16      def get_allowed_patterns_for_tool(self, tool_name: str) -> list[str]:
 17          """Get allowed patterns for a specific tool."""
 18          ...
 19  
 20  
 21  # Tool parameter mappings for building tool call strings
 22  TOOL_PARAM_MAP: dict[str, tuple[str, ...]] = {
 23      "Read": ("file_path", "path"),
 24      "Write": ("file_path", "path"),
 25      "Edit": ("file_path", "path"),
 26      "MultiEdit": ("file_path", "path"),
 27      "Glob": ("path", "cwd"),
 28      "Grep": ("path", "file"),
 29      "LS": ("path", "dir"),
 30      "Bash": ("command",),
 31      "Ag3ntumBash": ("command",),  # MCP tool: mcp__ag3ntum__Bash
 32      "WebFetch": ("url",),
 33      "WebSearch": ("query",),
 34  }
 35  
 36  
 37  def normalize_tool_name(tool_name: str) -> str:
 38      """
 39      Normalize MCP tool names to their permission-matching format.
 40  
 41      Converts MCP tool prefixes to the format used in permission patterns:
 42      - mcp__ag3ntum__Bash -> Ag3ntumBash
 43      - mcp__ag3ntum__Read -> Ag3ntumRead
 44  
 45      Standard SDK tool names are returned unchanged.
 46  
 47      Args:
 48          tool_name: Raw tool name from SDK.
 49  
 50      Returns:
 51          Normalized tool name for permission matching.
 52      """
 53      prefix = "mcp__ag3ntum__"
 54      if tool_name.startswith(prefix):
 55          suffix = tool_name[len(prefix):]
 56          return f"Ag3ntum{suffix}"
 57      return tool_name
 58  
 59  
 60  def build_tool_call_string(tool_name: str, tool_input: dict[str, Any]) -> str:
 61      """
 62      Build a tool call string for permission matching.
 63  
 64      Constructs a standardized string representation of a tool call
 65      suitable for pattern matching against permission rules.
 66  
 67      MCP tool names are normalized (e.g., mcp__ag3ntum__Bash -> Ag3ntumBash).
 68  
 69      Args:
 70          tool_name: Name of the tool being called.
 71          tool_input: Input parameters for the tool.
 72  
 73      Returns:
 74          Formatted tool call string, e.g., "Read(./input/task.md)" or "Ag3ntumBash(ls -la)".
 75  
 76      Examples:
 77          >>> build_tool_call_string("Read", {"file_path": "./input/task.md"})
 78          'Read(./input/task.md)'
 79          >>> build_tool_call_string("Bash", {"command": "ls -la"})
 80          'Bash(ls -la)'
 81          >>> build_tool_call_string("mcp__ag3ntum__Bash", {"command": "ls -la"})
 82          'Ag3ntumBash(ls -la)'
 83      """
 84      # Normalize MCP tool names for permission matching
 85      normalized_name = normalize_tool_name(tool_name)
 86  
 87      if normalized_name in TOOL_PARAM_MAP:
 88          param_keys = TOOL_PARAM_MAP[normalized_name]
 89          value = ""
 90          for key in param_keys:
 91              if key in tool_input:
 92                  value = tool_input[key]
 93                  break
 94          return f"{normalized_name}({value})"
 95      return normalized_name
 96  
 97  
 98  def build_actionable_denial_message(
 99      tool_name: str,
100      tool_input: dict[str, Any],
101      is_final_denial: bool,
102      permission_manager: Any = None,
103      max_patterns_shown: int = 5,
104      max_value_length: int = 50,
105  ) -> str:
106      """
107      Build an actionable denial message with allowed patterns.
108  
109      Instead of just saying "denied", provides concrete guidance on
110      what patterns the agent CAN use for this tool.
111  
112      Args:
113          tool_name: The tool that was denied.
114          tool_input: The input that was attempted.
115          is_final_denial: If True, this is the last chance before interrupt.
116          permission_manager: Optional permission manager to get allowed patterns.
117          max_patterns_shown: Maximum number of allowed patterns to show.
118          max_value_length: Maximum length of displayed values before truncation.
119  
120      Returns:
121          Actionable message with allowed patterns and optional interrupt warning.
122      """
123      # Normalize MCP tool names for consistent messaging
124      normalized_name = normalize_tool_name(tool_name)
125  
126      # Get allowed patterns for this tool from the permission manager
127      allowed_patterns: list[str] = []
128      if permission_manager is not None:
129          if hasattr(permission_manager, 'get_allowed_patterns_for_tool'):
130              allowed_patterns = permission_manager.get_allowed_patterns_for_tool(normalized_name)
131  
132      # Build the base denial message
133      if normalized_name in ("Bash", "Ag3ntumBash"):
134          command = tool_input.get("command", "")
135          truncated = command[:max_value_length] + "..." if len(command) > max_value_length else command
136          base_msg = f"{normalized_name} command '{truncated}' is not permitted."
137      else:
138          path = tool_input.get("file_path", tool_input.get("path", ""))
139          truncated = path[:max_value_length] if len(path) > max_value_length else path
140          base_msg = f"{normalized_name} for '{truncated}' is not permitted."
141  
142      # Add guidance about what IS allowed
143      if allowed_patterns:
144          patterns_str = ", ".join(f"'{p}'" for p in allowed_patterns[:max_patterns_shown])
145          guidance = f" Allowed patterns for {normalized_name}: {patterns_str}."
146      else:
147          guidance = f" No {normalized_name} operations are allowed in this security context."
148  
149      # Add interrupt warning if this is final denial
150      if is_final_denial:
151          warning = " FINAL WARNING: Agent will be stopped if this tool is denied again."
152      else:
153          warning = ""
154  
155      return base_msg + guidance + warning
156  
157  
def extract_patterns_for_tool(
    tool_name: str,
    permission_list: list[str],
) -> list[str]:
    """
    Extract patterns for a specific tool from a permission list.

    Parses patterns like "Read(./input/**)" or "Bash(python *)" to extract
    the inner pattern part, or returns "*" for bare tool names.

    Args:
        tool_name: Name of the tool to extract patterns for.
        permission_list: List of permission patterns (allow or deny list).

    Returns:
        List of extracted patterns for the tool.

    Examples:
        >>> extract_patterns_for_tool("Read", ["Read(./input/**)", "Write"])
        ['./input/**']
        >>> extract_patterns_for_tool("Bash", ["Bash"])
        ['*']
    """
    call_prefix = f"{tool_name}("
    extracted: list[str] = []

    for entry in permission_list:
        if entry == tool_name:
            # A bare tool name grants every use of the tool.
            extracted.append("*")
        elif entry.startswith(call_prefix):
            inner = entry[len(call_prefix):]
            # Drop the closing parenthesis when present; tolerate its absence.
            if inner.endswith(")"):
                inner = inner[:-1]
            extracted.append(inner)

    return extracted
194  
195  
def build_script_command(
    script_path: Path,
    args: Optional[list[str]] = None,
) -> list[str]:
    """
    Build a command list for executing a script.

    Chooses an interpreter from the file extension — the running Python
    for ``.py``, bash for ``.sh``/``.bash``, none otherwise — and appends
    any additional arguments.

    Args:
        script_path: Path to the script file.
        args: Optional additional arguments to pass to the script.

    Returns:
        Command list suitable for subprocess.run().

    Examples:
        >>> build_script_command(Path("script.sh"), ["--verbose"])
        ['bash', 'script.sh', '--verbose']
    """
    suffix = script_path.suffix
    if suffix == ".py":
        interpreter = [sys.executable]
    elif suffix in (".sh", ".bash"):
        interpreter = ["bash"]
    else:
        # No known interpreter: execute the file directly.
        interpreter = []

    return [*interpreter, str(script_path), *(args or [])]
230  
231  
def format_token_usage(
    input_tokens: int,
    output_tokens: int,
    cache_creation: int = 0,
    cache_read: int = 0,
    model: Optional[str] = None,
    get_context_size_fn: Optional[Any] = None,
) -> tuple[str, Optional[str]]:
    """
    Format token usage information into display strings.

    Args:
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.
        cache_creation: Number of cache creation tokens.
        cache_read: Number of cache read tokens.
        model: Optional model name for context size calculation.
        get_context_size_fn: Optional function to get model context size.

    Returns:
        Tuple of (tokens_line, cache_line); cache_line is None when no
        cache tokens were recorded.
    """
    # Cache tokens count toward the effective input for totals and context.
    effective_input = input_tokens + cache_creation + cache_read
    grand_total = effective_input + output_tokens

    segments = [
        f"Tokens: {grand_total:,}",
        f"(in: {effective_input:,}, out: {output_tokens:,})",
    ]

    # Context-window usage is shown only when both a model name and a lookup
    # function are supplied and the lookup yields a size.
    if model and get_context_size_fn:
        window = get_context_size_fn(model)
        if window:
            pct = (effective_input / window) * 100
            segments.append(
                f"Context: {effective_input:,}/{window:,} ({pct:.1f}%)"
            )

    tokens_line = " | ".join(segments)

    cache_line: Optional[str] = None
    if cache_creation > 0 or cache_read > 0:
        cache_line = " | ".join(
            label
            for label, count in (
                (f"cache_write: {cache_creation:,}", cache_creation),
                (f"cache_read: {cache_read:,}", cache_read),
            )
            if count > 0
        )

    return tokens_line, cache_line