generate_tokenized_dataset.py
import argparse

from datasets import load_dataset
from transformers import AutoTokenizer


def prepare_dataset(tokenizer, text_file_path: str):
    """
    Tokenizes a text file where each line is a different example.
    Padding is applied to each example.
    """
    # Each line of the file becomes a separate example
    dataset = load_dataset("text", data_files={"train": text_file_path})

    def tokenize_function(examples):
        return tokenizer(
            examples["text"], padding="max_length", truncation=True, max_length=128
        )

    tokenized_dataset = dataset.map(
        tokenize_function, batched=True, remove_columns=["text"]
    )
    return tokenized_dataset["train"]


def generate_tokenized_dataset(
    tokenizer_path: str, text_file_path: str, output_dir: str
) -> None:
    """
    Generate a tokenized dataset from a text file, where each line is a different example.

    Args:
        tokenizer_path (str): Path to the directory containing the tokenizer files.
        text_file_path (str): Path to the text file to tokenize.
        output_dir (str): Directory where the tokenized dataset will be saved.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    # Some models (e.g. Mistral) ship without a pad token; reuse EOS so padding works
    tokenizer.pad_token = tokenizer.eos_token

    train_dataset = prepare_dataset(tokenizer, text_file_path)
    train_dataset.save_to_disk(output_dir)


if __name__ == "__main__":
    # Example usage:
    # python generate_tokenized_dataset.py --tokenizer_path /shared/public/models/Mistral-7B --text_file_path ./../../resources/tiny_shakespeare.txt --output_dir ./../../resources/tiny_shakespeare_tokenized
    parser = argparse.ArgumentParser(
        description="Generate tokenized dataset from a text file."
    )

    # Add arguments
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        required=True,
        help="Path to the directory containing the tokenizer files.",
    )
    parser.add_argument(
        "--text_file_path",
        type=str,
        required=True,
        help="Path to the text file to tokenize.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory where the tokenized dataset will be saved.",
    )

    # Parse the arguments
    args = parser.parse_args()

    # Call the function with parsed arguments
    generate_tokenized_dataset(
        tokenizer_path=args.tokenizer_path,
        text_file_path=args.text_file_path,
        output_dir=args.output_dir,
    )
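
The dataset written by save_to_disk can be read back with datasets.load_from_disk. A minimal sketch of consuming the output, assuming the output path from the example usage above (substitute your own --output_dir):

from datasets import load_from_disk

# Load the Arrow-backed dataset written by save_to_disk
train_dataset = load_from_disk("./../../resources/tiny_shakespeare_tokenized")

print(train_dataset)                       # features are typically input_ids and attention_mask
print(train_dataset[0]["input_ids"][:10])  # first tokens of the first example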