/ tests / test_login_rate_limit.py
test_login_rate_limit.py
 1  """Tests for login endpoint rate limiting."""
 2  import pytest
 3  from fastapi.testclient import TestClient
 4  
 5  from restai.config import RESTAI_DEFAULT_PASSWORD
 6  from restai.main import app
 7  
 8  
 9  def _clear_login_attempts():
10      from restai.database import DBWrapper
11      from restai.models.databasemodels import LoginAttemptDatabase
12      db = DBWrapper()
13      try:
14          db.db.query(LoginAttemptDatabase).delete()
15          db.db.commit()
16      except Exception:
17          db.db.rollback()
18      finally:
19          db.db.close()
20  
21  
def test_login_success():
    """The admin account authenticates with the default password (HTTP 200)."""
    _clear_login_attempts()
    credentials = ("admin", RESTAI_DEFAULT_PASSWORD)
    with TestClient(app) as http:
        response = http.post("/auth/login", auth=credentials)
        assert response.status_code == 200
28  
29  
def test_login_wrong_password():
    """A bad password for the admin account is rejected with HTTP 401."""
    _clear_login_attempts()
    bad_credentials = ("admin", "wrongpassword")
    with TestClient(app) as http:
        response = http.post("/auth/login", auth=bad_credentials)
        assert response.status_code == 401
36  
37  
def test_login_rate_limit_triggers():
    """After 10 failed attempts, the next login request is rejected with 429.

    The first 10 attempts are checked as ordinary 401 auth failures, so the
    test pins the threshold instead of also passing if the limiter trips
    early. Attempt state is cleared in ``finally`` so a failing assertion
    cannot leave the limiter tripped for later tests.
    """
    max_attempts = 10
    _clear_login_attempts()
    try:
        with TestClient(app) as client:
            for attempt in range(1, max_attempts + 1):
                resp = client.post("/auth/login", auth=("admin", "bad"))
                # Below the threshold, bad credentials are a plain auth failure.
                assert resp.status_code == 401, (
                    f"attempt {attempt} returned {resp.status_code}"
                )

            # The attempt after the threshold should be rate limited.
            resp = client.post("/auth/login", auth=("admin", "bad"))
            assert resp.status_code == 429
            assert "Too many" in resp.json()["detail"]
    finally:
        _clear_login_attempts()
51  
52  
def test_login_rate_limit_doesnt_block_valid_after_reset():
    """Once the limiter has tripped, clearing attempt state lets a valid login through.

    The limiter is tripped first (and that is asserted), so the final 200
    genuinely demonstrates recovery after a reset. Without this step the test
    would just repeat ``test_login_success``.
    """
    max_attempts = 10
    _clear_login_attempts()
    try:
        with TestClient(app) as client:
            for _ in range(max_attempts):
                client.post("/auth/login", auth=("admin", "bad"))
            blocked = client.post("/auth/login", auth=("admin", "bad"))
            assert blocked.status_code == 429
    finally:
        # Reset even if the precondition above fails, so later tests are not throttled.
        _clear_login_attempts()

    # Use a fresh client, matching how the other tests observe a cleared state.
    with TestClient(app) as client:
        resp = client.post("/auth/login", auth=("admin", RESTAI_DEFAULT_PASSWORD))
        assert resp.status_code == 200