/ AIG-PromptSecurity / deepteam / plugin_system / plugin_registry.py
plugin_registry.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
 19  from typing import Dict, List, Optional, Any, Type
 20  from pathlib import Path
 21  import importlib.util
 22  import sys
 23  from loguru import logger
 24  
 25  from deepteam.attacks import BaseAttack
 26  from deepteam.metrics import BaseRedTeamingMetric
 27  from deepteam.vulnerabilities import BaseVulnerability
 28  
 29  
 30  class PluginRegistry:
 31      """插件注册表,管理已加载的插件"""
 32      
 33      def __init__(self):
 34          self.attack_plugins: Dict[str, Type[BaseAttack]] = {}
 35          self.metric_plugins: Dict[str, Type[BaseRedTeamingMetric]] = {}
 36          self.vulnerability_plugins: Dict[str, Type[BaseVulnerability]] = {}
 37          self.plugin_metadata: Dict[str, Dict[str, Any]] = {}
 38          self.loaded_modules: Dict[str, Any] = {}
 39      
 40      def register_plugin(self, plugin_path: Path, plugin_info: Dict[str, Any]) -> bool:
 41          """注册插件"""
 42          try:
 43              plugin_type = plugin_info['plugin_type']
 44              class_name = plugin_info['class_name']
 45              
 46              # 动态加载模块
 47              module_name = f"custom_plugin_{plugin_path.stem}"
 48              spec = importlib.util.spec_from_file_location(module_name, plugin_path)
 49              
 50              if spec is None:
 51                  logger.error(f"无法创建模块规范: {plugin_path}")
 52                  return False
 53                  
 54              module = importlib.util.module_from_spec(spec)
 55              if spec.loader is None:
 56                  logger.error(f"模块加载器为空: {plugin_path}")
 57                  return False
 58                  
 59              spec.loader.exec_module(module)
 60              
 61              # 获取类
 62              plugin_class = getattr(module, class_name, None)
 63              if not plugin_class:
 64                  logger.error(f"在模块中找不到类 {class_name}: {plugin_path}")
 65                  return False
 66              
 67              # 验证类类型
 68              if plugin_type == 'attack' and not issubclass(plugin_class, BaseAttack):
 69                  logger.error(f"类 {class_name} 不是有效的攻击插件")
 70                  return False
 71              elif plugin_type == 'metric' and not issubclass(plugin_class, BaseRedTeamingMetric):
 72                  logger.error(f"类 {class_name} 不是有效的指标插件")
 73                  return False
 74              elif plugin_type == 'vulnerability' and not issubclass(plugin_class, BaseVulnerability):
 75                  logger.error(f"类 {class_name} 不是有效的漏洞插件")
 76                  return False
 77              
 78              # 注册插件
 79              plugin_key = f"{plugin_type}_{class_name}"
 80              
 81              if plugin_type == 'attack':
 82                  self.attack_plugins[class_name] = plugin_class  # type: ignore
 83              elif plugin_type == 'metric':
 84                  self.metric_plugins[class_name] = plugin_class  # type: ignore
 85              elif plugin_type == 'vulnerability':
 86                  self.vulnerability_plugins[class_name] = plugin_class  # type: ignore
 87              
 88              # 保存元数据
 89              self.plugin_metadata[plugin_key] = {
 90                  'path': str(plugin_path),
 91                  'type': plugin_type,
 92                  'class_name': class_name,
 93                  'module_name': module_name,
 94                  'info': plugin_info
 95              }
 96              
 97              # 保存模块引用
 98              self.loaded_modules[module_name] = module
 99              
