/ AIG-PromptSecurity / cli / red_team_runner.py
red_team_runner.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
 19  import pandas as pd
 20  
 21  from cli.aig_logger import logger
 22  from cli.aig_logger import (
 23      newPlanStep, statusUpdate, toolUsed, actionLog, resultUpdate
 24  )
 25  import uuid
 26  import inspect
 27  from typing import List, Any, Optional
 28  from deepteam.red_teamer import RedTeamer
 29  from deepteam.plugin_system import PluginManager
 30  from utils.strategy_map import get_strategy_map
 31  from cli.model_utils import BaseLLM
 32  from cli.parsers import parse_attack, parse_vulnerability, parse_metric_class, dynamic_import
 33  
 34  
 35  class RedTeamRunner:
 36      """红队测试运行器"""
 37      
 38      def __init__(self, plugin_manager: PluginManager):
 39          self.plugin_manager = plugin_manager
 40      
 41      def run_red_team(
 42          self,
 43          models: List[BaseLLM],
 44          simulator_model: BaseLLM,
 45          evaluate_model: BaseLLM,
 46          scenarios: List[str],
 47          techniques: List[str],
 48          async_mode: bool = False,
 49          choice: str = "random",
 50          metric: Optional[str] = None,
 51          report_path: Optional[str] = None,
 52      ) -> str:
 53          """运行红队测试"""
 54          logger.new_plan_step(newPlanStep(stepId="1", title=logger.translated_msg("Pre-Jailbreak Parameter Parsing")))
 55          for m in models:
 56              logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load model: {model_name}", model_name=m.get_model_name()), status="running"))
 57              # 测试连通
 58              is_connection, msg = m.test_model_connection()
 59              m_status = "completed" if is_connection else "failed"
 60              logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load model: {model_name}", model_name=m.get_model_name()), status=m_status))
 61              if m_status == "failed":
 62                  logger.error(msg)
 63                  logger.critical_issue(content=logger.translated_msg("Load model: {model_name} failed: {message}", model_name=m.get_model_name(), message=msg))
 64                  return
 65  
 66          # 解析漏洞
 67          logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load scenarios"), status="completed"))
 68  
 69          vulnerabilities = []
 70          try:
 71              for arg in scenarios:
 72                  vs, vs_name = parse_vulnerability(arg, self.plugin_manager)
 73                  if vs_name is not None:
 74                      logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load inputs: {vs_name}", vs_name=vs_name), status="completed"))
 75                  vulnerabilities.extend(vs)
 76          except Exception as e:
 77              logger.exception(e)
 78              logger.critical_issue(content=logger.translated_msg("Load scenarios failed"))
 79              return
 80  
 81          # 运行红队测试
 82          red_teamer = RedTeamer(simulator_model=simulator_model, evaluation_model=evaluate_model, async_mode=async_mode)
 83          red_teamer.max_concurrent = max(red_teamer.max_concurrent, simulator_model.max_concurrent, evaluate_model.max_concurrent)
 84  
 85          # 如果指定了自定义metric,则对所有vulnerability类型使用该metric
 86          if metric:
 87              metric_class_path, metric_kwarg = parse_metric_class(metric)
 88          else:
 89              metric_class_path, metric_kwarg = None, None
 90          
 91          need_evaluation_model = True
 92          if metric_class_path:
 93              logger.debug(f"Using metric: {metric_class_path}")
 94              
 95              # 首先检查是否是自定义插件
 96              custom_metric = self.plugin_manager.create_metric_instance(metric_class_path, model=evaluate_model, async_mode=async_mode)
 97              if custom_metric:
 98                  red_teamer.custom_metric = custom_metric  # type: ignore
 99              else:
100                  # 如果不是自定义插件,使用内置映射
101                  custom_metric_class = dynamic_import(metric_class_path)
102  
103                  init_signature = inspect.signature(custom_metric_class.__init__)
104                  possible_params = {
105                      "model": evaluate_model,
106                      "async_mode": async_mode,
107                      **metric_kwarg  # 合并额外参数
108                  }
109                  
110                  # 筛选出 __init__ 支持的参数
111                  supported_params = {
112                      param: possible_params[param]
113                      for param in possible_params
114                      if param in init_signature.parameters
115                  }
116                  
117                  red_teamer.custom_metric = custom_metric_class(**supported_params)
118  
119                  # 如果评估方法需要评估模型
120                  if "model" not in supported_params:
121                      need_evaluation_model = False
122  
123          metric_name = red_teamer.custom_metric.__name__ if red_teamer.custom_metric else "Default"
124          logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load metric: {metric_name}", metric_name=metric_name), status="completed"))
125          if need_evaluation_model:
126              logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load evaluate model: {model_name}", model_name=evaluate_model.get_model_name()), status="running"))
127              # 测试连通
128              is_connection, msg = evaluate_model.test_model_connection()
129              m_status = "completed" if is_connection else "failed"
130              logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load evaluate model: {model_name}", model_name=evaluate_model.get_model_name()), status=m_status))
131              if m_status == "failed":
132                  logger.error(msg)
133                  logger.critical_issue(content=logger.translated_msg("Load evaluate model: {model_name} failed: {message}", model_name=evaluate_model.get_model_name(), message=msg))
134                  return
135  
136          # logger.debug(f"Total vulnerabilities created: {len(vulnerabilities)}")
137          for i, v in enumerate(vulnerabilities):
138              logger.debug(f"Vulnerability {i+1}: {v.get_name()}")
139              if hasattr(v, 'prompts'):
140                  logger.debug(f"Vulnerability {i+1} prompts: {v.prompts}")
141  
142          # 解析攻击手法
143          logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load attacks"), status="running"))
144          attacks = [parse_attack(a, self.plugin_manager) for a in techniques]
145          logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg(
146              "Load attacks: {attacks}", attacks=", ".join([attack.get_name() for attack in attacks])
147          ), status="completed"))
148          # logger.debug(f"Total attacks created: {len(attacks)}")
149  
150          # 获取攻击策略
151          logger.debug(f"Attack selection strategy: {choice}")
152  
153          logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load simulator model: {model_name}", model_name=simulator_model.get_model_name()), status="running"))
154          # 测试连通
155          is_connection, msg = simulator_model.test_model_connection()
156          m_status = "completed" if is_connection else "failed"
157          logger.status_update(statusUpdate(stepId="1", brief=logger.translated_msg("Pre-Jailbreak Parameter Parsing"), description=logger.translated_msg("Load simulator model: {model_name}", model_name=simulator_model.get_model_name()), status=m_status))
158          if m_status == "failed":
159              logger.error(msg)
160              logger.critical_issue(content=logger.translated_msg("Load simulator model: {model_name} failed: {message}", model_name=simulator_model.get_model_name(), message=msg))
161              return
162  
163          try:
164              all_risk_assessments = []
165              for model in models:
166                  red_teamer.max_concurrent = max(red_teamer.max_concurrent, model.max_concurrent)
167                  model_callback = model.a_generate if async_mode else model.generate
168  
169                  # 预校验编码类算子
170                  logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Risk Assessment"), description=logger.translated_msg(
171                      "Pre-verification attacks"
172                  ), status="running"))
173                  tool_id = uuid.uuid4().hex
174                  logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
175                      "Pre-verification attacks"
176                  ), status="todo"))
177                  logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
178                      "Attacks that need to be verified: {attacks}", attacks=", ".join([attack.get_name() for attack in attacks])
179                  ), status="doing"))
180  
181                  logger.disable()
182                  pre_vulnerabilities, _ = parse_vulnerability("Custom:prompt=Output the largest five-digit Arabic numeral", self.plugin_manager)
183                  strategy_map = get_strategy_map()
184                  for strategy in strategy_map["configSchema"]["strategies"]:
185                      if strategy["id"] == "Encoding":
186                          encoding_methods = [method["id"] for method in strategy["methods"]]
187                  pre_attacks = [attack for attack in attacks if attack.get_name() in encoding_methods]
188                  red_teamer.red_team(
189                      model_callback=model_callback,
190                      vulnerabilities=pre_vulnerabilities,
191                      attacks=pre_attacks,
192                      ignore_errors=True,
193                      reuse_simulated_attacks=False,
194                      choice="parallel",
195                      model_name=model.get_model_name()
196                  )
197                  filtered_name = []
198                  for case in red_teamer.risk_assessment.test_cases:
199                      # 1. 泛化前后无变化校验通过 2. 泛化变化的内容能正常输出通过
200                      if case.actual_output and "99999" not in case.actual_output:
201                          filtered_name.append(case.attack_method)
202                  reserved_attacks = [attack for attack in attacks if attack.get_name() not in filtered_name]
203                  logger.enable()
204  
205                  if not reserved_attacks:
206                      logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
207                      "The selected attacks are all invalid for the current model. Please try other attacks."
208                  ), status="done"))
209                  else:
210                      logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
211                          "Attacks that passed verification: {attacks}", attacks=", ".join([attack.get_name() for attack in reserved_attacks])
212                      ), status="done"))
213  
214                  logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Risk Assessment"), description=logger.translated_msg(
215                      "Pre-verification attacks"
216                  ), status="completed"))
217  
218                  red_teamer.red_team(
219                      model_callback=model_callback,
220                      vulnerabilities=vulnerabilities,
221                      attacks=reserved_attacks,
222                      ignore_errors=True,
223                      reuse_simulated_attacks=False,
224                      choice=choice,
225                      model_name=model.get_model_name()
226                  )
227                  all_risk_assessments.append((model.get_model_name(), red_teamer.risk_assessment))
228          except Exception as e:
229              logger.exception(e)
230              logger.critical_issue(content=logger.translated_msg("An error occurred during {model_name} assessment. Please try again later.", model_name=model.get_model_name()))
231              return
232  
233          tool_id = uuid.uuid4().hex
234          logger.new_plan_step(newPlanStep(stepId="3", title=logger.translated_msg("Generating report")))
235          logger.status_update(statusUpdate(stepId="3", brief=logger.translated_msg("A.I.G is working"), description=logger.translated_msg("Generating report"), status="running"))
236          logger.tool_used(toolUsed(stepId="3", tool_id=tool_id, brief=logger.translated_msg("Report in progress"), status="todo"))
237          
238          try:
239              # content, status = red_teamer.get_risk_assessment_markdown()
240              # with open(report_path, "w", encoding="utf-8") as fw:
241              #     fw.write(content)
242              # logger.result_update(resultUpdate(msgType="file", content=report_path, status=status))
243              contents = []
244              final_status = False
245              df_list = []
246              attachment_path = f"logs/attachment_{uuid.uuid4().hex}.csv"
247              for model_name, risk_assessment in all_risk_assessments:
248                  content, status = red_teamer.get_risk_assessment_json(risk_assessment, model_name)
249                  final_status = True if final_status else status
250                  try:
251                      df_list.append(pd.read_csv(content["attachment"]))
252                  except Exception as e:
253                      logger.exception(e)
254                  content["attachment"] = attachment_path
255                  contents.append(content)
256  
257              if df_list:
258                  combined_df = pd.concat(df_list, ignore_index=True)
259              else:
260                  combined_df = pd.DataFrame([])
261              combined_df.to_csv(attachment_path, encoding="utf-8-sig", index=False)
262          except Exception as e:
263              logger.exception(e)
264              logger.critical_issue(content=logger.translated_msg("An error occurred during report generated. Please try again later."))
265              return
266  
267          logger.tool_used(toolUsed(stepId="3", tool_id=tool_id, tool_name="Report generated", brief=logger.translated_msg("Report generated"), status="done"))
268          logger.status_update(statusUpdate(stepId="3", brief=logger.translated_msg("A.I.G is working"), description=logger.translated_msg("Generating report"), status="completed"))
269          # save_report_path = red_teamer.save_risk_assessment_report() 
270          # logger.info(f'Original {model_name} report save to: {save_report_path}')
271          logger.result_update(resultUpdate(msgType="json", content=contents, status=final_status))
272          logger.info(f'Get resultUpdate done!')