// abzu-inference/src/shard.rs
//! Token sharding for privacy-preserving inference.
//!
//! This module implements the core sharding logic that splits hidden states
//! into XOR shares intended for separate mesh nodes, so that no single node
//! sees the full context. Every share is required to reconstruct the input.
  5  
  6  use blake3::Hasher;
  7  use rand::Rng;
  8  use serde::{Deserialize, Serialize};
  9  
/// Configuration for shard distribution.
///
/// Consumed by `TokenShard::split`, which reads `shard_count`, `pad_shards`
/// and `target_shard_size`. Use `ShardConfig::validate` to check the counts
/// before splitting; `split` itself does not call it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardConfig {
    /// Number of shards `TokenShard::split` produces for one input.
    ///
    /// `validate` rejects values below 2.
    pub shard_count: usize,

    /// Minimum number of shards needed to reconstruct (threshold).
    ///
    /// NOTE(review): the XOR split in this file does not read this field, and
    /// `TokenShard::reconstruct` requires every shard regardless of its value.
    /// It is only range-checked by `validate` (1..=`shard_count`).
    pub threshold: usize,

    /// When `true`, the payload is padded with random bytes before splitting
    /// so that shard sizes do not track the input size.
    pub pad_shards: bool,

    /// Size in bytes that the payload is padded up to when `pad_shards` is set.
    ///
    /// Inputs are never truncated: a payload that is already at least this
    /// long simply yields larger shards.
    pub target_shard_size: usize,
}
 25  
 26  impl Default for ShardConfig {
 27      fn default() -> Self {
 28          Self {
 29              shard_count: 3,
 30              threshold: 2,     // 2-of-3 by default
 31              pad_shards: true,
 32              target_shard_size: 1024, // 1KB shards
 33          }
 34      }
 35  }
 36  
 37  impl ShardConfig {
 38      /// Create a new config with specified shard count
 39      pub fn new(shard_count: usize) -> Self {
 40          let threshold = (shard_count / 2) + 1; // Majority threshold
 41          Self {
 42              shard_count,
 43              threshold,
 44              ..Default::default()
 45          }
 46      }
 47  
 48      /// Set the reconstruction threshold
 49      pub fn with_threshold(mut self, threshold: usize) -> Self {
 50          self.threshold = threshold;
 51          self
 52      }
 53  
 54      /// Validate configuration
 55      pub fn validate(&self) -> Result<(), String> {
 56          if self.shard_count < 2 {
 57              return Err("shard_count must be at least 2".into());
 58          }
 59          if self.threshold < 1 {
 60              return Err("threshold must be at least 1".into());
 61          }
 62          if self.threshold > self.shard_count {
 63              return Err("threshold cannot exceed shard_count".into());
 64          }
 65          Ok(())
 66      }
 67  }
 68  
