agent.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

"""Multi-stage LLM scan pipelines for MCP projects (static and dynamic)."""

import time
from typing import Any, Optional

from agent.base_agent import BaseAgent
from tools.dispatcher import ToolDispatcher
from utils.aig_logger import mcpLogger
from utils.extract_vuln import VulnerabilityExtractor
from utils.loging import logger
from utils.project_analyzer import analyze_language, calc_mcp_score, get_top_language
from utils.prompt_manager import prompt_manager

# ---------------------------------------------------------------------------
# Output-format instructions handed to the stage agents as ``output_format``.
# They are plain prompt text (not f-strings). Keeping them at module level lets
# the static and dynamic scans share one copy of the vulnerability-review
# contract instead of two hand-synchronised literals.
# ---------------------------------------------------------------------------

_STATIC_INFO_FORMAT = "生成一份详细的信息收集报告,使用Markdown格式。报告需基于输入数据如实总结,确保读者(对项目一无所知)能快速理解项目全貌。"

_CODE_AUDIT_FORMAT = """
markdown格式返回
对于每个确认的漏洞,必须提供:
- 具体位置:文件路径和行号范围
- 完整代码片段:显示漏洞的代码段
- 技术分析:漏洞原理和利用方法
- 影响评估:可获得的权限和影响范围
- 修复建议:详细的安全加固方案
- 攻击路径:具体的利用步骤(如适用)
严格标准:必须提供完整的漏洞利用路径和影响分析。
"""

# Final-stage contract: the agent's answer is checked by is_vuln_review_output
# and then handed to VulnerabilityExtractor.extract_vulnerabilities.
_VULN_REVIEW_FORMAT = """
必须满足以下xml格式,多个漏洞返回多个vuln标签
<vuln>
<title>title</title>
<desc>
<!-- Markdown格式漏洞描述 -->
## 漏洞详情
**文件位置**:
**漏洞类型**:
**风险等级**:

### 技术分析

### 攻击路径

### 影响评估
</desc>
<risk_type>RiskType</risk_type>
<level>Level</level>
<suggestion>
## 修复建议
</suggestion>
</vuln>
若无漏洞或漏洞为空,返回<empty>
""".strip()

_DYNAMIC_INFO_FORMAT = "生成一份详细的MCP(model context protocol)信息收集报告,使用Markdown格式。报告需基于输入数据如实总结,确保读者(对项目一无所知)能快速理解项目全貌。"

# The doubled braces are literal text in this (non f-string) literal.
_DYNAMIC_TESTING_FORMAT = """
## Output format
- The output should be in Markdown format. Please Never use any other format, and make sure the output has no format issue.
- The Markdown document should have the following Chapter:
- "Overview": `YES` or `NO`, representing whether there are any risks analyzed.
- "Threats": A list of xml strings, each representing a threat analyzed. Including threat types, confidence scores, and potential impacts.
- "Reasons": A list of normal strings, each representing the reason why the corresponding threat is analyzed.
- "Summarization": A paragraph summarizing the overall security assessment results.
- example:
```
# Overview
- YES
# Threats
- <threat><tool_name>{{ tool_name }}</tool_name><type>SQL Injection</type><confidence>0.9</confidence><impact>High</impact></threat>
# Reasons
- SQL Injection: The tool named {{ tool_name }} detected a potential SQL Injection vulnerability in the input parameter.
# Summarization:
...... (The clear, detailed summary of the security assessment results)
```
"""


def is_vuln_review_output(content: Optional[str]) -> bool:
    """Return True if ``content`` looks like a finished vulnerability review.

    A valid review contains at least one ``<vuln>`` tag or the ``<empty>``
    marker required by ``_VULN_REVIEW_FORMAT``. Falsy input (``None`` or
    ``""``) is rejected instead of raising ``TypeError``.
    """
    if not content:
        return False
    return "<vuln>" in content or "<empty>" in content


class ScanStage:
    """Describe one stage of a scan pipeline."""

    def __init__(
        self,
        stage_id: str,
        name: str,
        template: str,
        output_format: Optional[str] = None,
        output_check_fn=None,
        language="zh",
    ):
        """
        Args:
            stage_id: Step id used for logging and ``mcpLogger`` plan steps.
            name: Human-readable stage name; also the key in
                ``ScanPipeline.results``.
            template: Prompt template name passed to
                ``prompt_manager.load_template``.
            output_format: Output-format instructions forwarded to BaseAgent.
            output_check_fn: Optional predicate forwarded to BaseAgent as
                ``output_check_fn`` (e.g. ``is_vuln_review_output``).
            language: Language code forwarded to BaseAgent.
        """
        self.stage_id = stage_id
        self.name = name
        self.template = template
        self.output_format = output_format
        self.output_check_fn = output_check_fn
        self.language = language


