/ services / tools / StreamingToolExecutor.ts
StreamingToolExecutor.ts
  1  import type { ToolUseBlock } from '@anthropic-ai/sdk/resources/index.mjs'
  2  import {
  3    createUserMessage,
  4    REJECT_MESSAGE,
  5    withMemoryCorrectionHint,
  6  } from 'src/utils/messages.js'
  7  import type { CanUseToolFn } from '../../hooks/useCanUseTool.js'
  8  import { findToolByName, type Tools, type ToolUseContext } from '../../Tool.js'
  9  import { BASH_TOOL_NAME } from '../../tools/BashTool/toolName.js'
 10  import type { AssistantMessage, Message } from '../../types/message.js'
 11  import { createChildAbortController } from '../../utils/abortController.js'
 12  import { runToolUse } from './toolExecution.js'
 13  
 14  type MessageUpdate = {
 15    message?: Message
 16    newContext?: ToolUseContext
 17  }
 18  
 19  type ToolStatus = 'queued' | 'executing' | 'completed' | 'yielded'
 20  
 21  type TrackedTool = {
 22    id: string
 23    block: ToolUseBlock
 24    assistantMessage: AssistantMessage
 25    status: ToolStatus
 26    isConcurrencySafe: boolean
 27    promise?: Promise<void>
 28    results?: Message[]
 29    // Progress messages are stored separately and yielded immediately
 30    pendingProgress: Message[]
 31    contextModifiers?: Array<(context: ToolUseContext) => ToolUseContext>
 32  }
 33  
 34  /**
 35   * Executes tools as they stream in with concurrency control.
 36   * - Concurrent-safe tools can execute in parallel with other concurrent-safe tools
 37   * - Non-concurrent tools must execute alone (exclusive access)
 38   * - Results are buffered and emitted in the order tools were received
 39   */
 40  export class StreamingToolExecutor {
 41    private tools: TrackedTool[] = []
 42    private toolUseContext: ToolUseContext
 43    private hasErrored = false
 44    private erroredToolDescription = ''
 45    // Child of toolUseContext.abortController. Fires when a Bash tool errors
 46    // so sibling subprocesses die immediately instead of running to completion.
 47    // Aborting this does NOT abort the parent — query.ts won't end the turn.
 48    private siblingAbortController: AbortController
 49    private discarded = false
 50    // Signal to wake up getRemainingResults when progress is available
 51    private progressAvailableResolve?: () => void
 52  
 53    constructor(
 54      private readonly toolDefinitions: Tools,
 55      private readonly canUseTool: CanUseToolFn,
 56      toolUseContext: ToolUseContext,
 57    ) {
 58      this.toolUseContext = toolUseContext
 59      this.siblingAbortController = createChildAbortController(
 60        toolUseContext.abortController,
 61      )
 62    }
 63  
 64    /**
 65     * Discards all pending and in-progress tools. Called when streaming fallback
 66     * occurs and results from the failed attempt should be abandoned.
 67     * Queued tools won't start, and in-progress tools will receive synthetic errors.
 68     */
 69    discard(): void {
 70      this.discarded = true
 71    }
 72  
 73    /**
 74     * Add a tool to the execution queue. Will start executing immediately if conditions allow.
 75     */
 76    addTool(block: ToolUseBlock, assistantMessage: AssistantMessage): void {
 77      const toolDefinition = findToolByName(this.toolDefinitions, block.name)
 78      if (!toolDefinition) {
 79        this.tools.push({
 80          id: block.id,
 81          block,
 82          assistantMessage,
 83          status: 'completed',
 84          isConcurrencySafe: true,
 85          pendingProgress: [],
 86          results: [
 87            createUserMessage({
 88              content: [
 89                {
 90                  type: 'tool_result',
 91                  content: `<tool_use_error>Error: No such tool available: ${block.name}</tool_use_error>`,
 92                  is_error: true,
 93                  tool_use_id: block.id,
 94                },
 95              ],
 96              toolUseResult: `Error: No such tool available: ${block.name}`,
 97              sourceToolAssistantUUID: assistantMessage.uuid,
 98            }),
 99          ],
100        })
101        return
102      }
103  
104      const parsedInput = toolDefinition.inputSchema.safeParse(block.input)
105      const isConcurrencySafe = parsedInput?.success
106        ? (() => {
107            try {
108              return Boolean(toolDefinition.isConcurrencySafe(parsedInput.data))
109            } catch {
110              return false
111            }
112          })()
113        : false
114      this.tools.push({
115        id: block.id,
116        block,
117        assistantMessage,
118        status: 'queued',
119        isConcurrencySafe,
120        pendingProgress: [],
121      })
122  
123      void this.processQueue()
124    }
125  
126    /**
127     * Check if a tool can execute based on current concurrency state
128     */
129    private canExecuteTool(isConcurrencySafe: boolean): boolean {
130      const executingTools = this.tools.filter(t => t.status === 'executing')
131      return (
132        executingTools.length === 0 ||
133        (isConcurrencySafe && executingTools.every(t => t.isConcurrencySafe))
134      )
135    }
136  
137    /**
138     * Process the queue, starting tools when concurrency conditions allow
139     */
140    private async processQueue(): Promise<void> {
141      for (const tool of this.tools) {
142        if (tool.status !== 'queued') continue
143  
144        if (this.canExecuteTool(tool.isConcurrencySafe)) {
145          await this.executeTool(tool)
146        } else {
147          // Can't execute this tool yet, and since we need to maintain order for non-concurrent tools, stop here
148          if (!tool.isConcurrencySafe) break
149        }
150      }
151    }
152  
153    private createSyntheticErrorMessage(
154      toolUseId: string,
155      reason: 'sibling_error' | 'user_interrupted' | 'streaming_fallback',
156      assistantMessage: AssistantMessage,
157    ): Message {
158      // For user interruptions (ESC to reject), use REJECT_MESSAGE so the UI shows
159      // "User rejected edit" instead of "Error editing file"
160      if (reason === 'user_interrupted') {
161        return createUserMessage({
162          content: [
163            {
164              type: 'tool_result',
165              content: withMemoryCorrectionHint(REJECT_MESSAGE),
166              is_error: true,
167              tool_use_id: toolUseId,
168            },
169          ],
170          toolUseResult: 'User rejected tool use',
171          sourceToolAssistantUUID: assistantMessage.uuid,
172        })
173      }
174      if (reason === 'streaming_fallback') {
175        return createUserMessage({
176          content: [
177            {
178              type: 'tool_result',
179              content:
180                '<tool_use_error>Error: Streaming fallback - tool execution discarded</tool_use_error>',
181              is_error: true,
182              tool_use_id: toolUseId,
183            },
184          ],
185          toolUseResult: 'Streaming fallback - tool execution discarded',
186          sourceToolAssistantUUID: assistantMessage.uuid,
187        })
188      }
189      const desc = this.erroredToolDescription
190      const msg = desc
191        ? `Cancelled: parallel tool call ${desc} errored`
192        : 'Cancelled: parallel tool call errored'
193      return createUserMessage({
194        content: [
195          {
196            type: 'tool_result',
197            content: `<tool_use_error>${msg}</tool_use_error>`,
198            is_error: true,
199            tool_use_id: toolUseId,
200          },
201        ],
202        toolUseResult: msg,
203        sourceToolAssistantUUID: assistantMessage.uuid,
204      })
205    }
206  
207    /**
208     * Determine why a tool should be cancelled.
209     */
210    private getAbortReason(
211      tool: TrackedTool,
212    ): 'sibling_error' | 'user_interrupted' | 'streaming_fallback' | null {
213      if (this.discarded) {
214        return 'streaming_fallback'
215      }
216      if (this.hasErrored) {
217        return 'sibling_error'
218      }
219      if (this.toolUseContext.abortController.signal.aborted) {
220        // 'interrupt' means the user typed a new message while tools were
221        // running. Only cancel tools whose interruptBehavior is 'cancel';
222        // 'block' tools shouldn't reach here (abort isn't fired).
223        if (this.toolUseContext.abortController.signal.reason === 'interrupt') {
224          return this.getToolInterruptBehavior(tool) === 'cancel'
225            ? 'user_interrupted'
226            : null
227        }
228        return 'user_interrupted'
229      }
230      return null
231    }
232  
233    private getToolInterruptBehavior(tool: TrackedTool): 'cancel' | 'block' {
234      const definition = findToolByName(this.toolDefinitions, tool.block.name)
235      if (!definition?.interruptBehavior) return 'block'
236      try {
237        return definition.interruptBehavior()
238      } catch {
239        return 'block'
240      }
241    }
242  
243    private getToolDescription(tool: TrackedTool): string {
244      const input = tool.block.input as Record<string, unknown> | undefined
245      const summary = input?.command ?? input?.file_path ?? input?.pattern ?? ''
246      if (typeof summary === 'string' && summary.length > 0) {
247        const truncated =
248          summary.length > 40 ? summary.slice(0, 40) + '\u2026' : summary
249        return `${tool.block.name}(${truncated})`
250      }
251      return tool.block.name
252    }
253  
254    private updateInterruptibleState(): void {
255      const executing = this.tools.filter(t => t.status === 'executing')
256      this.toolUseContext.setHasInterruptibleToolInProgress?.(
257        executing.length > 0 &&
258          executing.every(t => this.getToolInterruptBehavior(t) === 'cancel'),
259      )
260    }
261  
262    /**
263     * Execute a tool and collect its results
264     */
265    private async executeTool(tool: TrackedTool): Promise<void> {
266      tool.status = 'executing'
267      this.toolUseContext.setInProgressToolUseIDs(prev =>
268        new Set(prev).add(tool.id),
269      )
270      this.updateInterruptibleState()
271  
272      const messages: Message[] = []
273      const contextModifiers: Array<(context: ToolUseContext) => ToolUseContext> =
274        []
275  
276      const collectResults = async () => {
277        // If already aborted (by error or user), generate synthetic error block instead of running the tool
278        const initialAbortReason = this.getAbortReason(tool)
279        if (initialAbortReason) {
280          messages.push(
281            this.createSyntheticErrorMessage(
282              tool.id,
283              initialAbortReason,
284              tool.assistantMessage,
285            ),
286          )
287          tool.results = messages
288          tool.contextModifiers = contextModifiers
289          tool.status = 'completed'
290          this.updateInterruptibleState()
291          return
292        }
293  
294        // Per-tool child controller. Lets siblingAbortController kill running
295        // subprocesses (Bash spawns listen to this signal) when a Bash error
296        // cascades. Permission-dialog rejection also aborts this controller
297        // (PermissionContext.ts cancelAndAbort) — that abort must bubble up to
298        // the query controller so the query loop's post-tool abort check ends
299        // the turn. Without bubble-up, ExitPlanMode "clear context + auto"
300        // sends REJECT_MESSAGE to the model instead of aborting (#21056 regression).
301        const toolAbortController = createChildAbortController(
302          this.siblingAbortController,
303        )
304        toolAbortController.signal.addEventListener(
305          'abort',
306          () => {
307            if (
308              toolAbortController.signal.reason !== 'sibling_error' &&
309              !this.toolUseContext.abortController.signal.aborted &&
310              !this.discarded
311            ) {
312              this.toolUseContext.abortController.abort(
313                toolAbortController.signal.reason,
314              )
315            }
316          },
317          { once: true },
318        )
319  
320        const generator = runToolUse(
321          tool.block,
322          tool.assistantMessage,
323          this.canUseTool,
324          { ...this.toolUseContext, abortController: toolAbortController },
325        )
326  
327        // Track if this specific tool has produced an error result.
328        // This prevents the tool from receiving a duplicate "sibling error"
329        // message when it is the one that caused the error.
330        let thisToolErrored = false
331  
332        for await (const update of generator) {
333          // Check if we were aborted by a sibling tool error or user interruption.
334          // Only add the synthetic error if THIS tool didn't produce the error.
335          const abortReason = this.getAbortReason(tool)
336          if (abortReason && !thisToolErrored) {
337            messages.push(
338              this.createSyntheticErrorMessage(
339                tool.id,
340                abortReason,
341                tool.assistantMessage,
342              ),
343            )
344            break
345          }
346  
347          const isErrorResult =
348            update.message.type === 'user' &&
349            Array.isArray(update.message.message.content) &&
350            update.message.message.content.some(
351              _ => _.type === 'tool_result' && _.is_error === true,
352            )
353  
354          if (isErrorResult) {
355            thisToolErrored = true
356            // Only Bash errors cancel siblings. Bash commands often have implicit
357            // dependency chains (e.g. mkdir fails → subsequent commands pointless).
358            // Read/WebFetch/etc are independent — one failure shouldn't nuke the rest.
359            if (tool.block.name === BASH_TOOL_NAME) {
360              this.hasErrored = true
361              this.erroredToolDescription = this.getToolDescription(tool)
362              this.siblingAbortController.abort('sibling_error')
363            }
364          }
365  
366          if (update.message) {
367            // Progress messages go to pendingProgress for immediate yielding
368            if (update.message.type === 'progress') {
369              tool.pendingProgress.push(update.message)
370              // Signal that progress is available
371              if (this.progressAvailableResolve) {
372                this.progressAvailableResolve()
373                this.progressAvailableResolve = undefined
374              }
375            } else {
376              messages.push(update.message)
377            }
378          }
379          if (update.contextModifier) {
380            contextModifiers.push(update.contextModifier.modifyContext)
381          }
382        }
383        tool.results = messages
384        tool.contextModifiers = contextModifiers
385        tool.status = 'completed'
386        this.updateInterruptibleState()
387  
388        // NOTE: we currently don't support context modifiers for concurrent
389        //       tools. None are actively being used, but if we want to use
390        //       them in concurrent tools, we need to support that here.
391        if (!tool.isConcurrencySafe && contextModifiers.length > 0) {
392          for (const modifier of contextModifiers) {
393            this.toolUseContext = modifier(this.toolUseContext)
394          }
395        }
396      }
397  
398      const promise = collectResults()
399      tool.promise = promise
400  
401      // Process more queue when done
402      void promise.finally(() => {
403        void this.processQueue()
404      })
405    }
406  
407    /**
408     * Get any completed results that haven't been yielded yet (non-blocking)
409     * Maintains order where necessary
410     * Also yields any pending progress messages immediately
411     */
412    *getCompletedResults(): Generator<MessageUpdate, void> {
413      if (this.discarded) {
414        return
415      }
416  
417      for (const tool of this.tools) {
418        // Always yield pending progress messages immediately, regardless of tool status
419        while (tool.pendingProgress.length > 0) {
420          const progressMessage = tool.pendingProgress.shift()!
421          yield { message: progressMessage, newContext: this.toolUseContext }
422        }
423  
424        if (tool.status === 'yielded') {
425          continue
426        }
427  
428        if (tool.status === 'completed' && tool.results) {
429          tool.status = 'yielded'
430  
431          for (const message of tool.results) {
432            yield { message, newContext: this.toolUseContext }
433          }
434  
435          markToolUseAsComplete(this.toolUseContext, tool.id)
436        } else if (tool.status === 'executing' && !tool.isConcurrencySafe) {
437          break
438        }
439      }
440    }
441  
442    /**
443     * Check if any tool has pending progress messages
444     */
445    private hasPendingProgress(): boolean {
446      return this.tools.some(t => t.pendingProgress.length > 0)
447    }
448  
449    /**
450     * Wait for remaining tools and yield their results as they complete
451     * Also yields progress messages as they become available
452     */
453    async *getRemainingResults(): AsyncGenerator<MessageUpdate, void> {
454      if (this.discarded) {
455        return
456      }
457  
458      while (this.hasUnfinishedTools()) {
459        await this.processQueue()
460  
461        for (const result of this.getCompletedResults()) {
462          yield result
463        }
464  
465        // If we still have executing tools but nothing completed, wait for any to complete
466        // OR for progress to become available
467        if (
468          this.hasExecutingTools() &&
469          !this.hasCompletedResults() &&
470          !this.hasPendingProgress()
471        ) {
472          const executingPromises = this.tools
473            .filter(t => t.status === 'executing' && t.promise)
474            .map(t => t.promise!)
475  
476          // Also wait for progress to become available
477          const progressPromise = new Promise<void>(resolve => {
478            this.progressAvailableResolve = resolve
479          })
480  
481          if (executingPromises.length > 0) {
482            await Promise.race([...executingPromises, progressPromise])
483          }
484        }
485      }
486  
487      for (const result of this.getCompletedResults()) {
488        yield result
489      }
490    }
491  
492    /**
493     * Check if there are any completed results ready to yield
494     */
495    private hasCompletedResults(): boolean {
496      return this.tools.some(t => t.status === 'completed')
497    }
498  
499    /**
500     * Check if there are any tools still executing
501     */
502    private hasExecutingTools(): boolean {
503      return this.tools.some(t => t.status === 'executing')
504    }
505  
506    /**
507     * Check if there are any unfinished tools
508     */
509    private hasUnfinishedTools(): boolean {
510      return this.tools.some(t => t.status !== 'yielded')
511    }
512  
513    /**
514     * Get the current tool use context (may have been modified by context modifiers)
515     */
516    getUpdatedContext(): ToolUseContext {
517      return this.toolUseContext
518    }
519  }
520  
/**
 * Remove a tool-use id from the context's in-progress set.
 *
 * The setter receives an updater that returns a fresh Set; the previous
 * set is left unmodified.
 *
 * @param toolUseContext Context whose in-progress id set is updated.
 * @param toolUseID Id to remove; removing an absent id is a no-op.
 */
function markToolUseAsComplete(
  toolUseContext: ToolUseContext,
  toolUseID: string,
) {
  toolUseContext.setInProgressToolUseIDs(
    prev => new Set([...prev].filter(id => id !== toolUseID)),
  )
}