// src/screens/REPL.tsx
  1  import { ToolUseBlockParam } from '@anthropic-ai/sdk/resources/index.mjs'
  2  import { Box, Newline, Static } from 'ink'
  3  import ProjectOnboarding, {
  4    markProjectOnboardingComplete,
  5  } from '../ProjectOnboarding.js'
  6  import { CostThresholdDialog } from '../components/CostThresholdDialog.js'
  7  import * as React from 'react'
  8  import { useEffect, useMemo, useRef, useState, useCallback } from 'react'
  9  import { Command } from '../commands.js'
 10  import { Logo } from '../components/Logo.js'
 11  import { Message } from '../components/Message.js'
 12  import { MessageResponse } from '../components/MessageResponse.js'
 13  import { MessageSelector } from '../components/MessageSelector.js'
 14  import {
 15    PermissionRequest,
 16    ToolUseConfirm,
 17  } from '../components/permissions/PermissionRequest.js'
 18  import PromptInput from '../components/PromptInput.js'
 19  import { Spinner } from '../components/Spinner.js'
 20  import { getSystemPrompt } from '../constants/prompts.js'
 21  import { getContext } from '../context.js'
 22  import { getTotalCost, useCostSummary } from '../cost-tracker.js'
 23  import { useLogStartupTime } from '../hooks/useLogStartupTime.js'
 24  import { addToHistory } from '../history.js'
 25  import { useApiKeyVerification } from '../hooks/useApiKeyVerification.js'
 26  import { useCancelRequest } from '../hooks/useCancelRequest.js'
 27  import useCanUseTool from '../hooks/useCanUseTool.js'
 28  import { useLogMessages } from '../hooks/useLogMessages.js'
 29  import { setMessagesGetter, setMessagesSetter } from '../messages.js'
 30  import {
 31    AssistantMessage,
 32    BinaryFeedbackResult,
 33    Message as MessageType,
 34    ProgressMessage,
 35    query,
 36  } from '../query.js'
 37  import type { WrappedClient } from '../services/mcpClient.js'
 38  import type { Tool } from '../Tool.js'
 39  import { AutoUpdaterResult } from '../utils/autoUpdater.js'
 40  import { getGlobalConfig, saveGlobalConfig } from '../utils/config.js'
 41  import { logEvent } from '../services/statsig.js'
 42  import { getNextAvailableLogForkNumber } from '../utils/log.js'
 43  import {
 44    getErroredToolUseMessages,
 45    getInProgressToolUseIDs,
 46    getLastAssistantMessageId,
 47    getToolUseID,
 48    getUnresolvedToolUseIDs,
 49    INTERRUPT_MESSAGE,
 50    isNotEmptyMessage,
 51    type NormalizedMessage,
 52    normalizeMessages,
 53    normalizeMessagesForAPI,
 54    processUserInput,
 55    reorderMessages,
 56  } from '../utils/messages.js'
 57  import { getSlowAndCapableModel } from '../utils/model.js'
 58  import { clearTerminal, updateTerminalTitle } from '../utils/terminal.js'
 59  import { BinaryFeedback } from '../components/binary-feedback/BinaryFeedback.js'
 60  import { getMaxThinkingTokens } from '../utils/thinking.js'
 61  import { getOriginalCwd } from '../utils/state.js'
 62  
 63  type Props = {
 64    commands: Command[]
 65    dangerouslySkipPermissions?: boolean
 66    debug?: boolean
 67    initialForkNumber?: number | undefined
 68    initialPrompt: string | undefined
 69    // A unique name for the message log file, used to identify the fork
 70    messageLogName: string
 71    shouldShowPromptInput: boolean
 72    tools: Tool[]
 73    verbose: boolean | undefined
 74    // Initial messages to populate the REPL with
 75    initialMessages?: MessageType[]
 76    // MCP clients
 77    mcpClients?: WrappedClient[]
 78    // Flag to indicate if current model is default
 79    isDefaultModel?: boolean
 80  }
 81  
 82  export type BinaryFeedbackContext = {
 83    m1: AssistantMessage
 84    m2: AssistantMessage
 85    resolve: (result: BinaryFeedbackResult) => void
 86  }
 87  
 88  export function REPL({
 89    commands,
 90    dangerouslySkipPermissions,
 91    debug = false,
 92    initialForkNumber = 0,
 93    initialPrompt,
 94    messageLogName,
 95    shouldShowPromptInput,
 96    tools,
 97    verbose: verboseFromCLI,
 98    initialMessages,
 99    mcpClients = [],
100    isDefaultModel = true,
101  }: Props): React.ReactNode {
102    // TODO: probably shouldn't re-read config from file synchronously on every keystroke
103    const verbose = verboseFromCLI ?? getGlobalConfig().verbose
104  
105    // Used to force the logo to re-render and conversation log to use a new file
106    const [forkNumber, setForkNumber] = useState(
107      getNextAvailableLogForkNumber(messageLogName, initialForkNumber, 0),
108    )
109  
110    const [
111      forkConvoWithMessagesOnTheNextRender,
112      setForkConvoWithMessagesOnTheNextRender,
113    ] = useState<MessageType[] | null>(null)
114  
115    const [abortController, setAbortController] =
116      useState<AbortController | null>(null)
117    const [isLoading, setIsLoading] = useState(false)
118    const [autoUpdaterResult, setAutoUpdaterResult] =
119      useState<AutoUpdaterResult | null>(null)
120    const [toolJSX, setToolJSX] = useState<{
121      jsx: React.ReactNode | null
122      shouldHidePromptInput: boolean
123    } | null>(null)
124    const [toolUseConfirm, setToolUseConfirm] = useState<ToolUseConfirm | null>(
125      null,
126    )
127    const [messages, setMessages] = useState<MessageType[]>(initialMessages ?? [])
128    const [inputValue, setInputValue] = useState('')
129    const [inputMode, setInputMode] = useState<'bash' | 'prompt'>('prompt')
130    const [submitCount, setSubmitCount] = useState(0)
131    const [isMessageSelectorVisible, setIsMessageSelectorVisible] =
132      useState(false)
133    const [showCostDialog, setShowCostDialog] = useState(false)
134    const [haveShownCostDialog, setHaveShownCostDialog] = useState(
135      getGlobalConfig().hasAcknowledgedCostThreshold,
136    )
137  
138    const [binaryFeedbackContext, setBinaryFeedbackContext] =
139      useState<BinaryFeedbackContext | null>(null)
140  
141    const getBinaryFeedbackResponse = useCallback(
142      (
143        m1: AssistantMessage,
144        m2: AssistantMessage,
145      ): Promise<BinaryFeedbackResult> => {
146        return new Promise<BinaryFeedbackResult>(resolvePromise => {
147          setBinaryFeedbackContext({
148            m1,
149            m2,
150            resolve: resolvePromise,
151          })
152        })
153      },
154      [],
155    )
156  
157    const readFileTimestamps = useRef<{
158      [filename: string]: number
159    }>({})
160  
161    const { status: apiKeyStatus, reverify } = useApiKeyVerification()
162    function onCancel() {
163      if (!isLoading) {
164        return
165      }
166      setIsLoading(false)
167      if (toolUseConfirm) {
168        // Tool use confirm handles the abort signal itself
169        toolUseConfirm.onAbort()
170      } else {
171        abortController?.abort()
172      }
173    }
174  
175    useCancelRequest(
176      setToolJSX,
177      setToolUseConfirm,
178      setBinaryFeedbackContext,
179      onCancel,
180      isLoading,
181      isMessageSelectorVisible,
182      abortController?.signal,
183    )
184  
185    useEffect(() => {
186      if (forkConvoWithMessagesOnTheNextRender) {
187        setForkNumber(_ => _ + 1)
188        setForkConvoWithMessagesOnTheNextRender(null)
189        setMessages(forkConvoWithMessagesOnTheNextRender)
190      }
191    }, [forkConvoWithMessagesOnTheNextRender])
192  
193    useEffect(() => {
194      const totalCost = getTotalCost()
195      if (totalCost >= 5 /* $5 */ && !showCostDialog && !haveShownCostDialog) {
196        logEvent('tengu_cost_threshold_reached', {})
197        setShowCostDialog(true)
198      }
199    }, [messages, showCostDialog, haveShownCostDialog])
200  
201    const canUseTool = useCanUseTool(setToolUseConfirm)
202  
203    async function onInit() {
204      reverify()
205  
206      if (!initialPrompt) {
207        return
208      }
209  
210      setIsLoading(true)
211  
212      const abortController = new AbortController()
213      setAbortController(abortController)
214  
215      const model = await getSlowAndCapableModel()
216      const newMessages = await processUserInput(
217        initialPrompt,
218        'prompt',
219        setToolJSX,
220        {
221          abortController,
222          options: {
223            commands,
224            forkNumber,
225            messageLogName,
226            tools,
227            verbose,
228            slowAndCapableModel: model,
229            maxThinkingTokens: 0,
230          },
231          messageId: getLastAssistantMessageId(messages),
232          setForkConvoWithMessagesOnTheNextRender,
233          readFileTimestamps: readFileTimestamps.current,
234        },
235        null,
236      )
237  
238      if (newMessages.length) {
239        for (const message of newMessages) {
240          if (message.type === 'user') {
241            addToHistory(initialPrompt)
242            // TODO: setHistoryIndex
243          }
244        }
245        setMessages(_ => [..._, ...newMessages])
246  
247        // The last message is an assistant message if the user input was a bash command,
248        // or if the user input was an invalid slash command.
249        const lastMessage = newMessages[newMessages.length - 1]!
250        if (lastMessage.type === 'assistant') {
251          setAbortController(null)
252          setIsLoading(false)
253          return
254        }
255  
256        const [systemPrompt, context, model, maxThinkingTokens] =
257          await Promise.all([
258            getSystemPrompt(),
259            getContext(),
260            getSlowAndCapableModel(),
261            getMaxThinkingTokens([...messages, ...newMessages]),
262          ])
263  
264        for await (const message of query(
265          [...messages, ...newMessages],
266          systemPrompt,
267          context,
268          canUseTool,
269          {
270            options: {
271              commands,
272              forkNumber,
273              messageLogName,
274              tools,
275              slowAndCapableModel: model,
276              verbose,
277              dangerouslySkipPermissions,
278              maxThinkingTokens,
279            },
280            messageId: getLastAssistantMessageId([...messages, ...newMessages]),
281            readFileTimestamps: readFileTimestamps.current,
282            abortController,
283            setToolJSX,
284          },
285          getBinaryFeedbackResponse,
286        )) {
287          setMessages(oldMessages => [...oldMessages, message])
288        }
289      } else {
290        addToHistory(initialPrompt)
291        // TODO: setHistoryIndex
292      }
293  
294      setHaveShownCostDialog(
295        getGlobalConfig().hasAcknowledgedCostThreshold || false,
296      )
297  
298      setIsLoading(false)
299    }
300  
301    async function onQuery(
302      newMessages: MessageType[],
303      abortController: AbortController,
304    ) {
305      setMessages(oldMessages => [...oldMessages, ...newMessages])
306  
307      // Mark onboarding as complete when any user message is sent to Claude
308      markProjectOnboardingComplete()
309  
310      // The last message is an assistant message if the user input was a bash command,
311      // or if the user input was an invalid slash command.
312      const lastMessage = newMessages[newMessages.length - 1]!
313  
314      // Update terminal title based on user message
315      if (
316        lastMessage.type === 'user' &&
317        typeof lastMessage.message.content === 'string'
318      ) {
319        updateTerminalTitle(lastMessage.message.content)
320      }
321      if (lastMessage.type === 'assistant') {
322        setAbortController(null)
323        setIsLoading(false)
324        return
325      }
326  
327      const [systemPrompt, context, model, maxThinkingTokens] = await Promise.all(
328        [
329          getSystemPrompt(),
330          getContext(),
331          getSlowAndCapableModel(),
332          getMaxThinkingTokens([...messages, lastMessage]),
333        ],
334      )
335  
336      // query the API
337      for await (const message of query(
338        [...messages, lastMessage],
339        systemPrompt,
340        context,
341        canUseTool,
342        {
343          options: {
344            commands,
345            forkNumber,
346            messageLogName,
347            tools,
348            slowAndCapableModel: model,
349            verbose,
350            dangerouslySkipPermissions,
351            maxThinkingTokens,
352          },
353          messageId: getLastAssistantMessageId([...messages, lastMessage]),
354          readFileTimestamps: readFileTimestamps.current,
355          abortController,
356          setToolJSX,
357        },
358        getBinaryFeedbackResponse,
359      )) {
360        setMessages(oldMessages => [...oldMessages, message])
361      }
362      setIsLoading(false)
363    }
364  
365    // Register cost summary tracker
366    useCostSummary()
367  
368    // Register messages getter and setter
369    useEffect(() => {
370      const getMessages = () => messages
371      setMessagesGetter(getMessages)
372      setMessagesSetter(setMessages)
373    }, [messages])
374  
375    // Record transcripts locally, for debugging and conversation recovery
376    useLogMessages(messages, messageLogName, forkNumber)
377  
378    // Log startup time
379    useLogStartupTime()
380  
381    // Initial load
382    useEffect(() => {
383      onInit()
384      // TODO: fix this
385      // eslint-disable-next-line react-hooks/exhaustive-deps
386    }, [])
387  
388    const normalizedMessages = useMemo(
389      () => normalizeMessages(messages).filter(isNotEmptyMessage),
390      [messages],
391    )
392  
393    const unresolvedToolUseIDs = useMemo(
394      () => getUnresolvedToolUseIDs(normalizedMessages),
395      [normalizedMessages],
396    )
397  
398    const inProgressToolUseIDs = useMemo(
399      () => getInProgressToolUseIDs(normalizedMessages),
400      [normalizedMessages],
401    )
402  
403    const erroredToolUseIDs = useMemo(
404      () =>
405        new Set(
406          getErroredToolUseMessages(normalizedMessages).map(
407            _ => (_.message.content[0]! as ToolUseBlockParam).id,
408          ),
409        ),
410      [normalizedMessages],
411    )
412  
413    const messagesJSX = useMemo(() => {
414      return [
415        {
416          type: 'static',
417          jsx: (
418            <Box flexDirection="column" key={`logo${forkNumber}`}>
419              <Logo mcpClients={mcpClients} isDefaultModel={isDefaultModel} />
420              <ProjectOnboarding workspaceDir={getOriginalCwd()} />
421            </Box>
422          ),
423        },
424        ...reorderMessages(normalizedMessages).map(_ => {
425          const toolUseID = getToolUseID(_)
426          const message =
427            _.type === 'progress' ? (
428              _.content.message.content[0]?.type === 'text' &&
429              // Hack: AgentTool interrupts use Progress messages, so don't
430              // need an extra ⎿ because <Message /> already adds one.
431              // TODO: Find a cleaner way to do this.
432              _.content.message.content[0].text === INTERRUPT_MESSAGE ? (
433                <Message
434                  message={_.content}
435                  messages={_.normalizedMessages}
436                  addMargin={false}
437                  tools={_.tools}
438                  verbose={verbose ?? false}
439                  debug={debug}
440                  erroredToolUseIDs={new Set()}
441                  inProgressToolUseIDs={new Set()}
442                  unresolvedToolUseIDs={new Set()}
443                  shouldAnimate={false}
444                  shouldShowDot={false}
445                />
446              ) : (
447                <MessageResponse>
448                  <Message
449                    message={_.content}
450                    messages={_.normalizedMessages}
451                    addMargin={false}
452                    tools={_.tools}
453                    verbose={verbose ?? false}
454                    debug={debug}
455                    erroredToolUseIDs={new Set()}
456                    inProgressToolUseIDs={new Set()}
457                    unresolvedToolUseIDs={
458                      new Set([
459                        (_.content.message.content[0]! as ToolUseBlockParam).id,
460                      ])
461                    }
462                    shouldAnimate={false}
463                    shouldShowDot={false}
464                  />
465                </MessageResponse>
466              )
467            ) : (
468              <Message
469                message={_}
470                messages={normalizedMessages}
471                addMargin={true}
472                tools={tools}
473                verbose={verbose}
474                debug={debug}
475                erroredToolUseIDs={erroredToolUseIDs}
476                inProgressToolUseIDs={inProgressToolUseIDs}
477                shouldAnimate={
478                  !toolJSX &&
479                  !toolUseConfirm &&
480                  !isMessageSelectorVisible &&
481                  (!toolUseID || inProgressToolUseIDs.has(toolUseID))
482                }
483                shouldShowDot={true}
484                unresolvedToolUseIDs={unresolvedToolUseIDs}
485              />
486            )
487  
488          const type = shouldRenderStatically(
489            _,
490            normalizedMessages,
491            unresolvedToolUseIDs,
492          )
493            ? 'static'
494            : 'transient'
495  
496          if (debug) {
497            return {
498              type,
499              jsx: (
500                <Box
501                  borderStyle="single"
502                  borderColor={type === 'static' ? 'green' : 'red'}
503                  key={_.uuid}
504                  width="100%"
505                >
506                  {message}
507                </Box>
508              ),
509            }
510          }
511  
512          return {
513            type,
514            jsx: (
515              <Box key={_.uuid} width="100%">
516                {message}
517              </Box>
518            ),
519          }
520        }),
521      ]
522    }, [
523      forkNumber,
524      normalizedMessages,
525      tools,
526      verbose,
527      debug,
528      erroredToolUseIDs,
529      inProgressToolUseIDs,
530      toolJSX,
531      toolUseConfirm,
532      isMessageSelectorVisible,
533      unresolvedToolUseIDs,
534      mcpClients,
535      isDefaultModel,
536    ])
537  
538    // only show the dialog once not loading
539    const showingCostDialog = !isLoading && showCostDialog
540  
541    return (
542      <>
543        <Static
544          key={`static-messages-${forkNumber}`}
545          items={messagesJSX.filter(_ => _.type === 'static')}
546        >
547          {_ => _.jsx}
548        </Static>
549        {messagesJSX.filter(_ => _.type === 'transient').map(_ => _.jsx)}
550        <Box
551          borderColor="red"
552          borderStyle={debug ? 'single' : undefined}
553          flexDirection="column"
554          width="100%"
555        >
556          {!toolJSX && !toolUseConfirm && !binaryFeedbackContext && isLoading && (
557            <Spinner />
558          )}
559          {toolJSX ? toolJSX.jsx : null}
560          {!toolJSX && binaryFeedbackContext && !isMessageSelectorVisible && (
561            <BinaryFeedback
562              m1={binaryFeedbackContext.m1}
563              m2={binaryFeedbackContext.m2}
564              resolve={result => {
565                binaryFeedbackContext.resolve(result)
566                setTimeout(() => setBinaryFeedbackContext(null), 0)
567              }}
568              verbose={verbose}
569              normalizedMessages={normalizedMessages}
570              tools={tools}
571              debug={debug}
572              erroredToolUseIDs={erroredToolUseIDs}
573              inProgressToolUseIDs={inProgressToolUseIDs}
574              unresolvedToolUseIDs={unresolvedToolUseIDs}
575            />
576          )}
577          {!toolJSX &&
578            toolUseConfirm &&
579            !isMessageSelectorVisible &&
580            !binaryFeedbackContext && (
581              <PermissionRequest
582                toolUseConfirm={toolUseConfirm}
583                onDone={() => setToolUseConfirm(null)}
584                verbose={verbose}
585              />
586            )}
587          {!toolJSX &&
588            !toolUseConfirm &&
589            !isMessageSelectorVisible &&
590            !binaryFeedbackContext &&
591            showingCostDialog && (
592              <CostThresholdDialog
593                onDone={() => {
594                  setShowCostDialog(false)
595                  setHaveShownCostDialog(true)
596                  const projectConfig = getGlobalConfig()
597                  saveGlobalConfig({
598                    ...projectConfig,
599                    hasAcknowledgedCostThreshold: true,
600                  })
601                  logEvent('tengu_cost_threshold_acknowledged', {})
602                }}
603              />
604            )}
605  
606          {!toolUseConfirm &&
607            !toolJSX?.shouldHidePromptInput &&
608            shouldShowPromptInput &&
609            !isMessageSelectorVisible &&
610            !binaryFeedbackContext &&
611            !showingCostDialog && (
612              <>
613                <PromptInput
614                  commands={commands}
615                  forkNumber={forkNumber}
616                  messageLogName={messageLogName}
617                  tools={tools}
618                  isDisabled={apiKeyStatus === 'invalid'}
619                  isLoading={isLoading}
620                  onQuery={onQuery}
621                  debug={debug}
622                  verbose={verbose}
623                  messages={messages}
624                  setToolJSX={setToolJSX}
625                  onAutoUpdaterResult={setAutoUpdaterResult}
626                  autoUpdaterResult={autoUpdaterResult}
627                  input={inputValue}
628                  onInputChange={setInputValue}
629                  mode={inputMode}
630                  onModeChange={setInputMode}
631                  submitCount={submitCount}
632                  onSubmitCountChange={setSubmitCount}
633                  setIsLoading={setIsLoading}
634                  setAbortController={setAbortController}
635                  onShowMessageSelector={() =>
636                    setIsMessageSelectorVisible(prev => !prev)
637                  }
638                  setForkConvoWithMessagesOnTheNextRender={
639                    setForkConvoWithMessagesOnTheNextRender
640                  }
641                  readFileTimestamps={readFileTimestamps.current}
642                />
643              </>
644            )}
645        </Box>
646        {isMessageSelectorVisible && (
647          <MessageSelector
648            erroredToolUseIDs={erroredToolUseIDs}
649            unresolvedToolUseIDs={unresolvedToolUseIDs}
650            messages={normalizeMessagesForAPI(messages)}
651            onSelect={async message => {
652              setIsMessageSelectorVisible(false)
653  
654              // If the user selected the current prompt, do nothing
655              if (!messages.includes(message)) {
656                return
657              }
658  
659              // Cancel tool use calls/requests
660              onCancel()
661  
662              // Hack: make sure the "Interrupted by user" message is
663              // rendered in response to the cancellation. Otherwise,
664              // the screen will be cleared but there will remain a
665              // vestigial "Interrupted by user" message at the top.
666              setImmediate(async () => {
667                // Clear messages, and re-render
668                await clearTerminal()
669                setMessages([])
670                setForkConvoWithMessagesOnTheNextRender(
671                  messages.slice(0, messages.indexOf(message)),
672                )
673  
674                // Populate/reset the prompt input
675                if (typeof message.message.content === 'string') {
676                  setInputValue(message.message.content)
677                }
678              })
679            }}
680            onEscape={() => setIsMessageSelectorVisible(false)}
681            tools={tools}
682          />
683        )}
684        {/** Fix occasional rendering artifact */}
685        <Newline />
686      </>
687    )
688  }
689  
/**
 * Decides whether a normalized message goes into the `<Static>` group
 * (returns true) or the transient group (returns false).
 *
 * - User/assistant messages without a tool-use ID are always static.
 * - User/assistant messages whose tool-use ID is unresolved are transient.
 * - Otherwise, if a progress message exists for the same tool use, the
 *   message is static only when none of that progress message's sibling
 *   tool-use IDs is unresolved; with no such progress message it is static.
 * - Progress messages are static only when none of their sibling tool-use
 *   IDs is unresolved.
 *
 * @param message - the message being classified
 * @param messages - the normalized transcript, searched for a progress
 *   message sharing `message`'s tool-use ID
 * @param unresolvedToolUseIDs - tool-use IDs considered unresolved
 *   (presumably those still awaiting a result — see getUnresolvedToolUseIDs)
 * @returns true when the message should be rendered statically
 */
