/ tests / autologging / test_autologging_safety_unit.py
test_autologging_safety_unit.py
   1  import abc
   2  import copy
   3  import inspect
   4  from contextlib import nullcontext as does_not_raise
   5  from typing import Any, NamedTuple
   6  from unittest import mock
   7  
   8  import pytest
   9  
  10  import mlflow
  11  from mlflow import MlflowClient
  12  from mlflow.entities import RunStatus
  13  from mlflow.utils import autologging_utils
  14  from mlflow.utils.autologging_utils import (
  15      AutologgingEventLogger,
  16      ExceptionSafeAbstractClass,
  17      ExceptionSafeClass,
  18      autologging_integration,
  19      is_testing,
  20      picklable_exception_safe_function,
  21      safe_patch,
  22      with_managed_run,
  23  )
  24  from mlflow.utils.autologging_utils.safety import (
  25      ValidationExemptArgument,
  26      _AutologgingSessionManager,
  27      _validate_args,
  28      _validate_autologging_run,
  29  )
  30  from mlflow.utils.mlflow_tags import MLFLOW_AUTOLOGGING
  31  
  32  from tests.autologging.async_helper import asyncify, run_sync_or_async
  33  from tests.autologging.fixtures import (
  34      patch_destination,  # noqa: F401
  35      test_mode_off,
  36      test_mode_on,
  37  )
  38  from tests.autologging.test_autologging_utils import get_func_attrs
  39  
