/ internal / agent / tools_test.go
tools_test.go
  1  package agent
  2  
  3  import (
  4  	"context"
  5  	"strings"
  6  	"testing"
  7  
  8  	"github.com/Kocoro-lab/ShanClaw/internal/client"
  9  )
 10  
 11  func TestToolRegistry_Get(t *testing.T) {
 12  	reg := NewToolRegistry()
 13  	reg.Register(&mockTool{name: "file_read"})
 14  
 15  	tool, ok := reg.Get("file_read")
 16  	if !ok {
 17  		t.Fatal("expected to find file_read")
 18  	}
 19  	if tool.Info().Name != "file_read" {
 20  		t.Errorf("expected 'file_read', got %q", tool.Info().Name)
 21  	}
 22  
 23  	_, ok = reg.Get("nonexistent")
 24  	if ok {
 25  		t.Error("expected not found")
 26  	}
 27  }
 28  
 29  func TestToolRegistry_Schemas(t *testing.T) {
 30  	reg := NewToolRegistry()
 31  	reg.Register(&mockTool{name: "file_read"})
 32  	reg.Register(&mockTool{name: "bash"})
 33  
 34  	schemas := reg.Schemas()
 35  	if len(schemas) != 2 {
 36  		t.Errorf("expected 2 schemas, got %d", len(schemas))
 37  	}
 38  }
 39  
 40  type mockTool struct {
 41  	name string
 42  }
 43  
 44  func (m *mockTool) Info() ToolInfo {
 45  	return ToolInfo{
 46  		Name:        m.name,
 47  		Description: "mock tool",
 48  		Parameters:  map[string]any{"type": "object", "properties": map[string]any{}},
 49  	}
 50  }
 51  
 52  func (m *mockTool) Run(ctx context.Context, args string) (ToolResult, error) {
 53  	return ToolResult{Content: "mock result"}, nil
 54  }
 55  
 56  func (m *mockTool) RequiresApproval() bool { return false }
 57  
 58  type mockNativeTool struct {
 59  	name string
 60  }
 61  
 62  func (m *mockNativeTool) Info() ToolInfo {
 63  	return ToolInfo{Name: m.name, Description: "native tool"}
 64  }
 65  func (m *mockNativeTool) Run(ctx context.Context, args string) (ToolResult, error) {
 66  	return ToolResult{Content: "ok"}, nil
 67  }
 68  func (m *mockNativeTool) RequiresApproval() bool { return false }
 69  func (m *mockNativeTool) NativeToolDef() *client.NativeToolDef {
 70  	return &client.NativeToolDef{
 71  		Type:            "computer_20251124",
 72  		Name:            "computer",
 73  		DisplayWidthPx:  1280,
 74  		DisplayHeightPx: 800,
 75  	}
 76  }
 77  
 78  func TestToolRegistry_SchemasIncludesNativeTool(t *testing.T) {
 79  	reg := NewToolRegistry()
 80  	reg.Register(&mockNativeTool{name: "computer"})
 81  	reg.Register(&mockTool{name: "bash"})
 82  
 83  	schemas := reg.Schemas()
 84  	if len(schemas) != 2 {
 85  		t.Fatalf("expected 2 schemas, got %d", len(schemas))
 86  	}
 87  	// Native tool should use its own type
 88  	if schemas[0].Type != "computer_20251124" {
 89  		t.Errorf("expected type 'computer_20251124', got %q", schemas[0].Type)
 90  	}
 91  	if schemas[0].Name != "computer" {
 92  		t.Errorf("expected name 'computer', got %q", schemas[0].Name)
 93  	}
 94  	if schemas[0].DisplayWidthPx != 1280 {
 95  		t.Errorf("expected display_width_px 1280, got %d", schemas[0].DisplayWidthPx)
 96  	}
 97  	// Standard tool should use function type
 98  	if schemas[1].Type != "function" {
 99  		t.Errorf("expected type 'function' for bash, got %q", schemas[1].Type)
100  	}
101  }
102  
103  func TestToolRegistry_Remove(t *testing.T) {
104  	r := NewToolRegistry()
105  	r.Register(&mockTool{name: "a"})
106  	r.Register(&mockTool{name: "b"})
107  	r.Register(&mockTool{name: "c"})
108  
109  	r.Remove("b")
110  
111  	if _, ok := r.Get("b"); ok {
112  		t.Error("b should be removed")
113  	}
114  	if r.Len() != 2 {
115  		t.Errorf("Len() = %d, want 2", r.Len())
116  	}
117  	names := r.Names()
118  	if len(names) != 2 || names[0] != "a" || names[1] != "c" {
119  		t.Errorf("names = %v, want [a c]", names)
120  	}
121  }
122  
123  func TestToolRegistry_RemoveNonexistent(t *testing.T) {
124  	r := NewToolRegistry()
125  	r.Register(&mockTool{name: "a"})
126  	r.Remove("nonexistent") // should not panic
127  	if r.Len() != 1 {
128  		t.Errorf("Len() = %d, want 1", r.Len())
129  	}
130  }
131  
132  func TestToolRegistry_FilterByAllow(t *testing.T) {
133  	r := NewToolRegistry()
134  	r.Register(&mockTool{name: "file_read"})
135  	r.Register(&mockTool{name: "bash"})
136  	r.Register(&mockTool{name: "computer"})
137  	r.Register(&mockTool{name: "browser"})
138  
139  	filtered := r.FilterByAllow([]string{"file_read", "bash"})
140  	if filtered.Len() != 2 {
141  		t.Errorf("filtered Len() = %d, want 2", filtered.Len())
142  	}
143  	if _, ok := filtered.Get("computer"); ok {
144  		t.Error("computer should be filtered out")
145  	}
146  	if _, ok := filtered.Get("file_read"); !ok {
147  		t.Error("file_read should be present")
148  	}
149  }
150  
151  func TestToolRegistry_FilterByDeny(t *testing.T) {
152  	r := NewToolRegistry()
153  	r.Register(&mockTool{name: "file_read"})
154  	r.Register(&mockTool{name: "bash"})
155  	r.Register(&mockTool{name: "computer"})
156  	r.Register(&mockTool{name: "browser"})
157  
158  	filtered := r.FilterByDeny([]string{"computer", "browser"})
159  	if filtered.Len() != 2 {
160  		t.Errorf("filtered Len() = %d, want 2", filtered.Len())
161  	}
162  	if _, ok := filtered.Get("computer"); ok {
163  		t.Error("computer should be denied")
164  	}
165  	if _, ok := filtered.Get("file_read"); !ok {
166  		t.Error("file_read should be present")
167  	}
168  }
169  
170  func TestToolRegistry_CloneIndependence(t *testing.T) {
171  	r := NewToolRegistry()
172  	r.Register(&mockTool{name: "a"})
173  	r.Register(&mockTool{name: "b"})
174  
175  	c := r.Clone()
176  	c.Remove("a")
177  
178  	if _, ok := r.Get("a"); !ok {
179  		t.Error("original should still have 'a'")
180  	}
181  	if c.Len() != 1 {
182  		t.Errorf("clone Len() = %d, want 1", c.Len())
183  	}
184  }
185  
186  func TestToolRegistry_RegisterOverwrite(t *testing.T) {
187  	r := NewToolRegistry()
188  	r.Register(&mockTool{name: "a"})
189  	r.Register(&mockTool{name: "b"})
190  	r.Register(&mockTool{name: "a"}) // overwrite
191  
192  	names := r.Names()
193  	if len(names) != 2 {
194  		t.Errorf("expected 2 names after overwrite, got %d: %v", len(names), names)
195  	}
196  	if r.Len() != 2 {
197  		t.Errorf("Len() = %d, want 2", r.Len())
198  	}
199  	schemas := r.Schemas()
200  	if len(schemas) != 2 {
201  		t.Errorf("expected 2 schemas, got %d", len(schemas))
202  	}
203  }
204  
205  func TestToolRegistry_RemoveAndReRegister(t *testing.T) {
206  	r := NewToolRegistry()
207  	r.Register(&mockTool{name: "a"})
208  	r.Register(&mockTool{name: "b"})
209  	r.Remove("a")
210  	r.Register(&mockTool{name: "a"})
211  
212  	names := r.Names()
213  	if len(names) != 2 {
214  		t.Errorf("expected 2 names, got %d: %v", len(names), names)
215  	}
216  	schemas := r.Schemas()
217  	if len(schemas) != 2 {
218  		t.Errorf("expected 2 schemas, got %d", len(schemas))
219  	}
220  }
221  
222  func TestToolResultErrorHelpers(t *testing.T) {
223  	tests := []struct {
224  		name        string
225  		result      ToolResult
226  		wantIsError bool
227  		wantCat     ErrorCategory
228  		wantRetry   bool
229  		wantPrefix  string
230  	}{
231  		{
232  			name:        "TransientError",
233  			result:      TransientError("connection timed out"),
234  			wantIsError: true,
235  			wantCat:     ErrCategoryTransient,
236  			wantRetry:   true,
237  			wantPrefix:  "[transient error]",
238  		},
239  		{
240  			name:        "ValidationError",
241  			result:      ValidationError("invalid URL format"),
242  			wantIsError: true,
243  			wantCat:     ErrCategoryValidation,
244  			wantRetry:   false,
245  			wantPrefix:  "[validation error]",
246  		},
247  		{
248  			name:        "BusinessError",
249  			result:      BusinessError("refund exceeds policy limit"),
250  			wantIsError: true,
251  			wantCat:     ErrCategoryBusiness,
252  			wantRetry:   false,
253  			wantPrefix:  "[business error]",
254  		},
255  		{
256  			name:        "PermissionError",
257  			result:      PermissionError("access denied"),
258  			wantIsError: true,
259  			wantCat:     ErrCategoryPermission,
260  			wantRetry:   false,
261  			wantPrefix:  "[permission error]",
262  		},
263  	}
264  
265  	for _, tt := range tests {
266  		t.Run(tt.name, func(t *testing.T) {
267  			if tt.result.IsError != tt.wantIsError {
268  				t.Errorf("IsError = %v, want %v", tt.result.IsError, tt.wantIsError)
269  			}
270  			if tt.result.ErrorCategory != tt.wantCat {
271  				t.Errorf("ErrorCategory = %q, want %q", tt.result.ErrorCategory, tt.wantCat)
272  			}
273  			if tt.result.IsRetryable != tt.wantRetry {
274  				t.Errorf("IsRetryable = %v, want %v", tt.result.IsRetryable, tt.wantRetry)
275  			}
276  			if !strings.HasPrefix(tt.result.Content, tt.wantPrefix) {
277  				t.Errorf("Content = %q, want prefix %q", tt.result.Content, tt.wantPrefix)
278  			}
279  		})
280  	}
281  }
282  
283  func TestToolResult_ZeroValueNotError(t *testing.T) {
284  	r := ToolResult{Content: "some output"}
285  	if r.IsError {
286  		t.Error("zero-value ToolResult must not be an error")
287  	}
288  	if r.ErrorCategory != "" {
289  		t.Errorf("zero-value ErrorCategory must be empty, got %q", r.ErrorCategory)
290  	}
291  	if r.IsRetryable {
292  		t.Error("zero-value IsRetryable must be false")
293  	}
294  }
295  
296  func TestToolResult_ImagesField(t *testing.T) {
297  	result := ToolResult{
298  		Content: "Screenshot captured",
299  		IsError: false,
300  		Images: []ImageBlock{
301  			{MediaType: "image/png", Data: "iVBORfakedata"},
302  		},
303  	}
304  	if len(result.Images) != 1 {
305  		t.Errorf("expected 1 image, got %d", len(result.Images))
306  	}
307  	if result.Images[0].MediaType != "image/png" {
308  		t.Errorf("expected image/png, got %s", result.Images[0].MediaType)
309  	}
310  }
311  
312  // mockSourcedTool is a mock tool that implements ToolSourcer.
313  type mockSourcedTool struct {
314  	name   string
315  	source ToolSource
316  }
317  
318  func (m *mockSourcedTool) Info() ToolInfo {
319  	return ToolInfo{
320  		Name:        m.name,
321  		Description: "mock sourced tool",
322  		Parameters:  map[string]any{"type": "object", "properties": map[string]any{}},
323  	}
324  }
325  func (m *mockSourcedTool) Run(ctx context.Context, args string) (ToolResult, error) {
326  	return ToolResult{Content: "ok"}, nil
327  }
328  func (m *mockSourcedTool) RequiresApproval() bool  { return false }
329  func (m *mockSourcedTool) ToolSource() ToolSource { return m.source }
330  
331  func TestToolRegistry_SortedSchemas(t *testing.T) {
332  	r := NewToolRegistry()
333  	// Register in non-alphabetical, mixed-source order
334  	r.Register(&mockTool{name: "grep"})                                      // local
335  	r.Register(&mockSourcedTool{name: "browser_click", source: SourceMCP})   // mcp
336  	r.Register(&mockSourcedTool{name: "web_search", source: SourceGateway})  // gateway
337  	r.Register(&mockTool{name: "bash"})                                      // local
338  	r.Register(&mockSourcedTool{name: "browser_type", source: SourceMCP})    // mcp
339  	r.Register(&mockSourcedTool{name: "alpaca_news", source: SourceGateway}) // gateway
340  	r.Register(&mockTool{name: "file_read"})                                 // local
341  
342  	schemas := r.SortedSchemas()
343  	var names []string
344  	for _, s := range schemas {
345  		names = append(names, s.Function.Name)
346  	}
347  
348  	expected := []string{
349  		// local alpha
350  		"bash", "file_read", "grep",
351  		// mcp alpha
352  		"browser_click", "browser_type",
353  		// gateway alpha
354  		"alpaca_news", "web_search",
355  	}
356  	if len(names) != len(expected) {
357  		t.Fatalf("got %d schemas, want %d: %v", len(names), len(expected), names)
358  	}
359  	for i, want := range expected {
360  		if names[i] != want {
361  			t.Errorf("position %d: got %q, want %q (full: %v)", i, names[i], want, names)
362  			break
363  		}
364  	}
365  }
366  
367  func TestToolRegistry_SortedNames(t *testing.T) {
368  	r := NewToolRegistry()
369  	r.Register(&mockTool{name: "grep"})
370  	r.Register(&mockSourcedTool{name: "browser_click", source: SourceMCP})
371  	r.Register(&mockTool{name: "bash"})
372  
373  	names := r.SortedNames()
374  	expected := []string{"bash", "grep", "browser_click"}
375  	if len(names) != len(expected) {
376  		t.Fatalf("got %v, want %v", names, expected)
377  	}
378  	for i, want := range expected {
379  		if names[i] != want {
380  			t.Errorf("position %d: got %q, want %q", i, names[i], want)
381  		}
382  	}
383  }
384  
385  func TestToolRegistry_SortedSchemas_MCPAdditionDoesNotShiftLocal(t *testing.T) {
386  	r := NewToolRegistry()
387  	r.Register(&mockTool{name: "grep"})
388  	r.Register(&mockTool{name: "bash"})
389  
390  	before := r.SortedNames()
391  
392  	r.Register(&mockSourcedTool{name: "browser_navigate", source: SourceMCP})
393  
394  	after := r.SortedNames()
395  	// Local tools should still be in positions 0 and 1 with same order
396  	for i := 0; i < 2; i++ {
397  		if before[i] != after[i] {
398  			t.Errorf("local tool shifted: position %d was %q, now %q", i, before[i], after[i])
399  		}
400  	}
401  }
402  
403  func TestToolRegistry_SummaryList(t *testing.T) {
404  	reg := NewToolRegistry()
405  	reg.Register(&mockTool{name: "bash"})
406  	reg.Register(&mockTool{name: "file_read"})
407  
408  	summaries := reg.SummaryList()
409  	if len(summaries) != 2 {
410  		t.Fatalf("expected 2 summaries, got %d", len(summaries))
411  	}
412  	for _, s := range summaries {
413  		if s.Name == "" {
414  			t.Error("summary name is empty")
415  		}
416  		if s.Description == "" {
417  			t.Error("summary description is empty")
418  		}
419  	}
420  }
421  
422  func TestToolRegistry_FullSchemas(t *testing.T) {
423  	reg := NewToolRegistry()
424  	reg.Register(&mockTool{name: "bash"})
425  	reg.Register(&mockTool{name: "file_read"})
426  	reg.Register(&mockTool{name: "grep"})
427  
428  	schemas := reg.FullSchemas([]string{"bash", "file_read"})
429  	if len(schemas) != 2 {
430  		t.Fatalf("expected 2 schemas, got %d", len(schemas))
431  	}
432  	names := map[string]bool{}
433  	for _, s := range schemas {
434  		names[s.Function.Name] = true
435  	}
436  	if !names["bash"] || !names["file_read"] {
437  		t.Errorf("expected bash and file_read, got %v", names)
438  	}
439  }
440  
441  func TestToolRegistry_FullSchemas_Nonexistent(t *testing.T) {
442  	reg := NewToolRegistry()
443  	reg.Register(&mockTool{name: "bash"})
444  
445  	schemas := reg.FullSchemas([]string{"nonexistent"})
446  	if len(schemas) != 0 {
447  		t.Fatalf("expected 0 schemas for nonexistent tool, got %d", len(schemas))
448  	}
449  }
450  
451  func TestTurnUsage_CacheTelemetry(t *testing.T) {
452  	u := &TurnUsage{}
453  
454  	// Turn 1: cache creation (first turn always creates, no reads)
455  	u.Add(client.Usage{InputTokens: 5000, CacheCreationTokens: 4000, CacheReadTokens: 0})
456  	if !u.cacheCapable {
457  		t.Error("should be cache-capable after seeing CacheCreationTokens > 0")
458  	}
459  	if u.cacheMissStreak != 0 {
460  		t.Errorf("first turn should not count as miss, got streak %d", u.cacheMissStreak)
461  	}
462  
463  	// Turn 2: cache hit
464  	u.Add(client.Usage{InputTokens: 5000, CacheReadTokens: 3500})
465  	if u.cacheMissStreak != 0 {
466  		t.Errorf("cache hit should reset streak, got %d", u.cacheMissStreak)
467  	}
468  
469  	// Turns 3-5: cache misses
470  	for i := 0; i < 3; i++ {
471  		u.Add(client.Usage{InputTokens: 5000, CacheReadTokens: 0})
472  	}
473  	if u.cacheMissStreak != 3 {
474  		t.Errorf("expected miss streak 3, got %d", u.cacheMissStreak)
475  	}
476  }
477  
478  func TestTurnUsage_CacheTelemetry_NonCacheProvider(t *testing.T) {
479  	u := &TurnUsage{}
480  
481  	// Provider never returns cache tokens — should not flag as cache-capable
482  	for i := 0; i < 5; i++ {
483  		u.Add(client.Usage{InputTokens: 5000})
484  	}
485  	if u.cacheCapable {
486  		t.Error("should not be cache-capable when provider never returns cache tokens")
487  	}
488  	if u.cacheMissStreak != 0 {
489  		t.Errorf("non-cache provider should not accumulate miss streak, got %d", u.cacheMissStreak)
490  	}
491  }