/ tests / tracing / export / test_mlflow_v3_exporter.py
test_mlflow_v3_exporter.py
  1  import json
  2  import os
  3  import threading
  4  import time
  5  from concurrent.futures import ThreadPoolExecutor
  6  from unittest import mock
  7  
  8  import pytest
  9  from google.protobuf.json_format import ParseDict
 10  
 11  import mlflow
 12  from mlflow.entities import LiveSpan
 13  from mlflow.entities.model_registry import PromptVersion
 14  from mlflow.entities.span_event import SpanEvent
 15  from mlflow.entities.trace import Trace
 16  from mlflow.entities.trace_info import TraceInfo
 17  from mlflow.entities.trace_location import MlflowExperimentLocation
 18  from mlflow.protos import service_pb2 as pb
 19  from mlflow.tracing.constant import SpansLocation, TraceMetadataKey, TraceSizeStatsKey, TraceTagKey
 20  from mlflow.tracing.export.mlflow_v3 import MlflowV3SpanExporter
 21  from mlflow.tracing.provider import _get_trace_exporter
 22  from mlflow.tracing.trace_manager import InMemoryTraceManager
 23  from mlflow.tracing.utils import generate_trace_id_v3
 24  
 25  from tests.tracing.helper import create_mock_otel_span, create_test_trace_info
 26  
# Placeholder experiment ID used as the trace destination in every test below.
_EXPERIMENT_ID = "dummy-experiment-id"
 28  
 29  
 30  def join_thread_by_name_prefix(prefix: str, timeout: float = 5.0):
 31      """Join thread by name prefix to avoid time.sleep in tests."""
 32      for thread in threading.enumerate():
 33          if thread != threading.main_thread() and thread.name.startswith(prefix):
 34              thread.join(timeout=timeout)
 35  
 36  
 37  @mlflow.trace
 38  def _predict(x: str) -> str:
 39      with mlflow.start_span(name="child") as child_span:
 40          child_span.set_inputs("dummy")
 41          child_span.add_event(SpanEvent(name="child_event", attributes={"attr1": "val1"}))
 42      mlflow.update_current_trace(tags={"foo": "bar"})
 43      return x + "!"
 44  
 45  
 46  def _flush_async_logging():
 47      exporter = _get_trace_exporter()
 48      assert hasattr(exporter, "_async_queue"), "Async queue is not initialized"
 49      exporter._async_queue.flush(terminate=True)
 50  
 51  
