/ internal / hooks / hooks.go
hooks.go
  1  package hooks
  2  
  3  import (
  4  	"bytes"
  5  	"context"
  6  	"encoding/json"
  7  	"fmt"
  8  	"io"
  9  	"log"
 10  	"os"
 11  	"os/exec"
 12  	"path/filepath"
 13  	"regexp"
 14  	"strings"
 15  	"sync"
 16  	"syscall"
 17  	"time"
 18  )
 19  
 20  type HookEvent string
 21  
 22  const (
 23  	PreToolUse   HookEvent = "PreToolUse"
 24  	PostToolUse  HookEvent = "PostToolUse"
 25  	SessionStart HookEvent = "SessionStart"
 26  	Stop         HookEvent = "Stop"
 27  )
 28  
 29  const (
 30  	defaultTimeout   = 10 * time.Second
 31  	maxOutputBytes   = 10 * 1024 // 10KB
 32  	exitCodeDeny     = 2
 33  )
 34  
// HookConfig lists, per lifecycle event, the hooks to run. Within an event the
// entries run one after another in the order given.
//
// HookConfig uses PascalCase event names in both YAML and JSON to match the
// established schema (e.g. "PreToolUse" in config.yaml). This is intentionally
// inconsistent with the snake_case used by other config fields; clients doing a
// GET /config → PATCH /config round-trip must preserve these PascalCase keys.
type HookConfig struct {
	PreToolUse   []HookEntry `yaml:"PreToolUse"   json:"PreToolUse"`
	PostToolUse  []HookEntry `yaml:"PostToolUse"  json:"PostToolUse"`
	SessionStart []HookEntry `yaml:"SessionStart" json:"SessionStart"`
	Stop         []HookEntry `yaml:"Stop"         json:"Stop"`
}

// HookEntry describes a single hook.
type HookEntry struct {
	// Matcher is a Go regular expression tested against the tool name. It is
	// not anchored, so it selects the hook when it matches any part of the
	// name. An empty Matcher selects every tool. Matcher is only consulted for
	// PreToolUse and PostToolUse; SessionStart and Stop hooks always run.
	Matcher string `yaml:"matcher" json:"matcher"`
	// Command is the path of the executable to run. The whole string is used
	// as the program path and no arguments are passed. It must be an absolute
	// path under ~/.shannon (a leading "~/" is expanded) or a relative path
	// containing a path separator such as "./hook.sh". Bare names like "jq"
	// are rejected so that $PATH is never searched.
	Command string `yaml:"command" json:"command"`
}
 50  
// HookInput is the payload that is marshalled to JSON and written to a hook
// process's stdin.
//
// Fields:
//   - Event: the lifecycle event that triggered the hook.
//   - ToolName: the tool being invoked; omitted from the JSON when empty
//     (SessionStart and Stop hooks get no tool name).
//   - ToolInput: the tool's input for PreToolUse and PostToolUse; JSON null
//     for SessionStart and Stop.
//   - ToolResponse: the tool's output for PostToolUse; JSON null for every
//     other event.
//   - SessionID: the session the event belongs to.
//
// ToolInput and ToolResponse are built with toRawJSON. Text that is already
// valid JSON is embedded as-is, and anything else is encoded as a JSON string.
type HookInput struct {
	Event        HookEvent       `json:"event"`
	ToolName     string          `json:"tool_name,omitempty"`
	ToolInput    json.RawMessage `json:"tool_input"`
	ToolResponse json.RawMessage `json:"tool_response"`
	SessionID    string          `json:"session_id"`
}
 58  
