/ tests / entities / test_span_auto_extract_attachments.py
test_span_auto_extract_attachments.py
  1  import base64
  2  
  3  from pydantic import BaseModel
  4  
  5  from mlflow.entities.span import LiveSpan
  6  from mlflow.tracing.attachments import Attachment
  7  
  8  
  9  def _make_live_span(trace_id="tr-test123"):
 10      from opentelemetry.sdk.trace import TracerProvider
 11  
 12      tracer = TracerProvider().get_tracer("test")
 13      otel_span = tracer.start_span("test_span")
 14      return LiveSpan(otel_span, trace_id=trace_id)
 15  
 16  
 17  PNG_BYTES = base64.b64decode(
 18      "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwADhQGAWjR9awAAAABJRU5ErkJggg=="
 19  )
 20  PNG_DATA_URI = f"data:image/png;base64,{base64.b64encode(PNG_BYTES).decode()}"
 21  
 22  WAV_B64 = base64.b64encode(b"RIFF\x00\x00\x00\x00WAVEfmt ").decode()
 23  
 24  
 25  # --- Data URI extraction ---
 26  
 27  
 28  def test_extracts_image_data_uri():
 29      span = _make_live_span()
 30      span.set_inputs({"image": PNG_DATA_URI})
 31  
 32      inputs = span.inputs
 33      assert inputs["image"].startswith("mlflow-attachment://")
 34      assert len(span._attachments) == 1
 35      att = next(iter(span._attachments.values()))
 36      assert att.content_type == "image/png"
 37      assert att.content_bytes == PNG_BYTES
 38  
 39  
 40  def test_extracts_nested_data_uri():
 41      span = _make_live_span()
 42      span.set_inputs({
 43          "messages": [
 44              {
 45                  "role": "user",
 46                  "content": [
 47                      {"type": "text", "text": "What is this?"},
 48                      {
 49                          "type": "image_url",
 50                          "image_url": {"url": PNG_DATA_URI},
 51                      },
 52                  ],
 53              }
 54          ]
 55      })
 56  
 57      inputs = span.inputs
 58      url = inputs["messages"][0]["content"][1]["image_url"]["url"]
 59      assert url.startswith("mlflow-attachment://")
 60      assert len(span._attachments) == 1
 61  
 62  
 63  def test_leaves_http_urls_alone():
 64      span = _make_live_span()
 65      url = "https://example.com/photo.png"
 66      span.set_inputs({"image_url": {"url": url}})
 67      assert span.inputs["image_url"]["url"] == url
 68      assert len(span._attachments) == 0
 69  
 70  
 71  def test_leaves_plain_strings_alone():
 72      span = _make_live_span()
 73      span.set_inputs({"text": "hello world"})
 74      assert span.inputs["text"] == "hello world"
 75      assert len(span._attachments) == 0
 76  
 77  
 78  def test_handles_invalid_base64_gracefully():
 79      span = _make_live_span()
 80      bad_uri = "data:image/png;base64,!!!not-valid-base64!!!"
 81      span.set_inputs({"image": bad_uri})
 82      assert span.inputs["image"] == bad_uri
 83      assert len(span._attachments) == 0
 84  
 85  
 86  def test_rejects_base64_with_trailing_garbage():
 87      # "Zg==!!!" is silently accepted by b64decode without validate=True
 88      span = _make_live_span()
 89      bad_uri = "data:image/png;base64,Zg==!!!"
 90      span.set_inputs({"image": bad_uri})
 91      assert span.inputs["image"] == bad_uri
 92      assert len(span._attachments) == 0
 93  
 94  
 95  def test_handles_empty_mime_type():
 96      span = _make_live_span()
 97      bad_uri = "data:;base64,dGVzdA=="
 98      span.set_inputs({"val": bad_uri})
 99      assert span.inputs["val"] == bad_uri
100      assert len(span._attachments) == 0
101  
102  
def test_multiple_data_uris():
    # Two inputs holding the same data URI should each yield an attachment.
    live_span = _make_live_span()
    live_span.set_inputs({"img1": PNG_DATA_URI, "img2": PNG_DATA_URI})

    for key in ("img1", "img2"):
        assert live_span.inputs[key].startswith("mlflow-attachment://")
    assert len(live_span._attachments) == 2
