/ tests / test_totp.py
test_totp.py
  1  import random
  2  import pyotp
  3  import pytest
  4  from fastapi.testclient import TestClient
  5  
  6  from restai.config import RESTAI_DEFAULT_PASSWORD
  7  from restai.main import app
  8  
  9  
 10  @pytest.fixture(autouse=True)
 11  def clear_rate_limiter():
 12      """Clear DB-backed login rate limiter before each test."""
 13      from restai.database import DBWrapper
 14      from restai.models.databasemodels import LoginAttemptDatabase
 15      db = DBWrapper()
 16      try:
 17          db.db.query(LoginAttemptDatabase).delete()
 18          db.db.commit()
 19      except Exception:
 20          db.db.rollback()
 21      finally:
 22          db.db.close()
 23  
 24  
 25  @pytest.fixture(scope="module")
 26  def client():
 27      with TestClient(app) as c:
 28          yield c
 29  
 30  
 31  test_username = "test_totp_user_" + str(random.randint(0, 1000000))
 32  test_password = "totp_test_pass_123"
 33  totp_secret = None
 34  recovery_codes = []
 35  
 36  
 37  def test_setup_totp_user(client):
 38      """Create a test user for TOTP tests."""
 39      response = client.post(
 40          "/users",
 41          json={"username": test_username, "password": test_password, "admin": False, "private": False},
 42          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 43      )
 44      assert response.status_code == 201
 45  
 46  
 47  def test_totp_status_initially_disabled(client):
 48      response = client.get(
 49          f"/users/{test_username}/totp/status",
 50          auth=(test_username, test_password),
 51      )
 52      assert response.status_code == 200
 53      data = response.json()
 54      assert data["enabled"] is False
 55  
 56  
 57  def test_totp_setup(client):
 58      global totp_secret, recovery_codes
 59      response = client.post(
 60          f"/users/{test_username}/totp/setup",
 61          json={},
 62          auth=(test_username, test_password),
 63      )
 64      assert response.status_code == 200
 65      data = response.json()
 66      assert "secret" in data
 67      assert "provisioning_uri" in data
 68      assert "recovery_codes" in data
 69      assert len(data["recovery_codes"]) == 8
 70      assert data["provisioning_uri"].startswith("otpauth://totp/")
 71      totp_secret = data["secret"]
 72      recovery_codes = data["recovery_codes"]
 73  
 74  
 75  def test_totp_enable_with_invalid_code(client):
 76      response = client.post(
 77          f"/users/{test_username}/totp/enable",
 78          json={"code": "000000", "password": test_password},
 79          auth=(test_username, test_password),
 80      )
 81      assert response.status_code == 400
 82      assert "Invalid TOTP code" in response.json()["detail"]
 83  
 84  
 85  def test_totp_enable_with_wrong_password(client):
 86      code = pyotp.TOTP(totp_secret).now()
 87      response = client.post(
 88          f"/users/{test_username}/totp/enable",
 89          json={"code": code, "password": "wrong_password"},
 90          auth=(test_username, test_password),
 91      )
 92      assert response.status_code == 403
 93      assert "Invalid password" in response.json()["detail"]
 94  
 95  
 96  def test_totp_enable_with_valid_code(client):
 97      code = pyotp.TOTP(totp_secret).now()
 98      response = client.post(
 99          f"/users/{test_username}/totp/enable",
100          json={"code": code, "password": test_password},
101          auth=(test_username, test_password),
102      )
103      assert response.status_code == 200
104      assert "enabled" in response.json()["message"].lower()
105  
106  
def test_totp_status_after_enable(client):
    """Once enabled, the status endpoint reports TOTP as on."""
    status_url = f"/users/{test_username}/totp/status"
    status = client.get(status_url, auth=(test_username, test_password))
    assert status.status_code == 200
    is_enabled = status.json()["enabled"]
    assert is_enabled is True
