# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parameterized unit tests for quantizing a TensorFlow graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest

batch_norm = layers.batch_norm
conv2d = layers.conv2d
fully_connected = layers.fully_connected
separable_conv2d = layers.separable_conv2d


class QuantizeTest(test_util.TensorFlowTestCase):
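  """Tests that the quantize rewrite inserts the expected FakeQuant ops.

  Each test builds a small conv2d, fully_connected, or separable_conv2d graph
  (with or without batch norm), runs quantize.Quantize on it, and then checks
  the inputs and consumers of the inserted FakeQuantWithMinMaxVars operations.
  """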

  def _RunWithoutBatchNormTestOverParameters(self, test_fn):
    # TODO(suharshs): Use parameterized test once OSS TF supports it.
    parameters_list = [
        # (activation, activation_op_name, with_bypass, delay)
        (nn_ops.relu6, 'Relu6', False, None),
        (nn_ops.relu, 'Relu', False, None),
        (array_ops.identity, 'Identity', False, None),
        (nn_ops.relu6, 'Relu6', False, 5000),
        (nn_ops.relu, 'Relu', False, 5000),
        (array_ops.identity, 'Identity', False, 5000),
        (nn_ops.relu6, 'Relu6', True, None),
        (nn_ops.relu, 'Relu', True, None),
        (array_ops.identity, 'Identity', True, None),
        (nn_ops.relu6, 'Relu6', True, 5000),
        (nn_ops.relu, 'Relu', True, 5000),
        (array_ops.identity, 'Identity', True, 5000),
    ]
    for params in parameters_list:
      test_fn(params[0], params[1], params[2], params[3])

  def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name,
                                           with_bypass, delay):
    """Tests quantization: inputs -> Conv2d no batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
    """
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      stride = 1 if with_bypass else 2
      out_depth = 3 if with_bypass else 32
      activation_fn = None if with_bypass else activation
      scope = 'test/test2' if with_bypass else 'test'
      node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME',
                    weights_initializer=self._WeightInit(0.09),
                    activation_fn=activation_fn, scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
        node = activation(node, name='test/' + activation_op_name)
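      # Give the graph a downstream consumer behind a control dependency so the
      # activation quant's output op ('control_dependency') can be asserted on.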
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

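      # Rewrite the graph in place, inserting the FakeQuantWithMinMaxVars ops
      # that the assertions below look up by name.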
      quantize.Quantize(graph, True, quant_delay=delay)
    quantization_node_name = 'FakeQuantWithMinMaxVars'
    weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                quantization_node_name)
    self.assertEqual(weights_quant.type, quantization_node_name)
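    # The op names suggest weight ranges use last-value min/max updates
    # (AssignMinLast/AssignMaxLast), while the activation quants below use
    # exponential moving averages (AssignMinEma/AssignMaxEma).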
    expected_inputs = [
        scope + '/weights_quant/AssignMinLast',
        scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
    ]
    self._AssertInputOpsAre(weights_quant, expected_inputs)
    if delay and delay > 0:
      output_op_name = scope + '/weights_quant/delayed_quant/Switch_1'
    else:
      output_op_name = scope + '/Conv2D'

    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
      conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
                                               quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)
      expected_inputs = [
          scope + '/conv_quant/AssignMinEma',
          scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
      ]
      self._AssertInputOpsAre(conv_quant, expected_inputs)
      output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
                        if delay else 'test/Add')
      self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

    act_quant = graph.get_operation_by_name('test/act_quant/' +
                                            quantization_node_name)
    self.assertEqual(act_quant.type, quantization_node_name)

    expected_inputs = [
        'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
        'test/' + activation_op_name
    ]
    self._AssertInputOpsAre(act_quant, expected_inputs)
    output_op_name = ('test/act_quant/delayed_quant/Switch_1'
                      if delay else 'control_dependency')
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])

  def testQuantize_Conv2dWithoutBatchNorm(self):
    self._RunWithoutBatchNormTestOverParameters(
        self._TestQuantize_Conv2dWithoutBatchNorm)

  def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name,
                                       with_bypass, delay):
    """Tests quantization: inputs -> FC no batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
    """
    graph = ops.Graph()
    with graph.as_default():
      batch_size, depth = 5, 256
      inputs = array_ops.zeros((batch_size, depth))
      out_depth = 256 if with_bypass else 128
      activation_fn = None if with_bypass else activation
      scope = 'test/test2' if with_bypass else 'test'
      node = fully_connected(inputs, out_depth,
                             weights_initializer=self._WeightInit(0.03),
                             activation_fn=activation_fn, scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
        node = activation(node, name='test/' + activation_op_name)
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      quantize.Quantize(graph, True, quant_delay=delay)

    quantization_node_name = 'FakeQuantWithMinMaxVars'
    weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                quantization_node_name)
    self.assertEqual(weights_quant.type, quantization_node_name)
    expected_inputs = [
        scope + '/weights_quant/AssignMinLast',
        scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
    ]
    self._AssertInputOpsAre(weights_quant, expected_inputs)
    if delay and delay > 0:
      output_op_name = scope + '/weights_quant/delayed_quant/Switch_1'
    else:
      output_op_name = scope + '/MatMul'
    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
      conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
                                               quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)
      expected_inputs = [
          scope + '/conv_quant/AssignMinEma',
          scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
      ]
      self._AssertInputOpsAre(conv_quant, expected_inputs)
      output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
                        if delay else 'test/Add')
      self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

    act_quant = graph.get_operation_by_name('test/act_quant/' +
                                            quantization_node_name)
    self.assertEqual(act_quant.type, quantization_node_name)
    expected_inputs = [
        'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
        'test/' + activation_op_name
    ]
    self._AssertInputOpsAre(act_quant, expected_inputs)
    output_op_name = ('test/act_quant/delayed_quant/Switch_1'
                      if delay else 'control_dependency')
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])

  def testQuantize_FCWithoutBatchNorm(self):
    self._RunWithoutBatchNormTestOverParameters(
        self._TestQuantize_FCWithoutBatchNorm)

  def _TestQuantize_DepthwiseConv2dWithoutBatchNorm(
      self, activation, activation_op_name, with_bypass, delay):
    """Tests quantization: inputs -> DWConv2d no batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
    """
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else activation
      scope = 'test/test2' if with_bypass else 'test'
      node = separable_conv2d(inputs, None, [5, 5], stride=stride,
                              depth_multiplier=1.0, padding='SAME',
                              weights_initializer=self._WeightInit(0.09),
                              activation_fn=activation_fn, scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
        node = activation(node, name='test/' + activation_op_name)
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      quantize.Quantize(graph, True, quant_delay=delay)

    quantization_node_name = 'FakeQuantWithMinMaxVars'
    weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                quantization_node_name)
    self.assertEqual(weights_quant.type, quantization_node_name)
    expected_inputs = [
        scope + '/weights_quant/AssignMinLast',
        scope + '/weights_quant/AssignMaxLast',
        scope + '/depthwise_weights/read'
    ]
    self._AssertInputOpsAre(weights_quant, expected_inputs)
    if delay and delay > 0:
      output_op_name = scope + '/weights_quant/delayed_quant/Switch_1'
    else:
      output_op_name = scope + '/depthwise'
    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
      conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
                                               quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)
      expected_inputs = [
          scope + '/conv_quant/AssignMinEma',
          scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
      ]
      self._AssertInputOpsAre(conv_quant, expected_inputs)
      output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
                        if delay else 'test/Add')
      self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

    act_quant = graph.get_operation_by_name('test/act_quant/' +
                                            quantization_node_name)
    self.assertEqual(act_quant.type, quantization_node_name)
    expected_inputs = [
        'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
        'test/' + activation_op_name
    ]
    self._AssertInputOpsAre(act_quant, expected_inputs)
    output_op_name = ('test/act_quant/delayed_quant/Switch_1'
                      if delay else 'control_dependency')
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])

  def testQuantize_DepthwiseConv2dWithoutBatchNorm(self):
    self._RunWithoutBatchNormTestOverParameters(
        self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)

  def _RunBatchNormTestOverParameters(self, test_fn):
    # TODO(suharshs): Use parameterized test once OSS TF supports it.
    parameters_list = [
        # (activation, activation_op_name, with_bypass, delay, fused_batch_norm)
        (nn_ops.relu6, 'Relu6', False, None, False),
        (nn_ops.relu, 'Relu', False, None, False),
        (array_ops.identity, 'Identity', False, None, False),
        (nn_ops.relu6, 'Relu6', False, 5000, False),
        (nn_ops.relu, 'Relu', False, 5000, False),
        (array_ops.identity, 'Identity', False, 5000, False),
        (nn_ops.relu6, 'Relu6', True, None, False),
        (nn_ops.relu, 'Relu', True, None, False),
        (array_ops.identity, 'Identity', True, None, False),
        (nn_ops.relu6, 'Relu6', True, 5000, False),
        (nn_ops.relu, 'Relu', True, 5000, False),
        (array_ops.identity, 'Identity', True, 5000, False),
        (nn_ops.relu6, 'Relu6', False, None, True),
        (nn_ops.relu, 'Relu', False, None, True),
        (array_ops.identity, 'Identity', False, None, True),
        (nn_ops.relu6, 'Relu6', False, 5000, True),
        (nn_ops.relu, 'Relu', False, 5000, True),
        (array_ops.identity, 'Identity', False, 5000, True),
        (nn_ops.relu6, 'Relu6', True, None, True),
        (nn_ops.relu, 'Relu', True, None, True),
        (array_ops.identity, 'Identity', True, None, True),
        (nn_ops.relu6, 'Relu6', True, 5000, True),
        (nn_ops.relu, 'Relu', True, 5000, True),
        (array_ops.identity, 'Identity', True, 5000, True)
    ]
    for params in parameters_list:
      test_fn(params[0], params[1], params[2], params[3], params[4])

  def testQuantize_Conv2dWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm)

  def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
                                        with_bypass, delay, fused_batch_norm):
    """Tests quantization: inputs -> Conv2d with batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
    """
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      stride = 1 if with_bypass else 2
      out_depth = 3 if with_bypass else 32
      scope = 'test/test2' if with_bypass else 'test'
      node = conv2d(
          inputs,
          out_depth, [5, 5],
          stride=stride,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=scope)

      # Manually add a bypass (optionally) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')

      node = activation(node, name='test/' + activation_op_name)

      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

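      # Fold the batch norm into the preceding conv before quantizing; the
      # assertions below reference the folded op names (mul_fold, Conv2D_Fold,
      # add_fold).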
      fold_batch_norms.FoldBatchNorms(graph, is_training=True)

      quantize.Quantize(graph, True, quant_delay=delay)

    quantization_node_name = 'FakeQuantWithMinMaxVars'
    weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                quantization_node_name)
    self.assertEqual(weights_quant.type, quantization_node_name)
    expected_inputs = [
        scope + '/weights_quant/' + 'AssignMinLast',
        scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold'
    ]
    self._AssertInputOpsAre(weights_quant, expected_inputs)
    output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
                              if delay else '/Conv2D_Fold')
    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
      conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
                                               quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)
      expected_inputs = [
          scope + '/conv_quant/AssignMinEma',
          scope + '/conv_quant/AssignMaxEma', scope + '/add_fold'
      ]
      self._AssertInputOpsAre(conv_quant, expected_inputs)
      output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
                        if delay else 'test/Add')
      self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

    act_quant = graph.get_operation_by_name('test/act_quant/' +
                                            quantization_node_name)
    self.assertEqual(act_quant.type, quantization_node_name)
    expected_inputs = [
        'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
        'test/' + activation_op_name
    ]
    self._AssertInputOpsAre(act_quant, expected_inputs)
    output_op_name = ('test/act_quant/delayed_quant/Switch_1'
                      if delay else 'control_dependency')
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])

  def testQuantize_FCWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm)

  def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
                                    with_bypass, delay, fused_batch_norm):
    """Tests quantization: inputs -> FC with batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
    """
    graph = ops.Graph()
    with graph.as_default():
      batch_size, depth = 5, 256
      inputs = array_ops.zeros((batch_size, depth))
      out_depth = 256 if with_bypass else 128
      scope = 'test/test2' if with_bypass else 'test'
      node = fully_connected(
          inputs,
          out_depth,
          weights_initializer=self._WeightInit(0.03),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=scope)

      # Manually add a bypass (optionally) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')

      node = activation(node, name='test/' + activation_op_name)

      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      fold_batch_norms.FoldBatchNorms(graph, is_training=True)

      quantize.Quantize(graph, True, quant_delay=delay)

    quantization_node_name = 'FakeQuantWithMinMaxVars'
    weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                quantization_node_name)
    self.assertEqual(weights_quant.type, quantization_node_name)
    expected_inputs = [
        scope + '/weights_quant/' + 'AssignMinLast',
        scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold'
    ]
    self._AssertInputOpsAre(weights_quant, expected_inputs)
    output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
                              if delay else '/MatMul_Fold')
    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
      conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
                                               quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)
      expected_inputs = [
          scope + '/conv_quant/AssignMinEma',
          scope + '/conv_quant/AssignMaxEma', scope + '/add_fold'
      ]
      self._AssertInputOpsAre(conv_quant, expected_inputs)
      output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
                        if delay else 'test/Add')
      self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

    act_quant = graph.get_operation_by_name('test/act_quant/' +
                                            quantization_node_name)
    self.assertEqual(act_quant.type, quantization_node_name)
    expected_inputs = [
        'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
        'test/' + activation_op_name
    ]
    self._AssertInputOpsAre(act_quant, expected_inputs)
    output_op_name = ('test/act_quant/delayed_quant/Switch_1'
                      if delay else 'control_dependency')
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])

  def testQuantize_DepthwiseConv2dWithBatchNorm(self):
    self._RunBatchNormTestOverParameters(
        self._TestQuantize_DepthwiseConv2dWithBatchNorm)

  def _TestQuantize_DepthwiseConv2dWithBatchNorm(
      self, activation, activation_op_name, with_bypass, delay,
      fused_batch_norm):
    """Tests quantization: inputs -> DWConv2d with batch norm -> Activation.

    Args:
      activation: Callable that returns an Operation, a factory method for the
        Activation.
      activation_op_name: String, name of the Activation operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
    """
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      stride = 1 if with_bypass else 2
      scope = 'test/test2' if with_bypass else 'test'
      node = separable_conv2d(
          inputs,
          None, [5, 5],
          stride=stride,
          depth_multiplier=1.0,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=scope)

      # Manually add a bypass (optionally) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')

      node = activation(node, name='test/' + activation_op_name)

      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      fold_batch_norms.FoldBatchNorms(graph, is_training=True)

      quantize.Quantize(graph, True, quant_delay=delay)
    quantization_node_name = 'FakeQuantWithMinMaxVars'
    weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
                                                quantization_node_name)
    self.assertEqual(weights_quant.type, quantization_node_name)
    expected_inputs = [
        scope + '/weights_quant/' + 'AssignMinLast',
        scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold'
    ]
    self._AssertInputOpsAre(weights_quant, expected_inputs)
    output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
                              if delay else '/depthwise_Fold')
    self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])

    if with_bypass:
      conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
                                               quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)
      expected_inputs = [
          scope + '/conv_quant/AssignMinEma',
          scope + '/conv_quant/AssignMaxEma', scope + '/add_fold'
      ]
      self._AssertInputOpsAre(conv_quant, expected_inputs)
      output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
                        if delay else 'test/Add')
      self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])

    act_quant = graph.get_operation_by_name('test/act_quant/' +
                                            quantization_node_name)
    self.assertEqual(act_quant.type, quantization_node_name)
    expected_inputs = [
        'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
        'test/' + activation_op_name
    ]
    self._AssertInputOpsAre(act_quant, expected_inputs)
    output_op_name = ('test/act_quant/delayed_quant/Switch_1'
                      if delay else 'control_dependency')
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])

  def _BatchNormParams(self, fused=False):
    return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused}

  def _WeightInit(self, stddev):
    """Returns truncated normal variable initializer.

    Function is defined purely to shorten the name so that it stops wrapping.

    Args:
      stddev: Standard deviation of normal variable.

    Returns:
      An initializer that initializes variables with a truncated normal
      distribution.
    """
    return init_ops.truncated_normal_initializer(stddev=stddev)

  def _AssertInputOpsAre(self, op, in_op_names):
    """Asserts that all inputs to op come from in_op_names (disregarding order).

    Args:
      op: Operation to check inputs for.
      in_op_names: List of strings, operations where all op's inputs should
        come from.
    """
    expected_inputs = [in_op_name + ':0' for in_op_name in in_op_names]
    self.assertItemsEqual([t.name for t in op.inputs], expected_inputs)

  def _AssertOutputGoesToOps(self, op, graph, out_op_names):
    """Asserts that outputs from op go to out_op_names (and perhaps others).

    Args:
      op: Operation to check outputs for.
      graph: Graph where output operations are located.
      out_op_names: List of strings, operations where op's outputs should go.
    """
    for out_op_name in out_op_names:
      out_op = graph.get_operation_by_name(out_op_name)
      self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs])


if __name__ == '__main__':
  googletest.main()