Home | History | Annotate | Download | only in mnist
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Trains and Evaluates the MNIST network using a feed dictionary."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 # pylint: disable=missing-docstring
     22 import argparse
     23 import os
     24 import sys
     25 import time
     26 
     27 from six.moves import xrange  # pylint: disable=redefined-builtin
     28 import tensorflow as tf
     29 
     30 from tensorflow.examples.tutorials.mnist import input_data
     31 from tensorflow.examples.tutorials.mnist import mnist
     32 
# Basic model parameters as external flags.
# Holds the argparse namespace parsed in the __main__ block at the bottom of
# this file; it stays None until then. The functions below read it as a
# module-level global (e.g. FLAGS.batch_size, FLAGS.log_dir).
FLAGS = None
     35 
     36 
     37 def placeholder_inputs(batch_size):
     38   """Generate placeholder variables to represent the input tensors.
     39 
     40   These placeholders are used as inputs by the rest of the model building
     41   code and will be fed from the downloaded data in the .run() loop, below.
     42 
     43   Args:
     44     batch_size: The batch size will be baked into both placeholders.
     45 
     46   Returns:
     47     images_placeholder: Images placeholder.
     48     labels_placeholder: Labels placeholder.
     49   """
     50   # Note that the shapes of the placeholders match the shapes of the full
     51   # image and label tensors, except the first dimension is now batch_size
     52   # rather than the full size of the train or test data sets.
     53   images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
     54                                                          mnist.IMAGE_PIXELS))
     55   labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
     56   return images_placeholder, labels_placeholder
     57 
     58 
     59 def fill_feed_dict(data_set, images_pl, labels_pl):
     60   """Fills the feed_dict for training the given step.
     61 
     62   A feed_dict takes the form of:
     63   feed_dict = {
     64       <placeholder>: <tensor of values to be passed for placeholder>,
     65       ....
     66   }
     67 
     68   Args:
     69     data_set: The set of images and labels, from input_data.read_data_sets()
     70     images_pl: The images placeholder, from placeholder_inputs().
     71     labels_pl: The labels placeholder, from placeholder_inputs().
     72 
     73   Returns:
     74     feed_dict: The feed dictionary mapping from placeholders to values.
     75   """
     76   # Create the feed_dict for the placeholders filled with the next
     77   # `batch size` examples.
     78   images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
     79                                                  FLAGS.fake_data)
     80   feed_dict = {
     81       images_pl: images_feed,
     82       labels_pl: labels_feed,
     83   }
     84   return feed_dict
     85 
     86 
     87 def do_eval(sess,
     88             eval_correct,
     89             images_placeholder,
     90             labels_placeholder,
     91             data_set):
     92   """Runs one evaluation against the full epoch of data.
     93 
     94   Args:
     95     sess: The session in which the model has been trained.
     96     eval_correct: The Tensor that returns the number of correct predictions.
     97     images_placeholder: The images placeholder.
     98     labels_placeholder: The labels placeholder.
     99     data_set: The set of images and labels to evaluate, from
    100       input_data.read_data_sets().
    101   """
    102   # And run one epoch of eval.
    103   true_count = 0  # Counts the number of correct predictions.
    104   steps_per_epoch = data_set.num_examples // FLAGS.batch_size
    105   num_examples = steps_per_epoch * FLAGS.batch_size
    106   for step in xrange(steps_per_epoch):
    107     feed_dict = fill_feed_dict(data_set,
    108                                images_placeholder,
    109                                labels_placeholder)
    110     true_count += sess.run(eval_correct, feed_dict=feed_dict)
    111   precision = float(true_count) / num_examples
    112   print('Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
    113         (num_examples, true_count, precision))
    114 
    115 
    116 def run_training():
    117   """Train MNIST for a number of steps."""
    118   # Get the sets of images and labels for training, validation, and
    119   # test on MNIST.
    120   data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)
    121 
    122   # Tell TensorFlow that the model will be built into the default Graph.
    123   with tf.Graph().as_default():
    124     # Generate placeholders for the images and labels.
    125     images_placeholder, labels_placeholder = placeholder_inputs(
    126         FLAGS.batch_size)
    127 
    128     # Build a Graph that computes predictions from the inference model.
    129     logits = mnist.inference(images_placeholder,
    130                              FLAGS.hidden1,
    131                              FLAGS.hidden2)
    132 
    133     # Add to the Graph the Ops for loss calculation.
    134     loss = mnist.loss(logits, labels_placeholder)
    135 
    136     # Add to the Graph the Ops that calculate and apply gradients.
    137     train_op = mnist.training(loss, FLAGS.learning_rate)
    138 
    139     # Add the Op to compare the logits to the labels during evaluation.
    140     eval_correct = mnist.evaluation(logits, labels_placeholder)
    141 
    142     # Build the summary Tensor based on the TF collection of Summaries.
    143     summary = tf.summary.merge_all()
    144 
    145     # Add the variable initializer Op.
    146     init = tf.global_variables_initializer()
    147 
    148     # Create a saver for writing training checkpoints.
    149     saver = tf.train.Saver()
    150 
    151     # Create a session for running Ops on the Graph.
    152     sess = tf.Session()
    153 
    154     # Instantiate a SummaryWriter to output summaries and the Graph.
    155     summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
    156 
    157     # And then after everything is built:
    158 
    159     # Run the Op to initialize the variables.
    160     sess.run(init)
    161 
    162     # Start the training loop.
    163     for step in xrange(FLAGS.max_steps):
    164       start_time = time.time()
    165 
    166       # Fill a feed dictionary with the actual set of images and labels
    167       # for this particular training step.
    168       feed_dict = fill_feed_dict(data_sets.train,
    169                                  images_placeholder,
    170                                  labels_placeholder)
    171 
    172       # Run one step of the model.  The return values are the activations
    173       # from the `train_op` (which is discarded) and the `loss` Op.  To
    174       # inspect the values of your Ops or variables, you may include them
    175       # in the list passed to sess.run() and the value tensors will be
    176       # returned in the tuple from the call.
    177       _, loss_value = sess.run([train_op, loss],
    178                                feed_dict=feed_dict)
    179 
    180       duration = time.time() - start_time
    181 
    182       # Write the summaries and print an overview fairly often.
    183       if step % 100 == 0:
    184         # Print status to stdout.
    185         print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
    186         # Update the events file.
    187         summary_str = sess.run(summary, feed_dict=feed_dict)
    188         summary_writer.add_summary(summary_str, step)
    189         summary_writer.flush()
    190 
    191       # Save a checkpoint and evaluate the model periodically.
    192       if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
    193         checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
    194         saver.save(sess, checkpoint_file, global_step=step)
    195         # Evaluate against the training set.
    196         print('Training Data Eval:')
    197         do_eval(sess,
    198                 eval_correct,
    199                 images_placeholder,
    200                 labels_placeholder,
    201                 data_sets.train)
    202         # Evaluate against the validation set.
    203         print('Validation Data Eval:')
    204         do_eval(sess,
    205                 eval_correct,
    206                 images_placeholder,
    207                 labels_placeholder,
    208                 data_sets.validation)
    209         # Evaluate against the test set.
    210         print('Test Data Eval:')
    211         do_eval(sess,
    212                 eval_correct,
    213                 images_placeholder,
    214                 labels_placeholder,
    215                 data_sets.test)
    216 
    217 
    218 def main(_):
    219   if tf.gfile.Exists(FLAGS.log_dir):
    220     tf.gfile.DeleteRecursively(FLAGS.log_dir)
    221   tf.gfile.MakeDirs(FLAGS.log_dir)
    222   run_training()
    223 
    224 
    225 if __name__ == '__main__':
    226   parser = argparse.ArgumentParser()
    227   parser.add_argument(
    228       '--learning_rate',
    229       type=float,
    230       default=0.01,
    231       help='Initial learning rate.'
    232   )
    233   parser.add_argument(
    234       '--max_steps',
    235       type=int,
    236       default=2000,
    237       help='Number of steps to run trainer.'
    238   )
    239   parser.add_argument(
    240       '--hidden1',
    241       type=int,
    242       default=128,
    243       help='Number of units in hidden layer 1.'
    244   )
    245   parser.add_argument(
    246       '--hidden2',
    247       type=int,
    248       default=32,
    249       help='Number of units in hidden layer 2.'
    250   )
    251   parser.add_argument(
    252       '--batch_size',
    253       type=int,
    254       default=100,
    255       help='Batch size.  Must divide evenly into the dataset sizes.'
    256   )
    257   parser.add_argument(
    258       '--input_data_dir',
    259       type=str,
    260       default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
    261                            'tensorflow/mnist/input_data'),
    262       help='Directory to put the input data.'
    263   )
    264   parser.add_argument(
    265       '--log_dir',
    266       type=str,
    267       default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
    268                            'tensorflow/mnist/logs/fully_connected_feed'),
    269       help='Directory to put the log data.'
    270   )
    271   parser.add_argument(
    272       '--fake_data',
    273       default=False,
    274       help='If true, uses fake data for unit testing.',
    275       action='store_true'
    276   )
    277 
    278   FLAGS, unparsed = parser.parse_known_args()
    279   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
    280