# Expected result of calling the `patch_destination` fixture's default `fn`. Tests compare the
# patched call's return value against it to show the original function's result still reaches
# the caller, even when the patch implementation raises or returns something else.
# NOTE(review): must agree with what the fixture's `fn` returns (tests/autologging/fixtures.py).
PATCH_DESTINATION_FN_DEFAULT_RESULT = "original_result"
  41  
  42  
  43  @pytest.fixture(autouse=True)
  44  def turn_test_mode_off_by_default(test_mode_off):
  45      """
  46      Most of the unit test cases in this module assume that autologging APIs are operating in a
  47      standard execution mode (i.e. where test mode is disabled). Accordingly, we turn off autologging
  48      test mode for this test module by default. Test cases that verify behaviors specific to test
  49      mode enable test mode explicitly by specifying the `test_mode_on` fixture.
  50  
  51      For more information about autologging test mode, see the docstring for
  52      :py:func:`mlflow.utils.autologging_utils._is_testing()`.
  53      """
  54  
  55  
  56  @pytest.fixture
  57  def test_autologging_integration():
  58      integration_name = "test_integration"
  59  
  60      @autologging_integration(integration_name)
  61      def autolog(disable=False, silent=False):
  62          pass
  63  
  64      autolog()
  65  
  66      return integration_name
  67  
  68  
  69  class MockEventLogger(AutologgingEventLogger):
  70      class LoggerCall(NamedTuple):
  71          method: str
  72          session: Any
  73          patch_obj: Any
  74          function_name: str
  75          call_args: Any
  76          call_kwargs: Any
  77          exception: Any
  78  
  79      def __init__(self):
  80          self.calls = []
  81  
  82      def reset(self):
  83          self.calls = []
  84  
  85      def log_patch_function_start(self, session, patch_obj, function_name, call_args, call_kwargs):
  86          self.calls.append(
  87              MockEventLogger.LoggerCall(
  88                  "patch_start", session, patch_obj, function_name, call_args, call_kwargs, None
  89              )
  90          )
  91  
  92      def log_patch_function_success(self, session, patch_obj, function_name, call_args, call_kwargs):
  93          self.calls.append(
  94              MockEventLogger.LoggerCall(
  95                  "patch_success", session, patch_obj, function_name, call_args, call_kwargs, None
  96              )
  97          )
  98  
  99      def log_patch_function_error(
 100          self, session, patch_obj, function_name, call_args, call_kwargs, exception
 101      ):
 102          self.calls.append(
 103              MockEventLogger.LoggerCall(
 104                  "patch_error", session, patch_obj, function_name, call_args, call_kwargs, exception
 105              )
 106          )
 107  
 108      def log_original_function_start(
 109          self, session, patch_obj, function_name, call_args, call_kwargs
 110      ):
 111          self.calls.append(
 112              MockEventLogger.LoggerCall(
 113                  "original_start", session, patch_obj, function_name, call_args, call_kwargs, None
 114              )
 115          )
 116  
 117      def log_original_function_success(
 118          self, session, patch_obj, function_name, call_args, call_kwargs
 119      ):
 120          self.calls.append(
 121              MockEventLogger.LoggerCall(
 122                  "original_success", session, patch_obj, function_name, call_args, call_kwargs, None
 123              )
 124          )
 125  
 126      def log_original_function_error(
 127          self, session, patch_obj, function_name, call_args, call_kwargs, exception
 128      ):
 129          self.calls.append(
 130              MockEventLogger.LoggerCall(
 131                  "original_error",
 132                  session,
 133                  patch_obj,
 134                  function_name,
 135                  call_args,
 136                  call_kwargs,
 137                  exception,
 138              )
 139          )
 140  
 141  
 142  @pytest.fixture
 143  def mock_event_logger():
 144      prev_logger = AutologgingEventLogger.get_logger()
 145      try:
 146          logger = MockEventLogger()
 147          AutologgingEventLogger.set_logger(logger)
 148          yield logger
 149      finally:
 150          AutologgingEventLogger.set_logger(prev_logger)
 151  
 152  
 153  def test_is_testing_respects_environment_variable(monkeypatch):
 154      monkeypatch.delenv("MLFLOW_AUTOLOGGING_TESTING", raising=False)
 155      assert not is_testing()
 156  
 157      monkeypatch.setenv("MLFLOW_AUTOLOGGING_TESTING", "false")
 158      assert not is_testing()
 159  
 160      monkeypatch.setenv("MLFLOW_AUTOLOGGING_TESTING", "true")
 161      assert is_testing()
 162  
 163  
 164  def test_safe_patch_forwards_expected_arguments_to_function_based_patch_implementation(
 165      patch_destination, test_autologging_integration
 166  ):
 167      foo_val = None
 168      bar_val = None
 169  
 170      @asyncify(patch_destination.is_async)
 171      def patch_impl(original, foo, bar=10):
 172          nonlocal foo_val
 173          nonlocal bar_val
 174          foo_val = foo
 175          bar_val = bar
 176  
 177      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 178      run_sync_or_async(patch_destination.fn, 7, bar=11)
 179      assert foo_val == 7
 180      assert bar_val == 11
 181  
 182  
 183  def test_safe_patch_provides_expected_original_function(
 184      test_autologging_integration, patch_destination
 185  ):
 186      @asyncify(patch_destination.is_async)
 187      def original_fn(foo, bar=10):
 188          return {
 189              "foo": foo,
 190              "bar": bar,
 191          }
 192  
 193      patch_destination.fn = original_fn
 194  
 195      @asyncify(patch_destination.is_async)
 196      def patch_impl(original, foo, bar):
 197          return original(foo + 1, bar + 2)
 198  
 199      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 200  
 201      assert run_sync_or_async(patch_destination.fn, 1, 2) == {"foo": 2, "bar": 4}
 202  
 203  
 204  def test_safe_patch_propagates_exceptions_raised_from_original_function(
 205      patch_destination, test_autologging_integration
 206  ):
 207      exc_to_throw = Exception("Bad original function")
 208  
 209      @asyncify(patch_destination.is_async)
 210      def original(*args, **kwargs):
 211          raise exc_to_throw
 212  
 213      patch_destination.fn = original
 214  
 215      patch_impl_called = False
 216  
 217      @asyncify(patch_destination.is_async)
 218      def patch_impl(original, *args, **kwargs):
 219          nonlocal patch_impl_called
 220          patch_impl_called = True
 221          return original(*args, **kwargs)
 222  
 223      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 224  
 225      with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
 226          run_sync_or_async(patch_destination.fn)
 227  
 228      assert exc.value == exc_to_throw
 229      assert patch_impl_called
 230  
 231  
 232  def test_safe_patch_logs_exceptions_raised_outside_of_original_function_as_warnings(
 233      patch_destination, test_autologging_integration
 234  ):
 235      exc_to_throw = Exception("Bad patch implementation")
 236  
 237      @asyncify(patch_destination.is_async)
 238      def patch_impl(original, *args, **kwargs):
 239          raise exc_to_throw
 240  
 241      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 242      with mock.patch("mlflow.utils.autologging_utils._logger.warning") as logger_mock:
 243          assert run_sync_or_async(patch_destination.fn) == PATCH_DESTINATION_FN_DEFAULT_RESULT
 244          assert logger_mock.call_count == 1
 245          expected_warning = "Encountered unexpected error during {} autologging: {}".format(
 246              test_autologging_integration, exc_to_throw
 247          )
 248          assert logger_mock.call_args[0][0] == expected_warning
 249  
 250  
 251  @pytest.mark.usefixtures(test_mode_on.__name__)
 252  def test_safe_patch_propagates_exceptions_raised_outside_of_original_function_in_test_mode(
 253      patch_destination, test_autologging_integration
 254  ):
 255      exc_to_throw = Exception("Bad patch implementation")
 256  
 257      @asyncify(patch_destination.is_async)
 258      def patch_impl(original, *args, **kwargs):
 259          raise exc_to_throw
 260  
 261      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 262      with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
 263          run_sync_or_async(patch_destination.fn)
 264  
 265      assert exc.value == exc_to_throw
 266  
 267  
 268  def test_safe_patch_calls_original_function_when_patch_preamble_throws(
 269      patch_destination, test_autologging_integration
 270  ):
 271      patch_impl_called = False
 272  
 273      @asyncify(patch_destination.is_async)
 274      def patch_impl(original, *args, **kwargs):
 275          nonlocal patch_impl_called
 276          patch_impl_called = True
 277          raise Exception("Bad patch preamble")
 278  
 279      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 280      assert run_sync_or_async(patch_destination.fn) == PATCH_DESTINATION_FN_DEFAULT_RESULT
 281      assert patch_destination.fn_call_count == 1
 282      assert patch_impl_called
 283  
 284  
 285  def test_safe_patch_returns_original_result_without_second_call_when_patch_postamble_throws(
 286      patch_destination, test_autologging_integration
 287  ):
 288      patch_impl_called = False
 289  
 290      @asyncify(patch_destination.is_async)
 291      def patch_impl(original, *args, **kwargs):
 292          nonlocal patch_impl_called
 293          patch_impl_called = True
 294          original(*args, **kwargs)
 295          raise Exception("Bad patch postamble")
 296  
 297      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 298      assert run_sync_or_async(patch_destination.fn) == PATCH_DESTINATION_FN_DEFAULT_RESULT
 299      assert patch_destination.fn_call_count == 1
 300      assert patch_impl_called
 301  
 302  
 303  def test_safe_patch_respects_disable_flag(patch_destination):
 304      patch_impl_call_count = 0
 305  
 306      @autologging_integration("test_respects_disable")
 307      def autolog(disable=False, silent=False):
 308          @asyncify(patch_destination.is_async)
 309          def patch_impl(original, *args, **kwargs):
 310              nonlocal patch_impl_call_count
 311              patch_impl_call_count += 1
 312              return original(*args, **kwargs)
 313  
 314          safe_patch("test_respects_disable", patch_destination, "fn", patch_impl)
 315  
 316      autolog(disable=False)
 317      run_sync_or_async(patch_destination.fn)
 318      assert patch_impl_call_count == 1
 319  
 320      autolog(disable=True)
 321      run_sync_or_async(patch_destination.fn)
 322      assert patch_impl_call_count == 1
 323  
 324  
 325  def test_safe_patch_returns_original_result_and_ignores_patch_return_value(
 326      patch_destination, test_autologging_integration
 327  ):
 328      patch_impl_called = False
 329  
 330      @asyncify(patch_destination.is_async)
 331      def patch_impl(original, *args, **kwargs):
 332          nonlocal patch_impl_called
 333          patch_impl_called = True
 334          return 10
 335  
 336      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 337      assert run_sync_or_async(patch_destination.fn) == PATCH_DESTINATION_FN_DEFAULT_RESULT
 338      assert patch_destination.fn_call_count == 1
 339      assert patch_impl_called
 340  
 341  
 342  @pytest.mark.usefixtures(test_mode_on.__name__)
 343  def test_safe_patch_validates_arguments_to_original_function_in_test_mode(
 344      patch_destination, test_autologging_integration
 345  ):
 346      @asyncify(patch_destination.is_async)
 347      def patch_impl(original, *args, **kwargs):
 348          return original("1", "2", "3")
 349  
 350      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 351  
 352      with (
 353          pytest.raises(Exception, match="does not match expected input"),
 354          mock.patch(
 355              "mlflow.utils.autologging_utils.safety._validate_args",
 356              wraps=autologging_utils.safety._validate_args,
 357          ) as validate_mock,
 358      ):
 359          run_sync_or_async(patch_destination.fn, "a", "b", "c")
 360  
 361      assert validate_mock.call_count == 1
 362  
 363  
 364  @pytest.mark.usefixtures(test_mode_on.__name__)
 365  def test_safe_patch_throws_when_autologging_runs_are_leaked_in_test_mode(
 366      patch_destination, test_autologging_integration
 367  ):
 368      assert autologging_utils.is_testing()
 369  
 370      @asyncify(patch_destination.is_async)
 371      def leak_run_patch_impl(original, *args, **kwargs):
 372          mlflow.start_run(nested=True)
 373  
 374      safe_patch(test_autologging_integration, patch_destination, "fn", leak_run_patch_impl)
 375      with pytest.raises(AssertionError, match="leaked an active run"):
 376          run_sync_or_async(patch_destination.fn)
 377  
 378      # End the leaked run
 379      mlflow.end_run()
 380  
 381      with mlflow.start_run():
 382          # If a user-generated run existed prior to the autologged training session, we expect
 383          # that safe patch will not throw a leaked run exception
 384          patch_destination.fn()
 385          # End the leaked nested run
 386          mlflow.end_run()
 387  
 388      assert not mlflow.active_run()
 389  
 390  
 391  def test_safe_patch_does_not_throw_when_autologging_runs_are_leaked_in_standard_mode(
 392      patch_destination, test_autologging_integration
 393  ):
 394      assert not autologging_utils.is_testing()
 395  
 396      @asyncify(patch_destination.is_async)
 397      def leak_run_patch_impl(original, *args, **kwargs):
 398          mlflow.start_run(nested=True)
 399  
 400      safe_patch(test_autologging_integration, patch_destination, "fn", leak_run_patch_impl)
 401      run_sync_or_async(patch_destination.fn)
 402      assert mlflow.active_run()
 403  
 404      # End the leaked run
 405      mlflow.end_run()
 406  
 407      assert not mlflow.active_run()
 408  
 409  
 410  @pytest.mark.usefixtures(test_mode_on.__name__)
 411  def test_safe_patch_validates_autologging_runs_when_necessary_in_test_mode(
 412      patch_destination, test_autologging_integration
 413  ):
 414      assert autologging_utils.is_testing()
 415  
 416      def no_tag_run_patch(original, *args, **kwargs):
 417          with mlflow.start_run(nested=True):
 418              return original(*args, **kwargs)
 419  
 420      async def async_no_tag_run_patch(original, *args, **kwargs):
 421          with mlflow.start_run(nested=True):
 422              return await original(*args, **kwargs)
 423  
 424      if patch_destination.is_async:
 425          safe_patch(test_autologging_integration, patch_destination, "fn", async_no_tag_run_patch)
 426      else:
 427          safe_patch(test_autologging_integration, patch_destination, "fn", no_tag_run_patch)
 428  
 429      with mock.patch(
 430          "mlflow.utils.autologging_utils.safety._validate_autologging_run",
 431          wraps=_validate_autologging_run,
 432      ) as validate_run_mock:
 433          with pytest.raises(
 434              AssertionError, match="failed to set autologging tag with expected value"
 435          ):
 436              run_sync_or_async(patch_destination.fn)
 437          assert validate_run_mock.call_count == 1
 438  
 439          validate_run_mock.reset_mock()
 440  
 441          with mlflow.start_run(nested=True):
 442              # If a user-generated run existed prior to the autologged training session, we expect
 443              # that safe patch will not attempt to validate it
 444              run_sync_or_async(patch_destination.fn)
 445          assert not validate_run_mock.called
 446  
 447  
 448  def test_safe_patch_does_not_validate_autologging_runs_in_standard_mode(
 449      patch_destination, test_autologging_integration
 450  ):
 451      assert not autologging_utils.is_testing()
 452  
 453      @asyncify(patch_destination.is_async)
 454      def no_tag_run_patch_impl(original, *args, **kwargs):
 455          with mlflow.start_run(nested=True):
 456              return original(*args, **kwargs)
 457  
 458      safe_patch(test_autologging_integration, patch_destination, "fn", no_tag_run_patch_impl)
 459  
 460      with mock.patch(
 461          "mlflow.utils.autologging_utils.safety._validate_autologging_run",
 462          wraps=_validate_autologging_run,
 463      ) as validate_run_mock:
 464          run_sync_or_async(patch_destination.fn)
 465  
 466          with mlflow.start_run(nested=True):
 467              # If a user-generated run existed prior to the autologged training session, we expect
 468              # that safe patch will not attempt to validate it
 469              run_sync_or_async(patch_destination.fn)
 470  
 471          assert not validate_run_mock.called
 472  
 473  
 474  def test_safe_patch_manages_run_if_specified_and_sets_expected_run_tags(
 475      patch_destination, test_autologging_integration
 476  ):
 477      client = MlflowClient()
 478      active_run = None
 479  
 480      @asyncify(patch_destination.is_async)
 481      def patch_impl(original, *args, **kwargs):
 482          nonlocal active_run
 483          active_run = mlflow.active_run()
 484          return original(*args, **kwargs)
 485  
 486      if patch_destination.is_async:
 487          with pytest.raises(Exception, match="manage_run parameter is not supported"):
 488              safe_patch(
 489                  test_autologging_integration, patch_destination, "fn", patch_impl, manage_run=True
 490              )
 491          return
 492  
 493      with mock.patch(
 494          "mlflow.utils.autologging_utils.safety.with_managed_run", wraps=with_managed_run
 495      ) as managed_run_mock:
 496          safe_patch(
 497              test_autologging_integration, patch_destination, "fn", patch_impl, manage_run=True
 498          )
 499  
 500      run_sync_or_async(patch_destination.fn)
 501      assert managed_run_mock.call_count == 1
 502      assert active_run is not None
 503      assert active_run.info.run_id is not None
 504      assert (
 505          client.get_run(active_run.info.run_id).data.tags[MLFLOW_AUTOLOGGING] == "test_integration"
 506      )
 507  
 508  
 509  def test_safe_patch_does_not_manage_run_if_unspecified(
 510      patch_destination, test_autologging_integration
 511  ):
 512      active_run = None
 513  
 514      @asyncify(patch_destination.is_async)
 515      def patch_impl(original, *args, **kwargs):
 516          nonlocal active_run
 517          active_run = mlflow.active_run()
 518          return original(*args, **kwargs)
 519  
 520      with mock.patch(
 521          "mlflow.utils.autologging_utils.with_managed_run", wraps=with_managed_run
 522      ) as managed_run_mock:
 523          safe_patch(
 524              test_autologging_integration, patch_destination, "fn", patch_impl, manage_run=False
 525          )
 526          run_sync_or_async(patch_destination.fn)
 527          assert managed_run_mock.call_count == 0
 528          assert active_run is None
 529  
 530  
 531  def test_safe_patch_preserves_signature_of_patched_function(
 532      patch_destination, test_autologging_integration
 533  ):
 534      @asyncify(patch_destination.is_async)
 535      def original(a, b, c=10, *, d=11):
 536          return 10
 537  
 538      patch_destination.fn = original
 539  
 540      patch_impl_called = False
 541  
 542      @asyncify(patch_destination.is_async)
 543      def patch_impl(original, *args, **kwargs):
 544          nonlocal patch_impl_called
 545          patch_impl_called = True
 546          return original(*args, **kwargs)
 547  
 548      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 549      run_sync_or_async(patch_destination.fn, 1, 2)
 550      assert patch_impl_called
 551      assert inspect.signature(patch_destination.fn) == inspect.signature(original)
 552  
 553  
 554  def test_safe_patch_provides_original_function_with_expected_signature(
 555      patch_destination, test_autologging_integration
 556  ):
 557      @asyncify(patch_destination.is_async)
 558      def original(a, b, c=10, *, d=11):
 559          return 10
 560  
 561      patch_destination.fn = original
 562  
 563      original_signature = False
 564  
 565      @asyncify(patch_destination.is_async)
 566      def patch_impl(original, *args, **kwargs):
 567          nonlocal original_signature
 568          original_signature = inspect.signature(original)
 569          return original(*args, **kwargs)
 570  
 571      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 572      run_sync_or_async(patch_destination.fn, 1, 2)
 573      assert original_signature == inspect.signature(original)
 574  
 575  
 576  def test_safe_patch_makes_expected_event_logging_calls_for_successful_patch_invocation(
 577      patch_destination,
 578      test_autologging_integration,
 579      mock_event_logger,
 580  ):
 581      patch_session = None
 582      og_call_kwargs = {}
 583  
 584      @asyncify(patch_destination.is_async)
 585      def patch_impl(original, *args, **kwargs):
 586          nonlocal og_call_kwargs
 587          kwargs.update({"extra_func": picklable_exception_safe_function(lambda k: "foo")})
 588          og_call_kwargs = kwargs
 589  
 590          nonlocal patch_session
 591          patch_session = _AutologgingSessionManager.active_session()
 592  
 593          original(*args, **kwargs)
 594  
 595      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 596  
 597      run_sync_or_async(patch_destination.fn, "a", 1, b=2)
 598      expected_order = ["patch_start", "original_start", "original_success", "patch_success"]
 599      assert [call.method for call in mock_event_logger.calls] == expected_order
 600      assert all(call.session == patch_session for call in mock_event_logger.calls)
 601      assert all(call.patch_obj == patch_destination for call in mock_event_logger.calls)
 602      assert all(call.function_name == "fn" for call in mock_event_logger.calls)
 603      patch_start, original_start, original_success, patch_success = mock_event_logger.calls
 604      assert patch_start.call_args == patch_success.call_args == ("a", 1)
 605      assert patch_start.call_kwargs == patch_success.call_kwargs == {"b": 2}
 606      assert original_start.call_args == original_success.call_args == ("a", 1)
 607      assert original_start.call_kwargs == original_success.call_kwargs == og_call_kwargs
 608      assert patch_start.exception is original_start.exception is None
 609      assert patch_success.exception is original_success.exception is None
 610  
 611  
 612  def test_safe_patch_makes_expected_event_logging_calls_when_patch_impl_throws_and_original_succeeds(
 613      patch_destination,
 614      test_autologging_integration,
 615      mock_event_logger,
 616  ):
 617      exc_to_raise = Exception("thrown from patch")
 618  
 619      throw_location = None
 620  
 621      @asyncify(patch_destination.is_async)
 622      def patch_impl(original, *args, **kwargs):
 623          nonlocal throw_location
 624  
 625          if throw_location == "before":
 626              raise exc_to_raise
 627  
 628          original(*args, **kwargs)
 629  
 630          if throw_location != "before":
 631              raise exc_to_raise
 632  
 633      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 634  
 635      expected_order = [
 636          "patch_start",
 637          "original_start",
 638          "original_success",
 639          "patch_error",
 640      ]
 641  
 642      for throw_location in ["before", "after"]:
 643          mock_event_logger.reset()
 644          run_sync_or_async(patch_destination.fn)
 645          assert [call.method for call in mock_event_logger.calls] == expected_order
 646          patch_start, original_start, original_success, patch_error = mock_event_logger.calls
 647          assert patch_start.exception is None
 648          assert original_start.exception is None
 649          assert original_success.exception is None
 650          assert patch_error.exception == exc_to_raise
 651  
 652  
 653  def test_safe_patch_makes_expected_event_logging_calls_when_patch_impl_throws_and_original_throws(
 654      patch_destination,
 655      test_autologging_integration,
 656      mock_event_logger,
 657  ):
 658      exc_to_raise = Exception("thrown from patch")
 659      original_err_to_raise = Exception("throw from original")
 660  
 661      throw_location = None
 662  
 663      @asyncify(patch_destination.is_async)
 664      def patch_impl(original, *args, **kwargs):
 665          nonlocal throw_location
 666  
 667          if throw_location == "before":
 668              raise exc_to_raise
 669  
 670          original(*args, **kwargs)
 671  
 672          if throw_location != "before":
 673              raise exc_to_raise
 674  
 675      safe_patch(test_autologging_integration, patch_destination, "throw_error_fn", patch_impl)
 676  
 677      expected_order = ["patch_start", "original_start", "original_error"]
 678  
 679      for throw_location in ["before", "after"]:
 680          mock_event_logger.reset()
 681          with pytest.raises(Exception, match="throw from original"):
 682              run_sync_or_async(patch_destination.throw_error_fn, original_err_to_raise)
 683          assert [call.method for call in mock_event_logger.calls] == expected_order
 684          patch_start, original_start, original_error = mock_event_logger.calls
 685          assert patch_start.exception is None
 686          assert original_start.exception is None
 687          assert original_error.exception == original_err_to_raise
 688  
 689  
 690  def test_safe_patch_makes_expected_event_logging_calls_when_original_function_throws(
 691      patch_destination,
 692      test_autologging_integration,
 693      mock_event_logger,
 694  ):
 695      exc_to_raise = Exception("thrown from patch")
 696  
 697      @asyncify(patch_destination.is_async)
 698      def original(*args, **kwargs):
 699          raise exc_to_raise
 700  
 701      patch_destination.fn = original
 702  
 703      @asyncify(patch_destination.is_async)
 704      def patch_impl(original, *args, **kwargs):
 705          original(*args, **kwargs)
 706  
 707      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 708  
 709      with pytest.raises(Exception, match="thrown from patch"):
 710          run_sync_or_async(patch_destination.fn)
 711      expected_order = ["patch_start", "original_start", "original_error"]
 712      assert [call.method for call in mock_event_logger.calls] == expected_order
 713      patch_start, original_start, original_error = mock_event_logger.calls
 714      assert patch_start.exception is original_start.exception is None
 715      assert original_error.exception == exc_to_raise
 716  
 717  
 718  @pytest.mark.usefixtures(test_mode_off.__name__)
 719  def test_safe_patch_succeeds_when_event_logging_throws_in_standard_mode(
 720      patch_destination,
 721      test_autologging_integration,
 722  ):
 723      patch_preamble_called = False
 724      patch_postamble_called = False
 725  
 726      @asyncify(patch_destination.is_async)
 727      def patch_impl(original, *args, **kwargs):
 728          nonlocal patch_preamble_called
 729          patch_preamble_called = True
 730          original(*args, **kwargs)
 731          nonlocal patch_postamble_called
 732          patch_postamble_called = True
 733  
 734      safe_patch(test_autologging_integration, patch_destination, "fn", patch_impl)
 735  
 736      class ThrowingLogger(MockEventLogger):
 737          def log_patch_function_start(
 738              self, session, patch_obj, function_name, call_args, call_kwargs
 739          ):
 740              super().log_patch_function_start(
 741                  session, patch_obj, function_name, call_args, call_kwargs
 742              )
 743              raise Exception("failed")
 744  
 745          def log_patch_function_success(
 746              self, session, patch_obj, function_name, call_args, call_kwargs
 747          ):
 748              super().log_patch_function_success(
 749                  session, patch_obj, function_name, call_args, call_kwargs
 750              )
 751              raise Exception("failed")
 752  
 753          def log_patch_function_error(
 754              self, session, patch_obj, function_name, call_args, call_kwargs, exception
 755          ):
 756              super().log_patch_function_error(
 757                  session, patch_obj, function_name, call_args, call_kwargs, exception
 758              )
 759              raise Exception("failed")
 760  
 761          def log_original_function_start(
 762              self, session, patch_obj, function_name, call_args, call_kwargs
 763          ):
 764              super().log_original_function_start(
 765                  session, patch_obj, function_name, call_args, call_kwargs
 766              )
 767              raise Exception("failed")
 768  
 769          def log_original_function_success(
 770              self, session, patch_obj, function_name, call_args, call_kwargs
 771          ):
 772              super().log_original_function_success(
 773                  session, patch_obj, function_name, call_args, call_kwargs
 774              )
 775              raise Exception("failed")
 776  
 777          def log_original_function_error(
 778              self, session, patch_obj, function_name, call_args, call_kwargs, exception
 779          ):
 780              super().log_original_function_error(
 781                  session, patch_obj, function_name, call_args, call_kwargs, exception
 782              )
 783              raise Exception("failed")
 784  
 785      logger = ThrowingLogger()
 786      AutologgingEventLogger.set_logger(logger)
 787      assert run_sync_or_async(patch_destination.fn) == PATCH_DESTINATION_FN_DEFAULT_RESULT
 788      assert patch_preamble_called
 789      assert patch_postamble_called
 790      expected_calls = ["patch_start", "original_start", "original_success", "patch_success"]
 791      assert [call.method for call in logger.calls] == expected_calls
 792  
 793  
 794  def test_picklable_exception_safe_function_exhibits_expected_behavior_in_standard_mode():
 795      assert not autologging_utils.is_testing()
 796  
 797      @picklable_exception_safe_function
 798      def non_throwing_function():
 799          return 10
 800  
 801      assert non_throwing_function() == 10
 802  
 803      exc_to_throw = Exception("bad implementation")
 804  
 805      @picklable_exception_safe_function
 806      def throwing_function():
 807          raise exc_to_throw
 808  
 809      with mock.patch("mlflow.utils.autologging_utils._logger.warning") as logger_mock:
 810          throwing_function()
 811          assert logger_mock.call_count == 1
 812          message, formatting_arg = logger_mock.call_args[0]
 813          assert "unexpected error during autologging" in message
 814          assert formatting_arg == exc_to_throw
 815  
 816  
 817  @pytest.mark.usefixtures(test_mode_on.__name__)
 818  def test_picklable_exception_safe_function_exhibits_expected_behavior_in_test_mode():
 819      assert autologging_utils.is_testing()
 820  
 821      @picklable_exception_safe_function
 822      def non_throwing_function():
 823          return 10
 824  
 825      assert non_throwing_function() == 10
 826  
 827      exc_to_throw = Exception("function error")
 828  
 829      @picklable_exception_safe_function
 830      def throwing_function():
 831          raise exc_to_throw
 832  
 833      with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
 834          throwing_function()
 835  
 836      assert exc.value == exc_to_throw
 837  
 838  
 839  @pytest.mark.parametrize(
 840      ("baseclass", "metaclass"),
 841      [(object, ExceptionSafeClass), (abc.ABC, ExceptionSafeAbstractClass)],
 842  )
 843  def test_exception_safe_class_exhibits_expected_behavior_in_standard_mode(baseclass, metaclass):
 844      assert not autologging_utils.is_testing()
 845  
 846      class NonThrowingClass(baseclass, metaclass=metaclass):
 847          def function(self):
 848              return 10
 849  
 850      assert NonThrowingClass().function() == 10
 851  
 852      exc_to_throw = Exception("function error")
 853  
 854      class ThrowingClass(baseclass, metaclass=metaclass):
 855          def function(self):
 856              raise exc_to_throw
 857  
 858      with mock.patch("mlflow.utils.autologging_utils._logger.warning") as logger_mock:
 859          ThrowingClass().function()
 860  
 861          assert logger_mock.call_count == 1
 862  
 863          message, formatting_arg = logger_mock.call_args[0]
 864          assert "unexpected error during autologging" in message
 865          assert formatting_arg == exc_to_throw
 866  
 867  
 868  @pytest.mark.usefixtures(test_mode_on.__name__)
 869  @pytest.mark.parametrize(
 870      ("baseclass", "metaclass"),
 871      [(object, ExceptionSafeClass), (abc.ABC, ExceptionSafeAbstractClass)],
 872  )
 873  def test_exception_safe_class_exhibits_expected_behavior_in_test_mode(baseclass, metaclass):
 874      assert autologging_utils.is_testing()
 875  
 876      class NonThrowingClass(baseclass, metaclass=metaclass):
 877          def function(self):
 878              return 10
 879  
 880      assert NonThrowingClass().function() == 10
 881  
 882      exc_to_throw = Exception("function error")
 883  
 884      class ThrowingClass(baseclass, metaclass=metaclass):
 885          def function(self):
 886              raise exc_to_throw
 887  
 888      with pytest.raises(Exception, match=str(exc_to_throw)) as exc:
 889          ThrowingClass().function()
 890  
 891      assert exc.value == exc_to_throw
 892  
 893  
 894  def test_with_managed_run_with_non_throwing_function_exhibits_expected_behavior():
 895      client = MlflowClient()
 896  
 897      def patch_function(original, *args, **kwargs):
 898          return mlflow.active_run()
 899  
 900      patch_function = with_managed_run("test_integration", patch_function)
 901  
 902      run1 = patch_function(lambda: "foo")
 903      run1_status = client.get_run(run1.info.run_id).info.status
 904      assert RunStatus.from_string(run1_status) == RunStatus.FINISHED
 905  
 906      with mlflow.start_run() as active_run:
 907          run2 = patch_function(lambda: "foo")
 908  
 909      assert run2 == active_run
 910      run2_status = client.get_run(run2.info.run_id).info.status
 911      assert RunStatus.from_string(run2_status) == RunStatus.FINISHED
 912  
 913  
 914  def test_with_managed_run_with_throwing_function_exhibits_expected_behavior():
 915      client = MlflowClient()
 916      patch_function_active_run = None
 917  
 918      def patch_function(original, *args, **kwargs):
 919          nonlocal patch_function_active_run
 920          patch_function_active_run = mlflow.active_run()
 921          raise Exception("bad implementation")
 922  
 923      patch_function = with_managed_run("test_integration", patch_function)
 924  
 925      with pytest.raises(Exception, match="bad implementation"):
 926          patch_function(lambda: "foo")
 927  
 928      assert patch_function_active_run is not None
 929      status1 = client.get_run(patch_function_active_run.info.run_id).info.status
 930      assert RunStatus.from_string(status1) == RunStatus.FAILED
 931  
 932      with mlflow.start_run() as active_run, pytest.raises(Exception, match="bad implementation"):
 933          patch_function(lambda: "foo")
 934      assert patch_function_active_run == active_run
 935      # `with_managed_run` should not terminate a preexisting MLflow run,
 936      # even if the patch function throws
 937      status2 = client.get_run(active_run.info.run_id).info.status
 938      assert RunStatus.from_string(status2) == RunStatus.FINISHED
 939  
 940  
 941  def test_with_managed_run_sets_specified_run_tags():
 942      client = MlflowClient()
 943      tags_to_set = {
 944          "foo": "bar",
 945          "num_layers": "7",
 946      }
 947  
 948      patch_function_1 = with_managed_run(
 949          "test_integration", lambda original, *args, **kwargs: mlflow.active_run(), tags=tags_to_set
 950      )
 951      run1 = patch_function_1(lambda: "foo")
 952      assert tags_to_set.items() <= client.get_run(run1.info.run_id).data.tags.items()
 953  
 954  
 955  @pytest.mark.usefixtures(test_mode_on.__name__)
 956  def test_with_managed_run_ends_run_on_keyboard_interrupt():
 957      client = MlflowClient()
 958      run = None
 959  
 960      def original():
 961          nonlocal run
 962          run = mlflow.active_run()
 963          raise KeyboardInterrupt
 964  
 965      patch_function_1 = with_managed_run(
 966          "test_integration", lambda original, *args, **kwargs: original(*args, **kwargs)
 967      )
 968  
 969      with pytest.raises(KeyboardInterrupt, match=r".*"):
 970          patch_function_1(original)
 971  
 972      assert not mlflow.active_run()
 973      run_status_1 = client.get_run(run.info.run_id).info.status
 974      assert RunStatus.from_string(run_status_1) == RunStatus.FAILED
 975  
 976  
 977  @pytest.mark.usefixtures(test_mode_on.__name__)
 978  def test_validate_args_succeeds_when_arg_sets_are_equivalent_or_identical():
 979      args = (1, "b", ["c"])
 980      kwargs = {
 981          "foo": ["bar"],
 982          "biz": {"baz": 5},
 983      }
 984  
 985      _validate_args("autologging_integration_name", "function_name", args, kwargs, args, kwargs)
 986      _validate_args("autologging_integration_name", "function_name", args, {}, args, {})
 987      _validate_args("autologging_integration_name", "function_name", (), kwargs, (), kwargs)
 988  
 989      args_copy = copy.deepcopy(args)
 990      kwargs_copy = copy.deepcopy(kwargs)
 991  
 992      _validate_args(
 993          "autologging_integration_name", "function_name", args, kwargs, args_copy, kwargs_copy
 994      )
 995      _validate_args("autologging_integration_name", "function_name", args, {}, args_copy, {})
 996      _validate_args("autologging_integration_name", "function_name", (), kwargs, (), kwargs_copy)
 997  
 998  
 999  @pytest.mark.usefixtures(test_mode_on.__name__)
