dispatcher.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

import inspect
from typing import TYPE_CHECKING, Any, Optional

from tools.registry import get_tool_by_name, get_tools_prompt, needs_context
from utils.loging import logger
from utils.mcp_tools import MCPTools
from utils.prompt_manager import prompt_manager

if TYPE_CHECKING:  # pragma: no cover
    from utils.tool_context import ToolContext


class ToolDispatcher:
    """Dispatch tool calls to registered tools and build the tools prompt.

    The optional MCP server connection is established lazily, on first use.
    """

    def __init__(
        self, mcp_server_url: str | None = None, mcp_headers: dict[str, str] | None = None
    ):
        """Store MCP connection settings without performing any I/O.

        NOTE: __init__ must be synchronous. We do lazy MCP connection on first remote usage.

        Args:
            mcp_server_url: URL of the MCP server. ``None`` disables MCP.
            mcp_headers: Extra headers forwarded to ``MCPTools``.
        """
        self.mcp_server_url = mcp_server_url
        self.mcp_tools_manager: MCPTools | None = None
        # Transport to force. None means try "streamable-http" first, then "sse".
        self.mcp_transport: str | None = None
        self.mcp_headers = mcp_headers

    async def _ensure_mcp_manager(self) -> MCPTools | None:
        """Return a connected MCP tools manager, creating it on first use.

        If ``self.mcp_transport`` is set, only that transport is tried.
        Otherwise "streamable-http" is tried, then "sse". A transport counts as
        connected once ``describe_mcp_tools()`` succeeds on it; the first such
        manager is cached on ``self.mcp_tools_manager``.

        Returns:
            The cached or newly connected manager. Returns None when no server
            URL is configured or when every candidate transport failed.
        """
        if not self.mcp_server_url:
            return None
        if self.mcp_tools_manager:
            return self.mcp_tools_manager

        transports = [self.mcp_transport] if self.mcp_transport else ["streamable-http", "sse"]
        for transport in transports:
            try:
                manager = MCPTools(self.mcp_server_url, transport, headers=self.mcp_headers)  # type: ignore[arg-type]
                # Verify connectivity before caching the manager.
                await manager.describe_mcp_tools()
            except Exception as e:
                # Fall through to the next transport, but record why this one failed
                # so a total failure is diagnosable from the logs.
                logger.warning(
                    f"ToolDispatcher: MCP transport {transport!r} failed for "
                    f"{self.mcp_server_url}: {e}"
                )
                continue
            self.mcp_tools_manager = manager
            logger.info(
                f"ToolDispatcher: MCP tools manager initialized with transport: {transport}"
            )
            return manager

        logger.error(f"ToolDispatcher: Failed to connect to MCP server: {self.mcp_server_url}")
        return None

    async def get_all_tools_prompt(self) -> str:
        """Build the prompt that describes all available tools.

        The base is ``get_tools_prompt([])`` from the tool registry. When an MCP
        server URL is configured, the remote tool descriptions are rendered
        through the ``dynamic/system_prompt`` template and appended after a
        blank line. If fetching or rendering them fails, the error is logged
        and only the base prompt is returned.

        Raises:
            RuntimeError: An MCP server URL is configured but no transport
                could connect to it.
        """
        prompt = get_tools_prompt([])
        if not self.mcp_server_url:
            return prompt

        manager = await self._ensure_mcp_manager()
        if not manager:
            raise RuntimeError("Failed to connect to MCP server")
        try:
            mcp_prompt = await manager.describe_mcp_tools()
            mcp_remote_prompt = prompt_manager.format_prompt(
                "dynamic/system_prompt", mcp_tools=mcp_prompt
            )
        except Exception as e:
            # Best effort: degrade to the base prompt rather than failing the caller.
            logger.error(f"Failed to fetch MCP tools description: {e}")
            return prompt
        return f"{prompt}\n\n{mcp_remote_prompt}"

    async def call_tool(
        self, tool_name: str, args: dict[str, Any], context: Optional["ToolContext"] = None
    ) -> str:
        """Invoke the registered tool ``tool_name`` and format its result.

        Only tools resolvable through ``get_tool_by_name`` are dispatched; this
        method performs no direct MCP dispatch itself.

        Args:
            tool_name: Registry name of the tool.
            args: Keyword arguments for the tool. The caller's dict is not modified.
            context: Passed as the ``context`` keyword when
                ``needs_context(tool_name)`` is true and a context was given.

        Returns:
            The result formatted by ``_format_result``. If the tool is unknown,
            or it raises (either when called or while its awaitable is awaited),
            an ``"Error: ..."`` string is returned instead.
        """
        tool_func = get_tool_by_name(tool_name)
        if not tool_func:
            return f"Error: Tool '{tool_name}' not found locally or MCP server is unavailable"

        # Work on a copy so the injected context never leaks back into the caller's dict.
        call_args = dict(args)
        if needs_context(tool_name) and context:
            call_args["context"] = context

        try:
            result = tool_func(**call_args)
            # Async tools raise at await time, so the await must sit inside the try
            # for sync and async tools to report errors the same way.
            # asyncio.CancelledError is a BaseException and still propagates.
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            return f"Error: {e}"
        return self._format_result(result)

    def _format_result(self, result: Any) -> str:
        """Render a tool result as text.

        A dict becomes one ``<key>value</key>`` line per item, in iteration
        order, with every line newline-terminated. Any other value goes
        through ``str()``.
        """
        if isinstance(result, dict):
            return "".join(f"<{k}>{v}</{k}>\n" for k, v in result.items())
        return str(result)

    async def close(self) -> None:
        """Close the MCP tools manager, if one was created.

        The cached reference is cleared first. This makes repeated calls no-ops,
        and a later ``_ensure_mcp_manager`` reconnects instead of returning a
        closed manager.
        """
        manager = self.mcp_tools_manager
        if not manager:
            return
        self.mcp_tools_manager = None
        await manager.close()
        logger.info("ToolDispatcher: MCP tools manager closed")