//! tokenizer.rs — CLIP tokenizer pair for dual-text-encoder models.
1 use std::path::Path; 2 3 use super::ModelError; 4 use tokenizers::Tokenizer; 5 6 type EncodeResult = (Vec<i64>, Vec<i64>); 7 8 const MAX_LENGTH: usize = 77; 9 10 pub struct ClipTokenizers { 11 tokenizer: Tokenizer, 12 tokenizer_2: Tokenizer, 13 } 14 impl ClipTokenizers { 15 pub fn new(data_dir: &Path) -> Result<Self, ModelError> { 16 Ok(Self{ 17 tokenizer: Tokenizer::from_file(data_dir.join("tokenizer.json"))?, 18 tokenizer_2: Tokenizer::from_file(data_dir.join("tokenizer_2.json"))?, 19 }) 20 } 21 22 pub fn encode(&self, text: &str) -> Result<EncodeResult, ModelError> { 23 let tokens1 = self.encode_single(&self.tokenizer, text)?; 24 let tokens2 = self.encode_single(&self.tokenizer_2, text)?; 25 Ok((tokens1, tokens2)) 26 } 27 28 fn encode_single(&self, tokenizer: &Tokenizer, text: &str) -> Result<Vec<i64>, ModelError> { 29 let encoding = tokenizer.encode(text, true)?; 30 31 let mut ids: Vec<i64> = encoding 32 .get_ids() 33 .iter() 34 .map(|&id| id as i64) 35 .collect(); 36 37 ids.truncate(MAX_LENGTH); 38 39 while ids.len() < MAX_LENGTH { 40 ids.push(0); 41 } 42 43 Ok(ids) 44 } 45 }