1000  def test_validate_args_throws_when_extra_args_are_not_functions_classes_or_lists():
1001      user_call_args = (1, "b", ["c"])
1002      user_call_kwargs = {
1003          "foo": ["bar"],
1004          "biz": {"baz": 5},
1005      }
1006  
1007      invalid_type_autologging_call_args = copy.deepcopy(user_call_args)
1008      invalid_type_autologging_call_args[2].append(10)
1009      invalid_type_autologging_call_kwargs = copy.deepcopy(user_call_kwargs)
1010      invalid_type_autologging_call_kwargs["new"] = {}
1011  
1012      with pytest.raises(Exception, match="Invalid new input"):
1013          _validate_args(
1014              "autologging_integration_name",
1015              "function_name",
1016              user_call_args,
1017              user_call_kwargs,
1018              invalid_type_autologging_call_args,
1019              user_call_kwargs,
1020          )
1021  
1022      with pytest.raises(Exception, match="Invalid new input"):
1023          _validate_args(
1024              "autologging_integration_name",
1025              "function_name",
1026              user_call_args,
1027              user_call_kwargs,
1028              user_call_args,
1029              invalid_type_autologging_call_kwargs,
1030          )
1031  
1032  
@pytest.mark.usefixtures(test_mode_on.__name__)
def test_validate_args_throws_when_extra_args_are_not_exception_safe():
    """
    Functions and objects that autologging adds to the user's call must be exception-safe.
    - A plain lambda appended positionally is reported as not exception-safe.
    - An instance of an ordinary class added to a kwargs list is reported as invalid new input.
    - The same holds for such an instance added to a nested kwargs dict.
    """
    user_call_args = (1, "b", ["c"])
    user_call_kwargs = {
        "foo": ["bar"],
        "biz": {"baz": 5},
    }

    class Unsafe:
        pass

    args_with_plain_lambda = (*copy.deepcopy(user_call_args), lambda: "foo")

    kwargs_with_unsafe_in_list = copy.deepcopy(user_call_kwargs)
    kwargs_with_unsafe_in_list["foo"].append(Unsafe())

    kwargs_with_unsafe_in_dict = copy.deepcopy(user_call_kwargs)
    kwargs_with_unsafe_in_dict["biz"]["new"] = Unsafe()

    cases = [
        (args_with_plain_lambda, user_call_kwargs, "not exception-safe"),
        (user_call_args, kwargs_with_unsafe_in_list, "Invalid new input"),
        (user_call_args, kwargs_with_unsafe_in_dict, "Invalid new input"),
    ]
    for autologging_args, autologging_kwargs, expected_error in cases:
        with pytest.raises(Exception, match=expected_error):
            _validate_args(
                "autologging_integration_name",
                "function_name",
                user_call_args,
                user_call_kwargs,
                autologging_args,
                autologging_kwargs,
            )
