/ tests / tracing / processor / test_uc_table_processor.py
test_uc_table_processor.py
  1  from unittest import mock
  2  
  3  import pytest
  4  
  5  import mlflow
  6  import mlflow.tracking.context.default_context
  7  from mlflow.entities.span import LiveSpan
  8  from mlflow.entities.trace_location import TraceLocationType, UCSchemaLocation
  9  from mlflow.entities.trace_state import TraceState
 10  from mlflow.environment_variables import MLFLOW_TRACKING_USERNAME
 11  from mlflow.exceptions import MlflowException
 12  from mlflow.tracing.constant import TraceMetadataKey
 13  from mlflow.tracing.processor.uc_table import DatabricksUCTableSpanProcessor
 14  from mlflow.tracing.provider import _MLFLOW_TRACE_USER_DESTINATION
 15  from mlflow.tracing.trace_manager import InMemoryTraceManager
 16  
 17  from tests.tracing.helper import (
 18      create_mock_otel_span,
 19      create_test_trace_info,
 20  )
 21  
 22  
 23  @pytest.fixture
 24  def active_uc_schema_destination():
 25      destination = UCSchemaLocation(catalog_name="catalog1", schema_name="schema1")
 26      destination._otel_spans_table_name = "spans_table"
 27      _MLFLOW_TRACE_USER_DESTINATION.set(destination)
 28      try:
 29          yield
 30      finally:
 31          _MLFLOW_TRACE_USER_DESTINATION.reset()
 32  
 33  
 34  def test_on_start_with_uc_table_name(monkeypatch, active_uc_schema_destination):
 35      monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")
 36      monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "alice")
 37  
 38      # Root span should create a new trace on start
 39      trace_id = 12345
 40      span = create_mock_otel_span(trace_id=trace_id, span_id=1, parent_id=None, start_time=5_000_000)
 41      processor = DatabricksUCTableSpanProcessor(span_exporter=mock.MagicMock())
 42      processor.on_start(span)
 43  
 44      # Check that trace was created in trace manager
 45      trace_manager = InMemoryTraceManager.get_instance()
 46      traces = trace_manager._traces
 47      assert len(traces) == 1
 48  
 49      # Get the created trace
 50      created_trace = list(traces.values())[0]
 51      trace_info = created_trace.info
 52  
 53      # Verify trace location is UC_SCHEMA type
 54      assert trace_info.trace_location.type == TraceLocationType.UC_SCHEMA
 55      uc_schema = trace_info.trace_location.uc_schema
 56      assert uc_schema.catalog_name == "catalog1"
 57      assert uc_schema.schema_name == "schema1"
 58  
 59      # Verify trace state and timing
 60      assert trace_info.state == TraceState.IN_PROGRESS
 61      assert trace_info.request_time == 5  # 5_000_000 nanoseconds -> 5 milliseconds
 62      assert trace_info.execution_duration is None
 63  
 64  
 65  def test_on_start_without_uc_table_name(monkeypatch):
 66      monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")
 67      monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "alice")
 68  
 69      # Root span should create a new trace on start
 70      trace_id = 12345
 71      span = create_mock_otel_span(trace_id=trace_id, span_id=1, parent_id=None, start_time=5_000_000)
 72  
 73      _MLFLOW_TRACE_USER_DESTINATION.reset()
 74      processor = DatabricksUCTableSpanProcessor(span_exporter=mock.MagicMock())
 75      with pytest.raises(MlflowException, match="Unity Catalog spans table name is not set"):
 76          processor.on_start(span)
 77  
 78      # Check that trace was still created in trace manager
 79      trace_manager = InMemoryTraceManager.get_instance()
 80      traces = trace_manager._traces
 81      assert len(traces) == 0
 82  
 83  
 84  def test_constructor_disables_metrics_export():
 85      mock_exporter = mock.MagicMock()
 86      processor = DatabricksUCTableSpanProcessor(span_exporter=mock_exporter)
 87  
 88      # The export_metrics should be False
 89      assert not processor._export_metrics
 90  
 91  
 92  def test_trace_id_generation_with_uc_schema(active_uc_schema_destination):
 93      trace_id = 12345
 94      span = create_mock_otel_span(trace_id=trace_id, span_id=1, parent_id=None, start_time=5_000_000)
 95  
 96      with mock.patch(
 97          "mlflow.tracing.processor.uc_table.generate_trace_id_v4",
 98          return_value="trace:/catalog1.schema1/12345",
 99      ) as mock_generate_trace_id:
100          processor = DatabricksUCTableSpanProcessor(span_exporter=mock.MagicMock())
101          processor.on_start(span)
102  
103          # Verify generate_trace_id_v4 was called with correct arguments
104          mock_generate_trace_id.assert_called_once_with(span, "catalog1.schema1")
105  
106  
def test_on_end():
    """Ending a span hands that span, as a one-element tuple, to the exporter."""
    # Register a trace under the same trace id the mock span carries.
    manager = InMemoryTraceManager.get_instance()
    manager.register_trace("trace_id", create_test_trace_info("request_id", 0))

    otel_span = create_mock_otel_span(
        name="foo",
        trace_id="trace_id",
        span_id=1,
        parent_id=None,
        start_time=5_000_000,
        end_time=9_000_000,
    )

    # Give the span a status, inputs, and outputs through the LiveSpan wrapper
    # before it is ended.
    live_span = LiveSpan(otel_span, "request_id")
    live_span.set_status("OK")
    live_span.set_inputs({"input1": "test input"})
    live_span.set_outputs({"output": "test output"})

    exporter = mock.MagicMock()
    DatabricksUCTableSpanProcessor(span_exporter=exporter).on_end(otel_span)

    # The exporter must be called exactly once, with just this span.
    exporter.export.assert_called_once_with((otel_span,))
132  
133  
def test_on_end_sets_user_session_span_attributes():
    """on_end sets ``user.id`` / ``session.id`` span attributes from trace metadata."""
    manager = InMemoryTraceManager.get_instance()
    user_session_metadata = {
        TraceMetadataKey.TRACE_USER: "alice",
        TraceMetadataKey.TRACE_SESSION: "sess-123",
    }

    # NOTE(review): pop_trace is stubbed to return None while the span is active,
    # presumably so the trace stays registered for the on_end call below -- confirm.
    pop_trace_patch = mock.patch.object(manager, "pop_trace", return_value=None)
    with pop_trace_patch, mlflow.start_span("foo") as live_span:
        mlflow.update_current_trace(metadata=user_session_metadata)
        otel_span = live_span._span

    uc_processor = DatabricksUCTableSpanProcessor(span_exporter=mock.MagicMock())
    uc_processor.on_end(otel_span)

    span_attributes = otel_span.attributes
    assert span_attributes["user.id"] == "alice"
    assert span_attributes["session.id"] == "sess-123"
151  
152  
def test_on_end_does_not_set_user_session_attributes_when_missing():
    """on_end adds no ``user.id`` / ``session.id`` attributes when no metadata was set."""
    manager = InMemoryTraceManager.get_instance()

    # NOTE(review): pop_trace is stubbed to return None while the span is active,
    # presumably so the trace stays registered for the on_end call below -- confirm.
    pop_trace_patch = mock.patch.object(manager, "pop_trace", return_value=None)
    with pop_trace_patch, mlflow.start_span("foo") as live_span:
        otel_span = live_span._span

    uc_processor = DatabricksUCTableSpanProcessor(span_exporter=mock.MagicMock())
    uc_processor.on_end(otel_span)

    # No user or session metadata was recorded, so neither attribute may appear.
    for attribute_name in ("user.id", "session.id"):
        assert attribute_name not in otel_span.attributes
164  
165  
def test_trace_metadata_and_tags(active_uc_schema_destination):
    """A trace created by on_start has non-None ``trace_metadata`` and ``tags``."""
    root_span = create_mock_otel_span(
        trace_id=12345,
        span_id=1,
        parent_id=None,
        start_time=5_000_000,
    )
    DatabricksUCTableSpanProcessor(span_exporter=mock.MagicMock()).on_start(root_span)

    # Take the first trace held by the trace manager after on_start.
    registered_traces = InMemoryTraceManager.get_instance()._traces
    info = list(registered_traces.values())[0].info

    # Only presence is checked here, not the specific metadata or tag values.
    assert info.trace_metadata is not None
    assert info.tags is not None
179      assert trace_info.trace_metadata is not None
180      assert trace_info.tags is not None