/ internal / agent / tools.go
tools.go
  1  package agent
  2  
  3  import (
  4  	"context"
  5  	"sort"
  6  
  7  	"github.com/Kocoro-lab/ShanClaw/internal/client"
  8  )
  9  
 10  type ToolInfo struct {
 11  	Name        string
 12  	Description string
 13  	Parameters  map[string]any
 14  	Required    []string
 15  }
 16  
 17  type ImageBlock struct {
 18  	MediaType string // e.g. "image/png"
 19  	Data      string // base64-encoded
 20  }
 21  
 22  // ErrorCategory classifies the nature of a tool failure so the agent
 23  // can make informed retry decisions.
 24  type ErrorCategory string
 25  
 26  const (
 27  	// ErrCategoryTransient indicates a timeout or network error. Retry may help.
 28  	ErrCategoryTransient ErrorCategory = "transient"
 29  	// ErrCategoryValidation indicates the tool arguments were invalid. Fix before retrying.
 30  	ErrCategoryValidation ErrorCategory = "validation"
 31  	// ErrCategoryBusiness indicates a policy or constraint violation. Do not retry.
 32  	ErrCategoryBusiness ErrorCategory = "business"
 33  	// ErrCategoryPermission indicates access was denied. Escalate to user.
 34  	ErrCategoryPermission ErrorCategory = "permission"
 35  )
 36  
 37  // ToolSource classifies the origin of a tool for deterministic ordering.
 38  type ToolSource string
 39  
 40  const (
 41  	SourceLocal   ToolSource = "local"
 42  	SourceMCP     ToolSource = "mcp"
 43  	SourceGateway ToolSource = "gateway"
 44  )
 45  
 46  // ToolSourcer is an optional interface tools implement to declare their origin.
 47  // Tools that don't implement this are classified as SourceLocal.
 48  type ToolSourcer interface {
 49  	ToolSource() ToolSource
 50  }
 51  
 52  type ToolResult struct {
 53  	Content       string
 54  	IsError       bool
 55  	ErrorCategory ErrorCategory // empty when IsError is false
 56  	IsRetryable   bool          // true only for transient errors
 57  	Images        []ImageBlock
 58  	CloudResult   bool // true when result is a cloud deliverable (bypass LLM summarization)
 59  	// Usage optionally reports per-call cost for this tool. Gateway tools
 60  	// whose server returns billing info (x_search → xAI tokens, web_search
 61  	// → SerpAPI query count) populate this so the audit logger can write a
 62  	// cost breakdown per tool call. nil when the tool does not bill per call.
 63  	Usage *ToolUsage
 64  	// ContentBlocks, when non-nil, carries structured output (e.g. tool_reference
 65  	// blocks from tool_search) that loop.go passes through verbatim as
 66  	// tool_result content when the gateway/model supports the protocol.
 67  	// When nil, loop.go falls back to the Content string path.
 68  	ContentBlocks []client.ContentBlock
 69  	// SkillToolFilter, when non-nil, restricts the tool schemas sent to
 70  	// the LLM for the remainder of this Run() call. Only tools whose names
 71  	// appear in this list (plus use_skill itself) will be visible. Set by
 72  	// use_skill when the activated skill declares allowed-tools.
 73  	SkillToolFilter []string
 74  	// SkillToolHint, when non-empty, contains a <system-reminder> text to
 75  	// append to the tool_result content, guiding the LLM to restrict itself
 76  	// to the allowed tools. Works alongside SkillToolFilter (which provides
 77  	// execution-time denial).
 78  	SkillToolHint string
 79  }
 80  
 81  // ToolUsage is ToolResult's per-call cost breakdown. Mirrors client.ToolUsage
 82  // (see internal/client/gateway.go) so tool implementations depending on agent
 83  // don't need the client import.
 84  type ToolUsage struct {
 85  	Provider     string
 86  	Model        string
 87  	InputTokens  int
 88  	OutputTokens int
 89  	TotalTokens  int
 90  	CostUSD      float64
 91  	Units        int
 92  	UnitType     string
 93  }
 94  
 95  // TransientError returns a ToolResult for timeout/network failures where retry may help.
 96  func TransientError(msg string) ToolResult {
 97  	return ToolResult{
 98  		Content:       "[transient error] " + msg,
 99  		IsError:       true,
100  		ErrorCategory: ErrCategoryTransient,
101  		IsRetryable:   true,
102  	}
103  }
104  
105  // ValidationError returns a ToolResult for invalid tool arguments.
106  func ValidationError(msg string) ToolResult {
107  	return ToolResult{
108  		Content:       "[validation error] " + msg,
109  		IsError:       true,
110  		ErrorCategory: ErrCategoryValidation,
111  	}
112  }
113  
114  // BusinessError returns a ToolResult for policy/constraint violations that must not be retried.
115  func BusinessError(msg string) ToolResult {
116  	return ToolResult{
117  		Content:       "[business error] " + msg,
118  		IsError:       true,
119  		ErrorCategory: ErrCategoryBusiness,
120  	}
121  }
122  
123  // PermissionError returns a ToolResult for access denied scenarios requiring escalation.
124  func PermissionError(msg string) ToolResult {
125  	return ToolResult{
126  		Content:       "[permission error] " + msg,
127  		IsError:       true,
128  		ErrorCategory: ErrCategoryPermission,
129  	}
130  }
131  
132  type Tool interface {
133  	Info() ToolInfo
134  	Run(ctx context.Context, args string) (ToolResult, error)
135  	RequiresApproval() bool
136  }
137  
138  // NativeToolProvider is an optional interface for tools that use a provider's
139  // native tool schema (e.g., Anthropic's computer_20251124) instead of the
140  // standard function-calling format.
141  type NativeToolProvider interface {
142  	NativeToolDef() *client.NativeToolDef
143  }
144  
145  // SafeChecker is an optional interface tools can implement to indicate
146  // certain invocations are safe and don't need approval.
147  type SafeChecker interface {
148  	IsSafeArgs(argsJSON string) bool
149  }
150  
151  // SafeCheckerWithContext is like SafeChecker but receives the call context,
152  // allowing tools to use session-scoped CWD for path-based safety checks.
153  type SafeCheckerWithContext interface {
154  	IsSafeArgsWithContext(ctx context.Context, argsJSON string) bool
155  }
156  
157  // ReadOnlyChecker is an optional interface for tools that can classify
158  // individual calls as read-only based on arguments.
159  // If args parsing fails, implementations MUST return false (fail-closed).
160  type ReadOnlyChecker interface {
161  	IsReadOnlyCall(argsJSON string) bool
162  }
163  
// ToolSummary pairs a tool's name with its description, for listings that
// defer sending full schemas.
type ToolSummary struct {
	Name, Description string
}
169  
170  type ToolRegistry struct {
171  	tools map[string]Tool
172  	order []string
173  }
174  
175  func NewToolRegistry() *ToolRegistry {
176  	return &ToolRegistry{
177  		tools: make(map[string]Tool),
178  	}
179  }
180  
181  func (r *ToolRegistry) Register(t Tool) {
182  	name := t.Info().Name
183  	if _, exists := r.tools[name]; !exists {
184  		r.order = append(r.order, name)
185  	}
186  	r.tools[name] = t
187  }
188  
189  func (r *ToolRegistry) Clone() *ToolRegistry {
190  	clone := NewToolRegistry()
191  	for _, name := range r.order {
192  		tool := r.tools[name]
193  		clone.tools[name] = tool
194  		clone.order = append(clone.order, name)
195  	}
196  	return clone
197  }
198  
199  func (r *ToolRegistry) Get(name string) (Tool, bool) {
200  	t, ok := r.tools[name]
201  	return t, ok
202  }
203  
204  func (r *ToolRegistry) All() []Tool {
205  	tools := make([]Tool, 0, len(r.order))
206  	for _, name := range r.order {
207  		tools = append(tools, r.tools[name])
208  	}
209  	return tools
210  }
211  
212  // Remove removes a tool from the registry by name.
213  func (r *ToolRegistry) Remove(name string) {
214  	if _, ok := r.tools[name]; !ok {
215  		return
216  	}
217  	delete(r.tools, name)
218  	for i, n := range r.order {
219  		if n == name {
220  			r.order = append(r.order[:i], r.order[i+1:]...)
221  			return
222  		}
223  	}
224  }
225  
226  // Names returns the ordered list of tool names.
227  func (r *ToolRegistry) Names() []string {
228  	out := make([]string, len(r.order))
229  	copy(out, r.order)
230  	return out
231  }
232  
233  // Len returns the number of registered tools.
234  func (r *ToolRegistry) Len() int {
235  	return len(r.tools)
236  }
237  
238  // FilterByAllow returns a new registry containing only the named tools.
239  // Tools not found are silently skipped.
240  func (r *ToolRegistry) FilterByAllow(allow []string) *ToolRegistry {
241  	filtered := NewToolRegistry()
242  	for _, name := range allow {
243  		if t, ok := r.tools[name]; ok {
244  			filtered.Register(t)
245  		}
246  	}
247  	return filtered
248  }
249  
250  // FilterByDeny returns a new registry with the named tools removed.
251  func (r *ToolRegistry) FilterByDeny(deny []string) *ToolRegistry {
252  	denySet := make(map[string]struct{}, len(deny))
253  	for _, name := range deny {
254  		denySet[name] = struct{}{}
255  	}
256  	filtered := NewToolRegistry()
257  	for _, name := range r.order {
258  		if _, blocked := denySet[name]; !blocked {
259  			filtered.Register(r.tools[name])
260  		}
261  	}
262  	return filtered
263  }
264  
265  func (r *ToolRegistry) Schemas() []client.Tool {
266  	schemas := make([]client.Tool, 0, len(r.order))
267  	for _, name := range r.order {
268  		schemas = append(schemas, buildToolSchema(r.tools[name]))
269  	}
270  	return schemas
271  }
272  
273  // SummaryList returns name+description for all registered tools.
274  func (r *ToolRegistry) SummaryList() []ToolSummary {
275  	summaries := make([]ToolSummary, 0, len(r.order))
276  	for _, name := range r.order {
277  		info := r.tools[name].Info()
278  		summaries = append(summaries, ToolSummary{Name: info.Name, Description: info.Description})
279  	}
280  	return summaries
281  }
282  
283  // FullSchemas returns complete client.Tool schemas for the named tools.
284  // Unknown names are silently skipped.
285  func (r *ToolRegistry) FullSchemas(names []string) []client.Tool {
286  	schemas := make([]client.Tool, 0, len(names))
287  	for _, name := range names {
288  		if t, ok := r.tools[name]; ok {
289  			schemas = append(schemas, buildToolSchema(t))
290  		}
291  	}
292  	return schemas
293  }
294  
295  // SortedSchemas returns tool schemas in deterministic order:
296  // local tools (alpha) → MCP tools (alpha) → gateway tools (alpha).
297  func (r *ToolRegistry) SortedSchemas() []client.Tool {
298  	local, mcp, gw := r.partitionBySource()
299  	sort.Strings(local)
300  	sort.Strings(mcp)
301  	sort.Strings(gw)
302  
303  	schemas := make([]client.Tool, 0, len(r.order))
304  	for _, group := range [][]string{local, mcp, gw} {
305  		for _, name := range group {
306  			schemas = append(schemas, buildToolSchema(r.tools[name]))
307  		}
308  	}
309  	return schemas
310  }
311  
312  // buildToolSchema converts a Tool into a client.Tool schema definition.
313  func buildToolSchema(t Tool) client.Tool {
314  	if native, ok := t.(NativeToolProvider); ok {
315  		def := native.NativeToolDef()
316  		if def != nil {
317  			return client.Tool{
318  				Type:            def.Type,
319  				Name:            def.Name,
320  				DisplayWidthPx:  def.DisplayWidthPx,
321  				DisplayHeightPx: def.DisplayHeightPx,
322  			}
323  		}
324  	}
325  	info := t.Info()
326  	params := info.Parameters
327  	if params == nil {
328  		params = map[string]any{"type": "object", "properties": map[string]any{}}
329  	}
330  	if info.Required != nil {
331  		params["required"] = info.Required
332  	}
333  	return client.Tool{
334  		Type: "function",
335  		Function: client.FunctionDef{
336  			Name:        info.Name,
337  			Description: info.Description,
338  			Parameters:  params,
339  		},
340  	}
341  }
342  
343  // SortedNames returns tool names in the same deterministic order as SortedSchemas.
344  func (r *ToolRegistry) SortedNames() []string {
345  	local, mcp, gw := r.partitionBySource()
346  	sort.Strings(local)
347  	sort.Strings(mcp)
348  	sort.Strings(gw)
349  
350  	names := make([]string, 0, len(r.order))
351  	names = append(names, local...)
352  	names = append(names, mcp...)
353  	names = append(names, gw...)
354  	return names
355  }
356  
357  // MCPNames returns the names of all MCP-origin tools in the registry. Used by
358  // the loop detector to mark MCP tools as batch-tolerant so legitimate
359  // enumerations (e.g. iterating over distinct database UUIDs) do not trip the
360  // NoProgress nudge on count alone.
361  func (r *ToolRegistry) MCPNames() []string {
362  	_, mcp, _ := r.partitionBySource()
363  	return mcp
364  }
365  
366  // partitionBySource groups tool names by their source category.
367  func (r *ToolRegistry) partitionBySource() (local, mcp, gw []string) {
368  	for _, name := range r.order {
369  		t := r.tools[name]
370  		if sourcer, ok := t.(ToolSourcer); ok {
371  			switch sourcer.ToolSource() {
372  			case SourceMCP:
373  				mcp = append(mcp, name)
374  			case SourceGateway:
375  				gw = append(gw, name)
376  			default:
377  				local = append(local, name)
378  			}
379  		} else {
380  			local = append(local, name)
381  		}
382  	}
383  	return
384  }