helper.py
import os
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any
from unittest import mock

import opentelemetry.trace as trace_api
import pytest
from opentelemetry.sdk.trace import Event, ReadableSpan
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

import mlflow
from mlflow.entities import Trace, TraceData, TraceInfo
from mlflow.entities.trace_location import TraceLocation
from mlflow.entities.trace_state import TraceState
from mlflow.ml_package_versions import FLAVOR_TO_MODULE_NAME
from mlflow.tracing.client import TracingClient
from mlflow.tracing.constant import TRACE_SCHEMA_VERSION, TRACE_SCHEMA_VERSION_KEY
from mlflow.tracing.export.inference_table import pop_trace
from mlflow.tracing.processor.mlflow_v3 import MlflowV3SpanProcessor
from mlflow.tracing.processor.otel import OtelSpanProcessor
from mlflow.tracing.provider import _get_tracer
from mlflow.tracking.fluent import _get_experiment_id
from mlflow.utils.autologging_utils import AUTOLOGGING_INTEGRATIONS, get_autolog_function
from mlflow.utils.autologging_utils.safety import revert_patches
from mlflow.version import IS_TRACING_SDK_ONLY


def create_mock_otel_span(
    trace_id: int,
    span_id: int,
    name: str = "test_span",
    parent_id: int | None = None,
    start_time: int | None = None,
    end_time: int | None = None,
):
    """
    Create a mock OpenTelemetry span for testing purposes.

    OpenTelemetry doesn't allow creating a span outside of a tracer. So here we create a mock span
    that extends ReadableSpan (data object) and exposes the necessary attributes for testing.

    Args:
        trace_id: Trace ID stored on the span context (and on the parent context, if any).
        span_id: Span ID stored on the span context.
        name: Span name.
        parent_id: Span ID of the parent. A falsy value (``None`` or ``0``) yields a span
            without a parent context.
        start_time: Start timestamp. When ``None``, the current time in nanoseconds is used.
        end_time: End timestamp. While it is ``None`` the span reports itself as recording.

    Returns:
        A mock span object that is both a ``trace_api.Span`` and a ``ReadableSpan``.
    """

    @dataclass
    class _MockSpanContext:
        # IDs are plain integers, matching the ``trace_id`` / ``span_id`` arguments above.
        trace_id: int
        span_id: int
        trace_flags: trace_api.TraceFlags = trace_api.TraceFlags(1)
        trace_state: trace_api.TraceState = field(default_factory=trace_api.TraceState)

    class _MockOTelSpan(trace_api.Span, ReadableSpan):
        def __init__(
            self,
            name,
            context,
            parent,
            start_time=None,
            end_time=None,
            status=None,
        ):
            # NB: The base class initializers are intentionally not invoked; only the private
            # fields assigned below are populated on the mock.
            self._name = name
            self._parent = parent
            self._context = context
            self._start_time = start_time if start_time is not None else int(time.time() * 1e9)
            self._end_time = end_time
            # Build a fresh UNSET status per span rather than sharing one default instance.
            self._status = (
                status if status is not None else trace_api.Status(trace_api.StatusCode.UNSET)
            )
            self._attributes = {}
            self._events = []

        # NB: The following methods are defined as abstract method in the Span class.
        def set_attributes(self, attributes):
            self._attributes.update(attributes)

        def set_attribute(self, key, value):
            self._attributes[key] = value

        def set_status(self, status):
            self._status = status

        def add_event(self, name, attributes=None, timestamp=None):
            self._events.append(Event(name, attributes, timestamp))

        def get_span_context(self):
            return self._context

        def is_recording(self):
            # ``end`` below is a no-op, so this only reflects the ``end_time`` given at creation.
            return self._end_time is None

        def update_name(self, name):
            # Write the backing field set in ``__init__``; ``ReadableSpan.name`` is a read-only
            # property, so assigning to ``self.name`` would raise AttributeError.
            self._name = name

        def end(self, end_time_ns=None):
            pass

        def record_exception(self, exception, attributes=None, timestamp=None, escaped=False):
            # No-op. The parameters mirror the abstract ``Span.record_exception`` so that the
            # method is callable on an instance (it previously lacked ``self``).
            pass

    return _MockOTelSpan(
        name=name,
        context=_MockSpanContext(trace_id, span_id),
        parent=_MockSpanContext(trace_id, parent_id) if parent_id else None,
        start_time=start_time,
        end_time=end_time,
    )


def create_trace(request_id) -> Trace:
    """Create a Trace with default test trace info for ``request_id`` and empty trace data."""
    return Trace(info=create_test_trace_info(request_id), data=TraceData())


