# test_mlflow_v3_exporter.py
import json
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from unittest import mock

import pytest
from google.protobuf.json_format import ParseDict

import mlflow
from mlflow.entities import LiveSpan
from mlflow.entities.model_registry import PromptVersion
from mlflow.entities.span_event import SpanEvent
from mlflow.entities.trace import Trace
from mlflow.entities.trace_info import TraceInfo
from mlflow.entities.trace_location import MlflowExperimentLocation
from mlflow.protos import service_pb2 as pb
from mlflow.tracing.constant import SpansLocation, TraceMetadataKey, TraceSizeStatsKey, TraceTagKey
from mlflow.tracing.export.mlflow_v3 import MlflowV3SpanExporter
from mlflow.tracing.provider import _get_trace_exporter
from mlflow.tracing.trace_manager import InMemoryTraceManager
from mlflow.tracing.utils import generate_trace_id_v3

from tests.tracing.helper import create_mock_otel_span, create_test_trace_info

_EXPERIMENT_ID = "dummy-experiment-id"


def join_thread_by_name_prefix(prefix: str, timeout: float = 5.0):
    """Join thread by name prefix to avoid time.sleep in tests."""
    for thread in threading.enumerate():
        if thread != threading.main_thread() and thread.name.startswith(prefix):
            thread.join(timeout=timeout)


@mlflow.trace
def _predict(x: str) -> str:
    # Traced entry point used by the tests below: produces a root span (via the
    # decorator) plus one child span, a span event, and a trace-level tag.
    with mlflow.start_span(name="child") as child_span:
        child_span.set_inputs("dummy")
        child_span.add_event(SpanEvent(name="child_event", attributes={"attr1": "val1"}))
    mlflow.update_current_trace(tags={"foo": "bar"})
    return x + "!"


def _flush_async_logging():
    # Drain the exporter's async queue so assertions can run deterministically.
    exporter = _get_trace_exporter()
    assert hasattr(exporter, "_async_queue"), "Async queue is not initialized"
    exporter._async_queue.flush(terminate=True)


# Set a test timeout of 20 seconds to catch excessive delays due to request retry loops,
# e.g. when checking the MLflow server version
@pytest.mark.timeout(20)
@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
def test_export(is_async, monkeypatch):
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", str(is_async))
    # Disable batch span processor — this test verifies exporter-level async logging
    monkeypatch.setenv("MLFLOW_USE_BATCH_SPAN_PROCESSOR", "false")

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=_EXPERIMENT_ID))

    trace_info = None

    def mock_response(credentials, path, method, trace_json, *args, **kwargs):
        # Capture the TraceInfo sent to the backend and echo back a valid response.
        nonlocal trace_info
        trace_dict = json.loads(trace_json)
        trace_proto = ParseDict(trace_dict["trace"], pb.Trace())
        trace_info_proto = ParseDict(trace_dict["trace"]["trace_info"], pb.TraceInfoV3())
        trace_info = TraceInfo.from_proto(trace_info_proto)
        return pb.StartTraceV3.Response(trace=trace_proto)

    with (
        mock.patch(
            "mlflow.store.tracking.rest_store.call_endpoint", side_effect=mock_response
        ) as mock_call_endpoint,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
        mock.patch("mlflow.tracing.client.TracingClient._upload_attachments", return_value=None),
    ):
        _predict("hello")

        if is_async:
            _flush_async_logging()

        # Verify client methods were called correctly
        mock_call_endpoint.assert_called_once()
        mock_upload_trace_data.assert_called_once()

        # Access the trace that was passed to _start_trace
        endpoint = mock_call_endpoint.call_args.args[1]
        assert endpoint == "/api/3.0/mlflow/traces"
        trace_data = mock_upload_trace_data.call_args.args[1]

        # Basic validation of the trace object
        assert trace_info.trace_id is not None

        # Validate the size stats metadata
        # Using pop() to exclude the size of these fields when computing the expected size
        size_stats = json.loads(trace_info.trace_metadata.pop(TraceMetadataKey.SIZE_STATS))
        size_bytes = int(trace_info.trace_metadata.pop(TraceMetadataKey.SIZE_BYTES))

        # The total size of the trace should match the size of the trace object
        expected_size_bytes = len(
            Trace(info=trace_info, data=trace_data).to_json().encode("utf-8")
        )

        assert size_bytes == expected_size_bytes
        assert size_stats[TraceSizeStatsKey.TOTAL_SIZE_BYTES] == expected_size_bytes
        assert size_stats[TraceSizeStatsKey.NUM_SPANS] == 2
        assert size_stats[TraceSizeStatsKey.MAX_SPAN_SIZE_BYTES] > 0

        # Verify percentile stats are included
        assert TraceSizeStatsKey.P25_SPAN_SIZE_BYTES in size_stats
        assert TraceSizeStatsKey.P50_SPAN_SIZE_BYTES in size_stats
        assert TraceSizeStatsKey.P75_SPAN_SIZE_BYTES in size_stats

        # Verify percentiles are valid integers
        assert isinstance(size_stats[TraceSizeStatsKey.P25_SPAN_SIZE_BYTES], int)
        assert isinstance(size_stats[TraceSizeStatsKey.P50_SPAN_SIZE_BYTES], int)
        assert isinstance(size_stats[TraceSizeStatsKey.P75_SPAN_SIZE_BYTES], int)

        # Verify percentile ordering: P25 <= P50 <= P75 <= max
        assert (
            size_stats[TraceSizeStatsKey.P25_SPAN_SIZE_BYTES]
            <= size_stats[TraceSizeStatsKey.P50_SPAN_SIZE_BYTES]
        )
        assert (
            size_stats[TraceSizeStatsKey.P50_SPAN_SIZE_BYTES]
            <= size_stats[TraceSizeStatsKey.P75_SPAN_SIZE_BYTES]
        )
        assert (
            size_stats[TraceSizeStatsKey.P75_SPAN_SIZE_BYTES]
            <= size_stats[TraceSizeStatsKey.MAX_SPAN_SIZE_BYTES]
        )

        # Validate the data was passed to upload_trace_data
        call_args = mock_upload_trace_data.call_args
        assert isinstance(call_args.args[0], TraceInfo)
        assert call_args.args[0].trace_id == trace_info.trace_id

        # We don't need to validate the exact JSON structure anymore since
        # we're testing the client methods directly, not the HTTP request

        # Last active trace ID should be set
        assert mlflow.get_last_active_trace_id() is not None


