# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions to add support for magnitude-based model pruning.

  # Adds variables and ops to the graph to enable
  # elementwise masking of weights
  apply_mask(weights)

  # Returns a list containing the sparsity of each of the weight tensors
  get_weight_sparsity()

  # Returns a list of all the masked weight tensorflow variables
  get_masked_weights()

  # Returns a list of all the mask tensorflow variables
  get_masks()

  # Returns a list of all the thresholds
  get_thresholds()

  # Returns a list of all the weight tensors that have been masked
  get_weights()

  The Pruning class uses a tf.HParams object to set up the
  parameters for model pruning. Here's a typical usage:

  # Parse pruning hyperparameters
  pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

  # Create a pruning object using the pruning_hparams
  p = pruning.Pruning(pruning_hparams)

  # Add mask update ops to the graph
  mask_update_op = p.conditional_mask_update_op()

  # Add the summaries
  p.add_pruning_summaries()

  # Run the op
  session.run(mask_update_op)

  # A Pruning object also accepts an externally defined sparsity variable:
  sparsity = tf.Variable(0.5, name="ConstantSparsity")
  p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.model_pruning.python import pruning_utils
from tensorflow.contrib.model_pruning.python.layers import core_layers as core
from tensorflow.contrib.training.python.training import hparam
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import training_util

_MASK_COLLECTION = core.MASK_COLLECTION
_THRESHOLD_COLLECTION = core.THRESHOLD_COLLECTION
_MASKED_WEIGHT_COLLECTION = core.MASKED_WEIGHT_COLLECTION
_WEIGHT_COLLECTION = core.WEIGHT_COLLECTION
_MASKED_WEIGHT_NAME = core.MASKED_WEIGHT_NAME

def apply_mask(x, scope=''):
  """Apply mask to a given weight tensor.

  Args:
    x: Input weight tensor
    scope: The current variable scope. Defaults to "".
  Returns:
    Tensor representing masked_weights
  """

  mask = pruning_utils.weight_mask_variable(x, scope)
  threshold = pruning_utils.weight_threshold_variable(x, scope)
  # Add masked_weights in the weights name scope so as to make it easier
  # for the quantization library to add quant ops.
  masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME)

  # Make sure the mask for a given variable is not added multiple times to the
  # collection. This is particularly important when applying a mask to an
  # RNN's weight variables.
  if mask not in ops.get_collection_ref(_MASK_COLLECTION):
    ops.add_to_collection(_THRESHOLD_COLLECTION, threshold)
    ops.add_to_collection(_MASK_COLLECTION, mask)
    ops.add_to_collection(_MASKED_WEIGHT_COLLECTION, masked_weights)
    ops.add_to_collection(_WEIGHT_COLLECTION, x)
  return masked_weights
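

# For illustration, a minimal sketch of wrapping a fully connected layer with
# a pruning mask; `inputs` and the variable shape here are hypothetical:
#
#   with tf.variable_scope('fc1') as scope:
#     weights = tf.get_variable('weights', shape=[784, 10])
#     masked_weights = pruning.apply_mask(weights, scope)
#     logits = tf.matmul(inputs, masked_weights)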


def get_masked_weights():
  return ops.get_collection(_MASKED_WEIGHT_COLLECTION)


def get_masks():
  return ops.get_collection(_MASK_COLLECTION)


def get_thresholds():
  return ops.get_collection(_THRESHOLD_COLLECTION)


def get_weights():
  return ops.get_collection(_WEIGHT_COLLECTION)


def get_weight_sparsity():
  """Get sparsity of the weights.

  Args:
    None

  Returns:
    A list containing the sparsity of each of the weight tensors
  """
  masks = get_masks()
  return [nn_impl.zero_fraction(mask) for mask in masks]


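# For example, after a few mask updates one can inspect per-layer sparsity
# (a sketch; assumes a `session` with initialized variables):
#
#   sparsity_tensors = pruning.get_weight_sparsity()
#   print(session.run(sparsity_tensors))  # e.g. [0.5, 0.48, ...]

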
def get_pruning_hparams():
  """Get a tf.HParams object with the default values for the hyperparameters.

    name: string
      name of the pruning specification. Used for adding summaries and ops
      under a common tensorflow name_scope
    begin_pruning_step: integer
      the global step at which to begin pruning
    end_pruning_step: integer
      the global step at which to terminate pruning. Defaults to -1 implying
      that pruning continues until training stops
    weight_sparsity_map: list of strings
      comma separated list of weight variable name:target sparsity pairs.
      For layers/weights not in this list, sparsity as specified by the
      target_sparsity hyperparameter is used.
      Eg. [conv1:0.9,conv2/kernel:0.8]
    threshold_decay: float
      the decay factor to use for exponential decay of the thresholds
    pruning_frequency: integer
      how often the masks should be updated (in # of global_steps)
    nbins: integer
      number of bins to use for histogram computation
    block_height: integer
      number of rows in a block (defaults to 1)
    block_width: integer
      number of cols in a block (defaults to 1)
    block_pooling_function: string
      whether to perform average (AVG) or max (MAX) pooling in the block
      (default: AVG)
    initial_sparsity: float
      initial sparsity value
    target_sparsity: float
      target sparsity value
    sparsity_function_begin_step: integer
      the global step at which the gradual sparsity function begins to
      take effect
    sparsity_function_end_step: integer
      the global step used as the end point for the gradual sparsity function
    sparsity_function_exponent: float
      exponent = 1 is linearly varying sparsity between initial and final.
      exponent > 1 varies more slowly towards the end than the beginning
    use_tpu: boolean
      indicates whether to use TPU

    We use the following sparsity function:

    num_steps = (sparsity_function_end_step -
                 sparsity_function_begin_step) / pruning_frequency
    sparsity(step) = (initial_sparsity - target_sparsity) *
                     [1 - step/(num_steps - 1)]**exponent + target_sparsity

  Args:
    None

  Returns:
    tf.HParams object initialized to default values

  """
  return hparam.HParams(
      name='model_pruning',
      begin_pruning_step=0,
      end_pruning_step=-1,
      weight_sparsity_map=[''],
      threshold_decay=0.0,
      pruning_frequency=10,
      nbins=256,
      block_height=1,
      block_width=1,
      block_pooling_function='AVG',
      initial_sparsity=0.0,
      target_sparsity=0.5,
      sparsity_function_begin_step=0,
      sparsity_function_end_step=100,
      sparsity_function_exponent=3.0,
      use_tpu=False)

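# For intuition, a plain-Python sketch of the gradual sparsity schedule
# described in the docstring above; the graph version lives in
# Pruning._setup_sparsity, and the defaults here mirror get_pruning_hparams():
#
#   def sparsity_at(step, begin_step=0, end_step=100,
#                   initial=0.0, target=0.5, exponent=3.0):
#     p = min(1.0, max(0.0, (step - begin_step) / float(end_step - begin_step)))
#     return (initial - target) * (1 - p)**exponent + target
#
#   sparsity_at(0)    # 0.0 (initial_sparsity)
#   sparsity_at(50)   # 0.4375
#   sparsity_at(100)  # 0.5 (target_sparsity)
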

class Pruning(object):

  def __init__(self, spec=None, global_step=None, sparsity=None):
    """Set up the specification for model pruning.

    If a spec is provided, the sparsity is set up based on the
    sparsity_function in the spec. The effect of sparsity_function is
    overridden if the sparsity variable is passed to the constructor. This
    enables setting up arbitrary sparsity profiles externally and passing
    them to this pruning function.

    Args:
      spec: Pruning spec as defined in pruning.proto
      global_step: A tensorflow variable that is used while setting up the
        sparsity function
      sparsity: A tensorflow scalar variable storing the sparsity
    """
    # Pruning specification
    self._spec = spec if spec else get_pruning_hparams()

    # Sanity check for pruning hparams
    self._validate_spec()

    # A tensorflow variable tracking the global step, which the sparsity
    # function depends on. If not provided as input, the graph must already
    # contain the global_step variable before calling this constructor.
    self._global_step = self._setup_global_step(global_step)

    # Stores the tensorflow sparsity variable.
    # Built using self._setup_sparsity() or provided externally
    self._sparsity = (sparsity
                      if sparsity is not None else self._setup_sparsity())

    # List of tensorflow assignment ops for new masks and thresholds
    self._assign_ops = []

    # Tensorflow variable keeping track of the last global step when the masks
    # were updated
    self._last_update_step = self._setup_last_update_step()

    # Block dimensions
    self._block_dim = [self._spec.block_height, self._spec.block_width]

    # Block pooling function
    self._block_pooling_function = self._spec.block_pooling_function

    # Mapping of weight names and target sparsity
    self._weight_sparsity_map = self._get_weight_sparsity_map()

  def _validate_spec(self):
    spec = self._spec
    if spec.begin_pruning_step < 0:
      raise ValueError('Illegal value for begin_pruning_step')

    if spec.begin_pruning_step >= spec.end_pruning_step:
      if spec.end_pruning_step != -1:
        raise ValueError(
            'Pruning must begin before it can end. begin_step=%d, end_step=%d. '
            'Set end_pruning_step to -1 if pruning is required until training '
            'stops.' % (spec.begin_pruning_step, spec.end_pruning_step))

    if spec.sparsity_function_begin_step < 0:
      raise ValueError('Illegal value for sparsity_function_begin_step')

    if spec.sparsity_function_begin_step >= spec.sparsity_function_end_step:
      raise ValueError('Sparsity function requires begin_step < end_step')

    if not 0.0 <= spec.threshold_decay < 1.0:
      raise ValueError('threshold_decay must be in range [0,1)')

    if not 0.0 <= spec.initial_sparsity < 1.0:
      raise ValueError('initial_sparsity must be in range [0,1)')

    if not 0.0 <= spec.target_sparsity < 1.0:
      raise ValueError('target_sparsity must be in range [0,1)')

  def _setup_global_step(self, global_step):
    graph_global_step = global_step
    if graph_global_step is None:
      graph_global_step = training_util.get_global_step()

    return math_ops.cast(graph_global_step, dtypes.int32)

  def _setup_sparsity(self):
    begin_step = self._spec.sparsity_function_begin_step
    end_step = self._spec.sparsity_function_end_step
    initial_sparsity = self._spec.initial_sparsity
    target_sparsity = self._spec.target_sparsity
    exponent = self._spec.sparsity_function_exponent

    with ops.name_scope(self._spec.name):
      p = math_ops.minimum(
          1.0,
          math_ops.maximum(
              0.0,
              math_ops.div(
                  math_ops.cast(self._global_step - begin_step, dtypes.float32),
                  end_step - begin_step)))
      sparsity = math_ops.add(
          math_ops.multiply(initial_sparsity - target_sparsity,
                            math_ops.pow(1 - p, exponent)),
          target_sparsity,
          name='sparsity')

    return sparsity

  def _setup_last_update_step(self):
    with variable_scope.variable_scope(
        self._spec.name, use_resource=self._spec.use_tpu) as scope:
      try:
        last_update_step = variable_scope.get_variable(
            'last_mask_update_step', [],
            initializer=init_ops.zeros_initializer(),
            trainable=False,
            dtype=dtypes.int32)
      except ValueError:
        scope.reuse_variables()
        last_update_step = variable_scope.get_variable(
            'last_mask_update_step', dtype=dtypes.int32)
    return last_update_step

  def _get_weight_sparsity_map(self):
    """Return the map of weight_name:sparsity parsed from the hparams."""
    weight_sparsity_map = {}
    val_list = self._spec.weight_sparsity_map
    filtered_val_list = [l for l in val_list if l]
    for val in filtered_val_list:
      weight_name, sparsity = val.split(':')
      if float(sparsity) >= 1.0:
        raise ValueError('Weight sparsity cannot exceed 1.0')
      weight_sparsity_map[weight_name] = float(sparsity)

    return weight_sparsity_map

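  # For example, parsing weight_sparsity_map=['conv1:0.9', 'conv2/kernel:0.8']
  # yields {'conv1': 0.9, 'conv2/kernel': 0.8}; weights matching neither name
  # fall back to the target_sparsity hyperparameter (see _get_sparsity below).
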
  def _get_sparsity(self, weight_name):
    """Return target sparsity for the given layer/weight name."""
    target_sparsity = [
        sparsity for name, sparsity in self._weight_sparsity_map.items()
        if weight_name.find(name) != -1
    ]
    if not target_sparsity:
      return self._sparsity

    if len(target_sparsity) > 1:
      raise ValueError(
          'Multiple matches in weight_sparsity_map for weight %s' % weight_name)
    # TODO(suyoggupta): This will work when initial_sparsity = 0. Generalize
    # to handle other cases as well.
    return math_ops.multiply(
        self._sparsity,
        math_ops.div(target_sparsity[0], self._spec.target_sparsity))

  def _update_mask(self, weights, threshold):
    """Updates the mask for a given weight tensor.

    This function first computes the cdf of the weight tensor, and estimates
    the threshold value such that 'desired_sparsity' fraction of weights
    have magnitude less than the threshold.

    Args:
      weights: The weight tensor that needs to be masked.
      threshold: The current threshold value. The function will compute a new
        threshold and return the exponential moving average using the current
        value of threshold

    Returns:
      new_threshold: The new value of the threshold based on weights, and
        sparsity at the current global_step
      new_mask: A tensor of the same size and shape as weights, containing 0
        where the magnitude of the corresponding weight falls below the
        threshold and 1 elsewhere

    Raises:
      ValueError: if sparsity is not defined
    """
    if self._sparsity is None:
      raise ValueError('Sparsity variable undefined')

    sparsity = self._get_sparsity(weights.op.name)
    with ops.name_scope(weights.op.name + '_pruning_ops'):
      abs_weights = math_ops.abs(weights)
      k = math_ops.cast(
          math_ops.round(
              math_ops.cast(array_ops.size(abs_weights), dtypes.float32) *
              (1 - sparsity)), dtypes.int32)
      # Sort the entire array
      values, _ = nn_ops.top_k(
          array_ops.reshape(abs_weights, [-1]), k=array_ops.size(abs_weights))
      # Grab the (k-1)th value, i.e. the k-th largest magnitude
      current_threshold = array_ops.gather(values, k - 1)
      smoothed_threshold = math_ops.add_n([
          math_ops.multiply(current_threshold, 1 - self._spec.threshold_decay),
          math_ops.multiply(threshold, self._spec.threshold_decay)
      ])

      new_mask = math_ops.cast(
          math_ops.greater_equal(abs_weights, smoothed_threshold),
          dtypes.float32)

    return smoothed_threshold, new_mask

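  # For intuition, the threshold computation above is roughly equivalent to
  # the following NumPy sketch (ignoring the exponential smoothing and
  # assuming 0 < sparsity < 1):
  #
  #   import numpy as np
  #   def numpy_threshold_and_mask(weights, sparsity):
  #     abs_w = np.abs(weights)
  #     k = int(round(abs_w.size * (1 - sparsity)))  # number of weights kept
  #     threshold = np.sort(abs_w, axis=None)[::-1][k - 1]  # k-th largest
  #     return threshold, (abs_w >= threshold).astype(np.float32)
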
  def _maybe_update_block_mask(self, weights, threshold):
    """Performs block-granular masking of the weights.

    Block pruning occurs only if the block_height or block_width is > 1 and
    if the weight tensor, when squeezed, has ndims = 2. Otherwise, elementwise
    pruning occurs.

    Args:
      weights: The weight tensor that needs to be masked.
      threshold: The current threshold value. The function will compute a new
        threshold and return the exponential moving average using the current
        value of threshold

    Returns:
      new_threshold: The new value of the threshold based on weights, and
        sparsity at the current global_step
      new_mask: A tensor of the same size and shape as weights, containing 0
        where the magnitude of the corresponding weight falls below the
        threshold and 1 elsewhere

    Raises:
      ValueError: if block pooling function is not AVG or MAX
    """
    squeezed_weights = array_ops.squeeze(weights)
    if squeezed_weights.get_shape().ndims != 2 or self._block_dim == [1, 1]:
      return self._update_mask(weights, threshold)

    if self._block_pooling_function not in ['AVG', 'MAX']:
      raise ValueError('Unknown pooling function for block sparsity: %s' %
                       self._block_pooling_function)

    with ops.name_scope(weights.op.name + '_pruning_ops'):
      abs_weights = math_ops.abs(squeezed_weights)

      pool_window = [self._block_dim[0], self._block_dim[1]]
      pool_fn = pruning_utils.factorized_pool
      squeeze_axis = None
      if not self._spec.use_tpu:
        pool_fn = nn_ops.pool
        abs_weights = array_ops.reshape(
            abs_weights,
            [1, abs_weights.get_shape()[0],
             abs_weights.get_shape()[1], 1])
        squeeze_axis = [0, 3]

      pooled_weights = pool_fn(
          abs_weights,
          window_shape=pool_window,
          pooling_type=self._block_pooling_function,
          strides=pool_window,
          padding='SAME',
          name=weights.op.name + '_pooled')

      if pooled_weights.get_shape().ndims != 2:
        pooled_weights = array_ops.squeeze(pooled_weights, axis=squeeze_axis)

      smoothed_threshold, new_mask = self._update_mask(pooled_weights,
                                                       threshold)

      updated_mask = pruning_utils.expand_tensor(new_mask, self._block_dim)
      sliced_mask = array_ops.slice(
          updated_mask, [0, 0],
          [squeezed_weights.get_shape()[0],
           squeezed_weights.get_shape()[1]])

    return smoothed_threshold, array_ops.reshape(sliced_mask,
                                                 array_ops.shape(weights))

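  # For intuition, expanding a block mask back to the (squeezed) weight shape
  # behaves like a Kronecker product with a block of ones -- a NumPy sketch,
  # assuming this reading of pruning_utils.expand_tensor:
  #
  #   import numpy as np
  #   block_mask = np.array([[1., 0.],
  #                          [0., 1.]])
  #   expanded = np.kron(block_mask, np.ones((block_height, block_width)))
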
  def _get_mask_assign_ops(self):
    # Make sure the assignment ops have not already been added to the list
    if self._assign_ops:
      raise ValueError(
          'Assign op list not empty. _get_mask_assign_ops() called twice?')

    masks = get_masks()
    weights = get_weights()
    thresholds = get_thresholds()

    if len(masks) != len(thresholds):
      raise ValueError(
          'Number of masks %s and number of thresholds %s mismatch' %
          (len(masks), len(thresholds)))

    for index, mask in enumerate(masks):
      threshold = thresholds[index]
      weight = weights[index]
      is_partitioned = isinstance(weight, variables.PartitionedVariable)
      if is_partitioned:
        weight = weight.as_tensor()

      new_threshold, new_mask = self._maybe_update_block_mask(weight, threshold)
      self._assign_ops.append(
          pruning_utils.variable_assign(threshold, new_threshold))

      self._assign_ops.append(
          pruning_utils.partitioned_variable_assign(mask, new_mask)
          if is_partitioned else pruning_utils.variable_assign(mask, new_mask))

  def mask_update_op(self):
    with ops.name_scope(self._spec.name):
      if not self._assign_ops:
        self._get_mask_assign_ops()
      with ops.control_dependencies([
          state_ops.assign(
              self._last_update_step,
              self._global_step,
              name='last_mask_update_step_assign')
      ]):
        with ops.control_dependencies(self._assign_ops):
          logging.info('Updating masks.')
          return control_flow_ops.no_op('mask_update')

  def conditional_mask_update_op(self):

    def maybe_update_masks():
      with ops.name_scope(self._spec.name):
        is_step_within_pruning_range = math_ops.logical_and(
            math_ops.greater_equal(self._global_step,
                                   self._spec.begin_pruning_step),
            # If end_pruning_step is negative, keep pruning forever!
            math_ops.logical_or(
                math_ops.less_equal(self._global_step,
                                    self._spec.end_pruning_step),
                math_ops.less(self._spec.end_pruning_step, 0)))
        is_pruning_step = math_ops.less_equal(
            math_ops.add(self._last_update_step, self._spec.pruning_frequency),
            self._global_step)
        return math_ops.logical_and(is_step_within_pruning_range,
                                    is_pruning_step)

    def mask_update_op():
      return self.mask_update_op()

    def no_update_op():
      return control_flow_ops.no_op()

    return control_flow_ops.cond(maybe_update_masks(), mask_update_op,
                                 no_update_op)

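  # A minimal sketch of wiring the conditional update into a training loop;
  # `train_op`, `num_steps`, and `session` are assumed to be defined elsewhere:
  #
  #   p = pruning.Pruning(pruning_hparams, global_step=global_step)
  #   mask_update_op = p.conditional_mask_update_op()
  #   for _ in range(num_steps):
  #     session.run(train_op)
  #     session.run(mask_update_op)
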
  def add_pruning_summaries(self):
    """Adds summaries of weight sparsities and thresholds."""
    with ops.name_scope(self._spec.name + '_summaries'):
      summary.scalar('sparsity', self._sparsity)
      summary.scalar('last_mask_update_step', self._last_update_step)
      masks = get_masks()
      thresholds = get_thresholds()
      for mask, threshold in zip(masks, thresholds):
        summary.scalar(mask.op.name + '/sparsity', nn_impl.zero_fraction(mask))
        summary.scalar(threshold.op.name + '/threshold', threshold)

  def print_hparams(self):
    logging.info(self._spec.to_json())
    573