/ src / services / mcpClient.ts
mcpClient.ts
  1  import { zipObject } from 'lodash-es'
  2  import {
  3    getCurrentProjectConfig,
  4    McpServerConfig,
  5    saveCurrentProjectConfig,
  6    getGlobalConfig,
  7    saveGlobalConfig,
  8    getMcprcConfig,
  9    addMcprcServerForTesting,
 10    removeMcprcServerForTesting,
 11  } from '../utils/config.js'
 12  import { existsSync, readFileSync, writeFileSync } from 'fs'
 13  import { join } from 'path'
 14  import { getCwd } from '../utils/state.js'
 15  import { safeParseJSON } from '../utils/json.js'
 16  import {
 17    ImageBlockParam,
 18    MessageParam,
 19    ToolResultBlockParam,
 20  } from '@anthropic-ai/sdk/resources/index.mjs'
 21  import { Client } from '@modelcontextprotocol/sdk/client/index.js'
 22  import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'
 23  import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
 24  import {
 25    CallToolResultSchema,
 26    ClientRequest,
 27    ListPromptsResult,
 28    ListPromptsResultSchema,
 29    ListToolsResult,
 30    ListToolsResultSchema,
 31    Result,
 32    ResultSchema,
 33  } from '@modelcontextprotocol/sdk/types.js'
 34  import { memoize, pickBy } from 'lodash-es'
 35  import type { Tool } from '../Tool.js'
 36  import { MCPTool } from '../tools/MCPTool/MCPTool.js'
 37  import { logMCPError } from '../utils/log.js'
 38  import { Command } from '../commands.js'
 39  import { logEvent } from '../services/statsig.js'
 40  
 41  type McpName = string
 42  
 43  export function parseEnvVars(
 44    rawEnvArgs: string[] | undefined,
 45  ): Record<string, string> {
 46    const parsedEnv: Record<string, string> = {}
 47  
 48    // Parse individual env vars
 49    if (rawEnvArgs) {
 50      for (const envStr of rawEnvArgs) {
 51        const [key, ...valueParts] = envStr.split('=')
 52        if (!key || valueParts.length === 0) {
 53          throw new Error(
 54            `Invalid environment variable format: ${envStr}, environment variables should be added as: -e KEY1=value1 -e KEY2=value2`,
 55          )
 56        }
 57        parsedEnv[key] = valueParts.join('=')
 58      }
 59    }
 60    return parsedEnv
 61  }
 62  
 63  const VALID_SCOPES = ['project', 'global', 'mcprc'] as const
 64  type ConfigScope = (typeof VALID_SCOPES)[number]
 65  const EXTERNAL_SCOPES = ['project', 'global'] as ConfigScope[]
 66  
 67  export function ensureConfigScope(scope?: string): ConfigScope {
 68    if (!scope) return 'project'
 69  
 70    const scopesToCheck =
 71      process.env.USER_TYPE === 'external' ? EXTERNAL_SCOPES : VALID_SCOPES
 72  
 73    if (!scopesToCheck.includes(scope as ConfigScope)) {
 74      throw new Error(
 75        `Invalid scope: ${scope}. Must be one of: ${scopesToCheck.join(', ')}`,
 76      )
 77    }
 78  
 79    return scope as ConfigScope
 80  }
 81  
 82  export function addMcpServer(
 83    name: McpName,
 84    server: McpServerConfig,
 85    scope: ConfigScope = 'project',
 86  ): void {
 87    if (scope === 'mcprc') {
 88      if (process.env.NODE_ENV === 'test') {
 89        addMcprcServerForTesting(name, server)
 90      } else {
 91        const mcprcPath = join(getCwd(), '.mcprc')
 92        let mcprcConfig: Record<string, McpServerConfig> = {}
 93  
 94        // Read existing config if present
 95        if (existsSync(mcprcPath)) {
 96          try {
 97            const mcprcContent = readFileSync(mcprcPath, 'utf-8')
 98            const existingConfig = safeParseJSON(mcprcContent)
 99            if (existingConfig && typeof existingConfig === 'object') {
100              mcprcConfig = existingConfig as Record<string, McpServerConfig>
101            }
102          } catch {
103            // If we can't read/parse, start with empty config
104          }
105        }
106  
107        // Add the server
108        mcprcConfig[name] = server
109  
110        // Write back to .mcprc
111        try {
112          writeFileSync(mcprcPath, JSON.stringify(mcprcConfig, null, 2), 'utf-8')
113        } catch (error) {
114          throw new Error(`Failed to write to .mcprc: ${error}`)
115        }
116      }
117    } else if (scope === 'global') {
118      const config = getGlobalConfig()
119      if (!config.mcpServers) {
120        config.mcpServers = {}
121      }
122      config.mcpServers[name] = server
123      saveGlobalConfig(config)
124    } else {
125      const config = getCurrentProjectConfig()
126      if (!config.mcpServers) {
127        config.mcpServers = {}
128      }
129      config.mcpServers[name] = server
130      saveCurrentProjectConfig(config)
131    }
132  }
133  
/**
 * Deletes the MCP server registered as `name` from the given scope.
 *
 * - `mcprc`: removed from the `.mcprc` JSON file in the current working
 *   directory (delegated to `removeMcprcServerForTesting` when NODE_ENV is `test`).
 * - `global`: removed from the global config's `mcpServers` map.
 * - `project` (default): removed from the current project config's `mcpServers` map.
 *
 * @param name - server name to remove
 * @param scope - which config to remove it from
 * @throws Error if the server is not registered in that scope, if `.mcprc`
 *   is missing, or if `.mcprc` cannot be read or rewritten
 */
