/ src / solace_agent_mesh / shared / database / base_repository.py
base_repository.py
  1  """
  2  Base repository classes that leave transaction management to the caller.
  3  
  4  This module provides base classes for repositories that follow FastAPI best practices
  5  for database session management and transaction handling.
  6  """
  7  
  8  from abc import ABC, abstractmethod
  9  from typing import Any, Generic, TypeVar
 10  
 11  from sqlalchemy import inspect
 12  from sqlalchemy.orm import Session
 13  from solace_ai_connector.common.observability import DBMonitor, MonitorLatency
 14  
 15  from ..exceptions.exceptions import EntityNotFoundError
 16  
 17  T = TypeVar("T")
 18  ModelType = TypeVar("ModelType")
 19  EntityType = TypeVar("EntityType")
 20  
 21  
 22  class BaseRepository(ABC, Generic[ModelType, EntityType]):
 23      """
 24      Abstract base class for repositories with common database operations.
 25  
 26      This base class provides common patterns for database operations
 27      without manual transaction management, following the principle that
 28      transactions should be handled at the service/API layer.
 29      """
 30  
 31      def __init__(self, model_class: type[ModelType], entity_class: type[EntityType]):
 32          """
 33          Initialize repository with model and entity classes.
 34  
 35          Args:
 36              model_class: SQLAlchemy model class
 37              entity_class: Pydantic entity class
 38          """
 39          self.model_class = model_class
 40          self.entity_class = entity_class
 41  
 42      @property
 43      @abstractmethod
 44      def entity_name(self) -> str:
 45          """Return the entity name for error messages."""
 46          pass
 47  
 48      @property
 49      def table_name(self) -> str:
 50          """Return the database table name for observability."""
 51          return self.model_class.__tablename__
 52  
 53      def create(self, session: Session, create_data: dict[str, Any]) -> EntityType:
 54          """
 55          Create a new entity.
 56  
 57          Args:
 58              session: Database session (managed externally)
 59              create_data: Data for creating the entity
 60  
 61          Returns:
 62              Created entity
 63  
 64          Note:
 65              This method does NOT commit the transaction.
 66              Commit/rollback is handled by the service layer.
 67          """
 68          with MonitorLatency(DBMonitor.insert(self.table_name)):
 69              model_instance = self.model_class(**create_data)
 70              session.add(model_instance)
 71              session.flush()  # Flush to get generated IDs
 72              session.refresh(model_instance)
 73  
 74          entity = self.entity_class.model_validate(model_instance)
 75  
 76          return entity
 77  
 78      def get_by_id(self, session: Session, entity_id: Any) -> EntityType:
 79          """
 80          Get entity by ID.
 81  
 82          Args:
 83              session: Database session
 84              entity_id: Entity identifier
 85  
 86          Returns:
 87              Entity instance
 88  
 89          Raises:
 90              EntityNotFoundError: If entity not found
 91          """
 92          with MonitorLatency(DBMonitor.query(self.table_name)):
 93              model_instance = (
 94                  session.query(self.model_class)
 95                  .filter(self.model_class.id == str(entity_id))
 96                  .first()
 97              )
 98  
 99          if not model_instance:
100              raise EntityNotFoundError(self.entity_name, entity_id)
101  
102          return self.entity_class.model_validate(model_instance)
103  
104      def get_all(
105          self, session: Session, limit: int | None = None, offset: int | None = None
106      ) -> list[EntityType]:
107          """
108          Get all entities with optional pagination.
109  
110          Args:
111              session: Database session
112              limit: Maximum number of results
113              offset: Number of results to skip
114  
115          Returns:
116              List of entities
117          """
118          with MonitorLatency(DBMonitor.query(self.table_name)):
119              query = session.query(self.model_class)
120  
121              if offset:
122                  query = query.offset(offset)
123              if limit:
124                  query = query.limit(limit)
125  
126              model_instances = query.all()
127  
128          return [
129              self.entity_class.model_validate(instance) for instance in model_instances
130          ]
131  
132      def update(
133          self, session: Session, entity_id: Any, update_data: dict[str, Any]
134      ) -> EntityType:
135          """
136          Update an entity.
137  
138          Args:
139              session: Database session
140              entity_id: Entity identifier
141              update_data: Data to update
142  
143          Returns:
144              Updated entity
145  
146          Raises:
147              EntityNotFoundError: If entity not found
148          """
149          with MonitorLatency(DBMonitor.update(self.table_name)):
150              model_instance = (
151                  session.query(self.model_class)
152                  .filter(self.model_class.id == str(entity_id))
153                  .first()
154              )
155  
156              if not model_instance:
157                  raise EntityNotFoundError(self.entity_name, entity_id)
158  
159              for key, value in update_data.items():
160                  if value is not None and hasattr(model_instance, key):
161                      setattr(model_instance, key, value)
162  
163              session.flush()  # Flush to validate constraints
164              session.refresh(model_instance)
165  
166          entity = self.entity_class.model_validate(model_instance)
167  
168          return entity
169  
170      def delete(self, session: Session, entity_id: Any) -> None:
171          """
172          Delete an entity.
173  
174          Args:
175              session: Database session
176              entity_id: Entity identifier
177  
178          Raises:
179              EntityNotFoundError: If entity not found
180          """
181          with MonitorLatency(DBMonitor.delete(self.table_name)):
182              # with_for_update() acquires a row lock before the DELETE:
183              #   - Ensures the selected row is locked against concurrent conflicting writes.
184              #   - Helps prevent concurrent sessions from inserting child rows (e.g. deployments)
185              #     that reference this entity until the transaction completes, reducing the risk of
186              #     orphaned FK rows that would block the parent DELETE on MySQL.
187              # Note: with_for_update() does not by itself bypass the identity map or force a state
188              # refresh for already-loaded instances; relationship collections are explicitly expired
189              # below before the cascade is walked. On SQLite, with_for_update() is effectively a
190              # no-op (file-level locking only), but it is safe to call on all dialects.
191              model_instance = (
192                  session.query(self.model_class)
193                  .filter(self.model_class.id == str(entity_id))
194                  .with_for_update()
195                  .first()
196              )
197  
198              if not model_instance:
199                  raise EntityNotFoundError(self.entity_name, entity_id)
200  
201              # Expire all relationship collections (uselist=True) so SQLAlchemy
202              # re-fetches them before walking the cascade. Covers same-session
203              # staleness: if a child row was added later in the same session, it may
204              # only exist in the session's pending state and not in the cached
205              # collection, causing the cascade to miss it and the FK to block the
206              # parent DELETE.
207              for rel in inspect(type(model_instance)).relationships:
208                  if rel.uselist:
209                      session.expire(model_instance, [rel.key])
210  
211              session.delete(model_instance)
212              session.flush()  # Flush to validate constraints
213  
214      def exists(self, session: Session, entity_id: Any) -> bool:
215          """
216          Check if an entity exists.
217  
218          Args:
219              session: Database session
220              entity_id: Entity identifier
221  
222          Returns:
223              True if entity exists, False otherwise
224          """
225          with MonitorLatency(DBMonitor.query(self.table_name)):
226              count = (
227                  session.query(self.model_class)
228                  .filter(self.model_class.id == str(entity_id))
229                  .count()
230              )
231  
232          return count > 0
233  
234      def count(self, session: Session) -> int:
235          """
236          Get total count of entities.
237  
238          Args:
239              session: Database session
240  
241          Returns:
242              Total number of entities
243          """
244          # Note: Cannot use decorator here since we need self.table_name dynamically
245          with MonitorLatency(DBMonitor.query(self.table_name)):
246              return session.query(self.model_class).count()
247  
248  
class PaginatedRepository(BaseRepository[ModelType, EntityType]):
    """
    Repository base for entities that are listed page by page.

    Adds nothing beyond ``BaseRepository``. Concrete subclasses are expected
    to define their own pagination methods that apply entity-specific
    filtering and ordering before LIMIT/OFFSET is applied.
    """
258  
259  
class ValidationMixin:
    """
    Mixin supplying overridable validation hooks for repositories.

    Both hooks accept any input by default and return ``None``; subclasses
    override them and raise to reject invalid data.
    """

    def validate_create_data(self, create_data: dict[str, Any]) -> None:
        """
        Check data before it is used to create an entity.

        The base implementation accepts everything.

        Args:
            create_data: Data to validate

        Raises:
            ValidationError: In overriding implementations, if validation fails
        """
        return None

    def validate_update_data(self, update_data: dict[str, Any]) -> None:
        """
        Check data before it is used to update an entity.

        The base implementation accepts everything.

        Args:
            update_data: Data to validate

        Raises:
            ValidationError: In overriding implementations, if validation fails
        """
        return None