test_autologging_safety_unit.py
import abc
import copy
import inspect
from contextlib import nullcontext as does_not_raise
from typing import Any, NamedTuple
from unittest import mock

import pytest

import mlflow
from mlflow import MlflowClient
from mlflow.entities import RunStatus
from mlflow.utils import autologging_utils
from mlflow.utils.autologging_utils import (
    AutologgingEventLogger,
    ExceptionSafeAbstractClass,
    ExceptionSafeClass,
    autologging_integration,
    is_testing,
    picklable_exception_safe_function,
    safe_patch,
    with_managed_run,
)
from mlflow.utils.autologging_utils.safety import (
    ValidationExemptArgument,
    _AutologgingSessionManager,
    _validate_args,
    _validate_autologging_run,
)
from mlflow.utils.mlflow_tags import MLFLOW_AUTOLOGGING

from tests.autologging.async_helper import asyncify, run_sync_or_async
from tests.autologging.fixtures import (
    patch_destination,  # noqa: F401
    test_mode_off,
    test_mode_on,
)
from tests.autologging.test_autologging_utils import get_func_attrs

PATCH_DESTINATION_FN_DEFAULT_RESULT = "original_result"


@pytest.fixture(autouse=True)
def turn_test_mode_off_by_default(test_mode_off):
    """
    Most of the unit test cases in this module assume that autologging APIs are operating in a
    standard execution mode (i.e. where test mode is disabled). Accordingly, we turn off autologging
    test mode for this test module by default. Test cases that verify behaviors specific to test
    mode enable test mode explicitly by specifying the `test_mode_on` fixture.

    For more information about autologging test mode, see the docstring for
    :py:func:`mlflow.utils.autologging_utils.is_testing()`.
    """


@pytest.fixture
def test_autologging_integration():
    """
    Register (and enable) a no-op autologging integration and return its name, so tests can
    pass that name to `safe_patch`.
    """
    integration_name = "test_integration"

    @autologging_integration(integration_name)
    def autolog(disable=False, silent=False):
        pass

    autolog()

    return integration_name


class MockEventLogger(AutologgingEventLogger):
    """
    Autologging event logger that records every hook invocation, in order, as a `LoggerCall`
    in `self.calls` so tests can assert on the sequence and payload of events.
    """

    class LoggerCall(NamedTuple):
        # Name of the hook that fired, e.g. "patch_start" or "original_error".
        method: str
        session: Any
        patch_obj: Any
        function_name: str
        call_args: Any
        call_kwargs: Any
        # Exception passed to an `*_error` hook; None for start/success hooks.
        exception: Any

    def __init__(self):
        self.calls = []

    def reset(self):
        """Discard all recorded calls."""
        self.calls = []

    def _record_call(
        self, method, session, patch_obj, function_name, call_args, call_kwargs, exception=None
    ):
        """Append one `LoggerCall` describing a hook invocation to `self.calls`."""
        self.calls.append(
            MockEventLogger.LoggerCall(
                method, session, patch_obj, function_name, call_args, call_kwargs, exception
            )
        )

    def log_patch_function_start(self, session, patch_obj, function_name, call_args, call_kwargs):
        self._record_call(
            "patch_start", session, patch_obj, function_name, call_args, call_kwargs
        )

    def log_patch_function_success(self, session, patch_obj, function_name, call_args, call_kwargs):
        self._record_call(
            "patch_success", session, patch_obj, function_name, call_args, call_kwargs
        )

    def log_patch_function_error(
        self, session, patch_obj, function_name, call_args, call_kwargs, exception
    ):
        self._record_call(
            "patch_error", session, patch_obj, function_name, call_args, call_kwargs, exception
        )

    def log_original_function_start(
        self, session, patch_obj, function_name, call_args, call_kwargs
    ):
        self._record_call(
            "original_start", session, patch_obj, function_name, call_args, call_kwargs
        )

    def log_original_function_success(
        self, session, patch_obj, function_name, call_args, call_kwargs
    ):
        self._record_call(
            "original_success", session, patch_obj, function_name, call_args, call_kwargs
        )

    def log_original_function_error(
        self, session, patch_obj, function_name, call_args, call_kwargs, exception
    ):
        self._record_call(
            "original_error", session, patch_obj, function_name, call_args, call_kwargs, exception
        )


@pytest.fixture
def mock_event_logger():
    """
    Install a fresh `MockEventLogger` as the global autologging event logger for the duration of
    a test, restoring the previously installed logger afterwards.
    """
    prev_logger = AutologgingEventLogger.get_logger()
    try:
        logger = MockEventLogger()
        AutologgingEventLogger.set_logger(logger)
        yield logger
    finally:
        AutologgingEventLogger.set_logger(prev_logger)


def test_is_testing_respects_environment_variable(monkeypatch):
    monkeypatch.delenv("MLFLOW_AUTOLOGGING_TESTING", raising=False)
    assert not is_testing()

    monkeypatch.setenv("MLFLOW_AUTOLOGGING_TESTING", "false")
    assert not is_testing()

    monkeypatch.setenv("MLFLOW_AUTOLOGGING_TESTING", "true")
    assert is_testing()


def test_safe_patch_forwards_expected_arguments_to_function_based_patch_implementation(
    patch_destination, test_autologging_integration
):
    foo_val = None
    bar_val = None

    @asyncify(patch_destination.is_async)
    def patch_impl(original, foo, bar=10):
        nonlocal foo_val
        nonlocal bar_val
        foo_val = foo
        bar_val = bar

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    run_sync_or_async(patch_destination.fn, 7, bar=11)
    assert foo_val == 7
    assert bar_val == 11


def test_safe_patch_provides_expected_original_function(
    test_autologging_integration, patch_destination
):
    @asyncify(patch_destination.is_async)
    def original_fn(foo, bar=10):
        return {
            "foo": foo,
            "bar": bar,
        }

    patch_destination.fn = original_fn

    @asyncify(patch_destination.is_async)
    def patch_impl(original, foo, bar):
        return original(foo + 1, bar + 2)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    assert run_sync_or_async(patch_destination.fn, 1, 2) == {"foo": 2, "bar": 4}


def test_safe_patch_propagates_exceptions_raised_from_original_function(
    patch_destination, test_autologging_integration
):
    exc_to_throw = Exception("Bad original function")

    @asyncify(patch_destination.is_async)
    def original(*args, **kwargs):
        raise exc_to_throw

    patch_destination.fn = original

    patch_impl_called = False

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        nonlocal patch_impl_called
        patch_impl_called = True
        return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
        run_sync_or_async(patch_destination.fn)

    assert exc.value == exc_to_throw
    assert patch_impl_called


