/ tools / ToolSearchTool / ToolSearchTool.ts
ToolSearchTool.ts
  1  import type { ToolResultBlockParam } from '@anthropic-ai/sdk/resources/index.mjs'
  2  import memoize from 'lodash-es/memoize.js'
  3  import { z } from 'zod/v4'
  4  import {
  5    type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
  6    logEvent,
  7  } from '../../services/analytics/index.js'
  8  import {
  9    buildTool,
 10    findToolByName,
 11    type Tool,
 12    type ToolDef,
 13    type Tools,
 14  } from '../../Tool.js'
 15  import { logForDebugging } from '../../utils/debug.js'
 16  import { lazySchema } from '../../utils/lazySchema.js'
 17  import { escapeRegExp } from '../../utils/stringUtils.js'
 18  import { isToolSearchEnabledOptimistic } from '../../utils/toolSearch.js'
 19  import { getPrompt, isDeferredTool, TOOL_SEARCH_TOOL_NAME } from './prompt.js'
 20  
 21  export const inputSchema = lazySchema(() =>
 22    z.object({
 23      query: z
 24        .string()
 25        .describe(
 26          'Query to find deferred tools. Use "select:<tool_name>" for direct selection, or keywords to search.',
 27        ),
 28      max_results: z
 29        .number()
 30        .optional()
 31        .default(5)
 32        .describe('Maximum number of results to return (default: 5)'),
 33    }),
 34  )
 35  type InputSchema = ReturnType<typeof inputSchema>
 36  
 37  export const outputSchema = lazySchema(() =>
 38    z.object({
 39      matches: z.array(z.string()),
 40      query: z.string(),
 41      total_deferred_tools: z.number(),
 42      pending_mcp_servers: z.array(z.string()).optional(),
 43    }),
 44  )
 45  type OutputSchema = ReturnType<typeof outputSchema>
 46  
 47  export type Output = z.infer<OutputSchema>
 48  
// Cache key (as produced by getDeferredToolsCacheKey) of the deferred-tool set
// the memoized descriptions were last checked against. It is null until the
// first check and after clearToolSearchDescriptionCache(). When it differs
// from the current key, maybeInvalidateCache() clears the description cache.
let cachedDeferredToolNames: string | null = null
 51  
 52  /**
 53   * Get a cache key representing the current set of deferred tools.
 54   */
 55  function getDeferredToolsCacheKey(deferredTools: Tools): string {
 56    return deferredTools
 57      .map(t => t.name)
 58      .sort()
 59      .join(',')
 60  }
 61  
 62  /**
 63   * Get tool description, memoized by tool name.
 64   * Used for keyword search scoring.
 65   */
 66  const getToolDescriptionMemoized = memoize(
 67    async (toolName: string, tools: Tools): Promise<string> => {
 68      const tool = findToolByName(tools, toolName)
 69      if (!tool) {
 70        return ''
 71      }
 72      return tool.prompt({
 73        getToolPermissionContext: async () => ({
 74          mode: 'default' as const,
 75          additionalWorkingDirectories: new Map(),
 76          alwaysAllowRules: {},
 77          alwaysDenyRules: {},
 78          alwaysAskRules: {},
 79          isBypassPermissionsModeAvailable: false,
 80        }),
 81        tools,
 82        agents: [],
 83      })
 84    },
 85    (toolName: string) => toolName,
 86  )
 87  
 88  /**
 89   * Invalidate the description cache if deferred tools have changed.
 90   */
 91  function maybeInvalidateCache(deferredTools: Tools): void {
 92    const currentKey = getDeferredToolsCacheKey(deferredTools)
 93    if (cachedDeferredToolNames !== currentKey) {
 94      logForDebugging(
 95        `ToolSearchTool: cache invalidated - deferred tools changed`,
 96      )
 97      getToolDescriptionMemoized.cache.clear?.()
 98      cachedDeferredToolNames = currentKey
 99    }
100  }
101  
/**
 * Unconditionally reset the description cache: forgets the remembered
 * deferred-tool key and drops every memoized tool description.
 */
