1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the key functions in pruning library.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.contrib.model_pruning.python import pruning 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.ops import math_ops 26 from tensorflow.python.ops import partitioned_variables 27 from tensorflow.python.ops import random_ops 28 from tensorflow.python.ops import state_ops 29 from tensorflow.python.ops import variable_scope 30 from tensorflow.python.ops import variables 31 from tensorflow.python.platform import test 32 from tensorflow.python.training import training_util 33 34 35 class PruningHParamsTest(test.TestCase): 36 PARAM_LIST = [ 37 "name=test", "threshold_decay=0.9", "pruning_frequency=10", 38 "do_not_prune=[conv1,conv2]", "sparsity_function_end_step=100", 39 "target_sparsity=0.9" 40 ] 41 TEST_HPARAMS = ",".join(PARAM_LIST) 42 43 def setUp(self): 44 super(PruningHParamsTest, self).setUp() 45 # Add global step variable to the graph 46 self.global_step = training_util.get_or_create_global_step() 47 # Add sparsity 48 self.sparsity = variables.Variable(0.5, name="sparsity") 49 # Parse hparams 50 self.pruning_hparams = pruning.get_pruning_hparams().parse( 51 self.TEST_HPARAMS) 52 53 def testInit(self): 54 p = pruning.Pruning(self.pruning_hparams) 55 self.assertEqual(p._spec.name, "test") 56 self.assertAlmostEqual(p._spec.threshold_decay, 0.9) 57 self.assertEqual(p._spec.pruning_frequency, 10) 58 self.assertAllEqual(p._spec.do_not_prune, ["conv1", "conv2"]) 59 self.assertEqual(p._spec.sparsity_function_end_step, 100) 60 self.assertAlmostEqual(p._spec.target_sparsity, 0.9) 61 62 def testInitWithExternalSparsity(self): 63 with self.test_session(): 64 p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity) 65 variables.global_variables_initializer().run() 66 sparsity = p._sparsity.eval() 67 self.assertAlmostEqual(sparsity, 0.5) 68 69 def testInitWithVariableReuse(self): 70 with self.test_session(): 71 p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity) 72 p_copy = pruning.Pruning( 73 spec=self.pruning_hparams, sparsity=self.sparsity) 74 variables.global_variables_initializer().run() 75 sparsity = p._sparsity.eval() 76 self.assertAlmostEqual(sparsity, 0.5) 77 self.assertEqual(p._sparsity.eval(), p_copy._sparsity.eval()) 78 79 80 class PruningTest(test.TestCase): 81 82 def setUp(self): 83 super(PruningTest, self).setUp() 84 self.global_step = training_util.get_or_create_global_step() 85 86 def testCreateMask2D(self): 87 width = 10 88 height = 20 89 with self.test_session(): 90 weights = variables.Variable( 91 random_ops.random_normal([width, height], stddev=1), name="weights") 92 masked_weights = pruning.apply_mask(weights, 93 variable_scope.get_variable_scope()) 94 variables.global_variables_initializer().run() 95 weights_val = weights.eval() 96 masked_weights_val = masked_weights.eval() 97 self.assertAllEqual(weights_val, masked_weights_val) 98 99 def testUpdateSingleMask(self): 100 with self.test_session() as session: 101 weights = variables.Variable( 102 math_ops.linspace(1.0, 100.0, 100), name="weights") 103 masked_weights = pruning.apply_mask(weights) 104 sparsity = variables.Variable(0.5, name="sparsity") 105 p = pruning.Pruning(sparsity=sparsity) 106 p._spec.threshold_decay = 0.0 107 mask_update_op = p.mask_update_op() 108 variables.global_variables_initializer().run() 109 masked_weights_val = masked_weights.eval() 110 self.assertAllEqual(np.count_nonzero(masked_weights_val), 100) 111 session.run(mask_update_op) 112 masked_weights_val = masked_weights.eval() 113 self.assertAllEqual(np.count_nonzero(masked_weights_val), 51) 114 115 def _blockMasking(self, hparams, weights, expected_mask): 116 117 threshold = variables.Variable(0.0, name="threshold") 118 sparsity = variables.Variable(0.51, name="sparsity") 119 test_spec = ",".join(hparams) 120 pruning_hparams = pruning.get_pruning_hparams().parse(test_spec) 121 122 # Set up pruning 123 p = pruning.Pruning(pruning_hparams, sparsity=sparsity) 124 with self.test_session(): 125 variables.global_variables_initializer().run() 126 _, new_mask = p._maybe_update_block_mask(weights, threshold) 127 # Check if the mask is the same size as the weights 128 self.assertAllEqual(new_mask.get_shape(), weights.get_shape()) 129 mask_val = new_mask.eval() 130 self.assertAllEqual(mask_val, expected_mask) 131 132 def testBlockMasking(self): 133 param_list = ["block_height=2", "block_width=2", "threshold_decay=0"] 134 135 weights_avg = constant_op.constant( 136 [[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4], 137 [0.3, 0.3, 0.4, 0.4]]) 138 weights_max = constant_op.constant( 139 [[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0], 140 [0.0, -0.3, 0.0, -0.4]]) 141 expected_mask = [[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]] 142 143 self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max, 144 expected_mask) 145 self._blockMasking(param_list + ["block_pooling_function=AVG"], 146 weights_avg, expected_mask) 147 148 def testPartitionedVariableMasking(self): 149 partitioner = partitioned_variables.variable_axis_size_partitioner(40) 150 with self.test_session() as session: 151 with variable_scope.variable_scope("", partitioner=partitioner): 152 sparsity = variables.Variable(0.5, name="Sparsity") 153 weights = variable_scope.get_variable( 154 "weights", initializer=math_ops.linspace(1.0, 100.0, 100)) 155 masked_weights = pruning.apply_mask( 156 weights, scope=variable_scope.get_variable_scope()) 157 p = pruning.Pruning(sparsity=sparsity) 158 p._spec.threshold_decay = 0.0 159 mask_update_op = p.mask_update_op() 160 variables.global_variables_initializer().run() 161 masked_weights_val = masked_weights.eval() 162 session.run(mask_update_op) 163 masked_weights_val = masked_weights.eval() 164 self.assertAllEqual(np.count_nonzero(masked_weights_val), 51) 165 166 def testConditionalMaskUpdate(self): 167 param_list = [ 168 "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6" 169 ] 170 test_spec = ",".join(param_list) 171 pruning_hparams = pruning.get_pruning_hparams().parse(test_spec) 172 weights = variables.Variable( 173 math_ops.linspace(1.0, 100.0, 100), name="weights") 174 masked_weights = pruning.apply_mask(weights) 175 sparsity = variables.Variable(0.00, name="sparsity") 176 # Set up pruning 177 p = pruning.Pruning(pruning_hparams, sparsity=sparsity) 178 p._spec.threshold_decay = 0.0 179 mask_update_op = p.conditional_mask_update_op() 180 sparsity_val = math_ops.linspace(0.0, 0.9, 10) 181 increment_global_step = state_ops.assign_add(self.global_step, 1) 182 non_zero_count = [] 183 with self.test_session() as session: 184 variables.global_variables_initializer().run() 185 for i in range(10): 186 session.run(state_ops.assign(sparsity, sparsity_val[i])) 187 session.run(mask_update_op) 188 session.run(increment_global_step) 189 non_zero_count.append(np.count_nonzero(masked_weights.eval())) 190 # Weights pruned at steps 0,2,4,and,6 191 expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40] 192 self.assertAllEqual(expected_non_zero_count, non_zero_count) 193 194 195 if __name__ == "__main__": 196 test.main() 197