114  
115  
def test_login_requires_totp_when_enabled():
    """For a TOTP-enabled user, password login stops at a TOTP challenge.

    A fresh client is used instead of the shared fixture, presumably so that
    cookies set by this login do not leak into other tests.
    """
    with TestClient(app) as fresh:
        login = fresh.post("/auth/login", auth=(test_username, test_password))
        assert login.status_code == 200
        body = login.json()
        assert body["requires_totp"] is True
        assert "totp_token" in body
126  
127  
def test_verify_totp_invalid_code():
    """Completing the TOTP challenge with a wrong code is rejected with 401."""
    with TestClient(app) as fresh:
        challenge = fresh.post(
            "/auth/login", auth=(test_username, test_password)
        ).json()["totp_token"]

        verify = fresh.post(
            "/auth/verify-totp",
            json={"token": challenge, "code": "000000"},
        )
        assert verify.status_code == 401
138  
139  
def test_verify_totp_valid_code():
    """Completing the TOTP challenge with the current code logs the user in."""
    with TestClient(app) as fresh:
        challenge = fresh.post(
            "/auth/login", auth=(test_username, test_password)
        ).json()["totp_token"]

        current_code = pyotp.TOTP(totp_secret).now()
        verify = fresh.post(
            "/auth/verify-totp",
            json={"token": challenge, "code": current_code},
        )
        assert verify.status_code == 200
        message = verify.json()["message"]
        assert "Logged in" in message
152  
153  
def test_verify_totp_expired_token():
    """An unusable challenge token is rejected with 401.

    NOTE(review): despite the name, the token used here is malformed
    ("invalid.jwt.token"), not a genuinely expired one.
    """
    bogus = {"token": "invalid.jwt.token", "code": "123456"}
    with TestClient(app) as fresh:
        verify = fresh.post("/auth/verify-totp", json=bogus)
        assert verify.status_code == 401
161  
162  
def test_verify_totp_recovery_code():
    """A recovery code can stand in for a TOTP code and is reported as consumed."""
    first_recovery = recovery_codes[0]
    with TestClient(app) as fresh:
        challenge = fresh.post(
            "/auth/login", auth=(test_username, test_password)
        ).json()["totp_token"]

        verify = fresh.post(
            "/auth/verify-totp",
            json={"token": challenge, "code": first_recovery},
        )
        assert verify.status_code == 200
        assert "Recovery code consumed" in verify.json()["message"]
174  
175  
def test_recovery_code_single_use():
    """Redeeming the already-consumed recovery code again is rejected (401)."""
    with TestClient(app) as fresh:
        challenge = fresh.post(
            "/auth/login", auth=(test_username, test_password)
        ).json()["totp_token"]

        # recovery_codes[0] was used up by test_verify_totp_recovery_code.
        reuse = fresh.post(
            "/auth/verify-totp",
            json={"token": challenge, "code": recovery_codes[0]},
        )
        assert reuse.status_code == 401
187  
188  
def test_totp_disable_wrong_password(client):
    """Disabling TOTP requires the account password; a wrong one gets 403."""
    disable_url = f"/users/{test_username}/totp/disable"
    rejected = client.post(
        disable_url,
        json={"password": "wrong_password"},
        auth=(test_username, test_password),
    )
    assert rejected.status_code == 403
    assert "Invalid password" in rejected.json()["detail"]
197  
198  
def test_totp_disable_with_password(client):
    """The correct password turns TOTP off."""
    disable_url = f"/users/{test_username}/totp/disable"
    accepted = client.post(
        disable_url,
        json={"password": test_password},
        auth=(test_username, test_password),
    )
    assert accepted.status_code == 200
    message = accepted.json()["message"]
    assert "disabled" in message.lower()
207  
208  
def test_login_normal_after_disable():
    """After TOTP is disabled, a password login completes in a single step.

    Uses a separate client so this login's cookies don't pollute the shared one.
    """
    with TestClient(app) as fresh:
        login = fresh.post("/auth/login", auth=(test_username, test_password))
        assert login.status_code == 200
        body = login.json()
        # A missing key and an explicit non-True value are both acceptable.
        assert body.get("requires_totp") is not True
        assert "Logged in" in body.get("message", "")
