# tests/tracing/helper.py
  1  import os
  2  import time
  3  import uuid
  4  from concurrent.futures import ThreadPoolExecutor
  5  from dataclasses import dataclass, field
  6  from typing import Any
  7  from unittest import mock
  8  
  9  import opentelemetry.trace as trace_api
 10  import pytest
 11  from opentelemetry.sdk.trace import Event, ReadableSpan
 12  from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
 13  
 14  import mlflow
 15  from mlflow.entities import Trace, TraceData, TraceInfo
 16  from mlflow.entities.trace_location import TraceLocation
 17  from mlflow.entities.trace_state import TraceState
 18  from mlflow.ml_package_versions import FLAVOR_TO_MODULE_NAME
 19  from mlflow.tracing.client import TracingClient
 20  from mlflow.tracing.constant import TRACE_SCHEMA_VERSION, TRACE_SCHEMA_VERSION_KEY
 21  from mlflow.tracing.export.inference_table import pop_trace
 22  from mlflow.tracing.processor.mlflow_v3 import MlflowV3SpanProcessor
 23  from mlflow.tracing.processor.otel import OtelSpanProcessor
 24  from mlflow.tracing.provider import _get_tracer
 25  from mlflow.tracking.fluent import _get_experiment_id
 26  from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS, get_autolog_function
 27  from mlflow.utils.autologging_utils.safety import revert_patches
 28  from mlflow.version import IS_TRACING_SDK_ONLY
 29  
 30  
 31  def create_mock_otel_span(
 32      trace_id: int,
 33      span_id: int,
 34      name: str = "test_span",
 35      parent_id: int | None = None,
 36      start_time: int | None = None,
 37      end_time: int | None = None,
 38  ):
 39      """
 40      Create a mock OpenTelemetry span for testing purposes.
 41  
 42      OpenTelemetry doesn't allow creating a span outside of a tracer. So here we create a mock span
 43      that extends ReadableSpan (data object) and exposes the necessary attributes for testing.
 44      """
 45  
 46      @dataclass
 47      class _MockSpanContext:
 48          trace_id: str
 49          span_id: str
 50          trace_flags: trace_api.TraceFlags = trace_api.TraceFlags(1)
 51          trace_state: trace_api.TraceState = field(default_factory=trace_api.TraceState)
 52  
 53      class _MockOTelSpan(trace_api.Span, ReadableSpan):
 54          def __init__(
 55              self,
 56              name,
 57              context,
 58              parent,
 59              start_time=None,
 60              end_time=None,
 61              status=trace_api.Status(trace_api.StatusCode.UNSET),
 62          ):
 63              self._name = name
 64              self._parent = parent
 65              self._context = context
 66              self._start_time = start_time if start_time is not None else int(time.time() * 1e9)
 67              self._end_time = end_time
 68              self._status = status
 69              self._attributes = {}
 70              self._events = []
 71  
 72          # NB: The following methods are defined as abstract method in the Span class.
 73          def set_attributes(self, attributes):
 74              self._attributes.update(attributes)
 75  
 76          def set_attribute(self, key, value):
 77              self._attributes[key] = value
 78  
 79          def set_status(self, status):
 80              self._status = status
 81  
 82          def add_event(self, name, attributes=None, timestamp=None):
 83              self._events.append(Event(name, attributes, timestamp))
 84  
 85          def get_span_context(self):
 86              return self._context
 87  
 88          def is_recording(self):
 89              return self._end_time is None
 90  
 91          def update_name(self, name):
 92              self.name = name
 93  
 94          def end(self, end_time_ns=None):
 95              pass
 96  
 97          def record_exception():
 98              pass
 99  
100      return _MockOTelSpan(
101          name=name,
102          context=_MockSpanContext(trace_id, span_id),
103          parent=_MockSpanContext(trace_id, parent_id) if parent_id else None,
104          start_time=start_time,
105          end_time=end_time,
106      )
107  
108  
def create_trace(request_id) -> Trace:
    """Construct an empty Trace whose info carries the given request id."""
    info = create_test_trace_info(request_id)
    return Trace(info=info, data=TraceData())