@pytest.mark.timeout(20)
def test_export_with_batch_span_processor(monkeypatch):
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "true")
    monkeypatch.setenv("MLFLOW_USE_BATCH_SPAN_PROCESSOR", "true")

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=_EXPERIMENT_ID))

    trace_info = None

    def mock_response(credentials, path, method, trace_json, *args, **kwargs):
        nonlocal trace_info
        trace_dict = json.loads(trace_json)
        trace_proto = ParseDict(trace_dict["trace"], pb.Trace())
        trace_info_proto = ParseDict(trace_dict["trace"]["trace_info"], pb.TraceInfoV3())
        trace_info = TraceInfo.from_proto(trace_info_proto)
        return pb.StartTraceV3.Response(trace=trace_proto)

    with (
        mock.patch(
            "mlflow.store.tracking.rest_store.call_endpoint", side_effect=mock_response
        ) as mock_call_endpoint,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
        mock.patch("mlflow.tracing.client.TracingClient._upload_attachments", return_value=None),
    ):
        _predict("hello")

        # Flush the batch processor and async queue to ensure spans are exported
        mlflow.flush_trace_async_logging(terminate=True)

        # Verify the trace was exported through the batch processor pipeline
        mock_call_endpoint.assert_called_once()
        mock_upload_trace_data.assert_called_once()

        assert trace_info is not None
        assert trace_info.trace_id is not None
        assert mlflow.get_last_active_trace_id() is not None


def test_async_logging_disabled_in_databricks_notebook(monkeypatch):
    with mock.patch(
        "mlflow.tracing.export.mlflow_v3.is_in_databricks_notebook", return_value=True
    ):
        monkeypatch.delenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", raising=False)
        exporter = MlflowV3SpanExporter()
        assert not exporter._is_async_enabled

        # If the env var is set explicitly, we should respect that
        monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "True")
        exporter = MlflowV3SpanExporter()
        assert exporter._is_async_enabled


@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
def test_export_catch_failure(is_async, monkeypatch):
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", str(is_async))
    # Disable batch span processor — this test verifies exporter-level async logging
    monkeypatch.setenv("MLFLOW_USE_BATCH_SPAN_PROCESSOR", "false")

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=_EXPERIMENT_ID))

    response = mock.MagicMock()
    response.status_code = 500
    response.text = "Failed to export trace"

    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace",
            side_effect=Exception("Failed to start trace"),
        ),
        mock.patch("mlflow.tracing.export.mlflow_v3._logger") as mock_logger,
    ):
        _predict("hello")

        if is_async:
            _flush_async_logging()

        mock_logger.warning.assert_called()
        warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list]
        assert any("Failed to start trace" in msg for msg in warning_calls)


