/ mcp-scan / agent / agent.py
agent.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
 19  import time
 20  from typing import Any
 21  
 22  from agent.base_agent import BaseAgent
 23  from tools.dispatcher import ToolDispatcher
 24  from utils.aig_logger import mcpLogger
 25  from utils.extract_vuln import VulnerabilityExtractor
 26  from utils.loging import logger
 27  from utils.project_analyzer import analyze_language, calc_mcp_score, get_top_language
 28  from utils.prompt_manager import prompt_manager
 29  
 30  
 31  def is_vuln_review_output(content: str) -> bool:
 32      return "<vuln>" in content or "<empty>" in content
 33  
 34  
 35  class ScanStage:
 36      """定义扫描的一个阶段"""
 37  
 38      def __init__(
 39          self,
 40          stage_id: str,
 41          name: str,
 42          template: str,
 43          output_format: str = None,
 44          output_check_fn=None,
 45          language="zh",
 46      ):
 47          self.stage_id = stage_id
 48          self.name = name
 49          self.template = template
 50          self.output_format = output_format
 51          self.output_check_fn = output_check_fn
 52          self.language = language
 53  
 54  
 55  class ScanPipeline:
 56      """标准扫描流水线逻辑"""
 57  
 58      def __init__(self, agent_wrapper: "Agent"):
 59          self.agent_wrapper = agent_wrapper
 60          self.results = {}
 61  
 62      async def execute_stage(
 63          self, stage: ScanStage, repo_dir: str, prompt: str, context_data: dict[str, Any] = None
 64      ) -> str:
 65          logger.info(f"=== 阶段 {stage.stage_id}: {stage.name} ===")
 66          mcpLogger.new_plan_step(stepId=stage.stage_id, stepName=stage.name)
 67  
 68          # 加载提示词模板
 69          instruction = prompt_manager.load_template(stage.template)
 70  
 71          # 初始化阶段 Agent
 72          agent = BaseAgent(
 73              name=f"{stage.name} Agent",
 74              instruction=instruction,
 75              llm=self.agent_wrapper.llm,
 76              dispatcher=self.agent_wrapper.dispatcher,
 77              specialized_llms=self.agent_wrapper.specialized_llms,
 78              log_step_id=stage.stage_id,
 79              debug=self.agent_wrapper.debug,
 80              output_format=stage.output_format,
 81              output_check_fn=stage.output_check_fn,
 82              language=stage.language,
 83          )
 84          agent.set_repo_dir(repo_dir)
 85          await agent.initialize()
 86  
 87          # 构造用户消息
 88          user_msg = f"请进行{stage.name},文件夹在 {repo_dir}\n{prompt}"
 89          if context_data:
 90              user_msg += "\n\n有以下背景信息:\n"
 91              for key, value in context_data.items():
 92                  user_msg += f"{key}:{value}\n\n"
 93  
 94          agent.add_user_message(user_msg)
 95  
 96          # 运行并返回结果
 97          result = await agent.run()
 98          self.results[stage.name] = result
 99          return result
