# test_set_experiment_trace_location.py — tests for trace_location handling in mlflow.set_experiment
from unittest import mock

import pytest

import mlflow
from mlflow.entities import Experiment
from mlflow.entities.experiment_tag import ExperimentTag
from mlflow.entities.trace_location import UnityCatalog
from mlflow.exceptions import MlflowException
from mlflow.tracking.fluent import (
    _resolve_experiment_to_trace_location,
    _sync_trace_destination_and_provider,
)
from mlflow.utils.mlflow_tags import (
    MLFLOW_EXPERIMENT_DATABRICKS_TRACE_ANNOTATIONS_TABLE,
    MLFLOW_EXPERIMENT_DATABRICKS_TRACE_DESTINATION_PATH,
    MLFLOW_EXPERIMENT_DATABRICKS_TRACE_LOG_STORAGE_TABLE,
    MLFLOW_EXPERIMENT_DATABRICKS_TRACE_SPAN_STORAGE_TABLE,
)


def _experiment(tags=None):
    """Build an in-memory active ``Experiment`` with id ``"123"`` for tests.

    Args:
        tags: Optional mapping of tag key -> tag value. Each pair is wrapped in
            an ``ExperimentTag`` and attached to the experiment.

    Returns:
        An ``Experiment`` named ``"test-experiment"`` with artifact location
        ``"file:/tmp"``.
    """
    tag_entities = [ExperimentTag(k, v) for k, v in (tags or {}).items()]
    return Experiment(
        experiment_id="123",
        name="test-experiment",
        artifact_location="file:/tmp",
        lifecycle_stage="active",
        tags=tag_entities,
    )


def test_invalid_type_raises():
    """A plain string is rejected with an error that names ``UnityCatalog``."""
    with pytest.raises(MlflowException, match="UnityCatalog"):
        _resolve_experiment_to_trace_location(
            experiment=_experiment(),
            trace_location="not-a-location",
        )


def test_uc_schema_location_is_rejected():
    """A ``UCSchemaLocation`` is rejected with an error that names ``UnityCatalog``."""
    from mlflow.entities.trace_location import UCSchemaLocation

    with pytest.raises(MlflowException, match="UnityCatalog"):
        _resolve_experiment_to_trace_location(
            experiment=_experiment(),
            trace_location=UCSchemaLocation("catalog", "schema"),
        )


def test_no_trace_location_returns_none():
    """Without a requested trace location, resolution returns ``None``."""
    result = _resolve_experiment_to_trace_location(
        experiment=_experiment(),
        trace_location=None,
    )
    assert result is None


def test_non_databricks_backend_raises():
    """A UnityCatalog location on a non-Databricks tracking URI is an error."""
    with (
        mock.patch("mlflow.tracking.fluent._resolve_tracking_uri", return_value="file:///tmp"),
        mock.patch("mlflow.tracking.fluent.is_databricks_uri", return_value=False),
    ):
        with pytest.raises(MlflowException, match="only supported with a Databricks tracking URI"):
            _resolve_experiment_to_trace_location(
                experiment=_experiment(),
                trace_location=UnityCatalog("catalog", "schema", "prefix"),
            )


def test_set_experiment_with_table_prefix_env_var_points_to_trace_location_param(monkeypatch):
    """A table-prefix value in MLFLOW_TRACING_DESTINATION is rejected at tracer init.

    The error message must point the user at ``mlflow.set_experiment`` instead.
    """
    from mlflow.tracing.provider import _get_tracer

    monkeypatch.setenv("MLFLOW_TRACING_DESTINATION", "catalog.schema.prefix")

    mlflow.tracing.reset()
    try:
        mlflow.set_experiment("test-experiment")

        # The error surfaces lazily at trace creation time (provider init),
        # not eagerly at set_experiment time.
        with pytest.raises(
            MlflowException,
            match=r"Unity Catalog table-prefix destinations "
            r"\(<catalog_name>\.<schema_name>\.<table_prefix>\) are not supported in "
            r"MLFLOW_TRACING_DESTINATION.*Use `mlflow\.set_experiment",
        ):
            _get_tracer("test")
    finally:
        # Restore global tracing state even when an assertion above fails, so
        # provider state does not leak into subsequent tests.
        mlflow.tracing.reset()


