# attack_simulator.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

import copy
import random
import asyncio
from tqdm import tqdm
from pydantic import BaseModel
from typing import List, Optional, Union
import inspect
from cli.aig_logger import logger
from cli.aig_logger import (
    newPlanStep, statusUpdate, toolUsed, actionLog, resultUpdate
)
import uuid

from deepeval.models import DeepEvalBaseLLM
from deepeval.metrics.utils import initialize_model, trimAndLoadJson

from deepteam.attacks import BaseAttack
from deepteam.vulnerabilities import BaseVulnerability, CustomPrompt, MultiDatasetVulnerability
from deepteam.vulnerabilities.types import VulnerabilityType
from deepteam.attacks.multi_turn.types import CallbackType
from deepteam.attacks.attack_simulator.template import AttackSimulatorTemplate
from deepteam.attacks.attack_simulator.schema import SyntheticDataList


class SimulatedAttack(BaseModel):
    """One simulated adversarial attack against a single vulnerability type.

    Attributes:
        vulnerability: Display name of the vulnerability being probed.
        vulnerability_type: The concrete vulnerability type enum value.
        original_input: The baseline (unenhanced) attack text.
        input: The current attack text (possibly enhanced in place).
        attack_method: Name(s) of the enhancement method(s) applied.
        error: Populated instead of `input` when simulation/enhancement failed.
        useless: True when enhancement left the input unchanged.
    """

    vulnerability: str
    vulnerability_type: VulnerabilityType
    original_input: Optional[str] = None
    input: Optional[str] = None
    attack_method: Optional[str] = None
    error: Optional[str] = None
    useless: bool = False


