# tests/tensorflow/test_tensorflow2_autolog.py
   1  # pep8: disable=E501
   2  
   3  import functools
   4  import json
   5  import os
   6  import pickle
   7  import sys
   8  from pathlib import Path
   9  from unittest.mock import patch
  10  
  11  import numpy as np
  12  import pytest
  13  import tensorflow as tf
  14  import yaml
  15  from packaging.version import Version
  16  from tensorflow.keras import layers
  17  
  18  import mlflow
  19  from mlflow import MlflowClient
  20  from mlflow.exceptions import MlflowException
  21  from mlflow.models import Model
  22  from mlflow.models.utils import _read_example
  23  from mlflow.tensorflow import load_checkpoint
  24  from mlflow.tensorflow.autologging import _TensorBoard
  25  from mlflow.tensorflow.callback import MlflowCallback
  26  from mlflow.tracking.fluent import _shut_down_async_logging
  27  from mlflow.types.utils import _infer_schema
  28  from mlflow.utils.autologging_utils import (
  29      AUTOLOGGING_INTEGRATIONS,
  30      autologging_is_disabled,
  31  )
  32  from mlflow.utils.file_utils import local_file_uri_to_path
  33  from mlflow.utils.process import _exec_cmd
  34  
  35  np.random.seed(1337)
  36  
  37  
@pytest.fixture(autouse=True)
def clear_session():
    """Tear down per-test logging/Keras state.

    Runs after every test (code below ``yield``): flushes MLflow's async
    logging queue so pending metrics are persisted, then clears the Keras
    backend session so layer/graph state does not leak between tests.
    """
    yield
    _shut_down_async_logging()
    tf.keras.backend.clear_session()
  43  
  44  
  45  @pytest.fixture
  46  def random_train_data():
  47      return np.random.random((150, 4))
  48  
  49  
  50  @pytest.fixture
  51  def random_one_hot_labels():
  52      n = 150
  53      n_class = 3
  54      classes = np.random.randint(0, n_class, n)
  55      labels = np.zeros((n, n_class))
  56      labels[np.arange(n), classes] = 1
  57      return labels
  58  
  59  
  60  @pytest.fixture
  61  def random_train_dict_mapping(random_train_data):
  62      def _generate_features(pos):
  63          return [v[pos] for v in random_train_data]
  64  
  65      return {
  66          "a": np.array(_generate_features(0)),
  67          "b": np.array(_generate_features(1)),
  68          "c": np.array(_generate_features(2)),
  69          "d": np.array(_generate_features(3)),
  70      }
  71  
  72  
  73  def _create_model_for_dict_mapping():
  74      inputs = {
  75          "a": tf.keras.Input(shape=(1,), name="a"),
  76          "b": tf.keras.Input(shape=(1,), name="b"),
  77          "c": tf.keras.Input(shape=(1,), name="c"),
  78          "d": tf.keras.Input(shape=(1,), name="d"),
  79      }
  80      concatenated = layers.Concatenate()(list(inputs.values()))
  81      x = layers.Dense(16, activation="relu", input_shape=(4,))(concatenated)
  82      outputs = layers.Dense(3, activation="softmax")(x)
  83      model = tf.keras.Model(inputs=inputs, outputs=outputs)
  84      model.compile(
  85          optimizer=tf.keras.optimizers.Adam(), loss="categorical_crossentropy", metrics=["accuracy"]
  86      )
  87      return model
  88  
  89  
  90  @pytest.fixture
  91  def fashion_mnist_tf_dataset():
  92      train, _ = tf.keras.datasets.fashion_mnist.load_data()
  93      images, labels = train
  94      images = images / 255.0
  95      labels = labels.astype(np.int32)
  96      fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
  97      return fmnist_train_ds.shuffle(5000).batch(32)
  98  
  99  
 100  @pytest.fixture
 101  def fashion_mnist_tf_dataset_eval():
 102      _, eval_dataset = tf.keras.datasets.fashion_mnist.load_data()
 103      images, labels = eval_dataset
 104      images = images / 255.0
 105      labels = labels.astype(np.int32)
 106      fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
 107      return fmnist_train_ds.shuffle(5000).batch(32)
 108  
 109  
 110  def _create_fashion_mnist_model():
 111      model = tf.keras.Sequential([
 112          tf.keras.Input((28, 28)),
 113          tf.keras.layers.Flatten(),
 114          tf.keras.layers.Dense(10),
 115      ])
 116      model.compile(
 117          optimizer=tf.keras.optimizers.Adam(),
 118          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
 119          metrics=["accuracy"],
 120      )
 121      return model
 122  
 123  
 124  @pytest.fixture
 125  def keras_data_gen_sequence(random_train_data, random_one_hot_labels):
 126      class DataGenerator(tf.keras.utils.Sequence):
 127          def __len__(self):
 128              return 128
 129  
 130          def __getitem__(self, index):
 131              x = random_train_data
 132              y = random_one_hot_labels
 133              return x, y
 134  
 135      return DataGenerator()
 136  
 137  
 138  @pytest.fixture(autouse=True)
 139  def clear_fluent_autologging_import_hooks():
 140      """
 141      Clears import hooks for MLflow fluent autologging (`mlflow.autolog()`) between tests
 142      to ensure that interactions between fluent autologging and TensorFlow / tf.keras can
 143      be tested successfully
 144      """
 145      mlflow.utils.import_hooks._post_import_hooks.pop("tensorflow", None)
 146      mlflow.utils.import_hooks._post_import_hooks.pop("keras", None)
 147  
 148  
@pytest.fixture(autouse=True)
def clear_autologging_config():
    """
    Clears TensorFlow autologging config, simulating a fresh state where autologging has not
    been previously enabled with any particular configuration
    """
    # Dropping the flavor's entry discards any autolog() arguments set by a prior test.
    AUTOLOGGING_INTEGRATIONS.pop(mlflow.tensorflow.FLAVOR_NAME, None)
 156  
 157  
 158  def create_tf_keras_model():
 159      model = tf.keras.Sequential()
 160      model.add(tf.keras.Input(shape=(4,), dtype="float64"))
 161      model.add(layers.Dense(16, activation="relu"))
 162      model.add(layers.Dense(3, activation="softmax"))
 163  
 164      model.compile(
 165          optimizer=tf.keras.optimizers.Adam(), loss="categorical_crossentropy", metrics=["accuracy"]
 166      )
 167      return model
 168  
 169  
 170  def test_tf_keras_autolog_ends_auto_created_run(random_train_data, random_one_hot_labels):
 171      mlflow.tensorflow.autolog()
 172  
 173      data = random_train_data
 174      labels = random_one_hot_labels
 175  
 176      model = create_tf_keras_model()
 177      model.fit(data, labels, epochs=10)
 178  
 179      assert mlflow.active_run() is None
 180  
 181  
 182  def test_extra_tags_tensorflow_autolog(random_train_data, random_one_hot_labels):
 183      mlflow.tensorflow.autolog(extra_tags={"test_tag": "tf_autolog"})
 184  
 185      data = random_train_data
 186      labels = random_one_hot_labels
 187  
 188      model = create_tf_keras_model()
 189      model.fit(data, labels, epochs=10)
 190  
 191      run = mlflow.last_active_run()
 192      assert run.data.tags["test_tag"] == "tf_autolog"
 193      assert run.data.tags[mlflow.utils.mlflow_tags.MLFLOW_AUTOLOGGING] == "tensorflow"
 194  
 195  
 196  @pytest.mark.parametrize("log_models", [True, False])
 197  def test_tf_keras_autolog_log_models_configuration(
 198      random_train_data, random_one_hot_labels, log_models
 199  ):
 200      mlflow.tensorflow.autolog(log_models=log_models)
 201  
 202      data = random_train_data
 203      labels = random_one_hot_labels
 204  
 205      model = create_tf_keras_model()
 206  
 207      model.fit(data, labels, epochs=10)
 208  
 209      assert (mlflow.last_logged_model() is not None) == log_models
 210  
 211  
