/ utils / groupToolUses.ts
groupToolUses.ts
  1  import type { BetaToolUseBlock } from '@anthropic-ai/sdk/resources/beta/messages/messages.mjs'
  2  import type { ToolResultBlockParam } from '@anthropic-ai/sdk/resources/messages/messages.mjs'
  3  import type { Tools } from '../Tool.js'
  4  import type {
  5    GroupedToolUseMessage,
  6    NormalizedAssistantMessage,
  7    NormalizedMessage,
  8    NormalizedUserMessage,
  9    ProgressMessage,
 10    RenderableMessage,
 11  } from '../types/message.js'
 12  
/**
 * Any normalized message except progress messages. This is the element type
 * of the message list accepted by `applyGrouping`.
 */
export type MessageWithoutProgress = Exclude<NormalizedMessage, ProgressMessage>

/**
 * Return value of `applyGrouping`: the list of messages to render, in order.
 */
export type GroupingResult = {
  messages: RenderableMessage[]
}
 18  
 19  // Cache the set of tool names that support grouped rendering, keyed by the
 20  // tools array reference. The tools array is stable across renders (only
 21  // replaced on MCP connect/disconnect), so this avoids rebuilding the set on
 22  // every call. WeakMap lets old entries be GC'd when the array is replaced.
 23  const GROUPING_CACHE = new WeakMap<Tools, Set<string>>()
 24  
 25  function getToolsWithGrouping(tools: Tools): Set<string> {
 26    let cached = GROUPING_CACHE.get(tools)
 27    if (!cached) {
 28      cached = new Set(tools.filter(t => t.renderGroupedToolUse).map(t => t.name))
 29      GROUPING_CACHE.set(tools, cached)
 30    }
 31    return cached
 32  }
 33  
 34  function getToolUseInfo(
 35    msg: MessageWithoutProgress,
 36  ): { messageId: string; toolUseId: string; toolName: string } | null {
 37    if (msg.type === 'assistant' && msg.message.content[0]?.type === 'tool_use') {
 38      const content = msg.message.content[0]
 39      return {
 40        messageId: msg.message.id,
 41        toolUseId: content.id,
 42        toolName: content.name,
 43      }
 44    }
 45    return null
 46  }
 47  
 48  /**
 49   * Groups tool uses by message.id (same API response) if the tool supports grouped rendering.
 50   * Only groups 2+ tools of the same type from the same message.
 51   * Also collects corresponding tool_results and attaches them to the grouped message.
 52   * When verbose is true, skips grouping so messages render at original positions.
 53   */
 54  export function applyGrouping(
 55    messages: MessageWithoutProgress[],
 56    tools: Tools,
 57    verbose: boolean = false,
 58  ): GroupingResult {
 59    // In verbose mode, don't group - each message renders at its original position
 60    if (verbose) {
 61      return {
 62        messages: messages,
 63      }
 64    }
 65    const toolsWithGrouping = getToolsWithGrouping(tools)
 66  
 67    // First pass: group tool uses by message.id + tool name
 68    const groups = new Map<
 69      string,
 70      NormalizedAssistantMessage<BetaToolUseBlock>[]
 71    >()
 72  
 73    for (const msg of messages) {
 74      const info = getToolUseInfo(msg)
 75      if (info && toolsWithGrouping.has(info.toolName)) {
 76        const key = `${info.messageId}:${info.toolName}`
 77        const group = groups.get(key) ?? []
 78        group.push(msg as NormalizedAssistantMessage<BetaToolUseBlock>)
 79        groups.set(key, group)
 80      }
 81    }
 82  
 83    // Identify valid groups (2+ items) and collect their tool use IDs
 84    const validGroups = new Map<
 85      string,
 86      NormalizedAssistantMessage<BetaToolUseBlock>[]
 87    >()
 88    const groupedToolUseIds = new Set<string>()
 89  
 90    for (const [key, group] of groups) {
 91      if (group.length >= 2) {
 92        validGroups.set(key, group)
 93        for (const msg of group) {
 94          const info = getToolUseInfo(msg)
 95          if (info) {
 96            groupedToolUseIds.add(info.toolUseId)
 97          }
 98        }
 99      }
100    }
101  
102    // Collect result messages for grouped tool_uses
103    // Map from tool_use_id to the user message containing that result
104    const resultsByToolUseId = new Map<string, NormalizedUserMessage>()
105  
106    for (const msg of messages) {
107      if (msg.type === 'user') {
108        for (const content of msg.message.content) {
109          if (
110            content.type === 'tool_result' &&
111            groupedToolUseIds.has(content.tool_use_id)
112          ) {
113            resultsByToolUseId.set(content.tool_use_id, msg)
114          }
115        }
116      }
117    }
118  
119    // Second pass: build output, emitting each group only once
120    const result: RenderableMessage[] = []
121    const emittedGroups = new Set<string>()
122  
123    for (const msg of messages) {
124      const info = getToolUseInfo(msg)
125  
126      if (info) {
127        const key = `${info.messageId}:${info.toolName}`
128        const group = validGroups.get(key)
129  
130        if (group) {
131          if (!emittedGroups.has(key)) {
132            emittedGroups.add(key)
133            const firstMsg = group[0]!
134  
135            // Collect results for this group
136            const results: NormalizedUserMessage[] = []
137            for (const assistantMsg of group) {
138              const toolUseId = (
139                assistantMsg.message.content[0] as { id: string }
140              ).id
141              const resultMsg = resultsByToolUseId.get(toolUseId)
142              if (resultMsg) {
143                results.push(resultMsg)
144              }
145            }
146  
147            const groupedMessage: GroupedToolUseMessage = {
148              type: 'grouped_tool_use',
149              toolName: info.toolName,
150              messages: group,
151              results,
152              displayMessage: firstMsg,
153              uuid: `grouped-${firstMsg.uuid}`,
154              timestamp: firstMsg.timestamp,
155              messageId: info.messageId,
156            }
157            result.push(groupedMessage)
158          }
159          continue
160        }
161      }
162  
163      // Skip user messages whose tool_results are all grouped
164      if (msg.type === 'user') {
165        const toolResults = msg.message.content.filter(
166          (c): c is ToolResultBlockParam => c.type === 'tool_result',
167        )
168        if (toolResults.length > 0) {
169          const allGrouped = toolResults.every(tr =>
170            groupedToolUseIds.has(tr.tool_use_id),
171          )
172          if (allGrouped) {
173            continue
174          }
175        }
176      }
177  
178      result.push(msg)
179    }
180  
181    return { messages: result }
182  }