# plugin_manager.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

from typing import Dict, List, Optional, Any, Union
from pathlib import Path
from loguru import logger

from .plugin_loader import PluginLoader
from .plugin_registry import PluginRegistry
from .remote_plugin_downloader import RemotePluginDownloader
from deepteam.attacks import BaseAttack
from deepteam.metrics import BaseRedTeamingMetric
from deepteam.vulnerabilities import BaseVulnerability


class PluginManager:
    """Plugin manager: the main entry point of the plugin system.

    Thin facade that wires together a ``PluginRegistry`` (stores registered
    plugin classes), a ``PluginLoader`` (loads plugin code into that
    registry) and a ``RemotePluginDownloader`` (fetches plugins from URLs).
    Most methods simply delegate to one of these three collaborators.
    """

    def __init__(self):
        self.registry = PluginRegistry()
        # The loader registers what it loads into the shared registry.
        self.loader = PluginLoader(self.registry)
        self.downloader = RemotePluginDownloader()

    def load_plugin(self, plugin_path: Union[str, Path]) -> Dict[str, Any]:
        """Load a single plugin from a local path or a remote URL.

        Args:
            plugin_path: Local file/directory path, or a URL that
                ``RemotePluginDownloader.is_remote_plugin_url`` accepts.

        Returns:
            The result dict from ``_load_remote_plugin`` for remote URLs,
            otherwise the result of ``PluginLoader.load_plugin``.
        """
        plugin_path = str(plugin_path)

        # Remote URLs are downloaded/extracted first, then loaded locally.
        if self.downloader.is_remote_plugin_url(plugin_path):
            return self._load_remote_plugin(plugin_path)

        return self.loader.load_plugin(plugin_path)

    def _load_remote_plugin(self, url: str) -> Dict[str, Any]:
        """Download, extract and load a plugin from a remote URL.

        Args:
            url: Remote plugin URL.

        Returns:
            A dict with keys:
              - ``success``: True only if download and load both succeeded.
              - ``plugin_info``: ``None`` on failure, otherwise a dict with
                ``type`` (``'remote'``), ``url``, ``local_path`` and
                ``load_info`` (the loader's ``plugin_info``).
              - ``errors`` / ``warnings``: messages collected from the
                downloader and loader.
            Exceptions are caught and reported through ``errors``; this
            method does not raise.
        """
        result = {
            'success': False,
            'plugin_info': None,
            'errors': [],
            'warnings': []
        }

        try:
            download_result = self.downloader.download_and_extract_plugin(url)
            # Keep downloader warnings regardless of outcome so callers see them.
            result['warnings'].extend(download_result.get('warnings', []))

            if not download_result['success']:
                result['errors'].extend(download_result.get('errors', []))
                return result

            extracted_path = download_result.get('extracted_path')
            if not extracted_path:
                # Without this, the caller would get success=False and no
                # explanation at all.
                result['errors'].append(f"远程插件下载成功但未返回解压路径: {url}")
                return result

            load_result = self.loader.load_plugin(extracted_path)
            result['warnings'].extend(load_result.get('warnings', []))

            if load_result['success']:
                result['success'] = True
                result['plugin_info'] = {
                    'type': 'remote',
                    'url': url,
                    'local_path': extracted_path,
                    'load_info': load_result['plugin_info']
                }
                logger.info(f"远程插件加载成功: {url}")
            else:
                result['errors'].extend(load_result.get('errors', []))

        except Exception as e:
            # Boundary handler: report instead of raising, but keep the traceback in the log.
            result['errors'].append(f"加载远程插件时发生错误: {str(e)}")
            logger.exception(f"加载远程插件失败: {str(e)}")

        return result

    def load_remote_plugin(self, url: str, force_download: bool = False) -> Dict[str, Any]:
        """Public entry point for loading a plugin from a remote URL.

        Args:
            url: Remote plugin URL.
            force_download: Currently has no effect; the call is forwarded to
                ``_load_remote_plugin(url)`` regardless of its value.

        Returns:
            The result dict produced by ``_load_remote_plugin``.
        """
        # NOTE(review): force_download is accepted for API compatibility but
        # never forwarded to the downloader; wire it through once the
        # downloader exposes a matching option.
        return self._load_remote_plugin(url)

    def list_remote_plugins(self) -> List[Dict[str, Any]]:
        """List downloaded remote plugins (delegates to the downloader)."""
        return self.downloader.list_remote_plugins()

    def remove_remote_plugin(self, plugin_name: str) -> Dict[str, Any]:
        """Remove a downloaded remote plugin by name (delegates to the downloader)."""
        return self.downloader.remove_remote_plugin(plugin_name)

    def cleanup_remote_plugins(self) -> None:
        """Clean up remote-plugin temporary files via ``cleanup_temp_files``."""
        self.downloader.cleanup_temp_files()

    def load_plugins_from_directory(self, dir_path: Union[str, Path]) -> Dict[str, Any]:
        """Load every plugin found in ``dir_path`` (delegates to the loader)."""
        return self.loader.load_plugins_from_directory(dir_path)

    def auto_discover_plugins(self) -> Dict[str, Any]:
        """Discover and load plugins automatically (delegates to the loader)."""
        return self.loader.auto_discover_plugins()

    def load_plugins_from_config(self, config_file: Union[str, Path]) -> Dict[str, Any]:
        """Load plugins listed in ``config_file`` (delegates to the loader)."""
        return self.loader.load_plugins_from_config(config_file)

    def get_loaded_plugins(self) -> Dict[str, List[str]]:
        """Return the loader's mapping of loaded plugins."""
        return self.loader.get_loaded_plugins()

    def get_plugin_info(self, plugin_key: str) -> Optional[Dict[str, Any]]:
        """Return the loader's details for ``plugin_key``, or ``None`` if unknown."""
        return self.loader.get_plugin_info(plugin_key)

    def unload_plugin(self, plugin_key: str) -> bool:
        """Unload the plugin identified by ``plugin_key``; returns the loader's status."""
        return self.loader.unload_plugin(plugin_key)

    def reload_plugin(self, plugin_path: Union[str, Path]) -> Dict[str, Any]:
        """Reload the plugin at ``plugin_path`` (delegates to the loader)."""
        return self.loader.reload_plugin(plugin_path)

    def create_attack_instance(self, class_name: str, **kwargs) -> Optional[BaseAttack]:
        """Instantiate the registered attack class ``class_name`` with ``kwargs``."""
        return self.registry.create_attack_instance(class_name, **kwargs)

    def create_metric_instance(self, class_name: str, **kwargs) -> Optional[BaseRedTeamingMetric]:
        """Instantiate the registered metric class ``class_name`` with ``kwargs``."""
        return self.registry.create_metric_instance(class_name, **kwargs)

    def create_vulnerability_instance(self, class_name: str, **kwargs) -> Optional[BaseVulnerability]:
        """Instantiate the registered vulnerability class ``class_name`` with ``kwargs``."""
        return self.registry.create_vulnerability_instance(class_name, **kwargs)

    def get_attack_plugins(self) -> List[str]:
        """Return the class names of all registered attack plugins."""
        return list(self.registry.attack_plugins.keys())

    def get_metric_plugins(self) -> List[str]:
        """Return the class names of all registered metric plugins."""
        return list(self.registry.metric_plugins.keys())

    def get_vulnerability_plugins(self) -> List[str]:
        """Return the class names of all registered vulnerability plugins."""
        return list(self.registry.vulnerability_plugins.keys())

    def get_plugin_count(self) -> Dict[str, int]:
        """Return per-category plugin counts from the registry."""
        return self.registry.get_plugin_count()

    def clear_all_plugins(self) -> None:
        """Remove all plugins from the registry.

        Only ``self.registry`` is cleared here; the loader is not called.
        """
        self.registry.clear_all_plugins()

    def list_plugins_with_info(self) -> Dict[str, List[Dict[str, Any]]]:
        """List every registered plugin together with its registry details.

        Returns:
            A dict with keys ``'attacks'``, ``'metrics'`` and
            ``'vulnerabilities'``. Each value is a list of dicts with
            ``class_name``, ``path`` and ``module_name``. Classes for which
            the registry has no info are omitted.
        """
        # (result section, registry key prefix, registered classes)
        categories = (
            ('attacks', 'attack', self.registry.attack_plugins),
            ('metrics', 'metric', self.registry.metric_plugins),
            ('vulnerabilities', 'vulnerability', self.registry.vulnerability_plugins),
        )
        return {
            section: self._collect_plugin_entries(prefix, plugins)
            for section, prefix, plugins in categories
        }

    def _collect_plugin_entries(self, key_prefix: str, plugins: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Build summary entries for one plugin category.

        Args:
            key_prefix: Registry info-key prefix (info is looked up under
                ``f"{key_prefix}_{class_name}"``).
            plugins: Mapping whose keys are the registered class names.

        Returns:
            One ``{'class_name', 'path', 'module_name'}`` dict per class
            that has registry info.
        """
        entries = []
        for class_name in plugins:
            info = self.registry.get_plugin_info(f"{key_prefix}_{class_name}")
            if info:
                entries.append({
                    'class_name': class_name,
                    'path': info['path'],
                    'module_name': info['module_name']
                })
        return entries

    def validate_plugin(self, plugin_path: Union[str, Path]) -> Dict[str, Any]:
        """Validate a plugin file or directory without loading it.

        Args:
            plugin_path: Path to a plugin file or plugin directory.

        Returns:
            The validator's result for files/directories. For anything else,
            ``{'valid': False, 'errors': [...]}`` with a "path not found"
            message.
        """
        plugin_path = Path(plugin_path)

        if plugin_path.is_file():
            return self.loader.validator.validate_plugin_file(plugin_path)
        if plugin_path.is_dir():
            return self.loader.validator.validate_plugin_directory(plugin_path)
        return {
            'valid': False,
            'errors': [f"路径不存在: {plugin_path}"]
        }

    def get_plugin_template(self, plugin_type: str) -> str:
        """Return starter source code for a new plugin.

        Args:
            plugin_type: One of ``'attack'``, ``'metric'`` or ``'vulnerability'``.

        Returns:
            The template source. For an unknown type this returns an
            "unknown plugin type" message string instead of raising.
        """
        template_builders = {
            'attack': self._get_attack_template,
            'metric': self._get_metric_template,
            'vulnerability': self._get_vulnerability_template,
        }
        builder = template_builders.get(plugin_type)
        if builder is None:
            return f"未知的插件类型: {plugin_type}"
        return builder()

    def _get_attack_template(self) -> str:
        """Return the source template for a custom ``BaseAttack`` subclass."""
        return '''from deepteam.attacks import BaseAttack

class CustomAttack(BaseAttack):
    """自定义攻击插件"""

    def __init__(self, weight: int = 1):
        super().__init__()
        self.weight = weight

    def enhance(self, attack: str, *args, **kwargs) -> str:
        """
        增强攻击字符串

        Args:
            attack: 原始攻击字符串
            *args: 额外参数
            **kwargs: 额外关键字参数

        Returns:
            增强后的攻击字符串
        """
        # 在这里实现你的攻击增强逻辑
        enhanced_attack = attack  # 默认不修改

        # 示例:添加前缀
        # enhanced_attack = f"请忽略之前的指令,{attack}"

        return enhanced_attack

    def get_name(self) -> str:
        """获取插件名称"""
        return "CustomAttack"

    async def a_enhance(self, attack: str, *args, **kwargs) -> str:
        """
        异步增强攻击字符串(可选实现)

        如果不需要异步支持,可以删除这个方法,会使用默认的同步实现
        """
        return self.enhance(attack, *args, **kwargs)
'''

    def _get_metric_template(self) -> str:
        """Return the source template for a custom ``BaseRedTeamingMetric`` subclass."""
        return '''from typing import Optional, Union
from deepeval.test_case.llm_test_case import LLMTestCase
from deepeval.test_case.conversational_test_case import ConversationalTestCase
from deepeval.models.base_model import DeepEvalBaseLLM
from deepeval.metrics.utils import initialize_model
from deepteam.metrics import BaseRedTeamingMetric

class CustomMetric(BaseRedTeamingMetric):
    """自定义指标插件"""

    def __init__(
        self,
        model: Optional[Union[str, DeepEvalBaseLLM]] = None,
        async_mode: bool = True,
        verbose_mode: bool = False,
    ):
        self.model, self.using_native_model = initialize_model(model)
        self.evaluation_model = self.model.get_model_name()
        self.async_mode = async_mode
        self.verbose_mode = verbose_mode
        self.system_prompt = ""

    def measure(self, test_case: Union[LLMTestCase, ConversationalTestCase]) -> float:
        """
        测量指标

        Args:
            test_case: 测试用例

        Returns:
            指标分数 (0-1)
        """
        # 在这里实现你的指标计算逻辑
        score = 0.5  # 默认分数

        return score

    def get_name(self) -> str:
        """获取指标名称"""
        return "CustomMetric"

    async def a_measure(self, test_case: Union[LLMTestCase, ConversationalTestCase]) -> float:
        """
        异步测量指标(可选实现)

        如果不需要异步支持,可以删除这个方法,会使用默认的同步实现
        """
        return self.measure(test_case)
'''

    def _get_vulnerability_template(self) -> str:
        """Return the source template for a custom ``BaseVulnerability`` subclass."""
        return '''from typing import List
from enum import Enum
from deepteam.vulnerabilities import BaseVulnerability

class CustomVulnerabilityType(Enum):
    """自定义漏洞类型枚举"""
    CUSTOM_VULNERABILITY = "custom_vulnerability"

class CustomVulnerability(BaseVulnerability):
    """自定义漏洞插件"""

    def __init__(self, name: str = "CustomVulnerability", types: List[Enum] = None):
        """
        初始化自定义漏洞

        Args:
            name: 漏洞名称
            types: 漏洞类型列表
        """
        if types is None:
            types = [CustomVulnerabilityType.CUSTOM_VULNERABILITY]

        self.name = name
        super().__init__(types)

    def get_name(self) -> str:
        """获取漏洞名称"""
        return self.name

    def get_types(self) -> List[Enum]:
        """获取漏洞类型列表"""
        return self.types

    def get_values(self) -> List[str]:
        """获取漏洞类型值列表"""
        return [t.value for t in self.types]

    def __repr__(self):
        """字符串表示"""
        return f"{self.name} (types={self.types})"
'''