# mlflow/tensorflow/__init__.py
   1  """
   2  The ``mlflow.tensorflow`` module provides an API for logging and loading TensorFlow models.
   3  This module exports TensorFlow models with the following flavors:
   4  
   5  TensorFlow (native) format
   6      This is the main flavor that can be loaded back into TensorFlow.
   7  :py:mod:`mlflow.pyfunc`
   8      Produced for use by generic pyfunc-based deployment tools and batch inference.
   9  """
  10  
  11  import importlib
  12  import logging
  13  import os
  14  import shutil
  15  import tempfile
  16  from typing import Any, NamedTuple
  17  
  18  import numpy as np
  19  import pandas
  20  import yaml
  21  from packaging.version import Version
  22  
  23  import mlflow
  24  from mlflow import pyfunc
  25  from mlflow.data.code_dataset_source import CodeDatasetSource
  26  from mlflow.data.numpy_dataset import from_numpy
  27  from mlflow.data.tensorflow_dataset import from_tensorflow
  28  from mlflow.entities import LoggedModelInput
  29  from mlflow.environment_variables import MLFLOW_ALLOW_PICKLE_DESERIALIZATION
  30  from mlflow.exceptions import INVALID_PARAMETER_VALUE, MlflowException
  31  from mlflow.models import Model, ModelInputExample, ModelSignature, infer_signature
  32  from mlflow.models.model import MLMODEL_FILE_NAME
  33  from mlflow.models.signature import _infer_signature_from_input_example
  34  from mlflow.models.utils import _save_example
  35  from mlflow.tensorflow.callback import MlflowCallback, MlflowModelCheckpointCallback  # noqa: F401
  36  from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
  37  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
  38  from mlflow.tracking.context import registry as context_registry
  39  from mlflow.tracking.fluent import _initialize_logged_model, _shut_down_async_logging
  40  from mlflow.types.schema import TensorSpec
  41  from mlflow.utils import is_iterator
  42  from mlflow.utils.autologging_utils import (
  43      autologging_integration,
  44      get_autologging_config,
  45      log_fn_args_as_params,
  46      picklable_exception_safe_function,
  47      resolve_input_example_and_signature,
  48      safe_patch,
  49  )
  50  from mlflow.utils.checkpoint_utils import (
  51      _WEIGHT_ONLY_CHECKPOINT_SUFFIX,
  52      download_checkpoint_artifact,
  53  )
  54  from mlflow.utils.databricks_utils import (
  55      is_in_databricks_model_serving_environment,
  56      is_in_databricks_runtime,
  57  )
  58  from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
  59  from mlflow.utils.environment import (
  60      _CONDA_ENV_FILE_NAME,
  61      _CONSTRAINTS_FILE_NAME,
  62      _PYTHON_ENV_FILE_NAME,
  63      _REQUIREMENTS_FILE_NAME,
  64      _mlflow_conda_env,
  65      _process_conda_env,
  66      _process_pip_requirements,
  67      _PythonEnv,
  68      _validate_env_arguments,
  69  )
  70  from mlflow.utils.file_utils import TempDir, get_total_file_size, write_to
  71  from mlflow.utils.model_utils import (
  72      _add_code_from_conf_to_system_path,
  73      _copy_extra_files,
  74      _get_flavor_configuration,
  75      _validate_and_copy_code_paths,
  76      _validate_and_prepare_target_save_path,
  77  )
  78  from mlflow.utils.requirements_utils import _get_pinned_requirement
  79  
FLAVOR_NAME = "tensorflow"

_logger = logging.getLogger(__name__)

# For tracking if the run was started by autologging.
_AUTOLOG_RUN_ID = None

# File name to which custom objects cloudpickle is saved - used during save and load
_CUSTOM_OBJECTS_SAVE_PATH = "custom_objects.cloudpickle"
# File name to which custom objects stored in tensorflow _GLOBAL_CUSTOM_OBJECTS
# is saved - it is automatically detected and used during save and load
_GLOBAL_CUSTOM_OBJECTS_SAVE_PATH = "global_custom_objects.cloudpickle"
# File recording which keras module the model was saved with (save_model writes
# the module's __name__, e.g. "tensorflow.keras")
_KERAS_MODULE_SPEC_PATH = "keras_module.txt"
# File recording the keras save_format used at save time (e.g. "tf" or "h5")
_KERAS_SAVE_FORMAT_PATH = "save_format.txt"
# File name to which keras model is saved
_MODEL_SAVE_PATH = "model"


# Values recorded under "model_type" in the MLmodel flavor configuration
_MODEL_TYPE_KERAS = "keras"
_MODEL_TYPE_TF1_ESTIMATOR = "tf1-estimator"
_MODEL_TYPE_TF2_MODULE = "tf2-module"


# Artifact sub-directories holding the serialized model payload
_KERAS_MODEL_DATA_PATH = "data"
_TF2MODEL_SUBPATH = "tf2model"


MLflowCallback = MlflowCallback  # for backwards compatibility
 108  
 109  
 110  def get_default_pip_requirements(include_cloudpickle=False):
 111      """
 112      Returns
 113          A list of default pip requirements for MLflow Models produced by this flavor.
 114          Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment
 115          that, at minimum, contains these requirements.
 116      """
 117      pip_deps = [_get_pinned_requirement("tensorflow")]
 118      if include_cloudpickle:
 119          pip_deps.append(_get_pinned_requirement("cloudpickle"))
 120  
 121      return pip_deps
 122  
 123  
 124  def get_default_conda_env():
 125      """
 126      Returns:
 127          The default Conda environment for MLflow Models produced by calls to
 128          :func:`save_model()` and :func:`log_model()`.
 129      """
 130      return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())
 131  
 132  
 133  def get_global_custom_objects():
 134      """
 135      Returns:
 136          A live reference to the global dictionary of custom objects.
 137      """
 138      try:
 139          from tensorflow.keras.saving import get_custom_objects
 140  
 141          return get_custom_objects()
 142      except Exception:
 143          pass
 144  
 145  
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def log_model(
    model,
    artifact_path: str | None = None,
    custom_objects=None,
    conda_env=None,
    code_paths=None,
    signature: ModelSignature | None = None,
    input_example: ModelInputExample | None = None,
    registered_model_name=None,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements=None,
    extra_pip_requirements=None,
    saved_model_kwargs=None,
    keras_model_kwargs=None,
    metadata=None,
    extra_files=None,
    name: str | None = None,
    params: dict[str, Any] | None = None,
    tags: dict[str, Any] | None = None,
    model_type: str | None = None,
    step: int = 0,
    model_id: str | None = None,
    **kwargs,
):
    """
    Log a TF2 core model (inheriting tf.Module) or a Keras model in MLflow Model format.

    .. note::

        If you log a Keras or TensorFlow model without a signature, inference with
        :py:func:`mlflow.pyfunc.spark_udf()` will not work unless the model's pyfunc
        representation accepts pandas DataFrames as inference inputs.

        You can infer a model's signature by calling the :py:func:`mlflow.models.infer_signature()`
        API on features from the model's test dataset. You can also manually create a model
        signature, for example:

        .. code-block:: python
            :caption: Example of creating signature for saving TensorFlow and `tf.Keras` models

            from mlflow.types.schema import Schema, TensorSpec
            from mlflow.models import ModelSignature
            import numpy as np

            input_schema = Schema([
                TensorSpec(np.dtype(np.uint64), (-1, 5), "field1"),
                TensorSpec(np.dtype(np.float32), (-1, 3, 2), "field2"),
            ])
            # Create the signature for a model that requires 2 inputs:
            #  - Input with name "field1", shape (-1, 5), type "np.uint64"
            #  - Input with name "field2", shape (-1, 3, 2), type "np.float32"
            signature = ModelSignature(inputs=input_schema)

    Args:
        model: The TF2 core model (inheriting tf.Module) or Keras model to be saved.
        artifact_path: Deprecated. Use `name` instead.
        custom_objects: A Keras ``custom_objects`` dictionary mapping names (strings) to
            custom classes or functions associated with the Keras model. MLflow saves
            these custom layers using CloudPickle and restores them automatically
            when the model is loaded with :py:func:`mlflow.tensorflow.load_model` and
            :py:func:`mlflow.pyfunc.load_model`.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        signature: {{ signature }}
        input_example: {{ input_example }}
        registered_model_name: If given, create a model version under
            ``registered_model_name``, also creating a registered model if one
            with the given name does not exist.
        await_registration_for: Number of seconds to wait for the model version to finish
            being created and is in ``READY`` status. By default, the function
            waits for five minutes. Specify 0 or None to skip waiting.
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        saved_model_kwargs: a dict of kwargs to pass to ``tensorflow.saved_model.save`` method.
        keras_model_kwargs: a dict of kwargs to pass to ``keras_model.save`` method.
        metadata: {{ metadata }}
        extra_files: {{ extra_files }}
        name: {{ name }}
        params: {{ params }}
        tags: {{ tags }}
        model_type: {{ model_type }}
        step: {{ step }}
        model_id: {{ model_id }}
        kwargs: Extra arguments to pass to :py:func:`mlflow.models.Model.log`.

    Returns:
        A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
        metadata of the logged model.
    """

    # Delegate to Model.log, which calls back into this module's save_model()
    # (via flavor=mlflow.tensorflow) to serialize and log the artifacts.
    return Model.log(
        artifact_path=artifact_path,
        name=name,
        flavor=mlflow.tensorflow,
        model=model,
        conda_env=conda_env,
        code_paths=code_paths,
        custom_objects=custom_objects,
        registered_model_name=registered_model_name,
        signature=signature,
        input_example=input_example,
        await_registration_for=await_registration_for,
        pip_requirements=pip_requirements,
        extra_pip_requirements=extra_pip_requirements,
        saved_model_kwargs=saved_model_kwargs,
        keras_model_kwargs=keras_model_kwargs,
        metadata=metadata,
        extra_files=extra_files,
        params=params,
        tags=tags,
        model_type=model_type,
        step=step,
        model_id=model_id,
        **kwargs,
    )
 262  
 263  
 264  def _save_keras_custom_objects(path, custom_objects, file_name):
 265      """
 266      Save custom objects dictionary to a cloudpickle file so a model can be easily loaded later.
 267  
 268      Args:
 269          path: An absolute path that points to the data directory within /path/to/model.
 270          custom_objects: Keras ``custom_objects`` is a dictionary mapping
 271              names (strings) to custom classes or functions to be considered
 272              during deserialization. MLflow saves these custom layers using
 273              CloudPickle and restores them automatically when the model is
 274              loaded with :py:func:`mlflow.keras.load_model` and
 275              :py:func:`mlflow.pyfunc.load_model`.
 276          file_name: The file name to save the custom objects to.
 277      """
 278      import cloudpickle
 279  
 280      custom_objects_path = os.path.join(path, file_name)
 281      with open(custom_objects_path, "wb") as out_f:
 282          cloudpickle.dump(custom_objects, out_f)
 283  
 284  
