/ qwencoder-eval / instruct / eval-dev-quality / evaluate / task / task-write-test.go
task-write-test.go
  1  package task
  2  
  3  import (
  4  	"context"
  5  	"errors"
  6  	"fmt"
  7  	"strings"
  8  
  9  	pkgerrors "github.com/pkg/errors"
 10  	"github.com/symflower/eval-dev-quality/evaluate/metrics"
 11  	"github.com/symflower/eval-dev-quality/language"
 12  	"github.com/symflower/eval-dev-quality/log"
 13  	"github.com/symflower/eval-dev-quality/model"
 14  	evaltask "github.com/symflower/eval-dev-quality/task"
 15  )
 16  
 17  // TaskWriteTests holds the write test task.
 18  type TaskWriteTests struct {
 19  }
 20  
 21  var _ evaltask.Task = (*TaskWriteTests)(nil)
 22  
 23  // Identifier returns the write test task identifier.
 24  func (t *TaskWriteTests) Identifier() evaltask.Identifier {
 25  	return IdentifierWriteTests
 26  }
 27  
 28  // TaskWriteTests generates test files for the given implementation file in a repository.
 29  func (t *TaskWriteTests) Run(ctx evaltask.Context) (repositoryAssessment map[evaltask.Identifier]metrics.Assessments, problems []error, err error) {
 30  	modelCapability, ok := ctx.Model.(model.CapabilityWriteTests)
 31  	if !ok {
 32  		return nil, nil, pkgerrors.Wrap(evaltask.ErrTaskUnsupportedByModel, fmt.Sprintf("%q does not support %q", ctx.Model.ID(), string(t.Identifier())))
 33  	}
 34  
 35  	taskLogger, err := newTaskLogger(ctx, t)
 36  	if err != nil {
 37  		return nil, nil, err
 38  	}
 39  	defer func() {
 40  		taskLogger.finalize(problems)
 41  	}()
 42  
 43  	dataPath := ctx.Repository.DataPath()
 44  	filePaths, err := ctx.Language.Files(taskLogger.Logger, dataPath)
 45  	if err != nil {
 46  		return nil, problems, pkgerrors.WithStack(err)
 47  	}
 48  
 49  	modelAssessment := metrics.NewAssessments()
 50  	withSymflowerAssessment := metrics.NewAssessments()
 51  
 52  	maximumReachableFiles := uint64(len(filePaths))
 53  	modelAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles
 54  	withSymflowerAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles
 55  
 56  	for _, filePath := range filePaths {
 57  		modelAssessmentForFile := metrics.NewAssessments()
 58  		withSymflowerAssessmentForFile := modelAssessmentForFile // The symflower assessment tracks how the model result can be improved in case of a failure, so just link to the model assessment until a failure actually happens.
 59  
 60  		if err := ctx.Repository.Reset(ctx.Logger); err != nil {
 61  			ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err)
 62  		}
 63  
 64  		modelContext := model.Context{
 65  			Language: ctx.Language,
 66  
 67  			RepositoryPath: dataPath,
 68  			FilePath:       filePath,
 69  
 70  			Logger: taskLogger.Logger,
 71  		}
 72  		assessments, err := modelCapability.WriteTests(modelContext)
 73  		if err != nil {
 74  			problems = append(problems, pkgerrors.WithMessage(err, filePath))
 75  
 76  			continue
 77  		}
 78  		if assessments[metrics.AssessmentKeyProcessingTime] == 0 {
 79  			return nil, nil, pkgerrors.Errorf("no model response time measurement present for %q at repository %q", ctx.Model.ID(), ctx.Repository.Name())
 80  		}
 81  		modelAssessmentForFile.Add(assessments)
 82  		modelAssessmentForFile.Award(metrics.AssessmentKeyResponseNoError)
 83  
 84  		testResult, ps, err := ctx.Language.ExecuteTests(taskLogger.Logger, dataPath)
 85  		problems = append(problems, ps...)
 86  		if err != nil {
 87  			problems = append(problems, pkgerrors.WithMessage(err, filePath))
 88  
 89  			// If there is an execution timeout do not run "symflower fix" because the code itself is correct.
 90  			if errors.Is(err, context.DeadlineExceeded) {
 91  				modelAssessment.Add(modelAssessmentForFile)
 92  				withSymflowerAssessment.Add(withSymflowerAssessmentForFile)
 93  
 94  				continue
 95  			}
 96  
 97  			// Run "symflower fix" if the model response fails to execute.
 98  			if ctx.Language.ID() == "golang" { // Currently we only support Go for "symflower fix".
 99  				withSymflowerFixTestResult, processingTime, ps, err := ExecuteWithSymflowerFix(ctx, taskLogger.Logger, ctx.Repository.DataPath())
100  				problems = append(problems, ps...)
101  				if err != nil {
102  					problems = append(problems, err)
103  
104  					modelAssessment.Add(modelAssessmentForFile)
105  					withSymflowerAssessment.Add(withSymflowerAssessmentForFile)
106  
107  					continue
108  				} else {
109  					ctx.Logger.Printf("with symflower repair: Executes tests with %d coverage objects", withSymflowerFixTestResult.Coverage)
110  
111  					// Symflower was able to fix a failure so now update the assessment with the improved results.
112  					withSymflowerFixAssessments := metrics.NewAssessments()
113  					withSymflowerFixAssessments[metrics.AssessmentKeyProcessingTime] = processingTime
114  					withSymflowerFixAssessments.Award(metrics.AssessmentKeyFilesExecuted)
115  					withSymflowerFixAssessments.AwardPoints(metrics.AssessmentKeyCoverage, withSymflowerFixTestResult.Coverage)
116  
117  					withSymflowerAssessmentForFile = metrics.CombineWithSymflowerFixAssessments(modelAssessmentForFile, withSymflowerFixAssessments)
118  				}
119  			}
120  		} else {
121  			taskLogger.Printf("Executes tests with %d coverage objects", testResult.Coverage)
122  			modelAssessmentForFile.Award(metrics.AssessmentKeyFilesExecuted)
123  			modelAssessmentForFile.AwardPoints(metrics.AssessmentKeyCoverage, testResult.Coverage)
124  		}
125  
126  		modelAssessment.Add(modelAssessmentForFile)
127  		withSymflowerAssessment.Add(withSymflowerAssessmentForFile)
128  	}
129  
130  	repositoryAssessment = map[evaltask.Identifier]metrics.Assessments{
131  		IdentifierWriteTests:             modelAssessment,
132  		IdentifierWriteTestsSymflowerFix: withSymflowerAssessment,
133  	}
134  
135  	return repositoryAssessment, problems, nil
136  }
137  
138  // validateWriteTestsRepository checks if the repository for the "write-tests" task is well-formed.
139  func validateWriteTestsRepository(logger *log.Logger, repositoryPath string, language language.Language) (err error) {
140  	logger.Printf("validating repository %q", repositoryPath)
141  
142  	files, err := language.Files(logger, repositoryPath)
143  	if err != nil {
144  		return pkgerrors.WithStack(err)
145  	}
146  
147  	var sourceFiles []string
148  	var testFiles []string
149  	for _, file := range files {
150  		if strings.HasSuffix(file, language.DefaultTestFileSuffix()) {
151  			testFiles = append(testFiles, file)
152  		} else if strings.HasSuffix(file, language.DefaultFileExtension()) {
153  			sourceFiles = append(sourceFiles, file)
154  		}
155  	}
156  
157  	if len(sourceFiles) == 0 {
158  		return pkgerrors.Errorf("the repository %q must contain at least one %s source file, but found none", repositoryPath, language.Name())
159  	} else if len(testFiles) > 0 {
160  		return pkgerrors.Errorf("the repository %q must contain only %s source files, but found %+v", repositoryPath, language.Name(), testFiles)
161  	}
162  
163  	return nil
164  }