# tests/gemini/test_gemini_autolog.py
  1  """
  2  This file contains unit tests for the new Gemini Python SDK
  3  https://github.com/googleapis/python-genai
  4  """
  5  
  6  import asyncio
  7  import base64
  8  import importlib.metadata
  9  import re
 10  from unittest.mock import patch
 11  
 12  import pytest
 13  from google import genai
 14  from packaging.version import Version
 15  
 16  import mlflow
 17  from mlflow.entities.span import SpanType
 18  from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey
 19  from mlflow.version import IS_TRACING_SDK_ONLY
 20  
 21  from tests.tracing.helper import get_traces
 22  
# Installed google-genai SDK version; expectations below vary by version.
google_gemini_version = Version(importlib.metadata.version("google.genai"))
# Used to adjust the expected tool schema (see TOOL_ATTRIBUTE's "required" key).
is_gemini_1_7_or_newer = google_gemini_version >= Version("1.7.0")
 25  
# Canned model reply used by most stubbed responses in this module.
_CONTENT = {"parts": [{"text": "test answer"}], "role": "model"}

# Default usage metadata: 6 prompt + 6 candidate = 12 total tokens, no cache hits.
_USER_METADATA = {
    "prompt_token_count": 6,
    "candidates_token_count": 6,
    "total_token_count": 12,
    "cached_content_token_count": 0,
}

# Variant with 30 cached prompt tokens, used by the cached-tokens test.
_USER_METADATA_WITH_CACHE = {
    "prompt_token_count": 50,
    "candidates_token_count": 20,
    "total_token_count": 70,
    "cached_content_token_count": 30,
}
 41  
 42  
 43  def _get_candidate(content):
 44      candidate = {
 45          "content": content,
 46          "avg_logprobs": 0.0,
 47          "finish_reason": "STOP",
 48          "safety_ratings": [],
 49          "token_count": 0,
 50      }
 51  
 52      return genai.types.Candidate(**candidate)
 53  
 54  
 55  def _generate_content_response(content, usage_metadata=None):
 56      res = {
 57          "candidates": [_get_candidate(content)],
 58          "usage_metadata": usage_metadata or _USER_METADATA,
 59          "automatic_function_calling_history": [],
 60      }
 61  
 62      return genai.types.GenerateContentResponse(**res)
 63  
 64  
# Pre-built payloads returned by the patched SDK methods below.
_DUMMY_GENERATE_CONTENT_RESPONSE = _generate_content_response(_CONTENT)

_DUMMY_COUNT_TOKENS_RESPONSE = {"total_count": 10}

_DUMMY_EMBEDDING_RESPONSE = {"embedding": [1, 2, 3]}
 70  
 71  
 72  def _dummy_generate_content(is_async: bool):
 73      if is_async:
 74  
 75          async def _generate_content(self, model, contents, config):
 76              return _DUMMY_GENERATE_CONTENT_RESPONSE
 77  
 78      else:
 79  
 80          def _generate_content(self, model, contents, config):
 81              return _DUMMY_GENERATE_CONTENT_RESPONSE
 82  
 83      return _generate_content
 84  
 85  
 86  def send_message(self, content):
 87      return _DUMMY_GENERATE_CONTENT_RESPONSE
 88  
 89  
 90  def count_tokens(self, model, contents):
 91      return _DUMMY_COUNT_TOKENS_RESPONSE
 92  
 93  
 94  def embed_content(self, model, content):
 95      return _DUMMY_EMBEDDING_RESPONSE
 96  
 97  
 98  def multiply(a: float, b: float):
 99      """returns a * b."""
100      return a * b
101  
102  
# Expected value of the mlflow.chat.tools span attribute once `multiply`
# is registered as a Gemini tool; description comes from its docstring.
TOOL_ATTRIBUTE = [
    {
        "type": "function",
        "function": {
            "name": "multiply",
            "description": "returns a * b.",
            "parameters": {
                "properties": {
                    "a": {"type": "number", "description": None, "enum": None},
                    "b": {"type": "number", "description": None, "enum": None},
                },
                # google-genai >= 1.7.0 populates the required-parameter list.
                "required": ["a", "b"] if is_gemini_1_7_or_newer else None,
            },
        },
    },
]
119  
120  
@pytest.fixture(autouse=True)
def cleanup():
    # Ensure autologging is switched off after every test in this module.
    yield
    mlflow.gemini.autolog(disable=True)
