# test_retry_utils.py
"""Tests for agent.retry_utils jittered backoff."""

import threading
import time

import agent.retry_utils as retry_utils
from agent.retry_utils import jittered_backoff

# Upper bound (seconds) for any thread-coordination wait in these tests, so a
# misbehaving implementation fails the test instead of hanging the run.
_THREAD_TIMEOUT = 5.0


def _run_in_threads(target, count):
    """Run *target* in *count* daemon threads and return their exceptions.

    Exceptions raised inside a worker thread are otherwise lost (only printed),
    so they are collected here and handed back for the caller to assert on.
    Daemon threads keep a stuck worker from blocking interpreter exit.
    """
    errors = []

    def _worker():
        try:
            target()
        except Exception as exc:  # propagated to the main thread via `errors`
            errors.append(exc)

    threads = [threading.Thread(target=_worker, daemon=True) for _ in range(count)]
    for t in threads:
        t.start()
    for t in threads:
        t.join(timeout=_THREAD_TIMEOUT)
    assert not any(t.is_alive() for t in threads), "worker thread did not finish in time"
    return errors


def test_backoff_is_exponential():
    """Base delay should double each attempt (before jitter)."""
    for attempt in (1, 2, 3, 4):
        expected = min(5.0 * (2 ** (attempt - 1)), 120.0)
        delays = [
            jittered_backoff(attempt, base_delay=5.0, max_delay=120.0, jitter_ratio=0.0)
            for _ in range(100)
        ]
        # With jitter disabled every call should hit the expected value, so check
        # each delay rather than the mean (a mean can hide offsetting errors).
        wrong = [d for d in delays if abs(d - expected) >= 0.01]
        assert not wrong, f"attempt {attempt}: expected {expected}, got {wrong[:5]}"


def test_backoff_respects_max_delay():
    """Even with high attempt numbers, delay should not exceed max_delay."""
    for attempt in (10, 20, 100):
        delay = jittered_backoff(attempt, base_delay=5.0, max_delay=60.0, jitter_ratio=0.0)
        assert delay <= 60.0, f"attempt {attempt}: delay {delay} exceeds max 60s"


def test_backoff_adds_jitter():
    """With jitter enabled, delays should vary across calls."""
    delays = [
        jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5)
        for _ in range(50)
    ]
    assert min(delays) != max(delays), "jitter should produce varying delays"
    assert all(d >= 10.0 for d in delays), "jittered delay should be >= base delay"
    assert all(d <= 15.0 for d in delays), "jittered delay should be bounded"


def test_backoff_attempt_1_is_base():
    """First attempt delay should equal base_delay (with no jitter)."""
    delay = jittered_backoff(1, base_delay=3.0, max_delay=120.0, jitter_ratio=0.0)
    assert delay == 3.0


def test_backoff_with_zero_base_delay_returns_max():
    """base_delay=0 should return max_delay (guard against busy-wait)."""
    delay = jittered_backoff(1, base_delay=0.0, max_delay=60.0, jitter_ratio=0.0)
    assert delay == 60.0


def test_backoff_with_extreme_attempt_returns_max():
    """Very large attempt numbers should not overflow and should return max_delay."""
    delay = jittered_backoff(999, base_delay=5.0, max_delay=120.0, jitter_ratio=0.0)
    assert delay == 120.0


def test_backoff_negative_attempt_treated_as_one():
    """Negative attempt should not crash and behaves like attempt=1."""
    delay = jittered_backoff(-5, base_delay=10.0, max_delay=120.0, jitter_ratio=0.0)
    assert delay == 10.0


def test_backoff_thread_safety():
    """Concurrent calls should generally produce different delays."""
    n_threads = 8
    results = []
    # The timeout turns a stuck peer into BrokenBarrierError instead of a hang.
    barrier = threading.Barrier(n_threads, timeout=_THREAD_TIMEOUT)

    def _call_backoff():
        barrier.wait()
        results.append(jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5))

    errors = _run_in_threads(_call_backoff, n_threads)

    assert not errors, f"worker thread raised: {errors!r}"
    assert len(results) == n_threads
    unique = len(set(results))
    assert unique >= 6, f"Expected mostly unique delays, got {unique}/{n_threads} unique"


def test_backoff_uses_locked_tick_for_seed(monkeypatch):
    """Seed derivation should use per-call tick captured under lock."""
    monkeypatch.setattr(retry_utils, "_jitter_counter", 0)

    recorded_seeds = []

    class _RecordingRandom:
        """Stand-in for random.Random that records the seed it is built with."""

        def __init__(self, seed):
            recorded_seeds.append(seed)

        def uniform(self, a, b):
            return 0.0

    # NOTE(review): retry_utils.random / retry_utils.time are presumably the
    # stdlib modules, making these patches process-wide for the test's
    # duration; monkeypatch restores the originals at teardown.
    monkeypatch.setattr(retry_utils.random, "Random", _RecordingRandom)

    fixed_time_ns = 123456789

    def _time_ns_wait_for_two_ticks():
        # Wait (bounded) until the counter shows two ticks, so both callers see
        # the same timestamp and only the tick can make their seeds differ.
        # monotonic() is used because wall-clock time can jump.
        deadline = time.monotonic() + 2.0
        while retry_utils._jitter_counter < 2 and time.monotonic() < deadline:
            time.sleep(0.001)
        return fixed_time_ns

    monkeypatch.setattr(retry_utils.time, "time_ns", _time_ns_wait_for_two_ticks)

    barrier = threading.Barrier(2, timeout=_THREAD_TIMEOUT)

    def _call():
        barrier.wait()
        jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5)

    errors = _run_in_threads(_call, 2)

    # Without this check a crash after Random() was constructed would go
    # unnoticed, since the seed is already recorded by then.
    assert not errors, f"worker thread raised: {errors!r}"
    assert len(recorded_seeds) == 2
    assert len(set(recorded_seeds)) == 2, f"Expected unique seeds, got {recorded_seeds}"