"""
This file contains unit tests for the new Gemini Python SDK
https://github.com/googleapis/python-genai
"""

import asyncio
import base64
import importlib.metadata
import re
from unittest.mock import patch

import pytest
from google import genai
from packaging.version import Version

import mlflow
from mlflow.entities.span import SpanType
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey
from mlflow.version import IS_TRACING_SDK_ONLY

from tests.tracing.helper import get_traces

google_gemini_version = Version(importlib.metadata.version("google.genai"))
is_gemini_1_7_or_newer = google_gemini_version >= Version("1.7.0")

# Canonical model reply used by most mocked responses.
_CONTENT = {"parts": [{"text": "test answer"}], "role": "model"}

# Default usage metadata: 6 prompt + 6 candidate = 12 total tokens, no cache hits.
_USER_METADATA = {
    "prompt_token_count": 6,
    "candidates_token_count": 6,
    "total_token_count": 12,
    "cached_content_token_count": 0,
}

# Usage metadata variant with a cached-content hit (30 cached prompt tokens).
_USER_METADATA_WITH_CACHE = {
    "prompt_token_count": 50,
    "candidates_token_count": 20,
    "total_token_count": 70,
    "cached_content_token_count": 30,
}


def _get_candidate(content):
    """Wrap a content dict in a genai Candidate with benign metadata."""
    candidate = {
        "content": content,
        "avg_logprobs": 0.0,
        "finish_reason": "STOP",
        "safety_ratings": [],
        "token_count": 0,
    }

    return genai.types.Candidate(**candidate)


def _generate_content_response(content, usage_metadata=None):
    """Build a single-candidate GenerateContentResponse.

    Args:
        content: Candidate content (dict or genai.types.Content).
        usage_metadata: Optional usage dict; defaults to ``_USER_METADATA``.
    """
    res = {
        "candidates": [_get_candidate(content)],
        "usage_metadata": usage_metadata or _USER_METADATA,
        "automatic_function_calling_history": [],
    }

    return genai.types.GenerateContentResponse(**res)


_DUMMY_GENERATE_CONTENT_RESPONSE = _generate_content_response(_CONTENT)

_DUMMY_COUNT_TOKENS_RESPONSE = {"total_count": 10}

_DUMMY_EMBEDDING_RESPONSE = {"embedding": [1, 2, 3]}


def _dummy_generate_content(is_async: bool):
    """Return a sync or async stand-in for ``Models._generate_content``."""
    if is_async:

        async def _generate_content(self, model, contents, config):
            return _DUMMY_GENERATE_CONTENT_RESPONSE

    else:

        def _generate_content(self, model, contents, config):
            return _DUMMY_GENERATE_CONTENT_RESPONSE

    return _generate_content


def send_message(self, content):
    # Stand-in for Chat.send_message.
    return _DUMMY_GENERATE_CONTENT_RESPONSE


def count_tokens(self, model, contents):
    # Stand-in for Models.count_tokens.
    return _DUMMY_COUNT_TOKENS_RESPONSE


def embed_content(self, model, content):
    # Stand-in for Models.embed_content.
    return _DUMMY_EMBEDDING_RESPONSE


def multiply(a: float, b: float):
    """returns a * b."""
    return a * b


# Expected mlflow.chat.tools attribute derived from the ``multiply`` tool above.
# gemini >= 1.7.0 populates the "required" field from the function signature.
TOOL_ATTRIBUTE = [
    {
        "type": "function",
        "function": {
            "name": "multiply",
            "description": "returns a * b.",
            "parameters": {
                "properties": {
                    "a": {"type": "number", "description": None, "enum": None},
                    "b": {"type": "number", "description": None, "enum": None},
                },
                "required": ["a", "b"] if is_gemini_1_7_or_newer else None,
            },
        },
    },
]


@pytest.fixture(autouse=True)
def cleanup():
    # Always disable autologging after each test so state does not leak.
    yield
    mlflow.gemini.autolog(disable=True)


@pytest.fixture(params=[True, False], ids=["async", "sync"])
def is_async(request):
    """Parametrize a test over the async and sync Gemini client surfaces."""
    return request.param


def _call_generate_content(
    is_async: bool, contents: str, model: str = "gemini-1.5-flash", config=None
):
    """Invoke generate_content through the async or sync client API."""
    client = genai.Client(api_key="dummy")
    if is_async:
        return asyncio.run(
            client.aio.models.generate_content(model=model, contents=contents, config=config)
        )
    else:
        return client.models.generate_content(model=model, contents=contents, config=config)


