test_legacy_gemini_autolog.py
"""
This file contains unit tests for the legacy Gemini Python SDK
https://github.com/google-gemini/generative-ai-python
"""

import base64
from unittest.mock import patch

import google.generativeai as genai
import pytest
from packaging.version import Version

import mlflow
from mlflow.entities.span import SpanType

from tests.tracing.helper import get_traces

# Content payload returned by the default fake ``generate_content`` / ``send_message``.
_CONTENT = {"parts": [{"text": "test answer"}], "role": "model"}

# Sent as the "usage_metadata" field of every fake response built below.
_USER_METADATA = {
    "prompt_token_count": 6,
    "candidates_token_count": 6,
    "total_token_count": 6,
    "cached_content_token_count": 0,
}


def _get_candidate(content):
    """Build the dict for a single response candidate wrapping ``content``."""
    candidate = {
        "content": content,
        "avg_logprobs": 0.0,
        "finish_reason": 0,
        "grounding_attributions": [],
        "safety_ratings": [],
        "token_count": 0,
    }

    # "avg_logprobs" is dropped for SDK versions older than 0.8.3
    # (presumably the field does not exist there -- confirm against the SDK changelog).
    if Version(genai.__version__) < Version("0.8.3"):
        candidate.pop("avg_logprobs")

    return candidate


def _generate_content_response(content):
    """Build the raw (dict) form of a GenerateContentResponse with one candidate."""
    res = {
        "candidates": [_get_candidate(content)],
        "usage_metadata": _USER_METADATA,
    }

    # Only set "model_version" when the installed SDK's response type exposes it.
    if hasattr(genai.types.GenerateContentResponse, "model_version"):
        res["model_version"] = "gemini-1.5-flash-002"

    return res


_GENERATE_CONTENT_RESPONSE = _generate_content_response(_CONTENT)

_DUMMY_GENERATE_CONTENT_RESPONSE = genai.types.GenerateContentResponse.from_response(
    genai.protos.GenerateContentResponse(_GENERATE_CONTENT_RESPONSE)
)

_DUMMY_COUNT_TOKENS_RESPONSE = {"total_count": 10}

_DUMMY_EMBEDDING_RESPONSE = {"embedding": [1, 2, 3]}


# The fakes below replace SDK entry points via ``patch(..., new=...)``. The tests assert
# on the recorded span names and input keys, which match these functions' names and
# parameter names, so keep both in sync with the assertions when editing.


def generate_content(self, contents):
    """Fake ``GenerativeModel.generate_content`` returning the canned response."""
    return _DUMMY_GENERATE_CONTENT_RESPONSE


def send_message(self, content):
    """Fake ``ChatSession.send_message`` returning the canned response."""
    return _DUMMY_GENERATE_CONTENT_RESPONSE


def count_tokens(self, contents):
    """Fake ``GenerativeModel.count_tokens`` returning the canned token count."""
    return _DUMMY_COUNT_TOKENS_RESPONSE


def embed_content(model, content):
    """Fake ``genai.embed_content`` returning the canned embedding."""
    return _DUMMY_EMBEDDING_RESPONSE


# NOTE: the signature and docstring of ``multiply`` are test data. ``TOOL_ATTRIBUTE``
# below expects the description "returns a * b." and two "number" parameters, so do
# not reword the docstring or change the annotations without updating it.
def multiply(a: float, b: float):
    """returns a * b."""
    return a * b


# Expected value of the "mlflow.chat.tools" span attribute when ``multiply`` is the
# only tool registered on the model.
TOOL_ATTRIBUTE = [
    {
        "type": "function",
        "function": {
            "name": "multiply",
            "description": "returns a * b.",
            "parameters": {
                "properties": {
                    "a": {"type": "number", "description": "", "enum": []},
                    "b": {"type": "number", "description": "", "enum": []},
                },
                "required": ["a", "b"],
            },
        },
    },
]


