# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python support for quantization operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.framework.python.ops import model_variable
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import moving_averages


@add_arg_scope
def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
  """Adds a fake quantize layer with fixed quantization interval.

  Args:
    inputs: a tensor containing values to be quantized.
    init_min: the lower end of quantization interval.
    init_max: the upper end of quantization interval.
    scope: Optional scope for name_scope.
  Returns:
    a tensor containing quantized values.
  """
  with ops.name_scope(scope, 'FixedQuantize', values=[inputs]):
    return array_ops.fake_quant_with_min_max_args(
        inputs, min=init_min, max=init_max)


def _MinMaxVars(init_min, init_max, min_max_shape, vars_collection):
  """Creates the non-trainable 'min' and 'max' range variables.

  Args:
    init_min: a float scalar, initial value for the 'min' variable.
    init_max: a float scalar, initial value for the 'max' variable.
    min_max_shape: shape of both variables; [] for global quantization, [d]
      for per-channel quantization.
    vars_collection: collection where the variables are stored.
  Returns:
    A (min_var, max_var) tuple of variables.
  """
  min_var = model_variable(
      'min',
      shape=min_max_shape,
      initializer=init_ops.constant_initializer(init_min),
      collections=[vars_collection],
      trainable=False)
  max_var = model_variable(
      'max',
      shape=min_max_shape,
      initializer=init_ops.constant_initializer(init_max),
      collections=[vars_collection],
      trainable=False)
  return min_var, max_var


def _BatchMinMax(inputs, per_channel, input_dim):
  """Computes the observed min/max of a batch, clamped to include 0.0.

  Args:
    inputs: a tensor containing values to be quantized.
    per_channel: whether to compute one range per last-axis channel.
    input_dim: rank of `inputs`; only 1, 2 and 4 are supported per-channel.
  Returns:
    A (batch_min, batch_max) tuple of tensors, scalars when
    per_channel == False, otherwise shaped like the last axis of `inputs`.
  """
  if per_channel:
    if input_dim >= 2:
      # Reduce over every axis except the last (channel) axis: [0] for 2D
      # inputs, [0, 1, 2] for 4D inputs.
      reduce_dims = list(range(input_dim - 1))
      batch_min = math_ops.reduce_min(
          inputs, reduction_indices=reduce_dims, name='BatchMin')
      batch_max = math_ops.reduce_max(
          inputs, reduction_indices=reduce_dims, name='BatchMax')
    else:
      # 1D input is already per-channel; nothing to reduce.
      batch_min = inputs
      batch_max = inputs
  else:
    batch_min = math_ops.reduce_min(inputs, name='BatchMin')
    batch_max = math_ops.reduce_max(inputs, name='BatchMax')
  # TFLite requires that 0.0 is always in the [min; max] range.
  return (math_ops.minimum(batch_min, 0.0),
          math_ops.maximum(batch_max, 0.0))


@add_arg_scope
def LastValueQuantize(inputs,
                      per_channel=False,
                      init_min=-6.0,
                      init_max=6.0,
                      updates_collection=ops.GraphKeys.UPDATE_OPS,
                      vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                      name_prefix='LastValueQuant',
                      reuse=None,
                      is_training=True,
                      num_bits=8,
                      narrow_range=False):
  """Adds a layer that collects quantization ranges as last input ranges.

  LastValueQuantize creates variables called 'min' and 'max', representing the
  interval used for quantization and clamping.

  Args:
    inputs: a tensor containing values to be quantized.
    per_channel: (Optional) a boolean specifying whether to use different
      quantization ranges per output channel.
    init_min: a float scalar, the initial value for variable min.
    init_max: a float scalar, the initial value for variable max.
    updates_collection: (Optional) collections to collect the update ops for
      computation.
    vars_collection: (Optional) collection where to store variables for
      quantization interval ends.
    name_prefix: name_prefix for created nodes.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    is_training: Whether the op is applied to a training or eval graph.
    num_bits: Number of bits to use for quantization, must be between 2 and 8.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
  Returns:
    a tensor containing quantized values.
  """
  with variable_scope.variable_scope(
      None, default_name=name_prefix, values=[inputs], reuse=reuse):
    input_shape = inputs.get_shape()
    input_dim = len(input_shape)
    if per_channel:
      # Only support quantizing 1-, 2- and 4-dimensional tensors.
      assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
                                      ' scope: %s' % (input_shape, name_prefix))
      min_max_shape = [input_shape[-1]]
    else:
      min_max_shape = []

    min_var, max_var = _MinMaxVars(init_min, init_max, min_max_shape,
                                   vars_collection)
    if not is_training:
      # Eval graphs quantize with the stored (frozen) ranges; no updates.
      return _FakeQuantWithMinMaxVars(
          inputs,
          min_var,
          max_var,
          per_channel=per_channel,
          num_bits=num_bits,
          narrow_range=narrow_range)

    batch_min, batch_max = _BatchMinMax(inputs, per_channel, input_dim)

    # Overwrite the stored ranges with this batch's observed ranges.
    assign_min = state_ops.assign(min_var, batch_min, name='AssignMinLast')
    ops.add_to_collection(updates_collection, assign_min.op)
    assign_max = state_ops.assign(max_var, batch_max, name='AssignMaxLast')
    ops.add_to_collection(updates_collection, assign_max.op)

    return _FakeQuantWithMinMaxVars(
        inputs,
        assign_min,
        assign_max,
        per_channel=per_channel,
        num_bits=num_bits,
        narrow_range=narrow_range)


