/ src / solace_agent_mesh / common / utils / push_notification_auth.py
push_notification_auth.py
  1  import logging
  2  from jwcrypto import jwk
  3  import uuid
  4  from starlette.responses import JSONResponse
  5  from starlette.requests import Request
  6  from typing import Any
  7  
  8  import jwt
  9  import time
 10  import json
 11  import hashlib
 12  import httpx
 13  
 14  from jwt import PyJWK, PyJWKClient
 15  
 16  log = logging.getLogger(__name__)
 17  
 18  AUTH_HEADER_PREFIX = "Bearer "
 19  
 20  
 21  class PushNotificationAuth:
 22      def _calculate_request_body_sha256(self, data: dict[str, Any]):
 23          """Calculates the SHA256 hash of a request body.
 24  
 25          This logic needs to be same for both the agent who signs the payload and the client verifier.
 26          """
 27          body_str = json.dumps(
 28              data,
 29              ensure_ascii=False,
 30              allow_nan=False,
 31              indent=None,
 32              separators=(",", ":"),
 33          )
 34          return hashlib.sha256(body_str.encode()).hexdigest()
 35  
 36  
 37  class PushNotificationSenderAuth(PushNotificationAuth):
 38      def __init__(self):
 39          self.public_keys = []
 40          self.private_key_jwk: PyJWK = None
 41  
 42      @staticmethod
 43      async def verify_push_notification_url(url: str) -> bool:
 44          async with httpx.AsyncClient(timeout=10) as client:
 45              try:
 46                  validation_token = str(uuid.uuid4())
 47                  response = await client.get(
 48                      url, params={"validationToken": validation_token}
 49                  )
 50                  response.raise_for_status()
 51                  is_verified = response.text == validation_token
 52  
 53                  log.info("Verified push-notification URL: %s => %s", url, is_verified)
 54                  return is_verified
 55              except Exception as e:
 56                  log.warning(
 57                      "Error during sending push-notification for URL %s: %s", url, e
 58                  )
 59  
 60          return False
 61  
 62      def generate_jwk(self):
 63          key = jwk.JWK.generate(kty="RSA", size=2048, kid=str(uuid.uuid4()), use="sig")
 64          self.public_keys.append(key.export_public(as_dict=True))
 65          self.private_key_jwk = PyJWK.from_json(key.export_private())
 66  
 67      def handle_jwks_endpoint(self, _request: Request):
 68          """Allow clients to fetch public keys."""
 69          return JSONResponse({"keys": self.public_keys})
 70  
 71      def _generate_jwt(self, data: dict[str, Any]):
 72          """JWT is generated by signing both the request payload SHA digest and time of token generation.
 73  
 74          Payload is signed with private key and it ensures the integrity of payload for client.
 75          Including iat prevents from replay attack.
 76          """
 77  
 78          iat = int(time.time())
 79  
 80          return jwt.encode(
 81              {
 82                  "iat": iat,
 83                  "request_body_sha256": self._calculate_request_body_sha256(data),
 84              },
 85              key=self.private_key_jwk,
 86              headers={"kid": self.private_key_jwk.key_id},
 87              algorithm="RS256",
 88          )
 89  
 90      async def send_push_notification(self, url: str, data: dict[str, Any]):
 91          jwt_token = self._generate_jwt(data)
 92          headers = {"Authorization": f"Bearer {jwt_token}"}
 93          async with httpx.AsyncClient(timeout=10) as client:
 94              try:
 95                  response = await client.post(url, json=data, headers=headers)
 96                  response.raise_for_status()
 97                  log.info("Push-notification sent for URL: %s", url)
 98              except Exception as e:
 99                  log.warning(
100                      "Error during sending push-notification for URL %s: %s", url, e
101                  )
102  
103  
class PushNotificationReceiverAuth(PushNotificationAuth):
    """Verifies incoming push notifications against a sender's published JWKS.

    Usage: call :meth:`load_jwks` with the sender's JWKS URL, then call
    :meth:`verify_push_notification` for each incoming request.
    """

    # Tokens whose ``iat`` is older than this many seconds are rejected.
    MAX_TOKEN_AGE_SECONDS = 60 * 5

    def __init__(self):
        # Not read by this class; kept so existing attribute access keeps working.
        self.public_keys_jwks = []
        # Set by load_jwks(); resolves the signing key from a token's ``kid``.
        self.jwks_client: PyJWKClient | None = None

    async def load_jwks(self, jwks_url: str):
        """Point the verifier at the sender's JWKS endpoint.

        This only constructs the client; key retrieval happens during
        verification.
        """
        self.jwks_client = PyJWKClient(jwks_url)

    async def verify_push_notification(self, request: Request) -> bool:
        """Verify the bearer JWT and body integrity of a push-notification request.

        Returns:
            False if the ``Authorization`` header is missing or is not a bearer
            token. True if the token signature, claims, body digest and age
            all check out.

        Raises:
            RuntimeError: if :meth:`load_jwks` has not been called yet.
            ValueError: if the body digest does not match the signed
                ``request_body_sha256`` claim, or the token is older than
                ``MAX_TOKEN_AGE_SECONDS``.
            Exceptions from PyJWT propagate unchanged, e.g. a bad signature,
            missing required claims, or key-lookup failures.
        """
        import asyncio

        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX):
            log.warning("Invalid authorization header")
            return False

        if self.jwks_client is None:
            raise RuntimeError(
                "load_jwks() must be called before verifying push notifications"
            )

        token = auth_header[len(AUTH_HEADER_PREFIX) :]
        # PyJWKClient fetches the JWKS with blocking HTTP I/O when its cache
        # misses; run it in a worker thread so the event loop is not stalled.
        signing_key = await asyncio.to_thread(
            self.jwks_client.get_signing_key_from_jwt, token
        )

        decode_token = jwt.decode(
            token,
            signing_key,
            options={"require": ["iat", "request_body_sha256"]},
            algorithms=["RS256"],
        )

        # Re-serialize the received body the same way the sender did before signing.
        actual_body_sha256 = self._calculate_request_body_sha256(await request.json())
        if actual_body_sha256 != decode_token["request_body_sha256"]:
            raise ValueError("Invalid request body")

        if time.time() - decode_token["iat"] > self.MAX_TOKEN_AGE_SECONDS:
            raise ValueError("Token is expired")

        return True
135          return True