109  
110  
111  # --- Structured content extraction ---
112  
113  
def test_extracts_input_audio():
    # An "input_audio" content part carries bare base64 (no data: prefix); the
    # data should be extracted while the sibling "format" field is preserved.
    live_span = _make_live_span()
    question = {"type": "text", "text": "What does this say?"}
    audio_in = {
        "type": "input_audio",
        "input_audio": {"data": WAV_B64, "format": "wav"},
    }
    live_span.set_inputs({"messages": [{"role": "user", "content": [question, audio_in]}]})

    recorded_part = live_span.inputs["messages"][0]["content"][1]
    assert recorded_part["type"] == "input_audio"
    assert recorded_part["input_audio"]["data"].startswith("mlflow-attachment://")
    assert recorded_part["input_audio"]["format"] == "wav"
    assert len(live_span._attachments) == 1
    attachment = next(iter(live_span._attachments.values()))
    assert attachment.content_type == "audio/wav"
138  
139  
def test_extracts_b64_json():
    # A "b64_json" output field holding bare base64 PNG data should be
    # extracted, leaving the neighbouring "revised_prompt" untouched.
    encoded_png = base64.b64encode(PNG_BYTES).decode()
    live_span = _make_live_span()
    live_span.set_outputs({"data": [{"b64_json": encoded_png, "revised_prompt": "a sunset"}]})

    first_item = live_span.outputs["data"][0]
    assert first_item["b64_json"].startswith("mlflow-attachment://")
    assert first_item["revised_prompt"] == "a sunset"
    assert len(live_span._attachments) == 1
    attachment = next(iter(live_span._attachments.values()))
    assert attachment.content_type == "image/png"
    assert attachment.content_bytes == PNG_BYTES
153  
154  
def test_input_audio_with_invalid_base64():
    # Undecodable "input_audio" data must be kept verbatim with no attachment.
    broken_part = {
        "type": "input_audio",
        "input_audio": {"data": "!!!bad!!!", "format": "wav"},
    }
    live_span = _make_live_span()
    live_span.set_inputs({"content": [broken_part]})

    recorded_part = live_span.inputs["content"][0]
    assert recorded_part["input_audio"]["data"] == "!!!bad!!!"
    assert len(live_span._attachments) == 0
168  
169  
def test_structured_content_with_sibling_data_uri():
    # A structured audio part that also has a data-URI sibling key: both the
    # structured payload and the sibling should be extracted.
    combined_part = {
        "type": "input_audio",
        "input_audio": {"data": WAV_B64, "format": "wav"},
        "extra_image": PNG_DATA_URI,
    }
    live_span = _make_live_span()
    live_span.set_inputs({"content": [combined_part]})

    recorded_part = live_span.inputs["content"][0]
    assert recorded_part["input_audio"]["data"].startswith("mlflow-attachment://")
    assert recorded_part["extra_image"].startswith("mlflow-attachment://")
    assert len(live_span._attachments) == 2
185  
186  
def test_mixed_content_parts():
    # One message mixing text, an image data URI and an audio part: the text is
    # untouched while the image and the audio each become an attachment.
    live_span = _make_live_span()
    parts_in = [
        {"type": "text", "text": "Describe both"},
        {"type": "image_url", "image_url": {"url": PNG_DATA_URI}},
        {"type": "input_audio", "input_audio": {"data": WAV_B64, "format": "mp3"}},
    ]
    live_span.set_inputs({"messages": [{"role": "user", "content": parts_in}]})

    recorded_parts = live_span.inputs["messages"][0]["content"]
    assert recorded_parts[0] == {"type": "text", "text": "Describe both"}
    assert recorded_parts[1]["image_url"]["url"].startswith("mlflow-attachment://")
    assert recorded_parts[2]["input_audio"]["data"].startswith("mlflow-attachment://")
    assert len(live_span._attachments) == 2