@pytest.mark.parametrize("log_models", [True, False])
@pytest.mark.parametrize("log_datasets", [True, False])
def test_tf_keras_autolog_log_datasets_configuration_with_numpy(
    random_train_data, random_one_hot_labels, log_datasets, log_models
):
    """Dataset inputs and model inputs recorded on the run follow the autolog flags.

    With `log_datasets` on, the run carries exactly one dataset input whose schema
    matches the tensor specs inferred from the numpy features/targets; with
    `log_models` on, a LoggedModel exists and is linked to the run.
    """
    mlflow.tensorflow.autolog(log_datasets=log_datasets, log_models=log_models)

    data = random_train_data
    labels = random_one_hot_labels

    model = create_tf_keras_model()

    model.fit(data, labels, epochs=10)

    client = MlflowClient()
    run_inputs = client.get_run(mlflow.last_active_run().info.run_id).inputs
    dataset_inputs = run_inputs.dataset_inputs
    if log_datasets:
        assert len(dataset_inputs) == 1
        # The logged schema is a JSON envelope holding the tensor specs inferred
        # from the arrays passed to fit().
        feature_schema = _infer_schema(data)
        target_schema = _infer_schema(labels)
        assert dataset_inputs[0].dataset.schema == json.dumps({
            "mlflow_tensorspec": {
                "features": feature_schema.to_json(),
                "targets": target_schema.to_json(),
            }
        })
    else:
        assert len(dataset_inputs) == 0
    logged_model_inputs = run_inputs.model_inputs
    logged_model = mlflow.last_logged_model()
    if log_models:
        if log_datasets:
            # Model and datasets both logged: the model input is attached to the run.
            assert len(logged_model_inputs) == 1
            assert logged_model_inputs[0].model_id == logged_model.model_id
        else:
            # Model logged without datasets: it is still tied to the source run.
            assert logged_model is not None
            assert logged_model.source_run_id == mlflow.last_active_run().info.run_id
    else:
        assert len(logged_model_inputs) == 0
        assert logged_model is None
 253  
 254  
@pytest.mark.parametrize("log_datasets", [True, False])
def test_tf_keras_autolog_log_datasets_configuration_with_tensor(
    random_train_data, random_one_hot_labels, log_datasets
):
    """Dataset inputs for tf.Tensor training data obey the `log_datasets` flag."""
    mlflow.tensorflow.autolog(log_datasets=log_datasets)

    data_as_tensor = tf.convert_to_tensor(random_train_data)
    labels_as_tensor = tf.convert_to_tensor(random_one_hot_labels)

    model = create_tf_keras_model()

    model.fit(data_as_tensor, labels_as_tensor, epochs=10)

    client = MlflowClient()
    dataset_inputs = client.get_run(mlflow.last_active_run().info.run_id).inputs.dataset_inputs
    if log_datasets:
        assert len(dataset_inputs) == 1
        # The schema should match specs inferred from the underlying numpy arrays.
        feature_schema = _infer_schema(data_as_tensor.numpy())
        target_schema = _infer_schema(labels_as_tensor.numpy())
        assert dataset_inputs[0].dataset.schema == json.dumps({
            "mlflow_tensorspec": {
                "features": feature_schema.to_json(),
                "targets": target_schema.to_json(),
            }
        })
    else:
        assert len(dataset_inputs) == 0
 282  
 283  
@pytest.mark.parametrize("log_datasets", [True, False])
def test_tf_keras_autolog_log_datasets_configuration_with_tf_dataset(
    fashion_mnist_tf_dataset, log_datasets
):
    """Dataset inputs for tf.data.Dataset training data obey the `log_datasets` flag."""
    mlflow.tensorflow.autolog(log_datasets=log_datasets)
    fashion_mnist_model = _create_fashion_mnist_model()
    fashion_mnist_model.fit(fashion_mnist_tf_dataset)

    client = MlflowClient()
    dataset_inputs = client.get_run(mlflow.last_active_run().info.run_id).inputs.dataset_inputs
    if log_datasets:
        assert len(dataset_inputs) == 1
        # For tf.data input the features schema is inferred from one batch, keyed
        # by element position within the batch tuple, and targets are None.
        numpy_data = next(fashion_mnist_tf_dataset.as_numpy_iterator())
        assert dataset_inputs[0].dataset.schema == json.dumps({
            "mlflow_tensorspec": {
                "features": _infer_schema({
                    str(i): data_element for i, data_element in enumerate(numpy_data)
                }).to_json(),
                "targets": None,
            }
        })

    else:
        assert len(dataset_inputs) == 0
 308  
 309  
 310  def test_tf_keras_autolog_log_datasets_with_validation_data(
 311      fashion_mnist_tf_dataset, fashion_mnist_tf_dataset_eval
 312  ):
 313      mlflow.tensorflow.autolog(log_datasets=True)
 314      fashion_mnist_model = _create_fashion_mnist_model()
 315      fashion_mnist_model.fit(fashion_mnist_tf_dataset, validation_data=fashion_mnist_tf_dataset_eval)
 316  
 317      client = MlflowClient()
 318      dataset_inputs = client.get_run(mlflow.last_active_run().info.run_id).inputs.dataset_inputs
 319      assert len(dataset_inputs) == 2
 320      assert dataset_inputs[0].tags[0].value == "train"
 321      assert dataset_inputs[1].tags[0].value == "eval"
 322  
 323  
 324  def test_tf_keras_autolog_log_datasets_with_validation_data_as_numpy_tuple(
 325      fashion_mnist_tf_dataset, fashion_mnist_tf_dataset_eval
 326  ):
 327      mlflow.tensorflow.autolog(log_datasets=True)
 328      fashion_mnist_model = _create_fashion_mnist_model()
 329      X_eval, y_eval = next(fashion_mnist_tf_dataset_eval.as_numpy_iterator())
 330      fashion_mnist_model.fit(fashion_mnist_tf_dataset, validation_data=(X_eval, y_eval))
 331  
 332      client = MlflowClient()
 333      dataset_inputs = client.get_run(mlflow.last_active_run().info.run_id).inputs.dataset_inputs
 334      assert len(dataset_inputs) == 2
 335      assert dataset_inputs[0].tags[0].value == "train"
 336      assert dataset_inputs[1].tags[0].value == "eval"
 337  
 338  
 339  def test_tf_keras_autolog_log_datasets_with_validation_data_as_tf_tuple(
 340      fashion_mnist_tf_dataset, fashion_mnist_tf_dataset_eval
 341  ):
 342      mlflow.tensorflow.autolog(log_datasets=True)
 343      fashion_mnist_model = _create_fashion_mnist_model()
 344      # convert tensorflow dataset into tensors
 345      X_eval, y_eval = next(fashion_mnist_tf_dataset_eval.as_numpy_iterator())
 346      X_eval_tensor = tf.convert_to_tensor(X_eval)
 347      y_eval_tensor = tf.convert_to_tensor(y_eval)
 348      fashion_mnist_model.fit(
 349          fashion_mnist_tf_dataset, validation_data=(X_eval_tensor, y_eval_tensor)
 350      )
 351  
 352      client = MlflowClient()
 353      dataset_inputs = client.get_run(mlflow.last_active_run().info.run_id).inputs.dataset_inputs
 354      assert len(dataset_inputs) == 2
 355      assert dataset_inputs[0].tags[0].value == "train"
 356      assert dataset_inputs[1].tags[0].value == "eval"
 357  
 358  
 359  def test_tf_keras_autolog_persists_manually_created_run(random_train_data, random_one_hot_labels):
 360      mlflow.tensorflow.autolog()
 361      with mlflow.start_run() as run:
 362          data = random_train_data
 363          labels = random_one_hot_labels
 364  
 365          model = create_tf_keras_model()
 366          model.fit(data, labels, epochs=10)
 367  
 368          assert mlflow.active_run()
 369          assert mlflow.active_run().info.run_id == run.info.run_id
 370  
 371  
