/ src / query.ts
query.ts
  1  import {
  2    Message as APIAssistantMessage,
  3    MessageParam,
  4    ToolUseBlock,
  5  } from '@anthropic-ai/sdk/resources/index.mjs'
  6  import { UUID } from 'crypto'
  7  import type { Tool, ToolUseContext } from './Tool.js'
  8  import {
  9    messagePairValidForBinaryFeedback,
 10    shouldUseBinaryFeedback,
 11  } from './components/binary-feedback/utils.js'
 12  import { CanUseToolFn } from './hooks/useCanUseTool.js'
 13  import {
 14    formatSystemPromptWithContext,
 15    querySonnet,
 16  } from './services/claude.js'
 17  import { logEvent } from './services/statsig.js'
 18  import { all } from './utils/generators.js'
 19  import { logError } from './utils/log.js'
 20  import {
 21    createAssistantMessage,
 22    createProgressMessage,
 23    createToolResultStopMessage,
 24    createUserMessage,
 25    FullToolUseResult,
 26    INTERRUPT_MESSAGE,
 27    INTERRUPT_MESSAGE_FOR_TOOL_USE,
 28    NormalizedMessage,
 29    normalizeMessagesForAPI,
 30  } from './utils/messages.js'
 31  import { BashTool } from './tools/BashTool/BashTool.js'
 32  import { getCwd } from './utils/state.js'
 33  
 /** Cost and text of a completed model response. */
export type Response = { costUSD: number; response: string }

/** A user-role message. Tool results are sent back to the model in this shape. */
export type UserMessage = {
  message: MessageParam
  type: 'user'
  uuid: UUID
  /**
   * Structured tool output kept next to the API-facing content of a
   * tool-result message (presumably populated from `createUserMessage`'s
   * second argument -- confirm in utils/messages).
   */
  toolUseResult?: FullToolUseResult
}

/** An assistant-role message returned by the model, plus cost/latency metadata. */
export type AssistantMessage = {
  costUSD: number
  durationMs: number
  message: APIAssistantMessage
  type: 'assistant'
  uuid: UUID
  /** Marks a message that stands for a failed API call; the UI displays it as an error. */
  isApiErrorMessage?: boolean
}

/**
 * Outcome of fetching the next assistant message.
 *
 * `message` is `null` when the request was aborted. `shouldSkipPermissionCheck`
 * can only be `true` together with a real message; when set, tool uses from
 * that message run without asking `canUseTool` for permission.
 */
export type BinaryFeedbackResult =
  | { message: AssistantMessage | null; shouldSkipPermissionCheck: false }
  | { message: AssistantMessage; shouldSkipPermissionCheck: true }

/**
 * Intermediate output from a running tool, tied to `toolUseID` and to the IDs
 * of the other tool uses issued in the same assistant message. Progress
 * messages are not accumulated into the history sent back to the model.
 */
export type ProgressMessage = {
  content: AssistantMessage
  normalizedMessages: NormalizedMessage[]
  siblingToolUseIDs: Set<string>
  tools: Tool[]
  toolUseID: string
  type: 'progress'
  uuid: UUID
}

/**
 * Any message the query loop accepts in its history or yields to callers.
 * This is a plain union of the three message kinds -- there is no paired
 * message-and-response item.
 */
export type Message = UserMessage | AssistantMessage | ProgressMessage

/** Upper bound on how many tool uses `runToolsConcurrently` runs at the same time. */
const MAX_TOOL_USE_CONCURRENCY = 10
 69  
 /**
 * Fetches the next assistant message, optionally through binary feedback.
 *
 * Binary feedback is attempted only when `USER_TYPE` is `'ant'`, a
 * `getBinaryFeedbackResponse` callback was supplied, and
 * `shouldUseBinaryFeedback()` resolves truthy. In that mode two candidate
 * responses are requested in parallel; if both are usable, the callback picks
 * the winner (and may allow skipping the permission check).
 *
 * @returns `{ message: null }` when the abort signal fired while waiting for
 *   the model; otherwise the chosen assistant message.
 */