# Set a test timeout of 20 seconds to catch excessive delays due to request retry loops,
# e.g. when checking the MLflow server version
@pytest.mark.timeout(20)
@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
def test_export(is_async, monkeypatch):
    """End-to-end export: REST call, data upload, and size-stats metadata validation."""
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", str(is_async))
    # Disable batch span processor — this test verifies exporter-level async logging
    monkeypatch.setenv("MLFLOW_USE_BATCH_SPAN_PROCESSOR", "false")

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=_EXPERIMENT_ID))

    trace_info = None

    # Parse the JSON payload posted to the traces endpoint and capture the
    # resulting TraceInfo so the assertions below can inspect what was exported.
    def mock_response(credentials, path, method, trace_json, *args, **kwargs):
        nonlocal trace_info
        trace_dict = json.loads(trace_json)
        trace_proto = ParseDict(trace_dict["trace"], pb.Trace())
        trace_info_proto = ParseDict(trace_dict["trace"]["trace_info"], pb.TraceInfoV3())
        trace_info = TraceInfo.from_proto(trace_info_proto)
        return pb.StartTraceV3.Response(trace=trace_proto)

    with (
        mock.patch(
            "mlflow.store.tracking.rest_store.call_endpoint", side_effect=mock_response
        ) as mock_call_endpoint,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
        mock.patch("mlflow.tracing.client.TracingClient._upload_attachments", return_value=None),
    ):
        _predict("hello")

        if is_async:
            _flush_async_logging()

    # Verify client methods were called correctly
    mock_call_endpoint.assert_called_once()
    mock_upload_trace_data.assert_called_once()

    # Access the trace that was passed to _start_trace
    endpoint = mock_call_endpoint.call_args.args[1]
    assert endpoint == "/api/3.0/mlflow/traces"
    trace_data = mock_upload_trace_data.call_args.args[1]

    # Basic validation of the trace object
    assert trace_info.trace_id is not None

    # Validate the size stats metadata
    # Using pop() to exclude the size of these fields when computing the expected size
    size_stats = json.loads(trace_info.trace_metadata.pop(TraceMetadataKey.SIZE_STATS))
    size_bytes = int(trace_info.trace_metadata.pop(TraceMetadataKey.SIZE_BYTES))

    # The total size of the trace should match the size of the trace object
    expected_size_bytes = len(Trace(info=trace_info, data=trace_data).to_json().encode("utf-8"))

    assert size_bytes == expected_size_bytes
    assert size_stats[TraceSizeStatsKey.TOTAL_SIZE_BYTES] == expected_size_bytes
    # _predict produces exactly two spans: the traced root and the "child" span
    assert size_stats[TraceSizeStatsKey.NUM_SPANS] == 2
    assert size_stats[TraceSizeStatsKey.MAX_SPAN_SIZE_BYTES] > 0

    # Verify percentile stats are included
    assert TraceSizeStatsKey.P25_SPAN_SIZE_BYTES in size_stats
    assert TraceSizeStatsKey.P50_SPAN_SIZE_BYTES in size_stats
    assert TraceSizeStatsKey.P75_SPAN_SIZE_BYTES in size_stats

    # Verify percentiles are valid integers
    assert isinstance(size_stats[TraceSizeStatsKey.P25_SPAN_SIZE_BYTES], int)
    assert isinstance(size_stats[TraceSizeStatsKey.P50_SPAN_SIZE_BYTES], int)
    assert isinstance(size_stats[TraceSizeStatsKey.P75_SPAN_SIZE_BYTES], int)

    # Verify percentile ordering: P25 <= P50 <= P75 <= max
    assert (
        size_stats[TraceSizeStatsKey.P25_SPAN_SIZE_BYTES]
        <= size_stats[TraceSizeStatsKey.P50_SPAN_SIZE_BYTES]
    )
    assert (
        size_stats[TraceSizeStatsKey.P50_SPAN_SIZE_BYTES]
        <= size_stats[TraceSizeStatsKey.P75_SPAN_SIZE_BYTES]
    )
    assert (
        size_stats[TraceSizeStatsKey.P75_SPAN_SIZE_BYTES]
        <= size_stats[TraceSizeStatsKey.MAX_SPAN_SIZE_BYTES]
    )

    # Validate the data was passed to upload_trace_data
    call_args = mock_upload_trace_data.call_args
    assert isinstance(call_args.args[0], TraceInfo)
    assert call_args.args[0].trace_id == trace_info.trace_id

    # We don't need to validate the exact JSON structure anymore since
    # we're testing the client methods directly, not the HTTP request

    # Last active trace ID should be set
    assert mlflow.get_last_active_trace_id() is not None
149  
150  
@pytest.mark.timeout(20)
def test_export_with_batch_span_processor(monkeypatch):
    """Traces are exported end-to-end when the batch span processor is enabled."""
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "true")
    monkeypatch.setenv("MLFLOW_USE_BATCH_SPAN_PROCESSOR", "true")

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=_EXPERIMENT_ID))

    trace_info = None

    # Parse the JSON payload posted to the traces endpoint and capture the
    # resulting TraceInfo for the assertions below.
    def mock_response(credentials, path, method, trace_json, *args, **kwargs):
        nonlocal trace_info
        trace_dict = json.loads(trace_json)
        trace_proto = ParseDict(trace_dict["trace"], pb.Trace())
        trace_info_proto = ParseDict(trace_dict["trace"]["trace_info"], pb.TraceInfoV3())
        trace_info = TraceInfo.from_proto(trace_info_proto)
        return pb.StartTraceV3.Response(trace=trace_proto)

    with (
        mock.patch(
            "mlflow.store.tracking.rest_store.call_endpoint", side_effect=mock_response
        ) as mock_call_endpoint,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
        mock.patch("mlflow.tracing.client.TracingClient._upload_attachments", return_value=None),
    ):
        _predict("hello")

        # Flush the batch processor and async queue to ensure spans are exported
        mlflow.flush_trace_async_logging(terminate=True)

    # Verify the trace was exported through the batch processor pipeline
    mock_call_endpoint.assert_called_once()
    mock_upload_trace_data.assert_called_once()

    assert trace_info is not None
    assert trace_info.trace_id is not None
    assert mlflow.get_last_active_trace_id() is not None
