"""Tests for automatic attachment extraction on ``LiveSpan`` inputs/outputs.

File: test_span_auto_extract_attachments.py

Covers base64 payloads carried as data URIs and inside structured content
(``input_audio``, ``b64_json``, audio outputs), plus the Anthropic, Bedrock,
Gemini ``inline_data`` and Responses API ``image_generation_call`` patterns,
two-pass (post-serialization) extraction, the
``MLFLOW_TRACE_EXTRACT_ATTACHMENTS`` opt-out, and the
``MLFLOW_TRACE_MAX_ATTACHMENT_SIZE`` limit.
"""

import base64

from pydantic import BaseModel

from mlflow.entities.span import LiveSpan
from mlflow.tracing.attachments import Attachment


def _make_live_span(trace_id="tr-test123"):
    """Create a fresh ``LiveSpan`` wrapping a newly started OpenTelemetry span.

    A new ``TracerProvider`` and span are created on every call so tests do
    not share span state.
    """
    from opentelemetry.sdk.trace import TracerProvider

    tracer = TracerProvider().get_tracer("test")
    otel_span = tracer.start_span("test_span")
    return LiveSpan(otel_span, trace_id=trace_id)


# Minimal 1x1 PNG used as the canonical image payload.
PNG_BYTES = base64.b64decode(
    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwADhQGAWjR9awAAAABJRU5ErkJggg=="
)
PNG_B64 = base64.b64encode(PNG_BYTES).decode()
PNG_DATA_URI = f"data:image/png;base64,{PNG_B64}"

# WAV-style header bytes; tests only need well-formed base64, not playable audio.
WAV_BYTES = b"RIFF\x00\x00\x00\x00WAVEfmt "
WAV_B64 = base64.b64encode(WAV_BYTES).decode()


# --- Data URI extraction ---


def test_extracts_image_data_uri():
    span = _make_live_span()
    span.set_inputs({"image": PNG_DATA_URI})

    inputs = span.inputs
    assert inputs["image"].startswith("mlflow-attachment://")
    assert len(span._attachments) == 1
    att = next(iter(span._attachments.values()))
    assert att.content_type == "image/png"
    assert att.content_bytes == PNG_BYTES


def test_extracts_nested_data_uri():
    span = _make_live_span()
    span.set_inputs({
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is this?"},
                    {
                        "type": "image_url",
                        "image_url": {"url": PNG_DATA_URI},
                    },
                ],
            }
        ]
    })

    inputs = span.inputs
    url = inputs["messages"][0]["content"][1]["image_url"]["url"]
    assert url.startswith("mlflow-attachment://")
    assert len(span._attachments) == 1


def test_leaves_http_urls_alone():
    span = _make_live_span()
    url = "https://example.com/photo.png"
    span.set_inputs({"image_url": {"url": url}})
    assert span.inputs["image_url"]["url"] == url
    assert len(span._attachments) == 0


def test_leaves_plain_strings_alone():
    span = _make_live_span()
    span.set_inputs({"text": "hello world"})
    assert span.inputs["text"] == "hello world"
    assert len(span._attachments) == 0


def test_handles_invalid_base64_gracefully():
    span = _make_live_span()
    bad_uri = "data:image/png;base64,!!!not-valid-base64!!!"
    span.set_inputs({"image": bad_uri})
    assert span.inputs["image"] == bad_uri
    assert len(span._attachments) == 0


def test_rejects_base64_with_trailing_garbage():
    # Without validate=True, base64.b64decode discards the non-alphabet "!!!"
    # and decodes "Zg==!!!" to b"f"; the extractor must leave it untouched.
    span = _make_live_span()
    bad_uri = "data:image/png;base64,Zg==!!!"
    span.set_inputs({"image": bad_uri})
    assert span.inputs["image"] == bad_uri
    assert len(span._attachments) == 0


