/ agent-scan / utils / mcp_tools.py
mcp_tools.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
 19  import asyncio
 20  import json
 21  from datetime import timedelta
 22  from typing import Any, AsyncIterator, Dict, Literal, Optional
 23  from contextlib import asynccontextmanager
 24  
 25  from mcp import ClientSession
 26  from mcp.client.sse import sse_client
 27  from mcp.client.streamable_http import streamablehttp_client
 28  
 29  
 30  class MCPTools:
 31      """Small MCP-only wrapper used by this repo (no agno dependency)."""
 32  
 33      def __init__(self, url: Optional[str] = None, transport: Literal["sse", "streamable-http"] = "sse",
 34                   headers: dict = None):
 35          if headers is None:
 36              headers = {}
 37          self.url = url
 38          self.transport = transport
 39          self.timeout_seconds = 10
 40          self.headers = headers
 41          # 缓存工具 schema,用于参数类型转换
 42          self._tools_schema: Dict[str, Dict[str, Any]] = {}
 43  
 44      async def close(self) -> None:
 45          # Stateless wrapper: each operation uses a short-lived session.
 46          return
 47  
 48      @asynccontextmanager
 49      async def _session(self) -> AsyncIterator[ClientSession]:
 50          """Short-lived session (enter/exit in same coroutine; safe for SSE + anyio)."""
 51          if not self.url:
 52              raise ValueError("MCP server url is required")
 53  
 54          if self.transport == "sse":
 55              ctx = sse_client(url=self.url, headers=self.headers)  # type: ignore
 56          elif self.transport == "streamable-http":
 57              ctx = streamablehttp_client(url=self.url, headers=self.headers)  # type: ignore
 58          else:
 59              raise ValueError(f"Unsupported transport protocol: {self.transport}")
 60  
 61          async with ctx as session_params:  # type: ignore
 62              read, write = session_params[0:2]
 63              async with ClientSession(
 64                      read,
 65                      write,
 66                      read_timeout_seconds=timedelta(seconds=self.timeout_seconds),
 67              ) as session:  # type: ignore
 68                  await session.initialize()
 69                  yield session
 70  
 71      def _build_parameter_attributes(self, param: Dict[str, Any]) -> str:
 72          """构建参数的 XML 属性字符串,包含所有 schema 信息"""
 73          attrs = []
 74          
 75          # 基础属性:type 和 required 在调用处处理
 76          
 77          # description: 描述
 78          if 'description' in param and param['description']:
 79              desc = str(param['description']).replace('"', '"')
 80              attrs.append(f'description="{desc}"')
 81          
 82          # enum: 枚举值列表
 83          if 'enum' in param and param['enum']:
 84              enum_values = param['enum']
 85              if isinstance(enum_values, list):
 86                  enum_str = ','.join(str(v) for v in enum_values)
 87                  enum_str = enum_str.replace('"', '"')
 88                  attrs.append(f'enum="{enum_str}"')
 89          
 90          # default: 默认值
 91          if 'default' in param:
 92              default_val = param['default']
 93              if isinstance(default_val, (dict, list)):
 94                  default_str = json.dumps(default_val, ensure_ascii=False)
 95              else:
 96                  default_str = str(default_val)
 97              default_str = default_str.replace('"', '"')
 98              attrs.append(f'default="{default_str}"')
 99          