1081  
1082  
@pytest.mark.usefixtures(test_mode_on.__name__)
@pytest.mark.parametrize(
    ("baseclass", "metaclass"),
    [(object, ExceptionSafeClass), (abc.ABC, ExceptionSafeAbstractClass)],
)
def test_validate_args_succeeds_when_extra_args_are_picklable_exception_safe_functions_or_classes(
    baseclass, metaclass
):
    """
    Autologging may add inputs the user did not pass, provided they are exception-safe.
    Accepted inputs are instances of exception-safe classes and functions wrapped with
    `picklable_exception_safe_function`. They may be inserted into existing lists, appended
    positionally, or passed as new keyword arguments.
    """
    user_call_args = (1, "b", ["c"])
    user_call_kwargs = {
        "foo": ["bar"],
    }

    class Safe(baseclass, metaclass=metaclass):
        pass

    extended_args = copy.deepcopy(user_call_args)
    extended_args[2].append(Safe())
    extended_args = (*extended_args, picklable_exception_safe_function(lambda: "foo"))

    extended_kwargs = copy.deepcopy(user_call_kwargs)
    extended_kwargs["foo"].append(picklable_exception_safe_function(lambda: "foo"))
    extended_kwargs["new"] = Safe()

    _validate_args(
        "autologging_integration_name",
        "function_name",
        user_call_args,
        user_call_kwargs,
        extended_args,
        extended_kwargs,
    )
