mcpClient.ts
import { zipObject } from 'lodash-es'
import {
  getCurrentProjectConfig,
  McpServerConfig,
  saveCurrentProjectConfig,
  getGlobalConfig,
  saveGlobalConfig,
  getMcprcConfig,
  addMcprcServerForTesting,
  removeMcprcServerForTesting,
} from '../utils/config.js'
import { existsSync, readFileSync, writeFileSync } from 'fs'
import { join } from 'path'
import { getCwd } from '../utils/state.js'
import { safeParseJSON } from '../utils/json.js'
import {
  ImageBlockParam,
  MessageParam,
  ToolResultBlockParam,
} from '@anthropic-ai/sdk/resources/index.mjs'
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
import {
  CallToolResultSchema,
  ClientRequest,
  ListPromptsResult,
  ListPromptsResultSchema,
  ListToolsResult,
  ListToolsResultSchema,
  Result,
  ResultSchema,
} from '@modelcontextprotocol/sdk/types.js'
import { memoize, pickBy } from 'lodash-es'
import type { Tool } from '../Tool.js'
import { MCPTool } from '../tools/MCPTool/MCPTool.js'
import { logMCPError } from '../utils/log.js'
import { Command } from '../commands.js'
import { logEvent } from '../services/statsig.js'

type McpName = string

/**
 * Parses `KEY=value` strings into an environment map.
 *
 * Only the first `=` separates key from value, so values may themselves
 * contain `=`. An empty value (`KEY=`) is accepted and stored as `''`.
 *
 * @param rawEnvArgs - Raw `KEY=value` entries; `undefined` yields `{}`.
 * @returns A map from variable name to value.
 * @throws Error if an entry has an empty key or contains no `=`.
 */
export function parseEnvVars(
  rawEnvArgs: string[] | undefined,
): Record<string, string> {
  const parsedEnv: Record<string, string> = {}

  // Parse individual env vars
  if (rawEnvArgs) {
    for (const envStr of rawEnvArgs) {
      const [key, ...valueParts] = envStr.split('=')
      if (!key || valueParts.length === 0) {
        throw new Error(
          `Invalid environment variable format: ${envStr}, environment variables should be added as: -e KEY1=value1 -e KEY2=value2`,
        )
      }
      // Re-join so that any '=' inside the value is preserved.
      parsedEnv[key] = valueParts.join('=')
    }
  }
  return parsedEnv
}

const VALID_SCOPES = ['project', 'global', 'mcprc'] as const
type ConfigScope = (typeof VALID_SCOPES)[number]
// Scopes accepted when USER_TYPE is 'external' (see ensureConfigScope).
const EXTERNAL_SCOPES: ConfigScope[] = ['project', 'global']

/**
 * Validates a user-supplied scope string and narrows it to `ConfigScope`.
 *
 * @param scope - Scope name; a missing or empty value defaults to `'project'`.
 * @returns The validated scope.
 * @throws Error if the scope is not allowed. When `process.env.USER_TYPE` is
 *   `'external'` only `EXTERNAL_SCOPES` are accepted, otherwise `VALID_SCOPES`.
 */
export function ensureConfigScope(scope?: string): ConfigScope {
  if (!scope) return 'project'

  const scopesToCheck =
    process.env.USER_TYPE === 'external' ? EXTERNAL_SCOPES : VALID_SCOPES

  if (!scopesToCheck.includes(scope as ConfigScope)) {
    throw new Error(
      `Invalid scope: ${scope}. Must be one of: ${scopesToCheck.join(', ')}`,
    )
  }

  return scope as ConfigScope
}

/**
 * Adds (or overwrites) an MCP server definition in the given scope.
 *
 * - `'mcprc'`: updates the `.mcprc` file in the current working directory
 *   (or the in-memory test store when `NODE_ENV === 'test'`).
 * - `'global'`: updates the global config.
 * - anything else (`'project'`): updates the current project config.
 *
 * @param name - Name under which the server is stored.
 * @param server - Server configuration to store.
 * @param scope - Target scope; defaults to `'project'`.
 * @throws Error if `.mcprc` cannot be written.
 */
