database.py
import logging
from sqlalchemy import create_engine, func, or_
from restai import config
from datetime import datetime, timezone
from restai.models.databasemodels import (
    ApiKeyDatabase,
    LLMDatabase,
    EmbeddingDatabase,
    OutputDatabase,
    ProjectDatabase,
    ProjectToolDatabase,
    ProjectRoutineDatabase,
    CronLogDatabase,
    SettingDatabase,
    UserDatabase,
    TeamDatabase,
    TeamImageGeneratorDatabase,
    TeamAudioGeneratorDatabase,
    WidgetDatabase,
    ImageGeneratorDatabase,
    SpeechToTextDatabase,
    ProjectSecretDatabase,
)
from restai.models.models import (
    LLMModel,
    LLMUpdate,
    ProjectModelUpdate,
    User,
    UserUpdate,
    EmbeddingModel,
    EmbeddingUpdate,
    TeamModel,
    TeamModelUpdate,
    TeamModelCreate,
)
from sqlalchemy.orm import sessionmaker, Session
import bcrypt
from typing import Optional, List
from restai.config import MYSQL_HOST, MYSQL_URL, POSTGRES_HOST, POSTGRES_URL
import json
from restai.utils.crypto import decrypt_api_key, hash_api_key, verify_api_key_hash

import logging as _logging
_db_logger = _logging.getLogger(__name__)

# Connection-pool settings shared by every backend. Defined once so the
# three create_engine() calls below cannot drift apart (each value must go
# to its own keyword — e.g. max_overflow must come from DB_MAX_OVERFLOW).
_POOL_KWARGS = {
    "pool_size": config.DB_POOL_SIZE,
    "max_overflow": config.DB_MAX_OVERFLOW,
    "pool_recycle": config.DB_POOL_RECYCLE,
}

if MYSQL_HOST:
    _db_logger.info("Using MySQL database.")
    engine = create_engine(MYSQL_URL, **_POOL_KWARGS)
elif POSTGRES_HOST:
    _db_logger.info("Using PostgreSQL database.")
    engine = create_engine(POSTGRES_URL, **_POOL_KWARGS)
else:
    _db_logger.info("Using sqlite database.")
    engine = create_engine(
        "sqlite:///./restai.db",
        # Allow pooled sqlite3 connections to be used by threads other than
        # the one that opened them.
        connect_args={"check_same_thread": False},
        **_POOL_KWARGS,
    )

# Session factory; every DBWrapper opens its own session from it.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def hash_password(password: str) -> str:
    """Return a bcrypt hash of ``password`` (fresh random salt) as UTF-8 text."""
    return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')


def verify_password(password: str, hashed: Optional[str]) -> bool:
    """Check ``password`` against a bcrypt hash produced by :func:`hash_password`.

    Returns False when ``hashed`` is empty/None (e.g. a user created without
    a password) instead of raising on ``None.encode``.
    """
    if not hashed:
        return False
    return bcrypt.checkpw(password.encode('utf-8'), hashed.encode('utf-8'))


class DBWrapper:
    """Query/mutation helpers bound to one SQLAlchemy session (``self.db``).

    The session is opened from ``SessionLocal`` on construction. The mutating
    helpers in this class commit as soon as they have applied their change.
    """

    __slots__ = ("db",)

    def __init__(self):
        self.db: Session = SessionLocal()

    def create_user(
        self,
        username: str,
        password: Optional[str],
        admin: bool = False,
        private: bool = False,
        restricted: bool = False,
    ) -> UserDatabase:
        """Create, commit and return a new user.

        A truthy ``password`` is bcrypt-hashed and stamps
        ``password_updated_at``; otherwise the user is stored without a
        password hash. New users start with options ``{"credit": -1.0}``.
        """
        password_hash: Optional[str]
        if password:
            password_hash = hash_password(password)
            password_updated_at = datetime.now(timezone.utc)
        else:
            password_hash = None
            password_updated_at = None
        db_user: UserDatabase = UserDatabase(
            username=username,
            hashed_password=password_hash,
            password_updated_at=password_updated_at,
            is_admin=admin,
            is_private=private,
            is_restricted=restricted,
            options='{"credit": -1.0}',
        )
        self.db.add(db_user)
        self.db.commit()
        self.db.refresh(db_user)
        return db_user

    def create_llm(
        self,
        name: str,
        class_name: str,
        options: str,
        privacy: str,
        description: str,
        context_window: int = 4096,
        input_cost: float = 0.0,
        output_cost: float = 0.0,
    ) -> LLMDatabase:
        """Create, commit and return an LLM row.

        Sensitive keys in ``options`` (``LLM_SENSITIVE_KEYS``) are encrypted
        and the result is stored as a JSON string. If that fails, a warning
        is logged and ``options`` is stored exactly as passed in.
        """
        from restai.utils.crypto import encrypt_sensitive_options, LLM_SENSITIVE_KEYS
        try:
            opts_dict = json.loads(options) if isinstance(options, str) else options
            opts_dict = encrypt_sensitive_options(opts_dict, LLM_SENSITIVE_KEYS)
            options = json.dumps(opts_dict)
        except Exception as e:
            logging.warning("Failed to encrypt LLM options: %s", e)

        db_llm: LLMDatabase = LLMDatabase(
            name=name,
            class_name=class_name,
            options=options,
            privacy=privacy,
            description=description,
            context_window=context_window,
            input_cost=input_cost,
            output_cost=output_cost,
        )
        self.db.add(db_llm)
        self.db.commit()
        self.db.refresh(db_llm)
        return db_llm

    def create_embedding(
        self,
        name: str,
        class_name: str,
        options: str,
        privacy: str,
        description: str,
        dimension: int,
    ) -> EmbeddingDatabase:
        """Create, commit and return an embedding-model row (options stored as given)."""
        db_embedding: EmbeddingDatabase = EmbeddingDatabase(
            name=name,
            class_name=class_name,
            options=options,
            privacy=privacy,
            description=description,
            dimension=dimension,
        )
        self.db.add(db_embedding)
        self.db.commit()
        self.db.refresh(db_embedding)
        return db_embedding

    def get_users(self) -> list[UserDatabase]:
        """Return every user."""
        return self.db.query(UserDatabase).all()

    def get_llms(self) -> list[LLMDatabase]:
        """Return every LLM row."""
        return self.db.query(LLMDatabase).all()

    def get_embeddings(self) -> list[EmbeddingDatabase]:
        """Return every embedding-model row."""
        return self.db.query(EmbeddingDatabase).all()

    def get_llm_by_name(self, name: str) -> Optional[LLMDatabase]:
        """Return the first LLM named ``name``, or None."""
        return self.db.query(LLMDatabase).filter(LLMDatabase.name == name).first()

    def get_llm_by_id(self, id: int) -> Optional[LLMDatabase]:
        """Return the LLM with primary key ``id``, or None."""
        return self.db.query(LLMDatabase).filter(LLMDatabase.id == id).first()

    def get_embedding_by_name(self, name: str) -> Optional[EmbeddingDatabase]:
        """Return the first embedding model named ``name``, or None."""
        return (
            self.db.query(EmbeddingDatabase)
            .filter(EmbeddingDatabase.name == name)
            .first()
        )

    def get_embedding_by_id(self, id: int) -> Optional[EmbeddingDatabase]:
        """Return the embedding model with primary key ``id``, or None."""
        return self.db.query(EmbeddingDatabase).filter(EmbeddingDatabase.id == id).first()

    def get_user_by_apikey(self, apikey: str):
        """Returns (UserDatabase, ApiKeyDatabase) or (UserDatabase, None) for legacy, or (None, None).

        Keys are matched by their first 8 characters against ``key_prefix``
        and then verified against the salted ``key_hash``. If no such key
        matches, the legacy per-user ``api_key`` column is decrypted and
        compared for each user that has one. An empty/None key never matches.
        """
        if not apikey:
            return None, None

        # Lookup by key_prefix, then verify the salted hash
        prefix = apikey[:8]
        candidates = (
            self.db.query(ApiKeyDatabase)
            .filter(ApiKeyDatabase.key_prefix == prefix)
            .all()
        )
        for api_key_row in candidates:
            if verify_api_key_hash(apikey, api_key_row.key_hash):
                return api_key_row.user, api_key_row

        # Fallback: check legacy api_key column for migration period
        for user in self.db.query(UserDatabase).filter(UserDatabase.api_key.isnot(None)):
            try:
                if decrypt_api_key(user.api_key) == apikey:
                    return user, None
            except Exception:
                # An undecryptable legacy value simply doesn't match.
                continue
        return None, None

    def get_user_by_username(self, username: str) -> Optional[UserDatabase]:
        """Return the first user named ``username``, or None."""
        return (
            self.db.query(UserDatabase)
            .filter(UserDatabase.username == username)
            .first()
        )

    def create_api_key(
        self,
        user_id: int,
        encrypted_key: str,
        key_hash: str,
        key_prefix: str,
        description: str,
        allowed_projects: Optional[str] = None,
        read_only: bool = False,
    ) -> ApiKeyDatabase:
        """Persist an already-encrypted/hashed API key for ``user_id`` and return it."""
        api_key = ApiKeyDatabase(
            user_id=user_id,
            encrypted_key=encrypted_key,
            key_hash=key_hash,
            key_prefix=key_prefix,
            description=description,
            created_at=datetime.now(timezone.utc),
            allowed_projects=allowed_projects,
            read_only=read_only,
        )
        self.db.add(api_key)
        self.db.commit()
        self.db.refresh(api_key)
        return api_key

    def get_api_keys_for_user(self, user_id: int) -> list[ApiKeyDatabase]:
        """Return ``user_id``'s API keys, newest first."""
        return (
            self.db.query(ApiKeyDatabase)
            .filter(ApiKeyDatabase.user_id == user_id)
            .order_by(ApiKeyDatabase.created_at.desc())
            .all()
        )

    def delete_api_key(self, api_key_id: int, user_id: int) -> bool:
        """Delete key ``api_key_id`` only if it belongs to ``user_id``.

        Returns True when a key was deleted, False when no matching key exists.
        """
        api_key = (
            self.db.query(ApiKeyDatabase)
            .filter(ApiKeyDatabase.id == api_key_id, ApiKeyDatabase.user_id == user_id)
            .first()
        )
        if api_key is None:
            return False
        self.db.delete(api_key)
        self.db.commit()
        return True

    # ── Widget methods
