evaluate.go
  1  package evaluate
  2  
  3  import (
  4  	"os"
  5  	"path/filepath"
  6  	"strings"
  7  
  8  	"github.com/symflower/eval-dev-quality/evaluate/report"
  9  	evaluatetask "github.com/symflower/eval-dev-quality/evaluate/task"
 10  	"github.com/symflower/eval-dev-quality/language"
 11  	evallanguage "github.com/symflower/eval-dev-quality/language"
 12  	"github.com/symflower/eval-dev-quality/log"
 13  	evalmodel "github.com/symflower/eval-dev-quality/model"
 14  	"github.com/symflower/eval-dev-quality/provider"
 15  	evaltask "github.com/symflower/eval-dev-quality/task"
 16  )
 17  
 18  // Context holds an evaluation context.
 19  type Context struct {
 20  	// Log holds the logger of the context.
 21  	Log *log.Logger
 22  
 23  	// Languages determines which language should be used for the evaluation, or empty if all languages should be used.
 24  	Languages []evallanguage.Language
 25  
 26  	// Models determines which models should be used for the evaluation, or empty if all models should be used.
 27  	Models []evalmodel.Model
 28  	// ProviderForModel holds the models and their associated provider.
 29  	ProviderForModel map[evalmodel.Model]provider.Provider
 30  	// QueryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task.
 31  	QueryAttempts uint
 32  
 33  	// RepositoryPaths determines which relative repository paths should be used for the evaluation, or empty if all repositories should be used.
 34  	RepositoryPaths []string
 35  	// ResultPath holds the directory path where results should be written to.
 36  	ResultPath string
 37  	// TestdataPath determines the testdata path where all repositories reside grouped by languages.
 38  	TestdataPath string
 39  
 40  	// Runs holds the number of runs to perform.
 41  	Runs uint
 42  	// RunsSequential indicates that interleaved runs are disabled and runs are performed sequentially.
 43  	RunsSequential bool
 44  	// NoDisqualification indicates that models are not to be disqualified if they fail to solve basic language tasks.
 45  	NoDisqualification bool
 46  }
 47  
 48  // runsAtLanguageLevel returns how many runs to perform on language level.
 49  func (ctx *Context) runsAtLanguageLevel() uint {
 50  	if ctx.RunsSequential {
 51  		return 1
 52  	}
 53  
 54  	return ctx.Runs
 55  }
 56  
 57  // runsAtModelLevel returns how many runs to perform on model level.
 58  func (ctx *Context) runsAtModelLevel() uint {
 59  	if ctx.RunsSequential {
 60  		return ctx.Runs
 61  	}
 62  
 63  	return 1
 64  }
 65  
 66  // RepositoryPlainName holds the name of the plain repository.
 67  const RepositoryPlainName = "plain"
 68  
 69  // Evaluate runs an evaluation on the given context and returns its results.
 70  func Evaluate(ctx *Context) (assessments *report.AssessmentStore, totalScore uint64) {
 71  	// Check that models and languages can be evaluated by executing the "plain" repositories.
 72  	modelSucceededBasicChecksOfLanguage := map[evalmodel.Model]map[evallanguage.Language]bool{}
 73  	ctx.Log.Printf("Checking that models and languages can be used for evaluation")
 74  	// Ensure we report metrics for every model even if they are excluded.
 75  	assessments = report.NewAssessmentStore()
 76  	problemsPerModel := map[string][]error{}
 77  	// Write the evaluation CSV header so it's only written once.
 78  	evaluationCSVFile, err := os.OpenFile(filepath.Join(ctx.ResultPath, "evaluation.csv"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 79  	if err != nil {
 80  		ctx.Log.Panicf("ERROR: unable to create evaluation CSV file: %+v", err)
 81  	}
 82  	defer evaluationCSVFile.Close()
 83  	evaluationFile, err := report.NewEvaluationFile(evaluationCSVFile)
 84  	if err != nil {
 85  		ctx.Log.Panicf("ERROR: %+v", err)
 86  	}
 87  
 88  	{
 89  		// Create temporary repositories for each language so the repository is copied only once per language.
 90  		temporaryRepositories := map[string]evaltask.Repository{}
 91  		for _, language := range ctx.Languages {
 92  			repositoryPath := filepath.Join(language.ID(), RepositoryPlainName)
 93  			temporaryRepository, cleanup, err := evaluatetask.TemporaryRepository(ctx.Log, ctx.TestdataPath, repositoryPath)
 94  			if err != nil {
 95  				ctx.Log.Panicf("ERROR: unable to create temporary repository path: %+v", err)
 96  			} else if err = temporaryRepository.Validate(ctx.Log, language); err != nil {
 97  				ctx.Log.Panicf("ERROR: malformed repository %q: %+v", temporaryRepository.Name(), err)
 98  			}
 99  
100  			defer cleanup()
101  
102  			temporaryRepositories[repositoryPath] = temporaryRepository
103  		}
104  
105  		logger := ctx.Log
106  		for rl := uint(0); rl < ctx.runsAtLanguageLevel(); rl++ {
107  			if ctx.Runs > 1 && !ctx.RunsSequential {
108  				logger.Printf("Run %d/%d", rl+1, ctx.Runs)
109  			}
110  
111  			logger := logger.With(log.AttributeKeyRun, rl+1)
112  
113  			for _, language := range ctx.Languages {
114  				logger := logger.With(log.AttributeKeyLanguage, language.ID())
115  
116  				languageID := language.ID()
117  				repositoryPath := filepath.Join(language.ID(), RepositoryPlainName)
118  				temporaryRepository := temporaryRepositories[repositoryPath]
119  
120  				logger = logger.With(log.AttributeKeyRepository, temporaryRepository.Name())
121  				for _, model := range ctx.Models {
122  					modelID := model.ID()
123  					logger := logger.With(log.AttributeKeyModel, modelID)
124  
125  					if modelSucceededBasicChecksOfLanguage[model] == nil {
126  						modelSucceededBasicChecksOfLanguage[model] = map[evallanguage.Language]bool{}
127  					}
128  
129  					if r, ok := model.(evalmodel.SetQueryAttempts); ok {
130  						r.SetQueryAttempts(ctx.QueryAttempts)
131  					}
132  
133  					for _, taskIdentifier := range temporaryRepository.SupportedTasks() {
134  						task, err := evaluatetask.TaskForIdentifier(taskIdentifier)
135  						if err != nil {
136  							logger.Fatal(err)
137  						}
138  
139  						logger := logger.With(log.AttributeKeyTask, taskIdentifier)
140  						withLoadedModel(logger, model, ctx.ProviderForModel[model], func() {
141  							for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ {
142  								if ctx.Runs > 1 && ctx.RunsSequential {
143  									logger.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID)
144  								}
145  
146  								if err := temporaryRepository.Reset(logger); err != nil {
147  									logger.Panicf("ERROR: unable to reset temporary repository path: %s", err)
148  								}
149  
150  								taskContext := evaltask.Context{
151  									Language:   language,
152  									Repository: temporaryRepository,
153  									Model:      model,
154  
155  									ResultPath: ctx.ResultPath,
156  
157  									Logger: logger,
158  								}
159  								assessment, ps, err := task.Run(taskContext)
160  								if err != nil {
161  									ps = append(ps, err)
162  								}
163  								if len(ps) > 0 {
164  									logger.Printf("Model %q was not able to solve the %q repository for language %q: %+v", modelID, repositoryPath, languageID, ps)
165  									problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
166  								} else {
167  									modelSucceededBasicChecksOfLanguage[model][language] = true
168  								}
169  								assessments.AddAssessmentPerTask(model, language, repositoryPath, assessment)
170  								// Write the task assessment to the evaluation CSV file.
171  								evaluationFile.WriteEvaluationRecord(model, language, temporaryRepository.Name(), assessment)
172  							}
173  						})
174  					}
175  				}
176  			}
177  		}
178  	}
179  
180  	repositoriesLookup := make(map[string]bool, len(ctx.RepositoryPaths))
181  	for _, repositoryPath := range ctx.RepositoryPaths {
182  		repositoriesLookup[repositoryPath] = true
183  	}
184  
185  	// Evaluating models and languages.
186  	ctx.Log.Printf("Evaluating models and languages")
187  	// Create temporary repositories for each language so the repository is copied only once per language.
188  	temporaryRepositories := map[string]*evaluatetask.Repository{}
189  	for _, l := range ctx.Languages {
190  		relativeRepositoryPaths, err := language.RepositoriesForLanguage(l, ctx.TestdataPath)
191  		if err != nil {
192  			ctx.Log.Panicf("ERROR: %s", err)
193  		}
194  		for _, repositoryPath := range relativeRepositoryPaths {
195  
196  			// Do not include "plain" repositories in this step of the evaluation, because they have been checked with the common check before.
197  			if !repositoriesLookup[repositoryPath] || strings.HasSuffix(repositoryPath, RepositoryPlainName) {
198  				continue
199  			}
200  
201  			temporaryRepository, cleanup, err := evaluatetask.TemporaryRepository(ctx.Log, ctx.TestdataPath, repositoryPath)
202  			if err != nil {
203  				ctx.Log.Panicf("ERROR: unable to create temporary repository path: %s", err)
204  			} else if err = temporaryRepository.Validate(ctx.Log, l); err != nil {
205  				ctx.Log.Panicf("ERROR: malformed repository %q: %+v", temporaryRepository.Name(), err)
206  			}
207  
208  			defer cleanup()
209  
210  			temporaryRepositories[repositoryPath] = temporaryRepository
211  		}
212  	}
213  	logger := ctx.Log
214  	for rl := uint(0); rl < ctx.runsAtLanguageLevel(); rl++ {
215  		if ctx.Runs > 1 && !ctx.RunsSequential {
216  			logger.Printf("Run %d/%d", rl+1, ctx.Runs)
217  		}
218  
219  		logger := logger.With(log.AttributeKeyRun, rl+1)
220  
221  		for _, language := range ctx.Languages {
222  			languageID := language.ID()
223  			logger := logger.With(log.AttributeKeyLanguage, languageID)
224  
225  			languagePath := filepath.Join(ctx.TestdataPath, languageID)
226  			repositories, err := os.ReadDir(languagePath)
227  			if err != nil {
228  				logger.Panicf("ERROR: language path %q cannot be accessed: %s", languagePath, err)
229  			}
230  
231  			for _, repository := range repositories {
232  				repositoryPath := filepath.Join(languageID, repository.Name())
233  				temporaryRepository := temporaryRepositories[repositoryPath]
234  
235  				if !repository.IsDir() || (len(ctx.RepositoryPaths) > 0 && !repositoriesLookup[repositoryPath]) {
236  					continue
237  				}
238  
239  				// Do not include "plain" repositories in this step of the evaluation, because they have been checked with the common check before.
240  				if repository.Name() == RepositoryPlainName {
241  					continue
242  				}
243  
244  				logger = logger.With(log.AttributeKeyRepository, repositoryPath)
245  				for _, model := range ctx.Models {
246  					modelID := model.ID()
247  					logger := logger.With(log.AttributeKeyModel, modelID)
248  
249  					if !ctx.NoDisqualification && !modelSucceededBasicChecksOfLanguage[model][language] {
250  						logger.Printf("Excluding model %q for language %q cause it did not succeed basic checks", model.ID(), language.ID())
251  
252  						continue
253  					}
254  					for _, taskIdentifier := range temporaryRepository.Tasks {
255  						task, err := evaluatetask.TaskForIdentifier(taskIdentifier)
256  						if err != nil {
257  							logger.Fatal(err)
258  						}
259  						logger := logger.With(log.AttributeKeyTask, taskIdentifier)
260  						withLoadedModel(logger, model, ctx.ProviderForModel[model], func() {
261  							for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ {
262  								if ctx.Runs > 1 && ctx.RunsSequential {
263  									logger.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID)
264  								}
265  
266  								if err := temporaryRepository.Reset(logger); err != nil {
267  									logger.Panicf("ERROR: unable to reset temporary repository path: %s", err)
268  								}
269  
270  								taskContext := evaltask.Context{
271  									Language:   language,
272  									Repository: temporaryRepository,
273  									Model:      model,
274  
275  									ResultPath: ctx.ResultPath,
276  
277  									Logger: logger,
278  								}
279  								assessment, ps, err := task.Run(taskContext)
280  								problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
281  								if err != nil {
282  									logger.Printf("ERROR: Model %q encountered a hard error for language %q, repository %q: %+v", modelID, languageID, repositoryPath, err)
283  								}
284  								assessments.AddAssessmentPerTask(model, language, repositoryPath, assessment)
285  								// Write the task assessment to the evaluation CSV file.
286  								evaluationFile.WriteEvaluationRecord(model, language, temporaryRepository.Name(), assessment)
287  							}
288  						})
289  					}
290  				}
291  			}
292  		}
293  	}
294  
295  	// Set the total score to the number of evaluated languages if we are just checking the "plain" repositories since there is only one task to solve per language.
296  	isOnlyPlainRepositories := true
297  	for _, repositoryPath := range ctx.RepositoryPaths {
298  		if filepath.Base(repositoryPath) != RepositoryPlainName {
299  			isOnlyPlainRepositories = false
300  
301  			break
302  		}
303  	}
304  	if isOnlyPlainRepositories {
305  		// For every write-test task in the plain repository, each model is also executed with the `symflower fix` which results in double the total results.
306  		totalScore = 2 * uint64(len(ctx.Languages)) * uint64(ctx.Runs)
307  	}
308  
309  	return assessments, totalScore
310  }
311  
312  // withLoadedModel loads the model for the duration of the given task if supported by the model's provider.
313  func withLoadedModel(log *log.Logger, model evalmodel.Model, modelProvider provider.Provider, task func()) {
314  	if loader, ok := modelProvider.(provider.Loader); ok {
315  		log.Printf("preloading model %q", model.ID())
316  		if err := loader.Load(model.ID()); err != nil {
317  			log.Panicf("ERROR: could not load model %q with provider %q", model.ID(), modelProvider.ID())
318  		}
319  		defer func() {
320  			log.Printf("unloading model %q", model.ID())
321  			if err := loader.Unload(model.ID()); err != nil {
322  				log.Panicf("ERROR: could not unload model %q with provider %q", model.ID(), modelProvider.ID())
323  			}
324  		}()
325  	}
326  
327  	task()
328  }