plugin_registry.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

from typing import Dict, List, Optional, Any, Type
from pathlib import Path
import importlib.util
import sys
from loguru import logger

from deepteam.attacks import BaseAttack
from deepteam.metrics import BaseRedTeamingMetric
from deepteam.vulnerabilities import BaseVulnerability


class PluginRegistry:
    """Registry of plugin classes loaded dynamically from Python files.

    Plugins are grouped by type ('attack', 'metric', 'vulnerability') and
    looked up by class name within each group. Per-plugin metadata is stored
    under the key ``"<plugin_type>_<class_name>"``.
    """

    def __init__(self):
        # Per-type tables: class name -> plugin class.
        self.attack_plugins: Dict[str, Type[BaseAttack]] = {}
        self.metric_plugins: Dict[str, Type[BaseRedTeamingMetric]] = {}
        self.vulnerability_plugins: Dict[str, Type[BaseVulnerability]] = {}
        # "<plugin_type>_<class_name>" -> metadata recorded at registration.
        self.plugin_metadata: Dict[str, Dict[str, Any]] = {}
        # Generated module name -> module object the plugin was loaded from.
        self.loaded_modules: Dict[str, Any] = {}

    def _type_table(self) -> Dict[str, Any]:
        """Map each supported plugin type to (registry dict, base class, log label).

        Built on every call so it always reflects the current registry
        attributes, even if a caller has rebound one of them.
        """
        return {
            'attack': (self.attack_plugins, BaseAttack, "攻击"),
            'metric': (self.metric_plugins, BaseRedTeamingMetric, "指标"),
            'vulnerability': (self.vulnerability_plugins, BaseVulnerability, "漏洞"),
        }

    def register_plugin(self, plugin_path: Path, plugin_info: Dict[str, Any]) -> bool:
        """Load a plugin class from a file and register it.

        Args:
            plugin_path: Path of the Python file that defines the plugin class.
            plugin_info: Plugin descriptor. Must contain 'plugin_type'
                ('attack', 'metric' or 'vulnerability') and 'class_name'.
                The whole dict is stored in the metadata under 'info'.

        Returns:
            True if the class was loaded, validated and registered; False
            otherwise. Failures are logged and never raised to the caller.
        """
        try:
            plugin_type = plugin_info['plugin_type']
            class_name = plugin_info['class_name']

            # Reject unknown types before running any plugin code; otherwise
            # the plugin would be recorded in the metadata and reported as
            # registered without being retrievable from any per-type table.
            type_entry = self._type_table().get(plugin_type)
            if type_entry is None:
                logger.error(f"未知的插件类型 {plugin_type}: {plugin_path}")
                return False
            registry, base_class, label = type_entry

            # Load the module dynamically from its file path.
            module_name = f"custom_plugin_{plugin_path.stem}"
            spec = importlib.util.spec_from_file_location(module_name, plugin_path)

            if spec is None:
                logger.error(f"无法创建模块规范: {plugin_path}")
                return False

            if spec.loader is None:
                logger.error(f"模块加载器为空: {plugin_path}")
                return False

            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            # Look up the requested class in the freshly executed module.
            plugin_class = getattr(module, class_name, None)
            if plugin_class is None:
                logger.error(f"在模块中找不到类 {class_name}: {plugin_path}")
                return False

            # issubclass() raises TypeError for non-class objects, so make
            # sure we really have a class before checking its base.
            if not (isinstance(plugin_class, type) and issubclass(plugin_class, base_class)):
                logger.error(f"类 {class_name} 不是有效的{label}插件")
                return False

            plugin_key = f"{plugin_type}_{class_name}"
            registry[class_name] = plugin_class

            self.plugin_metadata[plugin_key] = {
                'path': str(plugin_path),
                'type': plugin_type,
                'class_name': class_name,
                'module_name': module_name,
                'info': plugin_info
            }

            # Keep a reference to the module the class was loaded from.
            self.loaded_modules[module_name] = module

            logger.info(f"成功注册插件: {plugin_key}")
            return True

        except Exception as e:
            # Executing third-party plugin code can raise anything; a broken
            # plugin must not take the registry (or its caller) down.
            logger.error(f"注册插件失败 {plugin_path}: {e}")
            return False

    def get_attack_plugin(self, class_name: str) -> Optional[Type[BaseAttack]]:
        """Return the registered attack plugin class, or None if unknown."""
        return self.attack_plugins.get(class_name)

    def get_metric_plugin(self, class_name: str) -> Optional[Type[BaseRedTeamingMetric]]:
        """Return the registered metric plugin class, or None if unknown."""
        return self.metric_plugins.get(class_name)

    def get_vulnerability_plugin(self, class_name: str) -> Optional[Type[BaseVulnerability]]:
        """Return the registered vulnerability plugin class, or None if unknown."""
        return self.vulnerability_plugins.get(class_name)

    def _instantiate(self, plugin_class: Any, class_name: str, label: str,
                     kwargs: Dict[str, Any]) -> Any:
        """Instantiate ``plugin_class`` with ``kwargs``.

        Returns None when ``plugin_class`` is None or its constructor raises;
        constructor failures are logged using ``label`` in the message.
        ``kwargs`` is passed as a dict so that caller-supplied keyword names
        can never collide with this helper's own parameters.
        """
        if plugin_class is None:
            return None
        try:
            return plugin_class(**kwargs)
        except Exception as e:
            logger.error(f"创建{label}插件实例失败 {class_name}: {e}")
            return None

    def create_attack_instance(self, class_name: str, **kwargs) -> Optional[BaseAttack]:
        """Create an attack plugin instance.

        Args:
            class_name: Name the attack plugin class was registered under.
            **kwargs: Forwarded to the plugin class constructor.

        Returns:
            The new instance, or None if the class is not registered or its
            constructor raised (the error is logged).
        """
        return self._instantiate(self.get_attack_plugin(class_name), class_name, "攻击", kwargs)

    def create_metric_instance(self, class_name: str, **kwargs) -> Optional[BaseRedTeamingMetric]:
        """Create a metric plugin instance.

        Args:
            class_name: Name the metric plugin class was registered under.
            **kwargs: Forwarded to the plugin class constructor.

        Returns:
            The new instance, or None if the class is not registered or its
            constructor raised (the error is logged).
        """
        return self._instantiate(self.get_metric_plugin(class_name), class_name, "指标", kwargs)

    def create_vulnerability_instance(self, class_name: str, **kwargs) -> Optional[BaseVulnerability]:
        """Create a vulnerability plugin instance.

        Args:
            class_name: Name the vulnerability plugin class was registered under.
            **kwargs: Forwarded to the plugin class constructor.

        Returns:
            The new instance, or None if the class is not registered or its
            constructor raised (the error is logged).
        """
        return self._instantiate(
            self.get_vulnerability_plugin(class_name), class_name, "漏洞", kwargs
        )

    def list_plugins(self) -> Dict[str, List[str]]:
        """Return the registered class names grouped by plugin type."""
        return {
            'attacks': list(self.attack_plugins),
            'metrics': list(self.metric_plugins),
            'vulnerabilities': list(self.vulnerability_plugins)
        }

    def get_plugin_info(self, plugin_key: str) -> Optional[Dict[str, Any]]:
        """Return the metadata stored for ``plugin_key``, or None if unknown.

        ``plugin_key`` has the form ``"<plugin_type>_<class_name>"``.
        """
        return self.plugin_metadata.get(plugin_key)

    def unregister_plugin(self, plugin_key: str) -> bool:
        """Remove a plugin from the registry.

        Args:
            plugin_key: Metadata key of the form "<plugin_type>_<class_name>".

        Returns:
            True if the plugin was registered and has been removed; False if
            the key is unknown or removal failed (failures are logged).
        """
        try:
            if plugin_key not in self.plugin_metadata:
                return False

            info = self.plugin_metadata[plugin_key]
            plugin_type = info['type']
            class_name = info['class_name']
            module_name = info['module_name']

            # Remove the class from its per-type table.
            type_entry = self._type_table().get(plugin_type)
            if type_entry is not None:
                type_entry[0].pop(class_name, None)

            # Several plugins can record the same module name (classes from
            # one file, or files sharing a stem); only drop the module
            # reference once no other registered plugin still uses it.
            still_in_use = any(
                key != plugin_key and meta.get('module_name') == module_name
                for key, meta in self.plugin_metadata.items()
            )
            if not still_in_use:
                self.loaded_modules.pop(module_name, None)

            # Drop the metadata last so a failure above leaves the entry
            # visible for a retry.
            del self.plugin_metadata[plugin_key]

            logger.info(f"成功注销插件: {plugin_key}")
            return True

        except Exception as e:
            logger.error(f"注销插件失败 {plugin_key}: {e}")

        return False

    def clear_all_plugins(self):
        """Remove every registered plugin, its metadata and module reference."""
        self.attack_plugins.clear()
        self.metric_plugins.clear()
        self.vulnerability_plugins.clear()
        self.plugin_metadata.clear()
        self.loaded_modules.clear()
        logger.info("已清除所有插件")

    def get_plugin_count(self) -> Dict[str, int]:
        """Return the number of registered plugins per type and in total.

        'total' is the number of metadata entries.
        """
        return {
            'attacks': len(self.attack_plugins),
            'metrics': len(self.metric_plugins),
            'vulnerabilities': len(self.vulnerability_plugins),
            'total': len(self.plugin_metadata)
        }