export function addMcpServer(
  name: McpName,
  server: McpServerConfig,
  scope: ConfigScope = 'project',
): void {
  if (scope === 'mcprc') {
    if (process.env.NODE_ENV === 'test') {
      addMcprcServerForTesting(name, server)
    } else {
      const mcprcPath = join(getCwd(), '.mcprc')
      let mcprcConfig: Record<string, McpServerConfig> = {}

      // Read existing config if present
      if (existsSync(mcprcPath)) {
        try {
          const mcprcContent = readFileSync(mcprcPath, 'utf-8')
          const existingConfig = safeParseJSON(mcprcContent)
          // Arrays are rejected: a named property set on an array is dropped
          // by JSON.stringify, which would silently lose the new server.
          if (
            existingConfig &&
            typeof existingConfig === 'object' &&
            !Array.isArray(existingConfig)
          ) {
            mcprcConfig = existingConfig as Record<string, McpServerConfig>
          }
        } catch {
          // If we can't read/parse, start with empty config
        }
      }

      // Add the server
      mcprcConfig[name] = server

      // Write back to .mcprc
      try {
        writeFileSync(mcprcPath, JSON.stringify(mcprcConfig, null, 2), 'utf-8')
      } catch (error) {
        throw new Error(`Failed to write to .mcprc: ${error}`)
      }
    }
  } else if (scope === 'global') {
    const config = getGlobalConfig()
    if (!config.mcpServers) {
      config.mcpServers = {}
    }
    config.mcpServers[name] = server
    saveGlobalConfig(config)
  } else {
    const config = getCurrentProjectConfig()
    if (!config.mcpServers) {
      config.mcpServers = {}
    }
    config.mcpServers[name] = server
    saveCurrentProjectConfig(config)
  }
}

/**
 * Removes an MCP server definition from the given scope.
 *
 * @param name - Name of the server to remove.
 * @param scope - Scope to remove it from; defaults to `'project'`.
 * @throws Error if the server (or, for `'mcprc'`, the `.mcprc` file itself)
 *   does not exist in that scope, or if `.mcprc` cannot be read or written.
 */
export function removeMcpServer(
  name: McpName,
  scope: ConfigScope = 'project',
): void {
  if (scope === 'mcprc') {
    if (process.env.NODE_ENV === 'test') {
      removeMcprcServerForTesting(name)
    } else {
      const mcprcPath = join(getCwd(), '.mcprc')
      if (!existsSync(mcprcPath)) {
        throw new Error('No .mcprc file found in this directory')
      }

      try {
        const mcprcContent = readFileSync(mcprcPath, 'utf-8')
        const mcprcConfig = safeParseJSON(mcprcContent) as Record<
          string,
          McpServerConfig
        > | null

        if (
          !mcprcConfig ||
          typeof mcprcConfig !== 'object' ||
          !mcprcConfig[name]
        ) {
          throw new Error(`No MCP server found with name: ${name} in .mcprc`)
        }

        delete mcprcConfig[name]
        writeFileSync(mcprcPath, JSON.stringify(mcprcConfig, null, 2), 'utf-8')
      } catch (error) {
        // Real Error instances (including the "not found" error above and fs
        // errors) propagate unchanged; only non-Error throwables are wrapped.
        if (error instanceof Error) {
          throw error
        }
        throw new Error(`Failed to remove from .mcprc: ${error}`)
      }
    }
  } else if (scope === 'global') {
    const config = getGlobalConfig()
    if (!config.mcpServers?.[name]) {
      throw new Error(`No global MCP server found with name: ${name}`)
    }
    delete config.mcpServers[name]
    saveGlobalConfig(config)
  } else {
    const config = getCurrentProjectConfig()
    if (!config.mcpServers?.[name]) {
      throw new Error(`No local MCP server found with name: ${name}`)
    }
    delete config.mcpServers[name]
    saveCurrentProjectConfig(config)
  }
}

/**
 * Returns every configured MCP server, merged across scopes.
 *
 * Precedence (lowest to highest): global, mcprc, project. Unlike
 * `getClients`, this does not filter mcprc servers by approval status.
 */
export function listMCPServers(): Record<string, McpServerConfig> {
  const globalConfig = getGlobalConfig()
  const mcprcConfig = getMcprcConfig()
  const projectConfig = getCurrentProjectConfig()
  return {
    ...(globalConfig.mcpServers ?? {}),
    ...(mcprcConfig ?? {}), // mcprc configs override global ones
    ...(projectConfig.mcpServers ?? {}), // Project configs override mcprc ones
  }
}

/** A server configuration annotated with the scope it was found in. */
export type ScopedMcpServerConfig = McpServerConfig & {
  scope: ConfigScope
}

/**
 * Looks up a single MCP server by name.
 *
 * Scopes are checked in order of precedence: project, then mcprc, then global.
 *
 * @param name - Server name to look up.
 * @returns The server config tagged with its scope, or `undefined` if the
 *   name is not configured in any scope.
 */