@pytest.fixture
def tf_keras_random_data_run(random_train_data, random_one_hot_labels, initial_epoch):
    """Run an autologged 10-epoch fit() and return (run, history).

    `initial_epoch` must be supplied by the consuming test via
    `@pytest.mark.parametrize("initial_epoch", ...)`.
    """
    mlflow.tensorflow.autolog()

    data = random_train_data
    labels = random_one_hot_labels

    model = create_tf_keras_model()
    history = model.fit(
        data, labels, epochs=initial_epoch + 10, steps_per_epoch=1, initial_epoch=initial_epoch
    )

    client = MlflowClient()
    # Most recently created run in the default experiment ("0").
    return client.get_run(client.search_runs(["0"])[0].info.run_id), history
 386  
 387  
@pytest.mark.parametrize("initial_epoch", [0, 10])
def test_tf_keras_autolog_logs_expected_data(tf_keras_random_data_run):
    """End-to-end check of what autologging records for one fit() call:
    metrics, explicit/default/optimizer params, excluded params, and the
    model summary artifact.
    """
    run, history = tf_keras_random_data_run
    data = run.data
    assert "accuracy" in data.metrics
    assert "loss" in data.metrics
    # Testing explicitly passed parameters are logged correctly
    assert "epochs" in data.params
    assert data.params["epochs"] == str(history.epoch[-1] + 1)
    assert "steps_per_epoch" in data.params
    assert data.params["steps_per_epoch"] == "1"
    # Testing default parameters are logged correctly
    assert "initial_epoch" in data.params
    assert data.params["initial_epoch"] == str(history.epoch[0])
    # Testing unwanted parameters are not logged
    assert "callbacks" not in data.params
    assert "validation_data" not in data.params
    # Testing optimizer parameters are logged
    assert "opt_name" in data.params
    assert data.params["opt_name"].lower() == "adam"
    assert "opt_learning_rate" in data.params
    assert "opt_beta_1" in data.params
    assert "opt_beta_2" in data.params
    assert "opt_epsilon" in data.params
    assert "opt_amsgrad" in data.params
    assert data.params["opt_amsgrad"] == "False"
    client = MlflowClient()
    # One "accuracy" entry per epoch; the fixture always trains for 10 epochs.
    all_epoch_acc = client.get_metric_history(run.info.run_id, "accuracy")
    num_of_epochs = len(history.history["loss"])
    assert len(all_epoch_acc) == num_of_epochs == 10
    artifacts = client.list_artifacts(run.info.run_id)
    artifacts = (x.path for x in artifacts)
    assert "model_summary.txt" in artifacts
 421  
 422  
 423  def __example_tf_dataset(batch_size):
 424      a = tf.data.Dataset.range(1)
 425      b = tf.data.Dataset.range(1)
 426      ds = tf.data.Dataset.zip((a, b))
 427      return ds.batch(batch_size)
 428  
 429  
 430  class __ExampleSequence(tf.keras.utils.Sequence):
 431      def __init__(self, batch_size, with_sample_weights=False):
 432          self.batch_size = batch_size
 433          self.with_sample_weights = with_sample_weights
 434  
 435      def __len__(self):
 436          return 10
 437  
 438      def __getitem__(self, idx):
 439          x = np.array([idx] * self.batch_size)
 440          y = np.array([-idx] * self.batch_size)
 441          if self.with_sample_weights:
 442              w = np.array([1] * self.batch_size)
 443              return x, y, w
 444          return x, y
 445  
 446  
 447  def __generator(data, target, batch_size):
 448      data_batches = np.split(data, data.shape[0] // batch_size)
 449      target_batches = np.split(target, target.shape[0] // batch_size)
 450      yield from zip(data_batches, target_batches)
 451  
 452  
 453  class __GeneratorClass:
 454      def __init__(self, data, target, batch_size):
 455          self.data = data
 456          self.target = target
 457          self.batch_size = batch_size
 458          self.ptr = 0
 459  
 460      def __next__(self):
 461          if self.ptr >= len(self.data):
 462              raise StopIteration
 463          idx = self.ptr % len(self.data)
 464          self.ptr += 1
 465          return self.data[idx : idx + self.batch_size], self.target[idx : idx + self.batch_size]
 466  
 467      def __iter__(self):
 468          return self
 469  
 470  
@pytest.mark.parametrize(
    "generate_data",
    [
        __example_tf_dataset,
        __ExampleSequence,
        functools.partial(__ExampleSequence, with_sample_weights=True),
        functools.partial(__generator, np.array([[1]] * 10), np.array([[1]] * 10)),
        pytest.param(
            functools.partial(__GeneratorClass, np.array([[1]] * 10), np.array([[1]] * 10)),
            # Skipped on TF >= 2.15 without TF_USE_LEGACY_KERAS — presumably the
            # newer Keras no longer accepts plain generator classes; see skip reason.
            marks=pytest.mark.skipif(
                Version(tf.__version__).release >= (2, 15)
                and "TF_USE_LEGACY_KERAS" not in os.environ,
                reason="does not support",
            ),
        ),
    ],
)
@pytest.mark.parametrize("batch_size", [5, 10])
def test_tf_keras_autolog_implicit_batch_size_works(generate_data, batch_size):
    """batch_size is inferred and logged for every supported data container,
    whether `x` is passed positionally or by keyword.
    """
    mlflow.autolog()
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(1, input_shape=(1,)))
    model.compile(loss="mse")

    # 'x' passed as arg
    model.fit(generate_data(batch_size), verbose=0)
    assert mlflow.last_active_run().data.params["batch_size"] == str(batch_size)

    # 'x' passed as kwarg
    model.fit(x=generate_data(batch_size), verbose=0)
    assert mlflow.last_active_run().data.params["batch_size"] == str(batch_size)
 502  
 503  
 504  def __tf_dataset_multi_input(batch_size):
 505      a = tf.data.Dataset.range(1)
 506      b = tf.data.Dataset.range(1)
 507      c = tf.data.Dataset.range(1)
 508      ds = tf.data.Dataset.zip(((a, b), c))
 509      return ds.batch(batch_size)
 510  
 511  
 512  class __SequenceMultiInput(tf.keras.utils.Sequence):
 513      def __init__(self, batch_size):
 514          self.batch_size = batch_size
 515  
 516      def __len__(self):
 517          return 10
 518  
 519      def __getitem__(self, idx):
 520          return (np.random.rand(self.batch_size), np.random.rand(self.batch_size)), np.random.rand(
 521              self.batch_size
 522          )
 523  
 524  
 525  def __generator_multi_input(data, target, batch_size):
 526      data_batches = np.split(data, data.shape[1] // batch_size, axis=1)
 527      target_batches = np.split(target, target.shape[0] // batch_size)
 528      for inputs, output in zip(data_batches, target_batches):
 529          yield tuple(inputs), output
 530  
 531  
 532  class __GeneratorClassMultiInput:
 533      def __init__(self, data, target, batch_size):
 534          self.data = data
 535          self.target = target
 536          self.batch_size = batch_size
 537          self.ptr = 0
 538  
 539      def __next__(self):
 540          if self.ptr >= len(self.data):
 541              raise StopIteration
 542          idx = self.ptr % len(self.data)
 543          self.ptr += 1
 544          return (
 545              self.data[idx : idx + self.batch_size, 0],
 546              self.data[idx : idx + self.batch_size, 1],
 547          ), self.target[idx : idx + self.batch_size]
 548  
 549      def __iter__(self):
 550          return self
 551  
 552  
 553  @pytest.mark.parametrize(
 554      "generate_data",
 555      [
 556          __tf_dataset_multi_input,
 557          __SequenceMultiInput,
 558          functools.partial(__generator_multi_input, np.random.rand(2, 10), np.random.rand(10)),
 559          functools.partial(__GeneratorClassMultiInput, np.random.rand(10, 2), np.random.rand(10, 1)),
 560      ],
 561  )
 562  @pytest.mark.parametrize("batch_size", [5, 10])
 563  def test_tf_keras_autolog_implicit_batch_size_works_multi_input(generate_data, batch_size):
 564      mlflow.tensorflow.autolog()
 565  
 566      input1 = tf.keras.Input(shape=(1,))
 567      input2 = tf.keras.Input(shape=(1,))
 568      concat = tf.keras.layers.Concatenate()([input1, input2])
 569      output = tf.keras.layers.Dense(1, activation="sigmoid")(concat)
 570  
 571      model = tf.keras.models.Model(inputs=[input1, input2], outputs=output)
 572      model.compile(loss="mse")
 573  
 574      # 'x' passed as arg
 575      model.fit(generate_data(batch_size), verbose=0)
 576      assert mlflow.last_active_run().data.params["batch_size"] == str(batch_size)
 577  
 578      # 'x' passed as kwarg
 579      model.fit(x=generate_data(batch_size), verbose=0)
 580      assert mlflow.last_active_run().data.params["batch_size"] == str(batch_size)
 581  
 582  
@pytest.mark.skipif(
    Version(tf.__version__) < Version("2.1.4"),
    reason="Does not support passing of generator classes as `x` in `fit`",
)
@pytest.mark.parametrize(
    "generator",
    [
        __generator,
        pytest.param(
            __GeneratorClass,
            marks=pytest.mark.skipif(
                Version(tf.__version__).release >= (2, 15)
                and "TF_USE_LEGACY_KERAS" not in os.environ,
                reason="does not support",
            ),
        ),
    ],
)
@pytest.mark.parametrize("batch_size", [2, 3, 6])
def test_tf_keras_autolog_implicit_batch_size_for_generator_dataset_without_side_effects(
    generator,
    batch_size,
):
    """Inferring the implicit batch size must not consume generator items:
    training with autolog enabled should yield (approximately) the same final
    metric as training with autolog disabled.
    """
    from tensorflow.keras.layers import Dense
    from tensorflow.keras.models import Sequential

    data = np.array([[1, 2, 3], [3, 2, 1], [2, 2, 2], [10, 20, 30], [30, 20, 10], [20, 20, 20]])
    target = np.array([[1], [3], [2], [11], [13], [12]])

    model = Sequential()
    # Zero-initialized weights keep the two fit() runs comparable.
    model.add(
        Dense(
            5, input_dim=3, activation="relu", kernel_initializer="zeros", bias_initializer="zeros"
        )
    )
    model.add(Dense(1, kernel_initializer="zeros", bias_initializer="zeros"))
    model.compile(loss="mae", optimizer="adam", metrics=["mse"])

    mlflow.autolog()
    actual_mse = model.fit(generator(data, target, batch_size), verbose=0).history["mse"][-1]

    mlflow.autolog(disable=True)
    expected_mse = model.fit(generator(data, target, batch_size), verbose=0).history["mse"][-1]

    np.testing.assert_allclose(actual_mse, expected_mse, atol=1)
    assert mlflow.last_active_run().data.params["batch_size"] == str(batch_size)
 629  
 630  
 631  def test_tf_keras_autolog_succeeds_for_tf_datasets_lacking_batch_size_info():
 632      X_train = np.random.rand(100, 100)
 633      y_train = np.random.randint(0, 10, 100)
 634  
 635      train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train))
 636      train_ds = train_ds.batch(50)
 637      train_ds = train_ds.cache().prefetch(buffer_size=5)
 638      assert not hasattr(train_ds, "_batch_size")
 639  
 640      model = tf.keras.Sequential()
 641      model.add(tf.keras.Input((100,)))
 642      model.add(tf.keras.layers.Dense(256, activation="relu"))
 643      model.add(tf.keras.layers.Dropout(rate=0.4))
 644      model.add(tf.keras.layers.Dense(10, activation="sigmoid"))
 645      model.compile(
 646          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
 647          optimizer="Adam",
 648          metrics=["accuracy"],
 649      )
 650  
 651      mlflow.tensorflow.autolog()
 652      model.fit(train_ds, epochs=100)
 653  
 654      assert mlflow.last_active_run().data.params["batch_size"] == "None"
 655  
 656  
 657  def test_tf_keras_autolog_records_metrics_for_last_epoch(random_train_data, random_one_hot_labels):
 658      num_training_epochs = 17
 659      mlflow.tensorflow.autolog(log_every_epoch=True)
 660  
 661      model = create_tf_keras_model()
 662      with mlflow.start_run() as run:
 663          model.fit(
 664              random_train_data,
 665              random_one_hot_labels,
 666              epochs=num_training_epochs,
 667              initial_epoch=0,
 668          )
 669  
 670      client = MlflowClient()
 671      run_metrics = client.get_run(run.info.run_id).data.metrics
 672      assert "accuracy" in run_metrics
 673      all_epoch_acc = client.get_metric_history(run.info.run_id, "accuracy")
 674      assert len(all_epoch_acc) == num_training_epochs
 675  
 676  
 677  def test_tf_keras_autolog_logs_metrics_for_single_epoch_training(
 678      random_train_data, random_one_hot_labels
 679  ):
 680      """
 681      tf.Keras exhibits inconsistent epoch indexing behavior in comparison with other
 682      TF2 APIs (e.g., tf.Estimator). tf.Keras uses zero-indexing for epochs,
 683      while other APIs use one-indexing. Accordingly, this test verifies that metrics are
 684      produced in the boundary case where a model is trained for a single epoch, ensuring
 685      that we don't miss the zero index in the tf.Keras case.
 686      """
 687      mlflow.tensorflow.autolog()
 688  
 689      model = create_tf_keras_model()
 690      with mlflow.start_run() as run:
 691          model.fit(random_train_data, random_one_hot_labels, epochs=1)
 692  
 693      client = MlflowClient()
 694      run_metrics = client.get_run(run.info.run_id).data.metrics
 695      assert "accuracy" in run_metrics
 696      assert "loss" in run_metrics
 697  
 698  
 699  def test_tf_keras_autolog_names_positional_parameters_correctly(
 700      random_train_data, random_one_hot_labels
 701  ):
 702      mlflow.tensorflow.autolog()
 703  
 704      data = random_train_data
 705      labels = random_one_hot_labels
 706  
 707      model = create_tf_keras_model()
 708  
 709      with mlflow.start_run():
 710          # Pass `batch_size` as a positional argument for testing purposes
 711          model.fit(data, labels, 8, epochs=10, steps_per_epoch=1)
 712          run_id = mlflow.active_run().info.run_id
 713  
 714      client = MlflowClient()
 715      run_info = client.get_run(run_id)
 716      assert run_info.data.params.get("batch_size") == "8"
 717  
 718  
 719  @pytest.mark.parametrize("initial_epoch", [0, 10])
 720  def test_tf_keras_autolog_model_can_load_from_artifact(tf_keras_random_data_run, random_train_data):
 721      run, _ = tf_keras_random_data_run
 722  
 723      client = MlflowClient()
 724      artifacts = client.list_artifacts(run.info.run_id)
 725      artifacts = (x.path for x in artifacts)
 726      assert "tensorboard_logs" in artifacts
 727      model = mlflow.tensorflow.load_model("runs:/" + run.info.run_id + "/model")
 728      model.predict(random_train_data)
 729  
 730  
def get_tf_keras_random_data_run_with_callback(
    random_train_data,
    random_one_hot_labels,
    callback,
    restore_weights,
    patience,
    initial_epoch,
    log_models,
):
    """Train with autologging plus a callback; return (run, history, callback).

    `callback` is either the string "early" — replaced by an EarlyStopping
    callback configured so training always stops early — or any other value,
    replaced by a no-op custom callback.
    """
    mlflow.tensorflow.autolog(log_models=log_models)

    data = random_train_data
    labels = random_one_hot_labels

    model = create_tf_keras_model()
    if callback == "early":
        # min_delta is set as such to guarantee early stopping
        callback = tf.keras.callbacks.EarlyStopping(
            monitor="loss",
            patience=patience,
            min_delta=99999999,
            restore_best_weights=restore_weights,
            verbose=1,
        )
    else:

        class CustomCallback(tf.keras.callbacks.Callback):
            def on_train_end(self, logs=None):
                pass

        callback = CustomCallback()

    history = model.fit(
        data, labels, epochs=initial_epoch + 10, callbacks=[callback], initial_epoch=initial_epoch
    )

    client = MlflowClient()
    # Most recently created run in the default experiment ("0").
    return client.get_run(client.search_runs(["0"])[0].info.run_id), history, callback
 769  
 770  
 771  @pytest.fixture
 772  def tf_keras_random_data_run_with_callback(
 773      random_train_data,
 774      random_one_hot_labels,
 775      callback,
 776      restore_weights,
 777      patience,
 778      initial_epoch,
 779      log_models,
 780  ):
 781      return get_tf_keras_random_data_run_with_callback(
 782          random_train_data,
 783          random_one_hot_labels,
 784          callback,
 785          restore_weights,
 786          patience,
 787          initial_epoch,
 788          log_models=log_models,
 789      )
 790  
 791  
@pytest.mark.parametrize("log_models", [True, False])
@pytest.mark.parametrize("restore_weights", [True])
@pytest.mark.parametrize("callback", ["early"])
@pytest.mark.parametrize("patience", [0, 1, 5])
@pytest.mark.parametrize("initial_epoch", [0, 10])
def test_tf_keras_autolog_early_stop_logs(
    tf_keras_random_data_run_with_callback, initial_epoch, log_models
):
    """Autologging captures EarlyStopping params, the stopped/restored epoch
    metrics, the best-model loss, and (if enabled) model-level metrics.
    """
    run, history, callback = tf_keras_random_data_run_with_callback
    metrics = run.data.metrics
    params = run.data.params
    assert "patience" in params
    assert params["patience"] == str(callback.patience)
    assert "monitor" in params
    assert params["monitor"] == "loss"
    assert "verbose" not in params
    assert "mode" not in params
    assert "stopped_epoch" in metrics
    assert "restored_epoch" in metrics
    restored_epoch = int(metrics["restored_epoch"])
    # In this test, the best epoch is always the first epoch because the early stopping callback
    # never observes a loss improvement due to an extremely large `min_delta` value
    assert restored_epoch == initial_epoch
    assert "loss" in history.history
    client = MlflowClient()
    metric_history = client.get_metric_history(run.info.run_id, "loss")
    # Check that MLflow has logged the metrics of the "best" model, in addition to per-epoch metrics
    loss = history.history["loss"]
    assert len(metric_history) == len(loss) + 1
    steps, values = map(list, zip(*[(m.step, m.value) for m in metric_history]))
    # Check that MLflow has logged the correct steps
    assert steps == [*history.epoch, callback.stopped_epoch + 1]
    # Check that MLflow has logged the correct metric values
    np.testing.assert_allclose(values, [*loss, callback.best])

    artifacts = [f.path for f in client.list_artifacts(run.info.run_id)]
    assert "tensorboard_logs" in artifacts

    # Check metrics are logged to the LoggedModel
    if log_models:
        logged_model = mlflow.last_logged_model()
        assert logged_model is not None
        assert {metric.key: metric.value for metric in logged_model.metrics} == metrics
 835  
 836  
 837  @pytest.mark.parametrize("log_models", [False])
 838  @pytest.mark.parametrize("restore_weights", [True])
 839  @pytest.mark.parametrize("callback", ["early"])
 840  @pytest.mark.parametrize("patience", [11])
 841  @pytest.mark.parametrize("initial_epoch", [0, 10])
 842  def test_tf_keras_autolog_early_stop_no_stop_does_not_log(tf_keras_random_data_run_with_callback):
 843      run, history, callback = tf_keras_random_data_run_with_callback
 844      metrics = run.data.metrics
 845      params = run.data.params
 846      assert "patience" in params
 847      assert params["patience"] == str(callback.patience)
 848      assert "monitor" in params
 849      assert params["monitor"] == "loss"
 850      assert "verbose" not in params
 851      assert "mode" not in params
 852      assert "stopped_epoch" not in metrics
 853      assert "restored_epoch" not in metrics
 854      assert "loss" in history.history
 855      num_of_epochs = len(history.history["loss"])
 856      client = MlflowClient()
 857      metric_history = client.get_metric_history(run.info.run_id, "loss")
 858      # Check the test epoch numbers are correct
 859      assert num_of_epochs == 10
 860      assert len(metric_history) == num_of_epochs
 861  
 862  
 863  @pytest.mark.parametrize("log_models", [False])
 864  @pytest.mark.parametrize("restore_weights", [False])
 865  @pytest.mark.parametrize("callback", ["early"])
 866  @pytest.mark.parametrize("patience", [5])
 867  @pytest.mark.parametrize("initial_epoch", [0, 10])
 868  def test_tf_keras_autolog_early_stop_no_restore_doesnt_log(tf_keras_random_data_run_with_callback):
 869      run, history, callback = tf_keras_random_data_run_with_callback
 870      metrics = run.data.metrics
 871      params = run.data.params
 872      assert "patience" in params
 873      assert params["patience"] == str(callback.patience)
 874      assert "monitor" in params
 875      assert params["monitor"] == "loss"
 876      assert "verbose" not in params
 877      assert "mode" not in params
 878      assert "stopped_epoch" in metrics
 879      assert "restored_epoch" not in metrics
 880      assert "loss" in history.history
 881      num_of_epochs = len(history.history["loss"])
 882      client = MlflowClient()
 883      metric_history = client.get_metric_history(run.info.run_id, "loss")
 884      # Check the test epoch numbers are correct
 885      assert num_of_epochs == callback.patience + 1
 886      assert len(metric_history) == num_of_epochs
 887  
 888  
 889  @pytest.mark.parametrize("log_models", [False])
 890  @pytest.mark.parametrize("restore_weights", [False])
 891  @pytest.mark.parametrize("callback", ["not-early"])
 892  @pytest.mark.parametrize("patience", [5])
 893  @pytest.mark.parametrize("initial_epoch", [0, 10])
 894  def test_tf_keras_autolog_non_early_stop_callback_no_log(tf_keras_random_data_run_with_callback):
 895      run, history = tf_keras_random_data_run_with_callback[:-1]
 896      metrics = run.data.metrics
 897      params = run.data.params
 898      assert "patience" not in params
 899      assert "monitor" not in params
 900      assert "verbose" not in params
 901      assert "mode" not in params
 902      assert "stopped_epoch" not in metrics
 903      assert "restored_epoch" not in metrics
 904      assert "loss" in history.history
 905      num_of_epochs = len(history.history["loss"])
 906      client = MlflowClient()
 907      metric_history = client.get_metric_history(run.info.run_id, "loss")
 908      # Check the test epoch numbers are correct
 909      assert num_of_epochs == 10
 910      assert len(metric_history) == num_of_epochs
 911  
 912  
 913  @pytest.mark.parametrize("positional", [True, False])
 914  def test_tf_keras_autolog_does_not_mutate_original_callbacks_list(
 915      tmp_path, random_train_data, random_one_hot_labels, positional
 916  ):
 917      """
 918      TensorFlow autologging passes new callbacks to the `fit()` / `fit_generator()` function. If
 919      preexisting user-defined callbacks already exist, these new callbacks are added to the
 920      user-specified ones. This test verifies that the new callbacks are added to the without
 921      permanently mutating the original list of callbacks.
 922      """
 923      mlflow.tensorflow.autolog()
 924  
 925      tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=tmp_path)
 926      callbacks = [tensorboard_callback]
 927  
 928      model = create_tf_keras_model()
 929      data = random_train_data
 930      labels = random_one_hot_labels
 931  
 932      if positional:
 933          model.fit(data, labels, None, 10, 1, callbacks)
 934      else:
 935          model.fit(data, labels, epochs=10, callbacks=callbacks)
 936  
 937      assert len(callbacks) == 1
 938      assert callbacks == [tensorboard_callback]
 939  
 940  
 941  def test_tf_keras_autolog_does_not_delete_logging_directory_for_tensorboard_callback(
 942      tmp_path, random_train_data, random_one_hot_labels
 943  ):
 944      tensorboard_callback_logging_dir_path = str(tmp_path.joinpath("tb_logs"))
 945      tensorboard_callback = tf.keras.callbacks.TensorBoard(
 946          tensorboard_callback_logging_dir_path, histogram_freq=0
 947      )
 948  
 949      mlflow.tensorflow.autolog()
 950  
 951      data = random_train_data
 952      labels = random_one_hot_labels
 953  
 954      model = create_tf_keras_model()
 955      model.fit(data, labels, epochs=10, callbacks=[tensorboard_callback])
 956  
 957      assert os.path.exists(tensorboard_callback_logging_dir_path)
 958  
 959  
 960  def test_tf_keras_autolog_logs_to_and_deletes_temporary_directory_when_tensorboard_callback_absent(
 961      tmp_path, random_train_data, random_one_hot_labels
 962  ):
 963      from mlflow.tensorflow import _TensorBoardLogDir
 964  
 965      mlflow.tensorflow.autolog()
 966  
 967      mock_log_dir_inst = _TensorBoardLogDir(
 968          location=str(tmp_path.joinpath("tb_logging")), is_temp=True
 969      )
 970      with patch("mlflow.tensorflow._TensorBoardLogDir", autospec=True) as mock_log_dir_class:
 971          mock_log_dir_class.return_value = mock_log_dir_inst
 972  
 973          data = random_train_data
 974          labels = random_one_hot_labels
 975  
 976          model = create_tf_keras_model()
 977          model.fit(data, labels, epochs=10)
 978  
 979          assert not os.path.exists(mock_log_dir_inst.location)
 980  
 981  
 982  def get_text_vec_model(train_samples):
 983      # Taken from: https://github.com/mlflow/mlflow/issues/3910
 984  
 985      try:
 986          from tensorflow.keras.layers.experimental.preprocessing import TextVectorization
 987      except ModuleNotFoundError:
 988          from tensorflow.keras.layers import TextVectorization
 989  
 990      VOCAB_SIZE = 10
 991      SEQUENCE_LENGTH = 16
 992      EMBEDDING_DIM = 16
 993  
 994      vectorizer_layer = TextVectorization(
 995          max_tokens=VOCAB_SIZE,
 996          output_mode="int",
 997          output_sequence_length=SEQUENCE_LENGTH,
 998      )
 999      vectorizer_layer.adapt(train_samples)