100              logger.info(f"成功注册插件: {plugin_key}")
101              return True
102              
103          except Exception as e:
104              logger.error(f"注册插件失败 {plugin_path}: {e}")
105              return False
106      
107      def get_attack_plugin(self, class_name: str) -> Optional[Type[BaseAttack]]:
108          """获取攻击插件类"""
109          return self.attack_plugins.get(class_name)
110      
111      def get_metric_plugin(self, class_name: str) -> Optional[Type[BaseRedTeamingMetric]]:
112          """获取指标插件类"""
113          return self.metric_plugins.get(class_name)
114      
115      def get_vulnerability_plugin(self, class_name: str) -> Optional[Type[BaseVulnerability]]:
116          """获取漏洞插件类"""
117          return self.vulnerability_plugins.get(class_name)
118      
119      def create_attack_instance(self, class_name: str, **kwargs) -> Optional[BaseAttack]:
120          """创建攻击插件实例"""
121          plugin_class = self.get_attack_plugin(class_name)
122          if plugin_class:
123              try:
124                  return plugin_class(**kwargs)
125              except Exception as e:
126                  logger.error(f"创建攻击插件实例失败 {class_name}: {e}")
127          return None
128      
129      def create_metric_instance(self, class_name: str, **kwargs) -> Optional[BaseRedTeamingMetric]:
130          """创建指标插件实例"""
131          plugin_class = self.get_metric_plugin(class_name)
132          if plugin_class:
133              try:
134                  return plugin_class(**kwargs)
135              except Exception as e:
136                  logger.error(f"创建指标插件实例失败 {class_name}: {e}")
137          return None
138      
139      def create_vulnerability_instance(self, class_name: str, **kwargs) -> Optional[BaseVulnerability]:
140          """创建漏洞插件实例"""
141          plugin_class = self.get_vulnerability_plugin(class_name)
142          if plugin_class:
143              try:
144                  return plugin_class(**kwargs)
145              except Exception as e:
146                  logger.error(f"创建漏洞插件实例失败 {class_name}: {e}")
147          return None
148      
149      def list_plugins(self) -> Dict[str, List[str]]:
150          """列出所有已注册的插件"""
151          return {
152              'attacks': list(self.attack_plugins.keys()),
153              'metrics': list(self.metric_plugins.keys()),
154              'vulnerabilities': list(self.vulnerability_plugins.keys())
155          }
156      
157      def get_plugin_info(self, plugin_key: str) -> Optional[Dict[str, Any]]:
158          """获取插件信息"""
159          return self.plugin_metadata.get(plugin_key)
160      
161      def unregister_plugin(self, plugin_key: str) -> bool:
162          """注销插件"""
163          try:
164              if plugin_key in self.plugin_metadata:
165                  info = self.plugin_metadata[plugin_key]
166                  plugin_type = info['type']
167                  class_name = info['class_name']
168                  
169                  # 从注册表中移除
170                  if plugin_type == 'attack' and class_name in self.attack_plugins:
171                      del self.attack_plugins[class_name]
172                  elif plugin_type == 'metric' and class_name in self.metric_plugins:
173                      del self.metric_plugins[class_name]
174                  elif plugin_type == 'vulnerability' and class_name in self.vulnerability_plugins:
175                      del self.vulnerability_plugins[class_name]
176                  
177                  # 移除模块引用
178                  module_name = info['module_name']
179                  if module_name in self.loaded_modules:
180                      del self.loaded_modules[module_name]
181                  
182                  # 移除元数据
183                  del self.plugin_metadata[plugin_key]
184                  
185                  logger.info(f"成功注销插件: {plugin_key}")
186                  return True
187                  
188          except Exception as e:
189              logger.error(f"注销插件失败 {plugin_key}: {e}")
190          
191          return False
192      
193      def clear_all_plugins(self):
194          """清除所有插件"""
195          self.attack_plugins.clear()
196          self.metric_plugins.clear()
197          self.vulnerability_plugins.clear()
198          self.plugin_metadata.clear()
199          self.loaded_modules.clear()
200          logger.info("已清除所有插件")
201      
202      def get_plugin_count(self) -> Dict[str, int]:
203          """获取插件数量统计"""
204          return {
205              'attacks': len(self.attack_plugins),
206              'metrics': len(self.metric_plugins),
207              'vulnerabilities': len(self.vulnerability_plugins),
208              'total': len(self.plugin_metadata)
209          }