/ restai / routers / auth.py
auth.py
  1  from datetime import timedelta, datetime, timezone
  2  from typing import Optional
  3  import logging
  4  
  5  import jwt
  6  import pyotp
  7  from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
  8  from fastapi.responses import RedirectResponse
  9  from restai import config
 10  from restai.auth import create_access_token, get_current_username, get_current_username_admin
 11  from restai.config import RESTAI_AUTH_SECRET
 12  from restai.database import DBWrapper, get_db_wrapper
 13  from restai.models.models import User, TOTPVerifyRequest
 14  from restai.utils.crypto import decrypt_totp_secret, hash_recovery_code, verify_recovery_code
 15  import json
 16  
 17  
# Configure the root logger from the application-wide LOG_LEVEL setting.
logging.basicConfig(level=config.LOG_LEVEL)

# Router that collects the /auth/* endpoints declared in this module.
router = APIRouter()

# --- DB-backed rate limiter for auth endpoints ---
# Number of attempts one IP may record inside a window before
# _check_login_rate_limit answers with HTTP 429.
_LOGIN_MAX_ATTEMPTS = 10
# Length of the look-back window, in seconds.
_LOGIN_WINDOW_SECONDS = 300  # 5 minutes
 25  
 26  
 27  def _check_login_rate_limit(request: Request, db_wrapper: DBWrapper):
 28      from restai.models.databasemodels import LoginAttemptDatabase
 29  
 30      ip = request.client.host if request.client else "unknown"
 31      now = datetime.now(timezone.utc)
 32      cutoff = now - timedelta(seconds=_LOGIN_WINDOW_SECONDS)
 33  
 34      # Count recent attempts for this IP
 35      count = (
 36          db_wrapper.db.query(LoginAttemptDatabase)
 37          .filter(
 38              LoginAttemptDatabase.ip == ip,
 39              LoginAttemptDatabase.attempted_at > cutoff,
 40          )
 41          .count()
 42      )
 43      if count >= _LOGIN_MAX_ATTEMPTS:
 44          raise HTTPException(status_code=429, detail="Too many login attempts. Try again later.")
 45  
 46      # Record this attempt
 47      db_wrapper.db.add(LoginAttemptDatabase(ip=ip, attempted_at=now))
 48      db_wrapper.db.commit()
 49  
 50      # Periodic cleanup: delete old entries (runs ~1% of requests to avoid overhead)
 51      import random
 52      if random.random() < 0.01:
 53          db_wrapper.db.query(LoginAttemptDatabase).filter(
 54              LoginAttemptDatabase.attempted_at < cutoff
 55          ).delete()
 56          db_wrapper.db.commit()
 57  
 58  