1000      model = tf.keras.Sequential([
1001          vectorizer_layer,
1002          tf.keras.layers.Embedding(
1003              VOCAB_SIZE,
1004              EMBEDDING_DIM,
1005              name="embedding",
1006              mask_zero=True,
1007          ),
1008          tf.keras.layers.GlobalAveragePooling1D(),
1009          tf.keras.layers.Dense(16, activation="relu"),
1010          tf.keras.layers.Dense(1, activation="tanh"),
1011      ])
1012      model.compile(optimizer="adam", loss="mse", metrics=["mae"])
1013      return model
1014  
1015  
@pytest.mark.skipif(
    Version(tf.__version__) < Version("2.3.0"),
    reason=(
        "Deserializing a model with `TextVectorization` and `Embedding` "
        "fails in tensorflow < 2.3.0. See this issue: "
        "https://github.com/tensorflow/tensorflow/issues/38250."
    ),
)
def test_autolog_text_vec_model(tmp_path):
    """
    Verifies autolog successfully saves a model that can't be saved in the H5 format
    """
    mlflow.tensorflow.autolog()

    samples = tf.convert_to_tensor(["this is an example", "another example"])
    sample_labels = np.array([0.4, 0.2])
    model = get_text_vec_model(samples)

    with mlflow.start_run() as run:
        model.fit(samples, sample_labels, epochs=1)

    # The reloaded model must produce identical predictions.
    reloaded_model = mlflow.tensorflow.load_model(f"runs:/{run.info.run_id}/model")
    np.testing.assert_array_equal(reloaded_model.predict(samples), model.predict(samples))
