# test_enablement.py
"""
Tests for mlflow.tracing.enablement module
"""

from unittest import mock

import pytest

import mlflow
from mlflow.entities.trace_location import UCSchemaLocation
from mlflow.exceptions import MlflowException
from mlflow.tracing.enablement import (
    set_experiment_trace_location,
    unset_experiment_trace_location,
)

from tests.tracing.helper import skip_when_testing_trace_sdk


@pytest.fixture
def mock_databricks_tracking_uri():
    """Patch ``mlflow.tracking.get_tracking_uri`` to return ``"databricks"`` for one test.

    Both enablement APIs are Databricks-only (see
    ``test_non_databricks_tracking_uri_errors``), so most tests need this patch
    to get past that check.
    """
    with mock.patch("mlflow.tracking.get_tracking_uri", return_value="databricks"):
        yield


@skip_when_testing_trace_sdk
def test_set_experiment_trace_location(mock_databricks_tracking_uri):
    """Explicit arguments are forwarded to ``TracingClient`` and its result is returned."""
    # Creates a real experiment through the tracking store, hence the trace-SDK skip.
    experiment_id = mlflow.create_experiment("test_experiment")
    location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
    sql_warehouse_id = "test-warehouse-id"

    with mock.patch("mlflow.tracing.client.TracingClient") as mock_client_class:
        mock_client = mock.MagicMock()
        mock_client_class.return_value = mock_client

        # The client's return value (with table names filled in) should be passed
        # straight back to the caller.
        expected_location = UCSchemaLocation(
            catalog_name="test_catalog",
            schema_name="test_schema",
        )
        expected_location._otel_logs_table_name = "logs_table"
        expected_location._otel_spans_table_name = "spans_table"
        mock_client._set_experiment_trace_location.return_value = expected_location

        result = set_experiment_trace_location(
            location=location,
            experiment_id=experiment_id,
            sql_warehouse_id=sql_warehouse_id,
        )

        mock_client._set_experiment_trace_location.assert_called_once_with(
            location=location,
            experiment_id=experiment_id,
            sql_warehouse_id=sql_warehouse_id,
        )
        assert result == expected_location


# `mlflow.set_experiment` goes through the experiment store (it resolves/creates a
# real experiment), exactly like `mlflow.create_experiment` in the test above, so
# this test must be skipped in trace-SDK-only runs as well.
@skip_when_testing_trace_sdk
def test_set_experiment_trace_location_with_default_experiment(mock_databricks_tracking_uri):
    """When ``experiment_id`` is omitted, the active experiment's ID is used."""
    location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
    default_experiment_id = mlflow.set_experiment("test_experiment").experiment_id

    with (
        mock.patch("mlflow.tracing.client.TracingClient") as mock_client_class,
        mock.patch("mlflow.tracking.fluent._get_experiment_id", return_value=default_experiment_id),
    ):
        mock_client = mock.MagicMock()
        mock_client_class.return_value = mock_client

        expected_location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
        mock_client._set_experiment_trace_location.return_value = expected_location

        result = set_experiment_trace_location(location=location)
        # No warehouse was given, so it must be forwarded as None.
        mock_client._set_experiment_trace_location.assert_called_once_with(
            location=location,
            experiment_id=default_experiment_id,
            sql_warehouse_id=None,
        )

    assert result == expected_location


def test_set_experiment_trace_location_no_experiment(mock_databricks_tracking_uri):
    """With neither an explicit nor an active experiment, the call is rejected."""
    location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
    with mock.patch("mlflow.tracking.fluent._get_experiment_id", return_value=None):
        with pytest.raises(MlflowException, match="Experiment ID is required"):
            set_experiment_trace_location(location=location)


@skip_when_testing_trace_sdk
def test_set_experiment_trace_location_non_existent_experiment(mock_databricks_tracking_uri):
    """An experiment ID that cannot be found is rejected."""
    location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")

    # Presumably no experiment with this ID exists in the per-test store.
    experiment_id = "12345"
    with pytest.raises(MlflowException, match="Could not find experiment with ID"):
        set_experiment_trace_location(location=location, experiment_id=experiment_id)


def test_unset_experiment_trace_location(mock_databricks_tracking_uri):
    """Explicit arguments are forwarded positionally as ``(experiment_id, location)``."""
    experiment_id = "123"
    location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")

    with mock.patch("mlflow.tracing.client.TracingClient") as mock_client_class:
        mock_client = mock.MagicMock()
        mock_client_class.return_value = mock_client
        unset_experiment_trace_location(
            location=location,
            experiment_id=experiment_id,
        )

        mock_client._unset_experiment_trace_location.assert_called_once_with(
            experiment_id,
            location,
        )


def test_unset_experiment_trace_location_errors(mock_databricks_tracking_uri):
    """Input validation: a non-``UCSchemaLocation`` location and a missing experiment."""
    # A plain "catalog.schema" string is not accepted in place of UCSchemaLocation.
    with pytest.raises(MlflowException, match="must be an instance of"):
        unset_experiment_trace_location(location="test_catalog.test_schema")

    with mock.patch("mlflow.tracking.fluent._get_experiment_id", return_value=None):
        with pytest.raises(MlflowException, match="Experiment ID is required"):
            unset_experiment_trace_location(
                location=UCSchemaLocation("test_catalog", "test_schema")
            )


def test_unset_experiment_trace_location_with_default_experiment(mock_databricks_tracking_uri):
    """When ``experiment_id`` is omitted, the active experiment's ID is forwarded."""
    default_experiment_id = "456"

    with (
        mock.patch("mlflow.tracing.client.TracingClient") as mock_client_class,
        mock.patch("mlflow.tracking.fluent._get_experiment_id", return_value=default_experiment_id),
    ):
        mock_client = mock.MagicMock()
        mock_client_class.return_value = mock_client

        location = UCSchemaLocation(catalog_name="test_catalog", schema_name="test_schema")
        # `location` passed positionally here, unlike the keyword form used above.
        unset_experiment_trace_location(location)

        mock_client._unset_experiment_trace_location.assert_called_once_with(
            default_experiment_id,
            location,
        )


def test_non_databricks_tracking_uri_errors():
    """Both APIs raise when the tracking URI is not Databricks.

    Deliberately does not use ``mock_databricks_tracking_uri``; this assumes the
    test environment's default tracking URI is not Databricks.
    """
    with pytest.raises(
        MlflowException,
        match="The `set_experiment_trace_location` API is only supported on Databricks.",
    ):
        set_experiment_trace_location(location=UCSchemaLocation("test_catalog", "test_schema"))

    with pytest.raises(
        MlflowException,
        match="The `unset_experiment_trace_location` API is only supported on Databricks.",
    ):
        unset_experiment_trace_location(location=UCSchemaLocation("test_catalog", "test_schema"))