function shouldRenderStatically(
  message: NormalizedMessage,
  messages: NormalizedMessage[],
  unresolvedToolUseIDs: Set<string>,
): boolean {
  switch (message.type) {
    case 'user':
    case 'assistant': {
      const toolUseID = getToolUseID(message)
      if (!toolUseID) {
        return true
      }
      if (unresolvedToolUseIDs.has(toolUseID)) {
        return false
      }

      // `Array.prototype.find` yields `undefined` (never `null`) on a miss,
      // so the asserted type must say `undefined`.
      const correspondingProgressMessage = messages.find(
        _ => _.type === 'progress' && _.toolUseID === toolUseID,
      ) as ProgressMessage | undefined
      if (!correspondingProgressMessage) {
        return true
      }

      return !intersects(
        unresolvedToolUseIDs,
        correspondingProgressMessage.siblingToolUseIDs,
      )
    }
    case 'progress':
      return !intersects(unresolvedToolUseIDs, message.siblingToolUseIDs)
  }
}
722  
/**
 * Returns true when `a` and `b` share at least one element.
 *
 * Walks the smaller of the two sets and probes the larger one, so no
 * intermediate array is allocated and at most min(|a|, |b|) lookups are made.
 * Either set being empty yields false.
 */
function intersects<A>(a: ReadonlySet<A>, b: ReadonlySet<A>): boolean {
  const smaller = a.size <= b.size ? a : b
  const larger = smaller === a ? b : a
  for (const item of smaller) {
    if (larger.has(item)) {
      return true
    }
  }
  return false
}