test_gemini_client.py
"""Unit tests for GeminiClient.complete().

Covers the happy path, attachment handling, error wrapping, retries, and
token-usage tracking.

google-genai does not use httpx, so pytest-httpx cannot intercept its calls.
We mock self._client.models.generate_content at the SDK level instead.
"""

import unittest.mock
from unittest.mock import MagicMock

import pytest
from google.genai import types as genai_types
from google.genai.errors import ClientError as _GenAIClientError
from google.genai.errors import ServerError as _GenAIServerError

from exceptions import APIError
from integrations.gemini import GeminiClient

_MODEL = "gemini-2.0-flash-lite"


def _make_client() -> GeminiClient:
    return GeminiClient(api_key="fake-key", model=_MODEL)


def _sent_contents(mock_fn: MagicMock) -> object:
    """Return the ``contents`` argument of the single generate_content() call.

    The client is expected to pass ``contents`` by keyword; a KeyError here
    means it was passed positionally or under a different name.
    """
    mock_fn.assert_called_once()
    return mock_fn.call_args.kwargs["contents"]


class TestGeminiClient:
    """Behaviour of GeminiClient.complete() against a mocked genai SDK."""

    def test_text_response_returned(self) -> None:
        client = _make_client()
        mock_response = MagicMock()
        mock_response.text = "Hello world"
        with unittest.mock.patch.object(
            client._client.models, "generate_content", return_value=mock_response
        ):
            result = client.complete(system="sys", user="usr")
        assert result == "Hello world"

    def test_4xx_wrapped_as_api_error(self) -> None:
        # Non-retriable 4xx from the genai SDK must be caught and wrapped into
        # APIError on the first attempt, without any retries.
        client = _make_client()
        mock_fn = MagicMock(
            side_effect=_GenAIClientError(400, {"message": "bad request", "status": "BAD_REQUEST"})
        )
        with (
            unittest.mock.patch.object(client._client.models, "generate_content", mock_fn),
            # Neutralise backoff so a regression that retries 4xx fails fast instead of sleeping.
            unittest.mock.patch.object(GeminiClient.complete.retry, "sleep"),  # type: ignore[attr-defined]
            pytest.raises(APIError) as exc_info,
        ):
            client.complete(system="sys", user="usr")
        assert exc_info.value.status_code == 400
        assert mock_fn.call_count == 1

    def test_429_rate_limit_retried_then_reraised(self) -> None:
        # 429 rate-limit errors surface as ClientError in the genai SDK
        # and must trigger tenacity retries.
        client = _make_client()
        mock_fn = MagicMock(
            side_effect=_GenAIClientError(
                429, {"message": "rate limit exceeded", "status": "RESOURCE_EXHAUSTED"}
            )
        )
        with (
            unittest.mock.patch.object(client._client.models, "generate_content", mock_fn),
            unittest.mock.patch.object(GeminiClient.complete.retry, "sleep"),  # type: ignore[attr-defined]
            pytest.raises(_GenAIClientError),
        ):
            client.complete(system="sys", user="usr")
        assert mock_fn.call_count > 1

    def test_5xx_retried_then_reraised(self) -> None:
        # Transient 5xx errors must trigger tenacity retries before reraising.
        client = _make_client()
        mock_fn = MagicMock(
            side_effect=_GenAIServerError(
                503, {"message": "service unavailable", "status": "UNAVAILABLE"}
            )
        )
        with (
            unittest.mock.patch.object(client._client.models, "generate_content", mock_fn),
            unittest.mock.patch.object(GeminiClient.complete.retry, "sleep"),  # type: ignore[attr-defined]
            pytest.raises(_GenAIServerError),
        ):
            client.complete(system="sys", user="usr")
        assert mock_fn.call_count > 1

    def test_attachments_builds_parts(self) -> None:
        # Attachments must produce Part.from_bytes() parts before the text.
        from models.llm import Attachment

        client = _make_client()
        attachment = Attachment(data=b"fake-pdf", media_type="application/pdf")
        mock_response = MagicMock()
        mock_response.text = "ok"
        mock_fn = MagicMock(return_value=mock_response)

        with unittest.mock.patch.object(client._client.models, "generate_content", mock_fn):
            client.complete(system="sys", user="usr", attachments=[attachment])

        contents = _sent_contents(mock_fn)
        assert isinstance(contents, list)
        assert len(contents) == 2  # Part + text string
        part, text = contents
        # The attachment's bytes and MIME type must reach the request unchanged.
        assert isinstance(part, genai_types.Part)
        assert part.inline_data is not None
        assert part.inline_data.data == b"fake-pdf"
        assert part.inline_data.mime_type == "application/pdf"
        assert text == "usr"

    def test_attachments_none_uses_string_content(self) -> None:
        # No attachments: contents is a plain string.
        client = _make_client()
        mock_response = MagicMock()
        mock_response.text = "ok"
        mock_fn = MagicMock(return_value=mock_response)

        with unittest.mock.patch.object(client._client.models, "generate_content", mock_fn):
            client.complete(system="sys", user="usr", attachments=None)

        assert _sent_contents(mock_fn) == "usr"


class TestGeminiLastUsage:
    """GeminiClient.last_usage tracking of (prompt, candidates) token counts."""

    def test_last_usage_populated_after_complete(self) -> None:
        client = _make_client()
        mock_response = MagicMock()
        mock_response.text = "ok"
        mock_response.usage_metadata.prompt_token_count = 100
        mock_response.usage_metadata.candidates_token_count = 50
        with unittest.mock.patch.object(
            client._client.models, "generate_content", return_value=mock_response
        ):
            client.complete(system="sys", user="usr")

        assert client.last_usage == (100, 50)

    def test_last_usage_unchanged_after_failed_complete(self) -> None:
        # Successful call first, then a 4xx failure — last_usage retains the successful values.
        client = _make_client()
        mock_response = MagicMock()
        mock_response.text = "ok"
        mock_response.usage_metadata.prompt_token_count = 100
        mock_response.usage_metadata.candidates_token_count = 50
        with unittest.mock.patch.object(
            client._client.models, "generate_content", return_value=mock_response
        ):
            client.complete(system="sys", user="usr")
        assert client.last_usage == (100, 50)

        with unittest.mock.patch.object(
            client._client.models,
            "generate_content",
            side_effect=_GenAIClientError(400, {"message": "bad", "status": "BAD"}),
        ):
            with pytest.raises(APIError):
                client.complete(system="sys", user="usr")

        assert client.last_usage == (100, 50)

    def test_last_usage_returns_zeros_when_usage_metadata_is_none(self) -> None:
        # Safety-blocked responses may have usage_metadata=None.
        client = _make_client()
        mock_response = MagicMock()
        mock_response.text = "ok"
        mock_response.usage_metadata = None
        with unittest.mock.patch.object(
            client._client.models, "generate_content", return_value=mock_response
        ):
            client.complete(system="sys", user="usr")

        assert client.last_usage == (0, 0)