#!/usr/bin/env python3
# main.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

"""
Agent Framework - main entry point.

An agent framework modeled on Claude Code / Gemini CLI: the agent calls tools
on its own to complete a task. It runs either as a static scan of a local
project directory (``--repo``) or as a dynamic analysis against a remote MCP
server (``--server_url``).
"""

import argparse
import asyncio
import os
import sys
import traceback

from agent.agent import Agent
from utils import config
from utils.aig_logger import mcpLogger
from utils.llm import LLM
from utils.llm_manager import LLMManager
from utils.loging import logger

# Important: the tools package is imported only for its side effect of
# registering the available tools.
import tools as _  # isort: skip

# Language instruction appended to the task prompt, keyed by --language.
_prompts = {"zh": "所有回复都应使用中文。", "en": "All responses should be in English."}

# Separators accepted between a header name and its value in --header items.
_HEADER_SEPARATORS = (":", "=")


def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser(
        description="Agent Framework - 代码扫描和漏洞检测工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Project directory to scan. Not enforced by argparse: it is only needed
    # when --server_url is absent, and main() checks it in that case.
    parser.add_argument("--repo", default="", help="要扫描的项目文件夹路径")

    # Optional arguments
    parser.add_argument("-p", "--prompt", default="", help="自定义扫描提示词(可选)")

    parser.add_argument(
        "-m",
        "--model",
        default=config.DEFAULT_MODEL,
        help=f"LLM 模型名称(默认: {config.DEFAULT_MODEL})",
    )

    parser.add_argument(
        "-k",
        "--api_key",
        default=None,
        help="API Key(如果不提供,将从环境变量 OPENROUTER_API_KEY 读取)",
    )

    parser.add_argument(
        "-u",
        "--base_url",
        default=config.DEFAULT_BASE_URL,
        help=f"API 基础 URL(默认: {config.DEFAULT_BASE_URL})",
    )

    parser.add_argument(
        "--debug",
        action="store_true",
        help="启用 debug 模式(包括 Laminar 跟踪)",
        default=False,
    )

    parser.add_argument("--server_url", help="remote MCP server URL", default=None)

    parser.add_argument(
        "--header",
        action="append",
        dest="headers",
        help="Custom header in key:value format (can be used multiple times)",
        default=[],
    )

    parser.add_argument("--language", default="zh", choices=["zh", "en"], help="Output language")

    return parser.parse_args()


def _parse_headers(items):
    """Turn repeated ``--header`` values into a ``{name: value}`` dict.

    Each item is split at the first ``:`` or ``=`` it contains (whichever comes
    first, since neither character may appear in an HTTP field name), and both
    sides are stripped. Items without a separator, or with an empty name, are
    skipped with a warning. A later item with the same name overwrites an
    earlier one. ``None`` or an empty list yields ``{}``.
    """
    headers = {}
    for item in items or []:
        positions = [pos for pos in (item.find(sep) for sep in _HEADER_SEPARATORS) if pos != -1]
        if not positions:
            logger.warning(f"Ignored invalid header format: {item}")
            continue
        cut = min(positions)
        key = item[:cut].strip()
        value = item[cut + 1 :].strip()
        if not key:
            # Do not echo the item: its value part may be a credential.
            logger.warning("Ignored header with empty name")
            continue
        headers[key] = value
    return headers


def _build_prompt(user_prompt, language):
    """Combine the user's prompt with the instruction for *language*.

    The user prompt is stripped. When it is empty, only the language
    instruction is returned ("" for a language missing from ``_prompts``).
    """
    user_prompt = user_prompt.strip()
    lang_prompt = _prompts.get(language, "")
    return f"{user_prompt}\n\n{lang_prompt}" if user_prompt else lang_prompt


def _validate_repo(path):
    """Exit with status 1 unless *path* names an existing directory."""
    if not path:
        logger.error("No project path given. Use --repo, or --server_url for dynamic analysis.")
        sys.exit(1)
    if not os.path.exists(path):
        logger.error(f"Project path does not exist: {path}")
        sys.exit(1)
    if not os.path.isdir(path):
        logger.error(f"Project path is not a directory: {path}")
        sys.exit(1)


async def main():
    """Build the LLMs and the Agent, then run the requested analysis.

    With --server_url, runs ``agent.dynamic_analysis`` against that server.
    Otherwise it validates --repo and runs ``agent.scan`` on it.

    Exits with status 1 when no API key is available or the project path is
    invalid. Any ``Exception`` raised by the analysis is printed, logged
    (including a traceback via ``mcpLogger``) and re-raised. Once the Agent
    exists, its ``dispatcher`` (if it has one) is closed on every exit path.
    """
    args = parse_args()

    # The command-line key wins; otherwise fall back to the environment.
    api_key = args.api_key or os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        logger.error(
            "API Key not provided. Use --api_key or set OPENROUTER_API_KEY environment variable."
        )
        sys.exit(1)

    # Validate the scan target before building any LLM or Agent, so a bad
    # path fails fast without constructing the model clients.
    if not args.server_url:
        _validate_repo(args.repo)

    # Main LLM instance.
    llm = LLM(
        model=args.model,
        api_key=api_key,
        base_url=args.base_url,
        context_window=config.DEFAULT_MODEL_CONTEXT_WINDOW,
    )
    logger.info(f"Main LLM initialized: {args.model}")

    # The main API key and base URL are passed to LLMManager as its defaults.
    llm_manager = LLMManager(api_key=api_key, base_url=args.base_url)

    # Mapping of specialized LLM instances for the requested roles.
    specialized_llms = llm_manager.get_specialized_llms(["thinking", "coding"])
    logger.info(f"Specialized LLMs configured: {list(specialized_llms.keys())}")

    prompt = _build_prompt(args.prompt, args.language)
    if prompt:
        logger.info(f"Custom prompt: {prompt}")

    headers = _parse_headers(args.headers)
    if headers:
        # Header values frequently carry credentials (e.g. Authorization),
        # so only the names are written to the log.
        logger.info(f"Custom headers (values redacted): {', '.join(headers)}")

    agent = Agent(
        llm=llm,
        specialized_llms=specialized_llms,
        debug=args.debug,
        server_url=args.server_url,
        language=args.language,
        headers=headers,
    )
    try:
        if args.server_url:
            logger.info(f"Server mode enabled with URL: {args.server_url}")
            dynamic_results = await agent.dynamic_analysis(prompt)
            logger.info(f"Dynamic analysis results:\n{dynamic_results}")
        else:
            logger.info(f"Starting scan on: {args.repo}")
            result = await agent.scan(args.repo, prompt)
            logger.info(f"Scan completed successfully:\n\n {result}")
    except KeyboardInterrupt:
        print("\n\nTask interrupted by user.")
        logger.warning("Task interrupted by user")
    except Exception as e:
        print(f"\n\nError during execution: {e}")
        logger.error(f"Error during execution: {e}")
        tb = traceback.format_exc()
        mcpLogger.error_log(f"Execution failed: {e}\n{tb}")
        raise
    finally:
        # Release the agent's resources on every exit path.
        if hasattr(agent, "dispatcher"):
            await agent.dispatcher.close()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # On Python 3.11+, asyncio.run() turns Ctrl+C into cancellation of
        # main() and re-raises KeyboardInterrupt here, so the handler inside
        # main() does not see it. Exit with the conventional 128 + SIGINT code.
        print("\n\nTask interrupted by user.")
        logger.warning("Task interrupted by user")
        sys.exit(130)