train.py
"""Train and evaluate a simple MLP on the Reuters newswire topic classification task.

The script loads the Reuters dataset, vectorizes the word-index sequences,
one-hot encodes the labels, fits a small dense network and prints the loss
and accuracy measured on the held-out test split.
"""

import numpy as np
from tensorflow import keras
from tensorflow.keras.datasets import reuters
from tensorflow.keras.layers import Activation, Dense, Dropout
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.text import Tokenizer

# Importing mlflow and calling autolog() is the only change needed for
# metrics and parameters to be logged to MLflow automatically.
import mlflow

mlflow.tensorflow.autolog()

# Vocabulary cap (also the width of the model input) and training settings.
max_words = 1000
batch_size = 32
epochs = 5

print("Loading data...")
(x_train, y_train), (x_test, y_test) = reuters.load_data(
    num_words=max_words,
    test_split=0.2,
)

print(len(x_train), "train sequences")
print(len(x_test), "test sequences")

# Labels are integer class ids, so the class count is the largest id plus one.
num_classes = np.max(y_train) + 1
print(num_classes, "classes")

print("Vectorizing sequence data...")
tokenizer = Tokenizer(num_words=max_words)
# Train split first, then test split, each converted with mode="binary".
x_train, x_test = (
    tokenizer.sequences_to_matrix(sequences, mode="binary")
    for sequences in (x_train, x_test)
)
print("x_train shape:", x_train.shape)
print("x_test shape:", x_test.shape)

print("Convert class vector to binary class matrix (for use with categorical_crossentropy)")
y_train, y_test = (
    keras.utils.to_categorical(labels, num_classes) for labels in (y_train, y_test)
)
print("y_train shape:", y_train.shape)
print("y_test shape:", y_test.shape)

print("Building model...")
# One hidden layer with dropout, followed by a softmax over the classes.
model = Sequential(
    [
        Dense(512, input_shape=(max_words,)),
        Activation("relu"),
        Dropout(0.5),
        Dense(num_classes),
        Activation("softmax"),
    ]
)

model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

# A tenth of the training data is held back for per-epoch validation.
history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    validation_split=0.1,
)

# evaluate() returns the loss followed by the compiled metrics (accuracy here).
score = model.evaluate(x_test, y_test, batch_size=batch_size, verbose=1)
for label, value in zip(("Test score:", "Test accuracy:"), score):
    print(label, value)