def test_export_catch_failure_with_batch_span_processor(monkeypatch):
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "true")
    monkeypatch.setenv("MLFLOW_USE_BATCH_SPAN_PROCESSOR", "true")

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=_EXPERIMENT_ID))

    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace",
            side_effect=Exception("Failed to start trace"),
        ),
        mock.patch("mlflow.tracing.export.mlflow_v3._logger") as mock_logger,
    ):
        _predict("hello")

        # Flush batch processor to ensure the export (and failure) is processed
        mlflow.flush_trace_async_logging(terminate=True)

        # Verify the failure was logged, not raised
        mock_logger.warning.assert_called()
        warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list]
        assert any("Failed to start trace" in msg for msg in warning_calls)


@pytest.mark.skipif(os.name == "nt", reason="Flaky on Windows")
def test_async_bulk_export(monkeypatch):
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "True")
    monkeypatch.setenv("MLFLOW_ASYNC_TRACE_LOGGING_MAX_QUEUE_SIZE", "1000")
    # Disable batch span processor — this test verifies exporter-level async logging
    monkeypatch.setenv("MLFLOW_USE_BATCH_SPAN_PROCESSOR", "false")

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=0))

    # Create a mock function that simulates delay
    def _mock_client_method(*args, **kwargs):
        # Simulate a slow response
        time.sleep(0.1)
        mock_trace = mock.MagicMock()
        mock_trace.info = mock.MagicMock()
        return mock_trace

    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace", side_effect=_mock_client_method
        ) as mock_start_trace,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
    ):
        # Log many traces
        start_time = time.time()
        with ThreadPoolExecutor(
            max_workers=10, thread_name_prefix="test-mlflow-v3-exporter"
        ) as executor:
            for _ in range(100):
                executor.submit(_predict, "hello")

        # Trace logging should not block the main thread
        assert time.time() - start_time < 5

        _flush_async_logging()

        # Verify the client methods were called the expected number of times
        assert mock_start_trace.call_count == 100
        assert mock_upload_trace_data.call_count == 100


@pytest.mark.skipif(os.name == "nt", reason="Flaky on Windows")
def test_async_bulk_export_with_batch_span_processor(monkeypatch):
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "True")
    monkeypatch.setenv("MLFLOW_USE_BATCH_SPAN_PROCESSOR", "true")

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=0))

    def _mock_client_method(*args, **kwargs):
        time.sleep(0.1)
        mock_trace = mock.MagicMock()
        mock_trace.info = mock.MagicMock()
        return mock_trace

    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace", side_effect=_mock_client_method
        ) as mock_start_trace,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
    ):
        # Log many traces concurrently
        start_time = time.time()
        with ThreadPoolExecutor(
            max_workers=10, thread_name_prefix="test-mlflow-v3-exporter-batch"
        ) as executor:
            for _ in range(100):
                executor.submit(_predict, "hello")

        # Trace logging should not block the main thread
        assert time.time() - start_time < 5

        # Flush batch processor and async queue
        mlflow.flush_trace_async_logging(terminate=True)

        # Verify all traces were exported
        assert mock_start_trace.call_count == 100
        assert mock_upload_trace_data.call_count == 100