1115  
1116  
@pytest.mark.usefixtures(test_mode_on.__name__)
def test_validate_args_throws_when_args_are_omitted():
    """
    `_validate_args` fails when the autologging call drops inputs the user supplied.
    Shortened positional tuples or lists are reported as "missing from the call". Dictionaries
    (keyword arguments or dict-valued arguments) that lack a key are reported as omitting
    expected keys.
    """
    user_call_args = (1, "b", ["c"], {"d": "e"})
    user_call_kwargs = {
        "foo": ["bar"],
        "biz": {"baz": 4, "fuzz": 5},
    }

    args_list_element_dropped = copy.deepcopy(user_call_args)
    args_list_element_dropped[2].pop()
    kwargs_list_element_dropped = copy.deepcopy(user_call_kwargs)
    kwargs_list_element_dropped["foo"].pop()

    args_first_arg_dropped = copy.deepcopy(user_call_args)[1:]
    kwargs_key_dropped = copy.deepcopy(user_call_kwargs)
    kwargs_key_dropped.pop("foo")

    args_nested_key_dropped = copy.deepcopy(user_call_args)
    args_nested_key_dropped[3].pop("d")
    kwargs_nested_key_dropped = copy.deepcopy(user_call_kwargs)
    kwargs_nested_key_dropped["biz"].pop("baz")

    cases = [
        (args_list_element_dropped, user_call_kwargs, "missing from the call"),
        (user_call_args, kwargs_list_element_dropped, "missing from the call"),
        (args_first_arg_dropped, user_call_kwargs, "missing from the call"),
        (user_call_args, kwargs_key_dropped, "omit one or more expected keys"),
        (args_nested_key_dropped, user_call_kwargs, "omit one or more expected keys"),
        (user_call_args, kwargs_nested_key_dropped, "omit one or more expected keys"),
    ]
    for autologging_args, autologging_kwargs, expected_error in cases:
        with pytest.raises(Exception, match=expected_error):
            _validate_args(
                "autologging_integration_name",
                "function_name",
                user_call_args,
                user_call_kwargs,
                autologging_args,
                autologging_kwargs,
            )