export function removeMcpServer(
  name: McpName,
  scope: ConfigScope = 'project',
): void {
  switch (scope) {
    case 'mcprc': {
      if (process.env.NODE_ENV === 'test') {
        removeMcprcServerForTesting(name)
        return
      }

      const mcprcPath = join(getCwd(), '.mcprc')
      if (!existsSync(mcprcPath)) {
        throw new Error('No .mcprc file found in this directory')
      }

      try {
        const servers = safeParseJSON(
          readFileSync(mcprcPath, 'utf-8'),
        ) as Record<string, McpServerConfig> | null
        if (!servers || typeof servers !== 'object' || !servers[name]) {
          throw new Error(`No MCP server found with name: ${name} in .mcprc`)
        }
        delete servers[name]
        writeFileSync(mcprcPath, JSON.stringify(servers, null, 2), 'utf-8')
      } catch (error) {
        // Error instances (including the not-found error above) propagate
        // unchanged; any other thrown value is wrapped in an Error.
        throw error instanceof Error
          ? error
          : new Error(`Failed to remove from .mcprc: ${error}`)
      }
      return
    }

    case 'global': {
      const globalConfig = getGlobalConfig()
      const servers = globalConfig.mcpServers
      if (!servers?.[name]) {
        throw new Error(`No global MCP server found with name: ${name}`)
      }
      delete servers[name]
      saveGlobalConfig(globalConfig)
      return
    }

    default: {
      const projectConfig = getCurrentProjectConfig()
      const servers = projectConfig.mcpServers
      if (!servers?.[name]) {
        throw new Error(`No local MCP server found with name: ${name}`)
      }
      delete servers[name]
      saveCurrentProjectConfig(projectConfig)
    }
  }
}
187  
/**
 * Returns every configured MCP server, merged across all scopes.
 *
 * On name collisions, project entries win over `.mcprc` entries, which win
 * over global ones. `.mcprc` servers are included here whether or not they
 * have been approved.
 */
export function listMCPServers(): Record<string, McpServerConfig> {
  // Ordered lowest to highest precedence; later layers overwrite earlier keys.
  const layers = [
    getGlobalConfig().mcpServers,
    getMcprcConfig(),
    getCurrentProjectConfig().mcpServers,
  ]
  return layers.reduce<Record<string, McpServerConfig>>(
    (merged, layer) => ({ ...merged, ...(layer ?? {}) }),
    {},
  )
}
198  
export type ScopedMcpServerConfig = McpServerConfig & {
  scope: ConfigScope
}

