Home | History | Annotate | Download | only in ops
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Functional tests for fused conv2d bias and activation operation."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import errors_impl
     27 from tensorflow.python.framework import test_util
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops import gen_array_ops
     30 from tensorflow.python.ops import nn_ops
     31 from tensorflow.python.ops import random_ops
     32 from tensorflow.python.platform import test
     33 from tensorflow.python.platform import tf_logging
     34 
     35 
     36 def GetShrunkInceptionShapes(shrink=10):
     37   """Iterator for smaller versions of convolution shapes in 2015 Inception.
     38 
     39   Relative to inception, each depth value is `depth // shrink`.
     40 
     41   Args:
     42     shrink: Factor to shrink each depth value by relative to Inception.
     43 
     44   Yields:
     45     Tuple (input_size, filter_size, out_size, stride, padding), the convolution
     46     parameters of Inception layers.
     47   """
     48   input_sizes = [[4, 5, 5, 1248], [4, 8, 8, 384], [4, 8, 8, 384], [
     49       4, 8, 8, 2048
     50   ], [4, 8, 8, 448], [4, 8, 8, 2048], [4, 8, 8, 2048], [4, 8, 8, 2048], [
     51       4, 8, 8, 1760
     52   ], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 17, 17, 192], [
     53       4, 17, 17, 192
     54   ], [4, 17, 17, 1248], [4, 17, 17, 128], [4, 17, 17, 1248], [4, 17, 17, 224], [
     55       4, 17, 17, 192
     56   ], [4, 17, 17, 192], [4, 17, 17, 1216], [4, 17, 17, 1216], [4, 17, 17, 224], [
     57       4, 17, 17, 192
     58   ], [4, 17, 17, 192], [4, 17, 17, 1152], [4, 17, 17, 1152], [4, 17, 17, 192], [
     59       4, 17, 17, 160
     60   ], [4, 17, 17, 1152], [4, 17, 17, 1024], [4, 17, 17, 128], [4, 17, 17, 1024],
     61                  [4, 17, 17, 128], [4, 17, 17, 1024], [4, 17, 17, 128], [
     62                      4, 17, 17, 768
     63                  ], [4, 17, 17, 128], [4, 17, 17, 128], [4, 17, 17, 768],
     64                  [4, 17, 17, 768], [4, 35, 35, 96], [4, 35, 35, 288], [
     65                      4, 35, 35, 64
     66                  ], [4, 35, 35, 288], [4, 35, 35, 256], [4, 35, 35, 48], [
     67                      4, 35, 35, 256
     68                  ], [4, 35, 35, 96], [4, 35, 35, 192], [4, 35, 35, 192], [
     69                      4, 35, 35, 192
     70                  ], [4, 73, 73, 64], [4, 73, 73, 64], [4, 147, 147, 24]]
     71   filter_sizes = [[1, 1, 1248, 128], [1, 3, 384, 384], [3, 1, 384, 384], [
     72       1, 1, 2048, 192
     73   ], [3, 3, 448, 384], [1, 1, 2048, 320], [1, 1, 2048, 448], [1, 1, 2048, 384],
     74                   [1, 1, 1760, 384], [1, 1, 1760, 192], [1, 1, 1760, 448], [
     75                       1, 1, 1760, 320
     76                   ], [3, 3, 192, 192], [3, 3, 192, 192], [1, 1, 1248, 192], [
     77                       3, 3, 128, 320
     78                   ], [1, 1, 1248, 128], [1, 3, 224, 224], [3, 1, 192, 256], [
     79                       1, 3, 192, 256
     80                   ], [1, 1, 1216, 192], [1, 1, 1216, 96], [3, 1, 224, 224], [
     81                       3, 3, 192, 224
     82                   ], [1, 3, 192, 192], [1, 1, 1152, 192], [1, 1, 1152, 128], [
     83                       3, 1, 192, 192
     84                   ], [3, 3, 160, 192], [1, 1, 1152, 160], [1, 1, 1024, 128], [
     85                       1, 3, 128, 192
     86                   ], [1, 1, 1024, 160], [3, 1, 128, 192], [1, 1, 1024, 256], [
     87                       3, 1, 128, 128
     88                   ], [1, 1, 768, 192], [1, 3, 128, 128], [3, 3, 128, 128], [
     89                       1, 1, 768, 128
     90                   ], [1, 1, 768, 320], [3, 3, 96, 96], [3, 3, 288, 384], [
     91                       3, 3, 64, 96
     92                   ], [1, 1, 288, 64], [1, 1, 256, 64], [5, 5, 48, 64],
     93                   [1, 1, 256, 48], [3, 3, 96, 96], [1, 1, 192, 32], [
     94                       1, 1, 192, 64
     95                   ], [1, 1, 192, 48], [3, 3, 64, 192], [1, 1, 64,
     96                                                         64], [1, 1, 24, 64]]
     97   out_sizes = [[4, 5, 5, 128], [4, 8, 8, 384], [4, 8, 8, 384], [4, 8, 8, 192], [
     98       4, 8, 8, 384
     99   ], [4, 8, 8, 320], [4, 8, 8, 448], [4, 8, 8, 384], [4, 8, 8, 384], [
    100       4, 8, 8, 192
    101   ], [4, 8, 8, 448], [4, 8, 8, 320], [4, 8, 8, 192], [4, 17, 17, 192], [
    102       4, 17, 17, 192
    103   ], [4, 8, 8, 320], [4, 17, 17, 128], [4, 17, 17, 224], [4, 17, 17, 256], [
    104       4, 17, 17, 256
    105   ], [4, 17, 17, 192], [4, 17, 17, 96], [4, 17, 17, 224], [4, 17, 17, 224], [
    106       4, 17, 17, 192
    107   ], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 192], [
    108       4, 17, 17, 160
    109   ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 160], [4, 17, 17, 192], [
    110       4, 17, 17, 256
    111   ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 128], [
    112       4, 17, 17, 128
    113   ], [4, 17, 17, 320], [4, 17, 17, 96], [4, 17, 17, 384], [4, 35, 35, 96], [
    114       4, 35, 35, 64
    115   ], [4, 35, 35, 64], [4, 35, 35, 64], [4, 35, 35, 48], [4, 35, 35, 96],
    116                [4, 35, 35, 32], [4, 35, 35, 64], [4, 35, 35, 48],
    117                [4, 71, 71, 192], [4, 73, 73, 64], [4, 147, 147, 64]]
    118   strides = [
    119       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    120       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1,
    121       1, 1, 1, 1, 1
    122   ]
    123   # Shrink sizes to make the test faster
    124   for i in input_sizes:
    125     i[3] //= shrink
    126   for f in filter_sizes:
    127     f[2] //= shrink
    128     f[3] //= shrink
    129   for o in out_sizes:
    130     o[3] //= shrink
    131   # pylint: disable=invalid-name
    132   VALID = "VALID"
    133   SAME = "SAME"
    134   # pylint: enable=invalid-name
    135   paddings = [
    136       SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
    137       VALID, SAME, SAME, VALID, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
    138       SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
    139       SAME, SAME, SAME, SAME, SAME, VALID, VALID, SAME, SAME, SAME, SAME, SAME,
    140       SAME, SAME, SAME, SAME, VALID, VALID, VALID
    141   ]
    142   for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
    143                            paddings):
    144     yield i, f, o, s, p
    145 
    146 
    147 def GetTestConfigs():
    148   """Get all the valid tests configs to run.
    149 
    150   Returns:
    151     all the valid test configs as tuples of data_format and use_gpu.
    152   """
    153   test_configs = [("NCHW", True), ("NHWC", True)]
    154   return test_configs
    155 
    156 
    157 class FusedConv2DBiasActivationTest(test.TestCase):
    158 
    159   def _DtypesToTest(self, use_gpu):
    160     return [dtypes.float32]
    161 
    162   def _FilterFormatsToTest(self, use_gpu):
    163     return ["HWIO", "OIHW"]
    164 
    165   def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, bias,
    166                             strides, padding, activation_mode, data_format,
    167                             filter_format, dtype):
    168     """Verifies the output values of the convolution function.
    169 
    170     Args:
    171       tensor_in_sizes: Input tensor dimensions in
    172         [batch, input_rows, input_cols, input_depth].
    173       filter_in_sizes: Filter tensor dimensions in
    174         [kernel_rows, kernel_cols, input_depth, output_depth].
    175       bias: 1-D bias tensor of length output_depth.
    176       strides: Stride: [col_stride, row_stride]
    177       padding: Padding type.
    178       activation_mode: Activation mode.
    179       data_format: Format of the data tensors.
    180       filter_format: Filter format to use for the fused convolution.
    181       dtype: Data type for inputs and outputs.
    182     Returns:
    183       Symbolic tensor value and reference value that can be used to
    184       execute the computation and verify the results.
    185     """
    186     input_size = np.prod(tensor_in_sizes)
    187     filter_size = np.prod(filter_in_sizes)
    188     bias_size = filter_in_sizes[-1]  # equals to output depth
    189     # Initializes the input tensor with array containing incrementing
    190     # numbers from 1.
    191     x1 = [f * 1.0 for f in range(1, input_size + 1)]
    192     x2 = [f * 1.0 for f in range(1, filter_size + 1)]
    193     # This is to guarantee that there is always negative values after
    194     # bias add so that we can test whether relu works correctly.
    195     x3 = bias
    196     with self.test_session(use_gpu=True):
    197       t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
    198       t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
    199       fused_t2 = t2
    200       if filter_format == "OIHW":
    201         fused_t2 = HwioToOihw(t2)
    202       t3 = constant_op.constant(x3, shape=[bias_size], dtype=dtype)
    203       strides = [1] + strides + [1]
    204       if data_format == "NCHW":
    205         t1 = test_util.NHWCToNCHW(t1)
    206         strides = test_util.NHWCToNCHW(strides)
    207       output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
    208           t1,
    209           fused_t2,
    210           t3,
    211           strides=strides,
    212           padding=padding,
    213           data_format=data_format,
    214           filter_format=filter_format,
    215           activation_mode=activation_mode)
    216       ref_conv_output = nn_ops.conv2d(
    217           t1, t2, strides=strides, padding=padding, data_format=data_format)
    218       ref_bias_output = nn_ops.bias_add(
    219           ref_conv_output, t3, data_format=data_format)
    220       ref_output = nn_ops.relu(ref_bias_output)
    221       if data_format == "NCHW":
    222         output = test_util.NCHWToNHWC(output)
    223         ref_output = test_util.NCHWToNHWC(ref_output)
    224 
    225       return output, ref_output
    226 
    227   def _CompareFwdValues(self, tensor_in_sizes, filter_in_sizes, conv_strides,
    228                         padding):
    229     """Verifies that CPU and GPU produce the same values.
    230 
    231     Args:
    232       tensor_in_sizes: Input tensor dimensions in
    233         [batch, input_rows, input_cols, input_depth].
    234       filter_in_sizes: Filter tensor dimensions in
    235         [kernel_rows, kernel_cols, input_depth, output_depth].
    236       conv_strides: [row_stride, col_stride] for the convolution;
    237       padding: Padding type.
    238     """
    239     x1 = np.random.rand(*tensor_in_sizes).astype(np.float32)
    240     x2 = np.random.rand(*filter_in_sizes).astype(np.float32)
    241     x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32)
    242 
    243     def _SetupVal(data_format, use_gpu):
    244       with self.test_session(use_gpu=use_gpu):
    245         t1 = constant_op.constant(x1, shape=tensor_in_sizes)
    246         t2 = constant_op.constant(x2, shape=filter_in_sizes)
    247         t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]])
    248         strides = [1] + conv_strides + [1]
    249         if data_format == "NCHW":
    250           t1 = test_util.NHWCToNCHW(t1)
    251           strides = test_util.NHWCToNCHW(strides)
    252         output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
    253             t1,
    254             t2,
    255             t3,
    256             strides=strides,
    257             padding=padding,
    258             data_format=data_format,
    259             activation_mode="Relu")
    260 
    261         if data_format == "NCHW":
    262           output = test_util.NCHWToNHWC(output)
    263         return output
    264 
    265     tensors = []
    266     for (data_format, use_gpu) in GetTestConfigs():
    267       tensors.append(_SetupVal(data_format, use_gpu))
    268     with self.test_session() as sess:
    269       values = sess.run(tensors)
    270       for i in range(1, len(values)):
    271         self.assertAllClose(values[0], values[i], rtol=1e-5, atol=1e-5)
    272 
    273   def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, bias, strides,
    274                     padding):
    275     tensors = []
    276     ref_tensors = []
    277     for (data_format, use_gpu) in GetTestConfigs():
    278       for dtype in self._DtypesToTest(use_gpu):
    279         for filter_format in self._FilterFormatsToTest(use_gpu):
    280           result, expected = self._SetupValuesForDevice(
    281               tensor_in_sizes, filter_in_sizes, bias, strides, padding, "Relu",
    282               data_format, filter_format, dtype)
    283         tensors.append(result)
    284         ref_tensors.append(expected)
    285       with self.test_session() as sess:
    286         values = sess.run(tensors)
    287         ref_values = sess.run(ref_tensors)
    288         for i in range(len(tensors)):
    289           conv = tensors[i]
    290           value = values[i]
    291           ref_value = ref_values[i]
    292           print("expected = ", ref_value)
    293           print("actual = ", value)
    294           tol = 1e-5
    295           if value.dtype == np.float16:
    296             tol = 1e-3
    297           self.assertAllClose(
    298               np.ravel(ref_value), np.ravel(value), atol=tol, rtol=tol)
    299           self.assertShapeEqual(value, conv)
    300 
    301   def testConv2D1x1Filter(self, gpu_only=True):
    302     if gpu_only and not test.is_gpu_available():
    303       tf_logging.info("Skipping Conv2D1x1Filter test.")
    304       return
    305     # expected_output = [
    306     #    0.0, 0.0, 0.0, 21.0, 0.0, 0.0, 57.0, 0.0, 0.0, 93.0, 41.0, 0.0, 129.0,
    307     #    86.0, 43.0, 165.0, 131.0, 97.0
    308     # ]
    309     medians = [-45.0, -130.0, -215.0]
    310     self._VerifyValues(
    311         tensor_in_sizes=[1, 2, 3, 3],
    312         filter_in_sizes=[1, 1, 3, 3],
    313         bias=medians,
    314         strides=[1, 1],
    315         padding="VALID")
    316 
    317   def testConv2DEmpty(self, gpu_only=True):
    318     if gpu_only and not test.is_gpu_available():
    319       tf_logging.info("Skipping Conv2DEmpty test.")
    320       return
    321     # expected_output = []
    322     self._VerifyValues(
    323         tensor_in_sizes=[0, 2, 3, 3],
    324         filter_in_sizes=[1, 1, 3, 3],
    325         bias=[0.0, 0.0, 0.0],
    326         strides=[1, 1],
    327         padding="VALID")
    328 
    329   def testConv2D2x2Filter(self, gpu_only=True):
    330     if gpu_only and not test.is_gpu_available():
    331       tf_logging.info("Skipping Conv2D2x2Filter test.")
    332       return
    333     # expected_output = [0.0, 0.0, 0.0, 401.0, 533.0, 665.0]
    334     self._VerifyValues(
    335         tensor_in_sizes=[1, 2, 3, 3],
    336         filter_in_sizes=[2, 2, 3, 3],
    337         bias=[-2500.0, -2500.0, -2500.0],
    338         strides=[1, 1],
    339         padding="VALID")
    340 
    341   def testConv2D1x2Filter(self, gpu_only=True):
    342     if gpu_only and not test.is_gpu_available():
    343       tf_logging.info("Skipping Conv2D1x2Filter test.")
    344       return
    345     # expected_output = [
    346     #    0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 190.0, 265.0, 340.0, 343.0, 436.0, 529.0
    347     # ]
    348     self._VerifyValues(
    349         tensor_in_sizes=[1, 2, 3, 3],
    350         filter_in_sizes=[1, 2, 3, 3],
    351         bias=[-500.0, -500.0, -500.0],
    352         strides=[1, 1],
    353         padding="VALID")
    354 
    355   def testConv2D2x2FilterStride2(self, gpu_only=True):
    356     if gpu_only and not test.is_gpu_available():
    357       tf_logging.info("Skipping Conv2D2x2FilterStride2 test.")
    358       return
    359     # expected_output = [0.0, 67.0, 163.0]
    360     self._VerifyValues(
    361         tensor_in_sizes=[1, 2, 3, 3],
    362         filter_in_sizes=[2, 2, 3, 3],
    363         bias=[-2300.0, -2300.0, -2300.0],
    364         strides=[2, 2],
    365         padding="VALID")
    366 
    367   def testConv2D2x2FilterStride2Same(self, gpu_only=True):
    368     if gpu_only and not test.is_gpu_available():
    369       tf_logging.info("Skipping Conv2D2x2FilterStride2Same test.")
    370       return
    371     # expected_output = [0.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0]
    372     self._VerifyValues(
    373         tensor_in_sizes=[1, 2, 3, 3],
    374         filter_in_sizes=[2, 2, 3, 3],
    375         bias=[-2300.0, -1000.0, -1000.0],
    376         strides=[2, 2],
    377         padding="SAME")
    378 
    379   def testConv2D2x2FilterStride1x2(self, gpu_only=True):
    380     if gpu_only and not test.is_gpu_available():
    381       tf_logging.info("Skipping Conv2D2x2FilterStride1x2 test.")
    382       return
    383     # expected_output = [0.0, 0.0, 8.0, 28.0, 48.0, 68.0]
    384     self._VerifyValues(
    385         tensor_in_sizes=[1, 3, 6, 1],
    386         filter_in_sizes=[2, 2, 1, 1],
    387         bias=[-90.0],
    388         strides=[1, 2],
    389         padding="VALID")
    390 
    391   def testConv2DKernelSmallerThanStrideValid(self, gpu_only=True):
    392     if gpu_only and not test.is_gpu_available():
    393       tf_logging.info("Skipping Conv2DKernelSmallerThanStrideValid test.")
    394       return
    395     # expected_output = [0, 0, 175, 205]
    396     self._VerifyValues(
    397         tensor_in_sizes=[1, 7, 7, 1],
    398         filter_in_sizes=[2, 2, 1, 1],
    399         bias=[-100.0],
    400         strides=[3, 3],
    401         padding="VALID")
    402 
    403   def testConv2DKernelSmallerThanStrideSame(self, gpu_only=True):
    404     if gpu_only and not test.is_gpu_available():
    405       tf_logging.info("Skipping Conv2DKernelSmallerThanStrideSame test.")
    406       return
    407     # expected = [0, 0, 2, 4]
    408     self._VerifyValues(
    409         tensor_in_sizes=[1, 3, 3, 1],
    410         filter_in_sizes=[1, 1, 1, 1],
    411         bias=[-5.0],
    412         strides=[2, 2],
    413         padding="SAME")
    414 
    415     # expected = [0, 0, 4, 6]
    416     self._VerifyValues(
    417         tensor_in_sizes=[1, 4, 4, 1],
    418         filter_in_sizes=[1, 1, 1, 1],
    419         bias=[-5.0],
    420         strides=[2, 2],
    421         padding="SAME")
    422 
    423     # expected = [4, 0, 1, 0]
    424     self._VerifyValues(
    425         tensor_in_sizes=[1, 4, 4, 1],
    426         filter_in_sizes=[2, 2, 1, 1],
    427         bias=[-40.0],
    428         strides=[3, 3],
    429         padding="SAME")
    430 
    431   def testConv2DKernelSizeMatchesInputSize(self, gpu_only=True):
    432     if gpu_only and not test.is_gpu_available():
    433       tf_logging.info("Skipping Conv2DKernelSizeMatchesInputSize test.")
    434       return
    435     # expected = [0, 5]
    436     self._VerifyValues(
    437         tensor_in_sizes=[1, 2, 2, 1],
    438         filter_in_sizes=[2, 2, 1, 2],
    439         bias=[-50.0, -55.0],
    440         strides=[1, 1],
    441         padding="VALID")
    442 
    443     # expected = [0, 2, 282, 322]
    444     self._VerifyValues(
    445         tensor_in_sizes=[1, 8, 8, 1],
    446         filter_in_sizes=[2, 2, 1, 1],
    447         bias=[-200.0],
    448         strides=[4, 4],
    449         padding="SAME")
    450 
    451   def testShapeFunctionEdgeCases(self):
    452     # All shapes unknown.
    453     c1 = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
    454         array_ops.placeholder(dtypes.float32),
    455         array_ops.placeholder(dtypes.float32),
    456         array_ops.placeholder(dtypes.float32),
    457         strides=[1, 1, 1, 1],
    458         padding="SAME",
    459         activation_mode="Relu")
    460     self.assertEqual([None, None, None, None], c1.get_shape().as_list())
    461 
    462     # Incorrect input shape.
    463     with self.assertRaises(ValueError):
    464       fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
    465           array_ops.placeholder(dtypes.float32, shape=[1, 3]),
    466           array_ops.placeholder(dtypes.float32),
    467           array_ops.placeholder(dtypes.float32),
    468           strides=[1, 1, 1, 1],
    469           padding="SAME",
    470           activation_mode="Relu")
    471 
    472     # Incorrect filter shape.
    473     with self.assertRaises(ValueError):
    474       fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
    475           array_ops.placeholder(dtypes.float32),
    476           array_ops.placeholder(dtypes.float32, shape=[1, 3]),
    477           array_ops.placeholder(dtypes.float32),
    478           strides=[1, 1, 1, 1],
    479           padding="SAME",
    480           activation_mode="Relu")
    481 
    482     # Depth mismatch.
    483     with self.assertRaises(ValueError):
    484       fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
    485           array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
    486           array_ops.placeholder(dtypes.float32, shape=[4, 4, 2, 2]),
    487           array_ops.placeholder(dtypes.float32),
    488           strides=[1, 1, 1, 1],
    489           padding="SAME",
    490           activation_mode="Relu")
    491 
    492   def testOpEdgeCases(self, gpu_only=True):
    493     if gpu_only and not test.is_gpu_available():
    494       tf_logging.info("Skipping OpEdgeCases tests.")
    495       return
    496     with self.test_session() as sess:
    497       # Illegal strides.
    498       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
    499                                    "Convolutional strides are not supported in "
    500                                    "the batch or depth dimensions."):
    501         sess.run(
    502             fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
    503                 array_ops.placeholder(dtypes.float32),
    504                 array_ops.placeholder(dtypes.float32),
    505                 array_ops.placeholder(dtypes.float32),
    506                 strides=[2, 1, 1, 1],
    507                 padding="SAME",
    508                 activation_mode="Relu"))
    509       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
    510                                    "Convolutional strides are not supported in "
    511                                    "the batch or depth dimensions."):
    512         sess.run(
    513             fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
    514                 array_ops.placeholder(dtypes.float32),
    515                 array_ops.placeholder(dtypes.float32),
    516                 array_ops.placeholder(dtypes.float32),
    517                 strides=[1, 1, 1, 2],
    518                 padding="SAME",
    519                 activation_mode="Relu"))
    520 
    521       # Illegal activation mode.
    522       with self.assertRaisesRegexp(ValueError,
    523                                    "Op passed string 'Tanh' not in:"):
    524         sess.run(
    525             fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
    526                 array_ops.placeholder(dtypes.float32),
    527                 array_ops.placeholder(dtypes.float32),
    528                 array_ops.placeholder(dtypes.float32),
    529                 strides=[1, 1, 1, 1],
    530                 padding="SAME",
    531                 activation_mode="Tanh"))
    532 
    533       # Filter larger than input.
    534       with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
    535         sess.run(
    536             fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
    537                 array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
    538                 array_ops.placeholder(dtypes.float32, shape=[20, 21, 3, 2]),
    539                 array_ops.placeholder(dtypes.float32, shape=[2]),
    540                 strides=[1, 1, 1, 1],
    541                 padding="VALID",
    542                 activation_mode="Relu"))
    543       with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
    544         sess.run(
    545             fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
    546                 array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
    547                 array_ops.placeholder(dtypes.float32, shape=[21, 20, 3, 2]),
    548                 array_ops.placeholder(dtypes.float32, shape=[2]),
    549                 strides=[1, 1, 1, 1],
    550                 padding="VALID",
    551                 activation_mode="Relu"))
    552 
    553 
    554 def GetInceptionFwdTest(input_size, filter_size, stride, padding,
    555                         gpu_only=True):
    556 
    557   def Test(self):
    558     if gpu_only and not test.is_gpu_available():
    559       tf_logging.info("Skipping InceptionFwd %s", (input_size, filter_size,
    560                                                    stride, padding))
    561       return
    562     tf_logging.info("Testing InceptionFwd %s", (input_size, filter_size, stride,
    563                                                 padding))
    564     self._CompareFwdValues(input_size, filter_size, [stride, stride], padding)
    565 
    566   return Test
    567 
    568 
    569 def CalculateCovolvedOutputDim(input_dim, filter_dim, stride, padding_type):
    570   """Calculates the size of an output dimension of a strided convolution.
    571 
    572   Given the sizes of the corresponding dimension of the input and filter shapes,
    573   and the stride and padding_types, calculates the size of the output dimension.
    574   This function can be called separately for each input dimension.
    575 
    576   Args:
    577     input_dim: An `int` specifying the size of the input dimension.
    578     filter_dim: An `int` specifying the size of the filter dimension.
    579     stride: An `int` specifying the step size of the convolution along the
    580       input dimension.
    581     padding_type: either 'VALID' or 'SAME'.
    582 
    583   Returns:
    584     The size of the output dimension.
    585   """
    586   if padding_type == "VALID":
    587     return (input_dim - filter_dim + stride) // stride
    588   else:  # padding_type == 'SAME'
    589     return (input_dim + stride - 1) // stride
    590 
    591 
    592 def NchwVectCToNchw(in_tensor):
    593   # [N, C / 4, H, W, 4] => [N, C / 4, 4, H, W] == [N, C, H, W]
    594   t = array_ops.transpose(in_tensor, [0, 1, 4, 2, 3])
    595   n = in_tensor.shape.dims[0].value
    596   c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
    597   h = in_tensor.shape.dims[2].value
    598   w = in_tensor.shape.dims[3].value
    599   return array_ops.reshape(t, [n, c, h, w])
    600 
    601 
    602 def OihwVectIToHwio(in_tensor):
    603   # [O, I / 4, H, W, 4] => [O, I / 4, 4, H, W] == [O, I, H, W]
    604   t = array_ops.transpose(in_tensor, [2, 3, 1, 4, 0])
    605   o = in_tensor.shape.dims[0].value
    606   i = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
    607   h = in_tensor.shape.dims[2].value
    608   w = in_tensor.shape.dims[3].value
    609   return array_ops.reshape(t, [h, w, i, o])
    610 
    611 
    612 def NchwToNchwVectC(in_tensor):
    613   n, c, h, w = in_tensor.shape.as_list()
    614   assert c % 4 == 0
    615   t = array_ops.reshape(in_tensor, [n, c // 4, 4, h, w])
    616   return array_ops.transpose(t, [0, 1, 3, 4, 2])
    617 
    618 
    619 def HwioToOihw(in_tensor):
    620   return array_ops.transpose(in_tensor, [3, 2, 0, 1])
    621 
    622 
    623 def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel,
    624                                           padding, strides, side_input_scale,
    625                                           side_input, biases):
    626   """Simulates the int8 fused 2-D convolution op using separate float ops.
    627 
    628     The arguments and return values have the same format, meanings and
    629     restrictions as the actual op.
    630   Args:
    631     conv_input_scale: A scalar 'float'.
    632     conv_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
    633     kernel: A `Tensor` of type `qint8` in OIHW_VECT_I layout.
    634     padding: A `string` from: `"SAME", "VALID"`.
    635     strides: A list of `ints`.
    636     side_input_scale: A scalar 'float'.
    637     side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
    638     biases: A `Tensor` of type `float32` in NCHW layout.
    639   Returns:
    640     A `Tensor` of type `qint8` in NCHW_VECT_C layout.
    641   """
    642   conv_result = nn_ops.conv2d(
    643       NchwVectCToNchw(gen_array_ops.dequantize(conv_input, -128, 127)),
    644       OihwVectIToHwio(gen_array_ops.dequantize(kernel, -128, 127)),
    645       strides=strides,
    646       padding=padding,
    647       data_format="NCHW") * conv_input_scale
    648 
    649   conv_and_side_inputs = conv_result + side_input_scale * NchwVectCToNchw(
    650       gen_array_ops.dequantize(side_input, -128, 127))
    651 
    652   logit = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW")
    653 
    654   result, _, _ = gen_array_ops.quantize_v2(
    655       NchwToNchwVectC(nn_ops.relu(logit)), -128, 127, dtypes.qint8)
    656   return result
    657 
    658 
    659 class FusedConvInt8Tests(test.TestCase):
    660   _test_params = [
    661       {
    662           "batch_size": 1,
    663           "input_channels": 4,
    664           "output_channels": 4,
    665           "input_height": 8,
    666           "input_width": 8,
    667           "filter_height": 6,
    668           "filter_width": 6,
    669           "vertical_stride": 2,
    670           "horizontal_stride": 2,
    671           "conv_input_scale": 0.002,
    672           "side_input_scale": 0.0,
    673           "bias_scale": 1,
    674           "padding_type": "SAME"
    675       },
    676       {
    677           "batch_size": 1,
    678           "input_channels": 4,
    679           "output_channels": 4,
    680           "input_height": 6,
    681           "input_width": 6,
    682           "filter_height": 6,
    683           "filter_width": 6,
    684           "vertical_stride": 2,
    685           "horizontal_stride": 2,
    686           "conv_input_scale": 0.002,
    687           "side_input_scale": 0.0,
    688           "bias_scale": 1,
    689           "padding_type": "SAME"
    690       },
    691       {
    692           "batch_size": 2,
    693           "input_channels": 8,
    694           "output_channels": 16,
    695           "input_height": 8,
    696           "input_width": 8,
    697           "filter_height": 3,
    698           "filter_width": 3,
    699           "vertical_stride": 2,
    700           "horizontal_stride": 2,
    701           "conv_input_scale": 0.002,
    702           "side_input_scale": 0.0,
    703           "bias_scale": 1,
    704           "padding_type": "VALID"
    705       },
    706       {
    707           "batch_size": 2,
    708           "input_channels": 8,
    709           "output_channels": 16,
    710           "input_height": 8,
    711           "input_width": 8,
    712           "filter_height": 3,
    713           "filter_width": 3,
    714           "vertical_stride": 2,
    715           "horizontal_stride": 2,
    716           "conv_input_scale": 0.002,
    717           "side_input_scale": 0.0,
    718           "bias_scale": 1,
    719           "padding_type": "SAME"
    720       },
    721       {
    722           "batch_size": 2,
    723           "input_channels": 8,
    724           "output_channels": 16,
    725           "input_height": 8,
    726           "input_width": 8,
    727           "filter_height": 3,
    728           "filter_width": 3,
    729           "vertical_stride": 2,
    730           "horizontal_stride": 2,
    731           "conv_input_scale": 0.002,
    732           "side_input_scale": 0.5,
    733           "bias_scale": 1,
    734           "padding_type": "VALID"
    735       },
    736       {
    737           "batch_size": 2,
    738           "input_channels": 16,
    739           "output_channels": 16,
    740           "input_height": 9,
    741           "input_width": 9,
    742           "filter_height": 3,
    743           "filter_width": 3,
    744           "vertical_stride": 1,
    745           "horizontal_stride": 1,
    746           "conv_input_scale": 0.001,
    747           "side_input_scale": 0.5,
    748           "bias_scale": 1,
    749           "padding_type": "SAME"
    750       },
    751       {
    752           "batch_size": 3,
    753           "input_channels": 8,
    754           "output_channels": 8,
    755           "input_height": 9,
    756           "input_width": 9,
    757           "filter_height": 5,
    758           "filter_width": 5,
    759           "vertical_stride": 1,
    760           "horizontal_stride": 1,
    761           "conv_input_scale": 0.001,
    762           "side_input_scale": 0.5,
    763           "bias_scale": 1,
    764           "padding_type": "SAME"
    765       },
    766       {
    767           "batch_size": 3,
    768           "input_channels": 8,
    769           "output_channels": 8,
    770           "input_height": 9,
    771           "input_width": 9,
    772           "filter_height": 7,
    773           "filter_width": 1,
    774           "vertical_stride": 2,
    775           "horizontal_stride": 1,
    776           "conv_input_scale": 0.002,
    777           "side_input_scale": 0.5,
    778           "bias_scale": 1,
    779           "padding_type": "SAME"
    780       },
    781       {
    782           "batch_size": 3,
    783           "input_channels": 8,
    784           "output_channels": 8,
    785           "input_height": 9,
    786           "input_width": 9,
    787           "filter_height": 1,
    788           "filter_width": 7,
    789           "vertical_stride": 1,
    790           "horizontal_stride": 1,
    791           "conv_input_scale": 0.002,
    792           "side_input_scale": 0.5,
    793           "bias_scale": 1,
    794           "padding_type": "SAME"
    795       },
    796   ]
    797 
    798   def runTest(self, test_param):
    799     batch_size = test_param["batch_size"]
    800     input_channels = test_param["input_channels"]
    801     output_channels = test_param["output_channels"]
    802     input_height = test_param["input_height"]
    803     input_width = test_param["input_width"]
    804     filter_height = test_param["filter_height"]
    805     filter_width = test_param["filter_width"]
    806     vertical_stride = test_param["vertical_stride"]
    807     horizontal_stride = test_param["horizontal_stride"]
    808     conv_input_scale = test_param["conv_input_scale"]
    809     side_input_scale = test_param["side_input_scale"]
    810     bias_scale = test_param["bias_scale"]
    811     padding_type = test_param["padding_type"]
    812 
    813     conv_input, _, _ = gen_array_ops.quantize_v2(
    814         random_ops.random_uniform(
    815             [batch_size, input_channels // 4, input_height, input_width, 4],
    816             minval=-0.0,
    817             maxval=1.0,
    818             dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
    819 
    820     kernel, _, _ = gen_array_ops.quantize_v2(
    821         random_ops.random_uniform(
    822             [
    823                 output_channels, input_channels // 4, filter_height,
    824                 filter_width, 4
    825             ],
    826             minval=-1.0,
    827             maxval=1.0,
    828             dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
    829 
    830     output_height = CalculateCovolvedOutputDim(input_height, filter_height,
    831                                                vertical_stride, padding_type)
    832     output_width = CalculateCovolvedOutputDim(input_width, filter_width,
    833                                               horizontal_stride, padding_type)
    834     print("output_height=", output_height, ", output_width=", output_width)
    835 
    836     side_input, _, _ = gen_array_ops.quantize_v2(
    837         random_ops.random_uniform(
    838             [batch_size, output_channels // 4, output_height, output_width, 4],
    839             minval=0.0,
    840             maxval=1.0,
    841             dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
    842 
    843     biases = random_ops.random_uniform(
    844         [output_channels],
    845         minval=-10 * bias_scale,
    846         maxval=20 * bias_scale,
    847         dtype=dtypes.float32)
    848 
    849     strides = [1, 1, vertical_stride, horizontal_stride]
    850 
    851     actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
    852         conv_input,
    853         kernel,
    854         biases,
    855         strides=strides,
    856         padding=padding_type,
    857         conv_input_scale=conv_input_scale,
    858         side_input_scale=side_input_scale,
    859         side_input=side_input,
    860         data_format="NCHW_VECT_C",
    861         filter_format="OIHW_VECT_I")
    862 
    863     expected = SimulateFusedConv2dBiasActivationInt8(
    864         conv_input_scale, conv_input, kernel, padding_type, strides,
    865         side_input_scale, side_input, biases)
    866 
    867     with self.test_session(use_gpu=True) as sess:
    868       actual_y, expected_y = sess.run([actual, expected])
    869       print("actual_y = ", actual_y)
    870       print("expected_y = ", expected_y)
    871       self.assertTrue(np.array_equal(actual_y, expected_y))
    872 
    873   def testFusedConvInt8(self):
    874     if not test.is_gpu_available(
    875         cuda_only=True, min_cuda_compute_capability=(6, 1)):
    876       tf_logging.info("int8 test skipped because not run with --config=cuda or "
    877                       "no GPUs with compute capability >= 6.1 are available.")
    878       return
    879     for test_param in self._test_params:
    880       self.runTest(test_param)
    881 
    882 
    883 if __name__ == "__main__":
    884   for index, (input_size_, filter_size_, output_size_, stride_,
    885               padding_) in enumerate(GetShrunkInceptionShapes()):
    886     setattr(FusedConv2DBiasActivationTest, "testInceptionFwd_" + str(index),
    887             GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_))
    888 
    889   # TODO(b/35359731)
    890   # Fwd, BckInput, and BackFilter to test that for certain input parameter
    891   # set, winograd nonfused algorithm will be excluded from conv autotune. If
    892   # in such case, winograd nonfused algorithm is added as one option of the
    893   # conv autotune, and cuDNN version is smaller than 7, the following tests
    894   # will fail.
    895   ishape = [1, 400, 400, 1]
    896   fshape = [1, 1, 1, 256]
    897   oshape = [1, 400, 400, 256]
    898   setattr(FusedConv2DBiasActivationTest,
    899           "testInceptionFwd_No_Winograd_Nonfused",
    900           GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True))
    901   test.main()
    902