@pytest.fixture(autouse=True)
def cleanup():
    """Disable Gemini autologging after every test so tests do not leak state."""
    yield
    mlflow.gemini.autolog(disable=True)


def _get_single_span(name, span_type):
    """Assert that exactly one successful, single-span trace exists and return its span.

    Args:
        name: Expected span name.
        span_type: Expected ``SpanType`` value of the span.

    Returns:
        The only span of the only recorded trace.
    """
    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert len(traces[0].data.spans) == 1
    span = traces[0].data.spans[0]
    assert span.name == name
    assert span.span_type == span_type
    return span


def _assert_no_new_trace():
    """Assert that the trace count is still one, i.e. no trace was added after disabling."""
    traces = get_traces()
    assert len(traces) == 1


def test_generate_content_enable_disable_autolog():
    with patch("google.generativeai.GenerativeModel.generate_content", new=generate_content):
        mlflow.gemini.autolog()
        model = genai.GenerativeModel("gemini-1.5-flash")
        model.generate_content("test content")

        span = _get_single_span("GenerativeModel.generate_content", SpanType.LLM)
        assert span.inputs == {"contents": "test content"}
        assert span.outputs == _GENERATE_CONTENT_RESPONSE

        mlflow.gemini.autolog(disable=True)
        model = genai.GenerativeModel("gemini-1.5-flash")
        model.generate_content("test content")

        # No new trace should be created
        _assert_no_new_trace()


def test_generate_content_tracing_with_error():
    with patch(
        "google.generativeai.GenerativeModel.generate_content", side_effect=Exception("dummy error")
    ):
        mlflow.gemini.autolog()
        model = genai.GenerativeModel("gemini-1.5-flash")

        with pytest.raises(Exception, match="dummy error"):
            model.generate_content("test content")

        traces = get_traces()
        assert len(traces) == 1
        assert traces[0].info.status == "ERROR"
        assert traces[0].data.spans[0].status.status_code == "ERROR"
        assert traces[0].data.spans[0].status.description == "Exception: dummy error"


def test_generate_content_image_autolog():
    image = base64.b64encode(b"image").decode("utf-8")
    request = [{"mime_type": "image/jpeg", "data": image}, "Caption this image"]
    with patch("google.generativeai.GenerativeModel.generate_content", new=generate_content):
        mlflow.gemini.autolog()
        model = genai.GenerativeModel("gemini-1.5-flash")
        model.generate_content(request)

        span = _get_single_span("GenerativeModel.generate_content", SpanType.LLM)
        assert span.inputs == {"contents": request}
        assert span.outputs == _GENERATE_CONTENT_RESPONSE


def test_generate_content_tool_calling_autolog():
    tool_call_content = {
        "parts": [
            {
                "function_call": {
                    "name": "multiply",
                    "args": {
                        "a": 57.0,
                        "b": 44.0,
                    },
                }
            }
        ],
        "role": "model",
    }

    raw_response = _generate_content_response(tool_call_content)
    response = genai.types.GenerateContentResponse.from_response(
        genai.protos.GenerateContentResponse(raw_response)
    )

    # Must be named ``generate_content`` (the span-name assertion below depends on it) and
    # take ``contents`` like the real SDK method, since inputs are keyed by parameter name.
    def generate_content(self, contents):
        return response

    with patch("google.generativeai.GenerativeModel.generate_content", new=generate_content):
        mlflow.gemini.autolog()
        model = genai.GenerativeModel("gemini-1.5-flash", tools=[multiply])
        model.generate_content(
            "I have 57 cats, each owns 44 mittens, how many mittens is that in total?"
        )

        span = _get_single_span("GenerativeModel.generate_content", SpanType.LLM)
        assert span.inputs == {
            "contents": "I have 57 cats, each owns 44 mittens, how many mittens is that in total?"
        }
        assert span.get_attribute("mlflow.chat.tools") == TOOL_ATTRIBUTE


