/ tests / genai / evaluate / test_rate_limiter.py
test_rate_limiter.py
  1  import pytest
  2  
  3  from mlflow.genai.evaluation.harness import (
  4      AUTO_INITIAL_RPS,
  5      _make_rate_limiter,
  6      _parse_rate_limit,
  7  )
  8  from mlflow.genai.evaluation.rate_limiter import (
  9      NoOpRateLimiter,
 10      RPSRateLimiter,
 11      call_with_retry,
 12      eval_retry_context,
 13      is_rate_limit_error,
 14  )
 15  from mlflow.genai.judges.adapters.litellm_adapter import (
 16      _get_litellm_retry_policy,
 17      is_litellm_rate_limit_retries_disabled,
 18  )
 19  from mlflow.utils.rest_utils import is_429_retry_disabled
 20  
 21  
 22  class FakeClock:
 23      """Deterministic clock for testing. sleep() advances the clock by the requested amount.
 24  
 25      Thread safety is not needed here because RPSRateLimiter's internal lock serializes
 26      all calls to clock() and sleep() — they are never called concurrently for a given limiter.
 27      """
 28  
 29      def __init__(self):
 30          self._now = 0.0
 31          self.sleep_calls: list[float] = []
 32  
 33      def monotonic(self) -> float:
 34          return self._now
 35  
 36      def sleep(self, seconds: float) -> None:
 37          self.sleep_calls.append(seconds)
 38          self._now += seconds
 39  
 40  
 41  # ── Token bucket tests ──
 42  
 43  
 44  def test_invalid_rate_raises():
 45      with pytest.raises(ValueError, match="must be positive"):
 46          RPSRateLimiter(0)
 47      with pytest.raises(ValueError, match="must be positive"):
 48          RPSRateLimiter(-1)
 49  
 50  
 51  def test_sub_one_rps_can_acquire():
 52      # rps < 1.0 was broken: _max_tokens was set to rps, so the bucket could never
 53      # accumulate a full token and acquire() would loop forever.
 54      clock = FakeClock()
 55      limiter = RPSRateLimiter(0.5, clock=clock.monotonic, sleep=clock.sleep)  # 1 req / 2s
 56  
 57      # First acquire: initial tokens=0.5, sleeps 1s to reach 1.0, then succeeds.
 58      # Second acquire: tokens=0, sleeps 2s to reach 1.0, then succeeds.
 59      limiter.acquire()
 60      limiter.acquire()
 61  
 62      total_sleep = sum(clock.sleep_calls)
 63      assert total_sleep == pytest.approx(3.0, abs=0.1)
 64  
 65  
 66  def test_burst_tokens_consumed_without_sleeping():
 67      clock = FakeClock()
 68      limiter = RPSRateLimiter(5, clock=clock.monotonic, sleep=clock.sleep)
 69  
 70      for _ in range(5):
 71          limiter.acquire()
 72  
 73      assert clock.sleep_calls == []
 74  
 75  
 76  def test_sleep_called_when_tokens_exhausted():
 77      clock = FakeClock()
 78      limiter = RPSRateLimiter(5, clock=clock.monotonic, sleep=clock.sleep)
 79  
 80      for _ in range(5):
 81          limiter.acquire()
 82  
 83      limiter.acquire()
 84      assert len(clock.sleep_calls) == 1
 85      assert clock.sleep_calls[0] == pytest.approx(0.2, abs=0.01)
 86  
 87  
 88  def test_total_sleep_for_sustained_rate():
 89      clock = FakeClock()
 90      limiter = RPSRateLimiter(10, clock=clock.monotonic, sleep=clock.sleep)
 91  
 92      for _ in range(20):
 93          limiter.acquire()
 94  
 95      total_sleep = sum(clock.sleep_calls)
 96      assert total_sleep == pytest.approx(1.0, abs=0.01)
 97  
 98  
 99  def test_tokens_refill_after_idle():