def test_set_experiment_defaults_empty_prefix_to_experiment_id():
    """A UnityCatalog location without a table prefix gets the experiment id as prefix.

    The caller's ``UnityCatalog`` object must not be mutated in the process.
    """
    resolved = UnityCatalog("catalog", "schema", table_prefix="123")

    with (
        mock.patch("mlflow.tracking.fluent.TrackingServiceClient") as mock_client_cls,
        mock.patch(
            "mlflow.tracking.fluent._resolve_experiment_to_trace_location",
            return_value=resolved,
        ) as mock_resolve,
        mock.patch("mlflow.tracking.fluent._sync_trace_destination_and_provider"),
    ):
        client = mock_client_cls.return_value
        client.get_experiment_by_name.return_value = _experiment()  # experiment_id="123"

        original = UnityCatalog("catalog", "schema")  # no prefix
        mlflow.set_experiment("test-experiment", trace_location=original)

        # Verify _resolve was called with a location that has the experiment ID as prefix
        _, kwargs = mock_resolve.call_args
        passed_location = kwargs["trace_location"]
        assert passed_location.table_prefix == "123"
        assert passed_location.catalog_name == "catalog"
        assert passed_location.schema_name == "schema"

        # Original object should not be mutated
        assert original.table_prefix is None
def test_creates_and_links_when_no_existing_location(monkeypatch):
    """An untagged experiment gets a location created via TracingClient and linked to it.

    The SQL warehouse id is read from MLFLOW_TRACING_SQL_WAREHOUSE_ID.
    """
    monkeypatch.setenv("MLFLOW_TRACING_SQL_WAREHOUSE_ID", "warehouse-1")
    requested = UnityCatalog("catalog", "schema", table_prefix="prefix")
    resolved = UnityCatalog("catalog", "schema", table_prefix="prefix")

    with (
        mock.patch("mlflow.tracking.fluent._resolve_tracking_uri", return_value="databricks"),
        mock.patch("mlflow.tracking.fluent.is_databricks_uri", return_value=True),
        mock.patch("mlflow.tracing.client.TracingClient") as tc_cls,
    ):
        tc = tc_cls.return_value
        tc._create_or_get_trace_location.return_value = resolved

        result = _resolve_experiment_to_trace_location(
            experiment=_experiment(),
            trace_location=requested,
        )

    assert result is resolved
    tc._create_or_get_trace_location.assert_called_once_with(requested, "warehouse-1")
    tc._link_trace_location.assert_called_once_with(
        experiment_id="123",
        location=resolved,
    )


def test_noop_when_existing_location_matches():
    """An experiment already tagged with the requested location resolves from its tags.

    The returned location carries the span, log and annotation table names
    taken from the experiment's tags.
    """
    requested = UnityCatalog("catalog", "schema", table_prefix="prefix")
    experiment = _experiment(
        tags={
            MLFLOW_EXPERIMENT_DATABRICKS_TRACE_DESTINATION_PATH: "catalog.schema.prefix",
            MLFLOW_EXPERIMENT_DATABRICKS_TRACE_SPAN_STORAGE_TABLE: (
                "catalog.schema.prefix_otel_spans"
            ),
            MLFLOW_EXPERIMENT_DATABRICKS_TRACE_LOG_STORAGE_TABLE: (
                "catalog.schema.prefix_otel_logs"
            ),
            MLFLOW_EXPERIMENT_DATABRICKS_TRACE_ANNOTATIONS_TABLE: (
                "catalog.schema.prefix_annotations"
            ),
        }
    )

    with (
        mock.patch("mlflow.tracking.fluent._resolve_tracking_uri", return_value="databricks"),
        mock.patch("mlflow.tracking.fluent.is_databricks_uri", return_value=True),
    ):
        result = _resolve_experiment_to_trace_location(
            experiment=experiment,
            trace_location=requested,
        )

    assert result == requested
    assert result._otel_spans_table_name == "catalog.schema.prefix_otel_spans"
    assert result._otel_logs_table_name == "catalog.schema.prefix_otel_logs"
    assert result._annotations_table_name == "catalog.schema.prefix_annotations"


def test_errors_when_existing_location_differs():
    """Requesting a different table prefix than the linked one raises."""
    requested = UnityCatalog("catalog", "schema", table_prefix="new_prefix")
    experiment = _experiment(
        tags={MLFLOW_EXPERIMENT_DATABRICKS_TRACE_DESTINATION_PATH: "catalog.schema.old_prefix"}
    )

    with (
        mock.patch("mlflow.tracking.fluent._resolve_tracking_uri", return_value="databricks"),
        mock.patch("mlflow.tracking.fluent.is_databricks_uri", return_value=True),
    ):
        with pytest.raises(MlflowException, match="already linked to a different"):
            _resolve_experiment_to_trace_location(
                experiment=experiment,
                trace_location=requested,
            )