def _rate_limit_dependency(request: Request, db_wrapper: DBWrapper = Depends(get_db_wrapper)) -> None:
    """FastAPI dependency that checks the login rate limit BEFORE authentication.

    Thin adapter around ``_check_login_rate_limit`` so the check can be
    declared with ``Depends(...)`` in a route signature. The ``login`` route
    lists it ahead of ``get_current_username``.

    Raises:
        HTTPException: 429, propagated from ``_check_login_rate_limit`` when
            the caller's IP has exhausted its attempts for the window.
    """
    _check_login_rate_limit(request, db_wrapper)
 62  
 63  
 64  def _password_age_warning(user_db, db_wrapper) -> Optional[dict]:
 65      """Return a soft warning dict when the user's password is older than
 66      the admin-configured `password_max_age_days` setting. Returns None
 67      when the feature is disabled (max_age=0), the user has no password
 68      (SSO/LDAP), or the timestamp is missing (legacy row pre-migration).
 69  
 70      Soft-only by design: we never block login on stale credentials —
 71      forcing rotation creates worse outcomes than nudging it. The UI can
 72      show a banner when this field appears in the response."""
 73      try:
 74          max_days = int(db_wrapper.get_setting_value("password_max_age_days", "0") or "0")
 75      except (TypeError, ValueError):
 76          return None
 77      if max_days <= 0 or user_db is None or user_db.password_updated_at is None:
 78          return None
 79      from datetime import datetime, timezone
 80      last = user_db.password_updated_at
 81      if last.tzinfo is None:
 82          last = last.replace(tzinfo=timezone.utc)
 83      age_days = (datetime.now(timezone.utc) - last).days
 84      if age_days < max_days:
 85          return None
 86      return {
 87          "password_age_days": age_days,
 88          "password_max_age_days": max_days,
 89          "message": (
 90              f"Your password is {age_days} days old (limit: {max_days}). "
 91              "Please change it from the Account page."
 92          ),
 93      }
 94  
 95  
 96  @router.post("/auth/login")
 97  async def login(
 98      request: Request,
 99      response: Response,
100      _rl=Depends(_rate_limit_dependency),
101      user: User = Depends(get_current_username),
102      db_wrapper: DBWrapper = Depends(get_db_wrapper),
103  ):
104      """Authenticate and receive a session cookie. If 2FA is enabled, returns a temporary token instead."""
105      # Check if user has TOTP enabled
106      user_db = db_wrapper.get_user_by_username(user.username)
107      if user_db and user_db.totp_enabled:
108          # Return temp token for TOTP verification (5 min expiry)
109          totp_token = create_access_token(
110              data={"username": user.username, "purpose": "totp_verify"},
111              expires_delta=timedelta(minutes=5),
112          )
113          return {"requires_totp": True, "totp_token": totp_token}
114  
115      # Normal login — no 2FA
116      jwt_token = create_access_token(
117          data={"username": user.username}, expires_delta=timedelta(minutes=1440)
118      )
119  
120      response.set_cookie(
121          key="restai_token",
122          value=jwt_token,
123          samesite="strict",
124          expires=86400,
125          httponly=True,
126      )
127  
128      out = {"message": "Logged in successfully."}
129      warning = _password_age_warning(user_db, db_wrapper)
130      if warning:
131          out["password_warning"] = warning
132      return out
133  
134  
def _issue_totp_session(response, username, user_db, db_wrapper, message):
    """Set the 24-hour session cookie and build the success reply for a 2FA login."""
    jwt_token = create_access_token(
        data={"username": username}, expires_delta=timedelta(minutes=1440)
    )
    response.set_cookie(
        key="restai_token", value=jwt_token,
        samesite="strict", expires=86400, httponly=True,
    )
    out = {"message": message}
    warning = _password_age_warning(user_db, db_wrapper)
    if warning:
        out["password_warning"] = warning
    return out


def _totp_code_matches(user_db, code, username) -> bool:
    """Return True when ``code`` is a currently valid TOTP code for the user.

    Decryption/verification errors are logged and reported as "no match" so
    the caller can still fall back to recovery codes.
    """
    try:
        secret = decrypt_totp_secret(user_db.totp_secret)
        # valid_window=1 also accepts the adjacent time steps (clock drift).
        return bool(pyotp.TOTP(secret).verify(code, valid_window=1))
    except Exception:
        logging.getLogger(__name__).exception(
            "TOTP verification failed unexpectedly for user %s", username
        )
        return False


def _match_recovery_code(user_db, code, username):
    """Return ``(codes, matched_hash)`` for a matching recovery code, else ``(None, None)``.

    ``codes`` is the decoded list of stored hashes. Corrupt stored data or a
    failing comparison is logged and treated as "no match".
    """
    try:
        codes = json.loads(user_db.totp_recovery_codes)
        if not isinstance(codes, list):
            raise ValueError("stored recovery codes are not a JSON list")
        for stored_hash in codes:
            if verify_recovery_code(code, stored_hash):
                return codes, stored_hash
    except Exception:
        logging.getLogger(__name__).exception(
            "Recovery code check failed unexpectedly for user %s", username
        )
    return None, None