class ScanPipeline:
    """Run individual scan stages, each with a fresh BaseAgent."""

    def __init__(self, agent_wrapper: "Agent"):
        # The owning Agent supplies the shared LLMs, dispatcher and debug flag.
        self.agent_wrapper = agent_wrapper
        # Stage name -> final output of that stage (later runs overwrite).
        self.results = {}

    @staticmethod
    def _build_user_message(
        user_msg: str, context_data: Optional[dict[str, Any]]
    ) -> str:
        """Append every ``key:value`` pair of ``context_data`` to ``user_msg``."""
        if not context_data:
            return user_msg
        parts = [user_msg, "\n\n有以下背景信息:\n"]
        parts.extend(f"{key}:{value}\n\n" for key, value in context_data.items())
        return "".join(parts)

    async def _run_stage(
        self,
        stage: ScanStage,
        user_msg: str,
        context_data: Optional[dict[str, Any]] = None,
        repo_dir: Optional[str] = None,
    ) -> str:
        """Create a BaseAgent for ``stage``, send it one message and run it.

        Args:
            stage: Stage description.
            user_msg: Leading part of the user message.
            context_data: Optional background reports appended to the message.
            repo_dir: When not None, passed to ``agent.set_repo_dir`` before
                initialisation (static scans only).

        Returns:
            The agent's final output, also stored in ``self.results``.
        """
        logger.info(f"=== 阶段 {stage.stage_id}: {stage.name} ===")
        mcpLogger.new_plan_step(stepId=stage.stage_id, stepName=stage.name)

        instruction = prompt_manager.load_template(stage.template)

        agent = BaseAgent(
            name=f"{stage.name} Agent",
            instruction=instruction,
            llm=self.agent_wrapper.llm,
            dispatcher=self.agent_wrapper.dispatcher,
            specialized_llms=self.agent_wrapper.specialized_llms,
            log_step_id=stage.stage_id,
            debug=self.agent_wrapper.debug,
            output_format=stage.output_format,
            output_check_fn=stage.output_check_fn,
            # Previously dropped by the dynamic path, so the configured
            # language was ignored there.
            language=stage.language,
        )
        if repo_dir is not None:
            agent.set_repo_dir(repo_dir)
        await agent.initialize()

        agent.add_user_message(self._build_user_message(user_msg, context_data))

        result = await agent.run()
        self.results[stage.name] = result
        return result

    async def execute_stage(
        self,
        stage: ScanStage,
        repo_dir: str,
        prompt: str,
        context_data: Optional[dict[str, Any]] = None,
    ) -> str:
        """Run a static-scan stage against the project in ``repo_dir``.

        Args:
            stage: Stage description.
            repo_dir: Project directory handed to the agent.
            prompt: Extra user instructions appended to the stage message.
            context_data: Optional reports from earlier stages.

        Returns:
            The stage agent's final output.
        """
        return await self._run_stage(
            stage,
            f"请进行{stage.name},文件夹在 {repo_dir}\n{prompt}",
            context_data,
            repo_dir=repo_dir,
        )

    async def execute_stage_dynamic(
        self,
        stage: ScanStage,
        prompt: str,
        context_data: Optional[dict[str, Any]] = None,
    ) -> str:
        """Run a dynamic-scan stage (no repository directory is set).

        Args:
            stage: Stage description.
            prompt: Extra user instructions appended to the stage message.
            context_data: Optional reports from earlier stages.

        Returns:
            The stage agent's final output.
        """
        return await self._run_stage(
            stage,
            f"请进行{stage.name},进行MCP动态扫描\n{prompt}",
            context_data,
        )