@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
def test_prompt_linking_in_mlflow_v3_exporter(is_async, monkeypatch):
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", str(is_async))

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=_EXPERIMENT_ID))

    # Capture prompt linking calls
    captured_prompts = None
    captured_trace_id = None

    def mock_link_prompt_versions_to_trace(trace_id, prompts):
        nonlocal captured_prompts, captured_trace_id
        captured_prompts = prompts
        captured_trace_id = trace_id

    # Mock the prompt linking method and other client methods
    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace",
        ) as mock_start_trace,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
        mock.patch(
            "mlflow.tracing.client.TracingClient.link_prompt_versions_to_trace",
            side_effect=mock_link_prompt_versions_to_trace,
        ) as mock_link_prompts,
    ):
        # Create test prompt versions
        prompt1 = PromptVersion(
            name="test_prompt_1",
            version=1,
            template="Hello, {{name}}!",
            commit_message="Test prompt 1",
            creation_timestamp=123456789,
        )
        prompt2 = PromptVersion(
            name="test_prompt_2",
            version=2,
            template="Goodbye, {{name}}!",
            commit_message="Test prompt 2",
            creation_timestamp=123456790,
        )

        # Create a mock OTEL span and trace
        otel_span = create_mock_otel_span(
            name="root",
            trace_id=12345,
            span_id=1,
            parent_id=None,
        )
        trace_id = generate_trace_id_v3(otel_span)
        span = LiveSpan(otel_span, trace_id)

        # Register the trace and spans
        trace_manager = InMemoryTraceManager.get_instance()
        trace_info = create_test_trace_info(trace_id, _EXPERIMENT_ID)
        trace_manager.register_trace(otel_span.context.trace_id, trace_info)
        trace_manager.register_span(span)

        # Register prompts to the trace
        trace_manager.register_prompt(trace_id, prompt1)
        trace_manager.register_prompt(trace_id, prompt2)

        # Create and use the exporter
        exporter = MlflowV3SpanExporter()
        exporter.export([otel_span])

        if is_async:
            # For async tests, we need to flush the specific exporter's queue
            exporter._async_queue.flush(terminate=True)

        # Wait for any prompt linking threads to complete
        join_thread_by_name_prefix("link_prompts_from_exporter")

        # Verify that trace info contains the linked prompts tags
        tag_value = trace_info.tags.get(TraceTagKey.LINKED_PROMPTS)
        assert tag_value is not None
        tag_value = json.loads(tag_value)
        assert len(tag_value) == 2
        assert tag_value[0]["name"] == "test_prompt_1"
        assert tag_value[0]["version"] == "1"
        assert tag_value[1]["name"] == "test_prompt_2"
        assert tag_value[1]["version"] == "2"

        # Verify that prompt linking was called
        mock_link_prompts.assert_called_once()
        assert captured_prompts is not None, "Prompts were not passed to link method"
        assert len(captured_prompts) == 2, f"Expected 2 prompts, got {len(captured_prompts)}"

        # Verify prompt details
        prompt_names = {p.name for p in captured_prompts}
        assert prompt_names == {"test_prompt_1", "test_prompt_2"}

        # Verify the trace ID matches
        assert captured_trace_id == trace_id

        # Verify other client methods were also called
        mock_start_trace.assert_called_once()
        mock_upload_trace_data.assert_called_once()


@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
def test_prompt_linking_with_empty_prompts_mlflow_v3(is_async, monkeypatch):
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", str(is_async))

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=_EXPERIMENT_ID))

    # Capture prompt linking calls
    captured_prompts = None
    captured_trace_id = None

    def mock_link_prompt_versions_to_trace(trace_id, prompts):
        nonlocal captured_prompts, captured_trace_id
        captured_prompts = prompts
        captured_trace_id = trace_id

    # Mock the client methods
    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace",
            return_value=mock.MagicMock(trace_id="test-trace-id"),
        ) as mock_start_trace,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
        mock.patch(
            "mlflow.tracing.client.TracingClient.link_prompt_versions_to_trace",
            side_effect=mock_link_prompt_versions_to_trace,
        ) as mock_link_prompts,
    ):
        # Create a mock OTEL span and trace (no prompts added)
        otel_span = create_mock_otel_span(
            name="root",
            trace_id=12345,
            span_id=1,
            parent_id=None,
        )
        trace_id = generate_trace_id_v3(otel_span)
        span = LiveSpan(otel_span, trace_id)

        # Register the trace and spans (but no prompts)
        trace_manager = InMemoryTraceManager.get_instance()
        trace_info = create_test_trace_info(trace_id, _EXPERIMENT_ID)
        trace_manager.register_trace(otel_span.context.trace_id, trace_info)
        trace_manager.register_span(span)

        # Create and use the exporter
        exporter = MlflowV3SpanExporter()
        exporter.export([otel_span])

        if is_async:
            # For async tests, we need to flush the specific exporter's queue
            exporter._async_queue.flush(terminate=True)

        # Wait for any prompt linking threads to complete
        join_thread_by_name_prefix("link_prompts_from_exporter")

        # Verify that prompt linking was NOT called for empty prompts (this is correct behavior)
        mock_link_prompts.assert_not_called()
        # Since no prompts were passed, no thread was started and no call was made
        assert captured_trace_id is None  # No linking occurred, so trace_id was never captured

        # Verify other client methods were also called
        mock_start_trace.assert_called_once()
        mock_upload_trace_data.assert_called_once()


