shard.rs
1 //! Token sharding for privacy-preserving inference. 2 //! 3 //! This module implements the core sharding logic that splits hidden states 4 //! across multiple mesh nodes, ensuring no single node sees the full context. 5 6 use blake3::Hasher; 7 use rand::Rng; 8 use serde::{Deserialize, Serialize}; 9 10 /// Configuration for shard distribution. 11 #[derive(Debug, Clone, Serialize, Deserialize)] 12 pub struct ShardConfig { 13 /// Number of shards to split across 14 pub shard_count: usize, 15 16 /// Minimum shards needed to reconstruct (threshold) 17 pub threshold: usize, 18 19 /// Add random padding to prevent size analysis 20 pub pad_shards: bool, 21 22 /// Target shard size in bytes (for padding) 23 pub target_shard_size: usize, 24 } 25 26 impl Default for ShardConfig { 27 fn default() -> Self { 28 Self { 29 shard_count: 3, 30 threshold: 2, // 2-of-3 by default 31 pad_shards: true, 32 target_shard_size: 1024, // 1KB shards 33 } 34 } 35 } 36 37 impl ShardConfig { 38 /// Create a new config with specified shard count 39 pub fn new(shard_count: usize) -> Self { 40 let threshold = (shard_count / 2) + 1; // Majority threshold 41 Self { 42 shard_count, 43 threshold, 44 ..Default::default() 45 } 46 } 47 48 /// Set the reconstruction threshold 49 pub fn with_threshold(mut self, threshold: usize) -> Self { 50 self.threshold = threshold; 51 self 52 } 53 54 /// Validate configuration 55 pub fn validate(&self) -> Result<(), String> { 56 if self.shard_count < 2 { 57 return Err("shard_count must be at least 2".into()); 58 } 59 if self.threshold < 1 { 60 return Err("threshold must be at least 1".into()); 61 } 62 if self.threshold > self.shard_count { 63 return Err("threshold cannot exceed shard_count".into()); 64 } 65 Ok(()) 66 } 67 } 68 69 /// A single shard of tokenized/embedded content. 70 #[derive(Debug, Clone, Serialize, Deserialize)] 71 pub struct TokenShard { 72 /// Unique shard identifier (derived from request ID + shard index) 73 pub shard_id: [u8; 32], 74 75 /// Request ID this shard belongs to 76 pub request_id: [u8; 32], 77 78 /// Shard index (0-based) 79 pub index: usize, 80 81 /// Total shards in this request 82 pub total: usize, 83 84 /// The sharded data (XOR-split or secret-shared) 85 pub data: Vec<u8>, 86 87 /// Verification hash (to detect tampering) 88 pub commitment: [u8; 32], 89 } 90 91 impl TokenShard { 92 /// Create shards from input data using XOR splitting. 93 /// 94 /// This is a simple but effective approach: 95 /// - Generate N-1 random shards 96 /// - Final shard is XOR of input with all random shards 97 /// - XOR all shards together to reconstruct 98 pub fn split(input: &[u8], config: &ShardConfig) -> Vec<Self> { 99 let request_id = Self::generate_request_id(input); 100 let mut rng = rand::thread_rng(); 101 102 // Pad input to target size if configured 103 let padded_input = if config.pad_shards { 104 Self::pad_to_size(input, config.target_shard_size) 105 } else { 106 input.to_vec() 107 }; 108 109 let data_len = padded_input.len(); 110 let mut shards = Vec::with_capacity(config.shard_count); 111 let mut xor_accumulator = padded_input.clone(); 112 113 // Generate N-1 random shards 114 for i in 0..(config.shard_count - 1) { 115 let random_data: Vec<u8> = (0..data_len).map(|_| rng.gen()).collect(); 116 117 // XOR into accumulator 118 for (j, byte) in random_data.iter().enumerate() { 119 xor_accumulator[j] ^= byte; 120 } 121 122 let shard_id = Self::derive_shard_id(&request_id, i); 123 let commitment = Self::compute_commitment(&random_data); 124 125 shards.push(TokenShard { 126 shard_id, 127 request_id, 128 index: i, 129 total: config.shard_count, 130 data: random_data, 131 commitment, 132 }); 133 } 134 135 // Final shard is the XOR accumulator (input XOR all random shards) 136 let final_index = config.shard_count - 1; 137 let final_shard_id = Self::derive_shard_id(&request_id, final_index); 138 let final_commitment = Self::compute_commitment(&xor_accumulator); 139 140 shards.push(TokenShard { 141 shard_id: final_shard_id, 142 request_id, 143 index: final_index, 144 total: config.shard_count, 145 data: xor_accumulator, 146 commitment: final_commitment, 147 }); 148 149 shards 150 } 151 152 /// Reconstruct input from shards using XOR. 153 pub fn reconstruct(shards: &[TokenShard]) -> Result<Vec<u8>, String> { 154 if shards.is_empty() { 155 return Err("no shards provided".into()); 156 } 157 158 // Verify all shards belong to same request 159 let request_id = shards[0].request_id; 160 let total = shards[0].total; 161 162 for shard in shards { 163 if shard.request_id != request_id { 164 return Err("shards from different requests".into()); 165 } 166 if shard.total != total { 167 return Err("inconsistent shard total".into()); 168 } 169 } 170 171 // Need all shards for XOR reconstruction 172 if shards.len() != total { 173 return Err(format!( 174 "need all {} shards for XOR reconstruction, got {}", 175 total, shards.len() 176 )); 177 } 178 179 // Verify commitments 180 for shard in shards { 181 let computed = Self::compute_commitment(&shard.data); 182 if computed != shard.commitment { 183 return Err(format!("shard {} commitment mismatch", shard.index)); 184 } 185 } 186 187 // XOR all shards together 188 let data_len = shards[0].data.len(); 189 let mut result = vec![0u8; data_len]; 190 191 for shard in shards { 192 for (i, byte) in shard.data.iter().enumerate() { 193 result[i] ^= byte; 194 } 195 } 196 197 // Remove padding (trailing zeros after null terminator) 198 Ok(Self::unpad(&result)) 199 } 200 201 /// Generate a request ID from input 202 fn generate_request_id(input: &[u8]) -> [u8; 32] { 203 let mut hasher = Hasher::new(); 204 hasher.update(b"abzu:inference:request:"); 205 hasher.update(input); 206 // Add randomness so same input produces different request IDs 207 let nonce: [u8; 16] = rand::thread_rng().gen(); 208 hasher.update(&nonce); 209 *hasher.finalize().as_bytes() 210 } 211 212 /// Derive shard ID from request ID and index 213 fn derive_shard_id(request_id: &[u8; 32], index: usize) -> [u8; 32] { 214 let mut hasher = Hasher::new(); 215 hasher.update(b"abzu:inference:shard:"); 216 hasher.update(request_id); 217 hasher.update(&(index as u64).to_le_bytes()); 218 *hasher.finalize().as_bytes() 219 } 220 221 /// Compute commitment hash for shard data 222 fn compute_commitment(data: &[u8]) -> [u8; 32] { 223 let mut hasher = Hasher::new(); 224 hasher.update(b"abzu:inference:commitment:"); 225 hasher.update(data); 226 *hasher.finalize().as_bytes() 227 } 228 229 /// Pad data to target size with random bytes 230 fn pad_to_size(data: &[u8], target: usize) -> Vec<u8> { 231 let mut padded = data.to_vec(); 232 padded.push(0x00); // Null terminator to mark end 233 234 let mut rng = rand::thread_rng(); 235 while padded.len() < target { 236 padded.push(rng.gen()); 237 } 238 239 padded 240 } 241 242 /// Remove padding from data 243 fn unpad(data: &[u8]) -> Vec<u8> { 244 // Find last null terminator 245 if let Some(_pos) = data.iter().rposition(|&b| b == 0x00) { 246 // Check if everything after is random padding 247 // We assume the first null after actual data is the terminator 248 for (i, &byte) in data.iter().enumerate() { 249 if byte == 0x00 { 250 return data[..i].to_vec(); 251 } 252 } 253 } 254 data.to_vec() 255 } 256 } 257 258 #[cfg(test)] 259 mod tests { 260 use super::*; 261 262 #[test] 263 fn test_shard_config_default() { 264 let config = ShardConfig::default(); 265 assert_eq!(config.shard_count, 3); 266 assert_eq!(config.threshold, 2); 267 assert!(config.validate().is_ok()); 268 } 269 270 #[test] 271 fn test_shard_config_validation() { 272 let invalid = ShardConfig { 273 shard_count: 1, 274 threshold: 2, 275 ..Default::default() 276 }; 277 assert!(invalid.validate().is_err()); 278 } 279 280 #[test] 281 fn test_split_and_reconstruct() { 282 let input = b"What is the meaning of life?"; 283 let config = ShardConfig::new(3); 284 285 let shards = TokenShard::split(input, &config); 286 assert_eq!(shards.len(), 3); 287 288 // All shards have same request ID 289 let request_id = shards[0].request_id; 290 for shard in &shards { 291 assert_eq!(shard.request_id, request_id); 292 } 293 294 // Reconstruct 295 let reconstructed = TokenShard::reconstruct(&shards).unwrap(); 296 assert_eq!(reconstructed, input); 297 } 298 299 #[test] 300 fn test_individual_shard_reveals_nothing() { 301 let input = b"Secret prompt about sensitive topic"; 302 let config = ShardConfig::new(3); 303 304 let shards = TokenShard::split(input, &config); 305 306 // Each individual shard should look random 307 // (Not contain the original plaintext) 308 for shard in &shards { 309 assert!(!shard.data.windows(input.len()).any(|w| w == input)); 310 } 311 } 312 313 #[test] 314 fn test_commitment_verification() { 315 let input = b"Test input"; 316 let config = ShardConfig::new(2); 317 318 let mut shards = TokenShard::split(input, &config); 319 320 // Tamper with shard data 321 shards[0].data[0] ^= 0xFF; 322 323 // Reconstruction should fail due to commitment mismatch 324 let result = TokenShard::reconstruct(&shards); 325 assert!(result.is_err()); 326 assert!(result.unwrap_err().contains("commitment mismatch")); 327 } 328 }