212  
213  
214  # --- Anthropic image pattern ---
215  
216  
def test_extracts_anthropic_image():
    # Image block with a base64 "source": only source["data"] is replaced; the
    # "type" and "media_type" fields are preserved.
    image_block = {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": "image/png",
            "data": base64.b64encode(PNG_BYTES).decode(),
        },
    }
    text_block = {"type": "text", "text": "What is this?"}
    live_span = _make_live_span()
    live_span.set_inputs({"messages": [{"role": "user", "content": [text_block, image_block]}]})

    recorded_image = live_span.inputs["messages"][0]["content"][1]
    assert recorded_image["type"] == "image"
    assert recorded_image["source"]["data"].startswith("mlflow-attachment://")
    assert recorded_image["source"]["type"] == "base64"
    assert recorded_image["source"]["media_type"] == "image/png"
    assert len(live_span._attachments) == 1
    attachment = next(iter(live_span._attachments.values()))
    assert attachment.content_type == "image/png"
    assert attachment.content_bytes == PNG_BYTES
247  
248  
def test_extracts_multiple_anthropic_images():
    # Two base64 image blocks in one message should each become an attachment.
    live_span = _make_live_span()
    jpeg_like_bytes = b"fake jpeg bytes"
    png_block = {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": "image/png",
            "data": base64.b64encode(PNG_BYTES).decode(),
        },
    }
    jpeg_block = {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": "image/jpeg",
            "data": base64.b64encode(jpeg_like_bytes).decode(),
        },
    }
    live_span.set_inputs({"messages": [{"role": "user", "content": [png_block, jpeg_block]}]})

    recorded_blocks = live_span.inputs["messages"][0]["content"]
    assert recorded_blocks[0]["source"]["data"].startswith("mlflow-attachment://")
    assert recorded_blocks[1]["source"]["data"].startswith("mlflow-attachment://")
    assert len(live_span._attachments) == 2
282  
283  
284  # --- Audio output pattern ---
285  
286  
def test_extracts_audio_output():
    # Assistant message with an "audio" object: the base64 "data" field should
    # be extracted while "id" and "transcript" are preserved.
    span = _make_live_span()
    # Reuse the module-level WAV fixture rather than re-encoding the same bytes.
    span.set_outputs({
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": None,
                    "audio": {
                        "id": "audio_123",
                        "data": WAV_B64,
                        "transcript": "Hello world",
                    },
                }
            }
        ]
    })

    outputs = span.outputs
    audio = outputs["choices"][0]["message"]["audio"]
    assert audio["data"].startswith("mlflow-attachment://")
    assert audio["transcript"] == "Hello world"
    assert audio["id"] == "audio_123"
    assert len(span._attachments) == 1
    att = next(iter(span._attachments.values()))
    assert att.content_type == "audio/wav"
314  
315  
def test_extracts_b64_json_multiple():
    # Every "b64_json" entry in the list is extracted independently and each
    # keeps its own "revised_prompt".
    encoded_png = base64.b64encode(PNG_BYTES).decode()
    items_in = [
        {"b64_json": encoded_png, "revised_prompt": "a circle"},
        {"b64_json": encoded_png, "revised_prompt": "a triangle"},
    ]
    live_span = _make_live_span()
    live_span.set_outputs({"data": items_in})

    recorded_items = live_span.outputs["data"]
    assert recorded_items[0]["b64_json"].startswith("mlflow-attachment://")
    assert recorded_items[1]["b64_json"].startswith("mlflow-attachment://")
    assert recorded_items[0]["revised_prompt"] == "a circle"
    assert recorded_items[1]["revised_prompt"] == "a triangle"
    assert len(live_span._attachments) == 2