192  
193  
def test_async_logging_disabled_in_databricks_notebook(monkeypatch):
    """Async logging defaults to off inside a Databricks notebook, unless the
    MLFLOW_ENABLE_ASYNC_TRACE_LOGGING env var explicitly enables it."""
    monkeypatch.delenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", raising=False)
    with mock.patch("mlflow.tracing.export.mlflow_v3.is_in_databricks_notebook", return_value=True):
        # Default behavior in a notebook: async disabled
        assert not MlflowV3SpanExporter()._is_async_enabled

        # If the env var is set explicitly, we should respect that
        monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "True")
        assert MlflowV3SpanExporter()._is_async_enabled
204  
205  
@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
def test_export_catch_failure(is_async, monkeypatch):
    """Export failures are swallowed and logged as warnings, never raised to callers.

    Fix: removed the unused ``response = mock.MagicMock()`` block (its
    ``status_code``/``text`` attributes were never wired into any patch or
    assertion — dead code left over from an earlier HTTP-level mock).
    """
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", str(is_async))
    # Disable batch span processor — this test verifies exporter-level async logging
    monkeypatch.setenv("MLFLOW_USE_BATCH_SPAN_PROCESSOR", "false")

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=_EXPERIMENT_ID))

    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace",
            side_effect=Exception("Failed to start trace"),
        ),
        mock.patch("mlflow.tracing.export.mlflow_v3._logger") as mock_logger,
    ):
        # The prediction itself must succeed even though trace export fails
        _predict("hello")

        if is_async:
            _flush_async_logging()

    # The failure should surface only as a logged warning
    mock_logger.warning.assert_called()
    warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list]
    assert any("Failed to start trace" in msg for msg in warning_calls)
236  
237  
def test_export_catch_failure_with_batch_span_processor(monkeypatch):
    """Export failures under the batch span processor are logged, not raised."""
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "true")
    monkeypatch.setenv("MLFLOW_USE_BATCH_SPAN_PROCESSOR", "true")

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=_EXPERIMENT_ID))

    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace",
            side_effect=Exception("Failed to start trace"),
        ),
        mock.patch("mlflow.tracing.export.mlflow_v3._logger") as mock_logger,
    ):
        # The prediction itself must succeed even though trace export fails
        _predict("hello")

        # Flush batch processor to ensure the export (and failure) is processed
        mlflow.flush_trace_async_logging(terminate=True)

    # Verify the failure was logged, not raised
    mock_logger.warning.assert_called()
    warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list]
    assert any("Failed to start trace" in msg for msg in warning_calls)
263  
264  
@pytest.mark.skipif(os.name == "nt", reason="Flaky on Windows")
def test_async_bulk_export(monkeypatch):
    """Async logging keeps the caller unblocked while 100 traces are exported."""
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "True")
    # Queue must be large enough to hold all 100 traces without dropping any
    monkeypatch.setenv("MLFLOW_ASYNC_TRACE_LOGGING_MAX_QUEUE_SIZE", "1000")
    # Disable batch span processor — this test verifies exporter-level async logging
    monkeypatch.setenv("MLFLOW_USE_BATCH_SPAN_PROCESSOR", "false")

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=0))

    # Create a mock function that simulates delay
    def _mock_client_method(*args, **kwargs):
        # Simulate a slow response
        time.sleep(0.1)
        mock_trace = mock.MagicMock()
        mock_trace.info = mock.MagicMock()
        return mock_trace

    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace", side_effect=_mock_client_method
        ) as mock_start_trace,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
    ):
        # Log many traces
        start_time = time.time()
        with ThreadPoolExecutor(
            max_workers=10, thread_name_prefix="test-mlflow-v3-exporter"
        ) as executor:
            for _ in range(100):
                executor.submit(_predict, "hello")

        # Trace logging should not block the main thread
        assert time.time() - start_time < 5

        _flush_async_logging()

    # Verify the client methods were called the expected number of times
    assert mock_start_trace.call_count == 100
    assert mock_upload_trace_data.call_count == 100
