partition.go
1 package agent 2 3 import ( 4 "context" 5 "fmt" 6 "sync" 7 "time" 8 ) 9 10 const maxToolConcurrency = 10 11 12 // isReadOnly checks if a tool call is read-only by testing the ReadOnlyChecker 13 // optional interface. Tools without the interface default to false (fail-closed). 14 func isReadOnly(ac approvedToolCall) bool { 15 checker, ok := ac.tool.(ReadOnlyChecker) 16 if !ok { 17 return false 18 } 19 return checker.IsReadOnlyCall(ac.argsStr) 20 } 21 22 // partitionToolCalls groups approved tool calls into execution batches. 23 // Consecutive read-only calls are grouped into a single concurrent batch. 24 // Non-read-only calls each get their own sequential batch of size 1. 25 func partitionToolCalls(approved []approvedToolCall) [][]approvedToolCall { 26 if len(approved) == 0 { 27 return nil 28 } 29 var batches [][]approvedToolCall 30 var currentBatch []approvedToolCall 31 currentIsReadOnly := false 32 33 for i, ac := range approved { 34 ro := isReadOnly(ac) 35 if i == 0 { 36 currentBatch = []approvedToolCall{ac} 37 currentIsReadOnly = ro 38 continue 39 } 40 if ro && currentIsReadOnly { 41 currentBatch = append(currentBatch, ac) 42 } else { 43 batches = append(batches, currentBatch) 44 currentBatch = []approvedToolCall{ac} 45 currentIsReadOnly = ro 46 } 47 } 48 if len(currentBatch) > 0 { 49 batches = append(batches, currentBatch) 50 } 51 return batches 52 } 53 54 // executeBatches runs partitioned tool call batches sequentially. 55 // Read-only batches (len > 1) run concurrently with a channel semaphore 56 // capped at maxToolConcurrency. Write batches (len == 1) run directly. 57 // After each batch, readTracker is updated for successful file_read calls 58 // so that subsequent write batches can pass read-before-edit checks. 59 // If handler is non-nil, OnToolCall is fired for each call immediately 60 // before execution begins (so "running" status reflects actual execution). 61 func executeBatches(ctx context.Context, batches [][]approvedToolCall, execResults []toolExecResult, readTracker *ReadTracker, handler EventHandler) { 62 // Attach the handler's OnUsage as the per-run usage emitter so tools 63 // that bill per call (gateway tools reporting xAI/Grok or SerpAPI costs) 64 // can fold their usage into the session totals. 65 if handler != nil { 66 ctx = WithUsageEmit(ctx, handler.OnUsage) 67 } 68 for _, batch := range batches { 69 if len(batch) == 1 { 70 // Single call: run directly, no goroutine overhead. 71 ac := batch[0] 72 func() { 73 defer func() { 74 if r := recover(); r != nil { 75 execResults[ac.index] = toolExecResult{ 76 result: ToolResult{Content: fmt.Sprintf("tool panicked: %v", r), IsError: true}, 77 } 78 } 79 }() 80 if handler != nil { 81 handler.OnToolCall(ac.fc.Name, ac.argsStr) 82 } 83 startTime := time.Now() 84 result, runErr := ac.tool.Run(ctx, ac.argsStr) 85 execResults[ac.index] = toolExecResult{result: result, elapsed: time.Since(startTime), err: runErr} 86 }() 87 } else { 88 // Concurrent batch with semaphore. 89 sem := make(chan struct{}, maxToolConcurrency) 90 var wg sync.WaitGroup 91 wg.Add(len(batch)) 92 for _, ac := range batch { 93 sem <- struct{}{} // acquire — blocks until a slot is free 94 // Emit "running" after acquiring the slot so the event reflects 95 // actual execution start, not just batch membership. Called from 96 // the main goroutine so handler writes stay serialized. 97 if handler != nil { 98 handler.OnToolCall(ac.fc.Name, ac.argsStr) 99 } 100 go func(ac approvedToolCall) { 101 defer wg.Done() 102 defer func() { <-sem }() // release 103 defer func() { 104 if r := recover(); r != nil { 105 execResults[ac.index] = toolExecResult{ 106 result: ToolResult{Content: fmt.Sprintf("tool panicked: %v", r), IsError: true}, 107 } 108 } 109 }() 110 startTime := time.Now() 111 result, runErr := ac.tool.Run(ctx, ac.argsStr) 112 execResults[ac.index] = toolExecResult{result: result, elapsed: time.Since(startTime), err: runErr} 113 }(ac) 114 } 115 wg.Wait() 116 } 117 118 // Inter-batch side effect: track file_read results for ReadTracker. 119 if readTracker != nil { 120 for _, ac := range batch { 121 if ac.fc.Name == "file_read" { 122 er := execResults[ac.index] 123 if !er.result.IsError && er.err == nil { 124 if p := extractPathArg(ac.argsStr); p != "" { 125 readTracker.MarkRead(p) 126 } 127 } 128 } 129 } 130 } 131 } 132 }