tests/tracing/processor/test_inference_table_processor.py
test_inference_table_processor.py
  1  import json
  2  from unittest import mock
  3  
  4  import pytest
  5  
  6  from mlflow.entities.span import LiveSpan
  7  from mlflow.entities.trace_state import TraceState
  8  from mlflow.tracing.constant import (
  9      TRACE_SCHEMA_VERSION,
 10      TRACE_SCHEMA_VERSION_KEY,
 11      SpanAttributeKey,
 12      TraceMetadataKey,
 13  )
 14  from mlflow.tracing.processor.inference_table import (
 15      _HEADER_REQUEST_ID_KEY,
 16      InferenceTableSpanProcessor,
 17  )
 18  from mlflow.tracing.trace_manager import InMemoryTraceManager
 19  from mlflow.tracing.utils import generate_trace_id_v3
 20  from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_MODEL_SERVING_ENDPOINT_NAME
 21  
 22  from tests.tracing.helper import (
 23      create_mock_otel_span,
 24      create_test_trace_info,
 25      skip_module_when_testing_trace_sdk,
 26  )
 27  
# Module-level skip guard; per the helper's name, this presumably skips the
# whole module when the suite runs against the tracing-only SDK — confirm.
skip_module_when_testing_trace_sdk()

# NOTE(review): these imports are placed after the skip guard, presumably
# because these modules are not importable in the tracing-only SDK — confirm.
from mlflow.pyfunc.context import Context, set_prediction_context
from mlflow.tracking.fluent import set_active_model

# Fixed OpenTelemetry trace ID shared by the root and child mock spans below.
_OTEL_TRACE_ID = 12345
# Databricks request ID fed to the processor (via prediction context or Flask
# header); the tests expect it to become the trace's client request ID.
_DATABRICKS_REQUEST_ID = "databricks-request-id"
 35  
 36  
 37  @pytest.mark.parametrize("context_type", ["mlflow", "flask"])
 38  def test_on_start(context_type):
 39      # Root span should create a new trace on start
 40      span = create_mock_otel_span(
 41          trace_id=_OTEL_TRACE_ID, span_id=1, parent_id=None, start_time=5_000_000
 42      )
 43      trace_manager = InMemoryTraceManager.get_instance()
 44      processor = InferenceTableSpanProcessor(span_exporter=mock.MagicMock())
 45      model = set_active_model(name="test-model")
 46  
 47      if context_type == "mlflow":
 48          with set_prediction_context(
 49              Context(request_id=_DATABRICKS_REQUEST_ID, endpoint_name="test-endpoint")
 50          ):
 51              processor.on_start(span)
 52      else:
 53          with mock.patch(
 54              "mlflow.tracing.processor.inference_table._get_flask_request"
 55          ) as mock_get_flask_request:
 56              request = mock_get_flask_request.return_value
 57              request.headers = {_HEADER_REQUEST_ID_KEY: _DATABRICKS_REQUEST_ID}
 58  
 59              processor.on_start(span)
 60  
 61      expected_trace_id = generate_trace_id_v3(span)
 62  
 63      with trace_manager.get_trace(expected_trace_id) as trace:
 64          # Trace ID should be generated by MLflow
 65          assert trace.info.trace_id == expected_trace_id
 66          # Databricks request ID should be set to the client request ID
 67          assert trace.info.client_request_id == _DATABRICKS_REQUEST_ID
 68          assert trace.info.experiment_id is None
 69          assert trace.info.timestamp_ms == 5
 70          assert trace.info.execution_time_ms is None
 71          assert trace.info.state == TraceState.IN_PROGRESS
 72  
 73          if context_type == "mlflow":
 74              assert trace.info.request_metadata == {
 75                  TRACE_SCHEMA_VERSION_KEY: str(TRACE_SCHEMA_VERSION),
 76                  MLFLOW_DATABRICKS_MODEL_SERVING_ENDPOINT_NAME: "test-endpoint",
 77                  TraceMetadataKey.MODEL_ID: model.model_id,
 78              }
 79  
 80      # Child span should not create a new trace
 81      child_span = create_mock_otel_span(
 82          trace_id=_OTEL_TRACE_ID, span_id=2, parent_id=1, start_time=8_000_000
 83      )
 84  
 85      with set_prediction_context(Context(request_id=_DATABRICKS_REQUEST_ID)):
 86          processor.on_start(child_span)
 87  
 88      assert child_span.attributes.get(SpanAttributeKey.REQUEST_ID) == json.dumps(expected_trace_id)
 89  
 90      # start time should not be overwritten
 91      with trace_manager.get_trace(expected_trace_id) as trace:
 92          assert trace.info.timestamp_ms == 5
 93  
 94  
 95  def test_on_start_with_experiment_id_env_var(monkeypatch):
 96      # When the MLFLOW_EXPERIMENT_ID env var is set, it should be populated into the trace info
 97      monkeypatch.setenv("MLFLOW_EXPERIMENT_ID", "123")
 98  
 99      span = create_mock_otel_span(
100          trace_id=_OTEL_TRACE_ID, span_id=1, parent_id=None, start_time=5_000_000
101      )
102      trace_manager = InMemoryTraceManager.get_instance()
103      processor = InferenceTableSpanProcessor(span_exporter=mock.MagicMock())
104  
105      with set_prediction_context(Context(request_id=_DATABRICKS_REQUEST_ID)):
106          processor.on_start(span)
107  
108      expected_trace_id = generate_trace_id_v3(span)
109      with trace_manager.get_trace(expected_trace_id) as trace:
110          assert trace.info.trace_id == expected_trace_id
111          assert trace.info.client_request_id == _DATABRICKS_REQUEST_ID
112          assert trace.info.experiment_id == "123"
113  
114  
def test_on_end():
    """Ending the root span exports it and finalizes the trace info.

    Ending a non-root span afterwards must not trigger another export.
    """
    root_otel = create_mock_otel_span(
        name="foo",
        trace_id=_OTEL_TRACE_ID,
        span_id=1,
        parent_id=None,
        start_time=5_000_000,
        end_time=9_000_000,
    )

    trace_id = generate_trace_id_v3(root_otel)
    info = create_test_trace_info(trace_id, 0)
    InMemoryTraceManager.get_instance().register_trace(_OTEL_TRACE_ID, info)

    live = LiveSpan(root_otel, trace_id)
    live.set_status("OK")
    # NOTE(review): payloads are presumably large on purpose (e.g. to exercise
    # size handling); nothing below asserts on them.
    live.set_inputs({"input1": "very long input" * 100})
    live.set_outputs({"output": "very long output" * 100})

    exporter = mock.MagicMock()
    span_processor = InferenceTableSpanProcessor(span_exporter=exporter)

    span_processor.on_end(root_otel)

    exporter.export.assert_called_once_with((root_otel,))
    # The trace info picks up the span status and its 9 ms - 5 ms duration.
    assert info.state == TraceState.OK
    assert info.execution_duration == 4

    # A non-root span ending must not be exported.
    exporter.reset_mock()
    child = create_mock_otel_span(trace_id=_OTEL_TRACE_ID, span_id=2, parent_id=1)
    span_processor.on_end(child)
    exporter.export.assert_not_called()