125  
126  
@pytest.fixture(params=[True, False], ids=["async", "sync"])
def is_async(request):
    # Parametrizes tests over the async (client.aio) and sync client APIs.
    return request.param
130  
131  
def _call_generate_content(
    is_async: bool, contents: str, model: str = "gemini-1.5-flash", config=None
):
    """Invoke generate_content through a fresh dummy-keyed client, sync or async."""
    client = genai.Client(api_key="dummy")
    if not is_async:
        return client.models.generate_content(model=model, contents=contents, config=config)
    coro = client.aio.models.generate_content(model=model, contents=contents, config=config)
    return asyncio.run(coro)
142  
143  
def _create_chat_and_send_message(is_async: bool, message: str):
    """Create a chat session and send one message, sync or async."""
    client = genai.Client(api_key="dummy")
    if not is_async:
        return client.chats.create(model="gemini-1.5-flash").send_message(message)
    chat = client.aio.chats.create(model="gemini-1.5-flash")
    return asyncio.run(chat.send_message(message))
152  
153  
def test_generate_content_enable_disable_autolog(is_async, mock_litellm_cost):
    """generate_content is traced with usage/cost; disabling autolog stops tracing."""
    cls = "AsyncModels" if is_async else "Models"
    with (
        patch(
            f"google.genai.models.{cls}._generate_content", new=_dummy_generate_content(is_async)
        ),
    ):
        mlflow.gemini.autolog()
        _call_generate_content(is_async, "test content")

        traces = get_traces()
        assert len(traces) == 1
        assert traces[0].info.status == "OK"
        # One span for public generate_content, one for the patched private
        # _generate_content it delegates to.
        assert len(traces[0].data.spans) == 2

        span = traces[0].data.spans[0]
        assert span.name == f"{cls}.generate_content"
        assert span.span_type == SpanType.LLM
        assert span.inputs == {
            "contents": "test content",
            "model": "gemini-1.5-flash",
            "config": None,
        }
        assert span.outputs == _DUMMY_GENERATE_CONTENT_RESPONSE.model_dump()
        assert span.model_name == "gemini-1.5-flash"

        span1 = traces[0].data.spans[1]
        assert span1.name == f"{cls}._generate_content"
        assert span1.span_type == SpanType.LLM
        assert span1.inputs == {
            "contents": "test content",
            "model": "gemini-1.5-flash",
            "config": None,
        }
        assert span1.outputs == _DUMMY_GENERATE_CONTENT_RESPONSE.model_dump()

        # Token usage mirrors _USER_METADATA (6 in / 6 out / 12 total, no cache).
        assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
            TokenUsageKey.INPUT_TOKENS: 6,
            TokenUsageKey.OUTPUT_TOKENS: 6,
            TokenUsageKey.TOTAL_TOKENS: 12,
            TokenUsageKey.CACHE_READ_INPUT_TOKENS: 0,
        }

        if not IS_TRACING_SDK_ONLY:
            # Verify cost is calculated (6 input tokens * 1.0 + 6 output tokens * 2.0)
            assert span.llm_cost == {
                "input_cost": 6.0,
                "output_cost": 12.0,
                "total_cost": 18.0,
            }

        assert traces[0].info.token_usage == {
            "input_tokens": 6,
            "output_tokens": 6,
            "total_tokens": 12,
            "cache_read_input_tokens": 0,
        }

        mlflow.gemini.autolog(disable=True)
        _call_generate_content(is_async, "test content")

        # No new trace should be created
        traces = get_traces()
        assert len(traces) == 1
