base_repository.py
"""
Base repository classes with proper transaction management.

This module provides base classes for repositories that follow FastAPI best practices
for database session management and transaction handling.
"""

from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar

from sqlalchemy import inspect
from sqlalchemy.orm import Query, Session
from solace_ai_connector.common.observability import DBMonitor, MonitorLatency

from ..exceptions.exceptions import EntityNotFoundError

T = TypeVar("T")
ModelType = TypeVar("ModelType")
EntityType = TypeVar("EntityType")


class BaseRepository(ABC, Generic[ModelType, EntityType]):
    """
    Abstract base class for repositories with common database operations.

    This base class provides common patterns for database operations
    without manual transaction management, following the principle that
    transactions should be handled at the service/API layer. Methods may
    ``flush()`` the session to surface generated values and constraint
    errors early, but none of them commit or roll back.

    Every operation is timed with ``MonitorLatency`` keyed on ``table_name``.
    A method decorator is not used for this because the table name is only
    known per repository instance.

    Subclasses must implement ``entity_name``. The generic operations assume
    the mapped model exposes an ``id`` attribute (compared against
    ``str(entity_id)``) and a ``__tablename__``, and that ``entity_class``
    provides ``model_validate``.
    """

    def __init__(self, model_class: type[ModelType], entity_class: type[EntityType]):
        """
        Initialize repository with model and entity classes.

        Args:
            model_class: SQLAlchemy model class
            entity_class: Pydantic entity class
        """
        self.model_class = model_class
        self.entity_class = entity_class

    @property
    @abstractmethod
    def entity_name(self) -> str:
        """Return the entity name for error messages."""
        pass

    @property
    def table_name(self) -> str:
        """Return the database table name for observability."""
        return self.model_class.__tablename__

    def _query_by_id(self, session: Session, entity_id: Any) -> Query:
        """Build a query selecting the row whose ``id`` equals ``str(entity_id)``."""
        # The id is always compared as a string; presumably the ``id`` column
        # is a string type (e.g. textual UUIDs) — confirm against the models.
        return session.query(self.model_class).filter(
            self.model_class.id == str(entity_id)
        )

    def create(self, session: Session, create_data: dict[str, Any]) -> EntityType:
        """
        Create a new entity.

        Args:
            session: Database session (managed externally)
            create_data: Data for creating the entity

        Returns:
            Created entity

        Note:
            This method does NOT commit the transaction.
            Commit/rollback is handled by the service layer.
        """
        with MonitorLatency(DBMonitor.insert(self.table_name)):
            model_instance = self.model_class(**create_data)
            session.add(model_instance)
            session.flush()  # Flush to get generated IDs
            # Reload server-side defaults/generated columns before conversion.
            session.refresh(model_instance)

            entity = self.entity_class.model_validate(model_instance)

        return entity

    def get_by_id(self, session: Session, entity_id: Any) -> EntityType:
        """
        Get entity by ID.

        Args:
            session: Database session
            entity_id: Entity identifier (compared as ``str(entity_id)``)

        Returns:
            Entity instance

        Raises:
            EntityNotFoundError: If entity not found
        """
        with MonitorLatency(DBMonitor.query(self.table_name)):
            model_instance = self._query_by_id(session, entity_id).first()

        if model_instance is None:
            raise EntityNotFoundError(self.entity_name, entity_id)

        return self.entity_class.model_validate(model_instance)

    def get_all(
        self, session: Session, limit: int | None = None, offset: int | None = None
    ) -> list[EntityType]:
        """
        Get all entities with optional pagination.

        No ordering is applied, so page contents depend on the database's
        natural row order.

        Args:
            session: Database session
            limit: Maximum number of results; ``None`` means no limit
                (``0`` returns an empty list)
            offset: Number of results to skip; ``None`` or ``0`` skips nothing

        Returns:
            List of entities
        """
        with MonitorLatency(DBMonitor.query(self.table_name)):
            query = session.query(self.model_class)

            # An offset of 0 is a no-op, so a truthiness check is sufficient.
            if offset:
                query = query.offset(offset)
            # Compare against None: an explicit limit=0 must not be treated
            # as "no limit" and return every row.
            if limit is not None:
                query = query.limit(limit)

            model_instances = query.all()

        return [
            self.entity_class.model_validate(instance) for instance in model_instances
        ]

    def update(
        self, session: Session, entity_id: Any, update_data: dict[str, Any]
    ) -> EntityType:
        """
        Update an entity.

        Only keys that are attributes of the model and whose value is not
        ``None`` are applied. ``None`` is treated as "not provided", so this
        method cannot set a column to NULL. Unknown keys are silently ignored.

        Args:
            session: Database session
            entity_id: Entity identifier (compared as ``str(entity_id)``)
            update_data: Data to update

        Returns:
            Updated entity

        Raises:
            EntityNotFoundError: If entity not found
        """
        with MonitorLatency(DBMonitor.update(self.table_name)):
            model_instance = self._query_by_id(session, entity_id).first()

            if model_instance is None:
                raise EntityNotFoundError(self.entity_name, entity_id)

            for key, value in update_data.items():
                if value is not None and hasattr(model_instance, key):
                    setattr(model_instance, key, value)

            session.flush()  # Flush to validate constraints
            session.refresh(model_instance)

            entity = self.entity_class.model_validate(model_instance)

        return entity

    def delete(self, session: Session, entity_id: Any) -> None:
        """
        Delete an entity.

        The row is selected with ``FOR UPDATE`` before deletion, and the
        session is flushed so constraint violations surface here rather than
        at commit time. The transaction is not committed.

        Args:
            session: Database session
            entity_id: Entity identifier (compared as ``str(entity_id)``)

        Raises:
            EntityNotFoundError: If entity not found
        """
        with MonitorLatency(DBMonitor.delete(self.table_name)):
            # with_for_update() acquires a row lock before the DELETE:
            # - Ensures the selected row is locked against concurrent conflicting writes.
            # - Helps prevent concurrent sessions from inserting child rows (e.g. deployments)
            #   that reference this entity until the transaction completes, reducing the risk of
            #   orphaned FK rows that would block the parent DELETE on MySQL.
            # Note: with_for_update() does not by itself bypass the identity map or force a state
            # refresh for already-loaded instances; relationship collections are explicitly expired
            # below before the cascade is walked. On SQLite, with_for_update() is effectively a
            # no-op (file-level locking only), but it is safe to call on all dialects.
            model_instance = (
                self._query_by_id(session, entity_id).with_for_update().first()
            )

            if model_instance is None:
                raise EntityNotFoundError(self.entity_name, entity_id)

            # Expire all relationship collections (uselist=True) so SQLAlchemy
            # re-fetches them before walking the cascade. Covers same-session
            # staleness: a child row added later in the same session may not be
            # reflected in an already-loaded collection, causing the cascade to
            # miss it and the FK to block the parent DELETE.
            for rel in inspect(type(model_instance)).relationships:
                if rel.uselist:
                    session.expire(model_instance, [rel.key])

            session.delete(model_instance)
            session.flush()  # Flush to validate constraints

    def exists(self, session: Session, entity_id: Any) -> bool:
        """
        Check if an entity exists.

        Args:
            session: Database session
            entity_id: Entity identifier (compared as ``str(entity_id)``)

        Returns:
            True if entity exists, False otherwise
        """
        with MonitorLatency(DBMonitor.query(self.table_name)):
            count = self._query_by_id(session, entity_id).count()

        return count > 0

    def count(self, session: Session) -> int:
        """
        Get total count of entities.

        Args:
            session: Database session

        Returns:
            Total number of entities
        """
        # Note: Cannot use decorator here since we need self.table_name dynamically
        with MonitorLatency(DBMonitor.query(self.table_name)):
            return session.query(self.model_class).count()


class PaginatedRepository(BaseRepository[ModelType, EntityType]):
    """
    Base repository with pagination support.

    Concrete repositories should implement their own pagination methods
    that apply specific filters and ordering before pagination.
    """

    pass


class ValidationMixin:
    """
    Mixin for repositories that need validation logic.

    Both hooks are no-ops by default; override them to reject invalid data.
    """

    def validate_create_data(self, create_data: dict[str, Any]) -> None:
        """
        Validate data before creation.

        The default implementation accepts all data.

        Args:
            create_data: Data to validate

        Raises:
            ValidationError: If validation fails (in overriding implementations)
        """
        pass

    def validate_update_data(self, update_data: dict[str, Any]) -> None:
        """
        Validate data before update.

        The default implementation accepts all data.

        Args:
            update_data: Data to validate

        Raises:
            ValidationError: If validation fails (in overriding implementations)
        """
        pass