/// A single shard of tokenized/embedded content.
///
/// Produced by `TokenShard::split` and consumed by `TokenShard::reconstruct`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenShard {
    /// Unique shard identifier: a BLAKE3 hash over a domain prefix, the
    /// request ID and the shard index.
    pub shard_id: [u8; 32],

    /// Request ID this shard belongs to.
    ///
    /// A BLAKE3 hash over a domain prefix, the input and a fresh random
    /// nonce; all shards from one `split` call carry the same value.
    pub request_id: [u8; 32],

    /// Shard index (0-based). The highest index (`total - 1`) holds the XOR
    /// remainder; lower indices hold random bytes.
    pub index: usize,

    /// Total shards in this request. `reconstruct` requires exactly this
    /// many shards.
    pub total: usize,

    /// The shard payload. In this file shards are XOR shares: XOR-ing the
    /// `data` of every shard of a request yields the split payload.
    pub data: Vec<u8>,

    /// Verification hash: BLAKE3 over a domain prefix and `data`.
    /// `reconstruct` recomputes it and rejects the shard on mismatch.
    ///
    /// NOTE(review): the hash is unkeyed and covers only `data`, so it
    /// catches corruption but can be recomputed by anyone able to rewrite
    /// the shard — confirm this matches the intended threat model.
    pub commitment: [u8; 32],
}
 90  
 91  impl TokenShard {
 92      /// Create shards from input data using XOR splitting.
 93      ///
 94      /// This is a simple but effective approach:
 95      /// - Generate N-1 random shards
 96      /// - Final shard is XOR of input with all random shards
 97      /// - XOR all shards together to reconstruct
 98      pub fn split(input: &[u8], config: &ShardConfig) -> Vec<Self> {
 99          let request_id = Self::generate_request_id(input);
100          let mut rng = rand::thread_rng();
101          
102          // Pad input to target size if configured
103          let padded_input = if config.pad_shards {
104              Self::pad_to_size(input, config.target_shard_size)
105          } else {
106              input.to_vec()
107          };
108          
109          let data_len = padded_input.len();
110          let mut shards = Vec::with_capacity(config.shard_count);
111          let mut xor_accumulator = padded_input.clone();
112          
113          // Generate N-1 random shards
114          for i in 0..(config.shard_count - 1) {
115              let random_data: Vec<u8> = (0..data_len).map(|_| rng.gen()).collect();
116              
117              // XOR into accumulator
118              for (j, byte) in random_data.iter().enumerate() {
119                  xor_accumulator[j] ^= byte;
120              }
121              
122              let shard_id = Self::derive_shard_id(&request_id, i);
123              let commitment = Self::compute_commitment(&random_data);
124              
125              shards.push(TokenShard {
126                  shard_id,
127                  request_id,
128                  index: i,
129                  total: config.shard_count,
130                  data: random_data,
131                  commitment,
132              });
133          }
134          
135          // Final shard is the XOR accumulator (input XOR all random shards)
136          let final_index = config.shard_count - 1;
137          let final_shard_id = Self::derive_shard_id(&request_id, final_index);
138          let final_commitment = Self::compute_commitment(&xor_accumulator);
139          
140          shards.push(TokenShard {
141              shard_id: final_shard_id,
142              request_id,
143              index: final_index,
144              total: config.shard_count,
145              data: xor_accumulator,
146              commitment: final_commitment,
147          });
148          
149          shards
150      }
151      
152      /// Reconstruct input from shards using XOR.
153      pub fn reconstruct(shards: &[TokenShard]) -> Result<Vec<u8>, String> {
154          if shards.is_empty() {
155              return Err("no shards provided".into());
156          }
157          
158          // Verify all shards belong to same request
159          let request_id = shards[0].request_id;
160          let total = shards[0].total;
161          
162          for shard in shards {
163              if shard.request_id != request_id {
164                  return Err("shards from different requests".into());
165              }
166              if shard.total != total {
167                  return Err("inconsistent shard total".into());
168              }
169          }
170          
171          // Need all shards for XOR reconstruction
172          if shards.len() != total {
173              return Err(format!(
174                  "need all {} shards for XOR reconstruction, got {}",
175                  total, shards.len()
176              ));
177          }
178          
179          // Verify commitments
180          for shard in shards {
181              let computed = Self::compute_commitment(&shard.data);
182              if computed != shard.commitment {
183                  return Err(format!("shard {} commitment mismatch", shard.index));
184              }
185          }
186          
187          // XOR all shards together
188          let data_len = shards[0].data.len();
189          let mut result = vec![0u8; data_len];
190          
191          for shard in shards {
192              for (i, byte) in shard.data.iter().enumerate() {
193                  result[i] ^= byte;
194              }
195          }
196          
197          // Remove padding (trailing zeros after null terminator)
198          Ok(Self::unpad(&result))
199      }
200      
201      /// Generate a request ID from input
202      fn generate_request_id(input: &[u8]) -> [u8; 32] {
203          let mut hasher = Hasher::new();
204          hasher.update(b"abzu:inference:request:");
205          hasher.update(input);
206          // Add randomness so same input produces different request IDs
207          let nonce: [u8; 16] = rand::thread_rng().gen();
208          hasher.update(&nonce);
209          *hasher.finalize().as_bytes()
210      }
211      
212      /// Derive shard ID from request ID and index
213      fn derive_shard_id(request_id: &[u8; 32], index: usize) -> [u8; 32] {
214          let mut hasher = Hasher::new();
215          hasher.update(b"abzu:inference:shard:");
216          hasher.update(request_id);
217          hasher.update(&(index as u64).to_le_bytes());
218          *hasher.finalize().as_bytes()
219      }
220      
221      /// Compute commitment hash for shard data
222      fn compute_commitment(data: &[u8]) -> [u8; 32] {
223          let mut hasher = Hasher::new();
224          hasher.update(b"abzu:inference:commitment:");
225          hasher.update(data);
226          *hasher.finalize().as_bytes()
227      }
228      
229      /// Pad data to target size with random bytes
230      fn pad_to_size(data: &[u8], target: usize) -> Vec<u8> {
231          let mut padded = data.to_vec();
232          padded.push(0x00); // Null terminator to mark end
233          
234          let mut rng = rand::thread_rng();
235          while padded.len() < target {
236              padded.push(rng.gen());
237          }
238          
239          padded
240      }
241      
242      /// Remove padding from data
243      fn unpad(data: &[u8]) -> Vec<u8> {
244          // Find last null terminator
245          if let Some(_pos) = data.iter().rposition(|&b| b == 0x00) {
246              // Check if everything after is random padding
247              // We assume the first null after actual data is the terminator
248              for (i, &byte) in data.iter().enumerate() {
249                  if byte == 0x00 {
250                      return data[..i].to_vec();
251                  }
252              }
253          }
254          data.to_vec()
255      }
256  }
257  
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is 2-of-3 and passes validation.
    #[test]
    fn test_shard_config_default() {
        let defaults = ShardConfig::default();

        assert_eq!((defaults.shard_count, defaults.threshold), (3, 2));
        assert!(defaults.validate().is_ok());
    }

    /// A single-shard configuration is rejected by `validate`.
    #[test]
    fn test_shard_config_validation() {
        let single_shard = ShardConfig {
            shard_count: 1,
            threshold: 2,
            ..ShardConfig::default()
        };

        assert!(single_shard.validate().is_err());
    }

    /// Splitting into three shards and recombining returns the input.
    #[test]
    fn test_split_and_reconstruct() {
        let prompt = b"What is the meaning of life?";

        let pieces = TokenShard::split(prompt, &ShardConfig::new(3));
        assert_eq!(pieces.len(), 3);

        // Every shard carries the request ID of the first one.
        let expected_id = pieces[0].request_id;
        assert!(pieces.iter().all(|piece| piece.request_id == expected_id));

        let recovered = TokenShard::reconstruct(&pieces).unwrap();
        assert_eq!(recovered, prompt);
    }

    /// No shard's payload contains the plaintext as a contiguous run.
    #[test]
    fn test_individual_shard_reveals_nothing() {
        let secret = b"Secret prompt about sensitive topic";

        let pieces = TokenShard::split(secret, &ShardConfig::new(3));

        for piece in &pieces {
            let leaked = piece
                .data
                .windows(secret.len())
                .any(|window| window == secret);
            assert!(!leaked);
        }
    }

    /// Flipping bits in a shard's payload makes reconstruction fail with a
    /// commitment error.
    #[test]
    fn test_commitment_verification() {
        let mut pieces = TokenShard::split(b"Test input", &ShardConfig::new(2));

        // Corrupt the first byte of the first shard.
        pieces[0].data[0] ^= 0xFF;

        match TokenShard::reconstruct(&pieces) {
            Ok(_) => panic!("tampered shard was accepted"),
            Err(message) => assert!(message.contains("commitment mismatch")),
        }
    }
}
327      }
328  }