def test_safe_patch_logs_exceptions_raised_outside_of_original_function_as_warnings(
    patch_destination, test_autologging_integration
):
    exc_to_throw = Exception("Bad patch implementation")

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        raise exc_to_throw

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    with mock.patch("mlflow.utils.autologging_utils._logger.warning") as logger_mock:
        assert run_sync_or_async(patch_destination.fn) == PATCH_DESTINATION_FN_DEFAULT_RESULT
        assert logger_mock.call_count == 1
        expected_warning = (
            f"Encountered unexpected error during {test_autologging_integration} autologging: "
            f"{exc_to_throw}"
        )
        assert logger_mock.call_args[0][0] == expected_warning


@pytest.mark.usefixtures(test_mode_on.__name__)
def test_safe_patch_propagates_exceptions_raised_outside_of_original_function_in_test_mode(
    patch_destination, test_autologging_integration
):
    exc_to_throw = Exception("Bad patch implementation")

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        raise exc_to_throw

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
        run_sync_or_async(patch_destination.fn)

    assert exc.value == exc_to_throw


def test_safe_patch_calls_original_function_when_patch_preamble_throws(
    patch_destination, test_autologging_integration
):
    patch_impl_called = False

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        nonlocal patch_impl_called
        patch_impl_called = True
        raise Exception("Bad patch preamble")

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    assert run_sync_or_async(patch_destination.fn) == PATCH_DESTINATION_FN_DEFAULT_RESULT
    assert patch_destination.fn_call_count == 1
    assert patch_impl_called


def test_safe_patch_returns_original_result_without_second_call_when_patch_postamble_throws(
    patch_destination, test_autologging_integration
):
    patch_impl_called = False

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        nonlocal patch_impl_called
        patch_impl_called = True
        original(*args, **kwargs)
        raise Exception("Bad patch postamble")

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    assert run_sync_or_async(patch_destination.fn) == PATCH_DESTINATION_FN_DEFAULT_RESULT
    assert patch_destination.fn_call_count == 1
    assert patch_impl_called


def test_safe_patch_respects_disable_flag(patch_destination):
    patch_impl_call_count = 0

    @autologging_integration("test_respects_disable")
    def autolog(disable=False, silent=False):
        @asyncify(patch_destination.is_async)
        def patch_impl(original, *args, **kwargs):
            nonlocal patch_impl_call_count
            patch_impl_call_count += 1
            return original(*args, **kwargs)

        safe_patch("test_respects_disable", patch_destination, "fn", patch_impl)

    autolog(disable=False)
    run_sync_or_async(patch_destination.fn)
    assert patch_impl_call_count == 1

    autolog(disable=True)
    run_sync_or_async(patch_destination.fn)
    assert patch_impl_call_count == 1


def test_safe_patch_returns_original_result_and_ignores_patch_return_value(
    patch_destination, test_autologging_integration
):
    patch_impl_called = False

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        nonlocal patch_impl_called
        patch_impl_called = True
        return 10

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    assert run_sync_or_async(patch_destination.fn) == PATCH_DESTINATION_FN_DEFAULT_RESULT
    assert patch_destination.fn_call_count == 1
    assert patch_impl_called


@pytest.mark.usefixtures(test_mode_on.__name__)
def test_safe_patch_validates_arguments_to_original_function_in_test_mode(
    patch_destination, test_autologging_integration
):
    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        return original("1", "2", "3")

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    with (
        pytest.raises(Exception, match="does not match expected input"),
        mock.patch(
            "mlflow.utils.autologging_utils.safety._validate_args",
            wraps=autologging_utils.safety._validate_args,
        ) as validate_mock,
    ):
        run_sync_or_async(patch_destination.fn, "a", "b", "c")

    assert validate_mock.call_count == 1


@pytest.mark.usefixtures(test_mode_on.__name__)
def test_safe_patch_throws_when_autologging_runs_are_leaked_in_test_mode(
    patch_destination, test_autologging_integration
):
    assert autologging_utils.is_testing()

    @asyncify(patch_destination.is_async)
    def leak_run_patch_impl(original, *args, **kwargs):
        mlflow.start_run(nested=True)

    safe_patch(test_autologging_integration, patch_destination, "fn", leak_run_patch_impl)
    with pytest.raises(AssertionError, match="leaked an active run"):
        run_sync_or_async(patch_destination.fn)

    # End the leaked run
    mlflow.end_run()

    with mlflow.start_run():
        # If a user-generated run existed prior to the autologged training session, we expect
        # that safe patch will not throw a leaked run exception. `run_sync_or_async` is required
        # here: calling an async `fn` directly would only create an un-awaited coroutine, so the
        # patch would never run and `end_run` below would end the user's run instead.
        run_sync_or_async(patch_destination.fn)
        # End the leaked nested run
        mlflow.end_run()

    assert not mlflow.active_run()


def test_safe_patch_does_not_throw_when_autologging_runs_are_leaked_in_standard_mode(
    patch_destination, test_autologging_integration
):
    assert not autologging_utils.is_testing()

    @asyncify(patch_destination.is_async)
    def leak_run_patch_impl(original, *args, **kwargs):
        mlflow.start_run(nested=True)

    safe_patch(test_autologging_integration, patch_destination, "fn", leak_run_patch_impl)
    run_sync_or_async(patch_destination.fn)
    assert mlflow.active_run()

    # End the leaked run
    mlflow.end_run()

    assert not mlflow.active_run()


@pytest.mark.usefixtures(test_mode_on.__name__)
def test_safe_patch_validates_autologging_runs_when_necessary_in_test_mode(
    patch_destination, test_autologging_integration
):
    assert autologging_utils.is_testing()

    def no_tag_run_patch(original, *args, **kwargs):
        with mlflow.start_run(nested=True):
            return original(*args, **kwargs)

    async def async_no_tag_run_patch(original, *args, **kwargs):
        with mlflow.start_run(nested=True):
            return await original(*args, **kwargs)

    if patch_destination.is_async:
        safe_patch(test_autologging_integration, patch_destination, "fn", async_no_tag_run_patch)
    else:
        safe_patch(test_autologging_integration, patch_destination, "fn", no_tag_run_patch)

    with mock.patch(
        "mlflow.utils.autologging_utils.safety._validate_autologging_run",
        wraps=_validate_autologging_run,
    ) as validate_run_mock:
        with pytest.raises(
            AssertionError, match="failed to set autologging tag with expected value"
        ):
            run_sync_or_async(patch_destination.fn)
        assert validate_run_mock.call_count == 1

        validate_run_mock.reset_mock()

        with mlflow.start_run(nested=True):
            # If a user-generated run existed prior to the autologged training session, we expect
            # that safe patch will not attempt to validate it
            run_sync_or_async(patch_destination.fn)
        assert not validate_run_mock.called