218  
219  
def test_generate_content_tracing_with_error(is_async):
    """Exceptions raised by the SDK are recorded as ERROR on trace and spans."""
    if is_async:

        async def _generate_content(self, model, contents, config):
            raise Exception("dummy error")

    else:

        def _generate_content(self, model, contents, config):
            raise Exception("dummy error")

    cls = "AsyncModels" if is_async else "Models"
    with patch(f"google.genai.models.{cls}._generate_content", new=_generate_content):
        mlflow.gemini.autolog()

        with pytest.raises(Exception, match="dummy error"):
            _call_generate_content(is_async, "test content")

    traces = get_traces()
    assert len(traces) == 1
    assert len(traces[0].data.spans) == 2

    # Both the public and the private span carry the error status and message.
    assert traces[0].info.status == "ERROR"
    assert traces[0].data.spans[0].status.status_code == "ERROR"
    assert traces[0].data.spans[0].status.description == "Exception: dummy error"
    assert traces[0].data.spans[1].status.status_code == "ERROR"
    assert traces[0].data.spans[1].status.description == "Exception: dummy error"
247  
248  
def test_generate_content_image_autolog(is_async, mock_litellm_cost):
    """Inline image bytes in contents are traced as mlflow-attachment:// URIs.

    Fix: ``is_async`` is now requested as a fixture parameter. The original
    signature omitted it, so the body silently resolved ``is_async`` to the
    module-level fixture *function*, which is always truthy — the test only
    ever exercised the async path with ``cls == "AsyncModels"``.
    """
    image = base64.b64encode(b"image").decode("utf-8")
    request = [
        genai.types.Part.from_bytes(mime_type="image/jpeg", data=image),
        "Caption this image",
    ]
    cls = "AsyncModels" if is_async else "Models"
    with patch(
        f"google.genai.models.{cls}._generate_content", new=_dummy_generate_content(is_async)
    ):
        mlflow.gemini.autolog()
        _call_generate_content(is_async, request)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert len(traces[0].data.spans) == 2

    span = traces[0].data.spans[0]
    assert span.name == f"{cls}.generate_content"
    assert span.span_type == SpanType.LLM
    assert span.inputs["model"] == "gemini-1.5-flash"
    # google-genai >= 1.15.0 adds a display_name field to inline data parts.
    extra = {"display_name": None} if google_gemini_version >= Version("1.15.0") else {}
    inline_data = span.inputs["contents"][0]["inline_data"]
    assert inline_data["mime_type"] == "image/jpeg"
    # Auto-extraction replaces bytes repr with mlflow-attachment:// URI
    assert inline_data["data"].startswith("mlflow-attachment://")
    assert "content_type=image%2Fjpeg" in inline_data["data"]
    if extra:
        assert inline_data["display_name"] is None
    assert span.inputs["contents"][1] == "Caption this image"
    assert span.outputs == _DUMMY_GENERATE_CONTENT_RESPONSE.model_dump()
    assert span.model_name == "gemini-1.5-flash"

    span1 = traces[0].data.spans[1]
    assert span1.name == f"{cls}._generate_content"
    assert span1.span_type == SpanType.LLM
    assert span1.parent_id == span.span_id
    assert span1.inputs["model"] == "gemini-1.5-flash"
    inline_data1 = span1.inputs["contents"][0]["inline_data"]
    assert inline_data1["mime_type"] == "image/jpeg"
    assert inline_data1["data"].startswith("mlflow-attachment://")
    assert "content_type=image%2Fjpeg" in inline_data1["data"]
    if extra:
        assert inline_data1["display_name"] is None
    assert span1.inputs["contents"][1] == "Caption this image"
    assert span1.outputs == _DUMMY_GENERATE_CONTENT_RESPONSE.model_dump()

    # Token usage mirrors _USER_METADATA (6 in / 6 out / 12 total, no cache).
    assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
        TokenUsageKey.INPUT_TOKENS: 6,
        TokenUsageKey.OUTPUT_TOKENS: 6,
        TokenUsageKey.TOTAL_TOKENS: 12,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 0,
    }
    if not IS_TRACING_SDK_ONLY:
        assert span.llm_cost == {
            "input_cost": 6.0,
            "output_cost": 12.0,
            "total_cost": 18.0,
        }

    assert traces[0].info.token_usage == {
        "input_tokens": 6,
        "output_tokens": 6,
        "total_tokens": 12,
        "cache_read_input_tokens": 0,
    }