def test_handles_empty_mime_type():
    span = _make_live_span()
    bad_uri = "data:;base64,dGVzdA=="
    span.set_inputs({"val": bad_uri})
    assert span.inputs["val"] == bad_uri
    assert len(span._attachments) == 0


def test_multiple_data_uris():
    span = _make_live_span()
    span.set_inputs({"img1": PNG_DATA_URI, "img2": PNG_DATA_URI})
    assert span.inputs["img1"].startswith("mlflow-attachment://")
    assert span.inputs["img2"].startswith("mlflow-attachment://")
    assert len(span._attachments) == 2


# --- Structured content extraction ---


def test_extracts_input_audio():
    span = _make_live_span()
    span.set_inputs({
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What does this say?"},
                    {
                        "type": "input_audio",
                        "input_audio": {"data": WAV_B64, "format": "wav"},
                    },
                ],
            }
        ]
    })

    audio_part = span.inputs["messages"][0]["content"][1]
    assert audio_part["type"] == "input_audio"
    assert audio_part["input_audio"]["data"].startswith("mlflow-attachment://")
    assert audio_part["input_audio"]["format"] == "wav"
    assert len(span._attachments) == 1
    att = next(iter(span._attachments.values()))
    assert att.content_type == "audio/wav"


def test_extracts_b64_json():
    span = _make_live_span()
    span.set_outputs({"data": [{"b64_json": PNG_B64, "revised_prompt": "a sunset"}]})

    output = span.outputs
    item = output["data"][0]
    assert item["b64_json"].startswith("mlflow-attachment://")
    assert item["revised_prompt"] == "a sunset"
    assert len(span._attachments) == 1
    att = next(iter(span._attachments.values()))
    assert att.content_type == "image/png"
    assert att.content_bytes == PNG_BYTES


def test_input_audio_with_invalid_base64():
    span = _make_live_span()
    span.set_inputs({
        "content": [
            {
                "type": "input_audio",
                "input_audio": {"data": "!!!bad!!!", "format": "wav"},
            }
        ]
    })
    audio_part = span.inputs["content"][0]
    assert audio_part["input_audio"]["data"] == "!!!bad!!!"
    assert len(span._attachments) == 0


def test_structured_content_with_sibling_data_uri():
    span = _make_live_span()
    span.set_inputs({
        "content": [
            {
                "type": "input_audio",
                "input_audio": {"data": WAV_B64, "format": "wav"},
                "extra_image": PNG_DATA_URI,
            }
        ]
    })
    part = span.inputs["content"][0]
    assert part["input_audio"]["data"].startswith("mlflow-attachment://")
    assert part["extra_image"].startswith("mlflow-attachment://")
    assert len(span._attachments) == 2


def test_mixed_content_parts():
    span = _make_live_span()
    span.set_inputs({
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe both"},
                    {
                        "type": "image_url",
                        "image_url": {"url": PNG_DATA_URI},
                    },
                    {
                        "type": "input_audio",
                        "input_audio": {"data": WAV_B64, "format": "mp3"},
                    },
                ],
            }
        ]
    })
    content = span.inputs["messages"][0]["content"]
    assert content[0] == {"type": "text", "text": "Describe both"}
    assert content[1]["image_url"]["url"].startswith("mlflow-attachment://")
    assert content[2]["input_audio"]["data"].startswith("mlflow-attachment://")
    assert len(span._attachments) == 2


# --- Anthropic image pattern ---


def test_extracts_anthropic_image():
    span = _make_live_span()
    span.set_inputs({
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is this?"},
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": PNG_B64,
                        },
                    },
                ],
            }
        ]
    })

    content = span.inputs["messages"][0]["content"]
    assert content[1]["type"] == "image"
    assert content[1]["source"]["data"].startswith("mlflow-attachment://")
    assert content[1]["source"]["type"] == "base64"
    assert content[1]["source"]["media_type"] == "image/png"
    assert len(span._attachments) == 1
    att = next(iter(span._attachments.values()))
    assert att.content_type == "image/png"
    assert att.content_bytes == PNG_BYTES


