# deps.py
"""
FastAPI dependencies for Ag3ntum API.

Provides dependency injection for authentication, database sessions, etc.
"""
import json
import logging
from dataclasses import dataclass, field
from typing import Optional

from fastapi import Depends, HTTPException, Query, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.ext.asyncio import AsyncSession

from ..db.database import get_db
from ..services.auth_service import auth_service, UserEnvironmentError
from ..services.connection_token import validate_connection_token
from ..core.sandbox_path_resolver import (
    configure_sandbox_path_resolver,
    has_sandbox_path_resolver,
)

logger = logging.getLogger(__name__)

# HTTP Bearer authentication scheme (rejects requests without credentials
# before the dependency body runs).
bearer_scheme = HTTPBearer(auto_error=True)
# Optional bearer for endpoints that also accept query param tokens
bearer_scheme_optional = HTTPBearer(auto_error=False)

# X-API-Key prefixes recognised by get_auth_context. Any other X-API-Key value
# is ignored there and JWT Bearer auth is used instead.
_RESELLER_KEY_PREFIX = "ag3_res_"
_ADMIN_KEY_PREFIX = "ag3_adm_"


async def _validate_jwt_or_raise(token: str, db: AsyncSession) -> str:
    """Validate a JWT and return its user_id, mapping failures to HTTP errors.

    Shared by every JWT-accepting dependency in this module so they all
    respond identically to bad tokens.

    Args:
        token: Raw JWT string.
        db: Database session passed through to auth_service.validate_token.

    Returns:
        The user_id the token belongs to.

    Raises:
        HTTPException: 403 if the user environment is misconfigured
            (UserEnvironmentError), 401 if the token is invalid/expired.
    """
    try:
        user_id = await auth_service.validate_token(token, db)
    except UserEnvironmentError as e:
        # User account exists but filesystem is misconfigured
        # Return 403 Forbidden - user must be recreated
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e),
        ) from e

    if not user_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return user_id


async def get_current_user_id(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
    db: AsyncSession = Depends(get_db),
) -> str:
    """
    Dependency that extracts and validates the JWT token.

    Returns the user_id from the token.

    Raises:
        HTTPException: 401 if token is invalid/expired, 403 if user environment misconfigured.
    """
    return await _validate_jwt_or_raise(credentials.credentials, db)


async def get_proxy_caller_id(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme_optional),
    db: AsyncSession = Depends(get_db),
) -> str:
    """
    Auth dependency for the LLM proxy endpoint.

    Accepts two authentication methods:
    1. Loopback requests (127.0.0.1) with x-api-key header → returns "internal-agent"
       (This is how the Claude Agent SDK authenticates when ANTHROPIC_BASE_URL is set)
    2. Standard JWT Bearer token → same validation as get_current_user_id

    This is needed because the SDK sends x-api-key (Anthropic API auth), not
    JWT Bearer tokens, when making requests to the proxy endpoint.

    Raises:
        HTTPException: 401 if neither method authenticates, 403 if the JWT
            user's environment is misconfigured.
    """
    client_host = request.client.host if request.client else None
    x_api_key = request.headers.get("x-api-key")

    # Path 1: Loopback traffic with x-api-key (internal SDK calls).
    # NOTE(review): only the *presence* of x-api-key is checked, not its value,
    # and client_host is the direct peer address. If this app is deployed
    # behind a reverse proxy on the same host, external requests may appear to
    # come from 127.0.0.1 — confirm the deployment topology. IPv6 loopback
    # (::1) is intentionally not accepted here.
    if client_host == "127.0.0.1" and x_api_key:
        logger.info("LLM Proxy: loopback auth accepted from %s", client_host)
        return "internal-agent"

    # Path 2: Standard JWT Bearer auth
    if credentials and credentials.credentials:
        return await _validate_jwt_or_raise(credentials.credentials, db)

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or expired token",
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
    db: AsyncSession = Depends(get_db),
):
    """
    Dependency that extracts, validates JWT token and returns the full User object.

    Returns the User object from the database.

    Raises:
        HTTPException: 401 if token is invalid/expired or the user no longer
            exists, 403 if user environment misconfigured.
    """
    user_id = await _validate_jwt_or_raise(credentials.credentials, db)

    user = await auth_service.get_user_by_id(db, user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return user


async def require_admin(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
    db: AsyncSession = Depends(get_db),
):
    """
    Dependency that requires admin role.

    Returns the User object if user is an admin.

    Raises:
        HTTPException: 401 if not authenticated, 403 if not admin (or if the
            user environment is misconfigured).
    """
    # Called directly (not via Depends), so pass the resolved values through.
    user = await get_current_user(credentials, db)

    if user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin access required",
        )

    return user