316  
317  
def test_generate_content_tool_calling_autolog(is_async, mock_litellm_cost):
    """A tool-call response is traced with CHAT_TOOLS and MESSAGE_FORMAT attributes."""
    # Model reply containing a function_call part instead of text.
    tool_call_content = {
        "parts": [
            {
                "function_call": {
                    "name": "multiply",
                    "args": {
                        "a": 57.0,
                        "b": 44.0,
                    },
                }
            }
        ],
        "role": "model",
    }

    response = _generate_content_response(tool_call_content)
    if is_async:

        async def _generate_content(self, model, contents, config):
            return response

    else:

        def _generate_content(self, model, contents, config):
            return response

    cls = "AsyncModels" if is_async else "Models"

    with patch(f"google.genai.models.{cls}._generate_content", new=_generate_content):
        mlflow.gemini.autolog()
        # Automatic function calling is disabled so the raw tool call is returned.
        _call_generate_content(
            is_async,
            model="gemini-1.5-flash",
            contents="I have 57 cats, each owns 44 mittens, how many mittens is that in total?",
            config=genai.types.GenerateContentConfig(
                tools=[multiply],
                automatic_function_calling=genai.types.AutomaticFunctionCallingConfig(disable=True),
            ),
        )

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert len(traces[0].data.spans) == 2

    span = traces[0].data.spans[0]
    assert span.name == f"{cls}.generate_content"
    assert span.span_type == SpanType.LLM
    assert (
        span.inputs["contents"]
        == "I have 57 cats, each owns 44 mittens, how many mittens is that in total?"
    )
    # The registered tool schema is recorded on both spans.
    assert span.get_attribute(SpanAttributeKey.CHAT_TOOLS) == TOOL_ATTRIBUTE
    assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "gemini"
    assert span.model_name == "gemini-1.5-flash"

    span1 = traces[0].data.spans[1]
    assert span1.name == f"{cls}._generate_content"
    assert span1.span_type == SpanType.LLM
    assert span1.parent_id == span.span_id
    assert (
        span1.inputs["contents"]
        == "I have 57 cats, each owns 44 mittens, how many mittens is that in total?"
    )
    assert span1.get_attribute(SpanAttributeKey.CHAT_TOOLS) == TOOL_ATTRIBUTE
    assert span1.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "gemini"

    # Token usage mirrors _USER_METADATA (6 in / 6 out / 12 total, no cache).
    assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
        TokenUsageKey.INPUT_TOKENS: 6,
        TokenUsageKey.OUTPUT_TOKENS: 6,
        TokenUsageKey.TOTAL_TOKENS: 12,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 0,
    }
    if not IS_TRACING_SDK_ONLY:
        assert span.llm_cost == {
            "input_cost": 6.0,
            "output_cost": 12.0,
            "total_cost": 18.0,
        }

    assert traces[0].info.token_usage == {
        "input_tokens": 6,
        "output_tokens": 6,
        "total_tokens": 12,
        "cache_read_input_tokens": 0,
    }
