query.ts
import {
  Message as APIAssistantMessage,
  MessageParam,
  ToolUseBlock,
} from '@anthropic-ai/sdk/resources/index.mjs'
import { UUID } from 'crypto'
import type { Tool, ToolUseContext } from './Tool.js'
import {
  messagePairValidForBinaryFeedback,
  shouldUseBinaryFeedback,
} from './components/binary-feedback/utils.js'
import { CanUseToolFn } from './hooks/useCanUseTool.js'
import {
  formatSystemPromptWithContext,
  querySonnet,
} from './services/claude.js'
import { logEvent } from './services/statsig.js'
import { all } from './utils/generators.js'
import { logError } from './utils/log.js'
import {
  createAssistantMessage,
  createProgressMessage,
  createToolResultStopMessage,
  createUserMessage,
  FullToolUseResult,
  INTERRUPT_MESSAGE,
  INTERRUPT_MESSAGE_FOR_TOOL_USE,
  NormalizedMessage,
  normalizeMessagesForAPI,
} from './utils/messages.js'
import { BashTool } from './tools/BashTool/BashTool.js'
import { getCwd } from './utils/state.js'

export type Response = { costUSD: number; response: string }

/** A user-role message; tool results are carried back to the model in this shape. */
export type UserMessage = {
  message: MessageParam
  type: 'user'
  uuid: UUID
  toolUseResult?: FullToolUseResult
}

/** A message produced by the model, annotated with cost and latency. */
export type AssistantMessage = {
  costUSD: number
  durationMs: number
  message: APIAssistantMessage
  type: 'assistant'
  uuid: UUID
  isApiErrorMessage?: boolean
}

/**
 * Outcome of `queryWithBinaryFeedback`. `message` may only be `null` when
 * `shouldSkipPermissionCheck` is `false`.
 */
export type BinaryFeedbackResult =
  | { message: AssistantMessage | null; shouldSkipPermissionCheck: false }
  | { message: AssistantMessage; shouldSkipPermissionCheck: true }

/** An intermediate update emitted by a tool while it is still running. */
export type ProgressMessage = {
  content: AssistantMessage
  normalizedMessages: NormalizedMessage[]
  siblingToolUseIDs: Set<string>
  tools: Tool[]
  toolUseID: string
  type: 'progress'
  uuid: UUID
}

// Each array item is either a single message or a message-and-response pair
export type Message = UserMessage | AssistantMessage | ProgressMessage

/** Upper bound handed to `all()` when tool uses are run concurrently. */
const MAX_TOOL_USE_CONCURRENCY = 10

/**
 * Fetches the next assistant message, optionally sampling two candidates and
 * delegating the choice to `getBinaryFeedbackResponse`.
 *
 * A single request is made when `process.env.USER_TYPE` is not `'ant'`, when
 * no feedback callback is supplied, or when `shouldUseBinaryFeedback()`
 * resolves falsy. Otherwise two requests run in parallel; the callback is only
 * consulted when neither candidate is an API error and the pair passes
 * `messagePairValidForBinaryFeedback`.
 *
 * @returns A message if we got one, or `message: null` if the abort signal
 *   fired while waiting (i.e. the user cancelled).
 */
async function queryWithBinaryFeedback(
  toolUseContext: ToolUseContext,
  getAssistantResponse: () => Promise<AssistantMessage>,
  getBinaryFeedbackResponse?: (
    m1: AssistantMessage,
    m2: AssistantMessage,
  ) => Promise<BinaryFeedbackResult>,
): Promise<BinaryFeedbackResult> {
  if (
    process.env.USER_TYPE !== 'ant' ||
    !getBinaryFeedbackResponse ||
    !(await shouldUseBinaryFeedback())
  ) {
    const assistantMessage = await getAssistantResponse()
    if (toolUseContext.abortController.signal.aborted) {
      return { message: null, shouldSkipPermissionCheck: false }
    }
    return { message: assistantMessage, shouldSkipPermissionCheck: false }
  }
  const [m1, m2] = await Promise.all([
    getAssistantResponse(),
    getAssistantResponse(),
  ])
  if (toolUseContext.abortController.signal.aborted) {
    return { message: null, shouldSkipPermissionCheck: false }
  }
  if (m2.isApiErrorMessage) {
    // If m2 is an error, we might as well return m1, even if it's also an error --
    // the UI will display it as an error as it would in the non-feedback path.
    return { message: m1, shouldSkipPermissionCheck: false }
  }
  if (m1.isApiErrorMessage) {
    return { message: m2, shouldSkipPermissionCheck: false }
  }
  if (!messagePairValidForBinaryFeedback(m1, m2)) {
    return { message: m1, shouldSkipPermissionCheck: false }
  }
  return await getBinaryFeedbackResponse(m1, m2)
}

