/ AIG-PromptSecurity / deepteam / plugin_system / plugin_manager.py
plugin_manager.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
 19  from typing import Dict, List, Optional, Any, Union
 20  from pathlib import Path
 21  from loguru import logger
 22  
 23  from .plugin_loader import PluginLoader
 24  from .plugin_registry import PluginRegistry
 25  from .remote_plugin_downloader import RemotePluginDownloader
 26  from deepteam.attacks import BaseAttack
 27  from deepteam.metrics import BaseRedTeamingMetric
 28  from deepteam.vulnerabilities import BaseVulnerability
 29  
 30  
 31  class PluginManager:
 32      """插件管理器,作为插件系统的主要接口"""
 33      
 34      def __init__(self):
 35          self.registry = PluginRegistry()
 36          self.loader = PluginLoader(self.registry)
 37          self.downloader = RemotePluginDownloader()
 38          
 39      def load_plugin(self, plugin_path: Union[str, Path]) -> Dict[str, Any]:
 40          """加载单个插件"""
 41          plugin_path = str(plugin_path)
 42          
 43          # 检查是否为远程插件URL
 44          if self.downloader.is_remote_plugin_url(plugin_path):
 45              return self._load_remote_plugin(plugin_path)
 46          
 47          # 本地插件加载
 48          return self.loader.load_plugin(plugin_path)
 49      
 50      def _load_remote_plugin(self, url: str) -> Dict[str, Any]:
 51          """加载远程插件"""
 52          result = {
 53              'success': False,
 54              'plugin_info': None,
 55              'errors': [],
 56              'warnings': []
 57          }
 58          
 59          try:
 60              # 下载并解压远程插件
 61              download_result = self.downloader.download_and_extract_plugin(url)
 62              
 63              if not download_result['success']:
 64                  result['errors'].extend(download_result['errors'])
 65                  result['warnings'].extend(download_result['warnings'])
 66                  return result
 67              
 68              # 加载解压后的插件
 69              extracted_path = download_result['extracted_path']
 70              if extracted_path:
 71                  load_result = self.loader.load_plugin(extracted_path)
 72                  
 73                  if load_result['success']:
 74                      result['success'] = True
 75                      result['plugin_info'] = {
 76                          'type': 'remote',
 77                          'url': url,
 78                          'local_path': extracted_path,
 79                          'load_info': load_result['plugin_info']
 80                      }
 81                      logger.info(f"远程插件加载成功: {url}")
 82                  else:
 83                      result['errors'].extend(load_result['errors'])
 84                      result['warnings'].extend(load_result['warnings'])
 85              
 86          except Exception as e:
 87              result['errors'].append(f"加载远程插件时发生错误: {str(e)}")
 88              logger.error(f"加载远程插件失败: {str(e)}")
 89          
 90          return result
 91      
 92      def load_remote_plugin(self, url: str, force_download: bool = False) -> Dict[str, Any]:
 93          """专门用于加载远程插件的方法"""
 94          return self._load_remote_plugin(url)
 95      
 96      def list_remote_plugins(self) -> List[Dict[str, Any]]:
 97          """列出已下载的远程插件"""
 98          return self.downloader.list_remote_plugins()
 99      