# Warning emitted by save_model() when no signature was provided and none
# could be inferred from the input example.
_NO_MODEL_SIGNATURE_WARNING = (
    "You are saving a TensorFlow Core model or Keras model "
    "without a signature. Inference with mlflow.pyfunc.spark_udf() will not work "
    "unless the model's pyfunc representation accepts pandas DataFrames as "
    "inference inputs."
)
 291  
 292  
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def save_model(
    model,
    path,
    conda_env=None,
    code_paths=None,
    mlflow_model=None,
    custom_objects=None,
    signature: ModelSignature | None = None,
    input_example: ModelInputExample | None = None,
    pip_requirements=None,
    extra_pip_requirements=None,
    saved_model_kwargs=None,
    keras_model_kwargs=None,
    metadata=None,
    extra_files=None,
):
    """
    Save a TF2 core model (inheriting tf.Module) or Keras model in MLflow Model format to a path on
    the local file system.

    .. note::
        If you save a Keras or TensorFlow model without a signature, inference with
        :py:func:`mlflow.pyfunc.spark_udf()` will not work unless the model's pyfunc
        representation accepts pandas DataFrames as inference inputs.
        You can infer a model's signature by calling the :py:func:`mlflow.models.infer_signature()`
        API on features from the model's test dataset. You can also manually create a model
        signature, for example:

        .. code-block:: python
            :caption: Example of creating signature for saving TensorFlow and `tf.Keras` models

            from mlflow.types.schema import Schema, TensorSpec
            from mlflow.models import ModelSignature
            import numpy as np

            input_schema = Schema([
                TensorSpec(np.dtype(np.uint64), (-1, 5), "field1"),
                TensorSpec(np.dtype(np.float32), (-1, 3, 2), "field2"),
            ])
            # Create the signature for a model that requires 2 inputs:
            #  - Input with name "field1", shape (-1, 5), type "np.uint64"
            #  - Input with name "field2", shape (-1, 3, 2), type "np.float32"
            signature = ModelSignature(inputs=input_schema)

    Args:
        model: The Keras model or Tensorflow module to be saved.
        path: Local path where the MLflow model is to be saved.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        mlflow_model: MLflow model configuration to which to add the ``tensorflow`` flavor.
        custom_objects: A Keras ``custom_objects`` dictionary mapping names (strings) to
            custom classes or functions associated with the Keras model. MLflow saves
            these custom layers using CloudPickle and restores them automatically
            when the model is loaded with :py:func:`mlflow.tensorflow.load_model` and
            :py:func:`mlflow.pyfunc.load_model`.
        signature: {{ signature }}
        input_example: {{ input_example }}
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        saved_model_kwargs: a dict of kwargs to pass to ``tensorflow.saved_model.save`` method
            if the model to be saved is a Tensorflow module.
        keras_model_kwargs: a dict of kwargs to pass to ``model.save`` method if the model
            to be saved is a keras model.
        metadata: {{ metadata }}
        extra_files: {{ extra_files }}
    """
    import tensorflow as tf
    from tensorflow.keras.models import Model as KerasModel

    # check if path exists
    path = os.path.abspath(path)
    _validate_and_prepare_target_save_path(path)

    code_dir_subpath = _validate_and_copy_code_paths(code_paths, path)

    if mlflow_model is None:
        mlflow_model = Model()
    saved_example = _save_example(mlflow_model, input_example, path)

    if signature is None and saved_example is not None:
        # No explicit signature: infer one by running the saved input example
        # through a pyfunc-style wrapper of the model.
        wrapped_model = None
        if isinstance(model, KerasModel):
            wrapped_model = _KerasModelWrapper(model, signature)
        elif isinstance(model, tf.Module):
            wrapped_model = _TF2ModuleWrapper(model, signature)
        if wrapped_model is not None:
            signature = _infer_signature_from_input_example(saved_example, wrapped_model)
    elif signature is False:
        # signature=False explicitly opts out of signature inference; normalize
        # to None so no signature is validated or saved.
        signature = None

    if signature is None:
        _logger.warning(_NO_MODEL_SIGNATURE_WARNING)
    else:
        # Validate that the signature is tensor-based: every input must be a
        # TensorSpec whose first (batch) dimension is variable (-1).
        num_inputs = len(signature.inputs.inputs)
        if num_inputs == 0:
            raise MlflowException(
                "The model signature's input schema must contain at least one field.",
                error_code=INVALID_PARAMETER_VALUE,
            )
        for field in signature.inputs.inputs:
            if not isinstance(field, TensorSpec):
                raise MlflowException(
                    "All fields in the model signature's input schema must be of type TensorSpec.",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            if field.shape[0] != -1:
                raise MlflowException(
                    "All fields in the model signature's input schema must have a shape "
                    "in which the first dimension is a variable dimension.",
                    error_code=INVALID_PARAMETER_VALUE,
                )

    _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)

    if signature is not None:
        mlflow_model.signature = signature
    if metadata is not None:
        mlflow_model.metadata = metadata

    if isinstance(model, KerasModel):
        keras_model_kwargs = keras_model_kwargs or {}

        data_subpath = _KERAS_MODEL_DATA_PATH
        # construct new data folder in existing path
        data_path = os.path.join(path, data_subpath)
        os.makedirs(data_path)
        model_subpath = os.path.join(data_subpath, _MODEL_SAVE_PATH)

        keras_module = importlib.import_module("tensorflow.keras")
        # save custom objects if there are custom objects
        if custom_objects is not None:
            _save_keras_custom_objects(data_path, custom_objects, _CUSTOM_OBJECTS_SAVE_PATH)
        # save custom objects stored within _GLOBAL_CUSTOM_OBJECTS
        if global_custom_objects := get_global_custom_objects():
            _save_keras_custom_objects(
                data_path, global_custom_objects, _GLOBAL_CUSTOM_OBJECTS_SAVE_PATH
            )

        # save keras module spec to path/data/keras_module.txt
        with open(os.path.join(data_path, _KERAS_MODULE_SPEC_PATH), "w") as f:
            f.write(keras_module.__name__)

        # Use the SavedModel format if `save_format` is unspecified
        save_format = keras_model_kwargs.get("save_format", "tf")

        # save keras save_format to path/data/save_format.txt so load_model
        # knows which loader path to take
        with open(os.path.join(data_path, _KERAS_SAVE_FORMAT_PATH), "w") as f:
            f.write(save_format)

        # save keras model
        # To maintain prior behavior, when the format is HDF5, we save
        # with the h5 file extension. Otherwise, model_path is a directory
        # where the saved_model.pb will be stored (for SavedModel format)
        # For tensorflow 2.16.0 (including dev version),
        # it only supports saving model in .h5 or .keras format
        if save_format == "h5":
            file_extension = ".h5"
        elif Version(tf.__version__).release >= (2, 16):
            file_extension = ".keras"
        else:
            file_extension = ""
        model_path = os.path.join(path, model_subpath) + file_extension
        if path.startswith("/dbfs/"):
            # The Databricks Filesystem uses a FUSE implementation that does not support
            # random writes. It causes an error. Save to a local temp file first,
            # then copy the finished file into DBFS.
            with tempfile.NamedTemporaryFile(suffix=".h5") as f:
                model.save(f.name, **keras_model_kwargs)
                f.flush()  # force flush the data
                shutil.copy2(src=f.name, dst=model_path)
        else:
            model.save(model_path, **keras_model_kwargs)

        pyfunc_options = {
            "data": data_subpath,
        }

        flavor_options = {
            **pyfunc_options,
            "model_type": _MODEL_TYPE_KERAS,
            "keras_version": tf.__version__,
            "save_format": save_format,
        }
    elif isinstance(model, tf.Module):
        # Non-Keras TF2 module: export with the SavedModel API.
        saved_model_kwargs = saved_model_kwargs or {}
        model_dir_subpath = _TF2MODEL_SUBPATH
        model_path = os.path.join(path, model_dir_subpath)
        tf.saved_model.save(model, model_path, **saved_model_kwargs)
        pyfunc_options = {}
        flavor_options = {
            "saved_model_dir": model_dir_subpath,
            "model_type": _MODEL_TYPE_TF2_MODULE,
        }
    else:
        raise MlflowException(f"Unknown model type: {type(model)}")

    extra_files_config = _copy_extra_files(extra_files, path)

    # update flavor info to mlflow_model
    mlflow_model.add_flavor(
        FLAVOR_NAME, code=code_dir_subpath, **flavor_options, **extra_files_config
    )

    # append loader_module, data and env data to mlflow_model
    pyfunc.add_to_model(
        mlflow_model,
        loader_module="mlflow.tensorflow",
        conda_env=_CONDA_ENV_FILE_NAME,
        python_env=_PYTHON_ENV_FILE_NAME,
        code=code_dir_subpath,
        **pyfunc_options,
    )

    # add model file size to mlflow_model
    if size := get_total_file_size(path):
        mlflow_model.model_size_bytes = size

    # save mlflow_model to path/MLmodel
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    # cloudpickle is only a requirement when custom objects were serialized
    include_cloudpickle = custom_objects is not None or get_global_custom_objects() is not None
    if conda_env is None:
        if pip_requirements is None:
            default_reqs = get_default_pip_requirements(include_cloudpickle)
            # To ensure `_load_pyfunc` can successfully load the model during the dependency
            # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
            inferred_reqs = mlflow.models.infer_pip_requirements(
                path, FLAVOR_NAME, fallback=default_reqs
            )
            default_reqs = sorted(set(inferred_reqs).union(default_reqs))
        else:
            default_reqs = None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs,
            pip_requirements,
            extra_pip_requirements,
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    # Save `constraints.txt` if necessary
    if pip_constraints:
        write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    # Save `requirements.txt`
    write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
 544  
 545  
 546  def _load_custom_objects(path, file_name):
 547      custom_objects_path = None
 548      if os.path.isdir(path):
 549          if os.path.isfile(os.path.join(path, file_name)):
 550              custom_objects_path = os.path.join(path, file_name)
 551      if custom_objects_path is not None:
 552          if (
 553              not MLFLOW_ALLOW_PICKLE_DESERIALIZATION.get()
 554              and not is_in_databricks_runtime()
 555              and not is_in_databricks_model_serving_environment()
 556          ):
 557              raise MlflowException(
 558                  "Deserializing custom objects using cloudpickle is disallowed, but this model "
 559                  "was saved with custom objects in pickle format. To address this issue, you need "
 560                  "to set environment variable 'MLFLOW_ALLOW_PICKLE_DESERIALIZATION' to 'true'."
 561              )
 562          import cloudpickle
 563  
 564          with open(custom_objects_path, "rb") as f:
 565              return cloudpickle.load(f)
 566  
 567  
