# tool_utils.py
"""
Shared utility functions for tool permission handling.

This module contains common functions used by both hooks.py and permissions.py
for building tool call strings and actionable denial messages.
"""
import sys
from pathlib import Path
from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class PermissionManagerProtocol(Protocol):
    """Protocol for permission managers that support pattern lookup."""

    def get_allowed_patterns_for_tool(self, tool_name: str) -> list[str]:
        """Get allowed patterns for a specific tool."""
        ...


# Tool parameter mappings for building tool call strings.
# For each tool, the input keys are tried in order; the first key present in
# the tool input supplies the argument shown inside the parentheses.
TOOL_PARAM_MAP: dict[str, tuple[str, ...]] = {
    "Read": ("file_path", "path"),
    "Write": ("file_path", "path"),
    "Edit": ("file_path", "path"),
    "MultiEdit": ("file_path", "path"),
    "Glob": ("path", "cwd"),
    "Grep": ("path", "file"),
    "LS": ("path", "dir"),
    "Bash": ("command",),
    "Ag3ntumBash": ("command",),  # MCP tool: mcp__ag3ntum__Bash
    "WebFetch": ("url",),
    "WebSearch": ("query",),
}

# Input keys consulted, in order, to describe the target of a denied call for
# a tool that has no entry in TOOL_PARAM_MAP.
_DEFAULT_TARGET_KEYS: tuple[str, ...] = ("file_path", "path")


def _first_present_value(
    tool_input: dict[str, Any],
    keys: tuple[str, ...],
    default: Any = "",
) -> Any:
    """Return the value of the first key in *keys* present in *tool_input*, else *default*."""
    for key in keys:
        if key in tool_input:
            return tool_input[key]
    return default


def _truncate(value: Any, max_length: int) -> str:
    """Render *value* as text, cut to *max_length* characters plus '...' when longer."""
    # None (or a non-string value) must not crash message building.
    text = "" if value is None else str(value)
    if len(text) > max_length:
        return text[:max_length] + "..."
    return text


def normalize_tool_name(tool_name: str) -> str:
    """
    Normalize MCP tool names to their permission-matching format.

    Converts MCP tool prefixes to the format used in permission patterns:
    - mcp__ag3ntum__Bash -> Ag3ntumBash
    - mcp__ag3ntum__Read -> Ag3ntumRead

    Standard SDK tool names are returned unchanged.

    Args:
        tool_name: Raw tool name from SDK.

    Returns:
        Normalized tool name for permission matching.
    """
    prefix = "mcp__ag3ntum__"
    if tool_name.startswith(prefix):
        suffix = tool_name[len(prefix):]
        return f"Ag3ntum{suffix}"
    return tool_name


def build_tool_call_string(tool_name: str, tool_input: dict[str, Any]) -> str:
    """
    Build a tool call string for permission matching.

    Constructs a standardized string representation of a tool call
    suitable for pattern matching against permission rules.

    MCP tool names are normalized (e.g., mcp__ag3ntum__Bash -> Ag3ntumBash).
    Tools without an entry in TOOL_PARAM_MAP are rendered as the bare
    normalized name; mapped tools whose input has none of the mapped keys
    are rendered with empty parentheses.

    Args:
        tool_name: Name of the tool being called.
        tool_input: Input parameters for the tool.

    Returns:
        Formatted tool call string, e.g., "Read(./input/task.md)" or "Ag3ntumBash(ls -la)".

    Examples:
        >>> build_tool_call_string("Read", {"file_path": "./input/task.md"})
        'Read(./input/task.md)'
        >>> build_tool_call_string("Bash", {"command": "ls -la"})
        'Bash(ls -la)'
        >>> build_tool_call_string("mcp__ag3ntum__Bash", {"command": "ls -la"})
        'Ag3ntumBash(ls -la)'
    """
    # Normalize MCP tool names for permission matching
    normalized_name = normalize_tool_name(tool_name)

    param_keys = TOOL_PARAM_MAP.get(normalized_name)
    if param_keys is None:
        return normalized_name
    value = _first_present_value(tool_input, param_keys)
    return f"{normalized_name}({value})"


