/ tests / gateway / test_budget_tracker.py
test_budget_tracker.py
  1  from datetime import datetime, timedelta, timezone
  2  from unittest.mock import patch
  3  
  4  import pytest
  5  
  6  from mlflow.entities.gateway_budget_policy import (
  7      BudgetAction,
  8      BudgetDuration,
  9      BudgetDurationUnit,
 10      BudgetTargetScope,
 11      BudgetUnit,
 12      GatewayBudgetPolicy,
 13  )
 14  from mlflow.gateway.budget_tracker import (
 15      BudgetTracker,
 16      _compute_window_end,
 17      _compute_window_start,
 18      _policy_applies,
 19  )
 20  from mlflow.gateway.budget_tracker.in_memory import InMemoryBudgetTracker
 21  
 22  
 23  def _make_policy(
 24      budget_policy_id="bp-test",
 25      budget_amount=100.0,
 26      duration=None,
 27      target_scope=BudgetTargetScope.GLOBAL,
 28      budget_action=BudgetAction.ALERT,
 29      workspace=None,
 30  ):
 31      return GatewayBudgetPolicy(
 32          budget_policy_id=budget_policy_id,
 33          budget_unit=BudgetUnit.USD,
 34          budget_amount=budget_amount,
 35          duration=duration or BudgetDuration(unit=BudgetDurationUnit.DAYS, value=1),
 36          target_scope=target_scope,
 37          budget_action=budget_action,
 38          created_at=0,
 39          last_updated_at=0,
 40          workspace=workspace,
 41      )
 42  
 43  
 44  # --- _compute_window_start tests ---
 45  
 46  
 47  def test_compute_window_start_minutes():
 48      now = datetime(2025, 6, 15, 10, 37, 0, tzinfo=timezone.utc)
 49      start = _compute_window_start(BudgetDuration(unit=BudgetDurationUnit.MINUTES, value=15), now)
 50      assert start == datetime(2025, 6, 15, 10, 30, 0, tzinfo=timezone.utc)
 51  
 52  
 53  def test_compute_window_start_hours():
 54      now = datetime(2025, 6, 15, 10, 30, 0, tzinfo=timezone.utc)
 55      start = _compute_window_start(BudgetDuration(unit=BudgetDurationUnit.HOURS, value=2), now)
 56      assert start == datetime(2025, 6, 15, 10, 0, 0, tzinfo=timezone.utc)
 57  
 58  
 59  def test_compute_window_start_days():
 60      now = datetime(2025, 6, 15, 10, 0, 0, tzinfo=timezone.utc)
 61      start = _compute_window_start(BudgetDuration(unit=BudgetDurationUnit.DAYS, value=7), now)
 62      epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
 63      days_since_epoch = (now - epoch).days
 64      window_index = days_since_epoch // 7
 65      expected = epoch + timedelta(days=window_index * 7)
 66      assert start == expected
 67  
 68  
 69  def test_compute_window_start_weeks():
 70      # June 15, 2025 is a Sunday — window should start on that Sunday
 71      now = datetime(2025, 6, 15, 10, 0, 0, tzinfo=timezone.utc)
 72      start = _compute_window_start(BudgetDuration(unit=BudgetDurationUnit.WEEKS, value=1), now)
 73      assert start == datetime(2025, 6, 15, 0, 0, 0, tzinfo=timezone.utc)
 74      assert start.weekday() == 6  # Sunday
 75  
 76      # Wednesday mid-week — window should start on preceding Sunday
 77      now = datetime(2025, 6, 18, 14, 30, 0, tzinfo=timezone.utc)
 78      start = _compute_window_start(BudgetDuration(unit=BudgetDurationUnit.WEEKS, value=1), now)
 79      assert start == datetime(2025, 6, 15, 0, 0, 0, tzinfo=timezone.utc)
 80      assert start.weekday() == 6  # Sunday
 81  
 82      # Multi-week (2-week) windows also start on Sundays
 83      now = datetime(2025, 6, 20, 0, 0, 0, tzinfo=timezone.utc)
 84      start = _compute_window_start(BudgetDuration(unit=BudgetDurationUnit.WEEKS, value=2), now)
 85      assert start.weekday() == 6  # Sunday
 86  
 87  
 88  def test_compute_window_start_months():
 89      now = datetime(2025, 8, 15, tzinfo=timezone.utc)
 90      start = _compute_window_start(BudgetDuration(unit=BudgetDurationUnit.MONTHS, value=3), now)
 91      # Total months from epoch: (2025-1970)*12 + (8-1) = 660 + 7 = 667
 92      # Window index: 667 // 3 = 222, window_start_months = 666
 93      # start_year = 1970 + 666//12 = 1970 + 55 = 2025, start_month = (666%12)+1 = 7
 94      assert start == datetime(2025, 7, 1, tzinfo=timezone.utc)
 95  
 96  
 97  # --- _compute_window_end tests ---
 98  
 99  
