# test_totp.py
"""Integration tests for TOTP-based two-factor authentication.

The tests share state through module-level globals (``totp_secret``,
``recovery_codes``, ``user_api_key``) that earlier tests populate, so they rely
on pytest's default in-file execution order.
"""
import random
import time

import pyotp
import pytest
from fastapi.testclient import TestClient

from restai.config import RESTAI_DEFAULT_PASSWORD
from restai.main import app


@pytest.fixture(autouse=True)
def clear_rate_limiter():
    """Clear DB-backed login rate limiter before each test."""
    from restai.database import DBWrapper
    from restai.models.databasemodels import LoginAttemptDatabase

    db = DBWrapper()
    try:
        db.db.query(LoginAttemptDatabase).delete()
        db.db.commit()
    except Exception:
        # Best effort: a failed purge is rolled back rather than failing the test.
        db.db.rollback()
    finally:
        db.db.close()


@pytest.fixture(scope="module")
def client():
    with TestClient(app) as c:
        yield c


test_username = "test_totp_user_" + str(random.randint(0, 1000000))
test_password = "totp_test_pass_123"
totp_secret = None
recovery_codes = []
user_api_key = None


@pytest.fixture(scope="module", autouse=True)
def restore_global_state(client):
    """Undo module-wide side effects even if the run stops early (e.g. ``pytest -x``).

    ``test_cleanup`` performs the same work but is skipped when an earlier
    failure aborts the session; ``enforce_2fa`` would then stay switched on and
    could leak into other test modules. Status codes are deliberately ignored
    here because ``test_cleanup`` may already have done this.
    """
    yield
    client.patch("/settings", json={"enforce_2fa": False}, auth=("admin", RESTAI_DEFAULT_PASSWORD))
    client.delete(f"/users/{test_username}", auth=("admin", RESTAI_DEFAULT_PASSWORD))


def _invalid_totp_code(secret):
    """Return a 6-digit code that is not currently valid for *secret*.

    A hard-coded "000000" occasionally *is* the valid code, which makes negative
    tests flaky. Candidates matching the previous, current or next time step are
    skipped, since servers often tolerate one step of clock drift.
    """
    totp = pyotp.TOTP(secret)
    now = time.time()
    valid = {totp.at(now, counter_offset=offset) for offset in (-1, 0, 1)}
    # Four candidates vs. at most three valid codes, so one always qualifies.
    return next(code for code in ("000000", "111111", "222222", "333333") if code not in valid)


def _begin_totp_login(c):
    """Perform the password step of login with client *c* and return the pending TOTP token."""
    login_resp = c.post("/auth/login", auth=(test_username, test_password))
    assert login_resp.status_code == 200
    data = login_resp.json()
    assert data.get("requires_totp") is True
    return data["totp_token"]


def test_setup_totp_user(client):
    """Create a test user for TOTP tests."""
    response = client.post(
        "/users",
        json={"username": test_username, "password": test_password, "admin": False, "private": False},
        auth=("admin", RESTAI_DEFAULT_PASSWORD),
    )
    assert response.status_code == 201


def test_totp_status_initially_disabled(client):
    response = client.get(
        f"/users/{test_username}/totp/status",
        auth=(test_username, test_password),
    )
    assert response.status_code == 200
    data = response.json()
    assert data["enabled"] is False


def test_totp_setup(client):
    global totp_secret, recovery_codes
    response = client.post(
        f"/users/{test_username}/totp/setup",
        json={},
        auth=(test_username, test_password),
    )
    assert response.status_code == 200
    data = response.json()
    assert "secret" in data
    assert "provisioning_uri" in data
    assert "recovery_codes" in data
    assert len(data["recovery_codes"]) == 8
    assert data["provisioning_uri"].startswith("otpauth://totp/")
    totp_secret = data["secret"]
    recovery_codes = data["recovery_codes"]


def test_totp_enable_with_invalid_code(client):
    response = client.post(
        f"/users/{test_username}/totp/enable",
        json={"code": _invalid_totp_code(totp_secret), "password": test_password},
        auth=(test_username, test_password),
    )
    assert response.status_code == 400
    assert "Invalid TOTP code" in response.json()["detail"]


def test_totp_enable_with_wrong_password(client):
    code = pyotp.TOTP(totp_secret).now()
    response = client.post(
        f"/users/{test_username}/totp/enable",
        json={"code": code, "password": "wrong_password"},
        auth=(test_username, test_password),
    )
    assert response.status_code == 403
    assert "Invalid password" in response.json()["detail"]


def test_totp_enable_with_valid_code(client):
    code = pyotp.TOTP(totp_secret).now()
    response = client.post(
        f"/users/{test_username}/totp/enable",
        json={"code": code, "password": test_password},
        auth=(test_username, test_password),
    )
    assert response.status_code == 200
    assert "enabled" in response.json()["message"].lower()


def test_totp_status_after_enable(client):
    response = client.get(
        f"/users/{test_username}/totp/status",
        auth=(test_username, test_password),
    )
    assert response.status_code == 200
    assert response.json()["enabled"] is True


def test_login_requires_totp_when_enabled():
    with TestClient(app) as c:
        response = c.post(
            "/auth/login",
            auth=(test_username, test_password),
        )
        assert response.status_code == 200
        data = response.json()
        assert data["requires_totp"] is True
        assert "totp_token" in data


def test_verify_totp_invalid_code():
    with TestClient(app) as c:
        totp_token = _begin_totp_login(c)

        response = c.post(
            "/auth/verify-totp",
            json={"token": totp_token, "code": _invalid_totp_code(totp_secret)},
        )
        assert response.status_code == 401


