"""Tests for batch_runner checkpoint behavior — incremental writes, resume, atomicity."""

import json
import os
from pathlib import Path
from threading import Lock
from unittest.mock import patch, MagicMock  # NOTE(review): patch/MagicMock currently unused; kept for future use

import pytest

# batch_runner uses relative imports, ensure project root is on path
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))

from batch_runner import BatchRunner, _process_batch_worker


@pytest.fixture
def runner(tmp_path):
    """Create a BatchRunner with all paths pointing at tmp_path.

    Uses ``__new__`` to bypass ``__init__`` (and any I/O it may do), then
    assigns only the attributes the checkpoint methods read.
    """
    prompts_file = tmp_path / "prompts.jsonl"
    prompts_file.write_text("")
    output_file = tmp_path / "output.jsonl"
    checkpoint_file = tmp_path / "checkpoint.json"
    r = BatchRunner.__new__(BatchRunner)
    r.run_name = "test_run"
    r.checkpoint_file = checkpoint_file
    r.output_file = output_file
    r.prompts_file = prompts_file
    return r


class TestSaveCheckpoint:
    """Verify _save_checkpoint writes valid, atomic JSON."""

    def test_writes_valid_json(self, runner):
        data = {"run_name": "test", "completed_prompts": [1, 2, 3], "batch_stats": {}}
        runner._save_checkpoint(data)

        result = json.loads(runner.checkpoint_file.read_text())
        assert result["run_name"] == "test"
        assert result["completed_prompts"] == [1, 2, 3]

    def test_adds_last_updated(self, runner):
        # _save_checkpoint is expected to stamp the payload with a timestamp.
        data = {"run_name": "test", "completed_prompts": []}
        runner._save_checkpoint(data)

        result = json.loads(runner.checkpoint_file.read_text())
        assert "last_updated" in result
        assert result["last_updated"] is not None

    def test_overwrites_previous_checkpoint(self, runner):
        runner._save_checkpoint({"run_name": "test", "completed_prompts": [1]})
        runner._save_checkpoint({"run_name": "test", "completed_prompts": [1, 2, 3]})

        result = json.loads(runner.checkpoint_file.read_text())
        assert result["completed_prompts"] == [1, 2, 3]

    def test_with_lock(self, runner):
        # Passing an explicit lock must not change the written contents.
        lock = Lock()
        data = {"run_name": "test", "completed_prompts": [42]}
        runner._save_checkpoint(data, lock=lock)

        result = json.loads(runner.checkpoint_file.read_text())
        assert result["completed_prompts"] == [42]

    def test_without_lock(self, runner):
        data = {"run_name": "test", "completed_prompts": [99]}
        runner._save_checkpoint(data, lock=None)

        result = json.loads(runner.checkpoint_file.read_text())
        assert result["completed_prompts"] == [99]

    def test_creates_parent_dirs(self, tmp_path):
        # Bare instance: only checkpoint_file is set, so this also checks that
        # _save_checkpoint does not require other attributes to exist.
        runner_deep = BatchRunner.__new__(BatchRunner)
        runner_deep.checkpoint_file = tmp_path / "deep" / "nested" / "checkpoint.json"

        data = {"run_name": "test", "completed_prompts": []}
        runner_deep._save_checkpoint(data)

        assert runner_deep.checkpoint_file.exists()

    def test_no_temp_files_left(self, runner):
        # Atomic write: any ".tmp" intermediate must be renamed/removed.
        runner._save_checkpoint({"run_name": "test", "completed_prompts": []})

        tmp_files = [f for f in runner.checkpoint_file.parent.iterdir()
                     if ".tmp" in f.name]
        assert len(tmp_files) == 0


class TestLoadCheckpoint:
    """Verify _load_checkpoint reads existing data or returns defaults."""

    def test_returns_empty_when_no_file(self, runner):
        result = runner._load_checkpoint()
        assert result.get("completed_prompts", []) == []

    def test_loads_existing_checkpoint(self, runner):
        data = {"run_name": "test_run", "completed_prompts": [5, 10, 15],
                "batch_stats": {"0": {"processed": 3}}}
        runner.checkpoint_file.write_text(json.dumps(data))

        result = runner._load_checkpoint()
        assert result["completed_prompts"] == [5, 10, 15]
        assert result["batch_stats"]["0"]["processed"] == 3

    def test_handles_corrupt_json(self, runner):
        runner.checkpoint_file.write_text("{broken json!!")

        result = runner._load_checkpoint()
        # Should return empty/default, not crash
        assert isinstance(result, dict)


