orchestrator.py
1 """ 2 红队编排器:主入口,编排 Attacker / Target / Evaluator 三角色协作的攻击流程。 3 支持 Crescendo 与 TAP 两种策略。 4 """ 5 6 from __future__ import annotations 7 8 import os 9 from typing import List, Optional, Any, Literal 10 11 from openai import AsyncOpenAI 12 13 from redteam.attacker import AttackerAgent 14 from redteam.evaluator import EvaluatorAgent 15 from redteam.target import TargetRunner 16 from redteam.strategy import ( 17 CrescendoStrategy, 18 CrescendoPhase, 19 TAPStrategy, 20 AttackNode, 21 ConversationTurn, 22 ) 23 24 try: 25 from utils.config import DEFAULT_MODEL, DEFAULT_BASE_URL 26 from utils.loging import logger 27 except ImportError: 28 DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "deepseek/deepseek-v3.2-exp") 29 DEFAULT_BASE_URL = os.environ.get("DEFAULT_BASE_URL", "https://openrouter.ai/api/v1") 30 import logging 31 logger = logging.getLogger("redteam") 32 33 34 def _get_api_key() -> str: 35 return os.environ.get("OPENROUTER_API_KEY") or os.environ.get("API_KEY") or "" 36 37 38 class RedTeamOrchestrator: 39 """ 40 红队编排器:创建 Attacker / Evaluator / Target,按策略执行多轮攻击并收集结果。 41 LLM 调用统一走 OpenAI 兼容接口(AsyncOpenAI),与 mcp-scan 的模型配置方式一致。 42 """ 43 44 def __init__( 45 self, 46 api_key: Optional[str] = None, 47 base_url: Optional[str] = None, 48 model: Optional[str] = None, 49 repo_dir: Optional[str] = None, 50 ): 51 self.api_key = api_key or _get_api_key() 52 if not self.api_key: 53 raise ValueError( 54 "Missing API key for AsyncOpenAI client. Please provide 'api_key' explicitly " 55 "or set the 'OPENROUTER_API_KEY' or 'API_KEY' environment variable." 56 ) 57 self.base_url = base_url or DEFAULT_BASE_URL 58 self.model = model or DEFAULT_MODEL 59 self.repo_dir = repo_dir or "" 60 self._client: Optional[AsyncOpenAI] = None 61 self._attacker: Optional[AttackerAgent] = None 62 self._evaluator: Optional[EvaluatorAgent] = None 63 self._target: Optional[TargetRunner] = None 64 65 @property 66 def client(self) -> AsyncOpenAI: 67 if self._client is None: 68 self._client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url, timeout=90) 69 return self._client 70 71 @property 72 def attacker(self) -> AttackerAgent: 73 if self._attacker is None: 74 self._attacker = AttackerAgent(self.client, self.model) 75 return self._attacker 76 77 @property 78 def evaluator(self) -> EvaluatorAgent: 79 if self._evaluator is None: 80 self._evaluator = EvaluatorAgent(self.client, self.model) 81 return self._evaluator 82 83 @property 84 def target(self) -> TargetRunner: 85 if self._target is None: 86 self._target = TargetRunner(self.client, self.model, self.repo_dir) 87 return self._target 88 89 def set_repo(self, repo_dir: str) -> None: 90 self.repo_dir = repo_dir 91 if self._target is not None: 92 self._target.set_repo(repo_dir) 93 94 async def run_crescendo( 95 self, 96 attack_target: str, 97 max_rounds_per_phase: int = 3, 98 max_total_rounds: int = 20, 99 ) -> dict: 100 """ 101 执行 Crescendo 策略:建立信任 → 试探边界 → 逐步升级 → 发起攻击。 102 返回包含 rounds, success, best_score, phases 等的结果字典。 103 """ 104 strategy = CrescendoStrategy( 105 max_rounds_per_phase=max_rounds_per_phase, 106 max_total_rounds=max_total_rounds, 107 ) 108 history: List[ConversationTurn] = [] 109 rounds_log: List[dict] = [] 110 best_score = 0.0 111 success = False 112 for round_index in range(max_total_rounds): 113 phase = strategy.current_phase(round_index) 114 # Attacker 生成攻击 115 attack_out = await self.attacker.generate_attack( 116 attack_target=attack_target, 117 history=history, 118 phase=phase, 119 ) 120 msg = attack_out.get("message", "").strip() 121 if not msg: 122 logger.warning("Attacker produced empty message, stopping.") 123 break 124 # Target 模拟响应 125 recent = [t.to_history_text() for t in history] 126 target_response = await self.target.respond_to_attack(msg, recent) 127 turn = ConversationTurn( 128 attack_message=msg, 129 target_response=target_response, 130 attack_technique=attack_out.get("attack_technique"), 131 thought=attack_out.get("thought"), 132 reflection=attack_out.get("reflection"), 133 ) 134 history.append(turn) 135 # Evaluator 打分 136 eval_result = await self.evaluator.evaluate( 137 attack_target=attack_target, 138 turn=turn, 139 history_before=history[:-1], 140 ) 141 score = float(eval_result.get("score", 5)) 142 is_success = bool(eval_result.get("is_successful", False)) 143 best_score = max(best_score, score) 144 if is_success: 145 success = True 146 rounds_log.append({ 147 "round": round_index + 1, 148 "phase": phase.value, 149 "turn": { 150 "attack_message": msg, 151 "target_response": target_response, 152 "attack_technique": turn.attack_technique, 153 }, 154 "eval": eval_result, 155 "score": score, 156 }) 157 logger.info(f"Crescendo round {round_index + 1} phase={phase.value} score={score} success={is_success}") 158 if not strategy.should_continue(round_index + 1, score, is_success): 159 break 160 return { 161 "strategy": "crescendo", 162 "attack_target": attack_target, 163 "rounds": rounds_log, 164 "success": success, 165 "best_score": best_score, 166 "total_rounds": len(rounds_log), 167 "history": history, 168 } 169 170 async def run_tap( 171 self, 172 attack_target: str, 173 branch_factor: int = 3, 174 top_k: int = 2, 175 max_depth: int = 5, 176 ) -> dict: 177 """ 178 执行 TAP 策略:多分支生成,两阶段剪枝保留 top_k。 179 返回包含 root, success_nodes, all_nodes, best_score 等的结果字典。 180 """ 181 strategy = TAPStrategy(branch_factor=branch_factor, top_k=top_k, max_depth=max_depth) 182 root = AttackNode( 183 node_id="root", 184 turn=ConversationTurn(attack_message="", target_response=""), 185 depth=0, 186 ) 187 all_nodes: List[AttackNode] = [root] 188 node_counter = 0 189 190 def next_id() -> str: 191 nonlocal node_counter 192 node_counter += 1 193 return f"n{node_counter}" 194 195 async def expand_node(node: AttackNode) -> None: 196 if not strategy.should_expand(node): 197 return 198 history = node.conversation_history() 199 # 注意:history[0] 是 root 节点的「空」占位轮次(attack_message / target_response 皆为空字符串), 200 # 仅用于统一对话树结构,不应暴露给 attacker / target / evaluator。 201 # 因此下面统一使用 history[1:] 只传递真实的对话轮次; 202 # 当 node 为 root(depth=0)时,history[1:] 为空列表,表示「当前无历史对话」,这是 *有意为之*。 203 # 生成 branch_factor 个变体 204 candidates: List[AttackNode] = [] 205 for _ in range(strategy.branch_factor): 206 attack_out = await self.attacker.generate_attack( 207 attack_target=attack_target, 208 history=history[1:], # 去掉 root 的空占位轮次;root 节点时传入空历史是预期行为 209 phase=None, 210 ) 211 msg = attack_out.get("message", "").strip() 212 if not msg: 213 continue 214 recent = [t.to_history_text() for t in history[1:]] 215 target_response = await self.target.respond_to_attack(msg, recent) 216 turn = ConversationTurn( 217 attack_message=msg, 218 target_response=target_response, 219 attack_technique=attack_out.get("attack_technique"), 220 thought=attack_out.get("thought"), 221 reflection=attack_out.get("reflection"), 222 ) 223 child = AttackNode(node_id=next_id(), turn=turn, depth=node.depth + 1) 224 eval_result = await self.evaluator.evaluate( 225 attack_target=attack_target, 226 turn=turn, 227 history_before=history[1:], 228 ) 229 child.score = float(eval_result.get("score", 5)) 230 child.on_topic = bool(eval_result.get("on_topic", True)) 231 child.is_successful = bool(eval_result.get("is_successful", False)) 232 child.meta["eval"] = eval_result 233 candidates.append(child) 234 all_nodes.append(child) 235 # 两阶段剪枝 236 kept = strategy.prune(candidates) 237 for c in kept: 238 node.add_child(c) 239 # 递归扩展叶节点 240 for c in kept: 241 await expand_node(c) 242 243 await expand_node(root) 244 245 def collect_reachable_nodes(node: AttackNode) -> List[AttackNode]: 246 """ 247 从 root 出发遍历 TAP 树,收集最终树上(未被剪枝掉)的所有节点。 248 这样 success_nodes / best_score 等统计只基于最终树,而不是所有候选节点。 249 """ 250 reachable: List[AttackNode] = [] 251 stack: List[AttackNode] = [node] 252 while stack: 253 current = stack.pop() 254 reachable.append(current) 255 # 假设 AttackNode.children 存在且在 add_child 时已维护 256 for child in getattr(current, "children", []) or []: 257 stack.append(child) 258 return reachable 259 260 reachable_nodes = collect_reachable_nodes(root) 261 leaves = strategy.leaves(root) 262 success_nodes = [n for n in reachable_nodes if n.is_successful] 263 best_score = max((n.score for n in reachable_nodes if n.depth > 0), default=0.0) 264 return { 265 "strategy": "tap", 266 "attack_target": attack_target, 267 "root": root, 268 "all_nodes": all_nodes, 269 "success_nodes": success_nodes, 270 "leaves": leaves, 271 "best_score": best_score, 272 } 273 274 async def run( 275 self, 276 attack_target: str, 277 strategy_name: Literal["crescendo", "tap"] = "crescendo", 278 **strategy_kwargs: Any, 279 ) -> dict: 280 """ 281 统一入口:按 strategy_name 执行 Crescendo 或 TAP,返回策略相关结果字典。 282 """ 283 if strategy_name == "crescendo": 284 return await self.run_crescendo(attack_target, **strategy_kwargs) 285 if strategy_name == "tap": 286 return await self.run_tap(attack_target, **strategy_kwargs) 287 raise ValueError(f"Unknown strategy: {strategy_name}")