/ AIG-PromptSecurity / cli_run.py
cli_run.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

import argparse
import sys
import time
from pathlib import Path

from cli.aig_logger import logger
from cli.aig_logger import (
    newPlanStep, statusUpdate, toolUsed, actionLog, resultUpdate
)
from deepteam.plugin_system import PluginManager
from cli.models import create_model
from cli.plugin_commands import list_plugins, load_plugins_from_args, show_plugin_template, validate_plugin, auto_discover_plugins
from cli.red_team_runner import RedTeamRunner
from cli.tool_scanner_cli import handle_tool_scanning


# Logger config. The doubled braces survive the f-string as a literal
# "{time:...}" placeholder that is handed to the logger unchanged.
logger.add(f"logs/red_team_{{time:YYYY-MM-DD_HH-mm-ss}}.log", level="DEBUG", enqueue=True, retention="7 days")

# Global plugin manager shared by every plugin-related command below.
plugin_manager = PluginManager()


def cleanup_expired_files(log_path: str = "logs", max_age_seconds: int = 86400*30, pattern: str = "attachment_*.csv"):
    """Delete files under *log_path* matching *pattern* that are too old.

    Args:
        log_path: Directory whose direct children are scanned.
        max_age_seconds: Files whose modification time is more than this many
            seconds in the past are removed (default: 30 days).
        pattern: Glob pattern, relative to *log_path*, selecting candidates.
    """
    now = time.time()
    for file in Path(log_path).glob(pattern):
        # Age is measured from the last modification time.
        if (now - file.stat().st_mtime) > max_age_seconds:
            file.unlink()


