/ src / api / deps.py
deps.py
  1  """
  2  FastAPI dependencies for Ag3ntum API.
  3  
  4  Provides dependency injection for authentication, database sessions, etc.
  5  """
  6  import logging
  7  from dataclasses import dataclass, field
  8  from typing import Optional
  9  
 10  from fastapi import Depends, HTTPException, Query, Request, status
 11  from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 12  from sqlalchemy.ext.asyncio import AsyncSession
 13  
 14  from ..db.database import get_db
 15  from ..services.auth_service import auth_service, UserEnvironmentError
 16  from ..services.connection_token import validate_connection_token
 17  from ..core.sandbox_path_resolver import (
 18      configure_sandbox_path_resolver,
 19      has_sandbox_path_resolver,
 20  )
 21  
 22  logger = logging.getLogger(__name__)
 23  
 24  # HTTP Bearer authentication scheme
 25  bearer_scheme = HTTPBearer(auto_error=True)
 26  # Optional bearer for endpoints that also accept query param tokens
 27  bearer_scheme_optional = HTTPBearer(auto_error=False)
 28  
 29  
 30  async def get_current_user_id(
 31      credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
 32      db: AsyncSession = Depends(get_db),
 33  ) -> str:
 34      """
 35      Dependency that extracts and validates the JWT token.
 36  
 37      Returns the user_id from the token.
 38  
 39      Raises:
 40          HTTPException: 401 if token is invalid/expired, 403 if user environment misconfigured.
 41      """
 42      token = credentials.credentials
 43  
 44      try:
 45          user_id = await auth_service.validate_token(token, db)
 46      except UserEnvironmentError as e:
 47          # User account exists but filesystem is misconfigured
 48          # Return 403 Forbidden - user must be recreated
 49          raise HTTPException(
 50              status_code=status.HTTP_403_FORBIDDEN,
 51              detail=str(e),
 52          )
 53  
 54      if not user_id:
 55          raise HTTPException(
 56              status_code=status.HTTP_401_UNAUTHORIZED,
 57              detail="Invalid or expired token",
 58              headers={"WWW-Authenticate": "Bearer"},
 59          )
 60  
 61      return user_id
 62  
 63  
 64  async def get_proxy_caller_id(
 65      request: Request,
 66      credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme_optional),
 67      db: AsyncSession = Depends(get_db),
 68  ) -> str:
 69      """
 70      Auth dependency for the LLM proxy endpoint.
 71  
 72      Accepts two authentication methods:
 73      1. Loopback requests (127.0.0.1) with x-api-key header → returns "internal-agent"
 74         (This is how the Claude Agent SDK authenticates when ANTHROPIC_BASE_URL is set)
 75      2. Standard JWT Bearer token → falls back to get_current_user_id logic
 76  
 77      This is needed because the SDK sends x-api-key (Anthropic API auth), not
 78      JWT Bearer tokens, when making requests to the proxy endpoint.
 79      """
 80      client_host = request.client.host if request.client else None
 81      x_api_key = request.headers.get("x-api-key")
 82  
 83      # Path 1: Loopback traffic with x-api-key (internal SDK calls)
 84      if client_host == "127.0.0.1" and x_api_key:
 85          logger.info("LLM Proxy: loopback auth accepted from %s", client_host)
 86          return "internal-agent"
 87  
 88      # Path 2: Standard JWT Bearer auth
 89      if credentials and credentials.credentials:
 90          token = credentials.credentials
 91          try:
 92              user_id = await auth_service.validate_token(token, db)
 93          except UserEnvironmentError as e:
 94              raise HTTPException(
 95                  status_code=status.HTTP_403_FORBIDDEN,
 96                  detail=str(e),
 97              )
 98          if user_id:
 99              return user_id
