task-write-test.go
1 package task 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "strings" 8 9 pkgerrors "github.com/pkg/errors" 10 "github.com/symflower/eval-dev-quality/evaluate/metrics" 11 "github.com/symflower/eval-dev-quality/language" 12 "github.com/symflower/eval-dev-quality/log" 13 "github.com/symflower/eval-dev-quality/model" 14 evaltask "github.com/symflower/eval-dev-quality/task" 15 ) 16 17 // TaskWriteTests holds the write test task. 18 type TaskWriteTests struct { 19 } 20 21 var _ evaltask.Task = (*TaskWriteTests)(nil) 22 23 // Identifier returns the write test task identifier. 24 func (t *TaskWriteTests) Identifier() evaltask.Identifier { 25 return IdentifierWriteTests 26 } 27 28 // TaskWriteTests generates test files for the given implementation file in a repository. 29 func (t *TaskWriteTests) Run(ctx evaltask.Context) (repositoryAssessment map[evaltask.Identifier]metrics.Assessments, problems []error, err error) { 30 modelCapability, ok := ctx.Model.(model.CapabilityWriteTests) 31 if !ok { 32 return nil, nil, pkgerrors.Wrap(evaltask.ErrTaskUnsupportedByModel, fmt.Sprintf("%q does not support %q", ctx.Model.ID(), string(t.Identifier()))) 33 } 34 35 taskLogger, err := newTaskLogger(ctx, t) 36 if err != nil { 37 return nil, nil, err 38 } 39 defer func() { 40 taskLogger.finalize(problems) 41 }() 42 43 dataPath := ctx.Repository.DataPath() 44 filePaths, err := ctx.Language.Files(taskLogger.Logger, dataPath) 45 if err != nil { 46 return nil, problems, pkgerrors.WithStack(err) 47 } 48 49 modelAssessment := metrics.NewAssessments() 50 withSymflowerAssessment := metrics.NewAssessments() 51 52 maximumReachableFiles := uint64(len(filePaths)) 53 modelAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles 54 withSymflowerAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles 55 56 for _, filePath := range filePaths { 57 modelAssessmentForFile := metrics.NewAssessments() 58 withSymflowerAssessmentForFile := modelAssessmentForFile // The symflower assessment tracks how the model result can be improved in case of a failure, so just link to the model assessment until a failure actually happens. 59 60 if err := ctx.Repository.Reset(ctx.Logger); err != nil { 61 ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err) 62 } 63 64 modelContext := model.Context{ 65 Language: ctx.Language, 66 67 RepositoryPath: dataPath, 68 FilePath: filePath, 69 70 Logger: taskLogger.Logger, 71 } 72 assessments, err := modelCapability.WriteTests(modelContext) 73 if err != nil { 74 problems = append(problems, pkgerrors.WithMessage(err, filePath)) 75 76 continue 77 } 78 if assessments[metrics.AssessmentKeyProcessingTime] == 0 { 79 return nil, nil, pkgerrors.Errorf("no model response time measurement present for %q at repository %q", ctx.Model.ID(), ctx.Repository.Name()) 80 } 81 modelAssessmentForFile.Add(assessments) 82 modelAssessmentForFile.Award(metrics.AssessmentKeyResponseNoError) 83 84 testResult, ps, err := ctx.Language.ExecuteTests(taskLogger.Logger, dataPath) 85 problems = append(problems, ps...) 86 if err != nil { 87 problems = append(problems, pkgerrors.WithMessage(err, filePath)) 88 89 // If there is an execution timeout do not run "symflower fix" because the code itself is correct. 90 if errors.Is(err, context.DeadlineExceeded) { 91 modelAssessment.Add(modelAssessmentForFile) 92 withSymflowerAssessment.Add(withSymflowerAssessmentForFile) 93 94 continue 95 } 96 97 // Run "symflower fix" if the model response fails to execute. 98 if ctx.Language.ID() == "golang" { // Currently we only support Go for "symflower fix". 99 withSymflowerFixTestResult, processingTime, ps, err := ExecuteWithSymflowerFix(ctx, taskLogger.Logger, ctx.Repository.DataPath()) 100 problems = append(problems, ps...) 101 if err != nil { 102 problems = append(problems, err) 103 104 modelAssessment.Add(modelAssessmentForFile) 105 withSymflowerAssessment.Add(withSymflowerAssessmentForFile) 106 107 continue 108 } else { 109 ctx.Logger.Printf("with symflower repair: Executes tests with %d coverage objects", withSymflowerFixTestResult.Coverage) 110 111 // Symflower was able to fix a failure so now update the assessment with the improved results. 112 withSymflowerFixAssessments := metrics.NewAssessments() 113 withSymflowerFixAssessments[metrics.AssessmentKeyProcessingTime] = processingTime 114 withSymflowerFixAssessments.Award(metrics.AssessmentKeyFilesExecuted) 115 withSymflowerFixAssessments.AwardPoints(metrics.AssessmentKeyCoverage, withSymflowerFixTestResult.Coverage) 116 117 withSymflowerAssessmentForFile = metrics.CombineWithSymflowerFixAssessments(modelAssessmentForFile, withSymflowerFixAssessments) 118 } 119 } 120 } else { 121 taskLogger.Printf("Executes tests with %d coverage objects", testResult.Coverage) 122 modelAssessmentForFile.Award(metrics.AssessmentKeyFilesExecuted) 123 modelAssessmentForFile.AwardPoints(metrics.AssessmentKeyCoverage, testResult.Coverage) 124 } 125 126 modelAssessment.Add(modelAssessmentForFile) 127 withSymflowerAssessment.Add(withSymflowerAssessmentForFile) 128 } 129 130 repositoryAssessment = map[evaltask.Identifier]metrics.Assessments{ 131 IdentifierWriteTests: modelAssessment, 132 IdentifierWriteTestsSymflowerFix: withSymflowerAssessment, 133 } 134 135 return repositoryAssessment, problems, nil 136 } 137 138 // validateWriteTestsRepository checks if the repository for the "write-tests" task is well-formed. 139 func validateWriteTestsRepository(logger *log.Logger, repositoryPath string, language language.Language) (err error) { 140 logger.Printf("validating repository %q", repositoryPath) 141 142 files, err := language.Files(logger, repositoryPath) 143 if err != nil { 144 return pkgerrors.WithStack(err) 145 } 146 147 var sourceFiles []string 148 var testFiles []string 149 for _, file := range files { 150 if strings.HasSuffix(file, language.DefaultTestFileSuffix()) { 151 testFiles = append(testFiles, file) 152 } else if strings.HasSuffix(file, language.DefaultFileExtension()) { 153 sourceFiles = append(sourceFiles, file) 154 } 155 } 156 157 if len(sourceFiles) == 0 { 158 return pkgerrors.Errorf("the repository %q must contain at least one %s source file, but found none", repositoryPath, language.Name()) 159 } else if len(testFiles) > 0 { 160 return pkgerrors.Errorf("the repository %q must contain only %s source files, but found %+v", repositoryPath, language.Name(), testFiles) 161 } 162 163 return nil 164 }