async function queryWithBinaryFeedback(
  toolUseContext: ToolUseContext,
  getAssistantResponse: () => Promise<AssistantMessage>,
  getBinaryFeedbackResponse?: (
    m1: AssistantMessage,
    m2: AssistantMessage,
  ) => Promise<BinaryFeedbackResult>,
): Promise<BinaryFeedbackResult> {
  const wasAborted = (): boolean => toolUseContext.abortController.signal.aborted
  const pick = (message: AssistantMessage | null): BinaryFeedbackResult => ({
    message,
    shouldSkipPermissionCheck: false,
  })

  // Evaluated left to right with short-circuiting, so shouldUseBinaryFeedback()
  // is only awaited when the first two conditions hold.
  const binaryFeedbackEnabled =
    process.env.USER_TYPE === 'ant' &&
    Boolean(getBinaryFeedbackResponse) &&
    Boolean(await shouldUseBinaryFeedback())

  if (!binaryFeedbackEnabled || !getBinaryFeedbackResponse) {
    const single = await getAssistantResponse()
    return wasAborted() ? pick(null) : pick(single)
  }

  const [m1, m2] = await Promise.all([
    getAssistantResponse(),
    getAssistantResponse(),
  ])
  if (wasAborted()) {
    return pick(null)
  }

  // An errored candidate can't be compared, so fall back to the other one.
  // When m2 failed we return m1 even if it failed too -- the UI renders it as
  // an error exactly as in the single-response path.
  if (m2.isApiErrorMessage) return pick(m1)
  if (m1.isApiErrorMessage) return pick(m2)
  if (!messagePairValidForBinaryFeedback(m1, m2)) return pick(m1)

  return await getBinaryFeedbackResponse(m1, m2)
}
110  
/**
 * Runs one turn of the agent loop and keeps recursing while the model asks
 * for tools.
 *
 * Each turn fetches an assistant message (optionally via binary feedback) and
 * yields it. It then runs the message's `tool_use` blocks -- concurrently when
 * every requested tool is read-only, serially otherwise -- and yields their
 * progress and result messages. Finally it recurses with the assistant message
 * and the tool results appended to the history. If the abort signal fires
 * while waiting for the model or while tools run, an interrupt message is
 * yielded and the loop stops.
 *
 * The rules of thinking are lengthy and fortuitous. They require plenty of thinking
 * of most long duration and deep meditation for a wizard to wrap one's noggin around.
 *
 * The rules follow:
 * 1. A message that contains a thinking or redacted_thinking block must be part of a query whose max_thinking_length > 0
 * 2. A thinking block may not be the last message in a block
 * 3. Thinking blocks must be preserved for the duration of an assistant trajectory (a single turn, or if that turn includes a tool_use block then also its subsequent tool_result and the following assistant message)
 *
 * Heed these rules well, young wizard. For they are the rules of thinking, and
 * the rules of thinking are the rules of the universe. If ye does not heed these
 * rules, ye will be punished with an entire day of debugging and hair pulling.
 */
export async function* query(
  messages: Message[],
  systemPrompt: string[],
  context: { [k: string]: string },
  canUseTool: CanUseToolFn,
  toolUseContext: ToolUseContext,
  getBinaryFeedbackResponse?: (
    m1: AssistantMessage,
    m2: AssistantMessage,
  ) => Promise<BinaryFeedbackResult>,
): AsyncGenerator<Message, void> {
  const fullSystemPrompt = formatSystemPromptWithContext(systemPrompt, context)

  function getAssistantResponse() {
    return querySonnet(
      normalizeMessagesForAPI(messages),
      fullSystemPrompt,
      toolUseContext.options.maxThinkingTokens,
      toolUseContext.options.tools,
      toolUseContext.abortController.signal,
      {
        dangerouslySkipPermissions:
          toolUseContext.options.dangerouslySkipPermissions ?? false,
        model: toolUseContext.options.slowAndCapableModel,
        prependCLISysprompt: true,
      },
    )
  }

  const result = await queryWithBinaryFeedback(
    toolUseContext,
    getAssistantResponse,
    getBinaryFeedbackResponse,
  )

  if (result.message === null) {
    yield createAssistantMessage(INTERRUPT_MESSAGE)
    return
  }

  const assistantMessage = result.message
  const shouldSkipPermissionCheck = result.shouldSkipPermissionCheck

  yield assistantMessage

  // @see https://docs.anthropic.com/en/docs/build-with-claude/tool-use
  // Note: stop_reason === 'tool_use' is unreliable -- it's not always set correctly,
  // so inspect the content blocks instead.
  const toolUseMessages = assistantMessage.message.content.filter(
    (block): block is ToolUseBlock => block.type === 'tool_use',
  )

  // If there's no more tool use, we're done
  if (toolUseMessages.length === 0) {
    return
  }

  // Prefer to run tools concurrently, if we can
  // TODO: tighten up the logic -- we can run concurrently much more often than this
  const canRunConcurrently = toolUseMessages.every(msg =>
    toolUseContext.options.tools.find(t => t.name === msg.name)?.isReadOnly(),
  )
  const runTools = canRunConcurrently ? runToolsConcurrently : runToolsSerially

  const toolResults: UserMessage[] = []
  for await (const message of runTools(
    toolUseMessages,
    assistantMessage,
    canUseTool,
    toolUseContext,
    shouldSkipPermissionCheck,
  )) {
    yield message
    // progress messages are not sent to the server, so don't need to be accumulated for the next turn
    if (message.type === 'user') {
      toolResults.push(message)
    }
  }

  if (toolUseContext.abortController.signal.aborted) {
    yield createAssistantMessage(INTERRUPT_MESSAGE_FOR_TOOL_USE)
    return
  }

  // Concurrent execution may yield results in completion order; restore the
  // order of the tool_use blocks. Results are matched on `tool_use_id` --
  // tool_result blocks carry no `id` field. Unmatched results keep their
  // relative order (sort is stable) and go last.
  const toolUseOrder = new Map(
    toolUseMessages.map((toolUse, index): [string, number] => [toolUse.id, index]),
  )
  const rankOf = (toolResult: UserMessage): number => {
    const id = toolUseIDOfResult(toolResult)
    return (
      (id === undefined ? undefined : toolUseOrder.get(id)) ??
      toolUseMessages.length
    )
  }
  const orderedToolResults = [...toolResults].sort((a, b) => rankOf(a) - rankOf(b))

  yield* query(
    [...messages, assistantMessage, ...orderedToolResults],
    systemPrompt,
    context,
    canUseTool,
    toolUseContext,
    getBinaryFeedbackResponse,
  )
}

