# restai/database.py
   1  import logging
   2  from sqlalchemy import create_engine, func, or_
   3  from restai import config
   4  from datetime import datetime, timezone
   5  from restai.models.databasemodels import (
   6      ApiKeyDatabase,
   7      LLMDatabase,
   8      EmbeddingDatabase,
   9      OutputDatabase,
  10      ProjectDatabase,
  11      ProjectToolDatabase,
  12      ProjectRoutineDatabase,
  13      CronLogDatabase,
  14      SettingDatabase,
  15      UserDatabase,
  16      TeamDatabase,
  17      TeamImageGeneratorDatabase,
  18      TeamAudioGeneratorDatabase,
  19      WidgetDatabase,
  20      ImageGeneratorDatabase,
  21      SpeechToTextDatabase,
  22      ProjectSecretDatabase,
  23  )
  24  from restai.models.models import (
  25      LLMModel,
  26      LLMUpdate,
  27      ProjectModelUpdate,
  28      User,
  29      UserUpdate,
  30      EmbeddingModel,
  31      EmbeddingUpdate,
  32      TeamModel,
  33      TeamModelUpdate,
  34      TeamModelCreate,
  35  )
  36  from sqlalchemy.orm import sessionmaker, Session
  37  import bcrypt
  38  from typing import Optional, List
  39  from restai.config import MYSQL_HOST, MYSQL_URL, POSTGRES_HOST, POSTGRES_URL
  40  import json
  41  from restai.utils.crypto import decrypt_api_key, hash_api_key, verify_api_key_hash
  42  
  43  import logging as _logging
  44  _db_logger = _logging.getLogger(__name__)
  45  
  46  if MYSQL_HOST:
  47      _db_logger.info("Using MySQL database.")
  48      engine = create_engine(
  49          MYSQL_URL,
  50          pool_size=config.DB_POOL_SIZE,
  51          max_overflow=config.DB_MAX_OVERFLOW,
  52          pool_recycle=config.DB_POOL_RECYCLE,
  53      )
  54  elif POSTGRES_HOST:
  55      _db_logger.info("Using PostgreSQL database.")
  56      engine = create_engine(
  57          POSTGRES_URL,
  58          pool_size=config.DB_POOL_SIZE,
  59          max_overflow=config.DB_MAX_OVERFLOW,
  60          pool_recycle=config.DB_POOL_RECYCLE,
  61      )
  62  else:
  63      _db_logger.info("Using sqlite database.")
  64      engine = create_engine(
  65          "sqlite:///./restai.db",
  66          connect_args={"check_same_thread": False},
  67          pool_size=config.DB_POOL_SIZE,
  68          max_overflow=config.DB_POOL_RECYCLE,
  69          pool_recycle=config.DB_POOL_RECYCLE,
  70      )
  71  
  72  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
  73  def hash_password(password: str) -> str:
  74      return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
  75  
  76  def verify_password(password: str, hashed: str) -> bool:
  77      return bcrypt.checkpw(password.encode('utf-8'), hashed.encode('utf-8'))
  78  
  79  
  80  class DBWrapper:
  81      __slots__ = ("db",)
  82  
  83      def __init__(self):
  84          self.db: Session = SessionLocal()
  85  
  86      def create_user(
  87          self,
  88          username: str,
  89          password: Optional[str],
  90          admin: bool = False,
  91          private: bool = False,
  92          restricted: bool = False,
  93      ) -> UserDatabase:
  94          from datetime import datetime, timezone
  95          password_hash: Optional[str]
  96          if password:
  97              password_hash = hash_password(password)
  98              password_updated_at = datetime.now(timezone.utc)
  99          else:
 100              password_hash = None
 101              password_updated_at = None
 102          db_user: UserDatabase = UserDatabase(
 103              username=username,
 104              hashed_password=password_hash,
 105              password_updated_at=password_updated_at,
 106              is_admin=admin,
 107              is_private=private,
 108              is_restricted=restricted,
 109              options='{"credit": -1.0}',
 110          )
 111          self.db.add(db_user)
 112          self.db.commit()
 113          self.db.refresh(db_user)
 114          return db_user
 115  
 116      def create_llm(
 117          self,
 118          name: str,
 119          class_name: str,
 120          options: str,
 121          privacy: str,
 122          description: str,
 123          context_window: int = 4096,
 124          input_cost: float = 0.0,
 125          output_cost: float = 0.0,
 126      ) -> LLMDatabase:
 127          # Encrypt sensitive fields (api_key) in the options JSON
 128          from restai.utils.crypto import encrypt_sensitive_options, LLM_SENSITIVE_KEYS
 129          import json as _json
 130          try:
 131              opts_dict = _json.loads(options) if isinstance(options, str) else options
 132              opts_dict = encrypt_sensitive_options(opts_dict, LLM_SENSITIVE_KEYS)
 133              options = _json.dumps(opts_dict)
 134          except Exception as e:
 135              logging.warning("Failed to encrypt LLM options: %s", e)
 136  
 137          db_llm: LLMDatabase = LLMDatabase(
 138              name=name,
 139              class_name=class_name,
 140              options=options,
 141              privacy=privacy,
 142              description=description,
 143              context_window=context_window,
 144              input_cost=input_cost,
 145              output_cost=output_cost,
 146          )
 147          self.db.add(db_llm)
 148          self.db.commit()
 149          self.db.refresh(db_llm)
 150          return db_llm
 151  
 152      def create_embedding(
 153          self,
 154          name: str,
 155          class_name: str,
 156          options: str,
 157          privacy: str,
 158          description: str,
 159          dimension: int,
 160      ) -> EmbeddingDatabase:
 161          db_embedding: EmbeddingDatabase = EmbeddingDatabase(
 162              name=name,
 163              class_name=class_name,
 164              options=options,
 165              privacy=privacy,
 166              description=description,
 167              dimension=dimension,
 168          )
 169          self.db.add(db_embedding)
 170          self.db.commit()
 171          self.db.refresh(db_embedding)
 172          return db_embedding
 173  
 174      def get_users(self) -> list[UserDatabase]:
 175          users: list[UserDatabase] = self.db.query(UserDatabase).all()
 176          return users
 177  
 178      def get_llms(self) -> list[LLMDatabase]:
 179          llms: list[LLMDatabase] = self.db.query(LLMDatabase).all()
 180          return llms
 181  
 182      def get_embeddings(self) -> list[EmbeddingDatabase]:
 183          embeddings: list[EmbeddingDatabase] = self.db.query(EmbeddingDatabase).all()
 184          return embeddings
 185  
 186      def get_llm_by_name(self, name: str) -> Optional[LLMDatabase]:
 187          llm: Optional[LLMDatabase] = (
 188              self.db.query(LLMDatabase).filter(LLMDatabase.name == name).first()
 189          )
 190          return llm
 191  
 192      def get_llm_by_id(self, id: int) -> Optional[LLMDatabase]:
 193          return self.db.query(LLMDatabase).filter(LLMDatabase.id == id).first()
 194  
 195      def get_embedding_by_name(self, name: str) -> Optional[EmbeddingDatabase]:
 196          llm: Optional[EmbeddingDatabase] = (
 197              self.db.query(EmbeddingDatabase)
 198              .filter(EmbeddingDatabase.name == name)
 199              .first()
 200          )
 201          return llm
 202  
 203      def get_embedding_by_id(self, id: int) -> Optional[EmbeddingDatabase]:
 204          return self.db.query(EmbeddingDatabase).filter(EmbeddingDatabase.id == id).first()
 205  
 206      def get_user_by_apikey(self, apikey: str):
 207          """Returns (UserDatabase, ApiKeyDatabase) or (UserDatabase, None) for legacy, or (None, None)."""
 208          # Lookup by key_prefix, then verify the salted hash
 209          prefix = apikey[:8]
 210          candidates = (
 211              self.db.query(ApiKeyDatabase)
 212              .filter(ApiKeyDatabase.key_prefix == prefix)
 213              .all()
 214          )
 215          for api_key_row in candidates:
 216              if verify_api_key_hash(apikey, api_key_row.key_hash):
 217                  return api_key_row.user, api_key_row
 218          # Fallback: check legacy api_key column for migration period
 219          for user in self.db.query(UserDatabase).filter(UserDatabase.api_key.isnot(None)):
 220              try:
 221                  if decrypt_api_key(user.api_key) == apikey:
 222                      return user, None
 223              except Exception:
 224                  continue
 225          return None, None
 226  
 227      def get_user_by_username(self, username: str) -> Optional[UserDatabase]:
 228          user: Optional[UserDatabase] = (
 229              self.db.query(UserDatabase)
 230              .filter(UserDatabase.username == username)
 231              .first()
 232          )
 233          return user
 234  
 235      def create_api_key(self, user_id: int, encrypted_key: str, key_hash: str, key_prefix: str, description: str, allowed_projects: str = None, read_only: bool = False) -> ApiKeyDatabase:
 236          api_key = ApiKeyDatabase(
 237              user_id=user_id,
 238              encrypted_key=encrypted_key,
 239              key_hash=key_hash,
 240              key_prefix=key_prefix,
 241              description=description,
 242              created_at=datetime.now(timezone.utc),
 243              allowed_projects=allowed_projects,
 244              read_only=read_only,
 245          )
 246          self.db.add(api_key)
 247          self.db.commit()
 248          self.db.refresh(api_key)
 249          return api_key
 250  
 251      def get_api_keys_for_user(self, user_id: int) -> list[ApiKeyDatabase]:
 252          return (
 253              self.db.query(ApiKeyDatabase)
 254              .filter(ApiKeyDatabase.user_id == user_id)
 255              .order_by(ApiKeyDatabase.created_at.desc())
 256              .all()
 257          )
 258  
 259      def delete_api_key(self, api_key_id: int, user_id: int) -> bool:
 260          api_key = (
 261              self.db.query(ApiKeyDatabase)
 262              .filter(ApiKeyDatabase.id == api_key_id, ApiKeyDatabase.user_id == user_id)
 263              .first()
 264          )
 265          if api_key is None:
 266              return False
 267          self.db.delete(api_key)
 268          self.db.commit()
 269          return True
 270  
 271      # ── Widget methods ─────────────────────────────────────────────────
 272  
 273      def create_widget(self, project_id, creator_id, encrypted_key, key_hash, key_prefix, name, config_json, allowed_domains_json):
 274          now = datetime.now(timezone.utc)
 275          widget = WidgetDatabase(
 276              project_id=project_id,
 277              creator_id=creator_id,
 278              encrypted_key=encrypted_key,
 279              key_hash=key_hash,
 280              key_prefix=key_prefix,
 281              name=name,
 282              config=config_json,
 283              allowed_domains=allowed_domains_json,
 284              enabled=True,
 285              created_at=now,
 286              updated_at=now,
 287          )
 288          self.db.add(widget)
 289          self.db.commit()
 290          self.db.refresh(widget)
 291          return widget
 292  
 293      def get_widget_by_id(self, widget_id):
 294          return self.db.query(WidgetDatabase).filter(WidgetDatabase.id == widget_id).first()
 295  
 296      def get_widget_by_key_hash(self, key_hash):
 297          return self.db.query(WidgetDatabase).filter(WidgetDatabase.key_hash == key_hash).first()
 298  
 299      def get_widget_by_key(self, plaintext_key):
 300          """Look up a widget by plaintext key using prefix-then-verify (salted hash)."""
 301          prefix = plaintext_key[:11]
 302          candidates = self.db.query(WidgetDatabase).filter(WidgetDatabase.key_prefix == prefix).all()
 303          for w in candidates:
 304              if verify_api_key_hash(plaintext_key, w.key_hash):
 305                  return w
 306          return None
 307  
 308      def get_widgets_for_project(self, project_id):
 309          return (
 310              self.db.query(WidgetDatabase)
 311              .filter(WidgetDatabase.project_id == project_id)
 312              .order_by(WidgetDatabase.created_at.desc())
 313              .all()
 314          )
 315  
 316      def delete_widget(self, widget):
 317          self.db.delete(widget)
 318          self.db.commit()
 319          return True
 320  
 321      def update_user(self, user: User, user_update: UserUpdate) -> bool:
 322          if user_update.password is not None:
 323              from datetime import datetime, timezone
 324              user.hashed_password = hash_password(user_update.password)
 325              user.password_updated_at = datetime.now(timezone.utc)
 326  
 327          if user_update.is_admin is not None:
 328              user.is_admin = user_update.is_admin
 329  
 330          if user_update.is_private is not None:
 331              user.is_private = user_update.is_private
 332  
 333          if user_update.is_restricted is not None:
 334              user.is_restricted = user_update.is_restricted
 335  
 336          if hasattr(user_update, "options") and user_update.options is not None:
 337              try:
 338                  current_options = json.loads(user.options) if user.options else {}
 339                  new_options = user_update.options.model_dump()
 340                  if current_options != new_options:
 341                      user.options = json.dumps(new_options)
 342              except json.JSONDecodeError:
 343                  user.options = json.dumps(user_update.options.model_dump())
 344  
 345          self.db.commit()
 346          return True
 347  
 348      def update_llm(self, llm: LLMModel, llmUpdate: LLMUpdate) -> bool:
 349          if llmUpdate.class_name is not None and llm.class_name != llmUpdate.class_name:
 350              llm.class_name = llmUpdate.class_name
 351  
 352          if llmUpdate.options is not None and llm.options != llmUpdate.options:
 353              # Encrypt sensitive fields (api_key) before persisting
 354              from restai.utils.crypto import encrypt_sensitive_options, LLM_SENSITIVE_KEYS
 355              import json as _json
 356              try:
 357                  opts_dict = _json.loads(llmUpdate.options) if isinstance(llmUpdate.options, str) else llmUpdate.options
 358                  # If api_key is the masked value, preserve the existing one
 359                  if opts_dict.get("api_key") == "********":
 360                      existing = _json.loads(llm.options) if isinstance(llm.options, str) else (llm.options or {})
 361                      if "api_key" in existing:
 362                          opts_dict["api_key"] = existing["api_key"]
 363                      else:
 364                          del opts_dict["api_key"]
 365                  opts_dict = encrypt_sensitive_options(opts_dict, LLM_SENSITIVE_KEYS)
 366                  llm.options = _json.dumps(opts_dict) if isinstance(llmUpdate.options, str) else opts_dict
 367              except Exception as e:
 368                  logging.warning("Failed to encrypt LLM options on update: %s", e)
 369                  llm.options = llmUpdate.options
 370  
 371          if llmUpdate.privacy is not None and llm.privacy != llmUpdate.privacy:
 372              llm.privacy = llmUpdate.privacy
 373  
 374          if (
 375              llmUpdate.description is not None
 376              and llm.description != llmUpdate.description
 377          ):
 378              llm.description = llmUpdate.description
 379  
 380          if llmUpdate.input_cost is not None and llm.input_cost != llmUpdate.input_cost:
 381              llm.input_cost = llmUpdate.input_cost
 382  
 383          if (
 384              llmUpdate.output_cost is not None
 385              and llm.output_cost != llmUpdate.output_cost
 386          ):
 387              llm.output_cost = llmUpdate.output_cost
 388  
 389          if (
 390              llmUpdate.context_window is not None
 391              and llm.context_window != llmUpdate.context_window
 392          ):
 393              llm.context_window = llmUpdate.context_window
 394  
 395          self.db.commit()
 396          return True
 397  
 398      def update_embedding(
 399          self, embedding: EmbeddingModel, embeddingUpdate: EmbeddingUpdate
 400      ) -> bool:
 401          if (
 402              embeddingUpdate.class_name is not None
 403              and embedding.class_name != embeddingUpdate.class_name
 404          ):
 405              embedding.class_name = embeddingUpdate.class_name
 406  
 407          if (
 408              embeddingUpdate.options is not None
 409              and embedding.options != embeddingUpdate.options
 410          ):
 411              # If api_key is the masked value, preserve the existing one
 412              import json as _json
 413              try:
 414                  new_opts = _json.loads(embeddingUpdate.options) if isinstance(embeddingUpdate.options, str) else (embeddingUpdate.options or {})
 415                  if new_opts.get("api_key") == "********":
 416                      existing = _json.loads(embedding.options) if isinstance(embedding.options, str) else (embedding.options or {})
 417                      if "api_key" in existing:
 418                          new_opts["api_key"] = existing["api_key"]
 419                      else:
 420                          del new_opts["api_key"]
 421                      embeddingUpdate.options = _json.dumps(new_opts)
 422              except Exception:
 423                  pass
 424              embedding.options = embeddingUpdate.options
 425  
 426          if (
 427              embeddingUpdate.privacy is not None
 428              and embedding.privacy != embeddingUpdate.privacy
 429          ):
 430              embedding.privacy = embeddingUpdate.privacy
 431  
 432          if (
 433              embeddingUpdate.description is not None
 434              and embedding.description != embeddingUpdate.description
 435          ):
 436              embedding.description = embeddingUpdate.description
 437  
 438          if (
 439              embeddingUpdate.dimension is not None
 440              and embedding.dimension != embeddingUpdate.dimension
 441          ):
 442              embedding.dimension = embeddingUpdate.dimension
 443  
 444          self.db.commit()
 445          return True
 446  
 447      def delete_llm(self, llm: LLMDatabase) -> bool:
 448          self.db.delete(llm)
 449          self.db.commit()
 450          return True
 451  
 452      def delete_embedding(self, embedding: EmbeddingDatabase) -> bool:
 453          self.db.delete(embedding)
 454          self.db.commit()
 455          return True
 456  
 457      # ─── Image generators ─────────────────────────────────────────────
 458  
 459      def get_image_generators(self) -> list[ImageGeneratorDatabase]:
 460          return self.db.query(ImageGeneratorDatabase).order_by(ImageGeneratorDatabase.name).all()
 461  
 462      def get_image_generator_by_id(self, gen_id: int) -> Optional[ImageGeneratorDatabase]:
 463          return self.db.query(ImageGeneratorDatabase).filter(ImageGeneratorDatabase.id == gen_id).first()
 464  
 465      def get_image_generator_by_name(self, name: str) -> Optional[ImageGeneratorDatabase]:
 466          return self.db.query(ImageGeneratorDatabase).filter(ImageGeneratorDatabase.name == name).first()
 467  
 468      def create_image_generator(
 469          self,
 470          name: str,
 471          class_name: str,
 472          options,
 473          privacy: str = "public",
 474          description: Optional[str] = None,
 475          enabled: bool = True,
 476      ) -> ImageGeneratorDatabase:
 477          from restai.utils.crypto import encrypt_sensitive_options, LLM_SENSITIVE_KEYS
 478          import json as _json
 479  
 480          try:
 481              opts_dict = _json.loads(options) if isinstance(options, str) else (options or {})
 482              opts_dict = encrypt_sensitive_options(opts_dict, LLM_SENSITIVE_KEYS)
 483              options_str = _json.dumps(opts_dict)
 484          except Exception as e:
 485              logging.warning("Failed to encrypt image generator options: %s", e)
 486              options_str = options if isinstance(options, str) else _json.dumps(options or {})
 487  
 488          now = datetime.now(timezone.utc)
 489          row = ImageGeneratorDatabase(
 490              name=name,
 491              class_name=class_name,
 492              options=options_str,
 493              privacy=privacy,
 494              description=description,
 495              enabled=enabled,
 496              created_at=now,
 497              updated_at=now,
 498          )
 499          self.db.add(row)
 500          self.db.commit()
 501          self.db.refresh(row)
 502          return row
 503  
 504      def edit_image_generator(self, gen: ImageGeneratorDatabase, update) -> bool:
 505          """Patch an image generator. `update` is an ImageGeneratorModelUpdate.
 506          api_key in `options` is preserved when the submitted value is the
 507          masked sentinel `"********"` (matches the LLM edit pattern)."""
 508          from restai.utils.crypto import encrypt_sensitive_options, LLM_SENSITIVE_KEYS
 509          import json as _json
 510  
 511          changed = False
 512          if update.class_name is not None and gen.class_name != update.class_name:
 513              gen.class_name = update.class_name
 514              changed = True
 515  
 516          if update.options is not None:
 517              try:
 518                  new_opts = _json.loads(update.options) if isinstance(update.options, str) else (update.options or {})
 519                  # Preserve masked sensitive fields
 520                  existing = _json.loads(gen.options) if gen.options else {}
 521                  # The existing values in DB are already encrypted; decrypt
 522                  # them so the comparison is plaintext-vs-plaintext.
 523                  from restai.utils.crypto import decrypt_sensitive_options
 524                  existing_plain = decrypt_sensitive_options(dict(existing), LLM_SENSITIVE_KEYS)
 525                  for k in LLM_SENSITIVE_KEYS:
 526                      val = new_opts.get(k)
 527                      if isinstance(val, str) and val == "********":
 528                          if k in existing_plain:
 529                              new_opts[k] = existing_plain[k]
 530                          else:
 531                              new_opts.pop(k, None)
 532                  new_opts_enc = encrypt_sensitive_options(new_opts, LLM_SENSITIVE_KEYS)
 533                  gen.options = _json.dumps(new_opts_enc)
 534                  changed = True
 535              except Exception as e:
 536                  logging.warning("Failed to update image generator options: %s", e)
 537  
 538          if update.privacy is not None and gen.privacy != update.privacy:
 539              gen.privacy = update.privacy
 540              changed = True
 541  
 542          if update.description is not None and gen.description != update.description:
 543              gen.description = update.description
 544              changed = True
 545  
 546          if update.enabled is not None and gen.enabled != update.enabled:
 547              gen.enabled = update.enabled
 548              changed = True
 549  
 550          if changed:
 551              gen.updated_at = datetime.now(timezone.utc)
 552              self.db.commit()
 553          return changed
 554  
 555      def delete_image_generator(self, gen: ImageGeneratorDatabase) -> bool:
 556          # Also drop any team grants pointing at this name so we don't leave
 557          # dangling rows in the legacy string-keyed teams_image_generators.
 558          try:
 559              self.db.query(TeamImageGeneratorDatabase).filter(
 560                  TeamImageGeneratorDatabase.generator_name == gen.name
 561              ).delete(synchronize_session=False)
 562          except Exception:
 563              pass
 564          self.db.delete(gen)
 565          self.db.commit()
 566          return True
 567  
 568      # ─── Speech-to-text models ────────────────────────────────────────
 569  
 570      def get_speech_to_text(self) -> list[SpeechToTextDatabase]:
 571          return self.db.query(SpeechToTextDatabase).order_by(SpeechToTextDatabase.name).all()
 572  
 573      def get_speech_to_text_by_id(self, model_id: int) -> Optional[SpeechToTextDatabase]:
 574          return self.db.query(SpeechToTextDatabase).filter(SpeechToTextDatabase.id == model_id).first()
 575  
 576      def get_speech_to_text_by_name(self, name: str) -> Optional[SpeechToTextDatabase]:
 577          return self.db.query(SpeechToTextDatabase).filter(SpeechToTextDatabase.name == name).first()
 578  
 579      def create_speech_to_text(
 580          self,
 581          name: str,
 582          class_name: str,
 583          options,
 584          privacy: str = "public",
 585          description: Optional[str] = None,
 586          enabled: bool = True,
 587      ) -> SpeechToTextDatabase:
 588          from restai.utils.crypto import encrypt_sensitive_options, LLM_SENSITIVE_KEYS
 589          import json as _json
 590  
 591          try:
 592              opts_dict = _json.loads(options) if isinstance(options, str) else (options or {})
 593              opts_dict = encrypt_sensitive_options(opts_dict, LLM_SENSITIVE_KEYS)
 594              options_str = _json.dumps(opts_dict)
 595          except Exception as e:
 596              logging.warning("Failed to encrypt speech-to-text options: %s", e)
 597              options_str = options if isinstance(options, str) else _json.dumps(options or {})
 598  
 599          now = datetime.now(timezone.utc)
 600          row = SpeechToTextDatabase(
 601              name=name,
 602              class_name=class_name,
 603              options=options_str,
 604              privacy=privacy,
 605              description=description,
 606              enabled=enabled,
 607              created_at=now,
 608              updated_at=now,
 609          )
 610          self.db.add(row)
 611          self.db.commit()
 612          self.db.refresh(row)
 613          return row
 614  
 615      def edit_speech_to_text(self, model: SpeechToTextDatabase, update) -> bool:
 616          """Patch a speech-to-text row. `"********"` in `options.api_key`
 617          preserves the existing value (matches the LLM/image-gen pattern)."""
 618          from restai.utils.crypto import (
 619              decrypt_sensitive_options,
 620              encrypt_sensitive_options,
 621              LLM_SENSITIVE_KEYS,
 622          )
 623          import json as _json
 624  
 625          changed = False
 626          if update.class_name is not None and model.class_name != update.class_name:
 627              model.class_name = update.class_name
 628              changed = True
 629  
 630          if update.options is not None:
 631              try:
 632                  new_opts = _json.loads(update.options) if isinstance(update.options, str) else (update.options or {})
 633                  existing = _json.loads(model.options) if model.options else {}
 634                  existing_plain = decrypt_sensitive_options(dict(existing), LLM_SENSITIVE_KEYS)
 635                  for k in LLM_SENSITIVE_KEYS:
 636                      val = new_opts.get(k)
 637                      if isinstance(val, str) and val == "********":
 638                          if k in existing_plain:
 639                              new_opts[k] = existing_plain[k]
 640                          else:
 641                              new_opts.pop(k, None)
 642                  new_opts_enc = encrypt_sensitive_options(new_opts, LLM_SENSITIVE_KEYS)
 643                  model.options = _json.dumps(new_opts_enc)
 644                  changed = True
 645              except Exception as e:
 646                  logging.warning("Failed to update speech-to-text options: %s", e)
 647  
 648          if update.privacy is not None and model.privacy != update.privacy:
 649              model.privacy = update.privacy
 650              changed = True
 651          if update.description is not None and model.description != update.description:
 652              model.description = update.description
 653              changed = True
 654          if update.enabled is not None and model.enabled != update.enabled:
 655              model.enabled = update.enabled
 656              changed = True
 657  
 658          if changed:
 659              model.updated_at = datetime.now(timezone.utc)
 660              self.db.commit()
 661          return changed
 662  
 663      def delete_speech_to_text(self, model: SpeechToTextDatabase) -> bool:
 664          # Drop team grants pointing at this name (string-keyed table).
 665          try:
 666              self.db.query(TeamAudioGeneratorDatabase).filter(
 667                  TeamAudioGeneratorDatabase.generator_name == model.name
 668              ).delete(synchronize_session=False)
 669          except Exception:
 670              pass
 671          self.db.delete(model)
 672          self.db.commit()
 673          return True
 674  
 675      # ─── Project secrets (agentic browser vault) ─────────────────────
 676  
 677      def get_project_secrets(self, project_id: int) -> list[ProjectSecretDatabase]:
 678          return (
 679              self.db.query(ProjectSecretDatabase)
 680              .filter(ProjectSecretDatabase.project_id == project_id)
 681              .order_by(ProjectSecretDatabase.name)
 682              .all()
 683          )
 684  
 685      def get_project_secret_by_id(self, secret_id: int) -> Optional[ProjectSecretDatabase]:
 686          return self.db.query(ProjectSecretDatabase).filter(ProjectSecretDatabase.id == secret_id).first()
 687  
 688      def get_project_secret_by_name(self, project_id: int, name: str) -> Optional[ProjectSecretDatabase]:
 689          return (
 690              self.db.query(ProjectSecretDatabase)
 691              .filter(ProjectSecretDatabase.project_id == project_id, ProjectSecretDatabase.name == name)
 692              .first()
 693          )
 694  
 695      def create_project_secret(
 696          self,
 697          project_id: int,
 698          name: str,
 699          value: str,
 700          description: Optional[str] = None,
 701      ) -> ProjectSecretDatabase:
 702          from restai.utils.crypto import encrypt_field
 703          now = datetime.now(timezone.utc)
 704          row = ProjectSecretDatabase(
 705              project_id=project_id,
 706              name=name,
 707              value=encrypt_field(value),
 708              description=description,
 709              created_at=now,
 710              updated_at=now,
 711          )
 712          self.db.add(row)
 713          self.db.commit()
 714          self.db.refresh(row)
 715          return row
 716  
 717      def edit_project_secret(self, secret: ProjectSecretDatabase, update) -> bool:
 718          """Patch a project secret. Value `"********"` preserves the
 719          existing stored value (same mask-round-trip as LLMs)."""
 720          from restai.utils.crypto import encrypt_field
 721  
 722          changed = False
 723          if update.value is not None and update.value != "********":
 724              secret.value = encrypt_field(update.value)
 725              changed = True
 726          if update.description is not None and secret.description != update.description:
 727              secret.description = update.description
 728              changed = True
 729          if changed:
 730              secret.updated_at = datetime.now(timezone.utc)
 731              self.db.commit()
 732          return changed
 733  
 734      def delete_project_secret(self, secret: ProjectSecretDatabase) -> bool:
 735          self.db.delete(secret)
 736          self.db.commit()
 737          return True
 738  
 739      def resolve_project_secret(self, project_id: int, name: str) -> Optional[str]:
 740          """Server-side plaintext resolution — only called from inside a tool
 741          (e.g. `browser_fill`). Returns None when the secret doesn't exist.
 742          The plaintext never crosses back into LLM context because callers
 743          pass it directly to the micro-server, not back to the agent."""
 744          from restai.utils.crypto import decrypt_field
 745          row = self.get_project_secret_by_name(project_id, name)
 746          if row is None or not row.value:
 747              return None
 748          try:
 749              return decrypt_field(row.value)
 750          except Exception:
 751              return None
 752  
 753      # ─────────────────────────────────────────────────────────────────────
 754  
 755      def get_user_by_id(self, user_id: int) -> Optional[UserDatabase]:
 756          user: Optional[UserDatabase] = (
 757              self.db.query(UserDatabase).filter(UserDatabase.id == user_id).first()
 758          )
 759          return user
 760  
 761      def delete_user(self, user: UserDatabase) -> bool:
 762          self.db.delete(user)
 763          self.db.commit()
 764          return True
 765  
 766      def get_project_by_name(self, name: str) -> Optional[ProjectDatabase]:
 767          project: Optional[ProjectDatabase] = (
 768              self.db.query(ProjectDatabase).filter(ProjectDatabase.name == name).first()
 769          )
 770          return project
 771  
 772      def get_project_by_id(self, id: int) -> Optional[ProjectDatabase]:
 773          project: Optional[ProjectDatabase] = (
 774              self.db.query(ProjectDatabase).filter(ProjectDatabase.id == id).first()
 775          )
 776          return project
 777  
 778      def create_project(
 779          self,
 780          name: str,
 781          embeddings: str,
 782          llm: str,
 783          vectorstore: str,
 784          human_name: str,
 785          human_description: str,
 786          project_type: str,
 787          creator: int,
 788          team_id: int,
 789      ) -> Optional[ProjectDatabase]:
 790          # Validate that the team exists
 791          team = self.get_team_by_id(team_id)
 792          if team is None:
 793              return None
 794              
 795          # Validate that the team has access to the specified LLM (block projects don't need one)
 796          if llm:
 797              llm_db = self.get_llm_by_name(llm)
 798              if llm_db is None or llm_db not in team.llms:
 799                  return None
 800              
 801          # If embeddings are specified, validate that the team has access to it
 802          if embeddings:
 803              embedding_db = self.get_embedding_by_name(embeddings)
 804              if embedding_db is None or embedding_db not in team.embeddings:
 805                  return None
 806                  
 807          # Look up the creator so we can associate them in the same transaction
 808          creator_user = self.get_user_by_id(creator) if creator else None
 809  
 810          db_project: ProjectDatabase = ProjectDatabase(
 811              name=name,
 812              embeddings=embeddings,
 813              llm=llm,
 814              vectorstore=vectorstore,
 815              human_name=human_name,
 816              human_description=human_description,
 817              type=project_type,
 818              creator=creator,
 819              options='{"logging": true}',
 820          )
 821          self.db.add(db_project)
 822  
 823          # Associate with team and creator in ONE transaction so we never
 824          # end up with a project row that has no users.
 825          if db_project not in team.projects:
 826              team.projects.append(db_project)
 827          if creator_user and db_project not in creator_user.projects:
 828              creator_user.projects.append(db_project)
 829  
 830          self.db.commit()
 831          self.db.refresh(db_project)
 832  
 833          return db_project
 834  
 835      def delete_project(self, project: ProjectDatabase) -> bool:
 836          self.db.delete(project)
 837          self.db.commit()
 838          return True
 839  
    def edit_project(self, id: int, projectModel: ProjectModelUpdate) -> bool:
        """Apply a partial update to the project with primary key `id`.

        Only fields of `projectModel` that are not None are considered, and
        most are written only when they differ from the stored value. A
        single commit is issued at the end if anything changed.

        Returns:
            False when the project does not exist, when no team lists it,
            when another project by the same creator already uses the
            requested name, or when the requested LLM / embedding model is
            unknown or not granted to any of the project's teams. True
            otherwise — including when nothing needed to change.

        Raises:
            HTTPException: status 400 when `projectModel.users` names a user
                that does not exist, or a non-admin user who is neither a
                member nor an admin of one of the project's teams.

        NOTE(review): the `return False` exits below the `users` / `name`
        handling run after `proj_db` may already have been modified in this
        call; those assignments are neither committed nor rolled back here —
        confirm that callers discard the session state on failure.
        """
        proj_db: Optional[ProjectDatabase] = self.get_project_by_id(id)
        if proj_db is None:
            return False

        changed = False

        # Get all teams that have this project to validate LLM/embedding access
        teams_with_project = [team for team in self.get_teams() if proj_db in team.projects]
        if not teams_with_project:
            return False  # Project should belong to at least one team

        # Users: the supplied list REPLACES the project's user list. Every
        # name is checked first; a single bad entry rejects the whole update.
        if projectModel.users is not None:
            new_users = []
            rejected = []
            for username in projectModel.users:
                user_db = self.get_user_by_username(username)
                if user_db is None:
                    rejected.append(f"{username} (not found)")
                    continue
                # Platform admins bypass the team membership check
                if user_db.is_admin:
                    new_users.append(user_db)
                    continue
                # Otherwise the user must belong to one of the project's teams
                in_team = any(
                    user_db in team.users or user_db in team.admins
                    for team in teams_with_project
                )
                if in_team:
                    new_users.append(user_db)
                else:
                    rejected.append(f"{username} (not in project's team)")

            if rejected:
                from fastapi import HTTPException
                raise HTTPException(
                    status_code=400,
                    detail=f"Cannot assign users: {', '.join(rejected)}",
                )

            proj_db.users = new_users
            changed = True

        # Name: refuse the rename if the same creator already owns another
        # project with the requested name.
        if projectModel.name is not None and proj_db.name != projectModel.name:
            if (
                self.db.query(ProjectDatabase)
                .filter(
                    ProjectDatabase.creator == proj_db.creator,
                    ProjectDatabase.name == projectModel.name,
                    ProjectDatabase.id != proj_db.id,
                )
                .first()
                is not None
            ):
                return False
            proj_db.name = projectModel.name
            changed = True

        if projectModel.llm is not None and proj_db.llm != projectModel.llm:
            # Validate that at least one team has access to this LLM
            llm_db = self.get_llm_by_name(projectModel.llm)
            if llm_db is None:
                return False

            llm_access = False
            for team in teams_with_project:
                if llm_db in team.llms:
                    llm_access = True
                    break

            if not llm_access:
                return False  # No team has access to this LLM

            proj_db.llm = projectModel.llm
            changed = True

        if projectModel.embeddings is not None and proj_db.embeddings != projectModel.embeddings:
            # Validate that at least one team has access to this embedding model
            if projectModel.embeddings:  # Only check if embeddings is not empty
                embedding_db = self.get_embedding_by_name(projectModel.embeddings)
                if embedding_db is None:
                    return False

                embedding_access = False
                for team in teams_with_project:
                    if embedding_db in team.embeddings:
                        embedding_access = True
                        break

                if not embedding_access:
                    return False  # No team has access to this embedding model

            # An empty string skips the access check and is stored as-is.
            proj_db.embeddings = projectModel.embeddings
            changed = True

        if projectModel.system is not None and proj_db.system != projectModel.system:
            proj_db.system = projectModel.system
            changed = True
            # Auto-create prompt version
            # The author id is read from an optional `_user_id` attribute on
            # the update model; None is passed when it is not set.
            self._create_prompt_version(proj_db.id, projectModel.system, user_id=getattr(projectModel, '_user_id', None))

        # The remaining scalar fields share one rule: write only when a
        # value is supplied and differs from what is stored.
        if (
            projectModel.censorship is not None
            and proj_db.censorship != projectModel.censorship
        ):
            proj_db.censorship = projectModel.censorship
            changed = True

        if projectModel.guard is not None and proj_db.guard != projectModel.guard:
            proj_db.guard = projectModel.guard
            changed = True

        if (
            projectModel.human_name is not None
            and proj_db.human_name != projectModel.human_name
        ):
            proj_db.human_name = projectModel.human_name
            changed = True

        if (
            projectModel.human_description is not None
            and proj_db.human_description != projectModel.human_description
        ):
            proj_db.human_description = projectModel.human_description
            changed = True

        if projectModel.public is not None and proj_db.public != projectModel.public:
            proj_db.public = projectModel.public
            changed = True

        if (
            projectModel.default_prompt is not None
            and proj_db.default_prompt != projectModel.default_prompt
        ):
            proj_db.default_prompt = projectModel.default_prompt
            changed = True

        # Options: stored as a JSON string on the project row.
        if hasattr(projectModel, "options") and projectModel.options is not None:
            from restai.utils.crypto import encrypt_sensitive_options, PROJECT_SENSITIVE_KEYS
            # Keys that aren't surfaced in the project-edit form — they are
            # owned by dedicated endpoints (e.g. the Mobile pairing flow) and
            # must not be wiped when the edit form POSTs a ProjectOptions dump.
            PRESERVED_KEYS = ("mobile_enabled", "mobile_api_key_id")
            try:
                current_options = json.loads(proj_db.options) if proj_db.options else {}
            except json.JSONDecodeError:
                # Unparseable stored options are treated as empty.
                current_options = {}
            new_options = projectModel.options.model_dump()
            new_options = encrypt_sensitive_options(new_options, PROJECT_SENSITIVE_KEYS)
            # Carry the preserved keys over from the stored options.
            for k in PRESERVED_KEYS:
                if k in current_options:
                    new_options[k] = current_options[k]
            if current_options != new_options:
                proj_db.options = json.dumps(new_options)
                changed = True

        # One commit for everything gathered above; True even if no-op.
        if changed:
            self.db.commit()
        return True
