# test_jwt.py
"""Tests for the JWT helpers in ``api.utils.jwt`` (``encode_jwt`` / ``decode_jwt``)."""

import re
from datetime import datetime, timedelta
from typing import Any, cast
from unittest.mock import MagicMock

import jwt as _jwt
import pytest
from _pytest.monkeypatch import MonkeyPatch
from pytest_mock import MockerFixture

from api.settings import settings
from api.utils import jwt

# Secret injected into ``settings.jwt_secret`` and used when signing/verifying tokens directly with PyJWT.
# Keeping it in one place guarantees both sides of every test use the same key.
_JWT_SECRET = "My JWT secret"

# TTLs (seconds) for the decode tests. ``exp`` is truncated to whole seconds before encoding, so a 1-second
# TTL could leave a token with only milliseconds of validity and make "still valid" cases fail whenever a
# second boundary passes before decoding. A wide margin on both sides removes that race.
_VALID_TTL = 60
_EXPIRED_TTL = -60


@pytest.mark.parametrize(
    "data,now,ttl,expected",
    [
        ({}, 42, 10, {"exp": 52}),
        ({"foo": "bar", "x": 42, "y": 1337}, 123, 456, {"foo": "bar", "x": 42, "y": 1337, "exp": 579}),
        ({"foo": {"bar": [{"x": 42}, {"y": 1337}]}}, 123, 456, {"foo": {"bar": [{"x": 42}, {"y": 1337}]}, "exp": 579}),
    ],
)
async def test__jwt_encode(
    data: dict[str, Any], now: int, ttl: int, expected: dict[str, Any], mocker: MockerFixture, monkeypatch: MonkeyPatch
) -> None:
    """Check ``encode_jwt`` signs ``data`` with HS256 and adds ``exp = now + ttl``.

    The clock seen by ``api.utils.jwt`` is frozen at the epoch timestamp ``now`` so the expected ``exp``
    claim is deterministic.
    """
    # Replace the module's ``datetime`` so ``utcnow()`` returns the fixed instant ``now``.
    mocker.patch("api.utils.jwt.datetime", MagicMock(utcnow=lambda: datetime.utcfromtimestamp(now)))
    monkeypatch.setattr(settings, "jwt_secret", _JWT_SECRET)

    token = jwt.encode_jwt(data, timedelta(seconds=ttl))

    # Compact serialization: three base64url segments (header.payload.signature).
    assert re.match(r"^([a-zA-Z\d\-_]+)\.([a-zA-Z\d\-_]+)\.[a-zA-Z\d\-_]+$", token), "Invalid JWT format"

    assert _jwt.get_unverified_header(token) == {"typ": "JWT", "alg": "HS256"}  # type: ignore
    # ``now`` is a tiny epoch value, so ``exp`` is long past on the real clock; skip expiry checks here.
    assert _jwt.decode(token, _JWT_SECRET, ["HS256"], {"verify_exp": False}) == expected


@pytest.mark.parametrize(
    "data,ttl,require,expected",
    [
        ({}, _VALID_TTL, [], True),
        ({}, _EXPIRED_TTL, [], False),
        ({}, _VALID_TTL, ["foo"], False),
        ({"foo": "bar"}, _VALID_TTL, [], True),
        ({"foo": "bar"}, _VALID_TTL, ["foo"], True),
        ({"foo": "bar"}, _VALID_TTL, ["foo", "bar"], False),
        ({"foo": "bar"}, _EXPIRED_TTL, ["foo"], False),
    ],
)
async def test__jwt_decode(
    data: dict[str, Any], ttl: int, require: list[str], expected: bool, monkeypatch: MonkeyPatch
) -> None:
    """Check ``decode_jwt`` returns the payload for valid tokens and ``None`` otherwise.

    Tokens are built directly with PyJWT. A token is expected to decode only when it has not expired and
    every claim named in ``require`` is present in the payload.
    """
    # Drop sub-second precision so ``exp`` round-trips exactly through the integer ``exp`` claim.
    exp = (datetime.utcnow() + timedelta(seconds=ttl)).replace(microsecond=0)
    token = _jwt.encode(data | {"exp": exp}, _JWT_SECRET, "HS256")
    monkeypatch.setattr(settings, "jwt_secret", _JWT_SECRET)

    result = jwt.decode_jwt(token, require=require)

    if not expected:
        assert result is None
        return

    assert isinstance(result, dict)
    # Strip ``exp`` so the remaining claims can be compared to the original payload.
    exp_ = cast(dict[str, Any], result).pop("exp")
    assert result == data
    assert exp == datetime.utcfromtimestamp(exp_)