405  
406  
def test_generate_content_tool_calling_chat_history_autolog(is_async, mock_litellm_cost):
    """A multi-turn history (question, tool call, tool response) is traced intact."""
    # Turn 1: user question.
    question_content = genai.types.Content(**{
        "parts": [
            {
                "text": "I have 57 cats, each owns 44 mittens, how many mittens in total?",
            }
        ],
        "role": "user",
    })

    # Turn 2: model's tool call.
    tool_call_content = genai.types.Content(**{
        "parts": [
            {
                "function_call": {
                    "name": "multiply",
                    "args": {
                        "a": 57.0,
                        "b": 44.0,
                    },
                }
            }
        ],
        "role": "model",
    })

    # Turn 3: tool result fed back to the model.
    tool_response_content = genai.types.Content(**{
        "parts": [{"function_response": {"name": "multiply", "response": {"result": 2508.0}}}],
        "role": "user",
    })

    # Final stubbed model answer.
    response = _generate_content_response(
        genai.types.Content(**{
            "parts": [
                {
                    "text": "57 cats * 44 mittens/cat = 2508 mittens in total.",
                }
            ],
            "role": "model",
        })
    )

    cls = "AsyncModels" if is_async else "Models"

    if is_async:

        async def _generate_content(self, model, contents, config):
            return response

    else:

        def _generate_content(self, model, contents, config):
            return response

    with patch(f"google.genai.models.{cls}._generate_content", new=_generate_content):
        mlflow.gemini.autolog()
        # Automatic function calling is disabled; history is passed explicitly.
        _call_generate_content(
            is_async,
            model="gemini-1.5-flash",
            contents=[question_content, tool_call_content, tool_response_content],
            config=genai.types.GenerateContentConfig(
                tools=[multiply],
                automatic_function_calling=genai.types.AutomaticFunctionCallingConfig(disable=True),
            ),
        )

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert len(traces[0].data.spans) == 2

    span = traces[0].data.spans[0]
    assert span.name == f"{cls}.generate_content"
    assert span.span_type == SpanType.LLM
    # The full three-turn history is captured as span inputs.
    assert span.inputs["contents"] == [
        question_content.model_dump(),
        tool_call_content.model_dump(),
        tool_response_content.model_dump(),
    ]
    assert span.inputs["model"] == "gemini-1.5-flash"
    assert span.get_attribute("mlflow.chat.tools") == TOOL_ATTRIBUTE
    assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "gemini"
    assert span.model_name == "gemini-1.5-flash"

    span1 = traces[0].data.spans[1]
    assert span1.name == f"{cls}._generate_content"
    assert span1.span_type == SpanType.LLM
    assert span1.parent_id == span.span_id
    assert span1.inputs["contents"] == [
        question_content.model_dump(),
        tool_call_content.model_dump(),
        tool_response_content.model_dump(),
    ]
    assert span1.inputs["model"] == "gemini-1.5-flash"
    assert span1.get_attribute("mlflow.chat.tools") == TOOL_ATTRIBUTE
    assert span1.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "gemini"

    # Token usage mirrors _USER_METADATA (6 in / 6 out / 12 total, no cache).
    assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
        TokenUsageKey.INPUT_TOKENS: 6,
        TokenUsageKey.OUTPUT_TOKENS: 6,
        TokenUsageKey.TOTAL_TOKENS: 12,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 0,
    }
    if not IS_TRACING_SDK_ONLY:
        assert span.llm_cost == {
            "input_cost": 6.0,
            "output_cost": 12.0,
            "total_cost": 18.0,
        }

    assert traces[0].info.token_usage == {
        "input_tokens": 6,
        "output_tokens": 6,
        "total_tokens": 12,
        "cache_read_input_tokens": 0,
    }
522  
523  
def test_chat_session_autolog(is_async):
    """Chat.send_message is traced; disabling autolog stops new traces.

    Fix: the span-name assertion is now parenthesized. The original
    ``assert span.name == "AsyncChat.send_message" if is_async else "Chat.send_message"``
    parsed as ``assert ((span.name == ...) if is_async else "Chat.send_message")``,
    so the sync branch asserted a non-empty string and could never fail.
    """
    cls = "AsyncModels" if is_async else "Models"
    with patch(
        f"google.genai.models.{cls}._generate_content", new=_dummy_generate_content(is_async)
    ):
        mlflow.gemini.autolog()
        _create_chat_and_send_message(is_async, "test content")

        traces = get_traces()
        assert len(traces) == 1
        assert traces[0].info.status == "OK"
        # send_message span plus the two generate_content spans beneath it.
        assert len(traces[0].data.spans) == 3
        span = traces[0].data.spans[0]
        assert span.name == ("AsyncChat.send_message" if is_async else "Chat.send_message")
        assert span.span_type == SpanType.CHAT_MODEL
        assert span.inputs == {"message": "test content"}
        assert span.outputs == _DUMMY_GENERATE_CONTENT_RESPONSE.model_dump()
        assert span.model_name == "gemini-1.5-flash"

        mlflow.gemini.autolog(disable=True)
        _create_chat_and_send_message(is_async, "test content")

        # No new trace should be created
        traces = get_traces()
        assert len(traces) == 1
