// src/model/tokenizer.rs
 1  use std::path::Path;
 2  
 3  use super::ModelError;
 4  use tokenizers::Tokenizer;
 5  
/// Result of encoding one prompt with both tokenizers:
/// `(ids from tokenizer, ids from tokenizer_2)`, each exactly `MAX_LENGTH` long.
type EncodeResult = (Vec<i64>, Vec<i64>);

/// Fixed sequence length every encoded prompt is truncated/padded to.
/// 77 is the standard CLIP text-encoder context length — TODO confirm for this model.
const MAX_LENGTH: usize = 77;
 9  
/// Pair of CLIP tokenizers loaded from a model data directory
/// (`tokenizer.json` and `tokenizer_2.json`); used together to encode
/// one prompt into two fixed-length id sequences.
pub struct ClipTokenizers {
    // Primary tokenizer, loaded from "tokenizer.json".
    tokenizer: Tokenizer,
    // Secondary tokenizer, loaded from "tokenizer_2.json".
    tokenizer_2: Tokenizer,
}
14  impl ClipTokenizers {
15      pub fn new(data_dir: &Path) -> Result<Self, ModelError> {
16          Ok(Self{ 
17              tokenizer: Tokenizer::from_file(data_dir.join("tokenizer.json"))?,
18              tokenizer_2: Tokenizer::from_file(data_dir.join("tokenizer_2.json"))?,
19          })
20      }
21  
22      pub fn encode(&self, text: &str) -> Result<EncodeResult, ModelError> {
23          let tokens1 = self.encode_single(&self.tokenizer, text)?;
24          let tokens2 = self.encode_single(&self.tokenizer_2, text)?;
25          Ok((tokens1, tokens2))
26      }
27  
28      fn encode_single(&self, tokenizer: &Tokenizer, text: &str) -> Result<Vec<i64>, ModelError> {
29          let encoding = tokenizer.encode(text, true)?;
30  
31          let mut ids: Vec<i64> = encoding
32              .get_ids()
33              .iter()
34              .map(|&id| id as i64)
35              .collect();
36  
37          ids.truncate(MAX_LENGTH);
38  
39          while ids.len() < MAX_LENGTH {
40              ids.push(0);
41          }
42  
43          Ok(ids)
44      }
45  }