def _load_keras_model(model_path, keras_module, save_format, **kwargs):
    """Load a Keras model previously serialized by ``save_model``.

    Args:
        model_path: Path to the model data directory (or directly to the model
            file, without its format-specific extension).
        keras_module: The keras module the model was saved with; its
            ``.models`` submodule supplies ``load_model``.
        save_format: The save format recorded at save time (e.g. "h5" or "tf").
        kwargs: Extra arguments forwarded to ``load_model``; may include a
            caller-supplied ``custom_objects`` dict.

    Returns:
        The deserialized Keras model.
    """
    keras_models = importlib.import_module(keras_module.__name__ + ".models")
    custom_objects = kwargs.pop("custom_objects", {})
    # Merge custom objects saved with the model: caller-supplied entries
    # override saved ones, and the merged result overrides the saved
    # global custom objects.
    if saved_custom_objects := _load_custom_objects(model_path, _CUSTOM_OBJECTS_SAVE_PATH):
        saved_custom_objects.update(custom_objects)
        custom_objects = saved_custom_objects

    if global_custom_objects := _load_custom_objects(model_path, _GLOBAL_CUSTOM_OBJECTS_SAVE_PATH):
        global_custom_objects.update(custom_objects)
        custom_objects = global_custom_objects

    if os.path.isdir(model_path):
        model_path = os.path.join(model_path, _MODEL_SAVE_PATH)

    # If the save_format is HDF5, then we save with h5 file
    # extension to align with prior behavior of mlflow logging
    if save_format == "h5":
        model_path += ".h5"
    # Since TF 2.16.0, it only supports saving model in .h5 or .keras format.
    # But for backwards compatibility, we still save model without suffix
    # for older versions of TF.
    elif os.path.exists(model_path + ".keras"):
        model_path += ".keras"

    import tensorflow as tf

    # Using naive tuple-based comparison here rather than packaging.version.Version, because
    # the latter consider dev version e.g. 2.16.0.dev2023010 as ahead of 2.16. While that is
    # 'correct', we rather want to treat it is a part of 2.16 here.
    if save_format == "h5" and (2, 2, 3) <= Version(tf.__version__).release < (2, 16):
        # NOTE: TF 2.2.3 does not work with unicode paths in python2. Pass in h5py.File instead
        # of string to avoid issues.
        import h5py

        with h5py.File(os.path.abspath(model_path), "r") as model_path:
            return keras_models.load_model(model_path, custom_objects=custom_objects, **kwargs)
    else:
        # NOTE: Older versions of Keras only handle filepath.
        return keras_models.load_model(model_path, custom_objects=custom_objects, **kwargs)
 607  
 608  
 609  def _get_flavor_conf(model_conf):
 610      if "keras" in model_conf.flavors:
 611          return model_conf.flavors["keras"]
 612      return model_conf.flavors[FLAVOR_NAME]
 613  
 614  
 615  def _infer_model_type(model_conf):
 616      model_type = _get_flavor_conf(model_conf).get("model_type")
 617      if model_type is not None:
 618          return model_type
 619      # Loading model logged by old version mlflow, which deos not record model_type
 620      # Inferring model type by checking whether model_conf contains "keras" flavor.
 621      if "keras" in model_conf.flavors:
 622          return _MODEL_TYPE_KERAS
 623      return _MODEL_TYPE_TF1_ESTIMATOR
 624  
 625  