332  
333  
334  # --- Bedrock image pattern ---
335  
336  
def test_extracts_bedrock_image():
    # Image block shaped as {"image": {"format", "source": {"bytes"}}}: the
    # base64 "bytes" value is extracted and "format" is preserved.
    encoded_png = base64.b64encode(PNG_BYTES).decode()
    image_block = {"image": {"format": "png", "source": {"bytes": encoded_png}}}
    blocks_in = [{"text": "Here is the image."}, image_block]
    live_span = _make_live_span()
    live_span.set_outputs({"output": {"message": {"content": blocks_in}}})

    recorded_blocks = live_span.outputs["output"]["message"]["content"]
    assert recorded_blocks[0] == {"text": "Here is the image."}
    recorded_image = recorded_blocks[1]["image"]
    assert recorded_image["format"] == "png"
    assert recorded_image["source"]["bytes"].startswith("mlflow-attachment://")
    assert len(live_span._attachments) == 1
    attachment = next(iter(live_span._attachments.values()))
    assert attachment.content_type == "image/png"
    assert attachment.content_bytes == PNG_BYTES
365  
366  
def test_bedrock_image_with_invalid_base64():
    # Undecodable source "bytes" must be left verbatim with no attachment.
    broken_block = {"image": {"format": "png", "source": {"bytes": "!!!bad!!!"}}}
    live_span = _make_live_span()
    live_span.set_outputs({"content": [broken_block]})

    recorded_block = live_span.outputs["content"][0]
    assert recorded_block["image"]["source"]["bytes"] == "!!!bad!!!"
    assert len(live_span._attachments) == 0
382  
383  
384  # --- Gemini inline_data pattern ---
385  
386  
def test_extracts_gemini_inline_data():
    # An "inline_data" part with a "mime_type" and base64 "data": the data is
    # extracted and the MIME type field is preserved.
    encoded_png = base64.b64encode(PNG_BYTES).decode()
    parts_in = [
        {"text": "Here is what I see."},
        {"inline_data": {"mime_type": "image/png", "data": encoded_png}},
    ]
    live_span = _make_live_span()
    live_span.set_outputs({"candidates": [{"content": {"parts": parts_in}}]})

    recorded_parts = live_span.outputs["candidates"][0]["content"]["parts"]
    assert recorded_parts[0] == {"text": "Here is what I see."}
    recorded_inline = recorded_parts[1]["inline_data"]
    assert recorded_inline["mime_type"] == "image/png"
    assert recorded_inline["data"].startswith("mlflow-attachment://")
    assert len(live_span._attachments) == 1
    attachment = next(iter(live_span._attachments.values()))
    assert attachment.content_type == "image/png"
    assert attachment.content_bytes == PNG_BYTES
417  
418  
def test_extracts_gemini_inline_data_bytes_repr():
    # Gemini SDK Pydantic serialization produces repr(bytes) instead of base64,
    # so "data" here is the repr string of the raw PNG bytes.
    repr_payload = repr(PNG_BYTES)
    parts_in = [
        {"text": "A small image."},
        {"inline_data": {"mime_type": "image/png", "data": repr_payload}},
    ]
    live_span = _make_live_span()
    live_span.set_outputs({"candidates": [{"content": {"parts": parts_in}}]})

    recorded_parts = live_span.outputs["candidates"][0]["content"]["parts"]
    assert recorded_parts[0] == {"text": "A small image."}
    recorded_inline = recorded_parts[1]["inline_data"]
    assert recorded_inline["data"].startswith("mlflow-attachment://")
    assert len(live_span._attachments) == 1
    attachment = next(iter(live_span._attachments.values()))
    assert attachment.content_type == "image/png"
    assert attachment.content_bytes == PNG_BYTES
449  
450  
def test_gemini_inline_data_with_invalid_base64():
    # Undecodable "inline_data" payloads are kept verbatim with no attachment.
    broken_part = {"inline_data": {"mime_type": "image/jpeg", "data": "!!!bad!!!"}}
    live_span = _make_live_span()
    live_span.set_outputs({"parts": [broken_part]})

    recorded_part = live_span.outputs["parts"][0]
    assert recorded_part["inline_data"]["data"] == "!!!bad!!!"
    assert len(live_span._attachments) == 0