def _create_chat_and_send_message(is_async: bool, message: str):
    """Create a chat session (async or sync) and send one message."""
    client = genai.Client(api_key="dummy")
    if is_async:
        chat = client.aio.chats.create(model="gemini-1.5-flash")
        return asyncio.run(chat.send_message(message))
    else:
        chat = client.chats.create(model="gemini-1.5-flash")
        return chat.send_message(message)


def test_generate_content_enable_disable_autolog(is_async, mock_litellm_cost):
    cls = "AsyncModels" if is_async else "Models"
    with patch(
        f"google.genai.models.{cls}._generate_content", new=_dummy_generate_content(is_async)
    ):
        mlflow.gemini.autolog()
        _call_generate_content(is_async, "test content")

        traces = get_traces()
        assert len(traces) == 1
        assert traces[0].info.status == "OK"
        assert len(traces[0].data.spans) == 2

        # Outer span: the public generate_content entry point.
        span = traces[0].data.spans[0]
        assert span.name == f"{cls}.generate_content"
        assert span.span_type == SpanType.LLM
        assert span.inputs == {
            "contents": "test content",
            "model": "gemini-1.5-flash",
            "config": None,
        }
        assert span.outputs == _DUMMY_GENERATE_CONTENT_RESPONSE.model_dump()
        assert span.model_name == "gemini-1.5-flash"

        # Inner span: the private _generate_content implementation.
        span1 = traces[0].data.spans[1]
        assert span1.name == f"{cls}._generate_content"
        assert span1.span_type == SpanType.LLM
        assert span1.inputs == {
            "contents": "test content",
            "model": "gemini-1.5-flash",
            "config": None,
        }
        assert span1.outputs == _DUMMY_GENERATE_CONTENT_RESPONSE.model_dump()

        assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
            TokenUsageKey.INPUT_TOKENS: 6,
            TokenUsageKey.OUTPUT_TOKENS: 6,
            TokenUsageKey.TOTAL_TOKENS: 12,
            TokenUsageKey.CACHE_READ_INPUT_TOKENS: 0,
        }

        if not IS_TRACING_SDK_ONLY:
            # Verify cost is calculated (6 input tokens * 1.0 + 6 output tokens * 2.0)
            assert span.llm_cost == {
                "input_cost": 6.0,
                "output_cost": 12.0,
                "total_cost": 18.0,
            }

        assert traces[0].info.token_usage == {
            "input_tokens": 6,
            "output_tokens": 6,
            "total_tokens": 12,
            "cache_read_input_tokens": 0,
        }

        mlflow.gemini.autolog(disable=True)
        _call_generate_content(is_async, "test content")

        # No new trace should be created
        traces = get_traces()
        assert len(traces) == 1


def test_generate_content_tracing_with_error(is_async):
    if is_async:

        async def _generate_content(self, model, contents, config):
            raise Exception("dummy error")

    else:

        def _generate_content(self, model, contents, config):
            raise Exception("dummy error")

    cls = "AsyncModels" if is_async else "Models"
    with patch(f"google.genai.models.{cls}._generate_content", new=_generate_content):
        mlflow.gemini.autolog()

        with pytest.raises(Exception, match="dummy error"):
            _call_generate_content(is_async, "test content")

        traces = get_traces()
        assert len(traces) == 1
        assert len(traces[0].data.spans) == 2

        # Both spans and the trace itself must be marked as failed.
        assert traces[0].info.status == "ERROR"
        assert traces[0].data.spans[0].status.status_code == "ERROR"
        assert traces[0].data.spans[0].status.description == "Exception: dummy error"
        assert traces[0].data.spans[1].status.status_code == "ERROR"
        assert traces[0].data.spans[1].status.description == "Exception: dummy error"