def test_safe_patch_does_not_validate_autologging_runs_in_standard_mode(
    patch_destination, test_autologging_integration
):
    assert not autologging_utils.is_testing()

    @asyncify(patch_destination.is_async)
    def no_tag_run_patch_impl(original, *args, **kwargs):
        with mlflow.start_run(nested=True):
            return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", no_tag_run_patch_impl)

    with mock.patch(
        "mlflow.utils.autologging_utils.safety._validate_autologging_run",
        wraps=_validate_autologging_run,
    ) as validate_run_mock:
        run_sync_or_async(patch_destination.fn)

        with mlflow.start_run(nested=True):
            # If a user-generated run existed prior to the autologged training session, we expect
            # that safe patch will not attempt to validate it
            run_sync_or_async(patch_destination.fn)

        assert not validate_run_mock.called


def test_safe_patch_manages_run_if_specified_and_sets_expected_run_tags(
    patch_destination, test_autologging_integration
):
    client = MlflowClient()
    active_run = None

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        nonlocal active_run
        active_run = mlflow.active_run()
        return original(*args, **kwargs)

    if patch_destination.is_async:
        with pytest.raises(Exception, match="manage_run parameter is not supported"):
            safe_patch(
                test_autologging_integration, patch_destination, "fn", patch_impl, manage_run=True
            )
        return

    with mock.patch(
        "mlflow.utils.autologging_utils.safety.with_managed_run", wraps=with_managed_run
    ) as managed_run_mock:
        safe_patch(
            test_autologging_integration, patch_destination, "fn", patch_impl, manage_run=True
        )

        run_sync_or_async(patch_destination.fn)
        assert managed_run_mock.call_count == 1
        assert active_run is not None
        assert active_run.info.run_id is not None
        run_tags = client.get_run(active_run.info.run_id).data.tags
        assert run_tags[MLFLOW_AUTOLOGGING] == test_autologging_integration


def test_safe_patch_does_not_manage_run_if_unspecified(
    patch_destination, test_autologging_integration
):
    active_run = None

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        nonlocal active_run
        active_run = mlflow.active_run()
        return original(*args, **kwargs)

    # Patch `with_managed_run` where `safe_patch` looks it up (the `safety` module); patching it
    # on `autologging_utils` would never intercept the call, making the assertion vacuous.
    with mock.patch(
        "mlflow.utils.autologging_utils.safety.with_managed_run", wraps=with_managed_run
    ) as managed_run_mock:
        safe_patch(
            test_autologging_integration, patch_destination, "fn", patch_impl, manage_run=False
        )
        run_sync_or_async(patch_destination.fn)
        assert managed_run_mock.call_count == 0
        assert active_run is None


def test_safe_patch_preserves_signature_of_patched_function(
    patch_destination, test_autologging_integration
):
    @asyncify(patch_destination.is_async)
    def original(a, b, c=10, *, d=11):
        return 10

    patch_destination.fn = original

    patch_impl_called = False

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        nonlocal patch_impl_called
        patch_impl_called = True
        return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    run_sync_or_async(patch_destination.fn, 1, 2)
    assert patch_impl_called
    assert inspect.signature(patch_destination.fn) == inspect.signature(original)


def test_safe_patch_provides_original_function_with_expected_signature(
    patch_destination, test_autologging_integration
):
    @asyncify(patch_destination.is_async)
    def original(a, b, c=10, *, d=11):
        return 10

    patch_destination.fn = original

    original_signature = None

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        nonlocal original_signature
        original_signature = inspect.signature(original)
        return original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
    run_sync_or_async(patch_destination.fn, 1, 2)
    assert original_signature == inspect.signature(original)


def test_safe_patch_makes_expected_event_logging_calls_for_successful_patch_invocation(
    patch_destination,
    test_autologging_integration,
    mock_event_logger,
):
    patch_session = None
    og_call_kwargs = {}

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        nonlocal og_call_kwargs
        nonlocal patch_session

        kwargs.update({"extra_func": picklable_exception_safe_function(lambda k: "foo")})
        og_call_kwargs = kwargs

        patch_session = _AutologgingSessionManager.active_session()

        original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    run_sync_or_async(patch_destination.fn, "a", 1, b=2)
    expected_order = ["patch_start", "original_start", "original_success", "patch_success"]
    assert [call.method for call in mock_event_logger.calls] == expected_order
    assert all(call.session == patch_session for call in mock_event_logger.calls)
    assert all(call.patch_obj == patch_destination for call in mock_event_logger.calls)
    assert all(call.function_name == "fn" for call in mock_event_logger.calls)
    patch_start, original_start, original_success, patch_success = mock_event_logger.calls
    assert patch_start.call_args == patch_success.call_args == ("a", 1)
    assert patch_start.call_kwargs == patch_success.call_kwargs == {"b": 2}
    assert original_start.call_args == original_success.call_args == ("a", 1)
    assert original_start.call_kwargs == original_success.call_kwargs == og_call_kwargs
    assert patch_start.exception is original_start.exception is None
    assert patch_success.exception is original_success.exception is None


def test_safe_patch_makes_expected_event_logging_calls_when_patch_impl_throws_and_original_succeeds(
    patch_destination,
    test_autologging_integration,
    mock_event_logger,
):
    exc_to_raise = Exception("thrown from patch")

    # Read (not reassigned) by `patch_impl`; the loop below sets it before each invocation.
    throw_location = None

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        if throw_location == "before":
            raise exc_to_raise

        original(*args, **kwargs)

        if throw_location != "before":
            raise exc_to_raise

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    expected_order = [
        "patch_start",
        "original_start",
        "original_success",
        "patch_error",
    ]

    for throw_location in ["before", "after"]:
        mock_event_logger.reset()
        run_sync_or_async(patch_destination.fn)
        assert [call.method for call in mock_event_logger.calls] == expected_order
        patch_start, original_start, original_success, patch_error = mock_event_logger.calls
        assert patch_start.exception is None
        assert original_start.exception is None
        assert original_success.exception is None
        assert patch_error.exception == exc_to_raise


def test_safe_patch_makes_expected_event_logging_calls_when_patch_impl_throws_and_original_throws(
    patch_destination,
    test_autologging_integration,
    mock_event_logger,
):
    exc_to_raise = Exception("thrown from patch")
    original_err_to_raise = Exception("throw from original")

    # Read (not reassigned) by `patch_impl`; the loop below sets it before each invocation.
    throw_location = None

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        if throw_location == "before":
            raise exc_to_raise

        original(*args, **kwargs)

        if throw_location != "before":
            raise exc_to_raise

    safe_patch(test_autologging_integration, patch_destination, "throw_error_fn", patch_impl)

    expected_order = ["patch_start", "original_start", "original_error"]

    for throw_location in ["before", "after"]:
        mock_event_logger.reset()
        with pytest.raises(Exception, match="throw from original"):
            run_sync_or_async(patch_destination.throw_error_fn, original_err_to_raise)
        assert [call.method for call in mock_event_logger.calls] == expected_order
        patch_start, original_start, original_error = mock_event_logger.calls
        assert patch_start.exception is None
        assert original_start.exception is None
        assert original_error.exception == original_err_to_raise


