iris_data_utils.py
# From https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py
# This is the example TensorFlow uses to get users started; it is vendored
# here for testing.
import pandas as pd
import tensorflow as tf

TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

CSV_COLUMN_NAMES = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth", "Species"]
SPECIES = ["Setosa", "Versicolor", "Virginica"]


def maybe_download():
    """Downloads the training and test CSVs if not already cached locally."""
    train_path = tf.keras.utils.get_file(TRAIN_URL.split("/")[-1], TRAIN_URL)
    test_path = tf.keras.utils.get_file(TEST_URL.split("/")[-1], TEST_URL)

    return train_path, test_path


def load_data(y_name="Species"):
    """Returns the iris dataset as (train_x, train_y), (test_x, test_y)."""
    train_path, test_path = maybe_download()

    train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
    train_y = train.pop(y_name)
    train_x = train

    test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
    test_y = test.pop(y_name)
    test_x = test

    return (train_x, train_y), (test_x, test_y)


def train_input_fn(features, labels, batch_size):
    """An input function for training."""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    return dataset.shuffle(1000).repeat().batch(batch_size)


def eval_input_fn(features, labels, batch_size):
    """An input function for evaluation or prediction."""
    features = dict(features)

    # Use only features when labels is None (i.e., during prediction).
    inputs = features if labels is None else (features, labels)

    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices(inputs)

    # Batch the examples.
    assert batch_size is not None, "batch_size must not be None"
    return dataset.batch(batch_size)


# The remainder of this file contains a simple example of a CSV parser,
# implemented using the `Dataset` class.

# `tf.io.decode_csv` sets the types of its outputs to match the examples
# given in the `record_defaults` argument.
CSV_TYPES = [[0.0], [0.0], [0.0], [0.0], [0]]


def _parse_line(line):
    # Decode the line into its fields. (`tf.io.decode_csv` is the current
    # name for what the original sample called `tf.decode_csv`.)
    fields = tf.io.decode_csv(line, record_defaults=CSV_TYPES)

    # Pack the result into a dictionary.
    features = dict(zip(CSV_COLUMN_NAMES, fields))

    # Separate the label from the features.
    label = features.pop("Species")

    return features, label


def csv_input_fn(csv_path, batch_size):
    # Create a dataset containing the text lines, skipping the header row.
    dataset = tf.data.TextLineDataset(csv_path).skip(1)

    # Parse each line.
    dataset = dataset.map(_parse_line)

    # Shuffle, repeat, and batch the examples.
    return dataset.shuffle(1000).repeat().batch(batch_size)
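
# ---------------------------------------------------------------------------
# Usage sketch (an addition, not part of the upstream sample): how the pandas
# loaders and input functions above are typically wired into a premade
# Estimator, mirroring TensorFlow's companion premade_estimator.py example.
# The classifier configuration (hidden_units, batch_size, train_steps) is
# illustrative only.


def _example_premade_estimator(batch_size=100, train_steps=1000):
    (train_x, train_y), (test_x, test_y) = load_data()

    # One numeric feature column per CSV feature column.
    feature_columns = [
        tf.feature_column.numeric_column(key=key) for key in train_x.keys()
    ]

    # A DNN with two hidden layers; one output class per species.
    classifier = tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[10, 10],
        n_classes=len(SPECIES),
    )

    classifier.train(
        input_fn=lambda: train_input_fn(train_x, train_y, batch_size),
        steps=train_steps,
    )

    eval_result = classifier.evaluate(
        input_fn=lambda: eval_input_fn(test_x, test_y, batch_size)
    )
    print("Test set accuracy: {accuracy:0.3f}".format(**eval_result))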
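
# `csv_input_fn` is a drop-in alternative to `train_input_fn` that parses the
# raw CSV inside the tf.data pipeline instead of loading it through pandas
# first. A minimal sketch under the same assumptions as above (the batch size
# and step count are illustrative):


def _example_csv_pipeline(classifier, batch_size=100, train_steps=1000):
    # Train directly from the downloaded CSV file; `csv_input_fn` skips the
    # header row and shuffles/repeats/batches just like `train_input_fn`.
    train_path, _ = maybe_download()
    classifier.train(
        input_fn=lambda: csv_input_fn(train_path, batch_size),
        steps=train_steps,
    )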