/ tests / test_gemini_client.py
test_gemini_client.py
  1  """Unit tests for GeminiClient.complete() — happy path, error wrapping, and retries.
  2  
  3  google-genai does not use httpx, so pytest-httpx cannot intercept its calls.
  4  We mock self._client.models.generate_content at the SDK level instead.
  5  """
  6  
  7  import unittest.mock
  8  from unittest.mock import MagicMock
  9  
 10  import pytest
 11  from google.genai.errors import ClientError as _GenAIClientError
 12  from google.genai.errors import ServerError as _GenAIServerError
 13  
 14  from exceptions import APIError
 15  from integrations.gemini import GeminiClient
 16  
 17  _MODEL = "gemini-2.0-flash-lite"
 18  
 19  
 20  def _make_client() -> GeminiClient:
 21      return GeminiClient(api_key="fake-key", model=_MODEL)
 22  
 23  
 24  class TestGeminiClient:
 25      def test_text_response_returned(self) -> None:
 26          client = _make_client()
 27          mock_response = MagicMock()
 28          mock_response.text = "Hello world"
 29          with unittest.mock.patch.object(
 30              client._client.models, "generate_content", return_value=mock_response
 31          ):
 32              result = client.complete(system="sys", user="usr")
 33          assert result == "Hello world"
 34  
 35      def test_4xx_wrapped_as_api_error(self) -> None:
 36          # Non-retriable 4xx from the genai SDK must be caught and wrapped into APIError.
 37          client = _make_client()
 38          with unittest.mock.patch.object(
 39              client._client.models,
 40              "generate_content",
 41              side_effect=_GenAIClientError(400, {"message": "bad request", "status": "BAD_REQUEST"}),
 42          ):
 43              with pytest.raises(APIError) as exc_info:
 44                  client.complete(system="sys", user="usr")
 45          assert exc_info.value.status_code == 400
 46  
 47      def test_429_rate_limit_retried_then_reraised(self) -> None:
 48          # 429 rate-limit errors surface as ClientError in the genai SDK
 49          # and must trigger tenacity retries.
 50          client = _make_client()
 51          mock_fn = MagicMock(
 52              side_effect=_GenAIClientError(
 53                  429, {"message": "rate limit exceeded", "status": "RESOURCE_EXHAUSTED"}
 54              )
 55          )
 56          with (
 57              unittest.mock.patch.object(client._client.models, "generate_content", mock_fn),
 58              unittest.mock.patch.object(GeminiClient.complete.retry, "sleep"),  # type: ignore[attr-defined]
 59              pytest.raises(_GenAIClientError),
 60          ):
 61              client.complete(system="sys", user="usr")
 62          assert mock_fn.call_count > 1
 63  
 64      def test_5xx_retried_then_reraised(self) -> None:
 65          # Transient 5xx errors must trigger tenacity retries before reraising.
 66          client = _make_client()
 67          mock_fn = MagicMock(
 68              side_effect=_GenAIServerError(
 69                  503, {"message": "service unavailable", "status": "UNAVAILABLE"}
 70              )
 71          )
 72          with (
 73              unittest.mock.patch.object(client._client.models, "generate_content", mock_fn),
 74              unittest.mock.patch.object(GeminiClient.complete.retry, "sleep"),  # type: ignore[attr-defined]
 75              pytest.raises(_GenAIServerError),
 76          ):
 77              client.complete(system="sys", user="usr")
 78          assert mock_fn.call_count > 1
 79  
 80      def test_attachments_builds_parts(self) -> None:
 81          # Attachments must produce Part.from_bytes() parts before the text.
 82          from unittest.mock import MagicMock, patch
 83  
 84          from models.llm import Attachment
 85  
 86          client = _make_client()
 87          attachment = Attachment(data=b"fake-pdf", media_type="application/pdf")
 88          mock_response = MagicMock()
 89          mock_response.text = "ok"
 90          mock_fn = MagicMock(return_value=mock_response)
 91  
 92          with patch.object(client._client.models, "generate_content", mock_fn):
 93              client.complete(system="sys", user="usr", attachments=[attachment])
 94  
 95          call_kwargs = mock_fn.call_args
 96          contents = call_kwargs.kwargs.get("contents") or call_kwargs[1].get("contents")
 97          assert isinstance(contents, list)
 98          assert len(contents) == 2  # Part + text string
 99          assert contents[1] == "usr"
100  
101      def test_attachments_none_uses_string_content(self) -> None:
102          # No attachments: contents is a plain string.
103          from unittest.mock import MagicMock, patch
104  
105          client = _make_client()
106          mock_response = MagicMock()
107          mock_response.text = "ok"
108          mock_fn = MagicMock(return_value=mock_response)
109  
110          with patch.object(client._client.models, "generate_content", mock_fn):
111              client.complete(system="sys", user="usr", attachments=None)
112  
113          call_kwargs = mock_fn.call_args
114          contents = call_kwargs.kwargs.get("contents") or call_kwargs[1].get("contents")
115          assert contents == "usr"
116  
117  
class TestGeminiLastUsage:
    """Tests for the ``last_usage`` value GeminiClient exposes after complete()."""

    def test_last_usage_populated_after_complete(self) -> None:
        # A successful call must record (prompt tokens, candidate tokens) from usage_metadata.
        gemini = _make_client()
        sdk_reply = MagicMock(text="ok")
        sdk_reply.usage_metadata.prompt_token_count = 100
        sdk_reply.usage_metadata.candidates_token_count = 50

        patched_call = unittest.mock.patch.object(
            gemini._client.models, "generate_content", return_value=sdk_reply
        )
        with patched_call:
            gemini.complete(system="sys", user="usr")

        assert gemini.last_usage == (100, 50)

    def test_last_usage_unchanged_after_failed_complete(self) -> None:
        # Successful call first, then a 4xx failure — last_usage retains the successful values.
        gemini = _make_client()
        sdk_reply = MagicMock(text="ok")
        sdk_reply.usage_metadata.prompt_token_count = 100
        sdk_reply.usage_metadata.candidates_token_count = 50

        # Step 1: a successful call seeds last_usage.
        with unittest.mock.patch.object(
            gemini._client.models, "generate_content", return_value=sdk_reply
        ):
            gemini.complete(system="sys", user="usr")
        assert gemini.last_usage == (100, 50)

        # Step 2: a failing call surfaces as APIError and must leave last_usage alone.
        bad_request = _GenAIClientError(400, {"message": "bad", "status": "BAD"})
        with (
            unittest.mock.patch.object(
                gemini._client.models, "generate_content", side_effect=bad_request
            ),
            pytest.raises(APIError),
        ):
            gemini.complete(system="sys", user="usr")

        assert gemini.last_usage == (100, 50)

    def test_last_usage_returns_zeros_when_usage_metadata_is_none(self) -> None:
        # Safety-blocked responses may have usage_metadata=None.
        gemini = _make_client()
        sdk_reply = MagicMock(text="ok", usage_metadata=None)

        patched_call = unittest.mock.patch.object(
            gemini._client.models, "generate_content", return_value=sdk_reply
        )
        with patched_call:
            gemini.complete(system="sys", user="usr")

        assert gemini.last_usage == (0, 0)