/**
 * Returns the `tool_use_id` carried by the first content block of a
 * tool-result user message, or `undefined` when the message has string
 * content or its first block is not a `tool_result`.
 */
function getToolResultToolUseID(message: UserMessage): string | undefined {
  const { content } = message.message
  if (typeof content === 'string') {
    return undefined
  }
  const firstBlock = content[0]
  return firstBlock?.type === 'tool_result' ? firstBlock.tool_use_id : undefined
}

/**
 * Returns a copy of `toolResults` ordered like the `tool_use` blocks they
 * answer. Results whose ID is unknown sort first (position -1), keeping their
 * relative order.
 */
function orderToolResults(
  toolResults: readonly UserMessage[],
  toolUseMessages: readonly ToolUseBlock[],
): UserMessage[] {
  const positionByID = new Map(
    toolUseMessages.map((toolUse, index) => [toolUse.id, index] as const),
  )
  const positionOf = (message: UserMessage): number => {
    const id = getToolResultToolUseID(message)
    return (id === undefined ? undefined : positionByID.get(id)) ?? -1
  }
  // Sort a copy so the caller's accumulator is left untouched.
  return [...toolResults].sort((a, b) => positionOf(a) - positionOf(b))
}

/**
 * The rules of thinking are lengthy and fortuitous. They require plenty of thinking
 * of most long duration and deep meditation for a wizard to wrap one's noggin around.
 *
 * The rules follow:
 * 1. A message that contains a thinking or redacted_thinking block must be part of a query whose max_thinking_length > 0
 * 2. A thinking block may not be the last message in a block
 * 3. Thinking blocks must be preserved for the duration of an assistant trajectory (a single turn, or if that turn includes a tool_use block then also its subsequent tool_result and the following assistant message)
 *
 * Heed these rules well, young wizard. For they are the rules of thinking, and
 * the rules of thinking are the rules of the universe. If ye does not heed these
 * rules, ye will be punished with an entire day of debugging and hair pulling.
 *
 * Runs one assistant turn and, if that turn requested tools, executes them and
 * recurses with the results appended to the conversation. Every assistant,
 * progress and tool-result message is yielded as it becomes available.
 *
 * @param messages Conversation so far.
 * @param systemPrompt System prompt parts, combined with `context` before sending.
 * @param context Key/value context merged into the system prompt.
 * @param canUseTool Permission callback consulted before each tool call.
 * @param toolUseContext Tools, options and the abort controller for this query.
 * @param getBinaryFeedbackResponse Optional chooser between two sampled responses.
 */
