test_inference_table_processor.py
import json
from unittest import mock

import pytest

from mlflow.entities.span import LiveSpan
from mlflow.entities.trace_state import TraceState
from mlflow.tracing.constant import (
    TRACE_SCHEMA_VERSION,
    TRACE_SCHEMA_VERSION_KEY,
    SpanAttributeKey,
    TraceMetadataKey,
)
from mlflow.tracing.processor.inference_table import (
    _HEADER_REQUEST_ID_KEY,
    InferenceTableSpanProcessor,
)
from mlflow.tracing.trace_manager import InMemoryTraceManager
from mlflow.tracing.utils import generate_trace_id_v3
from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_MODEL_SERVING_ENDPOINT_NAME

from tests.tracing.helper import (
    create_mock_otel_span,
    create_test_trace_info,
    skip_module_when_testing_trace_sdk,
)

# NOTE: this call intentionally precedes the imports below; keep the order.
skip_module_when_testing_trace_sdk()

from mlflow.pyfunc.context import Context, set_prediction_context
from mlflow.tracking.fluent import set_active_model

_OTEL_TRACE_ID = 12345
_DATABRICKS_REQUEST_ID = "databricks-request-id"


def _register_finished_root_span(**trace_info_kwargs):
    """Build a finished root span and register a matching trace.

    Shared setup for the ``on_end`` tests: creates a mock root OTel span named
    ``"foo"`` with ``start_time=5_000_000`` and ``end_time=9_000_000``, derives
    the MLflow trace ID from it, and registers a test ``TraceInfo`` under
    ``_OTEL_TRACE_ID`` in the in-memory trace manager.

    Args:
        **trace_info_kwargs: Extra keyword arguments forwarded unchanged to
            ``create_test_trace_info`` (e.g. ``state=...``).

    Returns:
        A tuple ``(otel_span, trace_id, trace_info, trace_manager)``.
    """
    otel_span = create_mock_otel_span(
        name="foo",
        trace_id=_OTEL_TRACE_ID,
        span_id=1,
        parent_id=None,
        start_time=5_000_000,
        end_time=9_000_000,
    )

    trace_id = generate_trace_id_v3(otel_span)
    trace_info = create_test_trace_info(trace_id, 0, **trace_info_kwargs)
    trace_manager = InMemoryTraceManager.get_instance()
    trace_manager.register_trace(_OTEL_TRACE_ID, trace_info)
    return otel_span, trace_id, trace_info, trace_manager


@pytest.mark.parametrize("context_type", ["mlflow", "flask"])
def test_on_start(context_type):
    """A root span starts a new trace; a child span reuses the existing one.

    The Databricks request ID is supplied either through the MLflow prediction
    context or through a (mocked) Flask request header, depending on
    ``context_type``.
    """
    # Root span should create a new trace on start
    span = create_mock_otel_span(
        trace_id=_OTEL_TRACE_ID, span_id=1, parent_id=None, start_time=5_000_000
    )
    trace_manager = InMemoryTraceManager.get_instance()
    processor = InferenceTableSpanProcessor(span_exporter=mock.MagicMock())
    model = set_active_model(name="test-model")

    if context_type == "mlflow":
        with set_prediction_context(
            Context(request_id=_DATABRICKS_REQUEST_ID, endpoint_name="test-endpoint")
        ):
            processor.on_start(span)
    else:
        with mock.patch(
            "mlflow.tracing.processor.inference_table._get_flask_request"
        ) as mock_get_flask_request:
            request = mock_get_flask_request.return_value
            request.headers = {_HEADER_REQUEST_ID_KEY: _DATABRICKS_REQUEST_ID}

            processor.on_start(span)

    expected_trace_id = generate_trace_id_v3(span)

    with trace_manager.get_trace(expected_trace_id) as trace:
        # Trace ID should be generated by MLflow
        assert trace.info.trace_id == expected_trace_id
        # Databricks request ID should be set to the client request ID
        assert trace.info.client_request_id == _DATABRICKS_REQUEST_ID
        assert trace.info.experiment_id is None
        # start_time=5_000_000 on the span is expected to surface as 5 ms here
        assert trace.info.timestamp_ms == 5
        assert trace.info.execution_time_ms is None
        assert trace.info.state == TraceState.IN_PROGRESS

        # The endpoint name is only provided via the MLflow prediction context,
        # so the full metadata is only checked for that parametrization.
        if context_type == "mlflow":
            assert trace.info.request_metadata == {
                TRACE_SCHEMA_VERSION_KEY: str(TRACE_SCHEMA_VERSION),
                MLFLOW_DATABRICKS_MODEL_SERVING_ENDPOINT_NAME: "test-endpoint",
                TraceMetadataKey.MODEL_ID: model.model_id,
            }

    # Child span should not create a new trace
    child_span = create_mock_otel_span(
        trace_id=_OTEL_TRACE_ID, span_id=2, parent_id=1, start_time=8_000_000
    )

    with set_prediction_context(Context(request_id=_DATABRICKS_REQUEST_ID)):
        processor.on_start(child_span)

    assert child_span.attributes.get(SpanAttributeKey.REQUEST_ID) == json.dumps(expected_trace_id)

    # start time should not be overwritten
    with trace_manager.get_trace(expected_trace_id) as trace:
        assert trace.info.timestamp_ms == 5