309  
310  
@pytest.mark.skipif(os.name == "nt", reason="Flaky on Windows")
def test_async_bulk_export_with_batch_span_processor(monkeypatch):
    """Bulk export stays non-blocking and lossless with the batch span processor."""
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "True")
    monkeypatch.setenv("MLFLOW_USE_BATCH_SPAN_PROCESSOR", "true")

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=0))

    # Simulate a slow backend response so blocking behavior would be visible
    def _mock_client_method(*args, **kwargs):
        time.sleep(0.1)
        mock_trace = mock.MagicMock()
        mock_trace.info = mock.MagicMock()
        return mock_trace

    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace", side_effect=_mock_client_method
        ) as mock_start_trace,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
    ):
        # Log many traces concurrently
        start_time = time.time()
        with ThreadPoolExecutor(
            max_workers=10, thread_name_prefix="test-mlflow-v3-exporter-batch"
        ) as executor:
            for _ in range(100):
                executor.submit(_predict, "hello")

        # Trace logging should not block the main thread
        assert time.time() - start_time < 5

        # Flush batch processor and async queue
        mlflow.flush_trace_async_logging(terminate=True)

    # Verify all traces were exported
    assert mock_start_trace.call_count == 100
    assert mock_upload_trace_data.call_count == 100
352  
353  
@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
def test_prompt_linking_in_mlflow_v3_exporter(is_async, monkeypatch):
    """Registered prompts are tagged on the trace and linked via the client on export."""
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", str(is_async))

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=_EXPERIMENT_ID))

    # Capture prompt linking calls
    captured_prompts = None
    captured_trace_id = None

    def mock_link_prompt_versions_to_trace(trace_id, prompts):
        nonlocal captured_prompts, captured_trace_id
        captured_prompts = prompts
        captured_trace_id = trace_id

    # Mock the prompt linking method and other client methods
    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace",
        ) as mock_start_trace,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
        mock.patch(
            "mlflow.tracing.client.TracingClient.link_prompt_versions_to_trace",
            side_effect=mock_link_prompt_versions_to_trace,
        ) as mock_link_prompts,
    ):
        # Create test prompt versions
        prompt1 = PromptVersion(
            name="test_prompt_1",
            version=1,
            template="Hello, {{name}}!",
            commit_message="Test prompt 1",
            creation_timestamp=123456789,
        )
        prompt2 = PromptVersion(
            name="test_prompt_2",
            version=2,
            template="Goodbye, {{name}}!",
            commit_message="Test prompt 2",
            creation_timestamp=123456790,
        )

        # Create a mock OTEL span and trace
        otel_span = create_mock_otel_span(
            name="root",
            trace_id=12345,
            span_id=1,
            parent_id=None,
        )
        trace_id = generate_trace_id_v3(otel_span)
        span = LiveSpan(otel_span, trace_id)

        # Register the trace and spans
        trace_manager = InMemoryTraceManager.get_instance()
        trace_info = create_test_trace_info(trace_id, _EXPERIMENT_ID)
        trace_manager.register_trace(otel_span.context.trace_id, trace_info)
        trace_manager.register_span(span)

        # Register prompts to the trace
        trace_manager.register_prompt(trace_id, prompt1)
        trace_manager.register_prompt(trace_id, prompt2)

        # Create and use the exporter
        exporter = MlflowV3SpanExporter()
        exporter.export([otel_span])

        if is_async:
            # For async tests, we need to flush the specific exporter's queue
            exporter._async_queue.flush(terminate=True)

        # Wait for any prompt linking threads to complete
        join_thread_by_name_prefix("link_prompts_from_exporter")

    # Verify that trace info contains the linked prompts tags
    tag_value = trace_info.tags.get(TraceTagKey.LINKED_PROMPTS)
    assert tag_value is not None
    tag_value = json.loads(tag_value)
    assert len(tag_value) == 2
    # NOTE: versions are serialized as strings in the tag payload
    assert tag_value[0]["name"] == "test_prompt_1"
    assert tag_value[0]["version"] == "1"
    assert tag_value[1]["name"] == "test_prompt_2"
    assert tag_value[1]["version"] == "2"

    # Verify that prompt linking was called
    mock_link_prompts.assert_called_once()
    assert captured_prompts is not None, "Prompts were not passed to link method"
    assert len(captured_prompts) == 2, f"Expected 2 prompts, got {len(captured_prompts)}"

    # Verify prompt details
    prompt_names = {p.name for p in captured_prompts}
    assert prompt_names == {"test_prompt_1", "test_prompt_2"}

    # Verify the trace ID matches
    assert captured_trace_id == trace_id

    # Verify other client methods were also called
    mock_start_trace.assert_called_once()
    mock_upload_trace_data.assert_called_once()
