// hooks.go
1 package hooks 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "io" 9 "log" 10 "os" 11 "os/exec" 12 "path/filepath" 13 "regexp" 14 "strings" 15 "sync" 16 "syscall" 17 "time" 18 ) 19 20 type HookEvent string 21 22 const ( 23 PreToolUse HookEvent = "PreToolUse" 24 PostToolUse HookEvent = "PostToolUse" 25 SessionStart HookEvent = "SessionStart" 26 Stop HookEvent = "Stop" 27 ) 28 29 const ( 30 defaultTimeout = 10 * time.Second 31 maxOutputBytes = 10 * 1024 // 10KB 32 exitCodeDeny = 2 33 ) 34 35 // HookConfig uses PascalCase event names in both YAML and JSON to match the 36 // established schema (e.g. "PreToolUse" in config.yaml). This is intentionally 37 // inconsistent with the snake_case used by other config fields; clients doing a 38 // GET /config → PATCH /config round-trip must preserve these PascalCase keys. 39 type HookConfig struct { 40 PreToolUse []HookEntry `yaml:"PreToolUse" json:"PreToolUse"` 41 PostToolUse []HookEntry `yaml:"PostToolUse" json:"PostToolUse"` 42 SessionStart []HookEntry `yaml:"SessionStart" json:"SessionStart"` 43 Stop []HookEntry `yaml:"Stop" json:"Stop"` 44 } 45 46 type HookEntry struct { 47 Matcher string `yaml:"matcher" json:"matcher"` 48 Command string `yaml:"command" json:"command"` 49 } 50 51 type HookInput struct { 52 Event HookEvent `json:"event"` 53 ToolName string `json:"tool_name,omitempty"` 54 ToolInput json.RawMessage `json:"tool_input"` 55 ToolResponse json.RawMessage `json:"tool_response"` 56 SessionID string `json:"session_id"` 57 } 58 59 type HookRunner struct { 60 config HookConfig 61 timeout time.Duration 62 allowedDirs []string // additional allowed absolute path prefixes (for testing) 63 64 mu sync.Mutex 65 inHook bool // recursion guard 66 } 67 68 func NewHookRunner(config HookConfig) *HookRunner { 69 return &HookRunner{ 70 config: config, 71 timeout: defaultTimeout, 72 } 73 } 74 75 // RunPreToolUse runs matching PreToolUse hooks. 76 // Returns: decision ("allow"/"deny"/""), reason string, error. 
77 // If any hook exits with code 2, returns "deny" with stderr as reason. 78 // If any hook exits with non-zero (not 2), logs warning but doesn't block. 79 func (h *HookRunner) RunPreToolUse(ctx context.Context, toolName string, toolInput string, sessionID string) (string, string, error) { 80 if h == nil { 81 return "", "", nil 82 } 83 if h.enterHook() { 84 return "", "", nil // skip recursive invocations 85 } 86 defer h.exitHook() 87 88 entries := h.matchEntries(h.config.PreToolUse, toolName) 89 if len(entries) == 0 { 90 return "", "", nil 91 } 92 93 input := HookInput{ 94 Event: PreToolUse, 95 ToolName: toolName, 96 ToolInput: toRawJSON(toolInput), 97 ToolResponse: nullJSON(), 98 SessionID: sessionID, 99 } 100 101 for _, entry := range entries { 102 exitCode, _, stderr, err := h.runHook(ctx, entry, input) 103 if err != nil { 104 log.Printf("[hooks] warning: PreToolUse hook %q failed: %v", entry.Command, err) 105 continue 106 } 107 if exitCode == exitCodeDeny { 108 reason := strings.TrimSpace(stderr) 109 if reason == "" { 110 reason = "blocked by hook" 111 } 112 return "deny", reason, nil 113 } 114 if exitCode != 0 { 115 log.Printf("[hooks] warning: PreToolUse hook %q exited with code %d: %s", entry.Command, exitCode, strings.TrimSpace(stderr)) 116 } 117 } 118 119 return "allow", "", nil 120 } 121 122 // RunPostToolUse runs matching PostToolUse hooks (fire-and-forget, errors logged). 
123 func (h *HookRunner) RunPostToolUse(ctx context.Context, toolName string, toolInput string, toolResponse string, sessionID string) error { 124 if h == nil { 125 return nil 126 } 127 if h.enterHook() { 128 return nil 129 } 130 defer h.exitHook() 131 132 entries := h.matchEntries(h.config.PostToolUse, toolName) 133 if len(entries) == 0 { 134 return nil 135 } 136 137 input := HookInput{ 138 Event: PostToolUse, 139 ToolName: toolName, 140 ToolInput: toRawJSON(toolInput), 141 ToolResponse: toRawJSON(toolResponse), 142 SessionID: sessionID, 143 } 144 145 for _, entry := range entries { 146 exitCode, _, stderr, err := h.runHook(ctx, entry, input) 147 if err != nil { 148 log.Printf("[hooks] warning: PostToolUse hook %q failed: %v", entry.Command, err) 149 continue 150 } 151 if exitCode != 0 { 152 log.Printf("[hooks] warning: PostToolUse hook %q exited with code %d: %s", entry.Command, exitCode, strings.TrimSpace(stderr)) 153 } 154 } 155 156 return nil 157 } 158 159 // RunSessionStart runs all SessionStart hooks. 160 func (h *HookRunner) RunSessionStart(ctx context.Context, sessionID string) error { 161 if h == nil { 162 return nil 163 } 164 if h.enterHook() { 165 return nil 166 } 167 defer h.exitHook() 168 169 if len(h.config.SessionStart) == 0 { 170 return nil 171 } 172 173 input := HookInput{ 174 Event: SessionStart, 175 ToolInput: nullJSON(), 176 ToolResponse: nullJSON(), 177 SessionID: sessionID, 178 } 179 180 for _, entry := range h.config.SessionStart { 181 exitCode, _, stderr, err := h.runHook(ctx, entry, input) 182 if err != nil { 183 log.Printf("[hooks] warning: SessionStart hook %q failed: %v", entry.Command, err) 184 continue 185 } 186 if exitCode != 0 { 187 log.Printf("[hooks] warning: SessionStart hook %q exited with code %d: %s", entry.Command, exitCode, strings.TrimSpace(stderr)) 188 } 189 } 190 191 return nil 192 } 193 194 // RunStop runs all Stop hooks. 
195 func (h *HookRunner) RunStop(ctx context.Context, sessionID string) error { 196 if h == nil { 197 return nil 198 } 199 if h.enterHook() { 200 return nil 201 } 202 defer h.exitHook() 203 204 if len(h.config.Stop) == 0 { 205 return nil 206 } 207 208 input := HookInput{ 209 Event: Stop, 210 ToolInput: nullJSON(), 211 ToolResponse: nullJSON(), 212 SessionID: sessionID, 213 } 214 215 for _, entry := range h.config.Stop { 216 exitCode, _, stderr, err := h.runHook(ctx, entry, input) 217 if err != nil { 218 log.Printf("[hooks] warning: Stop hook %q failed: %v", entry.Command, err) 219 continue 220 } 221 if exitCode != 0 { 222 log.Printf("[hooks] warning: Stop hook %q exited with code %d: %s", entry.Command, exitCode, strings.TrimSpace(stderr)) 223 } 224 } 225 226 return nil 227 } 228 229 func (h *HookRunner) runHook(ctx context.Context, entry HookEntry, input HookInput) (exitCode int, stdout string, stderr string, err error) { 230 cmdPath, err := resolveCommand(entry.Command, h.allowedDirs) 231 if err != nil { 232 return -1, "", "", err 233 } 234 235 inputJSON, err := json.Marshal(input) 236 if err != nil { 237 return -1, "", "", fmt.Errorf("failed to marshal hook input: %w", err) 238 } 239 240 hookCtx, cancel := context.WithTimeout(ctx, h.timeout) 241 defer cancel() 242 243 cmd := exec.CommandContext(hookCtx, cmdPath) 244 cmd.Stdin = bytes.NewReader(inputJSON) 245 246 // Create a new process group so we can kill the entire tree on timeout 247 cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} 248 cmd.Cancel = func() error { 249 // Kill the entire process group (negative PID) 250 return syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) 251 } 252 253 var stdoutBuf, stderrBuf bytes.Buffer 254 cmd.Stdout = &limitedWriter{buf: &stdoutBuf, limit: maxOutputBytes} 255 cmd.Stderr = &limitedWriter{buf: &stderrBuf, limit: maxOutputBytes} 256 257 if runErr := cmd.Run(); runErr != nil { 258 if hookCtx.Err() == context.DeadlineExceeded { 259 return -1, stdoutBuf.String(), 
stderrBuf.String(), fmt.Errorf("hook timed out after %v", h.timeout) 260 } 261 if exitErr, ok := runErr.(*exec.ExitError); ok { 262 return exitErr.ExitCode(), stdoutBuf.String(), stderrBuf.String(), nil 263 } 264 return -1, stdoutBuf.String(), stderrBuf.String(), runErr 265 } 266 267 return 0, stdoutBuf.String(), stderrBuf.String(), nil 268 } 269 270 // matchEntries returns hook entries whose matcher regex matches the tool name. 271 // An empty matcher matches all tools. 272 func (h *HookRunner) matchEntries(entries []HookEntry, toolName string) []HookEntry { 273 var matched []HookEntry 274 for _, e := range entries { 275 if e.Matcher == "" { 276 matched = append(matched, e) 277 continue 278 } 279 re, err := regexp.Compile(e.Matcher) 280 if err != nil { 281 log.Printf("[hooks] warning: invalid matcher regex %q: %v", e.Matcher, err) 282 continue 283 } 284 if re.MatchString(toolName) { 285 matched = append(matched, e) 286 } 287 } 288 return matched 289 } 290 291 // resolveCommand validates and resolves the command path. 292 // Allowed: relative paths (starting with . or no / prefix), or paths under ~/.shannon/. 293 // Additional directories can be allowed via extraAllowed (used for testing). 294 // Rejected: absolute paths outside allowed directories. 
func resolveCommand(command string, extraAllowed []string) (string, error) {
	if command == "" {
		return "", fmt.Errorf("hook command must not be empty")
	}

	// Expand a leading "~/" to the user's home directory (filepath.Join also
	// cleans the joined path).
	if strings.HasPrefix(command, "~/") {
		home, err := os.UserHomeDir()
		if err != nil {
			return "", fmt.Errorf("failed to resolve home directory: %w", err)
		}
		command = filepath.Join(home, command[2:])
	}

	// Reject bare command names that would be resolved via PATH
	if !filepath.IsAbs(command) && !strings.HasPrefix(command, "./") && !strings.HasPrefix(command, "../") && !strings.Contains(command, string(filepath.Separator)) {
		return "", fmt.Errorf("bare command %q rejected: use absolute path or ./ prefix", command)
	}

	// Relative paths are returned exactly as given. They must not be cleaned:
	// cleaning "./hook.sh" yields "hook.sh", which exec would look up in PATH.
	if !filepath.IsAbs(command) {
		return command, nil
	}

	// Clean before the containment check so ".." segments cannot walk out of
	// an allowed directory (e.g. /home/u/.shannon/../../../bin/sh).
	command = filepath.Clean(command)

	shannonDir := shannonDirPath()
	allowed := shannonDir != "" && pathWithin(command, shannonDir)
	for i := 0; !allowed && i < len(extraAllowed); i++ {
		allowed = pathWithin(command, extraAllowed[i])
	}

	if !allowed {
		target := shannonDir
		if target == "" {
			target = "~/.shannon"
		}
		return "", fmt.Errorf("absolute path %q rejected: must be under %s", command, target)
	}

	return command, nil
}