/** Returns the `tool_use_id` answered by a tool-result user message, if its first block is a tool_result. */
function toolUseIDOfResult(message: UserMessage): string | undefined {
  const { content } = message.message
  if (typeof content === 'string') {
    return undefined
  }
  const first = content[0]
  return first?.type === 'tool_result' ? first.tool_use_id : undefined
}
243  
/**
 * Starts every tool use in parallel (at most MAX_TOOL_USE_CONCURRENCY at once,
 * via `all`) and re-yields their messages as `all` produces them -- presumably
 * interleaved in completion order, so callers must not rely on ordering.
 *
 * Each tool use receives its own Set containing the IDs of all tool uses in
 * this batch.
 */
async function* runToolsConcurrently(
  toolUseMessages: ToolUseBlock[],
  assistantMessage: AssistantMessage,
  canUseTool: CanUseToolFn,
  toolUseContext: ToolUseContext,
  shouldSkipPermissionCheck?: boolean,
): AsyncGenerator<Message, void> {
  const batchIDs = toolUseMessages.map(({ id }) => id)
  const pending = toolUseMessages.map(toolUse =>
    runToolUse(
      toolUse,
      new Set(batchIDs),
      assistantMessage,
      canUseTool,
      toolUseContext,
      shouldSkipPermissionCheck,
    ),
  )
  yield* all(pending, MAX_TOOL_USE_CONCURRENCY)
}
265  
/**
 * Runs the tool uses one after another, in the order given, delegating fully
 * to each `runToolUse` generator before starting the next.
 *
 * Each tool use receives its own Set containing the IDs of all tool uses in
 * this batch.
 */
async function* runToolsSerially(
  toolUseMessages: ToolUseBlock[],
  assistantMessage: AssistantMessage,
  canUseTool: CanUseToolFn,
  toolUseContext: ToolUseContext,
  shouldSkipPermissionCheck?: boolean,
): AsyncGenerator<Message, void> {
  const batchIDs = toolUseMessages.map(({ id }) => id)
  for (const current of toolUseMessages) {
    const toolRun = runToolUse(
      current,
      new Set(batchIDs),
      assistantMessage,
      canUseTool,
      toolUseContext,
      shouldSkipPermissionCheck,
    )
    yield* toolRun
  }
}
284  
/**
 * Executes a single `tool_use` block and yields the messages it produces:
 * optional progress messages followed by a `tool_result` user message for
 * `toolUse.id`.
 *
 * - Unknown tool name: yields an error tool_result and logs `tengu_tool_use_error`.
 * - Abort signal already fired: yields the block built by
 *   `createToolResultStopMessage` and logs `tengu_tool_use_cancelled`.
 * - Otherwise: delegates to `checkPermissionsAndCallTool`.
 *
 * If anything throws before a tool_result has been yielded, the error is
 * logged and reported to the model as an error tool_result. This way an
 * exception does not leave the tool_use block unanswered in the next API
 * request.
 */