457  
458  
@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
def test_prompt_linking_with_empty_prompts_mlflow_v3(is_async, monkeypatch):
    """No prompt-linking call is made when a trace has no registered prompts."""
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", str(is_async))

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=_EXPERIMENT_ID))

    # Capture prompt linking calls
    captured_prompts = None
    captured_trace_id = None

    def mock_link_prompt_versions_to_trace(trace_id, prompts):
        nonlocal captured_prompts, captured_trace_id
        captured_prompts = prompts
        captured_trace_id = trace_id

    # Mock the client methods
    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace",
            return_value=mock.MagicMock(trace_id="test-trace-id"),
        ) as mock_start_trace,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
        mock.patch(
            "mlflow.tracing.client.TracingClient.link_prompt_versions_to_trace",
            side_effect=mock_link_prompt_versions_to_trace,
        ) as mock_link_prompts,
    ):
        # Create a mock OTEL span and trace (no prompts added)
        otel_span = create_mock_otel_span(
            name="root",
            trace_id=12345,
            span_id=1,
            parent_id=None,
        )
        trace_id = generate_trace_id_v3(otel_span)
        span = LiveSpan(otel_span, trace_id)

        # Register the trace and spans (but no prompts)
        trace_manager = InMemoryTraceManager.get_instance()
        trace_info = create_test_trace_info(trace_id, _EXPERIMENT_ID)
        trace_manager.register_trace(otel_span.context.trace_id, trace_info)
        trace_manager.register_span(span)

        # Create and use the exporter
        exporter = MlflowV3SpanExporter()
        exporter.export([otel_span])

        if is_async:
            # For async tests, we need to flush the specific exporter's queue
            exporter._async_queue.flush(terminate=True)

        # Wait for any prompt linking threads to complete
        join_thread_by_name_prefix("link_prompts_from_exporter")

    # Verify that prompt linking was NOT called for empty prompts (this is correct behavior)
    mock_link_prompts.assert_not_called()
    # Since no prompts were passed, no thread was started and no call was made
    assert captured_trace_id is None  # No linking occurred, so trace_id was never captured

    # Verify other client methods were also called
    mock_start_trace.assert_called_once()
    mock_upload_trace_data.assert_called_once()
