# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Demo of the tfdbg curses CLI: Locating the source of bad numerical values.

The neural network in this demo is largely based on the tutorial at:
  tensorflow/examples/tutorials/mnist/mnist_with_summaries.py

But modifications are made so that problematic numerical values (infs and nans)
appear in nodes of the graph during training.
"""
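# A typical way to launch the tfdbg CLI with this demo (the script filename
# below is an assumption; substitute the actual path of this file):
#   python debug_mnist.py --debug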
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python import debug as tf_debug


IMAGE_SIZE = 28
HIDDEN_SIZE = 500
NUM_LABELS = 10
RAND_SEED = 42


def main(_):
  # Import data
  mnist = input_data.read_data_sets(FLAGS.data_dir,
                                    one_hot=True,
                                    fake_data=FLAGS.fake_data)

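  # feed_dict(True) returns a fresh training mini-batch; feed_dict(False)
  # returns the full test set (unless --fake_data is set, in which case fake
  # data is used for both).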
  def feed_dict(train):
    if train or FLAGS.fake_data:
      xs, ys = mnist.train.next_batch(FLAGS.train_batch_size,
                                      fake_data=FLAGS.fake_data)
    else:
      xs, ys = mnist.test.images, mnist.test.labels

    return {x: xs, y_: ys}

  sess = tf.InteractiveSession()

  # Create the MNIST neural network graph.

  # Input placeholders.
  with tf.name_scope("input"):
    x = tf.placeholder(
        tf.float32, [None, IMAGE_SIZE * IMAGE_SIZE], name="x-input")
    y_ = tf.placeholder(tf.float32, [None, NUM_LABELS], name="y-input")

  def weight_variable(shape):
    """Create a weight variable with appropriate initialization."""
    initial = tf.truncated_normal(shape, stddev=0.1, seed=RAND_SEED)
    return tf.Variable(initial)

  def bias_variable(shape):
    """Create a bias variable with appropriate initialization."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

  def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
    """Reusable code for making a simple neural net layer."""
    # Adding a name scope ensures logical grouping of the layers in the graph.
    with tf.name_scope(layer_name):
      # This Variable will hold the state of the weights for the layer
      with tf.name_scope("weights"):
        weights = weight_variable([input_dim, output_dim])
      with tf.name_scope("biases"):
        biases = bias_variable([output_dim])
      with tf.name_scope("Wx_plus_b"):
        preactivate = tf.matmul(input_tensor, weights) + biases

      activations = act(preactivate)
      return activations

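  # Build the network: one ReLU hidden layer followed by a linear output
  # layer. The softmax is applied explicitly so that the cross entropy can be
  # written out by hand below, which is what introduces the bad values.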
  hidden = nn_layer(x, IMAGE_SIZE**2, HIDDEN_SIZE, "hidden")
  logits = nn_layer(hidden, HIDDEN_SIZE, NUM_LABELS, "output", tf.identity)
  y = tf.nn.softmax(logits)

  with tf.name_scope("cross_entropy"):
    # The following line is the culprit of the bad numerical values that appear
    # during training of this graph. Log of zero gives inf, which is first seen
    # in the intermediate tensor "cross_entropy/Log:0" during the 4th run()
    # call. A multiplication of the inf values with zeros leads to nans, which
    # are first seen in "cross_entropy/mul:0".
    #
    # You can use the built-in, numerically-stable implementation to fix this
    # issue:
    #   diff = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits)

    diff = -(y_ * tf.log(y))
    with tf.name_scope("total"):
      cross_entropy = tf.reduce_mean(diff)

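  # Minimize the hand-written cross entropy with the Adam optimizer; the
  # learning rate comes from the --learning_rate flag.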
  with tf.name_scope("train"):
    train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(
        cross_entropy)

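  # Accuracy: the fraction of examples whose predicted label (argmax of the
  # softmax output y) matches the one-hot label in y_.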
  with tf.name_scope("accuracy"):
    with tf.name_scope("correct_prediction"):
      correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    with tf.name_scope("accuracy"):
      accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

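  # Initialize all variables before any training or evaluation run() calls.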
  sess.run(tf.global_variables_initializer())

  if FLAGS.debug and FLAGS.tensorboard_debug_address:
    raise ValueError(
        "The --debug and --tensorboard_debug_address flags are mutually "
        "exclusive.")
  if FLAGS.debug:
    sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type)
  elif FLAGS.tensorboard_debug_address:
    sess = tf_debug.TensorBoardDebugWrapperSession(
        sess, FLAGS.tensorboard_debug_address)

  # At this point, sess is a debug wrapper around the actual Session if
  # FLAGS.debug is true. In that case, calling run() will launch the CLI.
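  # Each iteration first evaluates accuracy on the test set and then runs one
  # training step on a fresh training batch.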
  for i in range(FLAGS.max_steps):
    acc = sess.run(accuracy, feed_dict=feed_dict(False))
    print("Accuracy at step %d: %s" % (i, acc))

    sess.run(train_step, feed_dict=feed_dict(True))


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--max_steps",
      type=int,
      default=10,
      help="Number of steps to run trainer.")
  parser.add_argument(
      "--train_batch_size",
      type=int,
      default=100,
      help="Batch size used during training.")
  parser.add_argument(
      "--learning_rate",
      type=float,
      default=0.025,
      help="Initial learning rate.")
  parser.add_argument(
      "--data_dir",
      type=str,
      default="/tmp/mnist_data",
      help="Directory for storing data")
  parser.add_argument(
      "--ui_type",
      type=str,
      default="curses",
      help="Command-line user interface type (curses | readline)")
  parser.add_argument(
      "--fake_data",
      type="bool",
      nargs="?",
      const=True,
      default=False,
      help="Use fake MNIST data for unit testing")
  parser.add_argument(
      "--debug",
      type="bool",
      nargs="?",
      const=True,
      default=False,
      help="Use debugger to track down bad values during training. "
      "Mutually exclusive with the --tensorboard_debug_address flag.")
  parser.add_argument(
      "--tensorboard_debug_address",
      type=str,
      default=None,
      help="Connect to the TensorBoard Debugger Plugin backend specified by "
      "the gRPC address (e.g., localhost:1234). Mutually exclusive with the "
      "--debug flag.")
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)