tests/tensorflow/test_tensorflow2_metric_value_conversion_utils.py
test_tensorflow2_metric_value_conversion_utils.py
 1  import pytest
 2  import tensorflow as tf
 3  
 4  import mlflow
 5  from mlflow import tracking
 6  from mlflow.exceptions import INVALID_PARAMETER_VALUE, ErrorCode, MlflowException
 7  from mlflow.tracking.fluent import start_run
 8  from mlflow.tracking.metric_value_conversion_utils import convert_metric_value_to_float_if_possible
 9  
10  
def test_reraised_value_errors():
    """A multi-element TF tensor cannot be turned into a single float.

    The conversion helper must surface that failure as an ``MlflowException``
    whose message mentions the failed float conversion and whose error code is
    ``INVALID_PARAMETER_VALUE``.
    """
    non_scalar_tensor = tf.random.uniform([2, 2], dtype=tf.float32)
    expected_message = r"Failed to convert metric value to float"

    with pytest.raises(MlflowException, match=expected_message) as exc_info:
        convert_metric_value_to_float_if_possible(non_scalar_tensor)

    # Checked after the context exits, once the exception has been captured.
    raised = exc_info.value
    assert raised.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
18  
19  
def test_convert_metric_value_to_float():
    """A 0-d (scalar) TF tensor converts to the same Python float as its NumPy value."""
    scalar_tensor = tf.random.uniform([], dtype=tf.float32)
    converted = convert_metric_value_to_float_if_possible(scalar_tensor)
    expected = float(scalar_tensor.numpy())
    assert converted == expected
23  
24  
def test_log_tf_tensor_as_metric():
    """Logging a scalar TF tensor via ``mlflow.log_metric`` records its float value.

    After the run ends, the run's metrics must consist of exactly one entry:
    the tensor's value converted to a Python float.
    """
    scalar_tensor = tf.random.uniform([], dtype=tf.float32)
    expected_value = float(scalar_tensor.numpy())

    with start_run() as active_run:
        mlflow.log_metric("name_tf", scalar_tensor)

    # Re-fetch the run from the tracking backend rather than trusting local state.
    client = tracking.MlflowClient()
    stored_metrics = client.get_run(active_run.info.run_id).data.metrics
    assert stored_metrics == {"name_tf": expected_value}