test_auth_interface.py
"""
Unit tests for gateway/base/auth_interface.py
Tests the abstract AuthHandler interface.
"""

import pytest
from abc import ABC
from unittest.mock import AsyncMock, MagicMock

from solace_agent_mesh.gateway.base.auth_interface import AuthHandler


class ConcreteAuthHandler(AuthHandler):
    """In-memory AuthHandler subclass used to exercise the abstract interface.

    State is held in plain attributes so tests can read and set it directly:
    ``is_auth`` (authentication flag), ``auth_headers`` (headers handed out
    once authenticated) and the ``*_called`` flags recording which handlers
    have been invoked.
    """

    def __init__(self):
        # Call-tracking flags flipped by the two request handlers.
        self.authorize_called = False
        self.callback_called = False
        # Authentication state; tests mutate these directly.
        self.is_auth = False
        self.auth_headers = {}

    async def handle_authorize(self, request):
        """Record the call and return a fixed redirect payload."""
        self.authorize_called = True
        redirect = {
            "redirect_url": "https://auth.example.com/authorize",
            "status_code": 302,
        }
        return redirect

    async def handle_callback(self, request):
        """Record the call, mark the handler authenticated, report success."""
        self.callback_called = True
        self.is_auth = True
        outcome = {"success": True, "message": "Authentication successful"}
        return outcome

    async def get_auth_headers(self):
        """Return the stored headers when authenticated, else an empty dict."""
        return self.auth_headers if self.is_auth else {}

    async def is_authenticated(self):
        """Return the current value of the ``is_auth`` flag."""
        return self.is_auth


class TestAuthHandlerInterface:
    """Checks on the AuthHandler base class and a concrete subclass of it."""

    def test_auth_handler_is_abstract(self):
        """AuthHandler derives from ABC and cannot be instantiated."""
        assert issubclass(AuthHandler, ABC)

        # Direct instantiation of the base class must be rejected.
        with pytest.raises(TypeError):
            AuthHandler()

    def test_auth_handler_has_required_methods(self):
        """AuthHandler exposes every method name the interface requires."""
        expected_names = (
            'handle_authorize',
            'handle_callback',
            'get_auth_headers',
            'is_authenticated',
        )

        for name in expected_names:
            assert hasattr(AuthHandler, name)

    @pytest.mark.asyncio
    async def test_concrete_implementation_handle_authorize(self):
        """handle_authorize flags the call and returns the redirect payload."""
        auth = ConcreteAuthHandler()

        response = await auth.handle_authorize(MagicMock())

        assert auth.authorize_called is True
        assert "redirect_url" in response
        assert response["redirect_url"] == "https://auth.example.com/authorize"
        assert response["status_code"] == 302

    @pytest.mark.asyncio
    async def test_concrete_implementation_handle_callback(self):
        """handle_callback flags the call and returns a success payload."""
        auth = ConcreteAuthHandler()

        response = await auth.handle_callback(MagicMock())

        assert auth.callback_called is True
        assert response["success"] is True
        assert "message" in response

    @pytest.mark.asyncio
    async def test_concrete_implementation_get_auth_headers_authenticated(self):
        """get_auth_headers returns the stored headers once authenticated."""
        auth = ConcreteAuthHandler()
        auth.is_auth = True
        auth.auth_headers = {"Authorization": "Bearer test-token"}

        returned = await auth.get_auth_headers()

        assert returned == {"Authorization": "Bearer test-token"}

    @pytest.mark.asyncio
    async def test_concrete_implementation_get_auth_headers_not_authenticated(self):
        """get_auth_headers returns an empty dict when not authenticated."""
        auth = ConcreteAuthHandler()
        auth.is_auth = False

        returned = await auth.get_auth_headers()

        assert returned == {}

    @pytest.mark.asyncio
    async def test_concrete_implementation_is_authenticated(self):
        """is_authenticated is False at first and True after the callback."""
        auth = ConcreteAuthHandler()

        # Fresh handler: nothing has authenticated it yet.
        assert await auth.is_authenticated() is False

        # The callback is what flips the flag.
        await auth.handle_callback(MagicMock())
        assert await auth.is_authenticated() is True


class TestAuthHandlerWorkflow:
    """End-to-end walk through the handler methods in their usual order."""

    @pytest.mark.asyncio
    async def test_oauth_workflow(self):
        """Authorize, callback, then header retrieval, in that order."""
        auth = ConcreteAuthHandler()

        # Step 1: the user starts authorization.
        authorize_response = await auth.handle_authorize(MagicMock())
        assert authorize_response["redirect_url"] is not None

        # Authorization alone must not authenticate the handler.
        assert await auth.is_authenticated() is False
        assert await auth.get_auth_headers() == {}

        # Step 2: the callback arrives.
        callback_response = await auth.handle_callback(MagicMock())
        assert callback_response["success"] is True

        # Step 3: the handler now reports itself authenticated.
        assert await auth.is_authenticated() is True

        # Step 4: stored headers become retrievable.
        auth.auth_headers = {"Authorization": "Bearer token123"}
        returned = await auth.get_auth_headers()
        assert "Authorization" in returned