───────────────────────────────────────────────── 272 273 def create_widget(self, project_id, creator_id, encrypted_key, key_hash, key_prefix, name, config_json, allowed_domains_json): 274 now = datetime.now(timezone.utc) 275 widget = WidgetDatabase( 276 project_id=project_id, 277 creator_id=creator_id, 278 encrypted_key=encrypted_key, 279 key_hash=key_hash, 280 key_prefix=key_prefix, 281 name=name, 282 config=config_json, 283 allowed_domains=allowed_domains_json, 284 enabled=True, 285 created_at=now, 286 updated_at=now, 287 ) 288 self.db.add(widget) 289 self.db.commit() 290 self.db.refresh(widget) 291 return widget 292 293 def get_widget_by_id(self, widget_id): 294 return self.db.query(WidgetDatabase).filter(WidgetDatabase.id == widget_id).first() 295 296 def get_widget_by_key_hash(self, key_hash): 297 return self.db.query(WidgetDatabase).filter(WidgetDatabase.key_hash == key_hash).first() 298 299 def get_widget_by_key(self, plaintext_key): 300 """Look up a widget by plaintext key using prefix-then-verify (salted hash).""" 301 prefix = plaintext_key[:11] 302 candidates = self.db.query(WidgetDatabase).filter(WidgetDatabase.key_prefix == prefix).all() 303 for w in candidates: 304 if verify_api_key_hash(plaintext_key, w.key_hash): 305 return w 306 return None 307 308 def get_widgets_for_project(self, project_id): 309 return ( 310 self.db.query(WidgetDatabase) 311 .filter(WidgetDatabase.project_id == project_id) 312 .order_by(WidgetDatabase.created_at.desc()) 313 .all() 314 ) 315 316 def delete_widget(self, widget): 317 self.db.delete(widget) 318 self.db.commit() 319 return True 320 321 def update_user(self, user: User, user_update: UserUpdate) -> bool: 322 if user_update.password is not None: 323 from datetime import datetime, timezone 324 user.hashed_password = hash_password(user_update.password) 325 user.password_updated_at = datetime.now(timezone.utc) 326 327 if user_update.is_admin is not None: 328 user.is_admin = user_update.is_admin 329 330 if 
user_update.is_private is not None: 331 user.is_private = user_update.is_private 332 333 if user_update.is_restricted is not None: 334 user.is_restricted = user_update.is_restricted 335 336 if hasattr(user_update, "options") and user_update.options is not None: 337 try: 338 current_options = json.loads(user.options) if user.options else {} 339 new_options = user_update.options.model_dump() 340 if current_options != new_options: 341 user.options = json.dumps(new_options) 342 except json.JSONDecodeError: 343 user.options = json.dumps(user_update.options.model_dump()) 344 345 self.db.commit() 346 return True 347 348 def update_llm(self, llm: LLMModel, llmUpdate: LLMUpdate) -> bool: 349 if llmUpdate.class_name is not None and llm.class_name != llmUpdate.class_name: 350 llm.class_name = llmUpdate.class_name 351 352 if llmUpdate.options is not None and llm.options != llmUpdate.options: 353 # Encrypt sensitive fields (api_key) before persisting 354 from restai.utils.crypto import encrypt_sensitive_options, LLM_SENSITIVE_KEYS 355 import json as _json 356 try: 357 opts_dict = _json.loads(llmUpdate.options) if isinstance(llmUpdate.options, str) else llmUpdate.options 358 # If api_key is the masked value, preserve the existing one 359 if opts_dict.get("api_key") == "********": 360 existing = _json.loads(llm.options) if isinstance(llm.options, str) else (llm.options or {}) 361 if "api_key" in existing: 362 opts_dict["api_key"] = existing["api_key"] 363 else: 364 del opts_dict["api_key"] 365 opts_dict = encrypt_sensitive_options(opts_dict, LLM_SENSITIVE_KEYS) 366 llm.options = _json.dumps(opts_dict) if isinstance(llmUpdate.options, str) else opts_dict 367 except Exception as e: 368 logging.warning("Failed to encrypt LLM options on update: %s", e) 369 llm.options = llmUpdate.options 370 371 if llmUpdate.privacy is not None and llm.privacy != llmUpdate.privacy: 372 llm.privacy = llmUpdate.privacy 373 374 if ( 375 llmUpdate.description is not None 376 and llm.description != 
llmUpdate.description 377 ): 378 llm.description = llmUpdate.description 379 380 if llmUpdate.input_cost is not None and llm.input_cost != llmUpdate.input_cost: 381 llm.input_cost = llmUpdate.input_cost 382 383 if ( 384 llmUpdate.output_cost is not None 385 and llm.output_cost != llmUpdate.output_cost 386 ): 387 llm.output_cost = llmUpdate.output_cost 388 389 if ( 390 llmUpdate.context_window is not None 391 and llm.context_window != llmUpdate.context_window 392 ): 393 llm.context_window = llmUpdate.context_window 394 395 self.db.commit() 396 return True 397 398 def update_embedding( 399 self, embedding: EmbeddingModel, embeddingUpdate: EmbeddingUpdate 400 ) -> bool: 401 if ( 402 embeddingUpdate.class_name is not None 403 and embedding.class_name != embeddingUpdate.class_name 404 ): 405 embedding.class_name = embeddingUpdate.class_name 406 407 if ( 408 embeddingUpdate.options is not None 409 and embedding.options != embeddingUpdate.options 410 ): 411 # If api_key is the masked value, preserve the existing one 412 import json as _json 413 try: 414 new_opts = _json.loads(embeddingUpdate.options) if isinstance(embeddingUpdate.options, str) else (embeddingUpdate.options or {}) 415 if new_opts.get("api_key") == "********": 416 existing = _json.loads(embedding.options) if isinstance(embedding.options, str) else (embedding.options or {}) 417 if "api_key" in existing: 418 new_opts["api_key"] = existing["api_key"] 419 else: 420 del new_opts["api_key"] 421 embeddingUpdate.options = _json.dumps(new_opts) 422 except Exception: 423 pass 424 embedding.options = embeddingUpdate.options 425 426 if ( 427 embeddingUpdate.privacy is not None 428 and embedding.privacy != embeddingUpdate.privacy 429 ): 430 embedding.privacy = embeddingUpdate.privacy 431 432 if ( 433 embeddingUpdate.description is not None 434 and embedding.description != embeddingUpdate.description 435 ): 436 embedding.description = embeddingUpdate.description 437 438 if ( 439 embeddingUpdate.dimension is not None 440 
and embedding.dimension != embeddingUpdate.dimension 441 ): 442 embedding.dimension = embeddingUpdate.dimension 443 444 self.db.commit() 445 return True 446 447 def delete_llm(self, llm: LLMDatabase) -> bool: 448 self.db.delete(llm) 449 self.db.commit() 450 return True 451 452 def delete_embedding(self, embedding: EmbeddingDatabase) -> bool: 453 self.db.delete(embedding) 454 self.db.commit() 455 return True 456 457 # ─── Image generators ───────────────────────────────────────────── 458 459 def get_image_generators(self) -> list[ImageGeneratorDatabase]: 460 return self.db.query(ImageGeneratorDatabase).order_by(ImageGeneratorDatabase.name).all() 461 462 def get_image_generator_by_id(self, gen_id: int) -> Optional[ImageGeneratorDatabase]: 463 return self.db.query(ImageGeneratorDatabase).filter(ImageGeneratorDatabase.id == gen_id).first() 464 465 def get_image_generator_by_name(self, name: str) -> Optional[ImageGeneratorDatabase]: 466 return self.db.query(ImageGeneratorDatabase).filter(ImageGeneratorDatabase.name == name).first() 467 468 def create_image_generator( 469 self, 470 name: str, 471 class_name: str, 472 options, 473 privacy: str = "public", 474 description: Optional[str] = None, 475 enabled: bool = True, 476 ) -> ImageGeneratorDatabase: 477 from restai.utils.crypto import encrypt_sensitive_options, LLM_SENSITIVE_KEYS 478 import json as _json 479 480 try: 481 opts_dict = _json.loads(options) if isinstance(options, str) else (options or {}) 482 opts_dict = encrypt_sensitive_options(opts_dict, LLM_SENSITIVE_KEYS) 483 options_str = _json.dumps(opts_dict) 484 except Exception as e: 485 logging.warning("Failed to encrypt image generator options: %s", e) 486 options_str = options if isinstance(options, str) else _json.dumps(options or {}) 487 488 now = datetime.now(timezone.utc) 489 row = ImageGeneratorDatabase( 490 name=name, 491 class_name=class_name, 492 options=options_str, 493 privacy=privacy, 494 description=description, 495 enabled=enabled, 496 
created_at=now, 497 updated_at=now, 498 ) 499 self.db.add(row) 500 self.db.commit() 501 self.db.refresh(row) 502 return row 503 504 def edit_image_generator(self, gen: ImageGeneratorDatabase, update) -> bool: 505 """Patch an image generator. `update` is an ImageGeneratorModelUpdate. 506 api_key in `options` is preserved when the submitted value is the 507 masked sentinel `"********"` (matches the LLM edit pattern).""" 508 from restai.utils.crypto import encrypt_sensitive_options, LLM_SENSITIVE_KEYS 509 import json as _json 510 511 changed = False 512 if update.class_name is not None and gen.class_name != update.class_name: 513 gen.class_name = update.class_name 514 changed = True 515 516 if update.options is not None: 517 try: 518 new_opts = _json.loads(update.options) if isinstance(update.options, str) else (update.options or {}) 519 # Preserve masked sensitive fields 520 existing = _json.loads(gen.options) if gen.options else {} 521 # The existing values in DB are already encrypted; decrypt 522 # them so the comparison is plaintext-vs-plaintext. 
523 from restai.utils.crypto import decrypt_sensitive_options 524 existing_plain = decrypt_sensitive_options(dict(existing), LLM_SENSITIVE_KEYS) 525 for k in LLM_SENSITIVE_KEYS: 526 val = new_opts.get(k) 527 if isinstance(val, str) and val == "********": 528 if k in existing_plain: 529 new_opts[k] = existing_plain[k] 530 else: 531 new_opts.pop(k, None) 532 new_opts_enc = encrypt_sensitive_options(new_opts, LLM_SENSITIVE_KEYS) 533 gen.options = _json.dumps(new_opts_enc) 534 changed = True 535 except Exception as e: 536 logging.warning("Failed to update image generator options: %s", e) 537 538 if update.privacy is not None and gen.privacy != update.privacy: 539 gen.privacy = update.privacy 540 changed = True 541 542 if update.description is not None and gen.description != update.description: 543 gen.description = update.description 544 changed = True 545 546 if update.enabled is not None and gen.enabled != update.enabled: 547 gen.enabled = update.enabled 548 changed = True 549 550 if changed: 551 gen.updated_at = datetime.now(timezone.utc) 552 self.db.commit() 553 return changed 554 555 def delete_image_generator(self, gen: ImageGeneratorDatabase) -> bool: 556 # Also drop any team grants pointing at this name so we don't leave 557 # dangling rows in the legacy string-keyed teams_image_generators. 
558 try: 559 self.db.query(TeamImageGeneratorDatabase).filter( 560 TeamImageGeneratorDatabase.generator_name == gen.name 561 ).delete(synchronize_session=False) 562 except Exception: 563 pass 564 self.db.delete(gen) 565 self.db.commit() 566 return True 567 568 # ─── Speech-to-text models ──────────────────────────────────────── 569 570 def get_speech_to_text(self) -> list[SpeechToTextDatabase]: 571 return self.db.query(SpeechToTextDatabase).order_by(SpeechToTextDatabase.name).all() 572 573 def get_speech_to_text_by_id(self, model_id: int) -> Optional[SpeechToTextDatabase]: 574 return self.db.query(SpeechToTextDatabase).filter(SpeechToTextDatabase.id == model_id).first() 575 576 def get_speech_to_text_by_name(self, name: str) -> Optional[SpeechToTextDatabase]: 577 return self.db.query(SpeechToTextDatabase).filter(SpeechToTextDatabase.name == name).first() 578 579 def create_speech_to_text( 580 self, 581 name: str, 582 class_name: str, 583 options, 584 privacy: str = "public", 585 description: Optional[str] = None, 586 enabled: bool = True, 587 ) -> SpeechToTextDatabase: 588 from restai.utils.crypto import encrypt_sensitive_options, LLM_SENSITIVE_KEYS 589 import json as _json 590 591 try: 592 opts_dict = _json.loads(options) if isinstance(options, str) else (options or {}) 593 opts_dict = encrypt_sensitive_options(opts_dict, LLM_SENSITIVE_KEYS) 594 options_str = _json.dumps(opts_dict) 595 except Exception as e: 596 logging.warning("Failed to encrypt speech-to-text options: %s", e) 597 options_str = options if isinstance(options, str) else _json.dumps(options or {}) 598 599 now = datetime.now(timezone.utc) 600 row = SpeechToTextDatabase( 601 name=name, 602 class_name=class_name, 603 options=options_str, 604 privacy=privacy, 605 description=description, 606 enabled=enabled, 607 created_at=now, 608 updated_at=now, 609 ) 610 self.db.add(row) 611 self.db.commit() 612 self.db.refresh(row) 613 return row 614 615 def edit_speech_to_text(self, model: SpeechToTextDatabase, 
update) -> bool: 616 """Patch a speech-to-text row. `"********"` in `options.api_key` 617 preserves the existing value (matches the LLM/image-gen pattern).""" 618 from restai.utils.crypto import ( 619 decrypt_sensitive_options, 620 encrypt_sensitive_options, 621 LLM_SENSITIVE_KEYS, 622 ) 623 import json as _json 624 625 changed = False 626 if update.class_name is not None and model.class_name != update.class_name: 627 model.class_name = update.class_name 628 changed = True 629 630 if update.options is not None: 631 try: 632 new_opts = _json.loads(update.options) if isinstance(update.options, str) else (update.options or {}) 633 existing = _json.loads(model.options) if model.options else {} 634 existing_plain = decrypt_sensitive_options(dict(existing), LLM_SENSITIVE_KEYS) 635 for k in LLM_SENSITIVE_KEYS: 636 val = new_opts.get(k) 637 if isinstance(val, str) and val == "********": 638 if k in existing_plain: 639 new_opts[k] = existing_plain[k] 640 else: 641 new_opts.pop(k, None) 642 new_opts_enc = encrypt_sensitive_options(new_opts, LLM_SENSITIVE_KEYS) 643 model.options = _json.dumps(new_opts_enc) 644 changed = True 645 except Exception as e: 646 logging.warning("Failed to update speech-to-text options: %s", e) 647 648 if update.privacy is not None and model.privacy != update.privacy: 649 model.privacy = update.privacy 650 changed = True 651 if update.description is not None and model.description != update.description: 652 model.description = update.description 653 changed = True 654 if update.enabled is not None and model.enabled != update.enabled: 655 model.enabled = update.enabled 656 changed = True 657 658 if changed: 659 model.updated_at = datetime.now(timezone.utc) 660 self.db.commit() 661 return changed 662 663 def delete_speech_to_text(self, model: SpeechToTextDatabase) -> bool: 664 # Drop team grants pointing at this name (string-keyed table). 
665 try: 666 self.db.query(TeamAudioGeneratorDatabase).filter( 667 TeamAudioGeneratorDatabase.generator_name == model.name 668 ).delete(synchronize_session=False) 669 except Exception: 670 pass 671 self.db.delete(model) 672 self.db.commit() 673 return True 674 675 # ─── Project secrets (agentic browser vault) ───────────────────── 676 677 def get_project_secrets(self, project_id: int) -> list[ProjectSecretDatabase]: 678 return ( 679 self.db.query(ProjectSecretDatabase) 680 .filter(ProjectSecretDatabase.project_id == project_id) 681 .order_by(ProjectSecretDatabase.name) 682 .all() 683 ) 684 685 def get_project_secret_by_id(self, secret_id: int) -> Optional[ProjectSecretDatabase]: 686 return self.db.query(ProjectSecretDatabase).filter(ProjectSecretDatabase.id == secret_id).first() 687 688 def get_project_secret_by_name(self, project_id: int, name: str) -> Optional[ProjectSecretDatabase]: 689 return ( 690 self.db.query(ProjectSecretDatabase) 691 .filter(ProjectSecretDatabase.project_id == project_id, ProjectSecretDatabase.name == name) 692 .first() 693 ) 694 695 def create_project_secret( 696 self, 697 project_id: int, 698 name: str, 699 value: str, 700 description: Optional[str] = None, 701 ) -> ProjectSecretDatabase: 702 from restai.utils.crypto import encrypt_field 703 now = datetime.now(timezone.utc) 704 row = ProjectSecretDatabase( 705 project_id=project_id, 706 name=name, 707 value=encrypt_field(value), 708 description=description, 709 created_at=now, 710 updated_at=now, 711 ) 712 self.db.add(row) 713 self.db.commit() 714 self.db.refresh(row) 715 return row 716 717 def edit_project_secret(self, secret: ProjectSecretDatabase, update) -> bool: 718 """Patch a project secret. 
Value `"********"` preserves the 719 existing stored value (same mask-round-trip as LLMs).""" 720 from restai.utils.crypto import encrypt_field 721 722 changed = False 723 if update.value is not None and update.value != "********": 724 secret.value = encrypt_field(update.value) 725 changed = True 726 if update.description is not None and secret.description != update.description: 727 secret.description = update.description 728 changed = True 729 if changed: 730 secret.updated_at = datetime.now(timezone.utc) 731 self.db.commit() 732 return changed 733 734 def delete_project_secret(self, secret: ProjectSecretDatabase) -> bool: 735 self.db.delete(secret) 736 self.db.commit() 737 return True 738 739 def resolve_project_secret(self, project_id: int, name: str) -> Optional[str]: 740 """Server-side plaintext resolution — only called from inside a tool 741 (e.g. `browser_fill`). Returns None when the secret doesn't exist. 742 The plaintext never crosses back into LLM context because callers 743 pass it directly to the micro-server, not back to the agent.""" 744 from restai.utils.crypto import decrypt_field 745 row = self.get_project_secret_by_name(project_id, name) 746 if row is None or not row.value: 747 return None 748 try: 749 return decrypt_field(row.value) 750 except Exception: 751 return None 752 753 # ───────────────────────────────────────────────────────────────────── 754 755 def get_user_by_id(self, user_id: int) -> Optional[UserDatabase]: 756 user: Optional[UserDatabase] = ( 757 self.db.query(UserDatabase).filter(UserDatabase.id == user_id).first() 758 ) 759 return user 760 761 def delete_user(self, user: UserDatabase) -> bool: 762 self.db.delete(user) 763 self.db.commit() 764 return True 765 766 def get_project_by_name(self, name: str) -> Optional[ProjectDatabase]: 767 project: Optional[ProjectDatabase] = ( 768 self.db.query(ProjectDatabase).filter(ProjectDatabase.name == name).first() 769 ) 770 return project 771 772 def get_project_by_id(self, id: int) -> 
Optional[ProjectDatabase]: 773 project: Optional[ProjectDatabase] = ( 774 self.db.query(ProjectDatabase).filter(ProjectDatabase.id == id).first() 775 ) 776 return project 777 778 def create_project( 779 self, 780 name: str, 781 embeddings: str, 782 llm: str, 783 vectorstore: str, 784 human_name: str, 785 human_description: str, 786 project_type: str, 787 creator: int, 788 team_id: int, 789 ) -> Optional[ProjectDatabase]: 790 # Validate that the team exists 791 team = self.get_team_by_id(team_id) 792 if team is None: 793 return None 794 795 # Validate that the team has access to the specified LLM (block projects don't need one) 796 if llm: 797 llm_db = self.get_llm_by_name(llm) 798 if llm_db is None or llm_db not in team.llms: 799 return None 800 801 # If embeddings are specified, validate that the team has access to it 802 if embeddings: 803 embedding_db = self.get_embedding_by_name(embeddings) 804 if embedding_db is None or embedding_db not in team.embeddings: 805 return None 806 807 # Look up the creator so we can associate them in the same transaction 808 creator_user = self.get_user_by_id(creator) if creator else None 809 810 db_project: ProjectDatabase = ProjectDatabase( 811 name=name, 812 embeddings=embeddings, 813 llm=llm, 814 vectorstore=vectorstore, 815 human_name=human_name, 816 human_description=human_description, 817 type=project_type, 818 creator=creator, 819 options='{"logging": true}', 820 ) 821 self.db.add(db_project) 822 823 # Associate with team and creator in ONE transaction so we never 824 # end up with a project row that has no users. 
825 if db_project not in team.projects: 826 team.projects.append(db_project) 827 if creator_user and db_project not in creator_user.projects: 828 creator_user.projects.append(db_project) 829 830 self.db.commit() 831 self.db.refresh(db_project) 832 833 return db_project 834 835 def delete_project(self, project: ProjectDatabase) -> bool: 836 self.db.delete(project) 837 self.db.commit() 838 return True 839 840 def edit_project(self, id: int, projectModel: ProjectModelUpdate) -> bool: 841 proj_db: Optional[ProjectDatabase] = self.get_project_by_id(id) 842 if proj_db is None: 843 return False 844 845 changed = False 846 847 # Get all teams that have this project to validate LLM/embedding access 848 teams_with_project = [team for team in self.get_teams() if proj_db in team.projects] 849 if not teams_with_project: 850 return False # Project should belong to at least one team 851 852 if projectModel.users is not None: 853 new_users = [] 854 rejected = [] 855 for username in projectModel.users: 856 user_db = self.get_user_by_username(username) 857 if user_db is None: 858 rejected.append(f"{username} (not found)") 859 continue 860 # Platform admins bypass the team membership check 861 if user_db.is_admin: 862 new_users.append(user_db) 863 continue 864 # Otherwise the user must belong to one of the project's teams 865 in_team = any( 866 user_db in team.users or user_db in team.admins 867 for team in teams_with_project 868 ) 869 if in_team: 870 new_users.append(user_db) 871 else: 872 rejected.append(f"{username} (not in project's team)") 873 874 if rejected: 875 from fastapi import HTTPException 876 raise HTTPException( 877 status_code=400, 878 detail=f"Cannot assign users: {', '.join(rejected)}", 879 ) 880 881 proj_db.users = new_users 882 changed = True 883 884 if projectModel.name is not None and proj_db.name != projectModel.name: 885 if ( 886 self.db.query(ProjectDatabase) 887 .filter( 888 ProjectDatabase.creator == proj_db.creator, 889 ProjectDatabase.name == 
projectModel.name, 890 ProjectDatabase.id != proj_db.id, 891 ) 892 .first() 893 is not None 894 ): 895 return False 896 proj_db.name = projectModel.name 897 changed = True 898 899 if projectModel.llm is not None and proj_db.llm != projectModel.llm: 900 # Validate that at least one team has access to this LLM 901 llm_db = self.get_llm_by_name(projectModel.llm) 902 if llm_db is None: 903 return False 904 905 llm_access = False 906 for team in teams_with_project: 907 if llm_db in team.llms: 908 llm_access = True 909 break 910 911 if not llm_access: 912 return False # No team has access to this LLM 913 914 proj_db.llm = projectModel.llm 915 changed = True 916 917 if projectModel.embeddings is not None and proj_db.embeddings != projectModel.embeddings: 918 # Validate that at least one team has access to this embedding model 919 if projectModel.embeddings: # Only check if embeddings is not empty 920 embedding_db = self.get_embedding_by_name(projectModel.embeddings) 921 if embedding_db is None: 922 return False 923 924 embedding_access = False 925 for team in teams_with_project: 926 if embedding_db in team.embeddings: 927 embedding_access = True 928 break 929 930 if not embedding_access: 931 return False # No team has access to this embedding model 932 933 proj_db.embeddings = projectModel.embeddings 934 changed = True 935 936 if projectModel.system is not None and proj_db.system != projectModel.system: 937 proj_db.system = projectModel.system 938 changed = True 939 # Auto-create prompt version 940 self._create_prompt_version(proj_db.id, projectModel.system, user_id=getattr(projectModel, '_user_id', None)) 941 942 if ( 943 projectModel.censorship is not None 944 and proj_db.censorship != projectModel.censorship 945 ): 946 proj_db.censorship = projectModel.censorship 947 changed = True 948 949 if projectModel.guard is not None and proj_db.guard != projectModel.guard: 950 proj_db.guard = projectModel.guard 951 changed = True 952 953 if ( 954 projectModel.human_name is not 
None 955 and proj_db.human_name != projectModel.human_name 956 ): 957 proj_db.human_name = projectModel.human_name 958 changed = True 959 960 if ( 961 projectModel.human_description is not None 962 and proj_db.human_description != projectModel.human_description 963 ): 964 proj_db.human_description = projectModel.human_description 965 changed = True 966 967 if projectModel.public is not None and proj_db.public != projectModel.public: 968 proj_db.public = projectModel.public 969 changed = True 970 971 if ( 972 projectModel.default_prompt is not None 973 and proj_db.default_prompt != projectModel.default_prompt 974 ): 975 proj_db.default_prompt = projectModel.default_prompt 976 changed = True 977 978 if hasattr(projectModel, "options") and projectModel.options is not None: 979 from restai.utils.crypto import encrypt_sensitive_options, PROJECT_SENSITIVE_KEYS 980 # Keys that aren't surfaced in the project-edit form — they are 981 # owned by dedicated endpoints (e.g. the Mobile pairing flow) and 982 # must not be wiped when the edit form POSTs a ProjectOptions dump. 
983 PRESERVED_KEYS = ("mobile_enabled", "mobile_api_key_id") 984 try: 985 current_options = json.loads(proj_db.options) if proj_db.options else {} 986 except json.JSONDecodeError: 987 current_options = {} 988 new_options = projectModel.options.model_dump() 989 new_options = encrypt_sensitive_options(new_options, PROJECT_SENSITIVE_KEYS) 990 for k in PRESERVED_KEYS: 991 if k in current_options: 992 new_options[k] = current_options[k] 993 if current_options != new_options: 994 proj_db.options = json.dumps(new_options) 995 changed = True 996 997 if changed: 998 self.db.commit() 999 return True 1000 1001 def create_team(self, team_create: TeamModelCreate) -> TeamDatabase: 1002 db_team: TeamDatabase = TeamDatabase( 1003 name=team_create.name, 1004 description=team_create.description, 1005 creator_id=team_create.creator_id, 1006 budget=team_create.budget, 1007 created_at=datetime.now(timezone.utc), 1008 updated_at=datetime.now(timezone.utc), 1009 ) 1010 self.db.add(db_team) 1011 self.db.commit() 1012 self.db.refresh(db_team) 1013 return db_team 1014 1015 def get_team_by_id(self, team_id: int) -> Optional[TeamDatabase]: 1016 team: Optional[TeamDatabase] = ( 1017 self.db.query(TeamDatabase).filter(TeamDatabase.id == team_id).first() 1018 ) 1019 return team 1020 1021 def get_team_by_name(self, name: str) -> Optional[TeamDatabase]: 1022 team: Optional[TeamDatabase] = ( 1023 self.db.query(TeamDatabase).filter(TeamDatabase.name == name).first() 1024 ) 1025 return team 1026 1027 def get_teams(self) -> List[TeamDatabase]: 1028 teams: List[TeamDatabase] = self.db.query(TeamDatabase).all() 1029 return teams 1030 1031 def update_team(self, team: TeamDatabase, team_update: TeamModelUpdate) -> bool: 1032 changed = False 1033 1034 if team_update.name is not None and team.name != team_update.name: 1035 team.name = team_update.name 1036 changed = True 1037 1038 if team_update.description is not None and team.description != team_update.description: 1039 team.description = 
team_update.description 1040 changed = True 1041 1042 if team_update.budget is not None and team.budget != team_update.budget: 1043 team.budget = team_update.budget 1044 changed = True 1045 1046 if team_update.branding is not None: 1047 import json 1048 team.branding = json.dumps(team_update.branding.model_dump()) 1049 changed = True 1050 1051 if changed: 1052 team.updated_at = datetime.now(timezone.utc) 1053 self.db.commit() 1054 return changed 1055 1056 def delete_team(self, team: TeamDatabase) -> bool: 1057 self.db.delete(team) 1058 self.db.commit() 1059 return True 1060 1061 def add_user_to_team(self, team: TeamDatabase, user: UserDatabase) -> bool: 1062 if user not in team.users: 1063 team.users.append(user) 1064 self.db.commit() 1065 return True 1066 1067 def remove_user_from_team(self, team: TeamDatabase, user: UserDatabase) -> bool: 1068 if user in team.users: 1069 team.users.remove(user) 1070 self.db.commit() 1071 return True 1072 1073 def add_admin_to_team(self, team: TeamDatabase, user: UserDatabase) -> bool: 1074 if user not in team.admins: 1075 team.admins.append(user) 1076 self.db.commit() 1077 return True 1078 1079 def remove_admin_from_team(self, team: TeamDatabase, user: UserDatabase) -> bool: 1080 if user in team.admins: 1081 team.admins.remove(user) 1082 self.db.commit() 1083 return True 1084 1085 def add_project_to_team(self, team: TeamDatabase, project: ProjectDatabase) -> bool: 1086 if project not in team.projects: 1087 team.projects.append(project) 1088 self.db.commit() 1089 return True 1090 1091 def remove_project_from_team(self, team: TeamDatabase, project: ProjectDatabase) -> bool: 1092 if project in team.projects: 1093 team.projects.remove(project) 1094 self.db.commit() 1095 return True 1096 1097 def add_llm_to_team(self, team: TeamDatabase, llm: LLMDatabase) -> bool: 1098 if llm not in team.llms: 1099 team.llms.append(llm) 1100 self.db.commit() 1101 return True 1102 1103 def remove_llm_from_team(self, team: TeamDatabase, llm: 
LLMDatabase) -> bool: 1104 if llm in team.llms: 1105 team.llms.remove(llm) 1106 self.db.commit() 1107 return True 1108 1109 def add_embedding_to_team(self, team: TeamDatabase, embedding: EmbeddingDatabase) -> bool: 1110 if embedding not in team.embeddings: 1111 team.embeddings.append(embedding) 1112 self.db.commit() 1113 return True 1114 1115 def remove_embedding_from_team(self, team: TeamDatabase, embedding: EmbeddingDatabase) -> bool: 1116 if embedding in team.embeddings: 1117 team.embeddings.remove(embedding) 1118 self.db.commit() 1119 return True 1120 1121 def get_teams_for_user(self, user_id: int) -> List[TeamDatabase]: 1122 """Get all teams where the user is a member or admin""" 1123 user = self.get_user_by_id(user_id) 1124 if user is None: 1125 return [] 1126 return list(set(user.teams + user.admin_teams)) 1127 1128 def get_settings(self) -> list[SettingDatabase]: 1129 from restai.utils.crypto import SETTINGS_ENCRYPTED_KEYS, decrypt_field 1130 rows = self.db.query(SettingDatabase).all() 1131 for r in rows: 1132 if r.key in SETTINGS_ENCRYPTED_KEYS and r.value: 1133 self.db.expunge(r) 1134 r.value = decrypt_field(r.value) 1135 return rows 1136 1137 def get_setting(self, key: str) -> Optional[SettingDatabase]: 1138 from restai.utils.crypto import SETTINGS_ENCRYPTED_KEYS, decrypt_field 1139 row = self.db.query(SettingDatabase).filter(SettingDatabase.key == key).first() 1140 if row and key in SETTINGS_ENCRYPTED_KEYS and row.value: 1141 self.db.expunge(row) 1142 row.value = decrypt_field(row.value) 1143 return row 1144 1145 def get_setting_value(self, key: str, default: str = "") -> str: 1146 """Get a setting value by key, returning default if not found or empty.""" 1147 row = self.get_setting(key) 1148 return row.value if row and row.value else default 1149 1150 def upsert_setting(self, key: str, value: str) -> None: 1151 from restai.utils.crypto import SETTINGS_ENCRYPTED_KEYS, encrypt_field 1152 stored_value = encrypt_field(value) if (key in 
SETTINGS_ENCRYPTED_KEYS and value) else value 1153 existing = self.db.query(SettingDatabase).filter(SettingDatabase.key == key).first() 1154 if existing: 1155 existing.value = stored_value 1156 else: 1157 self.db.add(SettingDatabase(key=key, value=stored_value)) 1158 try: 1159 self.db.commit() 1160 except Exception: 1161 self.db.rollback() 1162 1163 def get_team_spending(self, team_id: int) -> float: 1164 now = datetime.now(timezone.utc) 1165 month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) 1166 result = self.db.query( 1167 func.coalesce(func.sum(OutputDatabase.input_cost + OutputDatabase.output_cost), 0.0) 1168 ).filter( 1169 or_( 1170 OutputDatabase.project_id.in_( 1171 self.db.query(ProjectDatabase.id).filter(ProjectDatabase.team_id == team_id) 1172 ), 1173 OutputDatabase.team_id == team_id 1174 ), 1175 OutputDatabase.date >= month_start 1176 ).scalar() 1177 return float(result) 1178 1179 def update_team_members(self, team: TeamDatabase, team_update: TeamModelUpdate) -> bool: 1180 changed = False 1181 1182 # Update users 1183 if team_update.users is not None: 1184 team.users = [] 1185 for username in team_update.users: 1186 user_db = self.get_user_by_username(username) 1187 if user_db is not None: 1188 team.users.append(user_db) 1189 changed = True 1190 1191 # Update admins 1192 if team_update.admins is not None: 1193 team.admins = [] 1194 for username in team_update.admins: 1195 user_db = self.get_user_by_username(username) 1196 if user_db is not None: 1197 team.admins.append(user_db) 1198 changed = True 1199 1200 # Update projects 1201 if team_update.projects is not None: 1202 team.projects = [] 1203 for project_name in team_update.projects: 1204 project_db = self.get_project_by_name(project_name) 1205 if project_db is not None: 1206 team.projects.append(project_db) 1207 changed = True 1208 1209 # Update LLMs 1210 if team_update.llms is not None: 1211 team.llms = [] 1212 for llm_name in team_update.llms: 1213 llm_db = 
self.get_llm_by_name(llm_name) 1214 if llm_db is not None: 1215 team.llms.append(llm_db) 1216 changed = True 1217 1218 # Update embeddings 1219 if team_update.embeddings is not None: 1220 team.embeddings = [] 1221 for embedding_name in team_update.embeddings: 1222 embedding_db = self.get_embedding_by_name(embedding_name) 1223 if embedding_db is not None: 1224 team.embeddings.append(embedding_db) 1225 changed = True 1226 1227 # Update image generators 1228 if team_update.image_generators is not None: 1229 team.image_generators = [] 1230 self.db.flush() 1231 for gen_name in team_update.image_generators: 1232 team.image_generators.append( 1233 TeamImageGeneratorDatabase(team_id=team.id, generator_name=gen_name) 1234 ) 1235 changed = True 1236 1237 # Update audio generators 1238 if team_update.audio_generators is not None: 1239 team.audio_generators = [] 1240 self.db.flush() 1241 for gen_name in team_update.audio_generators: 1242 team.audio_generators.append( 1243 TeamAudioGeneratorDatabase(team_id=team.id, generator_name=gen_name) 1244 ) 1245 changed = True 1246 1247 if changed: 1248 team.updated_at = datetime.now(timezone.utc) 1249 self.db.commit() 1250 1251 return changed 1252 1253 1254 def add_image_generator_to_team(self, team: TeamDatabase, generator_name: str) -> bool: 1255 existing = self.db.query(TeamImageGeneratorDatabase).filter( 1256 TeamImageGeneratorDatabase.team_id == team.id, 1257 TeamImageGeneratorDatabase.generator_name == generator_name 1258 ).first() 1259 if existing is None: 1260 team.image_generators.append( 1261 TeamImageGeneratorDatabase(team_id=team.id, generator_name=generator_name) 1262 ) 1263 self.db.commit() 1264 return True 1265 1266 def remove_image_generator_from_team(self, team: TeamDatabase, generator_name: str) -> bool: 1267 item = self.db.query(TeamImageGeneratorDatabase).filter( 1268 TeamImageGeneratorDatabase.team_id == team.id, 1269 TeamImageGeneratorDatabase.generator_name == generator_name 1270 ).first() 1271 if item is not 
None: 1272 self.db.delete(item) 1273 self.db.commit() 1274 return True 1275 1276 def add_audio_generator_to_team(self, team: TeamDatabase, generator_name: str) -> bool: 1277 existing = self.db.query(TeamAudioGeneratorDatabase).filter( 1278 TeamAudioGeneratorDatabase.team_id == team.id, 1279 TeamAudioGeneratorDatabase.generator_name == generator_name 1280 ).first() 1281 if existing is None: 1282 team.audio_generators.append( 1283 TeamAudioGeneratorDatabase(team_id=team.id, generator_name=generator_name) 1284 ) 1285 self.db.commit() 1286 return True 1287 1288 def remove_audio_generator_from_team(self, team: TeamDatabase, generator_name: str) -> bool: 1289 item = self.db.query(TeamAudioGeneratorDatabase).filter( 1290 TeamAudioGeneratorDatabase.team_id == team.id, 1291 TeamAudioGeneratorDatabase.generator_name == generator_name 1292 ).first() 1293 if item is not None: 1294 self.db.delete(item) 1295 self.db.commit() 1296 return True 1297 1298 1299 def _create_prompt_version(self, project_id: int, system_prompt: str, user_id: int = None): 1300 """Create a new prompt version record, marking it as active.""" 1301 from restai.models.databasemodels import PromptVersionDatabase 1302 1303 # Deactivate current active version 1304 self.db.query(PromptVersionDatabase).filter( 1305 PromptVersionDatabase.project_id == project_id, 1306 PromptVersionDatabase.is_active == True, 1307 ).update({"is_active": False}) 1308 1309 # Get next version number 1310 max_version = ( 1311 self.db.query(func.max(PromptVersionDatabase.version)) 1312 .filter(PromptVersionDatabase.project_id == project_id) 1313 .scalar() 1314 ) or 0 1315 1316 version = PromptVersionDatabase( 1317 project_id=project_id, 1318 version=max_version + 1, 1319 system_prompt=system_prompt or "", 1320 created_by=user_id, 1321 created_at=datetime.now(timezone.utc), 1322 is_active=True, 1323 ) 1324 self.db.add(version) 1325 1326 def get_prompt_versions(self, project_id: int): 1327 """Get all prompt versions for a project, newest 
first.""" 1328 from restai.models.databasemodels import PromptVersionDatabase 1329 return ( 1330 self.db.query(PromptVersionDatabase) 1331 .filter(PromptVersionDatabase.project_id == project_id) 1332 .order_by(PromptVersionDatabase.version.desc()) 1333 .all() 1334 ) 1335 1336 def get_prompt_version(self, version_id: int): 1337 """Get a specific prompt version by ID.""" 1338 from restai.models.databasemodels import PromptVersionDatabase 1339 return self.db.query(PromptVersionDatabase).filter(PromptVersionDatabase.id == version_id).first() 1340 1341 def get_active_prompt_version(self, project_id: int): 1342 """Get the active prompt version for a project.""" 1343 from restai.models.databasemodels import PromptVersionDatabase 1344 return ( 1345 self.db.query(PromptVersionDatabase) 1346 .filter(PromptVersionDatabase.project_id == project_id, PromptVersionDatabase.is_active == True) 1347 .first() 1348 ) 1349 1350 # ── Project Tools (agent-created) ──────────────────────────────────── 1351 1352 def get_project_tools(self, project_id: int) -> list[ProjectToolDatabase]: 1353 return ( 1354 self.db.query(ProjectToolDatabase) 1355 .filter(ProjectToolDatabase.project_id == project_id) 1356 .order_by(ProjectToolDatabase.name) 1357 .all() 1358 ) 1359 1360 def get_project_tool_by_name(self, project_id: int, name: str) -> Optional[ProjectToolDatabase]: 1361 return ( 1362 self.db.query(ProjectToolDatabase) 1363 .filter(ProjectToolDatabase.project_id == project_id, ProjectToolDatabase.name == name) 1364 .first() 1365 ) 1366 1367 def upsert_project_tool(self, project_id: int, name: str, description: str, parameters: str, code: str) -> ProjectToolDatabase: 1368 from datetime import datetime, timezone 1369 now = datetime.now(timezone.utc) 1370 existing = self.get_project_tool_by_name(project_id, name) 1371 if existing: 1372 existing.description = description 1373 existing.parameters = parameters 1374 existing.code = code 1375 existing.updated_at = now 1376 self.db.commit() 1377 return 
existing 1378 tool = ProjectToolDatabase( 1379 project_id=project_id, 1380 name=name, 1381 description=description, 1382 parameters=parameters, 1383 code=code, 1384 created_at=now, 1385 updated_at=now, 1386 ) 1387 self.db.add(tool) 1388 self.db.commit() 1389 return tool 1390 1391 def delete_project_tool(self, project_id: int, name: str) -> bool: 1392 tool = self.get_project_tool_by_name(project_id, name) 1393 if tool: 1394 self.db.delete(tool) 1395 self.db.commit() 1396 return True 1397 return False 1398 1399 # ── Project Routines (scheduled messages) ──────────────────────────── 1400 1401 def get_project_routines(self, project_id: int) -> list[ProjectRoutineDatabase]: 1402 return ( 1403 self.db.query(ProjectRoutineDatabase) 1404 .filter(ProjectRoutineDatabase.project_id == project_id) 1405 .order_by(ProjectRoutineDatabase.name) 1406 .all() 1407 ) 1408 1409 def get_all_enabled_routines(self) -> list[ProjectRoutineDatabase]: 1410 return ( 1411 self.db.query(ProjectRoutineDatabase) 1412 .filter(ProjectRoutineDatabase.enabled == True) 1413 .all() 1414 ) 1415 1416 def get_project_routine_by_id(self, routine_id: int) -> Optional[ProjectRoutineDatabase]: 1417 return self.db.query(ProjectRoutineDatabase).filter(ProjectRoutineDatabase.id == routine_id).first() 1418 1419 def create_project_routine(self, project_id: int, name: str, message: str, schedule_minutes: int, enabled: bool = True) -> ProjectRoutineDatabase: 1420 from datetime import datetime, timezone 1421 now = datetime.now(timezone.utc) 1422 routine = ProjectRoutineDatabase( 1423 project_id=project_id, 1424 name=name, 1425 message=message, 1426 schedule_minutes=schedule_minutes, 1427 enabled=enabled, 1428 created_at=now, 1429 updated_at=now, 1430 ) 1431 self.db.add(routine) 1432 self.db.commit() 1433 self.db.refresh(routine) 1434 return routine 1435 1436 def delete_project_routine(self, routine_id: int) -> bool: 1437 routine = self.get_project_routine_by_id(routine_id) 1438 if routine: 1439 self.db.delete(routine) 
1440 self.db.commit() 1441 return True 1442 return False 1443 1444 # ── Cron Logs ──────────────────────────────────────────────────────── 1445 1446 def create_cron_log(self, job, status, message, details=None, items_processed=0, duration_ms=None): 1447 from datetime import datetime, timezone 1448 entry = CronLogDatabase( 1449 job=job, 1450 status=status, 1451 message=message, 1452 details=details, 1453 items_processed=items_processed, 1454 duration_ms=duration_ms, 1455 date=datetime.now(timezone.utc), 1456 ) 1457 self.db.add(entry) 1458 self.db.commit() 1459 return entry 1460 1461 def get_cron_logs(self, job=None, status=None, start=0, end=50): 1462 query = self.db.query(CronLogDatabase).order_by(CronLogDatabase.date.desc()) 1463 if job: 1464 query = query.filter(CronLogDatabase.job == job) 1465 if status: 1466 query = query.filter(CronLogDatabase.status == status) 1467 return query.offset(start).limit(end - start).all() 1468 1469 1470 def get_db_wrapper() -> DBWrapper: 1471 wrapper: DBWrapper = DBWrapper() 1472 try: 1473 return wrapper 1474 finally: 1475 wrapper.db.close()