# target.py
"""
Target Runner (source-analysis mode).

Reads a project's source files and asks an LLM to simulate how that MCP
Server would respond to an attack message. No MCP process is started.
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import List, Optional

try:
    from openai import AsyncOpenAI
except ImportError:
    # The client object is injected by the caller and ``AsyncOpenAI`` is only
    # referenced in a lazily evaluated annotation, so this module (notably
    # ``gather_code_context``) stays importable without the SDK installed.
    AsyncOpenAI = None  # type: ignore[assignment,misc]

# Source analysis limits: readable extensions and per-file character cap
# (aligned with the mcp-scan capability).
READABLE_EXT = {".py", ".go", ".js", ".ts", ".md", ".json", ".yaml", ".yml", ".toml", ".sh", ".rs", ".java"}
MAX_FILE_CHARS = 50000
MAX_FILES = 100
# Upper bound on the combined size of all file bodies placed in the context.
MAX_TOTAL_CHARS = 300000
# Number of most recent history entries forwarded to the model.
RECENT_HISTORY_TURNS = 3


def _read_text_capped(path: Path) -> str:
    """Return up to ``MAX_FILE_CHARS`` characters of *path*, marking truncation.

    Only ``MAX_FILE_CHARS + 1`` characters are requested from the file, so an
    oversized file is never loaded into memory in full; the one extra
    character is what reveals that the file was longer than the cap.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    with path.open("r", encoding="utf-8", errors="replace") as handle:
        text = handle.read(MAX_FILE_CHARS + 1)
    if len(text) > MAX_FILE_CHARS:
        text = text[:MAX_FILE_CHARS] + "\n... (truncated)"
    return text


def gather_code_context(repo_dir: str) -> str:
    """Collect source/documentation text from a repository directory.

    Walks *repo_dir* recursively in sorted path order and concatenates the
    contents of regular files whose suffix is in ``READABLE_EXT``. Each file
    is emitted as a ``--- FILE: <relative path> ---`` header, its (possibly
    truncated) text, and a blank line.

    Limits applied:
      * at most ``MAX_FILES`` files are included;
      * each file contributes at most ``MAX_FILE_CHARS`` characters plus a
        truncation marker;
      * collection stops at the first file that would push the running total
        of file text above ``MAX_TOTAL_CHARS``.

    Symbolic links, and anything whose resolved path falls outside the
    repository root, are skipped.

    Args:
        repo_dir: Path to the repository root.

    Returns:
        The assembled context string; ``"Invalid repo path: ..."`` when
        *repo_dir* is not a directory; or ``"No readable files found in
        repo."`` when nothing was collected. If the walk itself fails, an
        ``"Error reading repo: ..."`` line is appended to whatever was
        gathered so far.
    """
    repo = Path(repo_dir).resolve()
    if not repo.is_dir():
        return f"Invalid repo path: {repo_dir}"

    lines: List[str] = []
    total_chars = 0
    file_count = 0
    try:
        for path in sorted(repo.rglob("*")):
            if file_count >= MAX_FILES:
                break
            # Never follow symlinks: they may point outside the repo or at
            # special targets.
            if path.is_symlink():
                continue
            try:
                resolved = path.resolve()
            except (OSError, RuntimeError):
                # RuntimeError covers symlink loops on older Python versions.
                continue
            try:
                resolved.relative_to(repo)
            except ValueError:
                # Resolved path escapes the repository boundary.
                continue
            if not resolved.is_file():
                continue
            if path.suffix.lower() not in READABLE_EXT:
                continue
            try:
                text = _read_text_capped(resolved)
            except (OSError, UnicodeError):
                continue

            prospective_total = total_chars + len(text)
            if prospective_total > MAX_TOTAL_CHARS:
                break

            lines.append(f"--- FILE: {path.relative_to(repo)} ---")
            lines.append(text)
            lines.append("")
            total_chars = prospective_total
            file_count += 1
    except Exception as e:
        # Deliberately best-effort: report the failure inside the context
        # instead of aborting the caller.
        lines.append(f"Error reading repo: {e}")

    if not lines:
        return "No readable files found in repo."
    return "\n".join(lines)