def test_generate_content_tool_calling_chat_history_autolog():
    question_content = genai.protos.Content({
        "parts": [
            {
                "text": "I have 57 cats, each owns 44 mittens, how many mittens in total?",
            }
        ],
        "role": "user",
    })

    tool_call_content = genai.protos.Content({
        "parts": [
            {
                "function_call": {
                    "name": "multiply",
                    "args": {
                        "a": 57.0,
                        "b": 44.0,
                    },
                }
            }
        ],
        "role": "model",
    })

    tool_response_content = genai.protos.Content({
        "parts": [{"function_response": {"name": "multiply", "response": {"result": 2508.0}}}],
        "role": "user",
    })

    raw_response = _generate_content_response(
        genai.protos.Content({
            "parts": [
                {
                    "text": "57 cats * 44 mittens/cat = 2508 mittens in total.",
                }
            ],
            "role": "model",
        })
    )

    response = genai.types.GenerateContentResponse.from_response(
        genai.protos.GenerateContentResponse(raw_response)
    )

    # Must be named ``generate_content`` (the span-name assertion below depends on it) and
    # take ``contents`` like the real SDK method, since inputs are keyed by parameter name.
    def generate_content(self, contents):
        return response

    with patch("google.generativeai.GenerativeModel.generate_content", new=generate_content):
        mlflow.gemini.autolog()
        model = genai.GenerativeModel("gemini-1.5-flash", tools=[multiply])
        model.generate_content([question_content, tool_call_content, tool_response_content])

        span = _get_single_span("GenerativeModel.generate_content", SpanType.LLM)
        assert span.inputs == {
            "contents": [
                str(question_content),
                str(tool_call_content),
                str(tool_response_content),
            ]
        }
        assert span.get_attribute("mlflow.chat.tools") == TOOL_ATTRIBUTE


def test_chat_session_autolog():
    with patch("google.generativeai.ChatSession.send_message", new=send_message):
        mlflow.gemini.autolog()
        model = genai.GenerativeModel("gemini-1.5-flash")
        chat = model.start_chat(history=[])
        chat.send_message("test content")

        span = _get_single_span("ChatSession.send_message", SpanType.CHAT_MODEL)
        assert span.inputs == {"content": "test content"}
        assert span.outputs == _GENERATE_CONTENT_RESPONSE

        mlflow.gemini.autolog(disable=True)
        model = genai.GenerativeModel("gemini-1.5-flash")
        chat = model.start_chat(history=[])
        chat.send_message("test content")

        # No new trace should be created
        _assert_no_new_trace()


def test_count_tokens_autolog():
    with patch("google.generativeai.GenerativeModel.count_tokens", new=count_tokens):
        mlflow.gemini.autolog()
        model = genai.GenerativeModel("gemini-1.5-flash")
        model.count_tokens("test content")

        span = _get_single_span("GenerativeModel.count_tokens", SpanType.LLM)
        assert span.inputs == {"contents": "test content"}
        assert span.outputs == _DUMMY_COUNT_TOKENS_RESPONSE

        mlflow.gemini.autolog(disable=True)
        model = genai.GenerativeModel("gemini-1.5-flash")
        model.count_tokens("test content")

        # No new trace should be created
        _assert_no_new_trace()


def test_embed_content_autolog():
    with patch("google.generativeai.embed_content", new=embed_content):
        mlflow.gemini.autolog()
        genai.embed_content(model="models/text-embedding-004", content="Hello World")

        span = _get_single_span("embed_content", SpanType.EMBEDDING)
        assert span.inputs == {"content": "Hello World", "model": "models/text-embedding-004"}
        assert span.outputs == _DUMMY_EMBEDDING_RESPONSE

        mlflow.gemini.autolog(disable=True)
        genai.embed_content(model="models/text-embedding-004", content="Hello World")

        # No new trace should be created
        _assert_no_new_trace()