// pathWithin reports whether path is dir itself or lies beneath it, comparing
// the lexically cleaned forms. dir must be absolute; an empty or relative dir
// never matches, so a blank allow-list entry cannot permit every path.
func pathWithin(path, dir string) bool {
	if !filepath.IsAbs(dir) {
		return false
	}
	rel, err := filepath.Rel(dir, path)
	if err != nil {
		return false
	}
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}

// shannonDirPath returns ~/.shannon, or "" when the home directory is unknown.
func shannonDirPath() string {
	home, err := os.UserHomeDir()
	if err != nil || home == "" {
		return ""
	}
	return filepath.Join(home, ".shannon")
}

// enterHook sets the recursion guard and reports whether it was already set,
// i.e. whether a hook run is already in progress on this runner.
350 func (h *HookRunner) enterHook() bool { 351 h.mu.Lock() 352 defer h.mu.Unlock() 353 if h.inHook { 354 return true 355 } 356 h.inHook = true 357 return false 358 } 359 360 func (h *HookRunner) exitHook() { 361 h.mu.Lock() 362 defer h.mu.Unlock() 363 h.inHook = false 364 } 365 366 // toRawJSON converts a string to json.RawMessage. 367 // If the string is valid JSON, it's used as-is. Otherwise it's marshaled as a JSON string. 368 func toRawJSON(s string) json.RawMessage { 369 if s == "" { 370 return json.RawMessage("null") 371 } 372 if json.Valid([]byte(s)) { 373 return json.RawMessage(s) 374 } 375 data, err := json.Marshal(s) 376 if err != nil { 377 return json.RawMessage("null") 378 } 379 return data 380 } 381 382 func nullJSON() json.RawMessage { 383 return json.RawMessage("null") 384 } 385 386 // limitedWriter wraps a bytes.Buffer and stops writing after limit bytes. 387 type limitedWriter struct { 388 buf *bytes.Buffer 389 limit int 390 } 391 392 func (w *limitedWriter) Write(p []byte) (int, error) { 393 totalLen := len(p) 394 remaining := w.limit - w.buf.Len() 395 if remaining <= 0 { 396 return totalLen, nil // discard but report success 397 } 398 toWrite := p 399 if len(toWrite) > remaining { 400 toWrite = toWrite[:remaining] 401 } 402 if _, err := w.buf.Write(toWrite); err != nil { 403 return 0, err 404 } 405 // Report all bytes as written even if we truncated, so exec doesn't fail 406 return totalLen, nil 407 } 408 409 // Ensure limitedWriter satisfies io.Writer. 410 var _ io.Writer = (*limitedWriter)(nil)