1198  
1199  
@pytest.mark.usefixtures(test_mode_on.__name__)
def test_validate_args_throws_when_arg_types_or_values_are_changed():
    """
    `_validate_args` fails when the autologging call alters the user's inputs. A changed value
    of the same type is reported as "does not match expected input". A value of a different
    type is reported as "does not match expected type".
    """
    user_call_args = (1, "b", ["c"])
    user_call_kwargs = {
        "foo": ["bar"],
    }

    # Same types as the user's inputs, different values: the first positional argument becomes 2
    # and the element of the "foo" list becomes "biz".
    invalid_autologging_call_args_1 = (2,) + copy.deepcopy(user_call_args)[1:]
    invalid_autologging_call_kwargs_1 = copy.deepcopy(user_call_kwargs)
    invalid_autologging_call_kwargs_1["foo"] = ["biz"]

    with pytest.raises(Exception, match="does not match expected input"):
        _validate_args(
            "autologging_integration_name",
            "function_name",
            user_call_args,
            user_call_kwargs,
            invalid_autologging_call_args_1,
            user_call_kwargs,
        )

    with pytest.raises(Exception, match="does not match expected input"):
        _validate_args(
            "autologging_integration_name",
            "function_name",
            user_call_args,
            user_call_kwargs,
            user_call_args,
            invalid_autologging_call_kwargs_1,
        )

    # Different types: replace only the first positional argument (int -> dict). The remaining
    # arguments stay in their original positions, so the type mismatch is isolated to a single
    # argument. Likewise, "foo" changes from a list to an int.
    _, call_arg_2, call_arg_3 = copy.deepcopy(user_call_args)
    invalid_autologging_call_args_2 = ({"7": 1}, call_arg_2, call_arg_3)
    invalid_autologging_call_kwargs_2 = copy.deepcopy(user_call_kwargs)
    invalid_autologging_call_kwargs_2["foo"] = 8

    with pytest.raises(Exception, match="does not match expected type"):
        _validate_args(
            "autologging_integration_name",
            "function_name",
            user_call_args,
            user_call_kwargs,
            invalid_autologging_call_args_2,
            user_call_kwargs,
        )

    with pytest.raises(Exception, match="does not match expected type"):
        _validate_args(
            "autologging_integration_name",
            "function_name",
            user_call_args,
            user_call_kwargs,
            user_call_args,
            invalid_autologging_call_kwargs_2,
        )
