# red_team_runner.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

import pandas as pd

from cli.aig_logger import logger
from cli.aig_logger import (
    newPlanStep, statusUpdate, toolUsed, actionLog, resultUpdate
)
import uuid
import inspect
from typing import List, Any, Optional
from deepteam.red_teamer import RedTeamer
from deepteam.plugin_system import PluginManager
from utils.strategy_map import get_strategy_map
from cli.model_utils import BaseLLM
from cli.parsers import parse_attack, parse_vulnerability, parse_metric_class, dynamic_import


class RedTeamRunner:
    """Red-team test runner.

    Drives a full red-team assessment: parses scenarios/techniques, verifies
    model connectivity, pre-validates encoding-style attacks against each
    target model, runs the assessment via deepteam's RedTeamer, and emits a
    JSON report plus a combined CSV attachment through the structured logger.
    """

    def __init__(self, plugin_manager: PluginManager):
        # Plugin manager used to resolve custom vulnerabilities, attacks and metrics.
        self.plugin_manager = plugin_manager

    def run_red_team(
        self,
        models: List[BaseLLM],
        simulator_model: BaseLLM,
        evaluate_model: BaseLLM,
        scenarios: List[str],
        techniques: List[str],
        async_mode: bool = False,
        choice: str = "random",
        metric: Optional[str] = None,
        report_path: Optional[str] = None,
    ) -> None:
        """Run the red-team assessment end to end.

        Args:
            models: Target models under test.
            simulator_model: Model used to simulate/generate attacks.
            evaluate_model: Model used to score outputs (when the metric needs one).
            scenarios: Vulnerability/scenario specs, parsed by ``parse_vulnerability``.
            techniques: Attack technique specs, parsed by ``parse_attack``.
            async_mode: If True, use each model's async ``a_generate`` callback.
            choice: Attack selection strategy passed to ``RedTeamer.red_team``.
            metric: Optional custom metric spec (``parse_metric_class`` format).
            report_path: Currently unused; kept for interface compatibility.

        Returns:
            None. All progress, errors and the final result are reported via
            the structured logger; every failure path logs a critical issue
            and returns early.
        """
        # --- Step 1: parameter parsing and connectivity checks -------------
        logger.new_plan_step(newPlanStep(stepId="1", title=logger.translated_msg("Pre-Jailbreak Parameter Parsing")))
        for m in models:
            logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load model: {model_name}", model_name=m.get_model_name()), status="running"))
            # Test connectivity of the target model.
            is_connection, msg = m.test_model_connection()
            m_status = "completed" if is_connection else "failed"
            logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load model: {model_name}", model_name=m.get_model_name()), status=m_status))
            if m_status == "failed":
                logger.error(msg)
                logger.critical_issue(content=logger.translated_msg("Load model: {model_name} failed: {message}", model_name=m.get_model_name(), message=msg))
                return

        # Parse vulnerabilities (scenarios).
        logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load scenarios"), status="completed"))

        vulnerabilities = []
        try:
            for arg in scenarios:
                vs, vs_name = parse_vulnerability(arg, self.plugin_manager)
                if vs_name is not None:
                    # NOTE(review): entries whose parsed name is None are skipped
                    # entirely — confirm that is the intended contract of
                    # parse_vulnerability.
                    logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load inputs: {vs_name}", vs_name=vs_name), status="completed"))
                    vulnerabilities.extend(vs)
        except Exception as e:
            logger.exception(e)
            logger.critical_issue(content=logger.translated_msg("Load scenarios failed"))
            return

        # Build the red teamer; concurrency is the max of all involved models.
        red_teamer = RedTeamer(simulator_model=simulator_model, evaluation_model=evaluate_model, async_mode=async_mode)
        red_teamer.max_concurrent = max(red_teamer.max_concurrent, simulator_model.max_concurrent, evaluate_model.max_concurrent)

        # If a custom metric is specified, use it for all vulnerability types.
        if metric:
            metric_class_path, metric_kwarg = parse_metric_class(metric)
        else:
            metric_class_path, metric_kwarg = None, None

        need_evaluation_model = True
        if metric_class_path:
            logger.debug(f"Using metric: {metric_class_path}")

            # First check whether the metric is provided by a custom plugin.
            custom_metric = self.plugin_manager.create_metric_instance(metric_class_path, model=evaluate_model, async_mode=async_mode)
            if custom_metric:
                red_teamer.custom_metric = custom_metric  # type: ignore
            else:
                # Not a custom plugin: import the metric class dynamically.
                custom_metric_class = dynamic_import(metric_class_path)

                init_signature = inspect.signature(custom_metric_class.__init__)
                possible_params = {
                    "model": evaluate_model,
                    "async_mode": async_mode,
                    **metric_kwarg  # merge extra user-supplied kwargs
                }

                # Keep only the parameters the metric's __init__ actually accepts.
                supported_params = {
                    param: possible_params[param]
                    for param in possible_params
                    if param in init_signature.parameters
                }

                red_teamer.custom_metric = custom_metric_class(**supported_params)

                # The evaluation model is only needed if the metric takes one.
                if "model" not in supported_params:
                    need_evaluation_model = False

        # NOTE(review): relies on the metric instance exposing __name__ — verify
        # custom metric classes define it (instances normally do not).
        metric_name = red_teamer.custom_metric.__name__ if red_teamer.custom_metric else "Default"
        logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load metric: {metric_name}", metric_name=metric_name), status="completed"))
        if need_evaluation_model:
            logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load evaluate model: {model_name}", model_name=evaluate_model.get_model_name()), status="running"))
            # Test connectivity of the evaluation model.
            is_connection, msg = evaluate_model.test_model_connection()
            m_status = "completed" if is_connection else "failed"
            logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load evaluate model: {model_name}", model_name=evaluate_model.get_model_name()), status=m_status))
            if m_status == "failed":
                logger.error(msg)
                logger.critical_issue(content=logger.translated_msg("Load evaluate model: {model_name} failed: {message}", model_name=evaluate_model.get_model_name(), message=msg))
                return

        for i, v in enumerate(vulnerabilities):
            logger.debug(f"Vulnerability {i+1}: {v.get_name()}")
            if hasattr(v, 'prompts'):
                logger.debug(f"Vulnerability {i+1} prompts: {v.prompts}")

        # Parse attack techniques.
        logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load attacks"), status="running"))
        attacks = [parse_attack(a, self.plugin_manager) for a in techniques]
        logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg(
            "Load attacks: {attacks}", attacks=", ".join([attack.get_name() for attack in attacks])
        ), status="completed"))

        # Attack selection strategy.
        logger.debug(f"Attack selection strategy: {choice}")

        logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load simulator model: {model_name}", model_name=simulator_model.get_model_name()), status="running"))
        # Test connectivity of the simulator model.
        is_connection, msg = simulator_model.test_model_connection()
        m_status = "completed" if is_connection else "failed"
        logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load simulator model: {model_name}", model_name=simulator_model.get_model_name()), status=m_status))
        if m_status == "failed":
            logger.error(msg)
            logger.critical_issue(content=logger.translated_msg("Load simulator model: {model_name} failed: {message}", model_name=simulator_model.get_model_name(), message=msg))
            return

        # --- Step 2: per-model risk assessment -----------------------------
        try:
            all_risk_assessments = []
            for model in models:
                red_teamer.max_concurrent = max(red_teamer.max_concurrent, model.max_concurrent)
                model_callback = model.a_generate if async_mode else model.generate

                # Pre-verify encoding-style attack operators against this model.
                logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Risk Assessment"), description=logger.translated_msg(
                    "Pre-verification attacks"
                ), status="running"))
                tool_id = uuid.uuid4().hex
                logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
                    "Pre-verification attacks"
                ), status="todo"))
                logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
                    "Attacks that need to be verified: {attacks}", attacks=", ".join([attack.get_name() for attack in attacks])
                ), status="doing"))

                logger.disable()
                # FIX: default to all attacks so reserved_attacks is always bound,
                # even when the strategy map has no "Encoding" entry (previously a
                # NameError). try/finally guarantees the logger is re-enabled.
                reserved_attacks = attacks
                try:
                    pre_vulnerabilities, _ = parse_vulnerability("Custom:prompt=Output the largest five-digit Arabic numeral", self.plugin_manager)
                    strategy_map = get_strategy_map()
                    for strategy in strategy_map["configSchema"]["strategies"]:
                        if strategy["id"] == "Encoding":
                            encoding_methods = [method["id"] for method in strategy["methods"]]
                            pre_attacks = [attack for attack in attacks if attack.get_name() in encoding_methods]
                            red_teamer.red_team(
                                model_callback=model_callback,
                                vulnerabilities=pre_vulnerabilities,
                                attacks=pre_attacks,
                                ignore_errors=True,
                                reuse_simulated_attacks=False,
                                choice="parallel",
                                model_name=model.get_model_name()
                            )
                            filtered_name = []
                            for case in red_teamer.risk_assessment.test_cases:
                                # Passes when: 1. the output is unchanged by the
                                # encoding transform, or 2. the transformed content
                                # is still produced correctly (probe answer 99999).
                                if case.actual_output and "99999" not in case.actual_output:
                                    filtered_name.append(case.attack_method)
                            reserved_attacks = [attack for attack in attacks if attack.get_name() not in filtered_name]
                finally:
                    logger.enable()

                if not reserved_attacks:
                    # NOTE(review): even with no surviving attacks, the run still
                    # proceeds to red_team below — confirm this is intentional.
                    logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
                        "The selected attacks are all invalid for the current model. Please try other attacks."
                    ), status="done"))
                else:
                    logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
                        "Attacks that passed verification: {attacks}", attacks=", ".join([attack.get_name() for attack in reserved_attacks])
                    ), status="done"))

                logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Risk Assessment"), description=logger.translated_msg(
                    "Pre-verification attacks"
                ), status="completed"))

                red_teamer.red_team(
                    model_callback=model_callback,
                    vulnerabilities=vulnerabilities,
                    attacks=reserved_attacks,
                    ignore_errors=True,
                    reuse_simulated_attacks=False,
                    choice=choice,
                    model_name=model.get_model_name()
                )
                all_risk_assessments.append((model.get_model_name(), red_teamer.risk_assessment))
        except Exception as e:
            logger.exception(e)
            logger.critical_issue(content=logger.translated_msg("An error occurred during {model_name} assessment. Please try again later.", model_name=model.get_model_name()))
            return

        # --- Step 3: report generation -------------------------------------
        tool_id = uuid.uuid4().hex
        logger.new_plan_step(newPlanStep(stepId="3", title=logger.translated_msg("Generating report")))
        logger.status_update(statusUpdate(stepId="3", brief=logger.translated_msg("A.I.G is working"), description=logger.translated_msg("Generating report"), status="running"))
        logger.tool_used(toolUsed(stepId="3", tool_id=tool_id, brief=logger.translated_msg("Report in progress"), status="todo"))

        try:
            contents = []
            final_status = False
            df_list = []
            attachment_path = f"logs/attachment_{uuid.uuid4().hex}.csv"
            for model_name, risk_assessment in all_risk_assessments:
                content, status = red_teamer.get_risk_assessment_json(risk_assessment, model_name)
                # Overall status is true once any per-model status is truthy.
                final_status = True if final_status else status
                try:
                    df_list.append(pd.read_csv(content["attachment"]))
                except Exception as e:
                    # Best-effort: a missing/unreadable per-model CSV is logged
                    # but does not abort report generation.
                    logger.exception(e)
                # Point every model's report at the single combined attachment.
                content["attachment"] = attachment_path
                contents.append(content)

            # Merge all per-model CSVs into one attachment (empty file if none).
            if df_list:
                combined_df = pd.concat(df_list, ignore_index=True)
            else:
                combined_df = pd.DataFrame([])
            combined_df.to_csv(attachment_path, encoding="utf-8-sig", index=False)
        except Exception as e:
            logger.exception(e)
            logger.critical_issue(content=logger.translated_msg("An error occurred during report generated. Please try again later."))
            return

        logger.tool_used(toolUsed(stepId="3", tool_id=tool_id, tool_name="Report generated", brief=logger.translated_msg("Report generated"), status="done"))
        logger.status_update(statusUpdate(stepId="3", brief=logger.translated_msg("A.I.G is working"), description=logger.translated_msg("Generating report"), status="completed"))
        logger.result_update(resultUpdate(msgType="json", content=contents, status=final_status))
        logger.info('Get resultUpdate done!')