def test_safe_patch_makes_expected_event_logging_calls_when_original_function_throws(
    patch_destination,
    test_autologging_integration,
    mock_event_logger,
):
    exc_to_raise = Exception("thrown from patch")

    @asyncify(patch_destination.is_async)
    def original(*args, **kwargs):
        raise exc_to_raise

    patch_destination.fn = original

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        original(*args, **kwargs)

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    with pytest.raises(Exception, match="thrown from patch"):
        run_sync_or_async(patch_destination.fn)
    expected_order = ["patch_start", "original_start", "original_error"]
    assert [call.method for call in mock_event_logger.calls] == expected_order
    patch_start, original_start, original_error = mock_event_logger.calls
    assert patch_start.exception is original_start.exception is None
    assert original_error.exception == exc_to_raise


@pytest.mark.usefixtures(test_mode_off.__name__)
def test_safe_patch_succeeds_when_event_logging_throws_in_standard_mode(
    patch_destination,
    test_autologging_integration,
):
    patch_preamble_called = False
    patch_postamble_called = False

    @asyncify(patch_destination.is_async)
    def patch_impl(original, *args, **kwargs):
        nonlocal patch_preamble_called
        nonlocal patch_postamble_called
        patch_preamble_called = True
        original(*args, **kwargs)
        patch_postamble_called = True

    safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)

    class ThrowingLogger(MockEventLogger):
        # Every `log_*` hook of `MockEventLogger` funnels through `_record_call`, so raising here
        # (after recording) makes every event-logging hook throw.
        def _record_call(self, *args, **kwargs):
            super()._record_call(*args, **kwargs)
            raise Exception("failed")

    logger = ThrowingLogger()
    prev_logger = AutologgingEventLogger.get_logger()
    AutologgingEventLogger.set_logger(logger)
    try:
        assert run_sync_or_async(patch_destination.fn) == PATCH_DESTINATION_FN_DEFAULT_RESULT
    finally:
        # Restore the global event logger so the throwing logger does not leak into later tests.
        AutologgingEventLogger.set_logger(prev_logger)
    assert patch_preamble_called
    assert patch_postamble_called
    expected_calls = ["patch_start", "original_start", "original_success", "patch_success"]
    assert [call.method for call in logger.calls] == expected_calls


def test_picklable_exception_safe_function_exhibits_expected_behavior_in_standard_mode():
    assert not autologging_utils.is_testing()

    @picklable_exception_safe_function
    def non_throwing_function():
        return 10

    assert non_throwing_function() == 10

    exc_to_throw = Exception("bad implementation")

    @picklable_exception_safe_function
    def throwing_function():
        raise exc_to_throw

    with mock.patch("mlflow.utils.autologging_utils._logger.warning") as logger_mock:
        throwing_function()
        assert logger_mock.call_count == 1
        message, formatting_arg = logger_mock.call_args[0]
        assert "unexpected error during autologging" in message
        assert formatting_arg == exc_to_throw


@pytest.mark.usefixtures(test_mode_on.__name__)
def test_picklable_exception_safe_function_exhibits_expected_behavior_in_test_mode():
    assert autologging_utils.is_testing()

    @picklable_exception_safe_function
    def non_throwing_function():
        return 10

    assert non_throwing_function() == 10

    exc_to_throw = Exception("function error")

    @picklable_exception_safe_function
    def throwing_function():
        raise exc_to_throw

    with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
        throwing_function()

    assert exc.value == exc_to_throw


@pytest.mark.parametrize(
    ("baseclass", "metaclass"),
    [(object, ExceptionSafeClass), (abc.ABC, ExceptionSafeAbstractClass)],
)
def test_exception_safe_class_exhibits_expected_behavior_in_standard_mode(baseclass, metaclass):
    assert not autologging_utils.is_testing()

    class NonThrowingClass(baseclass, metaclass=metaclass):
        def function(self):
            return 10

    assert NonThrowingClass().function() == 10

    exc_to_throw = Exception("function error")

    class ThrowingClass(baseclass, metaclass=metaclass):
        def function(self):
            raise exc_to_throw

    with mock.patch("mlflow.utils.autologging_utils._logger.warning") as logger_mock:
        ThrowingClass().function()

        assert logger_mock.call_count == 1

        message, formatting_arg = logger_mock.call_args[0]
        assert "unexpected error during autologging" in message
        assert formatting_arg == exc_to_throw


@pytest.mark.usefixtures(test_mode_on.__name__)
@pytest.mark.parametrize(
    ("baseclass", "metaclass"),
    [(object, ExceptionSafeClass), (abc.ABC, ExceptionSafeAbstractClass)],
)
def test_exception_safe_class_exhibits_expected_behavior_in_test_mode(baseclass, metaclass):
    assert autologging_utils.is_testing()

    class NonThrowingClass(baseclass, metaclass=metaclass):
        def function(self):
            return 10

    assert NonThrowingClass().function() == 10

    exc_to_throw = Exception("function error")

    class ThrowingClass(baseclass, metaclass=metaclass):
        def function(self):
            raise exc_to_throw

    with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
        ThrowingClass().function()

    assert exc.value == exc_to_throw


def test_with_managed_run_with_non_throwing_function_exhibits_expected_behavior():
    client = MlflowClient()

    def patch_function(original, *args, **kwargs):
        return mlflow.active_run()

    patch_function = with_managed_run("test_integration", patch_function)

    run1 = patch_function(lambda: "foo")
    run1_status = client.get_run(run1.info.run_id).info.status
    assert RunStatus.from_string(run1_status) == RunStatus.FINISHED

    with mlflow.start_run() as active_run:
        run2 = patch_function(lambda: "foo")

    assert run2 == active_run
    run2_status = client.get_run(run2.info.run_id).info.status
    assert RunStatus.from_string(run2_status) == RunStatus.FINISHED


def test_with_managed_run_with_throwing_function_exhibits_expected_behavior():
    client = MlflowClient()
    patch_function_active_run = None

    def patch_function(original, *args, **kwargs):
        nonlocal patch_function_active_run
        patch_function_active_run = mlflow.active_run()
        raise Exception("bad implementation")

    patch_function = with_managed_run("test_integration", patch_function)

    with pytest.raises(Exception, match="bad implementation"):
        patch_function(lambda: "foo")

    assert patch_function_active_run is not None
    status1 = client.get_run(patch_function_active_run.info.run_id).info.status
    assert RunStatus.from_string(status1) == RunStatus.FAILED

    with mlflow.start_run() as active_run:
        # Only the raising call belongs inside `pytest.raises`: statements after it in the same
        # block would never execute, silently skipping the checks below.
        with pytest.raises(Exception, match="bad implementation"):
            patch_function(lambda: "foo")
        assert patch_function_active_run == active_run
        # `with_managed_run` should not terminate a preexisting MLflow run,
        # even if the patch function throws
        assert mlflow.active_run() == active_run

    # The user's run is ended normally by its own context manager.
    status2 = client.get_run(active_run.info.run_id).info.status
    assert RunStatus.from_string(status2) == RunStatus.FINISHED


