test_rate_limiter.py
import pytest

from mlflow.genai.evaluation.harness import (
    AUTO_INITIAL_RPS,
    _make_rate_limiter,
    _parse_rate_limit,
)
from mlflow.genai.evaluation.rate_limiter import (
    NoOpRateLimiter,
    RPSRateLimiter,
    call_with_retry,
    eval_retry_context,
    is_rate_limit_error,
)
from mlflow.genai.judges.adapters.litellm_adapter import (
    _get_litellm_retry_policy,
    is_litellm_rate_limit_retries_disabled,
)
from mlflow.utils.rest_utils import is_429_retry_disabled


class FakeClock:
    """Deterministic fake clock: every sleep() advances virtual time by the requested amount.

    No locking is needed here — RPSRateLimiter serializes all calls to its
    clock() and sleep() callbacks behind its internal lock, so for a given
    limiter they are never invoked concurrently.
    """

    def __init__(self):
        self._now = 0.0
        self.sleep_calls: list[float] = []

    def monotonic(self) -> float:
        return self._now

    def sleep(self, seconds: float) -> None:
        self.sleep_calls.append(seconds)
        self._now += seconds


# ── Token bucket tests ──


def test_invalid_rate_raises():
    # Zero and negative rates are both rejected at construction time.
    for bad_rps in (0, -1):
        with pytest.raises(ValueError, match="must be positive"):
            RPSRateLimiter(bad_rps)


def test_sub_one_rps_can_acquire():
    # Regression test: with rps < 1.0, _max_tokens used to be set to rps, so
    # the bucket could never accumulate a whole token and acquire() looped forever.
    clock = FakeClock()
    bucket = RPSRateLimiter(0.5, clock=clock.monotonic, sleep=clock.sleep)  # 1 req / 2s

    # Acquire #1: starts with 0.5 tokens, sleeps 1s to reach 1.0.
    # Acquire #2: starts empty, sleeps 2s to reach 1.0.
    bucket.acquire()
    bucket.acquire()

    assert sum(clock.sleep_calls) == pytest.approx(3.0, abs=0.1)


def test_burst_tokens_consumed_without_sleeping():
    # The initial burst capacity covers the first `rps` acquires for free.
    clock = FakeClock()
    bucket = RPSRateLimiter(5, clock=clock.monotonic, sleep=clock.sleep)

    for _ in range(5):
        bucket.acquire()

    assert not clock.sleep_calls


def test_sleep_called_when_tokens_exhausted():
    clock = FakeClock()
    bucket = RPSRateLimiter(5, clock=clock.monotonic, sleep=clock.sleep)

    # Drain the burst; the next acquire must wait one token's worth (1/5 s).
    for _ in range(5):
        bucket.acquire()
    bucket.acquire()

    assert len(clock.sleep_calls) == 1
    assert clock.sleep_calls[0] == pytest.approx(0.2, abs=0.01)


def test_total_sleep_for_sustained_rate():
    clock = FakeClock()
    bucket = RPSRateLimiter(10, clock=clock.monotonic, sleep=clock.sleep)

    for _ in range(20):
        bucket.acquire()

    # 20 acquires at 10 rps: the 10 beyond the burst cost 1s of sleeping total.
    assert sum(clock.sleep_calls) == pytest.approx(1.0, abs=0.01)


def test_tokens_refill_after_idle():
    clock = FakeClock()
    bucket = RPSRateLimiter(10, clock=clock.monotonic, sleep=clock.sleep)

    for _ in range(10):
        bucket.acquire()

    # A full idle second refills the bucket completely.
    clock._now += 1.0

    baseline = len(clock.sleep_calls)
    for _ in range(10):
        bucket.acquire()

    assert clock.sleep_calls[baseline:] == []


def test_partial_refill():
    clock = FakeClock()
    bucket = RPSRateLimiter(10, clock=clock.monotonic, sleep=clock.sleep)

    for _ in range(10):
        bucket.acquire()

    # Half a second of idle time refills only half the bucket (5 tokens).
    clock._now += 0.5

    baseline = len(clock.sleep_calls)
    for _ in range(5):
        bucket.acquire()
    assert clock.sleep_calls[baseline:] == []

    # The sixth acquire exceeds the partial refill and must sleep.
    bucket.acquire()
    assert len(clock.sleep_calls) == baseline + 1


def test_noop_acquire_does_nothing():
    noop = NoOpRateLimiter()
    for _ in range(1000):
        noop.acquire()


# ── _make_rate_limiter / _parse_rate_limit tests ──

def test_make_rate_limiter_positive_rate():
    assert isinstance(_make_rate_limiter(10.0), RPSRateLimiter)


def test_make_rate_limiter_zero_returns_noop():
    assert isinstance(_make_rate_limiter(0.0), NoOpRateLimiter)


