/ AIG-PromptSecurity / cli_run.py
cli_run.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
 19  import time
 20  from pathlib import Path
 21  import argparse
 22  
 23  from cli.aig_logger import logger
 24  from cli.aig_logger import (
 25      newPlanStep, statusUpdate, toolUsed, actionLog, resultUpdate
 26  )
 27  from deepteam.plugin_system import PluginManager
 28  from cli.models import create_model
 29  from cli.plugin_commands import list_plugins, load_plugins_from_args, show_plugin_template, validate_plugin, auto_discover_plugins
 30  from cli.red_team_runner import RedTeamRunner
 31  from cli.tool_scanner_cli import handle_tool_scanning
 32  
 33  
 34  # logger config
 35  logger.add(f"logs/red_team_{{time:YYYY-MM-DD_HH-mm-ss}}.log", level="DEBUG", enqueue=True, retention="7 days")
 36  
 37  # 全局插件管理器
 38  plugin_manager = PluginManager()
 39  
 40  def cleanup_expired_files(log_path: str = "logs", max_age_seconds: int = 86400*30, pattern: str = "attachment_*.csv"):
 41      now = time.time()
 42      for file in Path(log_path).glob(pattern):
 43          if (now - file.stat().st_mtime) > max_age_seconds:
 44              file.unlink()
 45  
 46  def main():
 47      """主函数"""
 48      parser = argparse.ArgumentParser(description="Red Team CLI Runner")
 49      
 50      # 工具扫描相关参数(放在最前面,优先级最高)
 51      parser.add_argument("--scan-tools", type=str, choices=["all", "techniques", "metrics", "scenarios"], 
 52                         help="Scan and display all available tools and their parameters")
 53      parser.add_argument("--show-tool-params", type=str, 
 54                         help="Show detailed parameter information for a specific tool")
 55      
 56      # 插件相关参数
 57      parser.add_argument("--plugins", type=str, nargs='+', help="Custom plugin files or directories to load")
 58      parser.add_argument("--list-plugins", action="store_true", help="List all available plugins")
 59      parser.add_argument("--show-template", type=str, choices=["attack", "metric", "vulnerability"], help="Show plugin template")
 60      parser.add_argument("--validate-plugin", type=str, help="Validate a plugin file or directory")
 61      parser.add_argument("--auto-discover", action="store_true", help="Auto-discover plugins from default directories")
 62      
 63      # 红队测试相关参数
 64      parser.add_argument("--base_url", type=str, action='append', help="Base URL for ChatOpenAI")
 65      parser.add_argument("--api_key", type=str, nargs=1, action='append', help="API Key for ChatOpenAI")
 66      parser.add_argument("--model", type=str, action='append', help="Model name for ChatOpenAI")
 67      parser.add_argument("--max_concurrent", type=int, action='append', help="Max concurrent")
 68      parser.add_argument("--sim_base_url", type=str, help="Base URL for a simulator model")
 69      parser.add_argument("--sim_api_key", type=str, nargs=1, help="API Key for a simulator model")
 70      parser.add_argument("--simulator_model", type=str, help="Model name for a simulator model")
 71      parser.add_argument("--sim_max_concurrent", type=int, default=10, help="Max concurrent")
 72      parser.add_argument("--eval_base_url", type=str, help="Base URL for a evaluate model")
 73      parser.add_argument("--eval_api_key", type=str, nargs=1, help="API Key for a evaluate model")
 74      parser.add_argument("--evaluate_model", type=str, help="Model name for a evaluate model")
 75      parser.add_argument("--eval_max_concurrent", type=int, default=10, help="Max concurrent")
 76  
 77      parser.add_argument("--scenarios", type=str, nargs='+', help="Scenarios to test")
 78      parser.add_argument("--techniques", type=str, nargs='+', help="Techniques to test")
 79      
 80      parser.add_argument("--async_mode", action='store_true', help="Enable async mode")
 81      parser.add_argument("--choice", type=str, default="random", choices=["random", "serial", "parallel"], 
 82                         help="Technique selection strategy: 'random' (default) or 'serial' (nested techniques) or 'parallel'")
 83      parser.add_argument("--metric", type=str, help="Metric class name (e.g., 'RandomMetric')")
 84      parser.add_argument("--report", type=str, default="logs/report.md", help="Path to save the risk assessment report (default: logs/report.md)")
 85      parser.add_argument("--lang", type=str, default="zh_CN", help="Report language")
 86      
 87      args = parser.parse_args()
 88  
 89      logger.set_language(lang=args.lang)
 90  
 91      # 处理工具扫描相关命令(优先级最高)
 92      if args.scan_tools or args.show_tool_params:
 93          if handle_tool_scanning(args):
 94              exit(0)
 95  
 96      # 处理插件相关命令
 97      if args.show_template:
 98          show_plugin_template(args.show_template, plugin_manager)
 99          exit(0)
100      
101      if args.validate_plugin:
102          validate_plugin(args.validate_plugin, plugin_manager)
103          exit(0)
104      
105      # 加载插件(在list_plugins之前)
106      if args.auto_discover:
107          auto_discover_plugins(plugin_manager)
108      
109      if args.plugins:
110          load_plugins_from_args(args.plugins, plugin_manager)
111      
112      if args.list_plugins:
113          list_plugins(plugin_manager)
114          exit(0)
115  
116      # 初始化模型
117      models = []
118      lengths = list(map(len, (args.base_url, args.api_key, args.model, args.max_concurrent)))
119      if len(set(lengths)) != 1:
120          raise ValueError("base_url, api_key, model, max_concurrent must have same number of parameters")
121      for base_url, api_key, model_name, max_concurrent  in zip(args.base_url, args.api_key, args.model, args.max_concurrent):
122          model = create_model(model_name, base_url, api_key[0], max_concurrent)
123          models.append(model)
124          
125      if any(param is None for param in (args.evaluate_model, args.eval_base_url, args.eval_api_key, args.eval_max_concurrent)):
126          evaluate_model = models[0]
127      else:
128          evaluate_model = create_model(args.evaluate_model, args.eval_base_url, args.eval_api_key[0], args.eval_max_concurrent)
129  
130      if any(param is None for param in (args.simulator_model, args.sim_base_url, args.sim_api_key, args.sim_max_concurrent)):
131          simulator_model = evaluate_model
132      else:
133          simulator_model = create_model(args.simulator_model, args.sim_base_url, args.sim_api_key[0], args.sim_max_concurrent)
134  
135      # 创建红队运行器
136      runner = RedTeamRunner(plugin_manager)
137      
138      # 运行红队测试
139      runner.run_red_team(
140          models=models,
141          simulator_model=simulator_model,
142          evaluate_model=evaluate_model,
143          scenarios=args.scenarios,
144          techniques=args.techniques,
145          async_mode=args.async_mode,
146          choice=args.choice,
147          metric=args.metric,
148          report_path=args.report,
149      )
150  
151  
if __name__ == "__main__":
    try:
        main()
        # Remove expired attachment files after a successful run.
        cleanup_expired_files()
    except Exception as e:
        logger.error(e)
        logger.critical_issue(content=logger.translated_msg("Something went wrong. Please try again in a few moments."))
        # Signal the failure to the calling process instead of exiting with status 0.
        raise SystemExit(1)