100      clock = FakeClock()
101      limiter = RPSRateLimiter(10, clock=clock.monotonic, sleep=clock.sleep)
102  
103      for _ in range(10):
104          limiter.acquire()
105  
106      clock._now += 1.0
107  
108      sleep_before = len(clock.sleep_calls)
109      for _ in range(10):
110          limiter.acquire()
111  
112      assert clock.sleep_calls[sleep_before:] == []
113  
114  
def test_partial_refill():
    fake = FakeClock()
    limiter = RPSRateLimiter(10, clock=fake.monotonic, sleep=fake.sleep)

    for _ in range(10):
        limiter.acquire()

    # Half a second at 10 rps refills five tokens.
    fake._now += 0.5
    baseline = len(fake.sleep_calls)

    for _ in range(5):
        limiter.acquire()
    assert fake.sleep_calls[baseline:] == []

    # The sixth request finds the bucket empty again and must wait exactly once.
    limiter.acquire()
    assert len(fake.sleep_calls) - baseline == 1
132  
133  
def test_noop_acquire_does_nothing():
    # Repeatedly acquiring from the no-op limiter must simply return every time.
    noop = NoOpRateLimiter()
    for _ in range(1000):
        noop.acquire()
138  
139  
140  # ── _make_rate_limiter / _parse_rate_limit tests ──
141  
142  
def test_make_rate_limiter_positive_rate():
    # Any positive rate yields a real token-bucket limiter.
    limiter = _make_rate_limiter(10.0)
    assert isinstance(limiter, RPSRateLimiter)
145  
146  
def test_make_rate_limiter_zero_returns_noop():
    # A zero rate disables limiting, so the factory falls back to the no-op limiter.
    limiter = _make_rate_limiter(0.0)
    assert isinstance(limiter, NoOpRateLimiter)
149  
150  
def test_make_rate_limiter_none_returns_noop():
    # No configured rate also means no limiting.
    limiter = _make_rate_limiter(None)
    assert isinstance(limiter, NoOpRateLimiter)
153  
154  
def test_make_rate_limiter_adaptive():
    adaptive_limiter = _make_rate_limiter(10.0, adaptive=True)
    # Must be a real token bucket (not the no-op) with adaptive tuning switched on.
    assert isinstance(adaptive_limiter, RPSRateLimiter)
    assert adaptive_limiter._adaptive is True
159  
160  
# (raw setting, expected rps, expected adaptive flag) for _parse_rate_limit.
_PARSE_RATE_LIMIT_CASES = [
    # "auto" is case-insensitive and ignores surrounding whitespace; it selects adaptive
    # mode starting at AUTO_INITIAL_RPS.
    ("auto", AUTO_INITIAL_RPS, True),
    ("AUTO", AUTO_INITIAL_RPS, True),
    (" Auto ", AUTO_INITIAL_RPS, True),
    # A numeric string is a fixed, non-adaptive rate.
    ("25", 25.0, False),
    # Zero or an absent setting yields no rate at all.
    ("0", None, False),
    (None, None, False),
]


@pytest.mark.parametrize(("raw", "expected_rps", "expected_adaptive"), _PARSE_RATE_LIMIT_CASES)
def test_parse_rate_limit(raw, expected_rps, expected_adaptive):
    parsed_rps, parsed_adaptive = _parse_rate_limit(raw)
    assert parsed_rps == expected_rps
    assert parsed_adaptive == expected_adaptive
176  
177  
178  # ── is_rate_limit_error tests ──
179  
180  
181  class _FakeRateLimitError(Exception):
182      pass
183  
184  
185  _FakeRateLimitError.__name__ = "RateLimitError"
186  
187  
188  class _FakeStatusCodeError(Exception):
189      def __init__(self, status_code):
190          self.status_code = status_code
191          super().__init__(f"HTTP {status_code}")
192  
193  
194  class _FakeResponseError(Exception):
195      def __init__(self, status_code):
196          self.response = type("R", (), {"status_code": status_code})()
197          super().__init__(f"HTTP {status_code}")
198  
199  
@pytest.mark.parametrize(
    ("exc", "expected"),
    [
        # RateLimitError-named exception whose message avoids every textual hint, so
        # this case can only pass if the exception type itself is recognised
        # (presumably by its class name).
        (_FakeRateLimitError("opaque provider error"), True),
        (_FakeRateLimitError("rate limit"), True),
        # HTTP 429 carried on the exception itself or on its ``response``.
        (_FakeStatusCodeError(429), True),
        (_FakeResponseError(429), True),
        # Rate limiting mentioned only in the message text.
        (Exception("Error 429: too many requests"), True),
        (Exception("rate limit exceeded"), True),
        # Unrelated failures must not be classified as rate limits.
        (_FakeStatusCodeError(500), False),
        (_FakeResponseError(500), False),
        (Exception("something else entirely"), False),
        (ValueError("bad value"), False),
    ],
)
def test_is_rate_limit_error(exc, expected):
    assert is_rate_limit_error(exc) == expected
