safety.py
1 import abc 2 import functools 3 import inspect 4 import itertools 5 import uuid 6 from contextlib import asynccontextmanager, contextmanager 7 from typing import Any, Callable, NamedTuple 8 9 import mlflow 10 import mlflow.utils.autologging_utils 11 from mlflow.entities.run_status import RunStatus 12 from mlflow.environment_variables import _MLFLOW_AUTOLOGGING_TESTING 13 from mlflow.exceptions import MlflowException 14 from mlflow.utils import gorilla, is_iterator 15 from mlflow.utils.autologging_utils import _logger 16 from mlflow.utils.autologging_utils.events import AutologgingEventLoggerWrapper 17 from mlflow.utils.autologging_utils.logging_and_warnings import ( 18 MlflowEventsAndWarningsBehaviorGlobally, 19 NonMlflowWarningsBehaviorForCurrentThread, 20 ) 21 from mlflow.utils.mlflow_tags import MLFLOW_AUTOLOGGING 22 23 _AUTOLOGGING_PATCHES = {} 24 _AUTOLOGGING_CLEANUP_CALLBACKS = {} 25 26 27 # Function attribute used for testing purposes to verify that a given function 28 # has been wrapped with the `exception_safe_function_for_class` and 29 # `picklable_exception_safe_function` decorators 30 _ATTRIBUTE_EXCEPTION_SAFE = "exception_safe" 31 32 33 _ERROR_MSG = "Encountered unexpected error during {} autologging: {}" 34 35 36 def exception_safe_function_for_class(function): 37 """ 38 Wraps the specified function with broad exception handling to guard 39 against unexpected errors during autologging. 40 Note this function creates an unpicklable function as `safe_function` is locally defined, 41 but a class instance containing methods decorated by this function should be pickalable, 42 because pickle only saves instance attributes, not methods. 43 See https://docs.python.org/3/library/pickle.html#pickling-class-instances for more details. 44 """ 45 if is_testing(): 46 setattr(function, _ATTRIBUTE_EXCEPTION_SAFE, True) 47 48 def safe_function(*args, **kwargs): 49 try: 50 return function(*args, **kwargs) 51 except Exception as e: 52 if is_testing(): 53 raise 54 else: 55 _logger.warning("Encountered unexpected error during autologging: %s", e) 56 57 return update_wrapper_extended(safe_function, function) 58 59 60 def _safe_function(function, *args, **kwargs): 61 try: 62 return function(*args, **kwargs) 63 except Exception as e: 64 if is_testing(): 65 raise 66 else: 67 _logger.warning("Encountered unexpected error during autologging: %s", e) 68 69 70 def picklable_exception_safe_function(function): 71 """ 72 Wraps the specified function with broad exception handling to guard 73 against unexpected errors during autologging while preserving picklability. 74 """ 75 if is_testing(): 76 setattr(function, _ATTRIBUTE_EXCEPTION_SAFE, True) 77 78 return update_wrapper_extended(functools.partial(_safe_function, function), function) 79 80 81 def _exception_safe_class_factory(base_class): 82 """ 83 Creates an exception safe metaclass that inherits from `base_class`. 84 """ 85 86 class _ExceptionSafeClass(base_class): 87 """ 88 Metaclass that wraps all functions defined on the specified class with broad error handling 89 logic to guard against unexpected errors during autlogging. 90 91 Rationale: Patched autologging functions commonly pass additional class instances as 92 arguments to their underlying original training routines; for example, Keras autologging 93 constructs a subclass of `keras.callbacks.Callback` and forwards it to `Model.fit()`. 94 To prevent errors encountered during method execution within such classes from disrupting 95 model training, this metaclass wraps all class functions in a broad try / catch statement. 96 97 Note: `ExceptionSafeClass` does not handle exceptions in class methods or static methods, 98 as these are not always Python callables and are difficult to wrap 99 """ 100 101 def __new__(cls, name, bases, dct): 102 for m in dct: 103 # class methods or static methods are not callable. 104 if callable(dct[m]): 105 dct[m] = exception_safe_function_for_class(dct[m]) 106 return base_class.__new__(cls, name, bases, dct) 107 108 return _ExceptionSafeClass 109 110 111 ExceptionSafeClass = _exception_safe_class_factory(type) 112 113 # `ExceptionSafeClass` causes an error when used with an abstract class. 114 # 115 # ``` 116 # class AbstractClass(abc.ABC): 117 # ... 118 # 119 # class DerivedClass(AbstractClass, metaclass=ExceptionSafeClass): 120 # ... 121 # ``` 122 # 123 # This raises: 124 # 125 # ``` 126 # TypeError: metaclass conflict: the metaclass of a derived class must be 127 # a (non-strict) subclass of the metaclasses of all its bases. 128 # ``` 129 # 130 # To avoid this error, create `ExceptionSafeAbstractClass` that is based on `abc.ABCMeta`. 131 ExceptionSafeAbstractClass = _exception_safe_class_factory(abc.ABCMeta) 132 133 134 def with_managed_run(autologging_integration, patch_function, tags=None): 135 """Given a `patch_function`, returns an `augmented_patch_function` that wraps the execution of 136 `patch_function` with an active MLflow run. The following properties apply: 137 138 - An MLflow run is only created if there is no active run present when the 139 patch function is executed 140 141 - If an active run is created by the `augmented_patch_function`, it is terminated 142 with the `FINISHED` state at the end of function execution 143 144 - If an active run is created by the `augmented_patch_function`, it is terminated 145 with the `FAILED` if an unhandled exception is thrown during function execution 146 147 Note that, if nested runs or non-fluent runs are created by `patch_function`, `patch_function` 148 is responsible for terminating them by the time it terminates 149 (or in the event of an exception). 150 151 Args: 152 autologging_integration: The autologging integration associated 153 with the `patch_function`. 154 patch_function: A function object compatible with `safe_patch`. 155 tags: A dictionary of string tags to set on each managed run created during the 156 execution of `patch_function`. 157 """ 158 from mlflow.tracking.fluent import active_run 159 from mlflow.utils.autologging_utils import _has_active_training_session 160 161 def create_managed_run(): 162 managed_run = mlflow.start_run(tags=tags) 163 _logger.info( 164 "Created MLflow autologging run with ID '%s', which will track hyperparameters," 165 " performance metrics, model artifacts, and lineage information for the" 166 " current %s workflow", 167 managed_run.info.run_id, 168 autologging_integration, 169 ) 170 return managed_run 171 172 def patch_with_managed_run(original, *args, **kwargs): 173 managed_run = None 174 # If there is an active training session but there is no active run 175 # in current thread, it means the thread is spawned by `estimator.fit` 176 # as a worker thread, we should disable autologging in 177 # these worker threads, so skip creating managed run. 178 if not active_run() and not _has_active_training_session(): 179 managed_run = create_managed_run() 180 181 try: 182 result = patch_function(original, *args, **kwargs) 183 except (Exception, KeyboardInterrupt): 184 # In addition to standard Python exceptions, handle keyboard interrupts to ensure 185 # that runs are terminated if a user prematurely interrupts training execution 186 # (e.g. via sigint / ctrl-c) 187 if managed_run: 188 mlflow.end_run(RunStatus.to_string(RunStatus.FAILED)) 189 raise 190 else: 191 if managed_run: 192 mlflow.end_run(RunStatus.to_string(RunStatus.FINISHED)) 193 return result 194 195 return patch_with_managed_run 196 197 198 def is_testing(): 199 """ 200 Indicates whether or not autologging functionality is running in test mode (as determined 201 by the `MLFLOW_AUTOLOGGING_TESTING` environment variable). Test mode performs additional 202 validation during autologging, including: 203 204 - Checks for the exception safety of arguments passed to model training functions 205 (i.e. all additional arguments should be "exception safe" functions or classes) 206 - Disables exception handling for patched function logic, ensuring that patch code 207 executes without errors during testing 208 """ 209 return _MLFLOW_AUTOLOGGING_TESTING.get() 210 211 212 def _resolve_extra_tags(autologging_integration, extra_tags): 213 tags = {MLFLOW_AUTOLOGGING: autologging_integration} 214 if extra_tags: 215 if isinstance(extra_tags, dict): 216 if MLFLOW_AUTOLOGGING in extra_tags: 217 extra_tags.pop(MLFLOW_AUTOLOGGING) 218 _logger.warning( 219 f"Tag `{MLFLOW_AUTOLOGGING}` is ignored as it is a reserved tag by MLflow " 220 f"autologging." 221 ) 222 tags.update(extra_tags) 223 else: 224 raise mlflow.exceptions.MlflowException.invalid_parameter_value( 225 f"Invalid `extra_tags` type: expecting dictionary, " 226 f"received `{type(extra_tags).__name__}`" 227 ) 228 return tags 229 230 231 def safe_patch( 232 autologging_integration, 233 destination, 234 function_name, 235 patch_function, 236 manage_run=False, 237 extra_tags=None, 238 ): 239 """Patches the specified `function_name` on the specified `destination` class for autologging 240 purposes, preceding its implementation with an error-safe copy of the specified patch 241 `patch_function` with the following error handling behavior: 242 - Exceptions thrown from the underlying / original function 243 (`<destination>.<function_name>`) are propagated to the caller. 244 - Exceptions thrown from other parts of the patched implementation (`patch_function`) 245 are caught and logged as warnings. 246 247 Args: 248 autologging_integration: The name of the autologging integration associated with the 249 patch. 250 destination: The Python class on which the patch is being defined. 251 function_name: The name of the function to patch on the specified `destination` class. 252 patch_function: The patched function code to apply. The first argument should be reserved 253 for an `original` argument representing the underlying / original function. Subsequent 254 arguments should be identical to those of the original function being patched. 255 manage_run: If `True`, applies the `with_managed_run` wrapper to the specified 256 `patch_function`, which automatically creates & terminates an MLflow 257 active run during patch code execution if necessary. If `False`, 258 does not apply the `with_managed_run` wrapper to the specified 259 `patch_function`. 260 extra_tags: A dictionary of extra tags to set on each managed run created by autologging. 261 """ 262 from mlflow.tracking.fluent import active_run 263 from mlflow.utils.autologging_utils import autologging_is_disabled, get_autologging_config 264 265 # NB: Checking the signature of the patch function rather than original, so that we don't 266 # accidentally change the behavior of existing patches that may use sync patch function 267 # for async original functions (e.g. LangChain). 268 is_async_function = inspect.iscoroutinefunction(patch_function) 269 270 if manage_run: 271 if is_async_function: 272 raise MlflowException("manage_run parameter is not supported for async functions.") 273 274 tags = _resolve_extra_tags(autologging_integration, extra_tags) 275 patch_function = with_managed_run( 276 autologging_integration, 277 patch_function, 278 tags=tags, 279 ) 280 281 original_fn = gorilla.get_original_attribute( 282 destination, function_name, bypass_descriptor_protocol=False 283 ) 284 # Retrieve raw attribute while bypassing the descriptor protocol 285 raw_original_obj = gorilla.get_original_attribute( 286 destination, function_name, bypass_descriptor_protocol=True 287 ) 288 if original_fn != raw_original_obj: 289 raise RuntimeError(f"Unsupported patch on {destination}.{function_name}") 290 elif isinstance(original_fn, property): 291 if is_async_function: 292 raise MlflowException("Patching async property methods is not supported.") 293 294 is_property_method = True 295 296 # For property decorated methods (a kind of method delegation), e.g. 297 # class A: 298 # @property 299 # def f1(self): 300 # ... 301 # return delegated_f1 302 # 303 # suppose `a1` is an instance of class `A`, 304 # `A.f1.fget` will get the original `def f1(self)` method, 305 # and `A.f1.fget(a1)` will be equivalent to `a1.f1()` and 306 # its return value will be the `delegated_f1` function. 307 # So using the `property.fget` we can construct the (delegated) "original_fn" 308 def original(self, *args, **kwargs): 309 # the `original_fn.fget` will get the original method decorated by `property` 310 # the `original_fn.fget(self)` will get the delegated function returned by the 311 # property decorated method. 312 bound_delegate_method = original_fn.fget(self) 313 return bound_delegate_method(*args, **kwargs) 314 315 else: 316 original = original_fn 317 is_property_method = False 318 319 def safe_patch_function(*args, **kwargs): 320 """ 321 A safe wrapper around the specified `patch_function` implementation designed to 322 handle exceptions thrown during the execution of `patch_function`. This wrapper 323 distinguishes exceptions thrown from the underlying / original function 324 (`<destination>.<function_name>`) from exceptions thrown from other parts of 325 `patch_function`. This distinction is made by passing an augmented version of the 326 underlying / original function to `patch_function` that uses nonlocal state to track 327 whether or not it has been executed and whether or not it threw an exception. 328 Exceptions thrown from the underlying / original function are propagated to the caller, 329 while exceptions thrown from other parts of `patch_function` are caught and logged as 330 warnings. 331 332 NB: PLEASE BE SUPER CAREFUL WHEN MODIFYING THIS FUNCTION. IT IS USED IN A WIDE VARIETY 333 OF CONTEXTX AND CRITICAL PATH IN DBR/MLR BY DEFAULT. ANY BUG HERE CAN BREAK USERS' 334 WORKLOAD WITHOUT THEM TAKING ANY ACTION. 335 """ 336 # Reroute warnings encountered during the patch function implementation to an MLflow event 337 # logger, and enforce silent mode if applicable (i.e. if the corresponding autologging 338 # integration was called with `silent=True`), hiding MLflow event logging statements and 339 # hiding all warnings in the autologging preamble and postamble (i.e. the code surrounding 340 # the user's original / underlying ML function). Non-MLflow warnings are enabled during the 341 # execution of the original / underlying ML function 342 # 343 # Note that we've opted *not* to apply this context manager as a decorator on 344 # `safe_patch_function` because the context-manager-as-decorator pattern uses 345 # `contextlib.ContextDecorator`, which creates generator expressions that cannot be pickled 346 # during model serialization by ML frameworks such as scikit-learn 347 is_silent_mode = get_autologging_config(autologging_integration, "silent", False) 348 with ( 349 MlflowEventsAndWarningsBehaviorGlobally( 350 # MLflow warnings emitted during autologging training sessions are likely not 351 # actionable and result from the autologging implementation invoking another MLflow 352 # API. Accordingly, we reroute these warnings to the MLflow event logger with level 353 # WARNING For reference, see recommended warning and event logging behaviors from 354 # https://docs.python.org/3/howto/logging.html#when-to-use-logging 355 reroute_warnings=True, 356 disable_event_logs=is_silent_mode, 357 disable_warnings=is_silent_mode, 358 ), 359 NonMlflowWarningsBehaviorForCurrentThread( 360 # non-MLflow Warnings emitted during the autologging preamble (before the original / 361 # underlying ML function is called) and postamble (after the original / underlying 362 # ML function is called) are likely not actionable and result from the autologging 363 # implementation invoking an API from a dependent library. Accordingly, we reroute 364 # these warnings to the MLflow event logger with level WARNING. For reference, see 365 # recommended warning and event logging behaviors from 366 # https://docs.python.org/3/howto/logging.html#when-to-use-logging 367 reroute_warnings=True, 368 disable_warnings=is_silent_mode, 369 ), 370 ): 371 if is_testing(): 372 preexisting_run_for_testing = active_run() 373 374 # Whether or not to exclude autologged content from user-created fluent runs 375 # (i.e. runs created manually via `mlflow.start_run()`) 376 exclusive = get_autologging_config(autologging_integration, "exclusive", False) 377 user_created_fluent_run_is_active = ( 378 active_run() and not _AutologgingSessionManager.active_session() 379 ) 380 active_session_failed = ( 381 _AutologgingSessionManager.active_session() is not None 382 and _AutologgingSessionManager.active_session().state == "failed" 383 ) 384 385 if ( 386 active_session_failed 387 or autologging_is_disabled(autologging_integration) 388 or (user_created_fluent_run_is_active and exclusive) 389 or ( 390 mlflow.utils.autologging_utils._AUTOLOGGING_GLOBALLY_DISABLED 391 and autologging_integration 392 ) 393 ): 394 # If the autologging integration associated with this patch is disabled, 395 # or if the current autologging integration is in exclusive mode and a user-created 396 # fluent run is active, call the original function and return. Restore the original 397 # warning behavior during original function execution, since autologging is being 398 # skipped 399 with NonMlflowWarningsBehaviorForCurrentThread( 400 disable_warnings=False, 401 reroute_warnings=False, 402 ): 403 return original(*args, **kwargs) 404 405 # Whether or not the original / underlying function has been called during the 406 # execution of patched code 407 original_has_been_called = False 408 # The value returned by the call to the original / underlying function during 409 # the execution of patched code 410 original_result = None 411 # Whether or not an exception was raised from within the original / underlying function 412 # during the execution of patched code 413 failed_during_original = False 414 # The active MLflow run (if any) associated with patch code execution 415 patch_function_run_for_testing = None 416 # The exception raised during executing patching function 417 patch_error = None 418 419 with _AutologgingSessionManager.start_session(autologging_integration) as session: 420 event_logger = AutologgingEventLoggerWrapper(session, destination, function_name) 421 422 def call_original_fn_with_event_logging(original_fn, og_args, og_kwargs): 423 try: 424 event_logger.log_original_function_start(og_args, og_kwargs) 425 426 original_fn_result = original_fn(*og_args, **og_kwargs) 427 428 event_logger.log_original_function_success(og_args, og_kwargs) 429 return original_fn_result 430 except Exception as e: 431 event_logger.log_original_function_error(og_args, og_kwargs, e) 432 433 nonlocal failed_during_original 434 failed_during_original = True 435 raise 436 437 try: 438 439 def call_original(*og_args, **og_kwargs): 440 def _original_fn(*_og_args, **_og_kwargs): 441 if is_testing(): 442 _validate_args( 443 autologging_integration, 444 function_name, 445 args, 446 kwargs, 447 og_args, 448 og_kwargs, 449 ) 450 # By the time `original` is called by the patch implementation, we 451 # assume that either: 1. the patch implementation has already 452 # created an MLflow run or 2. the patch code will not create an 453 # MLflow run during the current execution. Here, we capture a 454 # reference to the active run, which we will use later on to 455 # determine whether or not the patch implementation created 456 # a run and perform validation if necessary 457 nonlocal patch_function_run_for_testing 458 patch_function_run_for_testing = active_run() 459 460 nonlocal original_has_been_called 461 original_has_been_called = True 462 463 nonlocal original_result 464 # Show all non-MLflow warnings as normal (i.e. not as event logs) 465 # during original function execution, even if silent mode is enabled 466 # (`silent=True`), since these warnings originate from the ML framework 467 # or one of its dependencies and are likely relevant to the caller 468 with NonMlflowWarningsBehaviorForCurrentThread( 469 disable_warnings=False, 470 reroute_warnings=False, 471 ): 472 original_result = original(*_og_args, **_og_kwargs) 473 return original_result 474 475 return call_original_fn_with_event_logging(_original_fn, og_args, og_kwargs) 476 477 # Apply the name, docstring, and signature of `original` to `call_original`. 478 # This is important because several autologging patch implementations inspect 479 # the signature of the `original` argument during execution 480 call_original = update_wrapper_extended(call_original, original) 481 482 event_logger.log_patch_function_start(args, kwargs) 483 484 patch_function(call_original, *args, **kwargs) 485 486 session.state = "succeeded" 487 event_logger.log_patch_function_success(args, kwargs) 488 489 except Exception as e: 490 session.state = "failed" 491 patch_error = e 492 # Exceptions thrown during execution of the original function should be 493 # propagated to the caller. Additionally, exceptions encountered during test 494 # mode should be reraised to detect bugs in autologging implementations 495 if failed_during_original or is_testing(): 496 raise 497 498 if is_testing() and not preexisting_run_for_testing: 499 # If an MLflow run was created during the execution of patch code, verify that 500 # it is no longer active and that it contains expected autologging tags 501 assert not active_run(), ( 502 f"Autologging integration {autologging_integration} leaked an active run" 503 ) 504 if patch_function_run_for_testing: 505 _validate_autologging_run( 506 autologging_integration, patch_function_run_for_testing.info.run_id 507 ) 508 try: 509 if original_has_been_called: 510 return original_result 511 else: 512 return call_original_fn_with_event_logging(original, args, kwargs) 513 finally: 514 # If original function succeeds, but `patch_function_exception` exists, 515 # it represent patching code unexpected failure, so we call 516 # `log_patch_function_error` in this case. 517 # If original function failed, we don't call `log_patch_function_error` 518 # even if `patch_function_exception` exists, because original function failure 519 # means there's some error in user code (e.g. user provide wrong arguments) 520 if patch_error is not None and not failed_during_original: 521 event_logger.log_patch_function_error(args, kwargs, patch_error) 522 _logger.warning(_ERROR_MSG.format(autologging_integration, patch_error)) 523 524 async def async_safe_patch_function(*args, **kwargs): 525 """ 526 Async version of safe_patch_function. 527 528 This code brainlessly copies the synchronous version of the function, but with async 529 context managers and async functions. This is done to avoid the risk of introducing 530 any bugs or regressions in the async version of the function. Note that we need to 531 be really careful here, because autologging is enabled by-default in DBR/MLR, hence 532 any bug here can break users' workload without them taking any action. 533 534 That said, some long comments are omitted in this version to avoid redundancy. If 535 you want to understand the context of the code better, please refer to the 536 synchronous version as well. 537 """ 538 is_silent_mode = get_autologging_config(autologging_integration, "silent", False) 539 async with ( 540 MlflowEventsAndWarningsBehaviorGlobally( 541 reroute_warnings=True, 542 disable_event_logs=is_silent_mode, 543 disable_warnings=is_silent_mode, 544 ), 545 NonMlflowWarningsBehaviorForCurrentThread( 546 disable_warnings=is_silent_mode, 547 reroute_warnings=True, 548 ), 549 ): 550 if is_testing(): 551 preexisting_run_for_testing = active_run() 552 553 # Whether or not to exclude autologged content from user-created fluent runs 554 # (i.e. runs created manually via `mlflow.start_run()`) 555 exclusive = get_autologging_config(autologging_integration, "exclusive", False) 556 user_created_fluent_run_is_active = ( 557 active_run() and not _AutologgingSessionManager.active_session() 558 ) 559 active_session_failed = ( 560 _AutologgingSessionManager.active_session() is not None 561 and _AutologgingSessionManager.active_session().state == "failed" 562 ) 563 564 if ( 565 active_session_failed 566 or autologging_is_disabled(autologging_integration) 567 or (user_created_fluent_run_is_active and exclusive) 568 or ( 569 mlflow.utils.autologging_utils._AUTOLOGGING_GLOBALLY_DISABLED 570 and autologging_integration 571 ) 572 ): 573 async with NonMlflowWarningsBehaviorForCurrentThread(False, False): 574 return await original(*args, **kwargs) 575 576 original_has_been_called = False 577 original_result = None 578 failed_during_original = False 579 patch_function_run_for_testing = None 580 patch_error = None 581 582 async with _AutologgingSessionManager.astart_session( 583 autologging_integration 584 ) as session: 585 event_logger = AutologgingEventLoggerWrapper(session, destination, function_name) 586 587 async def call_original_fn_with_event_logging(original_fn, og_args, og_kwargs): 588 try: 589 event_logger.log_original_function_start(og_args, og_kwargs) 590 original_fn_result = await original_fn(*og_args, **og_kwargs) 591 event_logger.log_original_function_success(og_args, og_kwargs) 592 return original_fn_result 593 except Exception as e: 594 event_logger.log_original_function_error(og_args, og_kwargs, e) 595 nonlocal failed_during_original 596 failed_during_original = True 597 raise 598 599 try: 600 601 async def call_original(*og_args, **og_kwargs): 602 async def _original_fn(*_og_args, **_og_kwargs): 603 if is_testing(): 604 _validate_args( 605 autologging_integration, 606 function_name, 607 args, 608 kwargs, 609 og_args, 610 og_kwargs, 611 ) 612 nonlocal patch_function_run_for_testing 613 patch_function_run_for_testing = active_run() 614 615 nonlocal original_has_been_called 616 original_has_been_called = True 617 618 nonlocal original_result 619 async with NonMlflowWarningsBehaviorForCurrentThread(False, False): 620 original_result = await original(*_og_args, **_og_kwargs) 621 return original_result 622 623 return await call_original_fn_with_event_logging( 624 _original_fn, og_args, og_kwargs 625 ) 626 627 # Apply the name, docstring, and signature of `original` to `call_original`. 628 # This is important because several autologging patch implementations inspect 629 # the signature of the `original` argument during execution 630 call_original = update_wrapper_extended(call_original, original) 631 632 event_logger.log_patch_function_start(args, kwargs) 633 634 await patch_function(call_original, *args, **kwargs) 635 636 session.state = "succeeded" 637 event_logger.log_patch_function_success(args, kwargs) 638 639 except Exception as e: 640 session.state = "failed" 641 patch_error = e 642 # Exceptions thrown during execution of the original function should be 643 # propagated to the caller. Additionally, exceptions encountered during test 644 # mode should be reraised to detect bugs in autologging implementations 645 if failed_during_original or is_testing(): 646 raise 647 648 if is_testing() and not preexisting_run_for_testing: 649 # If an MLflow run was created during the execution of patch code, verify that 650 # it is no longer active and that it contains expected autologging tags 651 assert not active_run(), ( 652 f"Autologging integration {autologging_integration} leaked an active run" 653 ) 654 if patch_function_run_for_testing: 655 _validate_autologging_run( 656 autologging_integration, patch_function_run_for_testing.info.run_id 657 ) 658 try: 659 if original_has_been_called: 660 return original_result 661 else: 662 return await call_original_fn_with_event_logging(original, args, kwargs) 663 finally: 664 if patch_error is not None and not failed_during_original: 665 event_logger.log_patch_function_error(args, kwargs, patch_error) 666 _logger.warning(_ERROR_MSG.format(autologging_integration, patch_error)) 667 668 if is_property_method: 669 # Create a patched function (also property decorated) 670 # like: 671 # 672 # class A: 673 # @property 674 # def get_bound_safe_patch_fn(self): 675 # original_fn.fget(self) # do availability check 676 # return bound_safe_patch_fn 677 # 678 # Suppose `a1` is instance of class A, 679 # then `a1.get_bound_safe_patch_fn(*args, **kwargs)` will be equivalent to 680 # `bound_safe_patch_fn(*args, **kwargs)` 681 def get_bound_safe_patch_fn(self): 682 # This `original_fn.fget` call is for availability check, if it raise error 683 # then `hasattr(obj, {func_name})` will return False 684 # so it mimic the original property behavior. 685 original_fn.fget(self) 686 687 def bound_safe_patch_fn(*args, **kwargs): 688 return safe_patch_function(self, *args, **kwargs) 689 690 # Make bound method `instance.target_method` keep the same doc and signature. 691 # Here return the bound safe patch function because user call property decorated 692 # method will like `instance.property_decorated_method(...)`, and internally it will 693 # call the `bound_safe_patch_fn`, the argument list don't include the `self` argument, 694 # so return bound function here. 695 return update_wrapper_extended(bound_safe_patch_fn, original_fn.fget) 696 697 # Make unbound method `class.target_method` keep the same doc and signature 698 get_bound_safe_patch_fn = update_wrapper_extended(get_bound_safe_patch_fn, original_fn.fget) 699 safe_patch_obj = property(get_bound_safe_patch_fn) 700 elif is_async_function: 701 safe_patch_obj = update_wrapper_extended(async_safe_patch_function, original) 702 else: 703 safe_patch_obj = update_wrapper_extended(safe_patch_function, original) 704 705 new_patch = _wrap_patch(destination, function_name, safe_patch_obj) 706 _store_patch(autologging_integration, new_patch) 707 708 709 def revert_patches(autologging_integration): 710 """Reverts all patches on the specified destination class for autologging disablement purposes. 711 712 Args: 713 autologging_integration: The name of the autologging integration associated with the 714 patch. Note: If called via fluent api (`autologging_integration="mlflow"`), then revert 715 all patches for all active autologging integrations. 716 717 """ 718 for patch in _AUTOLOGGING_PATCHES.get(autologging_integration, []): 719 gorilla.revert(patch) 720 721 _AUTOLOGGING_PATCHES.pop(autologging_integration, None) 722 723 # Call any registered cleanup callbacks (e.g., for OTel uninstrumentation) 724 for callback in _AUTOLOGGING_CLEANUP_CALLBACKS.get(autologging_integration, []): 725 try: 726 callback() 727 except Exception as e: 728 _logger.warning(f"Error calling cleanup callback for {autologging_integration}: {e}") 729 730 _AUTOLOGGING_CLEANUP_CALLBACKS.pop(autologging_integration, None) 731 732 733 # Represents an active autologging session using two fields: 734 # - integration: the name of the autologging integration corresponding to the session 735 # - id: a unique session identifier (e.g., a UUID) 736 # - state: the state of AutologgingSession, will be one of running/succeeded/failed 737 class AutologgingSession: 738 def __init__(self, integration, id_): 739 self.integration = integration 740 self.id = id_ 741 self.state = "running" 742 743 744 class _AutologgingSessionManager: 745 _session = None 746 747 @classmethod 748 @contextmanager 749 def start_session(cls, integration): 750 try: 751 prev_session = cls._session 752 if prev_session is None: 753 session_id = uuid.uuid4().hex 754 cls._session = AutologgingSession(integration, session_id) 755 yield cls._session 756 finally: 757 # Only end the session upon termination of the context if we created 758 # the session; otherwise, leave the session open for later termination 759 # by its creator 760 if prev_session is None: 761 cls._end_session() 762 763 @classmethod 764 @asynccontextmanager 765 async def astart_session(cls, integration): 766 try: 767 prev_session = cls._session 768 if prev_session is None: 769 session_id = uuid.uuid4().hex 770 cls._session = AutologgingSession(integration, session_id) 771 yield cls._session 772 finally: 773 if prev_session is None: 774 cls._end_session() 775 776 @classmethod 777 def active_session(cls): 778 return cls._session 779 780 @classmethod 781 def _end_session(cls): 782 cls._session = None 783 784 785 def update_wrapper_extended(wrapper, wrapped): 786 """Update a `wrapper` function to look like the `wrapped` function. This is an extension of 787 `functools.update_wrapper` that applies the docstring *and* signature of `wrapped` to 788 `wrapper`, producing a new function. 789 790 Returns: 791 A new function with the same implementation as `wrapper` and the same docstring 792 & signature as `wrapped`. 793 """ 794 updated_wrapper = functools.update_wrapper(wrapper, wrapped) 795 # Assign the signature of the `wrapped` function to the updated wrapper function. 796 # Certain frameworks may disallow signature inspection, causing `inspect.signature()` to throw. 797 # One such example is the `tensorflow.estimator.Estimator.export_savedmodel()` function 798 try: 799 updated_wrapper.__signature__ = inspect.signature(wrapped) 800 except Exception: 801 _logger.debug("Failed to restore original signature for wrapper around %s", wrapped) 802 return updated_wrapper 803 804 805 def _wrap_patch(destination, name, patch_obj, settings=None): 806 """Apply a patch. 807 808 Args: 809 destination: Patch destination. 810 name: Name of the attribute at the destination. 811 patch_obj: Patch object, it should be a function or a property decorated function 812 to be assigned to the patch point {destination}.{name}. 813 settings: Settings for gorilla.Patch. 814 815 """ 816 if settings is None: 817 settings = gorilla.Settings(allow_hit=True, store_hit=True) 818 819 patch = gorilla.Patch(destination, name, patch_obj, settings=settings) 820 gorilla.apply(patch) 821 return patch 822 823 824 def _store_patch(autologging_integration, patch): 825 """ 826 Stores a patch for a specified autologging_integration class. Later to be used for being able 827 to revert the patch when disabling autologging. 828 829 Args: 830 autologging_integration: The name of the autologging integration associated with the 831 patch. 832 patch: The patch to be stored. 833 """ 834 if autologging_integration in _AUTOLOGGING_PATCHES: 835 _AUTOLOGGING_PATCHES[autologging_integration].add(patch) 836 else: 837 _AUTOLOGGING_PATCHES[autologging_integration] = {patch} 838 839 840 def _validate_autologging_run(autologging_integration, run_id): 841 """ 842 For testing purposes, verifies that an MLflow run produced by an `autologging_integration` 843 satisfies the following properties: 844 845 - The run has an autologging tag whose value is the name of the autologging integration 846 - The run has a terminal status (e.g., KILLED, FAILED, FINISHED) 847 """ 848 from mlflow.tracking.client import MlflowClient 849 850 client = MlflowClient() 851 run = client.get_run(run_id) 852 autologging_tag_value = run.data.tags.get(MLFLOW_AUTOLOGGING) 853 assert autologging_tag_value == autologging_integration, ( 854 f"Autologging run with id {run_id} failed to set autologging tag with expected value. " 855 f"Expected: '{autologging_integration}', Actual: '{autologging_tag_value}'" 856 ) 857 assert RunStatus.is_terminated(RunStatus.from_string(run.info.status)), ( 858 f"Autologging run with id {run_id} has a non-terminal status '{run.info.status}'" 859 ) 860 861 862 class ValidationExemptArgument(NamedTuple): 863 """ 864 A NamedTuple representing the properties of an argument that is exempt from validation 865 866 autologging_integration: The name of the autologging integration. 867 function_name: The name of the function that is being validated. 868 type_function: A Callable that accepts an object and returns True if the given object matches 869 the argument type. Returns False otherwise. 870 positional_argument_index: The index of the argument in the function signature. 871 keyword_argument_name: The name of the argument in the function signature. 872 """ 873 874 autologging_integration: str 875 function_name: str 876 type_function: Callable[..., Any] 877 positional_argument_index: int | None = None 878 keyword_argument_name: str | None = None 879 880 def matches( 881 self, 882 autologging_integration, 883 function_name, 884 value, 885 argument_index=None, 886 argument_name=None, 887 ): 888 """ 889 This method checks if the properties provided through the function arguments matches the 890 properties defined in the NamedTuple. 891 892 Args: 893 autologging_integration: The name of an autologging integration. 894 function_name: The name of the function that is being matched. 895 value: The value of the argument. 896 argument_index: The index of the argument, if it is passed as a positional 897 argument. Otherwise it is None. 898 argument_name: The name of the argument, if it is passed as a keyword 899 argument. Otherwise it is None. 900 901 Returns: 902 Returns True if the given function properties matches the exempt argument's 903 properties. Returns False otherwise. 904 """ 905 return ( 906 self.autologging_integration == autologging_integration 907 and self.function_name == function_name 908 and ( 909 self.positional_argument_index == argument_index 910 or self.keyword_argument_name == argument_name 911 ) 912 and self.type_function(value) 913 ) 914 915 916 # WARNING: Exemptions should NOT be introduced unless absolutely necessary. If deemed necessary, 917 # clear reasons must be provided as comment in addition to thorough integration tests. 918 _VALIDATION_EXEMPT_ARGUMENTS = [ 919 # When extracting implicitly defined `batch_size` in the case that `x` is a generator or a 920 # generator class, we need to consume and restore the first element back to the generator to 921 # calculate the `batch_size`. This means that: 922 # 1. The type of `x` will become 'generator' regardless if user provided `x` as a generator or a 923 # custom generator class. 924 # 2. The instance of `x` will be different, since we reconstructed the generator after consuming 925 # the first element. 926 ValidationExemptArgument("tensorflow", "fit", is_iterator, 1, "x"), 927 ValidationExemptArgument("keras", "fit", is_iterator, 1, "x"), 928 ValidationExemptArgument("dspy", "__call__", lambda x: isinstance(x, Callable), 2, "metric"), 929 # Autologging injects tracing context headers as `extra_headers` to enable distributed 930 # tracing between client spans and gateway spans. The user may or may not have passed 931 # `extra_headers` originally, so the argument value will differ from the user's input. 932 ValidationExemptArgument( 933 "openai", "create", lambda x: isinstance(x, (dict, type(None))), None, "extra_headers" 934 ), 935 ValidationExemptArgument( 936 "openai", "parse", lambda x: isinstance(x, (dict, type(None))), None, "extra_headers" 937 ), 938 ValidationExemptArgument( 939 "anthropic", "create", lambda x: isinstance(x, (dict, type(None))), None, "extra_headers" 940 ), 941 # Gemini header injection goes through config.http_options.headers. Config can be 942 # None, a dict, or a Pydantic-style object with an http_options attribute. 943 ValidationExemptArgument( 944 "gemini", 945 "_generate_content", 946 lambda x: x is None or isinstance(x, dict) or hasattr(x, "http_options"), 947 None, 948 "config", 949 ), 950 ValidationExemptArgument( 951 "gemini", 952 "send_message", 953 lambda x: x is None or isinstance(x, dict) or hasattr(x, "http_options"), 954 None, 955 "config", 956 ), 957 ValidationExemptArgument( 958 "gemini", 959 "count_tokens", 960 lambda x: x is None or isinstance(x, dict) or hasattr(x, "http_options"), 961 None, 962 "config", 963 ), 964 ValidationExemptArgument( 965 "gemini", 966 "embed_content", 967 lambda x: x is None or isinstance(x, dict) or hasattr(x, "http_options"), 968 None, 969 "config", 970 ), 971 ] 972 973 974 def _is_arg_exempt_from_validation( 975 autologging_integration, 976 function_name, 977 argument, 978 argument_index=None, 979 argument_name=None, 980 ): 981 """This function is responsible for determining whether or not an argument is exempt from 982 autolog safety validations. This includes both type checking and immutable checking. 983 984 Args: 985 autologging_integration: The name of the autologging integration. 986 function_name: The name of the function that is being validated. 987 argument: The actual argument. 988 argument_index: The index of the argument, if it is passed as a positional 989 argument. Otherwise it is None. 990 argument_name: The name of the argument, if it is passed as a keyword argument. 991 Otherwise it is None. 992 993 Returns: 994 True or False 995 """ 996 return any( 997 exemption.matches( 998 autologging_integration, 999 function_name, 1000 argument, 1001 argument_index, 1002 argument_name, 1003 ) 1004 for exemption in _VALIDATION_EXEMPT_ARGUMENTS 1005 ) 1006 1007 1008 def _validate_args( 1009 autologging_integration, 1010 function_name, 1011 user_call_args, 1012 user_call_kwargs, 1013 autologging_call_args, 1014 autologging_call_kwargs, 1015 ): 1016 """ 1017 Used for testing purposes to verify that, when a patched ML function calls its underlying 1018 / original ML function, the following properties are satisfied: 1019 1020 - All arguments supplied to the patched ML function are forwarded to the 1021 original ML function 1022 - Any additional arguments supplied to the original function are exception safe (i.e. 1023 they are either functions decorated with the `@exception_safe_function_for_class` or 1024 `@pickalable_exception_safe_function` decorators, or classes / instances of classes with 1025 type `ExceptionSafeClass` 1026 """ 1027 1028 def _validate_new_input(inp): 1029 """ 1030 Validates a new input (arg or kwarg) introduced to the underlying / original ML function 1031 call during the execution of a patched ML function. The new input is valid if: 1032 1033 - The new input is a function that has been decorated with 1034 `exception_safe_function_for_class` or `pickalable_exception_safe_function` 1035 - OR the new input is a class with the `ExceptionSafeClass` metaclass 1036 - OR the new input is a list and each of its elements is valid according to the 1037 these criteria 1038 """ 1039 if type(inp) == list: 1040 for item in inp: 1041 _validate_new_input(item) 1042 elif isinstance(inp, dict) and "callbacks" in inp: 1043 _validate_new_input(inp["callbacks"]) 1044 elif callable(inp): 1045 assert getattr(inp, _ATTRIBUTE_EXCEPTION_SAFE, False), ( 1046 f"New function argument '{inp}' passed to original function is not exception-safe." 1047 " Please decorate the function with `exception_safe_function` or " 1048 "`pickalable_exception_safe_function`" 1049 ) 1050 else: 1051 assert hasattr(inp, "__class__") and type(inp.__class__) in [ 1052 ExceptionSafeClass, 1053 ExceptionSafeAbstractClass, 1054 ], ( 1055 f"Invalid new input '{inp}'. New args / kwargs introduced to `original` function " 1056 "calls by patched code must either be functions decorated with " 1057 "`exception_safe_function_for_class`, instances of classes with the " 1058 "`ExceptionSafeClass` or `ExceptionSafeAbstractClass` metaclass safe or lists of " 1059 "such exception safe functions / classes." 1060 ) 1061 1062 def _assert_autologging_input_positional_args_are_superset( 1063 autologging_call_input, user_call_input 1064 ): 1065 length_diff = len(autologging_call_input) - len(user_call_input) 1066 assert length_diff >= 0, ( 1067 f"{length_diff} expected inputs are missing from the call to the original function." 1068 ) 1069 1070 def _assert_autologging_input_kwargs_are_superset(autologging_call_input, user_call_input): 1071 assert set(user_call_input.keys()).issubset(set(autologging_call_input.keys())), ( 1072 "Keyword or dictionary arguments to original function omit" 1073 " one or more expected keys: '{}'".format( 1074 set(user_call_input.keys()) - set(autologging_call_input.keys()) 1075 ) 1076 ) 1077 1078 def _validate(autologging_call_input, user_call_input=None): 1079 """ 1080 Validates that the specified `autologging_call_input` and `user_call_input` 1081 are compatible. If `user_call_input` is `None`, then `autologging_call_input` 1082 is regarded as a new input added by autologging and is validated using 1083 `_validate_new_input`. Otherwise, the following properties must hold: 1084 1085 - `autologging_call_input` and `user_call_input` must have the same type 1086 (referred to as "input type") 1087 - if the input type is a tuple, list or dictionary, then `autologging_call_input` must 1088 be equivalent to `user_call_input` or be a superset of `user_call_input` 1089 - for all other input types, `autologging_call_input` and `user_call_input` 1090 must be equivalent by reference equality or by object equality 1091 1092 Args: 1093 autologging_call_input: call input from autologging. 1094 user_call_input: call input from user. 1095 """ 1096 1097 if user_call_input is None and autologging_call_input is not None: 1098 _validate_new_input(autologging_call_input) 1099 return 1100 1101 assert type(autologging_call_input) == type(user_call_input), ( 1102 "Type of input to original function '{}' does not match expected type '{}'".format( 1103 type(autologging_call_input), type(user_call_input) 1104 ) 1105 ) 1106 1107 if type(autologging_call_input) in [list, tuple]: 1108 _assert_autologging_input_positional_args_are_superset( 1109 autologging_call_input, user_call_input 1110 ) 1111 # If the autologging call input is longer than the user call input, we `zip_longest` 1112 # will pad the user call input with `None` values to ensure that the subsequent calls 1113 # to `_validate` identify new inputs added by the autologging call 1114 for a, u in itertools.zip_longest(autologging_call_input, user_call_input): 1115 _validate(a, u) 1116 elif type(autologging_call_input) == dict: 1117 _assert_autologging_input_kwargs_are_superset(autologging_call_input, user_call_input) 1118 for key in autologging_call_input.keys(): 1119 _validate(autologging_call_input[key], user_call_input.get(key, None)) 1120 1121 else: 1122 assert ( 1123 autologging_call_input is user_call_input 1124 or autologging_call_input == user_call_input 1125 ), ( 1126 "Input to original function does not match expected input." 1127 f" Original: '{autologging_call_input}'. Expected: '{user_call_input}'" 1128 ) 1129 1130 # Similar validation logic found in _validate, unraveling the list of arguments to exclude 1131 # checks for any validation exempt positional arguments. 1132 _assert_autologging_input_positional_args_are_superset(autologging_call_args, user_call_args) 1133 for index, autologging_call_arg, user_call_arg in itertools.zip_longest( 1134 range(len(user_call_args)), autologging_call_args, user_call_args 1135 ): 1136 if not _is_arg_exempt_from_validation( 1137 autologging_integration, 1138 function_name, 1139 user_call_arg, 1140 argument_index=index, 1141 ): 1142 _validate(autologging_call_arg, user_call_arg) 1143 1144 # Similar validation logic found in _validate, unraveling the dictionary of arguments to exclude 1145 # checks for any validation exempt keyword arguments. 1146 _assert_autologging_input_kwargs_are_superset(autologging_call_kwargs, user_call_kwargs) 1147 for key in autologging_call_kwargs.keys(): 1148 if not _is_arg_exempt_from_validation( 1149 autologging_integration, 1150 function_name, 1151 user_call_kwargs.get(key, None), 1152 argument_name=key, 1153 ): 1154 _validate( 1155 autologging_call_kwargs[key], 1156 user_call_kwargs.get(key, None), 1157 )