# test_budget_tracker.py
"""Tests for the gateway budget tracker.

Covers the window-boundary helpers (``_compute_window_start`` /
``_compute_window_end``), workspace scoping (``_policy_applies``), and the
``InMemoryBudgetTracker`` implementation of ``BudgetTracker``.
"""

from datetime import datetime, timedelta, timezone
from unittest.mock import patch

import pytest

from mlflow.entities.gateway_budget_policy import (
    BudgetAction,
    BudgetDuration,
    BudgetDurationUnit,
    BudgetTargetScope,
    BudgetUnit,
    GatewayBudgetPolicy,
)
from mlflow.gateway.budget_tracker import (
    BudgetTracker,
    _compute_window_end,
    _compute_window_start,
    _policy_applies,
)
from mlflow.gateway.budget_tracker.in_memory import InMemoryBudgetTracker


def _make_policy(
    budget_policy_id="bp-test",
    budget_amount=100.0,
    duration=None,
    target_scope=BudgetTargetScope.GLOBAL,
    budget_action=BudgetAction.ALERT,
    workspace=None,
):
    """Build a USD ``GatewayBudgetPolicy`` with test-friendly defaults.

    ``duration`` falls back to a 1-day window when omitted, and both timestamps
    are fixed at 0 so policies are deterministic across tests.
    """
    return GatewayBudgetPolicy(
        budget_policy_id=budget_policy_id,
        budget_unit=BudgetUnit.USD,
        budget_amount=budget_amount,
        duration=duration or BudgetDuration(unit=BudgetDurationUnit.DAYS, value=1),
        target_scope=target_scope,
        budget_action=budget_action,
        created_at=0,
        last_updated_at=0,
        workspace=workspace,
    )


# --- _compute_window_start tests ---


def test_compute_window_start_minutes():
    """A 15-minute window containing 10:37 starts at 10:30."""
    now = datetime(2025, 6, 15, 10, 37, 0, tzinfo=timezone.utc)
    start = _compute_window_start(BudgetDuration(unit=BudgetDurationUnit.MINUTES, value=15), now)
    assert start == datetime(2025, 6, 15, 10, 30, 0, tzinfo=timezone.utc)


def test_compute_window_start_hours():
    """A 2-hour window containing 10:30 starts at 10:00."""
    now = datetime(2025, 6, 15, 10, 30, 0, tzinfo=timezone.utc)
    start = _compute_window_start(BudgetDuration(unit=BudgetDurationUnit.HOURS, value=2), now)
    assert start == datetime(2025, 6, 15, 10, 0, 0, tzinfo=timezone.utc)


def test_compute_window_start_days():
    """Multi-day windows are aligned to the Unix epoch, not to calendar weeks."""
    # 2025-06-15 is day 20254 since 1970-01-01; 20254 // 7 * 7 = 20251, which is
    # 2025-06-12. The concrete value is pinned instead of being re-derived with the
    # same arithmetic the implementation uses, which would make the test tautological.
    now = datetime(2025, 6, 15, 10, 0, 0, tzinfo=timezone.utc)
    start = _compute_window_start(BudgetDuration(unit=BudgetDurationUnit.DAYS, value=7), now)
    assert start == datetime(2025, 6, 12, 0, 0, 0, tzinfo=timezone.utc)


def test_compute_window_start_weeks():
    """Week-based windows always start at midnight on a Sunday."""
    # June 15, 2025 is a Sunday — window should start on that Sunday
    now = datetime(2025, 6, 15, 10, 0, 0, tzinfo=timezone.utc)
    start = _compute_window_start(BudgetDuration(unit=BudgetDurationUnit.WEEKS, value=1), now)
    assert start == datetime(2025, 6, 15, 0, 0, 0, tzinfo=timezone.utc)
    assert start.weekday() == 6  # Sunday

    # Wednesday mid-week — window should start on preceding Sunday
    now = datetime(2025, 6, 18, 14, 30, 0, tzinfo=timezone.utc)
    start = _compute_window_start(BudgetDuration(unit=BudgetDurationUnit.WEEKS, value=1), now)
    assert start == datetime(2025, 6, 15, 0, 0, 0, tzinfo=timezone.utc)
    assert start.weekday() == 6  # Sunday

    # Multi-week (2-week) windows also start on Sundays
    now = datetime(2025, 6, 20, 0, 0, 0, tzinfo=timezone.utc)
    start = _compute_window_start(BudgetDuration(unit=BudgetDurationUnit.WEEKS, value=2), now)
    assert start.weekday() == 6  # Sunday