def test_generate_content_image_autolog(is_async, mock_litellm_cost):
    # BUG FIX: ``is_async`` was missing from the signature, so the name
    # resolved to the module-level fixture function (always truthy) and
    # the test silently exercised only the async path without parametrization.
    image = base64.b64encode(b"image").decode("utf-8")
    request = [
        genai.types.Part.from_bytes(mime_type="image/jpeg", data=image),
        "Caption this image",
    ]
    cls = "AsyncModels" if is_async else "Models"
    with patch(
        f"google.genai.models.{cls}._generate_content", new=_dummy_generate_content(is_async)
    ):
        mlflow.gemini.autolog()
        _call_generate_content(is_async, request)

        traces = get_traces()
        assert len(traces) == 1
        assert traces[0].info.status == "OK"
        assert len(traces[0].data.spans) == 2

        span = traces[0].data.spans[0]
        assert span.name == f"{cls}.generate_content"
        assert span.span_type == SpanType.LLM
        assert span.inputs["model"] == "gemini-1.5-flash"
        # gemini >= 1.15.0 adds a "display_name" field to inline data parts.
        extra = {"display_name": None} if google_gemini_version >= Version("1.15.0") else {}
        inline_data = span.inputs["contents"][0]["inline_data"]
        assert inline_data["mime_type"] == "image/jpeg"
        # Auto-extraction replaces bytes repr with mlflow-attachment:// URI
        assert inline_data["data"].startswith("mlflow-attachment://")
        assert "content_type=image%2Fjpeg" in inline_data["data"]
        if extra:
            assert inline_data["display_name"] is None
        assert span.inputs["contents"][1] == "Caption this image"
        assert span.outputs == _DUMMY_GENERATE_CONTENT_RESPONSE.model_dump()
        assert span.model_name == "gemini-1.5-flash"

        span1 = traces[0].data.spans[1]
        assert span1.name == f"{cls}._generate_content"
        assert span1.span_type == SpanType.LLM
        assert span1.parent_id == span.span_id
        assert span1.inputs["model"] == "gemini-1.5-flash"
        inline_data1 = span1.inputs["contents"][0]["inline_data"]
        assert inline_data1["mime_type"] == "image/jpeg"
        assert inline_data1["data"].startswith("mlflow-attachment://")
        assert "content_type=image%2Fjpeg" in inline_data1["data"]
        if extra:
            assert inline_data1["display_name"] is None
        assert span1.inputs["contents"][1] == "Caption this image"
        assert span1.outputs == _DUMMY_GENERATE_CONTENT_RESPONSE.model_dump()

        assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
            TokenUsageKey.INPUT_TOKENS: 6,
            TokenUsageKey.OUTPUT_TOKENS: 6,
            TokenUsageKey.TOTAL_TOKENS: 12,
            TokenUsageKey.CACHE_READ_INPUT_TOKENS: 0,
        }
        if not IS_TRACING_SDK_ONLY:
            assert span.llm_cost == {
                "input_cost": 6.0,
                "output_cost": 12.0,
                "total_cost": 18.0,
            }

        assert traces[0].info.token_usage == {
            "input_tokens": 6,
            "output_tokens": 6,
            "total_tokens": 12,
            "cache_read_input_tokens": 0,
        }


def test_generate_content_tool_calling_autolog(is_async, mock_litellm_cost):
    # Model reply that requests a call to the ``multiply`` tool.
    tool_call_content = {
        "parts": [
            {
                "function_call": {
                    "name": "multiply",
                    "args": {
                        "a": 57.0,
                        "b": 44.0,
                    },
                }
            }
        ],
        "role": "model",
    }

    response = _generate_content_response(tool_call_content)
    if is_async:

        async def _generate_content(self, model, contents, config):
            return response

    else:

        def _generate_content(self, model, contents, config):
            return response

    cls = "AsyncModels" if is_async else "Models"

    with patch(f"google.genai.models.{cls}._generate_content", new=_generate_content):
        mlflow.gemini.autolog()
        _call_generate_content(
            is_async,
            model="gemini-1.5-flash",
            contents="I have 57 cats, each owns 44 mittens, how many mittens is that in total?",
            config=genai.types.GenerateContentConfig(
                tools=[multiply],
                automatic_function_calling=genai.types.AutomaticFunctionCallingConfig(disable=True),
            ),
        )

        traces = get_traces()
        assert len(traces) == 1
        assert traces[0].info.status == "OK"
        assert len(traces[0].data.spans) == 2

        span = traces[0].data.spans[0]
        assert span.name == f"{cls}.generate_content"
        assert span.span_type == SpanType.LLM
        assert (
            span.inputs["contents"]
            == "I have 57 cats, each owns 44 mittens, how many mittens is that in total?"
        )
        assert span.get_attribute(SpanAttributeKey.CHAT_TOOLS) == TOOL_ATTRIBUTE
        assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "gemini"
        assert span.model_name == "gemini-1.5-flash"

        span1 = traces[0].data.spans[1]
        assert span1.name == f"{cls}._generate_content"
        assert span1.span_type == SpanType.LLM
        assert span1.parent_id == span.span_id
        assert (
            span1.inputs["contents"]
            == "I have 57 cats, each owns 44 mittens, how many mittens is that in total?"
        )
        assert span1.get_attribute(SpanAttributeKey.CHAT_TOOLS) == TOOL_ATTRIBUTE
        assert span1.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "gemini"

        assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
            TokenUsageKey.INPUT_TOKENS: 6,
            TokenUsageKey.OUTPUT_TOKENS: 6,
            TokenUsageKey.TOTAL_TOKENS: 12,
            TokenUsageKey.CACHE_READ_INPUT_TOKENS: 0,
        }
        if not IS_TRACING_SDK_ONLY:
            assert span.llm_cost == {
                "input_cost": 6.0,
                "output_cost": 12.0,
                "total_cost": 18.0,
            }

        assert traces[0].info.token_usage == {
            "input_tokens": 6,
            "output_tokens": 6,
            "total_tokens": 12,
            "cache_read_input_tokens": 0,
        }


