lib.rs
1 use half::bf16; 2 use safetensors::SafeTensors; 3 use std::collections::HashMap; 4 pub mod compute_handle; 5 pub mod gpu; 6 7 #[derive(Debug)] 8 pub struct BgTensor { 9 pub data: Vec<f32>, 10 pub shape: Vec<usize>, 11 } 12 13 pub fn load_weights(data: &[u8]) -> HashMap<String, BgTensor> { 14 let tensor_data = SafeTensors::deserialize(data).unwrap(); 15 let mut map = HashMap::new(); 16 for (weight_name, tensor) in tensor_data.tensors() { 17 let f32_weights: Vec<f32> = tensor 18 .data() 19 .chunks(2) 20 .map(|chunk| { 21 let float = bf16::from_le_bytes(chunk.try_into().unwrap()); 22 float.to_f32() 23 }) 24 .collect(); 25 let shape = tensor.shape().to_vec(); 26 let tensor = BgTensor { 27 data: f32_weights, 28 shape, 29 }; 30 map.insert(weight_name, tensor); 31 } 32 map 33 }