def load_model(model_uri, dst_path=None, saved_model_kwargs=None, keras_model_kwargs=None):
    """
    Load an MLflow model that contains the TensorFlow flavor from the specified path.

    Args:
        model_uri: The location, in URI format, of the MLflow model. For example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
            - ``models:/<model_name>/<model_version>``
            - ``models:/<model_name>/<stage>``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
            artifact-locations>`_.
        dst_path: The local filesystem path to which to download the model artifact.
            This directory must already exist. If unspecified, a local output
            path will be created.
        saved_model_kwargs: kwargs to pass to ``tensorflow.saved_model.load`` method.
            Only available when you are loading a tensorflow2 core model.
        keras_model_kwargs: kwargs to pass to ``keras.models.load_model`` method.
            Only available when you are loading a Keras model.

    Returns:
        A callable graph (tf.function) that takes inputs and returns inferences.
    """
    import tensorflow as tf

    local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)

    model_configuration_path = os.path.join(local_model_path, MLMODEL_FILE_NAME)
    model_conf = Model.load(model_configuration_path)

    flavor_conf = _get_flavor_conf(model_conf)

    # Make any model-bundled code (code_paths) importable before loading.
    _add_code_from_conf_to_system_path(local_model_path, flavor_conf)

    # Dispatch on the model type recorded (or inferred) from the flavor config.
    model_type = _infer_model_type(model_conf)
    if model_type == _MODEL_TYPE_KERAS:
        keras_model_kwargs = keras_model_kwargs or {}
        keras_module = importlib.import_module(flavor_conf.get("keras_module", "tensorflow.keras"))
        # For backwards compatibility, we assume h5 when the save_format is absent
        save_format = flavor_conf.get("save_format", "h5")
        model_path = os.path.join(local_model_path, flavor_conf.get("data", _MODEL_SAVE_PATH))
        return _load_keras_model(
            model_path=model_path,
            keras_module=keras_module,
            save_format=save_format,
            **keras_model_kwargs,
        )
    if model_type == _MODEL_TYPE_TF1_ESTIMATOR:
        tf_saved_model_dir = os.path.join(local_model_path, flavor_conf["saved_model_dir"])
        tf_meta_graph_tags = flavor_conf["meta_graph_tags"]
        tf_signature_def_key = flavor_conf["signature_def_key"]
        return _load_tf1_estimator_saved_model(
            tf_saved_model_dir=tf_saved_model_dir,
            tf_meta_graph_tags=tf_meta_graph_tags,
            tf_signature_def_key=tf_signature_def_key,
        )
    if model_type == _MODEL_TYPE_TF2_MODULE:
        saved_model_kwargs = saved_model_kwargs or {}
        tf_saved_model_dir = os.path.join(local_model_path, flavor_conf["saved_model_dir"])
        return tf.saved_model.load(tf_saved_model_dir, **saved_model_kwargs)

    raise MlflowException(f"Unknown model_type: {model_type}")
 693  
 694  
 695  def _load_tf1_estimator_saved_model(tf_saved_model_dir, tf_meta_graph_tags, tf_signature_def_key):
 696      """
 697      Load a specified TensorFlow model consisting of a TensorFlow metagraph and signature definition
 698      from a serialized TensorFlow ``SavedModel`` collection.
 699  
 700      Args:
 701          tf_saved_model_dir: The local filesystem path or run-relative artifact path to the model.
 702          tf_meta_graph_tags: A list of tags identifying the model's metagraph within the
 703              serialized ``SavedModel`` object. For more information, see the
 704              ``tags`` parameter of the `tf.saved_model.builder.SavedModelBuilder
 705              method <https://www.tensorflow.org/api_docs/python/tf/saved_model/
 706              builder/SavedModelBuilder#add_meta_graph>`_.
 707          tf_signature_def_key: A string identifying the input/output signature associated with the
 708              model. This is a key within the serialized ``SavedModel``'s
 709              signature definition mapping. For more information, see the
 710              ``signature_def_map`` parameter of the
 711              ``tf.saved_model.builder.SavedModelBuilder`` method.
 712  
 713      Returns:
 714          A callable graph (tensorflow.function) that takes inputs and returns inferences.
 715      """
 716      import tensorflow as tf
 717  
 718      loaded = tf.saved_model.load(tags=tf_meta_graph_tags, export_dir=tf_saved_model_dir)
 719      loaded_sig = loaded.signatures
 720      if tf_signature_def_key not in loaded_sig:
 721          raise MlflowException(
 722              f"Could not find signature def key {tf_signature_def_key}. "
 723              f"Available keys are: {list(loaded_sig.keys())}"
 724          )
 725      return loaded_sig[tf_signature_def_key]
 726  
 727  
 728  def _load_pyfunc(path):
 729      """
 730      Load PyFunc implementation. Called by ``pyfunc.load_model``. This function loads an MLflow
 731      model with the TensorFlow flavor into a new TensorFlow graph and exposes it behind the
 732      ``pyfunc.predict`` interface.
 733  
 734      Args:
 735          path: Local filesystem path to the MLflow Model with the ``tensorflow`` flavor.
 736      """
 737      import tensorflow as tf
 738  
 739      model_meta_path1 = os.path.join(path, MLMODEL_FILE_NAME)
 740      model_meta_path2 = os.path.join(os.path.dirname(path), MLMODEL_FILE_NAME)
 741  
 742      if os.path.isfile(model_meta_path1):
 743          model_meta = Model.load(model_meta_path1)
 744      elif os.path.isfile(model_meta_path2):
 745          model_meta = Model.load(model_meta_path2)
 746      else:
 747          raise MlflowException(f"Cannot find file {MLMODEL_FILE_NAME} for the logged model.")
 748  
 749      model_type = _infer_model_type(model_meta)
 750      if model_type == _MODEL_TYPE_KERAS:
 751          if os.path.isfile(os.path.join(path, _KERAS_MODULE_SPEC_PATH)):
 752              with open(os.path.join(path, _KERAS_MODULE_SPEC_PATH)) as f:
 753                  keras_module = importlib.import_module(f.read())
 754          else:
 755              from tensorflow import keras
 756  
 757              keras_module = keras
 758  
 759          # By default, we assume the save_format is h5 for backwards compatibility
 760          save_format = "h5"
 761          save_format_path = os.path.join(path, _KERAS_SAVE_FORMAT_PATH)
 762          if os.path.isfile(save_format_path):
 763              with open(save_format_path) as f:
 764                  save_format = f.read()
 765  
 766          # In SavedModel format, loaded model should be compiled.
 767          should_compile = save_format == "tf"
 768          m = _load_keras_model(
 769              path, keras_module=keras_module, save_format=save_format, compile=should_compile
 770          )
 771          return _KerasModelWrapper(m, model_meta.signature)
 772      if model_type == _MODEL_TYPE_TF1_ESTIMATOR:
 773          flavor_conf = _get_flavor_configuration(path, FLAVOR_NAME)
 774  
 775          tf_saved_model_dir = os.path.join(path, flavor_conf["saved_model_dir"])
 776          tf_meta_graph_tags = flavor_conf["meta_graph_tags"]
 777          tf_signature_def_key = flavor_conf["signature_def_key"]
 778  
 779          loaded_model = tf.saved_model.load(export_dir=tf_saved_model_dir, tags=tf_meta_graph_tags)
 780          return _TF2Wrapper(model=loaded_model, infer=loaded_model.signatures[tf_signature_def_key])
 781      if model_type == _MODEL_TYPE_TF2_MODULE:
 782          flavor_conf = _get_flavor_configuration(path, FLAVOR_NAME)
 783          tf_saved_model_dir = os.path.join(path, flavor_conf["saved_model_dir"])
 784          loaded_model = tf.saved_model.load(tf_saved_model_dir)
 785          return _TF2ModuleWrapper(model=loaded_model, signature=model_meta.signature)
 786  
 787      raise MlflowException("Unknown model_type.")
 788  
 789  
 790  class _TF2Wrapper:
 791      """
 792      Wrapper class that exposes a TensorFlow model for inference via a ``predict`` function such that
 793      ``predict(data: pandas.DataFrame) -> pandas.DataFrame``. For TensorFlow versions >= 2.0.0.
 794      """
 795  
 796      def __init__(self, model, infer):
 797          """
 798          Args:
 799              model: A Tensorflow SavedModel.
 800              infer: Tensorflow function returned by a saved model that is used for inference.
 801          """
 802          # Note: we need to retain the model reference in TF2Wrapper object, because the infer
 803          #  function in tensorflow will be `ConcreteFunction` which only retains WeakRefs to the
 804          #  variables they close over.
 805          #  See https://www.tensorflow.org/guide/function#deleting_tfvariables_between_function_calls
 806          self.model = model
 807          self.infer = infer
 808  
 809      def get_raw_model(self):
 810          """
 811          Returns the underlying model.
 812          """
 813          return self.model
 814  
 815      def predict(
 816          self,
 817          data,
 818          params: dict[str, Any] | None = None,
 819      ):
 820          """
 821          Args:
 822              data: Model input data.
 823              params: Additional parameters to pass to the model for inference.
 824  
 825          Returns:
 826              Model predictions.
 827          """
 828          import tensorflow as tf
 829  
 830          feed_dict = {}
 831          if isinstance(data, dict):
 832              feed_dict = {k: tf.constant(v) for k, v in data.items()}
 833          elif isinstance(data, pandas.DataFrame):
 834              for df_col_name in list(data):
 835                  # If there are multiple columns with the same name, selecting the shared name
 836                  # from the DataFrame will result in another DataFrame containing the columns
 837                  # with the shared name. TensorFlow cannot make eager tensors out of pandas
 838                  # DataFrames, so we convert the DataFrame to a numpy array here.
 839                  val = data[df_col_name]
 840                  val = val.values if isinstance(val, pandas.DataFrame) else np.array(val.to_list())
 841                  feed_dict[df_col_name] = tf.constant(val)
 842          else:
 843              raise TypeError("Only dict and DataFrame input types are supported")
 844  
 845          raw_preds = self.infer(**feed_dict)
 846          pred_dict = {col_name: raw_preds[col_name].numpy() for col_name in raw_preds.keys()}
 847          for col in pred_dict.keys():
 848              # If the output tensor is not 1-dimensional
 849              # AND all elements have length of 1, flatten the array with `ravel()`
 850              if len(pred_dict[col].shape) != 1 and all(
 851                  len(element) == 1 for element in pred_dict[col]
 852              ):
 853                  pred_dict[col] = pred_dict[col].ravel()
 854              else:
 855                  pred_dict[col] = pred_dict[col].tolist()
 856  
 857          if isinstance(data, dict):
 858              return pred_dict
 859          else:
 860              return pandas.DataFrame.from_dict(data=pred_dict)
 861  
 862  
 863  class _TF2ModuleWrapper:
 864      def __init__(self, model, signature):
 865          self.model = model
 866          self.signature = signature
 867  
 868      def get_raw_model(self):
 869          """
 870          Returns the underlying model.
 871          """
 872          return self.model
 873  
 874      def predict(
 875          self,
 876          data,
 877          params: dict[str, Any] | None = None,
 878      ):
 879          """
 880          Args:
 881              data: Model input data.
 882              params: Additional parameters to pass to the model for inference.
 883  
 884          Returns:
 885              Model predictions.
 886          """
 887          import tensorflow as tf
 888  
 889          if isinstance(data, (np.ndarray, list)):
 890              data = tf.convert_to_tensor(data)
 891          else:
 892              raise MlflowException(
 893                  f"Unsupported input data type: {type(data)}, the input data must be "
 894                  "numpy array or a list."
 895              )
 896          result = self.model(data)
 897          if isinstance(result, tf.Tensor):
 898              return result.numpy()
 899          return result
 900  
 901  
 902  class _KerasModelWrapper:
 903      def __init__(self, keras_model, signature):
 904          self.keras_model = keras_model
 905          self.signature = signature
 906  
 907      def get_raw_model(self):
 908          """
 909          Returns the underlying model.
 910          """
 911          return self.keras_model
 912  
 913      def predict(
 914          self,
 915          data,
 916          params: dict[str, Any] | None = None,
 917      ):
 918          """
 919          Args:
 920              data: Model input data.
 921              params: Additional parameters to pass to the model for inference.
 922  
 923          Returns
 924              Model predictions.
 925          """
 926          if isinstance(data, pandas.DataFrame):
 927              # This line is for backwards compatibility:
 928              # If model signature is not None, when calling
 929              # `keras_pyfunc_model.predict(pandas_dataframe)`, `_enforce_schema` will convert
 930              # dataframe input into dict input, so in the case `_KerasModelWrapper.predict`
 931              # will receive a dict type input.
 932              # If model signature is None, `_enforce_schema` can do nothing, and if the input
 933              # is dataframe, `_KerasModelWrapper.predict` will receive a dataframe input,
 934              # we need to handle this case, to keep backwards compatibility.
 935              return pandas.DataFrame(self.keras_model.predict(data.values), index=data.index)
 936  
 937          supported_input_types = (np.ndarray, list, tuple, dict)
 938          if not isinstance(data, supported_input_types):
 939              raise MlflowException(
 940                  f"Unsupported input data type: {type(data)}. "
 941                  f"Must be one of: {[x.__name__ for x in supported_input_types]}",
 942                  INVALID_PARAMETER_VALUE,
 943              )
 944          return self.keras_model.predict(data)
 945  
 946  
 947  def _assoc_list_to_map(lst):
 948      """
 949      Convert an association list to a dictionary.
 950      """
 951      d = {}
 952      for run_id, metric in lst:
 953          d[run_id] = d[run_id] + [metric] if run_id in d else [metric]
 954      return d
 955  
 956  
 957  @picklable_exception_safe_function
 958  def _get_tensorboard_callback(lst):
 959      import tensorflow as tf
 960  
 961      for x in lst:
 962          if isinstance(x, tf.keras.callbacks.TensorBoard):
 963              return x
 964      return None
 965  
 966  