def test_with_managed_run_sets_specified_run_tags():
    client = MlflowClient()
    tags_to_set = {
        "foo": "bar",
        "num_layers": "7",
    }

    patch_function_1 = with_managed_run(
        "test_integration", lambda original, *args, **kwargs: mlflow.active_run(), tags=tags_to_set
    )
    run1 = patch_function_1(lambda: "foo")
    assert tags_to_set.items() <= client.get_run(run1.info.run_id).data.tags.items()


@pytest.mark.usefixtures(test_mode_on.__name__)
def test_with_managed_run_ends_run_on_keyboard_interrupt():
    client = MlflowClient()
    run = None

    def original():
        nonlocal run
        run = mlflow.active_run()
        raise KeyboardInterrupt

    patch_function_1 = with_managed_run(
        "test_integration", lambda original, *args, **kwargs: original(*args, **kwargs)
    )

    with pytest.raises(KeyboardInterrupt, match=r".*"):
        patch_function_1(original)

    assert not mlflow.active_run()
    run_status_1 = client.get_run(run.info.run_id).info.status
    assert RunStatus.from_string(run_status_1) == RunStatus.FAILED


@pytest.mark.usefixtures(test_mode_on.__name__)
def test_validate_args_succeeds_when_arg_sets_are_equivalent_or_identical():
    args = (1, "b", ["c"])
    kwargs = {
        "foo": ["bar"],
        "biz": {"baz": 5},
    }

    _validate_args("autologging_integration_name", "function_name", args, kwargs, args, kwargs)
    _validate_args("autologging_integration_name", "function_name", args, {}, args, {})
    _validate_args("autologging_integration_name", "function_name", (), kwargs, (), kwargs)

    args_copy = copy.deepcopy(args)
    kwargs_copy = copy.deepcopy(kwargs)

    _validate_args(
        "autologging_integration_name", "function_name", args, kwargs, args_copy, kwargs_copy
    )
    _validate_args("autologging_integration_name", "function_name", args, {}, args_copy, {})
    _validate_args("autologging_integration_name", "function_name", (), kwargs, (), kwargs_copy)


@pytest.mark.usefixtures(test_mode_on.__name__)
def test_validate_args_throws_when_extra_args_are_not_functions_classes_or_lists():
    user_call_args = (1, "b", ["c"])
    user_call_kwargs = {
        "foo": ["bar"],
        "biz": {"baz": 5},
    }

    invalid_type_autologging_call_args = copy.deepcopy(user_call_args)
    invalid_type_autologging_call_args[2].append(10)
    invalid_type_autologging_call_kwargs = copy.deepcopy(user_call_kwargs)
    invalid_type_autologging_call_kwargs["new"] = {}

    with pytest.raises(Exception, match="Invalid new input"):
        _validate_args(
            "autologging_integration_name",
            "function_name",
            user_call_args,
            user_call_kwargs,
            invalid_type_autologging_call_args,
            user_call_kwargs,
        )

    with pytest.raises(Exception, match="Invalid new input"):
        _validate_args(
            "autologging_integration_name",
            "function_name",
            user_call_args,
            user_call_kwargs,
            user_call_args,
            invalid_type_autologging_call_kwargs,
        )


@pytest.mark.usefixtures(test_mode_on.__name__)
def test_validate_args_throws_when_extra_args_are_not_exception_safe():
    user_call_args = (1, "b", ["c"])
    user_call_kwargs = {
        "foo": ["bar"],
        "biz": {"baz": 5},
    }

    class Unsafe:
        pass

    unsafe_autologging_call_args = copy.deepcopy(user_call_args)
    unsafe_autologging_call_args += (lambda: "foo",)
    unsafe_autologging_call_kwargs1 = copy.deepcopy(user_call_kwargs)
    unsafe_autologging_call_kwargs1["foo"].append(Unsafe())

    with pytest.raises(Exception, match="not exception-safe"):
        _validate_args(
            "autologging_integration_name",
            "function_name",
            user_call_args,
            user_call_kwargs,
            unsafe_autologging_call_args,
            user_call_kwargs,
        )

    with pytest.raises(Exception, match="Invalid new input"):
        _validate_args(
            "autologging_integration_name",
            "function_name",
            user_call_args,
            user_call_kwargs,
            user_call_args,
            unsafe_autologging_call_kwargs1,
        )

    unsafe_autologging_call_kwargs2 = copy.deepcopy(user_call_kwargs)
    unsafe_autologging_call_kwargs2["biz"]["new"] = Unsafe()

    with pytest.raises(Exception, match="Invalid new input"):
        _validate_args(
            "autologging_integration_name",
            "function_name",
            user_call_args,
            user_call_kwargs,
            user_call_args,
            unsafe_autologging_call_kwargs2,
        )


@pytest.mark.usefixtures(test_mode_on.__name__)
@pytest.mark.parametrize(
    ("baseclass", "metaclass"),
    [(object, ExceptionSafeClass), (abc.ABC, ExceptionSafeAbstractClass)],
)
def test_validate_args_succeeds_when_extra_args_are_picklable_exception_safe_functions_or_classes(
    baseclass, metaclass
):
    user_call_args = (1, "b", ["c"])
    user_call_kwargs = {
        "foo": ["bar"],
    }

    class Safe(baseclass, metaclass=metaclass):
        pass

    autologging_call_args = copy.deepcopy(user_call_args)
    autologging_call_args[2].append(Safe())
    autologging_call_args += (picklable_exception_safe_function(lambda: "foo"),)

    autologging_call_kwargs = copy.deepcopy(user_call_kwargs)
    autologging_call_kwargs["foo"].append(picklable_exception_safe_function(lambda: "foo"))
    autologging_call_kwargs["new"] = Safe()

    _validate_args(
        "autologging_integration_name",
        "function_name",
        user_call_args,
        user_call_kwargs,
        autologging_call_args,
        autologging_call_kwargs,
    )


@pytest.mark.usefixtures(test_mode_on.__name__)
def test_validate_args_throws_when_args_are_omitted():
    user_call_args = (1, "b", ["c"], {"d": "e"})
    user_call_kwargs = {
        "foo": ["bar"],
        "biz": {"baz": 4, "fuzz": 5},
    }

    invalid_autologging_call_args_1 = copy.deepcopy(user_call_args)
    invalid_autologging_call_args_1[2].pop()
    invalid_autologging_call_kwargs_1 = copy.deepcopy(user_call_kwargs)
    invalid_autologging_call_kwargs_1["foo"].pop()

