# mcp_tools.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

import asyncio
import json
from datetime import timedelta
from typing import Any, AsyncIterator, Dict, Literal, Optional
from contextlib import asynccontextmanager

from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client


class MCPTools:
    """Small MCP-only wrapper used by this repo (no agno dependency).

    Every operation opens its own short-lived MCP session, so the object
    itself holds no connection. The only state kept between calls is the
    cache of tool input schemas filled by ``describe_mcp_tools``.
    """

    def __init__(
        self,
        url: Optional[str] = None,
        transport: Literal["sse", "streamable-http"] = "sse",
        headers: Optional[Dict[str, Any]] = None,
        timeout_seconds: float = 10,
    ):
        """Store the connection settings; no network access happens here.

        Args:
            url: MCP server endpoint. Required before any session is opened.
            transport: ``"sse"`` or ``"streamable-http"``.
            headers: Extra HTTP headers passed to the transport client.
                ``None`` means an empty dict.
            timeout_seconds: Read timeout, in seconds, given to each
                ``ClientSession``. Defaults to 10.
        """
        if headers is None:
            headers = {}
        self.url = url
        self.transport = transport
        self.timeout_seconds = timeout_seconds
        self.headers = headers
        # Tool name -> input schema, used to coerce argument types on calls.
        self._tools_schema: Dict[str, Dict[str, Any]] = {}

    async def close(self) -> None:
        """Do nothing; kept so callers can treat this like a closable client."""
        # Stateless wrapper: each operation uses a short-lived session.
        return

    @asynccontextmanager
    async def _session(self) -> AsyncIterator[ClientSession]:
        """Short-lived session (enter/exit in same coroutine; safe for SSE + anyio).

        Yields:
            A ``ClientSession`` on which ``initialize()`` has been awaited.

        Raises:
            ValueError: If no URL is configured or the transport is unknown.
        """
        if not self.url:
            raise ValueError("MCP server url is required")

        if self.transport == "sse":
            ctx = sse_client(url=self.url, headers=self.headers)  # type: ignore
        elif self.transport == "streamable-http":
            ctx = streamablehttp_client(url=self.url, headers=self.headers)  # type: ignore
        else:
            raise ValueError(f"Unsupported transport protocol: {self.transport}")

        async with ctx as session_params:  # type: ignore
            # Only the first two items (read/write streams) are used; any
            # further items yielded by the transport are ignored.
            read, write = session_params[0:2]
            async with ClientSession(
                read,
                write,
                read_timeout_seconds=timedelta(seconds=self.timeout_seconds),
            ) as session:  # type: ignore
                await session.initialize()
                yield session

    @staticmethod
    def _escape_attr(value: Any) -> str:
        """Return ``value`` as text that is safe inside a double-quoted attribute."""
        # Only the double quote is escaped: it is the one character that
        # would end the attribute value early.
        return str(value).replace('"', '&quot;')

    def _build_parameter_attributes(self, param: Dict[str, Any]) -> str:
        """Build the extra XML attribute string for one parameter schema.

        ``name``, ``type`` and ``required`` are added by the caller. This
        method covers the remaining schema keywords, in a fixed order:
        description, enum, default, minimum, maximum, minLength, maxLength,
        pattern, format, examples, items.type and items.enum.

        Args:
            param: The schema dict of a single parameter.

        Returns:
            Space-separated ``key="value"`` pairs, or ``""`` when none of
            the keywords are present.
        """
        attrs = []

        def add(name: str, value: Any) -> None:
            attrs.append(f'{name}="{self._escape_attr(value)}"')

        if param.get('description'):
            add('description', param['description'])

        # enum: list of allowed values, comma-joined.
        enum_values = param.get('enum')
        if enum_values and isinstance(enum_values, list):
            add('enum', ','.join(str(v) for v in enum_values))

        # default: containers are rendered as JSON, everything else via str().
        if 'default' in param:
            default_val = param['default']
            if isinstance(default_val, (dict, list)):
                add('default', json.dumps(default_val, ensure_ascii=False))
            else:
                add('default', default_val)

        # Numeric range and string length limits are emitted whenever the
        # key exists, even if the value is 0.
        for key in ('minimum', 'maximum', 'minLength', 'maxLength'):
            if key in param:
                add(key, param[key])

        # pattern (regular expression) and format (e.g. date-time, email, uri).
        for key in ('pattern', 'format'):
            if param.get(key):
                add(key, param[key])

        examples = param.get('examples')
        if examples and isinstance(examples, list):
            add('examples', ','.join(str(v) for v in examples))

        # items: element type / allowed element values for array parameters.
        items = param.get('items')
        if isinstance(items, dict):
            if 'type' in items:
                add('itemsType', items['type'])
            if isinstance(items.get('enum'), list):
                add('itemsEnum', ','.join(str(v) for v in items['enum']))

        return ' '.join(attrs)

    async def describe_mcp_tools(self) -> str:
        """Return `<mcp_tools>` XML listing tool names and descriptions.

        As a side effect, each tool's input schema is stored in
        ``self._tools_schema`` for later argument type conversion.

        Raises:
            RuntimeError: If listing the tools fails. The original
                exception is chained as ``__cause__``.
        """
        try:
            async with self._session() as session:
                data = await session.list_tools()
        except BaseExceptionGroup as eg:
            root_cause = self._extract_root_cause(eg)
            raise RuntimeError(f"Failed to fetch MCP tools: {root_cause}") from eg
        except Exception as e:
            raise RuntimeError(f"Failed to fetch MCP tools: {type(e).__name__}: {e}") from e

        xml_lines = ["<mcp_tools>"]
        for t in data.tools:
            schema = t.inputSchema or {}
            # Cache the schema for later argument type conversion.
            self._tools_schema[t.name] = schema

            # A tool without arguments may omit "properties"/"required".
            properties = schema.get('properties') or {}
            required_names = schema.get('required') or []

            parameter_tags = []
            for k, param in properties.items():
                required = 'true' if k in required_names else 'false'
                param_type = param.get('type', 'string')
                base_attrs = f'name="{k}" type="{param_type}" required="{required}"'
                extra_attrs = self._build_parameter_attributes(param)
                all_attrs = f'{base_attrs} {extra_attrs}' if extra_attrs else base_attrs
                parameter_tags.append(f'<parameter {all_attrs}></parameter>')
            parameters = ''.join(parameter_tags)

            name = t.name
            # Same fallback for both places the description is printed, so
            # a missing description never shows up as the text "None".
            detail = t.description or ""
            # The opening <tool> pairs with the </tool> emitted below.
            xml_lines.append(
                '\n'
                '<tool>\n'
                f'<name>{name}</name>\n'
                f'<description>{detail}</description>\n'
                '<parameters>\n'
                f'<parameter name="tool_name" type=string required=true>tool_name is {name}</parameter>\n'
                f'{parameters}\n'
                '</parameters>\n'
            )
            xml_lines.append(f"detail:{detail} 调用格式:\n<tool_name>{name}</tool_name>\n</tool>")
        xml_lines.append("</mcp_tools>")
        return "\n".join(xml_lines)

    def _convert_param_type(self, value: Any, param_type: str) -> Any:
        """Coerce ``value`` to the JSON-schema type named by ``param_type``.

        Args:
            value: The raw argument value; ``None`` is returned unchanged.
            param_type: ``"integer"``, ``"number"``, ``"boolean"``,
                ``"array"`` or ``"object"``. Any other type leaves the
                value as it is.

        Returns:
            The converted value, or the original value when the conversion
            fails.
        """
        if value is None:
            return None

        try:
            if param_type == "integer":
                return int(value)
            if param_type == "number":
                return float(value)
            if param_type == "boolean":
                if isinstance(value, bool):
                    return value
                if isinstance(value, str):
                    return value.lower() in ("true", "1", "yes")
                return bool(value)
            if param_type == "array":
                if isinstance(value, list):
                    return value
                if isinstance(value, str):
                    return json.loads(value)
                # A single non-string value becomes a one-element list.
                return [value]
            if param_type == "object":
                if isinstance(value, dict):
                    return value
                if isinstance(value, str):
                    return json.loads(value)
                return value
            # "string" or any other type: keep as is.
            return value
        except (ValueError, TypeError, OverflowError):
            # Conversion failed (json.JSONDecodeError is a ValueError):
            # fall back to the original value.
            return value

    def _convert_args_by_schema(self, tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
        """Convert every argument according to the cached schema of ``tool_name``.

        Args:
            tool_name: Name used to look up the cached input schema.
            args: Raw keyword arguments for the tool call.

        Returns:
            ``args`` itself when no schema is cached for the tool, otherwise
            a new dict with each value converted. Keys missing from the
            schema are treated as ``"string"`` and so pass through unchanged.
        """
        schema = self._tools_schema.get(tool_name)
        if not schema:
            return args

        properties = schema.get("properties", {})
        return {
            key: self._convert_param_type(value, properties.get(key, {}).get("type", "string"))
            for key, value in args.items()
        }

    def _extract_root_cause(self, exc: BaseException) -> str:
        """Flatten an exception (group) into one ``"Type: message"`` string.

        Nested exception groups are walked recursively and the leaf
        messages are joined with ``"; "``.
        """
        # ExceptionGroup handling (Python 3.11+).
        if isinstance(exc, BaseExceptionGroup):
            return "; ".join(self._extract_root_cause(sub_exc) for sub_exc in exc.exceptions)
        # Plain exception: report its type and message.
        return f"{type(exc).__name__}: {exc}"

    async def call_remote_tool(self, tool_name: str, **kw) -> Any:
        """Call a remote MCP server tool.

        Args:
            tool_name: Name of the tool to call.
            **kw: Tool arguments. They are type-converted with the cached
                schema when ``describe_mcp_tools`` has been run before.

        Returns:
            The ``text`` of the first content item if it has one, else its
            ``data``. ``None`` when the call returns nothing, the content
            list is empty, or the first item has neither attribute.

        Raises:
            ValueError: If ``tool_name`` is empty.
            RuntimeError: If the session or the call fails. The original
                exception is chained as ``__cause__``.
        """
        if not tool_name:
            raise ValueError("call_remote_tool requires call['toolName']")

        # Convert argument types according to the cached schema.
        converted_kw = self._convert_args_by_schema(tool_name, kw)

        try:
            async with self._session() as session:
                result = await session.call_tool(tool_name, converted_kw)
                # An empty content list is "no result", not a call failure.
                if result is None or not result.content:
                    return None
                first = result.content[0]
                # Text content exposes .text; image-like content exposes .data.
                if hasattr(first, 'text'):
                    return first.text
                if hasattr(first, 'data'):
                    return first.data
                return None
        except BaseExceptionGroup as eg:
            # Surface the original error hidden inside the TaskGroup.
            root_cause = self._extract_root_cause(eg)
            raise RuntimeError(f"MCP call failed: {root_cause}") from eg
        except Exception as e:
            raise RuntimeError(f"MCP call failed: {type(e).__name__}: {e}") from e


if __name__ == "__main__":
    async def main():
        # Manual smoke test against a local SSE server.
        mcp_tools_manager = MCPTools(url="http://localhost:8090/sse", transport="sse")
        description = await mcp_tools_manager.describe_mcp_tools()
        print(description)
        result = await mcp_tools_manager.call_remote_tool(
            "get_filename1",
            filename="/etc/passwd"
        )
        print(f"Tool call result: {result}")


    asyncio.run(main())