# test_uc_table_processor.py
from unittest import mock

import pytest

import mlflow
import mlflow.tracking.context.default_context
from mlflow.entities.span import LiveSpan
from mlflow.entities.trace_location import TraceLocationType, UCSchemaLocation
from mlflow.entities.trace_state import TraceState
from mlflow.environment_variables import MLFLOW_TRACKING_USERNAME
from mlflow.exceptions import MlflowException
from mlflow.tracing.constant import TraceMetadataKey
from mlflow.tracing.processor.uc_table import DatabricksUCTableSpanProcessor
from mlflow.tracing.provider import _MLFLOW_TRACE_USER_DESTINATION
from mlflow.tracing.trace_manager import InMemoryTraceManager

from tests.tracing.helper import (
    create_mock_otel_span,
    create_test_trace_info,
)


@pytest.fixture
def active_uc_schema_destination():
    """Install a ``catalog1.schema1`` UC schema destination for the duration of a test.

    The destination is registered with ``_MLFLOW_TRACE_USER_DESTINATION.set`` and is
    always cleared with ``reset()`` afterwards, even if the test body fails.
    """
    destination = UCSchemaLocation(catalog_name="catalog1", schema_name="schema1")
    # NOTE(review): private attribute set directly; presumably read by the processor
    # when resolving the spans table (the no-destination test expects an error
    # mentioning the spans table name) — confirm.
    destination._otel_spans_table_name = "spans_table"
    _MLFLOW_TRACE_USER_DESTINATION.set(destination)
    try:
        yield
    finally:
        _MLFLOW_TRACE_USER_DESTINATION.reset()


def test_on_start_with_uc_table_name(monkeypatch, active_uc_schema_destination):
    """Starting a root span registers an in-progress trace located in the active UC schema."""
    monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")
    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "alice")

    # Root span should create a new trace on start
    trace_id = 12345
    span = create_mock_otel_span(trace_id=trace_id, span_id=1, parent_id=None, start_time=5_000_000)
    processor = DatabricksUCTableSpanProcessor(span_exporter=mock.MagicMock())
    processor.on_start(span)

    # Check that trace was created in trace manager
    trace_manager = InMemoryTraceManager.get_instance()
    traces = trace_manager._traces
    assert len(traces) == 1

    # Get the created trace
    created_trace = list(traces.values())[0]
    trace_info = created_trace.info

    # Verify trace location is UC_SCHEMA type
    assert trace_info.trace_location.type == TraceLocationType.UC_SCHEMA
    uc_schema = trace_info.trace_location.uc_schema
    assert uc_schema.catalog_name == "catalog1"
    assert uc_schema.schema_name == "schema1"

    # Verify trace state and timing
    assert trace_info.state == TraceState.IN_PROGRESS
    assert trace_info.request_time == 5  # 5_000_000 nanoseconds -> 5 milliseconds
    assert trace_info.execution_duration is None


def test_on_start_without_uc_table_name(monkeypatch):
    """With no user destination set, on_start raises and no trace is registered."""
    monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")
    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "alice")

    # Root span would normally create a new trace on start
    trace_id = 12345
    span = create_mock_otel_span(trace_id=trace_id, span_id=1, parent_id=None, start_time=5_000_000)

    # Make sure no destination (and hence no spans table name) is configured.
    _MLFLOW_TRACE_USER_DESTINATION.reset()
    processor = DatabricksUCTableSpanProcessor(span_exporter=mock.MagicMock())
    with pytest.raises(MlflowException, match="Unity Catalog spans table name is not set"):
        processor.on_start(span)

    # Check that no trace was created in the trace manager
    trace_manager = InMemoryTraceManager.get_instance()
    traces = trace_manager._traces
    assert len(traces) == 0


def test_constructor_disables_metrics_export():
    """The processor is constructed with metrics export turned off."""
    mock_exporter = mock.MagicMock()
    processor = DatabricksUCTableSpanProcessor(span_exporter=mock_exporter)

    # The export_metrics should be False
    assert not processor._export_metrics