1039  
1040  
def test_tf_keras_model_autolog_registering_model(random_train_data, random_one_hot_labels):
    """Autolog with `registered_model_name` registers the fitted model in the
    model registry under that name."""
    model_name = "test_autolog_registered_model"
    mlflow.tensorflow.autolog(registered_model_name=model_name)
    with mlflow.start_run():
        create_tf_keras_model().fit(random_train_data, random_one_hot_labels, epochs=10)

        assert MlflowClient().get_registered_model(model_name).name == model_name
1050  
1051  
def test_fluent_autolog_with_tf_keras_logs_expected_content(
    random_train_data, random_one_hot_labels
):
    """
    Guards against previously-exhibited issues where using the fluent `mlflow.autolog()` API with
    `tf.keras` Models did not work due to conflicting patches set by both the
    `mlflow.tensorflow.autolog()` and the `mlflow.keras.autolog()` APIs.
    """
    mlflow.autolog()

    model = create_tf_keras_model()
    with mlflow.start_run() as run:
        model.fit(random_train_data, random_one_hot_labels, epochs=10)

    # Both training metrics and fit() params should have been captured.
    logged = MlflowClient().get_run(run.info.run_id).data
    assert "accuracy" in logged.metrics
    assert "epochs" in logged.params
1071  
1072  
def test_callback_is_picklable():
    """Both MLflow-provided Keras callbacks must be picklable."""
    for callback in (MlflowCallback(), _TensorBoard()):
        pickle.dumps(callback)