def test_generate_content_tool_calling_chat_history_autolog(is_async, mock_litellm_cost):
    # Full tool-calling exchange: user question, model tool call, tool result.
    question_content = genai.types.Content(**{
        "parts": [
            {
                "text": "I have 57 cats, each owns 44 mittens, how many mittens in total?",
            }
        ],
        "role": "user",
    })

    tool_call_content = genai.types.Content(**{
        "parts": [
            {
                "function_call": {
                    "name": "multiply",
                    "args": {
                        "a": 57.0,
                        "b": 44.0,
                    },
                }
            }
        ],
        "role": "model",
    })

    tool_response_content = genai.types.Content(**{
        "parts": [{"function_response": {"name": "multiply", "response": {"result": 2508.0}}}],
        "role": "user",
    })

    response = _generate_content_response(
        genai.types.Content(**{
            "parts": [
                {
                    "text": "57 cats * 44 mittens/cat = 2508 mittens in total.",
                }
            ],
            "role": "model",
        })
    )

    cls = "AsyncModels" if is_async else "Models"

    if is_async:

        async def _generate_content(self, model, contents, config):
            return response

    else:

        def _generate_content(self, model, contents, config):
            return response

    with patch(f"google.genai.models.{cls}._generate_content", new=_generate_content):
        mlflow.gemini.autolog()
        _call_generate_content(
            is_async,
            model="gemini-1.5-flash",
            contents=[question_content, tool_call_content, tool_response_content],
            config=genai.types.GenerateContentConfig(
                tools=[multiply],
                automatic_function_calling=genai.types.AutomaticFunctionCallingConfig(disable=True),
            ),
        )

        traces = get_traces()
        assert len(traces) == 1
        assert traces[0].info.status == "OK"
        assert len(traces[0].data.spans) == 2

        span = traces[0].data.spans[0]
        assert span.name == f"{cls}.generate_content"
        assert span.span_type == SpanType.LLM
        assert span.inputs["contents"] == [
            question_content.model_dump(),
            tool_call_content.model_dump(),
            tool_response_content.model_dump(),
        ]
        assert span.inputs["model"] == "gemini-1.5-flash"
        # Consistency fix: use the SpanAttributeKey constant like sibling tests
        # instead of the raw "mlflow.chat.tools" string literal.
        assert span.get_attribute(SpanAttributeKey.CHAT_TOOLS) == TOOL_ATTRIBUTE
        assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "gemini"
        assert span.model_name == "gemini-1.5-flash"

        span1 = traces[0].data.spans[1]
        assert span1.name == f"{cls}._generate_content"
        assert span1.span_type == SpanType.LLM
        assert span1.parent_id == span.span_id
        assert span1.inputs["contents"] == [
            question_content.model_dump(),
            tool_call_content.model_dump(),
            tool_response_content.model_dump(),
        ]
        assert span1.inputs["model"] == "gemini-1.5-flash"
        assert span1.get_attribute(SpanAttributeKey.CHAT_TOOLS) == TOOL_ATTRIBUTE
        assert span1.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "gemini"

        assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
            TokenUsageKey.INPUT_TOKENS: 6,
            TokenUsageKey.OUTPUT_TOKENS: 6,
            TokenUsageKey.TOTAL_TOKENS: 12,
            TokenUsageKey.CACHE_READ_INPUT_TOKENS: 0,
        }
        if not IS_TRACING_SDK_ONLY:
            assert span.llm_cost == {
                "input_cost": 6.0,
                "output_cost": 12.0,
                "total_cost": 18.0,
            }

        assert traces[0].info.token_usage == {
            "input_tokens": 6,
            "output_tokens": 6,
            "total_tokens": 12,
            "cache_read_input_tokens": 0,
        }