def test_trace_id_generation_with_uc_schema(active_uc_schema_destination):
    """on_start derives the trace ID via generate_trace_id_v4 using "<catalog>.<schema>"."""
    trace_id = 12345
    span = create_mock_otel_span(trace_id=trace_id, span_id=1, parent_id=None, start_time=5_000_000)

    with mock.patch(
        "mlflow.tracing.processor.uc_table.generate_trace_id_v4",
        return_value="trace:/catalog1.schema1/12345",
    ) as mock_generate_trace_id:
        processor = DatabricksUCTableSpanProcessor(span_exporter=mock.MagicMock())
        processor.on_start(span)

    # Verify generate_trace_id_v4 was called with correct arguments
    mock_generate_trace_id.assert_called_once_with(span, "catalog1.schema1")


def test_on_end():
    """on_end hands the finished OTel span to the exporter as a one-element tuple."""
    trace_info = create_test_trace_info("request_id", 0)
    trace_manager = InMemoryTraceManager.get_instance()
    trace_manager.register_trace("trace_id", trace_info)

    otel_span = create_mock_otel_span(
        name="foo",
        trace_id="trace_id",
        span_id=1,
        parent_id=None,
        start_time=5_000_000,
        end_time=9_000_000,
    )
    # Wrap in a LiveSpan only to populate status/inputs/outputs on the OTel span.
    span = LiveSpan(otel_span, "request_id")
    span.set_status("OK")
    span.set_inputs({"input1": "test input"})
    span.set_outputs({"output": "test output"})

    mock_exporter = mock.MagicMock()
    processor = DatabricksUCTableSpanProcessor(span_exporter=mock_exporter)

    processor.on_end(otel_span)

    # Verify span was exported
    mock_exporter.export.assert_called_once_with((otel_span,))


def test_on_end_sets_user_session_span_attributes():
    """Trace user/session metadata is copied onto the span as user.id / session.id."""
    trace_manager = InMemoryTraceManager.get_instance()
    # NOTE(review): pop_trace is stubbed, presumably so that ending the span under the
    # default tracing setup does not remove the trace (and its metadata) before the
    # processor under test runs — confirm.
    with mock.patch.object(trace_manager, "pop_trace", return_value=None):
        with mlflow.start_span("foo") as live_span:
            mlflow.update_current_trace(
                metadata={
                    TraceMetadataKey.TRACE_USER: "alice",
                    TraceMetadataKey.TRACE_SESSION: "sess-123",
                }
            )
            otel_span = live_span._span

        processor = DatabricksUCTableSpanProcessor(span_exporter=mock.MagicMock())
        processor.on_end(otel_span)

    assert otel_span.attributes["user.id"] == "alice"
    assert otel_span.attributes["session.id"] == "sess-123"


def test_on_end_does_not_set_user_session_attributes_when_missing():
    """No user.id / session.id span attributes are added when the trace has no such metadata."""
    trace_manager = InMemoryTraceManager.get_instance()
    # NOTE(review): same pop_trace stub as the positive test above; presumably keeps
    # the trace available to the processor — confirm.
    with mock.patch.object(trace_manager, "pop_trace", return_value=None):
        with mlflow.start_span("foo") as live_span:
            otel_span = live_span._span

        processor = DatabricksUCTableSpanProcessor(span_exporter=mock.MagicMock())
        processor.on_end(otel_span)

    assert "user.id" not in otel_span.attributes
    assert "session.id" not in otel_span.attributes


def test_trace_metadata_and_tags(active_uc_schema_destination):
    """A trace created by on_start exposes non-None metadata and tags."""
    trace_id = 12345
    span = create_mock_otel_span(trace_id=trace_id, span_id=1, parent_id=None, start_time=5_000_000)
    processor = DatabricksUCTableSpanProcessor(span_exporter=mock.MagicMock())
    processor.on_start(span)

    # Get the created trace
    trace_manager = InMemoryTraceManager.get_instance()
    traces = trace_manager._traces
    created_trace = list(traces.values())[0]
    trace_info = created_trace.info

    # Check that metadata and tags are present
    assert trace_info.trace_metadata is not None
    assert trace_info.tags is not None