/ AIG-PromptSecurity / deepteam / plugin_system / plugin_validator.py
plugin_validator.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
 19  import inspect
 20  import ast
 21  from typing import Dict, List, Tuple, Optional, Any
 22  from pathlib import Path
 23  import importlib.util
 24  from loguru import logger
 25  
 26  from deepteam.attacks import BaseAttack
 27  from deepteam.metrics import BaseRedTeamingMetric
 28  from deepteam.vulnerabilities import BaseVulnerability
 29  
 30  
 31  class PluginValidator:
 32      """验证用户上传的算子插件是否符合系统要求"""
 33      
 34      def __init__(self):
 35          self.required_base_classes = {
 36              'attack': BaseAttack,
 37              'metric': BaseRedTeamingMetric,
 38              'vulnerability': BaseVulnerability
 39          }
 40          
 41          self.required_methods = {
 42              'attack': ['enhance', 'get_name'],
 43              'metric': ['measure', 'get_name'],
 44              'vulnerability': ['get_name', 'get_types']
 45          }
 46      
 47      def validate_plugin_file(self, file_path: Path) -> Dict[str, Any]:
 48          """验证单个插件文件"""
 49          result = {
 50              'valid': False,
 51              'plugin_type': None,
 52              'class_name': None,
 53              'errors': [],
 54              'warnings': []
 55          }
 56          
 57          try:
 58              # 检查文件扩展名
 59              if file_path.suffix != '.py':
 60                  result['errors'].append(f"文件必须是Python文件 (.py): {file_path}")
 61                  return result
 62              
 63              # 解析Python文件
 64              with open(file_path, 'r', encoding='utf-8') as f:
 65                  content = f.read()
 66              
 67              # 使用AST解析文件
 68              tree = ast.parse(content)
 69              
 70              # 查找类定义
 71              classes = [node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)]
 72              
 73              if not classes:
 74                  result['errors'].append("文件中没有找到类定义")
 75                  return result
 76              
 77              # 检查每个类
 78              for class_node in classes:
 79                  class_result = self._validate_class(class_node, content, file_path)
 80                  if class_result['valid']:
 81                      result.update(class_result)
 82                      break
 83              else:
 84                  result['errors'].append("没有找到符合要求的类")
 85                  
 86          except Exception as e:
 87              result['errors'].append(f"验证过程中发生错误: {str(e)}")
 88          
 89          return result
 90      
 91      def _validate_class(self, class_node: ast.ClassDef, content: str, file_path: Path) -> Dict[str, Any]:
 92          """验证单个类是否符合插件要求"""
 93          result = {
 94              'valid': False,
 95              'plugin_type': None,
 96              'class_name': class_node.name,
 97              'errors': [],
 98              'warnings': []
 99          }
