# attack_simulator.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
 19  import copy
 20  import random
 21  import asyncio
 22  from tqdm import tqdm
 23  from pydantic import BaseModel
 24  from typing import List, Optional, Union
 25  import inspect
 26  from cli.aig_logger import logger
 27  from cli.aig_logger import (
 28      newPlanStep, statusUpdate, toolUsed, actionLog, resultUpdate
 29  )
 30  import uuid
 31  
 32  from deepeval.models import DeepEvalBaseLLM
 33  from deepeval.metrics.utils import initialize_model, trimAndLoadJson
 34  
 35  from deepteam.attacks import BaseAttack
 36  from deepteam.vulnerabilities import BaseVulnerability, CustomPrompt, MultiDatasetVulnerability
 37  from deepteam.vulnerabilities.types import VulnerabilityType
 38  from deepteam.attacks.multi_turn.types import CallbackType
 39  from deepteam.attacks.attack_simulator.template import AttackSimulatorTemplate
 40  from deepteam.attacks.attack_simulator.schema import SyntheticDataList
 41  
 42  
class SimulatedAttack(BaseModel):
    """A single simulated adversarial attack against one vulnerability type.

    Produced as a baseline by AttackSimulator and then (optionally) enhanced
    in place of ``input`` by one or more attack methods.
    """

    # Display name of the vulnerability this attack targets.
    vulnerability: str
    # The specific vulnerability type the attack was generated for.
    vulnerability_type: VulnerabilityType
    # The unenhanced attack text as originally generated.
    original_input: Optional[str] = None
    # The (possibly enhanced) attack text actually sent to the target.
    input: Optional[str] = None
    # Name(s) of the attack method(s) applied, e.g. "Base64 + ROT13".
    attack_method: Optional[str] = None
    # Error message if generation or enhancement failed (with ignore_errors).
    error: Optional[str] = None
    # True when enhancement left the input unchanged (no point in sending it twice).
    useless: bool = False
 51  
 52  
 53  class AttackSimulator:
 54      model_callback: Union[CallbackType, None] = None
 55      max_concurrent = 10
 56  
 57      def __init__(
 58          self,
 59          purpose: str,
 60          simulator_model: Optional[Union[str, DeepEvalBaseLLM]] = None,
 61      ):
 62          # Initialize models and async mode
 63          self.purpose = purpose
 64          self.simulator_model, self.using_native_model = initialize_model(
 65              simulator_model
 66          )
 67  
 68          # Define list of attacks and unaligned vulnerabilities
 69          self.simulated_attacks: List[SimulatedAttack] = []
 70  
 71      ##################################################
 72      ### Generating Attacks ###########################
 73      ##################################################
 74  
 75      def simulate(
 76          self,
 77          attacks_per_vulnerability_type: int,
 78          vulnerabilities: List[BaseVulnerability],
 79          attacks: List[BaseAttack],
 80          ignore_errors: bool,
 81          choice: str = "random",  # 新增参数:random 或 serial
 82      ) -> List[SimulatedAttack]:
 83          # Simulate unenhanced attacks for each vulnerability
 84          baseline_attacks: List[SimulatedAttack] = []
 85          num_vulnerabilities = len(vulnerabilities)
 86          num_vulnerability_types = sum(
 87              len(v.get_types()) for v in vulnerabilities
 88          )
 89          pbar = tqdm(
 90              vulnerabilities,
 91              desc=f"💥 Generating {num_vulnerability_types * attacks_per_vulnerability_type} attacks (for {num_vulnerability_types} vulnerability types across {num_vulnerabilities} vulnerability(s))",
 92          )
 93          logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg("Generating attacks"), status="running"))
 94          
 95          tool_id = uuid.uuid4().hex
 96          logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
 97              "Simulating {num_vulnerabilities} attacks", num_vulnerabilities=num_vulnerabilities
 98          ), status="todo"))
 99  