export function clearToolSearchDescriptionCache(): void {
  cachedDeferredToolNames = null
  getToolDescriptionMemoized.cache.clear?.()
}
106  
/**
 * Assemble the `{ data }` envelope returned by a search.
 *
 * @param matches Names of the tools that matched.
 * @param query The query the result belongs to.
 * @param totalDeferredTools Number of deferred tools that were searched.
 * @param pendingMcpServers Optional server names; only copied into the result
 *   (as `pending_mcp_servers`) when the list is non-empty.
 * @returns The search output wrapped in a `data` property.
 */
function buildSearchResult(
  matches: string[],
  query: string,
  totalDeferredTools: number,
  pendingMcpServers?: string[],
): { data: Output } {
  const data: Output = {
    matches,
    query,
    total_deferred_tools: totalDeferredTools,
  }

  // Leave the key out entirely (rather than set it to []) when nothing is pending.
  if (pendingMcpServers && pendingMcpServers.length > 0) {
    data.pending_mcp_servers = pendingMcpServers
  }

  return { data }
}
127  
/**
 * Break a tool name into lower-cased, searchable pieces.
 *
 * - MCP tools (names starting with `mcp__`): the prefix is dropped and the rest
 *   is split on underscores, e.g. `mcp__slack__send_message` gives
 *   `['slack', 'send', 'message']`.
 * - Other tools: split at lower-to-upper CamelCase transitions, underscores and
 *   whitespace, e.g. `NotebookEdit` gives `['notebook', 'edit']`.
 *
 * @param name Tool name as registered.
 * @returns `parts` (non-empty name pieces), `full` (the name with separators
 *   turned into spaces) and `isMcp` (whether the `mcp__` prefix was present).
 */
function parseToolName(name: string): {
  parts: string[]
  full: string
  isMcp: boolean
} {
  const mcpPrefix = 'mcp__'

  if (name.startsWith(mcpPrefix)) {
    const body = name.slice(mcpPrefix.length).toLowerCase()
    // Any run of underscores separates two parts; empty pieces are dropped.
    const segments = body.split(/_+/).filter(segment => segment !== '')
    // Double underscores become one space first, then remaining single ones.
    const spaced = body.split('__').join(' ').split('_').join(' ')
    return { parts: segments, full: spaced, isMcp: true }
  }

  // Insert a space at each lower-to-upper transition before lower-casing.
  const camelSeparated = name.replace(/([a-z])([A-Z])/g, '$1 $2')
  const normalized = camelSeparated.replace(/_/g, ' ').toLowerCase()
  const words = normalized.split(/\s+/).filter(word => word !== '')

  return { parts: words, full: words.join(' '), isMcp: false }
}
162  
/**
 * Pre-compile word-boundary regexes for all search terms.
 * Called once per search instead of tools×terms×2 times.
 *
 * A `\b` anchor is only placed on an edge of the term that is an ASCII word
 * character (`\w`). `\b` next to a non-word character requires a word
 * character on the other side, so an unconditional `\bterm\b` could not match
 * terms such as `c++` or `über` in ordinary text. For terms that start and end
 * with word characters the pattern is exactly `\bterm\b`.
 *
 * @param terms Search terms (expected lower-cased); duplicates are compiled once.
 * @returns Map from each distinct term to its case-sensitive, non-global regex.
 */
function compileTermPatterns(terms: string[]): Map<string, RegExp> {
  // An absent edge character (empty term) keeps the plain `\b` anchor.
  const boundaryFor = (edgeChar: string | undefined): string =>
    edgeChar === undefined || /\w/.test(edgeChar) ? '\\b' : ''

  const patterns = new Map<string, RegExp>()
  for (const term of terms) {
    if (patterns.has(term)) {
      continue
    }
    const leading = boundaryFor(term[0])
    const trailing = boundaryFor(term[term.length - 1])
    patterns.set(
      term,
      new RegExp(`${leading}${escapeRegExp(term)}${trailing}`),
    )
  }
  return patterns
}
176  
/** Lower-cased, pre-parsed view of one deferred tool, built once per search. */
type SearchableTool = {
  name: string
  parsed: ReturnType<typeof parseToolName>
  description: string
  hint: string
}