// HookRunner runs the hooks from a HookConfig as child processes. Construct it
// with NewHookRunner. Every Run* method accepts a nil *HookRunner and then
// does nothing.
//
// mu guards inHook, which is set for the duration of a Run* call. It is a
// single flag for the whole runner. Any Run* call that starts while it is set
// skips its hooks entirely, whether it is a recursive call or a concurrent
// call from another goroutine.
type HookRunner struct {
	config      HookConfig
	timeout     time.Duration
	allowedDirs []string // extra dirs (besides ~/.shannon) that absolute hook commands may live under; for testing

	mu     sync.Mutex
	inHook bool // re-entrancy guard; set while a Run* call is executing hooks
}
 67  
 68  func NewHookRunner(config HookConfig) *HookRunner {
 69  	return &HookRunner{
 70  		config:  config,
 71  		timeout: defaultTimeout,
 72  	}
 73  }
 74  
 75  // RunPreToolUse runs matching PreToolUse hooks.
 76  // Returns: decision ("allow"/"deny"/""), reason string, error.
 77  // If any hook exits with code 2, returns "deny" with stderr as reason.
 78  // If any hook exits with non-zero (not 2), logs warning but doesn't block.
 79  func (h *HookRunner) RunPreToolUse(ctx context.Context, toolName string, toolInput string, sessionID string) (string, string, error) {
 80  	if h == nil {
 81  		return "", "", nil
 82  	}
 83  	if h.enterHook() {
 84  		return "", "", nil // skip recursive invocations
 85  	}
 86  	defer h.exitHook()
 87  
 88  	entries := h.matchEntries(h.config.PreToolUse, toolName)
 89  	if len(entries) == 0 {
 90  		return "", "", nil
 91  	}
 92  
 93  	input := HookInput{
 94  		Event:        PreToolUse,
 95  		ToolName:     toolName,
 96  		ToolInput:    toRawJSON(toolInput),
 97  		ToolResponse: nullJSON(),
 98  		SessionID:    sessionID,
 99  	}
100  
101  	for _, entry := range entries {
102  		exitCode, _, stderr, err := h.runHook(ctx, entry, input)
103  		if err != nil {
104  			log.Printf("[hooks] warning: PreToolUse hook %q failed: %v", entry.Command, err)
105  			continue
106  		}
107  		if exitCode == exitCodeDeny {
108  			reason := strings.TrimSpace(stderr)
109  			if reason == "" {
110  				reason = "blocked by hook"
111  			}
112  			return "deny", reason, nil
113  		}
114  		if exitCode != 0 {
115  			log.Printf("[hooks] warning: PreToolUse hook %q exited with code %d: %s", entry.Command, exitCode, strings.TrimSpace(stderr))
116  		}
117  	}
118  
119  	return "allow", "", nil
120  }
121  
122  // RunPostToolUse runs matching PostToolUse hooks (fire-and-forget, errors logged).
123  func (h *HookRunner) RunPostToolUse(ctx context.Context, toolName string, toolInput string, toolResponse string, sessionID string) error {
124  	if h == nil {
125  		return nil
126  	}
127  	if h.enterHook() {
128  		return nil
129  	}
130  	defer h.exitHook()
131  
132  	entries := h.matchEntries(h.config.PostToolUse, toolName)
133  	if len(entries) == 0 {
134  		return nil
135  	}
136  
137  	input := HookInput{
138  		Event:        PostToolUse,
139  		ToolName:     toolName,
140  		ToolInput:    toRawJSON(toolInput),
141  		ToolResponse: toRawJSON(toolResponse),
142  		SessionID:    sessionID,
143  	}
144  
145  	for _, entry := range entries {
146  		exitCode, _, stderr, err := h.runHook(ctx, entry, input)
147  		if err != nil {
148  			log.Printf("[hooks] warning: PostToolUse hook %q failed: %v", entry.Command, err)
149  			continue
150  		}
151  		if exitCode != 0 {
152  			log.Printf("[hooks] warning: PostToolUse hook %q exited with code %d: %s", entry.Command, exitCode, strings.TrimSpace(stderr))
153  		}
154  	}
155  
156  	return nil
157  }
158  
159  // RunSessionStart runs all SessionStart hooks.
160  func (h *HookRunner) RunSessionStart(ctx context.Context, sessionID string) error {
161  	if h == nil {
162  		return nil
163  	}
164  	if h.enterHook() {
165  		return nil
166  	}
167  	defer h.exitHook()
168  
169  	if len(h.config.SessionStart) == 0 {
170  		return nil
171  	}
172  
173  	input := HookInput{
174  		Event:        SessionStart,
175  		ToolInput:    nullJSON(),
176  		ToolResponse: nullJSON(),
177  		SessionID:    sessionID,
178  	}
179  
180  	for _, entry := range h.config.SessionStart {
181  		exitCode, _, stderr, err := h.runHook(ctx, entry, input)
182  		if err != nil {
183  			log.Printf("[hooks] warning: SessionStart hook %q failed: %v", entry.Command, err)
184  			continue
185  		}
186  		if exitCode != 0 {
187  			log.Printf("[hooks] warning: SessionStart hook %q exited with code %d: %s", entry.Command, exitCode, strings.TrimSpace(stderr))
188  		}
189  	}
190  
191  	return nil
192  }
193  
194  // RunStop runs all Stop hooks.
195  func (h *HookRunner) RunStop(ctx context.Context, sessionID string) error {
196  	if h == nil {
197  		return nil
198  	}
199  	if h.enterHook() {
200  		return nil
201  	}
202  	defer h.exitHook()
203  
204  	if len(h.config.Stop) == 0 {
205  		return nil
206  	}
207  
208  	input := HookInput{
209  		Event:        Stop,
210  		ToolInput:    nullJSON(),
211  		ToolResponse: nullJSON(),
212  		SessionID:    sessionID,
213  	}
214  
215  	for _, entry := range h.config.Stop {
216  		exitCode, _, stderr, err := h.runHook(ctx, entry, input)
217  		if err != nil {
218  			log.Printf("[hooks] warning: Stop hook %q failed: %v", entry.Command, err)
219  			continue
220  		}
221  		if exitCode != 0 {
222  			log.Printf("[hooks] warning: Stop hook %q exited with code %d: %s", entry.Command, exitCode, strings.TrimSpace(stderr))
223  		}
224  	}
225  
226  	return nil
227  }
228  
229  func (h *HookRunner) runHook(ctx context.Context, entry HookEntry, input HookInput) (exitCode int, stdout string, stderr string, err error) {
230  	cmdPath, err := resolveCommand(entry.Command, h.allowedDirs)
231  	if err != nil {
232  		return -1, "", "", err
233  	}
234  
235  	inputJSON, err := json.Marshal(input)
236  	if err != nil {
237  		return -1, "", "", fmt.Errorf("failed to marshal hook input: %w", err)
238  	}
239  
240  	hookCtx, cancel := context.WithTimeout(ctx, h.timeout)
241  	defer cancel()
242  
243  	cmd := exec.CommandContext(hookCtx, cmdPath)
244  	cmd.Stdin = bytes.NewReader(inputJSON)
245  
246  	// Create a new process group so we can kill the entire tree on timeout
247  	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
248  	cmd.Cancel = func() error {
249  		// Kill the entire process group (negative PID)
250  		return syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
251  	}
252  
253  	var stdoutBuf, stderrBuf bytes.Buffer
254  	cmd.Stdout = &limitedWriter{buf: &stdoutBuf, limit: maxOutputBytes}
255  	cmd.Stderr = &limitedWriter{buf: &stderrBuf, limit: maxOutputBytes}
256  
257  	if runErr := cmd.Run(); runErr != nil {
258  		if hookCtx.Err() == context.DeadlineExceeded {
259  			return -1, stdoutBuf.String(), stderrBuf.String(), fmt.Errorf("hook timed out after %v", h.timeout)
260  		}
261  		if exitErr, ok := runErr.(*exec.ExitError); ok {
262  			return exitErr.ExitCode(), stdoutBuf.String(), stderrBuf.String(), nil
263  		}
264  		return -1, stdoutBuf.String(), stderrBuf.String(), runErr
265  	}
266  
267  	return 0, stdoutBuf.String(), stderrBuf.String(), nil
268  }
269  
270  // matchEntries returns hook entries whose matcher regex matches the tool name.
271  // An empty matcher matches all tools.
272  func (h *HookRunner) matchEntries(entries []HookEntry, toolName string) []HookEntry {
273  	var matched []HookEntry
274  	for _, e := range entries {
275  		if e.Matcher == "" {
276  			matched = append(matched, e)
277  			continue
278  		}
279  		re, err := regexp.Compile(e.Matcher)
280  		if err != nil {
281  			log.Printf("[hooks] warning: invalid matcher regex %q: %v", e.Matcher, err)
282  			continue
283  		}
284  		if re.MatchString(toolName) {
285  			matched = append(matched, e)
286  		}
287  	}
288  	return matched
289  }
290  
// resolveCommand validates a hook's Command and returns the path to execute.
//
// A leading "~/" is expanded to the user's home directory. After that:
//   - Bare names with no path separator (e.g. "jq") are rejected, so hooks are
//     never looked up through $PATH.
//   - Relative paths that contain a separator ("./hook.sh", "../x", "dir/x")
//     are returned unchanged. They resolve against the working directory at
//     exec time.
//   - Absolute paths are lexically cleaned, so "." and ".." elements are
//     resolved. The result must be ~/.shannon itself or lie beneath it, or
//     beneath one of extraAllowed (used for testing). The cleaned path is
//     returned, so the path that gets executed is exactly the one checked.
//
// Symbolic links are not resolved: a symlink inside an allowed directory may
// still point elsewhere.
func resolveCommand(command string, extraAllowed []string) (string, error) {
	if command == "" {
		return "", fmt.Errorf("hook command must not be empty")
	}

	// Expand ~ prefix
	if strings.HasPrefix(command, "~/") {
		home, err := os.UserHomeDir()
		if err != nil {
			return "", fmt.Errorf("failed to resolve home directory: %w", err)
		}
		command = filepath.Join(home, command[2:])
	}

	// Reject bare command names that would be resolved via PATH
	if !filepath.IsAbs(command) && !strings.HasPrefix(command, "./") && !strings.HasPrefix(command, "../") && !strings.Contains(command, string(filepath.Separator)) {
		return "", fmt.Errorf("bare command %q rejected: use absolute path or ./ prefix", command)
	}

	if !filepath.IsAbs(command) {
		// Deliberately not cleaned: filepath.Clean("./hook.sh") is "hook.sh",
		// which os/exec would then look up in $PATH.
		return command, nil
	}

	// Clean before the containment check. A raw prefix test would accept
	// "/home/u/.shannon/../../../bin/sh".
	cleaned := filepath.Clean(command)

	shannonDir := shannonDirPath()
	if shannonDir != "" && isWithinDir(cleaned, shannonDir) {
		return cleaned, nil
	}
	for _, dir := range extraAllowed {
		if isWithinDir(cleaned, dir) {
			return cleaned, nil
		}
	}

	target := shannonDir
	if target == "" {
		target = "~/.shannon"
	}
	return "", fmt.Errorf("absolute path %q rejected: must be under %s", command, target)
}