100          for idx, vulnerability in enumerate(pbar):
101              logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
102                  "Simulating {idx} / {num_vulnerabilities} attacks", idx=idx+1, num_vulnerabilities=num_vulnerabilities
103              ), status="doing"))
104  
105              baseline_attacks.extend(
106                  self.simulate_baseline_attacks(
107                      attacks_per_vulnerability_type=attacks_per_vulnerability_type,
108                      vulnerability=vulnerability,
109                      ignore_errors=ignore_errors,
110                  )
111              )
112          logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, tool_name="Simulate baseline attacks", brief=logger.translated_msg(
113              "Simulating {length} attacks done", length=len(vulnerabilities)
114          ), status="done"))
115  
116          logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg("Generating attacks"), status="completed"))
117  
118          # Enhance attacks by sampling from the provided distribution
119          enhanced_attacks: List[SimulatedAttack] = []
120          if choice == "serial":
121              unpack_attacks = [attacks]
122          elif choice == "parallel":
123              unpack_attacks = attacks
124          else:
125              attack_weights = [attack.weight for attack in attacks]
126              unpack_attacks = random.choices(attacks, weights=attack_weights, k=1)
127          num_baseline_attacks = len(baseline_attacks) * len(unpack_attacks)
128          pbar = tqdm(
129              total=num_baseline_attacks,
130              desc=f"✨ Simulating {num_vulnerability_types * attacks_per_vulnerability_type} attacks (using {len(attacks)} method(s))",
131          )
132  
133          logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg(
134              "Enhance {num_baseline_attacks} attacks", num_baseline_attacks=num_baseline_attacks
135          ), status="running"))
136          
137          tool_id = uuid.uuid4().hex
138          logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, tool_name="Enhance attacks", brief=logger.translated_msg(
139              "Enhance {num_baseline_attacks} attacks", num_baseline_attacks=num_baseline_attacks
140          ), status="todo"))
141  
142          for index, (baseline_attack, unpack_attack) in enumerate(
143              (baseline_attack, unpack_attack) 
144              for baseline_attack in baseline_attacks 
145              for unpack_attack in unpack_attacks
146          ):
147              logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
148                  "Simulating {idx} / {num_baseline_attacks} attacks", idx=index+1, num_baseline_attacks=num_baseline_attacks
149              ), status="doing"))
150              if choice == "serial":
151                  # 串行嵌套攻击:按顺序应用所有攻击方法
152                  enhanced_attack = self.enhance_attack_serial(
153                      attacks=unpack_attack,
154                      simulated_attack=baseline_attack,
155                      ignore_errors=ignore_errors,
156                  )
157              else:
158                  enhanced_attack = self.enhance_attack(
159                      attack=unpack_attack,
160                      simulated_attack=baseline_attack,
161                      ignore_errors=ignore_errors,
162                  )
163  
164              # 泛化前后无变化
165              if baseline_attack.input == enhanced_attack.input:
166                  enhanced_attack.useless = True
167              enhanced_attacks.append(enhanced_attack)
168              pbar.update(1)
169  
170          logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, tool_name="Enhance attacks", brief=logger.translated_msg(
171              "Enhance {num_baseline_attacks} attacks done", num_baseline_attacks=num_baseline_attacks
172          ), status="done"))
173  
174          logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg(
175              "Enhance {num_baseline_attacks} attacks", num_baseline_attacks=num_baseline_attacks
176          ), status="completed"))
177  
178          self.simulated_attacks.extend(enhanced_attacks)
179  
180          return enhanced_attacks
181  
182      async def a_simulate(
183          self,
184          attacks_per_vulnerability_type: int,
185          vulnerabilities: List[BaseVulnerability],
186          attacks: List[BaseAttack],
187          ignore_errors: bool,
188          choice: str = "random",  # 新增参数:random 或 serial
189      ) -> List[SimulatedAttack]:
190          self.semaphore = asyncio.Semaphore(self.max_concurrent)
191  
192          # Simulate unenhanced attacks for each vulnerability
193          baseline_attacks: List[SimulatedAttack] = []
194          num_vulnerabilities = len(vulnerabilities)
195          num_vulnerability_types = sum(
196              len(v.get_types()) for v in vulnerabilities
197          )
198          pbar = tqdm(
199              vulnerabilities,
200              desc=f"💥 Generating {num_vulnerability_types * attacks_per_vulnerability_type} attacks (for {num_vulnerability_types} vulnerability types across {num_vulnerabilities} vulnerability(s))",
201          )
202          tool_id = uuid.uuid4().hex
203          logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg("Generating attacks"), status="running"))
204          logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
205              "Simulating {num_vulnerabilities} attacks", num_vulnerabilities=num_vulnerabilities
206          ), status="todo"))
207          
208          async def throttled_simulate_baseline_attack(vulnerability):
209              result = await self.a_simulate_baseline_attacks(
210                  attacks_per_vulnerability_type=attacks_per_vulnerability_type,
211                  vulnerability=vulnerability,
212                  ignore_errors=ignore_errors,
213              )
214              return result
215  
216          simulate_tasks = [
217              throttled_simulate_baseline_attack(vulnerability) for vulnerability in vulnerabilities
218          ]
219          
220          for completed, coro in enumerate(asyncio.as_completed(simulate_tasks), 1):
221              result = await(coro)
222              baseline_attacks.extend(result)
223              logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
224                  "Simulating {idx} / {num_vulnerabilities} attacks", idx=completed, num_vulnerabilities=num_vulnerabilities
225              ), status="doing"))
226              pbar.update(1)
227          
228          logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, tool_name="Simulate baseline attacks", brief=logger.translated_msg(
229              "Simulating {num_vulnerabilities} attacks done", num_vulnerabilities=num_vulnerabilities
230          ), status="done"))
231          logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg("Generating attacks"), status="completed"))
232          pbar.close()
233  
234          # Enhance attacks by sampling from the provided distribution
235          enhanced_attacks: List[SimulatedAttack] = []
236          if choice == "serial":
237              unpack_attacks = [attacks]
238          elif choice == "parallel":
239              unpack_attacks = attacks
240          else:
241              attack_weights = [attack.weight for attack in attacks]
242              unpack_attacks = random.choices(attacks, weights=attack_weights, k=1)
243          num_baseline_attacks = len(baseline_attacks) * len(unpack_attacks)
244          pbar = tqdm(
245              total=num_baseline_attacks,
246              desc=f"✨ Simulating {num_vulnerability_types * attacks_per_vulnerability_type} attacks (using {len(attacks)} method(s))",
247          )
248              
249          async def throttled_attack_method(
250              unpack_attack: List[BaseAttack] | BaseAttack,
251              baseline_attack: SimulatedAttack,
252          ):
253              async with self.semaphore:
254                  if choice == "serial":
255                      # 串行嵌套攻击:按顺序应用所有攻击方法
256                      enhanced_attack = await self.a_enhance_attack_serial(
257                          attacks=unpack_attack,
258                          simulated_attack=baseline_attack,
259                          ignore_errors=ignore_errors,
260                      )
261                  else:
262                      enhanced_attack = await self.a_enhance_attack(
263                          attack=unpack_attack,
264                          simulated_attack=baseline_attack,
265                          ignore_errors=ignore_errors,
266                      )
267  
268                  # 泛化前后无变化
269                  if baseline_attack.input == enhanced_attack.input:
270                      enhanced_attack.useless = True
271                  return enhanced_attack
272          
273          logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg(
274              "Enhance {num_baseline_attacks} attacks", num_baseline_attacks=num_baseline_attacks
275          ), status="running"))
276  
277          tasks = [
278              throttled_attack_method(unpack_attack, baseline_attack) for baseline_attack in baseline_attacks for unpack_attack in unpack_attacks
279          ]
280  
281          logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, tool_name="Enhance attacks", brief=logger.translated_msg(
282              "Enhance {num_baseline_attacks} attacks", num_baseline_attacks=num_baseline_attacks
283          ), status="todo"))
284  
285          for completed, coro in enumerate(asyncio.as_completed(tasks), 1):
286              logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
287                  "Enhance {idx} / {num_baseline_attacks} attacks", idx=completed, num_baseline_attacks=num_baseline_attacks
288              ), status="doing"))
289              result = await coro
290              enhanced_attacks.append(result)
291              pbar.update(1)
292          
293          logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, tool_name="Enhance attacks", brief=logger.translated_msg(
294              "Enhance {num_baseline_attacks} attacks done", num_baseline_attacks=num_baseline_attacks
295          ), status="done"))
296          
297          logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg(
298              "Enhance {num_baseline_attacks} attacks", num_baseline_attacks=num_baseline_attacks
299          ), status="completed"))
300          pbar.close()
301  
302          # Store the simulated and enhanced attacks
303          self.simulated_attacks.extend(enhanced_attacks)
304  
305          return enhanced_attacks
306  
307      ##################################################
308      ### Simulating Base (Unenhanced) Attacks #########
309      ##################################################
310  
311      def simulate_baseline_attacks(
312          self,
313          attacks_per_vulnerability_type: int,
314          vulnerability: BaseVulnerability,
315          ignore_errors: bool,
316      ) -> List[SimulatedAttack]:
317          baseline_attacks: List[SimulatedAttack] = []
318  
319          for vulnerability_type in vulnerability.get_types():
320              try:
321                  if isinstance(vulnerability, CustomPrompt) or isinstance(vulnerability, MultiDatasetVulnerability):
322                      local_attacks = vulnerability.custom_prompt
323                  else:
324                      local_attacks = self.simulate_local_attack(
325                          self.purpose,
326                          vulnerability_type,
327                          attacks_per_vulnerability_type,
328                          (
329                              vulnerability.custom_prompt
330                              if hasattr(vulnerability, "custom_prompt")
331                              else None
332                          ),
333                      )
334                  baseline_attacks.extend(
335                      [
336                          SimulatedAttack(
337                              vulnerability=vulnerability.get_name(),
338                              vulnerability_type=vulnerability_type,
339                              original_input=local_attack,
340                              input=local_attack,
341                          )
342                          for local_attack in local_attacks
343                      ]
344                  )
345              except Exception as e:
346                  if ignore_errors:
347                      for _ in range(attacks_per_vulnerability_type):
348                          baseline_attacks.append(
349                              SimulatedAttack(
350                                  vulnerability=vulnerability.get_name(),
351                                  vulnerability_type=vulnerability_type,
352                                  error=f"Error simulating adversarial attacks: {str(e)}",
353                              )
354                          )
355                  else:
356                      raise
357          return baseline_attacks
358  
359      async def a_simulate_baseline_attacks(
360          self,
361          attacks_per_vulnerability_type: int,
362          vulnerability: BaseVulnerability,
363          ignore_errors: bool,
364      ) -> List[SimulatedAttack]:
365          baseline_attacks: List[SimulatedAttack] = []
366          for vulnerability_type in vulnerability.get_types():
367              try:
368                  if isinstance(vulnerability, CustomPrompt) or isinstance(vulnerability, MultiDatasetVulnerability):
369                      local_attacks = vulnerability.custom_prompt
370                  else:
371                      local_attacks = await self.a_simulate_local_attack(
372                          self.purpose,
373                          vulnerability_type,
374                          attacks_per_vulnerability_type,
375                          (
376                              vulnerability.custom_prompt
377                              if hasattr(vulnerability, "custom_prompt")
378                              else None
379                          ),
380                      )
381  
382                  baseline_attacks.extend(
383                      [
384                          SimulatedAttack(
385                              vulnerability=vulnerability.get_name(),
386                              vulnerability_type=vulnerability_type,
387                              original_input=local_attack, 
388                              input=local_attack,
389                          )
390                          for local_attack in local_attacks
391                      ]
392                  )
393              except Exception as e:
394                  if ignore_errors:
395                      for _ in range(attacks_per_vulnerability_type):
396                          baseline_attacks.append(
397                              SimulatedAttack(
398                                  vulnerability=vulnerability.get_name(),
399                                  vulnerability_type=vulnerability_type,
400                                  error=f"Error simulating adversarial attacks: {str(e)}",
401                              )
402                          )
403                  else:
404                      raise
405          return baseline_attacks
406  
407      ##################################################
408      ### Enhance attacks ##############################
409      ##################################################
410  
411      def enhance_attack(
412          self,
413          attack: BaseAttack,
414          simulated_attack: SimulatedAttack,
415          ignore_errors: bool,
416      ):
417          simulated_attack = copy.deepcopy(simulated_attack)
418          attack_input = simulated_attack.input
419          if attack_input is None:
420              return simulated_attack
421  
422          simulated_attack.attack_method = attack.get_name()
423          sig = inspect.signature(attack.enhance)
424          try:
425              if (
426                  "simulator_model" in sig.parameters
427                  and "model_callback" in sig.parameters
428              ):
429                  simulated_attack.input = attack.enhance(
430                      attack=attack_input,
431                      simulator_model=self.simulator_model,
432                      model_callback=self.model_callback,
433                  )
434              elif "simulator_model" in sig.parameters:
435                  simulated_attack.input = attack.enhance(
436                      attack=attack_input,
437                      simulator_model=self.simulator_model,
438                  )
439              elif "model_callback" in sig.parameters:
440                  simulated_attack.input = attack.enhance(
441                      attack=attack_input,
442                      model_callback=self.model_callback,
443                  )
444              else:
445                  simulated_attack.input = attack.enhance(attack=attack_input)
446          except Exception as e:
447              if ignore_errors:
448                  simulated_attack.error = "Error enhancing attack"
449                  return simulated_attack
450              else:
451                  raise
452  
453          return simulated_attack
454  
455      def enhance_attack_serial(
456          self,
457          attacks: List[BaseAttack],
458          simulated_attack: SimulatedAttack,
459          ignore_errors: bool,
460      ):
461          """
462          串行嵌套攻击:按顺序应用所有攻击方法
463          例如:Base64(ROT13(原始攻击))
464          """
465          attack_input = simulated_attack.input
466          if attack_input is None:
467              return simulated_attack
468  
469          # 记录所有使用的攻击方法名称
470          attack_methods = []
471          current_input = attack_input
472          
473          logger.debug(f"Starting serial attack enhancement")
474          logger.debug(f"Original input: {attack_input[:100]}...")
475          logger.debug(f"Number of attacks to apply: {len(attacks)}")
476  
477          try:
478              for i, attack in enumerate(attacks):
479                  attack_name = attack.get_name()
480                  attack_methods.append(attack_name)
481                  
482                  logger.debug(f"Step {i+1}/{len(attacks)} - Applying {attack_name}")
483                  logger.debug(f"Input before {attack_name}: {current_input[:100]}...")
484                  
485                  sig = inspect.signature(attack.enhance)
486                  
487                  # 根据攻击方法的参数需求调用
488                  if ("simulator_model" in sig.parameters and "model_callback" in sig.parameters):
489                      logger.debug(f"Calling {attack_name}.enhance with simulator_model and model_callback")
490                      current_input = attack.enhance(
491                          attack=current_input,
492                          simulator_model=self.simulator_model,
493                          model_callback=self.model_callback,
494                      )
495                  elif "simulator_model" in sig.parameters:
496                      logger.debug(f"Calling {attack_name}.enhance with simulator_model")
497                      current_input = attack.enhance(
498                          attack=current_input,
499                          simulator_model=self.simulator_model,
500                      )
501                  elif "model_callback" in sig.parameters:
502                      logger.debug(f"Calling {attack_name}.enhance with model_callback")
503                      current_input = attack.enhance(
504                          attack=current_input,
505                          model_callback=self.model_callback,
506                      )
507                  else:
508                      logger.debug(f"Calling {attack_name}.enhance with attack parameter only")
509                      current_input = attack.enhance(attack=current_input)
510                  
511                  logger.debug(f"Output after {attack_name}: {current_input[:100]}...")
512                  logger.debug(f"Input length changed to {len(current_input)}")
513  
514              # 更新模拟攻击对象
515              simulated_attack.input = current_input
516              simulated_attack.attack_method = " + ".join(attack_methods)  # 记录所有攻击方法
517              
518              logger.debug(f"Final attack method: {simulated_attack.attack_method}")
519              logger.debug(f"Final input: {current_input[:100]}...")
520              logger.debug(f"Serial attack enhancement completed successfully")
521              
522          except Exception as e:
523              logger.debug(f"Error in serial attack enhancement: {str(e)}")
524              if ignore_errors:
525                  simulated_attack.error = f"Error in serial attack enhancement: {str(e)}"
526                  return simulated_attack
527              else:
528                  raise
529  
530          return simulated_attack
531  
532      async def a_enhance_attack(
533          self,
534          attack: BaseAttack,
535          simulated_attack: SimulatedAttack,
536          ignore_errors: bool,
537      ):
538          simulated_attack = copy.deepcopy(simulated_attack)
539          attack_input = simulated_attack.input
540          if attack_input is None:
541              return simulated_attack
542  
543          simulated_attack.attack_method = attack.get_name()
544          sig = inspect.signature(attack.a_enhance)
545  
546          try:
547              if (
548                  "simulator_model" in sig.parameters
549                  and "model_callback" in sig.parameters
550              ):
551                  simulated_attack.input = await attack.a_enhance(
552                      attack=attack_input,
553                      simulator_model=self.simulator_model,
554                      model_callback=self.model_callback,
555                  )
556              elif "simulator_model" in sig.parameters:
557                  simulated_attack.input = await attack.a_enhance(
558                      attack=attack_input,
559                      simulator_model=self.simulator_model,
560                  )
561              elif "model_callback" in sig.parameters:
562                  simulated_attack.input = await attack.a_enhance(
563                      attack=attack_input,
564                      model_callback=self.model_callback,
565                  )
566              else:
567                  simulated_attack.input = await attack.a_enhance(
568                      attack=attack_input
569                  )
570          except:
571              if ignore_errors:
572                  simulated_attack.error = "Error enhancing attack"
573                  return simulated_attack
574              else:
575                  raise
576  
577          return simulated_attack
578  
579      async def a_enhance_attack_serial(
580          self,
581          attacks: List[BaseAttack],
582          simulated_attack: SimulatedAttack,
583          ignore_errors: bool,
584      ):
585          """
586          异步串行嵌套攻击:按顺序应用所有攻击方法
587          例如:Base64(ROT13(原始攻击))
588          """
589          attack_input = simulated_attack.input
590          if attack_input is None:
591              return simulated_attack
592  
593          # 记录所有使用的攻击方法名称
594          attack_methods = []
595          current_input = attack_input
596          
597          logger.debug(f"Starting async serial attack enhancement")
598          logger.debug(f"Original input: {attack_input[:100]}...")
599          logger.debug(f"Number of attacks to apply: {len(attacks)}")
600  
601          try:
602              for i, attack in enumerate(attacks):
603                  attack_name = attack.get_name()
604                  attack_methods.append(attack_name)
605                  
606                  logger.debug(f"Step {i+1}/{len(attacks)} - Applying {attack_name}")
607                  logger.debug(f"Input before {attack_name}: {current_input[:100]}...")
608                  
609                  sig = inspect.signature(attack.enhance)
610                  
611                  # 根据攻击方法的参数需求调用
612                  if ("simulator_model" in sig.parameters and "model_callback" in sig.parameters):
613                      logger.debug(f"Calling {attack_name}.enhance with simulator_model and model_callback")
614                      current_input = attack.enhance(
615                          attack=current_input,
616                          simulator_model=self.simulator_model,
617                          model_callback=self.model_callback,
618                      )
619                  elif "simulator_model" in sig.parameters:
620                      logger.debug(f"Calling {attack_name}.enhance with simulator_model")
621                      current_input = attack.enhance(
622                          attack=current_input,
623                          simulator_model=self.simulator_model,
624                      )
625                  elif "model_callback" in sig.parameters:
626                      logger.debug(f"Calling {attack_name}.enhance with model_callback")
627                      current_input = attack.enhance(
628                          attack=current_input,
629                          model_callback=self.model_callback,
630                      )
631                  else:
632                      logger.debug(f"Calling {attack_name}.enhance with attack parameter only")
633                      current_input = attack.enhance(attack=current_input)
634                  
635                  logger.debug(f"Output after {attack_name}: {current_input[:100]}...")
636                  logger.debug(f"Input length changed from {len(attack_input) if i == 0 else len(await attacks[i-1].enhance(attack_input))} to {len(current_input)}")
637  
638              # 更新模拟攻击对象
639              simulated_attack.input = current_input
640              simulated_attack.attack_method = " + ".join(attack_methods)  # 记录所有攻击方法
641              
642              logger.debug(f"Final attack method: {simulated_attack.attack_method}")
643              logger.debug(f"Final input: {current_input[:100]}...")
644              logger.debug(f"Async serial attack enhancement completed successfully")
645              
646          except Exception as e:
647              logger.debug(f"Error in async serial attack enhancement: {str(e)}")
648              if ignore_errors:
649                  simulated_attack.error = f"Error in serial attack enhancement: {str(e)}"
650                  return simulated_attack
651              else:
652                  raise
653  
654          return simulated_attack
655  
656      def simulate_local_attack(
657          self,
658          purpose: str,
659          vulnerability_type: VulnerabilityType,
660          num_attacks: int,
661          custom_prompt: Optional[str] = None,
662      ) -> List[str]:
663          """Simulate attacks using local LLM model"""
664          # Get the appropriate prompt template from AttackSimulatorTemplate
665          prompt = AttackSimulatorTemplate.generate_attacks(
666              max_goldens=num_attacks,
667              vulnerability_type=vulnerability_type,
668              purpose=purpose,
669              custom_prompt=custom_prompt,
670          )
671          if self.using_native_model:
672              # For models that support schema validation directly
673              res, _ = self.simulator_model.generate(
674                  prompt, schema=SyntheticDataList
675              )
676              return [item.input for item in res.data]
677          else:
678              try:
679                  res: SyntheticDataList = self.simulator_model.generate(
680                      prompt, schema=SyntheticDataList
681                  )
682                  return [item.input for item in res.data]
683              except TypeError:
684                  res = self.simulator_model.generate(prompt)
685                  data = trimAndLoadJson(res)
686                  return [item["input"] for item in data["data"]]
687  
688      async def a_simulate_local_attack(
689          self,
690          purpose: str,
691          vulnerability_type: VulnerabilityType,
692          num_attacks: int,
693          custom_prompt: Optional[str] = None,
694      ) -> List[str]:
695          """Asynchronously simulate attacks using local LLM model"""
696  
697          prompt = AttackSimulatorTemplate.generate_attacks(
698              max_goldens=num_attacks,
699              vulnerability_type=vulnerability_type,
700              purpose=purpose,
701              custom_prompt=custom_prompt,
702          )
703  
704          if self.using_native_model:
705              # For models that support schema validation directly
706              res, _ = await self.simulator_model.a_generate(
707                  prompt, schema=SyntheticDataList
708              )
709              return [item.input for item in res.data]
710          else:
711              try:
712                  res: SyntheticDataList = await self.simulator_model.a_generate(
713                      prompt, schema=SyntheticDataList
714                  )
715                  return [item.input for item in res.data]
716              except TypeError:
717                  res = await self.simulator_model.a_generate(prompt)
718                  data = trimAndLoadJson(res)
719                  return [item["input"] for item in data["data"]]