1079  
1080  
@pytest.mark.skipif(
    Version(tf.__version__) < Version("2.1.0"), reason="This test requires tensorflow >= 2.1.0"
)
def test_tf_keras_autolog_distributed_training(random_train_data, random_one_hot_labels):
    # Ref: https://www.tensorflow.org/tutorials/distribute/keras
    mlflow.tensorflow.autolog()

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = create_tf_keras_model()
    fit_params = {"epochs": 10, "batch_size": 10}
    with mlflow.start_run() as run:
        model.fit(random_train_data, random_one_hot_labels, **fit_params)
    # Every fit() kwarg should appear among the autologged run params.
    logged_params = MlflowClient().get_run(run.info.run_id).data.params
    assert logged_params.keys() >= fit_params.keys()
1095  
1096  
def test_import_tensorflow_with_fluent_autolog_enables_tensorflow_autologging():
    """After `mlflow.autolog()`, importing tensorflow must leave the
    tensorflow autologging integration enabled."""
    mlflow.autolog()

    import tensorflow  # noqa: F401

    assert not autologging_is_disabled(mlflow.tensorflow.FLAVOR_NAME)
1103  
1104  
def _assert_autolog_infers_model_signature_correctly(input_sig_spec, output_sig_spec):
    """Load the last logged model's MLmodel file and assert its signature
    matches the expected input/output tensor specs (given as parsed JSON)."""
    logged_model = mlflow.last_logged_model()
    artifact_dir = local_file_uri_to_path(logged_model.artifact_location)
    with open(os.path.join(artifact_dir, "MLmodel")) as f:
        mlmodel = yaml.safe_load(f)
    assert mlmodel is not None
    assert "signature" in mlmodel
    signature = mlmodel["signature"]
    assert signature is not None
    assert "inputs" in signature
    assert "outputs" in signature
    # Signatures are stored as JSON strings inside the YAML document.
    assert json.loads(signature["inputs"]) == input_sig_spec
    assert json.loads(signature["outputs"]) == output_sig_spec
1119  
1120  
def _assert_keras_autolog_input_example_load_and_predict_with_nparray(random_train_data):
    """Assert the autologged input example equals the first 5 training rows
    and that the pyfunc flavor can predict on it."""
    logged_model = mlflow.last_logged_model()
    model_uri = logged_model.model_uri
    example = _read_example(Model.load(model_uri), model_uri)
    np.testing.assert_array_almost_equal(example, random_train_data[:5])
    mlflow.pyfunc.load_model(model_uri).predict(example)
1128  
1129  
def test_keras_autolog_input_example_load_and_predict_with_nparray(
    random_train_data, random_one_hot_labels
):
    """Input-example logging works when fit() receives plain numpy arrays."""
    mlflow.tensorflow.autolog(log_input_examples=True, log_model_signatures=True)
    model = create_tf_keras_model()
    with mlflow.start_run():
        model.fit(random_train_data, random_one_hot_labels)
        _assert_keras_autolog_input_example_load_and_predict_with_nparray(random_train_data)