def test_compute_window_end_minutes():
    # A 15-minute window opened at 10:30 UTC is expected to close at 10:45 UTC.
    quarter_hour = BudgetDuration(unit=BudgetDurationUnit.MINUTES, value=15)
    opened_at = datetime(2025, 6, 15, 10, 30, 0, tzinfo=timezone.utc)
    expected_end = datetime(2025, 6, 15, 10, 45, 0, tzinfo=timezone.utc)

    assert _compute_window_end(quarter_hour, opened_at) == expected_end
104  
105  
def test_compute_window_end_hours():
    # A 2-hour window opened at 10:00 UTC is expected to close at 12:00 UTC.
    two_hours = BudgetDuration(unit=BudgetDurationUnit.HOURS, value=2)
    opened_at = datetime(2025, 6, 15, 10, 0, 0, tzinfo=timezone.utc)
    expected_end = datetime(2025, 6, 15, 12, 0, 0, tzinfo=timezone.utc)

    assert _compute_window_end(two_hours, opened_at) == expected_end
110  
111  
def test_compute_window_end_days():
    # A 7-day window opened on June 15 is expected to close on June 22.
    seven_days = BudgetDuration(unit=BudgetDurationUnit.DAYS, value=7)
    opened_at = datetime(2025, 6, 15, 0, 0, 0, tzinfo=timezone.utc)
    expected_end = datetime(2025, 6, 22, 0, 0, 0, tzinfo=timezone.utc)

    assert _compute_window_end(seven_days, opened_at) == expected_end
116  
117  
def test_compute_window_end_weeks():
    # A 2-week window opened on June 12 is expected to close on June 26.
    two_weeks = BudgetDuration(unit=BudgetDurationUnit.WEEKS, value=2)
    opened_at = datetime(2025, 6, 12, 0, 0, 0, tzinfo=timezone.utc)
    expected_end = datetime(2025, 6, 26, 0, 0, 0, tzinfo=timezone.utc)

    assert _compute_window_end(two_weeks, opened_at) == expected_end
122  
123  
def test_compute_window_end_months():
    # A 3-month window opened on July 1 is expected to close on October 1.
    quarter = BudgetDuration(unit=BudgetDurationUnit.MONTHS, value=3)
    opened_at = datetime(2025, 7, 1, tzinfo=timezone.utc)
    expected_end = datetime(2025, 10, 1, tzinfo=timezone.utc)

    assert _compute_window_end(quarter, opened_at) == expected_end
128  
129  
def test_compute_window_end_months_crosses_year():
    # A 3-month window opened on November 1, 2025 is expected to roll over
    # into the next year and close on February 1, 2026.
    quarter = BudgetDuration(unit=BudgetDurationUnit.MONTHS, value=3)
    opened_at = datetime(2025, 11, 1, tzinfo=timezone.utc)
    expected_end = datetime(2026, 2, 1, tzinfo=timezone.utc)

    assert _compute_window_end(quarter, opened_at) == expected_end
134  
135  
136  # --- _policy_applies tests ---
137  
138  
def test_policy_applies_global():
    # A GLOBAL policy is expected to apply with or without a workspace.
    global_policy = _make_policy(target_scope=BudgetTargetScope.GLOBAL)

    without_workspace = _policy_applies(global_policy, None)
    assert without_workspace is True

    with_workspace = _policy_applies(global_policy, "ws1")
    assert with_workspace is True
143  
144  
def test_policy_applies_workspace_match():
    # A WORKSPACE policy bound to "ws1" is expected to apply to "ws1".
    ws1_policy = _make_policy(target_scope=BudgetTargetScope.WORKSPACE, workspace="ws1")

    applies = _policy_applies(ws1_policy, "ws1")
    assert applies is True