def test_existing_uc_schema_destination_rejects_table_prefix():
    """An experiment linked to a bare ``catalog.schema`` path rejects a table-prefix location."""
    requested = UnityCatalog("catalog", "schema", table_prefix="pfx")
    experiment = _experiment(
        tags={MLFLOW_EXPERIMENT_DATABRICKS_TRACE_DESTINATION_PATH: "catalog.schema"}
    )

    with (
        mock.patch("mlflow.tracking.fluent._resolve_tracking_uri", return_value="databricks"),
        mock.patch("mlflow.tracking.fluent.is_databricks_uri", return_value=True),
    ):
        with pytest.raises(MlflowException, match="already linked to a different"):
            _resolve_experiment_to_trace_location(
                experiment=experiment,
                trace_location=requested,
            )


def test_link_failure_on_new_experiment_includes_retry_guidance():
    """If linking fails for a newly created experiment, the error explains how to retry."""
    with (
        mock.patch("mlflow.tracking.fluent.TrackingServiceClient") as mock_client_cls,
        mock.patch(
            "mlflow.tracking.fluent._resolve_experiment_to_trace_location",
            side_effect=MlflowException("backend error"),
        ) as mock_resolve,
    ):
        client = mock_client_cls.return_value
        # Simulate: experiment_name lookup returns None (not found) -> create -> get
        client.get_experiment_by_name.return_value = None
        # NOTE(review): create_experiment returns "456" while the fetched
        # experiment has id "123"; presumably harmless because resolution is
        # mocked to fail — confirm.
        client.create_experiment.return_value = "456"
        new_exp = _experiment()
        client.get_experiment.return_value = new_exp

        with pytest.raises(
            MlflowException, match="fix the issue and call set_experiment again"
        ) as exc_info:
            mlflow.set_experiment(
                "new-exp",
                trace_location=UnityCatalog("cat", "sch", "pfx"),
            )

        # The original backend error text must be preserved in the wrapped message.
        assert "backend error" in exc_info.value.message
        mock_resolve.assert_called_once()


def test_link_failure_on_existing_experiment_reraises_original():
    """If linking fails for an existing experiment, the original error propagates."""
    with (
        mock.patch("mlflow.tracking.fluent.TrackingServiceClient") as mock_client_cls,
        mock.patch(
            "mlflow.tracking.fluent._resolve_experiment_to_trace_location",
            side_effect=MlflowException("backend error"),
        ) as mock_resolve,
    ):
        client = mock_client_cls.return_value
        # Simulate: experiment already exists
        client.get_experiment_by_name.return_value = _experiment()

        with pytest.raises(MlflowException, match="backend error"):
            mlflow.set_experiment(
                "test-experiment",
                trace_location=UnityCatalog("cat", "sch", "pfx"),
            )

        mock_resolve.assert_called_once()


def test_set_experiment_wires_trace_location_to_returned_experiment():
    """set_experiment forwards the resolved location to the sync step and the result."""
    resolved = UnityCatalog("catalog", "schema", table_prefix="pfx")

    with (
        mock.patch(
            "mlflow.tracking.fluent._resolve_experiment_to_trace_location",
            return_value=resolved,
        ) as mock_register,
        mock.patch(
            "mlflow.tracking.fluent._sync_trace_destination_and_provider",
        ) as mock_sync,
    ):
        experiment = mlflow.set_experiment("test-trace-loc-integration")

    mock_register.assert_called_once()
    _, kwargs = mock_register.call_args
    assert kwargs["experiment"].name == "test-trace-loc-integration"
    mock_sync.assert_called_once_with(resolved)
    assert experiment.trace_location is resolved