111  
112  
def create_test_trace_info(
    trace_id,
    experiment_id="test",
    request_time=0,
    execution_duration=1,
    state=TraceState.OK,
    trace_metadata=None,
    tags=None,
):
    """Build a TraceInfo with sensible test defaults.

    Mirrors real trace creation by injecting the schema version key into the
    metadata (on a copy, so the caller's dict is never mutated) when absent.
    """
    metadata = trace_metadata or {}
    if TRACE_SCHEMA_VERSION_KEY not in metadata:
        metadata = {**metadata, TRACE_SCHEMA_VERSION_KEY: str(TRACE_SCHEMA_VERSION)}

    return TraceInfo(
        trace_id=trace_id,
        trace_location=TraceLocation.from_experiment_id(experiment_id),
        request_time=request_time,
        execution_duration=execution_duration,
        state=state,
        trace_metadata=metadata,
        tags=tags or {},
    )
137  
138  
def create_test_trace_info_with_uc_table(
    trace_id: str, catalog_name: str, schema_name: str
) -> TraceInfo:
    """Build a TraceInfo located in a Databricks UC schema, with fixed timings."""
    uc_location = TraceLocation.from_databricks_uc_schema(catalog_name, schema_name)
    return TraceInfo(
        trace_id=trace_id,
        trace_location=uc_location,
        request_time=0,
        execution_duration=1,
        state=TraceState.OK,
        trace_metadata={TRACE_SCHEMA_VERSION_KEY: str(TRACE_SCHEMA_VERSION)},
        tags={},
    )
151  
152  
def get_traces(experiment_id=None) -> list[Trace]:
    """Return every trace stored for the experiment (defaults to the active one)."""
    # Flush pending async trace writes first so tests observe complete results.
    mlflow.flush_trace_async_logging()
    location = experiment_id or _get_experiment_id()
    return TracingClient().search_traces(locations=[location])
160  
161  
def purge_traces(experiment_id=None):
    """Delete all traces for the experiment; no-op when there are none."""
    if not get_traces(experiment_id):
        return

    # Everything up to "now" gets deleted from the backend.
    TracingClient().delete_traces(
        experiment_id=experiment_id or _get_experiment_id(),
        max_traces=1000,
        max_timestamp_millis=int(time.time() * 1000),
    )
172  
173  
def get_tracer_tracking_uri() -> str | None:
    """Get current tracking URI configured as the trace export destination."""
    from opentelemetry import trace

    active_tracer = _get_tracer(__name__)
    # A ProxyTracer lazily wraps the concrete tracer; unwrap to reach its processors.
    if isinstance(active_tracer, trace.ProxyTracer):
        active_tracer = active_tracer._tracer
    processor = active_tracer.span_processor._span_processors[0]

    if isinstance(processor, MlflowV3SpanProcessor):
        return processor.span_exporter._client.tracking_uri
    return None
185  
186  
@pytest.fixture
def reset_autolog_state():
    """Reset autologging state to avoid interference between tests"""
    yield

    for flavor in FLAVOR_TO_MODULE_NAME:
        # 1. Remove post-import hooks (registered by global mlflow.autolog() function)
        mlflow.utils.import_hooks._post_import_hooks.pop(flavor, None)

    # BUG FIX: snapshot the keys before iterating. Calling autolog(disable=True)
    # records the new config back into AUTOLOGGING_INTEGRATIONS, and mutating the
    # dict while iterating over it raises "dictionary changed size during iteration".
    for flavor in list(AUTOLOGGING_INTEGRATIONS):
        # 2. Disable autologging for the flavor. This is necessary because some autologging
        #    update global settings (e.g. callbacks) and we need to revert them.
        try:
            if autolog := get_autolog_function(flavor):
                autolog(disable=True)
        except ImportError:
            # Flavor package isn't installed in this environment; nothing to disable.
            pass

        # 3. Revert any patches applied by autologging
        revert_patches(flavor)

    AUTOLOGGING_INTEGRATIONS.clear()
209  
210  
def score_in_model_serving(model_uri: str, model_input: dict[str, Any]):
    """
    A helper function to emulate model prediction inside a Databricks model serving environment.

    This is highly simplified version, but captures important aspects for testing tracing:
      1. Setting env vars that users set for enable tracing in model serving
      2. Load the model in a background thread
    """
    from mlflow.pyfunc.context import Context, set_prediction_context

    serving_env = os.environ | {
        "IS_IN_DB_MODEL_SERVING_ENV": "true",
        "ENABLE_MLFLOW_TRACING": "true",
    }
    with mock.patch.dict("os.environ", serving_env, clear=True):
        # Re-initialize tracing so it picks up the serving environment variables.
        mlflow.tracing.reset()

        # Model serving loads the model off the main thread, so emulate that here.
        with ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="test-tracing-helper"
        ) as pool:
            loaded_model = pool.submit(mlflow.pyfunc.load_model, model_uri).result()

        # Score the model under a prediction context keyed by a fresh request id.
        request_id = uuid.uuid4().hex
        with set_prediction_context(Context(request_id=request_id)):
            predictions = loaded_model.predict(model_input)

        return (request_id, predictions, pop_trace(request_id))
