/ tests / unit / common / test_auth_headers.py
test_auth_headers.py
  1  """Tests for common authentication header building utilities."""
  2  
  3  import pytest
  4  from unittest.mock import AsyncMock, patch
  5  from solace_agent_mesh.common.auth_headers import (
  6      build_static_auth_headers,
  7      build_full_auth_headers,
  8  )
  9  
 10  
 11  class TestBuildStaticAuthHeaders:
 12      """Test suite for build_static_auth_headers()."""
 13  
 14      def test_static_bearer_token(self):
 15          """Test static bearer token authentication."""
 16          agent_config = {
 17              "authentication": {"type": "static_bearer", "token": "test_token_123"}
 18          }
 19  
 20          headers = build_static_auth_headers(
 21              agent_name="test-agent",
 22              agent_config=agent_config,
 23              custom_headers_key="agent_card_headers",
 24              use_auth=True,
 25          )
 26  
 27          assert headers == {"Authorization": "Bearer test_token_123"}
 28  
 29      def test_static_apikey(self):
 30          """Test static API key authentication."""
 31          agent_config = {
 32              "authentication": {"type": "static_apikey", "token": "api_key_456"}
 33          }
 34  
 35          headers = build_static_auth_headers(
 36              agent_name="test-agent",
 37              agent_config=agent_config,
 38              custom_headers_key="agent_card_headers",
 39              use_auth=True,
 40          )
 41  
 42          assert headers == {"X-API-Key": "api_key_456"}
 43  
 44      def test_legacy_bearer_scheme(self):
 45          """Test backward compatibility with legacy 'scheme' field."""
 46          agent_config = {
 47              "authentication": {
 48                  "scheme": "bearer",  # Legacy format
 49                  "token": "legacy_token",
 50              }
 51          }
 52  
 53          headers = build_static_auth_headers(
 54              agent_name="test-agent",
 55              agent_config=agent_config,
 56              custom_headers_key="agent_card_headers",
 57              use_auth=True,
 58          )
 59  
 60          assert headers == {"Authorization": "Bearer legacy_token"}
 61  
 62      def test_legacy_apikey_scheme(self):
 63          """Test backward compatibility with legacy 'apikey' scheme."""
 64          agent_config = {
 65              "authentication": {
 66                  "scheme": "apikey",  # Legacy format
 67                  "token": "legacy_key",
 68              }
 69          }
 70  
 71          headers = build_static_auth_headers(
 72              agent_name="test-agent",
 73              agent_config=agent_config,
 74              custom_headers_key="agent_card_headers",
 75              use_auth=True,
 76          )
 77  
 78          assert headers == {"X-API-Key": "legacy_key"}
 79  
 80      def test_custom_headers_override_auth(self):
 81          """Test that custom headers override auth headers."""
 82          agent_config = {
 83              "authentication": {"type": "static_bearer", "token": "original_token"},
 84              "agent_card_headers": [
 85                  {"name": "Authorization", "value": "Custom override"},
 86                  {"name": "X-Custom", "value": "custom_value"},
 87              ],
 88          }
 89  
 90          headers = build_static_auth_headers(
 91              agent_name="test-agent",
 92              agent_config=agent_config,
 93              custom_headers_key="agent_card_headers",
 94              use_auth=True,
 95          )
 96  
 97          assert headers == {
 98              "Authorization": "Custom override",  # Overridden
 99              "X-Custom": "custom_value",
100          }
101  
102      @patch("solace_agent_mesh.common.auth_headers.log")
103      def test_oauth2_skipped_with_warning(self, mock_log):
104          """Test that OAuth2 is skipped in sync context with warning."""
105          agent_config = {
106              "authentication": {
107                  "type": "oauth2_client_credentials",
108                  "token_url": "https://auth.example.com/token",
109                  "client_id": "client123",
110                  "client_secret": "secret456",
111              }
112          }
113  
114          headers = build_static_auth_headers(
115              agent_name="test-agent",
116              agent_config=agent_config,
117              custom_headers_key="agent_card_headers",
118              use_auth=True,
119              log_identifier="[Test]",
120          )
121  
122          # No auth header should be added
123          assert headers == {}
124          # Warning should be logged
125          mock_log.warning.assert_called_once()
126          call_args = " ".join(str(arg) for arg in mock_log.warning.call_args[0])
127          assert "OAuth2 authentication" in call_args
128          assert "not supported in synchronous context" in call_args
129  
130      def test_use_auth_false(self):
131          """Test that use_auth=False skips authentication."""
132          agent_config = {
133              "authentication": {"type": "static_bearer", "token": "test_token"}
134          }
135  
136          headers = build_static_auth_headers(
137              agent_name="test-agent",
138              agent_config=agent_config,
139              custom_headers_key="agent_card_headers",
140              use_auth=False,
141          )
142  
143          assert headers == {}
144  
145      def test_no_authentication_config(self):
146          """Test behavior when no authentication is configured."""
147          agent_config = {}
148  
149          headers = build_static_auth_headers(
150              agent_name="test-agent",
151              agent_config=agent_config,
152              custom_headers_key="agent_card_headers",
153              use_auth=True,
154          )
155  
156          assert headers == {}
157  
158      def test_missing_token(self):
159          """Test behavior when token is missing."""
160          agent_config = {
161              "authentication": {
162                  "type": "static_bearer"
163                  # No token field
164              }
165          }
166  
167          headers = build_static_auth_headers(
168              agent_name="test-agent",
169              agent_config=agent_config,
170              custom_headers_key="agent_card_headers",
171              use_auth=True,
172          )
173  
174          assert headers == {}
175  
176      def test_custom_headers_only(self):
177          """Test custom headers without authentication."""
178          agent_config = {
179              "task_headers": [
180                  {"name": "X-Request-ID", "value": "req-123"},
181                  {"name": "X-Tenant", "value": "tenant-456"},
182              ]
183          }
184  
185          headers = build_static_auth_headers(
186              agent_name="test-agent",
187              agent_config=agent_config,
188              custom_headers_key="task_headers",
189              use_auth=False,
190          )
191  
192          assert headers == {"X-Request-ID": "req-123", "X-Tenant": "tenant-456"}
193  
194      def test_custom_headers_with_missing_fields(self):
195          """Test that custom headers with missing name or value are skipped."""
196          agent_config = {
197              "agent_card_headers": [
198                  {"name": "X-Valid", "value": "valid_value"},
199                  {"name": "X-No-Value"},  # Missing value
200                  {"value": "no_name"},  # Missing name
201                  {},  # Both missing
202              ]
203          }
204  
205          headers = build_static_auth_headers(
206              agent_name="test-agent",
207              agent_config=agent_config,
208              custom_headers_key="agent_card_headers",
209              use_auth=False,
210          )
211  
212          assert headers == {"X-Valid": "valid_value"}
213  
214  
class TestBuildAuthHeadersAsync:
    """Test suite for build_full_auth_headers()."""

    @staticmethod
    def _oauth2_agent_config(task_headers=None):
        """Return a fresh agent config using OAuth2 client-credentials auth.

        A new dict is built on every call so tests never share mutable state.

        Args:
            task_headers: Optional list of ``{"name": ..., "value": ...}``
                entries stored under the ``"task_headers"`` key. The key is
                omitted entirely when this is ``None``.

        Returns:
            The agent configuration dict.
        """
        config = {
            "authentication": {
                "type": "oauth2_client_credentials",
                "token_url": "https://auth.example.com/token",
                "client_id": "client123",
                "client_secret": "secret456",
            }
        }
        if task_headers is not None:
            config["task_headers"] = task_headers
        return config

    @pytest.mark.asyncio
    async def test_static_bearer_async(self):
        """Test static bearer token in async context."""
        agent_config = {
            "authentication": {"type": "static_bearer", "token": "async_token"}
        }

        headers = await build_full_auth_headers(
            agent_name="test-agent",
            agent_config=agent_config,
            custom_headers_key="task_headers",
            use_auth=True,
        )

        assert headers == {"Authorization": "Bearer async_token"}

    @pytest.mark.asyncio
    async def test_oauth2_with_token_fetcher(self):
        """Test OAuth2 with token fetcher."""
        agent_config = self._oauth2_agent_config()
        fetcher_calls = []

        async def mock_fetcher(agent_name, auth_config):
            # Record the arguments instead of asserting here: fetcher
            # exceptions are logged and treated as non-fatal (see
            # test_oauth2_token_fetch_failure), so an assertion raised inside
            # this callback may be absorbed and surface only as a confusing
            # header mismatch.
            fetcher_calls.append((agent_name, auth_config))
            return "oauth_access_token_xyz"

        headers = await build_full_auth_headers(
            agent_name="test-agent",
            agent_config=agent_config,
            custom_headers_key="task_headers",
            use_auth=True,
            oauth_token_fetcher=mock_fetcher,
        )

        assert fetcher_calls, "oauth_token_fetcher was never awaited"
        for called_agent_name, called_auth_config in fetcher_calls:
            assert called_agent_name == "test-agent"
            assert called_auth_config["client_id"] == "client123"
        assert headers == {"Authorization": "Bearer oauth_access_token_xyz"}

    @pytest.mark.asyncio
    async def test_oauth2_without_token_fetcher_raises(self):
        """Test that OAuth2 without token fetcher raises ValueError."""
        agent_config = self._oauth2_agent_config()

        with pytest.raises(ValueError, match="no oauth_token_fetcher provided"):
            await build_full_auth_headers(
                agent_name="test-agent",
                agent_config=agent_config,
                custom_headers_key="task_headers",
                use_auth=True,
                oauth_token_fetcher=None,  # Missing!
            )

    @pytest.mark.asyncio
    @patch("solace_agent_mesh.common.auth_headers.log")
    async def test_oauth2_token_fetch_failure(self, mock_log):
        """Test that OAuth2 token fetch failure is logged but non-fatal."""
        agent_config = self._oauth2_agent_config()

        # Mock token fetcher that raises
        async def failing_fetcher(agent_name, auth_config):
            raise RuntimeError("Token service unavailable")

        headers = await build_full_auth_headers(
            agent_name="test-agent",
            agent_config=agent_config,
            custom_headers_key="task_headers",
            use_auth=True,
            log_identifier="[Test]",
            oauth_token_fetcher=failing_fetcher,
        )

        # Should return headers without auth (matches existing behavior)
        assert "Authorization" not in headers
        # Error should be logged
        mock_log.error.assert_called_once()
        call_args = " ".join(str(arg) for arg in mock_log.error.call_args[0])
        assert "Failed to obtain OAuth 2.0 token" in call_args

    @pytest.mark.asyncio
    async def test_oauth2_with_custom_headers(self):
        """Test OAuth2 + custom headers combination."""
        agent_config = self._oauth2_agent_config(
            task_headers=[{"name": "X-Custom", "value": "custom_value"}]
        )

        async def mock_fetcher(agent_name, auth_config):
            return "oauth_token"

        headers = await build_full_auth_headers(
            agent_name="test-agent",
            agent_config=agent_config,
            custom_headers_key="task_headers",
            use_auth=True,
            oauth_token_fetcher=mock_fetcher,
        )

        assert headers == {
            "Authorization": "Bearer oauth_token",
            "X-Custom": "custom_value",
        }

    @pytest.mark.asyncio
    async def test_static_auth_with_custom_headers_override(self):
        """Test that custom headers override OAuth2 auth headers.

        Despite the "static_auth" in its name, this test configures OAuth2;
        the name is kept so existing test IDs stay stable.
        """
        agent_config = self._oauth2_agent_config(
            task_headers=[
                {"name": "Authorization", "value": "Custom Bearer override"}
            ]
        )

        async def mock_fetcher(agent_name, auth_config):
            return "oauth_token"

        headers = await build_full_auth_headers(
            agent_name="test-agent",
            agent_config=agent_config,
            custom_headers_key="task_headers",
            use_auth=True,
            oauth_token_fetcher=mock_fetcher,
        )

        # Custom header should override OAuth2 token
        assert headers == {"Authorization": "Custom Bearer override"}

    @pytest.mark.asyncio
    async def test_use_auth_false_in_async(self):
        """Test that use_auth=False works in async context."""
        agent_config = self._oauth2_agent_config()

        async def mock_fetcher(agent_name, auth_config):
            pytest.fail("Token fetcher should not be called when use_auth=False")

        headers = await build_full_auth_headers(
            agent_name="test-agent",
            agent_config=agent_config,
            custom_headers_key="task_headers",
            use_auth=False,
            oauth_token_fetcher=mock_fetcher,
        )

        assert headers == {}