/ mcp-scan / main.py
main.py
  1  #!/usr/bin/env python3
  2  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  3  #
  4  # Licensed under the Apache License, Version 2.0 (the "License");
  5  # you may not use this file except in compliance with the License.
  6  # You may obtain a copy of the License at
  7  #
  8  #     http://www.apache.org/licenses/LICENSE-2.0
  9  #
 10  # Unless required by applicable law or agreed to in writing, software
 11  # distributed under the License is distributed on an "AS IS" BASIS,
 12  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  # See the License for the specific language governing permissions and
 14  # limitations under the License.
 15  #
 16  # Requirement: Any integration or derivative work must explicitly attribute
 17  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 18  # documentation or user interface, as detailed in the NOTICE file.
 19  
 20  """
 21  Agent Framework - 主入口文件
 22  
 23  这是一个模仿 Claude Code / Gemini CLI 的 Agent 框架。
 24  Agent 可以自动调用工具完成任务。
 25  """
 26  
 27  import argparse
 28  import asyncio
 29  import os
 30  import sys
 31  
 32  from agent.agent import Agent
 33  from utils import config
 34  from utils.aig_logger import mcpLogger
 35  from utils.llm import LLM
 36  from utils.llm_manager import LLMManager
 37  from utils.loging import logger
 38  
 39  # 重要:导入 tools 包以触发工具注册
 40  import tools as _  # isort: skip
 41  
 42  _prompts = {"zh": "所有回复都应使用中文。", "en": "All responses should be in English."}
 43  
 44  
 45  def parse_args():
 46      """解析命令行参数"""
 47      parser = argparse.ArgumentParser(
 48          description="Agent Framework - 代码扫描和漏洞检测工具",
 49          formatter_class=argparse.RawDescriptionHelpFormatter,
 50      )
 51  
 52      # 必需参数
 53      parser.add_argument("--repo", default="", help="要扫描的项目文件夹路径")
 54  
 55      # 可选参数
 56      parser.add_argument("-p", "--prompt", default="", help="自定义扫描提示词(可选)")
 57  
 58      parser.add_argument(
 59          "-m",
 60          "--model",
 61          default=config.DEFAULT_MODEL,
 62          help=f"LLM 模型名称(默认: {config.DEFAULT_MODEL})",
 63      )
 64  
 65      parser.add_argument(
 66          "-k",
 67          "--api_key",
 68          default=None,
 69          help="API Key(如果不提供,将从环境变量 OPENROUTER_API_KEY 读取)",
 70      )
 71  
 72      parser.add_argument(
 73          "-u",
 74          "--base_url",
 75          default=config.DEFAULT_BASE_URL,
 76          help=f"API 基础 URL(默认: {config.DEFAULT_BASE_URL})",
 77      )
 78  
 79      parser.add_argument(
 80          "--debug",
 81          action="store_true",
 82          help="启用 debug 模式(包括 Laminar 跟踪)",
 83          default=False,
 84      )
 85  
 86      parser.add_argument("--server_url", help="remote MCP server URL", default=None)
 87  
 88      parser.add_argument(
 89          "--header",
 90          action="append",
 91          dest="headers",
 92          help="Custom header in key:value format (can be used multiple times)",
 93          default=[],
 94      )
 95  
 96      parser.add_argument("--language", default="zh", choices=["zh", "en"], help="Output language")
 97  
 98      return parser.parse_args()
 99  
100  
def _parse_headers(header_items):
    """Turn the repeated ``--header`` values into a header dict.

    Each item is split on its first ``:``; an item containing no ``:`` is split
    on its first ``=`` instead. Whitespace around key and value is stripped.
    Items with neither separator, or with an empty key, are skipped with a
    warning. A later occurrence of the same key overwrites an earlier one.

    Args:
        header_items: Iterable of raw header strings; may be empty or None.

    Returns:
        dict mapping header name to header value.
    """
    headers = {}
    for item in header_items or ():
        if ":" in item:
            key, value = item.split(":", 1)
        elif "=" in item:
            key, value = item.split("=", 1)
        else:
            logger.warning(f"Ignored invalid header format: {item}")
            continue
        key = key.strip()
        if not key:
            # e.g. ":value" would otherwise register a header with an empty name.
            logger.warning(f"Ignored invalid header format: {item}")
            continue
        headers[key] = value.strip()
    return headers


def _redact_headers(headers):
    """Return a log-safe copy of *headers* with every value masked.

    Custom headers commonly carry credentials (Authorization, API keys,
    cookies), so only the header names should reach the log.
    """
    return {key: "***" for key in headers}