/** Combine a tool with its rendered description into its searchable form. */
function toSearchableTool(tool: Tool, description: string): SearchableTool {
  return {
    name: tool.name,
    parsed: parseToolName(tool.name),
    description: description.toLowerCase(),
    hint: tool.searchHint?.toLowerCase() ?? '',
  }
}

/** True when `term` occurs in a name part, the description, or the search hint. */
function toolMatchesTerm(
  tool: SearchableTool,
  term: string,
  pattern: RegExp,
): boolean {
  return (
    tool.parsed.parts.some(part => part.includes(term)) ||
    pattern.test(tool.description) ||
    (tool.hint !== '' && pattern.test(tool.hint))
  )
}

/** Sum the relevance score of one tool over all scoring terms. */
function scoreSearchableTool(
  tool: SearchableTool,
  terms: string[],
  patterns: Map<string, RegExp>,
): number {
  const { parts, full, isMcp } = tool.parsed
  let score = 0
  for (const term of terms) {
    // compileTermPatterns creates an entry for every term it was given.
    const pattern = patterns.get(term)!

    // Exact part match (high weight for MCP server names, tool name parts)
    if (parts.includes(term)) {
      score += isMcp ? 12 : 10
    } else if (parts.some(part => part.includes(term))) {
      score += isMcp ? 6 : 5
    }

    // Full name fallback; only applies while nothing has scored yet.
    if (score === 0 && full.includes(term)) {
      score += 3
    }

    // searchHint match — curated capability phrase, higher signal than prompt
    if (tool.hint !== '' && pattern.test(tool.hint)) {
      score += 4
    }

    // Description match - use word boundary to avoid false positives
    if (pattern.test(tool.description)) {
      score += 2
    }
  }
  return score
}

/**
 * Keyword-based search over tool names and descriptions.
 * Handles both MCP tools (mcp__server__action) and regular tools (CamelCase).
 *
 * The model typically queries with:
 * - Server names when it knows the integration (e.g., "slack", "github")
 * - Action words when looking for functionality (e.g., "read", "list", "create")
 * - Tool-specific terms (e.g., "notebook", "shell", "kill")
 *
 * Resolution order:
 * 1. A query equal (case-insensitively) to a tool name returns that tool,
 *    looking at deferred tools first and then at the full tool set.
 * 2. A query starting with `mcp__` returns deferred tools whose name starts
 *    with the query, if there are any.
 * 3. Otherwise the query is split into whitespace-separated terms. Terms
 *    prefixed with `+` are required: a tool must contain every one of them in
 *    its name, description or search hint. Remaining tools are scored per term
 *    and returned best-first; tools scoring 0 are dropped.
 *
 * @param query Raw query text.
 * @param deferredTools Tools that are searched.
 * @param tools Full tool set, used for the exact-name fallback and for
 *   rendering descriptions.
 * @param maxResults Maximum number of names to return for steps 2 and 3. It is
 *   truncated to an integer; zero, negative or NaN values yield no results
 *   (a negative value would otherwise make `slice` drop items from the end
 *   instead of limiting the result).
 * @returns Matching tool names, at most `maxResults` except for step 1, which
 *   always returns the single exact match.
 */
