test_user.py
  1  from datetime import datetime
  2  from unittest.mock import AsyncMock, MagicMock
  3  
  4  import pytest
  5  from _pytest.monkeypatch import MonkeyPatch
  6  from pytest_mock import MockerFixture
  7  from sqlalchemy import func
  8  
  9  from api.database import db, db_wrapper, select
 10  from api.models import User
 11  from api.settings import settings
 12  from api.utils.passwords import verify_password
 13  
 14  
 15  @pytest.mark.parametrize(
 16      "enabled,admin,password,mfa", [(True, False, "asdf", False), (False, True, None, True), (True, True, None, False)]
 17  )
 18  async def test__serialize(enabled: bool, admin: bool, password: str | None, mfa: bool) -> None:
 19      obj = User(
 20          id="user_id",
 21          name="user_name",
 22          registration=datetime.fromtimestamp(123456),
 23          last_login=datetime.fromtimestamp(345678),
 24          enabled=enabled,
 25          admin=admin,
 26          password=password,
 27          mfa_enabled=mfa,
 28      )
 29  
 30      assert obj.serialize == {
 31          "id": "user_id",
 32          "name": "user_name",
 33          "registration": 123456,
 34          "last_login": 345678,
 35          "enabled": enabled,
 36          "admin": admin,
 37          "password": bool(password),
 38          "mfa_enabled": mfa,
 39      }
 40  
 41  
 42  @pytest.mark.parametrize("enabled,admin,password", [(True, False, "asdf"), (False, True, None), (True, True, None)])
 43  @db_wrapper
 44  async def test__create(enabled: bool, admin: bool, password: str | None) -> None:
 45      obj = await User.create("user_name", password, enabled, admin)
 46      users = await db.all(select(User))
 47      assert users == [obj]
 48  
 49      assert obj.name == "user_name"
 50  
 51      if password:
 52          assert await verify_password(password, obj.password)
 53      else:
 54          assert obj.password is None
 55  
 56      assert abs(datetime.utcnow() - obj.registration).total_seconds() < 10
 57      assert obj.last_login is None
 58      assert obj.enabled == enabled
 59      assert obj.admin == admin
 60      assert obj.mfa_secret is None
 61      assert obj.mfa_enabled is False
 62      assert obj.mfa_recovery_code is None
 63  
 64  
 65  async def test__filter_by_name() -> None:
 66      assert User.filter_by_name("UserName") == select(User).where(func.lower(User.name) == "username")
 67  
 68  
 69  @pytest.mark.parametrize("first_user", [True, False])
 70  @db_wrapper
 71  async def test__initialize(first_user: bool, monkeypatch: MonkeyPatch) -> None:
 72      monkeypatch.setattr(settings, "admin_username", "admin_username")
 73      monkeypatch.setattr(settings, "admin_password", "admin_password")
 74  
 75      if not first_user:
 76          await User.create("other_user", "other_password", True, True)
 77  
 78      await User.initialize()
 79  
 80      users = await db.all(select(User))
 81      assert len(users) == 1
 82      assert users[0].name == "admin_username" if first_user else "other_user"
 83  
 84  
 85  @pytest.mark.parametrize("arg,dbv,ok", [("foo", "foo", True), ("foo", "bar", False), ("foo", None, False)])
 86  async def test__check_password(arg: str, dbv: str | None, ok: bool, mocker: MockerFixture) -> None:
 87      mocker.patch("api.models.user.verify_password", AsyncMock(side_effect=str.__eq__))
 88  
 89      user = User(password=dbv)
 90      assert await user.check_password(arg) == ok
 91  
 92  
 93  @pytest.mark.parametrize("pw", ["asdf", None])
 94  async def test__change_password(pw: str | None, mocker: MockerFixture) -> None:
 95      hash_password = mocker.patch("api.models.user.hash_password", new_callable=AsyncMock)
 96  
 97      user = User(password="foobar")  # noqa: S106
 98      await user.change_password(pw)
 99  
