# test_redis_budget_tracker.py
"""Tests for ``RedisBudgetTracker``, exercised against a ``fakeredis`` client."""

from datetime import datetime, timedelta, timezone
from unittest.mock import patch

import pytest

from mlflow.entities.gateway_budget_policy import (
    BudgetAction,
    BudgetDuration,
    BudgetDurationUnit,
    BudgetTargetScope,
    BudgetUnit,
    GatewayBudgetPolicy,
)
from mlflow.gateway.budget_tracker import BudgetTracker

# Skip the whole module when the optional ``fakeredis`` test dependency is missing.
fakeredis = pytest.importorskip("fakeredis")


def _make_policy(
    budget_policy_id="bp-test",
    budget_amount=100.0,
    duration=None,
    target_scope=BudgetTargetScope.GLOBAL,
    budget_action=BudgetAction.ALERT,
    workspace=None,
):
    """Build a USD ``GatewayBudgetPolicy`` for tests.

    Args:
        budget_policy_id: Policy id; tests look windows up by this value.
        budget_amount: Spend limit for one window.
        duration: Window length; a one-day window is used when omitted.
        target_scope: ``GLOBAL`` or ``WORKSPACE`` scoping.
        budget_action: ``ALERT`` or ``REJECT`` once the limit is reached.
        workspace: Workspace name for ``WORKSPACE``-scoped policies.

    Returns:
        A ``GatewayBudgetPolicy`` whose timestamps are fixed at 0 (no test here
        reads them).
    """
    return GatewayBudgetPolicy(
        budget_policy_id=budget_policy_id,
        budget_unit=BudgetUnit.USD,
        budget_amount=budget_amount,
        duration=duration or BudgetDuration(unit=BudgetDurationUnit.DAYS, value=1),
        target_scope=target_scope,
        budget_action=budget_action,
        created_at=0,
        last_updated_at=0,
        workspace=workspace,
    )


def _make_tracker():
    """Return a ``RedisBudgetTracker`` backed by a fresh ``fakeredis`` client."""
    # NOTE(review): imported lazily, presumably so that importing this test module
    # does not require the Redis tracker's dependencies before importorskip runs.
    from mlflow.gateway.budget_tracker.redis import RedisBudgetTracker

    client = fakeredis.FakeRedis(decode_responses=True)
    return RedisBudgetTracker(_client=client)


def test_redis_tracker_is_budget_tracker():
    tracker = _make_tracker()
    assert isinstance(tracker, BudgetTracker)


def test_record_cost_below_limit():
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    newly_exceeded = tracker.record_cost(50.0)
    assert newly_exceeded == []

    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 50.0
    assert window.exceeded is False


def test_record_cost_exceeds_threshold():
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    newly_exceeded = tracker.record_cost(150.0)
    assert len(newly_exceeded) == 1
    assert newly_exceeded[0].policy.budget_policy_id == "bp-test"

    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 150.0
    assert window.exceeded is True


def test_record_cost_exceeds_only_once():
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    exceeded1 = tracker.record_cost(150.0)
    assert len(exceeded1) == 1

    # The policy is already over budget, so further spend must not re-report it,
    # but the spend itself is still accumulated.
    exceeded2 = tracker.record_cost(50.0)
    assert exceeded2 == []

    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 200.0


def test_record_cost_incremental_exceeding():
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    assert tracker.record_cost(60.0) == []
    # 60 + 50 crosses the 100 limit only on the second call.
    exceeded = tracker.record_cost(50.0)
    assert len(exceeded) == 1
    assert tracker._get_window_info("bp-test").cumulative_spend == 110.0


def test_should_reject_request_reject():
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0, budget_action=BudgetAction.REJECT)])

    tracker.record_cost(150.0)
    exceeded, window = tracker.should_reject_request()
    assert exceeded is True
    assert window.policy.budget_policy_id == "bp-test"


def test_should_reject_request_alert_only():
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0, budget_action=BudgetAction.ALERT)])

    # An ALERT policy never causes rejection, even when over budget.
    tracker.record_cost(150.0)
    exceeded, window = tracker.should_reject_request()
    assert exceeded is False
    assert window is None


