// evaluate.go
1 package evaluate 2 3 import ( 4 "os" 5 "path/filepath" 6 "strings" 7 8 "github.com/symflower/eval-dev-quality/evaluate/report" 9 evaluatetask "github.com/symflower/eval-dev-quality/evaluate/task" 10 "github.com/symflower/eval-dev-quality/language" 11 evallanguage "github.com/symflower/eval-dev-quality/language" 12 "github.com/symflower/eval-dev-quality/log" 13 evalmodel "github.com/symflower/eval-dev-quality/model" 14 "github.com/symflower/eval-dev-quality/provider" 15 evaltask "github.com/symflower/eval-dev-quality/task" 16 ) 17 18 // Context holds an evaluation context. 19 type Context struct { 20 // Log holds the logger of the context. 21 Log *log.Logger 22 23 // Languages determines which language should be used for the evaluation, or empty if all languages should be used. 24 Languages []evallanguage.Language 25 26 // Models determines which models should be used for the evaluation, or empty if all models should be used. 27 Models []evalmodel.Model 28 // ProviderForModel holds the models and their associated provider. 29 ProviderForModel map[evalmodel.Model]provider.Provider 30 // QueryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task. 31 QueryAttempts uint 32 33 // RepositoryPaths determines which relative repository paths should be used for the evaluation, or empty if all repositories should be used. 34 RepositoryPaths []string 35 // ResultPath holds the directory path where results should be written to. 36 ResultPath string 37 // TestdataPath determines the testdata path where all repositories reside grouped by languages. 38 TestdataPath string 39 40 // Runs holds the number of runs to perform. 41 Runs uint 42 // RunsSequential indicates that interleaved runs are disabled and runs are performed sequentially. 43 RunsSequential bool 44 // NoDisqualification indicates that models are not to be disqualified if they fail to solve basic language tasks. 
45 NoDisqualification bool 46 } 47 48 // runsAtLanguageLevel returns how many runs to perform on language level. 49 func (ctx *Context) runsAtLanguageLevel() uint { 50 if ctx.RunsSequential { 51 return 1 52 } 53 54 return ctx.Runs 55 } 56 57 // runsAtModelLevel returns how many runs to perform on model level. 58 func (ctx *Context) runsAtModelLevel() uint { 59 if ctx.RunsSequential { 60 return ctx.Runs 61 } 62 63 return 1 64 } 65 66 // RepositoryPlainName holds the name of the plain repository. 67 const RepositoryPlainName = "plain" 68 69 // Evaluate runs an evaluation on the given context and returns its results. 70 func Evaluate(ctx *Context) (assessments *report.AssessmentStore, totalScore uint64) { 71 // Check that models and languages can be evaluated by executing the "plain" repositories. 72 modelSucceededBasicChecksOfLanguage := map[evalmodel.Model]map[evallanguage.Language]bool{} 73 ctx.Log.Printf("Checking that models and languages can be used for evaluation") 74 // Ensure we report metrics for every model even if they are excluded. 75 assessments = report.NewAssessmentStore() 76 problemsPerModel := map[string][]error{} 77 // Write the evaluation CSV header so it's only written once. 78 evaluationCSVFile, err := os.OpenFile(filepath.Join(ctx.ResultPath, "evaluation.csv"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) 79 if err != nil { 80 ctx.Log.Panicf("ERROR: unable to create evaluation CSV file: %+v", err) 81 } 82 defer evaluationCSVFile.Close() 83 evaluationFile, err := report.NewEvaluationFile(evaluationCSVFile) 84 if err != nil { 85 ctx.Log.Panicf("ERROR: %+v", err) 86 } 87 88 { 89 // Create temporary repositories for each language so the repository is copied only once per language. 
90 temporaryRepositories := map[string]evaltask.Repository{} 91 for _, language := range ctx.Languages { 92 repositoryPath := filepath.Join(language.ID(), RepositoryPlainName) 93 temporaryRepository, cleanup, err := evaluatetask.TemporaryRepository(ctx.Log, ctx.TestdataPath, repositoryPath) 94 if err != nil { 95 ctx.Log.Panicf("ERROR: unable to create temporary repository path: %+v", err) 96 } else if err = temporaryRepository.Validate(ctx.Log, language); err != nil { 97 ctx.Log.Panicf("ERROR: malformed repository %q: %+v", temporaryRepository.Name(), err) 98 } 99 100 defer cleanup() 101 102 temporaryRepositories[repositoryPath] = temporaryRepository 103 } 104 105 logger := ctx.Log 106 for rl := uint(0); rl < ctx.runsAtLanguageLevel(); rl++ { 107 if ctx.Runs > 1 && !ctx.RunsSequential { 108 logger.Printf("Run %d/%d", rl+1, ctx.Runs) 109 } 110 111 logger := logger.With(log.AttributeKeyRun, rl+1) 112 113 for _, language := range ctx.Languages { 114 logger := logger.With(log.AttributeKeyLanguage, language.ID()) 115 116 languageID := language.ID() 117 repositoryPath := filepath.Join(language.ID(), RepositoryPlainName) 118 temporaryRepository := temporaryRepositories[repositoryPath] 119 120 logger = logger.With(log.AttributeKeyRepository, temporaryRepository.Name()) 121 for _, model := range ctx.Models { 122 modelID := model.ID() 123 logger := logger.With(log.AttributeKeyModel, modelID) 124 125 if modelSucceededBasicChecksOfLanguage[model] == nil { 126 modelSucceededBasicChecksOfLanguage[model] = map[evallanguage.Language]bool{} 127 } 128 129 if r, ok := model.(evalmodel.SetQueryAttempts); ok { 130 r.SetQueryAttempts(ctx.QueryAttempts) 131 } 132 133 for _, taskIdentifier := range temporaryRepository.SupportedTasks() { 134 task, err := evaluatetask.TaskForIdentifier(taskIdentifier) 135 if err != nil { 136 logger.Fatal(err) 137 } 138 139 logger := logger.With(log.AttributeKeyTask, taskIdentifier) 140 withLoadedModel(logger, model, ctx.ProviderForModel[model], func() { 141 
for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ { 142 if ctx.Runs > 1 && ctx.RunsSequential { 143 logger.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID) 144 } 145 146 if err := temporaryRepository.Reset(logger); err != nil { 147 logger.Panicf("ERROR: unable to reset temporary repository path: %s", err) 148 } 149 150 taskContext := evaltask.Context{ 151 Language: language, 152 Repository: temporaryRepository, 153 Model: model, 154 155 ResultPath: ctx.ResultPath, 156 157 Logger: logger, 158 } 159 assessment, ps, err := task.Run(taskContext) 160 if err != nil { 161 ps = append(ps, err) 162 } 163 if len(ps) > 0 { 164 logger.Printf("Model %q was not able to solve the %q repository for language %q: %+v", modelID, repositoryPath, languageID, ps) 165 problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...) 166 } else { 167 modelSucceededBasicChecksOfLanguage[model][language] = true 168 } 169 assessments.AddAssessmentPerTask(model, language, repositoryPath, assessment) 170 // Write the task assessment to the evaluation CSV file. 171 evaluationFile.WriteEvaluationRecord(model, language, temporaryRepository.Name(), assessment) 172 } 173 }) 174 } 175 } 176 } 177 } 178 } 179 180 repositoriesLookup := make(map[string]bool, len(ctx.RepositoryPaths)) 181 for _, repositoryPath := range ctx.RepositoryPaths { 182 repositoriesLookup[repositoryPath] = true 183 } 184 185 // Evaluating models and languages. 186 ctx.Log.Printf("Evaluating models and languages") 187 // Create temporary repositories for each language so the repository is copied only once per language. 
188 temporaryRepositories := map[string]*evaluatetask.Repository{} 189 for _, l := range ctx.Languages { 190 relativeRepositoryPaths, err := language.RepositoriesForLanguage(l, ctx.TestdataPath) 191 if err != nil { 192 ctx.Log.Panicf("ERROR: %s", err) 193 } 194 for _, repositoryPath := range relativeRepositoryPaths { 195 196 // Do not include "plain" repositories in this step of the evaluation, because they have been checked with the common check before. 197 if !repositoriesLookup[repositoryPath] || strings.HasSuffix(repositoryPath, RepositoryPlainName) { 198 continue 199 } 200 201 temporaryRepository, cleanup, err := evaluatetask.TemporaryRepository(ctx.Log, ctx.TestdataPath, repositoryPath) 202 if err != nil { 203 ctx.Log.Panicf("ERROR: unable to create temporary repository path: %s", err) 204 } else if err = temporaryRepository.Validate(ctx.Log, l); err != nil { 205 ctx.Log.Panicf("ERROR: malformed repository %q: %+v", temporaryRepository.Name(), err) 206 } 207 208 defer cleanup() 209 210 temporaryRepositories[repositoryPath] = temporaryRepository 211 } 212 } 213 logger := ctx.Log 214 for rl := uint(0); rl < ctx.runsAtLanguageLevel(); rl++ { 215 if ctx.Runs > 1 && !ctx.RunsSequential { 216 logger.Printf("Run %d/%d", rl+1, ctx.Runs) 217 } 218 219 logger := logger.With(log.AttributeKeyRun, rl+1) 220 221 for _, language := range ctx.Languages { 222 languageID := language.ID() 223 logger := logger.With(log.AttributeKeyLanguage, languageID) 224 225 languagePath := filepath.Join(ctx.TestdataPath, languageID) 226 repositories, err := os.ReadDir(languagePath) 227 if err != nil { 228 logger.Panicf("ERROR: language path %q cannot be accessed: %s", languagePath, err) 229 } 230 231 for _, repository := range repositories { 232 repositoryPath := filepath.Join(languageID, repository.Name()) 233 temporaryRepository := temporaryRepositories[repositoryPath] 234 235 if !repository.IsDir() || (len(ctx.RepositoryPaths) > 0 && !repositoriesLookup[repositoryPath]) { 236 continue 237 
} 238 239 // Do not include "plain" repositories in this step of the evaluation, because they have been checked with the common check before. 240 if repository.Name() == RepositoryPlainName { 241 continue 242 } 243 244 logger = logger.With(log.AttributeKeyRepository, repositoryPath) 245 for _, model := range ctx.Models { 246 modelID := model.ID() 247 logger := logger.With(log.AttributeKeyModel, modelID) 248 249 if !ctx.NoDisqualification && !modelSucceededBasicChecksOfLanguage[model][language] { 250 logger.Printf("Excluding model %q for language %q cause it did not succeed basic checks", model.ID(), language.ID()) 251 252 continue 253 } 254 for _, taskIdentifier := range temporaryRepository.Tasks { 255 task, err := evaluatetask.TaskForIdentifier(taskIdentifier) 256 if err != nil { 257 logger.Fatal(err) 258 } 259 logger := logger.With(log.AttributeKeyTask, taskIdentifier) 260 withLoadedModel(logger, model, ctx.ProviderForModel[model], func() { 261 for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ { 262 if ctx.Runs > 1 && ctx.RunsSequential { 263 logger.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID) 264 } 265 266 if err := temporaryRepository.Reset(logger); err != nil { 267 logger.Panicf("ERROR: unable to reset temporary repository path: %s", err) 268 } 269 270 taskContext := evaltask.Context{ 271 Language: language, 272 Repository: temporaryRepository, 273 Model: model, 274 275 ResultPath: ctx.ResultPath, 276 277 Logger: logger, 278 } 279 assessment, ps, err := task.Run(taskContext) 280 problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...) 281 if err != nil { 282 logger.Printf("ERROR: Model %q encountered a hard error for language %q, repository %q: %+v", modelID, languageID, repositoryPath, err) 283 } 284 assessments.AddAssessmentPerTask(model, language, repositoryPath, assessment) 285 // Write the task assessment to the evaluation CSV file. 
286 evaluationFile.WriteEvaluationRecord(model, language, temporaryRepository.Name(), assessment) 287 } 288 }) 289 } 290 } 291 } 292 } 293 } 294 295 // Set the total score to the number of evaluated languages if we are just checking the "plain" repositories since there is only one task to solve per language. 296 isOnlyPlainRepositories := true 297 for _, repositoryPath := range ctx.RepositoryPaths { 298 if filepath.Base(repositoryPath) != RepositoryPlainName { 299 isOnlyPlainRepositories = false 300 301 break 302 } 303 } 304 if isOnlyPlainRepositories { 305 // For every write-test task in the plain repository, each model is also executed with the `symflower fix` which results in double the total results. 306 totalScore = 2 * uint64(len(ctx.Languages)) * uint64(ctx.Runs) 307 } 308 309 return assessments, totalScore 310 } 311 312 // withLoadedModel loads the model for the duration of the given task if supported by the model's provider. 313 func withLoadedModel(log *log.Logger, model evalmodel.Model, modelProvider provider.Provider, task func()) { 314 if loader, ok := modelProvider.(provider.Loader); ok { 315 log.Printf("preloading model %q", model.ID()) 316 if err := loader.Load(model.ID()); err != nil { 317 log.Panicf("ERROR: could not load model %q with provider %q", model.ID(), modelProvider.ID()) 318 } 319 defer func() { 320 log.Printf("unloading model %q", model.ID()) 321 if err := loader.Unload(model.ID()); err != nil { 322 log.Panicf("ERROR: could not unload model %q with provider %q", model.ID(), modelProvider.ID()) 323 } 324 }() 325 } 326 327 task() 328 }