/ tests / gemini / test_legacy_gemini_autolog.py
test_legacy_gemini_autolog.py
  1  """
  2  This file contains unit tests for the legacy Gemini Python SDK
  3  https://github.com/google-gemini/generative-ai-python
  4  """
  5  
  6  import base64
  7  from unittest.mock import patch
  8  
  9  import google.generativeai as genai
 10  import pytest
 11  from packaging.version import Version
 12  
 13  import mlflow
 14  from mlflow.entities.span import SpanType
 15  
 16  from tests.tracing.helper import get_traces
 17  
 18  _CONTENT = {"parts": [{"text": "test answer"}], "role": "model"}
 19  
 20  _USER_METADATA = {
 21      "prompt_token_count": 6,
 22      "candidates_token_count": 6,
 23      "total_token_count": 6,
 24      "cached_content_token_count": 0,
 25  }
 26  
 27  
 28  def _get_candidate(content):
 29      candidate = {
 30          "content": content,
 31          "avg_logprobs": 0.0,
 32          "finish_reason": 0,
 33          "grounding_attributions": [],
 34          "safety_ratings": [],
 35          "token_count": 0,
 36      }
 37  
 38      if Version(genai.__version__) < Version("0.8.3"):
 39          candidate.pop("avg_logprobs")
 40  
 41      return candidate
 42  
 43  
 44  def _generate_content_response(content):
 45      res = {
 46          "candidates": [_get_candidate(content)],
 47          "usage_metadata": _USER_METADATA,
 48      }
 49  
 50      if hasattr(genai.types.GenerateContentResponse, "model_version"):
 51          res["model_version"] = "gemini-1.5-flash-002"
 52  
 53      return res
 54  
 55  
 56  _GENERATE_CONTENT_RESPONSE = _generate_content_response(_CONTENT)
 57  
 58  _DUMMY_GENERATE_CONTENT_RESPONSE = genai.types.GenerateContentResponse.from_response(
 59      genai.protos.GenerateContentResponse(_GENERATE_CONTENT_RESPONSE)
 60  )
 61  
 62  _DUMMY_COUNT_TOKENS_RESPONSE = {"total_count": 10}
 63  
 64  _DUMMY_EMBEDDING_RESPONSE = {"embedding": [1, 2, 3]}
 65  
 66  
 67  def generate_content(self, contents):
 68      return _DUMMY_GENERATE_CONTENT_RESPONSE
 69  
 70  
 71  def send_message(self, content):
 72      return _DUMMY_GENERATE_CONTENT_RESPONSE
 73  
 74  
 75  def count_tokens(self, contents):
 76      return _DUMMY_COUNT_TOKENS_RESPONSE
 77  
 78  
 79  def embed_content(model, content):
 80      return _DUMMY_EMBEDDING_RESPONSE
 81  
 82  
# Tool registered on ``genai.GenerativeModel(..., tools=[multiply])`` in the
# tool-calling tests. NOTE: the expected ``TOOL_ATTRIBUTE`` schema mirrors this
# signature ("a"/"b" of type "number", both required) and the exact docstring
# text ("returns a * b."), so the docstring and annotations are effectively
# test data; change them together with ``TOOL_ATTRIBUTE``.
def multiply(a: float, b: float):
    """returns a * b."""
    return a * b
 86  
 87  
 88  TOOL_ATTRIBUTE = [
 89      {
 90          "type": "function",
 91          "function": {
 92              "name": "multiply",
 93              "description": "returns a * b.",
 94              "parameters": {
 95                  "properties": {
 96                      "a": {"type": "number", "description": "", "enum": []},
 97                      "b": {"type": "number", "description": "", "enum": []},
 98                  },
 99                  "required": ["a", "b"],
100              },
101          },
102      },
103  ]
104  
105  
@pytest.fixture(autouse=True)
def cleanup():
    """Disable Gemini autologging after every test in this module.

    Tests call ``mlflow.gemini.autolog()`` themselves; turning it off on
    teardown keeps one test's patching from leaking into the next.
    """
    try:
        yield
    finally:
        mlflow.gemini.autolog(disable=True)