export async function* query(
  messages: Message[],
  systemPrompt: string[],
  context: { [k: string]: string },
  canUseTool: CanUseToolFn,
  toolUseContext: ToolUseContext,
  getBinaryFeedbackResponse?: (
    m1: AssistantMessage,
    m2: AssistantMessage,
  ) => Promise<BinaryFeedbackResult>,
): AsyncGenerator<Message, void> {
  const fullSystemPrompt = formatSystemPromptWithContext(systemPrompt, context)

  function getAssistantResponse() {
    return querySonnet(
      normalizeMessagesForAPI(messages),
      fullSystemPrompt,
      toolUseContext.options.maxThinkingTokens,
      toolUseContext.options.tools,
      toolUseContext.abortController.signal,
      {
        dangerouslySkipPermissions:
          toolUseContext.options.dangerouslySkipPermissions ?? false,
        model: toolUseContext.options.slowAndCapableModel,
        prependCLISysprompt: true,
      },
    )
  }

  const result = await queryWithBinaryFeedback(
    toolUseContext,
    getAssistantResponse,
    getBinaryFeedbackResponse,
  )

  if (result.message === null) {
    yield createAssistantMessage(INTERRUPT_MESSAGE)
    return
  }

  const assistantMessage = result.message
  const shouldSkipPermissionCheck = result.shouldSkipPermissionCheck

  yield assistantMessage

  // @see https://docs.anthropic.com/en/docs/build-with-claude/tool-use
  // Note: stop_reason === 'tool_use' is unreliable -- it's not always set correctly
  const toolUseMessages = assistantMessage.message.content.filter(
    (block): block is ToolUseBlock => block.type === 'tool_use',
  )

  // If there's no more tool use, we're done
  if (!toolUseMessages.length) {
    return
  }

  const toolResults: UserMessage[] = []

  // Prefer to run tools concurrently, if we can
  // TODO: tighten up the logic -- we can run concurrently much more often than this
  const allToolsAreReadOnly = toolUseMessages.every(msg =>
    toolUseContext.options.tools.find(t => t.name === msg.name)?.isReadOnly(),
  )
  const runTools = allToolsAreReadOnly ? runToolsConcurrently : runToolsSerially

  for await (const message of runTools(
    toolUseMessages,
    assistantMessage,
    canUseTool,
    toolUseContext,
    shouldSkipPermissionCheck,
  )) {
    yield message
    // progress messages are not sent to the server, so don't need to be accumulated for the next turn
    if (message.type === 'user') {
      toolResults.push(message)
    }
  }

  if (toolUseContext.abortController.signal.aborted) {
    yield createAssistantMessage(INTERRUPT_MESSAGE_FOR_TOOL_USE)
    return
  }

  // Sort toolResults to match the order of toolUseMessages. Results are
  // matched on the `tool_use_id` of their tool_result block; concurrent runs
  // can complete in any order.
  const orderedToolResults = orderToolResults(toolResults, toolUseMessages)

  yield* query(
    [...messages, assistantMessage, ...orderedToolResults],
    systemPrompt,
    context,
    canUseTool,
    toolUseContext,
    getBinaryFeedbackResponse,
  )
}

/**
 * Runs every tool use through `all()`, bounded by `MAX_TOOL_USE_CONCURRENCY`,
 * and yields whatever the individual `runToolUse` generators produce.
 */
async function* runToolsConcurrently(
  toolUseMessages: ToolUseBlock[],
  assistantMessage: AssistantMessage,
  canUseTool: CanUseToolFn,
  toolUseContext: ToolUseContext,
  shouldSkipPermissionCheck?: boolean,
): AsyncGenerator<Message, void> {
  yield* all(
    toolUseMessages.map(toolUse =>
      runToolUse(
        toolUse,
        new Set(toolUseMessages.map(_ => _.id)),
        assistantMessage,
        canUseTool,
        toolUseContext,
        shouldSkipPermissionCheck,
      ),
    ),
    MAX_TOOL_USE_CONCURRENCY,
  )
}

/**
 * Runs the tool uses one after another, in the order given, yielding each
 * tool's messages before starting the next tool.
 */
async function* runToolsSerially(
  toolUseMessages: ToolUseBlock[],
  assistantMessage: AssistantMessage,
  canUseTool: CanUseToolFn,
  toolUseContext: ToolUseContext,
  shouldSkipPermissionCheck?: boolean,
): AsyncGenerator<Message, void> {
  for (const toolUse of toolUseMessages) {
    yield* runToolUse(
      toolUse,
      new Set(toolUseMessages.map(_ => _.id)),
      assistantMessage,
      canUseTool,
      toolUseContext,
      shouldSkipPermissionCheck,
    )
  }
}

/**
 * Executes a single `tool_use` block and yields its progress and result
 * messages.
 *
 * An error `tool_result` is yielded when the named tool does not exist, and a
 * stop `tool_result` when the abort signal has already fired. Any error that
 * escapes `checkPermissionsAndCallTool` is logged and reported as an error
 * `tool_result`, so the tool use is not left without a result.
 *
 * @param toolUse The block requesting the tool call.
 * @param siblingToolUseIDs IDs of all tool uses in the same assistant message.
 * @param assistantMessage The assistant message that contained `toolUse`.
 * @param canUseTool Permission callback.
 * @param toolUseContext Tools, options and the abort controller.
 * @param shouldSkipPermissionCheck When truthy, `canUseTool` is not consulted.
 */