220  
221  
def test_totp_setup_overwrites_previous(client):
    """Each call to the setup endpoint issues a brand-new TOTP secret.

    The most recent secret is published in the module-level ``totp_secret``.
    """
    global totp_secret
    setup_url = f"/users/{test_username}/totp/setup"
    creds = (test_username, test_password)

    first = client.post(setup_url, json={}, auth=creds)
    # Check the status before reading the body, so that a failed call shows
    # up as an HTTP error instead of an opaque KeyError on "secret".
    assert first.status_code == 200
    second = client.post(setup_url, json={}, auth=creds)
    assert second.status_code == 200

    secret1 = first.json()["secret"]
    secret2 = second.json()["secret"]
    assert secret1 != secret2
    totp_secret = secret2
231  
232  
def test_non_admin_cannot_setup_other_user(client):
    """A regular user is forbidden (403) from starting TOTP setup for admin."""
    someone_elses_setup = "/users/admin/totp/setup"
    forbidden = client.post(
        someone_elses_setup, json={}, auth=(test_username, test_password)
    )
    assert forbidden.status_code == 403
240  
241  
def test_enforce_only_local_users(client):
    """API-key (Bearer) auth should work regardless of 2FA enforcement.

    NOTE(review): this runs before test_enforce_2fa_setting, and TOTP is
    disabled for the user at this point, so enforcement is not actually
    active while this test runs.
    """
    # Mint an API key for the test user.
    created = client.post(
        f"/users/{test_username}/apikeys",
        json={"description": "totp_test_key"},
        auth=(test_username, test_password),
    )
    assert created.status_code == 201
    bearer = {"Authorization": f"Bearer {created.json()['api_key']}"}

    # Authenticate with the key alone; no TOTP step is involved.
    whoami = client.get("/auth/whoami", headers=bearer)
    assert whoami.status_code == 200
    assert whoami.json()["username"] == test_username
260  
261  
def test_enforce_2fa_setting(client):
    """Admin switches the global ``enforce_2fa`` setting on.

    test_cleanup switches it back off at the end of the module.
    """
    admin_creds = ("admin", RESTAI_DEFAULT_PASSWORD)
    patched = client.patch("/settings", json={"enforce_2fa": True}, auth=admin_creds)
    assert patched.status_code == 200
269  
270  
def test_cannot_disable_when_enforced(client):
    """With ``enforce_2fa`` on, a user who has TOTP enabled cannot turn it off.

    Relies on test_enforce_2fa_setting having already switched enforcement on.
    """
    global totp_secret
    creds = (test_username, test_password)

    # Re-enroll. Every setup call issues a fresh secret (see
    # test_totp_setup_overwrites_previous), so the code must come from the
    # secret returned by THIS call. A previously stored secret would yield an
    # invalid code, and TOTP would silently stay disabled.
    setup = client.post(f"/users/{test_username}/totp/setup", json={}, auth=creds)
    assert setup.status_code == 200
    totp_secret = setup.json()["secret"]

    code = pyotp.TOTP(totp_secret).now()
    enable = client.post(
        f"/users/{test_username}/totp/enable",
        json={"code": code, "password": test_password},
        auth=creds,
    )
    # Without this check, the disable attempt below could pass against a
    # user who never actually had TOTP enabled.
    assert enable.status_code == 200

    # Try to disable; enforcement should block it.
    response = client.post(
        f"/users/{test_username}/totp/disable",
        json={"password": test_password},
        auth=creds,
    )
    assert response.status_code == 403
    assert "enforced" in response.json()["detail"].lower()
285  
286  
def test_cleanup(client):
    """Teardown: turn ``enforce_2fa`` back off, then delete the test user.

    Responses are not checked; this is a best-effort reset.
    """
    admin_creds = ("admin", RESTAI_DEFAULT_PASSWORD)
    client.patch("/settings", json={"enforce_2fa": False}, auth=admin_creds)
    client.delete(f"/users/{test_username}", auth=admin_creds)