def test_verify_totp_valid_code():
    with TestClient(app) as c:
        totp_token = _begin_totp_login(c)

        code = pyotp.TOTP(totp_secret).now()
        response = c.post(
            "/auth/verify-totp",
            json={"token": totp_token, "code": code},
        )
        assert response.status_code == 200
        assert "Logged in" in response.json()["message"]


def test_verify_totp_expired_token():
    """A TOTP token that cannot be decoded (here: a malformed JWT) is rejected."""
    with TestClient(app) as c:
        response = c.post(
            "/auth/verify-totp",
            json={"token": "invalid.jwt.token", "code": "123456"},
        )
        assert response.status_code == 401


def test_verify_totp_recovery_code():
    with TestClient(app) as c:
        totp_token = _begin_totp_login(c)

        response = c.post(
            "/auth/verify-totp",
            json={"token": totp_token, "code": recovery_codes[0]},
        )
        assert response.status_code == 200
        assert "Recovery code consumed" in response.json()["message"]


def test_recovery_code_single_use():
    with TestClient(app) as c:
        totp_token = _begin_totp_login(c)

        # Same recovery code should fail now
        response = c.post(
            "/auth/verify-totp",
            json={"token": totp_token, "code": recovery_codes[0]},
        )
        assert response.status_code == 401


def test_totp_disable_wrong_password(client):
    response = client.post(
        f"/users/{test_username}/totp/disable",
        json={"password": "wrong_password"},
        auth=(test_username, test_password),
    )
    assert response.status_code == 403
    assert "Invalid password" in response.json()["detail"]


def test_totp_disable_with_password(client):
    response = client.post(
        f"/users/{test_username}/totp/disable",
        json={"password": test_password},
        auth=(test_username, test_password),
    )
    assert response.status_code == 200
    assert "disabled" in response.json()["message"].lower()


def test_login_normal_after_disable():
    """Uses separate client to avoid cookie pollution from login."""
    with TestClient(app) as c:
        response = c.post(
            "/auth/login",
            auth=(test_username, test_password),
        )
        assert response.status_code == 200
        data = response.json()
        assert "requires_totp" not in data or data.get("requires_totp") is not True
        assert "Logged in" in data.get("message", "")


def test_totp_setup_overwrites_previous(client):
    global totp_secret
    # Setup twice: every call must mint a fresh secret.
    resp1 = client.post(f"/users/{test_username}/totp/setup", json={}, auth=(test_username, test_password))
    assert resp1.status_code == 200
    secret1 = resp1.json()["secret"]
    resp2 = client.post(f"/users/{test_username}/totp/setup", json={}, auth=(test_username, test_password))
    assert resp2.status_code == 200
    secret2 = resp2.json()["secret"]
    assert secret1 != secret2
    totp_secret = secret2


def test_non_admin_cannot_setup_other_user(client):
    response = client.post(
        "/users/admin/totp/setup",
        json={},
        auth=(test_username, test_password),
    )
    assert response.status_code == 403


def test_enforce_only_local_users(client):
    """API key auth should work regardless of 2FA enforcement.

    This runs while enforcement is still off; the key is stored in
    ``user_api_key`` so it can be re-checked once ``enforce_2fa`` is enabled.
    """
    global user_api_key
    # Create an API key for the test user
    response = client.post(
        f"/users/{test_username}/apikeys",
        json={"description": "totp_test_key"},
        auth=(test_username, test_password),
    )
    assert response.status_code == 201
    user_api_key = response.json()["api_key"]

    # Use API key auth — should work without 2FA
    response = client.get(
        "/auth/whoami",
        headers={"Authorization": f"Bearer {user_api_key}"},
    )
    assert response.status_code == 200
    assert response.json()["username"] == test_username


def test_enforce_2fa_setting(client):
    response = client.patch(
        "/settings",
        json={"enforce_2fa": True},
        auth=("admin", RESTAI_DEFAULT_PASSWORD),
    )
    assert response.status_code == 200


def test_api_key_works_while_2fa_enforced(client):
    """With ``enforce_2fa`` on (and 2FA not enabled for the user), API-key auth still works."""
    assert user_api_key is not None, "test_enforce_only_local_users must run first"
    response = client.get(
        "/auth/whoami",
        headers={"Authorization": f"Bearer {user_api_key}"},
    )
    assert response.status_code == 200
    assert response.json()["username"] == test_username


def test_cannot_disable_when_enforced(client):
    global totp_secret
    # First enable 2FA for the user. Setup rotates the secret, so the code must be
    # derived from THIS response rather than from an earlier test's secret.
    setup = client.post(f"/users/{test_username}/totp/setup", json={}, auth=(test_username, test_password))
    assert setup.status_code == 200
    totp_secret = setup.json()["secret"]
    code = pyotp.TOTP(totp_secret).now()
    enable = client.post(
        f"/users/{test_username}/totp/enable",
        json={"code": code, "password": test_password},
        auth=(test_username, test_password),
    )
    assert enable.status_code == 200

    # Try to disable — should fail
    response = client.post(
        f"/users/{test_username}/totp/disable",
        json={"password": test_password},
        auth=(test_username, test_password),
    )
    assert response.status_code == 403
    assert "enforced" in response.json()["detail"].lower()


def test_cleanup(client):
    """Reset enforce_2fa and delete test user."""
    # Issue both calls before asserting so a failed reset doesn't skip the delete.
    settings_resp = client.patch("/settings", json={"enforce_2fa": False}, auth=("admin", RESTAI_DEFAULT_PASSWORD))
    delete_resp = client.delete(f"/users/{test_username}", auth=("admin", RESTAI_DEFAULT_PASSWORD))
    assert settings_resp.status_code == 200
    assert 200 <= delete_resp.status_code < 300