def main():
    """Parse the command line and run the requested CLI action.

    Utility commands (tool scanning, plugin template/validation/listing) are
    handled first and terminate the process with status 0. Otherwise the
    target, evaluation and simulator models are created from the arguments and
    passed to ``RedTeamRunner.run_red_team``.

    Raises:
        ValueError: If any of ``--base_url``, ``--api_key``, ``--model`` or
            ``--max_concurrent`` is missing, or if they were not supplied the
            same number of times.
    """
    parser = argparse.ArgumentParser(description="Red Team CLI Runner")

    # Tool-scanning arguments (declared first, handled with highest priority).
    parser.add_argument("--scan-tools", type=str, choices=["all", "techniques", "metrics", "scenarios"],
                        help="Scan and display all available tools and their parameters")
    parser.add_argument("--show-tool-params", type=str,
                        help="Show detailed parameter information for a specific tool")

    # Plugin arguments.
    parser.add_argument("--plugins", type=str, nargs='+', help="Custom plugin files or directories to load")
    parser.add_argument("--list-plugins", action="store_true", help="List all available plugins")
    parser.add_argument("--show-template", type=str, choices=["attack", "metric", "vulnerability"], help="Show plugin template")
    parser.add_argument("--validate-plugin", type=str, help="Validate a plugin file or directory")
    parser.add_argument("--auto-discover", action="store_true", help="Auto-discover plugins from default directories")

    # Red-team test arguments. The four target-model options use
    # action='append', so each may be repeated once per model under test.
    parser.add_argument("--base_url", type=str, action='append', help="Base URL for ChatOpenAI")
    parser.add_argument("--api_key", type=str, nargs=1, action='append', help="API Key for ChatOpenAI")
    parser.add_argument("--model", type=str, action='append', help="Model name for ChatOpenAI")
    parser.add_argument("--max_concurrent", type=int, action='append', help="Max concurrent")
    parser.add_argument("--sim_base_url", type=str, help="Base URL for a simulator model")
    parser.add_argument("--sim_api_key", type=str, nargs=1, help="API Key for a simulator model")
    parser.add_argument("--simulator_model", type=str, help="Model name for a simulator model")
    parser.add_argument("--sim_max_concurrent", type=int, default=10, help="Max concurrent")
    parser.add_argument("--eval_base_url", type=str, help="Base URL for a evaluate model")
    parser.add_argument("--eval_api_key", type=str, nargs=1, help="API Key for a evaluate model")
    parser.add_argument("--evaluate_model", type=str, help="Model name for a evaluate model")
    parser.add_argument("--eval_max_concurrent", type=int, default=10, help="Max concurrent")

    parser.add_argument("--scenarios", type=str, nargs='+', help="Scenarios to test")
    parser.add_argument("--techniques", type=str, nargs='+', help="Techniques to test")

    parser.add_argument("--async_mode", action='store_true', help="Enable async mode")
    parser.add_argument("--choice", type=str, default="random", choices=["random", "serial", "parallel"],
                        help="Technique selection strategy: 'random' (default) or 'serial' (nested techniques) or 'parallel'")
    parser.add_argument("--metric", type=str, help="Metric class name (e.g., 'RandomMetric')")
    parser.add_argument("--report", type=str, default="logs/report.md", help="Path to save the risk assessment report (default: logs/report.md)")
    parser.add_argument("--lang", type=str, default="zh_CN", help="Report language")

    args = parser.parse_args()

    logger.set_language(lang=args.lang)

    # Tool-scanning commands (highest priority). Execution falls through to
    # the rest of main() when handle_tool_scanning returns a falsy value.
    if args.scan_tools or args.show_tool_params:
        if handle_tool_scanning(args):
            sys.exit(0)

    # Plugin utility commands.
    if args.show_template:
        show_plugin_template(args.show_template, plugin_manager)
        sys.exit(0)

    if args.validate_plugin:
        validate_plugin(args.validate_plugin, plugin_manager)
        sys.exit(0)

    # Load plugins before --list-plugins so the listing includes them.
    if args.auto_discover:
        auto_discover_plugins(plugin_manager)

    if args.plugins:
        load_plugins_from_args(args.plugins, plugin_manager)

    if args.list_plugins:
        list_plugins(plugin_manager)
        sys.exit(0)

    # Initialise the target models. An omitted 'append' option is None, so
    # check presence explicitly instead of letting len(None) raise an opaque
    # TypeError.
    target_options = {
        "--base_url": args.base_url,
        "--api_key": args.api_key,
        "--model": args.model,
        "--max_concurrent": args.max_concurrent,
    }
    missing = [flag for flag, values in target_options.items() if not values]
    if missing:
        raise ValueError(f"Missing required arguments for red team testing: {', '.join(missing)}")

    if len({len(values) for values in target_options.values()}) != 1:
        raise ValueError("base_url, api_key, model, max_concurrent must have same number of parameters")

    # --api_key uses nargs=1, so every entry is a one-element list.
    models = [
        create_model(model_name, base_url, api_key[0], max_concurrent)
        for base_url, api_key, model_name, max_concurrent
        in zip(args.base_url, args.api_key, args.model, args.max_concurrent)
    ]

    # Without a complete evaluation-model configuration, the first target
    # model doubles as the evaluator.
    if any(param is None for param in (args.evaluate_model, args.eval_base_url, args.eval_api_key, args.eval_max_concurrent)):
        evaluate_model = models[0]
    else:
        evaluate_model = create_model(args.evaluate_model, args.eval_base_url, args.eval_api_key[0], args.eval_max_concurrent)

    # Without a complete simulator configuration, reuse the evaluation model.
    if any(param is None for param in (args.simulator_model, args.sim_base_url, args.sim_api_key, args.sim_max_concurrent)):
        simulator_model = evaluate_model
    else:
        simulator_model = create_model(args.simulator_model, args.sim_base_url, args.sim_api_key[0], args.sim_max_concurrent)

    # Create the red-team runner.
    runner = RedTeamRunner(plugin_manager)

    # Run the red-team test.
    runner.run_red_team(
        models=models,
        simulator_model=simulator_model,
        evaluate_model=evaluate_model,
        scenarios=args.scenarios,
        techniques=args.techniques,
        async_mode=args.async_mode,
        choice=args.choice,
        metric=args.metric,
        report_path=args.report,
    )


if __name__ == "__main__":
    try:
        main()
        # Remove expired files. This only runs when main() returns normally;
        # the sys.exit() paths above raise SystemExit and skip it.
        cleanup_expired_files()
    except Exception as e:
        logger.error(e)
        logger.critical_issue(content=logger.translated_msg("Something went wrong. Please try again in a few moments."))