parsers.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

import os
import re
import importlib
import ast
from typing import List, Any, Tuple
from deepteam.plugin_system import PluginManager
from cli.aig_logger import logger
from cli.aig_logger import (
    newPlanStep, statusUpdate, toolUsed, actionLog, resultUpdate
)
from .mappings import TECHNIQUE_CLASS_MAP, SCENARIO_CLASS_MAP, METRIC_CLASS_MAP


# Bracket characters tracked by parse_kwargs so that commas nested inside
# list, dict/set or tuple literals do not split a parameter in two.
_OPEN_BRACKETS = "[{("
_CLOSE_BRACKETS = "]})"

# Built-in attacks that are created with an explicit weight=1 when no
# parameters are supplied, so that the attack methods are used evenly.
_WEIGHTED_ATTACKS = frozenset({"PromptInjection", "Roleplay", "Base64"})

# Splits a dataset file stem into "<name>" and an optional trailing
# "-<16 digits>" suffix; group(1) is the stem without that suffix.
_DATASET_SUFFIX_RE = re.compile(r"(.*?)(?:-\d{16})?$")


def dynamic_import(class_path: str) -> Any:
    """Import and return the attribute named by a dotted path.

    Args:
        class_path: Fully qualified name such as ``"package.module.ClassName"``.

    Returns:
        The object ``ClassName`` looked up on the imported module.
    """
    module_path, class_name = class_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def parse_kwargs(param_str: str) -> dict:
    """Parse a ``key=value,key=value`` parameter string into a dict.

    Each value is converted with :func:`ast.literal_eval` when it is a valid
    Python literal (number, quoted string, list, dict, ...); otherwise the
    stripped raw text is kept as a string. Items without ``=`` are ignored.

    A string starting with ``prompt=`` is special-cased: everything after the
    prefix is taken verbatim as the ``prompt`` value, so a free-form prompt
    may contain commas, brackets or ``=`` signs.

    Args:
        param_str: The text after the ``ClassName:`` prefix of an argument.

    Returns:
        Mapping of parameter names to parsed values.
    """
    if param_str.startswith("prompt="):
        return {"prompt": param_str[len("prompt="):]}

    # Split on top-level commas only; commas nested inside brackets belong
    # to a literal value such as a list, dict or tuple.
    params = []
    buf = []
    depth = 0
    for ch in param_str:
        if ch in _OPEN_BRACKETS:
            depth += 1
        elif ch in _CLOSE_BRACKETS:
            # Clamp at zero so a stray closing bracket cannot disable
            # splitting for the rest of the string.
            depth = max(depth - 1, 0)
        if ch == "," and depth == 0:
            params.append("".join(buf))
            buf = []
        else:
            buf.append(ch)
    if buf:
        params.append("".join(buf))

    kwargs = {}
    for item in params:
        if "=" not in item:
            continue
        key, raw = item.split("=", 1)
        raw = raw.strip()
        try:
            value = ast.literal_eval(raw)
        except Exception:
            # Not a Python literal (e.g. a bare word or a path): keep the text.
            value = raw
        kwargs[key.strip()] = value
    return kwargs


def parse_metric_class(arg: str) -> Tuple[str | None, dict | None]:
    """Resolve a metric argument of the form ``Name`` or ``Name:k=v,...``.

    Args:
        arg: The metric argument; may be empty.

    Returns:
        ``(class_ref, kwargs)``. ``class_ref`` is the METRIC_CLASS_MAP entry
        for ``Name`` (or ``Name`` itself when unmapped). ``kwargs`` is the
        parsed parameter dict, or ``None`` when no ``:`` was given. For an
        empty ``arg`` both elements are ``None``.
    """
    if not arg:
        # Return a pair rather than a bare None so callers can always unpack.
        return None, None
    class_name, sep, param_str = arg.partition(":")
    kwargs = parse_kwargs(param_str) if sep else None
    return METRIC_CLASS_MAP.get(class_name, class_name), kwargs


def _lookup_builtin_class(class_name: str, class_map: dict, error_prefix: str) -> Any:
    """Import the built-in class registered under *class_name* in *class_map*.

    Raises:
        ValueError: ``"<error_prefix>: <class_name>"`` if *class_name* is not
            registered (or maps to an empty path).
    """
    class_path = class_map.get(class_name)
    if not class_path:
        raise ValueError(f"{error_prefix}: {class_name}")
    return dynamic_import(class_path)


def parse_attack(arg: str, plugin_manager: PluginManager) -> Any:
    """Build an attack instance from ``Name`` or ``Name:k=v,...``.

    Custom plugins known to *plugin_manager* take precedence over the
    built-in TECHNIQUE_CLASS_MAP.

    Args:
        arg: The attack argument.
        plugin_manager: Source of custom attack plugins.

    Returns:
        The attack instance. Without a ``:`` the plugin's own instance is
        returned as-is; built-ins listed in ``_WEIGHTED_ATTACKS`` get
        ``weight=1``. With a ``:`` (even with nothing after it) a fresh
        instance is built from the parsed parameters.

    Raises:
        ValueError: If ``Name`` is neither a plugin nor a built-in attack.
    """
    class_name, sep, param_str = arg.partition(":")
    has_params = bool(sep)

    custom_attack = plugin_manager.create_attack_instance(class_name)
    if custom_attack:
        if not has_params:
            return custom_attack
        # Rebuild the plugin's class with the user-supplied parameters.
        return custom_attack.__class__(**parse_kwargs(param_str))

    cls = _lookup_builtin_class(class_name, TECHNIQUE_CLASS_MAP, "未知的攻击类型")
    if has_params:
        return cls(**parse_kwargs(param_str))
    if class_name in _WEIGHTED_ATTACKS:
        return cls(weight=1)
    return cls()