class TestAuthHandlerEdgeCases:
    """Less common call orders and state combinations."""

    @pytest.mark.asyncio
    async def test_get_auth_headers_empty_when_not_authenticated(self):
        """Stored headers are withheld while the handler is unauthenticated."""
        auth = ConcreteAuthHandler()
        auth.is_auth = False
        auth.auth_headers = {"Authorization": "Bearer should-not-be-returned"}

        returned = await auth.get_auth_headers()

        # The stored value must not leak out; an empty dict is expected.
        assert returned == {}

    @pytest.mark.asyncio
    async def test_multiple_authorize_calls(self):
        """handle_authorize can be invoked repeatedly on one handler."""
        auth = ConcreteAuthHandler()

        first = await auth.handle_authorize(MagicMock())
        second = await auth.handle_authorize(MagicMock())

        # Both invocations must yield a redirect payload.
        assert "redirect_url" in first
        assert "redirect_url" in second

    @pytest.mark.asyncio
    async def test_callback_before_authorize(self):
        """Callback without a prior authorize call (e.g., replay attack)."""
        auth = ConcreteAuthHandler()

        # This simple handler does not tie the callback to a prior authorize;
        # real implementations might validate state/nonce.
        response = await auth.handle_callback(MagicMock())

        assert response["success"] is True


class FailingAuthHandler(AuthHandler):
    """AuthHandler whose every method raises, for error-propagation tests."""

    async def handle_authorize(self, request):
        """Always raise ValueError."""
        raise ValueError("Authorization service unavailable")

    async def handle_callback(self, request):
        """Always raise ValueError."""
        raise ValueError("Invalid authorization code")

    async def get_auth_headers(self):
        """Always raise RuntimeError."""
        raise RuntimeError("Token expired")

    async def is_authenticated(self):
        """Always raise ConnectionError."""
        raise ConnectionError("Cannot reach auth service")


class TestAuthHandlerErrorHandling:
    """Exceptions raised inside handler methods reach the caller unchanged."""

    @pytest.mark.asyncio
    async def test_handle_authorize_exception(self):
        """handle_authorize surfaces its ValueError."""
        failing = FailingAuthHandler()

        with pytest.raises(ValueError, match="Authorization service unavailable"):
            await failing.handle_authorize(MagicMock())

    @pytest.mark.asyncio
    async def test_handle_callback_exception(self):
        """handle_callback surfaces its ValueError."""
        failing = FailingAuthHandler()

        with pytest.raises(ValueError, match="Invalid authorization code"):
            await failing.handle_callback(MagicMock())

    @pytest.mark.asyncio
    async def test_get_auth_headers_exception(self):
        """get_auth_headers surfaces its RuntimeError."""
        failing = FailingAuthHandler()

        with pytest.raises(RuntimeError, match="Token expired"):
            await failing.get_auth_headers()

    @pytest.mark.asyncio
    async def test_is_authenticated_exception(self):
        """is_authenticated surfaces its ConnectionError."""
        failing = FailingAuthHandler()

        with pytest.raises(ConnectionError, match="Cannot reach auth service"):
            await failing.is_authenticated()


class MockFrameworkRequest:
    """Stand-in request object carrying only ``params`` and ``headers``."""

    def __init__(self, params=None, headers=None):
        # Falsy arguments (None, empty dict) are replaced by a fresh dict.
        self.params = params if params else {}
        self.headers = headers if headers else {}

class TestAuthHandlerFrameworkAgnostic:
    """Handler methods accept a plain custom request object."""

    @pytest.mark.asyncio
    async def test_handle_authorize_with_custom_request(self):
        """handle_authorize returns a redirect payload for a custom request."""
        custom_request = MockFrameworkRequest(params={"client_id": "test123"})
        auth = ConcreteAuthHandler()

        response = await auth.handle_authorize(custom_request)

        assert response is not None
        assert "redirect_url" in response

    @pytest.mark.asyncio
    async def test_handle_callback_with_custom_request(self):
        """handle_callback reports success for a custom request."""
        callback_params = {"code": "auth_code_123", "state": "random_state"}
        custom_request = MockFrameworkRequest(params=callback_params)
        auth = ConcreteAuthHandler()

        response = await auth.handle_callback(custom_request)

        assert response["success"] is True