/ sussro_services / services / db_service.py
db_service.py
  1  """Database service for managing database connections and sessions."""
  2  from contextlib import contextmanager
  3  from typing import Generator
  4  
  5  from sqlalchemy.orm import Session, sessionmaker
  6  
  7  from ..db.session import SessionLocal, engine
  8  
  9  # Create a session factory
 10  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
 11  
 12  
 13  def get_db() -> Generator[Session, None, None]:
 14      """Get a database session.
 15      
 16      Yields:
 17          Session: A database session
 18          
 19      Example:
 20          ```python
 21          def some_endpoint(db: Session = Depends(get_db)):
 22              # Use the database session
 23              users = db.query(User).all()
 24              return users
 25          ```
 26      """
 27      db = SessionLocal()
 28      try:
 29          yield db
 30      finally:
 31          db.close()
 32  
 33  
 34  @contextmanager
 35  def get_db_context() -> Generator[Session, None, None]:
 36      """Context manager for database sessions.
 37      
 38      Yields:
 39          Session: A database session
 40          
 41      Example:
 42          ```python
 43          with get_db_context() as db:
 44              # Use the database session
 45              users = db.query(User).all()
 46          ```
 47      """
 48      db = SessionLocal()
 49      try:
 50          yield db
 51      finally:
 52          db.close()
 53  
 54  
 55  def get_db_session() -> Session:
 56      """Get a database session that needs to be closed manually.
 57      
 58      Returns:
 59          Session: A database session
 60          
 61      Note:
 62          It's recommended to use `get_db()` or `get_db_context()` instead,
 63          as they handle session closing automatically.
 64      """
 65      return SessionLocal()
 66  
 67  
 68  def commit_changes(db: Session) -> None:
 69      """Commit changes to the database.
 70      
 71      Args:
 72          db: Database session
 73          
 74      Raises:
 75          Exception: If there's an error committing changes
 76      """
 77      try:
 78          db.commit()
 79      except Exception as e:
 80          db.rollback()
 81          raise e
 82  
 83  
 84  def refresh_object(db: Session, obj: object) -> None:
 85      """Refresh an object from the database.
 86      
 87      Args:
 88          db: Database session
 89          obj: The object to refresh
 90      """
 91      db.refresh(obj)
 92  
 93  
 94  def close_session(db: Session) -> None:
 95      """Close a database session.
 96      
 97      Args:
 98          db: Database session to close
 99      """
100      db.close()