/ tests / tracking / fluent / test_set_experiment_trace_location.py
test_set_experiment_trace_location.py
  1  from unittest import mock
  2  
  3  import pytest
  4  
  5  import mlflow
  6  from mlflow.entities import Experiment
  7  from mlflow.entities.experiment_tag import ExperimentTag
  8  from mlflow.entities.trace_location import UnityCatalog
  9  from mlflow.exceptions import MlflowException
 10  from mlflow.tracking.fluent import (
 11      _resolve_experiment_to_trace_location,
 12      _sync_trace_destination_and_provider,
 13  )
 14  from mlflow.utils.mlflow_tags import (
 15      MLFLOW_EXPERIMENT_DATABRICKS_TRACE_ANNOTATIONS_TABLE,
 16      MLFLOW_EXPERIMENT_DATABRICKS_TRACE_DESTINATION_PATH,
 17      MLFLOW_EXPERIMENT_DATABRICKS_TRACE_LOG_STORAGE_TABLE,
 18      MLFLOW_EXPERIMENT_DATABRICKS_TRACE_SPAN_STORAGE_TABLE,
 19  )
 20  
 21  
 22  def _experiment(tags=None):
 23      tag_entities = [ExperimentTag(k, v) for k, v in (tags or {}).items()]
 24      return Experiment(
 25          experiment_id="123",
 26          name="test-experiment",
 27          artifact_location="file:/tmp",
 28          lifecycle_stage="active",
 29          tags=tag_entities,
 30      )
 31  
 32  
 33  def test_invalid_type_raises():
 34      with pytest.raises(MlflowException, match="UnityCatalog"):
 35          _resolve_experiment_to_trace_location(
 36              experiment=_experiment(),
 37              trace_location="not-a-location",
 38          )
 39  
 40  
 41  def test_uc_schema_location_is_rejected():
 42      from mlflow.entities.trace_location import UCSchemaLocation
 43  
 44      with pytest.raises(MlflowException, match="UnityCatalog"):
 45          _resolve_experiment_to_trace_location(
 46              experiment=_experiment(),
 47              trace_location=UCSchemaLocation("catalog", "schema"),
 48          )
 49  
 50  
 51  def test_no_trace_location_returns_none():
 52      result = _resolve_experiment_to_trace_location(
 53          experiment=_experiment(),
 54          trace_location=None,
 55      )
 56      assert result is None
 57  
 58  
 59  def test_non_databricks_backend_raises():
 60      with (
 61          mock.patch("mlflow.tracking.fluent._resolve_tracking_uri", return_value="file:///tmp"),
 62          mock.patch("mlflow.tracking.fluent.is_databricks_uri", return_value=False),
 63      ):
 64          with pytest.raises(MlflowException, match="only supported with a Databricks tracking URI"):
 65              _resolve_experiment_to_trace_location(
 66                  experiment=_experiment(),
 67                  trace_location=UnityCatalog("catalog", "schema", "prefix"),
 68              )
 69  
 70  
 71  def test_set_experiment_with_table_prefix_env_var_points_to_trace_location_param(monkeypatch):
 72      from mlflow.tracing.provider import _get_tracer
 73  
 74      monkeypatch.setenv("MLFLOW_TRACING_DESTINATION", "catalog.schema.prefix")
 75  
 76      mlflow.tracing.reset()
 77      mlflow.set_experiment("test-experiment")
 78  
 79      # The error surfaces lazily at trace creation time (provider init),
 80      # not eagerly at set_experiment time.
 81      with pytest.raises(
 82          MlflowException,
 83          match=r"Unity Catalog table-prefix destinations "
 84          r"\(<catalog_name>\.<schema_name>\.<table_prefix>\) are not supported in "
 85          r"MLFLOW_TRACING_DESTINATION.*Use `mlflow\.set_experiment",
 86      ):
 87          _get_tracer("test")
 88  
 89      mlflow.tracing.reset()
 90  
 91  
 92  def test_set_experiment_defaults_empty_prefix_to_experiment_id():
 93      resolved = UnityCatalog("catalog", "schema", table_prefix="123")
 94  
 95      with (
 96          mock.patch("mlflow.tracking.fluent.TrackingServiceClient") as mock_client_cls,
 97          mock.patch(
 98              "mlflow.tracking.fluent._resolve_experiment_to_trace_location",
 99              return_value=resolved,
100          ) as mock_resolve,
101          mock.patch("mlflow.tracking.fluent._sync_trace_destination_and_provider"),
102      ):
103          client = mock_client_cls.return_value
104          client.get_experiment_by_name.return_value = _experiment()  # experiment_id="123"
105  
106          original = UnityCatalog("catalog", "schema")  # no prefix
107          mlflow.set_experiment("test-experiment", trace_location=original)
108  
109          # Verify _resolve was called with a location that has the experiment ID as prefix
110          _, kwargs = mock_resolve.call_args
111          passed_location = kwargs["trace_location"]
112          assert passed_location.table_prefix == "123"
113          assert passed_location.catalog_name == "catalog"
114          assert passed_location.schema_name == "schema"
115  
116          # Original object should not be mutated
117          assert original.table_prefix is None
118  
119  
def test_creates_and_links_when_no_existing_location(monkeypatch):
    """An experiment with no destination tags gets its location created and linked.

    The SQL warehouse ID from the environment is passed to the create call.
    The location returned by the backend is the one linked and returned.
    """
    warehouse_id = "warehouse-1"
    monkeypatch.setenv("MLFLOW_TRACING_SQL_WAREHOUSE_ID", warehouse_id)
    requested = UnityCatalog("catalog", "schema", table_prefix="prefix")
    resolved = UnityCatalog("catalog", "schema", table_prefix="prefix")

    uri_patch = mock.patch(
        "mlflow.tracking.fluent._resolve_tracking_uri", return_value="databricks"
    )
    databricks_patch = mock.patch("mlflow.tracking.fluent.is_databricks_uri", return_value=True)
    tracing_client_patch = mock.patch("mlflow.tracing.client.TracingClient")

    with uri_patch, databricks_patch, tracing_client_patch as tracing_client_cls:
        tracing_client = tracing_client_cls.return_value
        tracing_client._create_or_get_trace_location.return_value = resolved

        outcome = _resolve_experiment_to_trace_location(
            experiment=_experiment(), trace_location=requested
        )

        assert outcome is resolved
        tracing_client._create_or_get_trace_location.assert_called_once_with(
            requested, warehouse_id
        )
        tracing_client._link_trace_location.assert_called_once_with(
            experiment_id="123", location=resolved
        )