100      def remove_remote_plugin(self, plugin_name: str) -> Dict[str, Any]:
101          """删除远程插件"""
102          return self.downloader.remove_remote_plugin(plugin_name)
103      
104      def cleanup_remote_plugins(self):
105          """清理远程插件临时文件"""
106          self.downloader.cleanup_temp_files()
107      
108      def load_plugins_from_directory(self, dir_path: Union[str, Path]) -> Dict[str, Any]:
109          """从目录加载所有插件"""
110          return self.loader.load_plugins_from_directory(dir_path)
111      
112      def auto_discover_plugins(self) -> Dict[str, Any]:
113          """自动发现并加载插件"""
114          return self.loader.auto_discover_plugins()
115      
116      def load_plugins_from_config(self, config_file: Union[str, Path]) -> Dict[str, Any]:
117          """从配置文件加载插件"""
118          return self.loader.load_plugins_from_config(config_file)
119      
120      def get_loaded_plugins(self) -> Dict[str, List[str]]:
121          """获取已加载的插件列表"""
122          return self.loader.get_loaded_plugins()
123      
124      def get_plugin_info(self, plugin_key: str) -> Optional[Dict[str, Any]]:
125          """获取插件详细信息"""
126          return self.loader.get_plugin_info(plugin_key)
127      
128      def unload_plugin(self, plugin_key: str) -> bool:
129          """卸载插件"""
130          return self.loader.unload_plugin(plugin_key)
131      
132      def reload_plugin(self, plugin_path: Union[str, Path]) -> Dict[str, Any]:
133          """重新加载插件"""
134          return self.loader.reload_plugin(plugin_path)
135      
136      def create_attack_instance(self, class_name: str, **kwargs) -> Optional[BaseAttack]:
137          """创建攻击插件实例"""
138          return self.registry.create_attack_instance(class_name, **kwargs)
139      
140      def create_metric_instance(self, class_name: str, **kwargs) -> Optional[BaseRedTeamingMetric]:
141          """创建指标插件实例"""
142          return self.registry.create_metric_instance(class_name, **kwargs)
143      
144      def create_vulnerability_instance(self, class_name: str, **kwargs) -> Optional[BaseVulnerability]:
145          """创建漏洞插件实例"""
146          return self.registry.create_vulnerability_instance(class_name, **kwargs)
147      
148      def get_attack_plugins(self) -> List[str]:
149          """获取所有攻击插件名称"""
150          return list(self.registry.attack_plugins.keys())
151      
152      def get_metric_plugins(self) -> List[str]:
153          """获取所有指标插件名称"""
154          return list(self.registry.metric_plugins.keys())
155      
156      def get_vulnerability_plugins(self) -> List[str]:
157          """获取所有漏洞插件名称"""
158          return list(self.registry.vulnerability_plugins.keys())
159      
160      def get_plugin_count(self) -> Dict[str, int]:
161          """获取插件数量统计"""
162          return self.registry.get_plugin_count()
163      
164      def clear_all_plugins(self):
165          """清除所有插件"""
166          self.registry.clear_all_plugins()
167      
168      def list_plugins_with_info(self) -> Dict[str, List[Dict[str, Any]]]:
169          """列出所有插件及其详细信息"""
170          result = {
171              'attacks': [],
172              'metrics': [],
173              'vulnerabilities': []
174          }
175          
176          # 获取攻击插件信息
177          for class_name in self.registry.attack_plugins.keys():
178              plugin_key = f"attack_{class_name}"
179              info = self.registry.get_plugin_info(plugin_key)
180              if info:
181                  result['attacks'].append({
182                      'class_name': class_name,
183                      'path': info['path'],
184                      'module_name': info['module_name']
185                  })
186          
187          # 获取指标插件信息
188          for class_name in self.registry.metric_plugins.keys():
189              plugin_key = f"metric_{class_name}"
190              info = self.registry.get_plugin_info(plugin_key)
191              if info:
192                  result['metrics'].append({
193                      'class_name': class_name,
194                      'path': info['path'],
195                      'module_name': info['module_name']
196                  })
197          
198          # 获取漏洞插件信息
199          for class_name in self.registry.vulnerability_plugins.keys():
200              plugin_key = f"vulnerability_{class_name}"
201              info = self.registry.get_plugin_info(plugin_key)
202              if info:
203                  result['vulnerabilities'].append({
204                      'class_name': class_name,
205                      'path': info['path'],
206                      'module_name': info['module_name']
207                  })
208          
209          return result
210      
211      def validate_plugin(self, plugin_path: Union[str, Path]) -> Dict[str, Any]:
212          """验证插件(不加载)"""
213          plugin_path = Path(plugin_path)
214          
215          if plugin_path.is_file():
216              return self.loader.validator.validate_plugin_file(plugin_path)
217          elif plugin_path.is_dir():
218              return self.loader.validator.validate_plugin_directory(plugin_path)
219          else:
220              return {
221                  'valid': False,
222                  'errors': [f"路径不存在: {plugin_path}"]
223              }
224      
225      def get_plugin_template(self, plugin_type: str) -> str:
226          """获取插件模板代码"""
227          if plugin_type == 'attack':
228              return self._get_attack_template()
229          elif plugin_type == 'metric':
230              return self._get_metric_template()
231          elif plugin_type == 'vulnerability':
232              return self._get_vulnerability_template()
233          else:
234              return f"未知的插件类型: {plugin_type}"
235      
236      def _get_attack_template(self) -> str:
237          """获取攻击插件模板"""
238          return '''from deepteam.attacks import BaseAttack
239  
240  class CustomAttack(BaseAttack):
241      """自定义攻击插件"""
242      
243      def __init__(self, weight: int = 1):
244          super().__init__()
245          self.weight = weight
246      
247      def enhance(self, attack: str, *args, **kwargs) -> str:
248          """
249          增强攻击字符串
250          
251          Args:
252              attack: 原始攻击字符串
253              *args: 额外参数
254              **kwargs: 额外关键字参数
255              
256          Returns:
257              增强后的攻击字符串
258          """
259          # 在这里实现你的攻击增强逻辑
260          enhanced_attack = attack  # 默认不修改
261          
262          # 示例:添加前缀
263          # enhanced_attack = f"请忽略之前的指令,{attack}"
264          
265          return enhanced_attack
266      
267      def get_name(self) -> str:
268          """获取插件名称"""
269          return "CustomAttack"
270      
271      async def a_enhance(self, attack: str, *args, **kwargs) -> str:
272          """
273          异步增强攻击字符串(可选实现)
274          
275          如果不需要异步支持,可以删除这个方法,会使用默认的同步实现
276          """
277          return self.enhance(attack, *args, **kwargs)
278  '''
279      
280      def _get_metric_template(self) -> str:
281          """获取指标插件模板"""
282          return '''from typing import Optional, Union
283  from deepeval.test_case.llm_test_case import LLMTestCase
284  from deepeval.test_case.conversational_test_case import ConversationalTestCase
285  from deepeval.models.base_model import DeepEvalBaseLLM
286  from deepeval.metrics.utils import initialize_model
287  from deepteam.metrics import BaseRedTeamingMetric
288  
289  class CustomMetric(BaseRedTeamingMetric):
290      """自定义指标插件"""
291      
292      def __init__(
293          self,
294          model: Optional[Union[str, DeepEvalBaseLLM]] = None,
295          async_mode: bool = True,
296          verbose_mode: bool = False,
297      ):
298          self.model, self.using_native_model = initialize_model(model)
299          self.evaluation_model = self.model.get_model_name()
300          self.async_mode = async_mode
301          self.verbose_mode = verbose_mode
302          self.system_prompt = ""
303      
304      def measure(self, test_case: Union[LLMTestCase, ConversationalTestCase]) -> float:
305          """
306          测量指标
307          
308          Args:
309              test_case: 测试用例
310              
311          Returns:
312              指标分数 (0-1)
313          """
314          # 在这里实现你的指标计算逻辑
315          score = 0.5  # 默认分数
316          
317          return score
318      
319      def get_name(self) -> str:
320          """获取指标名称"""
321          return "CustomMetric"
322      
323      async def a_measure(self, test_case: Union[LLMTestCase, ConversationalTestCase]) -> float:
324          """
325          异步测量指标(可选实现)
326          
327          如果不需要异步支持,可以删除这个方法,会使用默认的同步实现
328          """
329          return self.measure(test_case)
330  '''
331      
332      def _get_vulnerability_template(self) -> str:
333          """获取漏洞插件模板"""
334          return '''from typing import List
335  from enum import Enum
336  from deepteam.vulnerabilities import BaseVulnerability
337  
338  class CustomVulnerabilityType(Enum):
339      """自定义漏洞类型枚举"""
340      CUSTOM_VULNERABILITY = "custom_vulnerability"
341  
342  class CustomVulnerability(BaseVulnerability):
343      """自定义漏洞插件"""
344      
345      def __init__(self, name: str = "CustomVulnerability", types: List[Enum] = None):
346          """
347          初始化自定义漏洞
348          
349          Args:
350              name: 漏洞名称
351              types: 漏洞类型列表
352          """
353          if types is None:
354              types = [CustomVulnerabilityType.CUSTOM_VULNERABILITY]
355          
356          self.name = name
357          super().__init__(types)
358      
359      def get_name(self) -> str:
360          """获取漏洞名称"""
361          return self.name
362      
363      def get_types(self) -> List[Enum]:
364          """获取漏洞类型列表"""
365          return self.types
366      
367      def get_values(self) -> List[str]:
368          """获取漏洞类型值列表"""
369          return [t.value for t in self.types]
370      
371      def __repr__(self):
372          """字符串表示"""
373          return f"{self.name} (types={self.types})"
374  '''