plugin_validator.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

import inspect
import ast
from typing import Dict, List, Tuple, Optional, Any
from pathlib import Path
import importlib.util
from loguru import logger

from deepteam.attacks import BaseAttack
from deepteam.metrics import BaseRedTeamingMetric
from deepteam.vulnerabilities import BaseVulnerability


class PluginValidator:
    """Validate that user-uploaded operator plugins meet the system's requirements.

    A plugin file is accepted when it contains at least one class that

    * is recognised as an ``attack``, ``metric`` or ``vulnerability`` plugin
      (first by the names of its base classes, then by importing the file and
      checking ``issubclass``), and
    * defines, directly in its own body, every method listed for that plugin
      type in ``required_methods``.
    """

    # Plugin kinds in the order in which base-class names are matched.
    _PLUGIN_KINDS = ('attack', 'metric', 'vulnerability')

    def __init__(self):
        # Base class that a plugin of each kind must derive from; consulted
        # by the dynamic (import-based) check.
        self.required_base_classes = {
            'attack': BaseAttack,
            'metric': BaseRedTeamingMetric,
            'vulnerability': BaseVulnerability
        }

        # Methods a plugin class of each kind must define in its own body.
        self.required_methods = {
            'attack': ['enhance', 'get_name'],
            'metric': ['measure', 'get_name'],
            'vulnerability': ['get_name', 'get_types']
        }

    def validate_plugin_file(self, file_path: Path) -> Dict[str, Any]:
        """Validate a single plugin file.

        Args:
            file_path: Path of the ``.py`` file to inspect. A plain string is
                accepted as well and converted to ``Path``.

        Returns:
            A dict with the keys ``valid``, ``plugin_type``, ``class_name``,
            ``errors`` and ``warnings``. On success the dict carries the
            values reported for the first class that validated. On failure
            ``errors`` lists why each candidate class was rejected (prefixed
            with the class name) followed by a summary message.

        Any exception raised while reading or parsing the file is caught and
        reported through ``errors`` instead of propagating.
        """
        result = {
            'valid': False,
            'plugin_type': None,
            'class_name': None,
            'errors': [],
            'warnings': []
        }

        try:
            file_path = Path(file_path)

            # Only Python source files can be plugins.
            if file_path.suffix != '.py':
                result['errors'].append(f"文件必须是Python文件 (.py): {file_path}")
                return result

            content = file_path.read_text(encoding='utf-8')
            tree = ast.parse(content)

            # ast.walk visits the whole tree, so nested classes are
            # candidates too.
            classes = [node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)]

            if not classes:
                result['errors'].append("文件中没有找到类定义")
                return result

            # Keep the reason each class was rejected so the caller can see
            # why the file failed, not merely that it failed.
            rejected_errors = []
            for class_node in classes:
                class_result = self._validate_class(class_node, content, file_path)
                if class_result['valid']:
                    # The first valid class wins; its result replaces ours.
                    result.update(class_result)
                    break
                rejected_errors.extend(
                    f"{class_node.name}: {error}" for error in class_result['errors']
                )
            else:
                result['errors'].extend(rejected_errors)
                result['errors'].append("没有找到符合要求的类")

        except Exception as e:
            result['errors'].append(f"验证过程中发生错误: {str(e)}")

        return result

    def _validate_class(self, class_node: ast.ClassDef, content: str, file_path: Path) -> Dict[str, Any]:
        """Check whether one class definition satisfies the plugin requirements.

        Args:
            class_node: AST node of the class to check.
            content: Source text of the file (currently unused).
            file_path: File the class comes from; needed for the dynamic
                import fallback.

        Returns:
            A result dict shaped like the one from ``validate_plugin_file``,
            with ``class_name`` set to the class's name.
        """
        result = {
            'valid': False,
            'plugin_type': None,
            'class_name': class_node.name,
            'errors': [],
            'warnings': []
        }

        try:
            # Cheap static guess from the base-class names first.
            plugin_type = self._infer_plugin_type(self._get_base_classes(class_node))

            if not plugin_type:
                # Fall back to importing the file and checking issubclass.
                plugin_type = self._dynamic_validate_class(file_path, class_node.name)

            if not plugin_type:
                result['errors'].append(f"类 {class_node.name} 没有继承正确的基类")
                return result

            result['plugin_type'] = plugin_type

            # Only synchronous methods defined directly in the class body
            # count; inherited methods are not visible in the AST node.
            # NOTE(review): `async def` methods are not counted here --
            # confirm whether an async implementation should satisfy a
            # required method.
            class_methods = {
                node.name for node in class_node.body if isinstance(node, ast.FunctionDef)
            }
            missing_methods = [
                method for method in self.required_methods[plugin_type]
                if method not in class_methods
            ]

            if missing_methods:
                result['errors'].append(f"缺少必需的方法: {', '.join(missing_methods)}")
                return result

            # Abstract methods only produce a warning; the class stays valid.
            if self._check_abstract_methods(class_node):
                result['warnings'].append("类包含抽象方法,需要实现")

            result['valid'] = True

        except Exception as e:
            result['errors'].append(f"验证类时发生错误: {str(e)}")

        return result

    @classmethod
    def _infer_plugin_type(cls, base_classes: List[str]) -> Optional[str]:
        """Guess the plugin kind from base-class names.

        The first base class whose lower-cased name contains ``attack``,
        ``metric`` or ``vulnerability`` (checked in that order) decides the
        kind. Returns ``None`` when no base-class name matches.
        """
        for base_class_name in base_classes:
            lowered = base_class_name.lower()
            for kind in cls._PLUGIN_KINDS:
                if kind in lowered:
                    return kind
        return None

    def _get_base_classes(self, class_node: ast.ClassDef) -> List[str]:
        """Return the names of the class's bases.

        Plain names (``Base``) and dotted names (``module.Base``) are
        returned; any other base expression is skipped.
        """
        base_classes = []
        for base in class_node.bases:
            if isinstance(base, ast.Name):
                base_classes.append(base.id)
            elif isinstance(base, ast.Attribute):
                # e.g. ``module.Class``
                base_classes.append(self._get_attribute_name(base))
        return base_classes

    def _get_attribute_name(self, node: ast.Attribute) -> str:
        """Return the dotted name of an attribute access such as ``a.b.c``.

        When the chain starts from something other than a name or another
        attribute, only the final attribute name is returned.
        """
        if isinstance(node.value, ast.Name):
            return f"{node.value.id}.{node.attr}"
        if isinstance(node.value, ast.Attribute):
            return f"{self._get_attribute_name(node.value)}.{node.attr}"
        return node.attr

    def _check_abstract_methods(self, class_node: ast.ClassDef) -> bool:
        """Return True if any method in the class body is decorated as abstract.

        Both ``@abstractmethod`` and ``@<anything>.abstractmethod`` are
        recognised, on regular and ``async`` methods alike.
        """
        for node in class_node.body:
            if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            for decorator in node.decorator_list:
                if isinstance(decorator, ast.Name) and decorator.id == 'abstractmethod':
                    return True
                if isinstance(decorator, ast.Attribute) and decorator.attr == 'abstractmethod':
                    return True
        return False

    def _dynamic_validate_class(self, file_path: Path, class_name: str) -> Optional[str]:
        """Import the plugin file and classify ``class_name`` via ``issubclass``.

        Args:
            file_path: File to import.
            class_name: Name of the module-level attribute to inspect.

        Returns:
            ``'attack'``, ``'metric'`` or ``'vulnerability'`` when the class
            derives from the matching entry of ``required_base_classes``;
            ``None`` when the file cannot be loaded, the name is missing or
            not a class, or no base class matches. Import failures are logged
            as warnings and never propagate.
        """
        try:
            spec = importlib.util.spec_from_file_location("temp_module", file_path)
            if spec is None or spec.loader is None:
                return None

            module = importlib.util.module_from_spec(spec)

            # SECURITY: exec_module runs the file's top-level code in this
            # process. Plugin files are user-supplied, so this executes
            # untrusted code -- consider isolating it (e.g. a sandboxed
            # subprocess) before relying on this path.
            spec.loader.exec_module(module)

            cls = getattr(module, class_name, None)
            # Guard: issubclass raises TypeError for non-class objects.
            if not inspect.isclass(cls):
                return None

            for plugin_type, base_class in self.required_base_classes.items():
                if issubclass(cls, base_class):
                    return plugin_type

        except Exception as e:
            logger.warning(f"动态验证失败: {e}")

        return None

    def validate_plugin_directory(self, dir_path: Path) -> Dict[str, Any]:
        """Validate every ``*.py`` file directly inside a directory.

        Args:
            dir_path: Directory to scan (non-recursive). A plain string is
                accepted as well and converted to ``Path``.

        Returns:
            A dict with ``valid`` (True when at least one file validated),
            ``plugins`` (the per-file results of the valid files) and the
            ``errors`` / ``warnings`` of the invalid files, each prefixed with
            the file name.
        """
        result = {
            'valid': False,
            'plugins': [],
            'errors': [],
            'warnings': []
        }

        dir_path = Path(dir_path)

        if not dir_path.exists() or not dir_path.is_dir():
            result['errors'].append(f"目录不存在或不是有效目录: {dir_path}")
            return result

        # Sorted so that the order of plugins and messages is deterministic
        # rather than dependent on filesystem enumeration order.
        py_files = sorted(dir_path.glob("*.py"))

        if not py_files:
            result['errors'].append(f"目录中没有找到Python文件: {dir_path}")
            return result

        for py_file in py_files:
            file_result = self.validate_plugin_file(py_file)
            if file_result['valid']:
                result['plugins'].append(file_result)
            else:
                result['errors'].extend([f"{py_file.name}: {error}" for error in file_result['errors']])
                result['warnings'].extend([f"{py_file.name}: {warning}" for warning in file_result['warnings']])

        result['valid'] = len(result['plugins']) > 0

        return result