# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A deep MNIST classifier using convolutional layers.

See extensive documentation at
https://www.tensorflow.org/get_started/mnist/pros
"""
# Disable linter warnings to maintain consistency with the tutorial.
# pylint: disable=invalid-name
# pylint: disable=g-bad-import-order

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import tempfile

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

FLAGS = None


def deepnn(x):
  """deepnn builds the graph for a deep net for classifying digits.

  Args:
    x: an input tensor with the dimensions (N_examples, 784), where 784 is the
      number of pixels in a standard MNIST image.

  Returns:
    A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with
    values equal to the logits of classifying the digit into one of 10 classes
    (the digits 0-9). keep_prob is a scalar placeholder for the probability
    that each unit is kept during dropout.
  """
  # Reshape to use within a convolutional neural net.
  # Last dimension is for "features" -- there is only one here, since images
  # are grayscale; it would be 3 for an RGB image, 4 for RGBA, etc.
  with tf.name_scope('reshape'):
    x_image = tf.reshape(x, [-1, 28, 28, 1])
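    # x_image has shape (N_examples, 28, 28, 1); -1 lets reshape infer the
    # batch dimension from the size of x.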

  # First convolutional layer -- maps one grayscale image to 32 feature maps.
  with tf.name_scope('conv1'):
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
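    # h_conv1 has shape (N_examples, 28, 28, 32): stride-1 convolution with
    # SAME padding preserves the 28x28 spatial size.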

  # Pooling layer -- downsamples by 2X.
  with tf.name_scope('pool1'):
    h_pool1 = max_pool_2x2(h_conv1)
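    # h_pool1 has shape (N_examples, 14, 14, 32).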

  # Second convolutional layer -- maps 32 feature maps to 64.
  with tf.name_scope('conv2'):
    W_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)

  # Second pooling layer.
  with tf.name_scope('pool2'):
    h_pool2 = max_pool_2x2(h_conv2)
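    # h_pool2 has shape (N_examples, 7, 7, 64) -- 3136 values per example.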

  # Fully connected layer 1 -- after two rounds of downsampling, our 28x28
  # image is down to 64 feature maps of size 7x7; map these 3136 values to
  # 1024 features.
  with tf.name_scope('fc1'):
    W_fc1 = weight_variable([7 * 7 * 64, 1024])
    b_fc1 = bias_variable([1024])

    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

  # Dropout -- controls the complexity of the model, prevents co-adaptation of
  # features.
  with tf.name_scope('dropout'):
    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
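    # tf.nn.dropout keeps each unit with probability keep_prob and scales the
    # kept activations by 1 / keep_prob, so no extra rescaling is needed when
    # keep_prob is set to 1.0 at evaluation time.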

  # Map the 1024 features to 10 classes, one for each digit.
  with tf.name_scope('fc2'):
    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])

    y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
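    # y_conv holds raw logits; the softmax is applied later, inside the loss,
    # which is more numerically stable than applying it here.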
  return y_conv, keep_prob


def conv2d(x, W):
  """conv2d returns a 2d convolution layer with stride 1 and SAME padding."""
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
  """max_pool_2x2 downsamples a feature map by 2X."""
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')


def weight_variable(shape):
  """weight_variable generates a weight variable of a given shape."""
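  # Initialize with a small amount of truncated-normal noise to break the
  # symmetry between units; samples further than two standard deviations from
  # the mean are dropped and re-drawn.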
  initial = tf.truncated_normal(shape, stddev=0.1)
  return tf.Variable(initial)


def bias_variable(shape):
  """bias_variable generates a bias variable of a given shape."""
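  # A slightly positive initial bias helps avoid "dead" ReLU neurons that
  # would otherwise never activate.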
  initial = tf.constant(0.1, shape=shape)
  return tf.Variable(initial)


def main(_):
  # Import data
  mnist = input_data.read_data_sets(FLAGS.data_dir)

  # Create the model
  x = tf.placeholder(tf.float32, [None, 784])

  # Define loss and optimizer
  y_ = tf.placeholder(tf.int64, [None])
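  # read_data_sets defaults to one_hot=False, so y_ holds integer class
  # indices in [0, 9], matching the sparse cross-entropy loss below.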

  # Build the graph for the deep net
  y_conv, keep_prob = deepnn(x)

  with tf.name_scope('loss'):
    cross_entropy = tf.losses.sparse_softmax_cross_entropy(
        labels=y_, logits=y_conv)
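  # The loss op applies the softmax internally (more stable than a separate
  # softmax followed by cross-entropy) and, with its default reduction,
  # already averages over the batch, so the reduce_mean below is effectively
  # a no-op.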
  cross_entropy = tf.reduce_mean(cross_entropy)

  with tf.name_scope('adam_optimizer'):
    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

  with tf.name_scope('accuracy'):
    correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_)
    correct_prediction = tf.cast(correct_prediction, tf.float32)
  accuracy = tf.reduce_mean(correct_prediction)

  graph_location = tempfile.mkdtemp()
  print('Saving graph to: %s' % graph_location)
  train_writer = tf.summary.FileWriter(graph_location)
  train_writer.add_graph(tf.get_default_graph())

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(20000):
      batch = mnist.train.next_batch(50)
      if i % 100 == 0:
        train_accuracy = accuracy.eval(feed_dict={
            x: batch[0], y_: batch[1], keep_prob: 1.0})
        print('step %d, training accuracy %g' % (i, train_accuracy))
      train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

    print('test accuracy %g' % accuracy.eval(feed_dict={
        x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--data_dir', type=str,
                      default='/tmp/tensorflow/mnist/input_data',
                      help='Directory for storing input data')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)