/ tests / tracing / export / test_mlflow_v3_attachments.py
test_mlflow_v3_attachments.py
  1  from unittest.mock import MagicMock, patch
  2  
  3  from mlflow.tracing.attachments import Attachment
  4  from mlflow.tracing.constant import SpansLocation, TraceTagKey
  5  from mlflow.tracing.export.mlflow_v3 import MlflowV3SpanExporter
  6  
  7  
  8  def _make_trace_info_mock():
  9      info = MagicMock()
 10      info.trace_id = "tr-test123"
 11      info.tags = {TraceTagKey.SPANS_LOCATION: SpansLocation.ARTIFACT_REPO.value}
 12      info.metadata = {}
 13      return info
 14  
 15  
 16  def _make_trace(attachments_map=None):
 17      span = MagicMock()
 18      span._attachments = attachments_map or {}
 19  
 20      trace = MagicMock()
 21      trace.info = _make_trace_info_mock()
 22      trace.info.trace_id = "tr-test123"
 23      trace.data.spans = [span]
 24      return trace
 25  
 26  
 27  def _make_exporter(mock_client):
 28      with patch.object(MlflowV3SpanExporter, "__init__", return_value=None):
 29          exporter = MlflowV3SpanExporter()
 30      exporter._client = mock_client
 31      exporter._store_supports_log_spans = True
 32      return exporter
 33  
 34  
 35  def test_log_trace_uploads_attachments():
 36      att = Attachment(content_type="image/png", content_bytes=b"img")
 37      trace = _make_trace({att.id: att})
 38  
 39      mock_client = MagicMock()
 40      returned_info = _make_trace_info_mock()
 41      mock_client.start_trace.return_value = returned_info
 42  
 43      exporter = _make_exporter(mock_client)
 44  
 45      with (
 46          patch("mlflow.tracing.export.mlflow_v3.try_link_prompts_to_trace"),
 47          patch("mlflow.tracing.export.mlflow_v3.add_size_stats_to_trace_metadata"),
 48      ):
 49          exporter._log_trace(trace, prompts=[])
 50  
 51      mock_client._upload_trace_data.assert_called_once_with(returned_info, trace.data)
 52      mock_client._upload_attachments.assert_called_once()
 53      call_args = mock_client._upload_attachments.call_args
 54      assert call_args[0][0] is returned_info
 55      assert att.id in call_args[0][1]
 56  
 57  
 58  def test_log_trace_uploads_attachments_in_tracking_store_mode():
 59      att = Attachment(content_type="image/png", content_bytes=b"img")
 60      trace = _make_trace({att.id: att})
 61      # Override to TRACKING_STORE mode
 62      trace.info.tags = {TraceTagKey.SPANS_LOCATION: SpansLocation.TRACKING_STORE.value}
 63  
 64      mock_client = MagicMock()
 65      returned_info = MagicMock()
 66      returned_info.tags = {TraceTagKey.SPANS_LOCATION: SpansLocation.TRACKING_STORE.value}
 67      mock_client.start_trace.return_value = returned_info
 68  
 69      exporter = _make_exporter(mock_client)
 70  
 71      with (
 72          patch("mlflow.tracing.export.mlflow_v3.try_link_prompts_to_trace"),
 73          patch("mlflow.tracing.export.mlflow_v3.add_size_stats_to_trace_metadata"),
 74      ):
 75          exporter._log_trace(trace, prompts=[])
 76  
 77      # traces.json should NOT be uploaded in TRACKING_STORE mode
 78      mock_client._upload_trace_data.assert_not_called()
 79      # But attachments MUST still be uploaded to the artifact repo
 80      mock_client._upload_attachments.assert_called_once()
 81      call_args = mock_client._upload_attachments.call_args
 82      assert call_args[0][0] is returned_info
 83      assert att.id in call_args[0][1]
 84  
 85  
 86  def test_trace_data_still_lands_when_attachment_upload_fails():
 87      att = Attachment(content_type="image/png", content_bytes=b"img")
 88      trace = _make_trace({att.id: att})
 89  
 90      mock_client = MagicMock()
 91      returned_info = _make_trace_info_mock()
 92      mock_client.start_trace.return_value = returned_info
 93      mock_client._upload_attachments.side_effect = Exception("S3 timeout")
 94  
 95      exporter = _make_exporter(mock_client)
 96  
 97      with (
 98          patch("mlflow.tracing.export.mlflow_v3.try_link_prompts_to_trace"),
 99          patch("mlflow.tracing.export.mlflow_v3.add_size_stats_to_trace_metadata"),
100      ):
101          # Should not raise
102          exporter._log_trace(trace, prompts=[])
103  
104      # Trace data was still uploaded despite attachment failure
105      mock_client._upload_trace_data.assert_called_once_with(returned_info, trace.data)
106      mock_client._upload_attachments.assert_called_once()
107  
108  
def test_log_trace_skips_upload_when_no_attachments():
    """A trace without attachments uploads its data but never calls _upload_attachments."""
    trace = _make_trace()

    mock_client = MagicMock()
    returned_info = _make_trace_info_mock()
    mock_client.start_trace.return_value = returned_info

    exporter = _make_exporter(mock_client)

    with (
        patch("mlflow.tracing.export.mlflow_v3.try_link_prompts_to_trace"),
        patch("mlflow.tracing.export.mlflow_v3.add_size_stats_to_trace_metadata"),
    ):
        exporter._log_trace(trace, prompts=[])

    # Pin the arguments, not just the call count, so a regression that uploads
    # the wrong info or data is caught.
    mock_client._upload_trace_data.assert_called_once_with(returned_info, trace.data)
    mock_client._upload_attachments.assert_not_called()
124      mock_client._upload_trace_data.assert_called_once()
125      mock_client._upload_attachments.assert_not_called()