1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 """Trains and Evaluates the MNIST network using a feed dictionary.""" 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 # pylint: disable=missing-docstring 22 import argparse 23 import os 24 import sys 25 import time 26 27 from six.moves import xrange # pylint: disable=redefined-builtin 28 import tensorflow as tf 29 30 from tensorflow.examples.tutorials.mnist import input_data 31 from tensorflow.examples.tutorials.mnist import mnist 32 33 # Basic model parameters as external flags. 34 FLAGS = None 35 36 37 def placeholder_inputs(batch_size): 38 """Generate placeholder variables to represent the input tensors. 39 40 These placeholders are used as inputs by the rest of the model building 41 code and will be fed from the downloaded data in the .run() loop, below. 42 43 Args: 44 batch_size: The batch size will be baked into both placeholders. 45 46 Returns: 47 images_placeholder: Images placeholder. 48 labels_placeholder: Labels placeholder. 49 """ 50 # Note that the shapes of the placeholders match the shapes of the full 51 # image and label tensors, except the first dimension is now batch_size 52 # rather than the full size of the train or test data sets. 53 images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, 54 mnist.IMAGE_PIXELS)) 55 labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size)) 56 return images_placeholder, labels_placeholder 57 58 59 def fill_feed_dict(data_set, images_pl, labels_pl): 60 """Fills the feed_dict for training the given step. 61 62 A feed_dict takes the form of: 63 feed_dict = { 64 <placeholder>: <tensor of values to be passed for placeholder>, 65 .... 66 } 67 68 Args: 69 data_set: The set of images and labels, from input_data.read_data_sets() 70 images_pl: The images placeholder, from placeholder_inputs(). 71 labels_pl: The labels placeholder, from placeholder_inputs(). 72 73 Returns: 74 feed_dict: The feed dictionary mapping from placeholders to values. 75 """ 76 # Create the feed_dict for the placeholders filled with the next 77 # `batch size` examples. 78 images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size, 79 FLAGS.fake_data) 80 feed_dict = { 81 images_pl: images_feed, 82 labels_pl: labels_feed, 83 } 84 return feed_dict 85 86 87 def do_eval(sess, 88 eval_correct, 89 images_placeholder, 90 labels_placeholder, 91 data_set): 92 """Runs one evaluation against the full epoch of data. 93 94 Args: 95 sess: The session in which the model has been trained. 96 eval_correct: The Tensor that returns the number of correct predictions. 97 images_placeholder: The images placeholder. 98 labels_placeholder: The labels placeholder. 99 data_set: The set of images and labels to evaluate, from 100 input_data.read_data_sets(). 101 """ 102 # And run one epoch of eval. 103 true_count = 0 # Counts the number of correct predictions. 104 steps_per_epoch = data_set.num_examples // FLAGS.batch_size 105 num_examples = steps_per_epoch * FLAGS.batch_size 106 for step in xrange(steps_per_epoch): 107 feed_dict = fill_feed_dict(data_set, 108 images_placeholder, 109 labels_placeholder) 110 true_count += sess.run(eval_correct, feed_dict=feed_dict) 111 precision = float(true_count) / num_examples 112 print('Num examples: %d Num correct: %d Precision @ 1: %0.04f' % 113 (num_examples, true_count, precision)) 114 115 116 def run_training(): 117 """Train MNIST for a number of steps.""" 118 # Get the sets of images and labels for training, validation, and 119 # test on MNIST. 120 data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data) 121 122 # Tell TensorFlow that the model will be built into the default Graph. 123 with tf.Graph().as_default(): 124 # Generate placeholders for the images and labels. 125 images_placeholder, labels_placeholder = placeholder_inputs( 126 FLAGS.batch_size) 127 128 # Build a Graph that computes predictions from the inference model. 129 logits = mnist.inference(images_placeholder, 130 FLAGS.hidden1, 131 FLAGS.hidden2) 132 133 # Add to the Graph the Ops for loss calculation. 134 loss = mnist.loss(logits, labels_placeholder) 135 136 # Add to the Graph the Ops that calculate and apply gradients. 137 train_op = mnist.training(loss, FLAGS.learning_rate) 138 139 # Add the Op to compare the logits to the labels during evaluation. 140 eval_correct = mnist.evaluation(logits, labels_placeholder) 141 142 # Build the summary Tensor based on the TF collection of Summaries. 143 summary = tf.summary.merge_all() 144 145 # Add the variable initializer Op. 146 init = tf.global_variables_initializer() 147 148 # Create a saver for writing training checkpoints. 149 saver = tf.train.Saver() 150 151 # Create a session for running Ops on the Graph. 152 sess = tf.Session() 153 154 # Instantiate a SummaryWriter to output summaries and the Graph. 155 summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph) 156 157 # And then after everything is built: 158 159 # Run the Op to initialize the variables. 160 sess.run(init) 161 162 # Start the training loop. 163 for step in xrange(FLAGS.max_steps): 164 start_time = time.time() 165 166 # Fill a feed dictionary with the actual set of images and labels 167 # for this particular training step. 168 feed_dict = fill_feed_dict(data_sets.train, 169 images_placeholder, 170 labels_placeholder) 171 172 # Run one step of the model. The return values are the activations 173 # from the `train_op` (which is discarded) and the `loss` Op. To 174 # inspect the values of your Ops or variables, you may include them 175 # in the list passed to sess.run() and the value tensors will be 176 # returned in the tuple from the call. 177 _, loss_value = sess.run([train_op, loss], 178 feed_dict=feed_dict) 179 180 duration = time.time() - start_time 181 182 # Write the summaries and print an overview fairly often. 183 if step % 100 == 0: 184 # Print status to stdout. 185 print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration)) 186 # Update the events file. 187 summary_str = sess.run(summary, feed_dict=feed_dict) 188 summary_writer.add_summary(summary_str, step) 189 summary_writer.flush() 190 191 # Save a checkpoint and evaluate the model periodically. 192 if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps: 193 checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt') 194 saver.save(sess, checkpoint_file, global_step=step) 195 # Evaluate against the training set. 196 print('Training Data Eval:') 197 do_eval(sess, 198 eval_correct, 199 images_placeholder, 200 labels_placeholder, 201 data_sets.train) 202 # Evaluate against the validation set. 203 print('Validation Data Eval:') 204 do_eval(sess, 205 eval_correct, 206 images_placeholder, 207 labels_placeholder, 208 data_sets.validation) 209 # Evaluate against the test set. 210 print('Test Data Eval:') 211 do_eval(sess, 212 eval_correct, 213 images_placeholder, 214 labels_placeholder, 215 data_sets.test) 216 217 218 def main(_): 219 if tf.gfile.Exists(FLAGS.log_dir): 220 tf.gfile.DeleteRecursively(FLAGS.log_dir) 221 tf.gfile.MakeDirs(FLAGS.log_dir) 222 run_training() 223 224 225 if __name__ == '__main__': 226 parser = argparse.ArgumentParser() 227 parser.add_argument( 228 '--learning_rate', 229 type=float, 230 default=0.01, 231 help='Initial learning rate.' 232 ) 233 parser.add_argument( 234 '--max_steps', 235 type=int, 236 default=2000, 237 help='Number of steps to run trainer.' 238 ) 239 parser.add_argument( 240 '--hidden1', 241 type=int, 242 default=128, 243 help='Number of units in hidden layer 1.' 244 ) 245 parser.add_argument( 246 '--hidden2', 247 type=int, 248 default=32, 249 help='Number of units in hidden layer 2.' 250 ) 251 parser.add_argument( 252 '--batch_size', 253 type=int, 254 default=100, 255 help='Batch size. Must divide evenly into the dataset sizes.' 256 ) 257 parser.add_argument( 258 '--input_data_dir', 259 type=str, 260 default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'), 261 'tensorflow/mnist/input_data'), 262 help='Directory to put the input data.' 263 ) 264 parser.add_argument( 265 '--log_dir', 266 type=str, 267 default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'), 268 'tensorflow/mnist/logs/fully_connected_feed'), 269 help='Directory to put the log data.' 270 ) 271 parser.add_argument( 272 '--fake_data', 273 default=False, 274 help='If true, uses fake data for unit testing.', 275 action='store_true' 276 ) 277 278 FLAGS, unparsed = parser.parse_known_args() 279 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 280