// src/services/claude.ts
import '@anthropic-ai/sdk/shims/node'
import Anthropic, { APIConnectionError, APIError } from '@anthropic-ai/sdk'
import { AnthropicBedrock } from '@anthropic-ai/bedrock-sdk'
import { AnthropicVertex } from '@anthropic-ai/vertex-sdk'
import type { BetaUsage } from '@anthropic-ai/sdk/resources/beta/messages/messages.mjs'
import chalk from 'chalk'
import { createHash, randomUUID } from 'crypto'
import 'dotenv/config'
import { getBetas } from '../utils/betas.js'

import { addToTotalCost } from '../cost-tracker.js'
import type { AssistantMessage, UserMessage } from '../query.js'
import { Tool } from '../Tool.js'
import { getAnthropicApiKey, getOrCreateUserID } from '../utils/config.js'
import { logError, SESSION_ID } from '../utils/log.js'
import { USER_AGENT } from '../utils/http.js'
import {
  createAssistantAPIErrorMessage,
  normalizeContentFromAPI,
} from '../utils/messages.js'
import { countTokens } from '../utils/tokens.js'
import { logEvent } from './statsig.js'
import { withVCR } from './vcr.js'
import { zodToJsonSchema } from 'zod-to-json-schema'
import type { BetaMessageStream } from '@anthropic-ai/sdk/lib/BetaMessageStream.mjs'
import type {
  Message as APIMessage,
  MessageParam,
  TextBlockParam,
} from '@anthropic-ai/sdk/resources/index.mjs'
import {
  SMALL_FAST_MODEL,
  USE_BEDROCK,
  USE_VERTEX,
  getVertexRegionForModel,
} from '../utils/model.js'
import { getCLISyspromptPrefix } from '../constants/prompts.js'

interface StreamResponse extends APIMessage {
  ttftMs?: number
}

export const API_ERROR_MESSAGE_PREFIX = 'API Error'
export const PROMPT_TOO_LONG_ERROR_MESSAGE = 'Prompt is too long'
export const CREDIT_BALANCE_TOO_LOW_ERROR_MESSAGE = 'Credit balance is too low'
export const INVALID_API_KEY_ERROR_MESSAGE =
  'Invalid API key · Please run /login'
export const NO_CONTENT_MESSAGE = '(no content)'
const PROMPT_CACHING_ENABLED = !process.env.DISABLE_PROMPT_CACHING

// Prices are USD per million tokens.
// @see https://docs.anthropic.com/en/docs/about-claude/models#model-comparison-table
const HAIKU_COST_PER_MILLION_INPUT_TOKENS = 0.8
const HAIKU_COST_PER_MILLION_OUTPUT_TOKENS = 4
const HAIKU_COST_PER_MILLION_PROMPT_CACHE_WRITE_TOKENS = 1
const HAIKU_COST_PER_MILLION_PROMPT_CACHE_READ_TOKENS = 0.08

const SONNET_COST_PER_MILLION_INPUT_TOKENS = 3
const SONNET_COST_PER_MILLION_OUTPUT_TOKENS = 15
const SONNET_COST_PER_MILLION_PROMPT_CACHE_WRITE_TOKENS = 3.75
const SONNET_COST_PER_MILLION_PROMPT_CACHE_READ_TOKENS = 0.3
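
// Worked example (illustrative): 10,000 uncached input tokens and 2,000 output
// tokens on Sonnet cost (10_000 / 1e6) * 3 + (2_000 / 1e6) * 15 = $0.06.
// Cache reads bill at a tenth of the input rate; cache writes at 1.25x.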

export const MAIN_QUERY_TEMPERATURE = 1 // to get more variation for binary feedback

function getMetadata() {
  return {
    user_id: `${getOrCreateUserID()}_${SESSION_ID}`,
  }
}

const MAX_RETRIES = process.env.USER_TYPE === 'SWE_BENCH' ? 100 : 10
const BASE_DELAY_MS = 500

interface RetryOptions {
  maxRetries?: number
}

function getRetryDelay(
  attempt: number,
  retryAfterHeader?: string | null,
): number {
  // Honor a server-provided Retry-After header when present…
  if (retryAfterHeader) {
    const seconds = parseInt(retryAfterHeader, 10)
    if (!isNaN(seconds)) {
      return seconds * 1000
    }
  }
  // …otherwise fall back to exponential backoff, capped at 32s.
  return Math.min(BASE_DELAY_MS * Math.pow(2, attempt - 1), 32000)
}
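
// Backoff schedule implied by the constants above (with no Retry-After header):
// attempt 1 → 500ms, 2 → 1s, 3 → 2s, 4 → 4s, …, reaching the 32s cap at
// attempt 7 and staying there for later attempts.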

function shouldRetry(error: APIError): boolean {
  // Check for overloaded errors first and only retry for SWE_BENCH
  if (error.message?.includes('"type":"overloaded_error"')) {
    return process.env.USER_TYPE === 'SWE_BENCH'
  }

  // Note this is not a standard header.
  const shouldRetryHeader = error.headers?.['x-should-retry']

  // If the server explicitly says whether or not to retry, obey.
  if (shouldRetryHeader === 'true') return true
  if (shouldRetryHeader === 'false') return false

  if (error instanceof APIConnectionError) {
    return true
  }

  if (!error.status) return false

  // Retry on request timeouts.
  if (error.status === 408) return true

  // Retry on lock timeouts.
  if (error.status === 409) return true

  // Retry on rate limits.
  if (error.status === 429) return true

  // Retry internal errors.
  if (error.status >= 500) return true

  return false
}

async function withRetry<T>(
  operation: (attempt: number) => Promise<T>,
  options: RetryOptions = {},
): Promise<T> {
  const maxRetries = options.maxRetries ?? MAX_RETRIES
  let lastError: unknown

  for (let attempt = 1; attempt <= maxRetries + 1; attempt++) {
    try {
      return await operation(attempt)
    } catch (error) {
      lastError = error

      // Only retry if the error indicates we should
      if (
        attempt > maxRetries ||
        !(error instanceof APIError) ||
        !shouldRetry(error)
      ) {
        throw error
      }
      // Get retry-after header if available
      const retryAfter = error.headers?.['retry-after'] ?? null
      const delayMs = getRetryDelay(attempt, retryAfter)

      console.log(
        `  ⎿  ${chalk.red(`API ${error.name} (${error.message}) · Retrying in ${Math.round(delayMs / 1000)} seconds… (attempt ${attempt}/${maxRetries})`)}`,
      )

      logEvent('tengu_api_retry', {
        attempt: String(attempt),
        delayMs: String(delayMs),
        error: error.message,
        status: String(error.status),
        provider: USE_BEDROCK ? 'bedrock' : USE_VERTEX ? 'vertex' : '1p',
      })

      await new Promise(resolve => setTimeout(resolve, delayMs))
    }
  }

  throw lastError
}

export async function verifyApiKey(apiKey: string): Promise<boolean> {
  const anthropic = new Anthropic({
    apiKey,
    dangerouslyAllowBrowser: true,
    maxRetries: 3,
    defaultHeaders: {
      'User-Agent': USER_AGENT,
    },
  })

  try {
    await withRetry(
      async () => {
        const model = SMALL_FAST_MODEL
        const messages: MessageParam[] = [{ role: 'user', content: 'test' }]
        await anthropic.messages.create({
          model,
          max_tokens: 1,
          messages,
          temperature: 0,
          metadata: getMetadata(),
        })
        return true
      },
      { maxRetries: 2 }, // Use fewer retries for API key verification
    )
    return true
  } catch (error) {
    logError(error)
    // Check for authentication error
    if (
      error instanceof Error &&
      error.message.includes(
        '{"type":"error","error":{"type":"authentication_error","message":"invalid x-api-key"}}',
      )
    ) {
      return false
    }
    throw error
  }
}
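
// Usage sketch (illustrative): verifyApiKey resolves false only for a bad key;
// other failures (network errors, 5xx after retries) are rethrown to the caller.
//   const ok = await verifyApiKey(apiKey)
//   if (!ok) { /* surface INVALID_API_KEY_ERROR_MESSAGE and prompt /login */ }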

async function handleMessageStream(
  stream: BetaMessageStream,
): Promise<StreamResponse> {
  const streamStartTime = Date.now()
  let ttftMs: number | undefined

  // TODO(ben): Consider showing an incremental progress indicator.
  for await (const part of stream) {
    if (part.type === 'message_start') {
      // Time to first token, measured when the message_start event arrives.
      ttftMs = Date.now() - streamStartTime
    }
  }

  const finalResponse = await stream.finalMessage()
  return {
    ...finalResponse,
    ttftMs,
  }
}

let anthropicClient: Anthropic | null = null

/**
 * Get the Anthropic client, creating it if it doesn't exist.
 * Note: the client is cached after the first call, so the `model` argument
 * only affects the Vertex region chosen when the client is first created.
 */
export function getAnthropicClient(model?: string): Anthropic {
  if (anthropicClient) {
    return anthropicClient
  }

  const region = getVertexRegionForModel(model)

  const defaultHeaders: { [key: string]: string } = {
    'x-app': 'cli',
    'User-Agent': USER_AGENT,
  }
  if (process.env.ANTHROPIC_AUTH_TOKEN) {
    defaultHeaders['Authorization'] =
      `Bearer ${process.env.ANTHROPIC_AUTH_TOKEN}`
  }

  const ARGS = {
    defaultHeaders,
    maxRetries: 0, // Disabled auto-retry in favor of manual implementation
    timeout: parseInt(process.env.API_TIMEOUT_MS || String(60 * 1000), 10),
  }
  if (USE_BEDROCK) {
    const client = new AnthropicBedrock(ARGS)
    anthropicClient = client
    return client
  }
  if (USE_VERTEX) {
    const vertexArgs = {
      ...ARGS,
      region: region || process.env.CLOUD_ML_REGION || 'us-east5',
    }
    const client = new AnthropicVertex(vertexArgs)
    anthropicClient = client
    return client
  }

  const apiKey = getAnthropicApiKey()
  if (process.env.USER_TYPE === 'ant' && !apiKey) {
    console.error(
      chalk.red(
        '[ANT-ONLY] Please set the ANTHROPIC_API_KEY environment variable to use the CLI. To create a new key, go to https://console.anthropic.com/settings/keys.',
      ),
    )
  }
  anthropicClient = new Anthropic({
    apiKey,
    dangerouslyAllowBrowser: true,
    ...ARGS,
  })
  return anthropicClient
}

/**
 * Reset the Anthropic client to null, forcing a new client to be created on next use
 */
export function resetAnthropicClient(): void {
  anthropicClient = null
}

/**
 * Environment variables for different client types:
 *
 * Direct API:
 * - ANTHROPIC_API_KEY: Required for direct API access
 *
 * AWS Bedrock:
 * - AWS credentials configured via aws-sdk defaults
 *
 * Vertex AI:
 * - Model-specific region variables (highest priority):
 *   - VERTEX_REGION_CLAUDE_3_5_HAIKU: Region for the Claude 3.5 Haiku model
 *   - VERTEX_REGION_CLAUDE_3_5_SONNET: Region for the Claude 3.5 Sonnet model
 *   - VERTEX_REGION_CLAUDE_3_7_SONNET: Region for the Claude 3.7 Sonnet model
 * - CLOUD_ML_REGION: Optional. The default GCP region for all models when no
 *   model-specific region is set above
 * - ANTHROPIC_VERTEX_PROJECT_ID: Required. Your GCP project ID
 * - Standard GCP credentials configured via google-auth-library
 *
 * Priority for determining the region:
 * 1. Model-specific environment variables
 * 2. Global CLOUD_ML_REGION variable
 * 3. Default region from config
 * 4. Fallback region (us-east5)
 */
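
// Example configuration (illustrative; the project ID and override region are
// placeholders, not defaults):
//   export ANTHROPIC_VERTEX_PROJECT_ID=my-gcp-project
//   export CLOUD_ML_REGION=us-east5
//   export VERTEX_REGION_CLAUDE_3_7_SONNET=europe-west1
// With this configuration, Claude 3.7 Sonnet requests would go to europe-west1
// and every other model would fall back to us-east5.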

export function userMessageToMessageParam(
  message: UserMessage,
  addCache = false,
): MessageParam {
  if (addCache) {
    if (typeof message.message.content === 'string') {
      return {
        role: 'user',
        content: [
          {
            type: 'text',
            text: message.message.content,
            ...(PROMPT_CACHING_ENABLED
              ? { cache_control: { type: 'ephemeral' } }
              : {}),
          },
        ],
      }
    } else {
      // Mark only the last content block as a cache breakpoint.
      return {
        role: 'user',
        content: message.message.content.map((_, i) => ({
          ..._,
          ...(i === message.message.content.length - 1
            ? PROMPT_CACHING_ENABLED
              ? { cache_control: { type: 'ephemeral' } }
              : {}
            : {}),
        })),
      }
    }
  }
  return {
    role: 'user',
    content: message.message.content,
  }
}

export function assistantMessageToMessageParam(
  message: AssistantMessage,
  addCache = false,
): MessageParam {
  if (addCache) {
    if (typeof message.message.content === 'string') {
      return {
        role: 'assistant',
        content: [
          {
            type: 'text',
            text: message.message.content,
            ...(PROMPT_CACHING_ENABLED
              ? { cache_control: { type: 'ephemeral' } }
              : {}),
          },
        ],
      }
    } else {
      // Mark the last content block as a cache breakpoint, skipping thinking
      // blocks.
      return {
        role: 'assistant',
        content: message.message.content.map((_, i) => ({
          ..._,
          ...(i === message.message.content.length - 1 &&
          _.type !== 'thinking' &&
          _.type !== 'redacted_thinking'
            ? PROMPT_CACHING_ENABLED
              ? { cache_control: { type: 'ephemeral' } }
              : {}
            : {}),
        })),
      }
    }
  }
  return {
    role: 'assistant',
    content: message.message.content,
  }
}

function splitSysPromptPrefix(systemPrompt: string[]): string[] {
  // Split out the first block of the system prompt as the "prefix" for the API
  // to match on in https://console.statsig.com/4aF3Ewatb6xPVpCwxb5nA3/dynamic_configs/claude_cli_system_prompt_prefixes
  const systemPromptFirstBlock = systemPrompt[0] || ''
  const systemPromptRest = systemPrompt.slice(1)
  return [systemPromptFirstBlock, systemPromptRest.join('\n')].filter(Boolean)
}
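
// Worked example: splitSysPromptPrefix(['a', 'b', 'c']) → ['a', 'b\nc'];
// splitSysPromptPrefix(['a']) → ['a'] (the empty joined rest is filtered out).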

export async function querySonnet(
  messages: (UserMessage | AssistantMessage)[],
  systemPrompt: string[],
  maxThinkingTokens: number,
  tools: Tool[],
  signal: AbortSignal,
  options: {
    dangerouslySkipPermissions: boolean
    model: string
    prependCLISysprompt: boolean
  },
): Promise<AssistantMessage> {
  return await withVCR(messages, () =>
    querySonnetWithPromptCaching(
      messages,
      systemPrompt,
      maxThinkingTokens,
      tools,
      signal,
      options,
    ),
  )
}

export function formatSystemPromptWithContext(
  systemPrompt: string[],
  context: { [k: string]: string },
): string[] {
  if (Object.entries(context).length === 0) {
    return systemPrompt
  }

  return [
    ...systemPrompt,
    `\nAs you answer the user's questions, you can use the following context:\n`,
    ...Object.entries(context).map(
      ([key, value]) => `<context name="${key}">${value}</context>`,
    ),
  ]
}
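
// Worked example: formatSystemPromptWithContext(['You are a CLI.'], { cwd: '/repo' })
// appends the context preamble plus '<context name="cwd">/repo</context>' to
// the system prompt blocks.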

async function querySonnetWithPromptCaching(
  messages: (UserMessage | AssistantMessage)[],
  systemPrompt: string[],
  maxThinkingTokens: number,
  tools: Tool[],
  signal: AbortSignal,
  options: {
    dangerouslySkipPermissions: boolean
    model: string
    prependCLISysprompt: boolean
  },
): Promise<AssistantMessage> {
  const anthropic = getAnthropicClient(options.model)

  // Prepend system prompt block for easy API identification
  if (options.prependCLISysprompt) {
    // Log stats about first block for analyzing prefix matching config (see https://console.statsig.com/4aF3Ewatb6xPVpCwxb5nA3/dynamic_configs/claude_cli_system_prompt_prefixes)
    const [firstSyspromptBlock] = splitSysPromptPrefix(systemPrompt)
    logEvent('tengu_sysprompt_block', {
      snippet: firstSyspromptBlock?.slice(0, 20),
      length: String(firstSyspromptBlock?.length ?? 0),
      hash: firstSyspromptBlock
        ? createHash('sha256').update(firstSyspromptBlock).digest('hex')
        : '',
    })

    systemPrompt = [getCLISyspromptPrefix(), ...systemPrompt]
  }

  const system: TextBlockParam[] = splitSysPromptPrefix(systemPrompt).map(
    _ => ({
      ...(PROMPT_CACHING_ENABLED
        ? { cache_control: { type: 'ephemeral' } }
        : {}),
      text: _,
      type: 'text',
    }),
  )

  const toolSchemas = await Promise.all(
    tools.map(async _ => ({
      name: _.name,
      description: await _.prompt({
        dangerouslySkipPermissions: options.dangerouslySkipPermissions,
      }),
      // Use tool's JSON schema directly if provided, otherwise convert Zod schema
      input_schema: ('inputJSONSchema' in _ && _.inputJSONSchema
        ? _.inputJSONSchema
        : zodToJsonSchema(_.inputSchema)) as Anthropic.Tool.InputSchema,
    })),
  )

  const betas = await getBetas()
  const useBetas = PROMPT_CACHING_ENABLED && betas.length > 0
  logEvent('tengu_api_query', {
    model: options.model,
    messagesLength: String(
      JSON.stringify([...system, ...messages, ...toolSchemas]).length,
    ),
    temperature: String(MAIN_QUERY_TEMPERATURE),
    provider: USE_BEDROCK ? 'bedrock' : USE_VERTEX ? 'vertex' : '1p',
    ...(useBetas ? { betas: betas.join(',') } : {}),
  })

  const startIncludingRetries = Date.now()
  let start = Date.now()
  let attemptNumber = 0
  let response
  let stream: BetaMessageStream | undefined = undefined
  try {
    response = await withRetry(async attempt => {
      attemptNumber = attempt
      start = Date.now()
      const s = anthropic.beta.messages.stream(
        {
          model: options.model,
          max_tokens: Math.max(
            maxThinkingTokens + 1,
            getMaxTokensForModel(options.model),
          ),
          messages: addCacheBreakpoints(messages),
          temperature: MAIN_QUERY_TEMPERATURE,
          system,
          tools: toolSchemas,
          ...(useBetas ? { betas } : {}),
          metadata: getMetadata(),
          ...(process.env.USER_TYPE === 'ant' && maxThinkingTokens > 0
            ? {
                thinking: {
                  budget_tokens: maxThinkingTokens,
                  type: 'enabled',
                },
              }
            : {}),
        },
        { signal },
      )
      stream = s
      return handleMessageStream(s)
    })
  } catch (error) {
    logError(error)
    logEvent('tengu_api_error', {
      model: options.model,
      error: error instanceof Error ? error.message : String(error),
      status: error instanceof APIError ? String(error.status) : undefined,
      messageCount: String(messages.length),
      messageTokens: String(countTokens(messages)),
      durationMs: String(Date.now() - start),
      durationMsIncludingRetries: String(Date.now() - startIncludingRetries),
      attempt: String(attemptNumber),
      provider: USE_BEDROCK ? 'bedrock' : USE_VERTEX ? 'vertex' : '1p',
      requestId:
        (stream as BetaMessageStream | undefined)?.request_id ?? undefined,
    })
    return getAssistantMessageFromError(error)
  }
  const durationMs = Date.now() - start
  const durationMsIncludingRetries = Date.now() - startIncludingRetries
  logEvent('tengu_api_success', {
    model: options.model,
    messageCount: String(messages.length),
    messageTokens: String(countTokens(messages)),
    inputTokens: String(response.usage.input_tokens),
    outputTokens: String(response.usage.output_tokens),
    cachedInputTokens: String(
      (response.usage as BetaUsage).cache_read_input_tokens ?? 0,
    ),
    uncachedInputTokens: String(
      (response.usage as BetaUsage).cache_creation_input_tokens ?? 0,
    ),
    durationMs: String(durationMs),
    durationMsIncludingRetries: String(durationMsIncludingRetries),
    attempt: String(attemptNumber),
    ttftMs: String(response.ttftMs),
    provider: USE_BEDROCK ? 'bedrock' : USE_VERTEX ? 'vertex' : '1p',
    requestId:
      (stream as BetaMessageStream | undefined)?.request_id ?? undefined,
    stop_reason: response.stop_reason ?? undefined,
  })

  const inputTokens = response.usage.input_tokens
  const outputTokens = response.usage.output_tokens
  const cacheReadInputTokens =
    (response.usage as BetaUsage).cache_read_input_tokens ?? 0
  const cacheCreationInputTokens =
    (response.usage as BetaUsage).cache_creation_input_tokens ?? 0
  const costUSD =
    (inputTokens / 1_000_000) * SONNET_COST_PER_MILLION_INPUT_TOKENS +
    (outputTokens / 1_000_000) * SONNET_COST_PER_MILLION_OUTPUT_TOKENS +
    (cacheReadInputTokens / 1_000_000) *
      SONNET_COST_PER_MILLION_PROMPT_CACHE_READ_TOKENS +
    (cacheCreationInputTokens / 1_000_000) *
      SONNET_COST_PER_MILLION_PROMPT_CACHE_WRITE_TOKENS

  addToTotalCost(costUSD, durationMsIncludingRetries)

  return {
    message: {
      ...response,
      content: normalizeContentFromAPI(response.content),
      usage: {
        ...response.usage,
        cache_read_input_tokens: response.usage.cache_read_input_tokens ?? 0,
        cache_creation_input_tokens:
          response.usage.cache_creation_input_tokens ?? 0,
      },
    },
    costUSD,
    durationMs,
    type: 'assistant',
    uuid: randomUUID(),
  }
}

function getAssistantMessageFromError(error: unknown): AssistantMessage {
  if (error instanceof Error && error.message.includes('prompt is too long')) {
    return createAssistantAPIErrorMessage(PROMPT_TOO_LONG_ERROR_MESSAGE)
  }
  if (
    error instanceof Error &&
    error.message.includes('Your credit balance is too low')
  ) {
    return createAssistantAPIErrorMessage(CREDIT_BALANCE_TOO_LOW_ERROR_MESSAGE)
  }
  if (
    error instanceof Error &&
    error.message.toLowerCase().includes('x-api-key')
  ) {
    return createAssistantAPIErrorMessage(INVALID_API_KEY_ERROR_MESSAGE)
  }
  if (error instanceof Error) {
    return createAssistantAPIErrorMessage(
      `${API_ERROR_MESSAGE_PREFIX}: ${error.message}`,
    )
  }
  return createAssistantAPIErrorMessage(API_ERROR_MESSAGE_PREFIX)
}

function addCacheBreakpoints(
  messages: (UserMessage | AssistantMessage)[],
): MessageParam[] {
  // Only the final two messages get cache breakpoints (index > length - 3).
  return messages.map((msg, index) => {
    return msg.type === 'user'
      ? userMessageToMessageParam(msg, index > messages.length - 3)
      : assistantMessageToMessageParam(msg, index > messages.length - 3)
  })
}
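
// Worked example: with five messages (indices 0–4), only indices 3 and 4
// satisfy index > 5 - 3, so the last two messages carry cache_control and the
// conversation prefix before them can be served from the prompt cache.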

async function queryHaikuWithPromptCaching({
  systemPrompt,
  userPrompt,
  assistantPrompt,
  signal,
}: {
  systemPrompt: string[]
  userPrompt: string
  assistantPrompt?: string
  signal?: AbortSignal
}): Promise<AssistantMessage> {
  const anthropic = getAnthropicClient(SMALL_FAST_MODEL)
  const model = SMALL_FAST_MODEL
  const messages = [
    {
      role: 'user' as const,
      content: userPrompt,
    },
    ...(assistantPrompt
      ? [{ role: 'assistant' as const, content: assistantPrompt }]
      : []),
  ]

  const system: TextBlockParam[] = splitSysPromptPrefix(systemPrompt).map(
    _ => ({
      ...(PROMPT_CACHING_ENABLED
        ? { cache_control: { type: 'ephemeral' } }
        : {}),
      text: _,
      type: 'text',
    }),
  )

  logEvent('tengu_api_query', {
    model,
    messagesLength: String(JSON.stringify([...system, ...messages]).length),
    provider: USE_BEDROCK ? 'bedrock' : USE_VERTEX ? 'vertex' : '1p',
  })
  let attemptNumber = 0
  let start = Date.now()
  const startIncludingRetries = Date.now()
  let response: StreamResponse
  let stream: BetaMessageStream | undefined = undefined
  try {
    response = await withRetry(async attempt => {
      attemptNumber = attempt
      start = Date.now()
      const s = anthropic.beta.messages.stream(
        {
          model,
          max_tokens: 512,
          messages,
          system,
          temperature: 0,
          metadata: getMetadata(),
          stream: true,
        },
        { signal },
      )
      stream = s
      return await handleMessageStream(s)
    })
  } catch (error) {
    logError(error)
    logEvent('tengu_api_error', {
      error: error instanceof Error ? error.message : String(error),
      status: error instanceof APIError ? String(error.status) : undefined,
      model: SMALL_FAST_MODEL,
      messageCount: String(assistantPrompt ? 2 : 1),
      durationMs: String(Date.now() - start),
      durationMsIncludingRetries: String(Date.now() - startIncludingRetries),
      attempt: String(attemptNumber),
      provider: USE_BEDROCK ? 'bedrock' : USE_VERTEX ? 'vertex' : '1p',
      requestId:
        (stream as BetaMessageStream | undefined)?.request_id ?? undefined,
    })
    return getAssistantMessageFromError(error)
  }

  const inputTokens = response.usage.input_tokens
  const outputTokens = response.usage.output_tokens
  const cacheReadInputTokens = response.usage.cache_read_input_tokens ?? 0
  const cacheCreationInputTokens =
    response.usage.cache_creation_input_tokens ?? 0
  const costUSD =
    (inputTokens / 1_000_000) * HAIKU_COST_PER_MILLION_INPUT_TOKENS +
    (outputTokens / 1_000_000) * HAIKU_COST_PER_MILLION_OUTPUT_TOKENS +
    (cacheReadInputTokens / 1_000_000) *
      HAIKU_COST_PER_MILLION_PROMPT_CACHE_READ_TOKENS +
    (cacheCreationInputTokens / 1_000_000) *
      HAIKU_COST_PER_MILLION_PROMPT_CACHE_WRITE_TOKENS

  const durationMs = Date.now() - start
  const durationMsIncludingRetries = Date.now() - startIncludingRetries
  addToTotalCost(costUSD, durationMsIncludingRetries)

  const assistantMessage: AssistantMessage = {
    durationMs,
    message: {
      ...response,
      content: normalizeContentFromAPI(response.content),
    },
    costUSD,
    uuid: randomUUID(),
    type: 'assistant',
  }

  logEvent('tengu_api_success', {
    model: SMALL_FAST_MODEL,
    messageCount: String(assistantPrompt ? 2 : 1),
    inputTokens: String(inputTokens),
    outputTokens: String(outputTokens),
    cachedInputTokens: String(cacheReadInputTokens),
    uncachedInputTokens: String(cacheCreationInputTokens),
    durationMs: String(durationMs),
    durationMsIncludingRetries: String(durationMsIncludingRetries),
    ttftMs: String(response.ttftMs),
    provider: USE_BEDROCK ? 'bedrock' : USE_VERTEX ? 'vertex' : '1p',
    requestId:
      (stream as BetaMessageStream | undefined)?.request_id ?? undefined,
    stop_reason: response.stop_reason ?? undefined,
  })

  return assistantMessage
}

async function queryHaikuWithoutPromptCaching({
  systemPrompt,
  userPrompt,
  assistantPrompt,
  signal,
}: {
  systemPrompt: string[]
  userPrompt: string
  assistantPrompt?: string
  signal?: AbortSignal
}): Promise<AssistantMessage> {
  const anthropic = getAnthropicClient(SMALL_FAST_MODEL)
  const model = SMALL_FAST_MODEL
  const messages = [
    { role: 'user' as const, content: userPrompt },
    ...(assistantPrompt
      ? [{ role: 'assistant' as const, content: assistantPrompt }]
      : []),
  ]
  logEvent('tengu_api_query', {
    model,
    messagesLength: String(
      JSON.stringify([{ systemPrompt }, ...messages]).length,
    ),
    provider: USE_BEDROCK ? 'bedrock' : USE_VERTEX ? 'vertex' : '1p',
  })

  let attemptNumber = 0
  let start = Date.now()
  const startIncludingRetries = Date.now()
  let response: StreamResponse
  let stream: BetaMessageStream | undefined = undefined
  try {
    response = await withRetry(async attempt => {
      attemptNumber = attempt
      start = Date.now()
      const s = anthropic.beta.messages.stream(
        {
          model,
          max_tokens: 512,
          messages,
          system: splitSysPromptPrefix(systemPrompt).map(text => ({
            type: 'text',
            text,
          })),
          temperature: 0,
          metadata: getMetadata(),
          stream: true,
        },
        { signal },
      )
      stream = s
      return await handleMessageStream(s)
    })
  } catch (error) {
    logError(error)
    logEvent('tengu_api_error', {
      error: error instanceof Error ? error.message : String(error),
      status: error instanceof APIError ? String(error.status) : undefined,
      model: SMALL_FAST_MODEL,
      messageCount: String(assistantPrompt ? 2 : 1),
      durationMs: String(Date.now() - start),
      durationMsIncludingRetries: String(Date.now() - startIncludingRetries),
      attempt: String(attemptNumber),
      provider: USE_BEDROCK ? 'bedrock' : USE_VERTEX ? 'vertex' : '1p',
      requestId:
        (stream as BetaMessageStream | undefined)?.request_id ?? undefined,
    })
    return getAssistantMessageFromError(error)
  }
  const durationMs = Date.now() - start
  const durationMsIncludingRetries = Date.now() - startIncludingRetries
  logEvent('tengu_api_success', {
    model: SMALL_FAST_MODEL,
    messageCount: String(assistantPrompt ? 2 : 1),
    inputTokens: String(response.usage.input_tokens),
    outputTokens: String(response.usage.output_tokens),
    durationMs: String(durationMs),
    durationMsIncludingRetries: String(durationMsIncludingRetries),
    attempt: String(attemptNumber),
    provider: USE_BEDROCK ? 'bedrock' : USE_VERTEX ? 'vertex' : '1p',
    requestId:
      (stream as BetaMessageStream | undefined)?.request_id ?? undefined,
    stop_reason: response.stop_reason ?? undefined,
  })

  const inputTokens = response.usage.input_tokens
  const outputTokens = response.usage.output_tokens
  const costUSD =
    (inputTokens / 1_000_000) * HAIKU_COST_PER_MILLION_INPUT_TOKENS +
    (outputTokens / 1_000_000) * HAIKU_COST_PER_MILLION_OUTPUT_TOKENS

  // Report the duration including retries, consistent with the other query paths.
  addToTotalCost(costUSD, durationMsIncludingRetries)

  const assistantMessage: AssistantMessage = {
    durationMs,
    message: {
      ...response,
      content: normalizeContentFromAPI(response.content),
      usage: {
        ...response.usage,
        cache_read_input_tokens: 0,
        cache_creation_input_tokens: 0,
      },
    },
    costUSD,
    type: 'assistant',
    uuid: randomUUID(),
  }

  return assistantMessage
}

export async function queryHaiku({
  systemPrompt = [],
  userPrompt,
  assistantPrompt,
  enablePromptCaching = false,
  signal,
}: {
  systemPrompt: string[]
  userPrompt: string
  assistantPrompt?: string
  enablePromptCaching?: boolean
  signal?: AbortSignal
}): Promise<AssistantMessage> {
  return await withVCR(
    [
      {
        message: {
          role: 'user',
          content: systemPrompt.map(text => ({ type: 'text', text })),
        },
        type: 'user',
        uuid: randomUUID(),
      },
      {
        message: { role: 'user', content: userPrompt },
        type: 'user',
        uuid: randomUUID(),
      },
    ],
    () => {
      return enablePromptCaching
        ? queryHaikuWithPromptCaching({
            systemPrompt,
            userPrompt,
            assistantPrompt,
            signal,
          })
        : queryHaikuWithoutPromptCaching({
            systemPrompt,
            userPrompt,
            assistantPrompt,
            signal,
          })
    },
  )
}

function getMaxTokensForModel(model: string): number {
  // 3.5-generation and Haiku models are capped at 8192 output tokens; all
  // other models get a 20k budget.
  if (model.includes('3-5') || model.includes('haiku')) {
    return 8192
  }
  return 20_000
}
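
// Worked example: getMaxTokensForModel('claude-3-5-sonnet-20241022') → 8192;
// getMaxTokensForModel('claude-3-7-sonnet-20250219') → 20000.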