/ internal / agent / readtracker.go
readtracker.go
  1  package agent
  2  
  3  import (
  4  	"context"
  5  	"fmt"
  6  	"path/filepath"
  7  	"strings"
  8  
  9  	"github.com/Kocoro-lab/ShanClaw/internal/client"
 10  	"github.com/Kocoro-lab/ShanClaw/internal/cwdctx"
 11  )
 12  
 13  // readTrackerKey is the context key for ReadTracker.
 14  type readTrackerKey struct{}
 15  
 16  // memoryDirKey is the context key for the agent's memory directory path.
 17  type memoryDirKey struct{}
 18  
 19  // conversationSnapshotKey 是获取当前对话快照的 context key。
 20  type conversationSnapshotKey struct{}
 21  
 22  // ConversationSnapshotFunc 返回当前对话消息的快照副本。
 23  type ConversationSnapshotFunc func() []client.Message
 24  
 25  // WithConversationSnapshot 注入对话快照提供函数到 context。
 26  func WithConversationSnapshot(ctx context.Context, fn ConversationSnapshotFunc) context.Context {
 27  	return context.WithValue(ctx, conversationSnapshotKey{}, fn)
 28  }
 29  
 30  // ConversationSnapshotFromContext 从 context 获取对话快照提供函数。
 31  // 调用返回的函数可获取当前对话消息的副本。无 provider 时返回 nil。
 32  func ConversationSnapshotFromContext(ctx context.Context) ConversationSnapshotFunc {
 33  	fn, _ := ctx.Value(conversationSnapshotKey{}).(ConversationSnapshotFunc)
 34  	return fn
 35  }
 36  
 37  // WithMemoryDir returns a new context with the memory directory set.
 38  func WithMemoryDir(ctx context.Context, dir string) context.Context {
 39  	return context.WithValue(ctx, memoryDirKey{}, dir)
 40  }
 41  
 42  // MemoryDirFromContext returns the memory directory from context, or "".
 43  func MemoryDirFromContext(ctx context.Context) string {
 44  	if v, ok := ctx.Value(memoryDirKey{}).(string); ok {
 45  		return v
 46  	}
 47  	return ""
 48  }
 49  
 50  // IsMemoryFile returns true if path resolves to the MEMORY.md inside the
 51  // agent's configured memory directory. Returns false when no memory dir
 52  // is set in context (e.g. tool called outside agent loop).
 53  func IsMemoryFile(ctx context.Context, path string) bool {
 54  	dir, ok := ctx.Value(memoryDirKey{}).(string)
 55  	if !ok || dir == "" {
 56  		return false
 57  	}
 58  	resolvedPath := cwdctx.ResolvePath(ctx, path)
 59  	memPath := filepath.Clean(filepath.Join(dir, "MEMORY.md"))
 60  	return strings.EqualFold(resolvedPath, memPath)
 61  }
 62  
 63  // ReadTrackerKey returns the context key used to store a ReadTracker.
 64  // Exported for use in tests that need to inject a tracker into context.
 65  func ReadTrackerKey() any { return readTrackerKey{} }
 66  
 67  // ReadTracker tracks which files have been read during the current agent turn.
 68  // Used to enforce read-before-edit: file_edit and file_write on existing files
 69  // must be preceded by a file_read of that file.
 70  type ReadTracker struct {
 71  	read map[string]bool
 72  	cwd  string
 73  }
 74  
 75  // NewReadTracker creates a new ReadTracker.
 76  func NewReadTracker() *ReadTracker {
 77  	return &ReadTracker{read: make(map[string]bool)}
 78  }
 79  
 80  // SetCWD sets the session CWD used for relative path resolution.
 81  func (rt *ReadTracker) SetCWD(cwd string) {
 82  	rt.cwd = cwd
 83  }
 84  
 85  // MarkRead records that a file has been read.
 86  func (rt *ReadTracker) MarkRead(path string) {
 87  	norm := normalizePathWithCWD(path, rt.cwd)
 88  	if norm != "" {
 89  		rt.read[norm] = true
 90  	}
 91  }
 92  
 93  // HasRead returns true if the file has been read in this turn.
 94  func (rt *ReadTracker) HasRead(path string) bool {
 95  	norm := normalizePathWithCWD(path, rt.cwd)
 96  	if norm == "" {
 97  		return false
 98  	}
 99  	return rt.read[norm]
100  }
101  
102  // CheckReadBeforeWrite extracts the ReadTracker from context and returns an error
103  // if the given path has not been read. Returns nil if the tracker is absent (e.g.,
104  // tool called outside the agent loop) or the file has been read.
105  func CheckReadBeforeWrite(ctx context.Context, path string) error {
106  	rt, ok := ctx.Value(readTrackerKey{}).(*ReadTracker)
107  	if !ok || rt == nil {
108  		return nil
109  	}
110  	if !rt.HasRead(path) {
111  		return fmt.Errorf("You must read this file with file_read before editing it. Path: %s", path)
112  	}
113  	return nil
114  }
115  
// normalizePathWithCWD turns path into a canonical key.
//
// An empty path yields "". An absolute path is cleaned. A relative path is
// joined onto cwd and then cleaned. After that, symlinks are resolved when
// possible; if resolution fails (e.g. the file does not exist yet), the
// cleaned path is returned.
//
// When cwd is empty (scopeless daemon tasks that arrive without a CWD), a
// relative path is only cleaned and returned unresolved. Symlink evaluation
// is skipped in that case because it would silently interpret the path
// against the daemon process's working directory, which the CWD-hardening
// work is meant to avoid.
// NOTE(review): if cwd is itself relative, the joined path stays relative
// and EvalSymlinks will still use the process cwd — presumably callers
// always pass an absolute cwd; confirm.
func normalizePathWithCWD(path, cwd string) string {
	switch {
	case path == "":
		return ""
	case filepath.IsAbs(path):
		// Already absolute; nothing to join.
	case cwd == "":
		return filepath.Clean(path)
	default:
		path = filepath.Join(cwd, path)
	}

	cleaned := filepath.Clean(path)
	resolved, err := filepath.EvalSymlinks(cleaned)
	if err != nil {
		return cleaned
	}
	return resolved
}