def test_on_start_with_experiment_id_env_var(monkeypatch):
    """``MLFLOW_EXPERIMENT_ID`` from the environment is copied into the trace info."""
    # When the MLFLOW_EXPERIMENT_ID env var is set, it should be populated into the trace info
    monkeypatch.setenv("MLFLOW_EXPERIMENT_ID", "123")

    span = create_mock_otel_span(
        trace_id=_OTEL_TRACE_ID, span_id=1, parent_id=None, start_time=5_000_000
    )
    trace_manager = InMemoryTraceManager.get_instance()
    processor = InferenceTableSpanProcessor(span_exporter=mock.MagicMock())

    with set_prediction_context(Context(request_id=_DATABRICKS_REQUEST_ID)):
        processor.on_start(span)

    expected_trace_id = generate_trace_id_v3(span)
    with trace_manager.get_trace(expected_trace_id) as trace:
        assert trace.info.trace_id == expected_trace_id
        assert trace.info.client_request_id == _DATABRICKS_REQUEST_ID
        assert trace.info.experiment_id == "123"


def test_on_end():
    """Ending a root span exports it and updates the trace info; child spans are not exported."""
    otel_span, trace_id, trace_info, _ = _register_finished_root_span()

    span = LiveSpan(otel_span, trace_id)
    span.set_status("OK")
    span.set_inputs({"input1": "very long input" * 100})
    span.set_outputs({"output": "very long output" * 100})

    mock_exporter = mock.MagicMock()
    processor = InferenceTableSpanProcessor(span_exporter=mock_exporter)

    processor.on_end(otel_span)

    mock_exporter.export.assert_called_once_with((otel_span,))
    # Trace info should be updated according to the span attributes
    assert trace_info.state == TraceState.OK
    assert trace_info.execution_duration == 4

    # Non-root span should not be exported
    mock_exporter.reset_mock()
    child_span = create_mock_otel_span(trace_id=_OTEL_TRACE_ID, span_id=2, parent_id=1)
    processor.on_end(child_span)
    mock_exporter.export.assert_not_called()


def test_on_end_preserves_user_set_trace_state():
    """A trace state explicitly set to ERROR is not overwritten by an OK span status."""
    otel_span, trace_id, trace_info, trace_manager = _register_finished_root_span()

    # Explicitly set trace state to ERROR (user action)
    with trace_manager.get_trace(trace_id) as trace:
        trace.info.state = TraceState.ERROR

    span = LiveSpan(otel_span, trace_id)
    span.set_status("OK")  # Span status is OK
    span.set_inputs({"input1": "test"})
    span.set_outputs({"output": "test"})

    mock_exporter = mock.MagicMock()
    processor = InferenceTableSpanProcessor(span_exporter=mock_exporter)

    processor.on_end(otel_span)

    # Trace state should remain ERROR (user-set), not be overwritten by span status (OK)
    with trace_manager.get_trace(trace_id) as trace:
        assert trace.info.state == TraceState.ERROR
    assert trace_info.execution_duration == 4


def test_on_end_updates_trace_state_when_in_progress():
    """An IN_PROGRESS trace takes its final state from the root span status on end."""
    otel_span, trace_id, trace_info, trace_manager = _register_finished_root_span(
        state=TraceState.IN_PROGRESS
    )

    # Trace state remains IN_PROGRESS (not explicitly set by user)
    with trace_manager.get_trace(trace_id) as trace:
        assert trace.info.state == TraceState.IN_PROGRESS

    span = LiveSpan(otel_span, trace_id)
    span.set_status("ERROR")  # Span status is ERROR
    span.set_inputs({"input1": "test"})
    span.set_outputs({"output": "test"})

    mock_exporter = mock.MagicMock()
    processor = InferenceTableSpanProcessor(span_exporter=mock_exporter)

    processor.on_end(otel_span)

    # Trace state should be updated to ERROR from span status
    with trace_manager.get_trace(trace_id) as trace:
        assert trace.info.state == TraceState.ERROR
    assert trace_info.execution_duration == 4