async function searchToolsWithKeywords(
  query: string,
  deferredTools: Tools,
  tools: Tools,
  maxResults: number,
): Promise<string[]> {
  const queryLower = query.toLowerCase().trim()

  // Fast path: if query matches a tool name exactly, return it directly.
  // Handles models using a bare tool name instead of select: prefix (seen
  // from subagents/post-compaction). Checks deferred first, then falls back
  // to the full tool set — selecting an already-loaded tool is a harmless
  // no-op that lets the model proceed without retry churn.
  const exactMatch =
    deferredTools.find(t => t.name.toLowerCase() === queryLower) ??
    tools.find(t => t.name.toLowerCase() === queryLower)
  if (exactMatch) {
    return [exactMatch.name]
  }

  // Normalize the limit once; `!(limit > 0)` also covers NaN.
  const limit = Math.max(0, Math.trunc(maxResults))
  if (!(limit > 0)) {
    return []
  }

  // If query looks like an MCP tool prefix (mcp__server), find matching tools.
  // Handles models searching by server name with mcp__ prefix.
  if (queryLower.startsWith('mcp__') && queryLower.length > 5) {
    const prefixMatches = deferredTools
      .filter(t => t.name.toLowerCase().startsWith(queryLower))
      .slice(0, limit)
      .map(t => t.name)
    if (prefixMatches.length > 0) {
      return prefixMatches
    }
  }

  const queryTerms = queryLower.split(/\s+/).filter(term => term.length > 0)
  if (queryTerms.length === 0) {
    // Nothing to score against, so no tool can get a positive score.
    return []
  }

  // Partition into required (+prefixed) and optional terms
  const requiredTerms: string[] = []
  const optionalTerms: string[] = []
  for (const term of queryTerms) {
    if (term.startsWith('+') && term.length > 1) {
      requiredTerms.push(term.slice(1))
    } else {
      optionalTerms.push(term)
    }
  }

  // Without required terms this is the same list as queryTerms.
  const scoringTerms = [...requiredTerms, ...optionalTerms]
  const termPatterns = compileTermPatterns(scoringTerms)

  // Parse names and fetch descriptions once; both the required-term filter and
  // the scoring pass read from this.
  const searchable = await Promise.all(
    deferredTools.map(async tool =>
      toSearchableTool(
        tool,
        await getToolDescriptionMemoized(tool.name, tools),
      ),
    ),
  )

  // Keep only tools matching ALL required terms in name, description or hint
  const candidates =
    requiredTerms.length === 0
      ? searchable
      : searchable.filter(tool =>
          requiredTerms.every(term =>
            toolMatchesTerm(tool, term, termPatterns.get(term)!),
          ),
        )

  return candidates
    .map(tool => ({
      name: tool.name,
      score: scoreSearchableTool(tool, scoringTerms, termPatterns),
    }))
    .filter(item => item.score > 0)
    .sort((a, b) => b.score - a.score)
    .slice(0, limit)
    .map(item => item.name)
}
303  
/**
 * Tool that finds deferred tools by name or keyword.
 *
 * `call` supports two query forms:
 * - `select:<name>[,<name>...]` — direct selection by tool name.
 * - anything else — keyword search via searchToolsWithKeywords().
 *
 * The result lists the matching tool names; when nothing matched, it also
 * reports MCP servers that are still in the `pending` state.
 */
