# test_mlflow_v3_attachments.py
from unittest.mock import MagicMock, patch

from mlflow.tracing.attachments import Attachment
from mlflow.tracing.constant import SpansLocation, TraceTagKey
from mlflow.tracing.export.mlflow_v3 import MlflowV3SpanExporter

# Module whose collaborators are patched out while exercising ``_log_trace``.
_EXPORTER_MODULE = "mlflow.tracing.export.mlflow_v3"
_TRACE_ID = "tr-test123"


def _make_trace_info_mock(spans_location=SpansLocation.ARTIFACT_REPO):
    """Return a mock ``TraceInfo`` tagged with the given spans location.

    Args:
        spans_location: ``SpansLocation`` member whose ``value`` is stored under
            ``TraceTagKey.SPANS_LOCATION``. Defaults to ``ARTIFACT_REPO``.

    Returns:
        A ``MagicMock`` with ``trace_id``, ``tags`` and an empty ``metadata`` dict.
    """
    info = MagicMock()
    info.trace_id = _TRACE_ID
    info.tags = {TraceTagKey.SPANS_LOCATION: spans_location.value}
    info.metadata = {}
    return info


def _make_trace(attachments_map=None, spans_location=SpansLocation.ARTIFACT_REPO):
    """Return a mock trace containing a single span with the given attachments.

    Args:
        attachments_map: Mapping of attachment id to ``Attachment`` assigned to
            the span's ``_attachments``; ``None`` means no attachments.
        spans_location: Spans-location tag placed on the trace's info mock.

    Returns:
        A ``MagicMock`` trace whose ``data.spans`` is a one-element list.
    """
    span = MagicMock()
    span._attachments = attachments_map or {}

    trace = MagicMock()
    trace.info = _make_trace_info_mock(spans_location)
    trace.data.spans = [span]
    return trace


def _make_exporter(mock_client):
    """Return an ``MlflowV3SpanExporter`` wired to ``mock_client``.

    The real ``__init__`` is patched out, so only the attributes assigned here
    are set on the exporter.
    """
    with patch.object(MlflowV3SpanExporter, "__init__", return_value=None):
        exporter = MlflowV3SpanExporter()
    exporter._client = mock_client
    exporter._store_supports_log_spans = True
    return exporter


def _make_client(returned_info):
    """Return a mock client whose ``start_trace`` returns ``returned_info``."""
    mock_client = MagicMock()
    mock_client.start_trace.return_value = returned_info
    return mock_client


def _run_log_trace(exporter, trace):
    """Call ``exporter._log_trace(trace, prompts=[])`` with side helpers stubbed.

    Prompt linking and size-stat bookkeeping are patched out so the tests only
    observe the upload calls made on the client.
    """
    with (
        patch(f"{_EXPORTER_MODULE}.try_link_prompts_to_trace"),
        patch(f"{_EXPORTER_MODULE}.add_size_stats_to_trace_metadata"),
    ):
        exporter._log_trace(trace, prompts=[])


def test_log_trace_uploads_attachments():
    att = Attachment(content_type="image/png", content_bytes=b"img")
    trace = _make_trace({att.id: att})
    returned_info = _make_trace_info_mock()
    mock_client = _make_client(returned_info)

    _run_log_trace(_make_exporter(mock_client), trace)

    mock_client._upload_trace_data.assert_called_once_with(returned_info, trace.data)
    mock_client._upload_attachments.assert_called_once()
    upload_args = mock_client._upload_attachments.call_args.args
    # Attachments go to the TraceInfo returned by start_trace and include this attachment.
    assert upload_args[0] is returned_info
    assert att.id in upload_args[1]


def test_log_trace_uploads_attachments_in_tracking_store_mode():
    att = Attachment(content_type="image/png", content_bytes=b"img")
    trace = _make_trace({att.id: att}, spans_location=SpansLocation.TRACKING_STORE)
    returned_info = _make_trace_info_mock(SpansLocation.TRACKING_STORE)
    mock_client = _make_client(returned_info)

    _run_log_trace(_make_exporter(mock_client), trace)

    # traces.json should NOT be uploaded in TRACKING_STORE mode
    mock_client._upload_trace_data.assert_not_called()
    # But attachments MUST still be uploaded to the artifact repo
    mock_client._upload_attachments.assert_called_once()
    upload_args = mock_client._upload_attachments.call_args.args
    assert upload_args[0] is returned_info
    assert att.id in upload_args[1]


def test_trace_data_still_lands_when_attachment_upload_fails():
    att = Attachment(content_type="image/png", content_bytes=b"img")
    trace = _make_trace({att.id: att})
    returned_info = _make_trace_info_mock()
    mock_client = _make_client(returned_info)
    mock_client._upload_attachments.side_effect = Exception("S3 timeout")

    # Should not raise: an attachment upload failure must not abort the export.
    _run_log_trace(_make_exporter(mock_client), trace)

    # Trace data was still uploaded despite attachment failure
    mock_client._upload_trace_data.assert_called_once_with(returned_info, trace.data)
    mock_client._upload_attachments.assert_called_once()


def test_log_trace_skips_upload_when_no_attachments():
    trace = _make_trace()
    returned_info = _make_trace_info_mock()
    mock_client = _make_client(returned_info)

    _run_log_trace(_make_exporter(mock_client), trace)

    mock_client._upload_trace_data.assert_called_once_with(returned_info, trace.data)
    mock_client._upload_attachments.assert_not_called()