@router.post("/auth/verify-totp")
async def verify_totp(
    request: Request,
    body: TOTPVerifyRequest,
    response: Response,
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Complete 2FA login by verifying a TOTP code or recovery code.

    ``body.token`` must be the temporary token issued by ``/auth/login``
    (``purpose == "totp_verify"``). ``body.code`` is tried as a TOTP code
    first, then against the stored recovery-code hashes; a matching recovery
    code is removed and the removal committed before the session is issued.

    Only the verification steps are treated as best-effort. Errors while
    consuming a recovery code or issuing the session are no longer swallowed
    and reported as "Invalid TOTP code"; they propagate as server errors.

    Raises:
        HTTPException: 429 from the rate limiter; 401 for an invalid/expired
            token, an account without usable TOTP configuration, or a code
            that matches neither TOTP nor a recovery code.
    """
    _check_login_rate_limit(request, db_wrapper)

    # Decode the temporary token; any failure here is an authentication error.
    try:
        data = jwt.decode(body.token, RESTAI_AUTH_SECRET, algorithms=["HS512"])
        if data.get("purpose") != "totp_verify":
            raise ValueError("Invalid token purpose")
    except Exception:
        raise HTTPException(status_code=401, detail="Invalid or expired TOTP token")

    username = data.get("username")
    user_db = db_wrapper.get_user_by_username(username)
    if user_db is None or not user_db.totp_enabled or not user_db.totp_secret:
        raise HTTPException(status_code=401, detail="Invalid TOTP configuration")

    # NOTE(review): no "last used" TOTP counter is stored in the visible code,
    # so a code may be reusable while it is still inside its validity window.
    if _totp_code_matches(user_db, body.code, username):
        return _issue_totp_session(
            response, username, user_db, db_wrapper, "Logged in successfully."
        )

    if user_db.totp_recovery_codes:
        codes, matched_code = _match_recovery_code(user_db, body.code, username)
        if matched_code is not None:
            # Consume the recovery code before granting the session; a failed
            # commit must not result in a login with the code still usable.
            codes.remove(matched_code)
            user_db.totp_recovery_codes = json.dumps(codes)
            db_wrapper.db.commit()
            return _issue_totp_session(
                response, username, user_db, db_wrapper,
                "Logged in successfully. Recovery code consumed.",
            )

    raise HTTPException(status_code=401, detail="Invalid TOTP code")
209  
210  
211  
@router.get("/auth/whoami")
async def get_whoami(
    request: Request,
    user: User = Depends(get_current_username),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Return the authenticated user's profile plus an ``impersonating`` flag.

    The profile is the stored user validated through the ``User`` model and
    dumped to a dict. ``impersonating`` is True when the request carries a
    ``restai_token_admin`` cookie.
    """
    stored_user = db_wrapper.get_user_by_username(user.username)
    profile = User.model_validate(stored_user).model_dump()
    saved_admin_cookie = request.cookies.get("restai_token_admin")
    return {**profile, "impersonating": saved_admin_cookie is not None}
223  
224  
@router.post("/auth/impersonate/{username}")
async def impersonate_user(
    request: Request,
    response: Response,
    username: str = Path(description="Username to impersonate"),
    user: User = Depends(get_current_username_admin),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Start an admin-only impersonation session for ``username``.

    The caller's current ``restai_token`` cookie (when present) is parked in
    ``restai_token_admin`` for 30 minutes so ``/auth/exit-impersonation`` can
    restore it, then ``restai_token`` is replaced with a 24-hour token for the
    target user. The action is written to the audit log.

    Raises:
        HTTPException: 404 when the target user does not exist.
    """
    if db_wrapper.get_user_by_username(username) is None:
        raise HTTPException(status_code=404, detail="User not found")

    # Park the admin's own session so it can be restored later.
    current_session = request.cookies.get("restai_token")
    if current_session:
        response.set_cookie(
            key="restai_token_admin", value=current_session,
            samesite="strict", max_age=1800, httponly=True,
        )

    # Swap the active session for one belonging to the target user.
    target_session = create_access_token(
        data={"username": username},
        expires_delta=timedelta(minutes=1440),
    )
    response.set_cookie(
        key="restai_token", value=target_session,
        samesite="strict", expires=86400, httponly=True,
    )

    from restai.audit import _log_to_db

    _log_to_db(user.username, "IMPERSONATE_START", username, 200)

    return {"message": f"Impersonating {username}", "impersonating": True}
266  
267  
def _expire_admin_cookie_headers() -> dict:
    """Return headers carrying a Set-Cookie that expires ``restai_token_admin``.

    Cookies set on the injected ``Response`` are dropped when a route raises
    ``HTTPException`` (the exception handler builds its own response), so the
    header has to travel on the exception itself.
    """
    scratch = Response()
    scratch.delete_cookie(key="restai_token_admin")
    return {"set-cookie": scratch.headers["set-cookie"]}


@router.post("/auth/exit-impersonation")
async def exit_impersonation(
    request: Request,
    response: Response,
    _: User = Depends(get_current_username),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Exit impersonation and restore the admin session.

    Reads the token parked in the ``restai_token_admin`` cookie, checks that
    it is a valid session JWT belonging to an admin user, moves it back into
    ``restai_token`` and removes ``restai_token_admin``. The action is written
    to the audit log.

    Raises:
        HTTPException: 400 when no admin cookie is present, or when the
            parked token is invalid, purpose-scoped, or does not belong to an
            admin. In the latter cases the response also expires the stale
            ``restai_token_admin`` cookie.
    """
    admin_token = request.cookies.get("restai_token_admin")
    if not admin_token:
        raise HTTPException(status_code=400, detail="Not currently impersonating")

    # Validate the admin token is a valid JWT
    try:
        data = jwt.decode(admin_token, RESTAI_AUTH_SECRET, algorithms=["HS512"])
    except jwt.PyJWTError:
        raise HTTPException(
            status_code=400,
            detail="Invalid admin token",
            headers=_expire_admin_cookie_headers(),
        ) from None

    # Purpose-scoped tokens (e.g. the "totp_verify" token issued before the
    # second factor) are signed with the same secret; they must never be
    # promoted to a full session through this endpoint.
    is_session_token = data.get("purpose") is None

    # Verify the token belongs to an admin user
    admin_user = (
        db_wrapper.get_user_by_username(data.get("username", ""))
        if is_session_token
        else None
    )
    if admin_user is None or not admin_user.is_admin:
        raise HTTPException(
            status_code=400,
            detail="Invalid admin token",
            headers=_expire_admin_cookie_headers(),
        )

    # Restore admin token
    response.set_cookie(
        key="restai_token",
        value=admin_token,
        samesite="strict",
        expires=86400,
        httponly=True,
    )
    response.delete_cookie(key="restai_token_admin")

    from restai.audit import _log_to_db
    _log_to_db(admin_user.username, "IMPERSONATE_END", "", 200)

    return {"message": "Impersonation ended", "impersonating": False}
307  
308  
@router.post("/auth/logout")
async def logout(
    request: Request, response: Response, user: User = Depends(get_current_username)
):
    """Clear the session cookies and log out.

    Removes ``restai_token`` and also ``restai_token_admin``. Leaving the
    parked admin token behind would let whoever logs in next from the same
    browser call ``/auth/exit-impersonation`` and take over the admin session.
    """
    response.delete_cookie(key="restai_token")
    # Expiring a cookie the browser does not hold is harmless.
    response.delete_cookie(key="restai_token_admin")

    return {"message": "Logged out successfully."}
315  
316      return {"message": "Logged out successfully."}