TARGET_SYSTEM = """你正在模拟一个 MCP (Model Context Protocol) Server 的对外行为。
你已获得该 MCP 项目的一部分源码/文档作为上下文。当收到来自“用户/客户端”的消息时,你需要以该 MCP Server 的身份回复:
- 回复应贴合项目能力(例如暴露的 tools、resources 等)
- 若消息明显为恶意或越权请求,可以拒绝、简短报错或给出安全响应,但不要脱离“MCP Server”人设
- 回复尽量简短(几行到一段),不要暴露内部实现细节
仅输出模拟的 MCP Server 回复内容,不要输出解释或 meta 说明。"""


class TargetRunner:
    """Target runner operating in source-analysis mode.

    The code context for ``repo_dir`` is gathered lazily on first use and
    cached until ``set_repo`` is called. ``respond_to_attack`` sends that
    context together with the attack message to the chat-completions client
    and returns the model's reply as the simulated MCP Server response.
    """

    def __init__(
        self,
        client: AsyncOpenAI,
        model: str,
        repo_dir: Optional[str] = None,
        temperature: float = 0.3,
    ):
        """Store the client, model name, repository path and temperature.

        Args:
            client: Object exposing ``chat.completions.create`` as an
                awaitable (an ``AsyncOpenAI`` instance per the annotation).
            model: Model name passed through to the completion call.
            repo_dir: Repository to analyse; ``None`` is stored as ``""``.
            temperature: Sampling temperature passed to the completion call.
        """
        self.client = client
        self.model = model
        self.repo_dir = repo_dir or ""
        self.temperature = temperature
        # Cached output of gather_code_context; None means "not built yet".
        self._context: Optional[str] = None

    def set_repo(self, repo_dir: str) -> None:
        """Point the runner at another repository and drop the cached context."""
        self.repo_dir = repo_dir
        self._context = None

    def _get_context(self) -> str:
        """Return the cached code context, building it on first access."""
        if self._context is None:
            if self.repo_dir and os.path.isdir(self.repo_dir):
                self._context = gather_code_context(self.repo_dir)
            else:
                self._context = "No repository path provided or path is not a directory."
        return self._context

    def _build_messages(self, attack_message: str, recent_history: Optional[List[str]] = None) -> List[dict]:
        """Assemble the system/user message pair for the completion call.

        The user message contains the code context, optionally the last
        ``RECENT_HISTORY_TURNS`` entries of *recent_history*, the attack
        message, and a closing instruction.
        """
        context = self._get_context()
        user_parts = ["当前项目(MCP Server)源码/文档上下文:", "", context, "", "---", ""]
        if recent_history:
            user_parts.append("最近几轮交互(仅作参考):")
            user_parts.extend(recent_history[-RECENT_HISTORY_TURNS:])
            user_parts.append("")
        user_parts.append("本轮收到的用户/客户端消息:")
        user_parts.append(attack_message)
        user_parts.append("")
        user_parts.append("请以 MCP Server 身份回复上述消息(仅输出回复内容)。")
        return [
            {"role": "system", "content": TARGET_SYSTEM},
            {"role": "user", "content": "\n".join(user_parts)},
        ]

    async def respond_to_attack(
        self,
        attack_message: str,
        recent_history: Optional[List[str]] = None,
    ) -> str:
        """Simulate the MCP Server's reply to *attack_message*.

        Args:
            attack_message: The message the simulated server receives.
            recent_history: Optional earlier exchanges; only the most recent
                ``RECENT_HISTORY_TURNS`` entries are forwarded.

        Returns:
            The stripped text of the first completion choice, or ``""`` when
            the response carries no choices or the message has no content.
        """
        messages = self._build_messages(attack_message, recent_history)
        response = await self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
        )
        # A response may come back without any choices; treat that as an
        # empty reply rather than failing with IndexError.
        if not response.choices:
            return ""
        return (response.choices[0].message.content or "").strip()