// REPL.tsx
import { ToolUseBlockParam } from '@anthropic-ai/sdk/resources/index.mjs'
import { Box, Newline, Static } from 'ink'
import ProjectOnboarding, {
  markProjectOnboardingComplete,
} from '../ProjectOnboarding.js'
import { CostThresholdDialog } from '../components/CostThresholdDialog.js'
import * as React from 'react'
import { useEffect, useMemo, useRef, useState, useCallback } from 'react'
import { Command } from '../commands.js'
import { Logo } from '../components/Logo.js'
import { Message } from '../components/Message.js'
import { MessageResponse } from '../components/MessageResponse.js'
import { MessageSelector } from '../components/MessageSelector.js'
import {
  PermissionRequest,
  ToolUseConfirm,
} from '../components/permissions/PermissionRequest.js'
import PromptInput from '../components/PromptInput.js'
import { Spinner } from '../components/Spinner.js'
import { getSystemPrompt } from '../constants/prompts.js'
import { getContext } from '../context.js'
import { getTotalCost, useCostSummary } from '../cost-tracker.js'
import { useLogStartupTime } from '../hooks/useLogStartupTime.js'
import { addToHistory } from '../history.js'
import { useApiKeyVerification } from '../hooks/useApiKeyVerification.js'
import { useCancelRequest } from '../hooks/useCancelRequest.js'
import useCanUseTool from '../hooks/useCanUseTool.js'
import { useLogMessages } from '../hooks/useLogMessages.js'
import { setMessagesGetter, setMessagesSetter } from '../messages.js'
import {
  AssistantMessage,
  BinaryFeedbackResult,
  Message as MessageType,
  ProgressMessage,
  query,
} from '../query.js'
import type { WrappedClient } from '../services/mcpClient.js'
import type { Tool } from '../Tool.js'
import { AutoUpdaterResult } from '../utils/autoUpdater.js'
import { getGlobalConfig, saveGlobalConfig } from '../utils/config.js'
import { logEvent } from '../services/statsig.js'
import { getNextAvailableLogForkNumber } from '../utils/log.js'
import {
  getErroredToolUseMessages,
  getInProgressToolUseIDs,
  getLastAssistantMessageId,
  getToolUseID,
  getUnresolvedToolUseIDs,
  INTERRUPT_MESSAGE,
  isNotEmptyMessage,
  type NormalizedMessage,
  normalizeMessages,
  normalizeMessagesForAPI,
  processUserInput,
  reorderMessages,
} from '../utils/messages.js'
import { getSlowAndCapableModel } from '../utils/model.js'
import { clearTerminal, updateTerminalTitle } from '../utils/terminal.js'
import { BinaryFeedback } from '../components/binary-feedback/BinaryFeedback.js'
import { getMaxThinkingTokens } from '../utils/thinking.js'
import { getOriginalCwd } from '../utils/state.js'

/**
 * Shared default for the `mcpClients` prop.
 *
 * A `= []` default would create a new array on every render. Because
 * `mcpClients` is a dependency of the `messagesJSX` memo, that would force the
 * memo to recompute on every render whenever the prop is omitted. Never mutate
 * this array.
 */
const NO_MCP_CLIENTS: WrappedClient[] = []

type Props = {
  commands: Command[]
  dangerouslySkipPermissions?: boolean
  debug?: boolean
  initialForkNumber?: number | undefined
  initialPrompt: string | undefined
  // A unique name for the message log file, used to identify the fork
  messageLogName: string
  shouldShowPromptInput: boolean
  tools: Tool[]
  verbose: boolean | undefined
  // Initial messages to populate the REPL with
  initialMessages?: MessageType[]
  // MCP clients
  mcpClients?: WrappedClient[]
  // Flag to indicate if current model is default
  isDefaultModel?: boolean
}

/**
 * Pending A/B comparison shown in the `BinaryFeedback` UI.
 *
 * `m1` and `m2` are the two candidate assistant messages. `resolve` settles
 * the promise returned to `query` with the user's choice.
 */
export type BinaryFeedbackContext = {
  m1: AssistantMessage
  m2: AssistantMessage
  resolve: (result: BinaryFeedbackResult) => void
}