def test_prompt_linking_error_handling_mlflow_v3(monkeypatch):
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "False")  # Use sync for easier testing

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=_EXPERIMENT_ID))

    # Mock the client methods with prompt linking failing
    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace",
            return_value=mock.MagicMock(trace_id="test-trace-id"),
        ) as mock_start_trace,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
        mock.patch(
            "mlflow.tracing.client.TracingClient.link_prompt_versions_to_trace",
            side_effect=Exception("Prompt linking failed"),
        ) as mock_link_prompts,
        mock.patch("mlflow.tracing.export.utils._logger") as mock_logger,
    ):
        # Create a mock OTEL span and trace with a prompt
        otel_span = create_mock_otel_span(
            name="root",
            trace_id=12345,
            span_id=1,
            parent_id=None,
        )
        trace_id = generate_trace_id_v3(otel_span)
        span = LiveSpan(otel_span, trace_id)

        # Create a test prompt
        prompt = PromptVersion(
            name="test_prompt",
            version=1,
            template="Hello, {{name}}!",
            commit_message="Test prompt",
            creation_timestamp=123456789,
        )

        # Register the trace, span, and prompt
        trace_manager = InMemoryTraceManager.get_instance()
        trace_info = create_test_trace_info(trace_id, _EXPERIMENT_ID)
        trace_manager.register_trace(otel_span.context.trace_id, trace_info)
        trace_manager.register_span(span)
        trace_manager.register_prompt(trace_id, prompt)

        # Create and use the exporter
        exporter = MlflowV3SpanExporter()
        exporter.export([otel_span])

        # Wait for any prompt linking threads to complete so the error can be caught
        join_thread_by_name_prefix("link_prompts_from_exporter")

        # Verify that prompt linking was attempted but failed
        mock_link_prompts.assert_called_once()

        # Verify other client methods were still called
        # (trace export should succeed despite prompt linking failure)
        mock_start_trace.assert_called_once()
        mock_upload_trace_data.assert_called_once()

        # Verify that the error was logged but didn't crash the export
        mock_logger.warning.assert_called()
        warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list]
        assert any("Prompt linking failed" in msg for msg in warning_calls)


def test_no_log_spans_to_artifacts_if_stored_in_tracking_store(monkeypatch):
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "false")
    # Create a mock OTEL span and trace
    otel_span = create_mock_otel_span(
        name="root",
        trace_id=12345,
        span_id=1,
        parent_id=None,
    )
    trace_id = generate_trace_id_v3(otel_span)
    span = LiveSpan(otel_span, trace_id)

    # Register the trace and spans
    trace_manager = InMemoryTraceManager.get_instance()
    trace_info = create_test_trace_info(trace_id, _EXPERIMENT_ID)
    trace_info.tags[TraceTagKey.SPANS_LOCATION] = SpansLocation.TRACKING_STORE.value
    trace_manager.register_trace(otel_span.context.trace_id, trace_info)
    trace_manager.register_span(span)

    mlflow.flush_trace_async_logging()

    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace",
            return_value=trace_info,
        ) as mock_start_trace,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
    ):
        exporter = MlflowV3SpanExporter()
        exporter.export([otel_span])
        # Spans already persisted in the tracking store: no artifact upload expected
        mock_upload_trace_data.assert_not_called()
        mock_start_trace.assert_called_once()


def test_batch_write_skipped_when_store_unsupported(monkeypatch):
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "false")
    otel_span = create_mock_otel_span(name="root", trace_id=66666, span_id=1, parent_id=None)
    trace_id = generate_trace_id_v3(otel_span)
    span = LiveSpan(otel_span, trace_id)

    trace_manager = InMemoryTraceManager.get_instance()
    trace_info = create_test_trace_info(trace_id, _EXPERIMENT_ID)
    trace_manager.register_trace(otel_span.context.trace_id, trace_info)
    trace_manager.register_span(span)

    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace",
            return_value=trace_info,
        ) as mock_start_trace,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
        mock.patch("mlflow.tracing.client.TracingClient.log_spans") as mock_log_spans,
    ):
        exporter = MlflowV3SpanExporter()
        exporter._store_supports_log_spans = False
        exporter.export([otel_span])

        mock_start_trace.assert_called_once()
        # log_spans should NOT be called when store doesn't support it
        mock_log_spans.assert_not_called()
        # Artifact upload should still happen as fallback
        mock_upload_trace_data.assert_called_once()