@add_arg_scope
def MovingAvgQuantize(inputs,
                      per_channel=False,
                      init_min=-6.0,
                      init_max=6.0,
                      ema_decay=0.999,
                      updates_collection=ops.GraphKeys.UPDATE_OPS,
                      vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                      name_prefix='MovingAvgQuantize',
                      reuse=None,
                      is_training=True,
                      num_bits=8,
                      narrow_range=False):
  """Adds a layer that collects quantization ranges as EMAs of input ranges.

  MovingAvgQuantize creates variables called 'min' and 'max', representing the
  interval used for quantization and clamping.

  Args:
    inputs: a tensor containing values to be quantized.
    per_channel: (default False) a boolean specifying whether to use different
      quantization ranges per output channel.
    init_min: a float scalar, the initial value for variable min.
    init_max: a float scalar, the initial value for variable max.
    ema_decay: EMA decay parameter.
    updates_collection: (Optional) collections to collect the update ops for
      computation.
    vars_collection: (Optional) collection where to store variables for
      quantization interval ends.
    name_prefix: name_prefix for created nodes.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    is_training: Whether the op is applied to a training or eval graph.
    num_bits: Number of bits to use for quantization, must be between 2 and 8.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
  Returns:
    a tensor containing quantized values.
  """
  with variable_scope.variable_scope(
      None, default_name=name_prefix, values=[inputs], reuse=reuse):
    input_shape = inputs.get_shape()
    input_dim = len(input_shape)
    if per_channel:
      # Only support quantizing 1-, 2- and 4-dimensional tensors.
      assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
                                      ' scope: %s' % (input_shape, name_prefix))
      min_max_shape = [input_shape[-1]]
    else:
      min_max_shape = []

    min_var, max_var = _MinMaxVars(init_min, init_max, min_max_shape,
                                   vars_collection)
    if not is_training:
      # Eval graphs quantize with the stored (frozen) ranges; no updates.
      return _FakeQuantWithMinMaxVars(
          inputs,
          min_var,
          max_var,
          per_channel=per_channel,
          num_bits=num_bits,
          narrow_range=narrow_range)

    batch_min, batch_max = _BatchMinMax(inputs, per_channel, input_dim)

    # Blend this batch's observed ranges into the stored EMA ranges.
    assign_min = moving_averages.assign_moving_average(
        min_var, batch_min, ema_decay, name='AssignMinEma')
    ops.add_to_collection(updates_collection, assign_min.op)
    assign_max = moving_averages.assign_moving_average(
        max_var, batch_max, ema_decay, name='AssignMaxEma')
    ops.add_to_collection(updates_collection, assign_max.op)

    return _FakeQuantWithMinMaxVars(
        inputs,
        assign_min,
        assign_max,
        per_channel=per_channel,
        num_bits=num_bits,
        narrow_range=narrow_range)


def _FakeQuantWithMinMaxVars(inputs, min_var, max_var, per_channel, num_bits,
                             narrow_range):
  """Adds a fake quantization operation.

  Depending on value of per_channel, this operation may do global quantization
  or per channel quantization.  min_var and max_var should have corresponding
  shapes: [] (scalar) when per_channel == False and [d] when
  per_channel == True.

  Args:
    inputs: a tensor containing values to be quantized.
    min_var: a variable containing quantization range lower end(s).
    max_var: a variable containing quantization range upper end(s).
    per_channel: a boolean specifying whether to use per-channel quantization.
    num_bits: Number of bits to use for quantization, must be between 2 and 8.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
  Returns:
    a tensor containing quantized values.
  """

  if per_channel:
    assert len(min_var.get_shape()) == 1
    assert len(max_var.get_shape()) == 1
    return array_ops.fake_quant_with_min_max_vars_per_channel(
        inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range)
  else:
    assert min_var.get_shape() == []  # pylint: disable=g-explicit-bool-comparison
    assert max_var.get_shape() == []  # pylint: disable=g-explicit-bool-comparison
    return array_ops.fake_quant_with_min_max_vars(
        inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range)