class AttackSimulator:
    """Simulates baseline attacks per vulnerability and enhances them
    with one or more attack methods (random sampling, parallel fan-out,
    or serial nesting, e.g. Base64(ROT13(original))).
    """

    model_callback: Union[CallbackType, None] = None
    # Upper bound on concurrently running enhancement coroutines.
    max_concurrent = 10

    def __init__(
        self,
        purpose: str,
        simulator_model: Optional[Union[str, DeepEvalBaseLLM]] = None,
    ):
        # Initialize models and async mode
        self.purpose = purpose
        self.simulator_model, self.using_native_model = initialize_model(
            simulator_model
        )

        # Accumulates every enhanced attack produced by (a_)simulate.
        self.simulated_attacks: List[SimulatedAttack] = []

    ##################################################
    ### Generating Attacks ###########################
    ##################################################

    def simulate(
        self,
        attacks_per_vulnerability_type: int,
        vulnerabilities: List[BaseVulnerability],
        attacks: List[BaseAttack],
        ignore_errors: bool,
        choice: str = "random",  # "random", "serial", or "parallel"
    ) -> List[SimulatedAttack]:
        """Synchronously simulate baseline attacks and enhance them.

        Args:
            attacks_per_vulnerability_type: Attacks generated per vulnerability type.
            vulnerabilities: Vulnerabilities to generate attacks for.
            attacks: Pool of enhancement attack methods.
            ignore_errors: Record errors on the attack instead of raising.
            choice: "serial" nests all methods in order, "parallel" applies
                each method separately, anything else samples one by weight.

        Returns:
            The list of enhanced SimulatedAttack objects (also appended to
            self.simulated_attacks).
        """
        # Simulate unenhanced attacks for each vulnerability
        baseline_attacks: List[SimulatedAttack] = []
        num_vulnerabilities = len(vulnerabilities)
        num_vulnerability_types = sum(
            len(v.get_types()) for v in vulnerabilities
        )
        pbar = tqdm(
            vulnerabilities,
            desc=f"💥 Generating {num_vulnerability_types * attacks_per_vulnerability_type} attacks (for {num_vulnerability_types} vulnerability types across {num_vulnerabilities} vulnerability(s))",
        )
        logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg("Generating attacks"), status="running"))

        tool_id = uuid.uuid4().hex
        logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
            "Simulating {num_vulnerabilities} attacks", num_vulnerabilities=num_vulnerabilities
        ), status="todo"))

        for idx, vulnerability in enumerate(pbar):
            logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
                "Simulating {idx} / {num_vulnerabilities} attacks", idx=idx + 1, num_vulnerabilities=num_vulnerabilities
            ), status="doing"))

            baseline_attacks.extend(
                self.simulate_baseline_attacks(
                    attacks_per_vulnerability_type=attacks_per_vulnerability_type,
                    vulnerability=vulnerability,
                    ignore_errors=ignore_errors,
                )
            )
        logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, tool_name="Simulate baseline attacks", brief=logger.translated_msg(
            "Simulating {length} attacks done", length=len(vulnerabilities)
        ), status="done"))

        logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg("Generating attacks"), status="completed"))
        pbar.close()

        # Enhance attacks by sampling from the provided distribution
        enhanced_attacks: List[SimulatedAttack] = []
        if choice == "serial":
            # One "unpacked attack" that is itself the whole ordered list.
            unpack_attacks = [attacks]
        elif choice == "parallel":
            unpack_attacks = attacks
        else:
            attack_weights = [attack.weight for attack in attacks]
            unpack_attacks = random.choices(attacks, weights=attack_weights, k=1)
        num_baseline_attacks = len(baseline_attacks) * len(unpack_attacks)
        pbar = tqdm(
            total=num_baseline_attacks,
            desc=f"✨ Simulating {num_vulnerability_types * attacks_per_vulnerability_type} attacks (using {len(attacks)} method(s))",
        )

        logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg(
            "Enhance {num_baseline_attacks} attacks", num_baseline_attacks=num_baseline_attacks
        ), status="running"))

        tool_id = uuid.uuid4().hex
        logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, tool_name="Enhance attacks", brief=logger.translated_msg(
            "Enhance {num_baseline_attacks} attacks", num_baseline_attacks=num_baseline_attacks
        ), status="todo"))

        for index, (baseline_attack, unpack_attack) in enumerate(
            (baseline_attack, unpack_attack)
            for baseline_attack in baseline_attacks
            for unpack_attack in unpack_attacks
        ):
            logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
                "Simulating {idx} / {num_baseline_attacks} attacks", idx=index + 1, num_baseline_attacks=num_baseline_attacks
            ), status="doing"))
            if choice == "serial":
                # Serial nested attack: apply all attack methods in order.
                enhanced_attack = self.enhance_attack_serial(
                    attacks=unpack_attack,
                    simulated_attack=baseline_attack,
                    ignore_errors=ignore_errors,
                )
            else:
                enhanced_attack = self.enhance_attack(
                    attack=unpack_attack,
                    simulated_attack=baseline_attack,
                    ignore_errors=ignore_errors,
                )

            # Enhancement produced no change from the baseline input.
            if baseline_attack.input == enhanced_attack.input:
                enhanced_attack.useless = True
            enhanced_attacks.append(enhanced_attack)
            pbar.update(1)

        logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, tool_name="Enhance attacks", brief=logger.translated_msg(
            "Enhance {num_baseline_attacks} attacks done", num_baseline_attacks=num_baseline_attacks
        ), status="done"))

        logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg(
            "Enhance {num_baseline_attacks} attacks", num_baseline_attacks=num_baseline_attacks
        ), status="completed"))
        pbar.close()

        self.simulated_attacks.extend(enhanced_attacks)

        return enhanced_attacks

    async def a_simulate(
        self,
        attacks_per_vulnerability_type: int,
        vulnerabilities: List[BaseVulnerability],
        attacks: List[BaseAttack],
        ignore_errors: bool,
        choice: str = "random",  # "random", "serial", or "parallel"
    ) -> List[SimulatedAttack]:
        """Async counterpart of simulate(); throttled by max_concurrent."""
        self.semaphore = asyncio.Semaphore(self.max_concurrent)

        # Simulate unenhanced attacks for each vulnerability
        baseline_attacks: List[SimulatedAttack] = []
        num_vulnerabilities = len(vulnerabilities)
        num_vulnerability_types = sum(
            len(v.get_types()) for v in vulnerabilities
        )
        pbar = tqdm(
            vulnerabilities,
            desc=f"💥 Generating {num_vulnerability_types * attacks_per_vulnerability_type} attacks (for {num_vulnerability_types} vulnerability types across {num_vulnerabilities} vulnerability(s))",
        )
        tool_id = uuid.uuid4().hex
        logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg("Generating attacks"), status="running"))
        logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
            "Simulating {num_vulnerabilities} attacks", num_vulnerabilities=num_vulnerabilities
        ), status="todo"))

        async def throttled_simulate_baseline_attack(vulnerability):
            # Purpose: per-vulnerability baseline simulation task.
            result = await self.a_simulate_baseline_attacks(
                attacks_per_vulnerability_type=attacks_per_vulnerability_type,
                vulnerability=vulnerability,
                ignore_errors=ignore_errors,
            )
            return result

        simulate_tasks = [
            throttled_simulate_baseline_attack(vulnerability) for vulnerability in vulnerabilities
        ]

        for completed, coro in enumerate(asyncio.as_completed(simulate_tasks), 1):
            result = await coro
            baseline_attacks.extend(result)
            logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
                "Simulating {idx} / {num_vulnerabilities} attacks", idx=completed, num_vulnerabilities=num_vulnerabilities
            ), status="doing"))
            pbar.update(1)

        logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, tool_name="Simulate baseline attacks", brief=logger.translated_msg(
            "Simulating {num_vulnerabilities} attacks done", num_vulnerabilities=num_vulnerabilities
        ), status="done"))
        logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg("Generating attacks"), status="completed"))
        pbar.close()

        # Enhance attacks by sampling from the provided distribution
        enhanced_attacks: List[SimulatedAttack] = []
        if choice == "serial":
            unpack_attacks = [attacks]
        elif choice == "parallel":
            unpack_attacks = attacks
        else:
            attack_weights = [attack.weight for attack in attacks]
            unpack_attacks = random.choices(attacks, weights=attack_weights, k=1)
        num_baseline_attacks = len(baseline_attacks) * len(unpack_attacks)
        pbar = tqdm(
            total=num_baseline_attacks,
            desc=f"✨ Simulating {num_vulnerability_types * attacks_per_vulnerability_type} attacks (using {len(attacks)} method(s))",
        )

        async def throttled_attack_method(
            unpack_attack: List[BaseAttack] | BaseAttack,
            baseline_attack: SimulatedAttack,
        ):
            # Purpose: enhance one baseline attack under the semaphore.
            async with self.semaphore:
                if choice == "serial":
                    # Serial nested attack: apply all attack methods in order.
                    enhanced_attack = await self.a_enhance_attack_serial(
                        attacks=unpack_attack,
                        simulated_attack=baseline_attack,
                        ignore_errors=ignore_errors,
                    )
                else:
                    enhanced_attack = await self.a_enhance_attack(
                        attack=unpack_attack,
                        simulated_attack=baseline_attack,
                        ignore_errors=ignore_errors,
                    )

                # Enhancement produced no change from the baseline input.
                if baseline_attack.input == enhanced_attack.input:
                    enhanced_attack.useless = True
                return enhanced_attack

        logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg(
            "Enhance {num_baseline_attacks} attacks", num_baseline_attacks=num_baseline_attacks
        ), status="running"))

        tasks = [
            throttled_attack_method(unpack_attack, baseline_attack) for baseline_attack in baseline_attacks for unpack_attack in unpack_attacks
        ]

        logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, tool_name="Enhance attacks", brief=logger.translated_msg(
            "Enhance {num_baseline_attacks} attacks", num_baseline_attacks=num_baseline_attacks
        ), status="todo"))

        for completed, coro in enumerate(asyncio.as_completed(tasks), 1):
            logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, brief=logger.translated_msg(
                "Enhance {idx} / {num_baseline_attacks} attacks", idx=completed, num_baseline_attacks=num_baseline_attacks
            ), status="doing"))
            result = await coro
            enhanced_attacks.append(result)
            pbar.update(1)

        logger.tool_used(toolUsed(stepId="2", tool_id=tool_id, tool_name="Enhance attacks", brief=logger.translated_msg(
            "Enhance {num_baseline_attacks} attacks done", num_baseline_attacks=num_baseline_attacks
        ), status="done"))

        logger.status_update(statusUpdate(stepId="2", brief=logger.translated_msg("Jailbreaking"), description=logger.translated_msg(
            "Enhance {num_baseline_attacks} attacks", num_baseline_attacks=num_baseline_attacks
        ), status="completed"))
        pbar.close()

        # Store the simulated and enhanced attacks
        self.simulated_attacks.extend(enhanced_attacks)

        return enhanced_attacks

    ##################################################
    ### Simulating Base (Unenhanced) Attacks #########
    ##################################################

    def simulate_baseline_attacks(
        self,
        attacks_per_vulnerability_type: int,
        vulnerability: BaseVulnerability,
        ignore_errors: bool,
    ) -> List[SimulatedAttack]:
        """Generate unenhanced attacks for every type of one vulnerability.

        CustomPrompt/MultiDatasetVulnerability supply their own prompts;
        everything else goes through the simulator model. With ignore_errors,
        failures are recorded as error-bearing SimulatedAttack placeholders.
        """
        baseline_attacks: List[SimulatedAttack] = []

        for vulnerability_type in vulnerability.get_types():
            try:
                if isinstance(vulnerability, CustomPrompt) or isinstance(vulnerability, MultiDatasetVulnerability):
                    local_attacks = vulnerability.custom_prompt
                else:
                    local_attacks = self.simulate_local_attack(
                        self.purpose,
                        vulnerability_type,
                        attacks_per_vulnerability_type,
                        (
                            vulnerability.custom_prompt
                            if hasattr(vulnerability, "custom_prompt")
                            else None
                        ),
                    )
                baseline_attacks.extend(
                    [
                        SimulatedAttack(
                            vulnerability=vulnerability.get_name(),
                            vulnerability_type=vulnerability_type,
                            original_input=local_attack,
                            input=local_attack,
                        )
                        for local_attack in local_attacks
                    ]
                )
            except Exception as e:
                if ignore_errors:
                    for _ in range(attacks_per_vulnerability_type):
                        baseline_attacks.append(
                            SimulatedAttack(
                                vulnerability=vulnerability.get_name(),
                                vulnerability_type=vulnerability_type,
                                error=f"Error simulating adversarial attacks: {str(e)}",
                            )
                        )
                else:
                    raise
        return baseline_attacks

    async def a_simulate_baseline_attacks(
        self,
        attacks_per_vulnerability_type: int,
        vulnerability: BaseVulnerability,
        ignore_errors: bool,
    ) -> List[SimulatedAttack]:
        """Async counterpart of simulate_baseline_attacks()."""
        baseline_attacks: List[SimulatedAttack] = []
        for vulnerability_type in vulnerability.get_types():
            try:
                if isinstance(vulnerability, CustomPrompt) or isinstance(vulnerability, MultiDatasetVulnerability):
                    local_attacks = vulnerability.custom_prompt
                else:
                    local_attacks = await self.a_simulate_local_attack(
                        self.purpose,
                        vulnerability_type,
                        attacks_per_vulnerability_type,
                        (
                            vulnerability.custom_prompt
                            if hasattr(vulnerability, "custom_prompt")
                            else None
                        ),
                    )

                baseline_attacks.extend(
                    [
                        SimulatedAttack(
                            vulnerability=vulnerability.get_name(),
                            vulnerability_type=vulnerability_type,
                            original_input=local_attack,
                            input=local_attack,
                        )
                        for local_attack in local_attacks
                    ]
                )
            except Exception as e:
                if ignore_errors:
                    for _ in range(attacks_per_vulnerability_type):
                        baseline_attacks.append(
                            SimulatedAttack(
                                vulnerability=vulnerability.get_name(),
                                vulnerability_type=vulnerability_type,
                                error=f"Error simulating adversarial attacks: {str(e)}",
                            )
                        )
                else:
                    raise
        return baseline_attacks

    ##################################################
    ### Enhance attacks ##############################
    ##################################################

    def enhance_attack(
        self,
        attack: BaseAttack,
        simulated_attack: SimulatedAttack,
        ignore_errors: bool,
    ) -> SimulatedAttack:
        """Apply one attack method to a copy of the baseline attack.

        Inspects attack.enhance's signature and only passes the keyword
        arguments (simulator_model / model_callback) it actually accepts.
        """
        simulated_attack = copy.deepcopy(simulated_attack)
        attack_input = simulated_attack.input
        if attack_input is None:
            # Baseline simulation failed earlier; nothing to enhance.
            return simulated_attack

        simulated_attack.attack_method = attack.get_name()
        sig = inspect.signature(attack.enhance)
        try:
            if (
                "simulator_model" in sig.parameters
                and "model_callback" in sig.parameters
            ):
                simulated_attack.input = attack.enhance(
                    attack=attack_input,
                    simulator_model=self.simulator_model,
                    model_callback=self.model_callback,
                )
            elif "simulator_model" in sig.parameters:
                simulated_attack.input = attack.enhance(
                    attack=attack_input,
                    simulator_model=self.simulator_model,
                )
            elif "model_callback" in sig.parameters:
                simulated_attack.input = attack.enhance(
                    attack=attack_input,
                    model_callback=self.model_callback,
                )
            else:
                simulated_attack.input = attack.enhance(attack=attack_input)
        except Exception as e:
            if ignore_errors:
                simulated_attack.error = "Error enhancing attack"
                return simulated_attack
            else:
                raise

        return simulated_attack

    def enhance_attack_serial(
        self,
        attacks: List[BaseAttack],
        simulated_attack: SimulatedAttack,
        ignore_errors: bool,
    ) -> SimulatedAttack:
        """Serial nested attack: apply all attack methods in order.

        Example: Base64(ROT13(original attack)). The attack_method field
        records the chain, e.g. "ROT13 + Base64".
        """
        # FIX: deepcopy like enhance_attack() so the caller's baseline is not
        # mutated in place (otherwise the "useless" comparison in simulate()
        # compares the object to itself and always flags serial attacks).
        simulated_attack = copy.deepcopy(simulated_attack)
        attack_input = simulated_attack.input
        if attack_input is None:
            return simulated_attack

        # Names of every attack method applied, in order.
        attack_methods = []
        current_input = attack_input

        logger.debug(f"Starting serial attack enhancement")
        logger.debug(f"Original input: {attack_input[:100]}...")
        logger.debug(f"Number of attacks to apply: {len(attacks)}")

        try:
            for i, attack in enumerate(attacks):
                attack_name = attack.get_name()
                attack_methods.append(attack_name)

                logger.debug(f"Step {i+1}/{len(attacks)} - Applying {attack_name}")
                logger.debug(f"Input before {attack_name}: {current_input[:100]}...")

                sig = inspect.signature(attack.enhance)

                # Call with only the keyword args this method accepts.
                if ("simulator_model" in sig.parameters and "model_callback" in sig.parameters):
                    logger.debug(f"Calling {attack_name}.enhance with simulator_model and model_callback")
                    current_input = attack.enhance(
                        attack=current_input,
                        simulator_model=self.simulator_model,
                        model_callback=self.model_callback,
                    )
                elif "simulator_model" in sig.parameters:
                    logger.debug(f"Calling {attack_name}.enhance with simulator_model")
                    current_input = attack.enhance(
                        attack=current_input,
                        simulator_model=self.simulator_model,
                    )
                elif "model_callback" in sig.parameters:
                    logger.debug(f"Calling {attack_name}.enhance with model_callback")
                    current_input = attack.enhance(
                        attack=current_input,
                        model_callback=self.model_callback,
                    )
                else:
                    logger.debug(f"Calling {attack_name}.enhance with attack parameter only")
                    current_input = attack.enhance(attack=current_input)

                logger.debug(f"Output after {attack_name}: {current_input[:100]}...")
                logger.debug(f"Input length changed to {len(current_input)}")

            # Record the final nested result and the full method chain.
            simulated_attack.input = current_input
            simulated_attack.attack_method = " + ".join(attack_methods)

            logger.debug(f"Final attack method: {simulated_attack.attack_method}")
            logger.debug(f"Final input: {current_input[:100]}...")
            logger.debug(f"Serial attack enhancement completed successfully")

        except Exception as e:
            logger.debug(f"Error in serial attack enhancement: {str(e)}")
            if ignore_errors:
                simulated_attack.error = f"Error in serial attack enhancement: {str(e)}"
                return simulated_attack
            else:
                raise

        return simulated_attack

    async def a_enhance_attack(
        self,
        attack: BaseAttack,
        simulated_attack: SimulatedAttack,
        ignore_errors: bool,
    ) -> SimulatedAttack:
        """Async counterpart of enhance_attack(); uses attack.a_enhance."""
        simulated_attack = copy.deepcopy(simulated_attack)
        attack_input = simulated_attack.input
        if attack_input is None:
            return simulated_attack

        simulated_attack.attack_method = attack.get_name()
        sig = inspect.signature(attack.a_enhance)

        try:
            if (
                "simulator_model" in sig.parameters
                and "model_callback" in sig.parameters
            ):
                simulated_attack.input = await attack.a_enhance(
                    attack=attack_input,
                    simulator_model=self.simulator_model,
                    model_callback=self.model_callback,
                )
            elif "simulator_model" in sig.parameters:
                simulated_attack.input = await attack.a_enhance(
                    attack=attack_input,
                    simulator_model=self.simulator_model,
                )
            elif "model_callback" in sig.parameters:
                simulated_attack.input = await attack.a_enhance(
                    attack=attack_input,
                    model_callback=self.model_callback,
                )
            else:
                simulated_attack.input = await attack.a_enhance(
                    attack=attack_input
                )
        # FIX: was a bare `except:` which also swallowed KeyboardInterrupt /
        # SystemExit; narrow to Exception like the sync variant.
        except Exception:
            if ignore_errors:
                simulated_attack.error = "Error enhancing attack"
                return simulated_attack
            else:
                raise

        return simulated_attack

    async def a_enhance_attack_serial(
        self,
        attacks: List[BaseAttack],
        simulated_attack: SimulatedAttack,
        ignore_errors: bool,
    ) -> SimulatedAttack:
        """Async serial nested attack: apply all attack methods in order.

        Example: Base64(ROT13(original attack)).

        NOTE(review): this path invokes the synchronous attack.enhance()
        (matching the original code) — confirm whether a_enhance should be
        preferred for attacks that implement it.
        """
        # FIX: deepcopy so the caller's baseline attack is not mutated in
        # place (the "useless" check in a_simulate relies on the baseline
        # staying unchanged).
        simulated_attack = copy.deepcopy(simulated_attack)
        attack_input = simulated_attack.input
        if attack_input is None:
            return simulated_attack

        # Names of every attack method applied, in order.
        attack_methods = []
        current_input = attack_input

        logger.debug(f"Starting async serial attack enhancement")
        logger.debug(f"Original input: {attack_input[:100]}...")
        logger.debug(f"Number of attacks to apply: {len(attacks)}")

        try:
            for i, attack in enumerate(attacks):
                attack_name = attack.get_name()
                attack_methods.append(attack_name)

                logger.debug(f"Step {i+1}/{len(attacks)} - Applying {attack_name}")
                logger.debug(f"Input before {attack_name}: {current_input[:100]}...")

                sig = inspect.signature(attack.enhance)

                # Call with only the keyword args this method accepts.
                if ("simulator_model" in sig.parameters and "model_callback" in sig.parameters):
                    logger.debug(f"Calling {attack_name}.enhance with simulator_model and model_callback")
                    current_input = attack.enhance(
                        attack=current_input,
                        simulator_model=self.simulator_model,
                        model_callback=self.model_callback,
                    )
                elif "simulator_model" in sig.parameters:
                    logger.debug(f"Calling {attack_name}.enhance with simulator_model")
                    current_input = attack.enhance(
                        attack=current_input,
                        simulator_model=self.simulator_model,
                    )
                elif "model_callback" in sig.parameters:
                    logger.debug(f"Calling {attack_name}.enhance with model_callback")
                    current_input = attack.enhance(
                        attack=current_input,
                        model_callback=self.model_callback,
                    )
                else:
                    logger.debug(f"Calling {attack_name}.enhance with attack parameter only")
                    current_input = attack.enhance(attack=current_input)

                logger.debug(f"Output after {attack_name}: {current_input[:100]}...")
                # FIX: the original re-ran the previous attack and awaited its
                # (non-awaitable) result just to log a length — a guaranteed
                # TypeError for chains of >= 2 attacks. Log the new length only.
                logger.debug(f"Input length changed to {len(current_input)}")

            # Record the final nested result and the full method chain.
            simulated_attack.input = current_input
            simulated_attack.attack_method = " + ".join(attack_methods)

            logger.debug(f"Final attack method: {simulated_attack.attack_method}")
            logger.debug(f"Final input: {current_input[:100]}...")
            logger.debug(f"Async serial attack enhancement completed successfully")

        except Exception as e:
            logger.debug(f"Error in async serial attack enhancement: {str(e)}")
            if ignore_errors:
                simulated_attack.error = f"Error in serial attack enhancement: {str(e)}"
                return simulated_attack
            else:
                raise

        return simulated_attack

    def simulate_local_attack(
        self,
        purpose: str,
        vulnerability_type: VulnerabilityType,
        num_attacks: int,
        custom_prompt: Optional[str] = None,
    ) -> List[str]:
        """Simulate attacks using local LLM model"""
        # Get the appropriate prompt template from AttackSimulatorTemplate
        prompt = AttackSimulatorTemplate.generate_attacks(
            max_goldens=num_attacks,
            vulnerability_type=vulnerability_type,
            purpose=purpose,
            custom_prompt=custom_prompt,
        )
        if self.using_native_model:
            # For models that support schema validation directly
            res, _ = self.simulator_model.generate(
                prompt, schema=SyntheticDataList
            )
            return [item.input for item in res.data]
        else:
            try:
                res: SyntheticDataList = self.simulator_model.generate(
                    prompt, schema=SyntheticDataList
                )
                return [item.input for item in res.data]
            except TypeError:
                # Model does not accept a schema kwarg; parse raw JSON output.
                res = self.simulator_model.generate(prompt)
                data = trimAndLoadJson(res)
                return [item["input"] for item in data["data"]]

    async def a_simulate_local_attack(
        self,
        purpose: str,
        vulnerability_type: VulnerabilityType,
        num_attacks: int,
        custom_prompt: Optional[str] = None,
    ) -> List[str]:
        """Asynchronously simulate attacks using local LLM model"""

        prompt = AttackSimulatorTemplate.generate_attacks(
            max_goldens=num_attacks,
            vulnerability_type=vulnerability_type,
            purpose=purpose,
            custom_prompt=custom_prompt,
        )

        if self.using_native_model:
            # For models that support schema validation directly
            res, _ = await self.simulator_model.a_generate(
                prompt, schema=SyntheticDataList
            )
            return [item.input for item in res.data]
        else:
            try:
                res: SyntheticDataList = await self.simulator_model.a_generate(
                    prompt, schema=SyntheticDataList
                )
                return [item.input for item in res.data]
            except TypeError:
                # Model does not accept a schema kwarg; parse raw JSON output.
                res = await self.simulator_model.a_generate(prompt)
                data = trimAndLoadJson(res)
                return [item["input"] for item in data["data"]]