/ database.py
database.py
  1  import json
  2  import os
  3  from datetime import datetime
  4  
  5  import bcrypt
  6  from sqlalchemy import create_engine, inspect
  7  from sqlalchemy.orm import sessionmaker
  8  
  9  from restai.config import (
 10      MYSQL_HOST,
 11      MYSQL_URL,
 12      POSTGRES_HOST,
 13      POSTGRES_URL,
 14      RESTAI_DEFAULT_PASSWORD,
 15  )
 16  from restai.models.databasemodels import (
 17      ApiKeyDatabase,
 18      Base,
 19      LLMDatabase,
 20      ProjectDatabase,
 21      SettingDatabase,
 22      UserDatabase,
 23      EmbeddingDatabase,
 24      TeamDatabase
 25  )
 26  from restai.tools import DEFAULT_LLMS, DEFAULT_EMBEDDINGS
 27  
 28  if MYSQL_HOST:
 29      print("Using MySQL database")
 30      engine = create_engine(MYSQL_URL,
 31                             pool_size=30,
 32                             max_overflow=100,
 33                             pool_recycle=900)
 34  elif POSTGRES_HOST:
 35      print("Using PostgreSQL database")
 36      engine = create_engine(POSTGRES_URL,
 37                             pool_size=30,
 38                             max_overflow=100,
 39                             pool_recycle=900)
 40  else:
 41      print("Using sqlite database.")
 42      engine = create_engine(
 43          "sqlite:///./restai.db",
 44          connect_args={
 45              "check_same_thread": False},
 46          pool_size=30,
 47          max_overflow=100,
 48          pool_recycle=300)
 49  
 50  # Forcefully raise on failed connection
 51  try:
 52      with engine.connect() as conn:
 53          pass
 54  except Exception:
 55      raise
 56  
 57  SessionLocal = sessionmaker(
 58      autocommit=False, autoflush=False, bind=engine)
 59  
 60  def hash_password(password: str) -> str:
 61      return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
 62  
 63  if os.getenv("RESTAI_DB_SCHEMA"):
 64      Base.metadata.create_all(bind=engine)
 65  else:
 66      if "users" not in inspect(engine).get_table_names():
 67          print("Initializing database...")
 68          default_password = RESTAI_DEFAULT_PASSWORD
 69          Base.metadata.create_all(bind=engine)
 70          dbi = SessionLocal()
 71          db_user = UserDatabase(
 72              username="admin",
 73              hashed_password=hash_password(default_password),
 74              is_admin=True)
 75          dbi.add(db_user)
 76          
 77          dbi.commit()
 78          
 79          # Create a default team and add the admin user to it
 80          default_team = TeamDatabase(
 81              name="Default Team",
 82              description="Default team created during initialization",
 83              created_at=datetime.now()
 84          )
 85          dbi.add(default_team)
 86          dbi.commit()
 87          
 88          # Add admin as both a team member and a team admin
 89          default_team.users.append(db_user)
 90          default_team.admins.append(db_user)
 91          dbi.commit()
 92  
 93          for llm in DEFAULT_LLMS:
 94              llm_class, llm_args, privacy, description, input_cost, output_cost = DEFAULT_LLMS[llm]
 95              db_llm = LLMDatabase(
 96                  name=llm,
 97                  class_name=llm_class,
 98                  options=json.dumps(llm_args),
 99                  privacy=privacy,
100                  description=description,
101                  input_cost=input_cost,
102                  output_cost=output_cost
103              )
104              dbi.add(db_llm)
105              
106              # Add this LLM to the default team
107              default_team.llms.append(db_llm)
108              
109          dbi.commit()
110          
111          for embedding in DEFAULT_EMBEDDINGS:
112              embedding_class, embedding_args, privacy, description, dimension = DEFAULT_EMBEDDINGS[embedding]
113              db_embedding = EmbeddingDatabase(
114                  name=embedding,
115                  class_name=embedding_class,
116                  options=json.dumps(embedding_args),
117                  privacy=privacy,
118                  description=description,
119                  dimension=dimension
120              )
121              dbi.add(db_embedding)
122              
123              # Add this embedding model to the default team
124              default_team.embeddings.append(db_embedding)
125              
126          dbi.commit()
127          
128          dbi.commit()
129          dbi.close()
130          print("Database initialized.")
131          print("Default LLMs initialized.")
132          print("Default admin user created (admin:" + default_password + ").")
133      else:
134          # Ensure new tables are created on existing databases
135          Base.metadata.create_all(bind=engine, checkfirst=True)
136          print("Database already initialized.")