def test_make_rate_limiter_none_returns_noop():
    assert isinstance(_make_rate_limiter(None), NoOpRateLimiter)


def test_make_rate_limiter_adaptive():
    limiter = _make_rate_limiter(10.0, adaptive=True)
    assert isinstance(limiter, RPSRateLimiter)
    assert limiter._adaptive is True


@pytest.mark.parametrize(
    ("raw", "expected_rps", "expected_adaptive"),
    [
        # "auto" is case-insensitive and whitespace-tolerant.
        ("auto", AUTO_INITIAL_RPS, True),
        ("AUTO", AUTO_INITIAL_RPS, True),
        (" Auto ", AUTO_INITIAL_RPS, True),
        ("25", 25.0, False),
        # "0" and unset both mean "no rate limiting".
        ("0", None, False),
        (None, None, False),
    ],
)
def test_parse_rate_limit(raw, expected_rps, expected_adaptive):
    rps, adaptive = _parse_rate_limit(raw)
    assert rps == expected_rps
    assert adaptive == expected_adaptive


# ── is_rate_limit_error tests ──


class _FakeRateLimitError(Exception):
    pass


# Rename so the class presents as a provider "RateLimitError" without importing
# one — is_rate_limit_error presumably matches on the class name (the test
# below expects this exception to be classified as a rate-limit error).
_FakeRateLimitError.__name__ = "RateLimitError"


class _FakeStatusCodeError(Exception):
    """Error exposing a top-level ``status_code`` attribute."""

    def __init__(self, status_code):
        self.status_code = status_code
        super().__init__(f"HTTP {status_code}")


class _FakeResponseError(Exception):
    """Error exposing a nested ``response.status_code`` attribute."""

    def __init__(self, status_code):
        self.response = type("R", (), {"status_code": status_code})()
        super().__init__(f"HTTP {status_code}")


@pytest.mark.parametrize(
    ("exc", "expected"),
    [
        (_FakeRateLimitError("rate limit"), True),
        (_FakeStatusCodeError(429), True),
        (_FakeResponseError(429), True),
        (Exception("Error 429: too many requests"), True),
        (Exception("rate limit exceeded"), True),
        (_FakeStatusCodeError(500), False),
        (_FakeResponseError(500), False),
        (Exception("something else entirely"), False),
        (ValueError("bad value"), False),
    ],
)
def test_is_rate_limit_error(exc, expected):
    assert is_rate_limit_error(exc) == expected


# ── AIMD tests ──


def test_throttle_halves_rate():
    clock = FakeClock()
    limiter = RPSRateLimiter(10.0, adaptive=True, clock=clock.monotonic, sleep=clock.sleep)

    # Multiplicative decrease: one throttle signal halves the rate.
    limiter.report_throttle()
    assert limiter._rps == pytest.approx(5.0)


def test_throttle_respects_floor():
    clock = FakeClock()
    limiter = RPSRateLimiter(2.0, adaptive=True, clock=clock.monotonic, sleep=clock.sleep)

    # First throttle: 2.0 * 0.5 = 1.0
    limiter.report_throttle()
    assert limiter._rps == pytest.approx(1.0)

    # Second throttle (after cooldown): clamped at the 1.0 rps floor.
    clock._now += 10.0
    limiter.report_throttle()
    assert limiter._rps == pytest.approx(1.0)


def test_throttle_cooldown_coalesces_rapid_signals():
    clock = FakeClock()
    limiter = RPSRateLimiter(10.0, adaptive=True, clock=clock.monotonic, sleep=clock.sleep)

    limiter.report_throttle()
    assert limiter._rps == pytest.approx(5.0)

    # Within cooldown window — should be ignored so one burst of 429s does not
    # collapse the rate to the floor.
    clock._now += 1.0
    limiter.report_throttle()
    assert limiter._rps == pytest.approx(5.0)

    # After cooldown — should take effect
    clock._now += 10.0
    limiter.report_throttle()
    assert limiter._rps == pytest.approx(2.5)


def test_success_restores_rate():
    clock = FakeClock()
    limiter = RPSRateLimiter(10.0, adaptive=True, clock=clock.monotonic, sleep=clock.sleep)

    limiter.report_throttle()
    assert limiter._rps == pytest.approx(5.0)

    # Additive increase: repeated successes climb back past the initial rate.
    for _ in range(100):
        limiter.report_success()

    assert limiter._rps > 10.0


