1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """A binary to train pruned CIFAR-10 using a single GPU. 16 17 Accuracy: 18 cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of 19 data) as judged by cifar10_eval.py when target sparsity in 20 cifar10_pruning_spec.pbtxt is set to zero 21 22 Results: 23 Sparsity | Accuracy after 150K steps 24 -------- | ------------------------- 25 0% | 86% 26 50% | 86% 27 75% | TODO(suyoggupta) 28 90% | TODO(suyoggupta) 29 95% | 77% 30 31 Usage: 32 Please see the tutorial and website for how to download the CIFAR-10 33 data set, compile the program and train the model. 34 35 36 """ 37 from __future__ import absolute_import 38 from __future__ import division 39 from __future__ import print_function 40 41 import argparse 42 import datetime 43 import sys 44 import time 45 46 47 import tensorflow as tf 48 49 from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_pruning as cifar10 50 from tensorflow.contrib.model_pruning.python import pruning 51 52 FLAGS = None 53 54 55 def train(): 56 """Train CIFAR-10 for a number of steps.""" 57 with tf.Graph().as_default(): 58 global_step = tf.contrib.framework.get_or_create_global_step() 59 60 # Get images and labels for CIFAR-10. 61 images, labels = cifar10.distorted_inputs() 62 63 # Build a Graph that computes the logits predictions from the 64 # inference model. 65 logits = cifar10.inference(images) 66 67 # Calculate loss. 68 loss = cifar10.loss(logits, labels) 69 70 # Build a Graph that trains the model with one batch of examples and 71 # updates the model parameters. 72 train_op = cifar10.train(loss, global_step) 73 74 # Parse pruning hyperparameters 75 pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) 76 77 # Create a pruning object using the pruning hyperparameters 78 pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step) 79 80 # Use the pruning_obj to add ops to the training graph to update the masks 81 # The conditional_mask_update_op will update the masks only when the 82 # training step is in [begin_pruning_step, end_pruning_step] specified in 83 # the pruning spec proto 84 mask_update_op = pruning_obj.conditional_mask_update_op() 85 86 # Use the pruning_obj to add summaries to the graph to track the sparsity 87 # of each of the layers 88 pruning_obj.add_pruning_summaries() 89 90 class _LoggerHook(tf.train.SessionRunHook): 91 """Logs loss and runtime.""" 92 93 def begin(self): 94 self._step = -1 95 96 def before_run(self, run_context): 97 self._step += 1 98 self._start_time = time.time() 99 return tf.train.SessionRunArgs(loss) # Asks for loss value. 100 101 def after_run(self, run_context, run_values): 102 duration = time.time() - self._start_time 103 loss_value = run_values.results 104 if self._step % 10 == 0: 105 num_examples_per_step = 128 106 examples_per_sec = num_examples_per_step / duration 107 sec_per_batch = float(duration) 108 109 format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 110 'sec/batch)') 111 print(format_str % (datetime.datetime.now(), self._step, loss_value, 112 examples_per_sec, sec_per_batch)) 113 114 with tf.train.MonitoredTrainingSession( 115 checkpoint_dir=FLAGS.train_dir, 116 hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), 117 tf.train.NanTensorHook(loss), 118 _LoggerHook()], 119 config=tf.ConfigProto( 120 log_device_placement=FLAGS.log_device_placement)) as mon_sess: 121 while not mon_sess.should_stop(): 122 mon_sess.run(train_op) 123 # Update the masks 124 mon_sess.run(mask_update_op) 125 126 127 def main(argv=None): # pylint: disable=unused-argument 128 cifar10.maybe_download_and_extract() 129 if tf.gfile.Exists(FLAGS.train_dir): 130 tf.gfile.DeleteRecursively(FLAGS.train_dir) 131 tf.gfile.MakeDirs(FLAGS.train_dir) 132 train() 133 134 135 if __name__ == '__main__': 136 parser = argparse.ArgumentParser() 137 parser.add_argument( 138 '--train_dir', 139 type=str, 140 default='/tmp/cifar10_train', 141 help='Directory where to write event logs and checkpoint.') 142 parser.add_argument( 143 '--pruning_hparams', 144 type=str, 145 default='', 146 help="""Comma separated list of pruning-related hyperparameters""") 147 parser.add_argument( 148 '--max_steps', 149 type=int, 150 default=1000000, 151 help='Number of batches to run.') 152 parser.add_argument( 153 '--log_device_placement', 154 type=bool, 155 default=False, 156 help='Whether to log device placement.') 157 158 FLAGS, unparsed = parser.parse_known_args() 159 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 160