526  
527  
def test_prompt_linking_error_handling_mlflow_v3(monkeypatch):
    """A prompt-linking failure is logged as a warning and does not break export."""
    monkeypatch.setenv("DATABRICKS_HOST", "dummy-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "dummy-token")
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "False")  # Use sync for easier testing

    mlflow.set_tracking_uri("databricks")
    mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id=_EXPERIMENT_ID))

    # Mock the client methods with prompt linking failing
    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace",
            return_value=mock.MagicMock(trace_id="test-trace-id"),
        ) as mock_start_trace,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
        mock.patch(
            "mlflow.tracing.client.TracingClient.link_prompt_versions_to_trace",
            side_effect=Exception("Prompt linking failed"),
        ) as mock_link_prompts,
        # Linking errors are logged from mlflow.tracing.export.utils, not mlflow_v3
        mock.patch("mlflow.tracing.export.utils._logger") as mock_logger,
    ):
        # Create a mock OTEL span and trace with a prompt
        otel_span = create_mock_otel_span(
            name="root",
            trace_id=12345,
            span_id=1,
            parent_id=None,
        )
        trace_id = generate_trace_id_v3(otel_span)
        span = LiveSpan(otel_span, trace_id)

        # Create a test prompt
        prompt = PromptVersion(
            name="test_prompt",
            version=1,
            template="Hello, {{name}}!",
            commit_message="Test prompt",
            creation_timestamp=123456789,
        )

        # Register the trace, span, and prompt
        trace_manager = InMemoryTraceManager.get_instance()
        trace_info = create_test_trace_info(trace_id, _EXPERIMENT_ID)
        trace_manager.register_trace(otel_span.context.trace_id, trace_info)
        trace_manager.register_span(span)
        trace_manager.register_prompt(trace_id, prompt)

        # Create and use the exporter
        exporter = MlflowV3SpanExporter()
        exporter.export([otel_span])

        # Wait for any prompt linking threads to complete so the error can be caught
        join_thread_by_name_prefix("link_prompts_from_exporter")

    # Verify that prompt linking was attempted but failed
    mock_link_prompts.assert_called_once()

    # Verify other client methods were still called
    # (trace export should succeed despite prompt linking failure)
    mock_start_trace.assert_called_once()
    mock_upload_trace_data.assert_called_once()

    # Verify that the error was logged but didn't crash the export
    mock_logger.warning.assert_called()
    warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list]
    assert any("Prompt linking failed" in msg for msg in warning_calls)
596  
597  
def test_no_log_spans_to_artifacts_if_stored_in_tracking_store(monkeypatch):
    """Artifact upload is skipped when spans are already stored in the tracking store."""
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "false")
    # Create a mock OTEL span and trace
    otel_span = create_mock_otel_span(
        name="root",
        trace_id=12345,
        span_id=1,
        parent_id=None,
    )
    trace_id = generate_trace_id_v3(otel_span)
    span = LiveSpan(otel_span, trace_id)

    # Register the trace and spans; the tag marks spans as tracking-store resident
    trace_manager = InMemoryTraceManager.get_instance()
    trace_info = create_test_trace_info(trace_id, _EXPERIMENT_ID)
    trace_info.tags[TraceTagKey.SPANS_LOCATION] = SpansLocation.TRACKING_STORE.value
    trace_manager.register_trace(otel_span.context.trace_id, trace_info)
    trace_manager.register_span(span)

    # Drain any pending async work from previous tests before asserting call counts
    mlflow.flush_trace_async_logging()

    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace",
            return_value=trace_info,
        ) as mock_start_trace,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as mock_upload_trace_data,
    ):
        exporter = MlflowV3SpanExporter()
        exporter.export([otel_span])
        mock_upload_trace_data.assert_not_called()
        mock_start_trace.assert_called_once()
632  
633  
def test_batch_write_skipped_when_store_unsupported(monkeypatch):
    """When the store lacks log_spans support, spans fall back to artifact upload."""
    monkeypatch.setenv("MLFLOW_ENABLE_ASYNC_TRACE_LOGGING", "false")
    root_span = create_mock_otel_span(name="root", trace_id=66666, span_id=1, parent_id=None)
    trace_id = generate_trace_id_v3(root_span)

    manager = InMemoryTraceManager.get_instance()
    info = create_test_trace_info(trace_id, _EXPERIMENT_ID)
    manager.register_trace(root_span.context.trace_id, info)
    manager.register_span(LiveSpan(root_span, trace_id))

    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient.start_trace",
            return_value=info,
        ) as start_trace_mock,
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data", return_value=None
        ) as upload_data_mock,
        mock.patch("mlflow.tracing.client.TracingClient.log_spans") as log_spans_mock,
    ):
        exporter = MlflowV3SpanExporter()
        # Force the "store does not support batch span writes" code path
        exporter._store_supports_log_spans = False
        exporter.export([root_span])

        start_trace_mock.assert_called_once()
        # log_spans should NOT be called when store doesn't support it
        log_spans_mock.assert_not_called()
        # Artifact upload should still happen as fallback
        upload_data_mock.assert_called_once()