@pytest.mark.parametrize(
    ("multiplier", "expected_ceiling"),
    [(5.0, 50.0), (3.0, 30.0)],
)
def test_success_climbs_to_multiplier_ceiling(multiplier, expected_ceiling):
    clock = FakeClock()
    limiter = RPSRateLimiter(
        10.0,
        adaptive=True,
        max_rps_multiplier=multiplier,
        clock=clock.monotonic,
        sleep=clock.sleep,
    )
    # Far more successes than needed to reach the ceiling; rate must saturate
    # at initial_rps * multiplier instead of growing without bound.
    for _ in range(10000):
        limiter.report_success()
    assert limiter._rps == pytest.approx(expected_ceiling)


def test_adaptive_false_ignores_throttle_and_success():
    clock = FakeClock()
    limiter = RPSRateLimiter(10.0, adaptive=False, clock=clock.monotonic, sleep=clock.sleep)

    limiter.report_throttle()
    assert limiter._rps == pytest.approx(10.0)

    limiter.report_success()
    assert limiter._rps == pytest.approx(10.0)


# ── call_with_retry tests ──


def test_call_with_retry_success():
    sleep_calls = []
    limiter = NoOpRateLimiter()
    result = call_with_retry(lambda: 42, limiter, max_retries=3, sleep=sleep_calls.append)
    assert result == 42
    assert sleep_calls == []


def test_call_with_retry_retries_on_429_then_succeeds():
    sleep_calls = []
    limiter = NoOpRateLimiter()
    attempts = []

    def flaky_fn():
        attempts.append(1)
        if len(attempts) < 3:
            raise _FakeRateLimitError("rate limited")
        return "ok"

    result = call_with_retry(flaky_fn, limiter, max_retries=3, sleep=sleep_calls.append)
    assert result == "ok"
    assert len(attempts) == 3
    # Two retries with exponential backoff: 2^0=1, 2^1=2
    assert sleep_calls == [1, 2]


def test_call_with_retry_non_429_propagates_immediately():
    sleep_calls = []
    limiter = NoOpRateLimiter()

    def always_raises():
        raise ValueError("bad input")

    # Non-rate-limit errors must not be retried (no backoff sleeps at all).
    with pytest.raises(ValueError, match="bad input"):
        call_with_retry(always_raises, limiter, max_retries=3, sleep=sleep_calls.append)
    assert sleep_calls == []


def test_call_with_retry_exhausted_retries():
    sleep_calls = []
    limiter = NoOpRateLimiter()

    def always_rate_limited():
        raise _FakeRateLimitError("rate limited")

    with pytest.raises(_FakeRateLimitError, match="rate limited"):
        call_with_retry(always_rate_limited, limiter, max_retries=2, sleep=sleep_calls.append)
    # 3 attempts total (initial + 2 retries), 2 sleeps
    assert len(sleep_calls) == 2


def test_call_with_retry_reports_throttle_and_success():
    clock = FakeClock()
    limiter = RPSRateLimiter(10.0, adaptive=True, clock=clock.monotonic, sleep=clock.sleep)
    attempts = []

    def flaky_fn():
        attempts.append(1)
        if len(attempts) == 1:
            raise _FakeRateLimitError("rate limited")
        return "ok"

    result = call_with_retry(flaky_fn, limiter, max_retries=3, sleep=clock.sleep)
    assert result == "ok"
    # The 429 halves the rate (10 -> 5); the subsequent success nudges it back
    # up, but it should still sit below the initial 10 rps.
    assert limiter._rps < 10.0


# ── eval_retry_context tests ──


def _retry_flags_active():
    """True when BOTH downstream retry-suppression flags are set."""
    return is_litellm_rate_limit_retries_disabled() and is_429_retry_disabled()


def _retry_flags_cleared():
    """True when NEITHER downstream retry-suppression flag is set.

    Deliberately not ``not _retry_flags_active()``: by De Morgan that would be
    true even if only one of the two flags were reset, so a flag leaked by
    eval_retry_context would go undetected.
    """
    return not is_litellm_rate_limit_retries_disabled() and not is_429_retry_disabled()


def test_eval_retry_context_sets_and_resets():
    assert _retry_flags_cleared()

    with eval_retry_context():
        assert _retry_flags_active()

    # Both flags must be restored on exit, not just one of them.
    assert _retry_flags_cleared()


def test_eval_retry_context_nests():
    assert _retry_flags_cleared()

    with eval_retry_context():
        assert _retry_flags_active()
        with eval_retry_context():
            assert _retry_flags_active()
        # Exiting the inner context must not clear the outer context's flags.
        assert _retry_flags_active()

    assert _retry_flags_cleared()


def test_litellm_retry_policy_disables_rate_limit_retries_when_flag_set():
    with eval_retry_context():
        policy = _get_litellm_retry_policy(3)
        # Rate-limit retries are suppressed while the flag is set, but other
        # retry categories keep the requested count.
        assert policy.RateLimitErrorRetries == 0
        assert policy.TimeoutErrorRetries == 3