def test_compute_window_start_months():
    """Month windows are counted from January 1970 in blocks of ``value`` months."""
    now = datetime(2025, 8, 15, tzinfo=timezone.utc)
    start = _compute_window_start(BudgetDuration(unit=BudgetDurationUnit.MONTHS, value=3), now)
    # Total months from epoch: (2025-1970)*12 + (8-1) = 660 + 7 = 667
    # Window index: 667 // 3 = 222, window_start_months = 666
    # start_year = 1970 + 666//12 = 1970 + 55 = 2025, start_month = (666%12)+1 = 7
    assert start == datetime(2025, 7, 1, tzinfo=timezone.utc)


# --- _compute_window_end tests ---


def test_compute_window_end_minutes():
    """A 15-minute window starting at 10:30 ends at 10:45."""
    start = datetime(2025, 6, 15, 10, 30, 0, tzinfo=timezone.utc)
    end = _compute_window_end(BudgetDuration(unit=BudgetDurationUnit.MINUTES, value=15), start)
    assert end == datetime(2025, 6, 15, 10, 45, 0, tzinfo=timezone.utc)


def test_compute_window_end_hours():
    """A 2-hour window starting at 10:00 ends at 12:00."""
    start = datetime(2025, 6, 15, 10, 0, 0, tzinfo=timezone.utc)
    end = _compute_window_end(BudgetDuration(unit=BudgetDurationUnit.HOURS, value=2), start)
    assert end == datetime(2025, 6, 15, 12, 0, 0, tzinfo=timezone.utc)


def test_compute_window_end_days():
    """A 7-day window ends exactly seven days after it starts."""
    start = datetime(2025, 6, 15, 0, 0, 0, tzinfo=timezone.utc)
    end = _compute_window_end(BudgetDuration(unit=BudgetDurationUnit.DAYS, value=7), start)
    assert end == datetime(2025, 6, 22, 0, 0, 0, tzinfo=timezone.utc)


def test_compute_window_end_weeks():
    """A 2-week window ends fourteen days after it starts."""
    start = datetime(2025, 6, 12, 0, 0, 0, tzinfo=timezone.utc)
    end = _compute_window_end(BudgetDuration(unit=BudgetDurationUnit.WEEKS, value=2), start)
    assert end == datetime(2025, 6, 26, 0, 0, 0, tzinfo=timezone.utc)


def test_compute_window_end_months():
    """A 3-month window starting July 1 ends October 1."""
    start = datetime(2025, 7, 1, tzinfo=timezone.utc)
    end = _compute_window_end(BudgetDuration(unit=BudgetDurationUnit.MONTHS, value=3), start)
    assert end == datetime(2025, 10, 1, tzinfo=timezone.utc)


def test_compute_window_end_months_crosses_year():
    """Month arithmetic rolls over into the next calendar year."""
    start = datetime(2025, 11, 1, tzinfo=timezone.utc)
    end = _compute_window_end(BudgetDuration(unit=BudgetDurationUnit.MONTHS, value=3), start)
    assert end == datetime(2026, 2, 1, tzinfo=timezone.utc)


# --- _policy_applies tests ---


def test_policy_applies_global():
    """A GLOBAL policy applies regardless of the request's workspace."""
    policy = _make_policy(target_scope=BudgetTargetScope.GLOBAL)
    assert _policy_applies(policy, None) is True
    assert _policy_applies(policy, "ws1") is True


def test_policy_applies_workspace_match():
    """A WORKSPACE policy applies to requests from its own workspace."""
    policy = _make_policy(target_scope=BudgetTargetScope.WORKSPACE, workspace="ws1")
    assert _policy_applies(policy, "ws1") is True


