/ tests / gateway / test_redis_budget_tracker.py
test_redis_budget_tracker.py
  1  from datetime import datetime, timedelta, timezone
  2  from unittest.mock import patch
  3  
  4  import pytest
  5  
  6  from mlflow.entities.gateway_budget_policy import (
  7      BudgetAction,
  8      BudgetDuration,
  9      BudgetDurationUnit,
 10      BudgetTargetScope,
 11      BudgetUnit,
 12      GatewayBudgetPolicy,
 13  )
 14  from mlflow.gateway.budget_tracker import BudgetTracker
 15  
 16  fakeredis = pytest.importorskip("fakeredis")
 17  
 18  
 19  def _make_policy(
 20      budget_policy_id="bp-test",
 21      budget_amount=100.0,
 22      duration=None,
 23      target_scope=BudgetTargetScope.GLOBAL,
 24      budget_action=BudgetAction.ALERT,
 25      workspace=None,
 26  ):
 27      return GatewayBudgetPolicy(
 28          budget_policy_id=budget_policy_id,
 29          budget_unit=BudgetUnit.USD,
 30          budget_amount=budget_amount,
 31          duration=duration or BudgetDuration(unit=BudgetDurationUnit.DAYS, value=1),
 32          target_scope=target_scope,
 33          budget_action=budget_action,
 34          created_at=0,
 35          last_updated_at=0,
 36          workspace=workspace,
 37      )
 38  
 39  
 40  def _make_tracker():
 41      from mlflow.gateway.budget_tracker.redis import RedisBudgetTracker
 42  
 43      client = fakeredis.FakeRedis(decode_responses=True)
 44      return RedisBudgetTracker(_client=client)
 45  
 46  
 47  def test_redis_tracker_is_budget_tracker():
 48      tracker = _make_tracker()
 49      assert isinstance(tracker, BudgetTracker)
 50  
 51  
 52  def test_record_cost_below_limit():
 53      tracker = _make_tracker()
 54      tracker.refresh_policies([_make_policy(budget_amount=100.0)])
 55  
 56      newly_exceeded = tracker.record_cost(50.0)
 57      assert newly_exceeded == []
 58  
 59      window = tracker._get_window_info("bp-test")
 60      assert window.cumulative_spend == 50.0
 61      assert window.exceeded is False
 62  
 63  
 64  def test_record_cost_exceeds_threshold():
 65      tracker = _make_tracker()
 66      tracker.refresh_policies([_make_policy(budget_amount=100.0)])
 67  
 68      newly_exceeded = tracker.record_cost(150.0)
 69      assert len(newly_exceeded) == 1
 70      assert newly_exceeded[0].policy.budget_policy_id == "bp-test"
 71  
 72      window = tracker._get_window_info("bp-test")
 73      assert window.cumulative_spend == 150.0
 74      assert window.exceeded is True
 75  
 76  
 77  def test_record_cost_exceeds_only_once():
 78      tracker = _make_tracker()
 79      tracker.refresh_policies([_make_policy(budget_amount=100.0)])
 80  
 81      exceeded1 = tracker.record_cost(150.0)
 82      assert len(exceeded1) == 1
 83  
 84      exceeded2 = tracker.record_cost(50.0)
 85      assert exceeded2 == []
 86  
 87      window = tracker._get_window_info("bp-test")
 88      assert window.cumulative_spend == 200.0
 89  
 90  
 91  def test_record_cost_incremental_exceeding():
 92      tracker = _make_tracker()
 93      tracker.refresh_policies([_make_policy(budget_amount=100.0)])
 94  
 95      assert tracker.record_cost(60.0) == []
 96      exceeded = tracker.record_cost(50.0)
 97      assert len(exceeded) == 1
 98      assert tracker._get_window_info("bp-test").cumulative_spend == 110.0
 99  