1256  
1257  
@pytest.mark.usefixtures(test_mode_on.__name__)
@pytest.mark.parametrize(
    ("expectation", "al_name", "func_name", "user_args", "user_kwargs", "al_args", "al_kwargs"),
    [
        # The ("foo", "fit") exemption patched in below covers positional index 1 and keyword "x"
        # for int values, so changing that value (3 -> 4) is tolerated.
        (
            does_not_raise(),
            "foo",
            "fit",
            (
                None,
                3,
            ),
            {},
            (
                None,
                4,
            ),
            {},
        ),
        (does_not_raise(), "foo", "fit", (), {"x": 3}, (), {"x": 4}),
        # The same change at a non-exempt position (index 2) or keyword ("y") is still rejected.
        (
            pytest.raises(AssertionError, match="does not match expected input"),
            "foo",
            "fit",
            (None, None, 3),
            {},
            (
                None,
                None,
                4,
            ),
            {},
        ),
        (
            pytest.raises(AssertionError, match="does not match expected input"),
            "foo",
            "fit",
            (),
            {"y": 3},
            (),
            {"y": 4},
        ),
        # Exemptions are scoped to their exact (integration, function) pair: neither
        # ("foo2", "fit") nor ("foo", "fit2") is exempt.
        (
            pytest.raises(AssertionError, match="does not match expected input"),
            "foo2",
            "fit",
            (),
            {"x": 3},
            (),
            {"x": 4},
        ),
        (
            pytest.raises(AssertionError, match="does not match expected input"),
            "foo",
            "fit2",
            (),
            {"x": 3},
            (),
            {"x": 4},
        ),
        # The exemption does not hold when the user's value is not an int, so the ordinary type
        # check fails. NOTE(review): presumably the predicate is applied to the user's value;
        # confirm in `_validate_args`.
        (
            pytest.raises(AssertionError, match="does not match expected type"),
            "foo",
            "fit",
            (),
            {"x": [1, 2]},
            (),
            {"x": 4},
        ),
        # No exemption exists for ("foo", "bar"), so normal validation applies. A changed type is
        # rejected, and a value supplied where the user passed None is treated as new input.
        (
            pytest.raises(AssertionError, match="does not match expected type"),
            "foo",
            "bar",
            (1,),
            {},
            (None,),
            {},
        ),
        (
            pytest.raises(AssertionError, match="Invalid new input"),
            "foo",
            "bar",
            (None,),
            {},
            (2,),
            {},
        ),
        # The ("ml", "flow") exemption covers positional index 0 and keyword "cool" for list
        # values, so even a shorter list is tolerated there...
        (does_not_raise(), "ml", "flow", ([1, 2, 3],), {}, ([2],), {}),
        (does_not_raise(), "ml", "flow", (), {"cool": [1, 2, 3]}, (), {"cool": [2]}),
        # ...but not when the user's value is not a list.
        (
            pytest.raises(AssertionError, match="does not match expected type"),
            "ml",
            "flow",
            (),
            {"cool": 3},
            (),
            {"cool": [2]},
        ),
    ],
)
def test_validate_args_respects_validation_exemptions(
    expectation, al_name, func_name, user_args, user_kwargs, al_args, al_kwargs
):
    """
    `_validate_args` skips validation for arguments covered by a `ValidationExemptArgument`
    registered for the same (integration, function) pair. All other arguments are validated
    normally.

    `expectation` is either `does_not_raise()` or a `pytest.raises(...)` context that the
    validation call must satisfy.
    """
    # Replace the module's exemption list with two test-only exemptions for the duration of
    # the call, so the cases above are independent of the real exemption list.
    with (
        mock.patch(
            "mlflow.utils.autologging_utils.safety._VALIDATION_EXEMPT_ARGUMENTS",
            [
                ValidationExemptArgument("foo", "fit", lambda x: isinstance(x, int), 1, "x"),
                ValidationExemptArgument("ml", "flow", lambda z: isinstance(z, list), 0, "cool"),
            ],
        ),
        expectation,
    ):
        _validate_args(al_name, func_name, user_args, user_kwargs, al_args, al_kwargs)
1372  
1373  
def test_validate_autologging_run_validates_autologging_tag_correctly():
    """
    `_validate_autologging_run` rejects runs whose autologging tag is missing or names a
    different integration. It accepts runs tagged with the expected integration name.
    """

    def create_finished_run(**start_run_kwargs):
        # Leaving the `with` block (via `return`) ends the run before its id is used.
        with mlflow.start_run(**start_run_kwargs):
            return mlflow.active_run().info.run_id

    untagged_run_id = create_finished_run()
    with pytest.raises(AssertionError, match="failed to set autologging tag with expected value"):
        _validate_autologging_run("test_integration", untagged_run_id)

    mistagged_run_id = create_finished_run(tags={MLFLOW_AUTOLOGGING: "wrong_value"})
    with pytest.raises(
        AssertionError, match="failed to set autologging tag with expected value.*wrong_value"
    ):
        _validate_autologging_run("test_integration", mistagged_run_id)

    tagged_run_id = create_finished_run(tags={MLFLOW_AUTOLOGGING: "test_integration"})
    _validate_autologging_run("test_integration", tagged_run_id)
1393  
1394  
def test_validate_autologging_run_validates_run_status_correctly():
    """
    `_validate_autologging_run` accepts correctly tagged autologging runs that have reached a
    terminal status (FINISHED or FAILED). It rejects runs still in a non-terminal status
    (RUNNING).
    """
    client = MlflowClient()
    valid_autologging_tags = {
        MLFLOW_AUTOLOGGING: "test_integration",
    }

    def status_of(run_id):
        return RunStatus.from_string(client.get_run(run_id).info.status)

    with mlflow.start_run(tags=valid_autologging_tags) as run_finished:
        run_id_finished = run_finished.info.run_id

    assert status_of(run_id_finished) == RunStatus.FINISHED
    _validate_autologging_run("test_integration", run_id_finished)

    with mlflow.start_run(tags=valid_autologging_tags) as run_failed:
        run_id_failed = run_failed.info.run_id

    client.set_terminated(run_id_failed, status=RunStatus.to_string(RunStatus.FAILED))
    assert status_of(run_id_failed) == RunStatus.FAILED
    # Validate the FAILED run itself. Re-validating the FINISHED run here would leave the
    # FAILED terminal status untested.
    _validate_autologging_run("test_integration", run_id_failed)

    # `create_run` does not end the run, so it stays in RUNNING status.
    run_non_terminal = client.create_run(
        experiment_id=run_finished.info.experiment_id, tags=valid_autologging_tags
    )
    run_id_non_terminal = run_non_terminal.info.run_id
    assert status_of(run_id_non_terminal) == RunStatus.RUNNING
    with pytest.raises(AssertionError, match="has a non-terminal status"):
        _validate_autologging_run("test_integration", run_id_non_terminal)
1428  
1429  
def test_session_manager_creates_session_before_patch_executes(
    patch_destination, test_autologging_integration
):
    """An autologging session is already active by the time the patch function body runs."""
    observed_session = None

    @asyncify(patch_destination.is_async)
    def record_active_session(original):
        nonlocal observed_session
        observed_session = _AutologgingSessionManager.active_session()

    safe_patch(test_autologging_integration, patch_destination, "fn", record_active_session)
    run_sync_or_async(patch_destination.fn)
    assert observed_session is not None
1443  
1444  
def test_session_manager_exits_session_after_patch_executes(
    patch_destination, test_autologging_integration
):
    """No autologging session remains active once a patched call has returned."""

    @asyncify(patch_destination.is_async)
    def assert_session_is_active(original):
        assert _AutologgingSessionManager.active_session() is not None

    safe_patch(test_autologging_integration, patch_destination, "fn", assert_session_is_active)
    run_sync_or_async(patch_destination.fn)
    assert _AutologgingSessionManager.active_session() is None
1455  
1456  
def test_session_manager_terminates_session_when_appropriate():
    """
    A nested `start_session` reuses the outer session. The session stays active until the
    outermost context exits, after which no session is active.
    """
    start_session = _AutologgingSessionManager.start_session
    active_session = _AutologgingSessionManager.active_session

    with start_session("test_integration") as outer_sess:
        assert outer_sess

        with start_session("test_integration") as inner_sess:
            assert active_session() == inner_sess == outer_sess

        assert active_session() == outer_sess

    assert not active_session()
1467  
1468  
def test_original_fn_runs_if_patch_should_not_be_applied(patch_destination):
    """
    With `exclusive=True` and a user-created active run, the patch function is skipped entirely.
    The original function still executes exactly once.
    """
    patch_calls = 0

    @autologging_integration("test_respects_exclusive")
    def autolog(disable=False, exclusive=False, silent=False):
        @asyncify(patch_destination.is_async)
        def counting_patch(original, *args, **kwargs):
            nonlocal patch_calls
            patch_calls += 1
            return original(*args, **kwargs)

        safe_patch("test_respects_exclusive", patch_destination, "fn", counting_patch)

    autolog(exclusive=True)
    with mlflow.start_run():
        run_sync_or_async(patch_destination.fn)

    assert patch_calls == 0
    assert patch_destination.fn_call_count == 1