def test_extracts_multiple_anthropic_images():
    span = _make_live_span()
    img2_bytes = b"fake jpeg bytes"
    span.set_inputs({
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": PNG_B64,
                        },
                    },
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": base64.b64encode(img2_bytes).decode(),
                        },
                    },
                ],
            }
        ]
    })

    content = span.inputs["messages"][0]["content"]
    assert content[0]["source"]["data"].startswith("mlflow-attachment://")
    assert content[1]["source"]["data"].startswith("mlflow-attachment://")
    assert len(span._attachments) == 2


# --- Audio output pattern ---


def test_extracts_audio_output():
    span = _make_live_span()
    span.set_outputs({
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": None,
                    "audio": {
                        "id": "audio_123",
                        "data": WAV_B64,
                        "transcript": "Hello world",
                    },
                }
            }
        ]
    })

    outputs = span.outputs
    audio = outputs["choices"][0]["message"]["audio"]
    assert audio["data"].startswith("mlflow-attachment://")
    assert audio["transcript"] == "Hello world"
    assert audio["id"] == "audio_123"
    assert len(span._attachments) == 1
    att = next(iter(span._attachments.values()))
    assert att.content_type == "audio/wav"


def test_extracts_b64_json_multiple():
    span = _make_live_span()
    span.set_outputs({
        "data": [
            {"b64_json": PNG_B64, "revised_prompt": "a circle"},
            {"b64_json": PNG_B64, "revised_prompt": "a triangle"},
        ]
    })

    output = span.outputs
    assert output["data"][0]["b64_json"].startswith("mlflow-attachment://")
    assert output["data"][1]["b64_json"].startswith("mlflow-attachment://")
    assert output["data"][0]["revised_prompt"] == "a circle"
    assert output["data"][1]["revised_prompt"] == "a triangle"
    assert len(span._attachments) == 2


# --- Bedrock image pattern ---


def test_extracts_bedrock_image():
    span = _make_live_span()
    span.set_outputs({
        "output": {
            "message": {
                "content": [
                    {"text": "Here is the image."},
                    {
                        "image": {
                            "format": "png",
                            "source": {"bytes": PNG_B64},
                        }
                    },
                ]
            }
        }
    })

    content = span.outputs["output"]["message"]["content"]
    assert content[0] == {"text": "Here is the image."}
    img_block = content[1]
    assert img_block["image"]["format"] == "png"
    assert img_block["image"]["source"]["bytes"].startswith("mlflow-attachment://")
    assert len(span._attachments) == 1
    att = next(iter(span._attachments.values()))
    assert att.content_type == "image/png"
    assert att.content_bytes == PNG_BYTES


def test_bedrock_image_with_invalid_base64():
    span = _make_live_span()
    span.set_outputs({
        "content": [
            {
                "image": {
                    "format": "png",
                    "source": {"bytes": "!!!bad!!!"},
                }
            }
        ]
    })
    img_block = span.outputs["content"][0]
    assert img_block["image"]["source"]["bytes"] == "!!!bad!!!"
    assert len(span._attachments) == 0


# --- Gemini inline_data pattern ---


def test_extracts_gemini_inline_data():
    span = _make_live_span()
    span.set_outputs({
        "candidates": [
            {
                "content": {
                    "parts": [
                        {"text": "Here is what I see."},
                        {
                            "inline_data": {
                                "mime_type": "image/png",
                                "data": PNG_B64,
                            }
                        },
                    ]
                }
            }
        ]
    })

    parts = span.outputs["candidates"][0]["content"]["parts"]
    assert parts[0] == {"text": "Here is what I see."}
    inline = parts[1]
    assert inline["inline_data"]["mime_type"] == "image/png"
    assert inline["inline_data"]["data"].startswith("mlflow-attachment://")
    assert len(span._attachments) == 1
    att = next(iter(span._attachments.values()))
    assert att.content_type == "image/png"
    assert att.content_bytes == PNG_BYTES


