/ tests / keras / test_callback.py
test_callback.py
  1  import math
  2  
  3  import keras
  4  import numpy as np
  5  
  6  import mlflow
  7  from mlflow.keras.callback import MlflowCallback
  8  from mlflow.tracking.fluent import flush_async_logging
  9  
 10  
 11  def test_keras_mlflow_callback_log_every_epoch():
 12      # Prepare data for a 2-class classification.
 13      data = np.random.uniform(size=(20, 28, 28, 3))
 14      label = np.random.randint(2, size=20)
 15  
 16      model = keras.Sequential([
 17          keras.Input([28, 28, 3]),
 18          keras.layers.Flatten(),
 19          keras.layers.Dense(2),
 20      ])
 21  
 22      model.compile(
 23          loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
 24          optimizer=keras.optimizers.Adam(0.001),
 25          metrics=[keras.metrics.SparseCategoricalAccuracy()],
 26      )
 27  
 28      num_epochs = 2
 29      with mlflow.start_run() as run:
 30          mlflow_callback = MlflowCallback(log_every_epoch=True)
 31          model.fit(
 32              data,
 33              label,
 34              validation_data=(data, label),
 35              batch_size=4,
 36              epochs=num_epochs,
 37              callbacks=[mlflow_callback],
 38          )
 39      flush_async_logging()
 40      client = mlflow.MlflowClient()
 41      mlflow_run = client.get_run(run.info.run_id)
 42      run_metrics = mlflow_run.data.metrics
 43      model_info = mlflow_run.data.params
 44  
 45      assert "sparse_categorical_accuracy" in run_metrics
 46      assert model_info["optimizer_name"] == "adam"
 47      assert math.isclose(float(model_info["optimizer_learning_rate"]), 0.001, rel_tol=1e-6)
 48      assert "loss" in run_metrics
 49      assert "validation_loss" in run_metrics
 50  
 51      loss_history = client.get_metric_history(run_id=run.info.run_id, key="loss")
 52      assert len(loss_history) == num_epochs
 53  
 54      validation_loss_history = client.get_metric_history(
 55          run_id=run.info.run_id,
 56          key="validation_loss",
 57      )
 58      assert len(validation_loss_history) == num_epochs
 59  
 60  
 61  def test_keras_mlflow_callback_log_every_n_steps():
 62      # Prepare data for a 2-class classification.
 63      data = np.random.uniform(size=(20, 28, 28, 3))
 64      label = np.random.randint(2, size=20)
 65  
 66      model = keras.Sequential([
 67          keras.Input([28, 28, 3]),
 68          keras.layers.Flatten(),
 69          keras.layers.Dense(2),
 70      ])
 71  
 72      model.compile(
 73          loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
 74          optimizer=keras.optimizers.Adam(0.001),
 75          metrics=[keras.metrics.SparseCategoricalAccuracy()],
 76      )
 77  
 78      log_every_n_steps = 1
 79      num_epochs = 2
 80      with mlflow.start_run() as run:
 81          mlflow_callback = MlflowCallback(log_every_epoch=False, log_every_n_steps=log_every_n_steps)
 82          model.fit(
 83              data,
 84              label,
 85              validation_data=(data, label),
 86              batch_size=4,
 87              epochs=num_epochs,
 88              callbacks=[mlflow_callback],
 89          )
 90      flush_async_logging()
 91      client = mlflow.MlflowClient()
 92      mlflow_run = client.get_run(run.info.run_id)
 93      run_metrics = mlflow_run.data.metrics
 94      model_info = mlflow_run.data.params
 95  
 96      assert "sparse_categorical_accuracy" in run_metrics
 97      assert model_info["optimizer_name"] == "adam"
 98      assert math.isclose(float(model_info["optimizer_learning_rate"]), 0.001, rel_tol=1e-6)
 99      assert "loss" in run_metrics
100      assert "validation_loss" in run_metrics
101  
102      loss_history = client.get_metric_history(run_id=run.info.run_id, key="loss")
103      assert len(loss_history) == model.optimizer.iterations.numpy() // log_every_n_steps
104  
105      validation_loss_history = client.get_metric_history(
106          run_id=run.info.run_id,
107          key="validation_loss",
108      )
109      assert len(validation_loss_history) == num_epochs
110  
111  
def test_old_callback_still_exists():
    """The legacy ``MLflowCallback`` spelling must remain an alias of ``MlflowCallback``."""
    legacy_callback = mlflow.keras.MLflowCallback
    assert legacy_callback is mlflow.keras.MlflowCallback