/ restai / direct_access.py
direct_access.py
  1  from datetime import datetime, timezone
  2  from typing import Optional
  3  
  4  from fastapi import HTTPException
  5  
  6  from restai.database import DBWrapper
  7  from restai.models.databasemodels import OutputDatabase
  8  from restai.models.models import User
  9  
 10  
 11  def resolve_team_for_llm(user: User, llm_name: str, db: DBWrapper) -> Optional[int]:
 12      """Resolve which team grants the user access to an LLM.
 13  
 14      Returns:
 15          None  - admin bypass (no team_id needed)
 16          int   - team_id that grants access
 17      Raises:
 18          HTTPException 403 if no access
 19      """
 20      if user.is_admin:
 21          return None
 22  
 23      teams = db.get_teams_for_user(user.id)
 24      for team in teams:
 25          for llm in team.llms:
 26              if llm.name == llm_name:
 27                  # Check team budget
 28                  if team.budget >= 0:
 29                      spending = db.get_team_spending(team.id)
 30                      if team.budget - spending <= 0:
 31                          continue
 32                  return team.id
 33  
 34      raise HTTPException(status_code=403, detail="You do not have access to this model")
 35  
 36  
 37  def resolve_team_for_image_generator(user: User, generator_name: str, db: DBWrapper) -> Optional[int]:
 38      """Resolve which team grants the user access to an image generator."""
 39      if user.is_admin:
 40          return None
 41  
 42      teams = db.get_teams_for_user(user.id)
 43      for team in teams:
 44          gen_names = [g.generator_name for g in team.image_generators]
 45          if generator_name in gen_names:
 46              if team.budget >= 0:
 47                  spending = db.get_team_spending(team.id)
 48                  if team.budget - spending <= 0:
 49                      continue
 50              return team.id
 51  
 52      raise HTTPException(status_code=403, detail="You do not have access to this image generator")
 53  
 54  
 55  def resolve_team_for_audio_generator(user: User, generator_name: str, db: DBWrapper) -> Optional[int]:
 56      """Resolve which team grants the user access to an audio generator."""
 57      if user.is_admin:
 58          return None
 59  
 60      teams = db.get_teams_for_user(user.id)
 61      for team in teams:
 62          gen_names = [g.generator_name for g in team.audio_generators]
 63          if generator_name in gen_names:
 64              if team.budget >= 0:
 65                  spending = db.get_team_spending(team.id)
 66                  if team.budget - spending <= 0:
 67                      continue
 68              return team.id
 69  
 70      raise HTTPException(status_code=403, detail="You do not have access to this audio generator")
 71  
 72  
 73  def resolve_team_for_embedding(user: User, embedding_name: str, db: DBWrapper) -> Optional[int]:
 74      """Resolve which team grants the user access to an embedding model."""
 75      if user.is_admin:
 76          return None
 77  
 78      teams = db.get_teams_for_user(user.id)
 79      for team in teams:
 80          for emb in team.embeddings:
 81              if emb.name == embedding_name:
 82                  if team.budget >= 0:
 83                      spending = db.get_team_spending(team.id)
 84                      if team.budget - spending <= 0:
 85                          continue
 86                  return team.id
 87  
 88      raise HTTPException(status_code=403, detail="You do not have access to this embedding model")
 89  
 90  
 91  def log_direct_usage(
 92      db: DBWrapper,
 93      user_id: int,
 94      team_id: Optional[int],
 95      llm_name: str,
 96      question: str,
 97      answer: str,
 98      input_tokens: int,
 99      output_tokens: int,
100      input_cost: float,
101      output_cost: float,
102  ):
103      """Log direct access usage to OutputDatabase with project_id=NULL."""
104      entry = OutputDatabase(
105          user_id=user_id,
106          project_id=None,
107          team_id=team_id,
108          llm=llm_name,
109          question=question,
110          answer=answer,
111          date=datetime.now(timezone.utc),
112          input_tokens=input_tokens,
113          output_tokens=output_tokens,
114          input_cost=input_cost,
115          output_cost=output_cost,
116      )
117      db.db.add(entry)
118      db.db.commit()