100  
101      raise HTTPException(
102          status_code=status.HTTP_401_UNAUTHORIZED,
103          detail="Invalid or expired token",
104          headers={"WWW-Authenticate": "Bearer"},
105      )
106  
107  
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
    db: AsyncSession = Depends(get_db),
):
    """
    Authenticate the Bearer JWT and load the matching User row.

    Token validation is delegated to ``get_current_user_id`` (called directly,
    not through Depends), which applies exactly the same 403/401 rules; this
    dependency then fetches the full User object.

    Returns:
        The User object returned by ``auth_service.get_user_by_id``.

    Raises:
        HTTPException: 403 if the user environment is misconfigured; 401 if
            the token is invalid/expired or no user exists for its user_id.
    """
    user_id = await get_current_user_id(credentials, db)

    current_user = await auth_service.get_user_by_id(db, user_id)
    if current_user:
        return current_user

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="User not found",
        headers={"WWW-Authenticate": "Bearer"},
    )
146  
147  
async def require_admin(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
    db: AsyncSession = Depends(get_db),
):
    """
    Guard dependency that only lets administrators through.

    Authentication is done by calling ``get_current_user`` directly with the
    same credentials and database session.

    Returns:
        The authenticated User, whose ``role`` is ``"admin"``.

    Raises:
        HTTPException: whatever ``get_current_user`` raises when
            authentication fails, or 403 "Admin access required" when the
            user's role is anything other than ``"admin"``.
    """
    candidate = await get_current_user(credentials, db)

    if candidate.role == "admin":
        return candidate

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Admin access required",
    )
169  
170  
async def get_current_user_id_from_query_or_header(
    token: Optional[str] = Query(None, description="JWT token for authentication"),
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme_optional),
    db: AsyncSession = Depends(get_db),
) -> str:
    """
    Authenticate with a JWT taken from the Authorization header or the
    ``token`` query parameter.

    The query-parameter form exists for file downloads opened with
    ``window.open()``, which cannot attach custom headers. When both sources
    are present, the header wins.

    NOTE(review): tokens in query strings can leak into access logs, browser
    history and Referer headers; clients should prefer the header whenever
    they are able to set it.

    Returns:
        The user_id resolved from the chosen token.

    Raises:
        HTTPException: 401 if no token is supplied or it is invalid/expired,
            403 if the user environment is misconfigured.
    """
    header_token = credentials.credentials if credentials else None
    # Header first; an empty header value falls back to the query parameter.
    actual_token = header_token or token

    if not actual_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        user_id = await auth_service.validate_token(actual_token, db)
    except UserEnvironmentError as env_err:
        # Account exists but its filesystem is misconfigured; it has to be
        # recreated, so report 403 rather than 401.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(env_err),
        )

    if user_id:
        return user_id

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or expired token",
        headers={"WWW-Authenticate": "Bearer"},
    )
220  
221  
async def validate_sse_token(
    token: Optional[str],
    authorization: Optional[str],
    db: AsyncSession,
) -> str:
    """Validate an SSE connection token or JWT for SSE/polling endpoints.

    Tries a connection token first (preferred: single-use, short-lived), then
    falls back to JWT validation for backward compatibility.

    Args:
        token: Query parameter token (connection token or JWT).
        authorization: Authorization header value; used only when ``token``
            is empty and the header uses the Bearer scheme.
        db: Database session for JWT validation.

    Returns:
        The authenticated user_id.

    Raises:
        HTTPException: 401 if no token is supplied or neither validation
            accepts it; 403 if the JWT belongs to a user whose environment is
            misconfigured (the same mapping every other auth dependency uses,
            instead of letting UserEnvironmentError escape as a 500).
    """
    # Extract the token from the header only if none came as a query param.
    actual_token = token
    if not actual_token and authorization and authorization.lower().startswith("bearer "):
        # Strip so "Bearer   <token>" does not yield a token with leading spaces.
        actual_token = authorization[len("bearer "):].strip()

    if not actual_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing access token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Try connection token first (preferred for SSE)
    user_id = await validate_connection_token(actual_token)
    if user_id:
        return user_id

    # Fall back to JWT validation (backward compatibility)
    try:
        user_id = await auth_service.validate_token(actual_token, db)
    except UserEnvironmentError as e:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e),
        )
    if user_id:
        return user_id

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or expired token",
        headers={"WWW-Authenticate": "Bearer"},
    )
