partition_concurrency_test.go
1 package agent 2 3 import ( 4 "context" 5 "os" 6 "sync/atomic" 7 "testing" 8 "time" 9 10 "github.com/Kocoro-lab/ShanClaw/internal/client" 11 ) 12 13 // slowReadTool sleeps briefly and tracks max in-flight concurrency. 14 type slowReadTool struct { 15 inflight *atomic.Int32 16 maxSeen *atomic.Int32 17 } 18 19 func (t *slowReadTool) Info() ToolInfo { return ToolInfo{Name: "slow_read"} } 20 func (t *slowReadTool) RequiresApproval() bool { return false } 21 func (t *slowReadTool) IsReadOnlyCall(string) bool { return true } 22 func (t *slowReadTool) Run(_ context.Context, _ string) (ToolResult, error) { 23 cur := t.inflight.Add(1) 24 for { 25 old := t.maxSeen.Load() 26 if cur <= old || t.maxSeen.CompareAndSwap(old, cur) { 27 break 28 } 29 } 30 time.Sleep(50 * time.Millisecond) 31 t.inflight.Add(-1) 32 return ToolResult{Content: "ok"}, nil 33 } 34 35 func TestExecuteBatches_ConcurrencyLimit(t *testing.T) { 36 inflight := &atomic.Int32{} 37 maxSeen := &atomic.Int32{} 38 tool := &slowReadTool{inflight: inflight, maxSeen: maxSeen} 39 40 // 15 read-only calls — should never exceed maxToolConcurrency (10). 41 var approved []approvedToolCall 42 for i := 0; i < 15; i++ { 43 approved = append(approved, approvedToolCall{ 44 index: i, 45 fc: client.FunctionCall{Name: "slow_read"}, 46 tool: tool, 47 argsStr: "{}", 48 }) 49 } 50 51 execResults := make([]toolExecResult, 15) 52 batches := partitionToolCalls(approved) 53 executeBatches(context.Background(), batches, execResults, nil, nil) 54 55 if maxSeen.Load() > int32(maxToolConcurrency) { 56 t.Errorf("max concurrent = %d, want <= %d", maxSeen.Load(), maxToolConcurrency) 57 } 58 for i, er := range execResults { 59 if er.result.Content != "ok" { 60 t.Errorf("result[%d]: expected 'ok', got %q", i, er.result.Content) 61 } 62 } 63 } 64 65 // panicReadTool panics during Run. 66 type panicReadTool struct{} 67 68 func (t *panicReadTool) Info() ToolInfo { return ToolInfo{Name: "panic_read"} } 69 func (t *panicReadTool) RequiresApproval() bool { return false } 70 func (t *panicReadTool) IsReadOnlyCall(string) bool { return true } 71 func (t *panicReadTool) Run(context.Context, string) (ToolResult, error) { 72 panic("deliberate panic in tool") 73 } 74 75 func TestExecuteBatches_PanicRecovery(t *testing.T) { 76 normal := &readOnlyStub{name: "normal"} 77 panicker := &panicReadTool{} 78 79 approved := []approvedToolCall{ 80 {index: 0, fc: client.FunctionCall{Name: "normal"}, tool: normal, argsStr: "{}"}, 81 {index: 1, fc: client.FunctionCall{Name: "panic_read"}, tool: panicker, argsStr: "{}"}, 82 {index: 2, fc: client.FunctionCall{Name: "normal"}, tool: normal, argsStr: "{}"}, 83 } 84 85 execResults := make([]toolExecResult, 3) 86 batches := partitionToolCalls(approved) 87 executeBatches(context.Background(), batches, execResults, nil, nil) 88 89 // Normal tools should succeed. 90 if execResults[0].result.IsError { 91 t.Errorf("result[0]: expected success, got error: %s", execResults[0].result.Content) 92 } 93 // Panicking tool should have error result. 94 if !execResults[1].result.IsError { 95 t.Error("result[1]: expected error from panic, got success") 96 } 97 if execResults[2].result.IsError { 98 t.Errorf("result[2]: expected success, got error: %s", execResults[2].result.Content) 99 } 100 } 101 102 func TestExecuteBatches_ResultOrdering(t *testing.T) { 103 r := &readOnlyStub{name: "r"} 104 w := &writeStub{name: "w"} 105 106 approved := []approvedToolCall{ 107 {index: 0, fc: client.FunctionCall{Name: "r"}, tool: r, argsStr: "{}"}, 108 {index: 1, fc: client.FunctionCall{Name: "r"}, tool: r, argsStr: "{}"}, 109 {index: 2, fc: client.FunctionCall{Name: "w"}, tool: w, argsStr: "{}"}, 110 {index: 3, fc: client.FunctionCall{Name: "r"}, tool: r, argsStr: "{}"}, 111 } 112 113 execResults := make([]toolExecResult, 4) 114 batches := partitionToolCalls(approved) 115 executeBatches(context.Background(), batches, execResults, nil, nil) 116 117 // Verify all results are populated (not default zero values). 118 for i, er := range execResults { 119 if er.result.Content == "" && !er.result.IsError && er.err == nil { 120 // readOnlyStub and writeStub return empty ToolResult, which is valid. 121 // Just ensure the execution actually ran by checking err is nil. 122 _ = er 123 } 124 _ = i 125 } 126 // The key invariant: results are at their original indices. 127 // Batch 0 (reads): indices 0, 1 128 // Batch 1 (write): index 2 129 // Batch 2 (read): index 3 130 if len(batches) != 3 { 131 t.Fatalf("expected 3 batches, got %d", len(batches)) 132 } 133 } 134 135 func TestExecuteBatches_ReadTrackerInterBatch(t *testing.T) { 136 rt := NewReadTracker() 137 tmpDir := t.TempDir() 138 filePath := tmpDir + "/test.txt" 139 os.WriteFile(filePath, []byte("hello"), 0644) 140 141 // file_read is read-only -> batch 1; file_edit is write -> batch 2 142 readTool := &readOnlyStub{name: "file_read"} 143 editTool := &writeStub{name: "file_edit"} 144 145 argsJSON := `{"path":"` + filePath + `"}` 146 approved := []approvedToolCall{ 147 {index: 0, fc: client.FunctionCall{Name: "file_read"}, tool: readTool, argsStr: argsJSON}, 148 {index: 1, fc: client.FunctionCall{Name: "file_edit"}, tool: editTool, argsStr: argsJSON}, 149 } 150 151 execResults := make([]toolExecResult, 2) 152 batches := partitionToolCalls(approved) 153 154 if len(batches) != 2 { 155 t.Fatalf("expected 2 batches, got %d", len(batches)) 156 } 157 158 executeBatches(context.Background(), batches, execResults, rt, nil) 159 160 if !rt.HasRead(filePath) { 161 t.Error("ReadTracker should have marked file as read between batches") 162 } 163 }