class Agent:
    """Entry point that runs the static or dynamic MCP scan pipeline."""

    def __init__(
        self,
        llm,
        specialized_llms: Optional[dict] = None,
        debug: bool = False,
        server_url: Optional[str] = None,
        language="zh",
        headers=None,
    ):
        """
        Args:
            llm: Default LLM client for every stage; its ``model`` attribute
                is recorded in the scan result.
            specialized_llms: Optional extra LLMs forwarded to each BaseAgent.
            debug: Forwarded to each BaseAgent.
            server_url: MCP server URL handed to the ToolDispatcher.
            language: Language code given to every ScanStage.
            headers: MCP request headers handed to the ToolDispatcher.
        """
        self.llm = llm
        self.specialized_llms = specialized_llms or {}
        self.debug = debug
        self.dispatcher = ToolDispatcher(mcp_server_url=server_url, mcp_headers=headers)
        self.pipeline = ScanPipeline(self)
        self.language = language

    async def scan(self, repo_dir: str, prompt: str):
        """Statically scan the project source in ``repo_dir``.

        Runs info collection, code audit and vulnerability review in order,
        feeding each stage the previous stage's report. It then extracts
        structured findings and computes a safety score.

        Args:
            repo_dir: Path of the project to scan.
            prompt: Extra user instructions appended to every stage message.

        Returns:
            Result dict (also pushed via ``mcpLogger.result_update``) with
            keys readme, score, language, start_time, end_time, results, llm.
        """
        result_meta = {
            "readme": "",
            "score": 0,
            "language": "",
            "start_time": time.time(),
            "end_time": 0,
            "results": [],
            "llm": self.llm.model,
        }

        # 1. Information collection.
        info_collection = await self.pipeline.execute_stage(
            ScanStage(
                "1",
                "Info Collection",
                "agents/project_summary",
                output_format=_STATIC_INFO_FORMAT,
                language=self.language,
            ),
            repo_dir,
            prompt,
        )

        # 2. Code audit, informed by the info-collection report.
        code_audit = await self.pipeline.execute_stage(
            ScanStage(
                "2",
                "Code Audit",
                "agents/code_audit",
                output_format=_CODE_AUDIT_FORMAT,
                language=self.language,
            ),
            repo_dir,
            prompt,
            {"信息收集报告": info_collection},
        )

        # 3. Vulnerability review: normalise the audit into <vuln> XML.
        vuln_review = await self.pipeline.execute_stage(
            ScanStage(
                "3",
                "Vulnerability Review",
                "agents/vuln_review",
                output_format=_VULN_REVIEW_FORMAT,
                output_check_fn=is_vuln_review_output,
                language=self.language,
            ),
            repo_dir,
            prompt,
            {"代码审计报告": code_audit},
        )

        # Extract findings and derive the summary metrics.
        vuln_results = VulnerabilityExtractor().extract_vulnerabilities(vuln_review)

        elapsed_minutes = (time.time() - result_meta["start_time"]) / 60
        logger.info(f"扫描任务完成,总耗时 {elapsed_minutes:.2f} 分钟")
        top_language = get_top_language(analyze_language(repo_dir))
        safety_score = calc_mcp_score(vuln_results)

        result_meta.update(
            {
                "readme": info_collection,
                "score": safety_score,
                "language": top_language,
                "end_time": time.time(),
                "results": vuln_results,
            }
        )
        mcpLogger.result_update(result_meta)
        return result_meta

    async def dynamic_analysis(self, prompt: str):
        """Dynamically scan an MCP server (no source directory involved).

        The stage agents share ``self.dispatcher``, which was built from
        ``server_url``. Presumably the agents reach the live server through it;
        confirm in ToolDispatcher.

        Stages: info collection, malicious-behaviour testing, vulnerability
        testing, and a final review that turns both testing reports into
        ``<vuln>`` XML.

        Args:
            prompt: Extra user instructions appended to every stage message.

        Returns:
            Result dict (also pushed via ``mcpLogger.result_update``) with
            keys readme, score, language (always ""), start_time, end_time,
            results, llm.
        """
        result_meta = {
            "readme": "",
            "score": 0,
            "language": "",
            "start_time": time.time(),
            "end_time": 0,
            "results": [],
            # Recorded for parity with scan().
            "llm": self.llm.model,
        }

        # 1. Information collection.
        info_collection = await self.pipeline.execute_stage_dynamic(
            ScanStage(
                "1",
                "Info Collection",
                "agents/dynamic/project_summary",
                output_format=_DYNAMIC_INFO_FORMAT,
                language=self.language,
            ),
            prompt=prompt,
        )

        # 2./3. Vulnerability probing.
        # NOTE(review): these two template paths carry a ".md" suffix while
        # every other template name here does not. Confirm that
        # prompt_manager.load_template accepts both forms.
        malicious_report = await self.pipeline.execute_stage_dynamic(
            ScanStage(
                "2",
                "Malicious Testing",
                "agents/dynamic/malicious_behaviour_testing.md",
                output_format=_DYNAMIC_TESTING_FORMAT,
                language=self.language,
            ),
            prompt,
            {"信息收集报告": info_collection},
        )
        vulnerability_report = await self.pipeline.execute_stage_dynamic(
            ScanStage(
                "3",
                "Vulnerability Testing",
                "agents/dynamic/vulnerability_testing.md",
                output_format=_DYNAMIC_TESTING_FORMAT,
                language=self.language,
            ),
            prompt,
            {"信息收集报告": info_collection, "malicious testing": malicious_report},
        )

        # 4. Vulnerability review over both testing reports.
        vuln_review = await self.pipeline.execute_stage_dynamic(
            ScanStage(
                "4",
                "Vulnerability Review",
                "agents/dynamic/general_analyzing_prompt_template",
                output_format=_VULN_REVIEW_FORMAT,
                output_check_fn=is_vuln_review_output,
                language=self.language,
            ),
            prompt,
            {
                "malicious testing": malicious_report,
                "vulnerability testing": vulnerability_report,
            },
        )

        # Extract findings and score them.
        vuln_results = VulnerabilityExtractor().extract_vulnerabilities(vuln_review)
        safety_score = calc_mcp_score(vuln_results)

        result_meta.update(
            {
                "readme": info_collection,
                "score": safety_score,
                "end_time": time.time(),
                "results": vuln_results,
            }
        )
        mcpLogger.result_update(result_meta)
        return result_meta