Home | History | Annotate | Download | only in python
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the key functions in pruning library."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.contrib.model_pruning.python import pruning
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.ops import math_ops
     26 from tensorflow.python.ops import partitioned_variables
     27 from tensorflow.python.ops import random_ops
     28 from tensorflow.python.ops import state_ops
     29 from tensorflow.python.ops import variable_scope
     30 from tensorflow.python.ops import variables
     31 from tensorflow.python.platform import test
     32 from tensorflow.python.training import training_util
     33 
     34 
     35 class PruningHParamsTest(test.TestCase):
     36   PARAM_LIST = [
     37       "name=test", "threshold_decay=0.9", "pruning_frequency=10",
     38       "do_not_prune=[conv1,conv2]", "sparsity_function_end_step=100",
     39       "target_sparsity=0.9"
     40   ]
     41   TEST_HPARAMS = ",".join(PARAM_LIST)
     42 
     43   def setUp(self):
     44     super(PruningHParamsTest, self).setUp()
     45     # Add global step variable to the graph
     46     self.global_step = training_util.get_or_create_global_step()
     47     # Add sparsity
     48     self.sparsity = variables.Variable(0.5, name="sparsity")
     49     # Parse hparams
     50     self.pruning_hparams = pruning.get_pruning_hparams().parse(
     51         self.TEST_HPARAMS)
     52 
     53   def testInit(self):
     54     p = pruning.Pruning(self.pruning_hparams)
     55     self.assertEqual(p._spec.name, "test")
     56     self.assertAlmostEqual(p._spec.threshold_decay, 0.9)
     57     self.assertEqual(p._spec.pruning_frequency, 10)
     58     self.assertAllEqual(p._spec.do_not_prune, ["conv1", "conv2"])
     59     self.assertEqual(p._spec.sparsity_function_end_step, 100)
     60     self.assertAlmostEqual(p._spec.target_sparsity, 0.9)
     61 
     62   def testInitWithExternalSparsity(self):
     63     with self.test_session():
     64       p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
     65       variables.global_variables_initializer().run()
     66       sparsity = p._sparsity.eval()
     67       self.assertAlmostEqual(sparsity, 0.5)
     68 
     69   def testInitWithVariableReuse(self):
     70     with self.test_session():
     71       p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
     72       p_copy = pruning.Pruning(
     73           spec=self.pruning_hparams, sparsity=self.sparsity)
     74       variables.global_variables_initializer().run()
     75       sparsity = p._sparsity.eval()
     76       self.assertAlmostEqual(sparsity, 0.5)
     77       self.assertEqual(p._sparsity.eval(), p_copy._sparsity.eval())
     78 
     79 
     80 class PruningTest(test.TestCase):
     81 
     82   def setUp(self):
     83     super(PruningTest, self).setUp()
     84     self.global_step = training_util.get_or_create_global_step()
     85 
     86   def testCreateMask2D(self):
     87     width = 10
     88     height = 20
     89     with self.test_session():
     90       weights = variables.Variable(
     91           random_ops.random_normal([width, height], stddev=1), name="weights")
     92       masked_weights = pruning.apply_mask(weights,
     93                                           variable_scope.get_variable_scope())
     94       variables.global_variables_initializer().run()
     95       weights_val = weights.eval()
     96       masked_weights_val = masked_weights.eval()
     97       self.assertAllEqual(weights_val, masked_weights_val)
     98 
     99   def testUpdateSingleMask(self):
    100     with self.test_session() as session:
    101       weights = variables.Variable(
    102           math_ops.linspace(1.0, 100.0, 100), name="weights")
    103       masked_weights = pruning.apply_mask(weights)
    104       sparsity = variables.Variable(0.5, name="sparsity")
    105       p = pruning.Pruning(sparsity=sparsity)
    106       p._spec.threshold_decay = 0.0
    107       mask_update_op = p.mask_update_op()
    108       variables.global_variables_initializer().run()
    109       masked_weights_val = masked_weights.eval()
    110       self.assertAllEqual(np.count_nonzero(masked_weights_val), 100)
    111       session.run(mask_update_op)
    112       masked_weights_val = masked_weights.eval()
    113       self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)
    114 
    115   def _blockMasking(self, hparams, weights, expected_mask):
    116 
    117     threshold = variables.Variable(0.0, name="threshold")
    118     sparsity = variables.Variable(0.51, name="sparsity")
    119     test_spec = ",".join(hparams)
    120     pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
    121 
    122     # Set up pruning
    123     p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    124     with self.test_session():
    125       variables.global_variables_initializer().run()
    126       _, new_mask = p._maybe_update_block_mask(weights, threshold)
    127       # Check if the mask is the same size as the weights
    128       self.assertAllEqual(new_mask.get_shape(), weights.get_shape())
    129       mask_val = new_mask.eval()
    130       self.assertAllEqual(mask_val, expected_mask)
    131 
    132   def testBlockMasking(self):
    133     param_list = ["block_height=2", "block_width=2", "threshold_decay=0"]
    134 
    135     weights_avg = constant_op.constant(
    136         [[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4],
    137          [0.3, 0.3, 0.4, 0.4]])
    138     weights_max = constant_op.constant(
    139         [[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0],
    140          [0.0, -0.3, 0.0, -0.4]])
    141     expected_mask = [[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]
    142 
    143     self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max,
    144                        expected_mask)
    145     self._blockMasking(param_list + ["block_pooling_function=AVG"],
    146                        weights_avg, expected_mask)
    147 
    148   def testPartitionedVariableMasking(self):
    149     partitioner = partitioned_variables.variable_axis_size_partitioner(40)
    150     with self.test_session() as session:
    151       with variable_scope.variable_scope("", partitioner=partitioner):
    152         sparsity = variables.Variable(0.5, name="Sparsity")
    153         weights = variable_scope.get_variable(
    154             "weights", initializer=math_ops.linspace(1.0, 100.0, 100))
    155         masked_weights = pruning.apply_mask(
    156             weights, scope=variable_scope.get_variable_scope())
    157       p = pruning.Pruning(sparsity=sparsity)
    158       p._spec.threshold_decay = 0.0
    159       mask_update_op = p.mask_update_op()
    160       variables.global_variables_initializer().run()
    161       masked_weights_val = masked_weights.eval()
    162       session.run(mask_update_op)
    163       masked_weights_val = masked_weights.eval()
    164       self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)
    165 
    166   def testConditionalMaskUpdate(self):
    167     param_list = [
    168         "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6"
    169     ]
    170     test_spec = ",".join(param_list)
    171     pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
    172     weights = variables.Variable(
    173         math_ops.linspace(1.0, 100.0, 100), name="weights")
    174     masked_weights = pruning.apply_mask(weights)
    175     sparsity = variables.Variable(0.00, name="sparsity")
    176     # Set up pruning
    177     p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    178     p._spec.threshold_decay = 0.0
    179     mask_update_op = p.conditional_mask_update_op()
    180     sparsity_val = math_ops.linspace(0.0, 0.9, 10)
    181     increment_global_step = state_ops.assign_add(self.global_step, 1)
    182     non_zero_count = []
    183     with self.test_session() as session:
    184       variables.global_variables_initializer().run()
    185       for i in range(10):
    186         session.run(state_ops.assign(sparsity, sparsity_val[i]))
    187         session.run(mask_update_op)
    188         session.run(increment_global_step)
    189         non_zero_count.append(np.count_nonzero(masked_weights.eval()))
    190     # Weights pruned at steps 0,2,4,and,6
    191     expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
    192     self.assertAllEqual(expected_non_zero_count, non_zero_count)
    193 
    194 
    195 if __name__ == "__main__":
    196   test.main()
    197