# Model pruning: Training tensorflow models to have masked connections

This document describes the API that facilitates magnitude-based pruning of a
neural network's weight tensors. The API helps inject the necessary tensorflow
ops into the training graph so the model can be pruned while it is being
trained.

### Model creation

The first step involves adding mask and threshold variables to the layers that
need to undergo pruning. The mask variable has the same shape as the layer's
weight tensor and determines which of the weights participate in the forward
execution of the graph. This can be achieved by wrapping the weight tensor of
the layer with the `apply_mask` function provided in
[pruning.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/pruning.py).
For example:

```python
conv = tf.nn.conv2d(images, pruning.apply_mask(weights), stride, padding)
```

This creates a convolutional layer with two additional variables, mask and
threshold, as shown below: ![Convolutional layer with mask and
threshold](https://storage.googleapis.com/download.tensorflow.org/example_images/mask.png "Convolutional layer with mask and threshold")

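If you want to inspect the mask and threshold variables created by
`apply_mask`, the pruning module also exposes accessors for them. A minimal
sketch follows; the accessor names `get_masks()` and `get_thresholds()` are
taken from `pruning.py`, and the import path assumes an in-tree checkout:

```python
from tensorflow.contrib.model_pruning.python import pruning

# After the masked layers have been built, collect the auxiliary variables
# that apply_mask registered in the graph.
masks = pruning.get_masks()
thresholds = pruning.get_thresholds()

for mask in masks:
  print(mask.name, mask.shape)
```
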
Alternatively, the API also provides variants of tensorflow layers with these
auxiliary variables built in (see
[layers](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers)).
The following layers are currently supported (a usage sketch follows the list):

*   [layers.masked_conv2d](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/layers.py?l=83)

*   [layers.masked_fully_connected](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/layers.py?l=241)

*   [rnn_cells.MaskedLSTMCell](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py?l=154)

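For example, a small pruned network can be assembled directly from these
layers. The sketch below assumes the in-tree import path for `layers.py`; the
argument order (inputs, num_outputs, kernel_size) follows the
`tf.contrib.layers` convention, and the layer sizes and scopes are illustrative:

```python
import tensorflow as tf

from tensorflow.contrib.model_pruning.python.layers import layers

images = tf.placeholder(tf.float32, [None, 32, 32, 3])

# Convolution with built-in mask and threshold variables: 16 output channels,
# 3x3 kernel.
net = layers.masked_conv2d(images, 16, [3, 3], scope='conv1')

# Fully connected layer with built-in mask and threshold variables.
logits = layers.masked_fully_connected(tf.layers.flatten(net), 10, scope='fc1')
```
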
### Specifying the pruning hyperparameters

The pruning library allows for specification of the following hyperparameters:

| Hyperparameter               | Type    | Default       | Description |
|:-----------------------------|:-------:|:-------------:|:------------|
| name | string | model_pruning | Name of the pruning specification. Used for adding summaries and ops under a common tensorflow name_scope |
| begin_pruning_step | integer | 0 | The global step at which to begin pruning |
| end_pruning_step | integer | -1 | The global step at which to terminate pruning. Defaults to -1, implying that pruning continues until training stops |
| do_not_prune | list of strings | [""] | List of names of layers that are not pruned |
| threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds |
| pruning_frequency | integer | 10 | How often the masks are updated (in number of global steps) |
| nbins | integer | 255 | Number of bins to use for histogram computation |
| block_height | integer | 1 | Number of rows in a block for block-sparse matrices |
| block_width | integer | 1 | Number of columns in a block for block-sparse matrices |
| block_pooling_function | string | AVG | The function used to pool weight values in a block: average (AVG) or max (MAX) |
| initial_sparsity | float | 0.0 | Initial sparsity value |
| target_sparsity | float | 0.5 | Target sparsity value |
| sparsity_function_begin_step | integer | 0 | The global step at which the gradual sparsity function begins to take effect |
| sparsity_function_end_step | integer | 100 | The global step used as the end point for the gradual sparsity function |
| sparsity_function_exponent | float | 3.0 | exponent = 1 varies sparsity linearly between the initial and target values; exponent > 1 changes sparsity more slowly towards the end than at the beginning |

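The hyperparameters are supplied as a single comma-separated string and parsed
with `pruning.get_pruning_hparams().parse()`, as in the training-graph example
further below. The particular values here are illustrative:

```python
# Illustrative spec: start pruning at step 10000, reach 87.5% sparsity by
# step 100000, and update the masks every 500 global steps.
pruning_spec = ('name=my_pruning,begin_pruning_step=10000,'
                'end_pruning_step=100000,target_sparsity=0.875,'
                'sparsity_function_begin_step=10000,'
                'sparsity_function_end_step=100000,pruning_frequency=500')
pruning_hparams = pruning.get_pruning_hparams().parse(pruning_spec)
```
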
The sparsity $$s_t$$ at global step $$t$$ is given by:

$$ s_{t}=s_{f}+\left(s_{i}-s_{f}\right)\left(1-\frac{t-t_{0}}{n\Delta t}\right)^{3} $$

The interval between sparsity_function_begin_step and sparsity_function_end_step
is divided into $$n$$ intervals of size equal to the pruning_frequency ($$\Delta
t$$). $$s_f$$ is the target_sparsity, $$s_i$$ is the initial_sparsity, and $$t_0$$
is the sparsity_function_begin_step. In this equation, the
sparsity_function_exponent is set to 3.

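To make the schedule concrete, here is a small, self-contained Python sketch of
the same formula. It is not part of the library, and it ignores the fact that
the library only recomputes sparsity every pruning_frequency steps:

```python
def gradual_sparsity(step, initial_sparsity=0.0, target_sparsity=0.5,
                     begin_step=0, end_step=100, exponent=3.0):
  """Sparsity s_t for the gradual schedule described above (illustrative)."""
  if step <= begin_step:
    return initial_sparsity
  if step >= end_step:
    return target_sparsity
  progress = (step - begin_step) / float(end_step - begin_step)
  return (target_sparsity +
          (initial_sparsity - target_sparsity) * (1.0 - progress) ** exponent)

# With the defaults above, sparsity ramps from 0.0 towards 0.5:
#   step 0 -> 0.0, step 50 -> 0.4375, step 100 -> 0.5
```
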
### Adding pruning ops to the training graph

The final step involves adding ops to the training graph that monitor the
distribution of the layer's weight magnitudes and determine the layer threshold
such that masking all the weights below this threshold achieves the sparsity
level desired for the current training step. This can be achieved as follows:

```python
tf.app.flags.DEFINE_string(
    'pruning_hparams', '',
    """Comma-separated list of pruning-related hyperparameters""")

with tf.Graph().as_default():

  # Create the global step variable
  global_step = tf.train.get_or_create_global_step()

  # Parse pruning hyperparameters
  pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

  # Create a pruning object using the pruning specification
  p = pruning.Pruning(pruning_hparams, global_step=global_step)

  # Add conditional mask update op. Executing this op will update all
  # the masks in the graph if the current global step is in the range
  # [begin_pruning_step, end_pruning_step] as specified by the pruning spec
  mask_update_op = p.conditional_mask_update_op()

  # Add summaries to keep track of the sparsity in different layers during training
  p.add_pruning_summaries()

  with tf.train.MonitoredTrainingSession(...) as mon_sess:
    # Run the usual training op in the tf session
    mon_sess.run(train_op)

    # Update the masks by running the mask_update_op
    mon_sess.run(mask_update_op)
```

## Example: Pruning and training deep CNNs on the cifar10 dataset

Please see https://www.tensorflow.org/tutorials/deep_cnn for details on the
neural network architecture, setting up inputs, etc. The additional changes
needed to incorporate pruning are captured in the following:

*   [cifar10_pruning.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py)
    creates a deep CNN with the same architecture, but adds mask and threshold
    variables for each of the weight tensors in the convolutional and
    locally-connected layers.

*   [cifar10_train.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py)
    adds pruning ops to the training graph as described above.

To train the pruned version of cifar10:

```bash
$ examples_dir=contrib/model_pruning/examples
$ bazel build -c opt $examples_dir/cifar10:cifar10_{train,eval}
$ bazel-bin/$examples_dir/cifar10/cifar10_train --pruning_hparams=name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000
```

Eval:

```shell
$ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once
```

### Block Sparsity

For some hardware architectures, it may be beneficial to induce spatially
correlated sparsity. To train models in which the weight tensors have a block
sparse structure, set the *block_height* and *block_width* hyperparameters to
the desired block configuration (2x2, 4x4, 4x1, 1x8, etc.). Currently, block
sparsity is only supported for weight tensors with rank 2. The matrix is
partitioned into non-overlapping blocks of size *[block_height, block_width]*,
and either the average or the maximum absolute value in each block is taken as
a proxy for the entire block (selected by the *block_pooling_function*
hyperparameter). The convolutional layer tensors are always pruned using block
dimensions of [1,1].

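For example, 2x2 block sparsity with max pooling over each block could be
requested with a spec along these lines (the values are illustrative):

```python
block_pruning_spec = ('name=block_pruning,block_height=2,block_width=2,'
                      'block_pooling_function=MAX,target_sparsity=0.75')
pruning_hparams = pruning.get_pruning_hparams().parse(block_pruning_spec)
```
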
## References

Michael Zhu and Suyog Gupta, To prune, or not to prune: exploring the efficacy
of pruning for model compression, *2017 NIPS Workshop on Machine Learning of
Phones and other Consumer Devices* (https://arxiv.org/pdf/1710.01878.pdf)
    143