1138  
1139  
def test_keras_autolog_infers_model_signature_correctly_with_nparray(
    random_train_data, random_one_hot_labels
):
    """Signature inference from numpy inputs: variable-batch float64 4-feature
    input, float32 3-column output."""
    mlflow.tensorflow.autolog(log_model_signatures=True)
    model = create_tf_keras_model()
    with mlflow.start_run():
        model.fit(random_train_data, random_one_hot_labels)
        expected_inputs = [
            {"type": "tensor", "tensor-spec": {"dtype": "float64", "shape": [-1, 4]}}
        ]
        expected_outputs = [
            {"type": "tensor", "tensor-spec": {"dtype": "float32", "shape": [-1, 3]}}
        ]
        _assert_autolog_infers_model_signature_correctly(expected_inputs, expected_outputs)
1151  
1152  
@pytest.mark.skipif(
    Version(tf.__version__) < Version("2.1.0"),
    reason="tf.data.Dataset inputs are unsupported for input example logging in TensorFlow < 2.1.0",
)
def test_keras_autolog_input_example_load_and_predict_with_tf_dataset(fashion_mnist_tf_dataset):
    """Input-example logging works when fit() receives a tf.data.Dataset."""
    mlflow.tensorflow.autolog(log_input_examples=True, log_model_signatures=True)
    model = _create_fashion_mnist_model()
    with mlflow.start_run():
        model.fit(fashion_mnist_tf_dataset)
        logged_model = mlflow.last_logged_model()
        model_uri = logged_model.model_uri
        # The logged example must round-trip through the pyfunc flavor.
        example = _read_example(Model.load(model_uri), model_uri)
        mlflow.pyfunc.load_model(model_uri).predict(example)
1167  
1168  
@pytest.mark.skipif(
    Version(tf.__version__) < Version("2.1.0"),
    reason="tf.data.Dataset inputs are unsupported for signature logging in TensorFlow < 2.1.0",
)
def test_keras_autolog_infers_model_signature_correctly_with_tf_dataset(fashion_mnist_tf_dataset):
    """Signature inference works when fit() receives a tf.data.Dataset."""
    mlflow.tensorflow.autolog(log_model_signatures=True)
    model = _create_fashion_mnist_model()
    with mlflow.start_run():
        model.fit(fashion_mnist_tf_dataset)
        expected_inputs = [
            {"type": "tensor", "tensor-spec": {"dtype": "float64", "shape": [-1, 28, 28]}}
        ]
        expected_outputs = [
            {"type": "tensor", "tensor-spec": {"dtype": "float32", "shape": [-1, 10]}}
        ]
        _assert_autolog_infers_model_signature_correctly(expected_inputs, expected_outputs)
1182  
1183  
def test_keras_autolog_input_example_load_and_predict_with_dict(
    random_train_dict_mapping, random_one_hot_labels
):
    """Input-example logging works for dict-of-feature-arrays inputs; each
    logged column must equal the first 5 values of the matching feature."""
    mlflow.tensorflow.autolog(log_input_examples=True, log_model_signatures=True)
    model = _create_model_for_dict_mapping()
    with mlflow.start_run():
        model.fit(random_train_dict_mapping, random_one_hot_labels)
        logged_model = mlflow.last_logged_model()
        model_uri = logged_model.model_uri
        example = _read_example(Model.load(model_uri), model_uri)
        for feature_name, feature_values in random_train_dict_mapping.items():
            np.testing.assert_array_almost_equal(
                example[feature_name], np.take(feature_values, range(0, 5))
            )
        mlflow.pyfunc.load_model(model_uri).predict(example)
1198  
1199  
def test_keras_autolog_infers_model_signature_correctly_with_dict(
    random_train_dict_mapping, random_one_hot_labels
):
    """Signature inference for dict inputs: one named float64 column per key."""
    mlflow.tensorflow.autolog(log_model_signatures=True)
    model = _create_model_for_dict_mapping()
    with mlflow.start_run():
        model.fit(random_train_dict_mapping, random_one_hot_labels)
        expected_inputs = [
            {"name": key, "type": "tensor", "tensor-spec": {"dtype": "float64", "shape": [-1]}}
            for key in ("a", "b", "c", "d")
        ]
        _assert_autolog_infers_model_signature_correctly(
            expected_inputs,
            [{"type": "tensor", "tensor-spec": {"dtype": "float32", "shape": [-1, 3]}}],
        )
1216  
1217  
def test_keras_autolog_input_example_load_and_predict_with_keras_sequence(keras_data_gen_sequence):
    """Input-example logging works when fit() receives a keras Sequence."""
    mlflow.tensorflow.autolog(log_input_examples=True, log_model_signatures=True)
    model = create_tf_keras_model()
    with mlflow.start_run():
        model.fit(keras_data_gen_sequence)
        # The example should equal the sequence's first 5 feature rows.
        features = keras_data_gen_sequence[:][0]
        _assert_keras_autolog_input_example_load_and_predict_with_nparray(features[:5])
1226  
1227  
def test_keras_autolog_infers_model_signature_correctly_with_keras_sequence(
    keras_data_gen_sequence,
):
    """Signature inference works when fit() receives a keras Sequence."""
    mlflow.tensorflow.autolog(log_model_signatures=True)
    model = create_tf_keras_model()
    with mlflow.start_run():
        model.fit(keras_data_gen_sequence)
        expected_inputs = [
            {"type": "tensor", "tensor-spec": {"dtype": "float64", "shape": [-1, 4]}}
        ]
        expected_outputs = [
            {"type": "tensor", "tensor-spec": {"dtype": "float32", "shape": [-1, 3]}}
        ]
        _assert_autolog_infers_model_signature_correctly(expected_inputs, expected_outputs)
1239  
1240  
def test_keras_autolog_load_saved_hdf5_model(keras_data_gen_sequence):
    """`keras_model_kwargs={"save_format": "h5"}` must produce an H5 model
    artifact under the logged model's data directory."""
    mlflow.tensorflow.autolog(keras_model_kwargs={"save_format": "h5"})
    model = create_tf_keras_model()
    with mlflow.start_run():
        model.fit(keras_data_gen_sequence)
        artifact_dir = local_file_uri_to_path(mlflow.last_logged_model().artifact_location)
        assert Path(artifact_dir, "data", "model.h5").exists()
1249  
1250  
def test_keras_autolog_logs_model_signature_by_default(keras_data_gen_sequence):
    """Plain `mlflow.autolog()` with no options should still log an inferred
    model signature."""
    mlflow.autolog()
    create_tf_keras_model().fit(keras_data_gen_sequence)

    artifact_dir = local_file_uri_to_path(mlflow.last_logged_model().artifact_location)
    with open(os.path.join(artifact_dir, "MLmodel")) as f:
        mlmodel = yaml.safe_load(f)

    assert "signature" in mlmodel
    signature = mlmodel["signature"]
    assert signature is not None
    assert "inputs" in signature
    assert "outputs" in signature
    # Signatures are stored as JSON strings inside the YAML document.
    expected_inputs = [{"type": "tensor", "tensor-spec": {"dtype": "float64", "shape": [-1, 4]}}]
    expected_outputs = [{"type": "tensor", "tensor-spec": {"dtype": "float32", "shape": [-1, 3]}}]
    assert json.loads(signature["inputs"]) == expected_inputs
    assert json.loads(signature["outputs"]) == expected_outputs
1272  
1273  
def test_extract_tf_keras_input_example_unsupported_type_returns_None():
    """A plain Python list is not a supported Keras input type, so the
    input-example extractor must return None rather than raise."""
    from mlflow.tensorflow.autologging import extract_tf_keras_input_example

    assert extract_tf_keras_input_example([1, 2, 4, 5]) is None, (
        "Keras input data extraction function should have "
        "returned None as input type is not supported."
    )
1282  
1283  
def test_extract_input_example_from_tf_input_fn_unsupported_type_returns_None():
    """A bare callable (input_fn style) is unsupported, so the extractor
    must return None rather than raise."""
    from mlflow.tensorflow.autologging import extract_tf_keras_input_example

    assert extract_tf_keras_input_example(lambda: [1, 2, 4, 5]) is None, (
        "Tensorflow's input_fn training data extraction should have"
        " returned None as input type is not supported."
    )