148  
149  
def test_policy_applies_workspace_no_match():
    # A WORKSPACE policy bound to "ws1" is expected not to apply to "ws2".
    ws1_policy = _make_policy(target_scope=BudgetTargetScope.WORKSPACE, workspace="ws1")

    applies = _policy_applies(ws1_policy, "ws2")
    assert applies is False
153  
154  
def test_policy_applies_workspace_none():
    # A WORKSPACE policy bound to "ws1" is expected not to apply when the
    # request carries no workspace.
    ws1_policy = _make_policy(target_scope=BudgetTargetScope.WORKSPACE, workspace="ws1")

    applies = _policy_applies(ws1_policy, None)
    assert applies is False
158  
159  
def test_policy_applies_workspace_none_matches_default():
    # No workspace is passed to the factory; policy.workspace resolves to
    # "default" via __post_init__, and a request without a workspace is
    # expected to match it.
    default_ws_policy = _make_policy(target_scope=BudgetTargetScope.WORKSPACE)

    applies = _policy_applies(default_ws_policy, None)
    assert applies is True
164  
165  
166  # --- InMemoryBudgetTracker tests ---
167  
168  
def test_in_memory_tracker_is_budget_tracker():
    # The in-memory implementation must be an instance of BudgetTracker.
    assert isinstance(InMemoryBudgetTracker(), BudgetTracker)
172  
173  
def test_record_cost_below_limit():
    budget_tracker = InMemoryBudgetTracker()
    hundred_dollar_policy = _make_policy(budget_amount=100.0)
    budget_tracker.refresh_policies([hundred_dollar_policy])

    # Spending 50 of a 100 budget reports nothing as newly exceeded.
    assert budget_tracker.record_cost(50.0) == []

    tracked = budget_tracker._get_window_info("bp-test")
    assert tracked.cumulative_spend == 50.0
    assert tracked.exceeded is False
184  
185  
def test_record_cost_exceeds_threshold():
    budget_tracker = InMemoryBudgetTracker()
    hundred_dollar_policy = _make_policy(budget_amount=100.0)
    budget_tracker.refresh_policies([hundred_dollar_policy])

    # A single 150 charge against a 100 budget reports exactly one window.
    breached = budget_tracker.record_cost(150.0)
    assert len(breached) == 1
    first_breach = breached[0]
    assert first_breach.policy.budget_policy_id == "bp-test"

    tracked = budget_tracker._get_window_info("bp-test")
    assert tracked.cumulative_spend == 150.0
    assert tracked.exceeded is True
197  
198  
def test_record_cost_exceeds_only_once():
    budget_tracker = InMemoryBudgetTracker()
    hundred_dollar_policy = _make_policy(budget_amount=100.0)
    budget_tracker.refresh_policies([hundred_dollar_policy])

    # The charge that crosses the limit is reported...
    first_report = budget_tracker.record_cost(150.0)
    assert len(first_report) == 1

    # ...but a later charge in the same window is not reported again.
    second_report = budget_tracker.record_cost(50.0)
    assert second_report == []

    # Spend keeps accumulating regardless.
    assert budget_tracker._get_window_info("bp-test").cumulative_spend == 200.0
211  
212  
def test_record_cost_incremental_exceeding():
    budget_tracker = InMemoryBudgetTracker()
    hundred_dollar_policy = _make_policy(budget_amount=100.0)
    budget_tracker.refresh_policies([hundred_dollar_policy])

    # 60 stays under the 100 limit; the following 50 pushes the total to 110.
    under_limit_report = budget_tracker.record_cost(60.0)
    assert under_limit_report == []

    over_limit_report = budget_tracker.record_cost(50.0)
    assert len(over_limit_report) == 1

    tracked = budget_tracker._get_window_info("bp-test")
    assert tracked.cumulative_spend == 110.0
221  
222  
def test_should_reject_request_reject():
    budget_tracker = InMemoryBudgetTracker()
    reject_policy = _make_policy(budget_amount=100.0, budget_action=BudgetAction.REJECT)
    budget_tracker.refresh_policies([reject_policy])

    # Overspending a REJECT policy makes the tracker ask for rejection and
    # hand back the offending window.
    budget_tracker.record_cost(150.0)
    verdict = budget_tracker.should_reject_request()
    should_reject, offending_window = verdict

    assert should_reject is True
    assert offending_window.policy.budget_policy_id == "bp-test"
