/ mcp-scan / redteam / orchestrator.py
orchestrator.py
  1  """
  2  红队编排器:主入口,编排 Attacker / Target / Evaluator 三角色协作的攻击流程。
  3  支持 Crescendo 与 TAP 两种策略。
  4  """
  5  
  6  from __future__ import annotations
  7  
  8  import os
  9  from typing import List, Optional, Any, Literal
 10  
 11  from openai import AsyncOpenAI
 12  
 13  from redteam.attacker import AttackerAgent
 14  from redteam.evaluator import EvaluatorAgent
 15  from redteam.target import TargetRunner
 16  from redteam.strategy import (
 17      CrescendoStrategy,
 18      CrescendoPhase,
 19      TAPStrategy,
 20      AttackNode,
 21      ConversationTurn,
 22  )
 23  
 24  try:
 25      from utils.config import DEFAULT_MODEL, DEFAULT_BASE_URL
 26      from utils.loging import logger
 27  except ImportError:
 28      DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "deepseek/deepseek-v3.2-exp")
 29      DEFAULT_BASE_URL = os.environ.get("DEFAULT_BASE_URL", "https://openrouter.ai/api/v1")
 30      import logging
 31      logger = logging.getLogger("redteam")
 32  
 33  
 34  def _get_api_key() -> str:
 35      return os.environ.get("OPENROUTER_API_KEY") or os.environ.get("API_KEY") or ""
 36  
 37  
 38  class RedTeamOrchestrator:
 39      """
 40      红队编排器:创建 Attacker / Evaluator / Target,按策略执行多轮攻击并收集结果。
 41      LLM 调用统一走 OpenAI 兼容接口(AsyncOpenAI),与 mcp-scan 的模型配置方式一致。
 42      """
 43  
 44      def __init__(
 45          self,
 46          api_key: Optional[str] = None,
 47          base_url: Optional[str] = None,
 48          model: Optional[str] = None,
 49          repo_dir: Optional[str] = None,
 50      ):
 51          self.api_key = api_key or _get_api_key()
 52          if not self.api_key:
 53              raise ValueError(
 54                  "Missing API key for AsyncOpenAI client. Please provide 'api_key' explicitly "
 55                  "or set the 'OPENROUTER_API_KEY' or 'API_KEY' environment variable."
 56              )
 57          self.base_url = base_url or DEFAULT_BASE_URL
 58          self.model = model or DEFAULT_MODEL
 59          self.repo_dir = repo_dir or ""
 60          self._client: Optional[AsyncOpenAI] = None
 61          self._attacker: Optional[AttackerAgent] = None
 62          self._evaluator: Optional[EvaluatorAgent] = None
 63          self._target: Optional[TargetRunner] = None
 64  
 65      @property
 66      def client(self) -> AsyncOpenAI:
 67          if self._client is None:
 68              self._client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url, timeout=90)
 69          return self._client
 70  
 71      @property
 72      def attacker(self) -> AttackerAgent:
 73          if self._attacker is None:
 74              self._attacker = AttackerAgent(self.client, self.model)
 75          return self._attacker
 76  
 77      @property
 78      def evaluator(self) -> EvaluatorAgent:
 79          if self._evaluator is None:
 80              self._evaluator = EvaluatorAgent(self.client, self.model)
 81          return self._evaluator
 82  
 83      @property
 84      def target(self) -> TargetRunner:
 85          if self._target is None:
 86              self._target = TargetRunner(self.client, self.model, self.repo_dir)
 87          return self._target
 88  
 89      def set_repo(self, repo_dir: str) -> None:
 90          self.repo_dir = repo_dir
 91          if self._target is not None:
 92              self._target.set_repo(repo_dir)
 93  
 94      async def run_crescendo(
 95          self,
 96          attack_target: str,
 97          max_rounds_per_phase: int = 3,
 98          max_total_rounds: int = 20,
 99      ) -> dict:
