/ src / lib.rs
lib.rs
 1  use half::bf16;
 2  use safetensors::SafeTensors;
 3  use std::collections::HashMap;
 4  pub mod compute_handle;
 5  pub mod gpu;
 6  
 7  #[derive(Debug)]
 8  pub struct BgTensor {
 9      pub data: Vec<f32>,
10      pub shape: Vec<usize>,
11  }
12  
13  pub fn load_weights(data: &[u8]) -> HashMap<String, BgTensor> {
14      let tensor_data = SafeTensors::deserialize(data).unwrap();
15      let mut map = HashMap::new();
16      for (weight_name, tensor) in tensor_data.tensors() {
17          let f32_weights: Vec<f32> = tensor
18              .data()
19              .chunks(2)
20              .map(|chunk| {
21                  let float = bf16::from_le_bytes(chunk.try_into().unwrap());
22                  float.to_f32()
23              })
24              .collect();
25          let shape = tensor.shape().to_vec();
26          let tensor = BgTensor {
27              data: f32_weights,
28              shape,
29          };
30          map.insert(weight_name, tensor);
31      }
32      map
33  }