231  
232  
def test_should_reject_request_alert_only():
    budget_tracker = InMemoryBudgetTracker()
    alert_policy = _make_policy(budget_amount=100.0, budget_action=BudgetAction.ALERT)
    budget_tracker.refresh_policies([alert_policy])

    # Overspending an ALERT policy must not cause rejection.
    budget_tracker.record_cost(150.0)
    verdict = budget_tracker.should_reject_request()
    should_reject, offending_window = verdict

    assert should_reject is False
    assert offending_window is None
241  
242  
def test_should_reject_request_not_yet():
    budget_tracker = InMemoryBudgetTracker()
    reject_policy = _make_policy(budget_amount=100.0, budget_action=BudgetAction.REJECT)
    budget_tracker.refresh_policies([reject_policy])

    # Spend of 50 is still under the 100 limit, so nothing is rejected.
    budget_tracker.record_cost(50.0)
    verdict = budget_tracker.should_reject_request()
    should_reject, offending_window = verdict

    assert should_reject is False
    assert offending_window is None
251  
252  
def test_window_resets_on_expiry():
    five_minutes = BudgetDuration(unit=BudgetDurationUnit.MINUTES, value=5)
    budget_tracker = InMemoryBudgetTracker()
    budget_tracker.refresh_policies([_make_policy(budget_amount=100.0, duration=five_minutes)])
    budget_tracker.record_cost(80.0)

    active_window = budget_tracker._get_window_info("bp-test")
    assert active_window.cumulative_spend == 80.0

    # Pretend the clock has moved one second past the end of the active window.
    after_expiry = active_window.window_end + timedelta(seconds=1)
    with patch("mlflow.gateway.budget_tracker.in_memory.datetime") as fake_datetime:
        fake_datetime.now.return_value = after_expiry
        # Calling the patched name still builds real datetime objects.
        fake_datetime.side_effect = lambda *args, **kw: datetime(*args, **kw)
        newly_exceeded = budget_tracker.record_cost(10.0)

    # Only the post-expiry charge is counted in the window now tracked.
    fresh_window = budget_tracker._get_window_info("bp-test")
    assert fresh_window.cumulative_spend == 10.0
    assert fresh_window.exceeded is False
    assert newly_exceeded == []
277  
278  
def test_refresh_policies_preserves_spend_in_same_window():
    budget_tracker = InMemoryBudgetTracker()
    hundred_dollar_policy = _make_policy(budget_amount=100.0)
    active_policies = [hundred_dollar_policy]

    budget_tracker.refresh_policies(active_policies)
    budget_tracker.record_cost(60.0)

    # Refreshing with the very same policy must keep the recorded spend.
    budget_tracker.refresh_policies([hundred_dollar_policy])
    assert budget_tracker._get_window_info("bp-test").cumulative_spend == 60.0
289  
290  
def test_refresh_policies_removes_deleted_policy():
    budget_tracker = InMemoryBudgetTracker()
    kept_policy = _make_policy(budget_policy_id="bp-1", budget_amount=100.0)
    dropped_policy = _make_policy(budget_policy_id="bp-2", budget_amount=200.0)

    budget_tracker.refresh_policies([kept_policy, dropped_policy])
    budget_tracker.record_cost(50.0)

    # After a refresh that omits bp-2, only bp-1 still has a window.
    budget_tracker.refresh_policies([kept_policy])
    kept_window = budget_tracker._get_window_info("bp-1")
    dropped_window = budget_tracker._get_window_info("bp-2")
    assert kept_window is not None
    assert dropped_window is None