export function getMcpServer(name: McpName): ScopedMcpServerConfig | undefined {
  const projectConfig = getCurrentProjectConfig()
  const mcprcConfig = getMcprcConfig()
  const globalConfig = getGlobalConfig()

  // Check each scope in order of precedence
  if (projectConfig.mcpServers?.[name]) {
    return { ...projectConfig.mcpServers[name], scope: 'project' }
  }

  if (mcprcConfig?.[name]) {
    return { ...mcprcConfig[name], scope: 'mcprc' }
  }

  if (globalConfig.mcpServers?.[name]) {
    return { ...globalConfig.mcpServers[name], scope: 'global' }
  }

  return undefined
}

/**
 * Creates a transport for `serverRef`, connects an MCP client over it and
 * returns the connected client.
 *
 * The connection attempt is bounded by a 5 second timeout. If the attempt
 * fails or times out, the transport is closed so that a spawned stdio process
 * or an open SSE stream is not left behind.
 *
 * @param name - Server name, used in error and log messages.
 * @param serverRef - Server configuration (`'sse'` or stdio).
 * @returns The connected client.
 * @throws Error if the connection fails or times out.
 */
async function connectToServer(
  name: string,
  serverRef: McpServerConfig,
): Promise<Client> {
  const transport =
    serverRef.type === 'sse'
      ? new SSEClientTransport(new URL(serverRef.url))
      : new StdioClientTransport({
          command: serverRef.command,
          args: serverRef.args,
          env: {
            ...process.env,
            ...serverRef.env,
          } as Record<string, string>,
          stderr: 'pipe', // prevents error output from the MCP server from printing to the UI
        })

  const client = new Client(
    {
      name: 'claude',
      version: '0.1.0',
    },
    {
      capabilities: {},
    },
  )

  // Add a timeout to connection attempts to prevent tests from hanging indefinitely
  const CONNECTION_TIMEOUT_MS = 5000
  const connectPromise = client.connect(transport)
  let timeoutId: ReturnType<typeof setTimeout> | undefined
  const timeoutPromise = new Promise<never>((_, reject) => {
    timeoutId = setTimeout(() => {
      reject(
        new Error(
          `Connection to MCP server "${name}" timed out after ${CONNECTION_TIMEOUT_MS}ms`,
        ),
      )
    }, CONNECTION_TIMEOUT_MS)
  })

  try {
    await Promise.race([connectPromise, timeoutPromise])
  } catch (error) {
    // Release the transport on failure. Not awaited, so a slow or failing
    // close can neither delay nor mask the original error.
    void transport.close().catch(() => undefined)
    throw error
  } finally {
    // Clean up the timer whether connect resolved, rejected or timed out.
    clearTimeout(timeoutId)
  }

  if (serverRef.type === 'stdio') {
    ;(transport as StdioClientTransport).stderr?.on('data', (data: Buffer) => {
      const errorText = data.toString().trim()
      if (errorText) {
        logMCPError(name, `Server stderr: ${errorText}`)
      }
    })
  }
  return client
}

type ConnectedClient = {
  client: Client
  name: string
  type: 'connected'
}
type FailedClient = {
  name: string
  type: 'failed'
}
/** Result of a connection attempt to one configured server. */
export type WrappedClient = ConnectedClient | FailedClient

/**
 * Reports whether a `.mcprc` server has been approved or rejected in the
 * current project config.
 *
 * @param serverName - Name of the `.mcprc` server.
 * @returns `'approved'` or `'rejected'` if listed as such in the project
 *   config (approved is checked first), otherwise `'pending'`.
 */
export function getMcprcServerStatus(
  serverName: string,
): 'approved' | 'rejected' | 'pending' {
  const config = getCurrentProjectConfig()
  if (config.approvedMcprcServers?.includes(serverName)) {
    return 'approved'
  }
  if (config.rejectedMcprcServers?.includes(serverName)) {
    return 'rejected'
  }
  return 'pending'
}

/**
 * Connects to every configured MCP server and returns one entry per server.
 *
 * Servers are merged with precedence global < approved mcprc < project.
 * A failed connection is logged and reported as a `'failed'` entry rather
 * than rejecting. The result is memoized, so connections are attempted once.
 */