100  
def test_should_reject_request_reject():
    """An exceeded REJECT policy causes requests to be rejected."""
    tracker = _make_tracker()
    reject_policy = _make_policy(budget_amount=100.0, budget_action=BudgetAction.REJECT)
    tracker.refresh_policies([reject_policy])

    tracker.record_cost(150.0)

    should_reject, offending = tracker.should_reject_request()
    assert should_reject is True
    assert offending.policy.budget_policy_id == "bp-test"
108      assert window.policy.budget_policy_id == "bp-test"
109  
110  
def test_should_reject_request_alert_only():
    """An exceeded ALERT policy never causes rejection."""
    tracker = _make_tracker()
    alert_policy = _make_policy(budget_amount=100.0, budget_action=BudgetAction.ALERT)
    tracker.refresh_policies([alert_policy])

    tracker.record_cost(150.0)

    should_reject, offending = tracker.should_reject_request()
    assert should_reject is False
    assert offending is None
119  
120  
def test_should_reject_request_not_yet():
    """A REJECT policy that is still under its limit lets requests through."""
    tracker = _make_tracker()
    reject_policy = _make_policy(budget_amount=100.0, budget_action=BudgetAction.REJECT)
    tracker.refresh_policies([reject_policy])

    tracker.record_cost(50.0)

    should_reject, offending = tracker.should_reject_request()
    assert should_reject is False
    assert offending is None
129  
130  
def test_refresh_policies_removes_deleted_policy():
    """Refreshing without a policy drops its window but keeps the others intact.

    Spend is recorded before the refresh so the test also checks that the
    surviving policy's accumulated state is preserved. Without that check,
    the recorded cost would be unobservable.
    """
    tracker = _make_tracker()
    policy1 = _make_policy(budget_policy_id="bp-1", budget_amount=100.0)
    policy2 = _make_policy(budget_policy_id="bp-2", budget_amount=200.0)
    tracker.refresh_policies([policy1, policy2])
    tracker.record_cost(50.0)

    tracker.refresh_policies([policy1])

    surviving = tracker._get_window_info("bp-1")
    assert surviving is not None
    # Refreshing with an already-tracked policy must not reset its spend.
    assert surviving.cumulative_spend == 50.0
    assert tracker._get_window_info("bp-2") is None
141  
142  
def test_multiple_policies_independent():
    """ALERT and REJECT policies with different limits trip independently.

    The first cost crosses only the 50.0 ALERT limit, so requests are still
    allowed. The second brings cumulative spend to 105.0, crossing the 100.0
    REJECT limit. Only bp-reject is newly reported, because bp-alert was
    already exceeded.
    """
    tracker = _make_tracker()
    policy_alert = _make_policy(
        budget_policy_id="bp-alert",
        budget_amount=50.0,
        budget_action=BudgetAction.ALERT,
    )
    policy_reject = _make_policy(
        budget_policy_id="bp-reject",
        budget_amount=100.0,
        budget_action=BudgetAction.REJECT,
    )
    tracker.refresh_policies([policy_alert, policy_reject])

    newly_exceeded = tracker.record_cost(75.0)
    assert [w.policy.budget_policy_id for w in newly_exceeded] == ["bp-alert"]

    # An exceeded ALERT policy on its own must not cause rejection.
    should_reject, _ = tracker.should_reject_request()
    assert should_reject is False

    newly_exceeded = tracker.record_cost(30.0)
    assert [w.policy.budget_policy_id for w in newly_exceeded] == ["bp-reject"]

    should_reject, rejecting_window = tracker.should_reject_request()
    assert should_reject is True
    assert rejecting_window.policy.budget_policy_id == "bp-reject"