with pytest.raises(Exception, match="missing from the call"): 1131 _validate_args( 1132 "autologging_integration_name", 1133 "function_name", 1134 user_call_args, 1135 user_call_kwargs, 1136 invalid_autologging_call_args_1, 1137 user_call_kwargs, 1138 ) 1139 1140 with pytest.raises(Exception, match="missing from the call"): 1141 _validate_args( 1142 "autologging_integration_name", 1143 "function_name", 1144 user_call_args, 1145 user_call_kwargs, 1146 user_call_args, 1147 invalid_autologging_call_kwargs_1, 1148 ) 1149 1150 invalid_autologging_call_args_2 = copy.deepcopy(user_call_args)[1:] 1151 invalid_autologging_call_kwargs_2 = copy.deepcopy(user_call_kwargs) 1152 invalid_autologging_call_kwargs_2.pop("foo") 1153 1154 with pytest.raises(Exception, match="missing from the call"): 1155 _validate_args( 1156 "autologging_integration_name", 1157 "function_name", 1158 user_call_args, 1159 user_call_kwargs, 1160 invalid_autologging_call_args_2, 1161 user_call_kwargs, 1162 ) 1163 1164 with pytest.raises(Exception, match="omit one or more expected keys"): 1165 _validate_args( 1166 "autologging_integration_name", 1167 "function_name", 1168 user_call_args, 1169 user_call_kwargs, 1170 user_call_args, 1171 invalid_autologging_call_kwargs_2, 1172 ) 1173 1174 invalid_autologging_call_args_3 = copy.deepcopy(user_call_args) 1175 invalid_autologging_call_args_3[3].pop("d") 1176 invalid_autologging_call_kwargs_3 = copy.deepcopy(user_call_kwargs) 1177 invalid_autologging_call_kwargs_3["biz"].pop("baz") 1178 1179 with pytest.raises(Exception, match="omit one or more expected keys"): 1180 _validate_args( 1181 "autologging_integration_name", 1182 "function_name", 1183 user_call_args, 1184 user_call_kwargs, 1185 invalid_autologging_call_args_3, 1186 user_call_kwargs, 1187 ) 1188 1189 with pytest.raises(Exception, match="omit one or more expected keys"): 1190 _validate_args( 1191 "autologging_integration_name", 1192 "function_name", 1193 user_call_args, 1194 user_call_kwargs, 1195 
user_call_args, 1196 invalid_autologging_call_kwargs_3, 1197 ) 1198 1199 1200 @pytest.mark.usefixtures(test_mode_on.__name__) 1201 def test_validate_args_throws_when_arg_types_or_values_are_changed(): 1202 user_call_args = (1, "b", ["c"]) 1203 user_call_kwargs = { 1204 "foo": ["bar"], 1205 } 1206 1207 invalid_autologging_call_args_1 = copy.deepcopy(user_call_args) 1208 invalid_autologging_call_args_1 = (2,) + invalid_autologging_call_args_1[1:] 1209 invalid_autologging_call_kwargs_1 = copy.deepcopy(user_call_kwargs) 1210 invalid_autologging_call_kwargs_1["foo"] = ["biz"] 1211 1212 with pytest.raises(Exception, match="does not match expected input"): 1213 _validate_args( 1214 "autologging_integration_name", 1215 "function_name", 1216 user_call_args, 1217 user_call_kwargs, 1218 invalid_autologging_call_args_1, 1219 user_call_kwargs, 1220 ) 1221 1222 with pytest.raises(Exception, match="does not match expected input"): 1223 _validate_args( 1224 "autologging_integration_name", 1225 "function_name", 1226 user_call_args, 1227 user_call_kwargs, 1228 user_call_args, 1229 invalid_autologging_call_kwargs_1, 1230 ) 1231 1232 call_arg_1, call_arg_2, _ = copy.deepcopy(user_call_args) 1233 invalid_autologging_call_args_2 = ({"7": 1}, call_arg_1, call_arg_2) 1234 invalid_autologging_call_kwargs_2 = copy.deepcopy(user_call_kwargs) 1235 invalid_autologging_call_kwargs_2["foo"] = 8 1236 1237 with pytest.raises(Exception, match="does not match expected type"): 1238 _validate_args( 1239 "autologging_integration_name", 1240 "function_name", 1241 user_call_args, 1242 user_call_kwargs, 1243 invalid_autologging_call_args_2, 1244 user_call_kwargs, 1245 ) 1246 1247 with pytest.raises(Exception, match="does not match expected type"): 1248 _validate_args( 1249 "autologging_integration_name", 1250 "function_name", 1251 user_call_args, 1252 user_call_kwargs, 1253 user_call_args, 1254 invalid_autologging_call_kwargs_2, 1255 ) 1256 1257 1258 @pytest.mark.usefixtures(test_mode_on.__name__) 1259 
@pytest.mark.parametrize( 1260 ("expectation", "al_name", "func_name", "user_args", "user_kwargs", "al_args", "al_kwargs"), 1261 [ 1262 ( 1263 does_not_raise(), 1264 "foo", 1265 "fit", 1266 ( 1267 None, 1268 3, 1269 ), 1270 {}, 1271 ( 1272 None, 1273 4, 1274 ), 1275 {}, 1276 ), 1277 (does_not_raise(), "foo", "fit", (), {"x": 3}, (), {"x": 4}), 1278 ( 1279 pytest.raises(AssertionError, match="does not match expected input"), 1280 "foo", 1281 "fit", 1282 (None, None, 3), 1283 {}, 1284 ( 1285 None, 1286 None, 1287 4, 1288 ), 1289 {}, 1290 ), 1291 ( 1292 pytest.raises(AssertionError, match="does not match expected input"), 1293 "foo", 1294 "fit", 1295 (), 1296 {"y": 3}, 1297 (), 1298 {"y": 4}, 1299 ), 1300 ( 1301 pytest.raises(AssertionError, match="does not match expected input"), 1302 "foo2", 1303 "fit", 1304 (), 1305 {"x": 3}, 1306 (), 1307 {"x": 4}, 1308 ), 1309 ( 1310 pytest.raises(AssertionError, match="does not match expected input"), 1311 "foo", 1312 "fit2", 1313 (), 1314 {"x": 3}, 1315 (), 1316 {"x": 4}, 1317 ), 1318 ( 1319 pytest.raises(AssertionError, match="does not match expected type"), 1320 "foo", 1321 "fit", 1322 (), 1323 {"x": [1, 2]}, 1324 (), 1325 {"x": 4}, 1326 ), 1327 ( 1328 pytest.raises(AssertionError, match="does not match expected type"), 1329 "foo", 1330 "bar", 1331 (1,), 1332 {}, 1333 (None,), 1334 {}, 1335 ), 1336 ( 1337 pytest.raises(AssertionError, match="Invalid new input"), 1338 "foo", 1339 "bar", 1340 (None,), 1341 {}, 1342 (2,), 1343 {}, 1344 ), 1345 (does_not_raise(), "ml", "flow", ([1, 2, 3],), {}, ([2],), {}), 1346 (does_not_raise(), "ml", "flow", (), {"cool": [1, 2, 3]}, (), {"cool": [2]}), 1347 ( 1348 pytest.raises(AssertionError, match="does not match expected type"), 1349 "ml", 1350 "flow", 1351 (), 1352 {"cool": 3}, 1353 (), 1354 {"cool": [2]}, 1355 ), 1356 ], 1357 ) 1358 def test_validate_args_respects_validation_exemptions( 1359 expectation, al_name, func_name, user_args, user_kwargs, al_args, al_kwargs 1360 ): 1361 with ( 
1362 mock.patch( 1363 "mlflow.utils.autologging_utils.safety._VALIDATION_EXEMPT_ARGUMENTS", 1364 [ 1365 ValidationExemptArgument("foo", "fit", lambda x: isinstance(x, int), 1, "x"), 1366 ValidationExemptArgument("ml", "flow", lambda z: isinstance(z, list), 0, "cool"), 1367 ], 1368 ), 1369 expectation, 1370 ): 1371 _validate_args(al_name, func_name, user_args, user_kwargs, al_args, al_kwargs) 1372 1373 1374 def test_validate_autologging_run_validates_autologging_tag_correctly(): 1375 with mlflow.start_run(): 1376 run_id_1 = mlflow.active_run().info.run_id 1377 1378 with pytest.raises(AssertionError, match="failed to set autologging tag with expected value"): 1379 _validate_autologging_run("test_integration", run_id_1) 1380 1381 with mlflow.start_run(tags={MLFLOW_AUTOLOGGING: "wrong_value"}): 1382 run_id_2 = mlflow.active_run().info.run_id 1383 1384 with pytest.raises( 1385 AssertionError, match="failed to set autologging tag with expected value.*wrong_value" 1386 ): 1387 _validate_autologging_run("test_integration", run_id_2) 1388 1389 with mlflow.start_run(tags={MLFLOW_AUTOLOGGING: "test_integration"}): 1390 run_id_3 = mlflow.active_run().info.run_id 1391 1392 _validate_autologging_run("test_integration", run_id_3) 1393 1394 1395 def test_validate_autologging_run_validates_run_status_correctly(): 1396 valid_autologging_tags = { 1397 MLFLOW_AUTOLOGGING: "test_integration", 1398 } 1399 1400 with mlflow.start_run(tags=valid_autologging_tags) as run_finished: 1401 run_id_finished = run_finished.info.run_id 1402 1403 assert ( 1404 RunStatus.from_string(MlflowClient().get_run(run_id_finished).info.status) 1405 == RunStatus.FINISHED 1406 ) 1407 _validate_autologging_run("test_integration", run_id_finished) 1408 1409 with mlflow.start_run(tags=valid_autologging_tags) as run_failed: 1410 run_id_failed = run_failed.info.run_id 1411 1412 MlflowClient().set_terminated(run_id_failed, status=RunStatus.to_string(RunStatus.FAILED)) 1413 assert ( 1414 
RunStatus.from_string(MlflowClient().get_run(run_id_failed).info.status) == RunStatus.FAILED 1415 ) 1416 _validate_autologging_run("test_integration", run_id_finished) 1417 1418 run_non_terminal = MlflowClient().create_run( 1419 experiment_id=run_finished.info.experiment_id, tags=valid_autologging_tags 1420 ) 1421 run_id_non_terminal = run_non_terminal.info.run_id 1422 assert ( 1423 RunStatus.from_string(MlflowClient().get_run(run_id_non_terminal).info.status) 1424 == RunStatus.RUNNING 1425 ) 1426 with pytest.raises(AssertionError, match="has a non-terminal status"): 1427 _validate_autologging_run("test_integration", run_id_non_terminal) 1428 1429 1430 def test_session_manager_creates_session_before_patch_executes( 1431 patch_destination, test_autologging_integration 1432 ): 1433 is_session_active = None 1434 1435 @asyncify(patch_destination.is_async) 1436 def check_session_manager_status(original): 1437 nonlocal is_session_active 1438 is_session_active = _AutologgingSessionManager.active_session() 1439 1440 safe_patch(test_autologging_integration, patch_destination, "fn", check_session_manager_status) 1441 run_sync_or_async(patch_destination.fn) 1442 assert is_session_active is not None 1443 1444 1445 def test_session_manager_exits_session_after_patch_executes( 1446 patch_destination, test_autologging_integration 1447 ): 1448 @asyncify(patch_destination.is_async) 1449 def patch_fn(original): 1450 assert _AutologgingSessionManager.active_session() is not None 1451 1452 safe_patch(test_autologging_integration, patch_destination, "fn", patch_fn) 1453 run_sync_or_async(patch_destination.fn) 1454 assert _AutologgingSessionManager.active_session() is None 1455 1456 1457 def test_session_manager_terminates_session_when_appropriate(): 1458 with _AutologgingSessionManager.start_session("test_integration") as outer_sess: 1459 assert outer_sess 1460 1461 with _AutologgingSessionManager.start_session("test_integration") as inner_sess: 1462 assert 
_AutologgingSessionManager.active_session() == inner_sess == outer_sess 1463 1464 assert _AutologgingSessionManager.active_session() == outer_sess 1465 1466 assert not _AutologgingSessionManager.active_session() 1467 1468 1469 def test_original_fn_runs_if_patch_should_not_be_applied(patch_destination): 1470 patch_impl_call_count = 0 1471 1472 @autologging_integration("test_respects_exclusive") 1473 def autolog(disable=False, exclusive=False, silent=False): 1474 @asyncify(patch_destination.is_async) 1475 def patch_impl(original, *args, **kwargs): 1476 nonlocal patch_impl_call_count 1477 patch_impl_call_count += 1 1478 return original(*args, **kwargs) 1479 1480 safe_patch("test_respects_exclusive", patch_destination, "fn", patch_impl) 1481 1482 autolog(exclusive=True) 1483 with mlflow.start_run(): 1484 run_sync_or_async(patch_destination.fn) 1485 assert patch_impl_call_count == 0 1486 assert patch_destination.fn_call_count == 1 1487 1488 1489 def test_patch_runs_if_patch_should_be_applied(): 1490 patch_impl_call_count = 0 1491 1492 class TestPatchWithNewFnObj: 1493 def __init__(self): 1494 self.fn_call_count = 0 1495 1496 def fn(self, *args, **kwargs): 1497 self.fn_call_count += 1 1498 return PATCH_DESTINATION_FN_DEFAULT_RESULT 1499 1500 def new_fn(self, *args, **kwargs): 1501 with mlflow.start_run(): 1502 self.fn() 1503 1504 patch_obj = TestPatchWithNewFnObj() 1505 1506 @autologging_integration("test_respects_exclusive") 1507 def autolog(disable=False, exclusive=False, silent=False): 1508 def patch_impl(original, *args, **kwargs): 1509 nonlocal patch_impl_call_count 1510 patch_impl_call_count += 1 1511 1512 def new_fn_patch(original, *args, **kwargs): 1513 pass 1514 1515 safe_patch("test_respects_exclusive", patch_obj, "fn", patch_impl) 1516 safe_patch("test_respects_exclusive", patch_obj, "new_fn", new_fn_patch) 1517 1518 # Should patch if no active run 1519 autolog() 1520 patch_obj.fn() 1521 assert patch_impl_call_count == 1 1522 1523 # Should patch if active run, 
but not exclusive 1524 autolog(exclusive=False) 1525 with mlflow.start_run(): 1526 patch_obj.fn() 1527 assert patch_impl_call_count == 2 1528 1529 # Should patch if active run and exclusive, but active autologging session 1530 autolog(exclusive=True) 1531 patch_obj.new_fn() 1532 assert patch_impl_call_count == 3 1533 1534 1535 def test_nested_call_autologging_disabled_when_top_level_call_autologging_failed(patch_destination): 1536 patch_impl_call_count = 0 1537 1538 @autologging_integration( 1539 "test_nested_call_autologging_disabled_when_top_level_call_autologging_failed" 1540 ) 1541 def autolog(disable=False, exclusive=False, silent=False): 1542 @asyncify(patch_destination.is_async) 1543 def patch_impl(original, *args, **kwargs): 1544 nonlocal patch_impl_call_count 1545 patch_impl_call_count += 1 1546 1547 level = kwargs["level"] 1548 1549 if level == 0: 1550 raise RuntimeError("analog top level call autologging failure.") 1551 1552 return original(*args, **kwargs) 1553 1554 safe_patch( 1555 "test_nested_call_autologging_disabled_when_top_level_call_autologging_failed", 1556 patch_destination, 1557 "recursive_fn", 1558 patch_impl, 1559 ) 1560 1561 autolog() 1562 for max_depth in [1, 2, 3]: 1563 patch_impl_call_count = 0 1564 patch_destination.recurse_fn_call_count = 0 1565 with mlflow.start_run(): 1566 run_sync_or_async(patch_destination.recursive_fn, level=0, max_depth=max_depth) 1567 assert patch_impl_call_count == 1 1568 assert patch_destination.recurse_fn_call_count == max_depth + 1 1569 1570 1571 def test_old_patch_reverted_before_run_autolog_fn(): 1572 class PatchDestination: 1573 def f1(self): 1574 pass 1575 1576 original_f1 = PatchDestination.f1 1577 1578 @autologging_integration("test_old_patch_reverted_before_run_autolog_fn") 1579 def autolog(disable=False, exclusive=False, silent=False): 1580 assert PatchDestination.f1 is original_f1 # assert old patch has been reverted. 
1581 1582 def patch_impl(original, *args, **kwargs): 1583 pass 1584 1585 safe_patch( 1586 "test_old_patch_reverted_before_run_autolog_fn", 1587 PatchDestination, 1588 "f1", 1589 patch_impl, 1590 ) 1591 1592 autolog(disable=True) 1593 autolog() 1594 autolog() # Test second time call autolog will revert first autolog call installed patch 1595 1596 1597 def test_safe_patch_support_property_decorated_method(): 1598 class BaseEstimator: 1599 def __init__(self, has_predict): 1600 self._has_predict = has_predict 1601 1602 def _predict(self, X, a, b): 1603 return {"X": X, "a": a, "b": b} 1604 1605 @property 1606 def predict(self): 1607 if not self._has_predict: 1608 raise AttributeError("does not have predict") 1609 return self._predict 1610 1611 class ExtendedEstimator(BaseEstimator): 1612 pass 1613 1614 original_base_estimator_predict = object.__getattribute__(BaseEstimator, "predict") 1615 1616 def patched_predict(original, self, *args, **kwargs): 1617 result = original(self, *args, **kwargs) 1618 if "patch_count" not in result: 1619 result["patch_count"] = 1 1620 else: 1621 result["patch_count"] += 1 1622 return result 1623 1624 flavor_name = "test_if_delegate_has_method_decorated_method_patch" 1625 1626 @autologging_integration(flavor_name) 1627 def autolog(disable=False, exclusive=False, silent=False): 1628 mlflow.sklearn._patch_estimator_method_if_available( 1629 flavor_name, 1630 BaseEstimator, 1631 "predict", 1632 patched_predict, 1633 manage_run=False, 1634 ) 1635 mlflow.sklearn._patch_estimator_method_if_available( 1636 flavor_name, 1637 ExtendedEstimator, 1638 "predict", 1639 patched_predict, 1640 manage_run=False, 1641 ) 1642 1643 autolog() 1644 1645 for EstimatorCls in [BaseEstimator, ExtendedEstimator]: 1646 assert EstimatorCls.predict.__doc__ == original_base_estimator_predict.__doc__ 1647 good_estimator = EstimatorCls(has_predict=True) 1648 assert good_estimator.predict.__doc__ == original_base_estimator_predict.__doc__ 1649 1650 expected_result = {"X": 1, 
"a": 2, "b": 3, "patch_count": 1} 1651 assert hasattr(good_estimator, "predict") 1652 assert good_estimator.predict(X=1, a=2, b=3) == expected_result 1653 assert good_estimator.predict(1, a=2, b=3) == expected_result 1654 assert good_estimator.predict(1, 2, b=3) == expected_result 1655 assert good_estimator.predict(1, 2, 3) == expected_result 1656 1657 bad_estimator = EstimatorCls(has_predict=False) 1658 assert not hasattr(bad_estimator, "predict") 1659 with pytest.raises(AttributeError, match="does not have predict"): 1660 bad_estimator.predict(X=1, a=2, b=3) 1661 1662 autolog(disable=True) 1663 assert original_base_estimator_predict is object.__getattribute__(BaseEstimator, "predict") 1664 assert "predict" not in ExtendedEstimator.__dict__ 1665 1666 1667 def test_safe_patch_preserves_original_function_attributes(): 1668 class Test1: 1669 def predict(self, X, a, b): 1670 """ 1671 Test doc for Test1.predict 1672 """ 1673 1674 def patched_predict(original, self, *args, **kwargs): 1675 return original(self, *args, **kwargs) 1676 1677 flavor_name = "test_safe_patch_preserves_original_function_attributes" 1678 1679 @autologging_integration(flavor_name) 1680 def autolog(disable=False, exclusive=False, silent=False): 1681 safe_patch(flavor_name, Test1, "predict", patched_predict, manage_run=False) 1682 1683 original_predict = Test1.predict 1684 autolog() 1685 assert get_func_attrs(Test1.predict) == get_func_attrs(original_predict)