tool_scanner.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
 19  """
 20  工具扫描器模块
 21  用于扫描所有可用的工具并提取参数信息,支持CLI展示
 22  """
 23  
 24  import os
 25  import glob
 26  import importlib
 27  import inspect
 28  from typing import Dict, List, Optional, Any
 29  from pathlib import Path
 30  
 31  
 32  class ToolScanner:
 33      """工具扫描器,用于扫描和提取工具信息"""
 34      
 35      def __init__(self):
 36          self.tools_info = {}
 37          self.plugin_paths = []
 38          self.remote_plugin_urls = []
 39      
 40      def add_plugin_path(self, path: str):
 41          """添加插件路径"""
 42          if os.path.exists(path):
 43              self.plugin_paths.append(path)
 44      
 45      def add_remote_plugin_url(self, url: str):
 46          """添加远程插件URL"""
 47          if url.startswith('http') and (url.endswith('.zip') or url.endswith('.py')):
 48              self.remote_plugin_urls.append(url)
 49      
 50      def scan_all_tools(self) -> Dict[str, Any]:
 51          """扫描所有可用的工具"""
 52          self.tools_info = {}
 53          
 54          # 扫描内置工具
 55          self._scan_builtin_tools()
 56          # 扫描插件工具
 57          self._scan_plugin_tools()
 58          # 扫描远程插件工具
 59          self._scan_remote_plugin_tools()
 60          
 61          return self.tools_info
 62      
 63      def _scan_remote_plugin_tools(self):
 64          """扫描远程插件工具"""
 65          if not self.remote_plugin_urls:
 66              return
 67              
 68          try:
 69              from .remote_plugin_downloader import RemotePluginDownloader
 70              downloader = RemotePluginDownloader()
 71              
 72              for url in self.remote_plugin_urls:
 73                  # 下载并解压远程插件
 74                  download_result = downloader.download_and_extract_plugin(url)
 75                  if download_result['success'] and download_result['extracted_path']:
 76                      # 扫描解压后的插件
 77                      self._scan_plugin_directory(download_result['extracted_path'])
 78          except ImportError:
 79              print("Warning: 远程插件下载器不可用,跳过远程插件扫描")
 80          except Exception as e:
 81              print(f"Warning: 扫描远程插件时发生错误: {e}")
 82      
 83      def _scan_builtin_tools(self):
 84          """扫描内置工具"""
 85          # 扫描 deepteam/attacks/ 目录
 86          self._scan_directory('deepteam/attacks/', 'attack')
 87          # 扫描 deepteam/metrics/ 目录  
 88          self._scan_directory('deepteam/metrics/', 'metric')
 89          # 扫描 deepteam/vulnerabilities/ 目录
 90          self._scan_directory('deepteam/vulnerabilities/', 'vulnerability')
 91      
 92      def _scan_plugin_tools(self):
 93          """扫描插件工具"""
 94          # 扫描 plugin/ 目录下的插件
 95          if os.path.exists('plugin/'):
 96              self._scan_plugin_directory('plugin/')
 97          
 98          # 扫描用户指定的插件路径
 99          for plugin_path in self.plugin_paths:
100              self._scan_plugin_directory(plugin_path)
101      
102      def _scan_directory(self, directory: str, tool_type: str):
103          """扫描指定目录下的工具"""
104          if not os.path.exists(directory):
105              return
106              
107          for file_path in glob.glob(f"{directory}/**/*.py", recursive=True):
108              if self._is_tool_file(file_path):
109                  tool_info = self._extract_tool_info(file_path, tool_type)
110                  if tool_info:
111                      self.tools_info[tool_info['name']] = tool_info
112      
113      def _scan_plugin_directory(self, plugin_path: str):
114          """扫描插件目录"""
115          if os.path.isfile(plugin_path):
116              # 单个文件插件
117              if plugin_path.endswith('.py'):
118                  tool_infos = self._extract_tool_info_from_file(plugin_path)
119                  for tool_info in tool_infos:
120                      self.tools_info[tool_info['name']] = tool_info
121          elif os.path.isdir(plugin_path):
122              # 文件夹插件
123              for file_path in glob.glob(f"{plugin_path}/**/*.py", recursive=True):
124                  tool_infos = self._extract_tool_info_from_file(file_path)
125                  for tool_info in tool_infos:
126                      self.tools_info[tool_info['name']] = tool_info
127      
128      def _is_tool_file(self, file_path: str) -> bool:
129          """判断是否为工具文件"""
130          # 排除 __init__.py 和测试文件
131          filename = os.path.basename(file_path)
132          if filename.startswith('__') or filename.startswith('test_'):
133              return False
134          
135          # 检查文件内容是否包含工具类
136          try:
137              with open(file_path, 'r', encoding='utf-8') as f:
138                  content = f.read()
139                  # 简单检查是否包含工具基类
140                  tool_keywords = ['BaseAttack', 'BaseRedTeamingMetric', 'BaseVulnerability']
141                  return any(keyword in content for keyword in tool_keywords)
142          except:
143              return False
144      
145      def _extract_tool_info(self, file_path: str, tool_type: str) -> Optional[Dict[str, Any]]:
146          """从文件中提取工具信息"""
147          try:
148              # 转换为模块路径
149              module_path = self._file_path_to_module_path(file_path)
150              if not module_path:
151                  return None
152              
153              # 动态导入模块
154              module = importlib.import_module(module_path)
155              
156              # 查找工具类
157              tool_class = self._find_tool_class(module, tool_type)
158              if not tool_class:
159                  return None
160              
161              # 提取参数信息
162              param_info = self._extract_parameters(tool_class)
163              
164              return {
165                  'name': tool_class.__name__,
166                  'type': tool_type,
167                  'file': file_path,
168                  'description': getattr(tool_class, '__doc__', ''),
169                  'parameters': param_info,
170                  'has_parameter_descriptions': hasattr(tool_class, '_parameter_descriptions')
171              }
172          except Exception as e:
173              print(f"Warning: Failed to scan {file_path}: {e}")
174              return None
175      
176      def _extract_tool_info_from_file(self, file_path: str) -> List[Dict[str, Any]]:
177          """从插件文件中提取工具信息"""
178          try:
179              # 转换为模块路径
180              module_path = self._file_path_to_module_path(file_path)
181              if not module_path:
182                  return []
183              
184              # 动态导入模块
185              module = importlib.import_module(module_path)
186              
187              # 查找所有工具类
188              all_tools = []
189              for tool_type in ['attack', 'metric', 'vulnerability']:
190                  tool_classes = self._find_all_tool_classes(module, tool_type)
191                  for tool_class in tool_classes:
192                      param_info = self._extract_parameters(tool_class)
193                      tool_info = {
194                          'name': tool_class.__name__,
195                          'type': tool_type,
196                          'file': file_path,
197                          'description': getattr(tool_class, '__doc__', ''),
198                          'parameters': param_info,
199                          'has_parameter_descriptions': hasattr(tool_class, '_parameter_descriptions')
200                      }
201                      all_tools.append(tool_info)
202              
203              return all_tools
204              
205          except Exception as e:
206              print(f"Warning: Failed to scan plugin {file_path}: {e}")
207              return []
208      
209      def _file_path_to_module_path(self, file_path: str) -> Optional[str]:
210          """将文件路径转换为模块路径"""
211          try:
212              # 移除 .py 扩展名
213              if file_path.endswith('.py'):
214                  file_path = file_path[:-3]
215              
216              # 替换路径分隔符
217              module_path = file_path.replace('/', '.').replace('\\', '.')
218              
219              # 移除开头的点
220              if module_path.startswith('.'):
221                  module_path = module_path[1:]
222              
223              return module_path
224          except:
225              return None
226      
227      def _find_tool_class(self, module, tool_type: str):
228          """在模块中查找工具类"""
229          tool_base_classes = {
230              'attack': ['BaseAttack'],
231              'metric': ['BaseRedTeamingMetric'],
232              'vulnerability': ['BaseVulnerability']
233          }
234          
235          base_classes = tool_base_classes.get(tool_type, [])
236          
237          for attr_name in dir(module):
238              attr = getattr(module, attr_name)
239              if inspect.isclass(attr):
240                  # 检查是否继承自工具基类
241                  for base_class_name in base_classes:
242                      try:
243                          # 获取基类
244                          base_class = getattr(module, base_class_name, None)
245                          if base_class and issubclass(attr, base_class) and attr != base_class:
246                              return attr
247                      except:
248                          continue
249                  
250                  # 检查模块的基类
251                  for base_class_name in base_classes:
252                      try:
253                          # 尝试从其他模块导入基类
254                          if tool_type == 'attack':
255                              from deepteam.attacks.base_attack import BaseAttack
256                              if issubclass(attr, BaseAttack) and attr != BaseAttack:
257                                  return attr
258                          elif tool_type == 'metric':
259                              from deepteam.metrics.base_red_teaming_metric import BaseRedTeamingMetric
260                              if issubclass(attr, BaseRedTeamingMetric) and attr != BaseRedTeamingMetric:
261                                  return attr
262                          elif tool_type == 'vulnerability':
263                              from deepteam.vulnerabilities.base_vulnerability import BaseVulnerability
264                              if issubclass(attr, BaseVulnerability) and attr != BaseVulnerability:
265                                  return attr
266                      except:
267                          continue
268          
269          return None
270      
271      def _find_all_tool_classes(self, module, tool_type: str) -> List:
272          """在模块中查找所有工具类"""
273          tool_classes = []
274          tool_base_classes = {
275              'attack': ['BaseAttack'],
276              'metric': ['BaseRedTeamingMetric'],
277              'vulnerability': ['BaseVulnerability']
278          }
279          
280          base_classes = tool_base_classes.get(tool_type, [])
281          
282          for attr_name in dir(module):
283              attr = getattr(module, attr_name)
284              if inspect.isclass(attr):
285                  # 检查是否继承自工具基类
286                  for base_class_name in base_classes:
287                      try:
288                          # 获取基类
289                          base_class = getattr(module, base_class_name, None)
290                          if base_class and issubclass(attr, base_class) and attr != base_class:
291                              tool_classes.append(attr)
292                              break
293                      except:
294                          continue
295                  
296                  # 检查模块的基类
297                  for base_class_name in base_classes:
298                      try:
299                          # 尝试从其他模块导入基类
300                          if tool_type == 'attack':
301                              from deepteam.attacks.base_attack import BaseAttack
302                              if issubclass(attr, BaseAttack) and attr != BaseAttack:
303                                  tool_classes.append(attr)
304                                  break
305                          elif tool_type == 'metric':
306                              from deepteam.metrics.base_red_teaming_metric import BaseRedTeamingMetric
307                              if issubclass(attr, BaseRedTeamingMetric) and attr != BaseRedTeamingMetric:
308                                  tool_classes.append(attr)
309                                  break
310                          elif tool_type == 'vulnerability':
311                              from deepteam.vulnerabilities.base_vulnerability import BaseVulnerability
312                              if issubclass(attr, BaseVulnerability) and attr != BaseVulnerability:
313                                  tool_classes.append(attr)
314                                  break
315                      except:
316                          continue
317          
318          return tool_classes
319      
320      def _extract_parameters(self, tool_class) -> Dict[str, Any]:
321          """提取工具类的参数信息"""
322          parameters = {}
323          
324          try:
325              # 获取 __init__ 方法的参数
326              init_sig = inspect.signature(tool_class.__init__)
327              
328              for param_name, param in init_sig.parameters.items():
329                  if param_name == 'self':
330                      continue
331                      
332                  param_info = {
333                      'required': param.default == inspect.Parameter.empty,
334                      'default': param.default if param.default != inspect.Parameter.empty else None,
335                      'description': ''
336                  }
337                  
338                  # 从装饰器中获取参数描述
339                  if hasattr(tool_class, '_parameter_descriptions'):
340                      param_info['description'] = tool_class._parameter_descriptions.get(param_name, '')
341                  
342                  parameters[param_name] = param_info
343          except Exception as e:
344              print(f"Warning: Failed to extract parameters from {tool_class.__name__}: {e}")
345          
346          return parameters
347      
348      def validate_tool_completeness(self) -> List[str]:
349          """验证工具参数说明的完整性"""
350          warnings = []
351          for tool_name, tool_info in self.tools_info.items():
352              if not tool_info['has_parameter_descriptions']:
353                  warnings.append(f"Warning: {tool_name} 缺少参数说明装饰器")
354              else:
355                  # 检查是否所有参数都有描述
356                  for param_name, param_info in tool_info['parameters'].items():
357                      if not param_info['description']:
358                          warnings.append(f"Warning: {tool_name}.{param_name} 缺少参数描述")
359          
360          return warnings
361      
362      def get_tools_by_type(self, tool_type: str) -> Dict[str, Any]:
363          """获取指定类型的工具"""
364          return {name: info for name, info in self.tools_info.items() 
365                  if info['type'] == tool_type}
366      
367      def get_tool_info(self, tool_name: str) -> Optional[Dict[str, Any]]:
368          """获取指定工具的详细信息"""
369          return self.tools_info.get(tool_name)