tools_test.go
1 package agent 2 3 import ( 4 "context" 5 "strings" 6 "testing" 7 8 "github.com/Kocoro-lab/ShanClaw/internal/client" 9 ) 10 11 func TestToolRegistry_Get(t *testing.T) { 12 reg := NewToolRegistry() 13 reg.Register(&mockTool{name: "file_read"}) 14 15 tool, ok := reg.Get("file_read") 16 if !ok { 17 t.Fatal("expected to find file_read") 18 } 19 if tool.Info().Name != "file_read" { 20 t.Errorf("expected 'file_read', got %q", tool.Info().Name) 21 } 22 23 _, ok = reg.Get("nonexistent") 24 if ok { 25 t.Error("expected not found") 26 } 27 } 28 29 func TestToolRegistry_Schemas(t *testing.T) { 30 reg := NewToolRegistry() 31 reg.Register(&mockTool{name: "file_read"}) 32 reg.Register(&mockTool{name: "bash"}) 33 34 schemas := reg.Schemas() 35 if len(schemas) != 2 { 36 t.Errorf("expected 2 schemas, got %d", len(schemas)) 37 } 38 } 39 40 type mockTool struct { 41 name string 42 } 43 44 func (m *mockTool) Info() ToolInfo { 45 return ToolInfo{ 46 Name: m.name, 47 Description: "mock tool", 48 Parameters: map[string]any{"type": "object", "properties": map[string]any{}}, 49 } 50 } 51 52 func (m *mockTool) Run(ctx context.Context, args string) (ToolResult, error) { 53 return ToolResult{Content: "mock result"}, nil 54 } 55 56 func (m *mockTool) RequiresApproval() bool { return false } 57 58 type mockNativeTool struct { 59 name string 60 } 61 62 func (m *mockNativeTool) Info() ToolInfo { 63 return ToolInfo{Name: m.name, Description: "native tool"} 64 } 65 func (m *mockNativeTool) Run(ctx context.Context, args string) (ToolResult, error) { 66 return ToolResult{Content: "ok"}, nil 67 } 68 func (m *mockNativeTool) RequiresApproval() bool { return false } 69 func (m *mockNativeTool) NativeToolDef() *client.NativeToolDef { 70 return &client.NativeToolDef{ 71 Type: "computer_20251124", 72 Name: "computer", 73 DisplayWidthPx: 1280, 74 DisplayHeightPx: 800, 75 } 76 } 77 78 func TestToolRegistry_SchemasIncludesNativeTool(t *testing.T) { 79 reg := NewToolRegistry() 80 reg.Register(&mockNativeTool{name: "computer"}) 81 reg.Register(&mockTool{name: "bash"}) 82 83 schemas := reg.Schemas() 84 if len(schemas) != 2 { 85 t.Fatalf("expected 2 schemas, got %d", len(schemas)) 86 } 87 // Native tool should use its own type 88 if schemas[0].Type != "computer_20251124" { 89 t.Errorf("expected type 'computer_20251124', got %q", schemas[0].Type) 90 } 91 if schemas[0].Name != "computer" { 92 t.Errorf("expected name 'computer', got %q", schemas[0].Name) 93 } 94 if schemas[0].DisplayWidthPx != 1280 { 95 t.Errorf("expected display_width_px 1280, got %d", schemas[0].DisplayWidthPx) 96 } 97 // Standard tool should use function type 98 if schemas[1].Type != "function" { 99 t.Errorf("expected type 'function' for bash, got %q", schemas[1].Type) 100 } 101 } 102 103 func TestToolRegistry_Remove(t *testing.T) { 104 r := NewToolRegistry() 105 r.Register(&mockTool{name: "a"}) 106 r.Register(&mockTool{name: "b"}) 107 r.Register(&mockTool{name: "c"}) 108 109 r.Remove("b") 110 111 if _, ok := r.Get("b"); ok { 112 t.Error("b should be removed") 113 } 114 if r.Len() != 2 { 115 t.Errorf("Len() = %d, want 2", r.Len()) 116 } 117 names := r.Names() 118 if len(names) != 2 || names[0] != "a" || names[1] != "c" { 119 t.Errorf("names = %v, want [a c]", names) 120 } 121 } 122 123 func TestToolRegistry_RemoveNonexistent(t *testing.T) { 124 r := NewToolRegistry() 125 r.Register(&mockTool{name: "a"}) 126 r.Remove("nonexistent") // should not panic 127 if r.Len() != 1 { 128 t.Errorf("Len() = %d, want 1", r.Len()) 129 } 130 } 131 132 func TestToolRegistry_FilterByAllow(t *testing.T) { 133 r := NewToolRegistry() 134 r.Register(&mockTool{name: "file_read"}) 135 r.Register(&mockTool{name: "bash"}) 136 r.Register(&mockTool{name: "computer"}) 137 r.Register(&mockTool{name: "browser"}) 138 139 filtered := r.FilterByAllow([]string{"file_read", "bash"}) 140 if filtered.Len() != 2 { 141 t.Errorf("filtered Len() = %d, want 2", filtered.Len()) 142 } 143 if _, ok := filtered.Get("computer"); ok { 144 t.Error("computer should be filtered out") 145 } 146 if _, ok := filtered.Get("file_read"); !ok { 147 t.Error("file_read should be present") 148 } 149 } 150 151 func TestToolRegistry_FilterByDeny(t *testing.T) { 152 r := NewToolRegistry() 153 r.Register(&mockTool{name: "file_read"}) 154 r.Register(&mockTool{name: "bash"}) 155 r.Register(&mockTool{name: "computer"}) 156 r.Register(&mockTool{name: "browser"}) 157 158 filtered := r.FilterByDeny([]string{"computer", "browser"}) 159 if filtered.Len() != 2 { 160 t.Errorf("filtered Len() = %d, want 2", filtered.Len()) 161 } 162 if _, ok := filtered.Get("computer"); ok { 163 t.Error("computer should be denied") 164 } 165 if _, ok := filtered.Get("file_read"); !ok { 166 t.Error("file_read should be present") 167 } 168 } 169 170 func TestToolRegistry_CloneIndependence(t *testing.T) { 171 r := NewToolRegistry() 172 r.Register(&mockTool{name: "a"}) 173 r.Register(&mockTool{name: "b"}) 174 175 c := r.Clone() 176 c.Remove("a") 177 178 if _, ok := r.Get("a"); !ok { 179 t.Error("original should still have 'a'") 180 } 181 if c.Len() != 1 { 182 t.Errorf("clone Len() = %d, want 1", c.Len()) 183 } 184 } 185 186 func TestToolRegistry_RegisterOverwrite(t *testing.T) { 187 r := NewToolRegistry() 188 r.Register(&mockTool{name: "a"}) 189 r.Register(&mockTool{name: "b"}) 190 r.Register(&mockTool{name: "a"}) // overwrite 191 192 names := r.Names() 193 if len(names) != 2 { 194 t.Errorf("expected 2 names after overwrite, got %d: %v", len(names), names) 195 } 196 if r.Len() != 2 { 197 t.Errorf("Len() = %d, want 2", r.Len()) 198 } 199 schemas := r.Schemas() 200 if len(schemas) != 2 { 201 t.Errorf("expected 2 schemas, got %d", len(schemas)) 202 } 203 } 204 205 func TestToolRegistry_RemoveAndReRegister(t *testing.T) { 206 r := NewToolRegistry() 207 r.Register(&mockTool{name: "a"}) 208 r.Register(&mockTool{name: "b"}) 209 r.Remove("a") 210 r.Register(&mockTool{name: "a"}) 211 212 names := r.Names() 213 if len(names) != 2 { 214 t.Errorf("expected 2 names, got %d: %v", len(names), names) 215 } 216 schemas := r.Schemas() 217 if len(schemas) != 2 { 218 t.Errorf("expected 2 schemas, got %d", len(schemas)) 219 } 220 } 221 222 func TestToolResultErrorHelpers(t *testing.T) { 223 tests := []struct { 224 name string 225 result ToolResult 226 wantIsError bool 227 wantCat ErrorCategory 228 wantRetry bool 229 wantPrefix string 230 }{ 231 { 232 name: "TransientError", 233 result: TransientError("connection timed out"), 234 wantIsError: true, 235 wantCat: ErrCategoryTransient, 236 wantRetry: true, 237 wantPrefix: "[transient error]", 238 }, 239 { 240 name: "ValidationError", 241 result: ValidationError("invalid URL format"), 242 wantIsError: true, 243 wantCat: ErrCategoryValidation, 244 wantRetry: false, 245 wantPrefix: "[validation error]", 246 }, 247 { 248 name: "BusinessError", 249 result: BusinessError("refund exceeds policy limit"), 250 wantIsError: true, 251 wantCat: ErrCategoryBusiness, 252 wantRetry: false, 253 wantPrefix: "[business error]", 254 }, 255 { 256 name: "PermissionError", 257 result: PermissionError("access denied"), 258 wantIsError: true, 259 wantCat: ErrCategoryPermission, 260 wantRetry: false, 261 wantPrefix: "[permission error]", 262 }, 263 } 264 265 for _, tt := range tests { 266 t.Run(tt.name, func(t *testing.T) { 267 if tt.result.IsError != tt.wantIsError { 268 t.Errorf("IsError = %v, want %v", tt.result.IsError, tt.wantIsError) 269 } 270 if tt.result.ErrorCategory != tt.wantCat { 271 t.Errorf("ErrorCategory = %q, want %q", tt.result.ErrorCategory, tt.wantCat) 272 } 273 if tt.result.IsRetryable != tt.wantRetry { 274 t.Errorf("IsRetryable = %v, want %v", tt.result.IsRetryable, tt.wantRetry) 275 } 276 if !strings.HasPrefix(tt.result.Content, tt.wantPrefix) { 277 t.Errorf("Content = %q, want prefix %q", tt.result.Content, tt.wantPrefix) 278 } 279 }) 280 } 281 } 282 283 func TestToolResult_ZeroValueNotError(t *testing.T) { 284 r := ToolResult{Content: "some output"} 285 if r.IsError { 286 t.Error("zero-value ToolResult must not be an error") 287 } 288 if r.ErrorCategory != "" { 289 t.Errorf("zero-value ErrorCategory must be empty, got %q", r.ErrorCategory) 290 } 291 if r.IsRetryable { 292 t.Error("zero-value IsRetryable must be false") 293 } 294 } 295 296 func TestToolResult_ImagesField(t *testing.T) { 297 result := ToolResult{ 298 Content: "Screenshot captured", 299 IsError: false, 300 Images: []ImageBlock{ 301 {MediaType: "image/png", Data: "iVBORfakedata"}, 302 }, 303 } 304 if len(result.Images) != 1 { 305 t.Errorf("expected 1 image, got %d", len(result.Images)) 306 } 307 if result.Images[0].MediaType != "image/png" { 308 t.Errorf("expected image/png, got %s", result.Images[0].MediaType) 309 } 310 } 311 312 // mockSourcedTool is a mock tool that implements ToolSourcer. 313 type mockSourcedTool struct { 314 name string 315 source ToolSource 316 } 317 318 func (m *mockSourcedTool) Info() ToolInfo { 319 return ToolInfo{ 320 Name: m.name, 321 Description: "mock sourced tool", 322 Parameters: map[string]any{"type": "object", "properties": map[string]any{}}, 323 } 324 } 325 func (m *mockSourcedTool) Run(ctx context.Context, args string) (ToolResult, error) { 326 return ToolResult{Content: "ok"}, nil 327 } 328 func (m *mockSourcedTool) RequiresApproval() bool { return false } 329 func (m *mockSourcedTool) ToolSource() ToolSource { return m.source } 330 331 func TestToolRegistry_SortedSchemas(t *testing.T) { 332 r := NewToolRegistry() 333 // Register in non-alphabetical, mixed-source order 334 r.Register(&mockTool{name: "grep"}) // local 335 r.Register(&mockSourcedTool{name: "browser_click", source: SourceMCP}) // mcp 336 r.Register(&mockSourcedTool{name: "web_search", source: SourceGateway}) // gateway 337 r.Register(&mockTool{name: "bash"}) // local 338 r.Register(&mockSourcedTool{name: "browser_type", source: SourceMCP}) // mcp 339 r.Register(&mockSourcedTool{name: "alpaca_news", source: SourceGateway}) // gateway 340 r.Register(&mockTool{name: "file_read"}) // local 341 342 schemas := r.SortedSchemas() 343 var names []string 344 for _, s := range schemas { 345 names = append(names, s.Function.Name) 346 } 347 348 expected := []string{ 349 // local alpha 350 "bash", "file_read", "grep", 351 // mcp alpha 352 "browser_click", "browser_type", 353 // gateway alpha 354 "alpaca_news", "web_search", 355 } 356 if len(names) != len(expected) { 357 t.Fatalf("got %d schemas, want %d: %v", len(names), len(expected), names) 358 } 359 for i, want := range expected { 360 if names[i] != want { 361 t.Errorf("position %d: got %q, want %q (full: %v)", i, names[i], want, names) 362 break 363 } 364 } 365 } 366 367 func TestToolRegistry_SortedNames(t *testing.T) { 368 r := NewToolRegistry() 369 r.Register(&mockTool{name: "grep"}) 370 r.Register(&mockSourcedTool{name: "browser_click", source: SourceMCP}) 371 r.Register(&mockTool{name: "bash"}) 372 373 names := r.SortedNames() 374 expected := []string{"bash", "grep", "browser_click"} 375 if len(names) != len(expected) { 376 t.Fatalf("got %v, want %v", names, expected) 377 } 378 for i, want := range expected { 379 if names[i] != want { 380 t.Errorf("position %d: got %q, want %q", i, names[i], want) 381 } 382 } 383 } 384 385 func TestToolRegistry_SortedSchemas_MCPAdditionDoesNotShiftLocal(t *testing.T) { 386 r := NewToolRegistry() 387 r.Register(&mockTool{name: "grep"}) 388 r.Register(&mockTool{name: "bash"}) 389 390 before := r.SortedNames() 391 392 r.Register(&mockSourcedTool{name: "browser_navigate", source: SourceMCP}) 393 394 after := r.SortedNames() 395 // Local tools should still be in positions 0 and 1 with same order 396 for i := 0; i < 2; i++ { 397 if before[i] != after[i] { 398 t.Errorf("local tool shifted: position %d was %q, now %q", i, before[i], after[i]) 399 } 400 } 401 } 402 403 func TestToolRegistry_SummaryList(t *testing.T) { 404 reg := NewToolRegistry() 405 reg.Register(&mockTool{name: "bash"}) 406 reg.Register(&mockTool{name: "file_read"}) 407 408 summaries := reg.SummaryList() 409 if len(summaries) != 2 { 410 t.Fatalf("expected 2 summaries, got %d", len(summaries)) 411 } 412 for _, s := range summaries { 413 if s.Name == "" { 414 t.Error("summary name is empty") 415 } 416 if s.Description == "" { 417 t.Error("summary description is empty") 418 } 419 } 420 } 421 422 func TestToolRegistry_FullSchemas(t *testing.T) { 423 reg := NewToolRegistry() 424 reg.Register(&mockTool{name: "bash"}) 425 reg.Register(&mockTool{name: "file_read"}) 426 reg.Register(&mockTool{name: "grep"}) 427 428 schemas := reg.FullSchemas([]string{"bash", "file_read"}) 429 if len(schemas) != 2 { 430 t.Fatalf("expected 2 schemas, got %d", len(schemas)) 431 } 432 names := map[string]bool{} 433 for _, s := range schemas { 434 names[s.Function.Name] = true 435 } 436 if !names["bash"] || !names["file_read"] { 437 t.Errorf("expected bash and file_read, got %v", names) 438 } 439 } 440 441 func TestToolRegistry_FullSchemas_Nonexistent(t *testing.T) { 442 reg := NewToolRegistry() 443 reg.Register(&mockTool{name: "bash"}) 444 445 schemas := reg.FullSchemas([]string{"nonexistent"}) 446 if len(schemas) != 0 { 447 t.Fatalf("expected 0 schemas for nonexistent tool, got %d", len(schemas)) 448 } 449 } 450 451 func TestTurnUsage_CacheTelemetry(t *testing.T) { 452 u := &TurnUsage{} 453 454 // Turn 1: cache creation (first turn always creates, no reads) 455 u.Add(client.Usage{InputTokens: 5000, CacheCreationTokens: 4000, CacheReadTokens: 0}) 456 if !u.cacheCapable { 457 t.Error("should be cache-capable after seeing CacheCreationTokens > 0") 458 } 459 if u.cacheMissStreak != 0 { 460 t.Errorf("first turn should not count as miss, got streak %d", u.cacheMissStreak) 461 } 462 463 // Turn 2: cache hit 464 u.Add(client.Usage{InputTokens: 5000, CacheReadTokens: 3500}) 465 if u.cacheMissStreak != 0 { 466 t.Errorf("cache hit should reset streak, got %d", u.cacheMissStreak) 467 } 468 469 // Turns 3-5: cache misses 470 for i := 0; i < 3; i++ { 471 u.Add(client.Usage{InputTokens: 5000, CacheReadTokens: 0}) 472 } 473 if u.cacheMissStreak != 3 { 474 t.Errorf("expected miss streak 3, got %d", u.cacheMissStreak) 475 } 476 } 477 478 func TestTurnUsage_CacheTelemetry_NonCacheProvider(t *testing.T) { 479 u := &TurnUsage{} 480 481 // Provider never returns cache tokens — should not flag as cache-capable 482 for i := 0; i < 5; i++ { 483 u.Add(client.Usage{InputTokens: 5000}) 484 } 485 if u.cacheCapable { 486 t.Error("should not be cache-capable when provider never returns cache tokens") 487 } 488 if u.cacheMissStreak != 0 { 489 t.Errorf("non-cache provider should not accumulate miss streak, got %d", u.cacheMissStreak) 490 } 491 }