216  
217  
218  # ── AIMD tests ──
219  
220  
def test_throttle_halves_rate():
    fake = FakeClock()
    limiter = RPSRateLimiter(10.0, adaptive=True, clock=fake.monotonic, sleep=fake.sleep)

    # Multiplicative decrease: one throttle signal cuts the rate in half.
    limiter.report_throttle()

    assert limiter._rps == pytest.approx(10.0 / 2)
227  
228  
def test_throttle_respects_floor():
    fake = FakeClock()
    limiter = RPSRateLimiter(2.0, adaptive=True, clock=fake.monotonic, sleep=fake.sleep)

    # 2.0 halves to 1.0, which is also the floor.
    limiter.report_throttle()
    assert limiter._rps == pytest.approx(1.0)

    # Step past the throttle cooldown so this signal is not coalesced; even so the rate
    # must not drop below the 1.0 floor.
    fake._now += 10.0
    limiter.report_throttle()
    assert limiter._rps == pytest.approx(1.0)
241  
242  
def test_throttle_cooldown_coalesces_rapid_signals():
    fake = FakeClock()
    limiter = RPSRateLimiter(10.0, adaptive=True, clock=fake.monotonic, sleep=fake.sleep)

    def throttle_after(delay):
        """Advance the fake clock by `delay`, report a throttle, return the new rate."""
        fake._now += delay
        limiter.report_throttle()
        return limiter._rps

    # The first signal applies immediately.
    assert throttle_after(0.0) == pytest.approx(5.0)
    # A second signal 1s later falls inside the cooldown window and is ignored.
    assert throttle_after(1.0) == pytest.approx(5.0)
    # Once the cooldown has elapsed, the next signal halves the rate again.
    assert throttle_after(10.0) == pytest.approx(2.5)
259  
260  
def test_success_restores_rate():
    initial_rps = 10.0
    fake = FakeClock()
    limiter = RPSRateLimiter(initial_rps, adaptive=True, clock=fake.monotonic, sleep=fake.sleep)

    limiter.report_throttle()
    assert limiter._rps == pytest.approx(initial_rps / 2)

    # Additive increase: enough consecutive successes carry the rate back past its
    # starting value.
    for _ in range(100):
        limiter.report_success()
    assert limiter._rps > initial_rps
273  
274  
@pytest.mark.parametrize(("multiplier", "expected_ceiling"), [(5.0, 50.0), (3.0, 30.0)])
def test_success_climbs_to_multiplier_ceiling(multiplier, expected_ceiling):
    fake = FakeClock()
    limiter = RPSRateLimiter(
        10.0,
        adaptive=True,
        max_rps_multiplier=multiplier,
        clock=fake.monotonic,
        sleep=fake.sleep,
    )

    # Far more successes than needed to saturate: the rate must settle at
    # initial rate * max_rps_multiplier instead of growing without bound.
    remaining = 10000
    while remaining:
        limiter.report_success()
        remaining -= 1

    assert limiter._rps == pytest.approx(expected_ceiling)
291  
292  
def test_adaptive_false_ignores_throttle_and_success():
    fake = FakeClock()
    limiter = RPSRateLimiter(10.0, adaptive=False, clock=fake.monotonic, sleep=fake.sleep)

    # With adaptation disabled, neither feedback signal may move the configured rate.
    for report in (limiter.report_throttle, limiter.report_success):
        report()
        assert limiter._rps == pytest.approx(10.0)
302  
303  
304  # ── call_with_retry tests ──
305  
306  
def test_call_with_retry_success():
    recorded_sleeps = []
    noop = NoOpRateLimiter()

    outcome = call_with_retry(lambda: 42, noop, max_retries=3, sleep=recorded_sleeps.append)

    # A call that succeeds on the first try returns its value and never backs off.
    assert outcome == 42
    assert recorded_sleeps == []
