auth.rs
1 //! RPC Authentication Module 2 //! 3 //! Provides token-based authentication for the JSON-RPC control plane. 4 //! On startup, generates a random 256-bit token and writes it to a file 5 //! with restricted permissions. All RPC requests must include this token 6 //! in the Authorization header. 7 8 use std::fs::{self, OpenOptions}; 9 use std::io::Write; 10 use std::os::unix::fs::OpenOptionsExt; 11 use std::path::PathBuf; 12 use std::sync::Arc; 13 use std::task::{Context, Poll}; 14 15 use http::{Request, Response, StatusCode}; 16 use tower::{Layer, Service}; 17 use tracing::{debug, warn}; 18 19 /// 32-byte auth token (256 bits of entropy) 20 pub type AuthToken = [u8; 32]; 21 22 /// Generate a cryptographically random auth token 23 pub fn generate_token() -> AuthToken { 24 use rand::RngCore; 25 let mut token = [0u8; 32]; 26 rand::thread_rng().fill_bytes(&mut token); 27 token 28 } 29 30 /// Write auth token to file with restricted permissions (0600) 31 pub fn write_token_file(token: &AuthToken, path: &PathBuf) -> std::io::Result<()> { 32 // Ensure parent directory exists 33 if let Some(parent) = path.parent() { 34 fs::create_dir_all(parent)?; 35 } 36 37 // Open file with mode 0600 (owner read/write only) 38 let mut file = OpenOptions::new() 39 .write(true) 40 .create(true) 41 .truncate(true) 42 .mode(0o600) 43 .open(path)?; 44 45 // Write token as hex 46 let hex = hex_encode(token); 47 file.write_all(hex.as_bytes())?; 48 file.write_all(b"\n")?; 49 50 Ok(()) 51 } 52 53 /// Read auth token from file (for client-side authentication) 54 #[allow(dead_code)] 55 pub fn read_token_file(path: &PathBuf) -> std::io::Result<AuthToken> { 56 let contents = fs::read_to_string(path)?; 57 let hex = contents.trim(); 58 59 hex_decode(hex).map_err(|e| { 60 std::io::Error::new(std::io::ErrorKind::InvalidData, e) 61 }) 62 } 63 64 /// Default token file path 65 pub fn default_token_path() -> PathBuf { 66 dirs::home_dir() 67 .unwrap_or_else(|| PathBuf::from(".")) 68 .join(".abzu") 69 .join("rpc_token") 70 } 71 72 /// Authentication layer for tower middleware 73 #[derive(Clone)] 74 pub struct AuthLayer { 75 token: Arc<AuthToken>, 76 } 77 78 impl AuthLayer { 79 pub fn new(token: AuthToken) -> Self { 80 Self { token: Arc::new(token) } 81 } 82 } 83 84 impl<S> Layer<S> for AuthLayer { 85 type Service = AuthMiddleware<S>; 86 87 fn layer(&self, inner: S) -> Self::Service { 88 AuthMiddleware { 89 inner, 90 token: Arc::clone(&self.token), 91 } 92 } 93 } 94 95 /// Authentication middleware that validates the Authorization header 96 #[derive(Clone)] 97 pub struct AuthMiddleware<S> { 98 inner: S, 99 token: Arc<AuthToken>, 100 } 101 102 impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for AuthMiddleware<S> 103 where 104 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static, 105 S::Future: Send, 106 ReqBody: Send + 'static, 107 ResBody: Default + Send + 'static, 108 { 109 type Response = Response<ResBody>; 110 type Error = S::Error; 111 type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>; 112 113 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 114 self.inner.poll_ready(cx) 115 } 116 117 fn call(&mut self, req: Request<ReqBody>) -> Self::Future { 118 // Extract the Authorization header 119 let auth_header = req.headers().get("authorization").cloned(); 120 let expected_token = hex_encode(&*self.token); 121 122 // Check if request is authenticated 123 let is_authenticated = match auth_header { 124 Some(value) => { 125 if let Ok(value_str) = value.to_str() { 126 // Support both "Bearer <token>" and raw token 127 let token_str = value_str 128 .strip_prefix("Bearer ") 129 .unwrap_or(value_str); 130 131 let valid = token_str == expected_token; 132 if !valid { 133 debug!("Invalid auth token received"); 134 } 135 valid 136 } else { 137 false 138 } 139 } 140 None => { 141 debug!("Missing Authorization header"); 142 false 143 } 144 }; 145 146 if is_authenticated { 147 let future = self.inner.call(req); 148 Box::pin(future) 149 } else { 150 warn!("Unauthorized RPC request rejected"); 151 Box::pin(async move { 152 let response = Response::builder() 153 .status(StatusCode::UNAUTHORIZED) 154 .body(ResBody::default()) 155 .unwrap(); 156 Ok(response) 157 }) 158 } 159 } 160 } 161 162 /// Hex encode bytes 163 fn hex_encode(bytes: &[u8]) -> String { 164 bytes.iter().map(|b| format!("{:02x}", b)).collect() 165 } 166 167 /// Hex decode string to [u8; 32] 168 fn hex_decode(s: &str) -> Result<AuthToken, String> { 169 if s.len() != 64 { 170 return Err(format!("Expected 64 hex chars, got {}", s.len())); 171 } 172 173 let mut result = [0u8; 32]; 174 for (i, chunk) in s.as_bytes().chunks(2).enumerate() { 175 let byte_str = std::str::from_utf8(chunk) 176 .map_err(|_| "Invalid UTF-8 in hex string")?; 177 result[i] = u8::from_str_radix(byte_str, 16) 178 .map_err(|_| format!("Invalid hex char at position {}", i * 2))?; 179 } 180 181 Ok(result) 182 } 183 184 #[cfg(test)] 185 mod tests { 186 use super::*; 187 use tempfile::tempdir; 188 189 #[test] 190 fn test_token_generation() { 191 let token1 = generate_token(); 192 let token2 = generate_token(); 193 194 // Tokens should be different (overwhelmingly unlikely to be equal) 195 assert_ne!(token1, token2); 196 197 // Token should be 32 bytes 198 assert_eq!(token1.len(), 32); 199 } 200 201 #[test] 202 fn test_token_file_roundtrip() { 203 let dir = tempdir().unwrap(); 204 let path = dir.path().join("test_token"); 205 206 let original = generate_token(); 207 write_token_file(&original, &path).unwrap(); 208 209 // Check file permissions (Unix only) 210 #[cfg(unix)] 211 { 212 use std::os::unix::fs::MetadataExt; 213 let metadata = fs::metadata(&path).unwrap(); 214 assert_eq!(metadata.mode() & 0o777, 0o600); 215 } 216 217 let loaded = read_token_file(&path).unwrap(); 218 assert_eq!(original, loaded); 219 } 220 221 #[test] 222 fn test_hex_roundtrip() { 223 let original: AuthToken = [ 224 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 225 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 226 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 227 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10, 228 ]; 229 230 let hex = hex_encode(&original); 231 assert_eq!(hex.len(), 64); 232 233 let decoded = hex_decode(&hex).unwrap(); 234 assert_eq!(original, decoded); 235 } 236 }