/ database.py
database.py
import hmac
import os
from datetime import datetime
from typing import Iterator, Optional

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session

from config import load_config
from db.models import User, ApiKey, UsageRecord
  8  
  9  # Load configuration and define connection URL
 10  config = load_config()
 11  DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://user:password@localhost:5432/db_name")
 12  
 13  # Creating the SQLAlchemy engine and local session
 14  engine = create_engine(DATABASE_URL)
 15  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
 16  
 17  def get_db() -> Session:
 18      """
 19      FastAPI dependency to get a database session.
 20      """
 21      db = SessionLocal()
 22      try:
 23          yield db
 24      finally:
 25          db.close()
 26  
 27  # Add this alias for compatibility with auth_router.py
 28  get_db_connection = get_db
 29  
 30  def get_user_by_username(db: Session, username: str) -> Optional[User]:
 31      """Retrieves a user by username."""
 32      return db.query(User).filter(User.username == username).first()
 33  
 34  def get_user_by_id(db: Session, user_id: str) -> Optional[User]:
 35      """Retrieves a user by ID."""
 36      return db.query(User).filter(User.id == user_id).first()
 37  
 38  def get_api_key(db: Session, key: str) -> Optional[ApiKey]:
 39      """Retrieves information for an API key."""
 40      return db.query(ApiKey).filter(ApiKey.key == key).first()
 41  
 42  def get_api_keys_for_user(db: Session, user_id: str):
 43      """Retrieves all API keys for a user."""
 44      return db.query(ApiKey).filter(ApiKey.user_id == user_id).all()
 45  
 46  def create_user(db: Session, user: User) -> User:
 47      """Creates a new user in the database."""
 48      db.add(user)
 49      db.commit()
 50      db.refresh(user)
 51      return user
 52  
 53  def create_api_key(db: Session, api_key: ApiKey) -> ApiKey:
 54      """Creates a new API key."""
 55      db.add(api_key)
 56      db.commit()
 57      db.refresh(api_key)
 58      return api_key
 59  
 60  def update_api_key(db: Session, api_key: ApiKey) -> ApiKey:
 61      """Updates an existing API key."""
 62      db.merge(api_key)
 63      db.commit()
 64      db.refresh(api_key)
 65      return api_key
 66  
 67  def delete_api_key(db: Session, key: str) -> bool:
 68      """Deletes an API key."""
 69      api_key = get_api_key(db, key)
 70      if api_key:
 71          db.delete(api_key)
 72          db.commit()
 73          return True
 74      return False
 75  
 76  def record_api_usage(db: Session, record: UsageRecord) -> None:
 77      """Records API usage."""
 78      db.add(record)
 79      db.commit()
 80  
 81  def get_user_usage_stats(db: Session, user_id: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None) -> dict:
 82      """
 83      Retrieves usage statistics for a user.
 84      If start_date and end_date are provided, filters records accordingly.
 85      """
 86      query = db.query(UsageRecord).filter(UsageRecord.user_id == user_id)
 87      if start_date:
 88          query = query.filter(UsageRecord.timestamp >= start_date)
 89      if end_date:
 90          query = query.filter(UsageRecord.timestamp <= end_date)
 91      records = query.all()
 92      
 93      total_requests = len(records)
 94      total_tokens_input = sum(record.tokens_input for record in records)
 95      total_tokens_output = sum(record.tokens_output for record in records)
 96      avg_processing_time = sum(record.processing_time for record in records) / total_requests if total_requests else 0
 97  
 98      daily_stats = {}
 99      for record in records:
100          day = record.timestamp.strftime("%Y-%m-%d")
101          if day not in daily_stats:
102              daily_stats[day] = {
103                  "requests": 0,
104                  "tokens_input": 0,
105                  "tokens_output": 0,
106                  "processing_time": 0
107              }
108          daily_stats[day]["requests"] += 1
109          daily_stats[day]["tokens_input"] += record.tokens_input
110          daily_stats[day]["tokens_output"] += record.tokens_output
111          daily_stats[day]["processing_time"] += record.processing_time
112  
113      return {
114          "total_requests": total_requests,
115          "total_tokens_input": total_tokens_input,
116          "total_tokens_output": total_tokens_output,
117          "avg_processing_time": avg_processing_time,
118          "daily_stats": daily_stats
119      }
def get_all_users(db: Session):
    """Return every ``User`` row as a list."""
    everyone = db.query(User)
    return everyone.all()
