push_notification_auth.py
import hashlib
import json
import logging
import time
import uuid
from typing import Any, Optional

import httpx
import jwt
from jwcrypto import jwk
from jwt import PyJWK, PyJWKClient
from starlette.requests import Request
from starlette.responses import JSONResponse

log = logging.getLogger(__name__)

AUTH_HEADER_PREFIX = "Bearer "

# Maximum accepted age (seconds) of a push-notification JWT, measured from `iat`.
MAX_TOKEN_AGE_SECONDS = 60 * 5
# Tolerated clock drift (seconds) when a token's `iat` lies in the future.
MAX_CLOCK_SKEW_SECONDS = 60


class PushNotificationAuth:
    """Shared helpers for signing and verifying push-notification payloads."""

    def _calculate_request_body_sha256(self, data: dict[str, Any]) -> str:
        """Return the hex SHA-256 digest of ``data`` serialized as compact JSON.

        The sender signs this digest and the receiver recomputes it from the
        parsed request body, so both sides must serialize identically: compact
        separators, no ASCII escaping, NaN/Infinity rejected, and keys kept in
        dict insertion order (no ``sort_keys``).
        """
        body_str = json.dumps(
            data,
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
        )
        return hashlib.sha256(body_str.encode()).hexdigest()


class PushNotificationSenderAuth(PushNotificationAuth):
    """Signs outgoing push notifications with an RSA key and serves its JWKS."""

    def __init__(self):
        # Public halves of every key generated so far, as JWK dicts.
        self.public_keys: list[dict[str, Any]] = []
        # Active signing key; populated by generate_jwk().
        self.private_key_jwk: Optional[PyJWK] = None

    @staticmethod
    async def verify_push_notification_url(url: str) -> bool:
        """Check that ``url`` echoes back a random validation token.

        Sends ``GET url?validationToken=<uuid4>`` and returns True only if the
        request succeeds and the response body is exactly that token. Any
        failure (network error, non-2xx status, ...) is logged and reported as
        False.
        """
        async with httpx.AsyncClient(timeout=10) as client:
            try:
                validation_token = str(uuid.uuid4())
                response = await client.get(
                    url, params={"validationToken": validation_token}
                )
                response.raise_for_status()
                is_verified = response.text == validation_token

                log.info("Verified push-notification URL: %s => %s", url, is_verified)
                return is_verified
            except Exception as e:
                # Best-effort probe: any failure just means "not verified".
                log.warning(
                    "Error during verifying push-notification URL %s: %s", url, e
                )

        return False

    def generate_jwk(self):
        """Generate a fresh 2048-bit RSA signing key and make it the active key.

        The public half is appended to ``public_keys`` (previously generated
        public keys stay published); the private half replaces
        ``private_key_jwk``.
        """
        key = jwk.JWK.generate(kty="RSA", size=2048, kid=str(uuid.uuid4()), use="sig")
        self.public_keys.append(key.export_public(as_dict=True))
        self.private_key_jwk = PyJWK.from_json(key.export_private())

    def handle_jwks_endpoint(self, _request: Request):
        """Allow clients to fetch public keys as a JWKS document."""
        return JSONResponse({"keys": self.public_keys})

    def _generate_jwt(self, data: dict[str, Any]) -> str:
        """Return an RS256 JWT binding ``data`` to the time of signing.

        Claims are ``iat`` (issue time, integer Unix seconds) and
        ``request_body_sha256`` (digest of ``data``). The signature lets the
        receiver check payload integrity; ``iat`` lets it reject stale tokens,
        which narrows — but does not eliminate — the replay window.

        Raises:
            RuntimeError: if ``generate_jwk()`` has not been called yet.
        """
        if self.private_key_jwk is None:
            raise RuntimeError("No signing key; call generate_jwk() first")

        iat = int(time.time())

        return jwt.encode(
            {
                "iat": iat,
                "request_body_sha256": self._calculate_request_body_sha256(data),
            },
            key=self.private_key_jwk,
            # kid lets the receiver pick the matching key from our JWKS.
            headers={"kid": self.private_key_jwk.key_id},
            algorithm="RS256",
        )

    async def send_push_notification(self, url: str, data: dict[str, Any]):
        """POST ``data`` as JSON to ``url`` with a signed Bearer token.

        Delivery is best effort: HTTP and network errors are logged, not raised.
        """
        jwt_token = self._generate_jwt(data)
        headers = {"Authorization": f"{AUTH_HEADER_PREFIX}{jwt_token}"}
        async with httpx.AsyncClient(timeout=10) as client:
            try:
                response = await client.post(url, json=data, headers=headers)
                response.raise_for_status()
                log.info("Push-notification sent for URL: %s", url)
            except Exception as e:
                log.warning(
                    "Error during sending push-notification for URL %s: %s", url, e
                )


class PushNotificationReceiverAuth(PushNotificationAuth):
    """Verifies incoming push notifications against the sender's JWKS."""

    def __init__(self):
        # Not read anywhere in this module; kept for backward compatibility.
        self.public_keys_jwks = []
        # Set by load_jwks(); None until then.
        self.jwks_client: Optional[PyJWKClient] = None

    async def load_jwks(self, jwks_url: str):
        """Configure the JWKS URL used to resolve signing keys by ``kid``."""
        self.jwks_client = PyJWKClient(jwks_url)

    async def verify_push_notification(self, request: Request) -> bool:
        """Verify the signed Bearer token on an incoming push notification.

        Returns:
            False if the Authorization header is missing or is not a Bearer
            token; True if the signature, body digest and token age all check
            out.

        Raises:
            RuntimeError: if ``load_jwks()`` has not been called.
            jwt.InvalidTokenError / jwt.exceptions.PyJWKClientError (e.g.):
                if the signing key cannot be resolved or the token fails
                decoding (bad signature, missing required claim).
            ValueError: if the body digest does not match, the token is older
                than ``MAX_TOKEN_AGE_SECONDS``, or its ``iat`` is more than
                ``MAX_CLOCK_SKEW_SECONDS`` in the future.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith(AUTH_HEADER_PREFIX):
            log.warning("Invalid authorization header")
            return False

        if self.jwks_client is None:
            raise RuntimeError("JWKS not loaded; call load_jwks() first")

        token = auth_header[len(AUTH_HEADER_PREFIX):]
        # NOTE(review): get_signing_key_from_jwt is synchronous and may fetch
        # the JWKS over HTTP, blocking the event loop on a cache miss; consider
        # running it via asyncio.to_thread.
        signing_key = self.jwks_client.get_signing_key_from_jwt(token)

        decode_token = jwt.decode(
            token,
            signing_key,
            options={"require": ["iat", "request_body_sha256"]},
            algorithms=["RS256"],
        )

        actual_body_sha256 = self._calculate_request_body_sha256(await request.json())
        if actual_body_sha256 != decode_token["request_body_sha256"]:
            raise ValueError("Invalid request body")

        token_age = time.time() - decode_token["iat"]
        if token_age > MAX_TOKEN_AGE_SECONDS:
            raise ValueError("Token is expired")
        # A future-dated iat would otherwise keep the token "fresh" indefinitely.
        if token_age < -MAX_CLOCK_SKEW_SECONDS:
            raise ValueError("Token issued in the future")

        return True