144  
145  
def test_noop_when_existing_location_matches():
    """An experiment already linked to the requested location is treated as a no-op.

    The returned location equals the request. Its table names are filled in
    from the experiment's storage-table tags.
    """
    spans_table = "catalog.schema.prefix_otel_spans"
    logs_table = "catalog.schema.prefix_otel_logs"
    annotations_table = "catalog.schema.prefix_annotations"

    requested = UnityCatalog("catalog", "schema", table_prefix="prefix")
    linked_experiment = _experiment(
        tags={
            MLFLOW_EXPERIMENT_DATABRICKS_TRACE_DESTINATION_PATH: "catalog.schema.prefix",
            MLFLOW_EXPERIMENT_DATABRICKS_TRACE_SPAN_STORAGE_TABLE: spans_table,
            MLFLOW_EXPERIMENT_DATABRICKS_TRACE_LOG_STORAGE_TABLE: logs_table,
            MLFLOW_EXPERIMENT_DATABRICKS_TRACE_ANNOTATIONS_TABLE: annotations_table,
        }
    )

    uri_patch = mock.patch(
        "mlflow.tracking.fluent._resolve_tracking_uri", return_value="databricks"
    )
    databricks_patch = mock.patch("mlflow.tracking.fluent.is_databricks_uri", return_value=True)

    with uri_patch, databricks_patch:
        location = _resolve_experiment_to_trace_location(
            experiment=linked_experiment, trace_location=requested
        )

        assert location == requested
        assert location._otel_spans_table_name == spans_table
        assert location._otel_logs_table_name == logs_table
        assert location._annotations_table_name == annotations_table
176  
177  
def test_errors_when_existing_location_differs():
    """Requesting a different prefix than the one already linked is an error."""
    requested = UnityCatalog("catalog", "schema", table_prefix="new_prefix")
    already_linked = _experiment(
        tags={MLFLOW_EXPERIMENT_DATABRICKS_TRACE_DESTINATION_PATH: "catalog.schema.old_prefix"}
    )

    uri_patch = mock.patch(
        "mlflow.tracking.fluent._resolve_tracking_uri", return_value="databricks"
    )
    databricks_patch = mock.patch("mlflow.tracking.fluent.is_databricks_uri", return_value=True)

    with uri_patch, databricks_patch:
        with pytest.raises(MlflowException, match="already linked to a different"):
            _resolve_experiment_to_trace_location(
                experiment=already_linked, trace_location=requested
            )