244  
245  
def skip_when_testing_trace_sdk(f):
    """Decorator that skips a test when only the mlflow-tracing package is installed
    (i.e. the full mlflow package is unavailable)."""
    msg = "Skipping test because it requires mlflow or mlflow-skinny to be installed."
    return pytest.mark.skipif(IS_TRACING_SDK_ONLY, reason=msg)(f)
252  
253  
def skip_module_when_testing_trace_sdk():
    """Skip the entire module if only mlflow-tracing package is installed"""
    if not IS_TRACING_SDK_ONLY:
        return
    pytest.skip(
        "Skipping test because it requires mlflow or mlflow-skinny to be installed.",
        allow_module_level=True,
    )
261  
262  
@pytest.fixture
def capture_otel_export():
    """Capture traces in memory for testing otel export."""
    from mlflow.tracing.provider import provider

    in_memory_exporter = InMemorySpanExporter()
    # Ensure the tracer provider is initialized before we attach our processor.
    provider.get_or_init_tracer("test")
    tracer_provider = provider.get()
    span_processor = OtelSpanProcessor(span_exporter=in_memory_exporter, export_metrics=False)
    # Skip backend trace registration; this fixture only inspects exported spans.
    span_processor._should_register_traces = False
    tracer_provider.add_span_processor(span_processor)

    yield in_memory_exporter, span_processor

    # Drain any buffered spans, then tear the processor down.
    span_processor.force_flush(timeout_millis=5000)
    span_processor.shutdown()
277  
278  
# A trace serialized in the legacy v2 schema (note "mlflow.trace_schema.version": "2"
# in request_metadata), used as a fixture for testing deserialization/migration of
# old traces. It holds a root "predict" span and one child "LLM"-typed span; span
# attribute values are JSON-encoded strings, matching the v2 on-disk format.
V2_TRACE_DICT = {
    "info": {
        "request_id": "58f4e27101304034b15c512b603bf1b2",
        "experiment_id": "0",
        "timestamp_ms": 100,
        "execution_time_ms": 200,
        "status": "OK",
        "request_metadata": {
            "mlflow.trace_schema.version": "2",
            "mlflow.traceInputs": '{"x": 2, "y": 5}',
            "mlflow.traceOutputs": "8",
        },
        "tags": {
            "mlflow.source.name": "test",
            "mlflow.source.type": "LOCAL",
            "mlflow.traceName": "predict",
            "mlflow.artifactLocation": "/path/to/artifact",
        },
        "assessments": [],
    },
    "data": {
        "spans": [
            {
                "name": "predict",
                "context": {
                    "span_id": "0d48a6670588966b",
                    "trace_id": "63076d0c1b90f1df0970f897dc428bd6",
                },
                "parent_id": None,
                "start_time": 100,
                "end_time": 200,
                "status_code": "OK",
                "status_message": "",
                "attributes": {
                    "mlflow.traceRequestId": '"58f4e27101304034b15c512b603bf1b2"',
                    "mlflow.spanType": '"UNKNOWN"',
                    "mlflow.spanFunctionName": '"predict"',
                    "mlflow.spanInputs": '{"x": 2, "y": 5}',
                    "mlflow.spanOutputs": "8",
                },
                "events": [],
            },
            {
                "name": "add_one_with_custom_name",
                "context": {
                    "span_id": "6fc32f36ef591f60",
                    "trace_id": "63076d0c1b90f1df0970f897dc428bd6",
                },
                "parent_id": "0d48a6670588966b",
                "start_time": 300,
                "end_time": 400,
                "status_code": "OK",
                "status_message": "",
                "attributes": {
                    "mlflow.traceRequestId": '"58f4e27101304034b15c512b603bf1b2"',
                    "mlflow.spanType": '"LLM"',
                    "delta": "1",
                    "metadata": '{"foo": "bar"}',
                    "datetime": '"2025-04-29 08:37:06.772253"',
                    "mlflow.spanFunctionName": '"add_one"',
                    "mlflow.spanInputs": '{"z": 7}',
                    "mlflow.spanOutputs": "8",
                },
                "events": [],
            },
        ],
        "request": '{"x": 2, "y": 5}',
        "response": "8",
    },
}