# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Builds the CIFAR-10 network with additional variables to support pruning.
     16 
     17 Summary of available functions:
     18 
     19  # Compute input images and labels for training. If you would like to run
     20  # evaluations, use inputs() instead.
     21  inputs, labels = distorted_inputs()
     22 
     23  # Compute inference on the model inputs to make a prediction.
     24  predictions = inference(inputs)
     25 
     26  # Compute the total loss of the prediction with respect to the labels.
     27  loss = loss(predictions, labels)
     28 
     29  # Create a graph to run one step of training with respect to the loss.
     30  train_op = train(loss, global_step)
     31 """
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re
import sys
import tarfile

from six.moves import urllib
import tensorflow as tf

from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_input
from tensorflow.contrib.model_pruning.python import pruning

# Global constants describing the CIFAR-10 data set.
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
NUM_CLASSES = cifar10_input.NUM_CLASSES
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN  # pylint: disable=line-too-long
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
BATCH_SIZE = 128
DATA_DIR = '/tmp/cifar10_data'

# Constants describing the training process.
MOVING_AVERAGE_DECAY = 0.9999  # The decay to use for the moving average.
NUM_EPOCHS_PER_DECAY = 350.0  # Epochs after which learning rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.1  # Learning rate decay factor.
INITIAL_LEARNING_RATE = 0.1  # Initial learning rate.

# If a model is trained with multiple GPUs, prefix all Op names with tower_name
# to differentiate the operations. Note that this prefix is removed from the
# names of the summaries when visualizing a model.
TOWER_NAME = 'tower'

DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'


def _activation_summary(x):
  """Helper to create summaries for activations.

  Creates a summary that provides a histogram of activations.
  Creates a summary that measures the sparsity of activations.

  Args:
    x: Tensor
  Returns:
    nothing
  """
  # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
  # session. This helps the clarity of presentation on TensorBoard.
  tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
  tf.summary.histogram(tensor_name + '/activations', x)
  tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))

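# For example, _activation_summary(conv1) in inference() below records a
# histogram of the conv1 activations and a scalar giving their sparsity
# (fraction of zeros), with any 'tower_N/' prefix stripped from the tag.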

def _variable_on_cpu(name, shape, initializer):
  """Helper to create a Variable stored on CPU memory.

  Args:
    name: name of the variable
    shape: list of ints
    initializer: initializer for Variable

  Returns:
    Variable Tensor
  """
  with tf.device('/cpu:0'):
    dtype = tf.float32
    var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
  return var


def _variable_with_weight_decay(name, shape, stddev, wd):
  """Helper to create an initialized Variable with weight decay.

  Note that the Variable is initialized with a truncated normal distribution.
  A weight decay is added only if one is specified.

  Args:
    name: name of the variable
    shape: list of ints
    stddev: standard deviation of a truncated Gaussian
    wd: add L2Loss weight decay multiplied by this float. If None, weight
        decay is not added for this Variable.

  Returns:
    Variable Tensor
  """
  dtype = tf.float32
  var = _variable_on_cpu(name, shape,
                         tf.truncated_normal_initializer(
                             stddev=stddev, dtype=dtype))
  if wd is not None:
    weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
  return var

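# For example, the local3 call below, _variable_with_weight_decay('weights',
# shape=[dim, 384], stddev=0.04, wd=0.004), adds a term
# 0.004 * tf.nn.l2_loss(w) (i.e. 0.004 * sum(w**2) / 2) to the 'losses'
# collection, which loss() later folds into the total loss.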

def distorted_inputs():
  """Construct distorted input for CIFAR training using the Reader ops.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.

  Raises:
    ValueError: If no data_dir
  """
  if not DATA_DIR:
    raise ValueError('Please supply a data_dir')
  data_dir = os.path.join(DATA_DIR, 'cifar-10-batches-bin')
  images, labels = cifar10_input.distorted_inputs(
      data_dir=data_dir, batch_size=BATCH_SIZE)
  return images, labels


def inputs(eval_data):
  """Construct input for CIFAR evaluation using the Reader ops.

  Args:
    eval_data: bool, indicating if one should use the train or eval data set.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.

  Raises:
    ValueError: If no data_dir
  """
  if not DATA_DIR:
    raise ValueError('Please supply a data_dir')
  data_dir = os.path.join(DATA_DIR, 'cifar-10-batches-bin')
  images, labels = cifar10_input.inputs(
      eval_data=eval_data, data_dir=data_dir, batch_size=BATCH_SIZE)
  return images, labels


def inference(images):
  """Build the CIFAR-10 model.

  Args:
    images: Images returned from distorted_inputs() or inputs().

  Returns:
    Logits.
  """
  # We instantiate all variables using tf.get_variable() instead of
  # tf.Variable() in order to share variables across multiple GPU training runs.
  # If we only ran this model on a single GPU, we could simplify this function
  # by replacing all instances of tf.get_variable() with tf.Variable().
  #
  # While instantiating conv and local layers, we add mask and threshold
  # variables to the layer by calling the pruning.apply_mask() function.
  # Note that the masks are applied only to the weight tensors.
  # conv1
  with tf.variable_scope('conv1') as scope:
    kernel = _variable_with_weight_decay(
        'weights', shape=[5, 5, 3, 64], stddev=5e-2, wd=0.0)

    conv = tf.nn.conv2d(
        images, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
    pre_activation = tf.nn.bias_add(conv, biases)
    conv1 = tf.nn.relu(pre_activation, name=scope.name)
    _activation_summary(conv1)

  # pool1
  pool1 = tf.nn.max_pool(
      conv1,
      ksize=[1, 3, 3, 1],
      strides=[1, 2, 2, 1],
      padding='SAME',
      name='pool1')
  # norm1
  norm1 = tf.nn.lrn(
      pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1')

  # conv2
  with tf.variable_scope('conv2') as scope:
    kernel = _variable_with_weight_decay(
        'weights', shape=[5, 5, 64, 64], stddev=5e-2, wd=0.0)
    conv = tf.nn.conv2d(
        norm1, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
    pre_activation = tf.nn.bias_add(conv, biases)
    conv2 = tf.nn.relu(pre_activation, name=scope.name)
    _activation_summary(conv2)

  # norm2
  norm2 = tf.nn.lrn(
      conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2')
  # pool2
  pool2 = tf.nn.max_pool(
      norm2,
      ksize=[1, 3, 3, 1],
      strides=[1, 2, 2, 1],
      padding='SAME',
      name='pool2')

  # local3
  with tf.variable_scope('local3') as scope:
    # Move everything into depth so we can perform a single matrix multiply.
    reshape = tf.reshape(pool2, [BATCH_SIZE, -1])
    dim = reshape.get_shape()[1].value
    weights = _variable_with_weight_decay(
        'weights', shape=[dim, 384], stddev=0.04, wd=0.004)
    biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
    local3 = tf.nn.relu(
        tf.matmul(reshape, pruning.apply_mask(weights, scope)) + biases,
        name=scope.name)
    _activation_summary(local3)

  # local4
  with tf.variable_scope('local4') as scope:
    weights = _variable_with_weight_decay(
        'weights', shape=[384, 192], stddev=0.04, wd=0.004)
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
    local4 = tf.nn.relu(
        tf.matmul(local3, pruning.apply_mask(weights, scope)) + biases,
        name=scope.name)
    _activation_summary(local4)

  # linear layer (WX + b).
  # We don't apply softmax here because
  # tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits
  # and performs the softmax internally for efficiency.
  with tf.variable_scope('softmax_linear') as scope:
    weights = _variable_with_weight_decay(
        'weights', [192, NUM_CLASSES], stddev=1 / 192.0, wd=0.0)
    biases = _variable_on_cpu('biases', [NUM_CLASSES],
                              tf.constant_initializer(0.0))
    softmax_linear = tf.add(
        tf.matmul(local4, pruning.apply_mask(weights, scope)),
        biases,
        name=scope.name)
    _activation_summary(softmax_linear)

  return softmax_linear

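# A hedged sketch of how the mask and threshold variables added by
# pruning.apply_mask() are typically updated during training (this mirrors
# cifar10_train.py in this directory; the hparams string is illustrative):
#
#   pruning_hparams = pruning.get_pruning_hparams().parse(
#       'begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9')
#   p = pruning.Pruning(pruning_hparams, global_step=global_step)
#   mask_update_op = p.conditional_mask_update_op()
#   p.add_pruning_summaries()
#   # Run mask_update_op alongside the train op returned by train() below.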

def loss(logits, labels):
  """Add L2Loss to all the trainable variables.

  Add summary for "Loss" and "Loss/avg".

  Args:
    logits: Logits from inference().
    labels: Labels from distorted_inputs() or inputs(). 1-D tensor
            of shape [batch_size]

  Returns:
    Loss tensor of type float.
  """
  # Calculate the average cross entropy loss across the batch.
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)

  # The total loss is defined as the cross entropy loss plus all of the weight
  # decay terms (L2 loss).
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

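# Note: with the wd values used in inference() above, only the local3 and
# local4 weights (wd=0.004) contribute non-zero weight-decay terms to the
# 'losses' collection; the conv and softmax_linear layers pass wd=0.0, so
# their terms are zero.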

def _add_loss_summaries(total_loss):
  """Add summaries for losses in CIFAR-10 model.

  Generates moving average for all losses and associated summaries for
  visualizing the performance of the network.

  Args:
    total_loss: Total loss from loss().
  Returns:
    loss_averages_op: op for generating moving averages of losses.
  """
  # Compute the moving average of all individual losses and the total loss.
  loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
  losses = tf.get_collection('losses')
  loss_averages_op = loss_averages.apply(losses + [total_loss])

  # Attach a scalar summary to all individual losses and the total loss; do the
  # same for the averaged version of the losses.
  for l in losses + [total_loss]:
    # Name each loss as '(raw)' and name the moving average version of the loss
    # as the original loss name.
    tf.summary.scalar(l.op.name + ' (raw)', l)
    tf.summary.scalar(l.op.name, loss_averages.average(l))

  return loss_averages_op


def train(total_loss, global_step):
  """Train CIFAR-10 model.

  Create an optimizer and apply to all trainable variables. Add moving
  average for all trainable variables.

  Args:
    total_loss: Total loss from loss().
    global_step: Integer Variable counting the number of training steps
      processed.
  Returns:
    train_op: op for training.
  """
  # Variables that affect learning rate.
  num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / BATCH_SIZE
  decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
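  # Worked example (hedged; assumes the usual 50,000 CIFAR-10 training
  # examples): num_batches_per_epoch = 50000 / 128 = 390.625, so
  # decay_steps = int(390.625 * 350) = 136718, i.e. the learning rate is
  # multiplied by LEARNING_RATE_DECAY_FACTOR roughly every 136k steps.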

  # Decay the learning rate exponentially based on the number of steps.
  lr = tf.train.exponential_decay(
      INITIAL_LEARNING_RATE,
      global_step,
      decay_steps,
      LEARNING_RATE_DECAY_FACTOR,
      staircase=True)
  tf.summary.scalar('learning_rate', lr)

  # Generate moving averages of all losses and associated summaries.
  loss_averages_op = _add_loss_summaries(total_loss)

  # Compute gradients.
  with tf.control_dependencies([loss_averages_op]):
    opt = tf.train.GradientDescentOptimizer(lr)
    grads = opt.compute_gradients(total_loss)

  # Apply gradients.
  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

  # Add histograms for trainable variables.
  for var in tf.trainable_variables():
    tf.summary.histogram(var.op.name, var)

  # Add histograms for gradients.
  for grad, var in grads:
    if grad is not None:
      tf.summary.histogram(var.op.name + '/gradients', grad)

  # Track the moving averages of all trainable variables.
  variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,
                                                        global_step)
  variables_averages_op = variable_averages.apply(tf.trainable_variables())

  with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
    train_op = tf.no_op(name='train')

  return train_op

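# At eval time the shadow (moving-average) values created above are typically
# restored in place of the raw weights; a hedged sketch, as in the companion
# cifar10_eval.py:
#
#   variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
#   variables_to_restore = variable_averages.variables_to_restore()
#   saver = tf.train.Saver(variables_to_restore)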

def maybe_download_and_extract():
  """Download and extract the tarball from Alex's website."""
  dest_directory = DATA_DIR
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):

    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' %
                       (filename,
                        float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')

  tarfile.open(filepath, 'r:gz').extractall(dest_directory)
    402