# tests/unit/gateway/base/test_auth_interface.py
  1  """
  2  Unit tests for gateway/base/auth_interface.py
  3  Tests the abstract AuthHandler interface.
  4  """
  5  
  6  import pytest
  7  from abc import ABC
  8  from unittest.mock import AsyncMock, MagicMock
  9  
 10  from solace_agent_mesh.gateway.base.auth_interface import AuthHandler
 11  
 12  
 13  class ConcreteAuthHandler(AuthHandler):
 14      """Concrete implementation for testing the abstract interface."""
 15  
 16      def __init__(self):
 17          self.is_auth = False
 18          self.auth_headers = {}
 19          self.authorize_called = False
 20          self.callback_called = False
 21  
 22      async def handle_authorize(self, request):
 23          self.authorize_called = True
 24          return {"redirect_url": "https://auth.example.com/authorize", "status_code": 302}
 25  
 26      async def handle_callback(self, request):
 27          self.callback_called = True
 28          self.is_auth = True
 29          return {"success": True, "message": "Authentication successful"}
 30  
 31      async def get_auth_headers(self):
 32          if self.is_auth:
 33              return self.auth_headers
 34          return {}
 35  
 36      async def is_authenticated(self):
 37          return self.is_auth
 38  
 39  
 40  class TestAuthHandlerInterface:
 41      """Test the AuthHandler abstract interface."""
 42  
 43      def test_auth_handler_is_abstract(self):
 44          """Test that AuthHandler is an abstract base class."""
 45          assert issubclass(AuthHandler, ABC)
 46  
 47          # Attempting to instantiate should raise TypeError
 48          with pytest.raises(TypeError):
 49              AuthHandler()
 50  
 51      def test_auth_handler_has_required_methods(self):
 52          """Test that AuthHandler defines all required abstract methods."""
 53          required_methods = [
 54              'handle_authorize',
 55              'handle_callback',
 56              'get_auth_headers',
 57              'is_authenticated'
 58          ]
 59  
 60          for method_name in required_methods:
 61              assert hasattr(AuthHandler, method_name)
 62  
 63      @pytest.mark.asyncio
 64      async def test_concrete_implementation_handle_authorize(self):
 65          """Test concrete implementation of handle_authorize."""
 66          handler = ConcreteAuthHandler()
 67  
 68          result = await handler.handle_authorize(MagicMock())
 69  
 70          assert handler.authorize_called is True
 71          assert "redirect_url" in result
 72          assert result["redirect_url"] == "https://auth.example.com/authorize"
 73          assert result["status_code"] == 302
 74  
 75      @pytest.mark.asyncio
 76      async def test_concrete_implementation_handle_callback(self):
 77          """Test concrete implementation of handle_callback."""
 78          handler = ConcreteAuthHandler()
 79  
 80          result = await handler.handle_callback(MagicMock())
 81  
 82          assert handler.callback_called is True
 83          assert result["success"] is True
 84          assert "message" in result
 85  
 86      @pytest.mark.asyncio
 87      async def test_concrete_implementation_get_auth_headers_authenticated(self):
 88          """Test get_auth_headers when authenticated."""
 89          handler = ConcreteAuthHandler()
 90          handler.is_auth = True
 91          handler.auth_headers = {"Authorization": "Bearer test-token"}
 92  
 93          headers = await handler.get_auth_headers()
 94  
 95          assert headers == {"Authorization": "Bearer test-token"}
 96  
 97      @pytest.mark.asyncio
 98      async def test_concrete_implementation_get_auth_headers_not_authenticated(self):
 99          """Test get_auth_headers when not authenticated."""
100          handler = ConcreteAuthHandler()
101          handler.is_auth = False
102  
103          headers = await handler.get_auth_headers()
104  
105          assert headers == {}
106  
107      @pytest.mark.asyncio
108      async def test_concrete_implementation_is_authenticated(self):
109          """Test is_authenticated method."""
110          handler = ConcreteAuthHandler()
111  
112          # Initially not authenticated
113          assert await handler.is_authenticated() is False
114  
115          # After callback
116          await handler.handle_callback(MagicMock())
117          assert await handler.is_authenticated() is True
118  
119  
class TestAuthHandlerWorkflow:
    """Walk through the handler methods in the order a login flow uses them."""

    @pytest.mark.asyncio
    async def test_oauth_workflow(self):
        """Authorize, confirm still unauthenticated, run callback, read headers."""
        auth = ConcreteAuthHandler()

        # 1. Start authorization; the response must carry a redirect target.
        authorize_response = await auth.handle_authorize(MagicMock())
        assert authorize_response["redirect_url"] is not None

        # Starting authorization alone must not authenticate the handler.
        authenticated_before = await auth.is_authenticated()
        assert authenticated_before is False
        headers_before = await auth.get_auth_headers()
        assert headers_before == {}

        # 2. The provider calls back.
        callback_response = await auth.handle_callback(MagicMock())
        assert callback_response["success"] is True

        # 3. The handler now reports itself as authenticated.
        authenticated_after = await auth.is_authenticated()
        assert authenticated_after is True

        # 4. Headers stored on the handler are now returned to the caller.
        auth.auth_headers = {"Authorization": "Bearer token123"}
        headers_after = await auth.get_auth_headers()
        assert "Authorization" in headers_after