def test_extracts_gemini_inline_data_bytes_repr():
    # Gemini SDK Pydantic serialization produces repr(bytes) instead of base64,
    # so the extractor must also recover the raw bytes from that form.
    span = _make_live_span()
    bytes_repr = repr(PNG_BYTES)
    span.set_outputs({
        "candidates": [
            {
                "content": {
                    "parts": [
                        {"text": "A small image."},
                        {
                            "inline_data": {
                                "mime_type": "image/png",
                                "data": bytes_repr,
                            }
                        },
                    ]
                }
            }
        ]
    })

    parts = span.outputs["candidates"][0]["content"]["parts"]
    assert parts[0] == {"text": "A small image."}
    inline = parts[1]
    assert inline["inline_data"]["data"].startswith("mlflow-attachment://")
    assert len(span._attachments) == 1
    att = next(iter(span._attachments.values()))
    assert att.content_type == "image/png"
    assert att.content_bytes == PNG_BYTES


def test_gemini_inline_data_with_invalid_base64():
    span = _make_live_span()
    span.set_outputs({
        "parts": [
            {
                "inline_data": {
                    "mime_type": "image/jpeg",
                    "data": "!!!bad!!!",
                }
            }
        ]
    })
    inline = span.outputs["parts"][0]
    assert inline["inline_data"]["data"] == "!!!bad!!!"
    assert len(span._attachments) == 0


# --- Responses API image_generation_call pattern ---


def test_extracts_responses_api_image_generation():
    span = _make_live_span()
    span.set_outputs({
        "output": [
            {
                "type": "image_generation_call",
                "result": PNG_B64,
                "output_format": "png",
                "revised_prompt": "a blue square",
            },
            {
                "type": "message",
                "content": [{"type": "output_text", "text": "Here is the image."}],
            },
        ]
    })

    outputs = span.outputs
    img_call = outputs["output"][0]
    assert img_call["type"] == "image_generation_call"
    assert img_call["result"].startswith("mlflow-attachment://")
    assert img_call["revised_prompt"] == "a blue square"
    assert img_call["output_format"] == "png"
    msg = outputs["output"][1]
    assert msg["type"] == "message"
    assert len(span._attachments) == 1
    att = next(iter(span._attachments.values()))
    assert att.content_type == "image/png"
    assert att.content_bytes == PNG_BYTES


def test_responses_api_image_generation_with_invalid_base64():
    span = _make_live_span()
    span.set_outputs({
        "output": [
            {
                "type": "image_generation_call",
                "result": "!!!bad!!!",
                "output_format": "png",
            }
        ]
    })
    img_call = span.outputs["output"][0]
    assert img_call["result"] == "!!!bad!!!"
    assert len(span._attachments) == 0


# --- Two-pass serialization extraction ---


def test_extracts_base64_from_pydantic_model():
    """Pydantic models aren't traversable in the first pass but become
    plain dicts after JSON serialization. The second pass should extract
    the base64 data from the serialized form.
    """

    class AudioOutput(BaseModel):
        transcript: str
        audio: dict[str, str]

    output = AudioOutput(
        transcript="Hello",
        audio={"data": WAV_B64, "id": "audio_123"},
    )

    span = _make_live_span()
    span.set_outputs({"result": output})

    outputs = span.outputs
    assert outputs["result"]["audio"]["data"].startswith("mlflow-attachment://")
    assert outputs["result"]["transcript"] == "Hello"
    assert len(span._attachments) == 1
    att = next(iter(span._attachments.values()))
    assert att.content_type == "audio/wav"


def test_two_pass_with_explicit_attachment_and_pydantic():
    """When a span has both an explicit Attachment (first pass) AND a Pydantic
    model with base64 (second pass), both should be extracted.
    """

    class ImageResult(BaseModel):
        b64_json: str
        revised_prompt: str

    pydantic_output = ImageResult(b64_json=PNG_B64, revised_prompt="a sunset")
    explicit_att = Attachment(content_type="image/png", content_bytes=PNG_BYTES)

    span = _make_live_span()
    span.set_outputs({"image": pydantic_output, "thumbnail": explicit_att})

    outputs = span.outputs
    assert outputs["thumbnail"].startswith("mlflow-attachment://")
    assert outputs["image"]["b64_json"].startswith("mlflow-attachment://")
    assert len(span._attachments) == 2


