/ AIG-PromptSecurity / cli / parsers.py
parsers.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
 19  import os
 20  import re
 21  import importlib
 22  import ast
 23  from typing import List, Any, Tuple
 24  from deepteam.plugin_system import PluginManager
 25  from cli.aig_logger import logger
 26  from cli.aig_logger import (
 27      newPlanStep, statusUpdate, toolUsed, actionLog, resultUpdate
 28  )
 29  from .mappings import TECHNIQUE_CLASS_MAP, SCENARIO_CLASS_MAP, METRIC_CLASS_MAP
 30  
 31  
 32  def dynamic_import(class_path: str) -> Any:
 33      """动态导入类"""
 34      module_path, class_name = class_path.rsplit(".", 1)
 35      module = importlib.import_module(module_path)
 36      return getattr(module, class_name)
 37  
 38  
 39  def parse_kwargs(param_str: str) -> dict:
 40      """解析参数字符串为字典"""
 41      kwargs = {}
 42      
 43      # 特殊处理 prompt 参数,直接传递整个字符串
 44      if param_str.startswith("prompt="):
 45          prompt_value = param_str[7:]  # 去掉 "prompt="
 46          kwargs["prompt"] = prompt_value
 47          return kwargs
 48      
 49      # 处理其他参数
 50      params = []
 51      buf = ''
 52      bracket_level = 0
 53      for c in param_str:
 54          if c == '[':
 55              bracket_level += 1
 56          elif c == ']':
 57              bracket_level -= 1
 58          if c == ',' and bracket_level == 0:
 59              params.append(buf)
 60              buf = ''
 61          else:
 62              buf += c
 63      if buf:
 64          params.append(buf)
 65      
 66      for kv in params:
 67          if "=" in kv:
 68              k, v = kv.split("=", 1)
 69              v = v.strip()
 70              try:
 71                  v_eval = ast.literal_eval(v)
 72                  kwargs[k.strip()] = v_eval
 73              except Exception:
 74                  kwargs[k.strip()] = v
 75      return kwargs
 76  
 77  
 78  def parse_metric_class(arg: str) -> Tuple[str | None, str | None]:
 79      """解析指标类名"""
 80      if not arg:
 81          return None
 82      if ":" in arg:
 83          class_name, param_str = arg.split(":", 1)
 84          kwargs = parse_kwargs(param_str)
 85      else:
 86          class_name = arg
 87          kwargs = None
 88      return METRIC_CLASS_MAP.get(class_name, class_name), kwargs
 89  
 90  
 91  def parse_attack(arg: str, plugin_manager: PluginManager) -> Any:
 92      """解析攻击参数"""
 93      if ":" in arg:
 94          class_name, param_str = arg.split(":", 1)
 95          
 96          # 首先检查是否是自定义插件
 97          custom_attack = plugin_manager.create_attack_instance(class_name)
 98          if custom_attack:
 99              kwargs = parse_kwargs(param_str)