def test_policy_applies_workspace_no_match():
    """A WORKSPACE policy does not apply to requests from another workspace."""
    policy = _make_policy(target_scope=BudgetTargetScope.WORKSPACE, workspace="ws1")
    assert _policy_applies(policy, "ws2") is False


def test_policy_applies_workspace_none():
    """A request without a workspace does not match a non-default WORKSPACE policy."""
    policy = _make_policy(target_scope=BudgetTargetScope.WORKSPACE, workspace="ws1")
    assert _policy_applies(policy, None) is False


def test_policy_applies_workspace_none_matches_default():
    """A request without a workspace matches a policy scoped to the default workspace."""
    policy = _make_policy(target_scope=BudgetTargetScope.WORKSPACE)
    # policy.workspace resolves to "default" via __post_init__
    assert _policy_applies(policy, None) is True


# --- InMemoryBudgetTracker tests ---


def test_in_memory_tracker_is_budget_tracker():
    """InMemoryBudgetTracker instances are BudgetTracker instances."""
    tracker = InMemoryBudgetTracker()
    assert isinstance(tracker, BudgetTracker)


def test_record_cost_below_limit():
    """Spend under the budget accumulates without marking the window exceeded."""
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    newly_exceeded = tracker.record_cost(50.0)
    assert newly_exceeded == []

    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 50.0
    assert window.exceeded is False


def test_record_cost_exceeds_threshold():
    """Spend over the budget is reported as newly exceeded and flags the window."""
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    newly_exceeded = tracker.record_cost(150.0)
    assert len(newly_exceeded) == 1
    assert newly_exceeded[0].policy.budget_policy_id == "bp-test"

    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 150.0
    assert window.exceeded is True


def test_record_cost_exceeds_only_once():
    """A window is reported as newly exceeded only on the first crossing."""
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    exceeded1 = tracker.record_cost(150.0)
    assert len(exceeded1) == 1

    exceeded2 = tracker.record_cost(50.0)
    assert exceeded2 == []

    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 200.0


def test_record_cost_incremental_exceeding():
    """Several small costs that together cross the budget trigger exceedance."""
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    assert tracker.record_cost(60.0) == []
    exceeded = tracker.record_cost(50.0)
    assert len(exceeded) == 1
    assert tracker._get_window_info("bp-test").cumulative_spend == 110.0


def test_should_reject_request_reject():
    """An exceeded REJECT policy causes requests to be rejected."""
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0, budget_action=BudgetAction.REJECT)])

    tracker.record_cost(150.0)
    exceeded, window = tracker.should_reject_request()
    assert exceeded is True
    assert window.policy.budget_policy_id == "bp-test"


def test_should_reject_request_alert_only():
    """An exceeded ALERT policy never causes rejection."""
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0, budget_action=BudgetAction.ALERT)])

    tracker.record_cost(150.0)
    exceeded, window = tracker.should_reject_request()
    assert exceeded is False
    assert window is None


def test_should_reject_request_not_yet():
    """A REJECT policy under budget does not cause rejection."""
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0, budget_action=BudgetAction.REJECT)])

    tracker.record_cost(50.0)
    exceeded, window = tracker.should_reject_request()
    assert exceeded is False
    assert window is None


def test_window_resets_on_expiry():
    """Recording a cost after the window ends starts a fresh window."""
    tracker = InMemoryBudgetTracker()
    policy = _make_policy(
        budget_amount=100.0,
        duration=BudgetDuration(unit=BudgetDurationUnit.MINUTES, value=5),
    )
    tracker.refresh_policies([policy])
    tracker.record_cost(80.0)

    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 80.0

    # Simulate time passing beyond window end
    with patch(
        "mlflow.gateway.budget_tracker.in_memory.datetime",
    ) as mock_dt:
        mock_dt.now.return_value = window.window_end + timedelta(seconds=1)
        # Freeze only now(); direct datetime(...) construction in the module stays real.
        mock_dt.side_effect = lambda *args, **kw: datetime(*args, **kw)
        newly_exceeded = tracker.record_cost(10.0)

    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 10.0
    assert window.exceeded is False
    assert newly_exceeded == []


