tests/hermes_cli/test_atomic_json_write.py
test_atomic_json_write.py
  1  """Tests for utils.atomic_json_write — crash-safe JSON file writes."""
  2  
  3  import json
  4  import os
  5  from pathlib import Path
  6  from unittest.mock import patch
  7  
  8  import pytest
  9  
 10  from utils import atomic_json_write
 11  
 12  
 13  class TestAtomicJsonWrite:
 14      """Core atomic write behavior."""
 15  
 16      def test_writes_valid_json(self, tmp_path):
 17          target = tmp_path / "data.json"
 18          data = {"key": "value", "nested": {"a": 1}}
 19          atomic_json_write(target, data)
 20  
 21          result = json.loads(target.read_text(encoding="utf-8"))
 22          assert result == data
 23  
 24      def test_creates_parent_directories(self, tmp_path):
 25          target = tmp_path / "deep" / "nested" / "dir" / "data.json"
 26          atomic_json_write(target, {"ok": True})
 27  
 28          assert target.exists()
 29          assert json.loads(target.read_text())["ok"] is True
 30  
 31      def test_overwrites_existing_file(self, tmp_path):
 32          target = tmp_path / "data.json"
 33          target.write_text('{"old": true}')
 34  
 35          atomic_json_write(target, {"new": True})
 36          result = json.loads(target.read_text())
 37          assert result == {"new": True}
 38  
 39      def test_preserves_original_on_serialization_error(self, tmp_path):
 40          target = tmp_path / "data.json"
 41          original = {"preserved": True}
 42          target.write_text(json.dumps(original))
 43  
 44          # Try to write non-serializable data — should fail
 45          with pytest.raises(TypeError):
 46              atomic_json_write(target, {"bad": object()})
 47  
 48          # Original file should be untouched
 49          result = json.loads(target.read_text())
 50          assert result == original
 51  
 52      def test_no_leftover_temp_files_on_success(self, tmp_path):
 53          target = tmp_path / "data.json"
 54          atomic_json_write(target, [1, 2, 3])
 55  
 56          # No .tmp files should be left behind
 57          tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
 58          assert len(tmp_files) == 0
 59          assert target.exists()
 60  
 61      def test_no_leftover_temp_files_on_failure(self, tmp_path):
 62          target = tmp_path / "data.json"
 63  
 64          with pytest.raises(TypeError):
 65              atomic_json_write(target, {"bad": object()})
 66  
 67          # No temp files should be left behind
 68          tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
 69          assert len(tmp_files) == 0
 70  
 71      def test_cleans_up_temp_file_on_baseexception(self, tmp_path):
 72          class SimulatedAbort(BaseException):
 73              pass
 74  
 75          target = tmp_path / "data.json"
 76          original = {"preserved": True}
 77          target.write_text(json.dumps(original), encoding="utf-8")
 78  
 79          with patch("utils.json.dump", side_effect=SimulatedAbort):
 80              with pytest.raises(SimulatedAbort):
 81                  atomic_json_write(target, {"new": True})
 82  
 83          tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
 84          assert len(tmp_files) == 0
 85          assert json.loads(target.read_text(encoding="utf-8")) == original
 86  
 87      def test_accepts_string_path(self, tmp_path):
 88          target = str(tmp_path / "string_path.json")
 89          atomic_json_write(target, {"string": True})
 90  
 91          result = json.loads(Path(target).read_text())
 92          assert result == {"string": True}
 93  
 94      def test_writes_list_data(self, tmp_path):
 95          target = tmp_path / "list.json"
 96          data = [1, "two", {"three": 3}]
 97          atomic_json_write(target, data)
 98  
 99          result = json.loads(target.read_text())
100          assert result == data
101  
102      def test_empty_list(self, tmp_path):
103          target = tmp_path / "empty.json"
104          atomic_json_write(target, [])
105  
106          result = json.loads(target.read_text())
107          assert result == []
108  
109      def test_custom_indent(self, tmp_path):
110          target = tmp_path / "custom.json"
111          atomic_json_write(target, {"a": 1}, indent=4)
112  
113          text = target.read_text()
114          assert '    "a"' in text  # 4-space indent
115  
116      def test_accepts_json_dump_default_hook(self, tmp_path):
117          class CustomValue:
118              def __str__(self):
119                  return "custom-value"
120  
121          target = tmp_path / "custom_default.json"
122          atomic_json_write(target, {"value": CustomValue()}, default=str)
123  
124          result = json.loads(target.read_text(encoding="utf-8"))
125          assert result == {"value": "custom-value"}
126  
127      def test_unicode_content(self, tmp_path):
128          target = tmp_path / "unicode.json"
129          data = {"emoji": "🎉", "japanese": "日本語"}
130          atomic_json_write(target, data)
131  
132          result = json.loads(target.read_text(encoding="utf-8"))
133          assert result["emoji"] == "🎉"
134          assert result["japanese"] == "日本語"
135  
136      def test_concurrent_writes_dont_corrupt(self, tmp_path):
137          """Multiple rapid writes should each produce valid JSON."""
138          import threading
139  
140          target = tmp_path / "concurrent.json"
141          errors = []
142  
143          def writer(n):
144              try:
145                  atomic_json_write(target, {"writer": n, "data": list(range(100))})
146              except Exception as e:
147                  errors.append(e)
148  
149          threads = [threading.Thread(target=writer, args=(i,)) for i in range(10)]
150          for t in threads:
151              t.start()
152          for t in threads:
153              t.join()
154  
155          assert not errors
156          # File should contain valid JSON from one of the writers
157          result = json.loads(target.read_text())
158          assert "writer" in result
159          assert len(result["data"]) == 100