1000  
1001      def create_team(self, team_create: TeamModelCreate) -> TeamDatabase:
1002          db_team: TeamDatabase = TeamDatabase(
1003              name=team_create.name,
1004              description=team_create.description,
1005              creator_id=team_create.creator_id,
1006              budget=team_create.budget,
1007              created_at=datetime.now(timezone.utc),
1008              updated_at=datetime.now(timezone.utc),
1009          )
1010          self.db.add(db_team)
1011          self.db.commit()
1012          self.db.refresh(db_team)
1013          return db_team
1014  
1015      def get_team_by_id(self, team_id: int) -> Optional[TeamDatabase]:
1016          team: Optional[TeamDatabase] = (
1017              self.db.query(TeamDatabase).filter(TeamDatabase.id == team_id).first()
1018          )
1019          return team
1020  
1021      def get_team_by_name(self, name: str) -> Optional[TeamDatabase]:
1022          team: Optional[TeamDatabase] = (
1023              self.db.query(TeamDatabase).filter(TeamDatabase.name == name).first()
1024          )
1025          return team
1026  
1027      def get_teams(self) -> List[TeamDatabase]:
1028          teams: List[TeamDatabase] = self.db.query(TeamDatabase).all()
1029          return teams
1030  
1031      def update_team(self, team: TeamDatabase, team_update: TeamModelUpdate) -> bool:
1032          changed = False
1033  
1034          if team_update.name is not None and team.name != team_update.name:
1035              team.name = team_update.name
1036              changed = True
1037  
1038          if team_update.description is not None and team.description != team_update.description:
1039              team.description = team_update.description
1040              changed = True
1041  
1042          if team_update.budget is not None and team.budget != team_update.budget:
1043              team.budget = team_update.budget
1044              changed = True
1045  
1046          if team_update.branding is not None:
1047              import json
1048              team.branding = json.dumps(team_update.branding.model_dump())
1049              changed = True
1050  
1051          if changed:
1052              team.updated_at = datetime.now(timezone.utc)
1053              self.db.commit()
1054          return changed
1055  
1056      def delete_team(self, team: TeamDatabase) -> bool:
1057          self.db.delete(team)
1058          self.db.commit()
1059          return True
1060  
1061      def add_user_to_team(self, team: TeamDatabase, user: UserDatabase) -> bool:
1062          if user not in team.users:
1063              team.users.append(user)
1064              self.db.commit()
1065          return True
1066      
1067      def remove_user_from_team(self, team: TeamDatabase, user: UserDatabase) -> bool:
1068          if user in team.users:
1069              team.users.remove(user)
1070              self.db.commit()
1071          return True
1072      
1073      def add_admin_to_team(self, team: TeamDatabase, user: UserDatabase) -> bool:
1074          if user not in team.admins:
1075              team.admins.append(user)
1076              self.db.commit()
1077          return True
1078      
1079      def remove_admin_from_team(self, team: TeamDatabase, user: UserDatabase) -> bool:
1080          if user in team.admins:
1081              team.admins.remove(user)
1082              self.db.commit()
1083          return True
1084      
1085      def add_project_to_team(self, team: TeamDatabase, project: ProjectDatabase) -> bool:
1086          if project not in team.projects:
1087              team.projects.append(project)
1088              self.db.commit()
1089          return True
1090      
1091      def remove_project_from_team(self, team: TeamDatabase, project: ProjectDatabase) -> bool:
1092          if project in team.projects:
1093              team.projects.remove(project)
1094              self.db.commit()
1095          return True
1096      
1097      def add_llm_to_team(self, team: TeamDatabase, llm: LLMDatabase) -> bool:
1098          if llm not in team.llms:
1099              team.llms.append(llm)
1100              self.db.commit()
1101          return True
1102      
1103      def remove_llm_from_team(self, team: TeamDatabase, llm: LLMDatabase) -> bool:
1104          if llm in team.llms:
1105              team.llms.remove(llm)
1106              self.db.commit()
1107          return True
1108      
1109      def add_embedding_to_team(self, team: TeamDatabase, embedding: EmbeddingDatabase) -> bool:
1110          if embedding not in team.embeddings:
1111              team.embeddings.append(embedding)
1112              self.db.commit()
1113          return True
1114      
1115      def remove_embedding_from_team(self, team: TeamDatabase, embedding: EmbeddingDatabase) -> bool:
1116          if embedding in team.embeddings:
1117              team.embeddings.remove(embedding)
1118              self.db.commit()
1119          return True
1120          
1121      def get_teams_for_user(self, user_id: int) -> List[TeamDatabase]:
1122          """Get all teams where the user is a member or admin"""
1123          user = self.get_user_by_id(user_id)
1124          if user is None:
1125              return []
1126          return list(set(user.teams + user.admin_teams))
1127          
1128      def get_settings(self) -> list[SettingDatabase]:
1129          from restai.utils.crypto import SETTINGS_ENCRYPTED_KEYS, decrypt_field
1130          rows = self.db.query(SettingDatabase).all()
1131          for r in rows:
1132              if r.key in SETTINGS_ENCRYPTED_KEYS and r.value:
1133                  self.db.expunge(r)
1134                  r.value = decrypt_field(r.value)
1135          return rows
1136  
1137      def get_setting(self, key: str) -> Optional[SettingDatabase]:
1138          from restai.utils.crypto import SETTINGS_ENCRYPTED_KEYS, decrypt_field
1139          row = self.db.query(SettingDatabase).filter(SettingDatabase.key == key).first()
1140          if row and key in SETTINGS_ENCRYPTED_KEYS and row.value:
1141              self.db.expunge(row)
1142              row.value = decrypt_field(row.value)
1143          return row
1144  
1145      def get_setting_value(self, key: str, default: str = "") -> str:
1146          """Get a setting value by key, returning default if not found or empty."""
1147          row = self.get_setting(key)
1148          return row.value if row and row.value else default
1149  
1150      def upsert_setting(self, key: str, value: str) -> None:
1151          from restai.utils.crypto import SETTINGS_ENCRYPTED_KEYS, encrypt_field
1152          stored_value = encrypt_field(value) if (key in SETTINGS_ENCRYPTED_KEYS and value) else value
1153          existing = self.db.query(SettingDatabase).filter(SettingDatabase.key == key).first()
1154          if existing:
1155              existing.value = stored_value
1156          else:
1157              self.db.add(SettingDatabase(key=key, value=stored_value))
1158          try:
1159              self.db.commit()
1160          except Exception:
1161              self.db.rollback()
1162  
1163      def get_team_spending(self, team_id: int) -> float:
1164          now = datetime.now(timezone.utc)
1165          month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
1166          result = self.db.query(
1167              func.coalesce(func.sum(OutputDatabase.input_cost + OutputDatabase.output_cost), 0.0)
1168          ).filter(
1169              or_(
1170                  OutputDatabase.project_id.in_(
1171                      self.db.query(ProjectDatabase.id).filter(ProjectDatabase.team_id == team_id)
1172                  ),
1173                  OutputDatabase.team_id == team_id
1174              ),
1175              OutputDatabase.date >= month_start
1176          ).scalar()
1177          return float(result)
1178  
1179      def update_team_members(self, team: TeamDatabase, team_update: TeamModelUpdate) -> bool:
1180          changed = False
1181          
1182          # Update users
1183          if team_update.users is not None:
1184              team.users = []
1185              for username in team_update.users:
1186                  user_db = self.get_user_by_username(username)
1187                  if user_db is not None:
1188                      team.users.append(user_db)
1189              changed = True
1190              
1191          # Update admins
1192          if team_update.admins is not None:
1193              team.admins = []
1194              for username in team_update.admins:
1195                  user_db = self.get_user_by_username(username)
1196                  if user_db is not None:
1197                      team.admins.append(user_db)
1198              changed = True
1199              
1200          # Update projects
1201          if team_update.projects is not None:
1202              team.projects = []
1203              for project_name in team_update.projects:
1204                  project_db = self.get_project_by_name(project_name)
1205                  if project_db is not None:
1206                      team.projects.append(project_db)
1207              changed = True
1208              
1209          # Update LLMs
1210          if team_update.llms is not None:
1211              team.llms = []
1212              for llm_name in team_update.llms:
1213                  llm_db = self.get_llm_by_name(llm_name)
1214                  if llm_db is not None:
1215                      team.llms.append(llm_db)
1216              changed = True
1217              
1218          # Update embeddings
1219          if team_update.embeddings is not None:
1220              team.embeddings = []
1221              for embedding_name in team_update.embeddings:
1222                  embedding_db = self.get_embedding_by_name(embedding_name)
1223                  if embedding_db is not None:
1224                      team.embeddings.append(embedding_db)
1225              changed = True
1226  
1227          # Update image generators
1228          if team_update.image_generators is not None:
1229              team.image_generators = []
1230              self.db.flush()
1231              for gen_name in team_update.image_generators:
1232                  team.image_generators.append(
1233                      TeamImageGeneratorDatabase(team_id=team.id, generator_name=gen_name)
1234                  )
1235              changed = True
1236  
1237          # Update audio generators
1238          if team_update.audio_generators is not None:
1239              team.audio_generators = []
1240              self.db.flush()
1241              for gen_name in team_update.audio_generators:
1242                  team.audio_generators.append(
1243                      TeamAudioGeneratorDatabase(team_id=team.id, generator_name=gen_name)
1244                  )
1245              changed = True
1246  
1247          if changed:
1248              team.updated_at = datetime.now(timezone.utc)
1249              self.db.commit()
1250              
1251          return changed
1252  
1253  
1254      def add_image_generator_to_team(self, team: TeamDatabase, generator_name: str) -> bool:
1255          existing = self.db.query(TeamImageGeneratorDatabase).filter(
1256              TeamImageGeneratorDatabase.team_id == team.id,
1257              TeamImageGeneratorDatabase.generator_name == generator_name
1258          ).first()
1259          if existing is None:
1260              team.image_generators.append(
1261                  TeamImageGeneratorDatabase(team_id=team.id, generator_name=generator_name)
1262              )
1263              self.db.commit()
1264          return True
1265  
1266      def remove_image_generator_from_team(self, team: TeamDatabase, generator_name: str) -> bool:
1267          item = self.db.query(TeamImageGeneratorDatabase).filter(
1268              TeamImageGeneratorDatabase.team_id == team.id,
1269              TeamImageGeneratorDatabase.generator_name == generator_name
1270          ).first()
1271          if item is not None:
1272              self.db.delete(item)
1273              self.db.commit()
1274          return True
1275  
1276      def add_audio_generator_to_team(self, team: TeamDatabase, generator_name: str) -> bool:
1277          existing = self.db.query(TeamAudioGeneratorDatabase).filter(
1278              TeamAudioGeneratorDatabase.team_id == team.id,
1279              TeamAudioGeneratorDatabase.generator_name == generator_name
1280          ).first()
1281          if existing is None:
1282              team.audio_generators.append(
1283                  TeamAudioGeneratorDatabase(team_id=team.id, generator_name=generator_name)
1284              )
1285              self.db.commit()
1286          return True
1287  
1288      def remove_audio_generator_from_team(self, team: TeamDatabase, generator_name: str) -> bool:
1289          item = self.db.query(TeamAudioGeneratorDatabase).filter(
1290              TeamAudioGeneratorDatabase.team_id == team.id,
1291              TeamAudioGeneratorDatabase.generator_name == generator_name
1292          ).first()
1293          if item is not None:
1294              self.db.delete(item)
1295              self.db.commit()
1296          return True
1297  
1298  
1299      def _create_prompt_version(self, project_id: int, system_prompt: str, user_id: int = None):
1300          """Create a new prompt version record, marking it as active."""
1301          from restai.models.databasemodels import PromptVersionDatabase
1302  
1303          # Deactivate current active version
1304          self.db.query(PromptVersionDatabase).filter(
1305              PromptVersionDatabase.project_id == project_id,
1306              PromptVersionDatabase.is_active == True,
1307          ).update({"is_active": False})
1308  
1309          # Get next version number
1310          max_version = (
1311              self.db.query(func.max(PromptVersionDatabase.version))
1312              .filter(PromptVersionDatabase.project_id == project_id)
1313              .scalar()
1314          ) or 0
1315  
1316          version = PromptVersionDatabase(
1317              project_id=project_id,
1318              version=max_version + 1,
1319              system_prompt=system_prompt or "",
1320              created_by=user_id,
1321              created_at=datetime.now(timezone.utc),
1322              is_active=True,
1323          )
1324          self.db.add(version)
1325  
1326      def get_prompt_versions(self, project_id: int):
1327          """Get all prompt versions for a project, newest first."""
1328          from restai.models.databasemodels import PromptVersionDatabase
1329          return (
1330              self.db.query(PromptVersionDatabase)
1331              .filter(PromptVersionDatabase.project_id == project_id)
1332              .order_by(PromptVersionDatabase.version.desc())
1333              .all()
1334          )
1335  
1336      def get_prompt_version(self, version_id: int):
1337          """Get a specific prompt version by ID."""
1338          from restai.models.databasemodels import PromptVersionDatabase
1339          return self.db.query(PromptVersionDatabase).filter(PromptVersionDatabase.id == version_id).first()
1340  
1341      def get_active_prompt_version(self, project_id: int):
1342          """Get the active prompt version for a project."""
1343          from restai.models.databasemodels import PromptVersionDatabase
1344          return (
1345              self.db.query(PromptVersionDatabase)
1346              .filter(PromptVersionDatabase.project_id == project_id, PromptVersionDatabase.is_active == True)
1347              .first()
1348          )
1349  
1350      # ── Project Tools (agent-created) ────────────────────────────────────
1351  
1352      def get_project_tools(self, project_id: int) -> list[ProjectToolDatabase]:
1353          return (
1354              self.db.query(ProjectToolDatabase)
1355              .filter(ProjectToolDatabase.project_id == project_id)
1356              .order_by(ProjectToolDatabase.name)
1357              .all()
1358          )
1359  
1360      def get_project_tool_by_name(self, project_id: int, name: str) -> Optional[ProjectToolDatabase]:
1361          return (
1362              self.db.query(ProjectToolDatabase)
1363              .filter(ProjectToolDatabase.project_id == project_id, ProjectToolDatabase.name == name)
1364              .first()
1365          )
1366  
1367      def upsert_project_tool(self, project_id: int, name: str, description: str, parameters: str, code: str) -> ProjectToolDatabase:
1368          from datetime import datetime, timezone
1369          now = datetime.now(timezone.utc)
1370          existing = self.get_project_tool_by_name(project_id, name)
1371          if existing:
1372              existing.description = description
1373              existing.parameters = parameters
1374              existing.code = code
1375              existing.updated_at = now
1376              self.db.commit()
1377              return existing
1378          tool = ProjectToolDatabase(
1379              project_id=project_id,
1380              name=name,
1381              description=description,
1382              parameters=parameters,
1383              code=code,
1384              created_at=now,
1385              updated_at=now,
1386          )
1387          self.db.add(tool)
1388          self.db.commit()
1389          return tool
1390  
1391      def delete_project_tool(self, project_id: int, name: str) -> bool:
1392          tool = self.get_project_tool_by_name(project_id, name)
1393          if tool:
1394              self.db.delete(tool)
1395              self.db.commit()
1396              return True
1397          return False
1398  
1399      # ── Project Routines (scheduled messages) ────────────────────────────
1400  
1401      def get_project_routines(self, project_id: int) -> list[ProjectRoutineDatabase]:
1402          return (
1403              self.db.query(ProjectRoutineDatabase)
1404              .filter(ProjectRoutineDatabase.project_id == project_id)
1405              .order_by(ProjectRoutineDatabase.name)
1406              .all()
1407          )
1408  
1409      def get_all_enabled_routines(self) -> list[ProjectRoutineDatabase]:
1410          return (
1411              self.db.query(ProjectRoutineDatabase)
1412              .filter(ProjectRoutineDatabase.enabled == True)
1413              .all()
1414          )
1415  
1416      def get_project_routine_by_id(self, routine_id: int) -> Optional[ProjectRoutineDatabase]:
1417          return self.db.query(ProjectRoutineDatabase).filter(ProjectRoutineDatabase.id == routine_id).first()
1418  
1419      def create_project_routine(self, project_id: int, name: str, message: str, schedule_minutes: int, enabled: bool = True) -> ProjectRoutineDatabase:
1420          from datetime import datetime, timezone
1421          now = datetime.now(timezone.utc)
1422          routine = ProjectRoutineDatabase(
1423              project_id=project_id,
1424              name=name,
1425              message=message,
1426              schedule_minutes=schedule_minutes,
1427              enabled=enabled,
1428              created_at=now,
1429              updated_at=now,
1430          )
1431          self.db.add(routine)
1432          self.db.commit()
1433          self.db.refresh(routine)
1434          return routine
1435  
1436      def delete_project_routine(self, routine_id: int) -> bool:
1437          routine = self.get_project_routine_by_id(routine_id)
1438          if routine:
1439              self.db.delete(routine)
1440              self.db.commit()
1441              return True
1442          return False
1443  
1444      # ── Cron Logs ────────────────────────────────────────────────────────
1445  
1446      def create_cron_log(self, job, status, message, details=None, items_processed=0, duration_ms=None):
1447          from datetime import datetime, timezone
1448          entry = CronLogDatabase(
1449              job=job,
1450              status=status,
1451              message=message,
1452              details=details,
1453              items_processed=items_processed,
1454              duration_ms=duration_ms,
1455              date=datetime.now(timezone.utc),
1456          )
1457          self.db.add(entry)
1458          self.db.commit()
1459          return entry
1460  
1461      def get_cron_logs(self, job=None, status=None, start=0, end=50):
1462          query = self.db.query(CronLogDatabase).order_by(CronLogDatabase.date.desc())
1463          if job:
1464              query = query.filter(CronLogDatabase.job == job)
1465          if status:
1466              query = query.filter(CronLogDatabase.status == status)
1467          return query.offset(start).limit(end - start).all()
1468  
1469  
def get_db_wrapper() -> DBWrapper:
    """Build a fresh DBWrapper and return it.

    NOTE(review): `close()` is called on the wrapper's session before the
    caller ever receives the wrapper, so nothing here closes the session
    after the caller is done with it. Presumably this relies on a closed
    SQLAlchemy session being usable again — confirm this is intended, as
    opposed to a generator-style dependency that closes after use.
    """
    db_wrapper: DBWrapper = DBWrapper()
    db_wrapper.db.close()
    return db_wrapper