# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python support for quantization operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.framework.python.ops import model_variable
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import moving_averages


@add_arg_scope
def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
  """Adds a fake quantize layer with fixed quantization interval.

  Args:
    inputs: a tensor containing values to be quantized.
    init_min: the lower end of the quantization interval.
    init_max: the upper end of the quantization interval.
    scope: Optional scope for name_scope.
  Returns:
    a tensor containing quantized values.
  """
  with ops.name_scope(scope, 'FixedQuantize', values=[inputs]):
    return array_ops.fake_quant_with_min_max_args(
        inputs, min=init_min, max=init_max)
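

# Illustrative usage sketch (for exposition only): FixedQuantize suits
# activations whose range is known ahead of time, e.g. the output of a ReLU6.
# The tensor below is a hypothetical stand-in for such an activation.
def _fixed_quantize_example():
  acts = array_ops.ones([1, 8, 8, 32])  # Stand-in for a ReLU6 activation.
  # Values are clamped to [-6, 6] and fake-quantized to 8 bits (the default).
  return FixedQuantize(acts, init_min=-6.0, init_max=6.0, scope='act_quant')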


@add_arg_scope
def LastValueQuantize(inputs,
                      per_channel=False,
                      init_min=-6.0,
                      init_max=6.0,
                      updates_collection=ops.GraphKeys.UPDATE_OPS,
                      vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                      name_prefix='LastValueQuant',
                      reuse=None,
                      is_training=True,
                      num_bits=8,
                      narrow_range=False):
  """Adds a layer that collects quantization ranges as last input ranges.

  LastValueQuantize creates variables called 'min' and 'max', representing the
  interval used for quantization and clamping.

  Args:
    inputs: a tensor containing values to be quantized.
    per_channel: (Optional) a boolean specifying whether to use different
      quantization ranges per output channel.
    init_min: a float scalar, the initial value for variable min.
    init_max: a float scalar, the initial value for variable max.
    updates_collection: (Optional) collection to collect the update ops for
      computation.
    vars_collection: (Optional) collection where to store variables for
      quantization interval ends.
    name_prefix: name_prefix for created nodes.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer, scope must be given.
    is_training: Whether the op is applied to a training or eval graph.
    num_bits: Number of bits to use for quantization, must be between 2 and 8.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
  Returns:
    a tensor containing quantized values.
  """
  with variable_scope.variable_scope(
      None, default_name=name_prefix, values=[inputs], reuse=reuse):
    input_shape = inputs.get_shape()
    input_dim = len(input_shape)
    if per_channel:
      # Only support quantizing 1-, 2- and 4-dimensional tensors.
      assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
                                      'scope: %s' % (input_shape, name_prefix))
      min_max_shape = [input_shape[-1]]
    else:
      min_max_shape = []

    min_var = model_variable(
        'min',
        shape=min_max_shape,
        initializer=init_ops.constant_initializer(init_min),
        collections=[vars_collection],
        trainable=False)
    max_var = model_variable(
        'max',
        shape=min_max_shape,
        initializer=init_ops.constant_initializer(init_max),
        collections=[vars_collection],
        trainable=False)
    if not is_training:
      return _FakeQuantWithMinMaxVars(
          inputs,
          min_var,
          max_var,
          per_channel=per_channel,
          num_bits=num_bits,
          narrow_range=narrow_range)

    if per_channel:
      # Reduce over all dimensions except the last (channel) dimension.
      if input_dim == 2:
        reduce_dims = [0]
      elif input_dim == 4:
        reduce_dims = [0, 1, 2]

    if per_channel:
      if input_dim >= 2:
        batch_min = math_ops.reduce_min(
            inputs, reduction_indices=reduce_dims, name='BatchMin')
      else:
        batch_min = inputs
    else:
      batch_min = math_ops.reduce_min(inputs, name='BatchMin')
    # TFLite requires that 0.0 is always in the [min; max] range.
    batch_min = math_ops.minimum(batch_min, 0.0)
    assign_min = state_ops.assign(min_var, batch_min, name='AssignMinLast')
    ops.add_to_collection(updates_collection, assign_min.op)

    if per_channel:
      if input_dim >= 2:
        batch_max = math_ops.reduce_max(
            inputs, reduction_indices=reduce_dims, name='BatchMax')
      else:
        batch_max = inputs
    else:
      batch_max = math_ops.reduce_max(inputs, name='BatchMax')
    # TFLite requires that 0.0 is always in the [min; max] range.
    batch_max = math_ops.maximum(batch_max, 0.0)
    assign_max = state_ops.assign(max_var, batch_max, name='AssignMaxLast')
    ops.add_to_collection(updates_collection, assign_max.op)

    return _FakeQuantWithMinMaxVars(
        inputs,
        assign_min,
        assign_max,
        per_channel=per_channel,
        num_bits=num_bits,
        narrow_range=narrow_range)
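

# Illustrative usage sketch (for exposition only): LastValueQuantize tracks the
# min/max of the most recent batch, which is the usual choice for weights. For
# a conv kernel of shape [height, width, in, out], per_channel=True keeps one
# range per output channel. All names below are hypothetical.
def _last_value_quantize_example():
  weights = model_variable(
      'example_weights',
      shape=[3, 3, 16, 32],
      initializer=init_ops.truncated_normal_initializer())
  # The assign ops that refresh 'min'/'max' are added to UPDATE_OPS and must be
  # run alongside the training step.
  return LastValueQuantize(
      weights,
      per_channel=True,
      is_training=True,
      num_bits=8,
      narrow_range=True)  # Weights are often quantized with the narrow range.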


@add_arg_scope
def MovingAvgQuantize(inputs,
                      per_channel=False,
                      init_min=-6.0,
                      init_max=6.0,
                      ema_decay=0.999,
                      updates_collection=ops.GraphKeys.UPDATE_OPS,
                      vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                      name_prefix='MovingAvgQuantize',
                      reuse=None,
                      is_training=True,
                      num_bits=8,
                      narrow_range=False):
  """Adds a layer that collects quantization ranges as EMAs of input ranges.

  MovingAvgQuantize creates variables called 'min' and 'max', representing the
  interval used for quantization and clamping.

  Args:
    inputs: a tensor containing values to be quantized.
    per_channel: (default False) a boolean specifying whether to use different
      quantization ranges per output channel.
    init_min: a float scalar, the initial value for variable min.
    init_max: a float scalar, the initial value for variable max.
    ema_decay: EMA decay parameter.
    updates_collection: (Optional) collection to collect the update ops for
      computation.
    vars_collection: (Optional) collection where to store variables for
      quantization interval ends.
    name_prefix: name_prefix for created nodes.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer, scope must be given.
    is_training: Whether the op is applied to a training or eval graph.
    num_bits: Number of bits to use for quantization, must be between 2 and 8.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
  Returns:
    a tensor containing quantized values.
  """
  with variable_scope.variable_scope(
      None, default_name=name_prefix, values=[inputs], reuse=reuse):
    input_shape = inputs.get_shape()
    input_dim = len(input_shape)
    if per_channel:
      # Only support quantizing 1-, 2- and 4-dimensional tensors.
      assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
                                      'scope: %s' % (input_shape, name_prefix))
      min_max_shape = [input_shape[-1]]
    else:
      min_max_shape = []

    min_var = model_variable(
        'min',
        shape=min_max_shape,
        initializer=init_ops.constant_initializer(init_min),
        collections=[vars_collection],
        trainable=False)
    max_var = model_variable(
        'max',
        shape=min_max_shape,
        initializer=init_ops.constant_initializer(init_max),
        collections=[vars_collection],
        trainable=False)
    if not is_training:
      return _FakeQuantWithMinMaxVars(
          inputs,
          min_var,
          max_var,
          per_channel=per_channel,
          num_bits=num_bits,
          narrow_range=narrow_range)
    if per_channel:
      # Reduce over all dimensions except the last (channel) dimension.
      if input_dim == 2:
        reduce_dims = [0]
      elif input_dim == 4:
        reduce_dims = [0, 1, 2]

    if per_channel:
      if input_dim >= 2:
        batch_min = math_ops.reduce_min(
            inputs, reduction_indices=reduce_dims, name='BatchMin')
      else:
        batch_min = inputs
    else:
      batch_min = math_ops.reduce_min(inputs, name='BatchMin')
    # TFLite requires that 0.0 is always in the [min; max] range.
    batch_min = math_ops.minimum(batch_min, 0.0)
    assign_min = moving_averages.assign_moving_average(
        min_var, batch_min, ema_decay, name='AssignMinEma')
    ops.add_to_collection(updates_collection, assign_min.op)

    if per_channel:
      if input_dim >= 2:
        batch_max = math_ops.reduce_max(
            inputs, reduction_indices=reduce_dims, name='BatchMax')
      else:
        batch_max = inputs
    else:
      batch_max = math_ops.reduce_max(inputs, name='BatchMax')
    # TFLite requires that 0.0 is always in the [min; max] range.
    batch_max = math_ops.maximum(batch_max, 0.0)
    assign_max = moving_averages.assign_moving_average(
        max_var, batch_max, ema_decay, name='AssignMaxEma')
    ops.add_to_collection(updates_collection, assign_max.op)

    return _FakeQuantWithMinMaxVars(
        inputs,
        assign_min,
        assign_max,
        per_channel=per_channel,
        num_bits=num_bits,
        narrow_range=narrow_range)
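

# Illustrative usage sketch (for exposition only): MovingAvgQuantize is the
# usual choice for activations, whose ranges vary from batch to batch; the EMA
# smooths the observed min/max. The activation tensor below is a hypothetical
# stand-in.
def _moving_avg_quantize_example():
  acts = array_ops.ones([32, 112, 112, 64])  # Stand-in for a layer output.
  # The EMA update ops land in UPDATE_OPS and must be executed with the
  # training step.
  return MovingAvgQuantize(
      acts,
      per_channel=False,
      ema_decay=0.999,
      is_training=True,
      num_bits=8)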


def _FakeQuantWithMinMaxVars(inputs, min_var, max_var, per_channel, num_bits,
                             narrow_range):
  """Adds a fake quantization operation.

  Depending on the value of per_channel, this operation may do global
  quantization or per-channel quantization.  min_var and max_var should have
  corresponding shapes: [] when per_channel == False and [d] when
  per_channel == True.

  Args:
    inputs: a tensor containing values to be quantized.
    min_var: a variable containing quantization range lower end(s).
    max_var: a variable containing quantization range upper end(s).
    per_channel: a boolean specifying whether to use per-channel quantization.
    num_bits: Number of bits to use for quantization, must be between 2 and 8.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
  Returns:
    a tensor containing quantized values.
  """

  if per_channel:
    assert len(min_var.get_shape()) == 1
    assert len(max_var.get_shape()) == 1
    return array_ops.fake_quant_with_min_max_vars_per_channel(
        inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range)
  else:
    assert min_var.get_shape() == []  # pylint: disable=g-explicit-bool-comparison
    assert max_var.get_shape() == []  # pylint: disable=g-explicit-bool-comparison
    return array_ops.fake_quant_with_min_max_vars(
        inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range)
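

# Illustrative sketch (for exposition only) of the shape contract used by
# _FakeQuantWithMinMaxVars: scalar range variables pair with per_channel=False,
# while rank-1 range variables (one entry per output channel) pair with
# per_channel=True. All names below are hypothetical.
def _fake_quant_shapes_example():
  inputs = array_ops.ones([4, 10, 10, 16])
  min_var = model_variable(
      'example_min', shape=[],
      initializer=init_ops.constant_initializer(-6.0), trainable=False)
  max_var = model_variable(
      'example_max', shape=[],
      initializer=init_ops.constant_initializer(6.0), trainable=False)
  # Global (per-tensor) quantization: scalar min/max variables.
  return _FakeQuantWithMinMaxVars(
      inputs, min_var, max_var, per_channel=False, num_bits=8,
      narrow_range=False)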