193  
194  
def test_existing_uc_schema_destination_rejects_table_prefix():
    """A schema-only destination tag conflicts with a requested table-prefix location."""
    requested = UnityCatalog("catalog", "schema", table_prefix="pfx")
    schema_linked = _experiment(
        tags={MLFLOW_EXPERIMENT_DATABRICKS_TRACE_DESTINATION_PATH: "catalog.schema"}
    )

    uri_patch = mock.patch(
        "mlflow.tracking.fluent._resolve_tracking_uri", return_value="databricks"
    )
    databricks_patch = mock.patch("mlflow.tracking.fluent.is_databricks_uri", return_value=True)

    with uri_patch, databricks_patch:
        with pytest.raises(MlflowException, match="already linked to a different"):
            _resolve_experiment_to_trace_location(
                experiment=schema_linked, trace_location=requested
            )
210  
211  
def test_link_failure_on_new_experiment_includes_retry_guidance():
    """A link failure on a newly created experiment adds retry guidance to the error.

    The raised message must still contain the original backend error text.
    """
    # The experiment returned by get_experiment must match what
    # create_experiment reports (ID "456", name "new-exp"). Otherwise the
    # mocked backend is internally inconsistent.
    created_experiment = Experiment(
        experiment_id="456",
        name="new-exp",
        artifact_location="file:/tmp",
        lifecycle_stage="active",
        tags=[],
    )

    with (
        mock.patch("mlflow.tracking.fluent.TrackingServiceClient") as mock_client_cls,
        mock.patch(
            "mlflow.tracking.fluent._resolve_experiment_to_trace_location",
            side_effect=MlflowException("backend error"),
        ) as mock_resolve,
    ):
        client = mock_client_cls.return_value
        # Simulate: name lookup misses -> create returns "456" -> get returns it.
        client.get_experiment_by_name.return_value = None
        client.create_experiment.return_value = "456"
        client.get_experiment.return_value = created_experiment

        with pytest.raises(
            MlflowException, match="fix the issue and call set_experiment again"
        ) as exc_info:
            mlflow.set_experiment(
                "new-exp",
                trace_location=UnityCatalog("cat", "sch", "pfx"),
            )

        assert "backend error" in exc_info.value.message
        mock_resolve.assert_called_once()
237  
238  
def test_link_failure_on_existing_experiment_reraises_original():
    """For an already-existing experiment, the resolver's error is re-raised as-is."""
    client_patch = mock.patch("mlflow.tracking.fluent.TrackingServiceClient")
    resolve_patch = mock.patch(
        "mlflow.tracking.fluent._resolve_experiment_to_trace_location",
        side_effect=MlflowException("backend error"),
    )

    with client_patch as client_cls, resolve_patch as resolve_mock:
        # The name lookup succeeds, i.e. the experiment already exists.
        client_cls.return_value.get_experiment_by_name.return_value = _experiment()

        with pytest.raises(MlflowException, match="backend error"):
            requested = UnityCatalog("cat", "sch", "pfx")
            mlflow.set_experiment("test-experiment", trace_location=requested)

        resolve_mock.assert_called_once()
258  
259  
def test_set_experiment_wires_trace_location_to_returned_experiment():
    """The resolved location is synced and attached to the returned experiment."""
    resolved = UnityCatalog("catalog", "schema", table_prefix="pfx")
    resolve_patch = mock.patch(
        "mlflow.tracking.fluent._resolve_experiment_to_trace_location",
        return_value=resolved,
    )
    sync_patch = mock.patch("mlflow.tracking.fluent._sync_trace_destination_and_provider")

    with resolve_patch as resolve_mock, sync_patch as sync_mock:
        experiment = mlflow.set_experiment("test-trace-loc-integration")

    resolve_mock.assert_called_once()
    assert resolve_mock.call_args.kwargs["experiment"].name == "test-trace-loc-integration"
    sync_mock.assert_called_once_with(resolved)
    assert experiment.trace_location is resolved