export async function* runToolUse(
  toolUse: ToolUseBlock,
  siblingToolUseIDs: Set<string>,
  assistantMessage: AssistantMessage,
  canUseTool: CanUseToolFn,
  toolUseContext: ToolUseContext,
  shouldSkipPermissionCheck?: boolean,
): AsyncGenerator<Message, void> {
  const toolName = toolUse.name
  const tool = toolUseContext.options.tools.find(t => t.name === toolName)

  // Check if the tool exists
  if (!tool) {
    logEvent('tengu_tool_use_error', {
      error: `No such tool available: ${toolName}`,
      messageID: assistantMessage.message.id,
      toolName,
      toolUseID: toolUse.id,
    })
    yield createUserMessage([
      {
        type: 'tool_result',
        content: `Error: No such tool available: ${toolName}`,
        is_error: true,
        tool_use_id: toolUse.id,
      },
    ])
    return
  }

  const toolInput = toolUse.input as { [key: string]: boolean | string | number }

  // Tracks whether a tool_result already went out, so the catch below never
  // emits a second result for the same tool_use_id.
  let emittedToolResult = false

  try {
    if (toolUseContext.abortController.signal.aborted) {
      logEvent('tengu_tool_use_cancelled', {
        toolName: tool.name,
        toolUseID: toolUse.id,
      })
      const message = createUserMessage([
        createToolResultStopMessage(toolUse.id),
      ])
      emittedToolResult = true
      yield message
      return
    }

    for await (const message of checkPermissionsAndCallTool(
      tool,
      toolUse.id,
      siblingToolUseIDs,
      toolInput,
      toolUseContext,
      canUseTool,
      assistantMessage,
      shouldSkipPermissionCheck,
    )) {
      if (message.type === 'user') {
        emittedToolResult = true
      }
      yield message
    }
  } catch (e) {
    logError(e)
    if (!emittedToolResult) {
      yield createUserMessage([
        {
          type: 'tool_result',
          content: formatError(e),
          is_error: true,
          tool_use_id: toolUse.id,
        },
      ])
    }
  }
}
346  
/**
 * Canonicalizes a tool's (already schema-validated) input before validation,
 * permission checks and execution.
 *
 * For BashTool, a leading `cd <cwd> && ` is removed -- presumably redundant
 * because the command already runs in the current working directory -- and a
 * falsy `timeout` is dropped. Other tools' input is returned unchanged.
 *
 * TODO: Generalize this to all tools
 */
export function normalizeToolInput(
  tool: Tool,
  input: { [key: string]: boolean | string | number },
): { [key: string]: boolean | string | number } {
  switch (tool) {
    case BashTool: {
      const { command, timeout } = BashTool.inputSchema.parse(input) // already validated upstream, won't throw
      // Strip the prefix only at the very start. String#replace would also
      // remove a `cd <cwd> && ` from the middle of a compound command (e.g.
      // after an earlier `cd elsewhere &&`), silently changing the directory
      // in which later steps run.
      const redundantCd = `cd ${getCwd()} && `
      return {
        command: command.startsWith(redundantCd)
          ? command.slice(redundantCd.length)
          : command,
        ...(timeout ? { timeout } : {}),
      }
    }
    default:
      return input
  }
}
364  
/**
 * Validates, authorizes and executes one tool call, yielding progress messages
 * and a `tool_result` user message for `toolUseID`.
 *
 * Pipeline (each failing step yields an `is_error` tool_result and stops):
 *  1. Structural check of `input` with `tool.inputSchema.safeParse`
 *     (surprisingly, the model is not great at generating valid input).
 *  2. `normalizeToolInput`, then the tool's optional `validateInput`.
 *  3. Permission check via `canUseTool`, skipped when
 *     `shouldSkipPermissionCheck` is truthy.
 *  4. `tool.call`: 'progress' items become progress messages; the first
 *     'result' item becomes the tool_result and ends iteration. Exceptions
 *     raised while calling or iterating the tool are formatted with
 *     `formatError` and reported as an error tool_result.
 *
 * NOTE(review): if `tool.call` finishes without producing a 'result' item, no
 * tool_result is yielded for `toolUseID` -- confirm every tool guarantees one.
 */
