safety.py
   1  import abc
   2  import functools
   3  import inspect
   4  import itertools
   5  import uuid
   6  from contextlib import asynccontextmanager, contextmanager
   7  from typing import Any, Callable, NamedTuple
   8  
   9  import mlflow
  10  import mlflow.utils.autologging_utils
  11  from mlflow.entities.run_status import RunStatus
  12  from mlflow.environment_variables import _MLFLOW_AUTOLOGGING_TESTING
  13  from mlflow.exceptions import MlflowException
  14  from mlflow.utils import gorilla, is_iterator
  15  from mlflow.utils.autologging_utils import _logger
  16  from mlflow.utils.autologging_utils.events import AutologgingEventLoggerWrapper
  17  from mlflow.utils.autologging_utils.logging_and_warnings import (
  18      MlflowEventsAndWarningsBehaviorGlobally,
  19      NonMlflowWarningsBehaviorForCurrentThread,
  20  )
  21  from mlflow.utils.mlflow_tags import MLFLOW_AUTOLOGGING
  22  
# Module-level registries, created empty at import time.
# NOTE(review): presumably these map autologging integrations to the patches applied for
# them and to their cleanup callbacks -- confirm against the code that populates them.
_AUTOLOGGING_PATCHES = {}
_AUTOLOGGING_CLEANUP_CALLBACKS = {}


# Name of the attribute that `exception_safe_function_for_class` and
# `picklable_exception_safe_function` set to `True` on the function they are given when
# running in test mode. It is used for testing purposes to verify that a given function
# has been made exception safe by one of these decorators.
_ATTRIBUTE_EXCEPTION_SAFE = "exception_safe"


# Template for the warning logged when patch code fails; formatted with the autologging
# integration name followed by the exception that was raised.
_ERROR_MSG = "Encountered unexpected error during {} autologging: {}"
  34  
  35  
  36  def exception_safe_function_for_class(function):
  37      """
  38      Wraps the specified function with broad exception handling to guard
  39      against unexpected errors during autologging.
  40      Note this function creates an unpicklable function as `safe_function` is locally defined,
  41      but a class instance containing methods decorated by this function should be pickalable,
  42      because pickle only saves instance attributes, not methods.
  43      See https://docs.python.org/3/library/pickle.html#pickling-class-instances for more details.
  44      """
  45      if is_testing():
  46          setattr(function, _ATTRIBUTE_EXCEPTION_SAFE, True)
  47  
  48      def safe_function(*args, **kwargs):
  49          try:
  50              return function(*args, **kwargs)
  51          except Exception as e:
  52              if is_testing():
  53                  raise
  54              else:
  55                  _logger.warning("Encountered unexpected error during autologging: %s", e)
  56  
  57      return update_wrapper_extended(safe_function, function)
  58  
  59  
  60  def _safe_function(function, *args, **kwargs):
  61      try:
  62          return function(*args, **kwargs)
  63      except Exception as e:
  64          if is_testing():
  65              raise
  66          else:
  67              _logger.warning("Encountered unexpected error during autologging: %s", e)
  68  
  69  
  70  def picklable_exception_safe_function(function):
  71      """
  72      Wraps the specified function with broad exception handling to guard
  73      against unexpected errors during autologging while preserving picklability.
  74      """
  75      if is_testing():
  76          setattr(function, _ATTRIBUTE_EXCEPTION_SAFE, True)
  77  
  78      return update_wrapper_extended(functools.partial(_safe_function, function), function)
  79  
  80  
  81  def _exception_safe_class_factory(base_class):
  82      """
  83      Creates an exception safe metaclass that inherits from `base_class`.
  84      """
  85  
  86      class _ExceptionSafeClass(base_class):
  87          """
  88          Metaclass that wraps all functions defined on the specified class with broad error handling
  89          logic to guard against unexpected errors during autlogging.
  90  
  91          Rationale: Patched autologging functions commonly pass additional class instances as
  92          arguments to their underlying original training routines; for example, Keras autologging
  93          constructs a subclass of `keras.callbacks.Callback` and forwards it to `Model.fit()`.
  94          To prevent errors encountered during method execution within such classes from disrupting
  95          model training, this metaclass wraps all class functions in a broad try / catch statement.
  96  
  97          Note: `ExceptionSafeClass` does not handle exceptions in class methods or static methods,
  98          as these are not always Python callables and are difficult to wrap
  99          """
 100  
 101          def __new__(cls, name, bases, dct):
 102              for m in dct:
 103                  # class methods or static methods are not callable.
 104                  if callable(dct[m]):
 105                      dct[m] = exception_safe_function_for_class(dct[m])
 106              return base_class.__new__(cls, name, bases, dct)
 107  
 108      return _ExceptionSafeClass
 109  
 110  
# Exception safe metaclass for classes whose metaclass is plain `type`.
ExceptionSafeClass = _exception_safe_class_factory(type)