1292  
1293  
@pytest.mark.skipif(
    Version(tf.__version__) < Version("2.6.0"),
    reason=("TensorFlow only has a hard dependency on Keras in version >= 2.6.0"),
)
def test_import_keras_model_trigger_import_tensorflow():
    """Importing a Keras model class must transitively import tensorflow.

    In Keras>=2.6 the keras autologging patches are installed by
    `mlflow.tensorflow.autolog`. If a user enables autologging via
    `mlflow.autolog()` and then imports keras, the patches can only be
    installed if that import also pulls in tensorflow — verify it does in a
    clean subprocess.
    """
    _exec_cmd([
        sys.executable,
        "-c",
        "from keras import Model; import sys; assert 'tensorflow' in sys.modules",
    ])
1310  
1311  
def test_autolog_throw_error_on_explicit_mlflow_callback(keras_data_gen_sequence):
    """Passing `MlflowCallback` explicitly while autologging is enabled must
    raise an MlflowException.

    `pytest.raises(match=...)` performs an `re.search`, so we match the exact
    message prefix. (The previous pattern ended in "off*", a glob-style
    wildcard that a regex reads as quantifying the final "f".)
    """
    mlflow.tensorflow.autolog()

    model = create_tf_keras_model()
    with mlflow.start_run() as run:
        with pytest.raises(MlflowException, match="MLflow autologging must be turned off"):
            model.fit(keras_data_gen_sequence, callbacks=[MlflowCallback(run)])
1319  
1320  
def test_autolog_correct_logging_frequency(random_train_data, random_one_hot_labels):
    """With `log_every_n_steps`, one loss entry is logged per `logging_freq`
    batches rather than once per epoch."""
    logging_freq = 5
    num_epochs = 2
    batch_size = 10
    mlflow.tensorflow.autolog(log_every_epoch=False, log_every_n_steps=logging_freq)
    model = create_tf_keras_model()
    with mlflow.start_run() as run:
        model.fit(
            random_train_data,
            random_one_hot_labels,
            batch_size=batch_size,
            epochs=num_epochs,
        )

    loss_history = MlflowClient().get_metric_history(run.info.run_id, "loss")
    # total batches = epochs * (samples // batch_size); one entry per freq.
    total_batches = num_epochs * (len(random_train_data) // batch_size)
    assert len(loss_history) == total_batches // logging_freq
1338  
1339  
def test_automatic_checkpoint_per_epoch_callback(random_train_data, random_one_hot_labels):
    """`checkpoint_save_freq="epoch"` saves a full model plus a metrics JSON
    under `checkpoints/epoch_0`, loadable via `load_checkpoint` both as the
    latest checkpoint and by explicit epoch."""
    mlflow.tensorflow.autolog(
        checkpoint=True,
        checkpoint_monitor=None,
        checkpoint_mode=None,
        checkpoint_save_best_only=False,
        checkpoint_save_weights_only=False,
        checkpoint_save_freq="epoch",
    )

    model = create_tf_keras_model()

    with mlflow.start_run() as run:
        model.fit(random_train_data, random_one_hot_labels, epochs=1)
    run_id = run.info.run_id

    checkpoint_metrics = mlflow.artifacts.load_dict(
        f"runs:/{run_id}/checkpoints/epoch_0/checkpoint_metrics.json"
    )
    assert set(checkpoint_metrics) == {"epoch", "loss", "accuracy", "global_step"}
    assert checkpoint_metrics["epoch"] == 0
    # One epoch over the 150-sample dataset yields 5 optimizer steps here
    # (presumably Keras' default batch size — TODO confirm).
    assert checkpoint_metrics["global_step"] == 5

    expected_preds = model.predict(random_train_data)
    latest_preds = load_checkpoint(run_id=run_id).predict(random_train_data)
    np.testing.assert_array_almost_equal(expected_preds, latest_preds)

    epoch0_preds = load_checkpoint(run_id=run_id, epoch=0).predict(random_train_data)
    np.testing.assert_array_almost_equal(expected_preds, epoch0_preds)
1369  
1370  
def test_automatic_checkpoint_per_epoch_save_weight_only_callback(
    random_train_data, random_one_hot_labels
):
    """With `checkpoint_save_weights_only=True`, restoring requires passing a
    freshly built model of the same architecture to `load_checkpoint`."""
    mlflow.tensorflow.autolog(
        checkpoint=True,
        checkpoint_monitor=None,
        checkpoint_mode=None,
        checkpoint_save_best_only=False,
        checkpoint_save_weights_only=True,
        checkpoint_save_freq="epoch",
    )

    model = create_tf_keras_model()

    with mlflow.start_run() as run:
        model.fit(random_train_data, random_one_hot_labels, epochs=1)
    run_id = run.info.run_id

    checkpoint_metrics = mlflow.artifacts.load_dict(
        f"runs:/{run_id}/checkpoints/epoch_0/checkpoint_metrics.json"
    )
    assert set(checkpoint_metrics) == {"epoch", "loss", "accuracy", "global_step"}
    assert checkpoint_metrics["epoch"] == 0
    assert checkpoint_metrics["global_step"] == 5

    # Weights-only checkpoints are loaded into a newly constructed model.
    fresh_model = create_tf_keras_model()
    restored = load_checkpoint(model=fresh_model, run_id=run_id)
    np.testing.assert_array_almost_equal(
        model.predict(random_train_data), restored.predict(random_train_data)
    )
1400  
1401  
def test_automatic_checkpoint_per_3_steps_callback(random_train_data, random_one_hot_labels):
    """An integer `checkpoint_save_freq` saves checkpoints keyed by global
    step (`checkpoints/global_step_3`) instead of by epoch."""
    mlflow.tensorflow.autolog(
        checkpoint=True,
        checkpoint_monitor=None,
        checkpoint_mode=None,
        checkpoint_save_best_only=False,
        checkpoint_save_weights_only=False,
        checkpoint_save_freq=3,
    )
    model = create_tf_keras_model()

    with mlflow.start_run() as run:
        model.fit(random_train_data, random_one_hot_labels, epochs=1)
    run_id = run.info.run_id
    checkpoint_metrics = mlflow.artifacts.load_dict(
        f"runs:/{run_id}/checkpoints/global_step_3/checkpoint_metrics.json"
    )
    assert set(checkpoint_metrics) == {"epoch", "loss", "accuracy", "global_step"}
    assert checkpoint_metrics["epoch"] == 0
    assert checkpoint_metrics["global_step"] == 3

    # Both "latest" and explicit-step loading must return a usable model.
    restored_latest = load_checkpoint(run_id=run_id)
    restored_step3 = load_checkpoint(run_id=run_id, global_step=3)
    assert isinstance(restored_latest, tf.keras.Sequential)
    assert isinstance(restored_step3, tf.keras.Sequential)
1425  
1426  
1427  def test_automatic_checkpoint_per_3_steps_save_best_only_callback(
1428      random_train_data, random_one_hot_labels
1429  ):
1430      mlflow.tensorflow.autolog(
1431          checkpoint=True,
1432          checkpoint_monitor="loss",
1433          checkpoint_mode="min",
1434          checkpoint_save_best_only=True,
1435          checkpoint_save_weights_only=False,
1436          checkpoint_save_freq=3,
1437      )
1438  
1439      model = create_tf_keras_model()
1440  
1441      with mlflow.start_run() as run:
1442          model.fit(
1443              random_train_data,
1444              random_one_hot_labels,
1445              epochs=1,
1446          )
1447      run_id = run.info.run_id
1448      logged_metrics = mlflow.artifacts.load_dict(
1449          f"runs:/{run_id}/checkpoints/latest_checkpoint_metrics.json"
1450      )
1451      assert set(logged_metrics) == {"epoch", "loss", "accuracy", "global_step"}
1452      assert logged_metrics["epoch"] == 0
1453      assert logged_metrics["global_step"] == 3
1454  
1455      assert isinstance(load_checkpoint(run_id=run_id), tf.keras.Sequential)