# test_tensorflow2_metric_value_conversion_utils.py
import pytest
import tensorflow as tf

import mlflow
from mlflow import tracking
from mlflow.exceptions import INVALID_PARAMETER_VALUE, ErrorCode, MlflowException
from mlflow.tracking.fluent import start_run
from mlflow.tracking.metric_value_conversion_utils import convert_metric_value_to_float_if_possible


def test_reraised_value_errors():
    """A 2x2 tensor cannot become one float; expect INVALID_PARAMETER_VALUE.

    The conversion helper must raise ``MlflowException`` whose message mentions
    the failed float conversion, tagged with the INVALID_PARAMETER_VALUE code.
    """
    matrix_tensor = tf.random.uniform([2, 2], dtype=tf.float32)

    with pytest.raises(
        MlflowException, match=r"Failed to convert metric value to float"
    ) as exc_info:
        convert_metric_value_to_float_if_possible(matrix_tensor)

    # The error code is compared by its symbolic name, as stored on the exception.
    expected_code = ErrorCode.Name(INVALID_PARAMETER_VALUE)
    assert exc_info.value.error_code == expected_code


def test_convert_metric_value_to_float():
    """A scalar tensor converts to the same Python float that ``.numpy()`` yields."""
    scalar_tensor = tf.random.uniform([], dtype=tf.float32)

    converted = convert_metric_value_to_float_if_possible(scalar_tensor)
    expected = float(scalar_tensor.numpy())

    assert converted == expected


def test_log_tf_tensor_as_metric():
    """Logging a scalar tensor via ``mlflow.log_metric`` stores its float value.

    After the run ends, the run fetched back through ``MlflowClient`` must carry
    exactly one metric, ``"name_tf"``, equal to the tensor's float value.
    """
    scalar_tensor = tf.random.uniform([], dtype=tf.float32)
    expected_value = float(scalar_tensor.numpy())

    with start_run() as active_run:
        mlflow.log_metric("name_tf", scalar_tensor)

    run_id = active_run.info.run_id
    stored_run = tracking.MlflowClient().get_run(run_id)
    assert stored_run.data.metrics == {"name_tf": expected_value}