/ backend / src / main.rs
main.rs
  1  #![forbid(unsafe_code)]
  2  // `rocket` macros generate a bunch of clippy warnings.
  3  // TODO: remove these when updating rocket past 0.5.0.
  4  #![allow(clippy::blocks_in_conditions)]
  5  #![allow(clippy::to_string_in_format_args)]
  6  
  7  #[macro_use]
  8  extern crate rocket;
  9  
 10  use std::str::FromStr;
 11  
 12  use api::common::response::{ErrMsg, ResponseStatus};
 13  use api::core::core_routes;
 14  use api::frontend::site_routes;
 15  use api::oauth::oauth_routes;
 16  use api::openid::{self, openid_routes};
 17  use eyre::{eyre, WrapErr};
 18  use rocket::http::Status;
 19  use rocket::response::Responder;
 20  use rocket::{fs::FileServer, Request};
 21  use rocket_dyn_templates::Template;
 22  use services::email_service::EmailProvider;
 23  use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
 24  use sqlx::ConnectOptions;
 25  use tokio::task;
 26  
 27  use crate::util::config::Config;
 28  
 29  mod api;
 30  mod background_task;
 31  mod db;
 32  pub mod models;
 33  pub mod services;
 34  pub mod util;
 35  
 36  use mobc::Pool;
 37  use mobc_redis::redis;
 38  use mobc_redis::RedisConnectionManager;
 39  
 40  const MAX_REDIS_CONNECTIONS: u64 = 20;
 41  
 42  #[rocket::main]
 43  async fn main() -> eyre::Result<()> {
 44      color_eyre::install()?;
 45  
 46      // Load
 47      let config = Config::new().wrap_err("Failed to load config")?;
 48  
 49      // Setup DB
 50      let mut pg_options = PgConnectOptions::from_str(&config.database_url).wrap_err(eyre!(
 51          "Invalid database url provided {:?}",
 52          config.database_url
 53      ))?;
 54  
 55      if !config.log_db_statements {
 56          pg_options = pg_options.disable_statement_logging();
 57      }
 58  
 59      let db_pool = PgPoolOptions::new()
 60          .max_connections(5)
 61          .connect_with(pg_options)
 62          .await
 63          .wrap_err("Failed to connect to DB")?;
 64  
 65      sqlx::migrate!("./migrations")
 66          .run(&db_pool)
 67          .await
 68          .wrap_err("Failed to run migrations")?;
 69  
 70      db::init(&db_pool)
 71          .await
 72          .wrap_err("Failed to initialize db")?;
 73  
 74      // Setup Redis cache
 75      let redis_client = redis::Client::open(config.redis_url.clone()).wrap_err(eyre!(
 76          "Failed to connect to redis on URL {:?}",
 77          config.redis_url
 78      ))?;
 79      let redis_manager = RedisConnectionManager::new(redis_client);
 80      let redis_pool = Pool::builder()
 81          .max_open(MAX_REDIS_CONNECTIONS)
 82          .build(redis_manager);
 83  
 84      // Test redis connection
 85      redis_pool
 86          .get()
 87          .await
 88          .wrap_err("Test connection to redis pool failed")?;
 89  
 90      // Setup background tasks
 91      let pool_clone = db_pool.clone();
 92      task::spawn(background_task::run_background_tasks(pool_clone));
 93  
 94      let email_provider = EmailProvider::from(&config.email);
 95  
 96      let rocket = rocket::build()
 97          .mount("/api/core", core_routes())
 98          .mount("/api/site", site_routes())
 99          .mount("/api/oauth", oauth_routes())
100          .mount("/api/openid", openid_routes())
101          .mount("/api/public", FileServer::from("static/public"))
102          .mount(
103              "/",
104              routes![openid::configuration::get_openid_configuration],
105          )
106          .register("/", catchers![unauthorized, forbidden])
107          .manage(db_pool.clone())
108          .manage(redis_pool)
109          .manage(config)
110          .manage(email_provider)
111          .attach(Template::fairing());
112  
113      rocket.launch().await?;
114  
115      Ok(())
116  }
117  
118  struct UnauthorizedResponse(ResponseStatus<()>);
119  
120  impl<'r> Responder<'r, 'r> for UnauthorizedResponse {
121      fn respond_to(self, req: &'r Request<'_>) -> rocket::response::Result<'r> {
122          rocket::Response::build_from(self.0.respond_to(req)?)
123              .raw_header("location", "/api/core/login")
124              .ok()
125      }
126  }
127  
128  #[catch(401)]
129  fn unauthorized() -> UnauthorizedResponse {
130      UnauthorizedResponse(ResponseStatus::err(
131          Status::Unauthorized,
132          ErrMsg::Unauthorized,
133      ))
134  }
135  
136  const FORBIDDEN_TEMPLATE_NAME: &str = "forbidden-handler";
137  
138  #[catch(403)]
139  fn forbidden(_req: &Request) -> Template {
140      Template::render(FORBIDDEN_TEMPLATE_NAME, ())
141  }