export const getClients = memoize(async (): Promise<WrappedClient[]> => {
  // TODO: This is a temporary fix for a hang during npm run verify in CI.
  // We need to investigate why MCP client connections hang in CI verify but not in CI tests.
  if (process.env.CI && process.env.NODE_ENV !== 'test') {
    return []
  }

  const globalServers = getGlobalConfig().mcpServers ?? {}
  const mcprcServers = getMcprcConfig()
  const projectServers = getCurrentProjectConfig().mcpServers ?? {}

  // Filter mcprc servers to only include approved ones
  const approvedMcprcServers = pickBy(
    mcprcServers,
    (_, name) => getMcprcServerStatus(name) === 'approved',
  )

  const allServers = {
    ...globalServers,
    ...approvedMcprcServers, // Approved .mcprc servers override global ones
    ...projectServers, // Project servers take highest precedence
  }

  return await Promise.all(
    Object.entries(allServers).map(async ([name, serverRef]) => {
      try {
        const client = await connectToServer(name, serverRef)
        logEvent('tengu_mcp_server_connection_succeeded', {})
        return { name, client, type: 'connected' as const }
      } catch (error) {
        logEvent('tengu_mcp_server_connection_failed', {})
        logMCPError(
          name,
          `Connection failed: ${error instanceof Error ? error.message : String(error)}`,
        )
        return { name, type: 'failed' as const }
      }
    }),
  )
})

/**
 * Sends the same request to every connected client that advertises
 * `requiredCapability` and collects the successful responses.
 *
 * Failed clients, clients lacking the capability, and clients whose request
 * throws (which is logged) are omitted from the result.
 *
 * @param req - Request to send.
 * @param resultSchema - Schema passed to `client.request` for the response.
 * @param requiredCapability - Key that must be truthy in the server's
 *   capabilities for the request to be sent.
 * @returns One `{ client, result }` pair per client that answered.
 */
async function requestAll<
  ResultT extends Result,
  ResultSchemaT extends typeof ResultSchema,
>(
  req: ClientRequest,
  resultSchema: ResultSchemaT,
  requiredCapability: string,
): Promise<{ client: ConnectedClient; result: ResultT }[]> {
  const clients = await getClients()
  const results = await Promise.allSettled(
    clients.map(async client => {
      if (client.type === 'failed') return null

      try {
        const capabilities = await client.client.getServerCapabilities()
        if (!capabilities?.[requiredCapability]) {
          return null
        }
        return {
          client,
          result: (await client.client.request(req, resultSchema)) as ResultT,
        }
      } catch (error) {
        if (client.type === 'connected') {
          logMCPError(
            client.name,
            `Failed to request '${req.method}': ${error instanceof Error ? error.message : String(error)}`,
          )
        }
        return null
      }
    }),
  )
  return results
    .filter(
      (
        result,
      ): result is PromiseFulfilledResult<{
        client: ConnectedClient
        result: ResultT
      } | null> => result.status === 'fulfilled',
    )
    .map(result => result.value)
    .filter(
      (result): result is { client: ConnectedClient; result: ResultT } =>
        result !== null,
    )
}

/**
 * Lists the tools exposed by all connected MCP servers, wrapped as `Tool`s.
 *
 * Each tool is named `mcp__<server>__<tool>` and delegates `call` to
 * `callMCPTool`. The result is memoized.
 */
export const getMCPTools = memoize(async (): Promise<Tool[]> => {
  const toolsList = await requestAll<
    ListToolsResult,
    typeof ListToolsResultSchema
  >(
    {
      method: 'tools/list',
    },
    ListToolsResultSchema,
    'tools',
  )

  // TODO: Add zod schema validation
  return toolsList.flatMap(({ client, result: { tools } }) =>
    tools.map(
      (tool): Tool => ({
        ...MCPTool,
        name: 'mcp__' + client.name + '__' + tool.name,
        async description() {
          return tool.description ?? ''
        },
        async prompt() {
          return tool.description ?? ''
        },
        inputJSONSchema: tool.inputSchema as Tool['inputJSONSchema'],
        async *call(args: Record<string, unknown>) {
          const data = await callMCPTool({ client, tool: tool.name, args })
          yield {
            type: 'result' as const,
            data,
            resultForAssistant: data,
          }
        },
        userFacingName() {
          return `${client.name}:${tool.name} (MCP)`
        },
      }),
    ),
  )
})

/**
 * Builds a human-readable description of a failed tool call.
 *
 * Prefers an explicit `error` field when the server supplied one; otherwise
 * joins the text items found in `content`. Falls back to `'Unknown error'`
 * when neither is available.
 */