# A representation of a TensorBoard event logging directory with two attributes:
# :location - string: The filesystem location of the logging directory
# :is_temp - boolean: `True` if the logging directory was created for temporary use by MLflow,
#                     `False` otherwise
class _TensorBoardLogDir(NamedTuple):
    """A TensorBoard event logging directory.

    Attributes:
        location: Filesystem location of the logging directory.
        is_temp: ``True`` if the logging directory was created for temporary use by MLflow,
            ``False`` otherwise.
    """

    location: str
    is_temp: bool
 974  
 975  
 976  def _setup_callbacks(callbacks, log_every_epoch, log_every_n_steps):
 977      """
 978      Adds TensorBoard and MlfLowTfKeras callbacks to the
 979      input list, and returns the new list and appropriate log directory.
 980      """
 981      from mlflow.tensorflow.autologging import _TensorBoard
 982      from mlflow.tensorflow.callback import MlflowCallback, MlflowModelCheckpointCallback
 983  
 984      tb = _get_tensorboard_callback(callbacks)
 985      for callback in callbacks:
 986          if isinstance(callback, MlflowCallback):
 987              raise MlflowException(
 988                  "MLflow autologging must be turned off if an `MlflowCallback` is explicitly added "
 989                  "to the callback list. You are creating an `MlflowCallback` while having "
 990                  "autologging enabled. Please either call `mlflow.tensorflow.autolog(disable=True)` "
 991                  "to disable autologging or remove `MlflowCallback` from the callback list. "
 992              )
 993      if tb is None:
 994          log_dir = _TensorBoardLogDir(location=tempfile.mkdtemp(), is_temp=True)
 995          callbacks.append(_TensorBoard(log_dir.location))
 996      else:
 997          log_dir = _TensorBoardLogDir(location=tb.log_dir, is_temp=False)
 998  
 999      callbacks.append(
1000          MlflowCallback(
1001              log_every_epoch=log_every_epoch,
1002              log_every_n_steps=log_every_n_steps,
1003          )
1004      )
1005  
1006      model_checkpoint = get_autologging_config(mlflow.tensorflow.FLAVOR_NAME, "checkpoint", True)
1007      if model_checkpoint:
1008          checkpoint_monitor = get_autologging_config(
1009              mlflow.tensorflow.FLAVOR_NAME, "checkpoint_monitor", "val_loss"
1010          )
1011          checkpoint_mode = get_autologging_config(
1012              mlflow.tensorflow.FLAVOR_NAME, "checkpoint_mode", "min"
1013          )
1014          checkpoint_save_best_only = get_autologging_config(
1015              mlflow.tensorflow.FLAVOR_NAME, "checkpoint_save_best_only", True
1016          )
1017          checkpoint_save_weights_only = get_autologging_config(
1018              mlflow.tensorflow.FLAVOR_NAME, "checkpoint_save_weights_only", False
1019          )
1020          checkpoint_save_freq = get_autologging_config(
1021              mlflow.tensorflow.FLAVOR_NAME, "checkpoint_save_freq", "epoch"
1022          )
1023  
1024          if not any(isinstance(callback, MlflowModelCheckpointCallback) for callback in callbacks):
1025              callbacks.append(
1026                  MlflowModelCheckpointCallback(
1027                      monitor=checkpoint_monitor,
1028                      mode=checkpoint_mode,
1029                      save_best_only=checkpoint_save_best_only,
1030                      save_weights_only=checkpoint_save_weights_only,
1031                      save_freq=checkpoint_save_freq,
1032                  )
1033              )
1034  
1035      return callbacks, log_dir
1036  
1037  
1038  @autologging_integration(FLAVOR_NAME)
1039  def autolog(
1040      log_models=True,
1041      log_datasets=True,
1042      disable=False,
1043      exclusive=False,
1044      disable_for_unsupported_versions=False,
1045      silent=False,
1046      registered_model_name=None,
1047      log_input_examples=False,
1048      log_model_signatures=True,
1049      saved_model_kwargs=None,
1050      keras_model_kwargs=None,
1051      extra_tags=None,
1052      log_every_epoch=True,
1053      log_every_n_steps=None,
1054      checkpoint=True,
1055      checkpoint_monitor="val_loss",
1056      checkpoint_mode="min",
1057      checkpoint_save_best_only=True,
1058      checkpoint_save_weights_only=False,
1059      checkpoint_save_freq="epoch",
1060  ):
1061      """
1062      Enables autologging for ``tf.keras``.
1063      Note that only ``tensorflow>=2.3`` are supported.
1064      As an example, try running the
1065      `Keras/TensorFlow example <https://github.com/mlflow/mlflow/blob/master/examples/keras/train.py>`_.
1066  
1067      For each TensorFlow module, autologging captures the following information:
1068  
1069      **tf.keras**
1070       - **Metrics** and **Parameters**
1071  
1072        - Training and validation loss.
1073        - User-specified metrics.
1074        - Optimizer config, e.g., learning_rate, momentum, etc.
1075        - Training configs, e.g., epochs, batch_size, etc.
1076  
1077       - **Artifacts**
1078  
1079        - Model summary on training start.
1080        - Saved Keras model in `MLflow Model <https://mlflow.org/docs/latest/models.html>`_ format.
1081        - TensorBoard logs on training end.
1082  
1083      **tf.keras.callbacks.EarlyStopping**
1084       - **Metrics** and **Parameters**
1085  
1086        - Metrics from the ``EarlyStopping`` callbacks: ``stopped_epoch``, ``restored_epoch``,
1087          ``restore_best_weight``, etc
1088        - ``fit()`` or ``fit_generator()`` parameters associated with ``EarlyStopping``:
1089          ``min_delta``, ``patience``, ``baseline``, ``restore_best_weights``, etc
1090  
1091      Refer to the autologging tracking documentation for more
1092      information on `TensorFlow workflows
1093      <https://www.mlflow.org/docs/latest/tracking.html#tensorflow-and-keras-experimental>`_.
1094  
1095      Note that autologging cannot be used together with explicit MLflow callback, i.e.,
1096      `mlflow.tensorflow.MlflowCallback`, because it will cause the same metrics to be logged twice.
1097      If you want to include `mlflow.tensorflow.MlflowCallback` in the callback list, please turn off
1098      autologging by calling `mlflow.tensorflow.autolog(disable=True)`.
1099  
1100      Args:
1101          log_models: If ``True``, trained models are logged as MLflow model artifacts.
1102              If ``False``, trained models are not logged.
1103          log_datasets: If ``True``, dataset information is logged to MLflow Tracking.
1104              If ``False``, dataset information is not logged.
1105          disable: If ``True``, disables the TensorFlow autologging integration. If ``False``,
1106              enables the TensorFlow integration autologging integration.
1107          exclusive: If ``True``, autologged content is not logged to user-created fluent runs.
1108              If ``False``, autologged content is logged to the active fluent run,
1109              which may be user-created.
1110          disable_for_unsupported_versions: If ``True``, disable autologging for versions of
1111              tensorflow that have not been tested against this version of the MLflow
1112              client or are incompatible.
1113          silent: If ``True``, suppress all event logs and warnings from MLflow during TensorFlow
1114              autologging. If ``False``, show all events and warnings during TensorFlow
1115              autologging.
1116          registered_model_name: If given, each time a model is trained, it is registered as a
1117              new model version of the registered model with this name.
1118              The registered model is created if it does not already exist.
1119          log_input_examples: If ``True``, input examples from training datasets are collected and
1120              logged along with tf/keras model artifacts during training. If
1121              ``False``, input examples are not logged.
1122          log_model_signatures: If ``True``,
1123              :py:class:`ModelSignatures <mlflow.models.ModelSignature>`
1124              describing model inputs and outputs are collected and logged along
1125              with tf/keras model artifacts during training. If ``False``,
1126              signatures are not logged. Note that logging TensorFlow models
1127              with signatures changes their pyfunc inference behavior when
1128              Pandas DataFrames are passed to ``predict()``.
1129              When a signature is present, an ``np.ndarray``
1130              (for single-output models) or a mapping from
1131              ``str`` -> ``np.ndarray`` (for multi-output models) is returned;
1132              when a signature is not present, a Pandas DataFrame is returned.
1133          saved_model_kwargs: a dict of kwargs to pass to ``tensorflow.saved_model.save`` method.
1134          keras_model_kwargs: a dict of kwargs to pass to ``keras_model.save`` method.
1135          extra_tags: A dictionary of extra tags to set on each managed run created by autologging.
1136          log_every_epoch: If True, training metrics will be logged at the end of each epoch.
1137          log_every_n_steps: If set, training metrics will be logged every `n` training steps.
1138              `log_every_n_steps` must be `None` when `log_every_epoch=True`.
1139          checkpoint: Enable automatic model checkpointing.
1140          checkpoint_monitor: In automatic model checkpointing, the metric name to monitor if
1141              you set `model_checkpoint_save_best_only` to True.
1142          checkpoint_mode: one of {"min", "max"}. In automatic model checkpointing,
1143              if save_best_only=True, the decision to overwrite the current save file is made based on
1144              either the maximization or the minimization of the monitored quantity.
1145          checkpoint_save_best_only: If True, automatic model checkpointing only saves when
1146              the model is considered the "best" model according to the quantity
1147              monitored and previous checkpoint model is overwritten.
1148          checkpoint_save_weights_only: In automatic model checkpointing, if True, then
1149              only the model's weights will be saved. Otherwise, the optimizer states,
1150              lr-scheduler states, etc are added in the checkpoint too.
1151          checkpoint_save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
1152              saves the model after each epoch. When using integer, the callback
1153              saves the model at end of this many batches. Note that if the saving isn't aligned to
1154              epochs, the monitored metric may potentially be less reliable (it
1155              could reflect as little as 1 batch, since the metrics get reset
1156              every epoch). Defaults to `"epoch"`.
1157      """
1158      import tensorflow as tf
1159  
1160      if Version(tf.__version__) < Version("2.3"):
1161          _logger.error(
1162              "Could not log to MLflow because your Tensorflow version is below 2.3, detected "
1163              f"version: {tf.__version__}."
1164          )
1165          return
1166  
1167      @picklable_exception_safe_function
1168      def _get_early_stop_callback(callbacks):
1169          for callback in callbacks:
1170              if isinstance(callback, tf.keras.callbacks.EarlyStopping):
1171                  return callback
1172          return None
1173  
1174      def _log_early_stop_callback_params(callback):
1175          if callback:
1176              try:
1177                  earlystopping_params = {
1178                      "monitor": callback.monitor,
1179                      "min_delta": callback.min_delta,
1180                      "patience": callback.patience,
1181                      "baseline": callback.baseline,
1182                      "restore_best_weights": callback.restore_best_weights,
1183                  }
1184                  mlflow.log_params(earlystopping_params)
1185              except Exception:
1186                  return
1187  
1188      def _get_early_stop_callback_attrs(callback):
1189          try:
1190              return callback.stopped_epoch, callback.restore_best_weights, callback.patience
1191          except Exception:
1192              return None
1193  
1194      def _log_early_stop_callback_metrics(callback, history, model_id=None):
1195          from mlflow import log_metrics
1196  
1197          if callback is None or not callback.model.stop_training:
1198              return
1199  
1200          callback_attrs = _get_early_stop_callback_attrs(callback)
1201          if callback_attrs is None:
1202              return
1203  
1204          stopped_epoch, restore_best_weights, _ = callback_attrs
1205          log_metrics({"stopped_epoch": stopped_epoch}, synchronous=False, model_id=model_id)
1206  
1207          if not restore_best_weights or callback.best_weights is None:
1208              return
1209  
1210          monitored_metric = history.history.get(callback.monitor)
1211          if not monitored_metric:
1212              return
1213  
1214          initial_epoch = history.epoch[0]
1215          # If `monitored_metric` contains multiple best values (e.g. [0.1, 0.1, 0.2] where 0.1 is
1216          # the minimum loss), the epoch corresponding to the first occurrence of the best value is
1217          # the best epoch. In keras > 2.6.0, the best epoch can be obtained via the `best_epoch`
1218          # attribute of an `EarlyStopping` instance: https://github.com/keras-team/keras/pull/15197
1219          restored_epoch = initial_epoch + monitored_metric.index(callback.best)
1220          log_metrics({"restored_epoch": restored_epoch}, synchronous=False, model_id=model_id)
1221          restored_index = history.epoch.index(restored_epoch)
1222          restored_metrics = {
1223              key: metrics[restored_index] for key, metrics in history.history.items()
1224          }
1225          # Checking that a metric history exists
1226          metric_key = next(iter(history.history), None)
1227          if metric_key is not None:
1228              log_metrics(restored_metrics, stopped_epoch + 1, synchronous=False, model_id=model_id)
1229  
1230      def _log_keras_model(history, args, model_id=None):
1231          def _infer_model_signature(input_data_slice):
1232              # In certain TensorFlow versions, calling `predict()` on model may modify
1233              # the `stop_training` attribute, so we save and restore it accordingly
1234              original_stop_training = history.model.stop_training
1235              model_output = history.model.predict(input_data_slice)
1236              history.model.stop_training = original_stop_training
1237              return infer_signature(input_data_slice, model_output)
1238  
1239          from mlflow.tensorflow.autologging import extract_tf_keras_input_example
1240  
1241          def _get_tf_keras_input_example_slice():
1242              input_training_data = args[0]
1243              keras_input_example_slice = extract_tf_keras_input_example(input_training_data)
1244              if keras_input_example_slice is None:
1245                  raise MlflowException(
1246                      "Cannot log input example or model signature for input with type"
1247                      f" {type(input_training_data)}. TensorFlow Keras autologging can"
1248                      " only log input examples and model signatures for the following"
1249                      " input types: numpy.ndarray, dict[string -> numpy.ndarray],"
1250                      " tensorflow.keras.utils.Sequence, and"
1251                      " tensorflow.data.Dataset (TensorFlow >= 2.1.0 required)",
1252                      INVALID_PARAMETER_VALUE,
1253                  )
1254              return keras_input_example_slice
1255  
1256          input_example, signature = resolve_input_example_and_signature(
1257              _get_tf_keras_input_example_slice,
1258              _infer_model_signature,
1259              log_input_examples,
1260              log_model_signatures,
1261              _logger,
1262          )
1263  
1264          log_model(
1265              history.model,
1266              "model",
1267              input_example=input_example,
1268              signature=signature,
1269              registered_model_name=get_autologging_config(
1270                  FLAVOR_NAME, "registered_model_name", None
1271              ),
1272              saved_model_kwargs=saved_model_kwargs,
1273              keras_model_kwargs=keras_model_kwargs,
1274              model_id=model_id,
1275          )
1276  
1277      def _patched_inference(original, inst, *args, **kwargs):
1278          log_dir = None
1279          try:
1280              unlogged_params = ["self", "x", "y", "callbacks", "validation_data", "verbose"]
1281  
1282              batch_size = None
1283              try:
1284                  is_single_input_model = isinstance(inst.input_shape, tuple)
1285                  training_data = kwargs["x"] if "x" in kwargs else args[0]
1286                  if isinstance(training_data, tf.data.Dataset) and hasattr(
1287                      training_data, "_batch_size"
1288                  ):
1289                      batch_size = training_data._batch_size.numpy()
1290                  elif isinstance(training_data, tf.keras.utils.Sequence):
1291                      first_batch_inputs, *_ = training_data[0]
1292                      if is_single_input_model:
1293                          batch_size = len(first_batch_inputs)
1294                      else:
1295                          batch_size = len(first_batch_inputs[0])
1296                  elif is_iterator(training_data):
1297                      peek = next(training_data)
1298                      batch_size = len(peek[0]) if is_single_input_model else len(peek[0][0])
1299  
1300                      def __restore_generator(prev_generator):
1301                          yield peek
1302                          yield from prev_generator
1303  
1304                      restored_generator = __restore_generator(training_data)
1305                      if "x" in kwargs:
1306                          kwargs["x"] = restored_generator
1307                      else:
1308                          args = (restored_generator,) + args[1:]
1309              except Exception as e:
1310                  _logger.warning(
1311                      "Encountered unexpected error while inferring batch size from training"
1312                      " dataset: %s",
1313                      e,
1314                  )
1315  
1316              if batch_size is not None:
1317                  mlflow.log_param("batch_size", batch_size)
1318                  unlogged_params.append("batch_size")
1319  
1320              log_fn_args_as_params(original, args, kwargs, unlogged_params)
1321  
1322              # Check if the 'callback' argument of fit() is set positionally
1323              if len(args) >= 6:
1324                  # Convert the positional training function arguments to a list in order to
1325                  # mutate the contents
1326                  args = list(args)
1327                  # Make a shallow copy of the preexisting callbacks to avoid permanently
1328                  # modifying their contents for future training invocations. Introduce
1329                  # TensorBoard & tf.keras callbacks if necessary
1330                  callbacks = list(args[5])
1331                  callbacks, log_dir = _setup_callbacks(
1332                      callbacks,
1333                      log_every_epoch=log_every_epoch,
1334                      log_every_n_steps=log_every_n_steps,
1335                  )
1336                  # Replace the callbacks positional entry in the copied arguments and convert
1337                  # the arguments back to tuple form for usage in the training function
1338                  args[5] = callbacks
1339                  args = tuple(args)
1340              else:
1341                  # Make a shallow copy of the preexisting callbacks and introduce TensorBoard
1342                  # & tf.keras callbacks if necessary
1343                  callbacks = list(kwargs.get("callbacks") or [])
1344                  kwargs["callbacks"], log_dir = _setup_callbacks(
1345                      callbacks,
1346                      log_every_epoch=log_every_epoch,
1347                      log_every_n_steps=log_every_n_steps,
1348                  )
1349  
1350              early_stop_callback = _get_early_stop_callback(callbacks)
1351              _log_early_stop_callback_params(early_stop_callback)
1352  
1353              model_id = None
1354              if log_models:
1355                  model_id = _initialize_logged_model("model", flavor=FLAVOR_NAME).model_id
1356  
1357              if log_datasets:
1358                  try:
1359                      context_tags = context_registry.resolve_tags()
1360                      source = CodeDatasetSource(tags=context_tags)
1361  
1362                      x = kwargs["x"] if "x" in kwargs else args[0]
1363                      if "y" in kwargs:
1364                          y = kwargs["y"]
1365                      elif len(args) >= 2:
1366                          y = args[1]
1367                      else:
1368                          y = None
1369  
1370                      if "validation_data" in kwargs:
1371                          validation_data = kwargs["validation_data"]
1372                      elif len(args) >= 8:
1373                          validation_data = args[7]
1374                      else:
1375                          validation_data = None
1376                      _log_tensorflow_dataset(x, source, "train", targets=y, model_id=model_id)
1377                      if validation_data is not None:
1378                          _log_tensorflow_dataset(validation_data, source, "eval", model_id=model_id)
1379  
1380                  except Exception as e:
1381                      _logger.warning(
1382                          "Failed to log training dataset information to MLflow Tracking. Reason: %s",
1383                          e,
1384                      )
1385  
1386              history = original(inst, *args, **kwargs)
1387  
1388              if log_models:
1389                  _log_keras_model(history, args, model_id=model_id)
1390  
1391              _log_early_stop_callback_metrics(
1392                  callback=early_stop_callback,
1393                  history=history,
1394                  model_id=model_id,
1395              )
1396              # Ensure all data are logged.
1397              # Shut down the async logging (instead of flushing)
1398              # to avoid leaving zombie threads between patchings.
1399              _shut_down_async_logging()
1400  
1401              mlflow.log_artifacts(
1402                  local_dir=log_dir.location,
1403                  artifact_path="tensorboard_logs",
1404              )
1405              if log_dir.is_temp:
1406                  shutil.rmtree(log_dir.location)
1407              return history
1408  
1409          except (Exception, KeyboardInterrupt) as e:
1410              try:
1411                  if log_dir is not None and log_dir.is_temp and os.path.exists(log_dir.location):
1412                      shutil.rmtree(log_dir.location)
1413              finally:
1414                  # Regardless of what happens during the `_on_exception` callback, reraise
1415                  # the original implementation exception once the callback completes
1416                  raise e
1417  
1418      safe_patch(
1419          FLAVOR_NAME,
1420          tf.keras.Model,
1421          "fit",
1422          _patched_inference,
1423          manage_run=True,
1424          extra_tags=extra_tags,
1425      )
1426  
1427  
def _log_tensorflow_dataset(
    tensorflow_dataset, source, context, name=None, targets=None, model_id=None
):
    """Convert a raw training input into an MLflow dataset and log it as a run input.

    Supported input types are numpy arrays, TF tensors, ``tf.data.Dataset`` objects,
    and ``(features, labels)`` tuples. Anything else is skipped with a warning
    rather than raising, so dataset logging never breaks training.
    """
    import tensorflow as tf

    data = tensorflow_dataset
    if isinstance(data, np.ndarray):
        dataset = from_numpy(features=data, targets=targets, source=source, name=name)
    elif isinstance(data, tf.Tensor):
        dataset = from_tensorflow(features=data, targets=targets, source=source, name=name)
    elif isinstance(data, tf.data.Dataset):
        # A tf.data.Dataset typically yields (features, labels) pairs itself,
        # so no separate `targets` is passed here.
        dataset = from_tensorflow(features=data, source=source, name=name)
    elif isinstance(data, tuple):
        features, labels = data[0], data[1]
        # Tensor pairs go through the tensorflow dataset path; everything else
        # (e.g. numpy arrays) through the numpy path.
        if isinstance(features, tf.Tensor) and isinstance(labels, tf.Tensor):
            dataset = from_tensorflow(features=features, targets=labels, source=source, name=name)
        else:
            dataset = from_numpy(features=features, targets=labels, source=source, name=name)
    else:
        _logger.warning(
            "Unrecognized dataset type %s. Dataset logging skipped.", type(data)
        )
        return

    # Associate the dataset with the logged model when a model id is available.
    model = LoggedModelInput(model_id=model_id) if model_id is not None else None
    mlflow.log_input(dataset, context, model=model)
