__init__.py

"""
The ``mlflow.tensorflow`` module provides an API for logging and loading TensorFlow models.
This module exports TensorFlow models with the following flavors:

TensorFlow (native) format
    This is the main flavor that can be loaded back into TensorFlow.
:py:mod:`mlflow.pyfunc`
    Produced for use by generic pyfunc-based deployment tools and batch inference.
"""

import importlib
import logging
import os
import shutil
import tempfile
from typing import Any, NamedTuple

import numpy as np
import pandas
import yaml
from packaging.version import Version

import mlflow
from mlflow import pyfunc
from mlflow.data.code_dataset_source import CodeDatasetSource
from mlflow.data.numpy_dataset import from_numpy
from mlflow.data.tensorflow_dataset import from_tensorflow
from mlflow.entities import LoggedModelInput
from mlflow.environment_variables import MLFLOW_ALLOW_PICKLE_DESERIALIZATION
from mlflow.exceptions import INVALID_PARAMETER_VALUE, MlflowException
from mlflow.models import Model, ModelInputExample, ModelSignature, infer_signature
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.models.signature import _infer_signature_from_input_example
from mlflow.models.utils import _save_example
from mlflow.tensorflow.callback import MlflowCallback, MlflowModelCheckpointCallback  # noqa: F401
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.tracking.context import registry as context_registry
from mlflow.tracking.fluent import _initialize_logged_model, _shut_down_async_logging
from mlflow.types.schema import TensorSpec
from mlflow.utils import is_iterator
from mlflow.utils.autologging_utils import (
    autologging_integration,
    get_autologging_config,
    log_fn_args_as_params,
    picklable_exception_safe_function,
    resolve_input_example_and_signature,
    safe_patch,
)
from mlflow.utils.checkpoint_utils import (
    _WEIGHT_ONLY_CHECKPOINT_SUFFIX,
    download_checkpoint_artifact,
)
from mlflow.utils.databricks_utils import (
    is_in_databricks_model_serving_environment,
    is_in_databricks_runtime,
)
from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
from mlflow.utils.environment import (
    _CONDA_ENV_FILE_NAME,
    _CONSTRAINTS_FILE_NAME,
    _PYTHON_ENV_FILE_NAME,
    _REQUIREMENTS_FILE_NAME,
    _mlflow_conda_env,
    _process_conda_env,
    _process_pip_requirements,
    _PythonEnv,
    _validate_env_arguments,
)
from mlflow.utils.file_utils import TempDir, get_total_file_size, write_to
from mlflow.utils.model_utils import (
    _add_code_from_conf_to_system_path,
    _copy_extra_files,
    _get_flavor_configuration,
    _validate_and_copy_code_paths,
    _validate_and_prepare_target_save_path,
)
from mlflow.utils.requirements_utils import _get_pinned_requirement

FLAVOR_NAME = "tensorflow"

_logger = logging.getLogger(__name__)

# For tracking if the run was started by autologging.
_AUTOLOG_RUN_ID = None

# File name to which custom objects cloudpickle is saved - used during save and load
_CUSTOM_OBJECTS_SAVE_PATH = "custom_objects.cloudpickle"
# File name to which custom objects stored in tensorflow _GLOBAL_CUSTOM_OBJECTS
# is saved - it is automatically detected and used during save and load
_GLOBAL_CUSTOM_OBJECTS_SAVE_PATH = "global_custom_objects.cloudpickle"
_KERAS_MODULE_SPEC_PATH = "keras_module.txt"
_KERAS_SAVE_FORMAT_PATH = "save_format.txt"
# File name to which the keras model is saved
_MODEL_SAVE_PATH = "model"


_MODEL_TYPE_KERAS = "keras"
_MODEL_TYPE_TF1_ESTIMATOR = "tf1-estimator"
_MODEL_TYPE_TF2_MODULE = "tf2-module"


_KERAS_MODEL_DATA_PATH = "data"
_TF2MODEL_SUBPATH = "tf2model"


MLflowCallback = MlflowCallback  # for backwards compatibility


def get_default_pip_requirements(include_cloudpickle=False):
    """
    Returns:
        A list of default pip requirements for MLflow Models produced by this flavor.
        Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment
        that, at minimum, contains these requirements.
    """
    pip_deps = [_get_pinned_requirement("tensorflow")]
    if include_cloudpickle:
        pip_deps.append(_get_pinned_requirement("cloudpickle"))

    return pip_deps


def get_default_conda_env():
    """
    Returns:
        The default Conda environment for MLflow Models produced by calls to
        :func:`save_model()` and :func:`log_model()`.
    """
    return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())


def get_global_custom_objects():
    """
    Returns:
        A live reference to the global dictionary of custom objects.
    """
    try:
        from tensorflow.keras.saving import get_custom_objects

        return get_custom_objects()
    except Exception:
        pass


@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def log_model(
    model,
    artifact_path: str | None = None,
    custom_objects=None,
    conda_env=None,
    code_paths=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    registered_model_name=None,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements=None,
    extra_pip_requirements=None,
    saved_model_kwargs=None,
    keras_model_kwargs=None,
    metadata=None,
    extra_files=None,
    name: str | None = None,
    params: dict[str, Any] | None = None,
    tags: dict[str, Any] | None = None,
    model_type: str | None = None,
    step: int = 0,
    model_id: str | None = None,
    **kwargs,
):
    """
    Log a TF2 core model (inheriting tf.Module) or a Keras model in MLflow Model format.

    .. note::

        If you log a Keras or TensorFlow model without a signature, inference with
        :py:func:`mlflow.pyfunc.spark_udf()` will not work unless the model's pyfunc
        representation accepts pandas DataFrames as inference inputs.

    You can infer a model's signature by calling the :py:func:`mlflow.models.infer_signature()`
    API on features from the model's test dataset. You can also manually create a model
    signature, for example:

    .. code-block:: python
        :caption: Example of creating signature for saving TensorFlow and `tf.Keras` models

        from mlflow.types.schema import Schema, TensorSpec
        from mlflow.models import ModelSignature
        import numpy as np

        input_schema = Schema([
            TensorSpec(np.dtype(np.uint64), (-1, 5), "field1"),
            TensorSpec(np.dtype(np.float32), (-1, 3, 2), "field2"),
        ])
        # Create the signature for a model that requires 2 inputs:
        #  - Input with name "field1", shape (-1, 5), type "np.uint64"
        #  - Input with name "field2", shape (-1, 3, 2), type "np.float32"
        signature = ModelSignature(inputs=input_schema)

    Args:
        model: The TF2 core model (inheriting tf.Module) or Keras model to be saved.
        artifact_path: Deprecated. Use `name` instead.
        custom_objects: A Keras ``custom_objects`` dictionary mapping names (strings) to
            custom classes or functions associated with the Keras model. MLflow saves
            these custom layers using CloudPickle and restores them automatically
            when the model is loaded with :py:func:`mlflow.tensorflow.load_model` and
            :py:func:`mlflow.pyfunc.load_model`.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        signature: {{ signature }}
        input_example: {{ input_example }}
        registered_model_name: If given, create a model version under
            ``registered_model_name``, also creating a registered model if one
            with the given name does not exist.
        await_registration_for: Number of seconds to wait for the model version to finish
            being created and reach the ``READY`` status. By default, the function
            waits for five minutes. Specify 0 or None to skip waiting.
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        saved_model_kwargs: a dict of kwargs to pass to the ``tensorflow.saved_model.save`` method.
        keras_model_kwargs: a dict of kwargs to pass to the ``keras_model.save`` method.
        metadata: {{ metadata }}
        extra_files: {{ extra_files }}
        name: {{ name }}
        params: {{ params }}
        tags: {{ tags }}
        model_type: {{ model_type }}
        step: {{ step }}
        model_id: {{ model_id }}
        kwargs: Extra arguments to pass to :py:func:`mlflow.models.Model.log`.

    Returns:
        A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
        metadata of the logged model.
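
    The sketch below is a minimal, illustrative usage example; the two-layer ``tf.keras``
    model and the random training data are placeholders, not part of the MLflow API:

    .. code-block:: python
        :caption: Example of logging a Keras model

        import numpy as np
        import tensorflow as tf

        import mlflow

        # Toy data and model purely for illustration.
        x = np.random.uniform(size=(10, 3)).astype(np.float32)
        y = np.random.uniform(size=(10, 1)).astype(np.float32)

        model = tf.keras.Sequential(
            [
                tf.keras.Input(shape=(3,)),
                tf.keras.layers.Dense(4, activation="relu"),
                tf.keras.layers.Dense(1),
            ]
        )
        model.compile(optimizer="adam", loss="mse")
        model.fit(x, y, epochs=1, verbose=0)

        with mlflow.start_run():
            # Passing an input example lets MLflow infer the model signature.
            model_info = mlflow.tensorflow.log_model(model, name="model", input_example=x[:2])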
    """

    return Model.log(
        artifact_path=artifact_path,
        name=name,
        flavor=mlflow.tensorflow,
        model=model,
        conda_env=conda_env,
        code_paths=code_paths,
        custom_objects=custom_objects,
        registered_model_name=registered_model_name,
        signature=signature,
        input_example=input_example,
        await_registration_for=await_registration_for,
        pip_requirements=pip_requirements,
        extra_pip_requirements=extra_pip_requirements,
        saved_model_kwargs=saved_model_kwargs,
        keras_model_kwargs=keras_model_kwargs,
        metadata=metadata,
        extra_files=extra_files,
        params=params,
        tags=tags,
        model_type=model_type,
        step=step,
        model_id=model_id,
        **kwargs,
    )


def _save_keras_custom_objects(path, custom_objects, file_name):
    """
    Save a custom objects dictionary to a cloudpickle file so a model can be easily loaded later.

    Args:
        path: An absolute path that points to the data directory within /path/to/model.
        custom_objects: Keras ``custom_objects`` is a dictionary mapping
            names (strings) to custom classes or functions to be considered
            during deserialization. MLflow saves these custom layers using
            CloudPickle and restores them automatically when the model is
            loaded with :py:func:`mlflow.keras.load_model` and
            :py:func:`mlflow.pyfunc.load_model`.
        file_name: The file name to save the custom objects to.
    """
    import cloudpickle

    custom_objects_path = os.path.join(path, file_name)
    with open(custom_objects_path, "wb") as out_f:
        cloudpickle.dump(custom_objects, out_f)


_NO_MODEL_SIGNATURE_WARNING = (
    "You are saving a TensorFlow Core model or Keras model "
    "without a signature. Inference with mlflow.pyfunc.spark_udf() will not work "
    "unless the model's pyfunc representation accepts pandas DataFrames as "
    "inference inputs."
)


@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def save_model(
    model,
    path,
    conda_env=None,
    code_paths=None,
    mlflow_model=None,
    custom_objects=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    pip_requirements=None,
    extra_pip_requirements=None,
    saved_model_kwargs=None,
    keras_model_kwargs=None,
    metadata=None,
    extra_files=None,
):
    """
    Save a TF2 core model (inheriting tf.Module) or Keras model in MLflow Model format to a path
    on the local file system.

    .. note::
        If you save a Keras or TensorFlow model without a signature, inference with
        :py:func:`mlflow.pyfunc.spark_udf()` will not work unless the model's pyfunc
        representation accepts pandas DataFrames as inference inputs.

    You can infer a model's signature by calling the :py:func:`mlflow.models.infer_signature()`
    API on features from the model's test dataset. You can also manually create a model
    signature, for example:

    .. code-block:: python
        :caption: Example of creating signature for saving TensorFlow and `tf.Keras` models

        from mlflow.types.schema import Schema, TensorSpec
        from mlflow.models import ModelSignature
        import numpy as np

        input_schema = Schema([
            TensorSpec(np.dtype(np.uint64), (-1, 5), "field1"),
            TensorSpec(np.dtype(np.float32), (-1, 3, 2), "field2"),
        ])
        # Create the signature for a model that requires 2 inputs:
        #  - Input with name "field1", shape (-1, 5), type "np.uint64"
        #  - Input with name "field2", shape (-1, 3, 2), type "np.float32"
        signature = ModelSignature(inputs=input_schema)

    Args:
        model: The Keras model or TensorFlow module to be saved.
        path: Local path where the MLflow model is to be saved.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        mlflow_model: MLflow model configuration to which to add the ``tensorflow`` flavor.
        custom_objects: A Keras ``custom_objects`` dictionary mapping names (strings) to
            custom classes or functions associated with the Keras model. MLflow saves
            these custom layers using CloudPickle and restores them automatically
            when the model is loaded with :py:func:`mlflow.tensorflow.load_model` and
            :py:func:`mlflow.pyfunc.load_model`.
        signature: {{ signature }}
        input_example: {{ input_example }}
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        saved_model_kwargs: a dict of kwargs to pass to the ``tensorflow.saved_model.save`` method
            if the model to be saved is a TensorFlow module.
        keras_model_kwargs: a dict of kwargs to pass to the ``model.save`` method if the model
            to be saved is a Keras model.
        metadata: {{ metadata }}
        extra_files: {{ extra_files }}
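
    The sketch below is a minimal, illustrative example of saving a Keras model to a local
    directory and loading it back; the toy model and the temporary target path are placeholders:

    .. code-block:: python
        :caption: Example of saving a Keras model to a local path

        import os
        import tempfile

        import numpy as np
        import tensorflow as tf

        import mlflow

        model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
        model.compile(optimizer="adam", loss="mse")

        # ``save_model`` expects a target path that does not already contain a model.
        save_path = os.path.join(tempfile.mkdtemp(), "tf_model")
        mlflow.tensorflow.save_model(
            model, path=save_path, input_example=np.zeros((1, 3), dtype=np.float32)
        )
        reloaded_model = mlflow.tensorflow.load_model(save_path)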
    """
    import tensorflow as tf
    from tensorflow.keras.models import Model as KerasModel

    # check if path exists
    path = os.path.abspath(path)
    _validate_and_prepare_target_save_path(path)

    code_dir_subpath = _validate_and_copy_code_paths(code_paths, path)

    if mlflow_model is None:
        mlflow_model = Model()
    saved_example = _save_example(mlflow_model, input_example, path)

    if signature is None and saved_example is not None:
        wrapped_model = None
        if isinstance(model, KerasModel):
            wrapped_model = _KerasModelWrapper(model, signature)
        elif isinstance(model, tf.Module):
            wrapped_model = _TF2ModuleWrapper(model, signature)
        if wrapped_model is not None:
            signature = _infer_signature_from_input_example(saved_example, wrapped_model)
    elif signature is False:
        signature = None

    if signature is None:
        _logger.warning(_NO_MODEL_SIGNATURE_WARNING)
    else:
        num_inputs = len(signature.inputs.inputs)
        if num_inputs == 0:
            raise MlflowException(
                "The model signature's input schema must contain at least one field.",
                error_code=INVALID_PARAMETER_VALUE,
            )
        for field in signature.inputs.inputs:
            if not isinstance(field, TensorSpec):
                raise MlflowException(
                    "All fields in the model signature's input schema must be of type TensorSpec.",
                    error_code=INVALID_PARAMETER_VALUE,
                )
            if field.shape[0] != -1:
                raise MlflowException(
                    "All fields in the model signature's input schema must have a shape "
                    "in which the first dimension is a variable dimension.",
                    error_code=INVALID_PARAMETER_VALUE,
                )

    _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)

    if signature is not None:
        mlflow_model.signature = signature
    if metadata is not None:
        mlflow_model.metadata = metadata

    if isinstance(model, KerasModel):
        keras_model_kwargs = keras_model_kwargs or {}

        data_subpath = _KERAS_MODEL_DATA_PATH
        # construct new data folder in existing path
        data_path = os.path.join(path, data_subpath)
        os.makedirs(data_path)
        model_subpath = os.path.join(data_subpath, _MODEL_SAVE_PATH)

        keras_module = importlib.import_module("tensorflow.keras")
        # save custom objects if there are custom objects
        if custom_objects is not None:
            _save_keras_custom_objects(data_path, custom_objects, _CUSTOM_OBJECTS_SAVE_PATH)
        # save custom objects stored within _GLOBAL_CUSTOM_OBJECTS
        if global_custom_objects := get_global_custom_objects():
            _save_keras_custom_objects(
                data_path, global_custom_objects, _GLOBAL_CUSTOM_OBJECTS_SAVE_PATH
            )

        # save keras module spec to path/data/keras_module.txt
        with open(os.path.join(data_path, _KERAS_MODULE_SPEC_PATH), "w") as f:
            f.write(keras_module.__name__)

        # Use the SavedModel format if `save_format` is unspecified
        save_format = keras_model_kwargs.get("save_format", "tf")

        # save keras save_format to path/data/save_format.txt
        with open(os.path.join(data_path, _KERAS_SAVE_FORMAT_PATH), "w") as f:
            f.write(save_format)

        # save keras model
        # To maintain prior behavior, when the format is HDF5, we save
        # with the h5 file extension. Otherwise, model_path is a directory
        # where the saved_model.pb will be stored (for SavedModel format).
        # Starting with TensorFlow 2.16.0 (including dev versions), models
        # can only be saved in the .h5 or .keras format.
        if save_format == "h5":
            file_extension = ".h5"
        elif Version(tf.__version__).release >= (2, 16):
            file_extension = ".keras"
        else:
            file_extension = ""
        model_path = os.path.join(path, model_subpath) + file_extension
        if path.startswith("/dbfs/"):
            # The Databricks Filesystem uses a FUSE implementation that does not support
            # random writes, which causes an error when saving the model directly to it.
            with tempfile.NamedTemporaryFile(suffix=".h5") as f:
                model.save(f.name, **keras_model_kwargs)
                f.flush()  # force flush the data
                shutil.copy2(src=f.name, dst=model_path)
        else:
            model.save(model_path, **keras_model_kwargs)

        pyfunc_options = {
            "data": data_subpath,
        }

        flavor_options = {
            **pyfunc_options,
            "model_type": _MODEL_TYPE_KERAS,
            "keras_version": tf.__version__,
            "save_format": save_format,
        }
    elif isinstance(model, tf.Module):
        saved_model_kwargs = saved_model_kwargs or {}
        model_dir_subpath = _TF2MODEL_SUBPATH
        model_path = os.path.join(path, model_dir_subpath)
        tf.saved_model.save(model, model_path, **saved_model_kwargs)
        pyfunc_options = {}
        flavor_options = {
            "saved_model_dir": model_dir_subpath,
            "model_type": _MODEL_TYPE_TF2_MODULE,
        }
    else:
        raise MlflowException(f"Unknown model type: {type(model)}")

    extra_files_config = _copy_extra_files(extra_files, path)

    # update flavor info to mlflow_model
    mlflow_model.add_flavor(
        FLAVOR_NAME, code=code_dir_subpath, **flavor_options, **extra_files_config
    )

    # append loader_module, data and env data to mlflow_model
    pyfunc.add_to_model(
        mlflow_model,
        loader_module="mlflow.tensorflow",
        conda_env=_CONDA_ENV_FILE_NAME,
        python_env=_PYTHON_ENV_FILE_NAME,
        code=code_dir_subpath,
        **pyfunc_options,
    )

    # add model file size to mlflow_model
    if size := get_total_file_size(path):
        mlflow_model.model_size_bytes = size

    # save mlflow_model to path/MLmodel
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    include_cloudpickle = custom_objects is not None or get_global_custom_objects() is not None
    if conda_env is None:
        if pip_requirements is None:
            default_reqs = get_default_pip_requirements(include_cloudpickle)
            # To ensure `_load_pyfunc` can successfully load the model during the dependency
            # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
            inferred_reqs = mlflow.models.infer_pip_requirements(
                path, FLAVOR_NAME, fallback=default_reqs
            )
            default_reqs = sorted(set(inferred_reqs).union(default_reqs))
        else:
            default_reqs = None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs,
            pip_requirements,
            extra_pip_requirements,
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    # Save `constraints.txt` if necessary
    if pip_constraints:
        write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    # Save `requirements.txt`
    write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))


def _load_custom_objects(path, file_name):
    custom_objects_path = None
    if os.path.isdir(path):
        if os.path.isfile(os.path.join(path, file_name)):
            custom_objects_path = os.path.join(path, file_name)
    if custom_objects_path is not None:
        if (
            not MLFLOW_ALLOW_PICKLE_DESERIALIZATION.get()
            and not is_in_databricks_runtime()
            and not is_in_databricks_model_serving_environment()
        ):
            raise MlflowException(
                "Deserializing custom objects using cloudpickle is disallowed, but this model "
                "was saved with custom objects in pickle format. To address this issue, you need "
                "to set environment variable 'MLFLOW_ALLOW_PICKLE_DESERIALIZATION' to 'true'."
            )
        import cloudpickle

        with open(custom_objects_path, "rb") as f:
            return cloudpickle.load(f)


def _load_keras_model(model_path, keras_module, save_format, **kwargs):
    keras_models = importlib.import_module(keras_module.__name__ + ".models")
    custom_objects = kwargs.pop("custom_objects", {})
    if saved_custom_objects := _load_custom_objects(model_path, _CUSTOM_OBJECTS_SAVE_PATH):
        saved_custom_objects.update(custom_objects)
        custom_objects = saved_custom_objects

    if global_custom_objects := _load_custom_objects(model_path, _GLOBAL_CUSTOM_OBJECTS_SAVE_PATH):
        global_custom_objects.update(custom_objects)
        custom_objects = global_custom_objects

    if os.path.isdir(model_path):
        model_path = os.path.join(model_path, _MODEL_SAVE_PATH)

    # If the save_format is HDF5, then we save with the h5 file
    # extension to align with prior behavior of mlflow logging
    if save_format == "h5":
        model_path += ".h5"
    # Since TF 2.16.0, models can only be saved in the .h5 or .keras format.
    # For backwards compatibility, models are still saved without a suffix
    # for older versions of TF.
    elif os.path.exists(model_path + ".keras"):
        model_path += ".keras"

    import tensorflow as tf

    # Using naive tuple-based comparison here rather than packaging.version.Version, because
    # the latter considers a dev version, e.g. 2.16.0.dev2023010, as ahead of 2.16. While that
    # is 'correct', we would rather treat it as part of 2.16 here.
    if save_format == "h5" and (2, 2, 3) <= Version(tf.__version__).release < (2, 16):
        # NOTE: TF 2.2.3 does not work with unicode paths in python2. Pass in h5py.File instead
        # of string to avoid issues.
        import h5py

        with h5py.File(os.path.abspath(model_path), "r") as model_path:
            return keras_models.load_model(model_path, custom_objects=custom_objects, **kwargs)
    else:
        # NOTE: Older versions of Keras only handle filepath.
        return keras_models.load_model(model_path, custom_objects=custom_objects, **kwargs)


def _get_flavor_conf(model_conf):
    if "keras" in model_conf.flavors:
        return model_conf.flavors["keras"]
    return model_conf.flavors[FLAVOR_NAME]


def _infer_model_type(model_conf):
    model_type = _get_flavor_conf(model_conf).get("model_type")
    if model_type is not None:
        return model_type
    # Loading a model logged by an old version of MLflow, which does not record model_type.
    # Infer the model type by checking whether model_conf contains a "keras" flavor.
    if "keras" in model_conf.flavors:
        return _MODEL_TYPE_KERAS
    return _MODEL_TYPE_TF1_ESTIMATOR


def load_model(model_uri, dst_path=None, saved_model_kwargs=None, keras_model_kwargs=None):
    """
    Load an MLflow model that contains the TensorFlow flavor from the specified path.

    Args:
        model_uri: The location, in URI format, of the MLflow model. For example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
            - ``models:/<model_name>/<model_version>``
            - ``models:/<model_name>/<stage>``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
            artifact-locations>`_.
        dst_path: The local filesystem path to which to download the model artifact.
            This directory must already exist. If unspecified, a local output
            path will be created.
        saved_model_kwargs: kwargs to pass to the ``tensorflow.saved_model.load`` method.
            Only available when you are loading a TensorFlow 2 core model.
        keras_model_kwargs: kwargs to pass to the ``keras.models.load_model`` method.
            Only available when you are loading a Keras model.

    Returns:
        The loaded Keras model, or, for TensorFlow SavedModels, a callable graph
        (``tf.function``) that takes inputs and returns inferences.
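
    The sketch below is a minimal, illustrative example; the ``runs:/`` URI is a placeholder
    that should point at a model previously logged with :py:func:`mlflow.tensorflow.log_model`:

    .. code-block:: python
        :caption: Example of loading a logged Keras model

        import numpy as np

        import mlflow

        model_uri = "runs:/<mlflow_run_id>/model"
        keras_model = mlflow.tensorflow.load_model(model_uri)

        # For a Keras model, the returned object is a `tf.keras` model instance.
        predictions = keras_model.predict(np.random.uniform(size=(2, 3)).astype(np.float32))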
    """
    import tensorflow as tf

    local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)

    model_configuration_path = os.path.join(local_model_path, MLMODEL_FILE_NAME)
    model_conf = Model.load(model_configuration_path)

    flavor_conf = _get_flavor_conf(model_conf)

    _add_code_from_conf_to_system_path(local_model_path, flavor_conf)

    model_type = _infer_model_type(model_conf)
    if model_type == _MODEL_TYPE_KERAS:
        keras_model_kwargs = keras_model_kwargs or {}
        keras_module = importlib.import_module(flavor_conf.get("keras_module", "tensorflow.keras"))
        # For backwards compatibility, we assume h5 when the save_format is absent
        save_format = flavor_conf.get("save_format", "h5")
        model_path = os.path.join(local_model_path, flavor_conf.get("data", _MODEL_SAVE_PATH))
        return _load_keras_model(
            model_path=model_path,
            keras_module=keras_module,
            save_format=save_format,
            **keras_model_kwargs,
        )
    if model_type == _MODEL_TYPE_TF1_ESTIMATOR:
        tf_saved_model_dir = os.path.join(local_model_path, flavor_conf["saved_model_dir"])
        tf_meta_graph_tags = flavor_conf["meta_graph_tags"]
        tf_signature_def_key = flavor_conf["signature_def_key"]
        return _load_tf1_estimator_saved_model(
            tf_saved_model_dir=tf_saved_model_dir,
            tf_meta_graph_tags=tf_meta_graph_tags,
            tf_signature_def_key=tf_signature_def_key,
        )
    if model_type == _MODEL_TYPE_TF2_MODULE:
        saved_model_kwargs = saved_model_kwargs or {}
        tf_saved_model_dir = os.path.join(local_model_path, flavor_conf["saved_model_dir"])
        return tf.saved_model.load(tf_saved_model_dir, **saved_model_kwargs)

    raise MlflowException(f"Unknown model_type: {model_type}")


def _load_tf1_estimator_saved_model(tf_saved_model_dir, tf_meta_graph_tags, tf_signature_def_key):
    """
    Load a specified TensorFlow model consisting of a TensorFlow metagraph and signature
    definition from a serialized TensorFlow ``SavedModel`` collection.

    Args:
        tf_saved_model_dir: The local filesystem path or run-relative artifact path to the model.
        tf_meta_graph_tags: A list of tags identifying the model's metagraph within the
            serialized ``SavedModel`` object. For more information, see the
            ``tags`` parameter of the `tf.saved_model.builder.SavedModelBuilder
            method <https://www.tensorflow.org/api_docs/python/tf/saved_model/
            builder/SavedModelBuilder#add_meta_graph>`_.
        tf_signature_def_key: A string identifying the input/output signature associated with the
            model. This is a key within the serialized ``SavedModel``'s
            signature definition mapping. For more information, see the
            ``signature_def_map`` parameter of the
            ``tf.saved_model.builder.SavedModelBuilder`` method.

    Returns:
        A callable graph (tensorflow.function) that takes inputs and returns inferences.
    """
    import tensorflow as tf

    loaded = tf.saved_model.load(tags=tf_meta_graph_tags, export_dir=tf_saved_model_dir)
    loaded_sig = loaded.signatures
    if tf_signature_def_key not in loaded_sig:
        raise MlflowException(
            f"Could not find signature def key {tf_signature_def_key}. "
            f"Available keys are: {list(loaded_sig.keys())}"
        )
    return loaded_sig[tf_signature_def_key]


def _load_pyfunc(path):
    """
    Load PyFunc implementation. Called by ``pyfunc.load_model``. This function loads an MLflow
    model with the TensorFlow flavor into a new TensorFlow graph and exposes it behind the
    ``pyfunc.predict`` interface.

    Args:
        path: Local filesystem path to the MLflow Model with the ``tensorflow`` flavor.
    """
    import tensorflow as tf

    model_meta_path1 = os.path.join(path, MLMODEL_FILE_NAME)
    model_meta_path2 = os.path.join(os.path.dirname(path), MLMODEL_FILE_NAME)

    if os.path.isfile(model_meta_path1):
        model_meta = Model.load(model_meta_path1)
    elif os.path.isfile(model_meta_path2):
        model_meta = Model.load(model_meta_path2)
    else:
        raise MlflowException(f"Cannot find file {MLMODEL_FILE_NAME} for the logged model.")

    model_type = _infer_model_type(model_meta)
    if model_type == _MODEL_TYPE_KERAS:
        if os.path.isfile(os.path.join(path, _KERAS_MODULE_SPEC_PATH)):
            with open(os.path.join(path, _KERAS_MODULE_SPEC_PATH)) as f:
                keras_module = importlib.import_module(f.read())
        else:
            from tensorflow import keras

            keras_module = keras

        # By default, we assume the save_format is h5 for backwards compatibility
        save_format = "h5"
        save_format_path = os.path.join(path, _KERAS_SAVE_FORMAT_PATH)
        if os.path.isfile(save_format_path):
            with open(save_format_path) as f:
                save_format = f.read()

        # In SavedModel format, the loaded model should be compiled.
        should_compile = save_format == "tf"
        m = _load_keras_model(
            path, keras_module=keras_module, save_format=save_format, compile=should_compile
        )
        return _KerasModelWrapper(m, model_meta.signature)
    if model_type == _MODEL_TYPE_TF1_ESTIMATOR:
        flavor_conf = _get_flavor_configuration(path, FLAVOR_NAME)

        tf_saved_model_dir = os.path.join(path, flavor_conf["saved_model_dir"])
        tf_meta_graph_tags = flavor_conf["meta_graph_tags"]
        tf_signature_def_key = flavor_conf["signature_def_key"]

        loaded_model = tf.saved_model.load(export_dir=tf_saved_model_dir, tags=tf_meta_graph_tags)
        return _TF2Wrapper(model=loaded_model, infer=loaded_model.signatures[tf_signature_def_key])
    if model_type == _MODEL_TYPE_TF2_MODULE:
        flavor_conf = _get_flavor_configuration(path, FLAVOR_NAME)
        tf_saved_model_dir = os.path.join(path, flavor_conf["saved_model_dir"])
        loaded_model = tf.saved_model.load(tf_saved_model_dir)
        return _TF2ModuleWrapper(model=loaded_model, signature=model_meta.signature)

    raise MlflowException("Unknown model_type.")


class _TF2Wrapper:
    """
    Wrapper class that exposes a TensorFlow model for inference via a ``predict`` function
    such that ``predict(data: pandas.DataFrame) -> pandas.DataFrame``. For TensorFlow
    versions >= 2.0.0.
    """

    def __init__(self, model, infer):
        """
        Args:
            model: A TensorFlow SavedModel.
            infer: TensorFlow function returned by a saved model that is used for inference.
        """
        # Note: we need to retain the model reference in TF2Wrapper object, because the infer
        # function in tensorflow will be `ConcreteFunction` which only retains WeakRefs to the
        # variables they close over.
        # See https://www.tensorflow.org/guide/function#deleting_tfvariables_between_function_calls
        self.model = model
        self.infer = infer

    def get_raw_model(self):
        """
        Returns the underlying model.
        """
        return self.model

    def predict(
        self,
        data,
        params: dict[str, Any] | None = None,
    ):
        """
        Args:
            data: Model input data.
            params: Additional parameters to pass to the model for inference.

        Returns:
            Model predictions.
        """
        import tensorflow as tf

        feed_dict = {}
        if isinstance(data, dict):
            feed_dict = {k: tf.constant(v) for k, v in data.items()}
        elif isinstance(data, pandas.DataFrame):
            for df_col_name in list(data):
                # If there are multiple columns with the same name, selecting the shared name
                # from the DataFrame will result in another DataFrame containing the columns
                # with the shared name. TensorFlow cannot make eager tensors out of pandas
                # DataFrames, so we convert the DataFrame to a numpy array here.
                val = data[df_col_name]
                val = val.values if isinstance(val, pandas.DataFrame) else np.array(val.to_list())
                feed_dict[df_col_name] = tf.constant(val)
        else:
            raise TypeError("Only dict and DataFrame input types are supported")

        raw_preds = self.infer(**feed_dict)
        pred_dict = {col_name: raw_preds[col_name].numpy() for col_name in raw_preds.keys()}
        for col in pred_dict.keys():
            # If the output tensor is not 1-dimensional
            # AND all elements have length of 1, flatten the array with `ravel()`
            if len(pred_dict[col].shape) != 1 and all(
                len(element) == 1 for element in pred_dict[col]
            ):
                pred_dict[col] = pred_dict[col].ravel()
            else:
                pred_dict[col] = pred_dict[col].tolist()

        if isinstance(data, dict):
            return pred_dict
        else:
            return pandas.DataFrame.from_dict(data=pred_dict)


class _TF2ModuleWrapper:
    def __init__(self, model, signature):
        self.model = model
        self.signature = signature

    def get_raw_model(self):
        """
        Returns the underlying model.
        """
        return self.model

    def predict(
        self,
        data,
        params: dict[str, Any] | None = None,
    ):
        """
        Args:
            data: Model input data.
            params: Additional parameters to pass to the model for inference.

        Returns:
            Model predictions.
        """
        import tensorflow as tf

        if isinstance(data, (np.ndarray, list)):
            data = tf.convert_to_tensor(data)
        else:
            raise MlflowException(
                f"Unsupported input data type: {type(data)}, the input data must be "
                "numpy array or a list."
            )
        result = self.model(data)
        if isinstance(result, tf.Tensor):
            return result.numpy()
        return result


class _KerasModelWrapper:
    def __init__(self, keras_model, signature):
        self.keras_model = keras_model
        self.signature = signature

    def get_raw_model(self):
        """
        Returns the underlying model.
        """
        return self.keras_model

    def predict(
        self,
        data,
        params: dict[str, Any] | None = None,
    ):
        """
        Args:
            data: Model input data.
            params: Additional parameters to pass to the model for inference.

        Returns:
            Model predictions.
        """
        if isinstance(data, pandas.DataFrame):
            # This branch is for backwards compatibility:
            # If the model signature is not None, calling
            # `keras_pyfunc_model.predict(pandas_dataframe)` makes `_enforce_schema` convert
            # the dataframe input into a dict input, so in that case `_KerasModelWrapper.predict`
            # receives a dict type input.
            # If the model signature is None, `_enforce_schema` can do nothing, and if the input
            # is a dataframe, `_KerasModelWrapper.predict` receives a dataframe input;
            # we need to handle this case to keep backwards compatibility.
            return pandas.DataFrame(self.keras_model.predict(data.values), index=data.index)

        supported_input_types = (np.ndarray, list, tuple, dict)
        if not isinstance(data, supported_input_types):
            raise MlflowException(
                f"Unsupported input data type: {type(data)}. "
                f"Must be one of: {[x.__name__ for x in supported_input_types]}",
                INVALID_PARAMETER_VALUE,
            )
        return self.keras_model.predict(data)


def _assoc_list_to_map(lst):
    """
    Convert an association list to a dictionary.
    """
    d = {}
    for run_id, metric in lst:
        d[run_id] = d[run_id] + [metric] if run_id in d else [metric]
    return d


@picklable_exception_safe_function
def _get_tensorboard_callback(lst):
    import tensorflow as tf

    for x in lst:
        if isinstance(x, tf.keras.callbacks.TensorBoard):
            return x
    return None


# A representation of a TensorBoard event logging directory with two attributes:
# :location - string: The filesystem location of the logging directory
# :is_temp - boolean: `True` if the logging directory was created for temporary use by MLflow,
#            `False` otherwise
class _TensorBoardLogDir(NamedTuple):
    location: str
    is_temp: bool


def _setup_callbacks(callbacks, log_every_epoch, log_every_n_steps):
    """
    Adds TensorBoard and MlflowCallback callbacks to the
    input list, and returns the new list and appropriate log directory.
    """
    from mlflow.tensorflow.autologging import _TensorBoard
    from mlflow.tensorflow.callback import MlflowCallback, MlflowModelCheckpointCallback

    tb = _get_tensorboard_callback(callbacks)
    for callback in callbacks:
        if isinstance(callback, MlflowCallback):
            raise MlflowException(
                "MLflow autologging must be turned off if an `MlflowCallback` is explicitly added "
                "to the callback list. You are creating an `MlflowCallback` while having "
                "autologging enabled. Please either call `mlflow.tensorflow.autolog(disable=True)` "
                "to disable autologging or remove `MlflowCallback` from the callback list."
            )
    if tb is None:
        log_dir = _TensorBoardLogDir(location=tempfile.mkdtemp(), is_temp=True)
        callbacks.append(_TensorBoard(log_dir.location))
    else:
        log_dir = _TensorBoardLogDir(location=tb.log_dir, is_temp=False)

    callbacks.append(
        MlflowCallback(
            log_every_epoch=log_every_epoch,
            log_every_n_steps=log_every_n_steps,
        )
    )

    model_checkpoint = get_autologging_config(mlflow.tensorflow.FLAVOR_NAME, "checkpoint", True)
    if model_checkpoint:
        checkpoint_monitor = get_autologging_config(
            mlflow.tensorflow.FLAVOR_NAME, "checkpoint_monitor", "val_loss"
        )
        checkpoint_mode = get_autologging_config(
            mlflow.tensorflow.FLAVOR_NAME, "checkpoint_mode", "min"
        )
        checkpoint_save_best_only = get_autologging_config(
            mlflow.tensorflow.FLAVOR_NAME, "checkpoint_save_best_only", True
        )
        checkpoint_save_weights_only = get_autologging_config(
            mlflow.tensorflow.FLAVOR_NAME, "checkpoint_save_weights_only", False
        )
        checkpoint_save_freq = get_autologging_config(
            mlflow.tensorflow.FLAVOR_NAME, "checkpoint_save_freq", "epoch"
        )

        if not any(isinstance(callback, MlflowModelCheckpointCallback) for callback in callbacks):
            callbacks.append(
                MlflowModelCheckpointCallback(
                    monitor=checkpoint_monitor,
                    mode=checkpoint_mode,
                    save_best_only=checkpoint_save_best_only,
                    save_weights_only=checkpoint_save_weights_only,
                    save_freq=checkpoint_save_freq,
                )
            )

    return callbacks, log_dir


@autologging_integration(FLAVOR_NAME)
def autolog(
    log_models=True,
    log_datasets=True,
    disable=False,
    exclusive=False,
    disable_for_unsupported_versions=False,
    silent=False,
    registered_model_name=None,
    log_input_examples=False,
    log_model_signatures=True,
    saved_model_kwargs=None,
    keras_model_kwargs=None,
    extra_tags=None,
    log_every_epoch=True,
    log_every_n_steps=None,
    checkpoint=True,
    checkpoint_monitor="val_loss",
    checkpoint_mode="min",
    checkpoint_save_best_only=True,
    checkpoint_save_weights_only=False,
    checkpoint_save_freq="epoch",
):
    """
    Enables autologging for ``tf.keras``.
    Note that only ``tensorflow>=2.3`` is supported.
    As an example, try running the
    `Keras/TensorFlow example <https://github.com/mlflow/mlflow/blob/master/examples/keras/train.py>`_.

    For each TensorFlow module, autologging captures the following information:

    **tf.keras**
     - **Metrics** and **Parameters**

       - Training and validation loss.
       - User-specified metrics.
       - Optimizer config, e.g., learning_rate, momentum, etc.
       - Training configs, e.g., epochs, batch_size, etc.

     - **Artifacts**

       - Model summary on training start.
       - Saved Keras model in `MLflow Model <https://mlflow.org/docs/latest/models.html>`_ format.
       - TensorBoard logs on training end.

    **tf.keras.callbacks.EarlyStopping**
     - **Metrics** and **Parameters**

       - Metrics from the ``EarlyStopping`` callbacks: ``stopped_epoch``, ``restored_epoch``,
         ``restore_best_weights``, etc.
       - ``fit()`` or ``fit_generator()`` parameters associated with ``EarlyStopping``:
         ``min_delta``, ``patience``, ``baseline``, ``restore_best_weights``, etc.

    Refer to the autologging tracking documentation for more
    information on `TensorFlow workflows
    <https://www.mlflow.org/docs/latest/tracking.html#tensorflow-and-keras-experimental>`_.

    Note that autologging cannot be used together with an explicit MLflow callback, i.e.,
    `mlflow.tensorflow.MlflowCallback`, because it would cause the same metrics to be logged twice.
    If you want to include `mlflow.tensorflow.MlflowCallback` in the callback list, please turn off
    autologging by calling `mlflow.tensorflow.autolog(disable=True)`.

    Args:
        log_models: If ``True``, trained models are logged as MLflow model artifacts.
            If ``False``, trained models are not logged.
        log_datasets: If ``True``, dataset information is logged to MLflow Tracking.
            If ``False``, dataset information is not logged.
        disable: If ``True``, disables the TensorFlow autologging integration. If ``False``,
            enables the TensorFlow autologging integration.
        exclusive: If ``True``, autologged content is not logged to user-created fluent runs.
            If ``False``, autologged content is logged to the active fluent run,
            which may be user-created.
        disable_for_unsupported_versions: If ``True``, disable autologging for versions of
            tensorflow that have not been tested against this version of the MLflow
            client or are incompatible.
        silent: If ``True``, suppress all event logs and warnings from MLflow during TensorFlow
            autologging. If ``False``, show all events and warnings during TensorFlow
            autologging.
        registered_model_name: If given, each time a model is trained, it is registered as a
            new model version of the registered model with this name.
            The registered model is created if it does not already exist.
        log_input_examples: If ``True``, input examples from training datasets are collected and
            logged along with tf/keras model artifacts during training. If
            ``False``, input examples are not logged.
        log_model_signatures: If ``True``,
            :py:class:`ModelSignatures <mlflow.models.ModelSignature>`
            describing model inputs and outputs are collected and logged along
            with tf/keras model artifacts during training. If ``False``,
            signatures are not logged. Note that logging TensorFlow models
            with signatures changes their pyfunc inference behavior when
            Pandas DataFrames are passed to ``predict()``.
            When a signature is present, an ``np.ndarray``
            (for single-output models) or a mapping from
            ``str`` -> ``np.ndarray`` (for multi-output models) is returned;
            when a signature is not present, a Pandas DataFrame is returned.
        saved_model_kwargs: a dict of kwargs to pass to the ``tensorflow.saved_model.save`` method.
        keras_model_kwargs: a dict of kwargs to pass to the ``keras_model.save`` method.
        extra_tags: A dictionary of extra tags to set on each managed run created by autologging.
        log_every_epoch: If True, training metrics will be logged at the end of each epoch.
        log_every_n_steps: If set, training metrics will be logged every `n` training steps.
            `log_every_n_steps` must be `None` when `log_every_epoch=True`.
        checkpoint: Enable automatic model checkpointing.
        checkpoint_monitor: In automatic model checkpointing, the metric name to monitor if
            you set `checkpoint_save_best_only` to True.
        checkpoint_mode: one of {"min", "max"}. In automatic model checkpointing,
            if save_best_only=True, the decision to overwrite the current save file is made based
            on either the maximization or the minimization of the monitored quantity.
        checkpoint_save_best_only: If True, automatic model checkpointing only saves when
            the model is considered the "best" model according to the quantity
            monitored, and the previous checkpoint model is overwritten.
        checkpoint_save_weights_only: In automatic model checkpointing, if True, then
            only the model's weights will be saved. Otherwise, the optimizer states,
            lr-scheduler states, etc. are added in the checkpoint too.
        checkpoint_save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
            saves the model after each epoch. When using an integer, the callback
            saves the model at the end of this many batches. Note that if the saving isn't
            aligned to epochs, the monitored metric may potentially be less reliable (it
            could reflect as little as 1 batch, since the metrics get reset
            every epoch). Defaults to `"epoch"`.
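
    The sketch below is a minimal, illustrative example of enabling autologging before training;
    the toy model and random data are placeholders, not part of the MLflow API:

    .. code-block:: python
        :caption: Example of enabling TensorFlow autologging

        import numpy as np
        import tensorflow as tf

        import mlflow

        mlflow.tensorflow.autolog()

        # Toy data and model purely for illustration.
        x = np.random.uniform(size=(20, 3)).astype(np.float32)
        y = np.random.uniform(size=(20, 1)).astype(np.float32)

        model = tf.keras.Sequential(
            [
                tf.keras.Input(shape=(3,)),
                tf.keras.layers.Dense(4, activation="relu"),
                tf.keras.layers.Dense(1),
            ]
        )
        model.compile(optimizer="adam", loss="mse")

        with mlflow.start_run():
            # Params, metrics, and (by default) the trained model are logged automatically.
            model.fit(x, y, epochs=2, batch_size=5, validation_split=0.2)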
    """
    import tensorflow as tf

    if Version(tf.__version__) < Version("2.3"):
        _logger.error(
            "Could not log to MLflow because your Tensorflow version is below 2.3, detected "
            f"version: {tf.__version__}."
        )
        return

    @picklable_exception_safe_function
    def _get_early_stop_callback(callbacks):
        for callback in callbacks:
            if isinstance(callback, tf.keras.callbacks.EarlyStopping):
                return callback
        return None

    def _log_early_stop_callback_params(callback):
        if callback:
            try:
                earlystopping_params = {
                    "monitor": callback.monitor,
                    "min_delta": callback.min_delta,
                    "patience": callback.patience,
                    "baseline": callback.baseline,
                    "restore_best_weights": callback.restore_best_weights,
                }
                mlflow.log_params(earlystopping_params)
            except Exception:
                return

    def _get_early_stop_callback_attrs(callback):
        try:
            return callback.stopped_epoch, callback.restore_best_weights, callback.patience
        except Exception:
            return None

    def _log_early_stop_callback_metrics(callback, history, model_id=None):
        from mlflow import log_metrics

        if callback is None or not callback.model.stop_training:
            return

        callback_attrs = _get_early_stop_callback_attrs(callback)
        if callback_attrs is None:
            return

        stopped_epoch, restore_best_weights, _ = callback_attrs
        log_metrics({"stopped_epoch": stopped_epoch}, synchronous=False, model_id=model_id)

        if not restore_best_weights or callback.best_weights is None:
            return

        monitored_metric = history.history.get(callback.monitor)
        if not monitored_metric:
            return

        initial_epoch = history.epoch[0]
        # If `monitored_metric` contains multiple best values (e.g. [0.1, 0.1, 0.2] where 0.1 is
        # the minimum loss), the epoch corresponding to the first occurrence of the best value is
        # the best epoch. In keras > 2.6.0, the best epoch can be obtained via the `best_epoch`
        # attribute of an `EarlyStopping` instance: https://github.com/keras-team/keras/pull/15197
        restored_epoch = initial_epoch + monitored_metric.index(callback.best)
        log_metrics({"restored_epoch": restored_epoch}, synchronous=False, model_id=model_id)
        restored_index = history.epoch.index(restored_epoch)
        restored_metrics = {
            key: metrics[restored_index] for key, metrics in history.history.items()
        }
        # Checking that a metric history exists
        metric_key = next(iter(history.history), None)
        if metric_key is not None:
            log_metrics(restored_metrics, stopped_epoch + 1, synchronous=False, model_id=model_id)

    def _log_keras_model(history, args, model_id=None):
        def _infer_model_signature(input_data_slice):
            # In certain TensorFlow versions, calling `predict()` on the model may modify
            # the `stop_training` attribute, so we save and restore it accordingly
            original_stop_training = history.model.stop_training
            model_output = history.model.predict(input_data_slice)
            history.model.stop_training = original_stop_training
            return infer_signature(input_data_slice, model_output)

        from mlflow.tensorflow.autologging import extract_tf_keras_input_example

        def _get_tf_keras_input_example_slice():
            input_training_data = args[0]
            keras_input_example_slice = extract_tf_keras_input_example(input_training_data)
            if keras_input_example_slice is None:
                raise MlflowException(
                    "Cannot log input example or model signature for input with type"
                    f" {type(input_training_data)}. TensorFlow Keras autologging can"
                    " only log input examples and model signatures for the following"
                    " input types: numpy.ndarray, dict[string -> numpy.ndarray],"
                    " tensorflow.keras.utils.Sequence, and"
                    " tensorflow.data.Dataset (TensorFlow >= 2.1.0 required)",
                    INVALID_PARAMETER_VALUE,
                )
            return keras_input_example_slice

        input_example, signature = resolve_input_example_and_signature(
            _get_tf_keras_input_example_slice,
            _infer_model_signature,
            log_input_examples,
            log_model_signatures,
            _logger,
        )

        log_model(
            history.model,
            "model",
            input_example=input_example,
            signature=signature,
            registered_model_name=get_autologging_config(
                FLAVOR_NAME, "registered_model_name", None
            ),
            saved_model_kwargs=saved_model_kwargs,
            keras_model_kwargs=keras_model_kwargs,
            model_id=model_id,
        )

    def _patched_inference(original, inst, *args, **kwargs):
        log_dir = None
        try:
            unlogged_params = ["self", "x", "y", "callbacks", "validation_data", "verbose"]

            batch_size = None
            try:
                is_single_input_model = isinstance(inst.input_shape, tuple)
                training_data = kwargs["x"] if "x" in kwargs else args[0]
                if isinstance(training_data, tf.data.Dataset) and hasattr(
                    training_data, "_batch_size"
                ):
                    batch_size = training_data._batch_size.numpy()
                elif isinstance(training_data, tf.keras.utils.Sequence):
                    first_batch_inputs, *_ = training_data[0]
                    if is_single_input_model:
                        batch_size = len(first_batch_inputs)
                    else:
                        batch_size = len(first_batch_inputs[0])
                elif is_iterator(training_data):
                    peek = next(training_data)
                    batch_size = len(peek[0]) if is_single_input_model else len(peek[0][0])

                    def __restore_generator(prev_generator):
                        yield peek
                        yield from prev_generator

                    restored_generator = __restore_generator(training_data)
                    if "x" in kwargs:
                        kwargs["x"] = restored_generator
                    else:
                        args = (restored_generator,) + args[1:]
            except Exception as e:
                _logger.warning(
                    "Encountered unexpected error while inferring batch size from training"
                    " dataset: %s",
                    e,
                )

            if batch_size is not None:
                mlflow.log_param("batch_size", batch_size)
                unlogged_params.append("batch_size")

            log_fn_args_as_params(original, args, kwargs, unlogged_params)

            # Check if the 'callback' argument of fit() is set positionally
            if len(args) >= 6:
                # Convert the positional training function arguments to a list in order to
                # mutate the contents
                args = list(args)
                # Make a shallow copy of the preexisting callbacks to avoid permanently
                # modifying their contents for future training invocations. Introduce
                # TensorBoard & tf.keras callbacks if necessary
                callbacks = list(args[5])
                callbacks, log_dir = _setup_callbacks(
                    callbacks,
                    log_every_epoch=log_every_epoch,
                    log_every_n_steps=log_every_n_steps,
                )
                # Replace the callbacks positional entry in the copied arguments and convert
                # the arguments back to tuple form for usage in the training function
                args[5] = callbacks
                args = tuple(args)
            else:
                # Make a shallow copy of the preexisting callbacks and introduce TensorBoard
                # & tf.keras callbacks if necessary
                callbacks = list(kwargs.get("callbacks") or [])
                kwargs["callbacks"], log_dir = _setup_callbacks(
                    callbacks,
                    log_every_epoch=log_every_epoch,
                    log_every_n_steps=log_every_n_steps,
                )

            early_stop_callback = _get_early_stop_callback(callbacks)
            _log_early_stop_callback_params(early_stop_callback)

            model_id = None
            if log_models:
                model_id = _initialize_logged_model("model", flavor=FLAVOR_NAME).model_id

            if log_datasets:
                try:
                    context_tags = context_registry.resolve_tags()
                    source = CodeDatasetSource(tags=context_tags)

                    x = kwargs["x"] if "x" in kwargs else args[0]
                    if "y" in kwargs:
                        y = kwargs["y"]
                    elif len(args) >= 2:
                        y = args[1]
                    else:
                        y = None

                    if "validation_data" in kwargs:
                        validation_data = kwargs["validation_data"]
                    elif len(args) >= 8:
                        validation_data = args[7]
                    else:
                        validation_data = None
                    _log_tensorflow_dataset(x, source, "train", targets=y, model_id=model_id)
                    if validation_data is not None:
                        _log_tensorflow_dataset(validation_data, source, "eval", model_id=model_id)

                except Exception as e:
                    _logger.warning(
                        "Failed to log training dataset information to MLflow Tracking. Reason: %s",
                        e,
                    )

            history = original(inst, *args, **kwargs)

            if log_models:
                _log_keras_model(history, args, model_id=model_id)

            _log_early_stop_callback_metrics(
                callback=early_stop_callback,
                history=history,
                model_id=model_id,
            )
            # Ensure all data are logged.
            # Shut down the async logging (instead of flushing)
            # to avoid leaving zombie threads between patchings.
            _shut_down_async_logging()

            mlflow.log_artifacts(
                local_dir=log_dir.location,
                artifact_path="tensorboard_logs",
            )
            if log_dir.is_temp:
                shutil.rmtree(log_dir.location)
            return history

        except (Exception, KeyboardInterrupt) as e:
            try:
                if log_dir is not None and log_dir.is_temp and os.path.exists(log_dir.location):
                    shutil.rmtree(log_dir.location)
            finally:
                # Regardless of what happens during the `_on_exception` callback, reraise
                # the original implementation exception once the callback completes
                raise e

    safe_patch(
        FLAVOR_NAME,
        tf.keras.Model,
        "fit",
        _patched_inference,
        manage_run=True,
        extra_tags=extra_tags,
    )


def _log_tensorflow_dataset(
    tensorflow_dataset, source, context, name=None, targets=None, model_id=None
):
    import tensorflow as tf

    # create a dataset
    if isinstance(tensorflow_dataset, np.ndarray):
        dataset = from_numpy(features=tensorflow_dataset, targets=targets, source=source, name=name)
    elif isinstance(tensorflow_dataset, tf.Tensor):
        dataset = from_tensorflow(
            features=tensorflow_dataset, targets=targets, source=source, name=name
        )
    elif isinstance(tensorflow_dataset, tf.data.Dataset):
        dataset = from_tensorflow(features=tensorflow_dataset, source=source, name=name)
    elif isinstance(tensorflow_dataset, tuple):
        x = tensorflow_dataset[0]
        y = tensorflow_dataset[1]
        # check if x and y are tensors
        if isinstance(x, tf.Tensor) and isinstance(y, tf.Tensor):
            dataset = from_tensorflow(features=x, source=source, targets=y, name=name)
        else:
            dataset = from_numpy(features=x, targets=y, source=source, name=name)
    else:
        _logger.warning(
            "Unrecognized dataset type %s. Dataset logging skipped.", type(tensorflow_dataset)
        )
        return

    model = None if model_id is None else LoggedModelInput(model_id=model_id)
    mlflow.log_input(dataset, context, model=model)


def load_checkpoint(model=None, run_id=None, epoch=None, global_step=None):
    """
    If you enable "checkpoint" in autologging, checkpointed models are logged as MLflow
    artifacts during Keras model training. Using this API, you can load the checkpointed model.

    If you want to load the latest checkpoint, set both `epoch` and `global_step` to None.
    If "checkpoint_save_freq" is set to "epoch" in autologging,
    you can set the `epoch` param to load the checkpoint saved at a specific epoch.
    If "checkpoint_save_freq" is set to an integer in autologging,
    you can set the `global_step` param to load the checkpoint saved at a specific
    global step.
    The `epoch` and `global_step` params can't both be set.

    Args:
        model: A Keras model, this argument is required
            only when the saved checkpoint is "weight-only".
        run_id: The ID of the run that the model is logged to. If not provided,
            the current active run is used.
        epoch: The epoch of the checkpoint to be loaded, if you set
            "checkpoint_save_freq" to "epoch".
        global_step: The global step of the checkpoint to be loaded, if
            you set "checkpoint_save_freq" to an integer.

    Returns:
        The instance of a Keras model restored from the specified checkpoint.

    .. code-block:: python
        :caption: Example

        import mlflow

        mlflow.tensorflow.autolog(checkpoint=True, checkpoint_save_best_only=False)

        model = create_tf_keras_model()  # Create a Keras model
        with mlflow.start_run() as run:
            model.fit(data, label, epochs=10)

        run_id = run.info.run_id

        # load the latest checkpoint model
        latest_checkpoint_model = mlflow.tensorflow.load_checkpoint(run_id=run_id)

        # load the checkpoint model logged at the second epoch
        checkpoint_model = mlflow.tensorflow.load_checkpoint(run_id=run_id, epoch=2)
    """
    import tensorflow as tf

    with TempDir() as tmp_dir:
        downloaded_checkpoint_filepath = download_checkpoint_artifact(
            run_id=run_id, epoch=epoch, global_step=global_step, dst_path=tmp_dir.path()
        )

        fname = os.path.splitext(downloaded_checkpoint_filepath)[0]
        if fname.endswith(_WEIGHT_ONLY_CHECKPOINT_SUFFIX):
            # the model is saved as weights only
            if model is None:
                raise MlflowException(
                    "The latest checkpoint is weights-only, 'model' argument must be provided"
                )
            model.load_weights(downloaded_checkpoint_filepath)
            return model
        return tf.keras.models.load_model(downloaded_checkpoint_filepath)