def test_chat_session_autolog(is_async):
    cls = "AsyncModels" if is_async else "Models"
    with patch(
        f"google.genai.models.{cls}._generate_content", new=_dummy_generate_content(is_async)
    ):
        mlflow.gemini.autolog()
        _create_chat_and_send_message(is_async, "test content")

        traces = get_traces()
        assert len(traces) == 1
        assert traces[0].info.status == "OK"
        assert len(traces[0].data.spans) == 3
        span = traces[0].data.spans[0]
        # BUG FIX: the conditional expression must be parenthesized; the original
        #   assert span.name == "AsyncChat.send_message" if is_async else "Chat.send_message"
        # parsed as `(assert ...) if is_async else "Chat.send_message"`, so the
        # sync case asserted a truthy string and always passed.
        assert span.name == ("AsyncChat.send_message" if is_async else "Chat.send_message")
        assert span.span_type == SpanType.CHAT_MODEL
        assert span.inputs == {"message": "test content"}
        assert span.outputs == _DUMMY_GENERATE_CONTENT_RESPONSE.model_dump()
        assert span.model_name == "gemini-1.5-flash"

        mlflow.gemini.autolog(disable=True)
        _create_chat_and_send_message(is_async, "test content")

        # No new trace should be created
        traces = get_traces()
        assert len(traces) == 1


def test_count_tokens_autolog():
    with patch("google.genai.models.Models.count_tokens", new=count_tokens):
        mlflow.gemini.autolog()
        client = genai.Client(api_key="dummy")
        client.models.count_tokens(model="gemini-1.5-flash", contents="test content")

        traces = get_traces()
        assert len(traces) == 1
        assert traces[0].info.status == "OK"
        assert len(traces[0].data.spans) == 1
        span = traces[0].data.spans[0]
        assert span.name == "Models.count_tokens"
        assert span.span_type == SpanType.LLM
        assert span.inputs == {"contents": "test content", "model": "gemini-1.5-flash"}
        assert span.outputs == _DUMMY_COUNT_TOKENS_RESPONSE
        assert span.model_name == "gemini-1.5-flash"

        mlflow.gemini.autolog(disable=True)
        client = genai.Client(api_key="dummy")
        client.models.count_tokens(model="gemini-1.5-flash", contents="test content")

        # No new trace should be created
        traces = get_traces()
        assert len(traces) == 1
def test_embed_content_autolog():
    """Embedding calls produce one EMBEDDING span; disabling stops tracing."""
    with patch("google.genai.models.Models.embed_content", new=embed_content):
        mlflow.gemini.autolog()
        client = genai.Client(api_key="dummy")
        client.models.embed_content(model="text-embedding-004", content="Hello World")

        recorded = get_traces()
        assert len(recorded) == 1
        assert recorded[0].info.status == "OK"
        spans = recorded[0].data.spans
        assert len(spans) == 1
        embed_span = spans[0]
        assert embed_span.name == "Models.embed_content"
        assert embed_span.span_type == SpanType.EMBEDDING
        assert embed_span.inputs == {"content": "Hello World", "model": "text-embedding-004"}
        assert embed_span.outputs == _DUMMY_EMBEDDING_RESPONSE
        assert embed_span.model_name == "text-embedding-004"

        mlflow.gemini.autolog(disable=True)
        client.models.embed_content(model="text-embedding-004", content="Hello World")

        # No new trace should be created
        recorded = get_traces()
        assert len(recorded) == 1