100              return custom_attack.__class__(**kwargs)
101          
102          # 如果不是自定义插件,使用内置映射
103          class_path = TECHNIQUE_CLASS_MAP.get(class_name)
104          if not class_path:
105              raise ValueError(f"未知的攻击类型: {class_name}")
106              
107          cls = dynamic_import(class_path)
108          kwargs = parse_kwargs(param_str)
109          return cls(**kwargs)
110      else:
111          class_name = arg
112          
113          # 首先检查是否是自定义插件
114          custom_attack = plugin_manager.create_attack_instance(class_name)
115          if custom_attack:
116              return custom_attack
117          
118          # 如果不是自定义插件,使用内置映射
119          class_path = TECHNIQUE_CLASS_MAP.get(class_name)
120          if not class_path:
121              raise ValueError(f"未知的攻击类型: {class_name}")
122              
123          cls = dynamic_import(class_path)
124          # 为不同攻击方法设置不同权重,确保均衡使用
125          if class_name == "PromptInjection":
126              return cls(weight=1)
127          elif class_name == "Roleplay":
128              return cls(weight=1)
129          elif class_name == "Base64":
130              return cls(weight=1)
131          else:
132              return cls()
133  
134  
def parse_vulnerability(
    arg: str, plugin_manager: PluginManager
) -> Tuple[List[Any], str | None]:
    """Build vulnerability objects from a spec ``Name`` or ``Name:key=value,...``.

    Resolution order for the name: a custom plugin known to
    ``plugin_manager``, then the special names ``Custom`` and
    ``MultiDataset``, then the built-in ``SCENARIO_CLASS_MAP``.

    Args:
        arg: Vulnerability specification.
        plugin_manager: Source of custom plugin vulnerabilities.

    Returns:
        A ``(vulnerabilities, tag)`` pair. ``vulnerabilities`` is a list of
        vulnerability objects. ``tag`` is the prompt text for an inline
        ``Custom`` string prompt, the base name of the prompt or dataset file
        for file-driven specs, and ``None`` otherwise.

    Raises:
        ValueError: The name is not a plugin, not a special name, and not in
            ``SCENARIO_CLASS_MAP``.
    """
    class_name, colon, param_str = arg.partition(":")

    # ---- Spec without parameters -------------------------------------
    if not colon:
        custom_vulnerability = plugin_manager.create_vulnerability_instance(class_name)
        if custom_vulnerability:
            return [custom_vulnerability], None

        if class_name == "Custom":
            from deepteam.vulnerabilities import CustomPrompt
            return [CustomPrompt()], None
        if class_name == "MultiDataset":
            from deepteam.vulnerabilities import MultiDatasetVulnerability
            return [MultiDatasetVulnerability()], None

        class_path = SCENARIO_CLASS_MAP.get(class_name)
        if not class_path:
            raise ValueError(f"未知的漏洞类型: {class_name}")
        cls = dynamic_import(class_path)
        return [cls()], None

    # ---- Spec with parameters ----------------------------------------
    # Custom plugins take precedence over everything else.
    custom_vulnerability = plugin_manager.create_vulnerability_instance(class_name)
    if custom_vulnerability:
        kwargs = parse_kwargs(param_str)
        # Return the same (list, tag) shape as every other path so callers
        # can always unpack the result.
        return [custom_vulnerability.__class__(**kwargs)], None

    if class_name == "Custom":
        from deepteam.vulnerabilities import CustomPrompt
        kwargs = parse_kwargs(param_str)
        logger.debug(f"Creating CustomPrompt with kwargs: {kwargs}")

        if 'prompt' in kwargs:
            prompt_value = kwargs['prompt']
            # Only a string prompt doubles as the tag; any other prompt
            # value still produces a result instead of an implicit None.
            tag = prompt_value if isinstance(prompt_value, str) else None
            return [CustomPrompt(**kwargs)], tag

        if 'prompt_file' in kwargs:
            prompt_file = kwargs['prompt_file']
            # Load the file once through a temporary CustomPrompt to obtain
            # its prompts and their metadata, then emit one object per prompt.
            temp_vuln = CustomPrompt(prompt_file=prompt_file)
            prompts = temp_vuln.prompts
            metadata = temp_vuln.metadata

            vulnerabilities = []
            for prompt, meta in zip(prompts, metadata):
                vuln = CustomPrompt(prompt=prompt)
                # Name each object after its metadata category (file input only).
                category = meta.get('category', 'custom')
                vuln.name = f"{category}"
                vulnerabilities.append(vuln)
            return vulnerabilities, os.path.basename(prompt_file)

        return [CustomPrompt(**kwargs)], None

    if class_name == "MultiDataset":
        from deepteam.vulnerabilities import MultiDatasetVulnerability
        kwargs = parse_kwargs(param_str)
        logger.debug(f"Creating MultiDatasetVulnerability with kwargs: {kwargs}")

        dataset_file = kwargs.get('dataset_file')
        num_prompts = kwargs.get('num_prompts', 10)
        random_seed = kwargs.get('random_seed')
        prompt_column = kwargs.get('prompt_column')
        filter_conditions = kwargs.get('filter_conditions')

        # Load the dataset once; the per-prompt objects below are derived
        # from this instance.
        vuln = MultiDatasetVulnerability(
            dataset_file=dataset_file,
            num_prompts=num_prompts,
            random_seed=random_seed,
            prompt_column=prompt_column,
            filter_conditions=filter_conditions
        )
        tag = vuln.dataset_name
        if tag is None:
            # Fall back to the file name without its extension, dropping a
            # trailing "-<16 digits>" suffix when present.
            source_file = os.path.basename(dataset_file)
            dataset_name = os.path.splitext(source_file)[0]
            match = re.search(r"(.*?)(?:-\d{16})?$", dataset_name)
            tag = match.group(1) if match else dataset_name

        # One vulnerability object per prompt, each carrying only its own
        # metadata entry.
        vulnerabilities = []
        for i, prompt in enumerate(vuln.prompts):
            single_vuln = MultiDatasetVulnerability(
                prompt=prompt,
            )
            single_vuln.types = vuln.types

            meta = vuln.metadata[i]
            single_vuln.metadata = [meta]
            category = meta.get('category')
            single_vuln.name = f"{tag}-{category}" if category else f"{tag}"

            vulnerabilities.append(single_vuln)

        return vulnerabilities, os.path.basename(dataset_file)

    # Built-in scenario classes.
    class_path = SCENARIO_CLASS_MAP.get(class_name)
    if not class_path:
        raise ValueError(f"未知的漏洞类型: {class_name}")

    cls = dynamic_import(class_path)
    kwargs = parse_kwargs(param_str)
    return [cls(**kwargs)], None