function formatToolError(result: Record<string, unknown>): string {
  const { error, content } = result
  if (error !== undefined && error !== null) {
    return `${error}`
  }
  if (Array.isArray(content)) {
    const text = content
      .flatMap((item: unknown) => {
        const candidate =
          typeof item === 'object' && item !== null
            ? (item as { text?: unknown }).text
            : undefined
        return typeof candidate === 'string' ? [candidate] : []
      })
      .join('\n')
    if (text) {
      return text
    }
  }
  return 'Unknown error'
}

/**
 * Invokes a tool on a connected MCP server and converts the response into
 * tool-result content.
 *
 * @param client - Connected client wrapper (`client` and `name` are used).
 * @param tool - Tool name as known to the server.
 * @param args - Arguments forwarded to the tool.
 * @returns A string for `toolResult`-style responses, or the content array
 *   with image items converted to base64 image blocks.
 * @throws Error if the server flags the result as an error (also logged), or
 *   if the response has neither `toolResult` nor a `content` array.
 */
async function callMCPTool({
  client: { client, name },
  tool,
  args,
}: {
  client: ConnectedClient
  tool: string
  args: Record<string, unknown>
}): Promise<ToolResultBlockParam['content']> {
  const result = await client.callTool(
    {
      name: tool,
      arguments: args,
    },
    CallToolResultSchema,
  )

  if ('isError' in result && result.isError) {
    // Error details normally arrive in `content`, not in an `error` field.
    const errorMessage = `Error calling tool ${tool}: ${formatToolError(result)}`
    logMCPError(name, errorMessage)
    throw new Error(errorMessage)
  }

  // Handle toolResult-type response
  if ('toolResult' in result) {
    return String(result.toolResult)
  }

  // Handle content array response
  if ('content' in result && Array.isArray(result.content)) {
    return result.content.map(item => {
      if (item.type === 'image') {
        return {
          type: 'image',
          source: {
            type: 'base64',
            data: String(item.data),
            media_type: item.mimeType as ImageBlockParam.Source['media_type'],
          },
        }
      }
      return item
    })
  }

  throw new Error(`Unexpected response format from tool ${tool}`)
}

/**
 * Lists the prompts exposed by all connected MCP servers, wrapped as
 * `Command`s.
 *
 * Each command is named `mcp__<server>__<prompt>`. Its argument string is
 * split on single spaces and zipped positionally with the prompt's declared
 * argument names. The result is memoized.
 */
export const getMCPCommands = memoize(async (): Promise<Command[]> => {
  const results = await requestAll<
    ListPromptsResult,
    typeof ListPromptsResultSchema
  >(
    {
      method: 'prompts/list',
    },
    ListPromptsResultSchema,
    'prompts',
  )

  return results.flatMap(
    ({ client, result }) =>
      result.prompts?.map(_ => {
        const argNames = (_.arguments ?? []).map(k => k.name)
        return {
          type: 'prompt',
          name: 'mcp__' + client.name + '__' + _.name,
          description: _.description ?? '',
          isEnabled: true,
          isHidden: false,
          progressMessage: 'running',
          userFacingName() {
            return `${client.name}:${_.name} (MCP)`
          },
          argNames,
          async getPromptForCommand(args: string) {
            const argsArray = args.split(' ')
            return await runCommand(
              { name: _.name, client },
              zipObject(argNames, argsArray),
            )
          },
        }
      }) ?? [], // never hand flatMap an undefined entry if prompts is absent
  )
})

/**
 * Fetches a prompt from an MCP server and converts its messages into
 * `MessageParam`s.
 *
 * Text content becomes a text block; any other content is treated as a
 * base64 image (resource content is not yet supported).
 *
 * @param name - Prompt name as known to the server.
 * @param client - Connected client wrapper to send the request through.
 * @param args - Prompt arguments keyed by argument name.
 * @returns The prompt's messages.
 * @throws Rethrows any error from the server after logging it.
 */
export async function runCommand(
  { name, client }: { name: string; client: ConnectedClient },
  args: Record<string, string>,
): Promise<MessageParam[]> {
  try {
    const result = await client.client.getPrompt({ name, arguments: args })
    // TODO: Support type == resource
    return result.messages.map(
      (message): MessageParam => ({
        role: message.role,
        content: [
          message.content.type === 'text'
            ? {
                type: 'text',
                text: message.content.text,
              }
            : {
                type: 'image',
                source: {
                  data: String(message.content.data),
                  media_type: message.content
                    .mimeType as ImageBlockParam.Source['media_type'],
                  type: 'base64',
                },
              },
        ],
      }),
    )
  } catch (error) {
    logMCPError(
      client.name,
      `Error running command '${name}': ${error instanceof Error ? error.message : String(error)}`,
    )
    throw error
  }
}