/**
 * Top-level interactive session component.
 *
 * It renders the conversation transcript: the logo/onboarding header and the
 * normalized messages. Messages that are settled go into Ink's `<Static>`
 * region; the rest are re-rendered as "transient".
 *
 * Below the transcript it renders whichever interactive surface is active:
 * tool-provided JSX, a binary-feedback comparison, a permission request, the
 * cost-threshold dialog, or the prompt input. It also owns the message
 * selector used to fork the conversation from an earlier prompt.
 */
export function REPL({
  commands,
  dangerouslySkipPermissions,
  debug = false,
  initialForkNumber = 0,
  initialPrompt,
  messageLogName,
  shouldShowPromptInput,
  tools,
  verbose: verboseFromCLI,
  initialMessages,
  mcpClients = NO_MCP_CLIENTS,
  isDefaultModel = true,
}: Props): React.ReactNode {
  // TODO: probably shouldn't re-read config from file synchronously on every keystroke
  const verbose = verboseFromCLI ?? getGlobalConfig().verbose

  // Used to force the logo to re-render and conversation log to use a new file.
  // Lazy initializer: the lookup only matters on mount, so don't repeat it on
  // every re-render.
  const [forkNumber, setForkNumber] = useState(() =>
    getNextAvailableLogForkNumber(messageLogName, initialForkNumber, 0),
  )

  // When set, the effect below starts a new fork seeded with these messages.
  const [
    forkConvoWithMessagesOnTheNextRender,
    setForkConvoWithMessagesOnTheNextRender,
  ] = useState<MessageType[] | null>(null)

  const [abortController, setAbortController] =
    useState<AbortController | null>(null)
  const [isLoading, setIsLoading] = useState(false)
  const [autoUpdaterResult, setAutoUpdaterResult] =
    useState<AutoUpdaterResult | null>(null)
  const [toolJSX, setToolJSX] = useState<{
    jsx: React.ReactNode | null
    shouldHidePromptInput: boolean
  } | null>(null)
  const [toolUseConfirm, setToolUseConfirm] = useState<ToolUseConfirm | null>(
    null,
  )
  const [messages, setMessages] = useState<MessageType[]>(initialMessages ?? [])
  const [inputValue, setInputValue] = useState('')
  const [inputMode, setInputMode] = useState<'bash' | 'prompt'>('prompt')
  const [submitCount, setSubmitCount] = useState(0)
  const [isMessageSelectorVisible, setIsMessageSelectorVisible] =
    useState(false)
  const [showCostDialog, setShowCostDialog] = useState(false)
  // Lazy initializer: avoid a synchronous config read on every re-render.
  const [haveShownCostDialog, setHaveShownCostDialog] = useState(
    () => getGlobalConfig().hasAcknowledgedCostThreshold,
  )

  const [binaryFeedbackContext, setBinaryFeedbackContext] =
    useState<BinaryFeedbackContext | null>(null)

  /**
   * Shows the BinaryFeedback UI for two candidate responses.
   *
   * The returned promise resolves with the user's choice, which is delivered
   * through `BinaryFeedbackContext.resolve`.
   */
  const getBinaryFeedbackResponse = useCallback(
    (
      m1: AssistantMessage,
      m2: AssistantMessage,
    ): Promise<BinaryFeedbackResult> => {
      return new Promise<BinaryFeedbackResult>(resolvePromise => {
        setBinaryFeedbackContext({
          m1,
          m2,
          resolve: resolvePromise,
        })
      })
    },
    [],
  )

  // filename -> number map shared with tools via `.current`.
  // NOTE(review): presumably the last time each file was read; confirm in the tools.
  const readFileTimestamps = useRef<{
    [filename: string]: number
  }>({})

  const { status: apiKeyStatus, reverify } = useApiKeyVerification()

  /**
   * Cancels the in-flight request; does nothing if nothing is loading.
   *
   * A pending permission prompt handles the abort itself. Otherwise the
   * active AbortController is aborted.
   */
  function onCancel() {
    if (!isLoading) {
      return
    }
    setIsLoading(false)
    if (toolUseConfirm) {
      // Tool use confirm handles the abort signal itself
      toolUseConfirm.onAbort()
    } else {
      abortController?.abort()
    }
  }

  useCancelRequest(
    setToolJSX,
    setToolUseConfirm,
    setBinaryFeedbackContext,
    onCancel,
    isLoading,
    isMessageSelectorVisible,
    abortController?.signal,
  )

  // Start a fork: bump the fork number (new log file and logo key) and
  // replace the transcript with the requested messages.
  useEffect(() => {
    if (forkConvoWithMessagesOnTheNextRender) {
      setForkNumber(_ => _ + 1)
      setForkConvoWithMessagesOnTheNextRender(null)
      setMessages(forkConvoWithMessagesOnTheNextRender)
    }
  }, [forkConvoWithMessagesOnTheNextRender])

  // Show the cost dialog once total cost crosses $5, unless it has already
  // been shown or acknowledged.
  useEffect(() => {
    const totalCost = getTotalCost()
    if (totalCost >= 5 /* $5 */ && !showCostDialog && !haveShownCostDialog) {
      logEvent('tengu_cost_threshold_reached', {})
      setShowCostDialog(true)
    }
  }, [messages, showCostDialog, haveShownCostDialog])

  const canUseTool = useCanUseTool(setToolUseConfirm)

  /**
   * Runs once on mount.
   *
   * It re-verifies the API key. If an `initialPrompt` was supplied, it
   * processes that prompt like typed input and streams the model's reply into
   * the transcript.
   */
  async function onInit() {
    reverify()

    if (!initialPrompt) {
      return
    }

    setIsLoading(true)

    const abortController = new AbortController()
    setAbortController(abortController)

    const initialModel = await getSlowAndCapableModel()
    const newMessages = await processUserInput(
      initialPrompt,
      'prompt',
      setToolJSX,
      {
        abortController,
        options: {
          commands,
          forkNumber,
          messageLogName,
          tools,
          verbose,
          slowAndCapableModel: initialModel,
          maxThinkingTokens: 0,
        },
        messageId: getLastAssistantMessageId(messages),
        setForkConvoWithMessagesOnTheNextRender,
        readFileTimestamps: readFileTimestamps.current,
      },
      null,
    )

    if (newMessages.length) {
      // Record the prompt in history once if it produced any user message.
      // (Previously it was added once per user message, which could duplicate it.)
      if (newMessages.some(_ => _.type === 'user')) {
        addToHistory(initialPrompt)
        // TODO: setHistoryIndex
      }
      setMessages(_ => [..._, ...newMessages])

      // The last message is an assistant message if the user input was a bash command,
      // or if the user input was an invalid slash command.
      const lastMessage = newMessages[newMessages.length - 1]!
      if (lastMessage.type === 'assistant') {
        setAbortController(null)
        setIsLoading(false)
        return
      }

      // Re-resolve the model here: processing the input may have run a
      // command, so don't reuse `initialModel`.
      const [systemPrompt, context, model, maxThinkingTokens] =
        await Promise.all([
          getSystemPrompt(),
          getContext(),
          getSlowAndCapableModel(),
          getMaxThinkingTokens([...messages, ...newMessages]),
        ])

      for await (const message of query(
        [...messages, ...newMessages],
        systemPrompt,
        context,
        canUseTool,
        {
          options: {
            commands,
            forkNumber,
            messageLogName,
            tools,
            slowAndCapableModel: model,
            verbose,
            dangerouslySkipPermissions,
            maxThinkingTokens,
          },
          messageId: getLastAssistantMessageId([...messages, ...newMessages]),
          readFileTimestamps: readFileTimestamps.current,
          abortController,
          setToolJSX,
        },
        getBinaryFeedbackResponse,
      )) {
        setMessages(oldMessages => [...oldMessages, message])
      }
    } else {
      addToHistory(initialPrompt)
      // TODO: setHistoryIndex
    }

    setHaveShownCostDialog(
      getGlobalConfig().hasAcknowledgedCostThreshold || false,
    )

    setIsLoading(false)
  }

  /**
   * Handles messages produced from the user's input; called by PromptInput.
   *
   * Appends `newMessages` to the transcript. Unless the last of them is an
   * assistant message, it then queries the model and streams each response
   * message into the transcript.
   *
   * NOTE(review): the query receives the prior transcript plus only the LAST
   * new message, whereas onInit sends all new messages. Confirm this asymmetry
   * is intended.
   */
  async function onQuery(
    newMessages: MessageType[],
    abortController: AbortController,
  ) {
    setMessages(oldMessages => [...oldMessages, ...newMessages])

    // Mark onboarding as complete when any user message is sent to Claude
    markProjectOnboardingComplete()

    // The last message is an assistant message if the user input was a bash command,
    // or if the user input was an invalid slash command.
    const lastMessage = newMessages[newMessages.length - 1]

    // Nothing to send (processUserInput can yield no messages): settle the
    // loading state instead of throwing on `lastMessage.type`.
    if (!lastMessage) {
      setAbortController(null)
      setIsLoading(false)
      return
    }

    // Update terminal title based on user message
    if (
      lastMessage.type === 'user' &&
      typeof lastMessage.message.content === 'string'
    ) {
      updateTerminalTitle(lastMessage.message.content)
    }
    if (lastMessage.type === 'assistant') {
      setAbortController(null)
      setIsLoading(false)
      return
    }

    const [systemPrompt, context, model, maxThinkingTokens] = await Promise.all(
      [
        getSystemPrompt(),
        getContext(),
        getSlowAndCapableModel(),
        getMaxThinkingTokens([...messages, lastMessage]),
      ],
    )

    // query the API
    for await (const message of query(
      [...messages, lastMessage],
      systemPrompt,
      context,
      canUseTool,
      {
        options: {
          commands,
          forkNumber,
          messageLogName,
          tools,
          slowAndCapableModel: model,
          verbose,
          dangerouslySkipPermissions,
          maxThinkingTokens,
        },
        messageId: getLastAssistantMessageId([...messages, lastMessage]),
        readFileTimestamps: readFileTimestamps.current,
        abortController,
        setToolJSX,
      },
      getBinaryFeedbackResponse,
    )) {
      setMessages(oldMessages => [...oldMessages, message])
    }
    setIsLoading(false)
  }

  // Register cost summary tracker
  useCostSummary()

  // Register messages getter and setter
  useEffect(() => {
    const getMessages = () => messages
    setMessagesGetter(getMessages)
    setMessagesSetter(setMessages)
  }, [messages])

  // Record transcripts locally, for debugging and conversation recovery
  useLogMessages(messages, messageLogName, forkNumber)

  // Log startup time
  useLogStartupTime()

  // Initial load (fire-and-forget; onInit manages its own loading state)
  useEffect(() => {
    void onInit()
    // TODO: fix this
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [])

  const normalizedMessages = useMemo(
    () => normalizeMessages(messages).filter(isNotEmptyMessage),
    [messages],
  )

  const unresolvedToolUseIDs = useMemo(
    () => getUnresolvedToolUseIDs(normalizedMessages),
    [normalizedMessages],
  )

  const inProgressToolUseIDs = useMemo(
    () => getInProgressToolUseIDs(normalizedMessages),
    [normalizedMessages],
  )

  const erroredToolUseIDs = useMemo(
    () =>
      new Set(
        getErroredToolUseMessages(normalizedMessages).map(
          _ => (_.message.content[0]! as ToolUseBlockParam).id,
        ),
      ),
    [normalizedMessages],
  )

  // One entry per transcript item, tagged 'static' (rendered once via
  // <Static>) or 'transient' (re-rendered each frame).
  const messagesJSX = useMemo(() => {
    return [
      {
        type: 'static',
        jsx: (
          <Box flexDirection="column" key={`logo${forkNumber}`}>
            <Logo mcpClients={mcpClients} isDefaultModel={isDefaultModel} />
            <ProjectOnboarding workspaceDir={getOriginalCwd()} />
          </Box>
        ),
      },
      ...reorderMessages(normalizedMessages).map(_ => {
        const toolUseID = getToolUseID(_)
        const message =
          _.type === 'progress' ? (
            _.content.message.content[0]?.type === 'text' &&
            // Hack: AgentTool interrupts use Progress messages, so don't
            // need an extra ⎿ because <Message /> already adds one.
            // TODO: Find a cleaner way to do this.
            _.content.message.content[0].text === INTERRUPT_MESSAGE ? (
              <Message
                message={_.content}
                messages={_.normalizedMessages}
                addMargin={false}
                tools={_.tools}
                verbose={verbose ?? false}
                debug={debug}
                erroredToolUseIDs={new Set()}
                inProgressToolUseIDs={new Set()}
                unresolvedToolUseIDs={new Set()}
                shouldAnimate={false}
                shouldShowDot={false}
              />
            ) : (
              <MessageResponse>
                <Message
                  message={_.content}
                  messages={_.normalizedMessages}
                  addMargin={false}
                  tools={_.tools}
                  verbose={verbose ?? false}
                  debug={debug}
                  erroredToolUseIDs={new Set()}
                  inProgressToolUseIDs={new Set()}
                  unresolvedToolUseIDs={
                    new Set([
                      (_.content.message.content[0]! as ToolUseBlockParam).id,
                    ])
                  }
                  shouldAnimate={false}
                  shouldShowDot={false}
                />
              </MessageResponse>
            )
          ) : (
            <Message
              message={_}
              messages={normalizedMessages}
              addMargin={true}
              tools={tools}
              verbose={verbose}
              debug={debug}
              erroredToolUseIDs={erroredToolUseIDs}
              inProgressToolUseIDs={inProgressToolUseIDs}
              shouldAnimate={
                !toolJSX &&
                !toolUseConfirm &&
                !isMessageSelectorVisible &&
                (!toolUseID || inProgressToolUseIDs.has(toolUseID))
              }
              shouldShowDot={true}
              unresolvedToolUseIDs={unresolvedToolUseIDs}
            />
          )

        const type = shouldRenderStatically(
          _,
          normalizedMessages,
          unresolvedToolUseIDs,
        )
          ? 'static'
          : 'transient'

        if (debug) {
          // Debug mode outlines each item: green = static, red = transient.
          return {
            type,
            jsx: (
              <Box
                borderStyle="single"
                borderColor={type === 'static' ? 'green' : 'red'}
                key={_.uuid}
                width="100%"
              >
                {message}
              </Box>
            ),
          }
        }

        return {
          type,
          jsx: (
            <Box key={_.uuid} width="100%">
              {message}
            </Box>
          ),
        }
      }),
    ]
  }, [
    forkNumber,
    normalizedMessages,
    tools,
    verbose,
    debug,
    erroredToolUseIDs,
    inProgressToolUseIDs,
    toolJSX,
    toolUseConfirm,
    isMessageSelectorVisible,
    unresolvedToolUseIDs,
    mcpClients,
    isDefaultModel,
  ])

  // only show the dialog once not loading
  const showingCostDialog = !isLoading && showCostDialog

  return (
    <>
      <Static
        key={`static-messages-${forkNumber}`}
        items={messagesJSX.filter(_ => _.type === 'static')}
      >
        {_ => _.jsx}
      </Static>
      {messagesJSX.filter(_ => _.type === 'transient').map(_ => _.jsx)}
      <Box
        borderColor="red"
        borderStyle={debug ? 'single' : undefined}
        flexDirection="column"
        width="100%"
      >
        {!toolJSX && !toolUseConfirm && !binaryFeedbackContext && isLoading && (
          <Spinner />
        )}
        {toolJSX ? toolJSX.jsx : null}
        {!toolJSX && binaryFeedbackContext && !isMessageSelectorVisible && (
          <BinaryFeedback
            m1={binaryFeedbackContext.m1}
            m2={binaryFeedbackContext.m2}
            resolve={result => {
              binaryFeedbackContext.resolve(result)
              setTimeout(() => setBinaryFeedbackContext(null), 0)
            }}
            verbose={verbose}
            normalizedMessages={normalizedMessages}
            tools={tools}
            debug={debug}
            erroredToolUseIDs={erroredToolUseIDs}
            inProgressToolUseIDs={inProgressToolUseIDs}
            unresolvedToolUseIDs={unresolvedToolUseIDs}
          />
        )}
        {!toolJSX &&
          toolUseConfirm &&
          !isMessageSelectorVisible &&
          !binaryFeedbackContext && (
            <PermissionRequest
              toolUseConfirm={toolUseConfirm}
              onDone={() => setToolUseConfirm(null)}
              verbose={verbose}
            />
          )}
        {!toolJSX &&
          !toolUseConfirm &&
          !isMessageSelectorVisible &&
          !binaryFeedbackContext &&
          showingCostDialog && (
            <CostThresholdDialog
              onDone={() => {
                setShowCostDialog(false)
                setHaveShownCostDialog(true)
                // Persist the acknowledgement in the global config.
                const globalConfig = getGlobalConfig()
                saveGlobalConfig({
                  ...globalConfig,
                  hasAcknowledgedCostThreshold: true,
                })
                logEvent('tengu_cost_threshold_acknowledged', {})
              }}
            />
          )}

        {!toolUseConfirm &&
          !toolJSX?.shouldHidePromptInput &&
          shouldShowPromptInput &&
          !isMessageSelectorVisible &&
          !binaryFeedbackContext &&
          !showingCostDialog && (
            <>
              <PromptInput
                commands={commands}
                forkNumber={forkNumber}
                messageLogName={messageLogName}
                tools={tools}
                isDisabled={apiKeyStatus === 'invalid'}
                isLoading={isLoading}
                onQuery={onQuery}
                debug={debug}
                verbose={verbose}
                messages={messages}
                setToolJSX={setToolJSX}
                onAutoUpdaterResult={setAutoUpdaterResult}
                autoUpdaterResult={autoUpdaterResult}
                input={inputValue}
                onInputChange={setInputValue}
                mode={inputMode}
                onModeChange={setInputMode}
                submitCount={submitCount}
                onSubmitCountChange={setSubmitCount}
                setIsLoading={setIsLoading}
                setAbortController={setAbortController}
                onShowMessageSelector={() =>
                  setIsMessageSelectorVisible(prev => !prev)
                }
                setForkConvoWithMessagesOnTheNextRender={
                  setForkConvoWithMessagesOnTheNextRender
                }
                readFileTimestamps={readFileTimestamps.current}
              />
            </>
          )}
      </Box>
      {isMessageSelectorVisible && (
        <MessageSelector
          erroredToolUseIDs={erroredToolUseIDs}
          unresolvedToolUseIDs={unresolvedToolUseIDs}
          messages={normalizeMessagesForAPI(messages)}
          onSelect={async message => {
            setIsMessageSelectorVisible(false)

            // If the user selected the current prompt, do nothing
            if (!messages.includes(message)) {
              return
            }

            // Cancel tool use calls/requests
            onCancel()

            // Hack: make sure the "Interrupted by user" message is
            // rendered in response to the cancellation. Otherwise,
            // the screen will be cleared but there will remain a
            // vestigial "Interrupted by user" message at the top.
            setImmediate(async () => {
              // Clear messages, and re-render
              await clearTerminal()
              setMessages([])
              // Fork from the messages that precede the selected one.
              setForkConvoWithMessagesOnTheNextRender(
                messages.slice(0, messages.indexOf(message)),
              )

              // Populate/reset the prompt input
              if (typeof message.message.content === 'string') {
                setInputValue(message.message.content)
              }
            })
          }}
          onEscape={() => setIsMessageSelectorVisible(false)}
          tools={tools}
        />
      )}
      {/** Fix occasional rendering artifact */}
      <Newline />
    </>
  )
}

/**
 * Decides whether a transcript item can go into Ink's `<Static>` region.
 *
 * User/assistant messages:
 * - With no tool use ID, they are static.
 * - If their tool use is still unresolved, they are not static.
 * - Otherwise they are static unless a progress message for the same tool
 *   use has sibling tool uses that are still unresolved.
 *
 * Progress messages are static unless any of their sibling tool uses are
 * still unresolved.
 */
function shouldRenderStatically(
  message: NormalizedMessage,
  messages: NormalizedMessage[],
  unresolvedToolUseIDs: Set<string>,
): boolean {
  switch (message.type) {
    case 'user':
    case 'assistant': {
      const toolUseID = getToolUseID(message)
      if (!toolUseID) {
        return true
      }
      if (unresolvedToolUseIDs.has(toolUseID)) {
        return false
      }

      // Array.prototype.find yields `undefined` (not `null`) when nothing matches.
      const correspondingProgressMessage = messages.find(
        _ => _.type === 'progress' && _.toolUseID === toolUseID,
      ) as ProgressMessage | undefined
      if (!correspondingProgressMessage) {
        return true
      }

      return !intersects(
        unresolvedToolUseIDs,
        correspondingProgressMessage.siblingToolUseIDs,
      )
    }
    case 'progress':
      return !intersects(unresolvedToolUseIDs, message.siblingToolUseIDs)
  }
}

/** Returns true if the two sets share at least one element. */
function intersects<A>(a: Set<A>, b: Set<A>): boolean {
  return a.size > 0 && b.size > 0 && [...a].some(_ => b.has(_))
}