# optimize.py
# -*- coding: utf-8 -*-
# @Date    : 8/23/2024 8:00 PM
# @Author  : didi
# @Desc    : Entrance of AFlow.

import argparse
from dataclasses import dataclass
from typing import Dict, List, Optional

from metagpt.configs.models_config import ModelsConfig
from metagpt.ext.aflow.data.download_data import download
from metagpt.ext.aflow.scripts.optimizer import Optimizer


@dataclass
class ExperimentConfig:
    """Per-dataset settings handed to the AFlow ``Optimizer``."""

    dataset: str  # dataset name; also the key in EXPERIMENT_CONFIGS
    question_type: str  # task family: "qa", "math" or "code" in the table below
    operators: List[str]  # names of the operators the optimizer may use


EXPERIMENT_CONFIGS: Dict[str, ExperimentConfig] = {
    "DROP": ExperimentConfig(
        dataset="DROP",
        question_type="qa",
        operators=["Custom", "AnswerGenerate", "ScEnsemble"],
    ),
    "HotpotQA": ExperimentConfig(
        dataset="HotpotQA",
        question_type="qa",
        operators=["Custom", "AnswerGenerate", "ScEnsemble"],
    ),
    "MATH": ExperimentConfig(
        dataset="MATH",
        question_type="math",
        operators=["Custom", "ScEnsemble", "Programmer"],
    ),
    "GSM8K": ExperimentConfig(
        dataset="GSM8K",
        question_type="math",
        operators=["Custom", "ScEnsemble", "Programmer"],
    ),
    "MBPP": ExperimentConfig(
        dataset="MBPP",
        question_type="code",
        operators=["Custom", "CustomCodeGenerate", "ScEnsemble", "Test"],
    ),
    "HumanEval": ExperimentConfig(
        dataset="HumanEval",
        question_type="code",
        operators=["Custom", "CustomCodeGenerate", "ScEnsemble", "Test"],
    ),
}

_TRUE_STRINGS = frozenset({"true", "t", "yes", "y", "1"})
_FALSE_STRINGS = frozenset({"false", "f", "no", "n", "0"})


def _str2bool(value: str) -> bool:
    """Parse a command-line boolean value such as ``true``/``false``.

    ``argparse`` with ``type=bool`` is a trap: ``bool("False")`` is ``True``, so
    every non-empty string would enable the flag. Values are matched
    case-insensitively; anything unrecognised is rejected rather than silently
    treated as False.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got {value!r}")


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the AFlow command-line options.

    Args:
        argv: Argument list to parse; ``None`` (the default) reads ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="AFlow Optimizer")
    parser.add_argument(
        "--dataset",
        type=str,
        choices=list(EXPERIMENT_CONFIGS.keys()),
        required=True,
        help="Dataset type",
    )
    parser.add_argument("--sample", type=int, default=4, help="Sample count")
    parser.add_argument(
        "--optimized_path",
        type=str,
        default="metagpt/ext/aflow/scripts/optimized",
        help="Optimized result save path",
    )
    parser.add_argument("--initial_round", type=int, default=1, help="Initial round")
    parser.add_argument("--max_rounds", type=int, default=20, help="Max iteration rounds")
    parser.add_argument(
        "--check_convergence",
        type=_str2bool,  # was type=bool, which made "--check_convergence False" enable it
        default=True,
        help="Whether to enable early stop",
    )
    parser.add_argument("--validation_rounds", type=int, default=5, help="Validation rounds")
    parser.add_argument(
        "--if_first_optimize",
        type=_str2bool,
        default=True,
        help="Whether to download dataset for the first time",
    )
    parser.add_argument(
        "--opt_model_name",
        type=str,
        default="claude-3-5-sonnet-20240620",
        help="Specifies the name of the model used for optimization tasks.",
    )
    parser.add_argument(
        "--exec_model_name",
        type=str,
        default="gpt-4o-mini",
        help="Specifies the name of the model used for execution tasks.",
    )
    return parser.parse_args(argv)


def _require_llm_config(models_config: ModelsConfig, model_name: str, role: str, flag: str):
    """Look up ``model_name`` in ``models_config``; raise ValueError if it is missing.

    Args:
        models_config: The loaded models configuration.
        model_name: Name of the model to look up.
        role: Human-readable role used in the error message ("optimization"/"execution").
        flag: CLI flag the user should use to pick another model.
    """
    llm_config = models_config.get(model_name)
    if llm_config is None:
        raise ValueError(
            f"The {role} model '{model_name}' was not found in the 'models' section of the configuration file. "
            f"Please add it to the configuration file or specify a valid model using the {flag} flag. "
        )
    return llm_config


def main(argv: Optional[List[str]] = None) -> None:
    """Run AFlow workflow optimization for the dataset selected on the command line.

    Both model names are validated before anything is downloaded, so a
    misconfigured run fails fast.

    Args:
        argv: Argument list to parse; ``None`` (the default) reads ``sys.argv``.

    Raises:
        ValueError: if either model name is absent from the models configuration.
    """
    args = parse_args(argv)

    config = EXPERIMENT_CONFIGS[args.dataset]

    models_config = ModelsConfig.default()
    opt_llm_config = _require_llm_config(models_config, args.opt_model_name, "optimization", "--opt_model_name")
    exec_llm_config = _require_llm_config(models_config, args.exec_model_name, "execution", "--exec_model_name")

    # The download helper receives if_first_download and decides whether to fetch.
    download(["datasets", "initial_rounds"], if_first_download=args.if_first_optimize)

    optimizer = Optimizer(
        dataset=config.dataset,
        question_type=config.question_type,
        opt_llm_config=opt_llm_config,
        exec_llm_config=exec_llm_config,
        check_convergence=args.check_convergence,
        operators=config.operators,
        optimized_path=args.optimized_path,
        sample=args.sample,
        initial_round=args.initial_round,
        max_rounds=args.max_rounds,
        validation_rounds=args.validation_rounds,
    )

    # Optimize workflow via setting the optimizer's mode to 'Graph'.
    # To test a workflow instead, set the mode to 'Test': optimizer.optimize("Test")
    optimizer.optimize("Graph")


if __name__ == "__main__":
    main()