Home | History | Annotate | Download | only in learn
      1 #  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 #  Licensed under the Apache License, Version 2.0 (the "License");
      4 #  you may not use this file except in compliance with the License.
      5 #  You may obtain a copy of the License at
      6 #
      7 #   http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 #  Unless required by applicable law or agreed to in writing, software
     10 #  distributed under the License is distributed on an "AS IS" BASIS,
     11 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 #  See the License for the specific language governing permissions and
     13 #  limitations under the License.
     14 """Example of DNNClassifier for Iris plant dataset.
     15 
This example uses TensorFlow 1.x APIs (1.4 or later); it does not run
unmodified on TensorFlow 2.x.
     17 """
     18 
     19 from __future__ import absolute_import
     20 from __future__ import division
     21 from __future__ import print_function
     22 
     23 import os
     24 import urllib
     25 
     26 import tensorflow as tf
     27 
     28 # Data sets
     29 IRIS_TRAINING = 'iris_training.csv'
     30 IRIS_TRAINING_URL = 'http://download.tensorflow.org/data/iris_training.csv'
     31 
     32 IRIS_TEST = 'iris_test.csv'
     33 IRIS_TEST_URL = 'http://download.tensorflow.org/data/iris_test.csv'
     34 
     35 FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
     36 
     37 
     38 def maybe_download_iris_data(file_name, download_url):
     39   """Downloads the file and returns the number of data."""
     40   if not os.path.exists(file_name):
     41     raw = urllib.urlopen(download_url).read()
     42     with open(file_name, 'w') as f:
     43       f.write(raw)
     44 
     45   # The first line is a comma-separated string. The first one is the number of
     46   # total data in the file.
     47   with open(file_name, 'r') as f:
     48     first_line = f.readline()
     49   num_elements = first_line.split(',')[0]
     50   return int(num_elements)
     51 
     52 
     53 def input_fn(file_name, num_data, batch_size, is_training):
     54   """Creates an input_fn required by Estimator train/evaluate."""
     55   # If the data sets aren't stored locally, download them.
     56 
     57   def _parse_csv(rows_string_tensor):
     58     """Takes the string input tensor and returns tuple of (features, labels)."""
     59     # Last dim is the label.
     60     num_features = len(FEATURE_KEYS)
     61     num_columns = num_features + 1
     62     columns = tf.decode_csv(rows_string_tensor,
     63                             record_defaults=[[]] * num_columns)
     64     features = dict(zip(FEATURE_KEYS, columns[:num_features]))
     65     labels = tf.cast(columns[num_features], tf.int32)
     66     return features, labels
     67 
     68   def _input_fn():
     69     """The input_fn."""
     70     dataset = tf.data.TextLineDataset([file_name])
     71     # Skip the first line (which does not have data).
     72     dataset = dataset.skip(1)
     73     dataset = dataset.map(_parse_csv)
     74 
     75     if is_training:
     76       # For this small dataset, which can fit into memory, to achieve true
     77       # randomness, the shuffle buffer size is set as the total number of
     78       # elements in the dataset.
     79       dataset = dataset.shuffle(num_data)
     80       dataset = dataset.repeat()
     81 
     82     dataset = dataset.batch(batch_size)
     83     iterator = dataset.make_one_shot_iterator()
     84     features, labels = iterator.get_next()
     85     return features, labels
     86 
     87   return _input_fn
     88 
     89 
     90 def main(unused_argv):
     91   tf.logging.set_verbosity(tf.logging.INFO)
     92 
     93   num_training_data = maybe_download_iris_data(
     94       IRIS_TRAINING, IRIS_TRAINING_URL)
     95   num_test_data = maybe_download_iris_data(IRIS_TEST, IRIS_TEST_URL)
     96 
     97   # Build 3 layer DNN with 10, 20, 10 units respectively.
     98   feature_columns = [
     99       tf.feature_column.numeric_column(key, shape=1) for key in FEATURE_KEYS]
    100   classifier = tf.estimator.DNNClassifier(
    101       feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3)
    102 
    103   # Train.
    104   train_input_fn = input_fn(IRIS_TRAINING, num_training_data, batch_size=32,
    105                             is_training=True)
    106   classifier.train(input_fn=train_input_fn, steps=400)
    107 
    108   # Eval.
    109   test_input_fn = input_fn(IRIS_TEST, num_test_data, batch_size=32,
    110                            is_training=False)
    111   scores = classifier.evaluate(input_fn=test_input_fn)
    112   print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
    113 
    114 
    115 if __name__ == '__main__':
    116   tf.app.run()
    117