def _build_custom_prompt_vulnerabilities(kwargs: dict) -> Tuple[List[Any], str | None]:
    """Create CustomPrompt vulnerabilities for a ``Custom:...`` argument.

    Returns:
        ``(vulnerabilities, tag)`` where tag is the prompt text for a string
        ``prompt``, the basename of ``prompt_file`` for file input, else None.
    """
    from deepteam.vulnerabilities import CustomPrompt

    logger.debug(f"Creating CustomPrompt with kwargs: {kwargs}")

    if "prompt" in kwargs:
        prompt_value = kwargs["prompt"]
        if isinstance(prompt_value, str):
            return [CustomPrompt(**kwargs)], prompt_value
        # A non-string prompt (e.g. a list parsed by literal_eval) used to
        # fall through and make the caller receive None; treat it like the
        # generic case and let CustomPrompt validate it.
        return [CustomPrompt(**kwargs)], None

    if "prompt_file" in kwargs:
        prompt_file = kwargs["prompt_file"]
        # Load the file once through a temporary object to obtain the
        # prompts and their metadata, then create one object per prompt.
        loader = CustomPrompt(prompt_file=prompt_file)
        vulnerabilities = []
        for prompt, meta in zip(loader.prompts, loader.metadata):
            vuln = CustomPrompt(prompt=prompt)
            # Name each vulnerability after its metadata category (file input only).
            vuln.name = f"{meta.get('category', 'custom')}"
            vulnerabilities.append(vuln)
        return vulnerabilities, os.path.basename(prompt_file)

    return [CustomPrompt(**kwargs)], None


def _build_multi_dataset_vulnerabilities(kwargs: dict) -> Tuple[List[Any], str | None]:
    """Create one MultiDatasetVulnerability per sampled prompt.

    Returns:
        ``(vulnerabilities, basename(dataset_file))``.
    """
    from deepteam.vulnerabilities import MultiDatasetVulnerability

    logger.debug(f"Creating MultiDatasetVulnerability with kwargs: {kwargs}")

    dataset_file = kwargs.get("dataset_file")
    vuln = MultiDatasetVulnerability(
        dataset_file=dataset_file,
        num_prompts=kwargs.get("num_prompts", 10),
        random_seed=kwargs.get("random_seed"),
        prompt_column=kwargs.get("prompt_column"),
        filter_conditions=kwargs.get("filter_conditions"),
    )

    tag = vuln.dataset_name
    if tag is None:
        # Fall back to the file stem, minus a trailing "-<16 digits>" suffix.
        stem = os.path.splitext(os.path.basename(dataset_file))[0]
        match = _DATASET_SUFFIX_RE.search(stem)
        tag = match.group(1) if match else stem

    vulnerabilities = []
    for i, prompt in enumerate(vuln.prompts):
        # A fresh instance holding just this prompt, sharing the parent's types.
        single_vuln = MultiDatasetVulnerability(prompt=prompt)
        single_vuln.types = vuln.types

        # Indexed (not zipped) so a metadata/prompts length mismatch still
        # raises instead of silently truncating.
        meta = vuln.metadata[i]
        single_vuln.metadata = [meta]
        category = meta.get("category")
        single_vuln.name = f"{tag}-{category}" if category else f"{tag}"

        vulnerabilities.append(single_vuln)

    return vulnerabilities, os.path.basename(dataset_file)


def parse_vulnerability(
    arg: str, plugin_manager: PluginManager
) -> Tuple[List[Any], str | None]:
    """Build vulnerability instances from ``Name`` or ``Name:k=v,...``.

    Resolution order: custom plugins from *plugin_manager*, then the special
    ``Custom`` and ``MultiDataset`` names, then SCENARIO_CLASS_MAP.

    Args:
        arg: The vulnerability argument.
        plugin_manager: Source of custom vulnerability plugins.

    Returns:
        ``(vulnerabilities, tag)``. ``vulnerabilities`` is a list of
        instances. ``tag`` is the prompt text for ``Custom:prompt=...``, the
        input file's basename for ``prompt_file``/``dataset_file`` input, and
        ``None`` otherwise.

    Raises:
        ValueError: If ``Name`` is unknown.
    """
    class_name, sep, param_str = arg.partition(":")
    has_params = bool(sep)

    custom_vulnerability = plugin_manager.create_vulnerability_instance(class_name)
    if custom_vulnerability:
        if not has_params:
            return [custom_vulnerability], None
        # Return the same (list, tag) pair as every other path; this branch
        # previously returned a bare list.
        return [custom_vulnerability.__class__(**parse_kwargs(param_str))], None

    if class_name == "Custom":
        if not has_params:
            from deepteam.vulnerabilities import CustomPrompt
            return [CustomPrompt()], None
        return _build_custom_prompt_vulnerabilities(parse_kwargs(param_str))

    if class_name == "MultiDataset":
        if not has_params:
            from deepteam.vulnerabilities import MultiDatasetVulnerability
            return [MultiDatasetVulnerability()], None
        return _build_multi_dataset_vulnerabilities(parse_kwargs(param_str))

    cls = _lookup_builtin_class(class_name, SCENARIO_CLASS_MAP, "未知的漏洞类型")
    if has_params:
        return [cls(**parse_kwargs(param_str))], None
    return [cls()], None