1487  
1488  
def test_patch_runs_if_patch_should_be_applied():
    """
    The patch for `fn` is applied in three situations:
    (1) no run is active;
    (2) a run is active but the integration is not exclusive;
    (3) the integration is exclusive, but `fn` is reached through another patched method
        (`new_fn`), i.e. inside an active autologging session.
    """
    patch_calls = 0

    class TestPatchWithNewFnObj:
        def __init__(self):
            self.fn_call_count = 0

        def fn(self, *args, **kwargs):
            self.fn_call_count += 1
            return PATCH_DESTINATION_FN_DEFAULT_RESULT

        def new_fn(self, *args, **kwargs):
            with mlflow.start_run():
                self.fn()

    patch_obj = TestPatchWithNewFnObj()

    @autologging_integration("test_respects_exclusive")
    def autolog(disable=False, exclusive=False, silent=False):
        def counting_patch(original, *args, **kwargs):
            nonlocal patch_calls
            patch_calls += 1

        # NOTE(review): this patch never calls `original`, yet case (3) below still expects `fn`
        # to be reached via `new_fn`. Presumably `safe_patch` invokes the original itself when
        # the patch does not; confirm against `safe_patch`.
        def new_fn_patch(original, *args, **kwargs):
            pass

        safe_patch("test_respects_exclusive", patch_obj, "fn", counting_patch)
        safe_patch("test_respects_exclusive", patch_obj, "new_fn", new_fn_patch)

    # (1) No active run
    autolog()
    patch_obj.fn()
    assert patch_calls == 1

    # (2) Active run, non-exclusive integration
    autolog(exclusive=False)
    with mlflow.start_run():
        patch_obj.fn()
    assert patch_calls == 2

    # (3) Exclusive integration, but the run is started inside the patched `new_fn`
    autolog(exclusive=True)
    patch_obj.new_fn()
    assert patch_calls == 3
1533  
1534  
def test_nested_call_autologging_disabled_when_top_level_call_autologging_failed(patch_destination):
    """
    When the patch raises on the top-level (level 0) call, nested recursive calls are not
    patched. The patch body runs exactly once per top-level call, while `recursive_fn` still
    executes `max_depth + 1` times.
    """
    integration_name = (
        "test_nested_call_autologging_disabled_when_top_level_call_autologging_failed"
    )
    patch_impl_call_count = 0

    @autologging_integration(integration_name)
    def autolog(disable=False, exclusive=False, silent=False):
        @asyncify(patch_destination.is_async)
        def patch_impl(original, *args, **kwargs):
            nonlocal patch_impl_call_count
            patch_impl_call_count += 1
            # Simulate an autologging failure only on the outermost call.
            if kwargs["level"] == 0:
                raise RuntimeError("analog top level call autologging failure.")
            return original(*args, **kwargs)

        safe_patch(integration_name, patch_destination, "recursive_fn", patch_impl)

    autolog()
    for max_depth in (1, 2, 3):
        patch_impl_call_count = 0
        patch_destination.recurse_fn_call_count = 0
        with mlflow.start_run():
            run_sync_or_async(patch_destination.recursive_fn, level=0, max_depth=max_depth)
        assert patch_impl_call_count == 1
        assert patch_destination.recurse_fn_call_count == max_depth + 1
1569  
1570  
def test_old_patch_reverted_before_run_autolog_fn():
    """Every `autolog()` invocation sees the unpatched method: earlier patches are reverted first."""

    class PatchDestination:
        def f1(self):
            pass

    unpatched_f1 = PatchDestination.f1
    integration_name = "test_old_patch_reverted_before_run_autolog_fn"

    @autologging_integration(integration_name)
    def autolog(disable=False, exclusive=False, silent=False):
        # Any patch installed by a previous call must already be gone at this point.
        assert PatchDestination.f1 is unpatched_f1

        def noop_patch(original, *args, **kwargs):
            pass

        safe_patch(integration_name, PatchDestination, "f1", noop_patch)

    autolog(disable=True)
    autolog()
    # The second enabled call must find the first call's patch already reverted.
    autolog()
1595  
1596  
def test_safe_patch_support_property_decorated_method():
    """
    Verify that patching a method exposed through a ``@property`` (via
    ``mlflow.sklearn._patch_estimator_method_if_available``) works for both the defining class and
    a subclass. The test checks that:

    - the patch wraps each call exactly once;
    - positional and keyword arguments are forwarded;
    - an ``AttributeError`` raised by the property getter still propagates, so ``hasattr`` keeps
      returning ``False``;
    - disabling autologging restores the original attributes.
    """

    class BaseEstimator:
        def __init__(self, has_predict):
            self._has_predict = has_predict

        def _predict(self, X, a, b):
            return {"X": X, "a": a, "b": b}

        @property
        def predict(self):
            # Mimics estimators whose `predict` is only available under some configurations.
            if not self._has_predict:
                raise AttributeError("does not have predict")
            return self._predict

    class ExtendedEstimator(BaseEstimator):
        pass

    # Fetch the raw `property` object from the class dict without triggering the descriptor.
    original_base_estimator_predict = object.__getattribute__(BaseEstimator, "predict")

    def patched_predict(original, self, *args, **kwargs):
        result = original(self, *args, **kwargs)
        # Record how many times the patch wrapped this call; a value above 1 would mean the
        # patch was applied more than once (e.g. once for the base class and once for the
        # subclass).
        result["patch_count"] = result.get("patch_count", 0) + 1
        return result

    # Use the test's own name as the integration name, consistent with the other tests here.
    flavor_name = "test_safe_patch_support_property_decorated_method"

    @autologging_integration(flavor_name)
    def autolog(disable=False, exclusive=False, silent=False):
        mlflow.sklearn._patch_estimator_method_if_available(
            flavor_name,
            BaseEstimator,
            "predict",
            patched_predict,
            manage_run=False,
        )
        mlflow.sklearn._patch_estimator_method_if_available(
            flavor_name,
            ExtendedEstimator,
            "predict",
            patched_predict,
            manage_run=False,
        )

    autolog()

    for EstimatorCls in [BaseEstimator, ExtendedEstimator]:
        assert EstimatorCls.predict.__doc__ == original_base_estimator_predict.__doc__
        good_estimator = EstimatorCls(has_predict=True)
        assert good_estimator.predict.__doc__ == original_base_estimator_predict.__doc__

        expected_result = {"X": 1, "a": 2, "b": 3, "patch_count": 1}
        assert hasattr(good_estimator, "predict")
        assert good_estimator.predict(X=1, a=2, b=3) == expected_result
        assert good_estimator.predict(1, a=2, b=3) == expected_result
        assert good_estimator.predict(1, 2, b=3) == expected_result
        assert good_estimator.predict(1, 2, 3) == expected_result

        # The getter's AttributeError must survive patching so feature detection still works.
        bad_estimator = EstimatorCls(has_predict=False)
        assert not hasattr(bad_estimator, "predict")
        with pytest.raises(AttributeError, match="does not have predict"):
            bad_estimator.predict(X=1, a=2, b=3)

    autolog(disable=True)
    # Reverting must restore the exact original property on the base class and must not leave a
    # stray `predict` entry behind on the subclass.
    assert original_base_estimator_predict is object.__getattribute__(BaseEstimator, "predict")
    assert "predict" not in ExtendedEstimator.__dict__
1665  
1666  
def test_safe_patch_preserves_original_function_attributes():
    """
    Verify that the function installed by ``safe_patch`` reports the same attributes (as collected
    by ``get_func_attrs``) as the function it replaces.
    """
    integration_name = "test_safe_patch_preserves_original_function_attributes"

    class Test1:
        def predict(self, X, a, b):
            """
            Test doc for Test1.predict
            """

    # Snapshot the unpatched function before autologging replaces it on the class.
    original_predict = Test1.predict

    def _passthrough_predict(original, self, *args, **kwargs):
        # Delegate straight to the original; only the wrapper's metadata is under test.
        return original(self, *args, **kwargs)

    @autologging_integration(integration_name)
    def autolog(disable=False, exclusive=False, silent=False):
        safe_patch(
            integration_name,
            Test1,
            "predict",
            _passthrough_predict,
            manage_run=False,
        )

    autolog()

    patched_attrs = get_func_attrs(Test1.predict)
    expected_attrs = get_func_attrs(original_predict)
    assert patched_attrs == expected_attrs