100      if pw:
101          hash_password.assert_called_once_with(pw)
102          assert user.password == await hash_password()
103      else:
104          hash_password.assert_not_called()
105          assert user.password is None
106  
107  
async def test__create_session(mocker: MockerFixture) -> None:
    """Check that ``User.create_session`` delegates to ``Session.create`` and sets ``last_login``."""
    session_create = mocker.patch("api.models.session.Session.create", new_callable=AsyncMock)

    account = User(id="my_user_id")
    result = await account.create_session("my device name")

    # Verify the call before invoking the mock again to obtain its return value.
    session_create.assert_called_once_with("my_user_id", "my device name")
    assert result == await session_create()

    # last_login should have been set to "just now" (10 second tolerance).
    assert account.last_login is not None
    since_login = abs(datetime.utcnow() - account.last_login)
    assert since_login.total_seconds() < 10
118  
119  
async def test__from_access_token__invalid_jwt(mocker: MockerFixture) -> None:
    """Check that ``User.from_access_token`` returns ``None`` when the JWT cannot be decoded."""
    # A decoder returning None stands in for an invalid token.
    jwt_decoder = MagicMock(return_value=None)
    mocker.patch("api.models.user.decode_jwt", jwt_decoder)

    resolved = await User.from_access_token("my_token")

    assert resolved is None
    jwt_decoder.assert_called_once_with("my_token", require=["uid", "sid", "rt"])
125  
126  
async def test__from_access_token__logout(mocker: MockerFixture) -> None:
    """Check that ``User.from_access_token`` returns ``None`` when the logout key exists in redis."""
    jwt_decoder = MagicMock(return_value={"rt": "my_refresh_token"})
    mocker.patch("api.models.user.decode_jwt", jwt_decoder)

    # redis reports that the logout marker for this refresh token exists.
    redis_exists = AsyncMock(return_value=True)
    mocker.patch("api.models.user.redis.exists", redis_exists)

    resolved = await User.from_access_token("my_token")

    assert resolved is None
    jwt_decoder.assert_called_once_with("my_token", require=["uid", "sid", "rt"])
    redis_exists.assert_called_once_with("session_logout:my_refresh_token")
135  
136  
@pytest.mark.parametrize("user_exists", [True, False])
@db_wrapper
async def test__from_access_token__valid(user_exists: bool, mocker: MockerFixture) -> None:
    """Check that ``User.from_access_token`` resolves the token's ``uid`` to the matching user.

    An unrelated user is always present. The user referenced by the token
    is only created when ``user_exists`` is True; otherwise the lookup is
    expected to yield ``None``.
    """
    jwt_decoder = MagicMock(return_value={"rt": "my_refresh_token", "uid": "my_uid"})
    mocker.patch("api.models.user.decode_jwt", jwt_decoder)

    # redis reports that no logout marker exists for this refresh token.
    redis_exists = AsyncMock(return_value=False)
    mocker.patch("api.models.user.redis.exists", redis_exists)

    # A user with a different id that must never be returned.
    unrelated = await User.create("other_user_name", "other_password", True, True)
    unrelated.id = "other_uid"

    expected: User | None = None
    if user_exists:
        expected = await User.create("my_user_name", "my_password", True, True)
        expected.id = "my_uid"

    resolved = await User.from_access_token("my_token")

    jwt_decoder.assert_called_once_with("my_token", require=["uid", "sid", "rt"])
    redis_exists.assert_called_once_with("session_logout:my_refresh_token")

    assert resolved is expected
156  
157  
async def test__logout() -> None:
    """Check that ``User.logout`` calls ``logout()`` once on each of the user's sessions."""
    # A stand-in user object carrying five mocked sessions.
    fake_user = MagicMock(sessions=[AsyncMock() for _ in range(5)])

    await User.logout(fake_user)

    for mocked_session in fake_user.sessions:
        mocked_session.logout.assert_called_once_with()
165          session.logout.assert_called_once_with()