export async function* runToolUse(
  toolUse: ToolUseBlock,
  siblingToolUseIDs: Set<string>,
  assistantMessage: AssistantMessage,
  canUseTool: CanUseToolFn,
  toolUseContext: ToolUseContext,
  shouldSkipPermissionCheck?: boolean,
): AsyncGenerator<Message, void> {
  const toolName = toolUse.name
  const tool = toolUseContext.options.tools.find(t => t.name === toolName)

  // Check if the tool exists
  if (!tool) {
    logEvent('tengu_tool_use_error', {
      error: `No such tool available: ${toolName}`,
      messageID: assistantMessage.message.id,
      toolName,
      toolUseID: toolUse.id,
    })
    yield createUserMessage([
      {
        type: 'tool_result',
        content: `Error: No such tool available: ${toolName}`,
        is_error: true,
        tool_use_id: toolUse.id,
      },
    ])
    return
  }

  // Not validated yet: the tool's schema is applied in checkPermissionsAndCallTool.
  const toolInput = toolUse.input as {
    [key: string]: boolean | string | number
  }

  try {
    if (toolUseContext.abortController.signal.aborted) {
      logEvent('tengu_tool_use_cancelled', {
        toolName: tool.name,
        toolUseID: toolUse.id,
      })
      const message = createUserMessage([
        createToolResultStopMessage(toolUse.id),
      ])
      yield message
      return
    }

    for await (const message of checkPermissionsAndCallTool(
      tool,
      toolUse.id,
      siblingToolUseIDs,
      toolInput,
      toolUseContext,
      canUseTool,
      assistantMessage,
      shouldSkipPermissionCheck,
    )) {
      yield message
    }
  } catch (e) {
    logError(e)
    // Errors from the tool call itself are handled inside
    // checkPermissionsAndCallTool, so anything arriving here was thrown before
    // a result was produced. Report it instead of dropping it.
    yield createUserMessage([
      {
        type: 'tool_result',
        content: formatError(e),
        is_error: true,
        tool_use_id: toolUse.id,
      },
    ])
  }
}

/**
 * Rewrites a tool's input before validation, permission checks and execution.
 * Only `BashTool` input is changed: the text `cd <cwd> && ` is removed from
 * the command and `timeout` is kept only when truthy. All other tools get
 * their input back unchanged.
 *
 * NOTE(review): `String.prototype.replace` with a string pattern removes the
 * first occurrence wherever it appears, not only a leading prefix -- confirm
 * that is intended.
 */
// TODO: Generalize this to all tools
export function normalizeToolInput(
  tool: Tool,
  input: { [key: string]: boolean | string | number },
): { [key: string]: boolean | string | number } {
  switch (tool) {
    case BashTool: {
      const { command, timeout } = BashTool.inputSchema.parse(input) // already validated upstream, won't throw
      return {
        command: command.replace(`cd ${getCwd()} && `, ''),
        ...(timeout ? { timeout } : {}),
      }
    }
    default:
      return input
  }
}

/**
 * Validates the input, checks permissions and then calls the tool.
 *
 * Steps, each of which ends the generator with an error `tool_result` on
 * failure:
 * 1. schema validation via `tool.inputSchema.safeParse`
 * 2. the tool's own optional `validateInput`
 * 3. `canUseTool`, unless `shouldSkipPermissionCheck` is truthy
 * 4. `tool.call`, relaying progress messages until the first result
 *
 * Errors thrown while calling the tool are logged and converted into an error
 * `tool_result`; errors thrown in steps 1-3 propagate to the caller.
 */
