# test_atomic_json_write.py
"""Tests for utils.atomic_json_write — crash-safe JSON file writes."""

import json
import os
import threading
from pathlib import Path
from unittest.mock import patch

import pytest

from utils import atomic_json_write


def _leftover_temp_files(directory: Path) -> list:
    """Return the entries directly inside *directory* whose name contains ".tmp".

    The tests use this marker to detect scratch files left behind by a write.
    NOTE(review): presumably matches the temp-file naming used by
    ``atomic_json_write`` — confirm against the implementation in ``utils``.
    """
    return [entry for entry in directory.iterdir() if ".tmp" in entry.name]


class TestAtomicJsonWrite:
    """Core atomic write behavior."""

    def test_writes_valid_json(self, tmp_path: Path) -> None:
        """A dict written to a fresh path round-trips through ``json.loads``."""
        target = tmp_path / "data.json"
        data = {"key": "value", "nested": {"a": 1}}
        atomic_json_write(target, data)

        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == data

    def test_creates_parent_directories(self, tmp_path: Path) -> None:
        """Missing intermediate directories are created for the target."""
        target = tmp_path / "deep" / "nested" / "dir" / "data.json"
        atomic_json_write(target, {"ok": True})

        assert target.exists()
        assert json.loads(target.read_text(encoding="utf-8"))["ok"] is True

    def test_overwrites_existing_file(self, tmp_path: Path) -> None:
        """An existing file is replaced by the new payload."""
        target = tmp_path / "data.json"
        target.write_text('{"old": true}', encoding="utf-8")

        atomic_json_write(target, {"new": True})
        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == {"new": True}

    def test_preserves_original_on_serialization_error(self, tmp_path: Path) -> None:
        """A failed serialization leaves the previous file content intact."""
        target = tmp_path / "data.json"
        original = {"preserved": True}
        target.write_text(json.dumps(original), encoding="utf-8")

        # Try to write non-serializable data — should fail
        with pytest.raises(TypeError):
            atomic_json_write(target, {"bad": object()})

        # Original file should be untouched
        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == original

    def test_no_leftover_temp_files_on_success(self, tmp_path: Path) -> None:
        """A successful write leaves only the target file behind."""
        target = tmp_path / "data.json"
        atomic_json_write(target, [1, 2, 3])

        # No .tmp files should be left behind
        assert _leftover_temp_files(tmp_path) == []
        assert target.exists()

    def test_no_leftover_temp_files_on_failure(self, tmp_path: Path) -> None:
        """A failed write does not leave scratch files in the directory."""
        target = tmp_path / "data.json"

        with pytest.raises(TypeError):
            atomic_json_write(target, {"bad": object()})

        # No temp files should be left behind
        assert _leftover_temp_files(tmp_path) == []

    def test_cleans_up_temp_file_on_baseexception(self, tmp_path: Path) -> None:
        """Cleanup also happens for exceptions outside the ``Exception`` tree."""

        class SimulatedAbort(BaseException):
            pass

        target = tmp_path / "data.json"
        original = {"preserved": True}
        target.write_text(json.dumps(original), encoding="utf-8")

        # Make the serialization step itself raise the BaseException subclass.
        with patch("utils.json.dump", side_effect=SimulatedAbort):
            with pytest.raises(SimulatedAbort):
                atomic_json_write(target, {"new": True})

        assert _leftover_temp_files(tmp_path) == []
        assert json.loads(target.read_text(encoding="utf-8")) == original

    def test_accepts_string_path(self, tmp_path: Path) -> None:
        """The target may be given as a plain ``str`` instead of a ``Path``."""
        target = str(tmp_path / "string_path.json")
        atomic_json_write(target, {"string": True})

        result = json.loads(Path(target).read_text(encoding="utf-8"))
        assert result == {"string": True}

    def test_writes_list_data(self, tmp_path: Path) -> None:
        """A top-level list (not only a dict) is written correctly."""
        target = tmp_path / "list.json"
        data = [1, "two", {"three": 3}]
        atomic_json_write(target, data)

        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == data

    def test_empty_list(self, tmp_path: Path) -> None:
        """An empty list round-trips as an empty list."""
        target = tmp_path / "empty.json"
        atomic_json_write(target, [])

        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == []

    def test_custom_indent(self, tmp_path: Path) -> None:
        """The ``indent`` keyword controls the indentation of the output."""
        target = tmp_path / "custom.json"
        atomic_json_write(target, {"a": 1}, indent=4)

        text = target.read_text(encoding="utf-8")
        # Anchor on the newline so exactly four spaces precede the key; a bare
        # substring check would also pass for any other indent width.
        assert '\n    "a"' in text  # 4-space indent

    def test_accepts_json_dump_default_hook(self, tmp_path: Path) -> None:
        """The ``default`` keyword is honoured for otherwise unserializable values."""

        class CustomValue:
            def __str__(self):
                return "custom-value"

        target = tmp_path / "custom_default.json"
        atomic_json_write(target, {"value": CustomValue()}, default=str)

        result = json.loads(target.read_text(encoding="utf-8"))
        assert result == {"value": "custom-value"}

    def test_unicode_content(self, tmp_path: Path) -> None:
        """Non-ASCII text survives a write/read cycle when read as UTF-8."""
        target = tmp_path / "unicode.json"
        data = {"emoji": "🎉", "japanese": "日本語"}
        atomic_json_write(target, data)

        result = json.loads(target.read_text(encoding="utf-8"))
        assert result["emoji"] == "🎉"
        assert result["japanese"] == "日本語"

    def test_concurrent_writes_dont_corrupt(self, tmp_path: Path) -> None:
        """Multiple rapid writes should each produce valid JSON."""
        target = tmp_path / "concurrent.json"
        errors = []

        def writer(n):
            # Collect failures so they can be asserted on in the main thread;
            # an exception raised inside a worker would not fail the test.
            try:
                atomic_json_write(target, {"writer": n, "data": list(range(100))})
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=writer, args=(i,)) for i in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert not errors
        # File should contain valid, complete JSON from exactly one of the writers
        result = json.loads(target.read_text(encoding="utf-8"))
        assert result["writer"] in range(10)
        assert result["data"] == list(range(100))