export const ToolSearchTool = buildTool({
  isEnabled() {
    return isToolSearchEnabledOptimistic()
  },
  isConcurrencySafe() {
    return true
  },
  isReadOnly() {
    return true
  },
  name: TOOL_SEARCH_TOOL_NAME,
  maxResultSizeChars: 100_000,
  async description() {
    return getPrompt()
  },
  async prompt() {
    return getPrompt()
  },
  get inputSchema(): InputSchema {
    return inputSchema()
  },
  get outputSchema(): OutputSchema {
    return outputSchema()
  },
  /**
   * Run a search over the deferred tools in `options.tools`.
   *
   * Every call logs a `tengu_tool_search_outcome` analytics event and a debug
   * line describing the outcome.
   */
  async call(input, { options: { tools }, getAppState }) {
    const { query, max_results = 5 } = input

    const deferredTools = tools.filter(isDeferredTool)
    // Drop memoized descriptions if the deferred tool set changed since last call.
    maybeInvalidateCache(deferredTools)

    // Check for MCP servers still connecting
    function getPendingServerNames(): string[] | undefined {
      const appState = getAppState()
      const pending = appState.mcp.clients.filter(c => c.type === 'pending')
      return pending.length > 0 ? pending.map(s => s.name) : undefined
    }

    // Helper to log search outcome
    function logSearchOutcome(
      matches: string[],
      queryType: 'select' | 'keyword',
    ): void {
      logEvent('tengu_tool_search_outcome', {
        query:
          query as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
        queryType:
          queryType as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
        matchCount: matches.length,
        totalDeferredTools: deferredTools.length,
        maxResults: max_results,
        hasMatches: matches.length > 0,
      })
    }

    // Check for select: prefix — direct tool selection.
    // Supports comma-separated multi-select: `select:A,B,C`.
    // If a name isn't in the deferred set but IS in the full tool set,
    // we still return it — the tool is already loaded, so "selecting" it
    // is a harmless no-op that lets the model proceed without retry churn.
    //
    // The query is trimmed first: `^`/`$` are anchored to the whole string and
    // `.` does not match line breaks, so leading whitespace or a trailing
    // newline would otherwise make a select query fall through to keyword
    // search (which trims on its own).
    const selectMatch = query.trim().match(/^select:(.+)$/i)
    if (selectMatch) {
      const requested = selectMatch[1]!
        .split(',')
        .map(s => s.trim())
        .filter(Boolean)

      const found: string[] = []
      const missing: string[] = []
      for (const toolName of requested) {
        const tool =
          findToolByName(deferredTools, toolName) ??
          findToolByName(tools, toolName)
        if (tool) {
          // Report each tool once even if it was requested repeatedly.
          if (!found.includes(tool.name)) found.push(tool.name)
        } else {
          missing.push(toolName)
        }
      }

      if (found.length === 0) {
        logForDebugging(
          `ToolSearchTool: select failed — none found: ${missing.join(', ')}`,
        )
        logSearchOutcome([], 'select')
        // The requested tools may belong to a server that is still connecting.
        const pendingServers = getPendingServerNames()
        return buildSearchResult(
          [],
          query,
          deferredTools.length,
          pendingServers,
        )
      }

      if (missing.length > 0) {
        logForDebugging(
          `ToolSearchTool: partial select — found: ${found.join(', ')}, missing: ${missing.join(', ')}`,
        )
      } else {
        logForDebugging(`ToolSearchTool: selected ${found.join(', ')}`)
      }
      logSearchOutcome(found, 'select')
      return buildSearchResult(found, query, deferredTools.length)
    }

    // Keyword search
    const matches = await searchToolsWithKeywords(
      query,
      deferredTools,
      tools,
      max_results,
    )

    logForDebugging(
      `ToolSearchTool: keyword search for "${query}", found ${matches.length} matches`,
    )

    logSearchOutcome(matches, 'keyword')

    // Include pending server info when search finds no matches
    if (matches.length === 0) {
      const pendingServers = getPendingServerNames()
      return buildSearchResult(
        matches,
        query,
        deferredTools.length,
        pendingServers,
      )
    }

    return buildSearchResult(matches, query, deferredTools.length)
  },
  renderToolUseMessage() {
    return null
  },
  userFacingName: () => '',
  /**
   * Returns a tool_result with tool_reference blocks.
   * This format works on 1P/Foundry. Bedrock/Vertex may not support
   * client-side tool_reference expansion yet.
   *
   * With no matches the result is a plain text message instead, which also
   * names any MCP servers recorded as still connecting.
   */
  mapToolResultToToolResultBlockParam(
    content: Output,
    toolUseID: string,
  ): ToolResultBlockParam {
    if (content.matches.length === 0) {
      let text = 'No matching deferred tools found'
      if (
        content.pending_mcp_servers &&
        content.pending_mcp_servers.length > 0
      ) {
        text += `. Some MCP servers are still connecting: ${content.pending_mcp_servers.join(', ')}. Their tools will become available shortly — try searching again.`
      }
      return {
        type: 'tool_result',
        tool_use_id: toolUseID,
        content: text,
      }
    }
    // tool_reference blocks are not part of the ToolResultBlockParam typing
    // used here, hence the cast.
    return {
      type: 'tool_result',
      tool_use_id: toolUseID,
      content: content.matches.map(name => ({
        type: 'tool_reference' as const,
        tool_name: name,
      })),
    } as unknown as ToolResultBlockParam
  },
} satisfies ToolDef<InputSchema, Output>)