def build_actionable_denial_message(
    tool_name: str,
    tool_input: dict[str, Any],
    is_final_denial: bool,
    permission_manager: Any = None,
    max_patterns_shown: int = 5,
    max_value_length: int = 50,
) -> str:
    """
    Build an actionable denial message with allowed patterns.

    Instead of just saying "denied", provides concrete guidance on
    what patterns the agent CAN use for this tool.

    For Bash / Ag3ntumBash the denied command is quoted. For other tools
    the quoted target is taken from the tool's TOOL_PARAM_MAP keys (falling
    back to "file_path" then "path" for unmapped tools). Quoted values longer
    than ``max_value_length`` are cut and suffixed with "...".

    Args:
        tool_name: The tool that was denied.
        tool_input: The input that was attempted.
        is_final_denial: If True, this is the last chance before interrupt.
        permission_manager: Optional permission manager (see
            PermissionManagerProtocol) used to look up allowed patterns.
            Objects without ``get_allowed_patterns_for_tool`` are ignored.
        max_patterns_shown: Maximum number of allowed patterns to show.
        max_value_length: Maximum length of displayed values before truncation.

    Returns:
        Actionable message with allowed patterns and optional interrupt warning.
    """
    # Normalize MCP tool names for consistent messaging
    normalized_name = normalize_tool_name(tool_name)

    # Get allowed patterns for this tool from the permission manager
    allowed_patterns: list[str] = []
    if permission_manager is not None and hasattr(
        permission_manager, "get_allowed_patterns_for_tool"
    ):
        # `or []` tolerates managers that return None for unknown tools.
        allowed_patterns = list(
            permission_manager.get_allowed_patterns_for_tool(normalized_name) or []
        )

    # Build the base denial message
    if normalized_name in ("Bash", "Ag3ntumBash"):
        truncated = _truncate(tool_input.get("command", ""), max_value_length)
        base_msg = f"{normalized_name} command '{truncated}' is not permitted."
    else:
        # Use the same parameter keys as build_tool_call_string so that e.g.
        # WebFetch reports the URL rather than an empty string.
        target_keys = TOOL_PARAM_MAP.get(normalized_name, _DEFAULT_TARGET_KEYS)
        target = _first_present_value(tool_input, target_keys)
        truncated = _truncate(target, max_value_length)
        base_msg = f"{normalized_name} for '{truncated}' is not permitted."

    # Add guidance about what IS allowed
    if allowed_patterns:
        patterns_str = ", ".join(f"'{p}'" for p in allowed_patterns[:max_patterns_shown])
        guidance = f" Allowed patterns for {normalized_name}: {patterns_str}."
    else:
        guidance = f" No {normalized_name} operations are allowed in this security context."

    # Add interrupt warning if this is final denial
    if is_final_denial:
        warning = " FINAL WARNING: Agent will be stopped if this tool is denied again."
    else:
        warning = ""

    return base_msg + guidance + warning


def extract_patterns_for_tool(
    tool_name: str,
    permission_list: list[str],
) -> list[str]:
    """
    Extract patterns for a specific tool from a permission list.

    Parses patterns like "Read(./input/**)" or "Bash(python *)" to extract
    the inner pattern part, or returns "*" for bare tool names.

    Args:
        tool_name: Name of the tool to extract patterns for.
        permission_list: List of permission patterns (allow or deny list).

    Returns:
        List of extracted patterns for the tool.

    Examples:
        >>> extract_patterns_for_tool("Read", ["Read(./input/**)", "Write"])
        ['./input/**']
        >>> extract_patterns_for_tool("Bash", ["Bash"])
        ['*']
    """
    patterns = []
    prefix = f"{tool_name}("

    for pattern in permission_list:
        if pattern.startswith(prefix):
            # Extract the pattern inside parentheses; an unterminated
            # "Tool(..." entry keeps everything after the opening paren.
            inner = pattern[len(prefix):-1] if pattern.endswith(")") else pattern[len(prefix):]
            patterns.append(inner)
        elif pattern == tool_name:
            # Tool name without parentheses means all uses are allowed
            patterns.append("*")

    return patterns


def build_script_command(
    script_path: Path,
    args: Optional[list[str]] = None,
) -> list[str]:
    """
    Build a command list for executing a script.

    Determines the appropriate interpreter based on file extension and
    constructs the full command with any additional arguments:
    ".py" runs under the current interpreter (``sys.executable``),
    ".sh"/".bash" run under ``bash``, anything else is executed directly.

    Args:
        script_path: Path to the script file.
        args: Optional additional arguments to pass to the script.

    Returns:
        Command list suitable for subprocess.run().

    Examples:
        >>> build_script_command(Path("script.py")) == [sys.executable, "script.py"]
        True
        >>> build_script_command(Path("script.sh"), ["--verbose"])
        ['bash', 'script.sh', '--verbose']
    """
    if script_path.suffix == ".py":
        cmd = [sys.executable, str(script_path)]
    elif script_path.suffix in (".sh", ".bash"):
        cmd = ["bash", str(script_path)]
    else:
        cmd = [str(script_path)]

    if args:
        cmd.extend(args)

    return cmd


def format_token_usage(
    input_tokens: int,
    output_tokens: int,
    cache_creation: int = 0,
    cache_read: int = 0,
    model: Optional[str] = None,
    get_context_size_fn: Optional[Any] = None,
) -> tuple[str, Optional[str]]:
    """
    Format token usage information into display strings.

    Total input counts ``input_tokens`` plus both cache token counts. A
    "Context" segment is appended only when both ``model`` and
    ``get_context_size_fn`` are given and the function returns a truthy size.

    Args:
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.
        cache_creation: Number of cache creation tokens.
        cache_read: Number of cache read tokens.
        model: Optional model name for context size calculation.
        get_context_size_fn: Optional callable taking ``model`` and returning
            its context size (falsy to skip the context segment).

    Returns:
        Tuple of (tokens_line, cache_line) where cache_line is None when both
        cache counts are zero.
    """
    total_input = input_tokens + cache_creation + cache_read
    total_tokens = total_input + output_tokens

    token_parts = [f"Tokens: {total_tokens:,}"]
    token_parts.append(f"(in: {total_input:,}, out: {output_tokens:,})")

    if model and get_context_size_fn:
        context_size = get_context_size_fn(model)
        if context_size:
            context_percent = (total_input / context_size) * 100
            token_parts.append(
                f"Context: {total_input:,}/{context_size:,} ({context_percent:.1f}%)"
            )

    tokens_line = " | ".join(token_parts)

    cache_line = None
    if cache_creation > 0 or cache_read > 0:
        cache_parts = []
        if cache_creation > 0:
            cache_parts.append(f"cache_write: {cache_creation:,}")
        if cache_read > 0:
            cache_parts.append(f"cache_read: {cache_read:,}")
        cache_line = " | ".join(cache_parts)

    return tokens_line, cache_line