// src/renderer/store/mcp.ts
  1  import { defineStore } from 'pinia'
  2  import { watch } from 'vue'
  3  import type {
  4    ChatCompletionRequestContent,
  5    ChatCompletionPromptMessage
  6  } from '@/renderer/types/message'
  7  
  8  import type { MCPAPI, McpObject, ToolType, McpToolType, DXTAPI } from '@/types/mcp'
  9  import { useStdioStore } from '@/renderer/store/stdio'
 10  
 11  import { merge, mapValues } from 'lodash'
 12  
 13  type McpPrimitiveType = 'tools' | 'resources' | 'prompts' | 'metadata'
 14  type AllowedPrimitive = Exclude<McpPrimitiveType, 'metadata'>
 15  type McpMethodType =
 16    | { type: 'list'; fn: () => any }
 17    | { type: 'get'; fn: () => any }
 18    | { type: 'read'; fn: () => any }
 19    | { type: 'call'; fn: () => any }
 20    | { type: 'templates/list'; fn: () => any }
 21    | string
 22  export type McpServerApi = MCPAPI | undefined
 23  
 24  export function getAllowedPrimitive(item: McpObject): AllowedPrimitive[] {
 25    if (!item) return []
 26  
 27    return (Object.keys(item) as Array<keyof typeof item>).filter((key) =>
 28      ['tools', 'resources', 'prompts'].includes(key as AllowedPrimitive)
 29    ) as AllowedPrimitive[]
 30  }
 31  
 32  export function getRawServers(): McpServerApi {
 33    return window.mcpServers?.get()
 34  }
 35  
 36  export function getServers(): McpServerApi {
 37    const mcpServers = getRawServers()
 38    const stdioServers = useStdioStore().configValues
 39  
 40    const renamedStdioServers = mapValues(stdioServers, (v, k) => {
 41      return {
 42        metadata: {
 43          name: k,
 44          type: 'metadata__stdio_config',
 45          config: v
 46        }
 47      }
 48    })
 49  
 50    merge(mcpServers, renamedStdioServers)
 51  
 52    return mcpServers
 53  }
 54  
 55  export function getDxtManifest(): DXTAPI | undefined {
 56    return window.dxtManifest?.get()
 57  }
 58  
 59  export interface FunctionType {
 60    type: 'function'
 61    function: ToolType
 62  }
 63  
 64  export interface McpCoreType {
 65    server: string
 66    primitive: McpPrimitiveType
 67    method: McpMethodType
 68  }
 69  
 70  function getObjectKeys(o: unknown) {
 71    return o && typeof o === 'object' && !Array.isArray(o) ? Object.keys(o) : []
 72  }
 73  
 74  function getSelectedByServer(
 75    serverName: string,
 76    selectedIndex: number | undefined,
 77    _version: number
 78  ): McpCoreType | null {
 79    const mcpServers = getServers()
 80    if (!mcpServers || !mcpServers[serverName]) return null
 81    console.log(mcpServers[serverName])
 82    if (typeof selectedIndex === 'number') {
 83      const selectedPrimitive = {
 84        server: serverName,
 85        primitive: Object.keys(mcpServers[serverName])[selectedIndex] as McpPrimitiveType,
 86        method: Object.values(mcpServers[serverName])[selectedIndex] as McpMethodType
 87      }
 88      console.log(selectedPrimitive)
 89      return selectedPrimitive
 90    } else {
 91      return {
 92        server: serverName,
 93        primitive: 'metadata',
 94        method: JSON.stringify(mcpServers[serverName].metadata, null, 2)
 95      }
 96    }
 97  }
 98  
 99  export const useMcpStore = defineStore('mcpStore', {
100    // TODO: fix any to type
101    state: (): any => ({
102      version: 1,
103      serverTools: [],
104      loading: true,
105      checkList: getObjectKeys(getServers()) as string[],
106      selected: undefined as string[] | undefined,
107      selectedChips: {} // { key : 0 | 1 | 2}
108    }),
109  
110    getters: {
111      getSelected(state): McpCoreType | null {
112        if (state.selected) {
113          const serverName = state.selected[0]
114          const selectedIndex = state.selectedChips[serverName]
115          return getSelectedByServer(serverName, selectedIndex, state.version)
116        } else {
117          return null
118        }
119      }
120    },
121  
122    actions: {
123      watchServerUpdate() {
124        watch(getServers, (newVal: McpServerApi, oldVal: McpServerApi) => {
125          const newKeys = getObjectKeys(newVal)
126          const oldKeys = getObjectKeys(oldVal)
127          const retainedKeys = this.checkList.filter((key: string) => newKeys.includes(key))
128          const addedKeys = newKeys.filter((key) => !oldKeys.includes(key))
129          this.checkList = [...retainedKeys, ...addedKeys]
130        })
131      },
132  
133      getAllByServer: function (serverName: string): McpCoreType[] {
134        const mcpServers = getServers()
135        if (!mcpServers) {
136          return []
137        }
138        const mcpServerObject = mcpServers[serverName]
139        if (!mcpServerObject) return []
140  
141        const allPrimitives = Object.entries(mcpServerObject).map(([key, value]) => {
142          return {
143            server: serverName,
144            primitive: key as McpPrimitiveType,
145            method: value as McpMethodType
146          }
147        })
148  
149        return allPrimitives
150      },
151      updateServers: async function () {
152        const servers = await window.mcpServers?.refresh()
153        console.log(servers)
154        return servers
155      },
156      getServerFunction: function (options: {
157        serverName: string
158        primitiveName: string
159        methodName: string
160      }): Function | null {
161        const { serverName, primitiveName, methodName } = options
162  
163        const allPrimitives = this.getAllByServer(serverName)
164  
165        const foundItem = allPrimitives.find((item: McpCoreType) => item.primitive === primitiveName)
166  
167        if (foundItem) {
168          return foundItem.method?.[methodName] || null
169        } else {
170          return null
171        }
172      },
173  
174      listServerTools: async function (serverNames?: string[]) {
175        const mcpTools: FunctionType[] = []
176  
177        const targets: string[] = serverNames?.length ? serverNames : [this.getSelected?.server]
178  
179        const promises = targets
180          .filter(Boolean) // filter invalid server
181          .map((serverName) =>
182            this.getServerFunction({
183              serverName,
184              primitiveName: 'tools',
185              methodName: 'list'
186            })?.()?.catch(() => null)
187          )
188  
189        const results = await Promise.all(promises)
190        for (const toolsData of results.filter(Boolean)) {
191          if (Array.isArray(toolsData?.tools)) {
192            toolsData.tools.forEach((tool: McpToolType) => {
193              const functionTool: FunctionType = {
194                type: 'function',
195                function: {
196                  name: tool.name,
197                  description: tool.description,
198                  parameters: tool.inputSchema
199                }
200              }
201              mcpTools.push(functionTool)
202            })
203          }
204        }
205  
206        return mcpTools
207      },
208  
209      loadServerTools: function () {
210        this.loading = true
211        try {
212          this.listServerTools().then((tools: FunctionType[]) => {
213            this.serverTools = tools.map((tool) => {
214              return {
215                name: tool.function.name,
216                description: tool.function.description
217              }
218            })
219          })
220        } catch (error) {
221          console.error('Failed to load tools:', error)
222        } finally {
223          this.loading = false
224        }
225      },
226  
227      listTools: async function () {
228        const mcpServers = getServers()
229        if (!mcpServers) {
230          return null
231        }
232        const mcpKeys = Object.keys(mcpServers)
233        const mcpTools: FunctionType[] = []
234        for (const key of mcpKeys) {
235          const toolsListFunction = mcpServers[key]?.tools?.list
236          if (typeof toolsListFunction === 'function') {
237            const tools = await toolsListFunction({ method: 'tools/list' })
238            // console.log(await mcpServers[key]?.prompts?.list())
239            // console.log(await mcpServers[key]?.resources['templates/list']())
240            // console.log(await mcpServers[key]?.resources?.list())
241            if (tools && Array.isArray(tools.tools)) {
242              for (const tool of tools.tools) {
243                const mcpTool: FunctionType = {
244                  type: 'function',
245                  function: {
246                    name: tool.name,
247                    description: tool.description,
248                    parameters: tool.inputSchema
249                    // strict: true
250                  }
251                }
252                mcpTools.push(mcpTool)
253              }
254            }
255          }
256        }
257        return mcpTools
258      },
259      getTool: async function (toolName: string) {
260        const mcpServers = getServers()
261        if (!mcpServers) {
262          return null
263        }
264        const mcpKeys = Object.keys(mcpServers)
265        const result = await Promise.any(
266          mcpKeys.map(async (key) => {
267            const toolsListFunction = mcpServers[key]?.tools?.list
268            if (typeof toolsListFunction === 'function') {
269              const tools = await toolsListFunction({ method: 'tools/list' })
270              if (tools && Array.isArray(tools.tools)) {
271                const foundTool = tools.tools.find((tool) => tool.name === toolName)
272                if (foundTool) {
273                  return {
274                    server: key,
275                    tool: foundTool
276                  }
277                }
278              }
279            }
280            throw new Error(`Tool ${toolName} not found on server ${key}`)
281          })
282        )
283  
284        return result
285      },
286      callTool: async function (toolName: string, toolArgs: string) {
287        const tool = await this.getTool(toolName)
288        if (!tool) {
289          return this.packReturn(`Tool name '${toolName}' not found`)
290        }
291  
292        let toolArguments = {}
293  
294        if (toolArgs) {
295          try {
296            toolArguments = JSON.parse(toolArgs)
297          } catch (e) {
298            return this.packReturn(`Arguments JSON parse error: '${e}'`)
299          }
300        }
301  
302        const params = {
303          name: toolName,
304          arguments: toolArguments
305        }
306  
307        const mcpServerObj: McpObject | undefined = getServers()?.[tool.server]
308  
309        if (mcpServerObj && mcpServerObj.tools && mcpServerObj.tools.call) {
310          return await mcpServerObj.tools.call({ method: 'tools/call', params: params })
311        } else {
312          return null
313        }
314      },
315      convertItem: function (
316        item: ChatCompletionPromptMessage['content']
317      ): ChatCompletionRequestContent {
318        if (item.type === 'text') {
319          return item
320        } else if (item.type === 'image') {
321          const imageUrl = `data:${item.mimeType};base64,${item.data}`
322          return {
323            type: 'image_url',
324            image_url: { url: imageUrl }
325          }
326        } else if (item.type === 'resource') {
327          return {
328            type: 'text',
329            text: JSON.stringify(item.resource, null, 2)
330          }
331        } else {
332          return {
333            type: 'text',
334            text: JSON.stringify(item, null, 2)
335          }
336        }
337      },
338      packReturn: (string: string) => {
339        return {
340          content: [
341            {
342              type: 'text',
343              text: string
344            }
345          ]
346        }
347      }
348    }
349  })