100  
101      async def execute_stage_dynamic(
102          self, stage: ScanStage, prompt: str, context_data: dict[str, Any] = None
103      ) -> str:
104          logger.info(f"=== 阶段 {stage.stage_id}: {stage.name} ===")
105          mcpLogger.new_plan_step(stepId=stage.stage_id, stepName=stage.name)
106  
107          # 加载提示词模板
108          instruction = prompt_manager.load_template(stage.template)
109  
110          # 初始化阶段 Agent
111          agent = BaseAgent(
112              name=f"{stage.name} Agent",
113              instruction=instruction,
114              llm=self.agent_wrapper.llm,
115              dispatcher=self.agent_wrapper.dispatcher,
116              specialized_llms=self.agent_wrapper.specialized_llms,
117              log_step_id=stage.stage_id,
118              debug=self.agent_wrapper.debug,
119              output_format=stage.output_format,
120              output_check_fn=stage.output_check_fn,
121          )
122          await agent.initialize()
123  
124          # 构造用户消息
125          user_msg = f"请进行{stage.name},进行MCP动态扫描\n{prompt}"
126          if context_data:
127              user_msg += "\n\n有以下背景信息:\n"
128              for key, value in context_data.items():
129                  user_msg += f"{key}:{value}\n\n"
130  
131          agent.add_user_message(user_msg)
132  
133          # 运行并返回结果
134          result = await agent.run()
135          self.results[stage.name] = result
136          return result
137  
138  
class Agent:
    """Top-level MCP security scanner that chains LLM-driven scan stages.

    ``scan`` audits a local project directory (static analysis).
    ``dynamic_analysis`` runs the dynamic stages against the MCP server whose
    URL was given as ``server_url`` (presumably reached through the
    ToolDispatcher; confirm). Both return a ``result_meta`` dict that is also
    sent to ``mcpLogger.result_update``.
    """

    # XML answer schema for the "Vulnerability Review" stage, shared by the
    # static and dynamic pipelines so both hand VulnerabilityExtractor the
    # same shape. Kept unindented: the dynamic copy used to carry stray code
    # indentation inside the prompt text.
    _VULN_REVIEW_FORMAT = """
必须满足以下xml格式,多个漏洞返回多个vuln标签
<vuln>
  <title>title</title>
  <desc>
  <!-- Markdown格式漏洞描述 -->
  ## 漏洞详情
  **文件位置**: 
  **漏洞类型**: 
  **风险等级**: 
  
  ### 技术分析
  
  ### 攻击路径
  
  ### 影响评估  
  </desc>
  <risk_type>RiskType</risk_type>
  <level>Level</level>
  <suggestion>
  ## 修复建议
  </suggestion>
</vuln>
若无漏洞或漏洞为空,返回<empty>
""".strip()

    def __init__(
        self,
        llm,
        specialized_llms: dict = None,
        debug: bool = False,
        server_url: str = None,
        language="zh",
        headers=None,
    ):
        # Default LLM handed to every stage agent; its ``.model`` is reported in results.
        self.llm = llm
        # Extra LLMs forwarded verbatim to each BaseAgent (presumably per-task overrides).
        self.specialized_llms = specialized_llms or {}
        self.debug = debug
        self.dispatcher = ToolDispatcher(mcp_server_url=server_url, mcp_headers=headers)
        self.pipeline = ScanPipeline(self)
        # Language applied to every ScanStage this Agent creates.
        self.language = language

    async def scan(self, repo_dir: str, prompt: str):
        """Statically audit the project in *repo_dir* in three LLM stages.

        Stages:
            1. Info collection -> Markdown project summary.
            2. Code audit, given the info report.
            3. Vulnerability review of the audit report, answered in
               ``_VULN_REVIEW_FORMAT``.

        The review output is parsed by VulnerabilityExtractor and scored with
        calc_mcp_score. The most common language of *repo_dir* is recorded.

        Args:
            repo_dir: project directory to scan.
            prompt: extra user instructions appended to each stage's message.

        Returns:
            dict with keys ``readme``, ``score``, ``language``, ``start_time``,
            ``end_time`` (``time.time()`` seconds), ``results`` and ``llm``.
            The same dict is passed to ``mcpLogger.result_update``.
        """
        result_meta = {
            "readme": "",
            "score": 0,
            "language": "",
            "start_time": time.time(),
            "end_time": 0,
            "results": [],
            "llm": self.llm.model,
        }

        # 1. Information collection
        info_ret_format = "生成一份详细的信息收集报告,使用Markdown格式。报告需基于输入数据如实总结,确保读者(对项目一无所知)能快速理解项目全貌。"
        info_collection = await self.pipeline.execute_stage(
            ScanStage(
                "1",
                "Info Collection",
                "agents/project_summary",
                output_format=info_ret_format,
                language=self.language,
            ),
            repo_dir,
            prompt,
        )

        # 2. Code audit
        audit_ret_format = """
markdown格式返回
对于每个确认的漏洞,必须提供:
- 具体位置:文件路径和行号范围
- 完整代码片段:显示漏洞的代码段
- 技术分析:漏洞原理和利用方法
- 影响评估:可获得的权限和影响范围
- 修复建议:详细的安全加固方案
- 攻击路径:具体的利用步骤(如适用)
严格标准:必须提供完整的漏洞利用路径和影响分析。
        """
        code_audit = await self.pipeline.execute_stage(
            ScanStage(
                "2",
                "Code Audit",
                "agents/code_audit",
                output_format=audit_ret_format,
                language=self.language,
            ),
            repo_dir,
            prompt,
            {"信息收集报告": info_collection},
        )

        # 3. Vulnerability review (structured XML answer, validated by is_vuln_review_output)
        vuln_review = await self.pipeline.execute_stage(
            ScanStage(
                "3",
                "Vulnerability Review",
                "agents/vuln_review",
                output_format=self._VULN_REVIEW_FORMAT,
                output_check_fn=is_vuln_review_output,
                language=self.language,
            ),
            repo_dir,
            prompt,
            {"代码审计报告": code_audit},
        )

        # Extract findings and compute result metadata.
        extractor = VulnerabilityExtractor()
        vuln_results = extractor.extract_vulnerabilities(vuln_review)

        elapsed_time = (time.time() - result_meta["start_time"]) / 60
        logger.info(f"扫描任务完成,总耗时 {elapsed_time:.2f} 分钟")
        lang_stats = analyze_language(repo_dir)
        top_language = get_top_language(lang_stats)
        safety_score = calc_mcp_score(vuln_results)

        result_meta.update(
            {
                "readme": info_collection,
                "score": safety_score,
                "language": top_language,
                "end_time": time.time(),
                "results": vuln_results,
            }
        )
        mcpLogger.result_update(result_meta)
        return result_meta

    async def dynamic_analysis(self, prompt: str):
        """Run the dynamic MCP scan in four LLM stages.

        Stages:
            1. Info collection about the MCP target.
            2. Malicious-behaviour testing, given the info report.
            3. Vulnerability testing, given the info report and stage 2's report.
            4. Vulnerability review of both testing reports, answered in
               ``_VULN_REVIEW_FORMAT``.

        No project directory is involved, so ``language`` stays ``""``.

        Args:
            prompt: extra user instructions appended to each stage's message.

        Returns:
            dict with the same keys as ``scan`` (including ``llm``). The same
            dict is passed to ``mcpLogger.result_update``.
        """
        result_meta = {
            "readme": "",
            "score": 0,
            "language": "",
            "start_time": time.time(),
            "end_time": 0,
            "results": [],
            # Reported for parity with scan().
            "llm": self.llm.model,
        }

        # 1. Information collection
        info_ret_format = "生成一份详细的MCP(model context protocol)信息收集报告,使用Markdown格式。报告需基于输入数据如实总结,确保读者(对项目一无所知)能快速理解项目全貌。"
        info_collection = await self.pipeline.execute_stage_dynamic(
            ScanStage(
                "1",
                "Info Collection",
                "agents/dynamic/project_summary",
                output_format=info_ret_format,
                language=self.language,
            ),
            prompt=prompt,
        )

        # 2-3. Vulnerability probing; both stages share the same Markdown report format.
        vuln_ret_format = """
        ## Output format
- The output should be in Markdown format. Please Never use any other format, and make sure the output has no format issue.
- The Markdown document should have the following Chapter:
    - "Overview": `YES` or `NO`, representing whether there are any risks analyzed.
    - "Threats": A list of xml strings, each representing a threat analyzed. Including threat types, confidence scores, and potential impacts.
    - "Reasons": A list of normal strings, each representing the reason why the corresponding threat is analyzed.
    - "Summarization": A paragraph summarizing the overall security assessment results.
- example:
    ```
    # Overview
    - YES
    # Threats
        - <threat><tool_name>{{ tool_name }}</tool_name><type>SQL Injection</type><confidence>0.9</confidence><impact>High</impact></threat>
    # Reasons 
        - SQL Injection: The tool named {{ tool_name }} detected a potential SQL Injection vulnerability in the input parameter.
    # Summarization: 
        ...... (The clear, detailed summary of the security assessment results)
    ```
        """
        # NOTE(review): these two template names carry a ".md" suffix while every
        # other stage's template name does not — confirm prompt_manager.load_template
        # resolves both forms.
        report1 = await self.pipeline.execute_stage_dynamic(
            ScanStage(
                "2",
                "Malicious Testing",
                "agents/dynamic/malicious_behaviour_testing.md",
                output_format=vuln_ret_format,
                language=self.language,
            ),
            prompt,
            {"信息收集报告": info_collection},
        )
        report2 = await self.pipeline.execute_stage_dynamic(
            ScanStage(
                "3",
                "Vulnerability Testing",
                "agents/dynamic/vulnerability_testing.md",
                output_format=vuln_ret_format,
                language=self.language,
            ),
            prompt,
            {"信息收集报告": info_collection, "malicious testing": report1},
        )

        # 4. Vulnerability review (same XML schema as the static scan)
        vuln_review = await self.pipeline.execute_stage_dynamic(
            ScanStage(
                "4",
                "Vulnerability Review",
                "agents/dynamic/general_analyzing_prompt_template",
                output_format=self._VULN_REVIEW_FORMAT,
                output_check_fn=is_vuln_review_output,
                language=self.language,
            ),
            prompt,
            {"malicious testing": report1, "vulnerability testing": report2},
        )

        # Extract findings and compute result metadata.
        extractor = VulnerabilityExtractor()
        vuln_results = extractor.extract_vulnerabilities(vuln_review)

        elapsed_time = (time.time() - result_meta["start_time"]) / 60
        logger.info(f"扫描任务完成,总耗时 {elapsed_time:.2f} 分钟")
        safety_score = calc_mcp_score(vuln_results)

        result_meta.update(
            {
                "readme": info_collection,
                "score": safety_score,
                "end_time": time.time(),
                "results": vuln_results,
            }
        )
        mcpLogger.result_update(result_meta)
        return result_meta