def test_generate_content_cached_tokens(is_async, mock_litellm_cost):
    """Cached prompt tokens reported by Gemini surface in the CHAT_USAGE attribute."""
    cached_response = _generate_content_response(_CONTENT, _USER_METADATA_WITH_CACHE)

    if is_async:

        async def _generate_content(self, model, contents, config):
            return cached_response

    else:

        def _generate_content(self, model, contents, config):
            return cached_response

    target = "google.genai.models.{}._generate_content".format(
        "AsyncModels" if is_async else "Models"
    )
    with patch(target, new=_generate_content):
        mlflow.gemini.autolog()
        _call_generate_content(is_async, "test content")

    recorded = get_traces()
    assert len(recorded) == 1
    usage = recorded[0].data.spans[0].get_attribute(SpanAttributeKey.CHAT_USAGE)
    assert usage == {
        TokenUsageKey.INPUT_TOKENS: 50,
        TokenUsageKey.OUTPUT_TOKENS: 20,
        TokenUsageKey.TOTAL_TOKENS: 70,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 30,
    }


def _config_headers(cfg):
    """Extract http_options headers from a config that may be a dict or object."""
    if isinstance(cfg, dict):
        return cfg.get("http_options", {}).get("headers", {})
    return getattr(getattr(cfg, "http_options", None), "headers", {}) or {}


def test_tracing_headers_injected_in_config(is_async):
    """Autolog injects a W3C traceparent header into the request config."""
    captured_config = {}

    if is_async:

        async def _generate_content(self, model, contents, config):
            captured_config["config"] = config
            return _DUMMY_GENERATE_CONTENT_RESPONSE

    else:

        def _generate_content(self, model, contents, config):
            captured_config["config"] = config
            return _DUMMY_GENERATE_CONTENT_RESPONSE

    target = "google.genai.models.{}._generate_content".format(
        "AsyncModels" if is_async else "Models"
    )
    with patch(target, new=_generate_content):
        mlflow.gemini.autolog()
        _call_generate_content(is_async, "test content", config={"temperature": 0.5})

    assert len(get_traces()) == 1

    # Verify traceparent was injected into config.http_options.headers
    headers = _config_headers(captured_config["config"])
    assert "traceparent" in headers
    assert re.fullmatch(r"00-[0-9a-f]{32}-[0-9a-f]{16}-[0-9a-f]{2}", headers["traceparent"])


def test_tracing_headers_preserve_existing_config_headers(is_async):
    """Injecting the traceparent must not drop caller-supplied headers."""
    captured_config = {}

    if is_async:

        async def _generate_content(self, model, contents, config):
            captured_config["config"] = config
            return _DUMMY_GENERATE_CONTENT_RESPONSE

    else:

        def _generate_content(self, model, contents, config):
            captured_config["config"] = config
            return _DUMMY_GENERATE_CONTENT_RESPONSE

    target = "google.genai.models.{}._generate_content".format(
        "AsyncModels" if is_async else "Models"
    )
    with patch(target, new=_generate_content):
        mlflow.gemini.autolog()
        _call_generate_content(
            is_async,
            "test content",
            config={
                "temperature": 0.5,
                "http_options": {"headers": {"X-Custom": "my-value"}},
            },
        )

    headers = _config_headers(captured_config["config"])

    # Both traceparent and user headers should be present
    assert "traceparent" in headers
    # User-provided headers take precedence
    assert headers["X-Custom"] == "my-value"


def test_tracing_headers_injected_when_config_is_none(is_async):
    """Headers are injected even when the caller passes no config at all."""
    captured_config = {}

    if is_async:

        async def _generate_content(self, model, contents, config):
            captured_config["config"] = config
            return _DUMMY_GENERATE_CONTENT_RESPONSE

    else:

        def _generate_content(self, model, contents, config):
            captured_config["config"] = config
            return _DUMMY_GENERATE_CONTENT_RESPONSE

    target = "google.genai.models.{}._generate_content".format(
        "AsyncModels" if is_async else "Models"
    )
    with patch(target, new=_generate_content):
        mlflow.gemini.autolog()
        # Call without config — headers should still be injected
        _call_generate_content(is_async, "test content")

    recorded = get_traces()
    assert len(recorded) == 1

    # Verify traceparent was injected via config even though original config was None
    headers = _config_headers(captured_config["config"])
    assert "traceparent" in headers
    assert re.fullmatch(r"00-[0-9a-f]{32}-[0-9a-f]{16}-[0-9a-f]{2}", headers["traceparent"])

    # Verify the traceparent is stripped from span inputs
    for span in recorded[0].data.spans:
        config_input = span.inputs.get("config")
        if config_input is None:
            continue
        assert "traceparent" not in _config_headers(config_input)