def create_test_trace_info(
    trace_id,
    experiment_id="test",
    request_time=0,
    execution_duration=1,
    state=TraceState.OK,
    trace_metadata=None,
    tags=None,
):
    """
    Create a TraceInfo located in an experiment, for use in tests.

    Args:
        trace_id: ID of the trace.
        experiment_id: Experiment ID used to build the trace location.
        request_time: Value passed through as the trace request time.
        execution_duration: Value passed through as the trace execution duration.
        state: Trace state.
        trace_metadata: Optional metadata dict. It is never mutated; if it lacks the schema
            version key, a copy with that key added is used instead.
        tags: Optional tags dict. Defaults to an empty dict.

    Returns:
        The constructed TraceInfo.
    """
    # Add schema version to metadata if not provided, to match real trace creation behavior
    final_metadata = trace_metadata or {}
    if TRACE_SCHEMA_VERSION_KEY not in final_metadata:
        # Copy first so the caller's dict is left untouched.
        final_metadata = final_metadata.copy()
        final_metadata[TRACE_SCHEMA_VERSION_KEY] = str(TRACE_SCHEMA_VERSION)

    return TraceInfo(
        trace_id=trace_id,
        trace_location=TraceLocation.from_experiment_id(experiment_id),
        request_time=request_time,
        execution_duration=execution_duration,
        state=state,
        trace_metadata=final_metadata,
        tags=tags or {},
    )


def create_test_trace_info_with_uc_table(
    trace_id: str, catalog_name: str, schema_name: str
) -> TraceInfo:
    """
    Create a TraceInfo whose location is a Databricks UC schema, for use in tests.

    Args:
        trace_id: ID of the trace.
        catalog_name: Catalog name passed to ``TraceLocation.from_databricks_uc_schema``.
        schema_name: Schema name passed to ``TraceLocation.from_databricks_uc_schema``.

    Returns:
        A TraceInfo in the OK state with the schema version set in its metadata and no tags.
    """
    return TraceInfo(
        trace_id=trace_id,
        trace_location=TraceLocation.from_databricks_uc_schema(catalog_name, schema_name),
        request_time=0,
        execution_duration=1,
        state=TraceState.OK,
        trace_metadata={TRACE_SCHEMA_VERSION_KEY: str(TRACE_SCHEMA_VERSION)},
        tags={},
    )


def get_traces(experiment_id=None) -> list[Trace]:
    """
    Return the traces found in the given experiment.

    Args:
        experiment_id: Experiment to search. When falsy, the ID returned by
            ``_get_experiment_id()`` is used.
    """
    # Flush any pending async trace writes before querying so tests see complete results.
    mlflow.flush_trace_async_logging()
    # Get all traces from the backend
    return TracingClient().search_traces(
        locations=[experiment_id or _get_experiment_id()],
    )


def purge_traces(experiment_id=None):
    """
    Delete traces from the given experiment, if any exist.

    Args:
        experiment_id: Experiment to purge. When falsy, the ID returned by
            ``_get_experiment_id()`` is used.
    """
    # Nothing to delete; skip the delete call entirely.
    if not get_traces(experiment_id):
        return

    # Delete all traces from the backend
    TracingClient().delete_traces(
        experiment_id=experiment_id or _get_experiment_id(),
        max_traces=1000,
        max_timestamp_millis=int(time.time() * 1000),
    )


def get_tracer_tracking_uri() -> str | None:
    """
    Get current tracking URI configured as the trace export destination.

    Returns:
        The tracking URI of the client held by the first span processor's exporter when that
        processor is an ``MlflowV3SpanProcessor``; ``None`` otherwise.
    """
    from opentelemetry import trace

    tracer = _get_tracer(__name__)
    # Unwrap the proxy to reach the underlying tracer that holds the span processors.
    if isinstance(tracer, trace.ProxyTracer):
        tracer = tracer._tracer
    span_processor = tracer.span_processor._span_processors[0]

    if isinstance(span_processor, MlflowV3SpanProcessor):
        return span_processor.span_exporter._client.tracking_uri
    return None


@pytest.fixture
def reset_autolog_state():
    """Reset autologging state to avoid interference between tests"""
    yield

    for flavor in FLAVOR_TO_MODULE_NAME:
        # 1. Remove post-import hooks (registered by global mlflow.autolog() function)
        mlflow.utils.import_hooks._post_import_hooks.pop(flavor, None)

    for flavor in AUTOLOGGING_INTEGRATIONS:
        # 2. Disable autologging for the flavor. This is necessary because some autologging
        #    update global settings (e.g. callbacks) and we need to revert them.
        try:
            if autolog := get_autolog_function(flavor):
                autolog(disable=True)
        except ImportError:
            # Deliberately best-effort: an ImportError while disabling a flavor is ignored so
            # that cleanup continues for the remaining flavors.
            pass

        # 3. Revert any patches applied by autologging
        revert_patches(flavor)

    AUTOLOGGING_INTEGRATIONS.clear()


