tests/tensorflow/iris_data_utils.py
# From https://github.com/tensorflow/models/blob/master/samples/core/get_started/iris_data.py
# This file is the getting-started example shipped with TensorFlow; it is vendored here for testing.
import pandas as pd
import tensorflow as tf

TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

CSV_COLUMN_NAMES = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth", "Species"]
SPECIES = ["Setosa", "Versicolor", "Virginica"]


def maybe_download():
    """Downloads the training and test CSV files (if not already cached) and returns their local paths."""
    train_path = tf.keras.utils.get_file(TRAIN_URL.split("/")[-1], TRAIN_URL)
    test_path = tf.keras.utils.get_file(TEST_URL.split("/")[-1], TEST_URL)

    return train_path, test_path


def load_data(y_name="Species"):
    """Returns the iris dataset as (train_x, train_y), (test_x, test_y).

    The feature sets are pandas DataFrames and the labels are pandas Series of
    integer class ids indexing into SPECIES.
    """
    train_path, test_path = maybe_download()

    train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
    train_y = train.pop(y_name)
    train_x = train

    test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
    test_y = test.pop(y_name)
    test_x = test

    return (train_x, train_y), (test_x, test_y)


def train_input_fn(features, labels, batch_size):
    """An input function for training."""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    return dataset.shuffle(1000).repeat().batch(batch_size)


def eval_input_fn(features, labels, batch_size):
    """An input function for evaluation or prediction."""
    features = dict(features)

    # Use only the features when labels is None (the prediction case).
    inputs = features if labels is None else (features, labels)

    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices(inputs)

    # Batch the examples.
    assert batch_size is not None, "batch_size must not be None"
    return dataset.batch(batch_size)


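# A minimal usage sketch (not part of the original TensorFlow example): it shows
# how the input functions above are typically bound with a lambda and handed to
# a premade estimator. The feature columns, hidden_units, batch size, and step
# count below are illustrative assumptions, and the sketch assumes a TensorFlow
# release that still ships the tf.estimator / tf.feature_column APIs.
def _example_estimator_usage(batch_size=100, train_steps=1000):
    (train_x, train_y), (test_x, test_y) = load_data()

    # One numeric feature column per feature in the DataFrame.
    feature_columns = [tf.feature_column.numeric_column(key) for key in train_x.keys()]

    classifier = tf.estimator.DNNClassifier(
        feature_columns=feature_columns, hidden_units=[10, 10], n_classes=3
    )
    classifier.train(
        input_fn=lambda: train_input_fn(train_x, train_y, batch_size), steps=train_steps
    )
    return classifier.evaluate(input_fn=lambda: eval_input_fn(test_x, test_y, batch_size))

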
# The remainder of this file contains a simple example of a CSV parser,
#     implemented using the `Dataset` class.

# `tf.io.decode_csv` sets the types of the outputs to match the example values
#     given in the `record_defaults` argument.
CSV_TYPES = [[0.0], [0.0], [0.0], [0.0], [0]]


def _parse_line(line):
    # Decode the line into its fields.
    fields = tf.io.decode_csv(line, record_defaults=CSV_TYPES)

    # Pack the result into a dictionary.
    features = dict(zip(CSV_COLUMN_NAMES, fields))

    # Separate the label from the features.
    label = features.pop("Species")

    return features, label


def csv_input_fn(csv_path, batch_size):
    # Create a dataset containing the text lines, skipping the header row.
    dataset = tf.data.TextLineDataset(csv_path).skip(1)

    # Parse each line.
    dataset = dataset.map(_parse_line)

    # Shuffle, repeat, and batch the examples.
    return dataset.shuffle(1000).repeat().batch(batch_size)
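

# A minimal, hand-runnable sketch (not part of the original example): it pulls
# one batch from each pipeline to confirm they yield (features, label) pairs.
# The batch size is an arbitrary assumption, and running this downloads the CSVs.
if __name__ == "__main__":
    (train_x, train_y), (test_x, test_y) = load_data()

    # In-memory pipeline built from the pandas data.
    for features, labels in train_input_fn(train_x, train_y, batch_size=32).take(1):
        print({name: tensor.shape for name, tensor in features.items()})
        print(labels.shape)

    # Line-by-line CSV pipeline built with the parser above.
    train_path, _ = maybe_download()
    for features, labels in csv_input_fn(train_path, batch_size=32).take(1):
        print({name: tensor.shape for name, tensor in features.items()})
        print(labels.shape)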