110  
111  
def test_generate_content_enable_disable_autolog():
    """``generate_content`` is traced only while autologging is enabled."""
    model_name = "gemini-1.5-flash"
    with patch("google.generativeai.GenerativeModel.generate_content", new=generate_content):
        # Enabled: one call produces exactly one successful single-span trace.
        mlflow.gemini.autolog()
        genai.GenerativeModel(model_name).generate_content("test content")

        traces = get_traces()
        assert len(traces) == 1
        trace = traces[0]
        assert trace.info.status == "OK"
        assert len(trace.data.spans) == 1
        span = trace.data.spans[0]
        assert span.name == "GenerativeModel.generate_content"
        assert span.span_type == SpanType.LLM
        assert span.inputs == {"contents": "test content"}
        assert span.outputs == _GENERATE_CONTENT_RESPONSE

        # Disabled: the same call must not add another trace.
        mlflow.gemini.autolog(disable=True)
        genai.GenerativeModel(model_name).generate_content("test content")
        assert len(get_traces()) == 1
134          assert len(traces) == 1
135  
136  
def test_generate_content_tracing_with_error():
    """A raising ``generate_content`` call is recorded as an ERROR trace."""
    failing_generate_content = patch(
        "google.generativeai.GenerativeModel.generate_content",
        side_effect=Exception("dummy error"),
    )
    with failing_generate_content:
        mlflow.gemini.autolog()
        model = genai.GenerativeModel("gemini-1.5-flash")
        with pytest.raises(Exception, match="dummy error"):
            model.generate_content("test content")

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "ERROR"
    root_span = trace.data.spans[0]
    assert root_span.status.status_code == "ERROR"
    assert root_span.status.description == "Exception: dummy error"
152  
153  
def test_generate_content_image_autolog():
    """A request mixing an inline image and text is recorded as given."""
    encoded_image = base64.b64encode(b"image").decode("utf-8")
    request = [
        {"mime_type": "image/jpeg", "data": encoded_image},
        "Caption this image",
    ]
    with patch("google.generativeai.GenerativeModel.generate_content", new=generate_content):
        mlflow.gemini.autolog()
        genai.GenerativeModel("gemini-1.5-flash").generate_content(request)

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.name == "GenerativeModel.generate_content"
    assert span.span_type == SpanType.LLM
    assert span.inputs == {"contents": request}
    assert span.outputs == _GENERATE_CONTENT_RESPONSE
171  
172  
def test_generate_content_tool_calling_autolog():
    """Tools registered on the model are recorded in ``mlflow.chat.tools``."""
    question = "I have 57 cats, each owns 44 mittens, how many mittens is that in total?"
    function_call = {"name": "multiply", "args": {"a": 57.0, "b": 44.0}}
    tool_call_content = {"parts": [{"function_call": function_call}], "role": "model"}

    response = genai.types.GenerateContentResponse.from_response(
        genai.protos.GenerateContentResponse(_generate_content_response(tool_call_content))
    )

    # This local stub names its parameter ``content`` (singular); the span
    # inputs asserted below are keyed by that name.
    def generate_content(self, content):
        return response

    with patch("google.generativeai.GenerativeModel.generate_content", new=generate_content):
        mlflow.gemini.autolog()
        model = genai.GenerativeModel("gemini-1.5-flash", tools=[multiply])
        model.generate_content(question)

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.name == "GenerativeModel.generate_content"
    assert span.span_type == SpanType.LLM
    assert span.inputs == {"content": question}
    assert span.get_attribute("mlflow.chat.tools") == TOOL_ATTRIBUTE