# --- Opt-out ---


def test_opt_out_via_env_var(monkeypatch):
    monkeypatch.setenv("MLFLOW_TRACE_EXTRACT_ATTACHMENTS", "false")
    span = _make_live_span()
    span.set_inputs({"image": PNG_DATA_URI})
    assert span.inputs["image"] == PNG_DATA_URI
    assert len(span._attachments) == 0


def test_explicit_attachment_still_works_when_opted_out(monkeypatch):
    monkeypatch.setenv("MLFLOW_TRACE_EXTRACT_ATTACHMENTS", "false")
    span = _make_live_span()
    att = Attachment(content_type="image/png", content_bytes=PNG_BYTES)
    span.set_inputs({"image": att})
    assert span.inputs["image"].startswith("mlflow-attachment://")
    assert len(span._attachments) == 1


# --- Attachment size limit ---


def test_attachment_under_size_limit_is_extracted(monkeypatch):
    monkeypatch.setenv("MLFLOW_TRACE_MAX_ATTACHMENT_SIZE", str(len(PNG_BYTES) + 1))
    span = _make_live_span()
    span.set_inputs({"image": PNG_DATA_URI})

    assert span.inputs["image"].startswith("mlflow-attachment://")
    assert len(span._attachments) == 1


def test_attachment_over_size_limit_is_discarded(monkeypatch):
    monkeypatch.setenv("MLFLOW_TRACE_MAX_ATTACHMENT_SIZE", str(len(PNG_BYTES) - 1))
    span = _make_live_span()
    span.set_inputs({"image": PNG_DATA_URI})

    assert "[Attachment too large:" in span.inputs["image"]
    assert len(span._attachments) == 0


def test_attachment_size_limit_unset_allows_all(monkeypatch):
    # Default is None (unset) — no limit enforced. Clear the variable so a value
    # exported in the surrounding environment cannot make this test fail.
    monkeypatch.delenv("MLFLOW_TRACE_MAX_ATTACHMENT_SIZE", raising=False)
    span = _make_live_span()
    span.set_inputs({"image": PNG_DATA_URI})

    assert span.inputs["image"].startswith("mlflow-attachment://")
    assert len(span._attachments) == 1


def test_structured_content_over_size_limit_is_discarded(monkeypatch):
    monkeypatch.setenv("MLFLOW_TRACE_MAX_ATTACHMENT_SIZE", "1")
    span = _make_live_span()
    span.set_inputs({
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "input_audio",
                        "input_audio": {"data": WAV_B64, "format": "wav"},
                    }
                ],
            }
        ]
    })

    audio_part = span.inputs["messages"][0]["content"][0]
    assert "[Attachment too large:" in audio_part["input_audio"]["data"]
    assert len(span._attachments) == 0


def test_explicit_attachment_over_size_limit_is_discarded(monkeypatch):
    monkeypatch.setenv("MLFLOW_TRACE_MAX_ATTACHMENT_SIZE", "1")
    span = _make_live_span()
    att = Attachment(content_type="image/png", content_bytes=PNG_BYTES)
    span.set_inputs({"image": att})

    assert "[Attachment too large:" in span.inputs["image"]
    assert len(span._attachments) == 0


def test_attachment_size_limit_negative_treated_as_disabled(monkeypatch):
    monkeypatch.setenv("MLFLOW_TRACE_MAX_ATTACHMENT_SIZE", "-1")
    span = _make_live_span()
    span.set_inputs({"image": PNG_DATA_URI})

    assert span.inputs["image"].startswith("mlflow-attachment://")
    assert len(span._attachments) == 1