/ src / net / oauth.rs
oauth.rs
  1  use std::io::{BufRead, BufReader, Write};
  2  use std::net::TcpListener;
  3  use anyhow::{anyhow, Result};
  4  use oauth2::basic::{BasicClient};
  5  use oauth2::{AccessToken, AuthUrl, AuthorizationCode, ClientId, ClientSecret,
  6  	CsrfToken, PkceCodeChallenge, RedirectUrl, ResponseType, Scope, TokenUrl};
  7  use reqwest::{ClientBuilder, Url};
  8  use reqwest::header::HeaderValue;
  9  use reqwest::redirect::Policy;
 10  use tracing::{error, info, warn};
 11  use crate::net::Oauth2Session;
 12  
 13  /**
 14  Object to contain and manage the standard OAuth2 authorization flow.
 15  */
 16  pub struct AuthorizationManager
 17  {
 18  	authCode: Option<AuthorizationCode>,
 19  	authUrl: AuthUrl,
 20  	clientId: ClientId,
 21  	clientSecret: ClientSecret,
 22  	csrfToken: Option<CsrfToken>,
 23  	preferredRedirectPort: Option<u64>,
 24  	tcpListener: Option<TcpListener>,
 25  	tokenUrl: TokenUrl,
 26  }
 27  
 28  impl AuthorizationManager
 29  {			
 30  	const SuccessHtml: &str = r#"<!DOCTYPE html>
 31  <html>
 32  	<head>
 33  		<title>Platform Authentication Successful | Local Achievements</title>
 34  		<style>
 35  			body
 36  			{
 37  				align-content: center;
 38  				color: rgb(204, 204, 204);
 39  				background-color: rgb(35, 35, 35);
 40  				height: 100vh;
 41  				overflow: hidden;
 42  				width: 100vw;
 43  			}
 44  			
 45  			h1 { text-align: center; }
 46  			p { text-align: center; }
 47  		</style>
 48  	</head>
 49  	<body>
 50  		<h1>Authentication Successful</h1>
 51  		<p>You can close this tab and return to Local Achievements.</p>
 52  	</body>
 53  </html>"#;
 54  	
 55  	pub fn constructAuthorizationHeader(authScheme: String, accessToken: &AccessToken) -> Result<HeaderValue>
 56  	{
 57  		let mut headerValue = HeaderValue::from_str(format!(
 58  			"{} {}",
 59  			authScheme,
 60  			accessToken.secret()
 61  		).as_str())?;
 62  		headerValue.set_sensitive(true);
 63  		return Ok(headerValue);
 64  	}
 65  	
 66  	pub fn new(clientId: String, clientSecret: String, authUrl: String, tokenUrl: String, redirectPort: Option<u64>) -> Result<Self>
 67  	{
 68  		let authUrl = AuthUrl::new(authUrl)?;
 69  		let tokenUrl = TokenUrl::new(tokenUrl)?;
 70  		
 71  		return Ok(Self
 72  		{
 73  			authCode: None,
 74  			authUrl,
 75  			clientId: ClientId::new(clientId),
 76  			clientSecret: ClientSecret::new(clientSecret),
 77  			csrfToken: None,
 78  			preferredRedirectPort: redirectPort,
 79  			tcpListener: None,
 80  			tokenUrl,
 81  		});
 82  	}
 83  	
 84  	pub async fn authorizationCodeFlow<T>(&mut self, responseType: ResponseType, scopes: Vec<Scope>) -> Result<T>
 85  		where T: Oauth2Session
 86  	{
 87  		let mut out: Option<T> = None;
 88  		self.bindListener()?;
 89  		
 90  		if let Some(listener) = &self.tcpListener
 91  		{
 92  			let localAddr = listener.local_addr()?;
 93  			let redirectUrl = RedirectUrl::new(format!("http://{}", localAddr.to_string()))?;
 94  			
 95  			let client = BasicClient::new(self.clientId.clone())
 96  				.set_client_secret(self.clientSecret.clone())
 97  				.set_auth_uri(self.authUrl.clone())
 98  				.set_token_uri(self.tokenUrl.clone())
 99  				.set_redirect_uri(redirectUrl);
100  			
101  			let (pkceChallenge, pkceVerifier) = PkceCodeChallenge::new_random_sha256();
102  			
103  			let (authUrl, csrfState) = client.authorize_url(CsrfToken::new_random)
104  				.add_scopes(scopes)
105  				.set_pkce_challenge(pkceChallenge)
106  				.set_response_type(&responseType)
107  				.url();
108  			
109  			self.csrfToken = Some(csrfState);
110  			
111  			//Note: HTTPS is required, due to including the 'hardened' feature flag
112  			if let Err(e) = webbrowser::open(authUrl.as_str())
113  			{
114  				error!("Error opening the default browser: {:?}", e);
115  			}
116  			
117  			self.waitForResponse().await?;
118  			
119  			if let Some(authCode) = self.authCode.clone()
120  			{
121  				let httpClient = ClientBuilder::new()
122  					.redirect(Policy::none())
123  					.build()?;
124  				
125  				let tokenResult = client
126  					.exchange_code(authCode)
127  					.set_pkce_verifier(pkceVerifier)
128  					.request_async(&httpClient)
129  					.await?;
130  				
131  				out = Some(T::fromTokenResult(tokenResult));
132  			}
133  		}
134  		
135  		self.dropListener();
136  		
137  		return match out
138  		{
139  			None => Err(anyhow!("Authorization flow failed without error")),
140  			Some(session) => Ok(session),
141  		};
142  	}
143  	
144  	fn bindListener(&mut self) -> Result<()>
145  	{
146  		let uri = format!("127.0.0.1:{}", match self.preferredRedirectPort
147  		{
148  			// Note: When the port is 0, this prompts the OS to choose a random port to which to bind
149  			None => 0,
150  			Some(port) => port,
151  		});
152  		
153  		self.tcpListener = Some(TcpListener::bind(uri)?);
154  		
155  		if let Some(listener) = &self.tcpListener
156  		{
157  			info!("Listening on: {}", listener.local_addr()?);
158  		}
159  		
160  		return Ok(());
161  	}
162  	
163  	fn dropListener(&mut self)
164  	{
165  		self.tcpListener = None;
166  	}
167  	
168  	/**
169  	Call `TcpListener::accept()` and wait for the redirect from the login process.
170  	Parse the authorization code from the response.
171  	*/
172  	async fn waitForResponse(&mut self) -> Result<()>
173  	{
174  		if let Some(listener) = &self.tcpListener
175  		{
176  			match listener.accept()
177  			{
178  				Err(e) => warn!("[OAuth2] Listener::accept error: {:?}", e),
179  				Ok((mut stream, address)) => {
180  					info!("[OAuth2] Received request from: {}", address);
181  					
182  					let bufReader = BufReader::new(&stream);
183  					let httpRequest: Vec<_> = bufReader
184  						.lines()
185  						.map(|result| result.unwrap())
186  						.take_while(|line| !line.is_empty())
187  						.collect();
188  					
189  					if let Some(get) = httpRequest.iter()
190  						.find(|l| l.starts_with("GET /?"))
191  					{
192  						if let Some(parameters) = get.split_whitespace().nth(1)
193  						{
194  							//TODO: Switch this to if let Ok(url) once stabilized, to avoid putting sensitive information into the logs
195  							match Url::parse(&format!("http://127.0.0.1{}", parameters))
196  							{
197  								Err(e) => warn!("[OAuth2] Failed to parse the url with parameters: {:?}", e),
198  								Ok(url) => {
199  									let state = url.query_pairs()
200  										.find_map(|(k, v)| match k == "state"
201  										{
202  											false => None,
203  											true => Some(CsrfToken::new(v.into_owned())),
204  										});
205  									
206  									//Verify that the CSRF tokens match
207  									if self.csrfToken.as_ref().is_some_and(|token| state.is_some_and(|s| s.secret() == token.secret()))
208  									{
209  										self.authCode = url.query_pairs()
210  											.find_map(|(k, v)| match k == "code"
211  										{
212  											false => None,
213  											true => Some(AuthorizationCode::new(v.into_owned())),
214  										});
215  									}
216  								}
217  							}
218  						}
219  					}
220  					
221  					//TODO: Make this more robust. Success is assumed but there are a number of things that could go wrong.
222  					let response = format!(
223  						"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\n\r\n{}",
224  						Self::SuccessHtml.len(),
225  						Self::SuccessHtml
226  					);
227  					stream.write_all(response.as_bytes())?;
228  				},
229  			}
230  		}
231  		
232  		return Ok(());
233  	}
234  }