# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic to update a TensorFlow model graph with quantization operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging

# Quantizable operation types that are supported by the quantization rewrite.
_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}

# Activations that are supported by the quantization rewrite.
_ACTIVATION_TYPES = {'Relu', 'Relu6', 'Identity'}

_RELU_TYPES = {'Relu', 'Relu6'}

_QUANTIZATION_OP = {'FakeQuantWithMinMaxVars'}
_VALID_SRC_OP = {'Add', 'Mul'}
_INTERMEDIATE_OP = {'Add', 'Mul'}
_PASS_THROUGH_OP = {'Reshape', 'Identity', 'BatchToSpaceND', 'SpaceToBatchND'}
_VALID_ACTIVATION_OP = {'Relu', 'Relu6'}


def Quantize(graph,
             is_training,
             weight_bits=8,
             activation_bits=8,
             symmetric=False,
             ema_decay=0.999,
             quant_delay=None,
             vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
             scope=None):
  """Updates graph with quantization operations.

  Currently we quantize the following tensors:
  * Conv/MatMul: Quantize the weights if the layer is matched.
  * Activation: Quantize the output if the layer is matched.
  * Bypass/Post-activation Bypass: Quantize both the input and the output
    if the layer is matched.

  Args:
    graph: Graph to modify.
    is_training: Whether quantizing the training graph or the eval graph.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    symmetric: (Optional) If true, use symmetric quantization limits instead of
      training the minimum and maximum of each quantization range separately.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    scope: The scope to be transformed. If it's not None, only ops in this
      scope will be transformed.
  Raises:
    ValueError: When quantization fails.
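
  Example:
    A minimal sketch of rewriting a training graph in place; `build_model`
    below is a hypothetical helper standing in for whatever Conv2D/MatMul
    layers the caller has already added to the graph:

      g = tf.Graph()
      with g.as_default():
        build_model()  # hypothetical: adds quantizable layers to `g`
      Quantize(g, is_training=True, weight_bits=8, activation_bits=8)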
  """
  if scope and not scope.endswith('/'):
    scope += '/'

  input_to_ops_map = input_to_ops.InputToOps(graph)
  quantized_ops = set()
  for layer_match in _FindLayersToQuantize(graph):
    # Quantize the weights.
    context = _GetContextFromOp(layer_match.layer_op)

    # If `scope` is given, only quantize it if the consumer of weights
    # (the layer op) is in the right scope.
    if layer_match.weight_tensor is not None:
      _InsertQuantOp(
          context,
          'weights_quant',
          layer_match.weight_tensor.op,
          input_to_ops_map.ConsumerOperations(layer_match.weight_tensor.op),
          is_training,
          moving_avg=False,
          ema_decay=ema_decay,
          quant_delay=quant_delay,
          narrow_range=True,
          vars_collection=vars_collection,
          bits=weight_bits,
          symmetric=symmetric,
          consumer_scope=scope)

    # Quantize the activations.
    if layer_match.activation_op is not None:
      consumer_ops = input_to_ops_map.ConsumerOperations(
          layer_match.activation_op)
      add_context = context
      if layer_match.bypass_op:
        pattern_match_result = re.search(r'^(.*)/([^/]+)', context)
        if pattern_match_result is not None:
          add_context = pattern_match_result.group(1)
        else:
          add_context = ''
      # If `scope` is given, only quantize it if the producer of weights
      # (usually it's the layer op) is in the right scope.
      _InsertQuantOp(
          add_context,
          'act_quant',
          layer_match.activation_op,
          consumer_ops,
          is_training,
          moving_avg=True,
          ema_decay=ema_decay,
          quant_delay=quant_delay,
          vars_collection=vars_collection,
          bits=activation_bits,
          symmetric=symmetric,
          init_min=0.0,
          producer_scope=scope)
      quantized_ops.add(layer_match.activation_op)

    # Quantize the inputs and output to the bypass (if it exists). The input to
    # the bypass is the bias add, and the output is the activation.
    if layer_match.bypass_op is not None:
      # If `scope` is given, only quantize it if both the producer and the
      # consumer are in the right scope.
      _InsertQuantOp(
          context,
          'conv_quant',
          layer_match.bias_add_op,
          input_to_ops_map.ConsumerOperations(layer_match.bias_add_op),
          is_training,
          moving_avg=True,
          ema_decay=ema_decay,
          quant_delay=quant_delay,
          vars_collection=vars_collection,
          bits=activation_bits,
          symmetric=symmetric,
          producer_scope=scope,
          consumer_scope=scope)
      quantized_ops.add(layer_match.bias_add_op)
      # Make sure the op following this isn't an activation, in which case we
      # shouldn't quantize it, since the activation will be fused into the
      # Add at inference time.
      consumers = input_to_ops_map.ConsumerOperations(layer_match.bypass_op)
      if any(consumer.type in _ACTIVATION_TYPES for consumer in consumers):
        logging.info('Skipping %s, because it is followed by an activation.',
                     layer_match.bypass_op.name)
      else:
        _InsertQuantOp(
            add_context,
            'add_quant',
            layer_match.bypass_op,
            input_to_ops_map.ConsumerOperations(layer_match.bypass_op),
            is_training,
            moving_avg=True,
            ema_decay=ema_decay,
            quant_delay=quant_delay,
            vars_collection=vars_collection,
            bits=activation_bits,
            symmetric=symmetric,
            producer_scope=scope,
            consumer_scope=scope)
        quantized_ops.add(layer_match.bypass_op)

    # Quantize bypass ops that occur after the activation.
    if layer_match.post_activation_bypass_op is not None:
      pattern_match_result = re.search(
          r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name)
      if pattern_match_result is not None:
        post_activation_bypass_context = pattern_match_result.group(1)
      else:
        post_activation_bypass_context = ''
      # If `scope` is given, only quantize it if the producer is in the right
      # scope.
      # Make sure the op following this isn't an activation, in which case we
      # shouldn't quantize it, since the activation will be fused into the
      # Add at inference time.
      consumers = input_to_ops_map.ConsumerOperations(
          layer_match.post_activation_bypass_op)
      if any(consumer.type in _RELU_TYPES for consumer in consumers):
        logging.info('Skipping %s, because it is followed by an activation.',
                     layer_match.post_activation_bypass_op.name)
      else:
        _InsertQuantOp(
            post_activation_bypass_context,
            'post_activation_bypass_quant',
            layer_match.post_activation_bypass_op,
            consumers,
            is_training,
            moving_avg=True,
            ema_decay=ema_decay,
            quant_delay=quant_delay,
            vars_collection=vars_collection,
            bits=activation_bits,
            symmetric=symmetric,
            producer_scope=scope)
        quantized_ops.add(layer_match.post_activation_bypass_op)

  _QuantizeActivationLayers(
      quantized_ops,
      graph,
      is_training,
      activation_bits,
      ema_decay,
      quant_delay,
      vars_collection,
      scope=scope)


def _QuantizeActivationLayers(quantized_ops,
                              graph,
                              is_training,
                              activation_bits=8,
                              ema_decay=0.999,
                              quant_delay=None,
                              vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                              scope=None):
  """Quantize intermediate activation tensors after addition and multiplication.

  Args:
    quantized_ops: Set of previously quantized activation ops.
    graph: Graph to modify.
    is_training: Whether quantizing training graph or eval graph.
    activation_bits: Number of bits to use for quantizing activations.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    scope: The scope to be transformed. If it's not None, only the ops which are
      in this scope will be transformed.

  Raises:
    ValueError: When quantization fails.
  """
  input_to_ops_map = input_to_ops.InputToOps(graph)
  for op in graph.get_operations():
    if _CheckIfQuantizableOp(op, quantized_ops):
      logging.info('Inserting fake quant op activation_%s_quant after %s',
                   op.type, op.name)
      consumers = input_to_ops_map.ConsumerOperations(op)
      _InsertQuantOp(
          op.name,
          'activation_' + op.type + '_quant',
          op,
          consumers,
          is_training,
          moving_avg=True,
          ema_decay=ema_decay,
          quant_delay=quant_delay,
          vars_collection=vars_collection,
          bits=activation_bits,
          producer_scope=scope)


def _CheckIfQuantizableOp(src_op, quantized_ops):
  """Check if the output of an op should be quantized.

  Args:
    src_op: Op to be checked.
    quantized_ops: Set of previously quantized activation ops.

  Returns:
    Boolean specifying if output should be quantized or not.
  """
  src_op_name = set([src_op.type])
  if src_op in quantized_ops:
    return False
  if not src_op_name.intersection(_VALID_SRC_OP):
    return False

  # If the src op is an Add or a Mul and its output is immediately followed by
  # an activation, skip quantizing the output.
  if len(src_op.outputs) == 1 and len(src_op.outputs[0].consumers()) == 1:
    op_consumers = src_op.outputs[0].consumers()
    if set([op_consumers[0].type]).intersection(_VALID_ACTIVATION_OP):
      logging.info('Skipping quant after %s', src_op.name)
      return False
  # The op is an Add or a Mul; check whether its inputs come from
  # already-quantized ops.
  input_ops = src_op.inputs

  for op in input_ops:
    curr_op = op.op
    curr_op_type = set([curr_op.type])
    while curr_op_type.intersection(_PASS_THROUGH_OP):
      # Walk back through pass-through ops.
      curr_op = curr_op.inputs[0].op
      curr_op_type = set([curr_op.type])
      # Now at a valid or quantizable op; need to check whether at least one
      # of the inputs to a valid op is connected to a quantizable op via
      # pass-through ops.

    if (curr_op_type.intersection(_QUANTIZATION_OP) or
        curr_op.name.find('delayed_quant/Merge') > 0):
      return True

    if curr_op_type.intersection(_INTERMEDIATE_OP):
      # Check if at least one input to the intermediate op is quantizable.
      for input_op in curr_op.inputs:
        if _CheckIfQuantizableOp(input_op.op, quantized_ops):
          return True
  return False


def _FindLayersToQuantize(graph):
  """Matches layers in graph to quantize.

  The following patterns get matched. Nodes surrounded by [] will be
  optionally matched:

          weight|folded_weight
                /
         conv|fc
            |
      [batch_to_space_nd]
            |
    [post_conv_correction]
            |
     [biasadd|folded_bias]
            |
         [bypass]
            |
        activation
            |
   [post_activation_bypass]

  Match replacements:
    If weight|folded_weight is found, FakeQuant is added afterwards.
    If bypass is found, FakeQuant is added before and after.
    If activation is found, FakeQuant is added afterwards.
    If post_activation_bypass is found, FakeQuant is added afterwards.

  Args:
    graph: Graph to perform match on.

  Returns:
    list of _LayerMatches.
  """
  input_pattern = graph_matcher.OpTypePattern('*')
  weight_var_pattern = graph_matcher.OpTypePattern('Variable|VariableV2')
  weight_partition_identity_pattern = graph_matcher.OpTypePattern(
      'Identity', inputs=[weight_var_pattern])
  weight_partition_concat_pattern = graph_matcher.OpTypePattern(
      'ConcatV2', inputs=[weight_partition_identity_pattern, '*', '*'])
  weight_identity_pattern = graph_matcher.OpTypePattern(
      'Identity',
      inputs=[
          graph_matcher.OneofPattern([
              weight_partition_identity_pattern,
              weight_partition_concat_pattern,
              weight_var_pattern,
          ])
      ])
  weight_resource_var_pattern = graph_matcher.OpTypePattern('ReadVariableOp')
  folded_weight_pattern = graph_matcher.OpTypePattern('Mul')

  # The weight inputs to the layer operation can come either from the Variable
  # or from the folded weight (Mul).
  layer_pattern = graph_matcher.OpTypePattern(
      '|'.join(_QUANTIZABLE_TYPES),
      inputs=[
          input_pattern,
          graph_matcher.OneofPattern([
              weight_identity_pattern, weight_resource_var_pattern,
              folded_weight_pattern
          ])
      ],
      ordered_inputs=False)

  # For atrous convolutions a BatchToSpaceND will occur after the depthwise
  # convolution.
  batch_to_space_pattern = graph_matcher.OpTypePattern(
      'BatchToSpaceND',
      inputs=[
          layer_pattern,
          graph_matcher.OpTypePattern('*'),
          graph_matcher.OpTypePattern('*')
      ])

  layer_output_pattern = graph_matcher.OneofPattern(
      [batch_to_space_pattern, layer_pattern])

  # For separable convolutions, we are looking for a conv, followed by a conv
  # with no activations between the two.
  sep_conv_pattern = graph_matcher.OpTypePattern(
      '|'.join(_QUANTIZABLE_TYPES),
      inputs=[
          graph_matcher.OneofPattern([layer_output_pattern]),
          graph_matcher.OpTypePattern('*')
      ],
      ordered_inputs=False)
  folded_bias_mul_pattern = graph_matcher.OpTypePattern(
      'Mul',
      inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern],
      ordered_inputs=False)
  post_layer_op_correction_pattern = graph_matcher.OpTypePattern(
      'Add',
      inputs=[folded_bias_mul_pattern,
              graph_matcher.OpTypePattern('*')],
      ordered_inputs=False)
  folded_bias_add_pattern = graph_matcher.OpTypePattern(
      'Add',
      inputs=[
          post_layer_op_correction_pattern,
          graph_matcher.OpTypePattern('*')
      ],
      ordered_inputs=False)

  # batch_norms with forced updates have an Identity operation at the end.
  # TODO(suharshs): Find a way to easily skip extra Identity operations. The
  # current issue is that doing so can often match patterns across many layers
  # incorrectly.
  batch_norm_identity = graph_matcher.OpTypePattern(
      'Identity', inputs=[folded_bias_add_pattern])

  bias_add_pattern = graph_matcher.OpTypePattern(
      'Add|BiasAdd', inputs=[layer_output_pattern, '*'], ordered_inputs=False)

  # The bias can come from the bias add or the folded bias add.
  bypass_pattern = graph_matcher.OpTypePattern(
      'Add',
      inputs=[
          graph_matcher.OneofPattern(
              [bias_add_pattern, folded_bias_add_pattern, batch_norm_identity]),
          '*'
      ],
      ordered_inputs=False)

  # The input to the activation can come from the bias add, the folded bias
  # add, or the bypasses.
  # TODO(suharshs): We should ideally skip Identity operations instead of
  # treating them as activations.
  activation_pattern = graph_matcher.OpTypePattern(
      '|'.join(_ACTIVATION_TYPES) + '|Identity',
      inputs=[
          graph_matcher.OneofPattern([
              bias_add_pattern,
              folded_bias_add_pattern,
              batch_norm_identity,
              bypass_pattern,
              layer_pattern,
          ])
      ])

  post_activation_bypass_pattern = graph_matcher.OpTypePattern(
      'Add', inputs=['*', activation_pattern], ordered_inputs=False)

  # The order of the following matching blocks is very important. Since matches
  # aren't guaranteed to be disjoint, we structure matches from largest to
  # smallest to guarantee that the largest match always wins. Additionally, we
  # ensure that we don't match layers multiple times.

  layer_matches = []
  # We use matched_layer_set to ensure that layers aren't matched multiple
  # times.
  matched_layer_set = set()

  # First, we match layers that have a post activation bypass. We do this first
  # to ensure we don't match only the first part of this layer, missing the
  # post activation bypass node.
  post_activation_bypass_layer_matcher = graph_matcher.GraphMatcher(
      post_activation_bypass_pattern)
  for match_result in post_activation_bypass_layer_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(folded_weight_pattern)
    activation_op = match_result.get_op(activation_pattern)
    bias_add_op = match_result.get_op(bias_add_pattern)
    if bias_add_op is None:
      bias_add_op = match_result.get_op(folded_bias_add_pattern)
    bypass_op = match_result.get_op(bypass_pattern)
    post_activation_bypass_op = match_result.get_op(
        post_activation_bypass_pattern)
    if layer_op not in matched_layer_set:
      matched_layer_set.add(layer_op)
      layer_matches.append(
          _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op,
                      post_activation_bypass_op, bias_add_op))

  # Now, we match the basic layer ending at an activation. We may get duplicate
  # matches from above, but we don't add them to layer_matches.
  layer_matcher = graph_matcher.GraphMatcher(activation_pattern)
  for match_result in layer_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(folded_weight_pattern)
    activation_op = match_result.get_op(activation_pattern)
    bias_add_op = match_result.get_op(bias_add_pattern)
    if bias_add_op is None:
      bias_add_op = match_result.get_op(folded_bias_add_pattern)
    bypass_op = match_result.get_op(bypass_pattern)
    if layer_op not in matched_layer_set:
      if not _IsSkipLayer(activation_op):
        matched_layer_set.add(layer_op)
        layer_matches.append(
            _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op, None,
                        bias_add_op))

  # Match the final layer, where there may not be an activation and instead
  # the output of the final BiasAdd must be quantized. We treat the BiasAdd
  # as the 'activation_op' in the _LayerMatch to ensure that its output is
  # quantized.
  final_layer_matcher = graph_matcher.GraphMatcher(
      graph_matcher.OneofPattern([bias_add_pattern, folded_bias_add_pattern]))
  for match_result in final_layer_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(folded_weight_pattern)
    activation_op = match_result.get_op(bias_add_pattern)
    if activation_op is None:
      activation_op = match_result.get_op(folded_bias_add_pattern)
    if layer_op not in matched_layer_set:
      matched_layer_set.add(layer_op)
      layer_matches.append(
          _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))

  # Look for separable convolutions here.
  sep_conv_matcher = graph_matcher.GraphMatcher(sep_conv_pattern)
  for match_result in sep_conv_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    activation_op = match_result.get_op(layer_pattern)
    if layer_op not in matched_layer_set:
      matched_layer_set.add(layer_op)
      layer_matches.append(
          _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))

  return layer_matches


def _IsSkipLayer(activation_op):
  """Skip quantizing conv->identity->batch norm layers.

  Args:
    activation_op: Activation op detected by the layer matching pattern.

  Returns:
    skip_layer: Boolean, true when conv->identity->batch norm is detected.
  """

  # Exclude quantization of conv->identity->BN.
  # After folding, this part corresponds to estimation of mean and variance
  # and should not be quantized.
  skip_layer = False
  if activation_op.type == 'Identity' and len(activation_op.outputs) == 1:
    if len(activation_op.outputs[0].consumers()) == 1:
      consumer = activation_op.outputs[0].consumers()[0]
      if consumer.type == 'FusedBatchNorm':
        skip_layer = True
        logging.info(
            'Skipping quantizing %s, because it is the output of a conv/fc '
            'followed by an identity, feeding a fused batch norm.',
            activation_op.name)
  return skip_layer


class _LayerMatch(object):
  """Contains all information related to a matched Layer."""

  def __init__(self, layer_op, weight_tensor, activation_op, bypass_op,
               post_activation_bypass_op, bias_add_op):
    self._layer_op = layer_op
    self._weight_tensor = weight_tensor
    self._activation_op = activation_op
    self._bypass_op = bypass_op
    self._post_activation_bypass_op = post_activation_bypass_op
    self._bias_add_op = bias_add_op

  @property
  def layer_op(self):
    return self._layer_op

  @property
  def weight_tensor(self):
    return self._weight_tensor

  @property
  def activation_op(self):
    return self._activation_op

  @property
  def bypass_op(self):
    return self._bypass_op

  @property
  def post_activation_bypass_op(self):
    return self._post_activation_bypass_op

  @property
  def bias_add_op(self):
    return self._bias_add_op


def _FollowedByFakeQuant(tensor):
  """Returns True if the tensor is followed by a FakeQuant."""
  fake_quant_ops = set([
      'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs',
      'FakeQuantWithMinMaxVarsPerChannel'
  ])
  pass_through_ops = set(['Reshape', 'Identity'])
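  # Search the tensor's consumers, walking through pass-through ops
  # (Reshape/Identity), and report True if any path reaches a FakeQuant op.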
  consumers = tensor.consumers()
  while consumers:
    c = consumers.pop()
    if c.type in fake_quant_ops:
      return True
    elif c.type in pass_through_ops:
      for output in c.outputs:
        consumers.extend(output.consumers())
  return False


def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   symmetric=False,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False,
                   producer_scope=None,
                   consumer_scope=None):
  """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    symmetric: (Optional) If true, use symmetric quantization limits instead of
      training the minimum and maximum of each quantization range separately.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    producer_scope: The restriction of producer scope. If not None, the new op
      will be inserted only when the producer is in this scope.
    consumer_scope: The restriction of consumer scope. If not None, the new op
      will be inserted only when all the consumers are in this scope.
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
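
  Example:
    A sketch of a typical call, mirroring how Quantize() above requests weight
    quantization; the op and scope names are placeholders, not part of this
    API:

      _InsertQuantOp('scope/layer', 'weights_quant', weight_op, consumer_ops,
                     is_training, moving_avg=False, narrow_range=True, bits=8)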
  """
  if producer_scope and not producer.name.startswith(producer_scope):
    logging.info(
        '_InsertQuantOp ignores context="%s" name="%s" '
        'because producer "%s" is not in scope "%s"',
        context, name, producer.name, producer_scope)
    return

  if consumer_scope:
    consumers_in_scope = []
    for consumer in consumers:
      if consumer.name.startswith(consumer_scope):
        consumers_in_scope.append(consumer)
      else:
        logging.info(
            '_InsertQuantOp context="%s" name="%s" ignores '
            'consumer "%s" because it is not in scope "%s"',
            context, name, consumer.name, consumer_scope)
        return
    consumers = consumers_in_scope

  name_prefix = _AddContextToName(context, name)
  # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
  # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
  # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
  # breaks things later.
  name_scope = ops.get_name_scope()
  if name_scope:
    name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')

  inputs = producer.outputs[0]
  # Prevent ops from being quantized multiple times. Bypass ops can sometimes
  # overlap between multiple matches, so we need to ensure that we don't
  # add duplicate FakeQuant operations.
  if _FollowedByFakeQuant(inputs):
    return

  if moving_avg:
    quant = (
        quant_ops.MovingAvgQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            ema_decay=ema_decay,
            is_training=is_training,
            num_bits=bits,
            symmetric=symmetric,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))
  else:
    quant = (
        quant_ops.LastValueQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            is_training=is_training,
            num_bits=bits,
            symmetric=symmetric,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))

  if quant_delay and quant_delay > 0:
    activate_quant = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=name_prefix + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=name_prefix + '/delayed_quant')

  if consumers:
    tensors_modified_count = common.RerouteTensor(
        quant, inputs, can_modify=consumers)
    # Some operations can have multiple output tensors going to the same
    # consumer. Since consumers is a set, we need to ensure that
    # tensors_modified_count is greater than or equal to the length of the set
    # of consumers.
    if tensors_modified_count < len(consumers):
      raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
          [consumer.name for consumer in consumers]))


def _GetContextFromOp(op):
  """Gets the root context name from the op name."""
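  # For example, an op named 'scope/subscope/Conv2D' (an illustrative name,
  # not one produced by this module) yields the context 'scope/subscope'.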
  context_re = re.search(r'^(.*)/([^/]+)', op.name)
  if context_re:
    return context_re.group(1)
  return ''


def _AddContextToName(context, name):
  """Adds the context to the name if it exists."""
  if not context:
    return name
  return context + '/' + name