async function* checkPermissionsAndCallTool(
  tool: Tool,
  toolUseID: string,
  siblingToolUseIDs: Set<string>,
  input: { [key: string]: boolean | string | number },
  context: ToolUseContext,
  canUseTool: CanUseToolFn,
  assistantMessage: AssistantMessage,
  shouldSkipPermissionCheck?: boolean,
): AsyncGenerator<UserMessage | ProgressMessage, void> {
  // Every failure path reports back to the model with this error-flavoured tool_result.
  const errorResult = (content: string) =>
    createUserMessage([
      {
        type: 'tool_result',
        content,
        is_error: true,
        tool_use_id: toolUseID,
      },
    ])

  // 1. Structural validation with zod.
  const parsed = tool.inputSchema.safeParse(input)
  if (!parsed.success) {
    const schemaError = `InputValidationError: ${parsed.error.message}`
    logEvent('tengu_tool_use_error', {
      error: schemaError,
      messageID: assistantMessage.message.id,
      toolName: tool.name,
      toolInput: JSON.stringify(input).slice(0, 200),
    })
    yield errorResult(schemaError)
    return
  }

  const normalizedInput = normalizeToolInput(tool, input)

  // 2. Tool-specific value validation (not every tool defines one).
  const validation = await tool.validateInput?.(
    normalizedInput as never,
    context,
  )
  if (validation && validation.result === false) {
    logEvent('tengu_tool_use_error', {
      error: validation.message.slice(0, 2000),
      messageID: assistantMessage.message.id,
      toolName: tool.name,
      toolInput: JSON.stringify(input).slice(0, 200),
      ...(validation.meta ?? {}),
    })
    yield errorResult(validation.message)
    return
  }

  // 3. Ask for permission unless the caller already obtained approval.
  if (!shouldSkipPermissionCheck) {
    const permission = await canUseTool(
      tool,
      normalizedInput,
      context,
      assistantMessage,
    )
    if (permission.result === false) {
      yield errorResult(permission.message)
      return
    }
  }

  // 4. Run the tool.
  try {
    for await (const item of tool.call(
      normalizedInput as never,
      context,
      canUseTool,
    )) {
      if (item.type === 'result') {
        logEvent('tengu_tool_use_success', {
          messageID: assistantMessage.message.id,
          toolName: tool.name,
        })
        yield createUserMessage(
          [
            {
              type: 'tool_result',
              content: item.resultForAssistant,
              tool_use_id: toolUseID,
            },
          ],
          {
            data: item.data,
            resultForAssistant: item.resultForAssistant,
          },
        )
        // The first result is final; returning also closes the tool's generator.
        return
      }
      if (item.type === 'progress') {
        logEvent('tengu_tool_use_progress', {
          messageID: assistantMessage.message.id,
          toolName: tool.name,
        })
        yield createProgressMessage(
          toolUseID,
          siblingToolUseIDs,
          item.content,
          item.normalizedMessages,
          item.tools,
        )
      }
    }
  } catch (error) {
    const content = formatError(error)
    logError(error)
    logEvent('tengu_tool_use_error', {
      error: content.slice(0, 2000),
      messageID: assistantMessage.message.id,
      toolName: tool.name,
      toolInput: JSON.stringify(input).slice(0, 1000),
    })
    yield errorResult(content)
  }
}
496  
/**
 * Renders a thrown value as text suitable for an error `tool_result`.
 *
 * For `Error` instances, the message and any string `stderr` / `stdout`
 * properties on the error are joined with newlines, skipping empty parts. If
 * all of them are empty, `String(error)` (e.g. `"Error"`, `"TypeError"`) is
 * used so the result is not an empty string. Text longer than `maxLength`
 * keeps the first and last `floor(maxLength / 2)` characters around a
 * truncation marker.
 *
 * @param error - Any thrown value; non-`Error` values are passed to `String()`.
 * @param maxLength - Length above which the text is truncated (default 10000).
 * @returns The formatted, possibly truncated, error text.
 */
function formatError(error: unknown, maxLength = 10000): string {
  if (!(error instanceof Error)) {
    return String(error)
  }
  const parts = [error.message]
  if ('stderr' in error && typeof error.stderr === 'string') {
    parts.push(error.stderr)
  }
  if ('stdout' in error && typeof error.stdout === 'string') {
    parts.push(error.stdout)
  }
  // An Error with no message and no output would otherwise yield '', i.e. an
  // empty error tool_result; Error#toString() at least includes the name.
  const fullMessage = parts.filter(Boolean).join('\n') || String(error)
  if (fullMessage.length <= maxLength) {
    return fullMessage
  }
  const halfLength = Math.max(0, Math.floor(maxLength / 2))
  const start = fullMessage.slice(0, halfLength)
  // slice(-0) would return the whole string, so guard the degenerate case.
  const end = halfLength > 0 ? fullMessage.slice(-halfLength) : ''
  const truncated = fullMessage.length - 2 * halfLength
  return `${start}\n\n... [${truncated} characters truncated] ...\n\n${end}`
}