      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """A binary to train pruned CIFAR-10 using a single GPU.
     16 
     17 Accuracy:
cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
data) as judged by cifar10_eval.py when the target sparsity in
cifar10_pruning_spec.pbtxt is set to zero.
     21 
     22 Results:
     23 Sparsity | Accuracy after 150K steps
     24 -------- | -------------------------
     25 0%       | 86%
     26 50%      | 86%
     27 75%      | TODO(suyoggupta)
     28 90%      | TODO(suyoggupta)
     29 95%      | 77%
     30 
     31 Usage:
     32 Please see the tutorial and website for how to download the CIFAR-10
     33 data set, compile the program and train the model.
     34 
     35 
     36 """
     37 from __future__ import absolute_import
     38 from __future__ import division
     39 from __future__ import print_function
     40 
     41 import argparse
     42 import datetime
     43 import sys
     44 import time
     45 
     46 
     47 import tensorflow as tf
     48 
     49 from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_pruning as cifar10
     50 from tensorflow.contrib.model_pruning.python import pruning
     51 
     52 FLAGS = None
     53 
     54 
def train():
  """Train CIFAR-10 for a number of steps.

  Builds the training graph (input pipeline, inference, loss, train op),
  attaches the model-pruning mask-update op and per-layer sparsity
  summaries, then runs a MonitoredTrainingSession until FLAGS.max_steps,
  alternating one training step with one conditional mask update.
  """
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    # Use the pruning_obj to add ops to the training graph to update the masks
    # The conditional_mask_update_op will update the masks only when the
    # training step is in [begin_pruning_step, end_pruning_step] specified in
    # the pruning spec proto
    mask_update_op = pruning_obj.conditional_mask_update_op()

    # Use the pruning_obj to add summaries to the graph to track the sparsity
    # of each of the layers
    pruning_obj.add_pruning_summaries()

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        # Step counter; incremented in before_run, so the first step is 0.
        self._step = -1

      def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        # Duration covers a single session.run of the training step.
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step % 10 == 0:
          # NOTE(review): assumes the cifar10 default batch size of 128 —
          # keep in sync with the input pipeline's actual batch size.
          num_examples_per_step = 128
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))

    # MonitoredTrainingSession handles checkpointing, summary writing,
    # and stopping: StopAtStepHook ends training at FLAGS.max_steps and
    # NanTensorHook aborts if the loss goes NaN.
    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
        # Update the masks
        mon_sess.run(mask_update_op)
    125 
    126 
    127 def main(argv=None):  # pylint: disable=unused-argument
    128   cifar10.maybe_download_and_extract()
    129   if tf.gfile.Exists(FLAGS.train_dir):
    130     tf.gfile.DeleteRecursively(FLAGS.train_dir)
    131   tf.gfile.MakeDirs(FLAGS.train_dir)
    132   train()
    133 
    134 
    135 if __name__ == '__main__':
    136   parser = argparse.ArgumentParser()
    137   parser.add_argument(
    138       '--train_dir',
    139       type=str,
    140       default='/tmp/cifar10_train',
    141       help='Directory where to write event logs and checkpoint.')
    142   parser.add_argument(
    143       '--pruning_hparams',
    144       type=str,
    145       default='',
    146       help="""Comma separated list of pruning-related hyperparameters""")
    147   parser.add_argument(
    148       '--max_steps',
    149       type=int,
    150       default=1000000,
    151       help='Number of batches to run.')
    152   parser.add_argument(
    153       '--log_device_placement',
    154       type=bool,
    155       default=False,
    156       help='Whether to log device placement.')
    157 
    158   FLAGS, unparsed = parser.parse_known_args()
    159   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
    160