async function* checkPermissionsAndCallTool(
  tool: Tool,
  toolUseID: string,
  siblingToolUseIDs: Set<string>,
  input: { [key: string]: boolean | string | number },
  context: ToolUseContext,
  canUseTool: CanUseToolFn,
  assistantMessage: AssistantMessage,
  shouldSkipPermissionCheck?: boolean,
): AsyncGenerator<UserMessage | ProgressMessage, void> {
  // Validate input types with zod
  // (surprisingly, the model is not great at generating valid input)
  const isValidInput = tool.inputSchema.safeParse(input)
  if (!isValidInput.success) {
    logEvent('tengu_tool_use_error', {
      error: `InputValidationError: ${isValidInput.error.message}`,
      messageID: assistantMessage.message.id,
      toolName: tool.name,
      toolInput: JSON.stringify(input).slice(0, 200),
    })
    yield createUserMessage([
      {
        type: 'tool_result',
        content: `InputValidationError: ${isValidInput.error.message}`,
        is_error: true,
        tool_use_id: toolUseID,
      },
    ])
    return
  }

  const normalizedInput = normalizeToolInput(tool, input)

  // Validate input values. Each tool has its own validation logic
  const isValidCall = await tool.validateInput?.(
    normalizedInput as never,
    context,
  )
  if (isValidCall?.result === false) {
    logEvent('tengu_tool_use_error', {
      error: isValidCall?.message.slice(0, 2000),
      messageID: assistantMessage.message.id,
      toolName: tool.name,
      toolInput: JSON.stringify(input).slice(0, 200),
      ...(isValidCall?.meta ?? {}),
    })
    yield createUserMessage([
      {
        type: 'tool_result',
        content: isValidCall!.message,
        is_error: true,
        tool_use_id: toolUseID,
      },
    ])
    return
  }

  // Check whether we have permission to use the tool,
  // and ask the user for permission if we don't
  const permissionResult = shouldSkipPermissionCheck
    ? ({ result: true } as const)
    : await canUseTool(tool, normalizedInput, context, assistantMessage)
  if (permissionResult.result === false) {
    yield createUserMessage([
      {
        type: 'tool_result',
        content: permissionResult.message,
        is_error: true,
        tool_use_id: toolUseID,
      },
    ])
    return
  }

  // Call the tool
  try {
    const generator = tool.call(normalizedInput as never, context, canUseTool)
    for await (const result of generator) {
      switch (result.type) {
        case 'result':
          logEvent('tengu_tool_use_success', {
            messageID: assistantMessage.message.id,
            toolName: tool.name,
          })
          yield createUserMessage(
            [
              {
                type: 'tool_result',
                content: result.resultForAssistant,
                tool_use_id: toolUseID,
              },
            ],
            {
              data: result.data,
              resultForAssistant: result.resultForAssistant,
            },
          )
          // The first result ends the call; anything the tool would emit
          // afterwards is not consumed.
          return
        case 'progress':
          logEvent('tengu_tool_use_progress', {
            messageID: assistantMessage.message.id,
            toolName: tool.name,
          })
          yield createProgressMessage(
            toolUseID,
            siblingToolUseIDs,
            result.content,
            result.normalizedMessages,
            result.tools,
          )
          break
      }
    }
  } catch (error) {
    const content = formatError(error)
    logError(error)
    logEvent('tengu_tool_use_error', {
      error: content.slice(0, 2000),
      messageID: assistantMessage.message.id,
      toolName: tool.name,
      toolInput: JSON.stringify(input).slice(0, 1000),
    })
    yield createUserMessage([
      {
        type: 'tool_result',
        content,
        is_error: true,
        tool_use_id: toolUseID,
      },
    ])
  }
}

/**
 * Renders a thrown value as text for a `tool_result`.
 *
 * Non-`Error` values are stringified. For `Error`s the message is followed by
 * any string `stderr` and `stdout` properties, one per line, with empty parts
 * dropped. Output longer than 10000 characters is cut to its first and last
 * 5000 characters around a marker stating how many characters were removed.
 */
function formatError(error: unknown): string {
  if (!(error instanceof Error)) {
    return String(error)
  }
  const maxLength = 10000
  const halfLength = maxLength / 2

  const parts = [error.message]
  if ('stderr' in error && typeof error.stderr === 'string') {
    parts.push(error.stderr)
  }
  if ('stdout' in error && typeof error.stdout === 'string') {
    parts.push(error.stdout)
  }
  const fullMessage = parts.filter(Boolean).join('\n')
  if (fullMessage.length <= maxLength) {
    return fullMessage
  }
  const start = fullMessage.slice(0, halfLength)
  const end = fullMessage.slice(-halfLength)
  return `${start}\n\n... [${fullMessage.length - maxLength} characters truncated] ...\n\n${end}`
}