147  
148  
class TestAuthHandlerEdgeCases:
    """Less common call patterns against the concrete test handler."""

    @pytest.mark.asyncio
    async def test_get_auth_headers_empty_when_not_authenticated(self):
        """Stored headers are withheld while the handler is unauthenticated."""
        auth = ConcreteAuthHandler()
        auth.is_auth = False
        auth.auth_headers = {"Authorization": "Bearer should-not-be-returned"}

        returned = await auth.get_auth_headers()

        # The stored headers must not be returned; only an empty dict is.
        assert returned == {}

    @pytest.mark.asyncio
    async def test_multiple_authorize_calls(self):
        """handle_authorize can be invoked repeatedly on the same handler."""
        auth = ConcreteAuthHandler()

        first = await auth.handle_authorize(MagicMock())
        second = await auth.handle_authorize(MagicMock())

        # Each call yields a response with a redirect target.
        assert "redirect_url" in first
        assert "redirect_url" in second

    @pytest.mark.asyncio
    async def test_callback_before_authorize(self):
        """handle_callback succeeds even if handle_authorize was never called."""
        auth = ConcreteAuthHandler()

        # This test handler performs no state/nonce validation, so the
        # callback works independently of authorize. Real implementations
        # might reject such a callback.
        outcome = await auth.handle_callback(MagicMock())

        assert outcome["success"] is True
186  
187  
class FailingAuthHandler(AuthHandler):
    """AuthHandler whose every method raises, for exercising error paths."""

    async def handle_authorize(self, request):
        """Always fail with a ValueError."""
        error = ValueError("Authorization service unavailable")
        raise error

    async def handle_callback(self, request):
        """Always fail with a ValueError."""
        error = ValueError("Invalid authorization code")
        raise error

    async def get_auth_headers(self):
        """Always fail with a RuntimeError."""
        error = RuntimeError("Token expired")
        raise error

    async def is_authenticated(self):
        """Always fail with a ConnectionError."""
        error = ConnectionError("Cannot reach auth service")
        raise error
202  
203  
class TestAuthHandlerErrorHandling:
    """Exceptions raised inside handler methods propagate to the caller."""

    @pytest.mark.asyncio
    async def test_handle_authorize_exception(self):
        """handle_authorize surfaces the ValueError raised by the handler."""
        failing = FailingAuthHandler()
        request = MagicMock()

        with pytest.raises(ValueError, match="Authorization service unavailable"):
            await failing.handle_authorize(request)

    @pytest.mark.asyncio
    async def test_handle_callback_exception(self):
        """handle_callback surfaces the ValueError raised by the handler."""
        failing = FailingAuthHandler()
        request = MagicMock()

        with pytest.raises(ValueError, match="Invalid authorization code"):
            await failing.handle_callback(request)

    @pytest.mark.asyncio
    async def test_get_auth_headers_exception(self):
        """get_auth_headers surfaces the RuntimeError raised by the handler."""
        failing = FailingAuthHandler()

        with pytest.raises(RuntimeError, match="Token expired"):
            await failing.get_auth_headers()

    @pytest.mark.asyncio
    async def test_is_authenticated_exception(self):
        """is_authenticated surfaces the ConnectionError raised by the handler."""
        failing = FailingAuthHandler()

        with pytest.raises(ConnectionError, match="Cannot reach auth service"):
            await failing.is_authenticated()
238  
239  
class MockFrameworkRequest:
    """Stand-in request object exposing only ``params`` and ``headers``."""

    def __init__(self, params=None, headers=None):
        # Any falsy argument (None, empty dict, ...) is replaced by a fresh
        # empty dict, so instances never share a default mapping.
        self.params = params if params else {}
        self.headers = headers if headers else {}
246  
247  
class TestAuthHandlerFrameworkAgnostic:
    """Handler methods accept a request object that is not a mock."""

    @pytest.mark.asyncio
    async def test_handle_authorize_with_custom_request(self):
        """handle_authorize works when given a MockFrameworkRequest."""
        auth = ConcreteAuthHandler()
        custom_request = MockFrameworkRequest(params={"client_id": "test123"})

        response = await auth.handle_authorize(custom_request)

        assert response is not None
        assert "redirect_url" in response

    @pytest.mark.asyncio
    async def test_handle_callback_with_custom_request(self):
        """handle_callback works when given a MockFrameworkRequest."""
        auth = ConcreteAuthHandler()
        callback_params = {"code": "auth_code_123", "state": "random_state"}
        custom_request = MockFrameworkRequest(params=callback_params)

        response = await auth.handle_callback(custom_request)

        assert response["success"] is True