/ tests / test_retry_utils.py
test_retry_utils.py
  1  """Tests for agent.retry_utils jittered backoff."""
  2  
  3  import threading
  4  
  5  import agent.retry_utils as retry_utils
  6  from agent.retry_utils import jittered_backoff
  7  
  8  
  9  def test_backoff_is_exponential():
 10      """Base delay should double each attempt (before jitter)."""
 11      for attempt in (1, 2, 3, 4):
 12          delays = [jittered_backoff(attempt, base_delay=5.0, max_delay=120.0, jitter_ratio=0.0) for _ in range(100)]
 13          expected = min(5.0 * (2 ** (attempt - 1)), 120.0)
 14          mean = sum(delays) / len(delays)
 15          assert abs(mean - expected) < 0.01, f"attempt {attempt}: expected {expected}, got {mean}"
 16  
 17  
 18  def test_backoff_respects_max_delay():
 19      """Even with high attempt numbers, delay should not exceed max_delay."""
 20      for attempt in (10, 20, 100):
 21          delay = jittered_backoff(attempt, base_delay=5.0, max_delay=60.0, jitter_ratio=0.0)
 22          assert delay <= 60.0, f"attempt {attempt}: delay {delay} exceeds max 60s"
 23  
 24  
 25  def test_backoff_adds_jitter():
 26      """With jitter enabled, delays should vary across calls."""
 27      delays = [jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5) for _ in range(50)]
 28      assert min(delays) != max(delays), "jitter should produce varying delays"
 29      assert all(d >= 10.0 for d in delays), "jittered delay should be >= base delay"
 30      assert all(d <= 15.0 for d in delays), "jittered delay should be bounded"
 31  
 32  
 33  def test_backoff_attempt_1_is_base():
 34      """First attempt delay should equal base_delay (with no jitter)."""
 35      delay = jittered_backoff(1, base_delay=3.0, max_delay=120.0, jitter_ratio=0.0)
 36      assert delay == 3.0
 37  
 38  
 39  def test_backoff_with_zero_base_delay_returns_max():
 40      """base_delay=0 should return max_delay (guard against busy-wait)."""
 41      delay = jittered_backoff(1, base_delay=0.0, max_delay=60.0, jitter_ratio=0.0)
 42      assert delay == 60.0
 43  
 44  
 45  def test_backoff_with_extreme_attempt_returns_max():
 46      """Very large attempt numbers should not overflow and should return max_delay."""
 47      delay = jittered_backoff(999, base_delay=5.0, max_delay=120.0, jitter_ratio=0.0)
 48      assert delay == 120.0
 49  
 50  
 51  def test_backoff_negative_attempt_treated_as_one():
 52      """Negative attempt should not crash and behaves like attempt=1."""
 53      delay = jittered_backoff(-5, base_delay=10.0, max_delay=120.0, jitter_ratio=0.0)
 54      assert delay == 10.0
 55  
 56  
 57  def test_backoff_thread_safety():
 58      """Concurrent calls should generally produce different delays."""
 59      results = []
 60      barrier = threading.Barrier(8)
 61  
 62      def _call_backoff():
 63          barrier.wait()
 64          results.append(jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5))
 65  
 66      threads = [threading.Thread(target=_call_backoff) for _ in range(8)]
 67      for t in threads:
 68          t.start()
 69      for t in threads:
 70          t.join(timeout=5)
 71  
 72      assert len(results) == 8
 73      unique = len(set(results))
 74      assert unique >= 6, f"Expected mostly unique delays, got {unique}/8 unique"
 75  
 76  
 77  def test_backoff_uses_locked_tick_for_seed(monkeypatch):
 78      """Seed derivation should use per-call tick captured under lock."""
 79      import time
 80  
 81      monkeypatch.setattr(retry_utils, "_jitter_counter", 0)
 82  
 83      recorded_seeds = []
 84  
 85      class _RecordingRandom:
 86          def __init__(self, seed):
 87              recorded_seeds.append(seed)
 88  
 89          def uniform(self, a, b):
 90              return 0.0
 91  
 92      monkeypatch.setattr(retry_utils.random, "Random", _RecordingRandom)
 93  
 94      fixed_time_ns = 123456789
 95  
 96      def _time_ns_wait_for_two_ticks():
 97          deadline = time.time() + 2.0
 98          while retry_utils._jitter_counter < 2 and time.time() < deadline:
 99              time.sleep(0.001)
100          return fixed_time_ns
101  
102      monkeypatch.setattr(retry_utils.time, "time_ns", _time_ns_wait_for_two_ticks)
103  
104      barrier = threading.Barrier(2)
105  
106      def _call():
107          barrier.wait()
108          jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5)
109  
110      threads = [threading.Thread(target=_call) for _ in range(2)]
111      for t in threads:
112          t.start()
113      for t in threads:
114          t.join(timeout=5)
115  
116      assert len(recorded_seeds) == 2
117      assert len(set(recorded_seeds)) == 2, f"Expected unique seeds, got {recorded_seeds}"