/ backend / src / services / oauth_authorization_service.rs
oauth_authorization_service.rs
  1  use std::collections::HashSet;
  2  
  3  use chrono::{DateTime, Duration, Utc};
  4  use mobc_redis::RedisConnectionManager;
  5  use rand::{distributions::Alphanumeric, thread_rng, Rng};
  6  use serde::{Deserialize, Serialize};
  7  use sqlx::Pool;
  8  use uuid::Uuid;
  9  
 10  use crate::{
 11      db::{
 12          client_scope_repository, new_transaction, oauth_client_repository,
 13          user_client_consent_repository, DB,
 14      },
 15      models::{client_scope::ClientScope, oauth_scope::OauthScope},
 16      util::accounts_error::AccountsError,
 17  };
 18  
 19  use super::redis_service::{self, RedisError};
 20  
 21  #[derive(Debug, thiserror::Error)]
 22  pub enum Oauth2Error {
 23      #[error("Client ID did not match the provided code")]
 24      InvalidClientId,
 25      #[error("There is no client with that client_id")]
 26      NoClientWithId,
 27      #[error("Redirect uri doesn't match client")]
 28      InvalidRedirectUri,
 29      #[error("Invalid client secret")]
 30      InvalidClientSecret,
 31      #[error("Invalid authorization code provided")]
 32      InvalidCode,
 33      #[error("Failed to generate expiration time")]
 34      ExpirationTimeGeneration,
 35      #[error("Sqlx error")]
 36      SqlxError(#[from] sqlx::Error),
 37      #[error("Accounts error")]
 38      AccountsError(#[from] AccountsError),
 39      #[error("Redis error")]
 40      RedisError(#[from] RedisError),
 41      #[error("Failed to insert access token into the redis cache")]
 42      CacheInsertion,
 43      #[error("The user has not consented to this client or has consented to the client but not all of the requested scopes")]
 44      MissingClientConsent { client_name: String },
 45      #[error("Invalid scope")]
 46      InvalidScope,
 47      #[error("Requested scope is not registered for the client")]
 48      ScopeNotRegistered,
 49  }
 50  
 51  const AUTH_TOKEN_LENGTH: usize = 48;
 52  
 53  const AUTHORIZATION_KEY_REDIS_PREFIX: &str = "authorization_codes";
 54  // 5 minutes
 55  const AUTHORIZATION_CODE_EXPIRATION_SECONDS: usize = 5 * 60;
 56  
 57  const ACCESS_TOKEN_LENGTH: usize = 128;
 58  pub const ACCESS_TOKEN_KEY_REDIS_PREFIX: &str = "access_tokens";
 59  // 1 hour
 60  const ACCESS_TOKEN_EXPIRATION_SECONDS: i64 = 60 * 60;
 61  
 62  #[derive(Deserialize, Serialize, Debug)]
 63  struct AuthToken {
 64      code: String,
 65      client_id: String,
 66      account_id: Uuid,
 67      scopes: HashSet<OauthScope>,
 68  }
 69  
 70  pub async fn get_auth_token(
 71      db_pool: &Pool<DB>,
 72      redis_pool: &mobc::Pool<RedisConnectionManager>,
 73      client_id: String,
 74      redirect_uri: &String,
 75      state: &String,
 76      account_id: Uuid,
 77      requested_scopes: &HashSet<OauthScope>,
 78  ) -> Result<String, Oauth2Error> {
 79      let mut transaction = new_transaction(db_pool).await?;
 80  
 81      let client = oauth_client_repository::get_by_client_id(&mut transaction, &client_id)
 82          .await?
 83          .ok_or(Oauth2Error::NoClientWithId)?;
 84  
 85      if redirect_uri != &client.redirect_uri {
 86          error!(
 87              "Redirect uri doesn't match, request redirect_uri: {}, client set redirect_uri: {}",
 88              redirect_uri, client.redirect_uri
 89          );
 90          return Err(Oauth2Error::InvalidRedirectUri);
 91      }
 92  
 93      // Ensure that the requested scopes are valid for this client.
 94      let registered_scopes =
 95          client_scope_repository::get_all_for_client(&mut transaction, &client).await?;
 96      if !validate_scopes(requested_scopes, &registered_scopes) {
 97          error!("Client ({client_id:?}) requested scope they are not registered for, requested: {requested_scopes:?}");
 98          return Err(Oauth2Error::ScopeNotRegistered);
 99      }
100  
101      // Check if the user has consented to this client and if not, ask if they would.
102      let Some(user_client_consent) = user_client_consent_repository::get_by_client_and_account(
103          &mut transaction,
104          &client,
105          &account_id,
106      )
107      .await?
108      else {
109          return Err(Oauth2Error::MissingClientConsent {
110              client_name: client.client_name,
111          });
112      };
113  
114      let consented_scopes = client_scope_repository::consented_by_user_for_client(
115          &mut transaction,
116          &client,
117          &user_client_consent,
118      )
119      .await?;
120  
121      if !validate_scopes(requested_scopes, &consented_scopes) {
122          warn!("The user has consented to this client but not all of the requested scopes, asking if they would consider it");
123          return Err(Oauth2Error::MissingClientConsent {
124              client_name: client.client_name,
125          });
126      }
127  
128      // We accept the request, now we generate and return a token.
129      let code: String = thread_rng()
130          .sample_iter(&Alphanumeric)
131          .take(AUTH_TOKEN_LENGTH)
132          .map(char::from)
133          .collect();
134  
135      let auth_token = AuthToken {
136          code: code.clone(),
137          client_id,
138          account_id,
139          scopes: requested_scopes.clone(),
140      };
141  
142      let key = format!("{}:{}", AUTHORIZATION_KEY_REDIS_PREFIX, code);
143      redis_service::redis_set(
144          redis_pool,
145          key,
146          auth_token,
147          AUTHORIZATION_CODE_EXPIRATION_SECONDS,
148      )
149      .await?;
150  
151      transaction.commit().await?;
152  
153      Ok(format!(
154          "{}?state={}&code={}",
155          client.redirect_uri, state, code
156      ))
157  }
158  
159  #[derive(Deserialize, Serialize, Debug, Clone)]
160  pub struct AccessToken {
161      pub access_token: String,
162      pub expiration: DateTime<Utc>,
163      pub client_id: String,
164      pub account_id: Uuid,
165      pub issued_at: DateTime<Utc>,
166      pub scopes: HashSet<OauthScope>,
167  }
168  
169  impl AccessToken {
170      pub fn expires_in(&self) -> u32 {
171          let now = Utc::now();
172          let expires_in = self.expiration.timestamp() - now.timestamp(); // The number of seconds until expiration
173          if expires_in <= 0 {
174              log::warn!("Expires in is {expires_in} before being returned to the caller!");
175          }
176          expires_in as u32
177      }
178  
179      pub fn has_scope(&self, oauth_scope: &OauthScope) -> bool {
180          self.scopes.contains(oauth_scope)
181      }
182  }
183  
184  pub async fn get_access_token(
185      db_pool: &Pool<DB>,
186      redis_pool: &mobc::Pool<RedisConnectionManager>,
187      client_id: String,
188      client_secret: String,
189      redirect_uri: String,
190      code: String,
191  ) -> Result<AccessToken, Oauth2Error> {
192      let mut transaction = new_transaction(db_pool).await?;
193  
194      let client = oauth_client_repository::get_by_client_id(&mut transaction, &client_id)
195          .await?
196          .ok_or(Oauth2Error::NoClientWithId)?;
197  
198      if client.redirect_uri != redirect_uri {
199          error!(
200              "Received redirect_uri ({}) did not match stored redirect_uri ({}) for client {}",
201              redirect_uri, client.redirect_uri, client_id
202          );
203          return Err(Oauth2Error::InvalidRedirectUri);
204      }
205  
206      if client.client_secret != client_secret {
207          error!(
208              "Received client_secret did not match stored client_secret for client {}",
209              client_id
210          );
211          return Err(Oauth2Error::InvalidClientSecret);
212      }
213  
214      let key = format!("{}:{}", AUTHORIZATION_KEY_REDIS_PREFIX, code);
215      let code_auth_token: AuthToken = redis_service::redis_get_option(redis_pool, key.clone())
216          .await?
217          .ok_or(Oauth2Error::InvalidCode)?;
218  
219      if code_auth_token.client_id != client_id {
220          error!(
221              "Stored clientId for code(code={}) {} did not match provided client id {}",
222              code, code_auth_token.client_id, client_id
223          );
224          return Err(Oauth2Error::NoClientWithId);
225      }
226  
227      // The request has met the requirements, we can now issue the access token.
228      // Start by deleting it from the cache
229      redis_service::redis_del(redis_pool, key).await?;
230  
231      let access_token = generate_access_token(
232          redis_pool,
233          code_auth_token.client_id,
234          code_auth_token.account_id,
235          code_auth_token.scopes,
236      )
237      .await?;
238  
239      transaction.commit().await?;
240  
241      Ok(access_token)
242  }
243  
244  pub async fn get_access_token_basic_auth(
245      redis_pool: &mobc::Pool<RedisConnectionManager>,
246      service: String,
247      account_id: Uuid,
248      scopes: HashSet<OauthScope>,
249  ) -> Result<AccessToken, Oauth2Error> {
250      generate_access_token(redis_pool, service, account_id, scopes).await
251  }
252  
253  async fn generate_access_token(
254      redis_pool: &mobc::Pool<RedisConnectionManager>,
255      client_id: String,
256      account_id: Uuid,
257      scopes: HashSet<OauthScope>,
258  ) -> Result<AccessToken, Oauth2Error> {
259      let access_token: String = thread_rng()
260          .sample_iter(&Alphanumeric)
261          .take(ACCESS_TOKEN_LENGTH)
262          .map(char::from)
263          .collect();
264  
265      let time_until_expiration = Duration::seconds(ACCESS_TOKEN_EXPIRATION_SECONDS);
266  
267      let issuing_time: DateTime<Utc> = Utc::now();
268      let expiration_time: DateTime<Utc> = issuing_time
269          .checked_add_signed(time_until_expiration)
270          .ok_or(Oauth2Error::ExpirationTimeGeneration)?;
271  
272      let access_token: AccessToken = AccessToken {
273          access_token: access_token.clone(),
274          expiration: expiration_time,
275          client_id,
276          account_id,
277          issued_at: issuing_time,
278          scopes,
279      };
280  
281      let key = format!(
282          "{}:{}",
283          ACCESS_TOKEN_KEY_REDIS_PREFIX, access_token.access_token
284      );
285      redis_service::redis_set(
286          redis_pool,
287          key,
288          access_token.clone(),
289          time_until_expiration.num_seconds() as usize,
290      )
291      .await
292      .or(Err(Oauth2Error::CacheInsertion))?;
293  
294      Ok(access_token)
295  }
296  
297  pub fn validate_scopes(
298      requested_scopes: &HashSet<OauthScope>,
299      client_scopes: &[ClientScope],
300  ) -> bool {
301      let consented_scopes: HashSet<&OauthScope> = client_scopes.iter().map(|s| &s.scope).collect();
302  
303      requested_scopes
304          .iter()
305          .all(|s| consented_scopes.contains(s))
306  }