// internal/agent/partition.go
  1  package agent
  2  
  3  import (
  4  	"context"
  5  	"fmt"
  6  	"sync"
  7  	"time"
  8  )
  9  
 10  const maxToolConcurrency = 10
 11  
 12  // isReadOnly checks if a tool call is read-only by testing the ReadOnlyChecker
 13  // optional interface. Tools without the interface default to false (fail-closed).
 14  func isReadOnly(ac approvedToolCall) bool {
 15  	checker, ok := ac.tool.(ReadOnlyChecker)
 16  	if !ok {
 17  		return false
 18  	}
 19  	return checker.IsReadOnlyCall(ac.argsStr)
 20  }
 21  
 22  // partitionToolCalls groups approved tool calls into execution batches.
 23  // Consecutive read-only calls are grouped into a single concurrent batch.
 24  // Non-read-only calls each get their own sequential batch of size 1.
 25  func partitionToolCalls(approved []approvedToolCall) [][]approvedToolCall {
 26  	if len(approved) == 0 {
 27  		return nil
 28  	}
 29  	var batches [][]approvedToolCall
 30  	var currentBatch []approvedToolCall
 31  	currentIsReadOnly := false
 32  
 33  	for i, ac := range approved {
 34  		ro := isReadOnly(ac)
 35  		if i == 0 {
 36  			currentBatch = []approvedToolCall{ac}
 37  			currentIsReadOnly = ro
 38  			continue
 39  		}
 40  		if ro && currentIsReadOnly {
 41  			currentBatch = append(currentBatch, ac)
 42  		} else {
 43  			batches = append(batches, currentBatch)
 44  			currentBatch = []approvedToolCall{ac}
 45  			currentIsReadOnly = ro
 46  		}
 47  	}
 48  	if len(currentBatch) > 0 {
 49  		batches = append(batches, currentBatch)
 50  	}
 51  	return batches
 52  }
 53  
 54  // executeBatches runs partitioned tool call batches sequentially.
 55  // Read-only batches (len > 1) run concurrently with a channel semaphore
 56  // capped at maxToolConcurrency. Write batches (len == 1) run directly.
 57  // After each batch, readTracker is updated for successful file_read calls
 58  // so that subsequent write batches can pass read-before-edit checks.
 59  // If handler is non-nil, OnToolCall is fired for each call immediately
 60  // before execution begins (so "running" status reflects actual execution).
 61  func executeBatches(ctx context.Context, batches [][]approvedToolCall, execResults []toolExecResult, readTracker *ReadTracker, handler EventHandler) {
 62  	// Attach the handler's OnUsage as the per-run usage emitter so tools
 63  	// that bill per call (gateway tools reporting xAI/Grok or SerpAPI costs)
 64  	// can fold their usage into the session totals.
 65  	if handler != nil {
 66  		ctx = WithUsageEmit(ctx, handler.OnUsage)
 67  	}
 68  	for _, batch := range batches {
 69  		if len(batch) == 1 {
 70  			// Single call: run directly, no goroutine overhead.
 71  			ac := batch[0]
 72  			func() {
 73  				defer func() {
 74  					if r := recover(); r != nil {
 75  						execResults[ac.index] = toolExecResult{
 76  							result: ToolResult{Content: fmt.Sprintf("tool panicked: %v", r), IsError: true},
 77  						}
 78  					}
 79  				}()
 80  				if handler != nil {
 81  					handler.OnToolCall(ac.fc.Name, ac.argsStr)
 82  				}
 83  				startTime := time.Now()
 84  				result, runErr := ac.tool.Run(ctx, ac.argsStr)
 85  				execResults[ac.index] = toolExecResult{result: result, elapsed: time.Since(startTime), err: runErr}
 86  			}()
 87  		} else {
 88  			// Concurrent batch with semaphore.
 89  			sem := make(chan struct{}, maxToolConcurrency)
 90  			var wg sync.WaitGroup
 91  			wg.Add(len(batch))
 92  			for _, ac := range batch {
 93  				sem <- struct{}{} // acquire — blocks until a slot is free
 94  				// Emit "running" after acquiring the slot so the event reflects
 95  				// actual execution start, not just batch membership. Called from
 96  				// the main goroutine so handler writes stay serialized.
 97  				if handler != nil {
 98  					handler.OnToolCall(ac.fc.Name, ac.argsStr)
 99  				}
100  				go func(ac approvedToolCall) {
101  					defer wg.Done()
102  					defer func() { <-sem }() // release
103  					defer func() {
104  						if r := recover(); r != nil {
105  							execResults[ac.index] = toolExecResult{
106  								result: ToolResult{Content: fmt.Sprintf("tool panicked: %v", r), IsError: true},
107  							}
108  						}
109  					}()
110  					startTime := time.Now()
111  					result, runErr := ac.tool.Run(ctx, ac.argsStr)
112  					execResults[ac.index] = toolExecResult{result: result, elapsed: time.Since(startTime), err: runErr}
113  				}(ac)
114  			}
115  			wg.Wait()
116  		}
117  
118  		// Inter-batch side effect: track file_read results for ReadTracker.
119  		if readTracker != nil {
120  			for _, ac := range batch {
121  				if ac.fc.Name == "file_read" {
122  					er := execResults[ac.index]
123  					if !er.result.IsError && er.err == nil {
124  						if p := extractPathArg(ac.argsStr); p != "" {
125  							readTracker.MarkRead(p)
126  						}
127  					}
128  				}
129  			}
130  		}
131  	}
132  }