/**
 * Looks up one MCP server by name and reports which scope defined it.
 *
 * Scopes are searched in precedence order: project, then `.mcprc`, then
 * global. The first match wins.
 *
 * @param name - server name to look up
 * @returns a copy of the server config tagged with its scope, or `undefined`
 *   when no scope defines `name`
 */
export function getMcpServer(name: McpName): ScopedMcpServerConfig | undefined {
  const projectServers = getCurrentProjectConfig().mcpServers
  const mcprcServers = getMcprcConfig()
  const globalServers = getGlobalConfig().mcpServers

  const projectEntry = projectServers?.[name]
  if (projectEntry) {
    return { ...projectEntry, scope: 'project' }
  }

  const mcprcEntry = mcprcServers?.[name]
  if (mcprcEntry) {
    return { ...mcprcEntry, scope: 'mcprc' }
  }

  const globalEntry = globalServers?.[name]
  return globalEntry ? { ...globalEntry, scope: 'global' } : undefined
}
223  
/**
 * Opens a connection to a single MCP server and returns the connected client.
 *
 * Configs with `type === 'sse'` connect over Server-Sent Events. Every other
 * config is launched as a stdio child process whose environment is
 * `process.env` overlaid with the server's own `env`. The connection attempt
 * is bounded by a 5 second timeout.
 *
 * @param name - server name, used in error and log messages
 * @param serverRef - the server's configuration
 * @returns the connected client
 * @throws Error if connecting fails or times out; in both cases the transport
 *   is closed first, so a spawned stdio server is not left running
 */
async function connectToServer(
  name: string,
  serverRef: McpServerConfig,
): Promise<Client> {
  const transport =
    serverRef.type === 'sse'
      ? new SSEClientTransport(new URL(serverRef.url))
      : new StdioClientTransport({
          command: serverRef.command,
          args: serverRef.args,
          env: {
            ...process.env,
            ...serverRef.env,
          } as Record<string, string>,
          stderr: 'pipe', // prevents error output from the MCP server from printing to the UI
        })

  const client = new Client(
    {
      name: 'claude',
      version: '0.1.0',
    },
    {
      capabilities: {},
    },
  )

  // Add a timeout to connection attempts to prevent tests from hanging indefinitely
  const CONNECTION_TIMEOUT_MS = 5000
  let timeoutId: ReturnType<typeof setTimeout> | undefined
  const timeoutPromise = new Promise<never>((_, reject) => {
    timeoutId = setTimeout(() => {
      reject(
        new Error(
          `Connection to MCP server "${name}" timed out after ${CONNECTION_TIMEOUT_MS}ms`,
        ),
      )
    }, CONNECTION_TIMEOUT_MS)
  })

  try {
    await Promise.race([client.connect(transport), timeoutPromise])
  } catch (error) {
    // Tear down the half-open transport. Without this, a stdio server that
    // timed out or failed to initialize keeps running after we give up on it.
    // Close errors are ignored so callers see the original failure.
    await transport.close().catch(() => {})
    throw error
  } finally {
    clearTimeout(timeoutId)
  }

  // The transport choice above treats every non-'sse' config as stdio, so
  // check the transport's actual class rather than `serverRef.type === 'stdio'`.
  // That way a config that omits `type` still has its piped stderr read and
  // logged instead of left unread.
  if (transport instanceof StdioClientTransport) {
    transport.stderr?.on('data', (data: Buffer) => {
      const errorText = data.toString().trim()
      if (errorText) {
        logMCPError(name, `Server stderr: ${errorText}`)
      }
    })
  }
  return client
}
282  
/** A configured server whose connection attempt succeeded. */
type ConnectedClient = {
  client: Client
  name: string
  type: 'connected'
}
/** A configured server whose connection attempt failed; only its name is kept. */
type FailedClient = {
  name: string
  type: 'failed'
}
/** Outcome of connecting to one configured server, discriminated by `type`. */
export type WrappedClient = ConnectedClient | FailedClient