466  
467  
468  # --- Responses API image_generation_call pattern ---
469  
470  
def test_extracts_responses_api_image_generation():
    # An "image_generation_call" output item keeps its base64 image in
    # "result"; that field is extracted while the item's other fields and the
    # sibling message item are preserved.
    encoded_png = base64.b64encode(PNG_BYTES).decode()
    generation_item = {
        "type": "image_generation_call",
        "result": encoded_png,
        "output_format": "png",
        "revised_prompt": "a blue square",
    }
    message_item = {
        "type": "message",
        "content": [{"type": "output_text", "text": "Here is the image."}],
    }
    live_span = _make_live_span()
    live_span.set_outputs({"output": [generation_item, message_item]})

    recorded_items = live_span.outputs["output"]
    recorded_call = recorded_items[0]
    assert recorded_call["type"] == "image_generation_call"
    assert recorded_call["result"].startswith("mlflow-attachment://")
    assert recorded_call["revised_prompt"] == "a blue square"
    assert recorded_call["output_format"] == "png"
    assert recorded_items[1]["type"] == "message"
    assert len(live_span._attachments) == 1
    attachment = next(iter(live_span._attachments.values()))
    assert attachment.content_type == "image/png"
    assert attachment.content_bytes == PNG_BYTES
501  
502  
def test_responses_api_image_generation_with_invalid_base64():
    # An undecodable "result" must be left verbatim with no attachment.
    broken_item = {
        "type": "image_generation_call",
        "result": "!!!bad!!!",
        "output_format": "png",
    }
    live_span = _make_live_span()
    live_span.set_outputs({"output": [broken_item]})

    recorded_call = live_span.outputs["output"][0]
    assert recorded_call["result"] == "!!!bad!!!"
    assert len(live_span._attachments) == 0
517  
518  
519  # --- Two-pass serialization extraction ---
520  
521  
def test_extracts_base64_from_pydantic_model():
    """Pydantic models aren't traversable in the first pass but become
    plain dicts after JSON serialization. The second pass should extract
    the base64 data from the serialized form.
    """

    class AudioOutput(BaseModel):
        transcript: str
        audio: dict[str, str]

    # Reuse the module-level WAV fixture rather than re-encoding the same bytes.
    output = AudioOutput(
        transcript="Hello",
        audio={"data": WAV_B64, "id": "audio_123"},
    )

    span = _make_live_span()
    span.set_outputs({"result": output})

    outputs = span.outputs
    assert outputs["result"]["audio"]["data"].startswith("mlflow-attachment://")
    assert outputs["result"]["transcript"] == "Hello"
    assert len(span._attachments) == 1
    att = next(iter(span._attachments.values()))
    assert att.content_type == "audio/wav"
547  
548  
def test_two_pass_with_explicit_attachment_and_pydantic():
    """An explicit Attachment (handled in the first pass) and a Pydantic model
    carrying base64 (handled in the second pass) on the same span should both
    end up as attachments.
    """

    class ImageResult(BaseModel):
        b64_json: str
        revised_prompt: str

    model_output = ImageResult(
        b64_json=base64.b64encode(PNG_BYTES).decode(),
        revised_prompt="a sunset",
    )
    thumbnail = Attachment(content_type="image/png", content_bytes=PNG_BYTES)

    live_span = _make_live_span()
    live_span.set_outputs({"image": model_output, "thumbnail": thumbnail})

    recorded = live_span.outputs
    assert recorded["thumbnail"].startswith("mlflow-attachment://")
    assert recorded["image"]["b64_json"].startswith("mlflow-attachment://")
    assert len(live_span._attachments) == 2
569  
570  
571  # --- Opt-out ---
572  
573  
def test_opt_out_via_env_var(monkeypatch):
    # With MLFLOW_TRACE_EXTRACT_ATTACHMENTS=false a data URI is stored verbatim.
    monkeypatch.setenv("MLFLOW_TRACE_EXTRACT_ATTACHMENTS", "false")
    live_span = _make_live_span()
    live_span.set_inputs({"image": PNG_DATA_URI})

    assert live_span.inputs["image"] == PNG_DATA_URI
    assert len(live_span._attachments) == 0