168  
169  
def test_workspace_scoped_cost_recording():
    """Costs count toward a workspace-scoped policy only when the workspace matches."""
    tracker = _make_tracker()
    ws1_policy = _make_policy(
        target_scope=BudgetTargetScope.WORKSPACE,
        workspace="ws1",
        budget_amount=100.0,
    )
    tracker.refresh_policies([ws1_policy])

    # Spend attributed to another workspace is ignored.
    tracker.record_cost(200.0, workspace="ws2")
    assert tracker._get_window_info("bp-test").cumulative_spend == 0.0

    # Spend in the policy's own workspace is tracked.
    tracker.record_cost(50.0, workspace="ws1")
    assert tracker._get_window_info("bp-test").cumulative_spend == 50.0
186  
187  
def test_backfill_spend_sets_cumulative():
    """Backfilled spend below the limit is stored without flagging the policy."""
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    tracker.backfill_spend({"bp-test": 42.5})

    info = tracker._get_window_info("bp-test")
    assert info.cumulative_spend == 42.5
    assert info.exceeded is False
196  
197  
def test_backfill_spend_sets_exceeded_when_exceeds():
    """Backfilled spend above the limit marks the window as exceeded."""
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    tracker.backfill_spend({"bp-test": 150.0})

    info = tracker._get_window_info("bp-test")
    assert info.cumulative_spend == 150.0
    assert info.exceeded is True
206  
207  
def test_backfill_spend_sets_exceeded_at_exact_limit():
    """Backfilled spend exactly equal to the limit already counts as exceeded."""
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    tracker.backfill_spend({"bp-test": 100.0})

    info = tracker._get_window_info("bp-test")
    assert info.cumulative_spend == 100.0
    assert info.exceeded is True
216  
217  
def test_backfill_spend_nonexistent_is_noop():
    """Backfilling spend for an unknown policy id leaves tracker state untouched.

    The original body had no assertions, so it only checked that the call did
    not raise. It now also checks that no window appears for the unknown id and
    that the tracked policy's window is unaffected.
    """
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy()])

    tracker.backfill_spend({"nonexistent-policy": 50.0})

    assert tracker._get_window_info("nonexistent-policy") is None
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 0.0
    assert not window.exceeded
222  
223  
def test_refresh_policies_returns_new_windows():
    """The first refresh returns one newly created window per policy."""
    tracker = _make_tracker()
    policies = [_make_policy(budget_policy_id=pid) for pid in ("bp-1", "bp-2")]

    created = tracker.refresh_policies(policies)

    assert len(created) == 2
    assert {w.policy.budget_policy_id for w in created} == {"bp-1", "bp-2"}
233  
234  
def test_refresh_policies_is_idempotent_for_existing_policies():
    """Re-registering an already-tracked policy creates no window and keeps its state."""
    tracker = _make_tracker()
    policy = _make_policy(budget_policy_id="bp-1", budget_amount=100.0)

    assert len(tracker.refresh_policies([policy])) == 1

    tracker.backfill_spend({"bp-1": 42.5})
    before = tracker._get_window_info("bp-1")
    assert before.cumulative_spend == 42.5

    # Refreshing with the same policy again reports no newly created windows...
    assert len(tracker.refresh_policies([policy])) == 0

    # ...and the existing window's state survives untouched.
    after = tracker._get_window_info("bp-1")
    assert after.cumulative_spend == before.cumulative_spend
    assert after.exceeded == before.exceeded
254  
255  
def test_get_all_windows():
    """Every registered policy exposes a window reflecting the shared GLOBAL spend."""
    tracker = _make_tracker()
    tracker.refresh_policies([
        _make_policy(budget_policy_id="bp-1", budget_amount=100.0),
        _make_policy(budget_policy_id="bp-2", budget_amount=200.0),
    ])

    tracker.record_cost(75.0)

    windows = tracker.get_all_windows()
    assert len(windows) == 2
    windows_by_id = {w.policy.budget_policy_id: w for w in windows}
    # 75.0 is below both limits, so neither window is exceeded.
    for policy_id in ("bp-1", "bp-2"):
        assert windows_by_id[policy_id].cumulative_spend == 75.0
        assert windows_by_id[policy_id].exceeded is False