/**
 * Reports the user's decision about a `.mcprc`-defined server, as recorded in
 * the current project config.
 *
 * The approved list is consulted first, so a name present in both lists
 * counts as approved. A name in neither list is `'pending'`.
 *
 * @param serverName - the `.mcprc` server name
 */
export function getMcprcServerStatus(
  serverName: string,
): 'approved' | 'rejected' | 'pending' {
  const { approvedMcprcServers, rejectedMcprcServers } =
    getCurrentProjectConfig()
  if (approvedMcprcServers?.includes(serverName)) {
    return 'approved'
  }
  return rejectedMcprcServers?.includes(serverName) ? 'rejected' : 'pending'
}
306  
/**
 * Connects to every configured MCP server; memoized, so the work happens once
 * per process.
 *
 * Servers come from the global config, approved `.mcprc` entries, and the
 * project config. On name collisions the project wins over `.mcprc`, which
 * wins over global. Connections are attempted in parallel. Each failure is
 * logged and reported as a `'failed'` entry instead of rejecting the batch.
 */
export const getClients = memoize(async (): Promise<WrappedClient[]> => {
  // TODO: This is a temporary fix for a hang during npm run verify in CI.
  // We need to investigate why MCP client connections hang in CI verify but not in CI tests.
  if (process.env.CI && process.env.NODE_ENV !== 'test') {
    return []
  }

  const globalServers = getGlobalConfig().mcpServers ?? {}
  const mcprcServers = getMcprcConfig()
  const projectServers = getCurrentProjectConfig().mcpServers ?? {}

  // Pending or rejected .mcprc servers are never connected.
  const approvedMcprcServers = pickBy(
    mcprcServers,
    (_, serverName) => getMcprcServerStatus(serverName) === 'approved',
  )

  // Later spreads win: global < approved .mcprc < project.
  const allServers = {
    ...globalServers,
    ...approvedMcprcServers,
    ...projectServers,
  }

  const connectOne = async (
    name: string,
    serverRef: McpServerConfig,
  ): Promise<WrappedClient> => {
    try {
      const client = await connectToServer(name, serverRef)
      logEvent('tengu_mcp_server_connection_succeeded', {})
      return { name, client, type: 'connected' }
    } catch (error) {
      logEvent('tengu_mcp_server_connection_failed', {})
      const reason = error instanceof Error ? error.message : String(error)
      logMCPError(name, `Connection failed: ${reason}`)
      return { name, type: 'failed' }
    }
  }

  return Promise.all(
    Object.entries(allServers).map(([name, serverRef]) =>
      connectOne(name, serverRef),
    ),
  )
})
347  
/**
 * Sends `req` to every connected MCP server that advertises
 * `requiredCapability` and collects the responses that succeed.
 *
 * Servers that failed to connect or lack the capability are skipped. A
 * request that throws is logged against its server and also skipped, so one
 * misbehaving server does not hide results from the others.
 *
 * @param req - the request to send to each server
 * @param resultSchema - schema used by the client to parse each response
 * @param requiredCapability - key that must be truthy in the server's capabilities
 * @returns one `{ client, result }` pair per server that answered, in client order
 */
async function requestAll<
  ResultT extends Result,
  ResultSchemaT extends typeof ResultSchema,