def test_set_experiment_with_trace_location_installs_uc_processor():
    """A resolved UC trace location makes the tracer use the UC table span processor."""
    from mlflow.tracing.export.uc_table import DatabricksUCTableSpanExporter
    from mlflow.tracing.processor.uc_table import DatabricksUCTableSpanProcessor
    from mlflow.tracing.provider import _MLFLOW_TRACE_USER_DESTINATION, _get_tracer

    resolved = UnityCatalog("catalog", "schema", table_prefix="pfx")
    mlflow.tracing.reset()
    _MLFLOW_TRACE_USER_DESTINATION.reset()

    try:
        with mock.patch(
            "mlflow.tracking.fluent._resolve_experiment_to_trace_location",
            return_value=resolved,
        ) as mock_register:
            experiment = mlflow.set_experiment("test-uc-processor")

        mock_register.assert_called_once()
        assert experiment.trace_location is resolved

        tracer = _get_tracer("test")
        processors = tracer.span_processor._span_processors
        assert len(processors) == 1
        assert isinstance(processors[0], DatabricksUCTableSpanProcessor)
        assert isinstance(processors[0].span_exporter, DatabricksUCTableSpanExporter)
    finally:
        # Always clear the UC destination and provider, even on assertion
        # failure, so the UC processor does not leak into other tests.
        _MLFLOW_TRACE_USER_DESTINATION.reset()
        mlflow.tracing.reset()


def test_set_experiment_without_trace_location_does_not_install_uc_processor():
    """Without a trace location, no UC table span processor is installed."""
    from mlflow.tracing.processor.uc_table import DatabricksUCTableSpanProcessor
    from mlflow.tracing.provider import _MLFLOW_TRACE_USER_DESTINATION, _get_tracer

    mlflow.tracing.reset()
    _MLFLOW_TRACE_USER_DESTINATION.reset()

    try:
        mlflow.set_experiment("test-no-uc-processor")

        tracer = _get_tracer("test")
        processors = tracer.span_processor._span_processors
        assert all(not isinstance(p, DatabricksUCTableSpanProcessor) for p in processors)
    finally:
        # Restore global tracing state regardless of the test outcome.
        _MLFLOW_TRACE_USER_DESTINATION.reset()
        mlflow.tracing.reset()


@pytest.fixture
def _clean_tracing_state():
    """Reset the user trace destination and tracing provider around a test.

    Yields:
        A ``(destination_registry, provider)`` tuple. The registry is
        ``_MLFLOW_TRACE_USER_DESTINATION``; the provider is
        ``mlflow.tracing.provider.provider``.
    """
    from mlflow.tracing.provider import _MLFLOW_TRACE_USER_DESTINATION, provider

    _MLFLOW_TRACE_USER_DESTINATION.reset()
    provider.reset()
    yield _MLFLOW_TRACE_USER_DESTINATION, provider
    # pytest runs the code after ``yield`` as teardown even when the test fails.
    _MLFLOW_TRACE_USER_DESTINATION.reset()
    provider.reset()


def test_sync_fresh_session_with_uc_location_sets_destination_only(_clean_tracing_state):
    """On a fresh session, syncing a UC location just records it as the destination."""
    destination_registry, _ = _clean_tracing_state
    location = UnityCatalog("catalog", "schema", table_prefix="pfx")

    _sync_trace_destination_and_provider(location)

    assert destination_registry.get() is location


def test_sync_experiment_switch_with_uc_location_resets_and_sets_new(_clean_tracing_state):
    """Switching to a new UC location replaces the destination and resets the provider."""
    destination_registry, prov = _clean_tracing_state
    destination_registry.set(UnityCatalog("catalog", "schema", table_prefix="old"))
    # Presumably simulates a provider whose one-time initialization already ran.
    prov.once._done = True

    new_location = UnityCatalog("catalog", "schema", table_prefix="new")
    _sync_trace_destination_and_provider(new_location)

    assert destination_registry.get() is new_location
    assert not prov.once._done


def test_sync_experiment_switch_without_location_clears_and_resets(_clean_tracing_state):
    """Switching to no location clears the previous destination and resets the provider."""
    destination_registry, prov = _clean_tracing_state
    destination_registry.set(UnityCatalog("catalog", "schema", table_prefix="old"))
    # Presumably simulates a provider whose one-time initialization already ran.
    prov.once._done = True

    _sync_trace_destination_and_provider(None)

    assert destination_registry.get() is None
    assert not prov.once._done


def test_sync_fresh_session_without_location_is_noop(_clean_tracing_state):
    """On a fresh session, syncing ``None`` leaves destination and provider untouched."""
    destination_registry, prov = _clean_tracing_state

    _sync_trace_destination_and_provider(None)

    assert destination_registry.get() is None
    assert not prov.once._done