271  
272  
def test_should_reject_request_workspace_filtering():
    """A workspace-scoped REJECT policy only rejects requests from its own workspace."""
    tracker = _make_tracker()
    ws1_reject_policy = _make_policy(
        target_scope=BudgetTargetScope.WORKSPACE,
        workspace="ws1",
        budget_amount=100.0,
        budget_action=BudgetAction.REJECT,
    )
    tracker.refresh_policies([ws1_reject_policy])

    tracker.record_cost(150.0, workspace="ws1")

    # ws2 traffic is unaffected by ws1 having blown its budget.
    rejected, culprit = tracker.should_reject_request(workspace="ws2")
    assert rejected is False
    assert culprit is None

    rejected, culprit = tracker.should_reject_request(workspace="ws1")
    assert rejected is True
    assert culprit.policy.budget_policy_id == "bp-test"
292  
293  
def test_record_cost_at_exact_budget_boundary():
    """Reaching the limit exactly (not just exceeding it) trips the policy."""
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    tripped = tracker.record_cost(100.0)
    assert len(tripped) == 1
    (tripped_window,) = tripped
    assert tripped_window.policy.budget_policy_id == "bp-test"

    info = tracker._get_window_info("bp-test")
    assert info.cumulative_spend == 100.0
    assert info.exceeded is True
305  
306  
def test_window_rollover_resets_spend():
    """Spend resets once the clock moves past the end of a 1-minute window."""
    tracker = _make_tracker()
    per_minute_policy = _make_policy(
        budget_amount=100.0,
        duration=BudgetDuration(unit=BudgetDurationUnit.MINUTES, value=1),
    )
    tracker.refresh_policies([per_minute_policy])

    tracker.record_cost(150.0)
    info = tracker._get_window_info("bp-test")
    assert info.cumulative_spend == 150.0
    assert info.exceeded is True

    # Pretend two minutes have elapsed by faking the `datetime` name inside the
    # redis tracker module. Only `now` is faked; `fromisoformat` is re-bound to
    # the real implementation (presumably because the tracker parses stored
    # timestamps with it — confirm against the module).
    two_minutes_later = datetime.now(timezone.utc) + timedelta(minutes=2)
    with patch("mlflow.gateway.budget_tracker.redis.datetime") as fake_datetime:
        fake_datetime.now.return_value = two_minutes_later
        fake_datetime.fromisoformat = datetime.fromisoformat
        tracker.record_cost(10.0)

    info = tracker._get_window_info("bp-test")
    assert info.cumulative_spend == 10.0
    assert info.exceeded is False
333  
334  
def test_get_budget_tracker_returns_redis_when_configured():
    """``get_budget_tracker`` builds a ``RedisBudgetTracker`` when a Redis URL is set.

    ``RedisBudgetTracker.__post_init__`` is patched out, presumably to avoid
    opening a real Redis connection. The module-level ``_budget_tracker``
    singleton is patched to ``None`` so the factory has to construct a tracker.
    ``patch(..., new=None)`` restores the original singleton value when the
    ``with`` block exits, including when an assertion fails. That makes a
    manual reset afterwards unnecessary; the previous manual reset was dead code.
    """
    from mlflow.gateway.budget_tracker.redis import RedisBudgetTracker

    with (
        patch(
            "mlflow.gateway.budget_tracker.MLFLOW_GATEWAY_BUDGET_REDIS_URL.get",
            return_value="redis://localhost:6379/0",
        ),
        patch(
            "mlflow.gateway.budget_tracker.redis.RedisBudgetTracker.__post_init__",
        ) as mock_init,
        patch(
            "mlflow.gateway.budget_tracker._budget_tracker",
            new=None,
        ),
    ):
        from mlflow.gateway.budget_tracker import get_budget_tracker

        tracker = get_budget_tracker()
        assert isinstance(tracker, RedisBudgetTracker)
        mock_init.assert_called_once()
359  
360          mlflow.gateway.budget_tracker._budget_tracker = None