313  
314  
def test_call_with_retry_retries_on_429_then_succeeds():
    recorded_sleeps = []
    noop = NoOpRateLimiter()
    call_count = 0

    def flaky_fn():
        nonlocal call_count
        call_count += 1
        # The first two attempts hit a rate limit; the third goes through.
        if call_count <= 2:
            raise _FakeRateLimitError("rate limited")
        return "ok"

    outcome = call_with_retry(flaky_fn, noop, max_retries=3, sleep=recorded_sleeps.append)

    assert outcome == "ok"
    assert call_count == 3
    # Exponential backoff between attempts: 2**0 and then 2**1 seconds.
    assert recorded_sleeps == [1, 2]
331  
332  
def test_call_with_retry_non_429_propagates_immediately():
    recorded_sleeps = []
    noop = NoOpRateLimiter()

    def always_raises():
        raise ValueError("bad input")

    # Errors that are not rate limits surface on the first attempt, with no backoff.
    with pytest.raises(ValueError, match="bad input"):
        call_with_retry(always_raises, noop, max_retries=3, sleep=recorded_sleeps.append)
    assert not recorded_sleeps
343  
344  
def test_call_with_retry_exhausted_retries():
    sleep_calls = []
    limiter = NoOpRateLimiter()
    attempts = []

    def always_rate_limited():
        attempts.append(1)
        raise _FakeRateLimitError("rate limited")

    # Once retries are used up, the last rate-limit error propagates to the caller.
    with pytest.raises(_FakeRateLimitError, match="rate limited"):
        call_with_retry(always_rate_limited, limiter, max_retries=2, sleep=sleep_calls.append)
    # 3 attempts total (initial + 2 retries), with a backoff sleep between consecutive
    # attempts: 2 sleeps.
    assert len(attempts) == 3
    assert len(sleep_calls) == 2
356  
357  
def test_call_with_retry_reports_throttle_and_success():
    clock = FakeClock()
    limiter = RPSRateLimiter(10.0, adaptive=True, clock=clock.monotonic, sleep=clock.sleep)
    attempts = []

    def flaky_fn():
        attempts.append(1)
        if len(attempts) == 1:
            raise _FakeRateLimitError("rate limited")
        return "ok"

    result = call_with_retry(flaky_fn, limiter, max_retries=3, sleep=clock.sleep)
    assert result == "ok"
    # Exactly one retry: the first attempt was throttled and the second succeeded.
    assert len(attempts) == 2
    # The throttle halved the rate (10.0 * 0.5 = 5.0). The following success may nudge it
    # back up slightly, but it must not have recovered to the initial rate.
    assert 5.0 <= limiter._rps < 10.0
373  
374  
375  # ── eval_retry_context tests ──
376  
377  
def _retry_flags_active():
    """Report whether the LiteLLM and REST-layer 429 retry-suppression flags are both set.

    Mirrors ``a and b``: the REST-layer flag is only consulted when the LiteLLM one is set.
    """
    litellm_disabled = is_litellm_rate_limit_retries_disabled()
    if not litellm_disabled:
        return litellm_disabled
    return is_429_retry_disabled()
381  
382  
def test_eval_retry_context_sets_and_resets():
    active_before = _retry_flags_active()
    assert not active_before

    with eval_retry_context():
        # Inside the context both suppression flags must be raised.
        assert _retry_flags_active()

    # Leaving the context restores the cleared state.
    active_after = _retry_flags_active()
    assert not active_after
390  
391  
def test_eval_retry_context_nests():
    # Start from a clean slate.
    assert not _retry_flags_active()

    with eval_retry_context():
        assert _retry_flags_active()
        # Re-entering while already active keeps the flags set...
        with eval_retry_context():
            assert _retry_flags_active()
        # ...and leaving the inner context must not clear them while the outer is open.
        assert _retry_flags_active()

    # Only the outermost exit clears the flags again.
    assert not _retry_flags_active()
402  
403  
def test_litellm_retry_policy_disables_rate_limit_retries_when_flag_set():
    # Build the policy while the eval retry context has the suppression flag set.
    with eval_retry_context():
        retry_policy = _get_litellm_retry_policy(3)

    # Rate-limit retries are switched off entirely, while timeout retries keep the
    # requested count.
    assert retry_policy.RateLimitErrorRetries == 0
    assert retry_policy.TimeoutErrorRetries == 3