def test_should_reject_request_not_yet():
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0, budget_action=BudgetAction.REJECT)])

    tracker.record_cost(50.0)
    exceeded, window = tracker.should_reject_request()
    assert exceeded is False
    assert window is None


def test_refresh_policies_removes_deleted_policy():
    tracker = _make_tracker()
    policy1 = _make_policy(budget_policy_id="bp-1", budget_amount=100.0)
    policy2 = _make_policy(budget_policy_id="bp-2", budget_amount=200.0)
    tracker.refresh_policies([policy1, policy2])
    tracker.record_cost(50.0)

    # Refreshing without bp-2 drops its window; bp-1 keeps its accumulated spend.
    tracker.refresh_policies([policy1])
    remaining = tracker._get_window_info("bp-1")
    assert remaining is not None
    assert remaining.cumulative_spend == 50.0
    assert tracker._get_window_info("bp-2") is None


def test_multiple_policies_independent():
    tracker = _make_tracker()
    policy_alert = _make_policy(
        budget_policy_id="bp-alert",
        budget_amount=50.0,
        budget_action=BudgetAction.ALERT,
    )
    policy_reject = _make_policy(
        budget_policy_id="bp-reject",
        budget_amount=100.0,
        budget_action=BudgetAction.REJECT,
    )
    tracker.refresh_policies([policy_alert, policy_reject])

    # 75 crosses only the 50 ALERT limit.
    newly_exceeded = tracker.record_cost(75.0)
    assert len(newly_exceeded) == 1
    assert newly_exceeded[0].policy.budget_policy_id == "bp-alert"

    rejected, _ = tracker.should_reject_request()
    assert rejected is False

    # 75 + 30 now crosses the 100 REJECT limit as well.
    tracker.record_cost(30.0)
    rejected, window = tracker.should_reject_request()
    assert rejected is True
    assert window.policy.budget_policy_id == "bp-reject"


def test_workspace_scoped_cost_recording():
    tracker = _make_tracker()
    policy = _make_policy(
        target_scope=BudgetTargetScope.WORKSPACE,
        workspace="ws1",
        budget_amount=100.0,
    )
    tracker.refresh_policies([policy])

    # Spend in a different workspace must not count against the ws1 policy.
    tracker.record_cost(200.0, workspace="ws2")
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 0.0

    tracker.record_cost(50.0, workspace="ws1")
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 50.0


def test_backfill_spend_sets_cumulative():
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    tracker.backfill_spend({"bp-test": 42.5})
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 42.5
    assert window.exceeded is False


def test_backfill_spend_sets_exceeded_when_exceeds():
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    tracker.backfill_spend({"bp-test": 150.0})
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 150.0
    assert window.exceeded is True


def test_backfill_spend_sets_exceeded_at_exact_limit():
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    # Reaching the limit exactly counts as exceeded.
    tracker.backfill_spend({"bp-test": 100.0})
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 100.0
    assert window.exceeded is True


def test_backfill_spend_nonexistent_is_noop():
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy()])
    tracker.backfill_spend({"nonexistent-policy": 50.0})

    # The tracked policy is untouched and no window appears for the unknown id.
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 0.0
    assert window.exceeded is False
    assert tracker._get_window_info("nonexistent-policy") is None


def test_refresh_policies_returns_new_windows():
    tracker = _make_tracker()
    policy1 = _make_policy(budget_policy_id="bp-1")
    policy2 = _make_policy(budget_policy_id="bp-2")

    new_windows = tracker.refresh_policies([policy1, policy2])
    assert len(new_windows) == 2
    ids = {w.policy.budget_policy_id for w in new_windows}
    assert ids == {"bp-1", "bp-2"}


def test_refresh_policies_is_idempotent_for_existing_policies():
    tracker = _make_tracker()
    policy = _make_policy(budget_policy_id="bp-1", budget_amount=100.0)

    first_windows = tracker.refresh_policies([policy])
    assert len(first_windows) == 1

    tracker.backfill_spend({"bp-1": 42.5})
    window_before = tracker._get_window_info("bp-1")
    assert window_before.cumulative_spend == 42.5

    # Second call with the same policy should not create a new window
    second_windows = tracker.refresh_policies([policy])
    assert len(second_windows) == 0

    # Existing window state should be preserved
    window_after = tracker._get_window_info("bp-1")
    assert window_after.cumulative_spend == window_before.cumulative_spend
    assert window_after.exceeded == window_before.exceeded