>(
  req: ClientRequest,
  resultSchema: ResultSchemaT,
  requiredCapability: string,
): Promise<{ client: ConnectedClient; result: ResultT }[]> {
  type Answer = { client: ConnectedClient; result: ResultT }

  const queryOne = async (wrapped: WrappedClient): Promise<Answer | null> => {
    if (wrapped.type === 'failed') {
      return null
    }
    try {
      const capabilities = await wrapped.client.getServerCapabilities()
      if (!capabilities?.[requiredCapability]) {
        return null
      }
      const result = (await wrapped.client.request(
        req,
        resultSchema,
      )) as ResultT
      return { client: wrapped, result }
    } catch (error) {
      logMCPError(
        wrapped.name,
        `Failed to request '${req.method}': ${error instanceof Error ? error.message : String(error)}`,
      )
      return null
    }
  }

  const clients = await getClients()
  const outcomes = await Promise.allSettled(clients.map(queryOne))

  const answers: Answer[] = []
  for (const outcome of outcomes) {
    if (outcome.status === 'fulfilled' && outcome.value !== null) {
      answers.push(outcome.value)
    }
  }
  return answers
}
396  
/**
 * Lists the tools exposed by every connected MCP server that supports
 * `tools`, wrapped as `Tool`s. Memoized.
 *
 * Each tool is named `mcp__<server>__<tool>` and shown to users as
 * `<server>:<tool> (MCP)`. Its MCP description doubles as its prompt, and
 * calling it forwards the arguments to the originating server via
 * `callMCPTool`.
 */
export const getMCPTools = memoize(async (): Promise<Tool[]> => {
  const responses = await requestAll<
    ListToolsResult,
    typeof ListToolsResultSchema
  >({ method: 'tools/list' }, ListToolsResultSchema, 'tools')

  // TODO: Add zod schema validation
  const wrappedTools: Tool[] = []
  for (const { client, result } of responses) {
    for (const mcpTool of result.tools) {
      const describe = async () => mcpTool.description ?? ''
      wrappedTools.push({
        ...MCPTool,
        name: `mcp__${client.name}__${mcpTool.name}`,
        description: describe,
        prompt: describe,
        inputJSONSchema: mcpTool.inputSchema as Tool['inputJSONSchema'],
        async *call(args: Record<string, unknown>) {
          const data = await callMCPTool({
            client,
            tool: mcpTool.name,
            args,
          })
          yield {
            type: 'result' as const,
            data,
            resultForAssistant: data,
          }
        },
        userFacingName() {
          return `${client.name}:${mcpTool.name} (MCP)`
        },
      })
    }
  }
  return wrappedTools
})
437  
/**
 * Invokes `tool` on a connected MCP server and converts its reply into
 * tool-result content.
 *
 * Two result shapes are handled:
 * - a `toolResult` value, which is stringified;
 * - a `content` array, where image items become base64 image blocks and all
 *   other items pass through unchanged.
 *
 * @param client - the connected server to call
 * @param tool - the tool's name on that server (without the `mcp__` prefix)
 * @param args - arguments forwarded as the tool call's `arguments`
 * @returns the converted tool-result content
 * @throws Error when the server flags the call with `isError` (the message
 *   carries the server's error text), or when the result has neither shape
 */
async function callMCPTool({
  client: { client, name },
  tool,
  args,
}: {
  client: ConnectedClient
  tool: string
  args: Record<string, unknown>
}): Promise<ToolResultBlockParam['content']> {
  const result = await client.callTool(
    {
      name: tool,
      arguments: args,
    },
    CallToolResultSchema,
  )

  if ('isError' in result && result.isError) {
    // MCP reports tool failures via `isError` with the details in `content`.
    // Prefer a non-standard `error` field if a server sends one; otherwise use
    // the text content, so the message doesn't end in "undefined".
    let detail: string
    if ('error' in result && result.error !== undefined) {
      detail = String(result.error)
    } else if (Array.isArray(result.content)) {
      const text = result.content
        .flatMap(item => (item.type === 'text' ? [String(item.text)] : []))
        .join('\n')
      detail = text || 'unknown error'
    } else {
      detail = 'unknown error'
    }
    const errorMessage = `Error calling tool ${tool}: ${detail}`
    logMCPError(name, errorMessage)
    throw new Error(errorMessage)
  }

  // Handle toolResult-type response
  if ('toolResult' in result) {
    return String(result.toolResult)
  }

  // Handle content array response
  if ('content' in result && Array.isArray(result.content)) {
    return result.content.map(item => {
      if (item.type === 'image') {
        return {
          type: 'image',
          source: {
            type: 'base64',
            data: String(item.data),
            media_type: item.mimeType as ImageBlockParam.Source['media_type'],
          },
        }
      }
      return item
    })
  }

  throw new Error(`Unexpected response format from tool ${tool}`)
}
485  
/**
 * Lists the prompts exposed by every connected MCP server that supports
 * `prompts`, wrapped as prompt commands. Memoized.
 *
 * Each command is named `mcp__<server>__<prompt>` and shown to users as
 * `<server>:<prompt> (MCP)`. When run, its argument string is split on single
 * spaces and matched positionally to the prompt's declared argument names
 * before the prompt is fetched via `runCommand`.
 */
