/ abzu-daemon / src / auth.rs
auth.rs
  1  //! RPC Authentication Module
  2  //!
  3  //! Provides token-based authentication for the JSON-RPC control plane.
  4  //! On startup, generates a random 256-bit token and writes it to a file
  5  //! with restricted permissions. All RPC requests must include this token
  6  //! in the Authorization header.
  7  
  8  use std::fs::{self, OpenOptions};
  9  use std::io::Write;
 10  use std::os::unix::fs::OpenOptionsExt;
 11  use std::path::PathBuf;
 12  use std::sync::Arc;
 13  use std::task::{Context, Poll};
 14  
 15  use http::{Request, Response, StatusCode};
 16  use tower::{Layer, Service};
 17  use tracing::{debug, warn};
 18  
 19  /// 32-byte auth token (256 bits of entropy)
 20  pub type AuthToken = [u8; 32];
 21  
 22  /// Generate a cryptographically random auth token
 23  pub fn generate_token() -> AuthToken {
 24      use rand::RngCore;
 25      let mut token = [0u8; 32];
 26      rand::thread_rng().fill_bytes(&mut token);
 27      token
 28  }
 29  
 30  /// Write auth token to file with restricted permissions (0600)
 31  pub fn write_token_file(token: &AuthToken, path: &PathBuf) -> std::io::Result<()> {
 32      // Ensure parent directory exists
 33      if let Some(parent) = path.parent() {
 34          fs::create_dir_all(parent)?;
 35      }
 36  
 37      // Open file with mode 0600 (owner read/write only)
 38      let mut file = OpenOptions::new()
 39          .write(true)
 40          .create(true)
 41          .truncate(true)
 42          .mode(0o600)
 43          .open(path)?;
 44  
 45      // Write token as hex
 46      let hex = hex_encode(token);
 47      file.write_all(hex.as_bytes())?;
 48      file.write_all(b"\n")?;
 49  
 50      Ok(())
 51  }
 52  
 53  /// Read auth token from file (for client-side authentication)
 54  #[allow(dead_code)]
 55  pub fn read_token_file(path: &PathBuf) -> std::io::Result<AuthToken> {
 56      let contents = fs::read_to_string(path)?;
 57      let hex = contents.trim();
 58      
 59      hex_decode(hex).map_err(|e| {
 60          std::io::Error::new(std::io::ErrorKind::InvalidData, e)
 61      })
 62  }
 63  
 64  /// Default token file path
 65  pub fn default_token_path() -> PathBuf {
 66      dirs::home_dir()
 67          .unwrap_or_else(|| PathBuf::from("."))
 68          .join(".abzu")
 69          .join("rpc_token")
 70  }
 71  
 72  /// Authentication layer for tower middleware
 73  #[derive(Clone)]
 74  pub struct AuthLayer {
 75      token: Arc<AuthToken>,
 76  }
 77  
 78  impl AuthLayer {
 79      pub fn new(token: AuthToken) -> Self {
 80          Self { token: Arc::new(token) }
 81      }
 82  }
 83  
 84  impl<S> Layer<S> for AuthLayer {
 85      type Service = AuthMiddleware<S>;
 86  
 87      fn layer(&self, inner: S) -> Self::Service {
 88          AuthMiddleware {
 89              inner,
 90              token: Arc::clone(&self.token),
 91          }
 92      }
 93  }
 94  
 95  /// Authentication middleware that validates the Authorization header
 96  #[derive(Clone)]
 97  pub struct AuthMiddleware<S> {
 98      inner: S,
 99      token: Arc<AuthToken>,
100  }
101  
102  impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for AuthMiddleware<S>
103  where
104      S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
105      S::Future: Send,
106      ReqBody: Send + 'static,
107      ResBody: Default + Send + 'static,
108  {
109      type Response = Response<ResBody>;
110      type Error = S::Error;
111      type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
112  
113      fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
114          self.inner.poll_ready(cx)
115      }
116  
117      fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
118          // Extract the Authorization header
119          let auth_header = req.headers().get("authorization").cloned();
120          let expected_token = hex_encode(&*self.token);
121          
122          // Check if request is authenticated
123          let is_authenticated = match auth_header {
124              Some(value) => {
125                  if let Ok(value_str) = value.to_str() {
126                      // Support both "Bearer <token>" and raw token
127                      let token_str = value_str
128                          .strip_prefix("Bearer ")
129                          .unwrap_or(value_str);
130                      
131                      let valid = token_str == expected_token;
132                      if !valid {
133                          debug!("Invalid auth token received");
134                      }
135                      valid
136                  } else {
137                      false
138                  }
139              }
140              None => {
141                  debug!("Missing Authorization header");
142                  false
143              }
144          };
145  
146          if is_authenticated {
147              let future = self.inner.call(req);
148              Box::pin(future)
149          } else {
150              warn!("Unauthorized RPC request rejected");
151              Box::pin(async move {
152                  let response = Response::builder()
153                      .status(StatusCode::UNAUTHORIZED)
154                      .body(ResBody::default())
155                      .unwrap();
156                  Ok(response)
157              })
158          }
159      }
160  }
161  
/// Render `bytes` as lowercase hexadecimal, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(DIGITS[usize::from(byte >> 4)] as char);
        out.push(DIGITS[usize::from(byte & 0x0f)] as char);
    }
    out
}
166  
167  /// Hex decode string to [u8; 32]
168  fn hex_decode(s: &str) -> Result<AuthToken, String> {
169      if s.len() != 64 {
170          return Err(format!("Expected 64 hex chars, got {}", s.len()));
171      }
172      
173      let mut result = [0u8; 32];
174      for (i, chunk) in s.as_bytes().chunks(2).enumerate() {
175          let byte_str = std::str::from_utf8(chunk)
176              .map_err(|_| "Invalid UTF-8 in hex string")?;
177          result[i] = u8::from_str_radix(byte_str, 16)
178              .map_err(|_| format!("Invalid hex char at position {}", i * 2))?;
179      }
180      
181      Ok(result)
182  }
183  
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Two fresh tokens differ and both are 32 bytes wide.
    #[test]
    fn test_token_generation() {
        let (first, second) = (generate_token(), generate_token());

        // A collision between two independent 256-bit values is astronomically unlikely.
        assert_ne!(first, second);
        assert_eq!(first.len(), 32);
    }

    /// Writing then reading a token yields the same bytes, in an owner-only file.
    #[test]
    fn test_token_file_roundtrip() {
        let scratch = tempdir().unwrap();
        let token_path = scratch.path().join("test_token");

        let written = generate_token();
        write_token_file(&written, &token_path).unwrap();

        #[cfg(unix)]
        {
            use std::os::unix::fs::MetadataExt;
            let permission_bits = fs::metadata(&token_path).unwrap().mode() & 0o777;
            assert_eq!(permission_bits, 0o600);
        }

        let read_back = read_token_file(&token_path).unwrap();
        assert_eq!(written, read_back);
    }

    /// Encoding uses every nibble value and decodes back losslessly.
    #[test]
    fn test_hex_roundtrip() {
        let sample: AuthToken = [
            0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77,
            0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff,
            0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
            0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10,
        ];

        let encoded = hex_encode(&sample);
        assert_eq!(encoded.len(), 64);

        let round_tripped = hex_decode(&encoded).unwrap();
        assert_eq!(sample, round_tripped);
    }
}