214      assert span.get_attribute("mlflow.chat.tools") == TOOL_ATTRIBUTE
215  
216  
def test_generate_content_tool_calling_chat_history_autolog():
    """A multi-turn tool-calling history is recorded with each turn stringified."""
    question_content = genai.protos.Content({
        "parts": [{"text": "I have 57 cats, each owns 44 mittens, how many mittens in total?"}],
        "role": "user",
    })
    tool_call_content = genai.protos.Content({
        "parts": [{"function_call": {"name": "multiply", "args": {"a": 57.0, "b": 44.0}}}],
        "role": "model",
    })
    tool_response_content = genai.protos.Content({
        "parts": [{"function_response": {"name": "multiply", "response": {"result": 2508.0}}}],
        "role": "user",
    })
    answer_content = genai.protos.Content({
        "parts": [{"text": "57 cats * 44 mittens/cat = 2508 mittens in total."}],
        "role": "model",
    })

    response = genai.types.GenerateContentResponse.from_response(
        genai.protos.GenerateContentResponse(_generate_content_response(answer_content))
    )

    def generate_content(self, content):
        return response

    history = [question_content, tool_call_content, tool_response_content]
    with patch("google.generativeai.GenerativeModel.generate_content", new=generate_content):
        mlflow.gemini.autolog()
        model = genai.GenerativeModel("gemini-1.5-flash", tools=[multiply])
        model.generate_content(history)

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.name == "GenerativeModel.generate_content"
    assert span.span_type == SpanType.LLM
    assert span.inputs == {
        "content": [str(question_content), str(tool_call_content), str(tool_response_content)]
    }
    assert span.get_attribute("mlflow.chat.tools") == TOOL_ATTRIBUTE
280      assert span.get_attribute("mlflow.chat.tools") == TOOL_ATTRIBUTE
281  
282  
def test_chat_session_autolog():
    """``ChatSession.send_message`` is traced only while autologging is enabled."""
    model_name = "gemini-1.5-flash"
    with patch("google.generativeai.ChatSession.send_message", new=send_message):
        # Enabled: one message produces exactly one CHAT_MODEL span.
        mlflow.gemini.autolog()
        genai.GenerativeModel(model_name).start_chat(history=[]).send_message("test content")

        traces = get_traces()
        assert len(traces) == 1
        trace = traces[0]
        assert trace.info.status == "OK"
        assert len(trace.data.spans) == 1
        span = trace.data.spans[0]
        assert span.name == "ChatSession.send_message"
        assert span.span_type == SpanType.CHAT_MODEL
        assert span.inputs == {"content": "test content"}
        assert span.outputs == _GENERATE_CONTENT_RESPONSE

        # Disabled: a fresh chat must not add another trace.
        mlflow.gemini.autolog(disable=True)
        genai.GenerativeModel(model_name).start_chat(history=[]).send_message("test content")
        assert len(get_traces()) == 1
308  
309  
def test_count_tokens_autolog():
    """``count_tokens`` is traced only while autologging is enabled."""
    model_name = "gemini-1.5-flash"
    with patch("google.generativeai.GenerativeModel.count_tokens", new=count_tokens):
        # Enabled: one call produces exactly one successful single-span trace.
        mlflow.gemini.autolog()
        genai.GenerativeModel(model_name).count_tokens("test content")

        traces = get_traces()
        assert len(traces) == 1
        trace = traces[0]
        assert trace.info.status == "OK"
        assert len(trace.data.spans) == 1
        span = trace.data.spans[0]
        assert span.name == "GenerativeModel.count_tokens"
        assert span.span_type == SpanType.LLM
        assert span.inputs == {"contents": "test content"}
        assert span.outputs == _DUMMY_COUNT_TOKENS_RESPONSE

        # Disabled: the same call must not add another trace.
        mlflow.gemini.autolog(disable=True)
        genai.GenerativeModel(model_name).count_tokens("test content")
        assert len(get_traces()) == 1
333  
334  
def test_embed_content_autolog():
    """Module-level ``genai.embed_content`` is traced only while autologging is on."""
    embed_kwargs = {"model": "models/text-embedding-004", "content": "Hello World"}
    with patch("google.generativeai.embed_content", new=embed_content):
        # Enabled: one call produces exactly one EMBEDDING span.
        mlflow.gemini.autolog()
        genai.embed_content(**embed_kwargs)

        traces = get_traces()
        assert len(traces) == 1
        trace = traces[0]
        assert trace.info.status == "OK"
        assert len(trace.data.spans) == 1
        span = trace.data.spans[0]
        assert span.name == "embed_content"
        assert span.span_type == SpanType.EMBEDDING
        assert span.inputs == {"content": "Hello World", "model": "models/text-embedding-004"}
        assert span.outputs == _DUMMY_EMBEDDING_RESPONSE

        # Disabled: the same call must not add another trace.
        mlflow.gemini.autolog(disable=True)
        genai.embed_content(**embed_kwargs)
        assert len(get_traces()) == 1