export const getMCPCommands = memoize(async (): Promise<Command[]> => {
  const results = await requestAll<
    ListPromptsResult,
    typeof ListPromptsResultSchema
  >(
    {
      method: 'prompts/list',
    },
    ListPromptsResultSchema,
    'prompts',
  )

  return results.flatMap(({ client, result }) =>
    // `?? []` rather than `?.map`: if a response lacks `prompts`, `?.map`
    // returns `undefined`, which flatMap keeps as an element of the result.
    (result.prompts ?? []).map(prompt => {
      const argNames = (prompt.arguments ?? []).map(arg => arg.name)
      return {
        type: 'prompt',
        name: 'mcp__' + client.name + '__' + prompt.name,
        description: prompt.description ?? '',
        isEnabled: true,
        isHidden: false,
        progressMessage: 'running',
        userFacingName() {
          return `${client.name}:${prompt.name} (MCP)`
        },
        argNames,
        async getPromptForCommand(args: string) {
          const argsArray = args.split(' ')
          return await runCommand(
            { name: prompt.name, client },
            zipObject(argNames, argsArray),
          )
        },
      }
    }),
  )
})
523  
/**
 * Fetches the MCP prompt `name` from `client` with the given arguments and
 * converts its messages into `MessageParam`s.
 *
 * Content is mapped as follows:
 * - text content becomes a text block;
 * - image content becomes a base64 image block;
 * - an embedded resource that carries `text` becomes a text block with that text.
 *
 * @param name - the prompt's name on the server
 * @param client - the connected server that owns the prompt
 * @param args - prompt arguments keyed by argument name
 * @returns one message per prompt message, preserving role and order
 * @throws Error when the prompt request fails or a message has content that
 *   cannot be represented (e.g. a resource without text); the error is logged
 *   against the server before being rethrown
 */
export async function runCommand(
  { name, client }: { name: string; client: ConnectedClient },
  args: Record<string, string>,
): Promise<MessageParam[]> {
  try {
    const result = await client.client.getPrompt({ name, arguments: args })
    return result.messages.map((message): MessageParam => {
      const { content } = message

      if (content.type === 'text') {
        return {
          role: message.role,
          content: [{ type: 'text', text: content.text }],
        }
      }

      if (content.type === 'image') {
        return {
          role: message.role,
          content: [
            {
              type: 'image',
              source: {
                data: String(content.data),
                media_type:
                  content.mimeType as ImageBlockParam.Source['media_type'],
                type: 'base64',
              },
            },
          ],
        }
      }

      // Only text resources can be represented. Other content (e.g. a
      // resource without text) is rejected below rather than being emitted
      // as an image block built from a missing `data` field.
      if (
        content.type === 'resource' &&
        'text' in content.resource &&
        typeof content.resource.text === 'string'
      ) {
        return {
          role: message.role,
          content: [{ type: 'text', text: content.resource.text }],
        }
      }

      throw new Error(
        `Unsupported content type '${String(content.type)}' in MCP prompt '${name}'`,
      )
    })
  } catch (error) {
    logMCPError(
      client.name,
      `Error running command '${name}': ${error instanceof Error ? error.message : String(error)}`,
    )
    throw error
  }
}