"""Tests for Gemini autolog spans exported with OpenTelemetry GenAI semantic-convention
attributes (test_gemini_genai_semconv_converter.py).

Every test runs with ``MLFLOW_ENABLE_OTEL_GENAI_SEMCONV=true`` (see the autouse fixture below).
"""

import json
from unittest.mock import patch

import pytest
from google import genai

import mlflow
from mlflow.gemini.genai_semconv_converter import _convert_part
from mlflow.tracing.constant import GenAiSemconvKey

from tests.gemini.test_gemini_autolog import (
    _dummy_generate_content,
    _generate_content_response,
    multiply,
)
from tests.tracing.helper import capture_otel_export, reset_autolog_state  # noqa: F401

MODEL = "gemini-1.5-flash"


@pytest.fixture(autouse=True)
def enable_genai_semconv(monkeypatch):
    """Turn on GenAI semantic-convention export for every test in this module."""
    monkeypatch.setenv("MLFLOW_ENABLE_OTEL_GENAI_SEMCONV", "true")


def _get_llm_span(exporter, processor):
    """Flush pending spans and return the first exported ``generate_content`` span.

    Args:
        exporter: Span exporter from the ``capture_otel_export`` fixture; queried via
            ``get_finished_spans()``.
        processor: Span processor from the same fixture; flushed so buffered spans reach
            the exporter before they are inspected.

    Returns:
        The first finished span whose GenAI operation-name attribute is ``"generate_content"``.

    Raises:
        AssertionError: If no such span was exported. The message lists the names of the
            spans that were exported, instead of letting a bare ``StopIteration`` escape
            from ``next()``.
    """
    processor.force_flush(timeout_millis=5000)
    # Materialize once: the spans are scanned for a match and, on failure, again for the
    # diagnostic message.
    spans = list(exporter.get_finished_spans())
    llm_span = next(
        (
            span
            for span in spans
            if span.attributes.get(GenAiSemconvKey.OPERATION_NAME) == "generate_content"
        ),
        None,
    )
    assert llm_span is not None, (
        f"No 'generate_content' span was exported; got spans: {[span.name for span in spans]}"
    )
    return llm_span


@pytest.mark.usefixtures("reset_autolog_state")
def test_autolog_basic(capture_otel_export):
    """A plain text request/response is exported as semconv input/output messages."""
    exporter, processor = capture_otel_export

    mlflow.gemini.autolog()
    with patch(
        "google.genai.models.Models._generate_content",
        new=_dummy_generate_content(is_async=False),
    ):
        client = genai.Client(api_key="dummy")
        client.models.generate_content(model=MODEL, contents="test content")

    llm_span = _get_llm_span(exporter, processor)
    assert llm_span.attributes[GenAiSemconvKey.OPERATION_NAME] == "generate_content"
    assert llm_span.attributes[GenAiSemconvKey.REQUEST_MODEL] == MODEL

    input_msgs = json.loads(llm_span.attributes[GenAiSemconvKey.INPUT_MESSAGES])
    assert input_msgs[0]["role"] == "user"
    assert input_msgs[0]["parts"][0]["type"] == "text"
    assert input_msgs[0]["parts"][0]["content"] == "test content"

    output_msgs = json.loads(llm_span.attributes[GenAiSemconvKey.OUTPUT_MESSAGES])
    assert len(output_msgs) == 1
    assert output_msgs[0]["role"] == "assistant"
    assert output_msgs[0]["parts"][0]["content"] == "test answer"
    # In semconv mode the exported span must carry no ``mlflow.``-prefixed attribute keys.
    assert not any(k.startswith("mlflow.") for k in llm_span.attributes)


@pytest.mark.usefixtures("reset_autolog_state")
def test_autolog_with_tool_calls(capture_otel_export):
    """A model reply containing a function call is exported as a ``tool_call`` part."""
    exporter, processor = capture_otel_export

    tool_call_content = {
        "parts": [
            {
                "function_call": {
                    "name": "multiply",
                    "args": {"a": 57.0, "b": 44.0},
                }
            }
        ],
        "role": "model",
    }
    response = _generate_content_response(tool_call_content)

    # Stand-in for ``Models._generate_content``: ignores its inputs and always returns the
    # canned tool-call response built above.
    def _generate_content(self, model, contents, config):
        return response

    mlflow.gemini.autolog()
    with patch("google.genai.models.Models._generate_content", new=_generate_content):
        client = genai.Client(api_key="dummy")
        client.models.generate_content(
            model=MODEL,
            contents="How much is 57 * 44?",
            config=genai.types.GenerateContentConfig(
                tools=[multiply],
                # Keep the SDK from executing ``multiply`` itself, so the call stays in the
                # model output.
                automatic_function_calling=genai.types.AutomaticFunctionCallingConfig(disable=True),
            ),
        )

    llm_span = _get_llm_span(exporter, processor)
    assert llm_span.attributes[GenAiSemconvKey.OPERATION_NAME] == "generate_content"
    assert llm_span.attributes[GenAiSemconvKey.REQUEST_MODEL] == MODEL

    input_msgs = json.loads(llm_span.attributes[GenAiSemconvKey.INPUT_MESSAGES])
    assert input_msgs[0]["role"] == "user"
    assert input_msgs[0]["parts"][0]["content"] == "How much is 57 * 44?"

    output_msgs = json.loads(llm_span.attributes[GenAiSemconvKey.OUTPUT_MESSAGES])
    assert len(output_msgs) == 1
    assert output_msgs[0]["role"] == "assistant"
    tool_part = output_msgs[0]["parts"][0]
    assert tool_part["type"] == "tool_call"
    assert tool_part["name"] == "multiply"
    assert tool_part["arguments"] == {"a": 57.0, "b": 44.0}

    assert GenAiSemconvKey.TOOL_DEFINITIONS in llm_span.attributes
    assert "multiply" in llm_span.attributes[GenAiSemconvKey.TOOL_DEFINITIONS]
    assert not any(k.startswith("mlflow.") for k in llm_span.attributes)


@pytest.mark.parametrize(
    ("part", "expected"),
    [
        # inline_data (image)
        (
            {"inline_data": {"data": "iVBOR...", "mime_type": "image/png"}},
            {
                "type": "blob",
                "modality": "image",
                "mime_type": "image/png",
                "content": "iVBOR...",
            },
        ),
        # file_data (image)
        (
            {"file_data": {"file_uri": "gs://bucket/img.jpg", "mime_type": "image/jpeg"}},
            {
                "type": "uri",
                "modality": "image",
                "mime_type": "image/jpeg",
                "uri": "gs://bucket/img.jpg",
            },
        ),
        # inline_data (audio)
        (
            {"inline_data": {"data": "audiodata", "mime_type": "audio/mp3"}},
            {
                "type": "blob",
                "modality": "audio",
                "mime_type": "audio/mp3",
                "content": "audiodata",
            },
        ),
        # file_data (video)
        (
            {"file_data": {"file_uri": "gs://bucket/vid.mp4", "mime_type": "video/mp4"}},
            {
                "type": "uri",
                "modality": "video",
                "mime_type": "video/mp4",
                "uri": "gs://bucket/vid.mp4",
            },
        ),
    ],
)
def test_convert_part_multimodal(part, expected):
    """``_convert_part`` maps inline/file media parts to blob/uri parts with a modality."""
    assert _convert_part(part) == expected