def test_get_all_windows():
    tracker = _make_tracker()
    policy1 = _make_policy(budget_policy_id="bp-1", budget_amount=100.0)
    policy2 = _make_policy(budget_policy_id="bp-2", budget_amount=200.0)
    tracker.refresh_policies([policy1, policy2])

    # Both policies are GLOBAL, so one cost is recorded against each of them.
    tracker.record_cost(75.0)

    windows = tracker.get_all_windows()
    assert len(windows) == 2
    by_id = {w.policy.budget_policy_id: w for w in windows}
    assert by_id["bp-1"].cumulative_spend == 75.0
    assert by_id["bp-1"].exceeded is False
    assert by_id["bp-2"].cumulative_spend == 75.0
    assert by_id["bp-2"].exceeded is False


def test_should_reject_request_workspace_filtering():
    tracker = _make_tracker()
    policy = _make_policy(
        target_scope=BudgetTargetScope.WORKSPACE,
        workspace="ws1",
        budget_amount=100.0,
        budget_action=BudgetAction.REJECT,
    )
    tracker.refresh_policies([policy])

    tracker.record_cost(150.0, workspace="ws1")

    # The exceeded ws1 policy must not reject requests from another workspace.
    exceeded, window = tracker.should_reject_request(workspace="ws2")
    assert exceeded is False
    assert window is None

    exceeded, window = tracker.should_reject_request(workspace="ws1")
    assert exceeded is True
    assert window.policy.budget_policy_id == "bp-test"


def test_record_cost_at_exact_budget_boundary():
    tracker = _make_tracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    exceeded = tracker.record_cost(100.0)
    assert len(exceeded) == 1
    assert exceeded[0].policy.budget_policy_id == "bp-test"

    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 100.0
    assert window.exceeded is True


def test_window_rollover_resets_spend():
    tracker = _make_tracker()
    tracker.refresh_policies([
        _make_policy(
            budget_amount=100.0, duration=BudgetDuration(unit=BudgetDurationUnit.MINUTES, value=1)
        )
    ])

    tracker.record_cost(150.0)
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 150.0
    assert window.exceeded is True

    # Simulate time advancing past the one-minute window boundary. The module's
    # ``datetime`` is replaced by a mock, so ``fromisoformat`` is pointed back at
    # the real implementation; only ``now()`` is faked.
    future = datetime.now(timezone.utc) + timedelta(minutes=2)
    with patch(
        "mlflow.gateway.budget_tracker.redis.datetime",
    ) as mock_dt:
        mock_dt.now.return_value = future
        mock_dt.fromisoformat = datetime.fromisoformat

        tracker.record_cost(10.0)

    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 10.0
    assert window.exceeded is False


def test_get_budget_tracker_returns_redis_when_configured():
    from mlflow.gateway.budget_tracker.redis import RedisBudgetTracker

    # ``__post_init__`` is mocked so no real Redis connection is attempted.
    # Patching ``_budget_tracker`` with ``new=None`` forces a fresh singleton.
    # ``patch`` restores the previous singleton on exit, even if an assertion
    # fails, so no manual reset is needed afterwards.
    with (
        patch(
            "mlflow.gateway.budget_tracker.MLFLOW_GATEWAY_BUDGET_REDIS_URL.get",
            return_value="redis://localhost:6379/0",
        ),
        patch(
            "mlflow.gateway.budget_tracker.redis.RedisBudgetTracker.__post_init__",
        ) as mock_init,
        patch(
            "mlflow.gateway.budget_tracker._budget_tracker",
            new=None,
        ),
    ):
        from mlflow.gateway.budget_tracker import get_budget_tracker

        tracker = get_budget_tracker()
        assert isinstance(tracker, RedisBudgetTracker)
        mock_init.assert_called_once()