/ internal / agent / partition_test.go
partition_test.go
  1  package agent
  2  
  3  import (
  4  	"context"
  5  	"testing"
  6  
  7  	"github.com/Kocoro-lab/ShanClaw/internal/client"
  8  )
  9  
 10  // readOnlyStub always classifies as read-only.
 11  type readOnlyStub struct{ name string }
 12  
 13  func (t *readOnlyStub) Info() ToolInfo                                  { return ToolInfo{Name: t.name} }
 14  func (t *readOnlyStub) Run(context.Context, string) (ToolResult, error) { return ToolResult{}, nil }
 15  func (t *readOnlyStub) RequiresApproval() bool                          { return false }
 16  func (t *readOnlyStub) IsReadOnlyCall(string) bool                      { return true }
 17  
 18  // writeStub always classifies as non-read-only.
 19  type writeStub struct{ name string }
 20  
 21  func (t *writeStub) Info() ToolInfo                                  { return ToolInfo{Name: t.name} }
 22  func (t *writeStub) Run(context.Context, string) (ToolResult, error) { return ToolResult{}, nil }
 23  func (t *writeStub) RequiresApproval() bool                          { return false }
 24  func (t *writeStub) IsReadOnlyCall(string) bool                      { return false }
 25  
 26  // plainStub does NOT implement ReadOnlyChecker — should default to non-read-only.
 27  type plainStub struct{ name string }
 28  
 29  func (t *plainStub) Info() ToolInfo                                  { return ToolInfo{Name: t.name} }
 30  func (t *plainStub) Run(context.Context, string) (ToolResult, error) { return ToolResult{}, nil }
 31  func (t *plainStub) RequiresApproval() bool                          { return false }
 32  
 33  func ac(tool Tool, index int) approvedToolCall {
 34  	return approvedToolCall{
 35  		index:   index,
 36  		fc:      client.FunctionCall{Name: tool.Info().Name},
 37  		tool:    tool,
 38  		argsStr: "{}",
 39  	}
 40  }
 41  
 42  func TestPartition_MixedReadWrite(t *testing.T) {
 43  	r := &readOnlyStub{name: "r"}
 44  	w := &writeStub{name: "w"}
 45  	batches := partitionToolCalls([]approvedToolCall{ac(r, 0), ac(r, 1), ac(w, 2), ac(r, 3)})
 46  	if len(batches) != 3 {
 47  		t.Fatalf("expected 3 batches, got %d", len(batches))
 48  	}
 49  	if len(batches[0]) != 2 {
 50  		t.Errorf("batch 0: expected 2 calls, got %d", len(batches[0]))
 51  	}
 52  	if len(batches[1]) != 1 {
 53  		t.Errorf("batch 1: expected 1 call, got %d", len(batches[1]))
 54  	}
 55  	if len(batches[2]) != 1 {
 56  		t.Errorf("batch 2: expected 1 call, got %d", len(batches[2]))
 57  	}
 58  }
 59  
 60  func TestPartition_AllWrites(t *testing.T) {
 61  	w := &writeStub{name: "w"}
 62  	batches := partitionToolCalls([]approvedToolCall{ac(w, 0), ac(w, 1)})
 63  	if len(batches) != 2 {
 64  		t.Fatalf("expected 2 batches, got %d", len(batches))
 65  	}
 66  }
 67  
 68  func TestPartition_AllReads(t *testing.T) {
 69  	r := &readOnlyStub{name: "r"}
 70  	batches := partitionToolCalls([]approvedToolCall{ac(r, 0), ac(r, 1), ac(r, 2)})
 71  	if len(batches) != 1 {
 72  		t.Fatalf("expected 1 batch, got %d", len(batches))
 73  	}
 74  	if len(batches[0]) != 3 {
 75  		t.Errorf("expected 3 calls, got %d", len(batches[0]))
 76  	}
 77  }
 78  
 79  func TestPartition_SingleWrite(t *testing.T) {
 80  	w := &writeStub{name: "w"}
 81  	batches := partitionToolCalls([]approvedToolCall{ac(w, 0)})
 82  	if len(batches) != 1 || len(batches[0]) != 1 {
 83  		t.Fatalf("expected [[w]], got %v", batches)
 84  	}
 85  }
 86  
 87  func TestPartition_NoReadOnlyChecker_TreatedAsWrite(t *testing.T) {
 88  	p := &plainStub{name: "mcp_tool"}
 89  	r := &readOnlyStub{name: "r"}
 90  	batches := partitionToolCalls([]approvedToolCall{ac(r, 0), ac(p, 1), ac(r, 2)})
 91  	if len(batches) != 3 {
 92  		t.Fatalf("expected 3 batches (plain treated as write), got %d", len(batches))
 93  	}
 94  }
 95  
 96  func TestPartition_Empty(t *testing.T) {
 97  	batches := partitionToolCalls(nil)
 98  	if len(batches) != 0 {
 99  		t.Fatalf("expected 0 batches, got %d", len(batches))
100  	}
101  }