tool_scanner.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

"""
Tool scanner module.

Scans all available tools (attacks, metrics, vulnerabilities) and extracts
their constructor parameter information so it can be displayed by the CLI.
"""

import os
import sys
import glob
import importlib
import importlib.util
import inspect
from typing import Dict, List, Optional, Any
from pathlib import Path


class ToolScanner:
    """Scan tool modules and collect metadata about the tools they define."""

    # Tool type -> name of the base class a tool of that type subclasses.
    _BASE_CLASS_NAMES = {
        'attack': 'BaseAttack',
        'metric': 'BaseRedTeamingMetric',
        'vulnerability': 'BaseVulnerability',
    }

    # Tool type -> (module path, class name) of the canonical deepteam base class.
    _CANONICAL_BASE_CLASSES = {
        'attack': ('deepteam.attacks.base_attack', 'BaseAttack'),
        'metric': ('deepteam.metrics.base_red_teaming_metric', 'BaseRedTeamingMetric'),
        'vulnerability': ('deepteam.vulnerabilities.base_vulnerability', 'BaseVulnerability'),
    }

    # Prefix for sys.modules names of files imported directly by file location.
    _FILE_MODULE_PREFIX = '_tool_scanner_plugin_'

    def __init__(self):
        # Tool name -> tool info dict; rebuilt by scan_all_tools().
        self.tools_info = {}
        # Extra local plugin files/directories registered via add_plugin_path().
        self.plugin_paths = []
        # Remote plugin URLs registered via add_remote_plugin_url().
        self.remote_plugin_urls = []

    def add_plugin_path(self, path: str):
        """Register a local plugin file or directory; non-existent paths are ignored."""
        if os.path.exists(path):
            self.plugin_paths.append(path)

    def add_remote_plugin_url(self, url: str):
        """Register a remote plugin URL.

        Only http(s) URLs pointing at a ``.zip`` archive or a ``.py`` file are
        accepted; anything else is silently ignored.
        """
        if url.startswith(('http://', 'https://')) and url.endswith(('.zip', '.py')):
            self.remote_plugin_urls.append(url)

    def scan_all_tools(self) -> Dict[str, Any]:
        """Rescan built-in, local plugin and remote plugin tools.

        Returns:
            The freshly built ``tools_info`` mapping (tool name -> tool info).
        """
        self.tools_info = {}

        self._scan_builtin_tools()
        self._scan_plugin_tools()
        self._scan_remote_plugin_tools()

        return self.tools_info

    def _scan_remote_plugin_tools(self):
        """Download, extract and scan every registered remote plugin URL.

        A failure for one URL is reported and the remaining URLs are still
        processed.
        """
        if not self.remote_plugin_urls:
            return

        try:
            from .remote_plugin_downloader import RemotePluginDownloader
        except ImportError:
            print("Warning: 远程插件下载器不可用,跳过远程插件扫描")
            return

        try:
            downloader = RemotePluginDownloader()
        except Exception as e:
            print(f"Warning: 扫描远程插件时发生错误: {e}")
            return

        for url in self.remote_plugin_urls:
            try:
                download_result = downloader.download_and_extract_plugin(url)
                # NOTE(review): assumes the downloader returns a dict with
                # 'success' and 'extracted_path' keys, as the original code read.
                if download_result.get('success') and download_result.get('extracted_path'):
                    self._scan_plugin_directory(download_result['extracted_path'])
            except Exception as e:
                print(f"Warning: 扫描远程插件时发生错误: {e}")

    def _scan_builtin_tools(self):
        """Scan the built-in deepteam tool directories (relative to the CWD)."""
        self._scan_directory('deepteam/attacks/', 'attack')
        self._scan_directory('deepteam/metrics/', 'metric')
        self._scan_directory('deepteam/vulnerabilities/', 'vulnerability')

    def _scan_plugin_tools(self):
        """Scan ``plugin/`` (relative to the CWD) and every registered plugin path."""
        if os.path.exists('plugin/'):
            self._scan_plugin_directory('plugin/')

        for plugin_path in self.plugin_paths:
            self._scan_plugin_directory(plugin_path)

    @staticmethod
    def _glob_python_files(directory: str) -> List[str]:
        """Return every ``.py`` file below *directory* (recursively), sorted."""
        # Escape the directory so glob metacharacters in real path names
        # (e.g. '[') are matched literally.
        pattern = os.path.join(glob.escape(directory), '**', '*.py')
        return sorted(glob.glob(pattern, recursive=True))

    def _register_tools(self, tool_infos: List[Dict[str, Any]]):
        """Store tool infos by name; a later tool with the same name wins."""
        for tool_info in tool_infos:
            self.tools_info[tool_info['name']] = tool_info

    def _scan_directory(self, directory: str, tool_type: str):
        """Scan *directory* for tools of a single, known *tool_type*."""
        if not os.path.exists(directory):
            return

        for file_path in self._glob_python_files(directory):
            if not self._is_tool_file(file_path):
                continue
            tool_info = self._extract_tool_info(file_path, tool_type)
            if tool_info:
                self.tools_info[tool_info['name']] = tool_info

    def _scan_plugin_directory(self, plugin_path: str):
        """Scan a plugin, which is either a single ``.py`` file or a directory."""
        if os.path.isfile(plugin_path):
            if plugin_path.endswith('.py'):
                self._register_tools(self._extract_tool_info_from_file(plugin_path))
        elif os.path.isdir(plugin_path):
            for file_path in self._glob_python_files(plugin_path):
                self._register_tools(self._extract_tool_info_from_file(file_path))

    def _is_tool_file(self, file_path: str) -> bool:
        """Cheap pre-filter deciding whether *file_path* may define a tool.

        Dunder files (e.g. ``__init__.py``) and ``test_*`` files are excluded;
        otherwise the file text must mention one of the tool base class names.
        Unreadable or non-UTF-8 files are treated as non-tool files.
        """
        filename = os.path.basename(file_path)
        if filename.startswith(('__', 'test_')):
            return False

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except (OSError, ValueError):  # ValueError covers UnicodeDecodeError
            return False

        return any(keyword in content for keyword in self._BASE_CLASS_NAMES.values())

    def _build_tool_info(self, tool_class, tool_type: str, file_path: str) -> Dict[str, Any]:
        """Build the metadata dict describing one tool class."""
        return {
            'name': tool_class.__name__,
            'type': tool_type,
            'file': file_path,
            # __doc__ is None for undocumented classes; normalise to ''.
            'description': tool_class.__doc__ or '',
            'parameters': self._extract_parameters(tool_class),
            'has_parameter_descriptions': hasattr(tool_class, '_parameter_descriptions'),
        }

    def _extract_tool_info(self, file_path: str, tool_type: str) -> Optional[Dict[str, Any]]:
        """Import *file_path* and describe its first tool class of *tool_type*.

        Returns:
            The tool info dict, or None if no matching class was found or the
            file could not be imported (a warning is printed in that case).
        """
        try:
            module = self._load_module(file_path)
            tool_class = self._find_tool_class(module, tool_type)
            if tool_class is None:
                return None
            return self._build_tool_info(tool_class, tool_type, file_path)
        except Exception as e:
            print(f"Warning: Failed to scan {file_path}: {e}")
            return None

    def _extract_tool_info_from_file(self, file_path: str) -> List[Dict[str, Any]]:
        """Import a plugin file and describe every tool class of every type in it.

        Returns:
            A list of tool info dicts; empty if the file could not be imported
            (a warning is printed in that case).
        """
        try:
            module = self._load_module(file_path)
            return [
                self._build_tool_info(tool_class, tool_type, file_path)
                for tool_type in ('attack', 'metric', 'vulnerability')
                for tool_class in self._find_all_tool_classes(module, tool_type)
            ]
        except Exception as e:
            print(f"Warning: Failed to scan plugin {file_path}: {e}")
            return []

    def _load_module(self, file_path: str):
        """Import the Python module stored at *file_path*.

        Files that map to a dotted module path under the current working
        directory are imported by that path (giving them their package
        context); any other file is loaded directly from its location.

        Raises:
            ImportError: (or any error raised by the module itself) if the
                file cannot be imported.
        """
        module_path = self._file_path_to_module_path(file_path)
        if module_path:
            return importlib.import_module(module_path)
        return self._import_from_file_location(file_path)

    def _import_from_file_location(self, file_path: str):
        """Load *file_path* as a standalone module under a synthesized name.

        The module is cached in ``sys.modules`` so repeated scans reuse it.
        """
        abs_path = os.path.abspath(file_path)
        safe_name = ''.join(ch if ch.isalnum() else '_' for ch in abs_path)
        module_name = self._FILE_MODULE_PREFIX + safe_name

        cached = sys.modules.get(module_name)
        if cached is not None:
            return cached

        spec = importlib.util.spec_from_file_location(module_name, abs_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load module from {file_path}")

        module = importlib.util.module_from_spec(spec)
        # Register before executing so code that looks the module up in
        # sys.modules during its own import can find it.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            sys.modules.pop(module_name, None)
            raise
        return module

    def _file_path_to_module_path(self, file_path: str) -> Optional[str]:
        """Convert *file_path* into a dotted module path relative to the CWD.

        ``'plugin/foo.py'`` and ``'./plugin/foo.py'`` both become
        ``'plugin.foo'``; an absolute path inside the CWD is made relative
        first.

        Returns:
            The dotted module path, or None when the path cannot be expressed
            as one (it lies outside the CWD, is on another drive, or has a
            component that is empty or contains a dot).
        """
        try:
            rel_path = os.path.relpath(file_path)
        except (TypeError, ValueError):
            return None

        if rel_path.endswith('.py'):
            rel_path = rel_path[:-3]

        parts = rel_path.replace('\\', '/').split('/')
        # '..' (outside the CWD), '.', hidden dirs and dotted names cannot be
        # expressed as dotted module paths.
        if not all(parts) or any('.' in part for part in parts):
            return None
        return '.'.join(parts)

    def _resolve_base_classes(self, module, tool_type: str) -> tuple:
        """Collect the base classes that tools of *tool_type* may inherit from.

        Two candidates are considered: an attribute of *module* carrying the
        base-class name, and the canonical deepteam base class (skipped when
        deepteam cannot be imported). Unknown tool types yield ``()``.
        """
        base_name = self._BASE_CLASS_NAMES.get(tool_type)
        if base_name is None:
            return ()

        bases = []
        local_base = getattr(module, base_name, None)
        if inspect.isclass(local_base):
            bases.append(local_base)

        canonical_module, canonical_name = self._CANONICAL_BASE_CLASSES[tool_type]
        try:
            canonical_base = getattr(importlib.import_module(canonical_module), canonical_name)
        except Exception:
            canonical_base = None
        if inspect.isclass(canonical_base) and canonical_base not in bases:
            bases.append(canonical_base)

        return tuple(bases)

    @staticmethod
    def _is_strict_subclass(cls, bases) -> bool:
        """True if *cls* subclasses some base in *bases* without being that base."""
        for base in bases:
            try:
                if cls is not base and issubclass(cls, base):
                    return True
            except Exception:
                continue
        return False

    def _iter_tool_classes(self, module, tool_type: str):
        """Yield each class in *module* that subclasses a *tool_type* base class.

        Classes are yielded once each, in ``dir(module)`` order.
        """
        bases = self._resolve_base_classes(module, tool_type)
        if not bases:
            return
        for attr_name in dir(module):
            attr = getattr(module, attr_name)
            if inspect.isclass(attr) and self._is_strict_subclass(attr, bases):
                yield attr

    def _find_tool_class(self, module, tool_type: str):
        """Return the first tool class of *tool_type* in *module*, or None."""
        return next(self._iter_tool_classes(module, tool_type), None)

    def _find_all_tool_classes(self, module, tool_type: str) -> List:
        """Return every tool class of *tool_type* in *module*, without duplicates."""
        return list(self._iter_tool_classes(module, tool_type))

    def _extract_parameters(self, tool_class) -> Dict[str, Any]:
        """Describe the parameters of ``tool_class.__init__``.

        Returns:
            Parameter name -> {'required', 'default', 'description'}. ``self``
            is skipped; ``*args``/``**kwargs`` are never required. Descriptions
            come from the class's ``_parameter_descriptions`` mapping if any.
            On failure a warning is printed and the parameters gathered so far
            are returned.
        """
        parameters = {}

        try:
            init_sig = inspect.signature(tool_class.__init__)
            descriptions = getattr(tool_class, '_parameter_descriptions', None) or {}

            for param_name, param in init_sig.parameters.items():
                if param_name == 'self':
                    continue

                has_default = param.default is not inspect.Parameter.empty
                is_variadic = param.kind in (
                    inspect.Parameter.VAR_POSITIONAL,
                    inspect.Parameter.VAR_KEYWORD,
                )
                parameters[param_name] = {
                    'required': not (has_default or is_variadic),
                    'default': param.default if has_default else None,
                    'description': descriptions.get(param_name, ''),
                }
        except Exception as e:
            print(f"Warning: Failed to extract parameters from {tool_class.__name__}: {e}")

        return parameters

    def validate_tool_completeness(self) -> List[str]:
        """Report tools lacking parameter descriptions.

        Returns:
            One warning string per tool without ``_parameter_descriptions`` and
            per parameter whose description is empty.
        """
        warnings = []
        for tool_name, tool_info in self.tools_info.items():
            if not tool_info['has_parameter_descriptions']:
                warnings.append(f"Warning: {tool_name} 缺少参数说明装饰器")
            else:
                for param_name, param_info in tool_info['parameters'].items():
                    if not param_info['description']:
                        warnings.append(f"Warning: {tool_name}.{param_name} 缺少参数描述")

        return warnings

    def get_tools_by_type(self, tool_type: str) -> Dict[str, Any]:
        """Return the scanned tools whose type equals *tool_type*."""
        return {name: info for name, info in self.tools_info.items()
                if info['type'] == tool_type}

    def get_tool_info(self, tool_name: str) -> Optional[Dict[str, Any]]:
        """Return the info dict for *tool_name*, or None if it was not scanned."""
        return self.tools_info.get(tool_name)