100          # minimum/maximum: 数值范围
101          if 'minimum' in param:
102              attrs.append(f'minimum="{param["minimum"]}"')
103          if 'maximum' in param:
104              attrs.append(f'maximum="{param["maximum"]}"')
105          
106          # minLength/maxLength: 字符串长度限制
107          if 'minLength' in param:
108              attrs.append(f'minLength="{param["minLength"]}"')
109          if 'maxLength' in param:
110              attrs.append(f'maxLength="{param["maxLength"]}"')
111          
112          # pattern: 正则表达式模式
113          if 'pattern' in param and param['pattern']:
114              pattern_str = str(param['pattern']).replace('"', '"')
115              attrs.append(f'pattern="{pattern_str}"')
116          
117          # format: 格式(如 date-time, email, uri 等)
118          if 'format' in param and param['format']:
119              attrs.append(f'format="{param["format"]}"')
120          
121          # examples: 示例值
122          if 'examples' in param and param['examples']:
123              examples = param['examples']
124              if isinstance(examples, list) and examples:
125                  examples_str = ','.join(str(v) for v in examples)
126                  examples_str = examples_str.replace('"', '"')
127                  attrs.append(f'examples="{examples_str}"')
128          
129          # items: 数组元素类型(对于 array 类型)
130          if 'items' in param:
131              items = param['items']
132              if isinstance(items, dict):
133                  if 'type' in items:
134                      attrs.append(f'itemsType="{items["type"]}"')
135                  if 'enum' in items:
136                      items_enum = items['enum']
137                      if isinstance(items_enum, list):
138                          items_enum_str = ','.join(str(v) for v in items_enum)
139                          items_enum_str = items_enum_str.replace('"', '"')
140                          attrs.append(f'itemsEnum="{items_enum_str}"')
141          
142          return ' '.join(attrs)
143  
144      async def describe_mcp_tools(self) -> str:
145          """Return `<mcp_tools>` XML listing tool names and descriptions."""
146          try:
147              async with self._session() as session:
148                  data = await session.list_tools()
149          except BaseExceptionGroup as eg:
150              root_cause = self._extract_root_cause(eg)
151              raise RuntimeError(f"Failed to fetch MCP tools: {root_cause}") from eg
152          except Exception as e:
153              raise RuntimeError(f"Failed to fetch MCP tools: {type(e).__name__}: {e}") from e
154  
155          xml_lines = ["<mcp_tools>"]
156          for t in data.tools:
157              # 缓存工具 schema,用于后续参数类型转换
158              self._tools_schema[t.name] = t.inputSchema
159  
160              parameters = ''
161              for k, param in t.inputSchema['properties'].items():
162                  required = 'true' if k in t.inputSchema.get("required", []) else 'false'
163                  param_type = param.get('type', 'string')
164                  # 构建基础属性
165                  base_attrs = f'name="{k}" type="{param_type}" required="{required}"'
166                  # 构建额外的 schema 属性
167                  extra_attrs = self._build_parameter_attributes(param)
168                  # 合并所有属性(如果 extra_attrs 不为空,则添加空格)
169                  all_attrs = f'{base_attrs} {extra_attrs}'.strip() if extra_attrs else base_attrs
170                  parameters += f'''<parameter {all_attrs}></parameter>'''
171              xml_lines.append(f'''
172      <name>{t.name}</name>
173      <description>{t.description}</description>
174      <parameters>
175        <parameter name="tool_name" type=string required=true>tool_name is {t.name}</parameter>
176        {parameters}
177      </parameters>
178              ''')
179              name = t.name
180              detail = t.description or ""
181              xml_lines.append(f"detail:{detail} 调用格式:\n<tool_name>{name}</tool_name>\n</tool>")
182          xml_lines.append("</mcp_tools>")
183          return "\n".join(xml_lines)
184  
185      def _convert_param_type(self, value: Any, param_type: str) -> Any:
186          """根据 schema 定义的类型转换参数值"""
187          if value is None:
188              return None
189  
190          try:
191              if param_type == "integer":
192                  return int(value)
193              elif param_type == "number":
194                  return float(value)
195              elif param_type == "boolean":
196                  if isinstance(value, bool):
197                      return value
198                  if isinstance(value, str):
199                      return value.lower() in ("true", "1", "yes")
200                  return bool(value)
201              elif param_type == "array":
202                  if isinstance(value, list):
203                      return value
204                  if isinstance(value, str):
205                      import json
206                      return json.loads(value)
207                  return [value]
208              elif param_type == "object":
209                  if isinstance(value, dict):
210                      return value
211                  if isinstance(value, str):
212                      import json
213                      return json.loads(value)
214                  return value
215              else:
216                  # string 或其他类型,保持原样
217                  return value
218          except (ValueError, TypeError):
219              # 转换失败,返回原值
220              return value
221  
222      def _convert_args_by_schema(self, tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
223          """根据工具 schema 转换所有参数类型"""
224          schema = self._tools_schema.get(tool_name)
225          if not schema:
226              return args
227  
228          properties = schema.get("properties", {})
229          converted_args = {}
230  
231          for key, value in args.items():
232              param_schema = properties.get(key, {})
233              param_type = param_schema.get("type", "string")
234              converted_args[key] = self._convert_param_type(value, param_type)
235  
236          return converted_args
237  
238      def _extract_root_cause(self, exc: Exception) -> str:
239          """从 ExceptionGroup/TaskGroup 中提取原始错误信息"""
240          # 处理 ExceptionGroup (Python 3.11+)
241          if isinstance(exc, BaseExceptionGroup):
242              messages = []
243              for sub_exc in exc.exceptions:
244                  # 递归提取嵌套的 ExceptionGroup
245                  messages.append(self._extract_root_cause(sub_exc))
246              return "; ".join(messages)
247          # 普通异常,返回其消息
248          return f"{type(exc).__name__}: {exc}"
249  
250      async def call_remote_tool(self, tool_name: str, **kw) -> Any:
251          """
252          Call remote MCP server tool.
253          call: {"toolName": name, "args": {...}}
254          """
255          if not tool_name:
256              raise ValueError("call_remote_tool requires call['toolName']")
257  
258          # 根据 schema 转换参数类型
259          converted_kw = self._convert_args_by_schema(tool_name, kw)
260  
261          try:
262              async with self._session() as session:
263                  result = await session.call_tool(tool_name, converted_kw)
264                  if result is None:
265                      return None
266                  result = result.content[0]
267                  # 判断TextContent or ImageContent or VideoContent
268                  if hasattr(result, 'text'):
269                      return result.text
270                  elif hasattr(result, 'data'):
271                      return result.data
272          except BaseExceptionGroup as eg:
273              # 提取 TaskGroup 中的原始错误
274              root_cause = self._extract_root_cause(eg)
275              raise RuntimeError(f"MCP call failed: {root_cause}") from eg
276          except Exception as e:
277              raise RuntimeError(f"MCP call failed: {type(e).__name__}: {e}") from e
278  
279  
if __name__ == "__main__":
    async def _demo() -> None:
        """Smoke test against a local SSE server: print the tool catalog, then call one tool."""
        client = MCPTools(url="http://localhost:8090/sse", transport="sse")
        print(await client.describe_mcp_tools())

        tool_output = await client.call_remote_tool("get_filename1", filename="/etc/passwd")
        print(f"Tool call result: {tool_output}")

    asyncio.run(_demo())