  1  """Tests for batch_runner checkpoint behavior — incremental writes, resume, atomicity."""
  2  
  3  import json
  4  import os
  5  from pathlib import Path
  6  from threading import Lock
  7  from unittest.mock import patch, MagicMock
  8  
  9  import pytest
 10  
 11  # batch_runner uses relative imports, ensure project root is on path
 12  import sys
 13  sys.path.insert(0, str(Path(__file__).parent.parent))
 14  
 15  from batch_runner import BatchRunner, _process_batch_worker
 16  
 17  
 18  @pytest.fixture
 19  def runner(tmp_path):
 20      """Create a BatchRunner with all paths pointing at tmp_path."""
 21      prompts_file = tmp_path / "prompts.jsonl"
 22      prompts_file.write_text("")
 23      output_file = tmp_path / "output.jsonl"
 24      checkpoint_file = tmp_path / "checkpoint.json"
 25      r = BatchRunner.__new__(BatchRunner)
 26      r.run_name = "test_run"
 27      r.checkpoint_file = checkpoint_file
 28      r.output_file = output_file
 29      r.prompts_file = prompts_file
 30      return r
 31  
 32  
 33  class TestSaveCheckpoint:
 34      """Verify _save_checkpoint writes valid, atomic JSON."""
 35  
 36      def test_writes_valid_json(self, runner):
 37          data = {"run_name": "test", "completed_prompts": [1, 2, 3], "batch_stats": {}}
 38          runner._save_checkpoint(data)
 39  
 40          result = json.loads(runner.checkpoint_file.read_text())
 41          assert result["run_name"] == "test"
 42          assert result["completed_prompts"] == [1, 2, 3]
 43  
 44      def test_adds_last_updated(self, runner):
 45          data = {"run_name": "test", "completed_prompts": []}
 46          runner._save_checkpoint(data)
 47  
 48          result = json.loads(runner.checkpoint_file.read_text())
 49          assert "last_updated" in result
 50          assert result["last_updated"] is not None
 51  
 52      def test_overwrites_previous_checkpoint(self, runner):
 53          runner._save_checkpoint({"run_name": "test", "completed_prompts": [1]})
 54          runner._save_checkpoint({"run_name": "test", "completed_prompts": [1, 2, 3]})
 55  
 56          result = json.loads(runner.checkpoint_file.read_text())
 57          assert result["completed_prompts"] == [1, 2, 3]
 58  
 59      def test_with_lock(self, runner):
 60          lock = Lock()
 61          data = {"run_name": "test", "completed_prompts": [42]}
 62          runner._save_checkpoint(data, lock=lock)
 63  
 64          result = json.loads(runner.checkpoint_file.read_text())
 65          assert result["completed_prompts"] == [42]
 66  
 67      def test_without_lock(self, runner):
 68          data = {"run_name": "test", "completed_prompts": [99]}
 69          runner._save_checkpoint(data, lock=None)
 70  
 71          result = json.loads(runner.checkpoint_file.read_text())
 72          assert result["completed_prompts"] == [99]
 73  
 74      def test_creates_parent_dirs(self, tmp_path):
 75          runner_deep = BatchRunner.__new__(BatchRunner)
 76          runner_deep.checkpoint_file = tmp_path / "deep" / "nested" / "checkpoint.json"
 77  
 78          data = {"run_name": "test", "completed_prompts": []}
 79          runner_deep._save_checkpoint(data)
 80  
 81          assert runner_deep.checkpoint_file.exists()
 82  
 83      def test_no_temp_files_left(self, runner):
 84          runner._save_checkpoint({"run_name": "test", "completed_prompts": []})
 85  
 86          tmp_files = [f for f in runner.checkpoint_file.parent.iterdir()
 87                       if ".tmp" in f.name]
 88          assert len(tmp_files) == 0
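
    def test_failed_write_preserves_existing_checkpoint(self, runner):
        # Hedged sketch of the atomicity claim in the class docstring:
        # assumes _save_checkpoint serializes via json.dump/json.dumps into
        # a temp file and only then swaps it into place, so a failure during
        # serialization must leave the previous checkpoint readable.
        # Adjust the patch targets if the implementation differs.
        runner._save_checkpoint({"run_name": "test", "completed_prompts": [1]})

        with patch("json.dump", side_effect=OSError("disk full")), \
             patch("json.dumps", side_effect=OSError("disk full")):
            try:
                runner._save_checkpoint({"run_name": "test", "completed_prompts": [2]})
            except OSError:
                pass  # propagating or swallowing the error are both acceptable

        result = json.loads(runner.checkpoint_file.read_text())
        assert result["completed_prompts"] == [1]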


class TestLoadCheckpoint:
    """Verify _load_checkpoint reads existing data or returns defaults."""

    def test_returns_empty_when_no_file(self, runner):
        result = runner._load_checkpoint()
        assert result.get("completed_prompts", []) == []

    def test_loads_existing_checkpoint(self, runner):
        data = {"run_name": "test_run", "completed_prompts": [5, 10, 15],
                "batch_stats": {"0": {"processed": 3}}}
        runner.checkpoint_file.write_text(json.dumps(data))

        result = runner._load_checkpoint()
        assert result["completed_prompts"] == [5, 10, 15]
        assert result["batch_stats"]["0"]["processed"] == 3

    def test_handles_corrupt_json(self, runner):
        runner.checkpoint_file.write_text("{broken json!!")

        result = runner._load_checkpoint()
        # Should return an empty/default dict, not crash
        assert isinstance(result, dict)
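
    def test_handles_empty_file(self, runner):
        # Hedged companion to the corrupt-JSON case: a zero-byte checkpoint
        # is a plausible leftover from an interrupted write. Assumes
        # _load_checkpoint treats it like corrupt data and returns a default
        # dict rather than raising.
        runner.checkpoint_file.write_text("")

        result = runner._load_checkpoint()
        assert isinstance(result, dict)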


class TestResumePreservesProgress:
    """Verify that initializing a run with resume=True loads the prior
    checkpoint."""

    @staticmethod
    def _load_for_run(runner):
        """Mirror run()'s resume logic: reuse the checkpoint only when its
        run_name matches, otherwise start fresh."""
        checkpoint_data = runner._load_checkpoint()
        if checkpoint_data.get("run_name") != runner.run_name:
            checkpoint_data = {
                "run_name": runner.run_name,
                "completed_prompts": [],
                "batch_stats": {},
                "last_updated": None,
            }
        return checkpoint_data

    def test_completed_prompts_loaded_from_checkpoint(self, runner):
        # Simulate a prior run that completed prompts 0-4
        prior = {
            "run_name": "test_run",
            "completed_prompts": [0, 1, 2, 3, 4],
            "batch_stats": {"0": {"processed": 5}},
            "last_updated": "2026-01-01T00:00:00",
        }
        runner.checkpoint_file.write_text(json.dumps(prior))

        checkpoint_data = self._load_for_run(runner)
        completed_set = set(checkpoint_data.get("completed_prompts", []))
        assert completed_set == {0, 1, 2, 3, 4}

    def test_different_run_name_starts_fresh(self, runner):
        prior = {
            "run_name": "different_run",
            "completed_prompts": [0, 1, 2],
            "batch_stats": {},
        }
        runner.checkpoint_file.write_text(json.dumps(prior))

        checkpoint_data = self._load_for_run(runner)
        assert checkpoint_data["completed_prompts"] == []
        assert checkpoint_data["run_name"] == "test_run"
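
    def test_batch_stats_preserved_on_resume(self, runner):
        # Hedged sketch: assumes run() keeps the loaded batch_stats when the
        # run_name matches, so per-batch counters survive a resume. Reuses
        # the same mirrored resume logic via _load_for_run above.
        prior = {
            "run_name": "test_run",
            "completed_prompts": [0],
            "batch_stats": {"0": {"processed": 1}},
        }
        runner.checkpoint_file.write_text(json.dumps(prior))

        checkpoint_data = self._load_for_run(runner)
        assert checkpoint_data["batch_stats"] == {"0": {"processed": 1}}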


class TestBatchWorkerResumeBehavior:
    """Prompts discarded for having no reasoning must still be marked
    completed, so a resumed run does not reprocess them."""

    def test_discarded_no_reasoning_prompts_are_marked_completed(self, tmp_path, monkeypatch):
        batch_file = tmp_path / "batch_1.jsonl"
        prompt_result = {
            "success": True,
            "trajectory": [{"role": "assistant", "content": "x"}],
            "reasoning_stats": {"has_any_reasoning": False},
            "tool_stats": {},
            "metadata": {},
            "completed": True,
            "api_calls": 1,
            "toolsets_used": [],
        }

        monkeypatch.setattr("batch_runner._process_single_prompt",
                            lambda *args, **kwargs: prompt_result)

        result = _process_batch_worker((
            1,
            [(0, {"prompt": "hi"})],
            tmp_path,
            set(),
            {"verbose": False},
        ))

        assert result["discarded_no_reasoning"] == 1
        assert result["completed_prompts"] == [0]
        assert not batch_file.exists() or batch_file.read_text() == ""
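
    def test_prompts_with_reasoning_are_not_discarded(self, tmp_path, monkeypatch):
        # Hedged companion sketch: a result whose reasoning_stats report
        # has_any_reasoning=True should be kept, not discarded. Assumes the
        # worker's counters mirror the test above; adjust the expected result
        # shape if _process_single_prompt's contract differs.
        prompt_result = {
            "success": True,
            "trajectory": [{"role": "assistant", "content": "x"}],
            "reasoning_stats": {"has_any_reasoning": True},
            "tool_stats": {},
            "metadata": {},
            "completed": True,
            "api_calls": 1,
            "toolsets_used": [],
        }

        monkeypatch.setattr("batch_runner._process_single_prompt",
                            lambda *args, **kwargs: prompt_result)

        result = _process_batch_worker((
            1,
            [(0, {"prompt": "hi"})],
            tmp_path,
            set(),
            {"verbose": False},
        ))

        assert result["discarded_no_reasoning"] == 0
        assert result["completed_prompts"] == [0]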


class TestFinalCheckpointNoDuplicates:
    """Regression: the final checkpoint must not contain duplicate prompt
    indices.

    Before PR #15161, `run()` populated `completed_prompts_set` incrementally
    as each batch completed, then at the end built `all_completed_prompts =
    list(completed_prompts_set)` AND extended it again with every batch's
    `completed_prompts`, double-counting every index.
    """

    def _simulate_final_aggregation_fixed(self, batch_results):
        """Mirror the fixed code path in batch_runner.run()."""
        completed_prompts_set = set()
        for result in batch_results:
            completed_prompts_set.update(result.get("completed_prompts", []))
        # This is what the fixed code now writes to the checkpoint:
        return sorted(completed_prompts_set)

    def test_no_duplicates_in_final_list(self):
        batch_results = [
            {"completed_prompts": [0, 1, 2]},
            {"completed_prompts": [3, 4]},
            {"completed_prompts": [5]},
        ]
        final = self._simulate_final_aggregation_fixed(batch_results)
        assert final == [0, 1, 2, 3, 4, 5]
        assert len(final) == len(set(final))  # no duplicates

    def test_persisted_checkpoint_has_unique_prompts(self, runner):
        """Write what run()'s fixed aggregation produces to disk; the file
        must load back with no duplicate indices."""
        batch_results = [
            {"completed_prompts": [0, 1]},
            {"completed_prompts": [2, 3]},
        ]
        final = self._simulate_final_aggregation_fixed(batch_results)
        runner._save_checkpoint({
            "run_name": runner.run_name,
            "completed_prompts": final,
            "batch_stats": {},
        })
        loaded = json.loads(runner.checkpoint_file.read_text())
        cp = loaded["completed_prompts"]
        assert cp == sorted(set(cp))
        assert len(cp) == 4

    def test_old_buggy_pattern_would_have_duplicates(self):
        """Document the bug this PR fixes: the old code shape produced
        duplicates. Kept as a sanity anchor so a future refactor that
        re-introduces the pattern is immediately visible."""
        completed_prompts_set = set()
        results = []
        for batch in ({"completed_prompts": [0, 1, 2]},
                      {"completed_prompts": [3, 4]}):
            completed_prompts_set.update(batch["completed_prompts"])
            results.append(batch)
        # Buggy aggregation (pre-fix): start from the set, then extend with
        # every batch's list again, so every index appears twice.
        buggy = list(completed_prompts_set)
        for br in results:
            buggy.extend(br.get("completed_prompts", []))
        assert len(buggy) == 2 * len(set(buggy))