# database.py
"""Database access layer: engine/session setup and query/mutation helpers.

Provides a session dependency (``get_db``) plus helpers for the ``User``,
``ApiKey`` and ``UsageRecord`` models. Helpers that add, change or delete
rows commit the session they are given.
"""
import hmac
import os
from datetime import datetime
from typing import Any, Dict, Iterator, List, Optional

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

from db.models import User, ApiKey, UsageRecord
from config import load_config

# Load configuration and define connection URL.
# NOTE: the connection URL is read from the DATABASE_URL environment variable,
# not from ``config``; ``config`` is only used for the confirmation code below.
config = load_config()
# NOTE(review): the fallback URL embeds placeholder credentials; presumably
# intended for local development only - confirm deployments set DATABASE_URL.
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://user:password@localhost:5432/db_name")

# Creating the SQLAlchemy engine and local session factory
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def get_db() -> Iterator[Session]:
    """
    FastAPI dependency to get a database session.

    Yields a fresh session and always closes it afterwards, even if the
    request handler raised.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# Add this alias for compatibility with auth_router.py
get_db_connection = get_db


def get_user_by_username(db: Session, username: str) -> Optional[User]:
    """Return the first user with the given username, or ``None``."""
    return db.query(User).filter(User.username == username).first()


def get_user_by_id(db: Session, user_id: str) -> Optional[User]:
    """Return the user with the given ID, or ``None``."""
    return db.query(User).filter(User.id == user_id).first()


def get_api_key(db: Session, key: str) -> Optional[ApiKey]:
    """Return the API key row whose ``key`` value equals ``key``, or ``None``."""
    return db.query(ApiKey).filter(ApiKey.key == key).first()


def get_api_keys_for_user(db: Session, user_id: str) -> List[ApiKey]:
    """Return all API keys owned by ``user_id``."""
    return db.query(ApiKey).filter(ApiKey.user_id == user_id).all()


def create_user(db: Session, user: User) -> User:
    """Insert ``user``, commit, and return it refreshed from the database."""
    db.add(user)
    db.commit()
    db.refresh(user)
    return user


def create_api_key(db: Session, api_key: ApiKey) -> ApiKey:
    """Insert ``api_key``, commit, and return it refreshed from the database."""
    db.add(api_key)
    db.commit()
    db.refresh(api_key)
    return api_key


def update_api_key(db: Session, api_key: ApiKey) -> ApiKey:
    """Update an existing API key and return the persistent, refreshed row.

    ``api_key`` may be detached (for example built from request data).
    ``Session.merge`` copies its state onto the instance tracked by this
    session and returns that tracked instance, which is the object that gets
    committed, refreshed and returned. If ``api_key`` was already attached,
    ``merge`` returns the same object.
    """
    merged = db.merge(api_key)
    db.commit()
    # Refresh the merged copy: refreshing the argument itself fails when it
    # is not persistent within this session.
    db.refresh(merged)
    return merged


def delete_api_key(db: Session, key: str) -> bool:
    """Delete the API key with value ``key``.

    Returns:
        ``True`` if a key was found and deleted, ``False`` otherwise.
    """
    api_key = get_api_key(db, key)
    if api_key:
        db.delete(api_key)
        db.commit()
        return True
    return False


def record_api_usage(db: Session, record: UsageRecord) -> None:
    """Persist a single usage record and commit."""
    db.add(record)
    db.commit()


def get_user_usage_stats(
    db: Session,
    user_id: str,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
) -> dict:
    """
    Retrieves usage statistics for a user.

    Only records with ``start_date <= timestamp <= end_date`` are counted;
    either bound may be omitted.

    Returns:
        A dict with ``total_requests``, ``total_tokens_input``,
        ``total_tokens_output``, ``avg_processing_time`` (0 when there are no
        records) and ``daily_stats``. ``daily_stats`` maps a ``YYYY-MM-DD``
        day string to per-day ``requests``, ``tokens_input``,
        ``tokens_output`` and summed ``processing_time``.
    """
    query = db.query(UsageRecord).filter(UsageRecord.user_id == user_id)
    if start_date is not None:
        query = query.filter(UsageRecord.timestamp >= start_date)
    if end_date is not None:
        query = query.filter(UsageRecord.timestamp <= end_date)

    total_requests = 0
    total_tokens_input = 0
    total_tokens_output = 0
    total_processing_time = 0
    daily_stats: Dict[str, Dict[str, Any]] = {}

    # Single pass over the records instead of one pass per aggregate.
    for record in query.all():
        total_requests += 1
        total_tokens_input += record.tokens_input
        total_tokens_output += record.tokens_output
        total_processing_time += record.processing_time

        day = record.timestamp.strftime("%Y-%m-%d")
        bucket = daily_stats.setdefault(
            day,
            {
                "requests": 0,
                "tokens_input": 0,
                "tokens_output": 0,
                "processing_time": 0,
            },
        )
        bucket["requests"] += 1
        bucket["tokens_input"] += record.tokens_input
        bucket["tokens_output"] += record.tokens_output
        bucket["processing_time"] += record.processing_time

    avg_processing_time = total_processing_time / total_requests if total_requests else 0

    return {
        "total_requests": total_requests,
        "total_tokens_input": total_tokens_input,
        "total_tokens_output": total_tokens_output,
        "avg_processing_time": avg_processing_time,
        "daily_stats": daily_stats,
    }


def get_all_users(db: Session) -> List[User]:
    """Return every user."""
    return db.query(User).all()


def get_all_api_keys(db: Session) -> List[ApiKey]:
    """Return every API key."""
    return db.query(ApiKey).all()


def get_all_usage_records(db: Session) -> List[UsageRecord]:
    """Return every usage record."""
    return db.query(UsageRecord).all()


def delete_user(db: Session, user_id: str) -> bool:
    """Delete the user with ID ``user_id``.

    Returns:
        ``True`` if the user existed and was deleted, ``False`` otherwise.
    """
    # NOTE(review): whether the user's API keys / usage records are removed
    # too depends on cascade settings in db.models - confirm.
    user = get_user_by_id(db, user_id)
    if user:
        db.delete(user)
        db.commit()
        return True
    return False


def delete_all_api_keys(db: Session) -> None:
    """Bulk-delete all API keys and commit."""
    db.query(ApiKey).delete()
    db.commit()


def delete_all_users(db: Session) -> None:
    """Bulk-delete all users and commit."""
    db.query(User).delete()
    db.commit()


def delete_all_usage_records(db: Session) -> None:
    """Bulk-delete all usage records and commit."""
    db.query(UsageRecord).delete()
    db.commit()


def delete_all_data(db: Session) -> None:
    """Delete all users, API keys and usage records in a single transaction.

    Rows are removed child-first: usage records (which carry ``user_id`` and
    ``api_key_id``), then API keys (which carry ``user_id``), then users.
    This keeps the deletes valid if those columns are enforced foreign keys
    (presumably they are - confirm against db.models). On any error the
    transaction is rolled back, so no table is left partially cleared.
    """
    try:
        db.query(UsageRecord).delete()
        db.query(ApiKey).delete()
        db.query(User).delete()
        db.commit()
    except Exception:
        db.rollback()
        raise


def delete_all_data_with_confirmation(db: Session, confirmation_code: str) -> bool:
    """Delete all data if ``confirmation_code`` matches the configured code.

    The configured code is ``config.confirmation_code``. Nothing is deleted
    when no non-empty string code is configured; this prevents an empty
    setting from being matched by an empty submission.

    Returns:
        ``True`` if the codes matched and all data was deleted, else ``False``.
    """
    expected = getattr(config, "confirmation_code", None)
    if not isinstance(expected, str) or not expected:
        return False
    if not isinstance(confirmation_code, str):
        return False
    # Constant-time comparison so response timing does not reveal how much
    # of the code matched. Encode to bytes because compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(confirmation_code.encode("utf-8"), expected.encode("utf-8")):
        return False
    delete_all_data(db)
    return True


def get_user_by_email(db: Session, email: str) -> Optional[User]:
    """Return the first user with the given email address, or ``None``."""
    return db.query(User).filter(User.email == email).first()


def get_api_key_by_id(db: Session, key_id: str) -> Optional[ApiKey]:
    """Return the API key with primary key ``key_id``, or ``None``."""
    return db.query(ApiKey).filter(ApiKey.id == key_id).first()


def _usage_records(db: Session, *criteria: Any) -> List[UsageRecord]:
    """Return all usage records matching every SQL criterion in ``criteria`` (ANDed)."""
    return db.query(UsageRecord).filter(*criteria).all()


def get_usage_records_by_user(db: Session, user_id: str) -> List[UsageRecord]:
    """Return all usage records for ``user_id``."""
    return _usage_records(db, UsageRecord.user_id == user_id)


def get_usage_records_by_api_key(db: Session, api_key_id: str) -> List[UsageRecord]:
    """Return all usage records for ``api_key_id``."""
    return _usage_records(db, UsageRecord.api_key_id == api_key_id)


def get_usage_records_by_date(db: Session, date: datetime) -> List[UsageRecord]:
    """Return usage records with ``timestamp`` on or after ``date``.

    NOTE(review): despite the name, this is an open-ended lower bound, not a
    single-day filter; use the ``*_date_range`` variant for a bounded window.
    """
    return _usage_records(db, UsageRecord.timestamp >= date)


def get_usage_records_by_date_range(
    db: Session, start_date: datetime, end_date: datetime
) -> List[UsageRecord]:
    """Return usage records with ``start_date <= timestamp <= end_date``."""
    return _usage_records(
        db,
        UsageRecord.timestamp >= start_date,
        UsageRecord.timestamp <= end_date,
    )


def get_usage_records_by_api_key_and_date(
    db: Session, api_key_id: str, date: datetime
) -> List[UsageRecord]:
    """Return records for ``api_key_id`` with ``timestamp`` on or after ``date``."""
    return _usage_records(
        db,
        UsageRecord.api_key_id == api_key_id,
        UsageRecord.timestamp >= date,
    )


def get_usage_records_by_api_key_and_date_range(
    db: Session, api_key_id: str, start_date: datetime, end_date: datetime
) -> List[UsageRecord]:
    """Return records for ``api_key_id`` with ``start_date <= timestamp <= end_date``."""
    return _usage_records(
        db,
        UsageRecord.api_key_id == api_key_id,
        UsageRecord.timestamp >= start_date,
        UsageRecord.timestamp <= end_date,
    )


def get_usage_records_by_user_and_date(
    db: Session, user_id: str, date: datetime
) -> List[UsageRecord]:
    """Return records for ``user_id`` with ``timestamp`` on or after ``date``."""
    return _usage_records(
        db,
        UsageRecord.user_id == user_id,
        UsageRecord.timestamp >= date,
    )


def get_usage_records_by_user_and_date_range(
    db: Session, user_id: str, start_date: datetime, end_date: datetime
) -> List[UsageRecord]:
    """Return records for ``user_id`` with ``start_date <= timestamp <= end_date``."""
    return _usage_records(
        db,
        UsageRecord.user_id == user_id,
        UsageRecord.timestamp >= start_date,
        UsageRecord.timestamp <= end_date,
    )


def get_usage_records_by_api_key_and_user(
    db: Session, api_key_id: str, user_id: str
) -> List[UsageRecord]:
    """Return records matching both ``api_key_id`` and ``user_id``."""
    return _usage_records(
        db,
        UsageRecord.api_key_id == api_key_id,
        UsageRecord.user_id == user_id,
    )


def get_usage_records_by_api_key_and_user_and_date(
    db: Session, api_key_id: str, user_id: str, date: datetime
) -> List[UsageRecord]:
    """Return records for ``api_key_id`` and ``user_id`` on or after ``date``."""
    return _usage_records(
        db,
        UsageRecord.api_key_id == api_key_id,
        UsageRecord.user_id == user_id,
        UsageRecord.timestamp >= date,
    )


def get_usage_records_by_api_key_and_user_and_date_range(
    db: Session,
    api_key_id: str,
    user_id: str,
    start_date: datetime,
    end_date: datetime,
) -> List[UsageRecord]:
    """Return records for ``api_key_id`` and ``user_id`` within ``[start_date, end_date]``."""
    return _usage_records(
        db,
        UsageRecord.api_key_id == api_key_id,
        UsageRecord.user_id == user_id,
        UsageRecord.timestamp >= start_date,
        UsageRecord.timestamp <= end_date,
    )


def get_usage_records_by_user_and_api_key(
    db: Session, user_id: str, api_key_id: str
) -> List[UsageRecord]:
    """Return records matching both ``user_id`` and ``api_key_id``."""
    return _usage_records(
        db,
        UsageRecord.user_id == user_id,
        UsageRecord.api_key_id == api_key_id,
    )


def get_usage_records_by_user_and_api_key_and_date(
    db: Session, user_id: str, api_key_id: str, date: datetime
) -> List[UsageRecord]:
    """Return records for ``user_id`` and ``api_key_id`` on or after ``date``."""
    return _usage_records(
        db,
        UsageRecord.user_id == user_id,
        UsageRecord.api_key_id == api_key_id,
        UsageRecord.timestamp >= date,
    )