partition_test.go
1 package agent 2 3 import ( 4 "context" 5 "testing" 6 7 "github.com/Kocoro-lab/ShanClaw/internal/client" 8 ) 9 10 // readOnlyStub always classifies as read-only. 11 type readOnlyStub struct{ name string } 12 13 func (t *readOnlyStub) Info() ToolInfo { return ToolInfo{Name: t.name} } 14 func (t *readOnlyStub) Run(context.Context, string) (ToolResult, error) { return ToolResult{}, nil } 15 func (t *readOnlyStub) RequiresApproval() bool { return false } 16 func (t *readOnlyStub) IsReadOnlyCall(string) bool { return true } 17 18 // writeStub always classifies as non-read-only. 19 type writeStub struct{ name string } 20 21 func (t *writeStub) Info() ToolInfo { return ToolInfo{Name: t.name} } 22 func (t *writeStub) Run(context.Context, string) (ToolResult, error) { return ToolResult{}, nil } 23 func (t *writeStub) RequiresApproval() bool { return false } 24 func (t *writeStub) IsReadOnlyCall(string) bool { return false } 25 26 // plainStub does NOT implement ReadOnlyChecker — should default to non-read-only. 27 type plainStub struct{ name string } 28 29 func (t *plainStub) Info() ToolInfo { return ToolInfo{Name: t.name} } 30 func (t *plainStub) Run(context.Context, string) (ToolResult, error) { return ToolResult{}, nil } 31 func (t *plainStub) RequiresApproval() bool { return false } 32 33 func ac(tool Tool, index int) approvedToolCall { 34 return approvedToolCall{ 35 index: index, 36 fc: client.FunctionCall{Name: tool.Info().Name}, 37 tool: tool, 38 argsStr: "{}", 39 } 40 } 41 42 func TestPartition_MixedReadWrite(t *testing.T) { 43 r := &readOnlyStub{name: "r"} 44 w := &writeStub{name: "w"} 45 batches := partitionToolCalls([]approvedToolCall{ac(r, 0), ac(r, 1), ac(w, 2), ac(r, 3)}) 46 if len(batches) != 3 { 47 t.Fatalf("expected 3 batches, got %d", len(batches)) 48 } 49 if len(batches[0]) != 2 { 50 t.Errorf("batch 0: expected 2 calls, got %d", len(batches[0])) 51 } 52 if len(batches[1]) != 1 { 53 t.Errorf("batch 1: expected 1 call, got %d", len(batches[1])) 54 } 55 if len(batches[2]) != 1 { 56 t.Errorf("batch 2: expected 1 call, got %d", len(batches[2])) 57 } 58 } 59 60 func TestPartition_AllWrites(t *testing.T) { 61 w := &writeStub{name: "w"} 62 batches := partitionToolCalls([]approvedToolCall{ac(w, 0), ac(w, 1)}) 63 if len(batches) != 2 { 64 t.Fatalf("expected 2 batches, got %d", len(batches)) 65 } 66 } 67 68 func TestPartition_AllReads(t *testing.T) { 69 r := &readOnlyStub{name: "r"} 70 batches := partitionToolCalls([]approvedToolCall{ac(r, 0), ac(r, 1), ac(r, 2)}) 71 if len(batches) != 1 { 72 t.Fatalf("expected 1 batch, got %d", len(batches)) 73 } 74 if len(batches[0]) != 3 { 75 t.Errorf("expected 3 calls, got %d", len(batches[0])) 76 } 77 } 78 79 func TestPartition_SingleWrite(t *testing.T) { 80 w := &writeStub{name: "w"} 81 batches := partitionToolCalls([]approvedToolCall{ac(w, 0)}) 82 if len(batches) != 1 || len(batches[0]) != 1 { 83 t.Fatalf("expected [[w]], got %v", batches) 84 } 85 } 86 87 func TestPartition_NoReadOnlyChecker_TreatedAsWrite(t *testing.T) { 88 p := &plainStub{name: "mcp_tool"} 89 r := &readOnlyStub{name: "r"} 90 batches := partitionToolCalls([]approvedToolCall{ac(r, 0), ac(p, 1), ac(r, 2)}) 91 if len(batches) != 3 { 92 t.Fatalf("expected 3 batches (plain treated as write), got %d", len(batches)) 93 } 94 } 95 96 func TestPartition_Empty(t *testing.T) { 97 batches := partitionToolCalls(nil) 98 if len(batches) != 0 { 99 t.Fatalf("expected 0 batches, got %d", len(batches)) 100 } 101 }