def get_all_api_keys(db: Session):
    """Return every ``ApiKey`` row as a list."""
    every_key = db.query(ApiKey)
    return every_key.all()
def get_all_usage_records(db: Session):
    """Return every ``UsageRecord`` row as a list."""
    every_record = db.query(UsageRecord)
    return every_record.all()
def delete_user(db: Session, user_id: str) -> bool:
    """
    Delete the user with the given ID and commit.

    NOTE(review): this function does not itself touch the user's ApiKey or
    UsageRecord rows; whether they cascade depends on the model/DB schema.

    Returns:
        True if the user existed and was deleted, False if nobody matched.
    """
    target = get_user_by_id(db, user_id)
    if not target:
        return False
    db.delete(target)
    db.commit()
    return True
def delete_all_api_keys(db: Session):
    """Bulk-delete every ``ApiKey`` row, then commit."""
    all_keys = db.query(ApiKey)
    all_keys.delete()
    db.commit()
def delete_all_users(db: Session):
    """Bulk-delete every ``User`` row, then commit."""
    all_users = db.query(User)
    all_users.delete()
    db.commit()
def delete_all_usage_records(db: Session):
    """Bulk-delete every ``UsageRecord`` row, then commit."""
    all_records = db.query(UsageRecord)
    all_records.delete()
    db.commit()
def delete_all_data(db: Session):
    """
    Deletes all data from the database in a single transaction.

    Rows are removed child-first: usage records, then API keys, then users.
    ``UsageRecord`` carries ``user_id``/``api_key_id`` and ``ApiKey`` carries
    ``user_id``, so the referencing rows go before the rows they point at.
    This matters if those columns are enforced foreign keys. A single
    commit at the end means a failure part-way leaves nothing deleted.

    Raises:
        Any exception raised by the session is re-raised after rolling back.
    """
    try:
        db.query(UsageRecord).delete()
        db.query(ApiKey).delete()
        db.query(User).delete()
        db.commit()
    except Exception:
        # Undo any partial deletes before propagating the error.
        db.rollback()
        raise
def delete_all_data_with_confirmation(db: Session, confirmation_code: str) -> bool:
    """
    Deletes all data from the database if *confirmation_code* is correct.

    The supplied code is compared with ``config.confirmation_code`` using a
    constant-time comparison. The request fails closed, deleting nothing, when
    the configured code is empty or not a string, or when the supplied code
    is not a string. This means an unset or blank configured code can never
    be matched by an empty input.

    Returns:
        True if the data was deleted, False if the code was rejected.
    """
    expected = config.confirmation_code
    if not isinstance(expected, str) or not expected:
        return False
    if not isinstance(confirmation_code, str):
        return False
    # compare_digest avoids leaking how many leading characters matched.
    if not hmac.compare_digest(confirmation_code.encode("utf-8"), expected.encode("utf-8")):
        return False
    delete_all_data(db)
    return True
def get_user_by_email(db: Session, email: str) -> Optional[User]:
    """Return the first user whose ``email`` equals *email*, or ``None``."""
    by_email = db.query(User).filter(User.email == email)
    return by_email.first()
def get_api_key_by_id(db: Session, key_id: str) -> Optional[ApiKey]:
    """Return the ``ApiKey`` whose ``id`` equals *key_id*, or ``None``."""
    by_id = db.query(ApiKey).filter(ApiKey.id == key_id)
    return by_id.first()
def get_usage_records_by_user(db: Session, user_id: str):
    """Return a list of all usage records whose ``user_id`` equals *user_id*."""
    user_records = db.query(UsageRecord).filter(UsageRecord.user_id == user_id)
    return user_records.all()
def get_usage_records_by_api_key(db: Session, api_key_id: str):
    """Return a list of all usage records whose ``api_key_id`` equals *api_key_id*."""
    key_records = db.query(UsageRecord).filter(UsageRecord.api_key_id == api_key_id)
    return key_records.all()
