test_jwt.py
 1  import re
 2  from datetime import datetime, timedelta
 3  from typing import Any, cast
 4  from unittest.mock import MagicMock
 5  
 6  import jwt as _jwt
 7  import pytest
 8  from _pytest.monkeypatch import MonkeyPatch
 9  from pytest_mock import MockerFixture
10  
11  from api.settings import settings
12  from api.utils import jwt
13  
14  
@pytest.mark.parametrize(
    "data,now,ttl,expected",
    [
        ({}, 42, 10, {"exp": 52}),
        ({"foo": "bar", "x": 42, "y": 1337}, 123, 456, {"foo": "bar", "x": 42, "y": 1337, "exp": 579}),
        ({"foo": {"bar": [{"x": 42}, {"y": 1337}]}}, 123, 456, {"foo": {"bar": [{"x": 42}, {"y": 1337}]}, "exp": 579}),
    ],
)
async def test__jwt_encode(
    data: dict[str, Any], now: int, ttl: int, expected: dict[str, Any], mocker: MockerFixture, monkeypatch: MonkeyPatch
) -> None:
    """`encode_jwt` produces an HS256 JWT carrying `data` plus an `exp` claim of now + ttl.

    `api.utils.jwt.datetime` is replaced so that `utcnow()` returns the fixed
    epoch timestamp `now`, which makes the expected `exp` claim deterministic.
    """
    mocker.patch("api.utils.jwt.datetime", MagicMock(utcnow=lambda: datetime.utcfromtimestamp(now)))
    monkeypatch.setattr(settings, "jwt_secret", "My JWT secret")

    token = jwt.encode_jwt(data, timedelta(seconds=ttl))

    # Compact serialization: three non-empty base64url segments joined by dots.
    # An explicit ASCII class is used because `\d` would also accept non-ASCII digits.
    assert re.fullmatch(r"[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+", token), "Invalid JWT format"

    assert _jwt.get_unverified_header(token) == {"typ": "JWT", "alg": "HS256"}  # type: ignore
    # Expiry is not verified: the mocked "now" values are seconds after the epoch,
    # so every generated token is long expired in real time.
    assert _jwt.decode(token, "My JWT secret", ["HS256"], {"verify_exp": False}) == expected
36  
37  
@pytest.mark.parametrize(
    "data,ttl,require,expected",
    [
        ({}, 60, [], True),
        ({}, -60, [], False),
        ({}, 60, ["foo"], False),
        ({"foo": "bar"}, 60, [], True),
        ({"foo": "bar"}, 60, ["foo"], True),
        ({"foo": "bar"}, 60, ["foo", "bar"], False),
        ({"foo": "bar"}, -60, ["foo"], False),
    ],
)
async def test__jwt_decode(
    data: dict[str, Any], ttl: int, require: list[str], expected: bool, monkeypatch: MonkeyPatch
) -> None:
    """`decode_jwt` returns a valid token's claims, or None for a rejected token.

    The token is built directly with PyJWT (HS256, secret "My JWT secret") with an
    `exp` claim `ttl` seconds from now. It must be rejected when it has expired or
    when a claim named in `require` is missing from `data`.
    """
    # `exp` is truncated to whole seconds so it round-trips exactly through the
    # token (see the comparison below). Truncation can leave a +1 s token only
    # milliseconds of lifetime, so a +/-60 s margin is used to keep the "valid"
    # cases from expiring between encode and decode (previously flaky).
    exp = (datetime.utcnow() + timedelta(seconds=ttl)).replace(microsecond=0)
    token = _jwt.encode(data | {"exp": exp}, "My JWT secret", "HS256")
    monkeypatch.setattr(settings, "jwt_secret", "My JWT secret")
    result = jwt.decode_jwt(token, require=require)

    if expected:
        assert isinstance(result, dict)
        exp_ = cast(dict[str, Any], result).pop("exp")
        assert result == data
        assert exp == datetime.utcfromtimestamp(exp_)
    else:
        assert result is None
64          assert result is None