direct_access.py
from datetime import datetime, timezone
from typing import Callable, Optional

from fastapi import HTTPException

from restai.database import DBWrapper
from restai.models.databasemodels import OutputDatabase
from restai.models.models import User


def _team_has_budget(team, db: DBWrapper) -> bool:
    """Return True if ``team`` passes the budget check.

    A negative ``team.budget`` skips the spending lookup entirely, so the team
    is always eligible. Otherwise the team is eligible only while
    ``team.budget - spending`` is strictly positive, where spending comes from
    ``db.get_team_spending``.
    """
    if team.budget >= 0:
        spending = db.get_team_spending(team.id)
        exhausted = team.budget - spending <= 0
        return not exhausted
    return True


def _resolve_team(
    user: User,
    db: DBWrapper,
    grants: Callable[[object], bool],
    detail: str,
) -> Optional[int]:
    """Return the first of the user's teams that grants a resource and has budget.

    Args:
        user: The requesting user. Admins bypass team resolution.
        db: Database wrapper used to list the user's teams and their spending.
        grants: Predicate called with each team; returns True when that team
            grants the requested resource.
        detail: Message for the 403 raised when no team qualifies.

    Returns:
        None for admins (no team_id needed); otherwise the id of the first
        qualifying team, in the order returned by ``db.get_teams_for_user``.

    Raises:
        HTTPException: 403 with ``detail`` if no team both grants the resource
            and passes the budget check.
    """
    if user.is_admin:
        return None

    for team in db.get_teams_for_user(user.id):
        # Evaluate the in-memory grant first so spending is only queried for
        # teams that would actually grant access, and at most once per team.
        if grants(team) and _team_has_budget(team, db):
            return team.id

    raise HTTPException(status_code=403, detail=detail)


def resolve_team_for_llm(user: User, llm_name: str, db: DBWrapper) -> Optional[int]:
    """Resolve which team grants the user access to an LLM.

    Returns:
        None - admin bypass (no team_id needed)
        int - team_id that grants access
    Raises:
        HTTPException 403 if no access
    """
    return _resolve_team(
        user,
        db,
        lambda team: any(llm.name == llm_name for llm in team.llms),
        "You do not have access to this model",
    )


def resolve_team_for_image_generator(user: User, generator_name: str, db: DBWrapper) -> Optional[int]:
    """Resolve which team grants the user access to an image generator.

    Returns None for admins, otherwise the granting team's id.
    Raises HTTPException 403 if no team with remaining budget grants access.
    """
    return _resolve_team(
        user,
        db,
        lambda team: any(g.generator_name == generator_name for g in team.image_generators),
        "You do not have access to this image generator",
    )


def resolve_team_for_audio_generator(user: User, generator_name: str, db: DBWrapper) -> Optional[int]:
    """Resolve which team grants the user access to an audio generator.

    Returns None for admins, otherwise the granting team's id.
    Raises HTTPException 403 if no team with remaining budget grants access.
    """
    return _resolve_team(
        user,
        db,
        lambda team: any(g.generator_name == generator_name for g in team.audio_generators),
        "You do not have access to this audio generator",
    )


def resolve_team_for_embedding(user: User, embedding_name: str, db: DBWrapper) -> Optional[int]:
    """Resolve which team grants the user access to an embedding model.

    Returns None for admins, otherwise the granting team's id.
    Raises HTTPException 403 if no team with remaining budget grants access.
    """
    return _resolve_team(
        user,
        db,
        lambda team: any(emb.name == embedding_name for emb in team.embeddings),
        "You do not have access to this embedding model",
    )


def log_direct_usage(
    db: DBWrapper,
    user_id: int,
    team_id: Optional[int],
    llm_name: str,
    question: str,
    answer: str,
    input_tokens: int,
    output_tokens: int,
    input_cost: float,
    output_cost: float,
):
    """Log direct access usage to OutputDatabase with project_id=NULL.

    Builds one ``OutputDatabase`` row stamped with the current UTC time, adds
    it to ``db.db`` and commits. If adding or committing fails, the session is
    rolled back and the original exception is re-raised.
    """
    entry = OutputDatabase(
        user_id=user_id,
        project_id=None,
        team_id=team_id,
        llm=llm_name,
        question=question,
        answer=answer,
        date=datetime.now(timezone.utc),
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        input_cost=input_cost,
        output_cost=output_cost,
    )
    session = db.db
    try:
        session.add(entry)
        session.commit()
    except Exception:
        # Roll back so the shared session is usable again after a failed
        # commit. NOTE(review): assumes db.db is a SQLAlchemy-style Session
        # exposing rollback() — confirm against DBWrapper.
        session.rollback()
        raise