test/resources/scripts/generate_tokenized_dataset.py
import argparse

from datasets import load_dataset
from transformers import AutoTokenizer


def prepare_dataset(tokenizer, text_file_path: str):
    """
    Tokenizes a text file where each line is a separate example.
    Every example is padded and truncated to a fixed length of 128 tokens.
    """
    # Each line of the file becomes one example in the "train" split
    dataset = load_dataset("text", data_files={"train": text_file_path})

    def tokenize_function(examples):
        # Pad/truncate every example to a uniform length so the saved
        # dataset can be batched without further collation logic
        return tokenizer(
            examples["text"], padding="max_length", truncation=True, max_length=128
        )

    tokenized_dataset = dataset.map(
        tokenize_function, batched=True, remove_columns=["text"]
    )
    return tokenized_dataset["train"]
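
# Illustrative only: each returned row carries fixed-length "input_ids" and
# "attention_mask" columns (the "text" column is removed by the map above).
# For example, assuming `tok` is any loaded AutoTokenizer with a pad token:
#   row = prepare_dataset(tok, "lines.txt")[0]
#   len(row["input_ids"])  # 128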


def generate_tokenized_dataset(
    tokenizer_path: str, text_file_path: str, output_dir: str
) -> None:
    """
    Generate a tokenized dataset from a text file, where each line is a separate example.

    Args:
        tokenizer_path (str): Path to the directory containing the tokenizer files.
        text_file_path (str): Path to the text file to tokenize.
        output_dir (str): Directory where the tokenized dataset will be saved.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
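    # Many causal-LM tokenizers (e.g. Mistral's) ship without a pad token,
    # so the EOS token is reused for padding below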
    tokenizer.pad_token = tokenizer.eos_token

    train_dataset = prepare_dataset(tokenizer, text_file_path)
    train_dataset.save_to_disk(output_dir)
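
    # The saved dataset can later be reloaded with datasets.load_from_disk:
    #   from datasets import load_from_disk
    #   train_dataset = load_from_disk(output_dir)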


if __name__ == "__main__":
    # Example usage:
    # python generate_tokenized_dataset.py --tokenizer_path /shared/public/models/Mistral-7B --text_file_path ./../../resources/tiny_shakespeare.txt --output_dir ./../../resources/tiny_shakespeare_tokenized
    parser = argparse.ArgumentParser(
        description="Generate tokenized dataset from a text file."
    )

    # Add arguments
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        required=True,
        help="Path to the directory containing the tokenizer files.",
    )
    parser.add_argument(
        "--text_file_path",
        type=str,
        required=True,
        help="Path to the text file to tokenize.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory where the tokenized dataset will be saved.",
    )

    # Parse the arguments
    args = parser.parse_args()

    # Call the function with parsed arguments
    generate_tokenized_dataset(
        tokenizer_path=args.tokenizer_path,
        text_file_path=args.text_file_path,
        output_dir=args.output_dir,
    )