class TestResumePreservesProgress:
    """Verify that initializing a run with resume=True loads prior checkpoint."""

    def test_completed_prompts_loaded_from_checkpoint(self, runner):
        # Simulate a prior run that completed prompts 0-4
        prior = {
            "run_name": "test_run",
            "completed_prompts": [0, 1, 2, 3, 4],
            "batch_stats": {"0": {"processed": 5}},
            "last_updated": "2026-01-01T00:00:00",
        }
        runner.checkpoint_file.write_text(json.dumps(prior))

        # Load checkpoint like run() does
        checkpoint_data = runner._load_checkpoint()
        if checkpoint_data.get("run_name") != runner.run_name:
            checkpoint_data = {
                "run_name": runner.run_name,
                "completed_prompts": [],
                "batch_stats": {},
                "last_updated": None,
            }

        completed_set = set(checkpoint_data.get("completed_prompts", []))
        assert completed_set == {0, 1, 2, 3, 4}

    def test_different_run_name_starts_fresh(self, runner):
        # A checkpoint written by a different run must be ignored, not resumed.
        prior = {
            "run_name": "different_run",
            "completed_prompts": [0, 1, 2],
            "batch_stats": {},
        }
        runner.checkpoint_file.write_text(json.dumps(prior))

        checkpoint_data = runner._load_checkpoint()
        if checkpoint_data.get("run_name") != runner.run_name:
            checkpoint_data = {
                "run_name": runner.run_name,
                "completed_prompts": [],
                "batch_stats": {},
                "last_updated": None,
            }

        assert checkpoint_data["completed_prompts"] == []
        assert checkpoint_data["run_name"] == "test_run"


class TestBatchWorkerResumeBehavior:
    """Worker-level accounting: discarded prompts still count as completed."""

    def test_discarded_no_reasoning_prompts_are_marked_completed(self, tmp_path, monkeypatch):
        batch_file = tmp_path / "batch_1.jsonl"
        # A "successful" result whose reasoning check fails — the worker is
        # expected to discard the output but still mark the prompt done so a
        # resume does not re-run it.
        prompt_result = {
            "success": True,
            "trajectory": [{"role": "assistant", "content": "x"}],
            "reasoning_stats": {"has_any_reasoning": False},
            "tool_stats": {},
            "metadata": {},
            "completed": True,
            "api_calls": 1,
            "toolsets_used": [],
        }

        monkeypatch.setattr("batch_runner._process_single_prompt", lambda *args, **kwargs: prompt_result)

        result = _process_batch_worker((
            1,
            [(0, {"prompt": "hi"})],
            tmp_path,
            set(),
            {"verbose": False},
        ))

        assert result["discarded_no_reasoning"] == 1
        assert result["completed_prompts"] == [0]
        # Discarded output must not be persisted to the batch file.
        assert not batch_file.exists() or batch_file.read_text() == ""


class TestFinalCheckpointNoDuplicates:
    """Regression: the final checkpoint must not contain duplicate prompt
    indices.

    Before PR #15161, `run()` populated `completed_prompts_set` incrementally
    as each batch completed, then at the end built `all_completed_prompts =
    list(completed_prompts_set)` AND extended it again with every batch's
    `completed_prompts` — double-counting every index.
    """

    def _simulate_final_aggregation_fixed(self, batch_results):
        """Mirror the fixed code path in batch_runner.run()."""
        completed_prompts_set = set()
        for result in batch_results:
            completed_prompts_set.update(result.get("completed_prompts", []))
        # This is what the fixed code now writes to the checkpoint:
        return sorted(completed_prompts_set)

    def test_no_duplicates_in_final_list(self):
        batch_results = [
            {"completed_prompts": [0, 1, 2]},
            {"completed_prompts": [3, 4]},
            {"completed_prompts": [5]},
        ]
        final = self._simulate_final_aggregation_fixed(batch_results)
        assert final == [0, 1, 2, 3, 4, 5]
        assert len(final) == len(set(final))  # no duplicates

    def test_persisted_checkpoint_has_unique_prompts(self, runner):
        """Write what run()'s fixed aggregation produces to disk; the file
        must load back with no duplicate indices."""
        batch_results = [
            {"completed_prompts": [0, 1]},
            {"completed_prompts": [2, 3]},
        ]
        final = self._simulate_final_aggregation_fixed(batch_results)
        runner._save_checkpoint({
            "run_name": runner.run_name,
            "completed_prompts": final,
            "batch_stats": {},
        })
        loaded = json.loads(runner.checkpoint_file.read_text())
        cp = loaded["completed_prompts"]
        assert cp == sorted(set(cp))
        assert len(cp) == 4

    def test_old_buggy_pattern_would_have_duplicates(self):
        """Document the bug this PR fixes: the old code shape produced
        duplicates. Kept as a sanity anchor so a future refactor that
        re-introduces the pattern is immediately visible."""
        completed_prompts_set = set()
        results = []
        for batch in ({"completed_prompts": [0, 1, 2]},
                      {"completed_prompts": [3, 4]}):
            completed_prompts_set.update(batch["completed_prompts"])
            results.append(batch)
        # Buggy aggregation (pre-fix):
        buggy = list(completed_prompts_set)
        for br in results:
            buggy.extend(br.get("completed_prompts", []))
        # Every index appears twice
        assert len(buggy) == 2 * len(set(buggy))