def test_refresh_policies_preserves_spend_in_same_window():
    """Reloading an unchanged policy within its window keeps accumulated spend."""
    tracker = InMemoryBudgetTracker()
    policy = _make_policy(budget_amount=100.0)
    tracker.refresh_policies([policy])
    tracker.record_cost(60.0)

    # Reload same policy — spend should be preserved
    tracker.refresh_policies([policy])
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 60.0


def test_refresh_policies_removes_deleted_policy():
    """Policies absent from a refresh have their windows dropped."""
    tracker = InMemoryBudgetTracker()
    policy1 = _make_policy(budget_policy_id="bp-1", budget_amount=100.0)
    policy2 = _make_policy(budget_policy_id="bp-2", budget_amount=200.0)
    tracker.refresh_policies([policy1, policy2])
    tracker.record_cost(50.0)

    # Reload with only policy1 — policy2 window should be gone
    tracker.refresh_policies([policy1])
    assert tracker._get_window_info("bp-1") is not None
    assert tracker._get_window_info("bp-2") is None


def test_multiple_policies_independent():
    """Each policy tracks its own threshold and action independently."""
    tracker = InMemoryBudgetTracker()
    policy_alert = _make_policy(
        budget_policy_id="bp-alert",
        budget_amount=50.0,
        budget_action=BudgetAction.ALERT,
    )
    policy_reject = _make_policy(
        budget_policy_id="bp-reject",
        budget_amount=100.0,
        budget_action=BudgetAction.REJECT,
    )
    tracker.refresh_policies([policy_alert, policy_reject])

    exceeded = tracker.record_cost(75.0)
    # Only the alert policy should be exceeded (50 < 75)
    assert len(exceeded) == 1
    assert exceeded[0].policy.budget_policy_id == "bp-alert"

    # Reject policy should be at 75, not exceeded yet
    assert tracker._get_window_info("bp-reject").cumulative_spend == 75.0
    exceeded, _ = tracker.should_reject_request()
    assert exceeded is False

    # Push reject over threshold
    tracker.record_cost(30.0)
    exceeded, window = tracker.should_reject_request()
    assert exceeded is True
    assert window.policy.budget_policy_id == "bp-reject"


def test_workspace_scoped_cost_recording():
    """Costs are only attributed to WORKSPACE policies of the matching workspace."""
    tracker = InMemoryBudgetTracker()
    policy = _make_policy(
        target_scope=BudgetTargetScope.WORKSPACE,
        workspace="ws1",
        budget_amount=100.0,
    )
    tracker.refresh_policies([policy])

    # Cost from different workspace — should not apply
    tracker.record_cost(200.0, workspace="ws2")
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 0.0

    # Cost from matching workspace — should apply. Re-fetch the window rather than
    # relying on the earlier object being mutated in place.
    tracker.record_cost(50.0, workspace="ws1")
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 50.0


@pytest.mark.parametrize("duration_unit", list(BudgetDurationUnit))
def test_all_duration_units_window_consistency(duration_unit):
    """For every unit, the computed window is non-empty and contains ``now``."""
    now = datetime(2025, 6, 15, 10, 30, 0, tzinfo=timezone.utc)
    duration = BudgetDuration(unit=duration_unit, value=1)
    start = _compute_window_start(duration, now)
    end = _compute_window_end(duration, start)
    assert start < end
    assert start <= now < end


# --- refresh_policies return value tests ---


def test_refresh_policies_returns_all_windows():
    """refresh_policies returns one window per loaded policy."""
    tracker = InMemoryBudgetTracker()
    policy1 = _make_policy(budget_policy_id="bp-1")
    policy2 = _make_policy(budget_policy_id="bp-2")

    windows = tracker.refresh_policies([policy1, policy2])
    assert len(windows) == 2
    ids = {w.policy.budget_policy_id for w in windows}
    assert ids == {"bp-1", "bp-2"}


def test_refresh_policies_returns_all_windows_on_reload():
    """Reloading an existing policy still returns its window."""
    tracker = InMemoryBudgetTracker()
    policy = _make_policy()

    windows = tracker.refresh_policies([policy])
    assert len(windows) == 1

    # Reload same policy within same window — still returns the existing window
    windows = tracker.refresh_policies([policy])
    assert len(windows) == 1