150  
151  
def test_on_end_preserves_user_set_trace_state():
    """A trace state the user set explicitly survives ``on_end``.

    The trace is marked ERROR before its root span ends with status OK. The
    processor must keep the user's ERROR but still record the duration.
    """
    root_otel = create_mock_otel_span(
        name="foo",
        trace_id=_OTEL_TRACE_ID,
        span_id=1,
        parent_id=None,
        start_time=5_000_000,
        end_time=9_000_000,
    )

    trace_id = generate_trace_id_v3(root_otel)
    info = create_test_trace_info(trace_id, 0)
    manager = InMemoryTraceManager.get_instance()
    manager.register_trace(_OTEL_TRACE_ID, info)

    # Simulate the user explicitly marking the trace as failed.
    with manager.get_trace(trace_id) as trace:
        trace.info.state = TraceState.ERROR

    live = LiveSpan(root_otel, trace_id)
    live.set_status("OK")  # the span itself succeeded
    live.set_inputs({"input1": "test"})
    live.set_outputs({"output": "test"})

    InferenceTableSpanProcessor(span_exporter=mock.MagicMock()).on_end(root_otel)

    # The user's ERROR wins over the span's OK status.
    with manager.get_trace(trace_id) as trace:
        assert trace.info.state == TraceState.ERROR
    # The duration is still filled in from the root span (9 ms - 5 ms).
    assert info.execution_duration == 4
185  
186  
def test_on_end_updates_trace_state_when_in_progress():
    """An IN_PROGRESS trace adopts the root span's final status on ``on_end``.

    Nothing sets the trace state explicitly here, so the root span's ERROR
    status is propagated to the trace, and the duration is recorded.
    """
    root_otel = create_mock_otel_span(
        name="foo",
        trace_id=_OTEL_TRACE_ID,
        span_id=1,
        parent_id=None,
        start_time=5_000_000,
        end_time=9_000_000,
    )

    trace_id = generate_trace_id_v3(root_otel)
    info = create_test_trace_info(trace_id, 0, state=TraceState.IN_PROGRESS)
    manager = InMemoryTraceManager.get_instance()
    manager.register_trace(_OTEL_TRACE_ID, info)

    # Precondition: the state is still the IN_PROGRESS it was registered with.
    with manager.get_trace(trace_id) as trace:
        assert trace.info.state == TraceState.IN_PROGRESS

    live = LiveSpan(root_otel, trace_id)
    live.set_status("ERROR")  # the span itself failed
    live.set_inputs({"input1": "test"})
    live.set_outputs({"output": "test"})

    span_processor = InferenceTableSpanProcessor(span_exporter=mock.MagicMock())
    span_processor.on_end(root_otel)

    # The span's ERROR status is propagated to the trace.
    with manager.get_trace(trace_id) as trace:
        assert trace.info.state == TraceState.ERROR
    # The root span ran from 5 ms to 9 ms.
    assert info.execution_duration == 4