100          
101          try:
102              # 检查类是否继承自正确的基类
103              base_classes = self._get_base_classes(class_node)
104              
105              # 确定插件类型
106              plugin_type = None
107              for base_class_name in base_classes:
108                  if 'Attack' in base_class_name or 'attack' in base_class_name.lower():
109                      plugin_type = 'attack'
110                      break
111                  elif 'Metric' in base_class_name or 'metric' in base_class_name.lower():
112                      plugin_type = 'metric'
113                      break
114                  elif 'Vulnerability' in base_class_name or 'vulnerability' in base_class_name.lower():
115                      plugin_type = 'vulnerability'
116                      break
117              
118              if not plugin_type:
119                  # 尝试动态导入验证
120                  plugin_type = self._dynamic_validate_class(file_path, class_node.name)
121              
122              if not plugin_type:
123                  result['errors'].append(f"类 {class_node.name} 没有继承正确的基类")
124                  return result
125              
126              result['plugin_type'] = plugin_type
127              
128              # 检查必需的方法
129              required_methods = self.required_methods[plugin_type]
130              class_methods = [node.name for node in class_node.body if isinstance(node, ast.FunctionDef)]
131              
132              missing_methods = []
133              for method in required_methods:
134                  if method not in class_methods:
135                      missing_methods.append(method)
136              
137              if missing_methods:
138                  result['errors'].append(f"缺少必需的方法: {', '.join(missing_methods)}")
139                  return result
140              
141              # 检查是否有抽象方法
142              has_abstract_methods = self._check_abstract_methods(class_node)
143              if has_abstract_methods:
144                  result['warnings'].append("类包含抽象方法,需要实现")
145              
146              result['valid'] = True
147              
148          except Exception as e:
149              result['errors'].append(f"验证类时发生错误: {str(e)}")
150          
151          return result
152      
153      def _get_base_classes(self, class_node: ast.ClassDef) -> List[str]:
154          """获取类的基类名称"""
155          base_classes = []
156          for base in class_node.bases:
157              if isinstance(base, ast.Name):
158                  base_classes.append(base.id)
159              elif isinstance(base, ast.Attribute):
160                  # 处理类似 module.Class 的情况
161                  base_classes.append(self._get_attribute_name(base))
162          return base_classes
163      
164      def _get_attribute_name(self, node: ast.Attribute) -> str:
165          """获取属性访问的完整名称"""
166          if isinstance(node.value, ast.Name):
167              return f"{node.value.id}.{node.attr}"
168          elif isinstance(node.value, ast.Attribute):
169              return f"{self._get_attribute_name(node.value)}.{node.attr}"
170          else:
171              return node.attr
172      
173      def _check_abstract_methods(self, class_node: ast.ClassDef) -> bool:
174          """检查类是否包含抽象方法"""
175          for node in class_node.body:
176              if isinstance(node, ast.FunctionDef):
177                  # 检查是否有 @abstractmethod 装饰器
178                  for decorator in node.decorator_list:
179                      if isinstance(decorator, ast.Name) and decorator.id == 'abstractmethod':
180                          return True
181                      elif isinstance(decorator, ast.Attribute) and decorator.attr == 'abstractmethod':
182                          return True
183          return False
184      
185      def _dynamic_validate_class(self, file_path: Path, class_name: str) -> Optional[str]:
186          """动态导入验证类"""
187          try:
188              # 动态导入模块
189              spec = importlib.util.spec_from_file_location("temp_module", file_path)
190              if spec is None:
191                  return None
192                  
193              module = importlib.util.module_from_spec(spec)
194              if spec.loader is None:
195                  return None
196                  
197              spec.loader.exec_module(module)
198              
199              # 获取类
200              cls = getattr(module, class_name, None)
201              if not cls:
202                  return None
203              
204              # 检查继承关系
205              if issubclass(cls, BaseAttack):
206                  return 'attack'
207              elif issubclass(cls, BaseRedTeamingMetric):
208                  return 'metric'
209              elif issubclass(cls, BaseVulnerability):
210                  return 'vulnerability'
211              
212          except Exception as e:
213              logger.warning(f"动态验证失败: {e}")
214          
215          return None
216      
217      def validate_plugin_directory(self, dir_path: Path) -> Dict[str, Any]:
218          """验证插件目录"""
219          result = {
220              'valid': False,
221              'plugins': [],
222              'errors': [],
223              'warnings': []
224          }
225          
226          if not dir_path.exists() or not dir_path.is_dir():
227              result['errors'].append(f"目录不存在或不是有效目录: {dir_path}")
228              return result
229          
230          # 查找所有Python文件
231          py_files = list(dir_path.glob("*.py"))
232          
233          if not py_files:
234              result['errors'].append(f"目录中没有找到Python文件: {dir_path}")
235              return result
236          
237          # 验证每个文件
238          for py_file in py_files:
239              file_result = self.validate_plugin_file(py_file)
240              if file_result['valid']:
241                  result['plugins'].append(file_result)
242              else:
243                  result['errors'].extend([f"{py_file.name}: {error}" for error in file_result['errors']])
244                  result['warnings'].extend([f"{py_file.name}: {warning}" for warning in file_result['warnings']])
245          
246          result['valid'] = len(result['plugins']) > 0
247          
248          return result