// isWithinDir reports whether path is dir itself or lies beneath it. Both are
// compared lexically after cleaning, and a trailing slash on dir is
// irrelevant. An empty or relative dir never matches an absolute path,
// because filepath.Rel fails in that case.
func isWithinDir(path, dir string) bool {
	rel, err := filepath.Rel(dir, path)
	if err != nil {
		return false
	}
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}

// shannonDirPath returns <home>/.shannon, or "" when the home directory cannot
// be determined.
func shannonDirPath() string {
	home, err := os.UserHomeDir()
	if err != nil || home == "" {
		return ""
	}
	return filepath.Join(home, ".shannon")
}
348  
349  // enterHook attempts to set the recursion guard. Returns true if already inside a hook.
350  func (h *HookRunner) enterHook() bool {
351  	h.mu.Lock()
352  	defer h.mu.Unlock()
353  	if h.inHook {
354  		return true
355  	}
356  	h.inHook = true
357  	return false
358  }
359  
360  func (h *HookRunner) exitHook() {
361  	h.mu.Lock()
362  	defer h.mu.Unlock()
363  	h.inHook = false
364  }
365  
// toRawJSON turns s into a JSON value for embedding in the hook payload.
//
// An empty s becomes JSON null. If s already parses as a JSON value, it is
// passed through verbatim. Otherwise s is encoded as a JSON string. Plain text
// that happens to be valid JSON, such as "42" or "true", is therefore
// forwarded as a number or boolean rather than as a string.
func toRawJSON(s string) json.RawMessage {
	switch {
	case s == "":
		return nullJSON()
	case json.Valid([]byte(s)):
		return json.RawMessage(s)
	}
	quoted, err := json.Marshal(s)
	if err != nil {
		return nullJSON()
	}
	return quoted
}

// nullJSON returns the JSON literal null as a json.RawMessage.
func nullJSON() json.RawMessage {
	return json.RawMessage("null")
}
385  
// limitedWriter is an io.Writer that appends to buf until buf holds limit
// bytes and silently discards everything after that.
type limitedWriter struct {
	buf   *bytes.Buffer
	limit int
}

// Write appends as much of p to buf as still fits under limit and drops the
// remainder. It reports len(p) bytes written even when it truncates. That
// keeps os/exec's output copying from failing on a short write, since io.Copy
// treats a short write as io.ErrShortWrite.
func (w *limitedWriter) Write(p []byte) (int, error) {
	if room := w.limit - w.buf.Len(); room > 0 {
		kept := p
		if len(kept) > room {
			kept = kept[:room]
		}
		if _, err := w.buf.Write(kept); err != nil {
			return 0, err
		}
	}
	return len(p), nil
}

// Compile-time check that *limitedWriter implements io.Writer.
var _ io.Writer = (*limitedWriter)(nil)