279  
280  
def test_set_experiment_with_trace_location_installs_uc_processor():
    """A resolved UC location yields a tracer with exactly one UC table span processor.

    That processor must export through ``DatabricksUCTableSpanExporter``.
    """
    from mlflow.tracing.export.uc_table import DatabricksUCTableSpanExporter
    from mlflow.tracing.processor.uc_table import DatabricksUCTableSpanProcessor
    from mlflow.tracing.provider import _MLFLOW_TRACE_USER_DESTINATION, _get_tracer

    resolved = UnityCatalog("catalog", "schema", table_prefix="pfx")
    mlflow.tracing.reset()
    _MLFLOW_TRACE_USER_DESTINATION.reset()

    try:
        with mock.patch(
            "mlflow.tracking.fluent._resolve_experiment_to_trace_location",
            return_value=resolved,
        ) as mock_register:
            experiment = mlflow.set_experiment("test-uc-processor")

        mock_register.assert_called_once()
        assert experiment.trace_location is resolved

        tracer = _get_tracer("test")
        processors = tracer.span_processor._span_processors
        assert len(processors) == 1
        assert isinstance(processors[0], DatabricksUCTableSpanProcessor)
        assert isinstance(processors[0].span_exporter, DatabricksUCTableSpanExporter)
    finally:
        # Restore global tracing state even when an assertion fails. Otherwise
        # the UC destination and provider leak into subsequent tests.
        _MLFLOW_TRACE_USER_DESTINATION.reset()
        mlflow.tracing.reset()
309  
310  
def test_set_experiment_without_trace_location_does_not_install_uc_processor():
    """Without a trace location, no UC table span processor is installed."""
    from mlflow.tracing.processor.uc_table import DatabricksUCTableSpanProcessor
    from mlflow.tracing.provider import _MLFLOW_TRACE_USER_DESTINATION, _get_tracer

    mlflow.tracing.reset()
    _MLFLOW_TRACE_USER_DESTINATION.reset()

    try:
        mlflow.set_experiment("test-no-uc-processor")

        tracer = _get_tracer("test")
        processors = tracer.span_processor._span_processors
        assert all(not isinstance(p, DatabricksUCTableSpanProcessor) for p in processors)
    finally:
        # Restore global tracing state even when the assertion fails.
        _MLFLOW_TRACE_USER_DESTINATION.reset()
        mlflow.tracing.reset()
326  
327  
@pytest.fixture
def _clean_tracing_state():
    """Reset the user trace destination and the tracer provider around a test.

    Yields:
        A ``(destination_registry, provider)`` tuple for inspection by the test.
        Both are reset again during fixture teardown.
    """
    from mlflow.tracing.provider import _MLFLOW_TRACE_USER_DESTINATION as destination
    from mlflow.tracing.provider import provider as tracer_provider

    def _reset_both():
        destination.reset()
        tracer_provider.reset()

    _reset_both()
    yield destination, tracer_provider
    _reset_both()
337  
338  
def test_sync_fresh_session_with_uc_location_sets_destination_only(_clean_tracing_state):
    """On a clean session, syncing a UC location records it as the user destination."""
    registry = _clean_tracing_state[0]
    uc_location = UnityCatalog("catalog", "schema", table_prefix="pfx")

    _sync_trace_destination_and_provider(uc_location)

    assert registry.get() is uc_location
346  
347  
def test_sync_experiment_switch_with_uc_location_resets_and_sets_new(_clean_tracing_state):
    """Switching to a new UC location replaces the destination and re-arms the provider."""
    registry, tracer_provider = _clean_tracing_state
    # Simulate a previous experiment: an old destination plus a provider whose
    # one-shot initialization flag is already set.
    registry.set(UnityCatalog("catalog", "schema", table_prefix="old"))
    tracer_provider.once._done = True

    replacement = UnityCatalog("catalog", "schema", table_prefix="new")
    _sync_trace_destination_and_provider(replacement)

    assert registry.get() is replacement
    assert not tracer_provider.once._done
358  
359  
def test_sync_experiment_switch_without_location_clears_and_resets(_clean_tracing_state):
    """Switching to an experiment with no location clears the destination and re-arms the provider."""
    registry, tracer_provider = _clean_tracing_state
    # Simulate a previous experiment that had a UC destination and an
    # already-initialized provider.
    registry.set(UnityCatalog("catalog", "schema", table_prefix="old"))
    tracer_provider.once._done = True

    _sync_trace_destination_and_provider(None)

    assert registry.get() is None
    assert not tracer_provider.once._done
369  
370  
def test_sync_fresh_session_without_location_is_noop(_clean_tracing_state):
    """On a clean session, syncing ``None`` leaves destination and provider untouched."""
    registry, tracer_provider = _clean_tracing_state

    _sync_trace_destination_and_provider(None)

    assert registry.get() is None
    assert not tracer_provider.once._done