302  
303  
def test_multiple_policies_independent():
    alert_at_50 = _make_policy(
        budget_policy_id="bp-alert",
        budget_amount=50.0,
        budget_action=BudgetAction.ALERT,
    )
    reject_at_100 = _make_policy(
        budget_policy_id="bp-reject",
        budget_amount=100.0,
        budget_action=BudgetAction.REJECT,
    )
    budget_tracker = InMemoryBudgetTracker()
    budget_tracker.refresh_policies([alert_at_50, reject_at_100])

    # 75 crosses the 50 alert limit but not the 100 reject limit.
    breached = budget_tracker.record_cost(75.0)
    assert len(breached) == 1
    assert breached[0].policy.budget_policy_id == "bp-alert"

    # The reject policy sits at 75, so requests are still allowed.
    should_reject, _ = budget_tracker.should_reject_request()
    assert should_reject is False

    # Another 30 takes the reject policy past its limit.
    budget_tracker.record_cost(30.0)
    should_reject, offending_window = budget_tracker.should_reject_request()
    assert should_reject is True
    assert offending_window.policy.budget_policy_id == "bp-reject"
332  
333  
def test_workspace_scoped_cost_recording():
    tracker = InMemoryBudgetTracker()
    policy = _make_policy(
        target_scope=BudgetTargetScope.WORKSPACE,
        workspace="ws1",
        budget_amount=100.0,
    )
    tracker.refresh_policies([policy])

    # Cost from different workspace — should not apply
    tracker.record_cost(200.0, workspace="ws2")
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 0.0

    # Cost from matching workspace — should apply
    tracker.record_cost(50.0, workspace="ws1")
    # Look the window up again rather than reusing the earlier reference, so
    # the assertion checks the tracker's current state and does not depend on
    # the previously returned object being updated in place.
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 50.0
351  
352  
@pytest.mark.parametrize("duration_unit", list(BudgetDurationUnit))
def test_all_duration_units_window_consistency(duration_unit):
    single_unit = BudgetDuration(unit=duration_unit, value=1)
    moment = datetime(2025, 6, 15, 10, 30, 0, tzinfo=timezone.utc)

    window_start = _compute_window_start(single_unit, moment)
    window_end = _compute_window_end(single_unit, window_start)

    # For every unit the window must be non-empty and must contain `moment`.
    assert window_start < window_end
    assert window_start <= moment < window_end
361  
362  
363  # --- refresh_policies return value tests ---
364  
365  
def test_refresh_policies_returns_all_windows():
    budget_tracker = InMemoryBudgetTracker()
    loaded_policies = [
        _make_policy(budget_policy_id="bp-1"),
        _make_policy(budget_policy_id="bp-2"),
    ]

    # One window per loaded policy is returned.
    returned = budget_tracker.refresh_policies(loaded_policies)
    assert len(returned) == 2
    assert {window.policy.budget_policy_id for window in returned} == {"bp-1", "bp-2"}
375  
376  
def test_refresh_policies_returns_all_windows_on_reload():
    budget_tracker = InMemoryBudgetTracker()
    default_policy = _make_policy()

    first_load = budget_tracker.refresh_policies([default_policy])
    assert len(first_load) == 1

    # Loading the same policy again within the same window still returns
    # the existing window.
    second_load = budget_tracker.refresh_policies([default_policy])
    assert len(second_load) == 1
387  
388  
def test_refresh_policies_returns_all_windows_on_mixed():
    budget_tracker = InMemoryBudgetTracker()
    existing_policy = _make_policy(budget_policy_id="bp-1")
    budget_tracker.refresh_policies([existing_policy])

    # Refresh with one already-tracked policy and one new policy; both
    # windows are expected in the return value.
    added_policy = _make_policy(budget_policy_id="bp-2")
    returned = budget_tracker.refresh_policies([existing_policy, added_policy])

    assert len(returned) == 2
    assert {window.policy.budget_policy_id for window in returned} == {"bp-1", "bp-2"}
399  
400  
401  # --- backfill_spend tests ---
402  
403  
def test_backfill_spend_sets_cumulative():
    budget_tracker = InMemoryBudgetTracker()
    hundred_dollar_policy = _make_policy(budget_amount=100.0)
    budget_tracker.refresh_policies([hundred_dollar_policy])

    # Backfilling 42.5 into a 100 budget sets the spend without exceeding it.
    spend_by_policy = {"bp-test": 42.5}
    budget_tracker.backfill_spend(spend_by_policy)

    tracked = budget_tracker._get_window_info("bp-test")
    assert tracked.cumulative_spend == 42.5
    assert tracked.exceeded is False
