/ qwencoder-eval / instruct / eval-dev-quality / evaluate / task / task-write-test_test.go
task-write-test_test.go
  1  package task
  2  
  3  import (
  4  	"context"
  5  	"fmt"
  6  	"os"
  7  	"path/filepath"
  8  	"testing"
  9  
 10  	"github.com/stretchr/testify/assert"
 11  	"github.com/stretchr/testify/mock"
 12  	"github.com/stretchr/testify/require"
 13  	"github.com/symflower/eval-dev-quality/evaluate/metrics"
 14  	metricstesting "github.com/symflower/eval-dev-quality/evaluate/metrics/testing"
 15  	tasktesting "github.com/symflower/eval-dev-quality/evaluate/task/testing"
 16  	"github.com/symflower/eval-dev-quality/language"
 17  	"github.com/symflower/eval-dev-quality/language/golang"
 18  	"github.com/symflower/eval-dev-quality/language/java"
 19  	"github.com/symflower/eval-dev-quality/language/ruby"
 20  	languagetesting "github.com/symflower/eval-dev-quality/language/testing"
 21  	"github.com/symflower/eval-dev-quality/log"
 22  	modeltesting "github.com/symflower/eval-dev-quality/model/testing"
 23  	"github.com/symflower/eval-dev-quality/task"
 24  	evaltask "github.com/symflower/eval-dev-quality/task"
 25  	"github.com/zimmski/osutil"
 26  	"github.com/zimmski/osutil/bytesutil"
 27  )
 28  
 29  func TestTaskWriteTestsRun(t *testing.T) {
 30  	validate := func(t *testing.T, tc *tasktesting.TestCaseTask) {
 31  		t.Run(tc.Name, func(t *testing.T) {
 32  			task, err := TaskForIdentifier(IdentifierWriteTests)
 33  			require.NoError(t, err)
 34  			tc.Task = task
 35  
 36  			tc.Validate(t,
 37  				func(logger *log.Logger, testDataPath string, repositoryPathRelative string) (repository evaltask.Repository, cleanup func(), err error) {
 38  					return TemporaryRepository(logger, testDataPath, repositoryPathRelative)
 39  				},
 40  			)
 41  		})
 42  	}
 43  
 44  	t.Run("Clear repository on each task file", func(t *testing.T) {
 45  		temporaryDirectoryPath := t.TempDir()
 46  
 47  		repositoryPath := filepath.Join(temporaryDirectoryPath, "golang", "plain")
 48  		require.NoError(t, os.MkdirAll(repositoryPath, 0700))
 49  		require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "go.mod"), []byte("module plain\n\ngo 1.21.5"), 0600))
 50  		require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "taskA.go"), []byte("package plain\n\nfunc TaskA(){}"), 0600))
 51  		require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "taskB.go"), []byte("package plain\n\nfunc TaskB(){}"), 0600))
 52  
 53  		modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model")
 54  
 55  		// Generate invalid code for the first taskcontext.
 56  		modelMock.RegisterGenerateSuccess(t, "taskA_test.go", "does not compile", metricstesting.AssessmentsWithProcessingTime).Once()
 57  		// Generate valid code for the second taskcontext.
 58  		modelMock.RegisterGenerateSuccess(t, "taskB_test.go", "package plain\n\nimport \"testing\"\n\nfunc TestTaskB(t *testing.T){}", metricstesting.AssessmentsWithProcessingTime).Once()
 59  
 60  		validate(t, &tasktesting.TestCaseTask{
 61  			Name: "Plain",
 62  
 63  			Model:          modelMock,
 64  			Language:       &golang.Language{},
 65  			TestDataPath:   temporaryDirectoryPath,
 66  			RepositoryPath: filepath.Join("golang", "plain"),
 67  
 68  			ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{
 69  				IdentifierWriteTests: metrics.Assessments{
 70  					metrics.AssessmentKeyFilesExecuted:                 1,
 71  					metrics.AssessmentKeyFilesExecutedMaximumReachable: 2,
 72  					metrics.AssessmentKeyResponseNoError:               2,
 73  				},
 74  				IdentifierWriteTestsSymflowerFix: metrics.Assessments{
 75  					metrics.AssessmentKeyFilesExecuted:                 1,
 76  					metrics.AssessmentKeyFilesExecutedMaximumReachable: 2,
 77  					metrics.AssessmentKeyResponseNoError:               2,
 78  				},
 79  			},
 80  			ExpectedProblemContains: []string{
 81  				"expected 'package', found does",
 82  				"exit status 1",
 83  			},
 84  			ValidateLog: func(t *testing.T, data string) {
 85  				assert.Contains(t, data, "Evaluating model \"mocked-model\"")
 86  				assert.Contains(t, data, "PASS: TestTaskB")
 87  			},
 88  		})
 89  	})
 90  
 91  	t.Run("Symflower Fix", func(t *testing.T) {
 92  		t.Run("Go", func(t *testing.T) {
 93  			validateGo := func(t *testing.T, testName string, language language.Language, testFileContent string, expectedAssessments map[evaltask.Identifier]metrics.Assessments, expectedProblems []string, assertTestsPass bool) {
 94  				temporaryDirectoryPath := t.TempDir()
 95  				repositoryPath := filepath.Join(temporaryDirectoryPath, "golang", "plain")
 96  				require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "plain"), repositoryPath))
 97  
 98  				modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model")
 99  				modelMock.RegisterGenerateSuccess(t, "plain_test.go", testFileContent, metricstesting.AssessmentsWithProcessingTime).Once()
100  
101  				validate(t, &tasktesting.TestCaseTask{
102  					Name: testName,
103  
104  					Model:          modelMock,
105  					Language:       language,
106  					TestDataPath:   temporaryDirectoryPath,
107  					RepositoryPath: filepath.Join("golang", "plain"),
108  
109  					ExpectedRepositoryAssessment: expectedAssessments,
110  					ExpectedProblemContains:      expectedProblems,
111  					ValidateLog: func(t *testing.T, data string) {
112  						assert.Contains(t, data, "Evaluating model \"mocked-model\"")
113  						if assertTestsPass {
114  							assert.Contains(t, data, "PASS: TestPlain")
115  						}
116  					},
117  				})
118  			}
119  			{
120  				expectedAssessments := map[evaltask.Identifier]metrics.Assessments{
121  					IdentifierWriteTests: metrics.Assessments{
122  						metrics.AssessmentKeyFilesExecuted:                 1,
123  						metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
124  						metrics.AssessmentKeyResponseNoError:               1,
125  						metrics.AssessmentKeyCoverage:                      10,
126  					},
127  					IdentifierWriteTestsSymflowerFix: metrics.Assessments{
128  						metrics.AssessmentKeyFilesExecuted:                 1,
129  						metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
130  						metrics.AssessmentKeyResponseNoError:               1,
131  						metrics.AssessmentKeyCoverage:                      10,
132  					},
133  				}
134  				validateGo(t, "Model generated correct test", &golang.Language{}, bytesutil.StringTrimIndentations(`
135  					package plain
136  
137  					import "testing"
138  
139  					func TestPlain(t *testing.T) {
140  						   plain()
141  					}
142  				`), expectedAssessments, nil, true)
143  			}
144  			{
145  				expectedAssessments := map[evaltask.Identifier]metrics.Assessments{
146  					IdentifierWriteTests: metrics.Assessments{
147  						metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
148  						metrics.AssessmentKeyResponseNoError:               1,
149  					},
150  					IdentifierWriteTestsSymflowerFix: metrics.Assessments{
151  						metrics.AssessmentKeyFilesExecuted:                 1,
152  						metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
153  						metrics.AssessmentKeyResponseNoError:               1,
154  						metrics.AssessmentKeyCoverage:                      10,
155  					},
156  				}
157  				expectedProblems := []string{
158  					"imported and not used",
159  				}
160  				validateGo(t, "Model generated test with unused import", &golang.Language{}, bytesutil.StringTrimIndentations(`
161  					package plain
162  
163  					import (
164  						"testing"
165  						"strings"
166  					)
167  
168  					func TestPlain(t *testing.T) {
169  					   	plain()
170  					}
171  				`), expectedAssessments, expectedProblems, true)
172  			}
173  			{
174  				expectedAssessments := map[evaltask.Identifier]metrics.Assessments{
175  					IdentifierWriteTests: metrics.Assessments{
176  						metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
177  						metrics.AssessmentKeyResponseNoError:               1,
178  					},
179  					IdentifierWriteTestsSymflowerFix: metrics.Assessments{
180  						metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
181  						metrics.AssessmentKeyResponseNoError:               1,
182  					},
183  				}
184  				expectedProblems := []string{
185  					"expected declaration, found this",
186  					"unable to format source code",
187  				}
188  				validateGo(t, "Model generated test that is unfixable", &golang.Language{}, bytesutil.StringTrimIndentations(`
189  					package plain
190  
191  					this is not valid go code
192  				`), expectedAssessments, expectedProblems, false)
193  			}
194  			{
195  				expectedAssessments := map[task.Identifier]metrics.Assessments{
196  					IdentifierWriteTests: metrics.Assessments{
197  						metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
198  						metrics.AssessmentKeyResponseNoError:               1,
199  					},
200  					IdentifierWriteTestsSymflowerFix: metrics.Assessments{
201  						metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
202  						metrics.AssessmentKeyResponseNoError:               1,
203  					},
204  				}
205  				expectedProblems := []string{
206  					"context deadline exceeded",
207  				}
208  
209  				languageMock := languagetesting.NewMockLanguageNamed(t, "golang")
210  				languageMock.On("Files", mock.Anything, mock.Anything).Return([]string{filepath.Join("golang", "plain")}, nil).Once()
211  				languageMock.On("ExecuteTests", mock.Anything, mock.Anything).Return(nil, nil, context.DeadlineExceeded).Once()
212  
213  				validateGo(t, "Execution timeout", languageMock, "", expectedAssessments, expectedProblems, false)
214  			}
215  		})
216  	})
217  
218  	{
219  		if osutil.IsWindows() {
220  			t.Skip("Ruby is not tested in the Windows CI")
221  		}
222  
223  		temporaryDirectoryPath := t.TempDir()
224  		repositoryPath := filepath.Join(temporaryDirectoryPath, "ruby", "plain")
225  		require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "ruby", "plain"), repositoryPath))
226  
227  		testFileContent := bytesutil.StringTrimIndentations(`
228  			require_relative "../lib/plain"
229  
230  			class TestPlain < Minitest::Test
231  				def test_plain
232  					plain()
233  				end
234  			end
235  		`)
236  		modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model")
237  		modelMock.RegisterGenerateSuccess(t, filepath.Join("test", "plain_test.rb"), testFileContent, metricstesting.AssessmentsWithProcessingTime).Maybe()
238  
239  		validate(t, &tasktesting.TestCaseTask{
240  			Name: "Ruby",
241  
242  			Model:          modelMock,
243  			Language:       &ruby.Language{},
244  			TestDataPath:   temporaryDirectoryPath,
245  			RepositoryPath: filepath.Join("ruby", "plain"),
246  
247  			ExpectedRepositoryAssessment: map[task.Identifier]metrics.Assessments{
248  				IdentifierWriteTests: metrics.Assessments{
249  					metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
250  					metrics.AssessmentKeyFilesExecuted:                 1,
251  					metrics.AssessmentKeyCoverage:                      10,
252  					metrics.AssessmentKeyResponseNoError:               1,
253  				},
254  				IdentifierWriteTestsSymflowerFix: metrics.Assessments{
255  					metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
256  					metrics.AssessmentKeyFilesExecuted:                 1,
257  					metrics.AssessmentKeyCoverage:                      10,
258  					metrics.AssessmentKeyResponseNoError:               1,
259  				},
260  			},
261  			ExpectedProblemContains: nil,
262  			ValidateLog: func(t *testing.T, data string) {
263  				assert.Contains(t, data, "Evaluating model \"mocked-model\"")
264  			},
265  		})
266  	}
267  }
268  
269  func TestValidateWriteTestsRepository(t *testing.T) {
270  	validate := func(t *testing.T, tc *tasktesting.TestCaseValidateRepository) {
271  		tc.Validate(t, validateWriteTestsRepository)
272  	}
273  
274  	t.Run("Go", func(t *testing.T) {
275  		t.Run("Plain", func(t *testing.T) {
276  			validate(t, &tasktesting.TestCaseValidateRepository{
277  				Name: "Well-formed",
278  
279  				TestdataPath:   filepath.Join("..", "..", "testdata"),
280  				RepositoryPath: filepath.Join("golang", "plain"),
281  				Language:       &golang.Language{},
282  			})
283  		})
284  		t.Run("Light", func(t *testing.T) {
285  			validate(t, &tasktesting.TestCaseValidateRepository{
286  				Name: "Repository with test files",
287  
288  				Before: func(repositoryPath string) {
289  					fileATest, err := os.Create(filepath.Join(repositoryPath, "fileA_test.go"))
290  					require.NoError(t, err)
291  					fileATest.Close()
292  				},
293  
294  				TestdataPath:   filepath.Join("..", "..", "testdata"),
295  				RepositoryPath: filepath.Join("golang", "light"),
296  				Language:       &golang.Language{},
297  				ExpectedError: func(t *testing.T, err error) {
298  					assert.ErrorContains(t, err, "must contain only Go source files, but found [fileA_test.go]")
299  				},
300  			})
301  			validate(t, &tasktesting.TestCaseValidateRepository{
302  				Name: "Well-formed",
303  
304  				TestdataPath:   filepath.Join("..", "..", "testdata"),
305  				RepositoryPath: filepath.Join("golang", "light"),
306  				Language:       &golang.Language{},
307  			})
308  		})
309  	})
310  	t.Run("Java", func(t *testing.T) {
311  		t.Run("Plain", func(t *testing.T) {
312  			validate(t, &tasktesting.TestCaseValidateRepository{
313  				Name: "Well-formed",
314  
315  				TestdataPath:   filepath.Join("..", "..", "testdata"),
316  				RepositoryPath: filepath.Join("java", "plain"),
317  				Language:       &java.Language{},
318  			})
319  		})
320  		t.Run("Light", func(t *testing.T) {
321  			validate(t, &tasktesting.TestCaseValidateRepository{
322  				Name: "Repository with test files",
323  
324  				Before: func(repositoryPath string) {
325  					somePackage := filepath.Join(repositoryPath, "src", "test", "java", "com", "eval")
326  					require.NoError(t, os.MkdirAll(somePackage, 0700))
327  
328  					fileATest, err := os.Create(filepath.Join(somePackage, "FileATest.java"))
329  					require.NoError(t, err)
330  					fileATest.Close()
331  				},
332  
333  				TestdataPath:   filepath.Join("..", "..", "testdata"),
334  				RepositoryPath: filepath.Join("java", "light"),
335  				Language:       &java.Language{},
336  
337  				ExpectedError: func(t *testing.T, err error) {
338  					assert.ErrorContains(t, err, fmt.Sprintf("must contain only Java source files, but found [%s]", filepath.Join("src", "test", "java", "com", "eval", "FileATest.java")))
339  				},
340  			})
341  			validate(t, &tasktesting.TestCaseValidateRepository{
342  				Name: "Well-formed",
343  
344  				TestdataPath:   filepath.Join("..", "..", "testdata"),
345  				RepositoryPath: filepath.Join("java", "light"),
346  				Language:       &java.Language{},
347  			})
348  		})
349  	})
350  }