/ restai / cli.py
cli.py
  1  """RESTai CLI — run the platform from a pip install."""
  2  
  3  import argparse
  4  import os
  5  import sys
  6  
  7  
  8  def _load_env(env_file):
  9      """Load environment variables from a .env file."""
 10      if not env_file:
 11          return
 12      if not os.path.isfile(env_file):
 13          print(f"Error: env file '{env_file}' not found.", file=sys.stderr)
 14          sys.exit(1)
 15      try:
 16          from dotenv import load_dotenv
 17          load_dotenv(env_file, override=True)
 18      except ImportError:
 19          # Manual fallback
 20          with open(env_file) as f:
 21              for line in f:
 22                  line = line.strip()
 23                  if not line or line.startswith("#"):
 24                      continue
 25                  if "=" in line:
 26                      key, _, value = line.partition("=")
 27                      key = key.strip()
 28                      value = value.strip().strip("'\"")
 29                      os.environ[key] = value
 30  
 31  
 32  def cmd_serve(args):
 33      """Start the RESTai server."""
 34      _load_env(args.env_file)
 35  
 36      # Set port from args or env
 37      if args.port:
 38          os.environ["RESTAI_PORT"] = str(args.port)
 39  
 40      import uvicorn
 41      from restai.config import RESTAI_PORT
 42  
 43      port = int(args.port or RESTAI_PORT or 9000)
 44      workers = args.workers
 45  
 46      print(f"Starting RESTai on port {port} with {workers} worker(s)")
 47      uvicorn.run(
 48          "restai.main:app",
 49          host=args.host,
 50          port=port,
 51          workers=workers,
 52          reload=args.reload,
 53      )
 54  
 55  
 56  def cmd_migrate(args):
 57      """Run database migrations."""
 58      _load_env(args.env_file)
 59  
 60      from alembic.config import Config
 61      from alembic import command
 62      from restai.config import POSTGRES_URL, MYSQL_URL, POSTGRES_HOST, MYSQL_HOST
 63  
 64      if POSTGRES_HOST:
 65          db_url = POSTGRES_URL
 66      elif MYSQL_HOST:
 67          db_url = MYSQL_URL
 68      else:
 69          db_url = "sqlite:///./restai.db"
 70  
 71      # Find alembic.ini and migrations — check repo root first, then site-packages
 72      import sysconfig
 73      package_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 74      site_packages = sysconfig.get_path("purelib")
 75  
 76      alembic_ini = os.path.join(package_root, "alembic.ini")
 77      if not os.path.isfile(alembic_ini):
 78          alembic_ini = os.path.join(site_packages, "alembic.ini")
 79      if not os.path.isfile(alembic_ini):
 80          print("Error: alembic.ini not found", file=sys.stderr)
 81          sys.exit(1)
 82  
 83      migrations_dir = os.path.join(package_root, "migrations")
 84      if not os.path.isdir(migrations_dir):
 85          migrations_dir = os.path.join(site_packages, "migrations")
 86  
 87      alembic_cfg = Config(alembic_ini)
 88      alembic_cfg.set_main_option("sqlalchemy.url", db_url)
 89      alembic_cfg.set_main_option("script_location", migrations_dir)
 90  
 91      if args.direction == "upgrade":
 92          command.upgrade(alembic_cfg, "head")
 93          print("Database migrated successfully.")
 94      else:
 95          command.downgrade(alembic_cfg, "-1")
 96          print("Database downgraded.")
 97  
 98  
 99  def cmd_init(args):
100      """Initialize the database with tables, admin user, and default models."""
101      _load_env(args.env_file)
102  
103      # The root database.py script creates tables and seeds data on import
104      import importlib
105      import database  # noqa: F401 — side-effect import that creates tables
106      print("Database initialized.")
107  
108  
def _run_script(args, script_path):
    """Load *script_path* as a module and call its ``main()`` if it defines one.

    *script_path* is tried as given (relative paths resolve against the
    current directory). If no such file exists, it is resolved relative to
    the repo root (parent of the ``restai`` package). Prints an error and
    exits with status 1 if the script is found in neither place.
    """
    _load_env(args.env_file)
    import importlib.util

    # spec_from_file_location() never checks that the file exists; it only
    # picks a loader from the suffix, so a missing ".py" still yields a spec.
    # Test for the file explicitly so the package-root fallback can trigger.
    if not os.path.isfile(script_path):
        package_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        script_path = os.path.join(package_root, script_path)
    spec = None
    if os.path.isfile(script_path):
        spec = importlib.util.spec_from_file_location("script", script_path)
    if spec is None:
        print(f"Error: script not found: {script_path}", file=sys.stderr)
        sys.exit(1)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    if hasattr(mod, "main"):
        mod.main()
126  
127  
def main(argv=None):
    """Parse the command line and dispatch to the selected subcommand.

    *argv* defaults to ``sys.argv[1:]``. When no subcommand is given, the
    arguments are re-parsed with ``serve`` appended. ``serve`` then runs with
    the defaults declared on its own subparser, so ``restai -e .env`` behaves
    exactly like ``restai -e .env serve``.
    """
    parser = argparse.ArgumentParser(
        prog="restai",
        description="RESTai — AI as a Service Platform",
    )
    parser.add_argument("--env-file", "-e", default=None, help="Path to .env file")
    subparsers = parser.add_subparsers(dest="command")

    # serve
    serve_parser = subparsers.add_parser("serve", help="Start the RESTai server")
    serve_parser.add_argument("--host", default="0.0.0.0", help="Bind host (default: 0.0.0.0)")
    serve_parser.add_argument("--port", "-p", type=int, default=None, help="Port (default: 9000 or RESTAI_PORT)")
    serve_parser.add_argument("--workers", "-w", type=int, default=4, help="Number of workers (default: 4)")
    serve_parser.add_argument("--reload", action="store_true", help="Enable auto-reload (dev mode)")
    serve_parser.set_defaults(func=cmd_serve)

    # migrate
    migrate_parser = subparsers.add_parser("migrate", help="Run database migrations")
    migrate_parser.add_argument("direction", nargs="?", default="upgrade", choices=["upgrade", "downgrade"])
    migrate_parser.set_defaults(func=cmd_migrate)

    # init
    init_parser = subparsers.add_parser("init", help="Initialize database schema and admin user")
    init_parser.set_defaults(func=cmd_init)

    # Standalone script runners: (subcommand, help text, script path).
    script_commands = (
        ("crons", "Run all cron jobs (single entry point)", "crons/runner.py"),
        ("sync", "Run knowledge base sync (cron-friendly)", "crons/sync.py"),
        ("telegram", "Poll Telegram for updates (cron-friendly)", "crons/telegram.py"),
        ("slack", "Poll Slack for messages (cron-friendly)", "crons/slack.py"),
        ("docker-cleanup", "Remove idle Docker containers (cron-friendly)", "crons/docker_cleanup.py"),
        ("routines", "Run project routines (cron-friendly)", "crons/routines.py"),
    )
    for name, help_text, script in script_commands:
        script_parser = subparsers.add_parser(name, help=help_text)
        # Bind the path as a default argument to avoid the late-binding
        # closure trap (every lambda would otherwise see the last path).
        script_parser.set_defaults(
            func=lambda parsed, _path=script: _run_script(parsed, _path)
        )

    argv = sys.argv[1:] if argv is None else list(argv)
    args = parser.parse_args(argv)
    if not args.command:
        # Default to serve, taking its defaults from serve_parser instead of
        # duplicating them here (where they could drift out of sync).
        args = parser.parse_args(argv + ["serve"])

    args.func(args)


if __name__ == "__main__":
    main()