tools.go
1 package agent 2 3 import ( 4 "context" 5 "sort" 6 7 "github.com/Kocoro-lab/ShanClaw/internal/client" 8 ) 9 10 type ToolInfo struct { 11 Name string 12 Description string 13 Parameters map[string]any 14 Required []string 15 } 16 17 type ImageBlock struct { 18 MediaType string // e.g. "image/png" 19 Data string // base64-encoded 20 } 21 22 // ErrorCategory classifies the nature of a tool failure so the agent 23 // can make informed retry decisions. 24 type ErrorCategory string 25 26 const ( 27 // ErrCategoryTransient indicates a timeout or network error. Retry may help. 28 ErrCategoryTransient ErrorCategory = "transient" 29 // ErrCategoryValidation indicates the tool arguments were invalid. Fix before retrying. 30 ErrCategoryValidation ErrorCategory = "validation" 31 // ErrCategoryBusiness indicates a policy or constraint violation. Do not retry. 32 ErrCategoryBusiness ErrorCategory = "business" 33 // ErrCategoryPermission indicates access was denied. Escalate to user. 34 ErrCategoryPermission ErrorCategory = "permission" 35 ) 36 37 // ToolSource classifies the origin of a tool for deterministic ordering. 38 type ToolSource string 39 40 const ( 41 SourceLocal ToolSource = "local" 42 SourceMCP ToolSource = "mcp" 43 SourceGateway ToolSource = "gateway" 44 ) 45 46 // ToolSourcer is an optional interface tools implement to declare their origin. 47 // Tools that don't implement this are classified as SourceLocal. 48 type ToolSourcer interface { 49 ToolSource() ToolSource 50 } 51 52 type ToolResult struct { 53 Content string 54 IsError bool 55 ErrorCategory ErrorCategory // empty when IsError is false 56 IsRetryable bool // true only for transient errors 57 Images []ImageBlock 58 CloudResult bool // true when result is a cloud deliverable (bypass LLM summarization) 59 // Usage optionally reports per-call cost for this tool. Gateway tools 60 // whose server returns billing info (x_search → xAI tokens, web_search 61 // → SerpAPI query count) populate this so the audit logger can write a 62 // cost breakdown per tool call. nil when the tool does not bill per call. 63 Usage *ToolUsage 64 // ContentBlocks, when non-nil, carries structured output (e.g. tool_reference 65 // blocks from tool_search) that loop.go passes through verbatim as 66 // tool_result content when the gateway/model supports the protocol. 67 // When nil, loop.go falls back to the Content string path. 68 ContentBlocks []client.ContentBlock 69 // SkillToolFilter, when non-nil, restricts the tool schemas sent to 70 // the LLM for the remainder of this Run() call. Only tools whose names 71 // appear in this list (plus use_skill itself) will be visible. Set by 72 // use_skill when the activated skill declares allowed-tools. 73 SkillToolFilter []string 74 // SkillToolHint, when non-empty, contains a <system-reminder> text to 75 // append to the tool_result content, guiding the LLM to restrict itself 76 // to the allowed tools. Works alongside SkillToolFilter (which provides 77 // execution-time denial). 78 SkillToolHint string 79 } 80 81 // ToolUsage is ToolResult's per-call cost breakdown. Mirrors client.ToolUsage 82 // (see internal/client/gateway.go) so tool implementations depending on agent 83 // don't need the client import. 84 type ToolUsage struct { 85 Provider string 86 Model string 87 InputTokens int 88 OutputTokens int 89 TotalTokens int 90 CostUSD float64 91 Units int 92 UnitType string 93 } 94 95 // TransientError returns a ToolResult for timeout/network failures where retry may help. 96 func TransientError(msg string) ToolResult { 97 return ToolResult{ 98 Content: "[transient error] " + msg, 99 IsError: true, 100 ErrorCategory: ErrCategoryTransient, 101 IsRetryable: true, 102 } 103 } 104 105 // ValidationError returns a ToolResult for invalid tool arguments. 106 func ValidationError(msg string) ToolResult { 107 return ToolResult{ 108 Content: "[validation error] " + msg, 109 IsError: true, 110 ErrorCategory: ErrCategoryValidation, 111 } 112 } 113 114 // BusinessError returns a ToolResult for policy/constraint violations that must not be retried. 115 func BusinessError(msg string) ToolResult { 116 return ToolResult{ 117 Content: "[business error] " + msg, 118 IsError: true, 119 ErrorCategory: ErrCategoryBusiness, 120 } 121 } 122 123 // PermissionError returns a ToolResult for access denied scenarios requiring escalation. 124 func PermissionError(msg string) ToolResult { 125 return ToolResult{ 126 Content: "[permission error] " + msg, 127 IsError: true, 128 ErrorCategory: ErrCategoryPermission, 129 } 130 } 131 132 type Tool interface { 133 Info() ToolInfo 134 Run(ctx context.Context, args string) (ToolResult, error) 135 RequiresApproval() bool 136 } 137 138 // NativeToolProvider is an optional interface for tools that use a provider's 139 // native tool schema (e.g., Anthropic's computer_20251124) instead of the 140 // standard function-calling format. 141 type NativeToolProvider interface { 142 NativeToolDef() *client.NativeToolDef 143 } 144 145 // SafeChecker is an optional interface tools can implement to indicate 146 // certain invocations are safe and don't need approval. 147 type SafeChecker interface { 148 IsSafeArgs(argsJSON string) bool 149 } 150 151 // SafeCheckerWithContext is like SafeChecker but receives the call context, 152 // allowing tools to use session-scoped CWD for path-based safety checks. 153 type SafeCheckerWithContext interface { 154 IsSafeArgsWithContext(ctx context.Context, argsJSON string) bool 155 } 156 157 // ReadOnlyChecker is an optional interface for tools that can classify 158 // individual calls as read-only based on arguments. 159 // If args parsing fails, implementations MUST return false (fail-closed). 160 type ReadOnlyChecker interface { 161 IsReadOnlyCall(argsJSON string) bool 162 } 163 164 // ToolSummary is a lightweight name+description pair for deferred tool listings. 165 type ToolSummary struct { 166 Name string 167 Description string 168 } 169 170 type ToolRegistry struct { 171 tools map[string]Tool 172 order []string 173 } 174 175 func NewToolRegistry() *ToolRegistry { 176 return &ToolRegistry{ 177 tools: make(map[string]Tool), 178 } 179 } 180 181 func (r *ToolRegistry) Register(t Tool) { 182 name := t.Info().Name 183 if _, exists := r.tools[name]; !exists { 184 r.order = append(r.order, name) 185 } 186 r.tools[name] = t 187 } 188 189 func (r *ToolRegistry) Clone() *ToolRegistry { 190 clone := NewToolRegistry() 191 for _, name := range r.order { 192 tool := r.tools[name] 193 clone.tools[name] = tool 194 clone.order = append(clone.order, name) 195 } 196 return clone 197 } 198 199 func (r *ToolRegistry) Get(name string) (Tool, bool) { 200 t, ok := r.tools[name] 201 return t, ok 202 } 203 204 func (r *ToolRegistry) All() []Tool { 205 tools := make([]Tool, 0, len(r.order)) 206 for _, name := range r.order { 207 tools = append(tools, r.tools[name]) 208 } 209 return tools 210 } 211 212 // Remove removes a tool from the registry by name. 213 func (r *ToolRegistry) Remove(name string) { 214 if _, ok := r.tools[name]; !ok { 215 return 216 } 217 delete(r.tools, name) 218 for i, n := range r.order { 219 if n == name { 220 r.order = append(r.order[:i], r.order[i+1:]...) 221 return 222 } 223 } 224 } 225 226 // Names returns the ordered list of tool names. 227 func (r *ToolRegistry) Names() []string { 228 out := make([]string, len(r.order)) 229 copy(out, r.order) 230 return out 231 } 232 233 // Len returns the number of registered tools. 234 func (r *ToolRegistry) Len() int { 235 return len(r.tools) 236 } 237 238 // FilterByAllow returns a new registry containing only the named tools. 239 // Tools not found are silently skipped. 240 func (r *ToolRegistry) FilterByAllow(allow []string) *ToolRegistry { 241 filtered := NewToolRegistry() 242 for _, name := range allow { 243 if t, ok := r.tools[name]; ok { 244 filtered.Register(t) 245 } 246 } 247 return filtered 248 } 249 250 // FilterByDeny returns a new registry with the named tools removed. 251 func (r *ToolRegistry) FilterByDeny(deny []string) *ToolRegistry { 252 denySet := make(map[string]struct{}, len(deny)) 253 for _, name := range deny { 254 denySet[name] = struct{}{} 255 } 256 filtered := NewToolRegistry() 257 for _, name := range r.order { 258 if _, blocked := denySet[name]; !blocked { 259 filtered.Register(r.tools[name]) 260 } 261 } 262 return filtered 263 } 264 265 func (r *ToolRegistry) Schemas() []client.Tool { 266 schemas := make([]client.Tool, 0, len(r.order)) 267 for _, name := range r.order { 268 schemas = append(schemas, buildToolSchema(r.tools[name])) 269 } 270 return schemas 271 } 272 273 // SummaryList returns name+description for all registered tools. 274 func (r *ToolRegistry) SummaryList() []ToolSummary { 275 summaries := make([]ToolSummary, 0, len(r.order)) 276 for _, name := range r.order { 277 info := r.tools[name].Info() 278 summaries = append(summaries, ToolSummary{Name: info.Name, Description: info.Description}) 279 } 280 return summaries 281 } 282 283 // FullSchemas returns complete client.Tool schemas for the named tools. 284 // Unknown names are silently skipped. 285 func (r *ToolRegistry) FullSchemas(names []string) []client.Tool { 286 schemas := make([]client.Tool, 0, len(names)) 287 for _, name := range names { 288 if t, ok := r.tools[name]; ok { 289 schemas = append(schemas, buildToolSchema(t)) 290 } 291 } 292 return schemas 293 } 294 295 // SortedSchemas returns tool schemas in deterministic order: 296 // local tools (alpha) → MCP tools (alpha) → gateway tools (alpha). 297 func (r *ToolRegistry) SortedSchemas() []client.Tool { 298 local, mcp, gw := r.partitionBySource() 299 sort.Strings(local) 300 sort.Strings(mcp) 301 sort.Strings(gw) 302 303 schemas := make([]client.Tool, 0, len(r.order)) 304 for _, group := range [][]string{local, mcp, gw} { 305 for _, name := range group { 306 schemas = append(schemas, buildToolSchema(r.tools[name])) 307 } 308 } 309 return schemas 310 } 311 312 // buildToolSchema converts a Tool into a client.Tool schema definition. 313 func buildToolSchema(t Tool) client.Tool { 314 if native, ok := t.(NativeToolProvider); ok { 315 def := native.NativeToolDef() 316 if def != nil { 317 return client.Tool{ 318 Type: def.Type, 319 Name: def.Name, 320 DisplayWidthPx: def.DisplayWidthPx, 321 DisplayHeightPx: def.DisplayHeightPx, 322 } 323 } 324 } 325 info := t.Info() 326 params := info.Parameters 327 if params == nil { 328 params = map[string]any{"type": "object", "properties": map[string]any{}} 329 } 330 if info.Required != nil { 331 params["required"] = info.Required 332 } 333 return client.Tool{ 334 Type: "function", 335 Function: client.FunctionDef{ 336 Name: info.Name, 337 Description: info.Description, 338 Parameters: params, 339 }, 340 } 341 } 342 343 // SortedNames returns tool names in the same deterministic order as SortedSchemas. 344 func (r *ToolRegistry) SortedNames() []string { 345 local, mcp, gw := r.partitionBySource() 346 sort.Strings(local) 347 sort.Strings(mcp) 348 sort.Strings(gw) 349 350 names := make([]string, 0, len(r.order)) 351 names = append(names, local...) 352 names = append(names, mcp...) 353 names = append(names, gw...) 354 return names 355 } 356 357 // MCPNames returns the names of all MCP-origin tools in the registry. Used by 358 // the loop detector to mark MCP tools as batch-tolerant so legitimate 359 // enumerations (e.g. iterating over distinct database UUIDs) do not trip the 360 // NoProgress nudge on count alone. 361 func (r *ToolRegistry) MCPNames() []string { 362 _, mcp, _ := r.partitionBySource() 363 return mcp 364 } 365 366 // partitionBySource groups tool names by their source category. 367 func (r *ToolRegistry) partitionBySource() (local, mcp, gw []string) { 368 for _, name := range r.order { 369 t := r.tools[name] 370 if sourcer, ok := t.(ToolSourcer); ok { 371 switch sourcer.ToolSource() { 372 case SourceMCP: 373 mcp = append(mcp, name) 374 case SourceGateway: 375 gw = append(gw, name) 376 default: 377 local = append(local, name) 378 } 379 } else { 380 local = append(local, name) 381 } 382 } 383 return 384 }