def _validate_repo(repo):
    """Exit the process with status 1 unless *repo* is an existing directory."""
    if not repo:
        logger.error(
            "Project path not provided. Use --repo <dir> for a local scan "
            "or --server_url for a remote MCP server."
        )
        sys.exit(1)
    if not os.path.exists(repo):
        logger.error(f"Project path does not exist: {repo}")
        sys.exit(1)
    if not os.path.isdir(repo):
        logger.error(f"Project path is not a directory: {repo}")
        sys.exit(1)


async def main():
    """Command-line entry point: configure LLMs and the Agent, then run a scan.

    The API key comes from ``--api_key`` or, failing that, the
    ``OPENROUTER_API_KEY`` environment variable. A main LLM plus the "thinking"
    and "coding" specialised LLMs are built. With ``--server_url`` the agent
    runs dynamic analysis of that server; otherwise it scans the local
    ``--repo`` directory. The agent's dispatcher, when present, is closed on
    every exit path once the agent exists.

    Exits with status 1 when the API key is missing or, for a local scan, the
    project path is missing or not a directory. A user interrupt is reported
    and the coroutine returns normally. Any other exception is printed, logged
    (the traceback goes to mcpLogger) and re-raised.
    """
    args = parse_args()

    # Prefer the CLI argument; fall back to the environment variable.
    api_key = args.api_key or os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        logger.error(
            "API Key not provided. Use --api_key or set OPENROUTER_API_KEY environment variable."
        )
        sys.exit(1)

    # Check the scan target up front so bad input fails before any LLM client
    # or Agent is constructed.
    if not args.server_url:
        _validate_repo(args.repo)

    # Main LLM instance.
    llm = LLM(
        model=args.model,
        api_key=api_key,
        base_url=args.base_url,
        context_window=config.DEFAULT_MODEL_CONTEXT_WINDOW,
    )
    logger.info(f"Main LLM initialized: {args.model}")

    # The specialised LLMs are given the main API key and base URL as defaults.
    llm_manager = LLMManager(api_key=api_key, base_url=args.base_url)
    specialized_llms = llm_manager.get_specialized_llms(["thinking", "coding"])
    logger.info(f"Specialized LLMs configured: {list(specialized_llms.keys())}")

    # Append the output-language instruction to the user's prompt, if any.
    user_prompt = args.prompt.strip()
    lang_prompt = _prompts.get(args.language, "")
    prompt = f"{user_prompt}\n\n{lang_prompt}" if user_prompt else lang_prompt
    if prompt:
        logger.info(f"Custom prompt: {prompt}")

    headers = _parse_headers(args.headers)
    if headers:
        # Never log header values: they routinely contain credentials.
        logger.info(f"Custom headers: {_redact_headers(headers)}")

    agent = Agent(
        llm=llm,
        specialized_llms=specialized_llms,
        debug=args.debug,
        server_url=args.server_url,
        language=args.language,
        headers=headers,
    )
    try:
        if args.server_url:
            logger.info(f"Server mode enabled with URL: {args.server_url}")
            dynamic_results = await agent.dynamic_analysis(prompt)
            logger.info(f"Dynamic analysis results:\n{dynamic_results}")
        else:
            logger.info(f"Starting scan on: {args.repo}")
            result = await agent.scan(args.repo, prompt)
            logger.info(f"Scan completed successfully:\n\n {result}")
    except (KeyboardInterrupt, asyncio.CancelledError):
        # Since Python 3.11 asyncio.run() handles Ctrl-C by cancelling the main
        # task, so the interrupt arrives here as CancelledError rather than
        # KeyboardInterrupt; both are reported as a user interrupt.
        print("\n\nTask interrupted by user.")
        logger.warning("Task interrupted by user")
    except Exception as e:
        print(f"\n\nError during execution: {e}")
        logger.error(f"Error during execution: {e}")
        import traceback

        tb = traceback.format_exc()
        mcpLogger.error_log(f"Execution failed: {e}\n{tb}")
        raise
    finally:
        # Release the agent's dispatcher resources on every path.
        if hasattr(agent, "dispatcher"):
            await agent.dispatcher.close()
196  
197  
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # asyncio.run() raises KeyboardInterrupt when a Ctrl-C is not absorbed
        # inside main() (for example a second Ctrl-C while resources are being
        # closed). Exit with the conventional SIGINT status (128 + 2) instead of
        # printing a traceback.
        print("\n\nTask interrupted by user.", file=sys.stderr)
        sys.exit(130)