task-write-test_test.go
1 package task 2 3 import ( 4 "context" 5 "fmt" 6 "os" 7 "path/filepath" 8 "testing" 9 10 "github.com/stretchr/testify/assert" 11 "github.com/stretchr/testify/mock" 12 "github.com/stretchr/testify/require" 13 "github.com/symflower/eval-dev-quality/evaluate/metrics" 14 metricstesting "github.com/symflower/eval-dev-quality/evaluate/metrics/testing" 15 tasktesting "github.com/symflower/eval-dev-quality/evaluate/task/testing" 16 "github.com/symflower/eval-dev-quality/language" 17 "github.com/symflower/eval-dev-quality/language/golang" 18 "github.com/symflower/eval-dev-quality/language/java" 19 "github.com/symflower/eval-dev-quality/language/ruby" 20 languagetesting "github.com/symflower/eval-dev-quality/language/testing" 21 "github.com/symflower/eval-dev-quality/log" 22 modeltesting "github.com/symflower/eval-dev-quality/model/testing" 23 "github.com/symflower/eval-dev-quality/task" 24 evaltask "github.com/symflower/eval-dev-quality/task" 25 "github.com/zimmski/osutil" 26 "github.com/zimmski/osutil/bytesutil" 27 ) 28 29 func TestTaskWriteTestsRun(t *testing.T) { 30 validate := func(t *testing.T, tc *tasktesting.TestCaseTask) { 31 t.Run(tc.Name, func(t *testing.T) { 32 task, err := TaskForIdentifier(IdentifierWriteTests) 33 require.NoError(t, err) 34 tc.Task = task 35 36 tc.Validate(t, 37 func(logger *log.Logger, testDataPath string, repositoryPathRelative string) (repository evaltask.Repository, cleanup func(), err error) { 38 return TemporaryRepository(logger, testDataPath, repositoryPathRelative) 39 }, 40 ) 41 }) 42 } 43 44 t.Run("Clear repository on each task file", func(t *testing.T) { 45 temporaryDirectoryPath := t.TempDir() 46 47 repositoryPath := filepath.Join(temporaryDirectoryPath, "golang", "plain") 48 require.NoError(t, os.MkdirAll(repositoryPath, 0700)) 49 require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "go.mod"), []byte("module plain\n\ngo 1.21.5"), 0600)) 50 require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "taskA.go"), []byte("package plain\n\nfunc TaskA(){}"), 0600)) 51 require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "taskB.go"), []byte("package plain\n\nfunc TaskB(){}"), 0600)) 52 53 modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model") 54 55 // Generate invalid code for the first taskcontext. 56 modelMock.RegisterGenerateSuccess(t, "taskA_test.go", "does not compile", metricstesting.AssessmentsWithProcessingTime).Once() 57 // Generate valid code for the second taskcontext. 58 modelMock.RegisterGenerateSuccess(t, "taskB_test.go", "package plain\n\nimport \"testing\"\n\nfunc TestTaskB(t *testing.T){}", metricstesting.AssessmentsWithProcessingTime).Once() 59 60 validate(t, &tasktesting.TestCaseTask{ 61 Name: "Plain", 62 63 Model: modelMock, 64 Language: &golang.Language{}, 65 TestDataPath: temporaryDirectoryPath, 66 RepositoryPath: filepath.Join("golang", "plain"), 67 68 ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{ 69 IdentifierWriteTests: metrics.Assessments{ 70 metrics.AssessmentKeyFilesExecuted: 1, 71 metrics.AssessmentKeyFilesExecutedMaximumReachable: 2, 72 metrics.AssessmentKeyResponseNoError: 2, 73 }, 74 IdentifierWriteTestsSymflowerFix: metrics.Assessments{ 75 metrics.AssessmentKeyFilesExecuted: 1, 76 metrics.AssessmentKeyFilesExecutedMaximumReachable: 2, 77 metrics.AssessmentKeyResponseNoError: 2, 78 }, 79 }, 80 ExpectedProblemContains: []string{ 81 "expected 'package', found does", 82 "exit status 1", 83 }, 84 ValidateLog: func(t *testing.T, data string) { 85 assert.Contains(t, data, "Evaluating model \"mocked-model\"") 86 assert.Contains(t, data, "PASS: TestTaskB") 87 }, 88 }) 89 }) 90 91 t.Run("Symflower Fix", func(t *testing.T) { 92 t.Run("Go", func(t *testing.T) { 93 validateGo := func(t *testing.T, testName string, language language.Language, testFileContent string, expectedAssessments map[evaltask.Identifier]metrics.Assessments, expectedProblems []string, assertTestsPass bool) { 94 temporaryDirectoryPath := t.TempDir() 95 repositoryPath := filepath.Join(temporaryDirectoryPath, "golang", "plain") 96 require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "plain"), repositoryPath)) 97 98 modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model") 99 modelMock.RegisterGenerateSuccess(t, "plain_test.go", testFileContent, metricstesting.AssessmentsWithProcessingTime).Once() 100 101 validate(t, &tasktesting.TestCaseTask{ 102 Name: testName, 103 104 Model: modelMock, 105 Language: language, 106 TestDataPath: temporaryDirectoryPath, 107 RepositoryPath: filepath.Join("golang", "plain"), 108 109 ExpectedRepositoryAssessment: expectedAssessments, 110 ExpectedProblemContains: expectedProblems, 111 ValidateLog: func(t *testing.T, data string) { 112 assert.Contains(t, data, "Evaluating model \"mocked-model\"") 113 if assertTestsPass { 114 assert.Contains(t, data, "PASS: TestPlain") 115 } 116 }, 117 }) 118 } 119 { 120 expectedAssessments := map[evaltask.Identifier]metrics.Assessments{ 121 IdentifierWriteTests: metrics.Assessments{ 122 metrics.AssessmentKeyFilesExecuted: 1, 123 metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, 124 metrics.AssessmentKeyResponseNoError: 1, 125 metrics.AssessmentKeyCoverage: 10, 126 }, 127 IdentifierWriteTestsSymflowerFix: metrics.Assessments{ 128 metrics.AssessmentKeyFilesExecuted: 1, 129 metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, 130 metrics.AssessmentKeyResponseNoError: 1, 131 metrics.AssessmentKeyCoverage: 10, 132 }, 133 } 134 validateGo(t, "Model generated correct test", &golang.Language{}, bytesutil.StringTrimIndentations(` 135 package plain 136 137 import "testing" 138 139 func TestPlain(t *testing.T) { 140 plain() 141 } 142 `), expectedAssessments, nil, true) 143 } 144 { 145 expectedAssessments := map[evaltask.Identifier]metrics.Assessments{ 146 IdentifierWriteTests: metrics.Assessments{ 147 metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, 148 metrics.AssessmentKeyResponseNoError: 1, 149 }, 150 IdentifierWriteTestsSymflowerFix: metrics.Assessments{ 151 metrics.AssessmentKeyFilesExecuted: 1, 152 metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, 153 metrics.AssessmentKeyResponseNoError: 1, 154 metrics.AssessmentKeyCoverage: 10, 155 }, 156 } 157 expectedProblems := []string{ 158 "imported and not used", 159 } 160 validateGo(t, "Model generated test with unused import", &golang.Language{}, bytesutil.StringTrimIndentations(` 161 package plain 162 163 import ( 164 "testing" 165 "strings" 166 ) 167 168 func TestPlain(t *testing.T) { 169 plain() 170 } 171 `), expectedAssessments, expectedProblems, true) 172 } 173 { 174 expectedAssessments := map[evaltask.Identifier]metrics.Assessments{ 175 IdentifierWriteTests: metrics.Assessments{ 176 metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, 177 metrics.AssessmentKeyResponseNoError: 1, 178 }, 179 IdentifierWriteTestsSymflowerFix: metrics.Assessments{ 180 metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, 181 metrics.AssessmentKeyResponseNoError: 1, 182 }, 183 } 184 expectedProblems := []string{ 185 "expected declaration, found this", 186 "unable to format source code", 187 } 188 validateGo(t, "Model generated test that is unfixable", &golang.Language{}, bytesutil.StringTrimIndentations(` 189 package plain 190 191 this is not valid go code 192 `), expectedAssessments, expectedProblems, false) 193 } 194 { 195 expectedAssessments := map[task.Identifier]metrics.Assessments{ 196 IdentifierWriteTests: metrics.Assessments{ 197 metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, 198 metrics.AssessmentKeyResponseNoError: 1, 199 }, 200 IdentifierWriteTestsSymflowerFix: metrics.Assessments{ 201 metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, 202 metrics.AssessmentKeyResponseNoError: 1, 203 }, 204 } 205 expectedProblems := []string{ 206 "context deadline exceeded", 207 } 208 209 languageMock := languagetesting.NewMockLanguageNamed(t, "golang") 210 languageMock.On("Files", mock.Anything, mock.Anything).Return([]string{filepath.Join("golang", "plain")}, nil).Once() 211 languageMock.On("ExecuteTests", mock.Anything, mock.Anything).Return(nil, nil, context.DeadlineExceeded).Once() 212 213 validateGo(t, "Execution timeout", languageMock, "", expectedAssessments, expectedProblems, false) 214 } 215 }) 216 }) 217 218 { 219 if osutil.IsWindows() { 220 t.Skip("Ruby is not tested in the Windows CI") 221 } 222 223 temporaryDirectoryPath := t.TempDir() 224 repositoryPath := filepath.Join(temporaryDirectoryPath, "ruby", "plain") 225 require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "ruby", "plain"), repositoryPath)) 226 227 testFileContent := bytesutil.StringTrimIndentations(` 228 require_relative "../lib/plain" 229 230 class TestPlain < Minitest::Test 231 def test_plain 232 plain() 233 end 234 end 235 `) 236 modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model") 237 modelMock.RegisterGenerateSuccess(t, filepath.Join("test", "plain_test.rb"), testFileContent, metricstesting.AssessmentsWithProcessingTime).Maybe() 238 239 validate(t, &tasktesting.TestCaseTask{ 240 Name: "Ruby", 241 242 Model: modelMock, 243 Language: &ruby.Language{}, 244 TestDataPath: temporaryDirectoryPath, 245 RepositoryPath: filepath.Join("ruby", "plain"), 246 247 ExpectedRepositoryAssessment: map[task.Identifier]metrics.Assessments{ 248 IdentifierWriteTests: metrics.Assessments{ 249 metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, 250 metrics.AssessmentKeyFilesExecuted: 1, 251 metrics.AssessmentKeyCoverage: 10, 252 metrics.AssessmentKeyResponseNoError: 1, 253 }, 254 IdentifierWriteTestsSymflowerFix: metrics.Assessments{ 255 metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, 256 metrics.AssessmentKeyFilesExecuted: 1, 257 metrics.AssessmentKeyCoverage: 10, 258 metrics.AssessmentKeyResponseNoError: 1, 259 }, 260 }, 261 ExpectedProblemContains: nil, 262 ValidateLog: func(t *testing.T, data string) { 263 assert.Contains(t, data, "Evaluating model \"mocked-model\"") 264 }, 265 }) 266 } 267 } 268 269 func TestValidateWriteTestsRepository(t *testing.T) { 270 validate := func(t *testing.T, tc *tasktesting.TestCaseValidateRepository) { 271 tc.Validate(t, validateWriteTestsRepository) 272 } 273 274 t.Run("Go", func(t *testing.T) { 275 t.Run("Plain", func(t *testing.T) { 276 validate(t, &tasktesting.TestCaseValidateRepository{ 277 Name: "Well-formed", 278 279 TestdataPath: filepath.Join("..", "..", "testdata"), 280 RepositoryPath: filepath.Join("golang", "plain"), 281 Language: &golang.Language{}, 282 }) 283 }) 284 t.Run("Light", func(t *testing.T) { 285 validate(t, &tasktesting.TestCaseValidateRepository{ 286 Name: "Repository with test files", 287 288 Before: func(repositoryPath string) { 289 fileATest, err := os.Create(filepath.Join(repositoryPath, "fileA_test.go")) 290 require.NoError(t, err) 291 fileATest.Close() 292 }, 293 294 TestdataPath: filepath.Join("..", "..", "testdata"), 295 RepositoryPath: filepath.Join("golang", "light"), 296 Language: &golang.Language{}, 297 ExpectedError: func(t *testing.T, err error) { 298 assert.ErrorContains(t, err, "must contain only Go source files, but found [fileA_test.go]") 299 }, 300 }) 301 validate(t, &tasktesting.TestCaseValidateRepository{ 302 Name: "Well-formed", 303 304 TestdataPath: filepath.Join("..", "..", "testdata"), 305 RepositoryPath: filepath.Join("golang", "light"), 306 Language: &golang.Language{}, 307 }) 308 }) 309 }) 310 t.Run("Java", func(t *testing.T) { 311 t.Run("Plain", func(t *testing.T) { 312 validate(t, &tasktesting.TestCaseValidateRepository{ 313 Name: "Well-formed", 314 315 TestdataPath: filepath.Join("..", "..", "testdata"), 316 RepositoryPath: filepath.Join("java", "plain"), 317 Language: &java.Language{}, 318 }) 319 }) 320 t.Run("Light", func(t *testing.T) { 321 validate(t, &tasktesting.TestCaseValidateRepository{ 322 Name: "Repository with test files", 323 324 Before: func(repositoryPath string) { 325 somePackage := filepath.Join(repositoryPath, "src", "test", "java", "com", "eval") 326 require.NoError(t, os.MkdirAll(somePackage, 0700)) 327 328 fileATest, err := os.Create(filepath.Join(somePackage, "FileATest.java")) 329 require.NoError(t, err) 330 fileATest.Close() 331 }, 332 333 TestdataPath: filepath.Join("..", "..", "testdata"), 334 RepositoryPath: filepath.Join("java", "light"), 335 Language: &java.Language{}, 336 337 ExpectedError: func(t *testing.T, err error) { 338 assert.ErrorContains(t, err, fmt.Sprintf("must contain only Java source files, but found [%s]", filepath.Join("src", "test", "java", "com", "eval", "FileATest.java"))) 339 }, 340 }) 341 validate(t, &tasktesting.TestCaseValidateRepository{ 342 Name: "Well-formed", 343 344 TestdataPath: filepath.Join("..", "..", "testdata"), 345 RepositoryPath: filepath.Join("java", "light"), 346 Language: &java.Language{}, 347 }) 348 }) 349 }) 350 }