# Source: tests/tracing/test_enablement.py
  1  """
  2  Tests for mlflow.tracing.enablement module
  3  """
  4  
  5  from unittest import mock
  6  
  7  import pytest
  8  
  9  import mlflow
 10  from mlflow.entities.trace_location import UCSchemaLocation
 11  from mlflow.exceptions import MlflowException
 12  from mlflow.tracing.enablement import (
 13      set_experiment_trace_location,
 14      unset_experiment_trace_location,
 15  )
 16  
 17  from tests.tracing.helper import skip_when_testing_trace_sdk
 18  
 19  
 20  @pytest.fixture
 21  def mock_databricks_tracking_uri():
 22      with mock.patch("mlflow.tracking.get_tracking_uri", return_value="databricks"):
 23          yield
 24  
 25  
 26  @skip_when_testing_trace_sdk
 27  def test_set_experiment_trace_location(mock_databricks_tracking_uri):
 28      experiment_id = mlflow.create_experiment("test_experiment")
 29      location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
 30      sql_warehouse_id = "test-warehouse-id"
 31  
 32      with mock.patch("mlflow.tracing.client.TracingClient") as mock_client_class:
 33          mock_client = mock.MagicMock()
 34          mock_client_class.return_value = mock_client
 35  
 36          expected_location = UCSchemaLocation(
 37              catalog_name="test_catalog",
 38              schema_name="test_schema",
 39          )
 40          expected_location._otel_logs_table_name = "logs_table"
 41          expected_location._otel_spans_table_name = "spans_table"
 42          mock_client._set_experiment_trace_location.return_value = expected_location
 43  
 44          result = set_experiment_trace_location(
 45              location=location,
 46              experiment_id=experiment_id,
 47              sql_warehouse_id=sql_warehouse_id,
 48          )
 49  
 50          mock_client._set_experiment_trace_location.assert_called_once_with(
 51              location=location,
 52              experiment_id=experiment_id,
 53              sql_warehouse_id=sql_warehouse_id,
 54          )
 55          assert result == expected_location
 56  
 57  
 58  def test_set_experiment_trace_location_with_default_experiment(mock_databricks_tracking_uri):
 59      location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
 60      default_experiment_id = mlflow.set_experiment("test_experiment").experiment_id
 61  
 62      with (
 63          mock.patch("mlflow.tracing.client.TracingClient") as mock_client_class,
 64          mock.patch("mlflow.tracking.fluent._get_experiment_id", return_value=default_experiment_id),
 65      ):
 66          mock_client = mock.MagicMock()
 67          mock_client_class.return_value = mock_client
 68  
 69          expected_location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
 70          mock_client._set_experiment_trace_location.return_value = expected_location
 71  
 72          result = set_experiment_trace_location(location=location)
 73          mock_client._set_experiment_trace_location.assert_called_once_with(
 74              location=location,
 75              experiment_id=default_experiment_id,
 76              sql_warehouse_id=None,
 77          )
 78  
 79          assert result == expected_location
 80  
 81  
 82  def test_set_experiment_trace_location_no_experiment(mock_databricks_tracking_uri):
 83      location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
 84      with mock.patch("mlflow.tracking.fluent._get_experiment_id", return_value=None):
 85          with pytest.raises(MlflowException, match="Experiment ID is required"):
 86              set_experiment_trace_location(location=location)
 87  
 88  
 89  @skip_when_testing_trace_sdk
 90  def test_set_experiment_trace_location_non_existent_experiment(mock_databricks_tracking_uri):
 91      location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
 92  
 93      experiment_id = "12345"
 94      with pytest.raises(MlflowException, match="Could not find experiment with ID"):
 95          set_experiment_trace_location(location=location, experiment_id=experiment_id)
 96  
 97  
 98  def test_unset_experiment_trace_location(mock_databricks_tracking_uri):
 99      experiment_id = "123"
100      location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
101  
102      with mock.patch("mlflow.tracing.client.TracingClient") as mock_client_class:
103          mock_client = mock.MagicMock()
104          mock_client_class.return_value = mock_client
105          unset_experiment_trace_location(
106              location=location,
107              experiment_id=experiment_id,
108          )
109          mock_client._unset_experiment_trace_location.assert_called_once_with(
110              experiment_id,
111              location,
112          )
113  
114  
def test_unset_experiment_trace_location_errors(mock_databricks_tracking_uri):
    """A non-UCSchemaLocation location and a missing experiment ID are both rejected."""
    # A plain "catalog.schema" string is not accepted in place of a UCSchemaLocation.
    with pytest.raises(MlflowException, match="must be an instance of"):
        unset_experiment_trace_location(location="test_catalog.test_schema")

    # No explicit experiment_id, and the default lookup yields nothing.
    with (
        mock.patch("mlflow.tracking.fluent._get_experiment_id", return_value=None),
        pytest.raises(MlflowException, match="Experiment ID is required"),
    ):
        unset_experiment_trace_location(location=UCSchemaLocation("test_catalog", "test_schema"))
124  
125  
def test_unset_experiment_trace_location_with_default_experiment(mock_databricks_tracking_uri):
    """Omitting ``experiment_id`` makes unset use the ID from ``_get_experiment_id``."""
    fallback_experiment_id = "456"

    with (
        mock.patch("mlflow.tracing.client.TracingClient") as client_cls,
        mock.patch(
            "mlflow.tracking.fluent._get_experiment_id", return_value=fallback_experiment_id
        ),
    ):
        location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
        # Location passed positionally on purpose: exercises the positional call form.
        unset_experiment_trace_location(location)

    client_cls.return_value._unset_experiment_trace_location.assert_called_once_with(
        fallback_experiment_id,
        location,
    )
143  
144  
def test_non_databricks_tracking_uri_errors():
    """Both trace-location APIs refuse to run on a non-Databricks tracking URI.

    No tracking-URI fixture is requested here. The test relies on the session's default
    tracking URI not being ``"databricks"``.
    """
    cases = [
        (
            set_experiment_trace_location,
            "The `set_experiment_trace_location` API is only supported on Databricks.",
        ),
        (
            unset_experiment_trace_location,
            "The `unset_experiment_trace_location` API is only supported on Databricks.",
        ),
    ]
    for api, expected_message in cases:
        with pytest.raises(MlflowException, match=expected_message):
            api(location=UCSchemaLocation("test_catalog", "test_schema"))