268  
269  
@dataclass
class AuthContext:
    """Caller identity shared by reseller/admin endpoints.

    Produced by ``get_auth_context`` either from a JWT (``api_key_id`` left as
    None) or from a reseller/admin API key (``api_key_id`` and
    ``api_key_scopes`` populated).
    """

    user_id: str
    # Expected values: "admin", "reseller", "user".
    role: str
    # Reseller the caller belongs to / acts for, when applicable.
    reseller_id: Optional[str] = None
    # Set only when the caller authenticated with an API key.
    api_key_id: Optional[str] = None
    # Scope names granted to that API key; consulted only for API-key auth.
    api_key_scopes: list = field(default_factory=list)

    @property
    def is_admin(self) -> bool:
        """True when the caller's role is exactly "admin"."""
        return self.role == "admin"

    @property
    def is_reseller(self) -> bool:
        """True when the caller's role is exactly "reseller"."""
        return self.role == "reseller"

    def has_scope(self, scope: str) -> bool:
        """Report whether the caller may use ``scope``.

        Callers without an API key (JWT auth) implicitly hold every scope;
        API-key callers need ``scope`` listed in ``api_key_scopes``.
        """
        return not self.api_key_id or scope in self.api_key_scopes
296  
297  
def _parse_api_key_scopes(raw_scopes) -> list:
    """Decode an API key's stored ``scopes`` value into a list of scope names.

    Accepts a JSON-encoded string or an already-decoded list. Fails closed:
    empty, malformed, or non-list values yield ``[]``. A bad row therefore
    neither crashes the request with a 500 nor grants scopes through
    substring matching on a bare JSON string (``"read" in "sessions:read"``).
    """
    import json

    if not raw_scopes:
        return []
    if isinstance(raw_scopes, list):
        return list(raw_scopes)
    try:
        decoded = json.loads(raw_scopes)
    except (TypeError, ValueError):
        logger.warning("API key scopes are not valid JSON; treating key as having no scopes")
        return []
    if not isinstance(decoded, list):
        logger.warning("API key scopes JSON is not a list; treating key as having no scopes")
        return []
    return decoded


async def _auth_context_from_api_key(
    request: Request,
    db: AsyncSession,
    api_key_header: str,
) -> AuthContext:
    """Authenticate an ag3_res_/ag3_adm_ API key and build its AuthContext.

    Checks run in this order: key validity via ``api_key_service.validate_key``
    (401), the key's IP allowlist (403), and the per-key rate limit (429).
    Each rejection is recorded with ``api_key_service.log_usage``. On success
    the key's last-used info is updated before the context is returned.
    """
    from ..services.api_key_service import api_key_service

    client_ip = request.client.host if request.client else "unknown"

    key = await api_key_service.validate_key(db, api_key_header)
    if not key:
        await api_key_service.log_usage(
            db, None, None, "auth_failed",
            None, client_ip, 401, error="Invalid or expired API key",
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired API key",
        )

    # Check IP allowlist
    if not api_key_service.check_ip_allowed(key, client_ip):
        await api_key_service.log_usage(
            db, key.id, key.reseller_id, "ip_denied",
            key.user_id, client_ip, 403, error="IP not in allowlist",
        )
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="IP address not in allowlist",
        )

    # Check per-key rate limit
    from ..services.api_key_rate_limiter import check_api_key_rate_limit
    if not await check_api_key_rate_limit(key.id, key.rate_limit_per_minute):
        await api_key_service.log_usage(
            db, key.id, key.reseller_id, "rate_limited",
            key.user_id, client_ip, 429, error="Rate limit exceeded",
        )
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="API key rate limit exceeded",
        )

    # Update last used
    await api_key_service.update_last_used(db, key.id, client_ip)

    # Role comes from the key prefix. NOTE(review): this assumes validate_key
    # matched the full key string, prefix included, so the prefix cannot be
    # swapped by the client -- confirm in api_key_service.
    role = "admin" if api_key_header.startswith("ag3_adm_") else "reseller"

    return AuthContext(
        user_id=key.user_id,
        role=role,
        reseller_id=key.reseller_id,
        api_key_id=key.id,
        api_key_scopes=_parse_api_key_scopes(key.scopes),
    )