def test_refresh_policies_returns_all_windows_on_mixed():
    """A refresh mixing existing and new policies returns windows for all of them."""
    tracker = InMemoryBudgetTracker()
    policy1 = _make_policy(budget_policy_id="bp-1")
    tracker.refresh_policies([policy1])

    policy2 = _make_policy(budget_policy_id="bp-2")
    windows = tracker.refresh_policies([policy1, policy2])
    assert len(windows) == 2
    ids = {w.policy.budget_policy_id for w in windows}
    assert ids == {"bp-1", "bp-2"}


# --- backfill_spend tests ---


def test_backfill_spend_sets_cumulative():
    """Backfilled spend under the budget is stored without marking exceedance."""
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    tracker.backfill_spend({"bp-test": 42.5})
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 42.5
    assert window.exceeded is False


def test_backfill_spend_sets_exceeded_when_exceeds():
    """Backfilled spend over the budget marks the window exceeded."""
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    tracker.backfill_spend({"bp-test": 150.0})
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 150.0
    assert window.exceeded is True


def test_backfill_spend_sets_exceeded_at_exact_limit():
    """Backfilled spend equal to the budget counts as exceeded (inclusive limit)."""
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])

    tracker.backfill_spend({"bp-test": 100.0})
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 100.0
    assert window.exceeded is True


def test_backfill_spend_nonexistent_is_noop():
    """Backfilling an unknown policy id neither raises nor touches other windows."""
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_policy()])
    # Should not raise
    tracker.backfill_spend({"nonexistent-policy": 50.0})

    # The known window is untouched and no window is created for the unknown id.
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 0.0
    assert window.exceeded is False
    assert tracker._get_window_info("nonexistent-policy") is None


def test_backfill_spend_uses_max_to_protect_in_process_spend():
    """Backfill never lowers spend already recorded in-process."""
    # Simulate trace-flush lag: in-process spend is ahead of what DB reports.
    # backfill_spend must not decrease cumulative_spend below the in-process value.
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_policy(budget_amount=100.0)])
    tracker.record_cost(30.0)  # in-process spend = 30

    tracker.backfill_spend({"bp-test": 10.0})  # DB is behind
    window = tracker._get_window_info("bp-test")
    assert window.cumulative_spend == 30.0  # in-process value preserved


# --- get_all_windows tests ---


def test_get_all_windows_empty():
    """A tracker with no policies has no windows."""
    tracker = InMemoryBudgetTracker()
    assert tracker.get_all_windows() == []


def test_get_all_windows_returns_all_policies():
    """get_all_windows returns one window per loaded policy."""
    tracker = InMemoryBudgetTracker()
    policy1 = _make_policy(budget_policy_id="bp-1")
    policy2 = _make_policy(budget_policy_id="bp-2")
    tracker.refresh_policies([policy1, policy2])

    windows = tracker.get_all_windows()
    assert len(windows) == 2
    ids = {w.policy.budget_policy_id for w in windows}
    assert ids == {"bp-1", "bp-2"}


def test_get_all_windows_reflects_current_spend():
    """Windows returned by get_all_windows include recorded spend."""
    tracker = InMemoryBudgetTracker()
    policy = _make_policy(budget_policy_id="bp-spend", budget_amount=100.0)
    tracker.refresh_policies([policy])
    tracker.record_cost(42.5)

    windows = tracker.get_all_windows()
    assert len(windows) == 1
    assert windows[0].cumulative_spend == 42.5


def test_get_all_windows_after_policy_removed():
    """Windows for policies dropped by a refresh are no longer returned."""
    tracker = InMemoryBudgetTracker()
    policy1 = _make_policy(budget_policy_id="bp-1")
    policy2 = _make_policy(budget_policy_id="bp-2")
    tracker.refresh_policies([policy1, policy2])

    tracker.refresh_policies([policy1])

    windows = tracker.get_all_windows()
    assert len(windows) == 1
    assert windows[0].policy.budget_policy_id == "bp-1"