# database.py
"""Database engine, session factory, and first-run bootstrap.

Importing this module has side effects:

* an SQLAlchemy engine is created for MySQL, PostgreSQL, or a local SQLite
  file, chosen from the values imported from ``restai.config``;
* a connection is opened once so that an unreachable database fails the
  import immediately;
* the schema is created and, when no ``users`` table exists yet, a default
  admin user, a default team, and the default LLM / embedding rows are
  inserted.
"""
import json
import os
from datetime import datetime

import bcrypt
from sqlalchemy import create_engine, inspect
from sqlalchemy.orm import sessionmaker

from restai.config import (
    MYSQL_HOST,
    MYSQL_URL,
    POSTGRES_HOST,
    POSTGRES_URL,
    RESTAI_DEFAULT_PASSWORD,
)
from restai.models.databasemodels import (
    ApiKeyDatabase,
    Base,
    LLMDatabase,
    ProjectDatabase,
    SettingDatabase,
    UserDatabase,
    EmbeddingDatabase,
    TeamDatabase,
)
from restai.tools import DEFAULT_LLMS, DEFAULT_EMBEDDINGS

# Backend selection: MySQL takes precedence over PostgreSQL; a local SQLite
# file is the fallback when neither host is configured.
if MYSQL_HOST:
    print("Using MySQL database")
    engine = create_engine(
        MYSQL_URL,
        pool_size=30,
        max_overflow=100,
        pool_recycle=900,
    )
elif POSTGRES_HOST:
    print("Using PostgreSQL database")
    engine = create_engine(
        POSTGRES_URL,
        pool_size=30,
        max_overflow=100,
        pool_recycle=900,
    )
else:
    print("Using sqlite database.")
    engine = create_engine(
        "sqlite:///./restai.db",
        connect_args={"check_same_thread": False},
        pool_size=30,
        max_overflow=100,
        pool_recycle=300,
    )

# Fail fast: open (and immediately release) one connection so that a bad
# URL or unreachable server raises here instead of on the first query.
with engine.connect():
    pass

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def hash_password(password: str) -> str:
    """Return the bcrypt hash of *password* as a UTF-8 string.

    A fresh salt is generated on every call, so hashing the same password
    twice yields different strings.
    """
    return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')


def _seed_defaults(session, default_password: str) -> None:
    """Insert the default admin, team, LLM and embedding rows.

    Commits after each stage (user, team, memberships, LLMs, embeddings).
    The caller owns *session* and is responsible for rollback and close.
    """
    admin = UserDatabase(
        username="admin",
        hashed_password=hash_password(default_password),
        is_admin=True,
    )
    session.add(admin)
    session.commit()

    # Create a default team and add the admin user to it.
    default_team = TeamDatabase(
        name="Default Team",
        description="Default team created during initialization",
        created_at=datetime.now(),
    )
    session.add(default_team)
    session.commit()

    # The admin is both a regular member and an administrator of the team.
    default_team.users.append(admin)
    default_team.admins.append(admin)
    session.commit()

    for llm in DEFAULT_LLMS:
        llm_class, llm_args, privacy, description, input_cost, output_cost = DEFAULT_LLMS[llm]
        db_llm = LLMDatabase(
            name=llm,
            class_name=llm_class,
            options=json.dumps(llm_args),
            privacy=privacy,
            description=description,
            input_cost=input_cost,
            output_cost=output_cost,
        )
        session.add(db_llm)
        # Make every default LLM available to the default team.
        default_team.llms.append(db_llm)
    session.commit()

    for embedding in DEFAULT_EMBEDDINGS:
        embedding_class, embedding_args, privacy, description, dimension = DEFAULT_EMBEDDINGS[embedding]
        db_embedding = EmbeddingDatabase(
            name=embedding,
            class_name=embedding_class,
            options=json.dumps(embedding_args),
            privacy=privacy,
            description=description,
            dimension=dimension,
        )
        session.add(db_embedding)
        # Make every default embedding model available to the default team.
        default_team.embeddings.append(db_embedding)
    session.commit()


if os.getenv("RESTAI_DB_SCHEMA"):
    # Schema-only mode: create the tables and skip seeding entirely.
    Base.metadata.create_all(bind=engine)
elif "users" not in inspect(engine).get_table_names():
    print("Initializing database...")
    default_password = RESTAI_DEFAULT_PASSWORD
    Base.metadata.create_all(bind=engine)
    dbi = SessionLocal()
    try:
        _seed_defaults(dbi, default_password)
    except Exception:
        # Discard the partially built stage before propagating the error.
        dbi.rollback()
        raise
    finally:
        # Always hand the connection back to the pool, even on failure.
        dbi.close()
    print("Database initialized.")
    print("Default LLMs initialized.")
    # NOTE(review): this writes the default admin password to stdout in
    # clear text; confirm that is acceptable for the deployment's logs.
    print(f"Default admin user created (admin:{default_password}).")
else:
    # Ensure new tables are created on existing databases.
    Base.metadata.create_all(bind=engine, checkfirst=True)
    print("Database already initialized.")