async def get_auth_context(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme_optional),
    db: AsyncSession = Depends(get_db),
) -> AuthContext:
    """Unified auth: accepts JWT Bearer OR API key in X-API-Key header.

    API key (checked first): only values prefixed ``ag3_res_`` or ``ag3_adm_``
    take this path; they are validated via api_key_service, and the context
    carries the key's scopes and reseller_id. Any other X-API-Key value is
    ignored and JWT auth is attempted instead.

    JWT: extracts the user from the token, takes the role from User.role and,
    for role "reseller", the reseller_id stored on the User row.

    Raises:
        HTTPException: 401 for missing/invalid credentials or unknown user,
            403 for IP-allowlist denial or a misconfigured user environment,
            429 when the API key's rate limit is exceeded.
    """
    # Try API key first (X-API-Key header)
    api_key_header = request.headers.get("x-api-key")
    if api_key_header and api_key_header.startswith(("ag3_res_", "ag3_adm_")):
        return await _auth_context_from_api_key(request, db, api_key_header)

    # Fall back to JWT Bearer auth
    if not credentials or not credentials.credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        user_id = await auth_service.validate_token(credentials.credentials, db)
    except UserEnvironmentError as e:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(e))

    if not user_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user = await auth_service.get_user_by_id(db, user_id)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # reseller_id is stored directly on the User row (set during reseller creation)
    reseller_id = user.reseller_id if user.role == "reseller" else None

    return AuthContext(
        user_id=user.id,
        role=user.role,
        reseller_id=reseller_id,
    )
407  
408  
async def require_reseller(
    auth: AuthContext = Depends(get_auth_context),
) -> AuthContext:
    """Let resellers through, plus admins for override access.

    Raises:
        HTTPException: 403 "Reseller access required" for any other role.
    """
    if auth.is_reseller or auth.is_admin:
        return auth
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Reseller access required",
    )
419  
420  
def require_scope(scope: str):
    """Build a dependency that demands ``scope`` from the caller's AuthContext.

    The check is ``AuthContext.has_scope``, under which JWT-authenticated
    callers hold every scope and API-key callers need it listed on the key.

    Returns:
        An async FastAPI dependency that yields the AuthContext, or raises a
        403 "Missing required scope: <scope>" HTTPException.
    """

    async def _check(auth: AuthContext = Depends(get_auth_context)) -> AuthContext:
        if auth.has_scope(scope):
            return auth
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Missing required scope: {scope}",
        )

    return _check
433  
434  
def configure_sandbox_path_resolver_if_needed(
    session_id: str,
    username: str,
    workspace_docker: str,
) -> None:
    """
    Configure SandboxPathResolver for a session if not already configured.

    Used by the File Explorer API to configure the resolver on demand when
    accessing existing sessions after a server restart. Best-effort: any
    exception from configuration is logged as a warning and swallowed, so
    the caller always continues.

    Args:
        session_id: The session ID
        username: The username for the session
        workspace_docker: The Docker workspace path
    """
    if has_sandbox_path_resolver(session_id):
        return

    try:
        configure_sandbox_path_resolver(
            session_id=session_id,
            username=username,
            workspace_docker=workspace_docker,
        )
    except Exception as e:
        # Deliberately broad: on-demand configuration must never break the
        # request that triggered it.
        logger.warning("Failed to configure SandboxPathResolver on-demand: %s", e)
        return

    logger.info("On-demand SandboxPathResolver configured for session %s", session_id)