def score_in_model_serving(model_uri: str, model_input: dict[str, Any]):
    """
    A helper function to emulate model prediction inside a Databricks model serving environment.

    This is a highly simplified version, but captures important aspects for testing tracing:
      1. Setting the env vars that users set to enable tracing in model serving
      2. Loading the model in a background thread

    Args:
        model_uri: URI of the model, passed to ``mlflow.pyfunc.load_model``.
        model_input: Input passed to the loaded model's ``predict``.

    Returns:
        A tuple of ``(request_id, predictions, trace)``, where ``request_id`` is a freshly
        generated hex UUID and ``trace`` is the result of ``pop_trace(request_id)``.
    """
    from mlflow.pyfunc.context import Context, set_prediction_context

    with mock.patch.dict(
        "os.environ",
        os.environ | {"IS_IN_DB_MODEL_SERVING_ENV": "true", "ENABLE_MLFLOW_TRACING": "true"},
        clear=True,
    ):
        # Reset tracing setup to start fresh w/ model serving environment
        mlflow.tracing.reset()

        def _load_model():
            return mlflow.pyfunc.load_model(model_uri)

        with ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="test-tracing-helper"
        ) as executor:
            model = executor.submit(_load_model).result()

        # Score the model
        request_id = uuid.uuid4().hex
        with set_prediction_context(Context(request_id=request_id)):
            predictions = model.predict(model_input)

        trace = pop_trace(request_id)
        return (request_id, predictions, trace)


def skip_when_testing_trace_sdk(f):
    """
    Decorator to skip the test if only the mlflow-tracing package is installed and
    not the full mlflow package.
    """
    msg = "Skipping test because it requires mlflow or mlflow-skinny to be installed."
    skip_decorator = pytest.mark.skipif(IS_TRACING_SDK_ONLY, reason=msg)
    return skip_decorator(f)


def skip_module_when_testing_trace_sdk():
    """Skip the entire module if only mlflow-tracing package is installed"""
    if IS_TRACING_SDK_ONLY:
        pytest.skip(
            "Skipping test because it requires mlflow or mlflow-skinny to be installed.",
            allow_module_level=True,
        )


@pytest.fixture
def capture_otel_export():
    """
    Capture traces in memory for testing otel export.

    Yields:
        A tuple of ``(exporter, processor)``: the in-memory exporter and the
        ``OtelSpanProcessor`` that was added to the tracer provider. The processor is
        flushed and shut down after the test.
    """
    from mlflow.tracing.provider import provider

    exporter = InMemorySpanExporter()
    # Make sure a tracer (and therefore a tracer provider) exists before fetching it.
    provider.get_or_init_tracer("test")
    tp = provider.get()
    processor = OtelSpanProcessor(span_exporter=exporter, export_metrics=False)
    processor._should_register_traces = False
    tp.add_span_processor(processor)
    yield exporter, processor
    processor.force_flush(timeout_millis=5000)
    processor.shutdown()


# A sample trace serialized as a dict that declares trace schema version "2".
V2_TRACE_DICT = {
    "info": {
        "request_id": "58f4e27101304034b15c512b603bf1b2",
        "experiment_id": "0",
        "timestamp_ms": 100,
        "execution_time_ms": 200,
        "status": "OK",
        "request_metadata": {
            "mlflow.trace_schema.version": "2",
            "mlflow.traceInputs": '{"x": 2, "y": 5}',
            "mlflow.traceOutputs": "8",
        },
        "tags": {
            "mlflow.source.name": "test",
            "mlflow.source.type": "LOCAL",
            "mlflow.traceName": "predict",
            "mlflow.artifactLocation": "/path/to/artifact",
        },
        "assessments": [],
    },
    "data": {
        "spans": [
            {
                "name": "predict",
                "context": {
                    "span_id": "0d48a6670588966b",
                    "trace_id": "63076d0c1b90f1df0970f897dc428bd6",
                },
                "parent_id": None,
                "start_time": 100,
                "end_time": 200,
                "status_code": "OK",
                "status_message": "",
                "attributes": {
                    "mlflow.traceRequestId": '"58f4e27101304034b15c512b603bf1b2"',
                    "mlflow.spanType": '"UNKNOWN"',
                    "mlflow.spanFunctionName": '"predict"',
                    "mlflow.spanInputs": '{"x": 2, "y": 5}',
                    "mlflow.spanOutputs": "8",
                },
                "events": [],
            },
            {
                "name": "add_one_with_custom_name",
                "context": {
                    "span_id": "6fc32f36ef591f60",
                    "trace_id": "63076d0c1b90f1df0970f897dc428bd6",
                },
                "parent_id": "0d48a6670588966b",
                "start_time": 300,
                "end_time": 400,
                "status_code": "OK",
                "status_message": "",
                "attributes": {
                    "mlflow.traceRequestId": '"58f4e27101304034b15c512b603bf1b2"',
                    "mlflow.spanType": '"LLM"',
                    "delta": "1",
                    "metadata": '{"foo": "bar"}',
                    "datetime": '"2025-04-29 08:37:06.772253"',
                    "mlflow.spanFunctionName": '"add_one"',
                    "mlflow.spanInputs": '{"z": 7}',
                    "mlflow.spanOutputs": "8",
                },
                "events": [],
            },
        ],
        "request": '{"x": 2, "y": 5}',
        "response": "8",
    },
}