# `ExceptionSafeClass` causes an error when used with an abstract class.
#
# ```
# class AbstractClass(abc.ABC):
#    ...
#
# class DerivedClass(AbstractClass, metaclass=ExceptionSafeClass):
#    ...
# ```
#
# This raises:
#
# ```
# TypeError: metaclass conflict: the metaclass of a derived class must be
#            a (non-strict) subclass of the metaclasses of all its bases.
# ```
#
# To avoid this error, `ExceptionSafeAbstractClass` is created from `abc.ABCMeta` (the
# metaclass of `abc.ABC`), so it can be used as the metaclass of classes derived from
# abstract base classes.
ExceptionSafeAbstractClass = _exception_safe_class_factory(abc.ABCMeta)
 132  
 133  
 134  def with_managed_run(autologging_integration, patch_function, tags=None):
 135      """Given a `patch_function`, returns an `augmented_patch_function` that wraps the execution of
 136      `patch_function` with an active MLflow run. The following properties apply:
 137  
 138          - An MLflow run is only created if there is no active run present when the
 139            patch function is executed
 140  
 141          - If an active run is created by the `augmented_patch_function`, it is terminated
 142            with the `FINISHED` state at the end of function execution
 143  
 144          - If an active run is created by the `augmented_patch_function`, it is terminated
 145            with the `FAILED` if an unhandled exception is thrown during function execution
 146  
 147      Note that, if nested runs or non-fluent runs are created by `patch_function`, `patch_function`
 148      is responsible for terminating them by the time it terminates
 149      (or in the event of an exception).
 150  
 151      Args:
 152          autologging_integration: The autologging integration associated
 153              with the `patch_function`.
 154          patch_function: A function object compatible with `safe_patch`.
 155          tags: A dictionary of string tags to set on each managed run created during the
 156              execution of `patch_function`.
 157      """
 158      from mlflow.tracking.fluent import active_run
 159      from mlflow.utils.autologging_utils import _has_active_training_session
 160  
 161      def create_managed_run():
 162          managed_run = mlflow.start_run(tags=tags)
 163          _logger.info(
 164              "Created MLflow autologging run with ID '%s', which will track hyperparameters,"
 165              " performance metrics, model artifacts, and lineage information for the"
 166              " current %s workflow",
 167              managed_run.info.run_id,
 168              autologging_integration,
 169          )
 170          return managed_run
 171  
 172      def patch_with_managed_run(original, *args, **kwargs):
 173          managed_run = None
 174          # If there is an active training session but there is no active run
 175          # in current thread, it means the thread is spawned by `estimator.fit`
 176          # as a worker thread, we should disable autologging in
 177          # these worker threads, so skip creating managed run.
 178          if not active_run() and not _has_active_training_session():
 179              managed_run = create_managed_run()
 180  
 181          try:
 182              result = patch_function(original, *args, **kwargs)
 183          except (Exception, KeyboardInterrupt):
 184              # In addition to standard Python exceptions, handle keyboard interrupts to ensure
 185              # that runs are terminated if a user prematurely interrupts training execution
 186              # (e.g. via sigint / ctrl-c)
 187              if managed_run:
 188                  mlflow.end_run(RunStatus.to_string(RunStatus.FAILED))
 189              raise
 190          else:
 191              if managed_run:
 192                  mlflow.end_run(RunStatus.to_string(RunStatus.FINISHED))
 193              return result
 194  
 195      return patch_with_managed_run
 196  
 197  
 198  def is_testing():
 199      """
 200      Indicates whether or not autologging functionality is running in test mode (as determined
 201      by the `MLFLOW_AUTOLOGGING_TESTING` environment variable). Test mode performs additional
 202      validation during autologging, including:
 203  
 204          - Checks for the exception safety of arguments passed to model training functions
 205            (i.e. all additional arguments should be "exception safe" functions or classes)
 206          - Disables exception handling for patched function logic, ensuring that patch code
 207            executes without errors during testing
 208      """
 209      return _MLFLOW_AUTOLOGGING_TESTING.get()
 210  
 211  
 212  def _resolve_extra_tags(autologging_integration, extra_tags):
 213      tags = {MLFLOW_AUTOLOGGING: autologging_integration}
 214      if extra_tags:
 215          if isinstance(extra_tags, dict):
 216              if MLFLOW_AUTOLOGGING in extra_tags:
 217                  extra_tags.pop(MLFLOW_AUTOLOGGING)
 218                  _logger.warning(
 219                      f"Tag `{MLFLOW_AUTOLOGGING}` is ignored as it is a reserved tag by MLflow "
 220                      f"autologging."
 221                  )
 222              tags.update(extra_tags)
 223          else:
 224              raise mlflow.exceptions.MlflowException.invalid_parameter_value(
 225                  f"Invalid `extra_tags` type: expecting dictionary, "
 226                  f"received `{type(extra_tags).__name__}`"
 227              )
 228      return tags
 229  
 230  
 231  def safe_patch(
 232      autologging_integration,
 233      destination,
 234      function_name,
 235      patch_function,
 236      manage_run=False,
 237      extra_tags=None,
 238  ):
 239      """Patches the specified `function_name` on the specified `destination` class for autologging
 240      purposes, preceding its implementation with an error-safe copy of the specified patch
 241      `patch_function` with the following error handling behavior:
 242          - Exceptions thrown from the underlying / original function
 243            (`<destination>.<function_name>`) are propagated to the caller.
 244          - Exceptions thrown from other parts of the patched implementation (`patch_function`)
 245            are caught and logged as warnings.
 246  
 247      Args:
 248          autologging_integration: The name of the autologging integration associated with the
 249              patch.
 250          destination: The Python class on which the patch is being defined.
 251          function_name: The name of the function to patch on the specified `destination` class.
 252          patch_function: The patched function code to apply. The first argument should be reserved
 253              for an `original` argument representing the underlying / original function. Subsequent
 254              arguments should be identical to those of the original function being patched.
 255          manage_run: If `True`, applies the `with_managed_run` wrapper to the specified
 256              `patch_function`, which automatically creates & terminates an MLflow
 257              active run during patch code execution if necessary. If `False`,
 258              does not apply the `with_managed_run` wrapper to the specified
 259              `patch_function`.
 260          extra_tags: A dictionary of extra tags to set on each managed run created by autologging.
 261      """
 262      from mlflow.tracking.fluent import active_run
 263      from mlflow.utils.autologging_utils import autologging_is_disabled, get_autologging_config
 264  
 265      # NB: Checking the signature of the patch function rather than original, so that we don't
 266      # accidentally change the behavior of existing patches that may use sync patch function
 267      # for async original functions (e.g. LangChain).
 268      is_async_function = inspect.iscoroutinefunction(patch_function)
 269  
 270      if manage_run:
 271          if is_async_function:
 272              raise MlflowException("manage_run parameter is not supported for async functions.")
 273  
 274          tags = _resolve_extra_tags(autologging_integration, extra_tags)
 275          patch_function = with_managed_run(
 276              autologging_integration,
 277              patch_function,
 278              tags=tags,
 279          )
 280  
 281      original_fn = gorilla.get_original_attribute(
 282          destination, function_name, bypass_descriptor_protocol=False
 283      )
 284      # Retrieve raw attribute while bypassing the descriptor protocol
 285      raw_original_obj = gorilla.get_original_attribute(
 286          destination, function_name, bypass_descriptor_protocol=True
 287      )
 288      if original_fn != raw_original_obj:
 289          raise RuntimeError(f"Unsupported patch on {destination}.{function_name}")
 290      elif isinstance(original_fn, property):
 291          if is_async_function:
 292              raise MlflowException("Patching async property methods is not supported.")
 293  
 294          is_property_method = True
 295  
 296          # For property decorated methods (a kind of method delegation), e.g.
 297          # class A:
 298          #   @property
 299          #   def f1(self):
 300          #     ...
 301          #     return delegated_f1
 302          #
 303          # suppose `a1` is an instance of class `A`,
 304          # `A.f1.fget` will get the original `def f1(self)` method,
 305          # and `A.f1.fget(a1)` will be equivalent to `a1.f1()` and
 306          # its return value will be the `delegated_f1` function.
 307          # So using the `property.fget` we can construct the (delegated) "original_fn"
 308          def original(self, *args, **kwargs):
 309              # the `original_fn.fget` will get the original method decorated by `property`
 310              # the `original_fn.fget(self)` will get the delegated function returned by the
 311              # property decorated method.
 312              bound_delegate_method = original_fn.fget(self)
 313              return bound_delegate_method(*args, **kwargs)
 314  
 315      else:
 316          original = original_fn
 317          is_property_method = False
 318  
 319      def safe_patch_function(*args, **kwargs):
 320          """
 321          A safe wrapper around the specified `patch_function` implementation designed to
 322          handle exceptions thrown during the execution of `patch_function`. This wrapper
 323          distinguishes exceptions thrown from the underlying / original function
 324          (`<destination>.<function_name>`) from exceptions thrown from other parts of
 325          `patch_function`. This distinction is made by passing an augmented version of the
 326          underlying / original function to `patch_function` that uses nonlocal state to track
 327          whether or not it has been executed and whether or not it threw an exception.
 328          Exceptions thrown from the underlying / original function are propagated to the caller,
 329          while exceptions thrown from other parts of `patch_function` are caught and logged as
 330          warnings.
 331  
 332          NB: PLEASE BE SUPER CAREFUL WHEN MODIFYING THIS FUNCTION. IT IS USED IN A WIDE VARIETY
 333          OF CONTEXTX AND CRITICAL PATH IN DBR/MLR BY DEFAULT. ANY BUG HERE CAN BREAK USERS'
 334          WORKLOAD WITHOUT THEM TAKING ANY ACTION.
 335          """
 336          # Reroute warnings encountered during the patch function implementation to an MLflow event
 337          # logger, and enforce silent mode if applicable (i.e. if the corresponding autologging
 338          # integration was called with `silent=True`), hiding MLflow event logging statements and
 339          # hiding all warnings in the autologging preamble and postamble (i.e. the code surrounding
 340          # the user's original / underlying ML function). Non-MLflow warnings are enabled during the
 341          # execution of the original / underlying ML function
 342          #
 343          # Note that we've opted *not* to apply this context manager as a decorator on
 344          # `safe_patch_function` because the context-manager-as-decorator pattern uses
 345          # `contextlib.ContextDecorator`, which creates generator expressions that cannot be pickled
 346          # during model serialization by ML frameworks such as scikit-learn
 347          is_silent_mode = get_autologging_config(autologging_integration, "silent", False)
 348          with (
 349              MlflowEventsAndWarningsBehaviorGlobally(
 350                  # MLflow warnings emitted during autologging training sessions are likely not
 351                  # actionable and result from the autologging implementation invoking another MLflow
 352                  # API. Accordingly, we reroute these warnings to the MLflow event logger with level
 353                  # WARNING For reference, see recommended warning and event logging behaviors from
 354                  # https://docs.python.org/3/howto/logging.html#when-to-use-logging
 355                  reroute_warnings=True,
 356                  disable_event_logs=is_silent_mode,
 357                  disable_warnings=is_silent_mode,
 358              ),
 359              NonMlflowWarningsBehaviorForCurrentThread(
 360                  # non-MLflow Warnings emitted during the autologging preamble (before the original /
 361                  # underlying ML function is called) and postamble (after the original / underlying
 362                  # ML function is called) are likely not actionable and result from the autologging
 363                  # implementation invoking an API from a dependent library. Accordingly, we reroute
 364                  # these warnings to the MLflow event logger with level WARNING. For reference, see
 365                  # recommended warning and event logging behaviors from
 366                  # https://docs.python.org/3/howto/logging.html#when-to-use-logging
 367                  reroute_warnings=True,
 368                  disable_warnings=is_silent_mode,
 369              ),
 370          ):
 371              if is_testing():
 372                  preexisting_run_for_testing = active_run()
 373  
 374              # Whether or not to exclude autologged content from user-created fluent runs
 375              # (i.e. runs created manually via `mlflow.start_run()`)
 376              exclusive = get_autologging_config(autologging_integration, "exclusive", False)
 377              user_created_fluent_run_is_active = (
 378                  active_run() and not _AutologgingSessionManager.active_session()
 379              )
 380              active_session_failed = (
 381                  _AutologgingSessionManager.active_session() is not None
 382                  and _AutologgingSessionManager.active_session().state == "failed"
 383              )
 384  
 385              if (
 386                  active_session_failed
 387                  or autologging_is_disabled(autologging_integration)
 388                  or (user_created_fluent_run_is_active and exclusive)
 389                  or (
 390                      mlflow.utils.autologging_utils._AUTOLOGGING_GLOBALLY_DISABLED
 391                      and autologging_integration
 392                  )
 393              ):
 394                  # If the autologging integration associated with this patch is disabled,
 395                  # or if the current autologging integration is in exclusive mode and a user-created
 396                  # fluent run is active, call the original function and return. Restore the original
 397                  # warning behavior during original function execution, since autologging is being
 398                  # skipped
 399                  with NonMlflowWarningsBehaviorForCurrentThread(
 400                      disable_warnings=False,
 401                      reroute_warnings=False,
 402                  ):
 403                      return original(*args, **kwargs)
 404  
 405              # Whether or not the original / underlying function has been called during the
 406              # execution of patched code
 407              original_has_been_called = False
 408              # The value returned by the call to the original / underlying function during
 409              # the execution of patched code
 410              original_result = None
 411              # Whether or not an exception was raised from within the original / underlying function
 412              # during the execution of patched code
 413              failed_during_original = False
 414              # The active MLflow run (if any) associated with patch code execution
 415              patch_function_run_for_testing = None
 416              # The exception raised during executing patching function
 417              patch_error = None
 418  
 419              with _AutologgingSessionManager.start_session(autologging_integration) as session:
 420                  event_logger = AutologgingEventLoggerWrapper(session, destination, function_name)
 421  
 422                  def call_original_fn_with_event_logging(original_fn, og_args, og_kwargs):
 423                      try:
 424                          event_logger.log_original_function_start(og_args, og_kwargs)
 425  
 426                          original_fn_result = original_fn(*og_args, **og_kwargs)
 427  
 428                          event_logger.log_original_function_success(og_args, og_kwargs)
 429                          return original_fn_result
 430                      except Exception as e:
 431                          event_logger.log_original_function_error(og_args, og_kwargs, e)
 432  
 433                          nonlocal failed_during_original
 434                          failed_during_original = True
 435                          raise
 436  
 437                  try:
 438  
 439                      def call_original(*og_args, **og_kwargs):
 440                          def _original_fn(*_og_args, **_og_kwargs):
 441                              if is_testing():
 442                                  _validate_args(
 443                                      autologging_integration,
 444                                      function_name,
 445                                      args,
 446                                      kwargs,
 447                                      og_args,
 448                                      og_kwargs,
 449                                  )
 450                                  # By the time `original` is called by the patch implementation, we
 451                                  # assume that either: 1. the patch implementation has already
 452                                  # created an MLflow run or 2. the patch code will not create an
 453                                  # MLflow run during the current execution. Here, we capture a
 454                                  # reference to the active run, which we will use later on to
 455                                  # determine whether or not the patch implementation created
 456                                  # a run and perform validation if necessary
 457                                  nonlocal patch_function_run_for_testing
 458                                  patch_function_run_for_testing = active_run()
 459  
 460                              nonlocal original_has_been_called
 461                              original_has_been_called = True
 462  
 463                              nonlocal original_result
 464                              # Show all non-MLflow warnings as normal (i.e. not as event logs)
 465                              # during original function execution, even if silent mode is enabled
 466                              # (`silent=True`), since these warnings originate from the ML framework
 467                              # or one of its dependencies and are likely relevant to the caller
 468                              with NonMlflowWarningsBehaviorForCurrentThread(
 469                                  disable_warnings=False,
 470                                  reroute_warnings=False,
 471                              ):
 472                                  original_result = original(*_og_args, **_og_kwargs)
 473                                  return original_result
 474  
 475                          return call_original_fn_with_event_logging(_original_fn, og_args, og_kwargs)
 476  
 477                      # Apply the name, docstring, and signature of `original` to `call_original`.
 478                      # This is important because several autologging patch implementations inspect
 479                      # the signature of the `original` argument during execution
 480                      call_original = update_wrapper_extended(call_original, original)
 481  
 482                      event_logger.log_patch_function_start(args, kwargs)
 483  
 484                      patch_function(call_original, *args, **kwargs)
 485  
 486                      session.state = "succeeded"
 487                      event_logger.log_patch_function_success(args, kwargs)
 488  
 489                  except Exception as e:
 490                      session.state = "failed"
 491                      patch_error = e
 492                      # Exceptions thrown during execution of the original function should be
 493                      # propagated to the caller. Additionally, exceptions encountered during test
 494                      # mode should be reraised to detect bugs in autologging implementations
 495                      if failed_during_original or is_testing():
 496                          raise
 497  
 498                  if is_testing() and not preexisting_run_for_testing:
 499                      # If an MLflow run was created during the execution of patch code, verify that
 500                      # it is no longer active and that it contains expected autologging tags
 501                      assert not active_run(), (
 502                          f"Autologging integration {autologging_integration} leaked an active run"
 503                      )
 504                      if patch_function_run_for_testing:
 505                          _validate_autologging_run(
 506                              autologging_integration, patch_function_run_for_testing.info.run_id
 507                          )
 508                  try:
 509                      if original_has_been_called:
 510                          return original_result
 511                      else:
 512                          return call_original_fn_with_event_logging(original, args, kwargs)
 513                  finally:
 514                      # If original function succeeds, but `patch_function_exception` exists,
 515                      # it represent patching code unexpected failure, so we call
 516                      # `log_patch_function_error` in this case.
 517                      # If original function failed, we don't call `log_patch_function_error`
 518                      # even if `patch_function_exception` exists, because original function failure
 519                      # means there's some error in user code (e.g. user provide wrong arguments)
 520                      if patch_error is not None and not failed_during_original:
 521                          event_logger.log_patch_function_error(args, kwargs, patch_error)
 522                          _logger.warning(_ERROR_MSG.format(autologging_integration, patch_error))
 523  
 524      async def async_safe_patch_function(*args, **kwargs):
 525          """
 526          Async version of safe_patch_function.
 527  
 528          This code brainlessly copies the synchronous version of the function, but with async
 529          context managers and async functions. This is done to avoid the risk of introducing
 530          any bugs or regressions in the async version of the function. Note that we need to
 531          be really careful here, because autologging is enabled by-default in DBR/MLR, hence
 532          any bug here can break users' workload without them taking any action.
 533  
 534          That said, some long comments are omitted in this version to avoid redundancy. If
 535          you want to understand the context of the code better, please refer to the
 536          synchronous version as well.
 537          """
 538          is_silent_mode = get_autologging_config(autologging_integration, "silent", False)
 539          async with (
 540              MlflowEventsAndWarningsBehaviorGlobally(
 541                  reroute_warnings=True,
 542                  disable_event_logs=is_silent_mode,
 543                  disable_warnings=is_silent_mode,
 544              ),
 545              NonMlflowWarningsBehaviorForCurrentThread(
 546                  disable_warnings=is_silent_mode,
 547                  reroute_warnings=True,
 548              ),
 549          ):
 550              if is_testing():
 551                  preexisting_run_for_testing = active_run()
 552  
 553              # Whether or not to exclude autologged content from user-created fluent runs
 554              # (i.e. runs created manually via `mlflow.start_run()`)
 555              exclusive = get_autologging_config(autologging_integration, "exclusive", False)
 556              user_created_fluent_run_is_active = (
 557                  active_run() and not _AutologgingSessionManager.active_session()
 558              )
 559              active_session_failed = (
 560                  _AutologgingSessionManager.active_session() is not None
 561                  and _AutologgingSessionManager.active_session().state == "failed"
 562              )
 563  
 564              if (
 565                  active_session_failed
 566                  or autologging_is_disabled(autologging_integration)
 567                  or (user_created_fluent_run_is_active and exclusive)
 568                  or (
 569                      mlflow.utils.autologging_utils._AUTOLOGGING_GLOBALLY_DISABLED
 570                      and autologging_integration
 571                  )
 572              ):
 573                  async with NonMlflowWarningsBehaviorForCurrentThread(False, False):
 574                      return await original(*args, **kwargs)
 575  
 576              original_has_been_called = False
 577              original_result = None
 578              failed_during_original = False
 579              patch_function_run_for_testing = None
 580              patch_error = None
 581  
 582              async with _AutologgingSessionManager.astart_session(
 583                  autologging_integration
 584              ) as session:
 585                  event_logger = AutologgingEventLoggerWrapper(session, destination, function_name)
 586  
 587                  async def call_original_fn_with_event_logging(original_fn, og_args, og_kwargs):
 588                      try:
 589                          event_logger.log_original_function_start(og_args, og_kwargs)
 590                          original_fn_result = await original_fn(*og_args, **og_kwargs)
 591                          event_logger.log_original_function_success(og_args, og_kwargs)
 592                          return original_fn_result
 593                      except Exception as e:
 594                          event_logger.log_original_function_error(og_args, og_kwargs, e)
 595                          nonlocal failed_during_original
 596                          failed_during_original = True
 597                          raise
 598  
 599                  try:
 600  
 601                      async def call_original(*og_args, **og_kwargs):
 602                          async def _original_fn(*_og_args, **_og_kwargs):
 603                              if is_testing():
 604                                  _validate_args(
 605                                      autologging_integration,
 606                                      function_name,
 607                                      args,
 608                                      kwargs,
 609                                      og_args,
 610                                      og_kwargs,
 611                                  )
 612                                  nonlocal patch_function_run_for_testing
 613                                  patch_function_run_for_testing = active_run()
 614  
 615                              nonlocal original_has_been_called
 616                              original_has_been_called = True
 617  
 618                              nonlocal original_result
 619                              async with NonMlflowWarningsBehaviorForCurrentThread(False, False):
 620                                  original_result = await original(*_og_args, **_og_kwargs)
 621                                  return original_result
 622  
 623                          return await call_original_fn_with_event_logging(
 624                              _original_fn, og_args, og_kwargs
 625                          )
 626  
 627                      # Apply the name, docstring, and signature of `original` to `call_original`.
 628                      # This is important because several autologging patch implementations inspect
 629                      # the signature of the `original` argument during execution
 630                      call_original = update_wrapper_extended(call_original, original)
 631  
 632                      event_logger.log_patch_function_start(args, kwargs)
 633  
 634                      await patch_function(call_original, *args, **kwargs)
 635  
 636                      session.state = "succeeded"
 637                      event_logger.log_patch_function_success(args, kwargs)
 638  
 639                  except Exception as e:
 640                      session.state = "failed"
 641                      patch_error = e
 642                      # Exceptions thrown during execution of the original function should be
 643                      # propagated to the caller. Additionally, exceptions encountered during test
 644                      # mode should be reraised to detect bugs in autologging implementations
 645                      if failed_during_original or is_testing():
 646                          raise
 647  
 648                  if is_testing() and not preexisting_run_for_testing:
 649                      # If an MLflow run was created during the execution of patch code, verify that
 650                      # it is no longer active and that it contains expected autologging tags
 651                      assert not active_run(), (
 652                          f"Autologging integration {autologging_integration} leaked an active run"
 653                      )
 654                      if patch_function_run_for_testing:
 655                          _validate_autologging_run(
 656                              autologging_integration, patch_function_run_for_testing.info.run_id
 657                          )
 658                  try:
 659                      if original_has_been_called:
 660                          return original_result
 661                      else:
 662                          return await call_original_fn_with_event_logging(original, args, kwargs)
 663                  finally:
 664                      if patch_error is not None and not failed_during_original:
 665                          event_logger.log_patch_function_error(args, kwargs, patch_error)
 666                          _logger.warning(_ERROR_MSG.format(autologging_integration, patch_error))
 667  
 668      if is_property_method:
 669          # Create a patched function (also property decorated)
 670          # like:
 671          #
 672          # class A:
 673          # @property
 674          # def get_bound_safe_patch_fn(self):
 675          #   original_fn.fget(self) # do availability check
 676          #   return bound_safe_patch_fn
 677          #
 678          # Suppose `a1` is instance of class A,
 679          # then `a1.get_bound_safe_patch_fn(*args, **kwargs)` will be equivalent to
 680          # `bound_safe_patch_fn(*args, **kwargs)`
 681          def get_bound_safe_patch_fn(self):
 682              # This `original_fn.fget` call is for availability check, if it raise error
 683              # then `hasattr(obj, {func_name})` will return False
 684              # so it mimic the original property behavior.
 685              original_fn.fget(self)
 686  
 687              def bound_safe_patch_fn(*args, **kwargs):
 688                  return safe_patch_function(self, *args, **kwargs)
 689  
 690              # Make bound method `instance.target_method` keep the same doc and signature.
 691              # Here return the bound safe patch function because user call property decorated
 692              # method will like `instance.property_decorated_method(...)`, and internally it will
 693              # call the `bound_safe_patch_fn`, the argument list don't include the `self` argument,
 694              # so return bound function here.
 695              return update_wrapper_extended(bound_safe_patch_fn, original_fn.fget)
 696  
 697          # Make unbound method `class.target_method` keep the same doc and signature
 698          get_bound_safe_patch_fn = update_wrapper_extended(get_bound_safe_patch_fn, original_fn.fget)
 699          safe_patch_obj = property(get_bound_safe_patch_fn)
 700      elif is_async_function:
 701          safe_patch_obj = update_wrapper_extended(async_safe_patch_function, original)
 702      else:
 703          safe_patch_obj = update_wrapper_extended(safe_patch_function, original)
 704  
 705      new_patch = _wrap_patch(destination, function_name, safe_patch_obj)
 706      _store_patch(autologging_integration, new_patch)
 707  
 708  
 709  def revert_patches(autologging_integration):
 710      """Reverts all patches on the specified destination class for autologging disablement purposes.
 711  
 712      Args:
 713          autologging_integration: The name of the autologging integration associated with the
 714              patch. Note: If called via fluent api (`autologging_integration="mlflow"`), then revert
 715              all patches for all active autologging integrations.
 716  
 717      """
 718      for patch in _AUTOLOGGING_PATCHES.get(autologging_integration, []):
 719          gorilla.revert(patch)
 720  
 721      _AUTOLOGGING_PATCHES.pop(autologging_integration, None)
 722  
 723      # Call any registered cleanup callbacks (e.g., for OTel uninstrumentation)
 724      for callback in _AUTOLOGGING_CLEANUP_CALLBACKS.get(autologging_integration, []):
 725          try:
 726              callback()
 727          except Exception as e:
 728              _logger.warning(f"Error calling cleanup callback for {autologging_integration}: {e}")
 729  
 730      _AUTOLOGGING_CLEANUP_CALLBACKS.pop(autologging_integration, None)
 731  
 732  
 733  # Represents an active autologging session using two fields:
 734  # - integration: the name of the autologging integration corresponding to the session
 735  # - id: a unique session identifier (e.g., a UUID)
 736  # - state: the state of AutologgingSession, will be one of running/succeeded/failed
 737  class AutologgingSession:
 738      def __init__(self, integration, id_):
 739          self.integration = integration
 740          self.id = id_
 741          self.state = "running"
 742  
 743  
 744  class _AutologgingSessionManager:
 745      _session = None
 746  
 747      @classmethod
 748      @contextmanager
 749      def start_session(cls, integration):
 750          try:
 751              prev_session = cls._session
 752              if prev_session is None:
 753                  session_id = uuid.uuid4().hex
 754                  cls._session = AutologgingSession(integration, session_id)
 755              yield cls._session
 756          finally:
 757              # Only end the session upon termination of the context if we created
 758              # the session; otherwise, leave the session open for later termination
 759              # by its creator
 760              if prev_session is None:
 761                  cls._end_session()
 762  
 763      @classmethod
 764      @asynccontextmanager
 765      async def astart_session(cls, integration):
 766          try:
 767              prev_session = cls._session
 768              if prev_session is None:
 769                  session_id = uuid.uuid4().hex
 770                  cls._session = AutologgingSession(integration, session_id)
 771              yield cls._session
 772          finally:
 773              if prev_session is None:
 774                  cls._end_session()
 775  
 776      @classmethod
 777      def active_session(cls):
 778          return cls._session
 779  
 780      @classmethod
 781      def _end_session(cls):
 782          cls._session = None
 783  
 784  
 785  def update_wrapper_extended(wrapper, wrapped):
 786      """Update a `wrapper` function to look like the `wrapped` function. This is an extension of
 787      `functools.update_wrapper` that applies the docstring *and* signature of `wrapped` to
 788      `wrapper`, producing a new function.
 789  
 790      Returns:
 791          A new function with the same implementation as `wrapper` and the same docstring
 792          & signature as `wrapped`.
 793      """
 794      updated_wrapper = functools.update_wrapper(wrapper, wrapped)
 795      # Assign the signature of the `wrapped` function to the updated wrapper function.
 796      # Certain frameworks may disallow signature inspection, causing `inspect.signature()` to throw.
 797      # One such example is the `tensorflow.estimator.Estimator.export_savedmodel()` function
 798      try:
 799          updated_wrapper.__signature__ = inspect.signature(wrapped)
 800      except Exception:
 801          _logger.debug("Failed to restore original signature for wrapper around %s", wrapped)
 802      return updated_wrapper
 803  
 804  
 805  def _wrap_patch(destination, name, patch_obj, settings=None):
 806      """Apply a patch.
 807  
 808      Args:
 809          destination: Patch destination.
 810          name: Name of the attribute at the destination.
 811          patch_obj: Patch object, it should be a function or a property decorated function
 812              to be assigned to the patch point {destination}.{name}.
 813          settings: Settings for gorilla.Patch.
 814  
 815      """
 816      if settings is None:
 817          settings = gorilla.Settings(allow_hit=True, store_hit=True)
 818  
 819      patch = gorilla.Patch(destination, name, patch_obj, settings=settings)
 820      gorilla.apply(patch)
 821      return patch
 822  
 823  
 824  def _store_patch(autologging_integration, patch):
 825      """
 826      Stores a patch for a specified autologging_integration class. Later to be used for being able
 827      to revert the patch when disabling autologging.
 828  
 829      Args:
 830          autologging_integration: The name of the autologging integration associated with the
 831              patch.
 832          patch: The patch to be stored.
 833      """
 834      if autologging_integration in _AUTOLOGGING_PATCHES:
 835          _AUTOLOGGING_PATCHES[autologging_integration].add(patch)
 836      else:
 837          _AUTOLOGGING_PATCHES[autologging_integration] = {patch}
 838  
 839  
 840  def _validate_autologging_run(autologging_integration, run_id):
 841      """
 842      For testing purposes, verifies that an MLflow run produced by an `autologging_integration`
 843      satisfies the following properties:
 844  
 845          - The run has an autologging tag whose value is the name of the autologging integration
 846          - The run has a terminal status (e.g., KILLED, FAILED, FINISHED)
 847      """
 848      from mlflow.tracking.client import MlflowClient
 849  
 850      client = MlflowClient()
 851      run = client.get_run(run_id)
 852      autologging_tag_value = run.data.tags.get(MLFLOW_AUTOLOGGING)
 853      assert autologging_tag_value == autologging_integration, (
 854          f"Autologging run with id {run_id} failed to set autologging tag with expected value. "
 855          f"Expected: '{autologging_integration}', Actual: '{autologging_tag_value}'"
 856      )
 857      assert RunStatus.is_terminated(RunStatus.from_string(run.info.status)), (
 858          f"Autologging run with id {run_id} has a non-terminal status '{run.info.status}'"
 859      )
 860  
 861  
 862  class ValidationExemptArgument(NamedTuple):
 863      """
 864      A NamedTuple representing the properties of an argument that is exempt from validation
 865  
 866      autologging_integration: The name of the autologging integration.
 867      function_name: The name of the function that is being validated.
 868      type_function: A Callable that accepts an object and returns True if the given object matches
 869                     the argument type. Returns False otherwise.
 870      positional_argument_index: The index of the argument in the function signature.
 871      keyword_argument_name: The name of the argument in the function signature.
 872      """
 873  
 874      autologging_integration: str
 875      function_name: str
 876      type_function: Callable[..., Any]
 877      positional_argument_index: int | None = None
 878      keyword_argument_name: str | None = None
 879  
 880      def matches(
 881          self,
 882          autologging_integration,
 883          function_name,
 884          value,
 885          argument_index=None,
 886          argument_name=None,
 887      ):
 888          """
 889          This method checks if the properties provided through the function arguments matches the
 890          properties defined in the NamedTuple.
 891  
 892          Args:
 893              autologging_integration: The name of an autologging integration.
 894              function_name: The name of the function that is being matched.
 895              value: The value of the argument.
 896              argument_index: The index of the argument, if it is passed as a positional
 897                  argument. Otherwise it is None.
 898              argument_name: The name of the argument, if it is passed as a keyword
 899                  argument. Otherwise it is None.
 900  
 901          Returns:
 902              Returns True if the given function properties matches the exempt argument's
 903              properties. Returns False otherwise.
 904          """
 905          return (
 906              self.autologging_integration == autologging_integration
 907              and self.function_name == function_name
 908              and (
 909                  self.positional_argument_index == argument_index
 910                  or self.keyword_argument_name == argument_name
 911              )
 912              and self.type_function(value)
 913          )
 914  
 915  
 916  # WARNING: Exemptions should NOT be introduced unless absolutely necessary. If deemed necessary,
 917  #          clear reasons must be provided as comment in addition to thorough integration tests.
 918  _VALIDATION_EXEMPT_ARGUMENTS = [
 919      # When extracting implicitly defined `batch_size` in the case that `x` is a generator or a
 920      # generator class, we need to consume and restore the first element back to the generator to
 921      # calculate the `batch_size`. This means that:
 922      # 1. The type of `x` will become 'generator' regardless if user provided `x` as a generator or a
 923      #    custom generator class.
 924      # 2. The instance of `x` will be different, since we reconstructed the generator after consuming
 925      #    the first element.
 926      ValidationExemptArgument("tensorflow", "fit", is_iterator, 1, "x"),
 927      ValidationExemptArgument("keras", "fit", is_iterator, 1, "x"),
 928      ValidationExemptArgument("dspy", "__call__", lambda x: isinstance(x, Callable), 2, "metric"),
 929      # Autologging injects tracing context headers as `extra_headers` to enable distributed
 930      # tracing between client spans and gateway spans. The user may or may not have passed
 931      # `extra_headers` originally, so the argument value will differ from the user's input.
 932      ValidationExemptArgument(
 933          "openai", "create", lambda x: isinstance(x, (dict, type(None))), None, "extra_headers"
 934      ),
 935      ValidationExemptArgument(
 936          "openai", "parse", lambda x: isinstance(x, (dict, type(None))), None, "extra_headers"
 937      ),
 938      ValidationExemptArgument(
 939          "anthropic", "create", lambda x: isinstance(x, (dict, type(None))), None, "extra_headers"
 940      ),
 941      # Gemini header injection goes through config.http_options.headers. Config can be
 942      # None, a dict, or a Pydantic-style object with an http_options attribute.
 943      ValidationExemptArgument(
 944          "gemini",
 945          "_generate_content",
 946          lambda x: x is None or isinstance(x, dict) or hasattr(x, "http_options"),
 947          None,
 948          "config",
 949      ),
 950      ValidationExemptArgument(
 951          "gemini",
 952          "send_message",
 953          lambda x: x is None or isinstance(x, dict) or hasattr(x, "http_options"),
 954          None,
 955          "config",
 956      ),
 957      ValidationExemptArgument(
 958          "gemini",
 959          "count_tokens",
 960          lambda x: x is None or isinstance(x, dict) or hasattr(x, "http_options"),
 961          None,
 962          "config",
 963      ),
 964      ValidationExemptArgument(
 965          "gemini",
 966          "embed_content",
 967          lambda x: x is None or isinstance(x, dict) or hasattr(x, "http_options"),
 968          None,
 969          "config",
 970      ),
 971  ]
 972  
 973  
 974  def _is_arg_exempt_from_validation(
 975      autologging_integration,
 976      function_name,
 977      argument,
 978      argument_index=None,
 979      argument_name=None,
 980  ):
 981      """This function is responsible for determining whether or not an argument is exempt from
 982      autolog safety validations. This includes both type checking and immutable checking.
 983  
 984      Args:
 985          autologging_integration: The name of the autologging integration.
 986          function_name: The name of the function that is being validated.
 987          argument: The actual argument.
 988          argument_index: The index of the argument, if it is passed as a positional
 989              argument. Otherwise it is None.
 990          argument_name: The name of the argument, if it is passed as a keyword argument.
 991              Otherwise it is None.
 992  
 993      Returns:
 994          True or False
 995      """
 996      return any(
 997          exemption.matches(
 998              autologging_integration,
 999              function_name,
1000              argument,
1001              argument_index,
1002              argument_name,
1003          )
1004          for exemption in _VALIDATION_EXEMPT_ARGUMENTS
1005      )
1006  
1007  
1008  def _validate_args(
1009      autologging_integration,
1010      function_name,
1011      user_call_args,
1012      user_call_kwargs,
1013      autologging_call_args,
1014      autologging_call_kwargs,
1015  ):
1016      """
1017      Used for testing purposes to verify that, when a patched ML function calls its underlying
1018      / original ML function, the following properties are satisfied:
1019  
1020          - All arguments supplied to the patched ML function are forwarded to the
1021            original ML function
1022          - Any additional arguments supplied to the original function are exception safe (i.e.
1023            they are either functions decorated with the `@exception_safe_function_for_class` or
1024            `@pickalable_exception_safe_function` decorators, or classes / instances of classes with
1025            type `ExceptionSafeClass`
1026      """
1027  
1028      def _validate_new_input(inp):
1029          """
1030          Validates a new input (arg or kwarg) introduced to the underlying / original ML function
1031          call during the execution of a patched ML function. The new input is valid if:
1032  
1033              - The new input is a function that has been decorated with
1034                `exception_safe_function_for_class` or `pickalable_exception_safe_function`
1035              - OR the new input is a class with the `ExceptionSafeClass` metaclass
1036              - OR the new input is a list and each of its elements is valid according to the
1037                these criteria
1038          """
1039          if type(inp) == list:
1040              for item in inp:
1041                  _validate_new_input(item)
1042          elif isinstance(inp, dict) and "callbacks" in inp:
1043              _validate_new_input(inp["callbacks"])
1044          elif callable(inp):
1045              assert getattr(inp, _ATTRIBUTE_EXCEPTION_SAFE, False), (
1046                  f"New function argument '{inp}' passed to original function is not exception-safe."
1047                  " Please decorate the function with `exception_safe_function` or "
1048                  "`pickalable_exception_safe_function`"
1049              )
1050          else:
1051              assert hasattr(inp, "__class__") and type(inp.__class__) in [
1052                  ExceptionSafeClass,
1053                  ExceptionSafeAbstractClass,
1054              ], (
1055                  f"Invalid new input '{inp}'. New args / kwargs introduced to `original` function "
1056                  "calls by patched code must either be functions decorated with "
1057                  "`exception_safe_function_for_class`, instances of classes with the "
1058                  "`ExceptionSafeClass` or `ExceptionSafeAbstractClass` metaclass safe or lists of "
1059                  "such exception safe functions / classes."
1060              )
1061  
1062      def _assert_autologging_input_positional_args_are_superset(
1063          autologging_call_input, user_call_input
1064      ):
1065          length_diff = len(autologging_call_input) - len(user_call_input)
1066          assert length_diff >= 0, (
1067              f"{length_diff} expected inputs are missing from the call to the original function."
1068          )
1069  
1070      def _assert_autologging_input_kwargs_are_superset(autologging_call_input, user_call_input):
1071          assert set(user_call_input.keys()).issubset(set(autologging_call_input.keys())), (
1072              "Keyword or dictionary arguments to original function omit"
1073              " one or more expected keys: '{}'".format(
1074                  set(user_call_input.keys()) - set(autologging_call_input.keys())
1075              )
1076          )
1077  
1078      def _validate(autologging_call_input, user_call_input=None):
1079          """
1080          Validates that the specified `autologging_call_input` and `user_call_input`
1081          are compatible. If `user_call_input` is `None`, then `autologging_call_input`
1082          is regarded as a new input added by autologging and is validated using
1083          `_validate_new_input`. Otherwise, the following properties must hold:
1084  
1085              - `autologging_call_input` and `user_call_input` must have the same type
1086                (referred to as "input type")
1087              - if the input type is a tuple, list or dictionary, then `autologging_call_input` must
1088                be equivalent to `user_call_input` or be a superset of `user_call_input`
1089              - for all other input types, `autologging_call_input` and `user_call_input`
1090                must be equivalent by reference equality or by object equality
1091  
1092          Args:
1093              autologging_call_input: call input from autologging.
1094              user_call_input: call input from user.
1095          """
1096  
1097          if user_call_input is None and autologging_call_input is not None:
1098              _validate_new_input(autologging_call_input)
1099              return
1100  
1101          assert type(autologging_call_input) == type(user_call_input), (
1102              "Type of input to original function '{}' does not match expected type '{}'".format(
1103                  type(autologging_call_input), type(user_call_input)
1104              )
1105          )
1106  
1107          if type(autologging_call_input) in [list, tuple]:
1108              _assert_autologging_input_positional_args_are_superset(
1109                  autologging_call_input, user_call_input
1110              )
1111              # If the autologging call input is longer than the user call input, we `zip_longest`
1112              # will pad the user call input with `None` values to ensure that the subsequent calls
1113              # to `_validate` identify new inputs added by the autologging call
1114              for a, u in itertools.zip_longest(autologging_call_input, user_call_input):
1115                  _validate(a, u)
1116          elif type(autologging_call_input) == dict:
1117              _assert_autologging_input_kwargs_are_superset(autologging_call_input, user_call_input)
1118              for key in autologging_call_input.keys():
1119                  _validate(autologging_call_input[key], user_call_input.get(key, None))
1120  
1121          else:
1122              assert (
1123                  autologging_call_input is user_call_input
1124                  or autologging_call_input == user_call_input
1125              ), (
1126                  "Input to original function does not match expected input."
1127                  f" Original: '{autologging_call_input}'. Expected: '{user_call_input}'"
1128              )
1129  
1130      # Similar validation logic found in _validate, unraveling the list of arguments to exclude
1131      # checks for any validation exempt positional arguments.
1132      _assert_autologging_input_positional_args_are_superset(autologging_call_args, user_call_args)
1133      for index, autologging_call_arg, user_call_arg in itertools.zip_longest(
1134          range(len(user_call_args)), autologging_call_args, user_call_args
1135      ):
1136          if not _is_arg_exempt_from_validation(
1137              autologging_integration,
1138              function_name,
1139              user_call_arg,
1140              argument_index=index,
1141          ):
1142              _validate(autologging_call_arg, user_call_arg)
1143  
1144      # Similar validation logic found in _validate, unraveling the dictionary of arguments to exclude
1145      # checks for any validation exempt keyword arguments.
1146      _assert_autologging_input_kwargs_are_superset(autologging_call_kwargs, user_call_kwargs)
1147      for key in autologging_call_kwargs.keys():
1148          if not _is_arg_exempt_from_validation(
1149              autologging_integration,
1150              function_name,
1151              user_call_kwargs.get(key, None),
1152              argument_name=key,
1153          ):
1154              _validate(
1155                  autologging_call_kwargs[key],
1156                  user_call_kwargs.get(key, None),
1157              )