tests/gemini/test_gemini_genai_semconv_converter.py
test_gemini_genai_semconv_converter.py
  1  import json
  2  from unittest.mock import patch
  3  
  4  import pytest
  5  from google import genai
  6  
  7  import mlflow
  8  from mlflow.gemini.genai_semconv_converter import _convert_part
  9  from mlflow.tracing.constant import GenAiSemconvKey
 10  
 11  from tests.gemini.test_gemini_autolog import (
 12      _dummy_generate_content,
 13      _generate_content_response,
 14      multiply,
 15  )
 16  from tests.tracing.helper import capture_otel_export, reset_autolog_state  # noqa: F401
 17  
 18  MODEL = "gemini-1.5-flash"
 19  
 20  
 21  @pytest.fixture(autouse=True)
 22  def enable_genai_semconv(monkeypatch):
 23      monkeypatch.setenv("MLFLOW_ENABLE_OTEL_GENAI_SEMCONV", "true")
 24      return
 25  
 26  
 27  def _get_llm_span(exporter, processor):
 28      processor.force_flush(timeout_millis=5000)
 29      spans = exporter.get_finished_spans()
 30      return next(
 31          s for s in spans if s.attributes.get(GenAiSemconvKey.OPERATION_NAME) == "generate_content"
 32      )
 33  
 34  
 35  @pytest.mark.usefixtures("reset_autolog_state")
 36  def test_autolog_basic(capture_otel_export):
 37      exporter, processor = capture_otel_export
 38  
 39      mlflow.gemini.autolog()
 40      with patch(
 41          "google.genai.models.Models._generate_content",
 42          new=_dummy_generate_content(is_async=False),
 43      ):
 44          client = genai.Client(api_key="dummy")
 45          client.models.generate_content(model=MODEL, contents="test content")
 46  
 47      llm_span = _get_llm_span(exporter, processor)
 48      assert llm_span.attributes[GenAiSemconvKey.OPERATION_NAME] == "generate_content"
 49      assert llm_span.attributes[GenAiSemconvKey.REQUEST_MODEL] == MODEL
 50  
 51      input_msgs = json.loads(llm_span.attributes[GenAiSemconvKey.INPUT_MESSAGES])
 52      assert input_msgs[0]["role"] == "user"
 53      assert input_msgs[0]["parts"][0]["type"] == "text"
 54      assert input_msgs[0]["parts"][0]["content"] == "test content"
 55  
 56      output_msgs = json.loads(llm_span.attributes[GenAiSemconvKey.OUTPUT_MESSAGES])
 57      assert len(output_msgs) == 1
 58      assert output_msgs[0]["role"] == "assistant"
 59      assert output_msgs[0]["parts"][0]["content"] == "test answer"
 60      assert not any(k.startswith("mlflow.") for k in llm_span.attributes)
 61  
 62  
 63  @pytest.mark.usefixtures("reset_autolog_state")
 64  def test_autolog_with_tool_calls(capture_otel_export):
 65      exporter, processor = capture_otel_export
 66  
 67      tool_call_content = {
 68          "parts": [
 69              {
 70                  "function_call": {
 71                      "name": "multiply",
 72                      "args": {"a": 57.0, "b": 44.0},
 73                  }
 74              }
 75          ],
 76          "role": "model",
 77      }
 78      response = _generate_content_response(tool_call_content)
 79  
 80      def _generate_content(self, model, contents, config):
 81          return response
 82  
 83      mlflow.gemini.autolog()
 84      with patch("google.genai.models.Models._generate_content", new=_generate_content):
 85          client = genai.Client(api_key="dummy")
 86          client.models.generate_content(
 87              model=MODEL,
 88              contents="How much is 57 * 44?",
 89              config=genai.types.GenerateContentConfig(
 90                  tools=[multiply],
 91                  automatic_function_calling=genai.types.AutomaticFunctionCallingConfig(disable=True),
 92              ),
 93          )
 94  
 95      llm_span = _get_llm_span(exporter, processor)
 96      assert llm_span.attributes[GenAiSemconvKey.OPERATION_NAME] == "generate_content"
 97      assert llm_span.attributes[GenAiSemconvKey.REQUEST_MODEL] == MODEL
 98  
 99      input_msgs = json.loads(llm_span.attributes[GenAiSemconvKey.INPUT_MESSAGES])
100      assert input_msgs[0]["role"] == "user"
101      assert input_msgs[0]["parts"][0]["content"] == "How much is 57 * 44?"
102  
103      output_msgs = json.loads(llm_span.attributes[GenAiSemconvKey.OUTPUT_MESSAGES])
104      assert len(output_msgs) == 1
105      assert output_msgs[0]["role"] == "assistant"
106      tool_part = output_msgs[0]["parts"][0]
107      assert tool_part["type"] == "tool_call"
108      assert tool_part["name"] == "multiply"
109      assert tool_part["arguments"] == {"a": 57.0, "b": 44.0}
110  
111      assert GenAiSemconvKey.TOOL_DEFINITIONS in llm_span.attributes
112      assert "multiply" in llm_span.attributes[GenAiSemconvKey.TOOL_DEFINITIONS]
113      assert not any(k.startswith("mlflow.") for k in llm_span.attributes)
114  
115  
@pytest.mark.parametrize(
    ("part", "expected"),
    [
        pytest.param(
            {"inline_data": {"data": "iVBOR...", "mime_type": "image/png"}},
            {
                "type": "blob",
                "modality": "image",
                "mime_type": "image/png",
                "content": "iVBOR...",
            },
            id="inline_data-image",
        ),
        pytest.param(
            {"file_data": {"file_uri": "gs://bucket/img.jpg", "mime_type": "image/jpeg"}},
            {
                "type": "uri",
                "modality": "image",
                "mime_type": "image/jpeg",
                "uri": "gs://bucket/img.jpg",
            },
            id="file_data-image",
        ),
        pytest.param(
            {"inline_data": {"data": "audiodata", "mime_type": "audio/mp3"}},
            {
                "type": "blob",
                "modality": "audio",
                "mime_type": "audio/mp3",
                "content": "audiodata",
            },
            id="inline_data-audio",
        ),
        pytest.param(
            {"file_data": {"file_uri": "gs://bucket/vid.mp4", "mime_type": "video/mp4"}},
            {
                "type": "uri",
                "modality": "video",
                "mime_type": "video/mp4",
                "uri": "gs://bucket/vid.mp4",
            },
            id="file_data-video",
        ),
    ],
)
def test_convert_part_multimodal(part, expected):
    """Check ``_convert_part`` on Gemini multimodal parts.

    ``inline_data`` parts are expected to become ``blob`` parts that carry the
    raw data. ``file_data`` parts are expected to become ``uri`` parts.
    ``modality`` is expected to match the MIME type's top-level type. Each case
    has an explicit ``id`` so a failing case is identifiable from its test name.
    """
    assert _convert_part(part) == expected