oauth.rs
1 use std::io::{BufRead, BufReader, Write}; 2 use std::net::TcpListener; 3 use anyhow::{anyhow, Result}; 4 use oauth2::basic::{BasicClient}; 5 use oauth2::{AccessToken, AuthUrl, AuthorizationCode, ClientId, ClientSecret, 6 CsrfToken, PkceCodeChallenge, RedirectUrl, ResponseType, Scope, TokenUrl}; 7 use reqwest::{ClientBuilder, Url}; 8 use reqwest::header::HeaderValue; 9 use reqwest::redirect::Policy; 10 use tracing::{error, info, warn}; 11 use crate::net::Oauth2Session; 12 13 /** 14 Object to contain and manage the standard OAuth2 authorization flow. 15 */ 16 pub struct AuthorizationManager 17 { 18 authCode: Option<AuthorizationCode>, 19 authUrl: AuthUrl, 20 clientId: ClientId, 21 clientSecret: ClientSecret, 22 csrfToken: Option<CsrfToken>, 23 preferredRedirectPort: Option<u64>, 24 tcpListener: Option<TcpListener>, 25 tokenUrl: TokenUrl, 26 } 27 28 impl AuthorizationManager 29 { 30 const SuccessHtml: &str = r#"<!DOCTYPE html> 31 <html> 32 <head> 33 <title>Platform Authentication Successful | Local Achievements</title> 34 <style> 35 body 36 { 37 align-content: center; 38 color: rgb(204, 204, 204); 39 background-color: rgb(35, 35, 35); 40 height: 100vh; 41 overflow: hidden; 42 width: 100vw; 43 } 44 45 h1 { text-align: center; } 46 p { text-align: center; } 47 </style> 48 </head> 49 <body> 50 <h1>Authentication Successful</h1> 51 <p>You can close this tab and return to Local Achievements.</p> 52 </body> 53 </html>"#; 54 55 pub fn constructAuthorizationHeader(authScheme: String, accessToken: &AccessToken) -> Result<HeaderValue> 56 { 57 let mut headerValue = HeaderValue::from_str(format!( 58 "{} {}", 59 authScheme, 60 accessToken.secret() 61 ).as_str())?; 62 headerValue.set_sensitive(true); 63 return Ok(headerValue); 64 } 65 66 pub fn new(clientId: String, clientSecret: String, authUrl: String, tokenUrl: String, redirectPort: Option<u64>) -> Result<Self> 67 { 68 let authUrl = AuthUrl::new(authUrl)?; 69 let tokenUrl = TokenUrl::new(tokenUrl)?; 70 71 return Ok(Self 72 { 73 authCode: None, 74 authUrl, 75 clientId: ClientId::new(clientId), 76 clientSecret: ClientSecret::new(clientSecret), 77 csrfToken: None, 78 preferredRedirectPort: redirectPort, 79 tcpListener: None, 80 tokenUrl, 81 }); 82 } 83 84 pub async fn authorizationCodeFlow<T>(&mut self, responseType: ResponseType, scopes: Vec<Scope>) -> Result<T> 85 where T: Oauth2Session 86 { 87 let mut out: Option<T> = None; 88 self.bindListener()?; 89 90 if let Some(listener) = &self.tcpListener 91 { 92 let localAddr = listener.local_addr()?; 93 let redirectUrl = RedirectUrl::new(format!("http://{}", localAddr.to_string()))?; 94 95 let client = BasicClient::new(self.clientId.clone()) 96 .set_client_secret(self.clientSecret.clone()) 97 .set_auth_uri(self.authUrl.clone()) 98 .set_token_uri(self.tokenUrl.clone()) 99 .set_redirect_uri(redirectUrl); 100 101 let (pkceChallenge, pkceVerifier) = PkceCodeChallenge::new_random_sha256(); 102 103 let (authUrl, csrfState) = client.authorize_url(CsrfToken::new_random) 104 .add_scopes(scopes) 105 .set_pkce_challenge(pkceChallenge) 106 .set_response_type(&responseType) 107 .url(); 108 109 self.csrfToken = Some(csrfState); 110 111 //Note: HTTPS is required, due to including the 'hardened' feature flag 112 if let Err(e) = webbrowser::open(authUrl.as_str()) 113 { 114 error!("Error opening the default browser: {:?}", e); 115 } 116 117 self.waitForResponse().await?; 118 119 if let Some(authCode) = self.authCode.clone() 120 { 121 let httpClient = ClientBuilder::new() 122 .redirect(Policy::none()) 123 .build()?; 124 125 let tokenResult = client 126 .exchange_code(authCode) 127 .set_pkce_verifier(pkceVerifier) 128 .request_async(&httpClient) 129 .await?; 130 131 out = Some(T::fromTokenResult(tokenResult)); 132 } 133 } 134 135 self.dropListener(); 136 137 return match out 138 { 139 None => Err(anyhow!("Authorization flow failed without error")), 140 Some(session) => Ok(session), 141 }; 142 } 143 144 fn bindListener(&mut self) -> Result<()> 145 { 146 let uri = format!("127.0.0.1:{}", match self.preferredRedirectPort 147 { 148 // Note: When the port is 0, this prompts the OS to choose a random port to which to bind 149 None => 0, 150 Some(port) => port, 151 }); 152 153 self.tcpListener = Some(TcpListener::bind(uri)?); 154 155 if let Some(listener) = &self.tcpListener 156 { 157 info!("Listening on: {}", listener.local_addr()?); 158 } 159 160 return Ok(()); 161 } 162 163 fn dropListener(&mut self) 164 { 165 self.tcpListener = None; 166 } 167 168 /** 169 Call `TcpListener::accept()` and wait for the redirect from the login process. 170 Parse the authorization code from the response. 171 */ 172 async fn waitForResponse(&mut self) -> Result<()> 173 { 174 if let Some(listener) = &self.tcpListener 175 { 176 match listener.accept() 177 { 178 Err(e) => warn!("[OAuth2] Listener::accept error: {:?}", e), 179 Ok((mut stream, address)) => { 180 info!("[OAuth2] Received request from: {}", address); 181 182 let bufReader = BufReader::new(&stream); 183 let httpRequest: Vec<_> = bufReader 184 .lines() 185 .map(|result| result.unwrap()) 186 .take_while(|line| !line.is_empty()) 187 .collect(); 188 189 if let Some(get) = httpRequest.iter() 190 .find(|l| l.starts_with("GET /?")) 191 { 192 if let Some(parameters) = get.split_whitespace().nth(1) 193 { 194 //TODO: Switch this to if let Ok(url) once stabilized, to avoid putting sensitive information into the logs 195 match Url::parse(&format!("http://127.0.0.1{}", parameters)) 196 { 197 Err(e) => warn!("[OAuth2] Failed to parse the url with parameters: {:?}", e), 198 Ok(url) => { 199 let state = url.query_pairs() 200 .find_map(|(k, v)| match k == "state" 201 { 202 false => None, 203 true => Some(CsrfToken::new(v.into_owned())), 204 }); 205 206 //Verify that the CSRF tokens match 207 if self.csrfToken.as_ref().is_some_and(|token| state.is_some_and(|s| s.secret() == token.secret())) 208 { 209 self.authCode = url.query_pairs() 210 .find_map(|(k, v)| match k == "code" 211 { 212 false => None, 213 true => Some(AuthorizationCode::new(v.into_owned())), 214 }); 215 } 216 } 217 } 218 } 219 } 220 221 //TODO: Make this more robust. Success is assumed but there are a number of things that could go wrong. 222 let response = format!( 223 "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\n\r\n{}", 224 Self::SuccessHtml.len(), 225 Self::SuccessHtml 226 ); 227 stream.write_all(response.as_bytes())?; 228 }, 229 } 230 } 231 232 return Ok(()); 233 } 234 }