def get_usage_records_by_date(db: Session, date: datetime):
    """
    Return usage records whose ``timestamp`` is at or after *date*.

    NOTE(review): despite the name this is an open-ended "since" filter, not
    a single-day filter; ``get_usage_records_by_date_range`` bounds both ends.
    """
    since = UsageRecord.timestamp >= date
    return db.query(UsageRecord).filter(since).all()
def get_usage_records_by_date_range(db: Session, start_date: datetime, end_date: datetime):
    """Return usage records with ``start_date <= timestamp <= end_date`` (both ends inclusive)."""
    window = (
        UsageRecord.timestamp >= start_date,
        UsageRecord.timestamp <= end_date,
    )
    return db.query(UsageRecord).filter(*window).all()
def get_usage_records_by_api_key_and_date(db: Session, api_key_id: str, date: datetime):
    """Return usage records for *api_key_id* whose ``timestamp`` is at or after *date*."""
    conditions = (
        UsageRecord.api_key_id == api_key_id,
        UsageRecord.timestamp >= date,
    )
    return db.query(UsageRecord).filter(*conditions).all()
def get_usage_records_by_api_key_and_date_range(db: Session, api_key_id: str, start_date: datetime, end_date: datetime):
    """Return usage records for *api_key_id* with ``start_date <= timestamp <= end_date``."""
    conditions = (
        UsageRecord.api_key_id == api_key_id,
        UsageRecord.timestamp >= start_date,
        UsageRecord.timestamp <= end_date,
    )
    return db.query(UsageRecord).filter(*conditions).all()
def get_usage_records_by_user_and_date(db: Session, user_id: str, date: datetime):
    """Return usage records for *user_id* whose ``timestamp`` is at or after *date*."""
    conditions = (
        UsageRecord.user_id == user_id,
        UsageRecord.timestamp >= date,
    )
    return db.query(UsageRecord).filter(*conditions).all()
def get_usage_records_by_user_and_date_range(db: Session, user_id: str, start_date: datetime, end_date: datetime):
    """Return usage records for *user_id* with ``start_date <= timestamp <= end_date``."""
    conditions = (
        UsageRecord.user_id == user_id,
        UsageRecord.timestamp >= start_date,
        UsageRecord.timestamp <= end_date,
    )
    return db.query(UsageRecord).filter(*conditions).all()
def get_usage_records_by_api_key_and_user(db: Session, api_key_id: str, user_id: str):
    """Return usage records matching both *api_key_id* and *user_id*."""
    conditions = (
        UsageRecord.api_key_id == api_key_id,
        UsageRecord.user_id == user_id,
    )
    return db.query(UsageRecord).filter(*conditions).all()
def get_usage_records_by_api_key_and_user_and_date(db: Session, api_key_id: str, user_id: str, date: datetime):
    """Return usage records for *api_key_id* and *user_id* with ``timestamp`` at or after *date*."""
    conditions = (
        UsageRecord.api_key_id == api_key_id,
        UsageRecord.user_id == user_id,
        UsageRecord.timestamp >= date,
    )
    return db.query(UsageRecord).filter(*conditions).all()
def get_usage_records_by_api_key_and_user_and_date_range(db: Session, api_key_id: str, user_id: str, start_date: datetime, end_date: datetime):
    """Return usage records for *api_key_id* and *user_id* with ``start_date <= timestamp <= end_date``."""
    conditions = (
        UsageRecord.api_key_id == api_key_id,
        UsageRecord.user_id == user_id,
        UsageRecord.timestamp >= start_date,
        UsageRecord.timestamp <= end_date,
    )
    return db.query(UsageRecord).filter(*conditions).all()
def get_usage_records_by_user_and_api_key(db: Session, user_id: str, api_key_id: str):
    """Return usage records matching both *user_id* and *api_key_id*."""
    conditions = (
        UsageRecord.user_id == user_id,
        UsageRecord.api_key_id == api_key_id,
    )
    return db.query(UsageRecord).filter(*conditions).all()
def get_usage_records_by_user_and_api_key_and_date(db: Session, user_id: str, api_key_id: str, date: datetime):
    """Return usage records for *user_id* and *api_key_id* with ``timestamp`` at or after *date*."""
    conditions = (
        UsageRecord.user_id == user_id,
        UsageRecord.api_key_id == api_key_id,
        UsageRecord.timestamp >= date,
    )
    return db.query(UsageRecord).filter(*conditions).all()