549  
550  
def test_count_tokens_autolog():
    """count_tokens calls are traced as LLM spans; disabling stops new traces."""
    with patch("google.genai.models.Models.count_tokens", new=count_tokens):
        mlflow.gemini.autolog()
        client = genai.Client(api_key="dummy")
        client.models.count_tokens(model="gemini-1.5-flash", contents="test content")

        traces = get_traces()
        assert len(traces) == 1
        assert traces[0].info.status == "OK"
        assert len(traces[0].data.spans) == 1
        span = traces[0].data.spans[0]
        assert span.name == "Models.count_tokens"
        assert span.span_type == SpanType.LLM
        assert span.inputs == {"contents": "test content", "model": "gemini-1.5-flash"}
        assert span.outputs == _DUMMY_COUNT_TOKENS_RESPONSE
        assert span.model_name == "gemini-1.5-flash"

        mlflow.gemini.autolog(disable=True)
        client = genai.Client(api_key="dummy")
        client.models.count_tokens(model="gemini-1.5-flash", contents="test content")

        # No new trace should be created
        traces = get_traces()
        assert len(traces) == 1
575  
576  
def test_embed_content_autolog():
    """embed_content calls are traced as EMBEDDING spans; disabling stops tracing."""
    with patch("google.genai.models.Models.embed_content", new=embed_content):
        mlflow.gemini.autolog()
        client = genai.Client(api_key="dummy")
        client.models.embed_content(model="text-embedding-004", content="Hello World")

        traces = get_traces()
        assert len(traces) == 1
        assert traces[0].info.status == "OK"
        assert len(traces[0].data.spans) == 1
        span = traces[0].data.spans[0]
        assert span.name == "Models.embed_content"
        assert span.span_type == SpanType.EMBEDDING
        assert span.inputs == {"content": "Hello World", "model": "text-embedding-004"}
        assert span.outputs == _DUMMY_EMBEDDING_RESPONSE
        assert span.model_name == "text-embedding-004"

        mlflow.gemini.autolog(disable=True)
        client.models.embed_content(model="text-embedding-004", content="Hello World")

        # No new trace should be created
        traces = get_traces()
        assert len(traces) == 1
600  
601  
def test_generate_content_cached_tokens(is_async, mock_litellm_cost):
    """cached_content_token_count is surfaced as CACHE_READ_INPUT_TOKENS."""
    cached_response = _generate_content_response(_CONTENT, _USER_METADATA_WITH_CACHE)

    if is_async:

        async def _generate_content(self, model, contents, config):
            return cached_response

    else:

        def _generate_content(self, model, contents, config):
            return cached_response

    cls = "AsyncModels" if is_async else "Models"
    with patch(f"google.genai.models.{cls}._generate_content", new=_generate_content):
        mlflow.gemini.autolog()
        _call_generate_content(is_async, "test content")

    traces = get_traces()
    assert len(traces) == 1
    span = traces[0].data.spans[0]
    # Values mirror _USER_METADATA_WITH_CACHE (50/20/70 with 30 cached).
    assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
        TokenUsageKey.INPUT_TOKENS: 50,
        TokenUsageKey.OUTPUT_TOKENS: 20,
        TokenUsageKey.TOTAL_TOKENS: 70,
        TokenUsageKey.CACHE_READ_INPUT_TOKENS: 30,
    }