async def get_current_user_id_from_query_or_header(
    token: Optional[str] = Query(None, description="JWT token for authentication"),
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme_optional),
    db: AsyncSession = Depends(get_db),
) -> str:
    """
    Dependency that accepts JWT token from either:
    1. Query parameter 'token' (for file downloads via browser)
    2. Authorization header (standard Bearer token)

    This is needed for file download endpoints where window.open() cannot set headers.

    Returns the user_id from the token.

    Raises:
        HTTPException: 401 if not authenticated/invalid, 403 if user environment misconfigured.
    """
    # Prefer header token if available, fall back to query param
    if credentials and credentials.credentials:
        actual_token = credentials.credentials
    else:
        actual_token = token

    if not actual_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return await _validate_jwt_or_raise(actual_token, db)


async def validate_sse_token(
    token: Optional[str],
    authorization: Optional[str],
    db: AsyncSession,
) -> str:
    """Validate an SSE connection token or JWT for SSE/polling endpoints.

    Tries connection token first (preferred, single-use, short-lived),
    then falls back to JWT validation for backward compatibility.

    Args:
        token: Query parameter token (connection token or JWT).
        authorization: Authorization header value, e.g. "Bearer <token>".
        db: Database session for JWT validation.

    Returns:
        The authenticated user_id.

    Raises:
        HTTPException: 401 if no token is supplied or neither token kind is
            valid, 403 if the JWT user's environment is misconfigured.
    """
    # Extract token from header if not provided as query param
    actual_token = token
    bearer_prefix = "bearer "
    if (
        not actual_token
        and authorization
        and authorization.lower().startswith(bearer_prefix)
    ):
        # Strip so "Bearer   <token>" does not yield a whitespace-prefixed token.
        actual_token = authorization[len(bearer_prefix):].strip()

    if not actual_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing access token",
        )

    # Try connection token first (preferred for SSE)
    user_id = await validate_connection_token(actual_token)
    if user_id:
        return user_id

    # Fall back to JWT validation (backward compatibility). The shared helper
    # maps UserEnvironmentError to 403 like the other dependencies; previously
    # it escaped this function uncaught.
    return await _validate_jwt_or_raise(actual_token, db)


@dataclass
class AuthContext:
    """Unified authentication context for reseller/admin endpoints."""

    user_id: str
    role: str  # "admin", "reseller", "user"
    reseller_id: Optional[str] = None
    # Set only when authenticated via X-API-Key; None means JWT auth.
    api_key_id: Optional[str] = None
    # Scopes granted to the API key (decoded from the key's JSON scopes column).
    api_key_scopes: list = field(default_factory=list)

    @property
    def is_admin(self) -> bool:
        """True when the context's role is "admin"."""
        return self.role == "admin"

    @property
    def is_reseller(self) -> bool:
        """True when the context's role is "reseller"."""
        return self.role == "reseller"

    def has_scope(self, scope: str) -> bool:
        """Check if auth context has a specific API key scope.

        JWT auth (no API key) has all scopes implicitly.
        """
        if not self.api_key_id:
            return True  # JWT auth = all scopes
        return scope in self.api_key_scopes