1458  
1459  
def load_checkpoint(model=None, run_id=None, epoch=None, global_step=None):
    """
    If you enable "checkpoint" in autologging, during Keras model
    training execution, checkpointed models are logged as MLflow artifacts.
    Using this API, you can load the checkpointed model.

    If you want to load the latest checkpoint, set both `epoch` and `global_step` to None.
    If "checkpoint_save_freq" is set to "epoch" in autologging,
    you can set `epoch` param to the epoch of the checkpoint to load specific epoch checkpoint.
    If "checkpoint_save_freq" is set to an integer in autologging,
    you can set `global_step` param to the global step of the checkpoint to load specific
    global step checkpoint.
    `epoch` param and `global_step` can't be set together.

    Args:
        model: A Keras model, this argument is required
            only when the saved checkpoint is "weight-only".
        run_id: The id of the run which model is logged to. If not provided,
            current active run is used.
        epoch: The epoch of the checkpoint to be loaded, if you set
            "checkpoint_save_freq" to "epoch".
        global_step: The global step of the checkpoint to be loaded, if
            you set "checkpoint_save_freq" to an integer.

    Returns:
        The instance of a Keras model restored from the specified checkpoint.

    .. code-block:: python
        :caption: Example

        import mlflow

        mlflow.tensorflow.autolog(checkpoint=True, checkpoint_save_best_only=False)

        model = create_tf_keras_model()  # Create a Keras model
        with mlflow.start_run() as run:
            model.fit(data, label, epochs=10)

        run_id = run.info.run_id

        # load latest checkpoint model
        latest_checkpoint_model = mlflow.tensorflow.load_checkpoint(run_id=run_id)

        # load history checkpoint model logged in second epoch
        checkpoint_model = mlflow.tensorflow.load_checkpoint(run_id=run_id, epoch=2)
    """
    import tensorflow as tf

    # Download into a temporary directory so the checkpoint artifact is cleaned
    # up automatically once the model has been restored into memory.
    with TempDir() as tmp_dir:
        downloaded_checkpoint_filepath = download_checkpoint_artifact(
            run_id=run_id, epoch=epoch, global_step=global_step, dst_path=tmp_dir.path()
        )

        # Weights-only checkpoints are identified by a suffix on the
        # extension-stripped checkpoint filename.
        fname = os.path.splitext(downloaded_checkpoint_filepath)[0]
        if fname.endswith(_WEIGHT_ONLY_CHECKPOINT_SUFFIX):
            # The checkpoint holds only weights; a compatible model instance
            # must be supplied to restore them into.
            if model is None:
                raise MlflowException(
                    "The latest checkpoint is weights-only, 'model' argument must be provided"
                )
            model.load_weights(downloaded_checkpoint_filepath)
            return model
        # Full-model checkpoint: architecture + weights restore directly.
        return tf.keras.models.load_model(downloaded_checkpoint_filepath)