629  
630  
def test_tracing_headers_injected_in_config(is_async):
    """Autologging injects a W3C traceparent header into the request config."""
    # Captures the config object actually passed to _generate_content.
    captured_config = {}

    if is_async:

        async def _generate_content(self, model, contents, config):
            captured_config["config"] = config
            return _DUMMY_GENERATE_CONTENT_RESPONSE

    else:

        def _generate_content(self, model, contents, config):
            captured_config["config"] = config
            return _DUMMY_GENERATE_CONTENT_RESPONSE

    cls = "AsyncModels" if is_async else "Models"
    with patch(f"google.genai.models.{cls}._generate_content", new=_generate_content):
        mlflow.gemini.autolog()
        _call_generate_content(is_async, "test content", config={"temperature": 0.5})

    traces = get_traces()
    assert len(traces) == 1

    # Verify traceparent was injected into config.http_options.headers
    config = captured_config["config"]
    # config passed to _generate_content may be a dict or object
    if isinstance(config, dict):
        headers = config.get("http_options", {}).get("headers", {})
    else:
        headers = getattr(getattr(config, "http_options", None), "headers", {}) or {}
    assert "traceparent" in headers
    # W3C traceparent: version-traceid-spanid-flags.
    assert re.fullmatch(r"00-[0-9a-f]{32}-[0-9a-f]{16}-[0-9a-f]{2}", headers["traceparent"])
663  
664  
def test_tracing_headers_preserve_existing_config_headers(is_async):
    """Injecting traceparent does not drop user-supplied http_options headers."""
    # Captures the config object actually passed to _generate_content.
    captured_config = {}

    if is_async:

        async def _generate_content(self, model, contents, config):
            captured_config["config"] = config
            return _DUMMY_GENERATE_CONTENT_RESPONSE

    else:

        def _generate_content(self, model, contents, config):
            captured_config["config"] = config
            return _DUMMY_GENERATE_CONTENT_RESPONSE

    cls = "AsyncModels" if is_async else "Models"
    with patch(f"google.genai.models.{cls}._generate_content", new=_generate_content):
        mlflow.gemini.autolog()
        _call_generate_content(
            is_async,
            "test content",
            config={
                "temperature": 0.5,
                "http_options": {"headers": {"X-Custom": "my-value"}},
            },
        )

    config = captured_config["config"]
    # config may arrive as a dict or as an object with http_options.
    if isinstance(config, dict):
        headers = config.get("http_options", {}).get("headers", {})
    else:
        headers = getattr(getattr(config, "http_options", None), "headers", {}) or {}

    # Both traceparent and user headers should be present
    assert "traceparent" in headers
    # User-provided headers take precedence
    assert headers["X-Custom"] == "my-value"
702  
703  
def test_tracing_headers_injected_when_config_is_none(is_async):
    """traceparent is injected even with no user config, and stripped from span inputs."""
    # Captures the config object actually passed to _generate_content.
    captured_config = {}

    if is_async:

        async def _generate_content(self, model, contents, config):
            captured_config["config"] = config
            return _DUMMY_GENERATE_CONTENT_RESPONSE

    else:

        def _generate_content(self, model, contents, config):
            captured_config["config"] = config
            return _DUMMY_GENERATE_CONTENT_RESPONSE

    cls = "AsyncModels" if is_async else "Models"
    with patch(f"google.genai.models.{cls}._generate_content", new=_generate_content):
        mlflow.gemini.autolog()
        # Call without config — headers should still be injected
        _call_generate_content(is_async, "test content")

    traces = get_traces()
    assert len(traces) == 1

    # Verify traceparent was injected via config even though original config was None
    config = captured_config["config"]
    if isinstance(config, dict):
        headers = config.get("http_options", {}).get("headers", {})
    else:
        headers = getattr(getattr(config, "http_options", None), "headers", {}) or {}
    assert "traceparent" in headers
    # W3C traceparent: version-traceid-spanid-flags.
    assert re.fullmatch(r"00-[0-9a-f]{32}-[0-9a-f]{16}-[0-9a-f]{2}", headers["traceparent"])

    # Verify the traceparent is stripped from span inputs
    for span in traces[0].data.spans:
        config_input = span.inputs.get("config")
        if config_input is None:
            continue
        if isinstance(config_input, dict):
            http_headers = config_input.get("http_options", {}).get("headers", {})
        else:
            http_headers = getattr(getattr(config_input, "http_options", None), "headers", {}) or {}
        assert "traceparent" not in http_headers