async def _auth_context_from_api_key(
    request: Request,
    db: AsyncSession,
    api_key_header: str,
) -> AuthContext:
    """Validate a prefixed X-API-Key value and build its AuthContext.

    Every rejection is recorded via api_key_service.log_usage before raising.
    On success the key's last-used timestamp/IP is updated.

    Raises:
        HTTPException: 401 if the key is invalid/expired, 403 if the client IP
            is not in the key's allowlist, 429 if the key's rate limit is hit.
    """
    # Service imports are kept function-local as in the original layout
    # (presumably to avoid import cycles — TODO confirm).
    from ..services.api_key_service import api_key_service

    client_ip = request.client.host if request.client else "unknown"

    key = await api_key_service.validate_key(db, api_key_header)
    if not key:
        await api_key_service.log_usage(
            db, None, None, "auth_failed",
            None, client_ip, 401, error="Invalid or expired API key",
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired API key",
        )

    # Check IP allowlist
    if not api_key_service.check_ip_allowed(key, client_ip):
        await api_key_service.log_usage(
            db, key.id, key.reseller_id, "ip_denied",
            key.user_id, client_ip, 403, error="IP not in allowlist",
        )
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="IP address not in allowlist",
        )

    # Check per-key rate limit
    from ..services.api_key_rate_limiter import check_api_key_rate_limit
    if not await check_api_key_rate_limit(key.id, key.rate_limit_per_minute):
        await api_key_service.log_usage(
            db, key.id, key.reseller_id, "rate_limited",
            key.user_id, client_ip, 429, error="Rate limit exceeded",
        )
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="API key rate limit exceeded",
        )

    # Update last used
    await api_key_service.update_last_used(db, key.id, client_ip)

    scopes = json.loads(key.scopes) if key.scopes else []

    # Determine role from key prefix
    role = "admin" if api_key_header.startswith(_ADMIN_KEY_PREFIX) else "reseller"

    return AuthContext(
        user_id=key.user_id,
        role=role,
        reseller_id=key.reseller_id,
        api_key_id=key.id,
        api_key_scopes=scopes,
    )


async def get_auth_context(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme_optional),
    db: AsyncSession = Depends(get_db),
) -> AuthContext:
    """Unified auth: accepts JWT Bearer OR API key in X-API-Key header.

    For JWT: extracts user from token, determines role from User.role,
    if role=reseller, looks up reseller_id.

    For API key: validates via APIKeyService, returns context with
    key's scopes and reseller_id.

    An X-API-Key that does not start with a recognised prefix is ignored and
    JWT Bearer auth is attempted instead.

    Raises:
        HTTPException: 401 if unauthenticated/invalid, 403 on IP-allowlist
            denial or misconfigured user environment, 429 on API-key rate limit.
    """
    # Try API key first (X-API-Key header)
    api_key_header = request.headers.get("x-api-key")
    if api_key_header and api_key_header.startswith(
        (_RESELLER_KEY_PREFIX, _ADMIN_KEY_PREFIX)
    ):
        return await _auth_context_from_api_key(request, db, api_key_header)

    # Fall back to JWT Bearer auth
    if not credentials or not credentials.credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user_id = await _validate_jwt_or_raise(credentials.credentials, db)

    user = await auth_service.get_user_by_id(db, user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # reseller_id is stored directly on the User row (set during reseller creation)
    reseller_id = user.reseller_id if user.role == "reseller" else None

    return AuthContext(
        user_id=user.id,
        role=user.role,
        reseller_id=reseller_id,
    )


async def require_reseller(
    auth: AuthContext = Depends(get_auth_context),
) -> AuthContext:
    """Require reseller role (or admin for override access).

    Raises:
        HTTPException: 403 if the caller is neither a reseller nor an admin.
    """
    if auth.role not in ("reseller", "admin"):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Reseller access required",
        )
    return auth


def require_scope(scope: str):
    """Factory for scope-checking dependencies.

    Returns an async dependency that resolves the AuthContext and raises 403
    unless AuthContext.has_scope(scope) is true (always true for JWT auth).
    """

    async def _check(auth: AuthContext = Depends(get_auth_context)) -> AuthContext:
        if not auth.has_scope(scope):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Missing required scope: {scope}",
            )
        return auth

    return _check


def configure_sandbox_path_resolver_if_needed(
    session_id: str,
    username: str,
    workspace_docker: str,
) -> None:
    """
    Configure SandboxPathResolver for a session if not already configured.

    This is used by the File Explorer API to configure the resolver on-demand
    when accessing existing sessions after a server restart.

    Failures are logged as warnings and not re-raised (best-effort).

    Args:
        session_id: The session ID
        username: The username for the session
        workspace_docker: The Docker workspace path
    """
    if has_sandbox_path_resolver(session_id):
        return

    try:
        configure_sandbox_path_resolver(
            session_id=session_id,
            username=username,
            workspace_docker=workspace_docker,
        )
    except Exception as e:
        # Deliberately best-effort: callers proceed without a resolver.
        logger.warning("Failed to configure SandboxPathResolver on-demand: %s", e)
    else:
        logger.info(
            "On-demand SandboxPathResolver configured for session %s", session_id
        )