1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Parameterized unit tests for quantizing a Tensorflow graph.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 from tensorflow.contrib.layers.python.layers import layers 22 from tensorflow.contrib.quantize.python import fold_batch_norms 23 from tensorflow.contrib.quantize.python import quantize 24 from tensorflow.python.framework import ops 25 from tensorflow.python.framework import test_util 26 from tensorflow.python.ops import array_ops 27 from tensorflow.python.ops import control_flow_ops 28 from tensorflow.python.ops import init_ops 29 from tensorflow.python.ops import math_ops 30 from tensorflow.python.ops import nn_ops 31 from tensorflow.python.platform import googletest 32 33 batch_norm = layers.batch_norm 34 conv2d = layers.conv2d 35 fully_connected = layers.fully_connected 36 separable_conv2d = layers.separable_conv2d 37 38 39 class QuantizeTest(test_util.TensorFlowTestCase): 40 41 def _RunWithoutBatchNormTestOverParameters(self, test_fn): 42 # TODO(suharshs): Use parameterized test once OSS TF supports it. 43 parameters_list = [ 44 # (activation, activation_op_name, with_bypass, delay) 45 (nn_ops.relu6, 'Relu6', False, None), 46 (nn_ops.relu, 'Relu', False, None), 47 (array_ops.identity, 'Identity', False, None), 48 (nn_ops.relu6, 'Relu6', False, 5000), 49 (nn_ops.relu, 'Relu', False, 5000), 50 (array_ops.identity, 'Identity', False, 5000), 51 (nn_ops.relu6, 'Relu6', True, None), 52 (nn_ops.relu, 'Relu', True, None), 53 (array_ops.identity, 'Identity', True, None), 54 (nn_ops.relu6, 'Relu6', True, 5000), 55 (nn_ops.relu, 'Relu', True, 5000), 56 (array_ops.identity, 'Identity', True, 5000), 57 ] 58 for params in parameters_list: 59 test_fn(params[0], params[1], params[2], params[3]) 60 61 def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name, 62 with_bypass, delay): 63 """Tests quantization: inputs -> Conv2d no batch norm -> Activation. 64 65 Args: 66 activation: Callable that returns an Operation, a factory method for the 67 Activation. 68 activation_op_name: String, name of the Activation operation. 69 with_bypass: Bool, when true there is an extra connection added from 70 inputs to just before Activation. 71 delay: Int (optional), delay in number of steps until quantization starts. 72 """ 73 graph = ops.Graph() 74 with graph.as_default(): 75 batch_size, height, width, depth = 5, 128, 128, 3 76 inputs = array_ops.zeros((batch_size, height, width, depth)) 77 stride = 1 if with_bypass else 2 78 out_depth = 3 if with_bypass else 32 79 activation_fn = None if with_bypass else activation 80 scope = 'test/test2' if with_bypass else 'test' 81 node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', 82 weights_initializer=self._WeightInit(0.09), 83 activation_fn=activation_fn, scope=scope) 84 if with_bypass: 85 node = math_ops.add(inputs, node, name='test/Add') 86 node = activation(node, name='test/' + activation_op_name) 87 update_barrier = control_flow_ops.no_op(name='update_barrier') 88 with ops.control_dependencies([update_barrier]): 89 array_ops.identity(node, name='control_dependency') 90 91 quantize.Quantize(graph, True, quant_delay=delay) 92 quantization_node_name = 'FakeQuantWithMinMaxVars' 93 weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + 94 quantization_node_name) 95 self.assertEqual(weights_quant.type, quantization_node_name) 96 expected_inputs = [ 97 scope + '/weights_quant/AssignMinLast', 98 scope + '/weights_quant/AssignMaxLast', scope + '/weights/read' 99 ] 100 self._AssertInputOpsAre(weights_quant, expected_inputs) 101 if delay and delay > 0: 102 output_op_name = scope + '/weights_quant/delayed_quant/Switch_1' 103 else: 104 output_op_name = scope + '/Conv2D' 105 106 self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) 107 108 if with_bypass: 109 conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + 110 quantization_node_name) 111 self.assertEqual(conv_quant.type, quantization_node_name) 112 expected_inputs = [ 113 scope + '/conv_quant/AssignMinEma', 114 scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd' 115 ] 116 self._AssertInputOpsAre(conv_quant, expected_inputs) 117 output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' 118 if delay else 'test/Add') 119 self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) 120 121 act_quant = graph.get_operation_by_name('test/act_quant/' + 122 quantization_node_name) 123 self.assertEqual(act_quant.type, quantization_node_name) 124 125 expected_inputs = [ 126 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', 127 'test/' + activation_op_name 128 ] 129 self._AssertInputOpsAre(act_quant, expected_inputs) 130 output_op_name = ('test/act_quant/delayed_quant/Switch_1' 131 if delay else 'control_dependency') 132 self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) 133 134 def testQuantize_Conv2dWithoutBatchNorm(self): 135 self._RunWithoutBatchNormTestOverParameters( 136 self._TestQuantize_Conv2dWithoutBatchNorm) 137 138 def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name, 139 with_bypass, delay): 140 """Tests quantization: inputs -> FC no batch norm -> Activation. 141 142 Args: 143 activation: Callable that returns an Operation, a factory method for the 144 Activation. 145 activation_op_name: String, name of the Activation operation. 146 with_bypass: Bool, when true there is an extra connection added from 147 inputs to just before Activation. 148 delay: Int (optional), delay in number of steps until quantization starts. 149 """ 150 graph = ops.Graph() 151 with graph.as_default(): 152 batch_size, depth = 5, 256 153 inputs = array_ops.zeros((batch_size, depth)) 154 out_depth = 256 if with_bypass else 128 155 activation_fn = None if with_bypass else activation 156 scope = 'test/test2' if with_bypass else 'test' 157 node = fully_connected(inputs, out_depth, 158 weights_initializer=self._WeightInit(0.03), 159 activation_fn=activation_fn, scope=scope) 160 if with_bypass: 161 node = math_ops.add(inputs, node, name='test/Add') 162 node = activation(node, name='test/' + activation_op_name) 163 update_barrier = control_flow_ops.no_op(name='update_barrier') 164 with ops.control_dependencies([update_barrier]): 165 array_ops.identity(node, name='control_dependency') 166 167 quantize.Quantize(graph, True, quant_delay=delay) 168 169 quantization_node_name = 'FakeQuantWithMinMaxVars' 170 weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + 171 quantization_node_name) 172 self.assertEqual(weights_quant.type, quantization_node_name) 173 expected_inputs = [ 174 scope + '/weights_quant/AssignMinLast', 175 scope + '/weights_quant/AssignMaxLast', scope + '/weights/read' 176 ] 177 self._AssertInputOpsAre(weights_quant, expected_inputs) 178 if delay and delay > 0: 179 output_op_name = scope + '/weights_quant/delayed_quant/Switch_1' 180 else: 181 output_op_name = scope + '/MatMul' 182 self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) 183 184 if with_bypass: 185 conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + 186 quantization_node_name) 187 self.assertEqual(conv_quant.type, quantization_node_name) 188 expected_inputs = [ 189 scope + '/conv_quant/AssignMinEma', 190 scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd' 191 ] 192 self._AssertInputOpsAre(conv_quant, expected_inputs) 193 output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' 194 if delay else 'test/Add') 195 self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) 196 197 act_quant = graph.get_operation_by_name('test/act_quant/' + 198 quantization_node_name) 199 self.assertEqual(act_quant.type, quantization_node_name) 200 expected_inputs = [ 201 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', 202 'test/' + activation_op_name 203 ] 204 self._AssertInputOpsAre(act_quant, expected_inputs) 205 output_op_name = ('test/act_quant/delayed_quant/Switch_1' 206 if delay else 'control_dependency') 207 self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) 208 209 def testQuantize_FCWithoutBatchNorm(self): 210 self._RunWithoutBatchNormTestOverParameters( 211 self._TestQuantize_FCWithoutBatchNorm) 212 213 def _TestQuantize_DepthwiseConv2dWithoutBatchNorm( 214 self, activation, activation_op_name, with_bypass, delay): 215 """Tests quantization: inputs -> DWConv2d no batch norm -> Activation. 216 217 Args: 218 activation: Callable that returns an Operation, a factory method for the 219 Activation. 220 activation_op_name: String, name of the Activation operation. 221 with_bypass: Bool, when true there is an extra connection added from 222 inputs to just before Activation. 223 delay: Int (optional), delay in number of steps until quantization starts. 224 """ 225 graph = ops.Graph() 226 with graph.as_default(): 227 batch_size, height, width, depth = 5, 128, 128, 3 228 inputs = array_ops.zeros((batch_size, height, width, depth)) 229 stride = 1 if with_bypass else 2 230 activation_fn = None if with_bypass else activation 231 scope = 'test/test2' if with_bypass else 'test' 232 node = separable_conv2d(inputs, None, [5, 5], stride=stride, 233 depth_multiplier=1.0, padding='SAME', 234 weights_initializer=self._WeightInit(0.09), 235 activation_fn=activation_fn, scope=scope) 236 if with_bypass: 237 node = math_ops.add(inputs, node, name='test/Add') 238 node = activation(node, name='test/' + activation_op_name) 239 update_barrier = control_flow_ops.no_op(name='update_barrier') 240 with ops.control_dependencies([update_barrier]): 241 array_ops.identity(node, name='control_dependency') 242 243 quantize.Quantize(graph, True, quant_delay=delay) 244 245 quantization_node_name = 'FakeQuantWithMinMaxVars' 246 weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + 247 quantization_node_name) 248 self.assertEqual(weights_quant.type, quantization_node_name) 249 expected_inputs = [ 250 scope + '/weights_quant/AssignMinLast', 251 scope + '/weights_quant/AssignMaxLast', 252 scope + '/depthwise_weights/read' 253 ] 254 self._AssertInputOpsAre(weights_quant, expected_inputs) 255 if delay and delay > 0: 256 output_op_name = scope + '/weights_quant/delayed_quant/Switch_1' 257 else: 258 output_op_name = scope + '/depthwise' 259 self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) 260 261 if with_bypass: 262 conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + 263 quantization_node_name) 264 self.assertEqual(conv_quant.type, quantization_node_name) 265 expected_inputs = [ 266 scope + '/conv_quant/AssignMinEma', 267 scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd' 268 ] 269 self._AssertInputOpsAre(conv_quant, expected_inputs) 270 output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' 271 if delay else 'test/Add') 272 self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) 273 274 act_quant = graph.get_operation_by_name('test/act_quant/' + 275 quantization_node_name) 276 self.assertEqual(act_quant.type, quantization_node_name) 277 expected_inputs = [ 278 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', 279 'test/' + activation_op_name 280 ] 281 self._AssertInputOpsAre(act_quant, expected_inputs) 282 output_op_name = ('test/act_quant/delayed_quant/Switch_1' 283 if delay else 'control_dependency') 284 self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) 285 286 def testQuantize_DepthwiseConv2dWithoutBatchNorm(self): 287 self._RunWithoutBatchNormTestOverParameters( 288 self._TestQuantize_DepthwiseConv2dWithoutBatchNorm) 289 290 def _RunBatchNormTestOverParameters(self, test_fn): 291 # TODO(suharshs): Use parameterized test once OSS TF supports it. 292 parameters_list = [ 293 # (activation, activation_op_name, with_bypass, delay, fused_batch_norm) 294 (nn_ops.relu6, 'Relu6', False, None, False), 295 (nn_ops.relu, 'Relu', False, None, False), 296 (array_ops.identity, 'Identity', False, None, False), 297 (nn_ops.relu6, 'Relu6', False, 5000, False), 298 (nn_ops.relu, 'Relu', False, 5000, False), 299 (array_ops.identity, 'Identity', False, 5000, False), 300 (nn_ops.relu6, 'Relu6', True, None, False), 301 (nn_ops.relu, 'Relu', True, None, False), 302 (array_ops.identity, 'Identity', True, None, False), 303 (nn_ops.relu6, 'Relu6', True, 5000, False), 304 (nn_ops.relu, 'Relu', True, 5000, False), 305 (array_ops.identity, 'Identity', True, 5000, False), 306 (nn_ops.relu6, 'Relu6', False, None, True), 307 (nn_ops.relu, 'Relu', False, None, True), 308 (array_ops.identity, 'Identity', False, None, True), 309 (nn_ops.relu6, 'Relu6', False, 5000, True), 310 (nn_ops.relu, 'Relu', False, 5000, True), 311 (array_ops.identity, 'Identity', False, 5000, True), 312 (nn_ops.relu6, 'Relu6', True, None, True), 313 (nn_ops.relu, 'Relu', True, None, True), 314 (array_ops.identity, 'Identity', True, None, True), 315 (nn_ops.relu6, 'Relu6', True, 5000, True), 316 (nn_ops.relu, 'Relu', True, 5000, True), 317 (array_ops.identity, 'Identity', True, 5000, True) 318 ] 319 for params in parameters_list: 320 test_fn(params[0], params[1], params[2], params[3], params[4]) 321 322 def testQuantize_Conv2dWithBatchNorm(self): 323 self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) 324 325 def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, 326 with_bypass, delay, fused_batch_norm): 327 """Tests quantization: inputs -> Conv2d with batch norm -> Activation. 328 329 Args: 330 activation: Callable that returns an Operation, a factory method for the 331 Activation. 332 activation_op_name: String, name of the Activation operation. 333 with_bypass: Bool, when true there is an extra connection added from 334 inputs to just before Activation. 335 delay: Int (optional), delay in number of steps until quantization starts. 336 fused_batch_norm: Bool, when true use FusedBatchNorm. 337 """ 338 graph = ops.Graph() 339 with graph.as_default(): 340 batch_size, height, width, depth = 5, 128, 128, 3 341 inputs = array_ops.zeros((batch_size, height, width, depth)) 342 stride = 1 if with_bypass else 2 343 out_depth = 3 if with_bypass else 32 344 scope = 'test/test2' if with_bypass else 'test' 345 node = conv2d( 346 inputs, 347 out_depth, [5, 5], 348 stride=stride, 349 padding='SAME', 350 weights_initializer=self._WeightInit(0.09), 351 activation_fn=None, 352 normalizer_fn=batch_norm, 353 normalizer_params=self._BatchNormParams(fused_batch_norm), 354 scope=scope) 355 356 # Manually add a bypass (optionaly) and an activation. 357 if with_bypass: 358 node = math_ops.add(inputs, node, name='test/Add') 359 360 node = activation(node, name='test/' + activation_op_name) 361 362 update_barrier = control_flow_ops.no_op(name='update_barrier') 363 with ops.control_dependencies([update_barrier]): 364 array_ops.identity(node, name='control_dependency') 365 366 fold_batch_norms.FoldBatchNorms(graph, is_training=True) 367 368 quantize.Quantize(graph, True, quant_delay=delay) 369 370 quantization_node_name = 'FakeQuantWithMinMaxVars' 371 weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + 372 quantization_node_name) 373 self.assertEqual(weights_quant.type, quantization_node_name) 374 expected_inputs = [ 375 scope + '/weights_quant/' + 'AssignMinLast', 376 scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' 377 ] 378 self._AssertInputOpsAre(weights_quant, expected_inputs) 379 output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' 380 if delay else '/Conv2D_Fold') 381 self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) 382 383 if with_bypass: 384 conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + 385 quantization_node_name) 386 self.assertEqual(conv_quant.type, quantization_node_name) 387 expected_inputs = [ 388 scope + '/conv_quant/AssignMinEma', 389 scope + '/conv_quant/AssignMaxEma', scope + '/add_fold' 390 ] 391 self._AssertInputOpsAre(conv_quant, expected_inputs) 392 output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' 393 if delay else 'test/Add') 394 self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) 395 396 act_quant = graph.get_operation_by_name('test/act_quant/' + 397 quantization_node_name) 398 self.assertEqual(act_quant.type, quantization_node_name) 399 expected_inputs = [ 400 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', 401 'test/' + activation_op_name 402 ] 403 self._AssertInputOpsAre(act_quant, expected_inputs) 404 output_op_name = ('test/act_quant/delayed_quant/Switch_1' 405 if delay else 'control_dependency') 406 self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) 407 408 def testQuantize_FCWithBatchNorm(self): 409 self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm) 410 411 def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name, 412 with_bypass, delay, fused_batch_norm): 413 """Tests quantization: inputs -> FC with batch norm -> Activation. 414 415 Args: 416 activation: Callable that returns an Operation, a factory method for the 417 Activation. 418 activation_op_name: String, name of the Activation operation. 419 with_bypass: Bool, when true there is an extra connection added from 420 inputs to just before Activation. 421 delay: Int (optional), delay in number of steps until quantization starts. 422 fused_batch_norm: Bool, when true use FusedBatchNorm. 423 """ 424 graph = ops.Graph() 425 with graph.as_default(): 426 batch_size, depth = 5, 256 427 inputs = array_ops.zeros((batch_size, depth)) 428 out_depth = 256 if with_bypass else 128 429 scope = 'test/test2' if with_bypass else 'test' 430 node = fully_connected( 431 inputs, 432 out_depth, 433 weights_initializer=self._WeightInit(0.03), 434 activation_fn=None, 435 normalizer_fn=batch_norm, 436 normalizer_params=self._BatchNormParams(fused_batch_norm), 437 scope=scope) 438 439 # Manually add a bypass (optionaly) and an activation. 440 if with_bypass: 441 node = math_ops.add(inputs, node, name='test/Add') 442 443 node = activation(node, name='test/' + activation_op_name) 444 445 update_barrier = control_flow_ops.no_op(name='update_barrier') 446 with ops.control_dependencies([update_barrier]): 447 array_ops.identity(node, name='control_dependency') 448 449 fold_batch_norms.FoldBatchNorms(graph, is_training=True) 450 451 quantize.Quantize(graph, True, quant_delay=delay) 452 453 quantization_node_name = 'FakeQuantWithMinMaxVars' 454 weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + 455 quantization_node_name) 456 self.assertEqual(weights_quant.type, quantization_node_name) 457 expected_inputs = [ 458 scope + '/weights_quant/' + 'AssignMinLast', 459 scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' 460 ] 461 self._AssertInputOpsAre(weights_quant, expected_inputs) 462 output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' 463 if delay else '/MatMul_Fold') 464 self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) 465 466 if with_bypass: 467 conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + 468 quantization_node_name) 469 self.assertEqual(conv_quant.type, quantization_node_name) 470 expected_inputs = [ 471 scope + '/conv_quant/AssignMinEma', 472 scope + '/conv_quant/AssignMaxEma', scope + '/add_fold' 473 ] 474 self._AssertInputOpsAre(conv_quant, expected_inputs) 475 output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' 476 if delay else 'test/Add') 477 self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) 478 479 act_quant = graph.get_operation_by_name('test/act_quant/' + 480 quantization_node_name) 481 self.assertEqual(act_quant.type, quantization_node_name) 482 expected_inputs = [ 483 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', 484 'test/' + activation_op_name 485 ] 486 self._AssertInputOpsAre(act_quant, expected_inputs) 487 output_op_name = ('test/act_quant/delayed_quant/Switch_1' 488 if delay else 'control_dependency') 489 self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) 490 491 def testQuantize_DepthwiseConv2dWithBatchNorm(self): 492 self._RunBatchNormTestOverParameters( 493 self._TestQuantize_DepthwiseConv2dWithBatchNorm) 494 495 def _TestQuantize_DepthwiseConv2dWithBatchNorm( 496 self, activation, activation_op_name, with_bypass, delay, 497 fused_batch_norm): 498 """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. 499 500 Args: 501 activation: Callable that returns an Operation, a factory method for the 502 Activation. 503 activation_op_name: String, name of the Activation operation. 504 with_bypass: Bool, when true there is an extra connection added from 505 inputs to just before Activation. 506 delay: Int (optional), delay in number of steps until quantization starts. 507 fused_batch_norm: Bool, when true use FusedBatchNorm. 508 """ 509 graph = ops.Graph() 510 with graph.as_default(): 511 batch_size, height, width, depth = 5, 128, 128, 3 512 inputs = array_ops.zeros((batch_size, height, width, depth)) 513 stride = 1 if with_bypass else 2 514 scope = 'test/test2' if with_bypass else 'test' 515 node = separable_conv2d( 516 inputs, 517 None, [5, 5], 518 stride=stride, 519 depth_multiplier=1.0, 520 padding='SAME', 521 weights_initializer=self._WeightInit(0.09), 522 activation_fn=None, 523 normalizer_fn=batch_norm, 524 normalizer_params=self._BatchNormParams(fused_batch_norm), 525 scope=scope) 526 527 # Manually add a bypass (optionaly) and an activation. 528 if with_bypass: 529 node = math_ops.add(inputs, node, name='test/Add') 530 531 node = activation(node, name='test/' + activation_op_name) 532 533 update_barrier = control_flow_ops.no_op(name='update_barrier') 534 with ops.control_dependencies([update_barrier]): 535 array_ops.identity(node, name='control_dependency') 536 537 fold_batch_norms.FoldBatchNorms(graph, is_training=True) 538 539 quantize.Quantize(graph, True, quant_delay=delay) 540 quantization_node_name = 'FakeQuantWithMinMaxVars' 541 weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + 542 quantization_node_name) 543 self.assertEqual(weights_quant.type, quantization_node_name) 544 expected_inputs = [ 545 scope + '/weights_quant/' + 'AssignMinLast', 546 scope + '/weights_quant/' + 'AssignMaxLast', scope + '/mul_fold' 547 ] 548 self._AssertInputOpsAre(weights_quant, expected_inputs) 549 output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' 550 if delay else '/depthwise_Fold') 551 self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) 552 553 if with_bypass: 554 conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + 555 quantization_node_name) 556 self.assertEqual(conv_quant.type, quantization_node_name) 557 expected_inputs = [ 558 scope + '/conv_quant/AssignMinEma', 559 scope + '/conv_quant/AssignMaxEma', scope + '/add_fold' 560 ] 561 self._AssertInputOpsAre(conv_quant, expected_inputs) 562 output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' 563 if delay else 'test/Add') 564 self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) 565 566 act_quant = graph.get_operation_by_name('test/act_quant/' + 567 quantization_node_name) 568 self.assertEqual(act_quant.type, quantization_node_name) 569 expected_inputs = [ 570 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', 571 'test/' + activation_op_name 572 ] 573 self._AssertInputOpsAre(act_quant, expected_inputs) 574 output_op_name = ('test/act_quant/delayed_quant/Switch_1' 575 if delay else 'control_dependency') 576 self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) 577 578 def _BatchNormParams(self, fused=False): 579 return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused} 580 581 def _WeightInit(self, stddev): 582 """Returns truncated normal variable initializer. 583 584 Function is defined purely to shorten the name so that it stops wrapping. 585 586 Args: 587 stddev: Standard deviation of normal variable. 588 589 Returns: 590 An initialized that initialzes with a truncated normal variable. 591 """ 592 return init_ops.truncated_normal_initializer(stddev=stddev) 593 594 def _AssertInputOpsAre(self, op, in_op_names): 595 """Asserts that all inputs to op come from in_op_names (disregarding order). 596 597 Args: 598 op: Operation to check inputs for. 599 in_op_names: List of strings, operations where all op's inputs should 600 come from. 601 """ 602 expected_inputs = [in_op_name + ':0' for in_op_name in in_op_names] 603 self.assertItemsEqual([t.name for t in op.inputs], expected_inputs) 604 605 def _AssertOutputGoesToOps(self, op, graph, out_op_names): 606 """Asserts that outputs from op go to out_op_names (and perhaps others). 607 608 Args: 609 op: Operation to check outputs for. 610 graph: Graph where output operations are located. 611 out_op_names: List of strings, operations where op's outputs should go. 612 """ 613 for out_op_name in out_op_names: 614 out_op = graph.get_operation_by_name(out_op_name) 615 self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs]) 616 617 618 if __name__ == '__main__': 619 googletest.main() 620