580  
581  
def test_explicit_attachment_still_works_when_opted_out(monkeypatch):
    # The opt-out only disables auto-extraction; an Attachment object passed
    # explicitly is still stored and referenced.
    monkeypatch.setenv("MLFLOW_TRACE_EXTRACT_ATTACHMENTS", "false")
    explicit = Attachment(content_type="image/png", content_bytes=PNG_BYTES)
    live_span = _make_live_span()
    live_span.set_inputs({"image": explicit})

    assert live_span.inputs["image"].startswith("mlflow-attachment://")
    assert len(live_span._attachments) == 1
589  
590  
591  # --- Attachment size limit ---
592  
593  
def test_attachment_under_size_limit_is_extracted(monkeypatch):
    # Limit is one byte larger than the PNG, so extraction should proceed.
    limit = len(PNG_BYTES) + 1
    monkeypatch.setenv("MLFLOW_TRACE_MAX_ATTACHMENT_SIZE", str(limit))
    live_span = _make_live_span()
    live_span.set_inputs({"image": PNG_DATA_URI})

    assert live_span.inputs["image"].startswith("mlflow-attachment://")
    assert len(live_span._attachments) == 1
601  
602  
def test_attachment_over_size_limit_is_discarded(monkeypatch):
    # Limit is one byte smaller than the PNG, so the value should be replaced
    # by a "too large" placeholder and nothing stored.
    limit = len(PNG_BYTES) - 1
    monkeypatch.setenv("MLFLOW_TRACE_MAX_ATTACHMENT_SIZE", str(limit))
    live_span = _make_live_span()
    live_span.set_inputs({"image": PNG_DATA_URI})

    assert "[Attachment too large:" in live_span.inputs["image"]
    assert len(live_span._attachments) == 0
610  
611  
def test_attachment_size_limit_unset_allows_all(monkeypatch):
    # Default is None (unset) — no limit enforced. Clear any value inherited
    # from the ambient environment so the test really exercises the unset case
    # instead of depending on how the test runner was launched.
    monkeypatch.delenv("MLFLOW_TRACE_MAX_ATTACHMENT_SIZE", raising=False)
    span = _make_live_span()
    span.set_inputs({"image": PNG_DATA_URI})

    assert span.inputs["image"].startswith("mlflow-attachment://")
    assert len(span._attachments) == 1
619  
620  
def test_structured_content_over_size_limit_is_discarded(monkeypatch):
    # The size limit also applies to structured "input_audio" payloads: with a
    # one-byte limit the data is replaced by a "too large" placeholder.
    monkeypatch.setenv("MLFLOW_TRACE_MAX_ATTACHMENT_SIZE", "1")
    audio_in = {
        "type": "input_audio",
        "input_audio": {"data": WAV_B64, "format": "wav"},
    }
    live_span = _make_live_span()
    live_span.set_inputs({"messages": [{"role": "user", "content": [audio_in]}]})

    recorded_part = live_span.inputs["messages"][0]["content"][0]
    assert "[Attachment too large:" in recorded_part["input_audio"]["data"]
    assert len(live_span._attachments) == 0
641  
642  
def test_explicit_attachment_over_size_limit_is_discarded(monkeypatch):
    # Explicit Attachment objects are subject to the size limit as well.
    monkeypatch.setenv("MLFLOW_TRACE_MAX_ATTACHMENT_SIZE", "1")
    oversized = Attachment(content_type="image/png", content_bytes=PNG_BYTES)
    live_span = _make_live_span()
    live_span.set_inputs({"image": oversized})

    assert "[Attachment too large:" in live_span.inputs["image"]
    assert len(live_span._attachments) == 0
651  
652  
def test_attachment_size_limit_negative_treated_as_disabled(monkeypatch):
    # A negative limit behaves like no limit: the PNG is still extracted.
    monkeypatch.setenv("MLFLOW_TRACE_MAX_ATTACHMENT_SIZE", "-1")
    live_span = _make_live_span()
    live_span.set_inputs({"image": PNG_DATA_URI})

    assert live_span.inputs["image"].startswith("mlflow-attachment://")
    assert len(live_span._attachments) == 1