412  
413  
def test_backfill_spend_sets_exceeded_when_exceeds():
    budget_tracker = InMemoryBudgetTracker()
    hundred_dollar_policy = _make_policy(budget_amount=100.0)
    budget_tracker.refresh_policies([hundred_dollar_policy])

    # Backfilling 150 into a 100 budget marks the window as exceeded.
    spend_by_policy = {"bp-test": 150.0}
    budget_tracker.backfill_spend(spend_by_policy)

    tracked = budget_tracker._get_window_info("bp-test")
    assert tracked.cumulative_spend == 150.0
    assert tracked.exceeded is True
422  
423  
def test_backfill_spend_sets_exceeded_at_exact_limit():
    budget_tracker = InMemoryBudgetTracker()
    hundred_dollar_policy = _make_policy(budget_amount=100.0)
    budget_tracker.refresh_policies([hundred_dollar_policy])

    # Spend equal to the budget amount already counts as exceeded.
    spend_by_policy = {"bp-test": 100.0}
    budget_tracker.backfill_spend(spend_by_policy)

    tracked = budget_tracker._get_window_info("bp-test")
    assert tracked.cumulative_spend == 100.0
    assert tracked.exceeded is True
432  
433  
def test_backfill_spend_nonexistent_is_noop():
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_policy()])

    # Should not raise for a policy id the tracker has never been given.
    tracker.backfill_spend({"nonexistent-policy": 50.0})

    # "No-op" also means no window appears for the unknown id and the spend
    # of the policy that is tracked stays untouched.
    assert tracker._get_window_info("nonexistent-policy") is None
    assert tracker._get_window_info("bp-test").cumulative_spend == 0.0
439  
440  
def test_backfill_spend_uses_max_to_protect_in_process_spend():
    # Scenario: trace-flush lag leaves the DB total behind the spend already
    # recorded in-process. A backfill must never lower cumulative_spend.
    budget_tracker = InMemoryBudgetTracker()
    hundred_dollar_policy = _make_policy(budget_amount=100.0)
    budget_tracker.refresh_policies([hundred_dollar_policy])

    budget_tracker.record_cost(30.0)  # spend recorded in-process: 30
    stale_db_totals = {"bp-test": 10.0}  # the DB reports less than that
    budget_tracker.backfill_spend(stale_db_totals)

    tracked = budget_tracker._get_window_info("bp-test")
    assert tracked.cumulative_spend == 30.0  # the larger in-process value wins
451  
452  
453  # --- get_all_windows tests ---
454  
455  
def test_get_all_windows_empty():
    # A tracker that has never been given policies reports no windows.
    all_windows = InMemoryBudgetTracker().get_all_windows()
    assert all_windows == []
459  
460  
def test_get_all_windows_returns_all_policies():
    budget_tracker = InMemoryBudgetTracker()
    loaded_policies = [
        _make_policy(budget_policy_id="bp-1"),
        _make_policy(budget_policy_id="bp-2"),
    ]
    budget_tracker.refresh_policies(loaded_policies)

    # One window per loaded policy is reported.
    all_windows = budget_tracker.get_all_windows()
    assert len(all_windows) == 2
    assert {window.policy.budget_policy_id for window in all_windows} == {"bp-1", "bp-2"}
471  
472  
def test_get_all_windows_reflects_current_spend():
    budget_tracker = InMemoryBudgetTracker()
    spend_policy = _make_policy(budget_policy_id="bp-spend", budget_amount=100.0)
    budget_tracker.refresh_policies([spend_policy])
    budget_tracker.record_cost(42.5)

    # The single reported window carries the cost recorded so far.
    all_windows = budget_tracker.get_all_windows()
    assert len(all_windows) == 1
    only_window = all_windows[0]
    assert only_window.cumulative_spend == 42.5
482  
483  
def test_get_all_windows_after_policy_removed():
    budget_tracker = InMemoryBudgetTracker()
    kept_policy = _make_policy(budget_policy_id="bp-1")
    dropped_policy = _make_policy(budget_policy_id="bp-2")
    budget_tracker.refresh_policies([kept_policy, dropped_policy])

    # A refresh that omits bp-2 leaves only bp-1's window behind.
    budget_tracker.refresh_policies([kept_policy])

    remaining = budget_tracker.get_all_windows()
    assert len(remaining) == 1
    assert remaining[0].policy.budget_policy_id == "bp-1"