100          """
101          执行 Crescendo 策略:建立信任 → 试探边界 → 逐步升级 → 发起攻击。
102          返回包含 rounds, success, best_score, phases 等的结果字典。
103          """
104          strategy = CrescendoStrategy(
105              max_rounds_per_phase=max_rounds_per_phase,
106              max_total_rounds=max_total_rounds,
107          )
108          history: List[ConversationTurn] = []
109          rounds_log: List[dict] = []
110          best_score = 0.0
111          success = False
112          for round_index in range(max_total_rounds):
113              phase = strategy.current_phase(round_index)
114              # Attacker 生成攻击
115              attack_out = await self.attacker.generate_attack(
116                  attack_target=attack_target,
117                  history=history,
118                  phase=phase,
119              )
120              msg = attack_out.get("message", "").strip()
121              if not msg:
122                  logger.warning("Attacker produced empty message, stopping.")
123                  break
124              # Target 模拟响应
125              recent = [t.to_history_text() for t in history]
126              target_response = await self.target.respond_to_attack(msg, recent)
127              turn = ConversationTurn(
128                  attack_message=msg,
129                  target_response=target_response,
130                  attack_technique=attack_out.get("attack_technique"),
131                  thought=attack_out.get("thought"),
132                  reflection=attack_out.get("reflection"),
133              )
134              history.append(turn)
135              # Evaluator 打分
136              eval_result = await self.evaluator.evaluate(
137                  attack_target=attack_target,
138                  turn=turn,
139                  history_before=history[:-1],
140              )
141              score = float(eval_result.get("score", 5))
142              is_success = bool(eval_result.get("is_successful", False))
143              best_score = max(best_score, score)
144              if is_success:
145                  success = True
146              rounds_log.append({
147                  "round": round_index + 1,
148                  "phase": phase.value,
149                  "turn": {
150                      "attack_message": msg,
151                      "target_response": target_response,
152                      "attack_technique": turn.attack_technique,
153                  },
154                  "eval": eval_result,
155                  "score": score,
156              })
157              logger.info(f"Crescendo round {round_index + 1} phase={phase.value} score={score} success={is_success}")
158              if not strategy.should_continue(round_index + 1, score, is_success):
159                  break
160          return {
161              "strategy": "crescendo",
162              "attack_target": attack_target,
163              "rounds": rounds_log,
164              "success": success,
165              "best_score": best_score,
166              "total_rounds": len(rounds_log),
167              "history": history,
168          }
169  
170      async def run_tap(
171          self,
172          attack_target: str,
173          branch_factor: int = 3,
174          top_k: int = 2,
175          max_depth: int = 5,
176      ) -> dict:
177          """
178          执行 TAP 策略:多分支生成,两阶段剪枝保留 top_k。
179          返回包含 root, success_nodes, all_nodes, best_score 等的结果字典。
180          """
181          strategy = TAPStrategy(branch_factor=branch_factor, top_k=top_k, max_depth=max_depth)
182          root = AttackNode(
183              node_id="root",
184              turn=ConversationTurn(attack_message="", target_response=""),
185              depth=0,
186          )
187          all_nodes: List[AttackNode] = [root]
188          node_counter = 0
189  
190          def next_id() -> str:
191              nonlocal node_counter
192              node_counter += 1
193              return f"n{node_counter}"
194  
195          async def expand_node(node: AttackNode) -> None:
196              if not strategy.should_expand(node):
197                  return
198              history = node.conversation_history()
199              # 注意:history[0] 是 root 节点的「空」占位轮次(attack_message / target_response 皆为空字符串),
200              # 仅用于统一对话树结构,不应暴露给 attacker / target / evaluator。
201              # 因此下面统一使用 history[1:] 只传递真实的对话轮次;
202              # 当 node 为 root(depth=0)时,history[1:] 为空列表,表示「当前无历史对话」,这是 *有意为之*。
203              # 生成 branch_factor 个变体
204              candidates: List[AttackNode] = []
205              for _ in range(strategy.branch_factor):
206                  attack_out = await self.attacker.generate_attack(
207                      attack_target=attack_target,
208                      history=history[1:],  # 去掉 root 的空占位轮次;root 节点时传入空历史是预期行为
209                      phase=None,
210                  )
211                  msg = attack_out.get("message", "").strip()
212                  if not msg:
213                      continue
214                  recent = [t.to_history_text() for t in history[1:]]
215                  target_response = await self.target.respond_to_attack(msg, recent)
216                  turn = ConversationTurn(
217                      attack_message=msg,
218                      target_response=target_response,
219                      attack_technique=attack_out.get("attack_technique"),
220                      thought=attack_out.get("thought"),
221                      reflection=attack_out.get("reflection"),
222                  )
223                  child = AttackNode(node_id=next_id(), turn=turn, depth=node.depth + 1)
224                  eval_result = await self.evaluator.evaluate(
225                      attack_target=attack_target,
226                      turn=turn,
227                      history_before=history[1:],
228                  )
229                  child.score = float(eval_result.get("score", 5))
230                  child.on_topic = bool(eval_result.get("on_topic", True))
231                  child.is_successful = bool(eval_result.get("is_successful", False))
232                  child.meta["eval"] = eval_result
233                  candidates.append(child)
234                  all_nodes.append(child)
235              # 两阶段剪枝
236              kept = strategy.prune(candidates)
237              for c in kept:
238                  node.add_child(c)
239              # 递归扩展叶节点
240              for c in kept:
241                  await expand_node(c)
242  
243          await expand_node(root)
244  
245          def collect_reachable_nodes(node: AttackNode) -> List[AttackNode]:
246              """
247              从 root 出发遍历 TAP 树,收集最终树上(未被剪枝掉)的所有节点。
248              这样 success_nodes / best_score 等统计只基于最终树,而不是所有候选节点。
249              """
250              reachable: List[AttackNode] = []
251              stack: List[AttackNode] = [node]
252              while stack:
253                  current = stack.pop()
254                  reachable.append(current)
255                  # 假设 AttackNode.children 存在且在 add_child 时已维护
256                  for child in getattr(current, "children", []) or []:
257                      stack.append(child)
258              return reachable
259  
260          reachable_nodes = collect_reachable_nodes(root)
261          leaves = strategy.leaves(root)
262          success_nodes = [n for n in reachable_nodes if n.is_successful]
263          best_score = max((n.score for n in reachable_nodes if n.depth > 0), default=0.0)
264          return {
265              "strategy": "tap",
266              "attack_target": attack_target,
267              "root": root,
268              "all_nodes": all_nodes,
269              "success_nodes": success_nodes,
270              "leaves": leaves,
271              "best_score": best_score,
272          }
273  
274      async def run(
275          self,
276          attack_target: str,
277          strategy_name: Literal["crescendo", "tap"] = "crescendo",
278          **strategy_kwargs: Any,
279      ) -> dict:
280          """
281          统一入口:按 strategy_name 执行 Crescendo 或 TAP,返回策略相关结果字典。
282          """
283          if strategy_name == "crescendo":
284              return await self.run_crescendo(attack_target, **strategy_kwargs)
285          if strategy_name == "tap":
286              return await self.run_tap(attack_target, **strategy_kwargs)
287          raise ValueError(f"Unknown strategy: {strategy_name}")