Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Functional tests for convolutional operations."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 import os
     23 import time
     24 
     25 import numpy as np
     26 
     27 from six.moves import xrange  # pylint: disable=redefined-builtin
     28 from tensorflow.contrib import layers
     29 from tensorflow.python.client import session as session_lib
     30 from tensorflow.python.framework import constant_op
     31 from tensorflow.python.framework import dtypes
     32 from tensorflow.python.framework import errors_impl
     33 from tensorflow.python.framework import ops
     34 from tensorflow.python.framework import test_util
     35 from tensorflow.python.ops import array_ops
     36 from tensorflow.python.ops import gradient_checker
     37 from tensorflow.python.ops import gradients_impl
     38 from tensorflow.python.ops import nn_impl
     39 from tensorflow.python.ops import nn_ops
     40 from tensorflow.python.ops import random_ops
     41 from tensorflow.python.ops import variables
     42 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
     43 from tensorflow.python.platform import test
     44 from tensorflow.python.platform import tf_logging
     45 
     46 
     47 def GetShrunkInceptionShapes(shrink=10):
     48   """Iterator for smaller versions of convolution shapes in 2015 Inception.
     49 
     50   Relative to inception, each depth value is `depth // shrink`.
     51 
     52   Args:
     53     shrink: Factor to shrink each depth value by relative to Inception.
     54 
     55   Yields:
     56     Tuple (input_size, filter_size, out_size, stride, padding), the convolution
     57     parameters of Inception layers.
     58   """
     59   input_sizes = [[4, 5, 5, 1248], [4, 8, 8, 384], [4, 8, 8, 384],
     60                  [4, 8, 8, 2048], [4, 8, 8, 448], [4, 8, 8, 2048],
     61                  [4, 8, 8, 2048], [4, 8, 8, 2048], [4, 8, 8, 1760],
     62                  [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 8, 8, 1760],
     63                  [4, 17, 17, 192], [4, 17, 17, 192], [4, 17, 17, 1248],
     64                  [4, 17, 17, 128], [4, 17, 17, 1248], [4, 17, 17, 224],
     65                  [4, 17, 17, 192], [4, 17, 17, 192], [4, 17, 17, 1216],
     66                  [4, 17, 17, 1216], [4, 17, 17, 224], [4, 17, 17, 192],
     67                  [4, 17, 17, 192], [4, 17, 17, 1152], [4, 17, 17, 1152],
     68                  [4, 17, 17, 192], [4, 17, 17, 160], [4, 17, 17, 1152],
     69                  [4, 17, 17, 1024], [4, 17, 17, 128], [4, 17, 17, 1024],
     70                  [4, 17, 17, 128], [4, 17, 17, 1024], [4, 17, 17, 128],
     71                  [4, 17, 17, 768], [4, 17, 17, 128], [4, 17, 17, 128],
     72                  [4, 17, 17, 768], [4, 17, 17, 768], [4, 35, 35, 96],
     73                  [4, 35, 35, 288], [4, 35, 35, 64], [4, 35, 35, 288],
     74                  [4, 35, 35, 256], [4, 35, 35, 48], [4, 35, 35, 256],
     75                  [4, 35, 35, 96], [4, 35, 35, 192], [4, 35, 35, 192],
     76                  [4, 35, 35, 192], [4, 73, 73, 64], [4, 73, 73, 64],
     77                  [4, 147, 147, 24]]
     78   filter_sizes = [[1, 1, 1248, 128], [1, 3, 384, 384], [3, 1, 384, 384],
     79                   [1, 1, 2048, 192], [3, 3, 448, 384], [1, 1, 2048, 320],
     80                   [1, 1, 2048, 448], [1, 1, 2048, 384], [1, 1, 1760, 384],
     81                   [1, 1, 1760, 192], [1, 1, 1760, 448], [1, 1, 1760, 320],
     82                   [3, 3, 192, 192], [3, 3, 192, 192], [1, 1, 1248, 192],
     83                   [3, 3, 128, 320], [1, 1, 1248, 128], [1, 3, 224, 224],
     84                   [3, 1, 192, 256], [1, 3, 192, 256], [1, 1, 1216, 192],
     85                   [1, 1, 1216, 96], [3, 1, 224, 224], [3, 3, 192, 224],
     86                   [1, 3, 192, 192], [1, 1, 1152, 192], [1, 1, 1152, 128],
     87                   [3, 1, 192, 192], [3, 3, 160, 192], [1, 1, 1152, 160],
     88                   [1, 1, 1024, 128], [1, 3, 128, 192], [1, 1, 1024, 160],
     89                   [3, 1, 128, 192], [1, 1, 1024, 256], [3, 1, 128, 128],
     90                   [1, 1, 768, 192], [1, 3, 128, 128], [3, 3, 128, 128],
     91                   [1, 1, 768, 128], [1, 1, 768, 320], [3, 3, 96, 96],
     92                   [3, 3, 288, 384], [3, 3, 64, 96], [1, 1, 288, 64],
     93                   [1, 1, 256, 64], [5, 5, 48, 64], [1, 1, 256, 48],
     94                   [3, 3, 96, 96], [1, 1, 192, 32], [1, 1, 192, 64],
     95                   [1, 1, 192, 48], [3, 3, 64, 192], [1, 1, 64, 64],
     96                   [1, 1, 24, 64]]
     97   out_sizes = [[4, 5, 5, 128], [4, 8, 8, 384], [4, 8, 8, 384],
     98                [4, 8, 8, 192], [4, 8, 8, 384], [4, 8, 8, 320],
     99                [4, 8, 8, 448], [4, 8, 8, 384], [4, 8, 8, 384],
    100                [4, 8, 8, 192], [4, 8, 8, 448], [4, 8, 8, 320],
    101                [4, 8, 8, 192], [4, 17, 17, 192], [4, 17, 17, 192],
    102                [4, 8, 8, 320], [4, 17, 17, 128], [4, 17, 17, 224],
    103                [4, 17, 17, 256], [4, 17, 17, 256], [4, 17, 17, 192],
    104                [4, 17, 17, 96], [4, 17, 17, 224], [4, 17, 17, 224],
    105                [4, 17, 17, 192], [4, 17, 17, 192], [4, 17, 17, 128],
    106                [4, 17, 17, 192], [4, 17, 17, 192], [4, 17, 17, 160],
    107                [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 160],
    108                [4, 17, 17, 192], [4, 17, 17, 256], [4, 17, 17, 128],
    109                [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 128],
    110                [4, 17, 17, 128], [4, 17, 17, 320], [4, 17, 17, 96],
    111                [4, 17, 17, 384], [4, 35, 35, 96], [4, 35, 35, 64],
    112                [4, 35, 35, 64], [4, 35, 35, 64], [4, 35, 35, 48],
    113                [4, 35, 35, 96], [4, 35, 35, 32], [4, 35, 35, 64],
    114                [4, 35, 35, 48], [4, 71, 71, 192], [4, 73, 73, 64],
    115                [4, 147, 147, 64]]
    116   strides = [
    117       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    118       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1,
    119       1, 1, 1, 1, 1
    120   ]
    121   # Shrink sizes to make the test faster
    122   for i in input_sizes:
    123     i[3] //= shrink
    124   for f in filter_sizes:
    125     f[2] //= shrink
    126     f[3] //= shrink
    127   for o in out_sizes:
    128     o[3] //= shrink
    129   # pylint: disable=invalid-name
    130   VALID = "VALID"
    131   SAME = "SAME"
    132   # pylint: enable=invalid-name
    133   paddings = [
    134       SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
    135       VALID, SAME, SAME, VALID, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
    136       SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
    137       SAME, SAME, SAME, SAME, SAME, VALID, VALID, SAME, SAME, SAME, SAME, SAME,
    138       SAME, SAME, SAME, SAME, VALID, VALID, VALID
    139   ]
    140   for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
    141                            paddings):
    142     yield i, f, o, s, p
    143 
    144 
    145 def GetTestConfigs():
    146   """Get all the valid tests configs to run.
    147 
    148   Returns:
    149     all the valid test configs as tuples of data_format and use_gpu.
    150   """
    151   test_configs = [("NHWC", False), ("NHWC", True)]
    152   if test.is_gpu_available(cuda_only=True):
    153     # "NCHW" format is only supported on CUDA.
    154     test_configs += [("NCHW", True)]
    155   return test_configs
    156 
    157 
    158 class Conv2DTest(test.TestCase):
    159 
    160   def _DtypesToTest(self, use_gpu):
    161     if use_gpu and not test_util.CudaSupportsHalfMatMulAndConv():
    162       return [dtypes.float32]
    163     else:
    164       # It is important that float32 comes before float16 here,
    165       # as we will be using its gradients as reference for fp16 gradients.
    166       return [dtypes.float32, dtypes.float16]
    167 
    168   def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, dilations,
    169                             strides, padding, data_format, dtype, use_gpu):
    170     """Verifies the output values of the convolution function.
    171 
    172     Args:
    173       tensor_in_sizes: Input tensor dimensions in
    174         [batch, input_rows, input_cols, input_depth].
    175       filter_in_sizes: Filter tensor dimensions in
    176         [kernel_rows, kernel_cols, input_depth, output_depth].
    177       dilations: Dilated rate: [col_dilation, row_dilation]
    178       strides: Stride: [col_stride, row_stride]
    179       padding: Padding type.
    180       data_format: Format of the data tensors.
    181       dtype: Data type for inputs and outputs.
    182       use_gpu: True if the operations should be run on GPU
    183     Returns:
    184       Symbolic tensor value that can be used to execute the computation
    185     """
    186     total_size_1 = 1
    187     total_size_2 = 1
    188     for s in tensor_in_sizes:
    189       total_size_1 *= s
    190     for s in filter_in_sizes:
    191       total_size_2 *= s
    192     # Initializes the input tensor with array containing incrementing
    193     # numbers from 1.
    194     x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
    195     x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
    196 
    197     with test_util.device(use_gpu):
    198       t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
    199       t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
    200       strides = [1] + strides + [1]
    201       dilations = [1] + dilations + [1]
    202       if data_format == "NCHW":
    203         t1 = test_util.NHWCToNCHW(t1)
    204         strides = test_util.NHWCToNCHW(strides)
    205         dilations = test_util.NHWCToNCHW(dilations)
    206       conv = nn_ops.conv2d(
    207           t1,
    208           t2,
    209           dilations=dilations,
    210           strides=strides,
    211           padding=padding,
    212           data_format=data_format)
    213       if data_format == "NCHW":
    214         conv = test_util.NCHWToNHWC(conv)
    215 
    216       return conv
    217 
    218   def _CompareFwdValues(self, tensor_in_sizes, filter_in_sizes, conv_strides,
    219                         padding):
    220     """Verifies that CPU and GPU produce the same values.
    221 
    222     Args:
    223       tensor_in_sizes: Input tensor dimensions in
    224         [batch, input_rows, input_cols, input_depth].
    225       filter_in_sizes: Filter tensor dimensions in
    226         [kernel_rows, kernel_cols, input_depth, output_depth].
    227       conv_strides: [row_stride, col_stride] for the convolution;
    228       padding: Padding type.
    229     """
    230     x1 = np.random.rand(*tensor_in_sizes).astype(np.float32)
    231     x2 = np.random.rand(*filter_in_sizes).astype(np.float32)
    232 
    233     def _SetupVal(data_format, use_gpu):
    234       with test_util.device(use_gpu):
    235         t1 = constant_op.constant(x1, shape=tensor_in_sizes)
    236         t2 = constant_op.constant(x2, shape=filter_in_sizes)
    237         strides = [1] + conv_strides + [1]
    238         if data_format == "NCHW":
    239           t1 = test_util.NHWCToNCHW(t1)
    240           strides = test_util.NHWCToNCHW(strides)
    241         conv = nn_ops.conv2d(
    242             t1, t2, strides=strides, padding=padding, data_format=data_format)
    243         if data_format == "NCHW":
    244           conv = test_util.NCHWToNHWC(conv)
    245         return conv
    246 
    247     tensors = []
    248     for (data_format, use_gpu) in GetTestConfigs():
    249       tensors.append(_SetupVal(data_format, use_gpu))
    250     values = self.evaluate(tensors)
    251     for i in range(1, len(values)):
    252       self.assertAllClose(values[0], values[i], rtol=1e-5, atol=1e-5)
    253 
    254   def _ComputeReferenceDilatedConv(self, tensor_in_sizes, filter_in_sizes,
    255                                    stride, dilation, padding, data_format,
    256                                    use_gpu):
    257     total_size_1 = 1
    258     total_size_2 = 1
    259     for s in tensor_in_sizes:
    260       total_size_1 *= s
    261     for s in filter_in_sizes:
    262       total_size_2 *= s
    263 
    264     # Initializes the input tensor with array containing incrementing
    265     # numbers from 1.
    266     x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
    267     x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
    268     with test_util.device(use_gpu):
    269       t1 = constant_op.constant(x1, shape=tensor_in_sizes)
    270       t2 = constant_op.constant(x2, shape=filter_in_sizes)
    271       if isinstance(stride, collections.Iterable):
    272         strides = list(stride)
    273       else:
    274         strides = [stride, stride]
    275       if data_format == "NCHW":
    276         t1 = test_util.NHWCToNCHW(t1)
    277         full_strides = [1, 1] + strides
    278         full_dilation = [1, 1] + dilation
    279       else:
    280         full_strides = [1] + strides + [1]
    281         full_dilation = [1] + dilation + [1]
    282       expected = nn_ops.convolution(
    283           t1,
    284           t2,
    285           padding=padding,
    286           strides=strides,
    287           dilation_rate=dilation,
    288           data_format=data_format)
    289       computed = nn_ops.conv2d(
    290           t1,
    291           t2,
    292           strides=full_strides,
    293           dilations=full_dilation,
    294           padding=padding,
    295           data_format=data_format)
    296       if data_format == "NCHW":
    297         expected = test_util.NCHWToNHWC(expected)
    298         computed = test_util.NCHWToNHWC(computed)
    299     return expected, computed
    300 
    301   def _VerifyDilatedConvValues(self, tensor_in_sizes, filter_in_sizes, strides,
    302                                padding, dilations):
    303     expected_results = []
    304     computed_results = []
    305     for data_format, use_gpu in GetTestConfigs():
    306       expected, computed = self._ComputeReferenceDilatedConv(
    307           tensor_in_sizes, filter_in_sizes, strides, dilations, padding,
    308           data_format, use_gpu)
    309       expected_results.append(expected)
    310       computed_results.append(computed)
    311       tolerance = 1e-2 if use_gpu else 1e-5
    312       expected_values = self.evaluate(expected_results)
    313       computed_values = self.evaluate(computed_results)
    314       for e_value, c_value in zip(expected_values, computed_values):
    315         print("expected = ", e_value)
    316         print("actual = ", c_value)
    317         self.assertAllClose(
    318             e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4)
    319 
    320   def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding,
    321                     expected):
    322     tensors = []
    323     dilations = [1, 1]
    324     for (data_format, use_gpu) in GetTestConfigs():
    325       for dtype in self._DtypesToTest(use_gpu):
    326         result = self._SetupValuesForDevice(
    327             tensor_in_sizes,
    328             filter_in_sizes,
    329             dilations,
    330             strides,
    331             padding,
    332             data_format,
    333             dtype,
    334             use_gpu=use_gpu)
    335         tensors.append(result)
    336       values = self.evaluate(tensors)
    337       for i in range(len(tensors)):
    338         conv = tensors[i]
    339         value = values[i]
    340         print("expected = ", expected)
    341         print("actual = ", value)
    342         tol = 1e-5
    343         if value.dtype == np.float16:
    344           tol = 1e-3
    345         self.assertAllClose(expected, np.ravel(value), atol=tol, rtol=tol)
    346         self.assertShapeEqual(value, conv)
    347 
    348   @test_util.run_in_graph_and_eager_modes()
    349   def testConv2D1x1Filter(self):
    350     expected_output = [
    351         30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0,
    352         204.0, 174.0, 216.0, 258.0, 210.0, 261.0, 312.0
    353     ]
    354     self._VerifyValues(
    355         tensor_in_sizes=[1, 2, 3, 3],
    356         filter_in_sizes=[1, 1, 3, 3],
    357         strides=[1, 1],
    358         padding="VALID",
    359         expected=expected_output)
    360 
    361   @test_util.run_in_graph_and_eager_modes()
    362   def testConv2D2x2Filter2x1Dilation(self):
    363     self._VerifyDilatedConvValues(
    364         tensor_in_sizes=[1, 4, 4, 1],
    365         filter_in_sizes=[2, 2, 1, 1],
    366         strides=[1, 1],
    367         dilations=[2, 1],
    368         padding="VALID")
    369 
    370   @test_util.run_in_graph_and_eager_modes()
    371   def testConv2DEmpty(self):
    372     expected_output = []
    373     self._VerifyValues(
    374         tensor_in_sizes=[0, 2, 3, 3],
    375         filter_in_sizes=[1, 1, 3, 3],
    376         strides=[1, 1],
    377         padding="VALID",
    378         expected=expected_output)
    379 
    380   @test_util.run_in_graph_and_eager_modes()
    381   def testConv2DEmptyDilation(self):
    382     self._VerifyDilatedConvValues(
    383         tensor_in_sizes=[0, 2, 3, 3],
    384         filter_in_sizes=[1, 1, 3, 3],
    385         strides=[1, 1],
    386         dilations=[2, 1],
    387         padding="VALID")
    388 
    389   @test_util.run_in_graph_and_eager_modes()
    390   def testConv2D2x2Filter(self):
    391     # The outputs are computed using third_party/py/IPython/notebook.
    392     expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0]
    393     self._VerifyValues(
    394         tensor_in_sizes=[1, 2, 3, 3],
    395         filter_in_sizes=[2, 2, 3, 3],
    396         strides=[1, 1],
    397         padding="VALID",
    398         expected=expected_output)
    399 
    400   @test_util.run_in_graph_and_eager_modes()
    401   def testConv2D2x2FilterDilation(self):
    402     self._VerifyDilatedConvValues(
    403         tensor_in_sizes=[1, 2, 3, 3],
    404         filter_in_sizes=[2, 2, 3, 3],
    405         strides=[1, 1],
    406         dilations=[1, 2],
    407         padding="VALID")
    408 
    409   @test_util.run_in_graph_and_eager_modes()
    410   def testConv2D1x2Filter(self):
    411     # The outputs are computed using third_party/py/IPython/notebook.
    412     expected_output = [
    413         231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0, 765.0, 840.0, 843.0,
    414         936.0, 1029.0
    415     ]
    416     self._VerifyValues(
    417         tensor_in_sizes=[1, 2, 3, 3],
    418         filter_in_sizes=[1, 2, 3, 3],
    419         strides=[1, 1],
    420         padding="VALID",
    421         expected=expected_output)
    422 
    423   @test_util.run_in_graph_and_eager_modes()
    424   def testConv2D1x2FilterDilation(self):
    425     self._VerifyDilatedConvValues(
    426         tensor_in_sizes=[1, 2, 3, 3],
    427         filter_in_sizes=[1, 2, 3, 3],
    428         strides=[1, 1],
    429         dilations=[2, 1],
    430         padding="VALID")
    431 
    432   @test_util.run_in_graph_and_eager_modes()
    433   def testConv2D2x2FilterStride2(self):
    434     expected_output = [2271.0, 2367.0, 2463.0]
    435     self._VerifyValues(
    436         tensor_in_sizes=[1, 2, 3, 3],
    437         filter_in_sizes=[2, 2, 3, 3],
    438         strides=[2, 2],
    439         padding="VALID",
    440         expected=expected_output)
    441 
    442   @test_util.run_in_graph_and_eager_modes()
    443   def testConv2D2x2FilterStride2Same(self):
    444     expected_output = [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0]
    445     self._VerifyValues(
    446         tensor_in_sizes=[1, 2, 3, 3],
    447         filter_in_sizes=[2, 2, 3, 3],
    448         strides=[2, 2],
    449         padding="SAME",
    450         expected=expected_output)
    451 
    452   @test_util.run_in_graph_and_eager_modes()
    453   def testConv2D2x2FilterStride1x2(self):
    454     expected_output = [58.0, 78.0, 98.0, 118.0, 138.0, 158.0]
    455     self._VerifyValues(
    456         tensor_in_sizes=[1, 3, 6, 1],
    457         filter_in_sizes=[2, 2, 1, 1],
    458         strides=[1, 2],
    459         padding="VALID",
    460         expected=expected_output)
    461 
    462   @test_util.run_in_graph_and_eager_modes()
    463   def testConv2DKernelSmallerThanStrideValid(self):
    464     expected_output = [65, 95, 275, 305]
    465     self._VerifyValues(
    466         tensor_in_sizes=[1, 7, 7, 1],
    467         filter_in_sizes=[2, 2, 1, 1],
    468         strides=[3, 3],
    469         padding="VALID",
    470         expected=expected_output)
    471 
    472   @test_util.run_in_graph_and_eager_modes()
    473   def testConv2DKernelSmallerThanStrideSame(self):
    474     self._VerifyValues(
    475         tensor_in_sizes=[1, 3, 3, 1],
    476         filter_in_sizes=[1, 1, 1, 1],
    477         strides=[2, 2],
    478         padding="SAME",
    479         expected=[1, 3, 7, 9])
    480 
    481     self._VerifyValues(
    482         tensor_in_sizes=[1, 4, 4, 1],
    483         filter_in_sizes=[1, 1, 1, 1],
    484         strides=[2, 2],
    485         padding="SAME",
    486         expected=[1, 3, 9, 11])
    487 
    488     self._VerifyValues(
    489         tensor_in_sizes=[1, 4, 4, 1],
    490         filter_in_sizes=[2, 2, 1, 1],
    491         strides=[3, 3],
    492         padding="SAME",
    493         expected=[44, 28, 41, 16])
    494 
    495   @test_util.run_in_graph_and_eager_modes()
    496   def testConv2DKernelSizeMatchesInputSize(self):
    497     self._VerifyValues(
    498         tensor_in_sizes=[1, 2, 2, 1],
    499         filter_in_sizes=[2, 2, 1, 2],
    500         strides=[1, 1],
    501         padding="VALID",
    502         expected=[50, 60])
    503 
    504   @test_util.run_in_graph_and_eager_modes()
    505   def testConv2DKernelSizeMatchesInputSizeDilation(self):
    506     self._VerifyDilatedConvValues(
    507         tensor_in_sizes=[1, 3, 3, 1],
    508         filter_in_sizes=[2, 2, 1, 2],
    509         strides=[1, 1],
    510         dilations=[2, 2],
    511         padding="VALID")
    512 
    513   # TODO(yzhwang): this currently fails.
    514   # self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],
    515   #                   filter_in_sizes=[2, 2, 1, 1],
    516   #                   strides=[4, 4], padding="SAME",
    517   #                   expected=[72, 112, 392, 432])
    518 
    519   # Testing for backprops
    520   def _RunAndVerifyBackpropInput(self, input_sizes, filter_sizes, output_sizes,
    521                                  strides, padding, expected, data_format,
    522                                  use_gpu, err):
    523     total_output_size = 1
    524     total_filter_size = 1
    525     for s in output_sizes:
    526       total_output_size *= s
    527     for s in filter_sizes:
    528       total_filter_size *= s
    529     # Initializes the input tensor with array containing incrementing
    530     # numbers from 1.
    531     x1 = [f * 1.0 for f in range(1, total_filter_size + 1)]
    532     x2 = [f * 1.0 for f in range(1, total_output_size + 1)]
    533     with test_util.device(use_gpu):
    534       if data_format == "NCHW":
    535         input_sizes = test_util.NHWCToNCHW(input_sizes)
    536       t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
    537       t1 = constant_op.constant(x1, shape=filter_sizes)
    538       t2 = constant_op.constant(x2, shape=output_sizes)
    539       strides = [1] + strides + [1]
    540       if data_format == "NCHW":
    541         t2 = test_util.NHWCToNCHW(t2)
    542         strides = test_util.NHWCToNCHW(strides)
    543       conv = nn_ops.conv2d_backprop_input(
    544           t0, t1, t2, strides=strides, padding=padding, data_format=data_format)
    545       if data_format == "NCHW":
    546         conv = test_util.NCHWToNHWC(conv)
    547       # "values" consists of two tensors for two backprops
    548       value = self.evaluate(conv)
    549       self.assertShapeEqual(value, conv)
    550     print("expected = ", expected)
    551     print("actual = ", value)
    552     self.assertArrayNear(expected, value.flatten(), err)
    553 
    554   def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes,
    555                             conv_strides, padding):
    556     x1 = np.random.rand(*filter_sizes).astype(np.float32)
    557     x2 = np.random.rand(*output_sizes).astype(np.float32)
    558 
    559     def _GetVal(data_format, use_gpu):
    560       with test_util.device(use_gpu):
    561         if data_format == "NCHW":
    562           new_input_sizes = test_util.NHWCToNCHW(input_sizes)
    563         else:
    564           new_input_sizes = input_sizes
    565         t0 = constant_op.constant(new_input_sizes, shape=[len(new_input_sizes)])
    566         t1 = constant_op.constant(x1, shape=filter_sizes)
    567         t2 = constant_op.constant(x2, shape=output_sizes)
    568         strides = [1] + conv_strides + [1]
    569         if data_format == "NCHW":
    570           t2 = test_util.NHWCToNCHW(t2)
    571           strides = test_util.NHWCToNCHW(strides)
    572         conv = nn_ops.conv2d_backprop_input(
    573             t0,
    574             t1,
    575             t2,
    576             strides=strides,
    577             padding=padding,
    578             data_format=data_format)
    579         if data_format == "NCHW":
    580           conv = test_util.NCHWToNHWC(conv)
    581         ret = self.evaluate(conv)
    582         self.assertShapeEqual(ret, conv)
    583         return ret
    584 
    585     values = []
    586     for (data_format, use_gpu) in GetTestConfigs():
    587       values.append(_GetVal(data_format, use_gpu))
    588 
    589     for i in range(1, len(values)):
    590       self.assertAllClose(values[0], values[i], rtol=1e-4, atol=1e-4)
    591 
    592   @test_util.run_in_graph_and_eager_modes()
    593   def testConv2D2x2Depth1ValidBackpropInput(self):
    594     expected_output = [1.0, 4.0, 4.0, 3.0, 10.0, 8.0]
    595     for (data_format, use_gpu) in GetTestConfigs():
    596       self._RunAndVerifyBackpropInput(
    597           input_sizes=[1, 2, 3, 1],
    598           filter_sizes=[2, 2, 1, 1],
    599           output_sizes=[1, 1, 2, 1],
    600           strides=[1, 1],
    601           padding="VALID",
    602           expected=expected_output,
    603           data_format=data_format,
    604           use_gpu=use_gpu,
    605           err=1e-5)
    606 
    607   @test_util.run_in_graph_and_eager_modes()
    608   def testConv2DEmptyBackpropInput(self):
    609     expected_output = []
    610     for (data_format, use_gpu) in GetTestConfigs():
    611       self._RunAndVerifyBackpropInput(
    612           input_sizes=[0, 2, 3, 1],
    613           filter_sizes=[2, 2, 1, 1],
    614           output_sizes=[0, 1, 2, 1],
    615           strides=[1, 1],
    616           padding="VALID",
    617           expected=expected_output,
    618           data_format=data_format,
    619           use_gpu=use_gpu,
    620           err=1e-5)
    621 
    622   @test_util.run_in_graph_and_eager_modes()
    623   def testConv2D2x2Depth3ValidBackpropInput(self):
    624     expected_output = [
    625         14.0, 32.0, 50.0, 100.0, 163.0, 226.0, 167.0, 212.0, 257.0, 122.0,
    626         140.0, 158.0, 478.0, 541.0, 604.0, 437.0, 482.0, 527.0
    627     ]
    628     for (data_format, use_gpu) in GetTestConfigs():
    629       # The GPU version of this test is not very stable. So adjusting the
    630       # error threshold to 1e-4.
    631       self._RunAndVerifyBackpropInput(
    632           input_sizes=[1, 2, 3, 3],
    633           filter_sizes=[2, 2, 3, 3],
    634           output_sizes=[1, 1, 2, 3],
    635           strides=[1, 1],
    636           padding="VALID",
    637           expected=expected_output,
    638           data_format=data_format,
    639           use_gpu=use_gpu,
    640           err=1e-4)
    641 
    642   @test_util.run_in_graph_and_eager_modes()
    643   def testConv2D2x2Depth3ValidBackpropInputStride1x2(self):
    644     expected_output = [
    645         1.0, 2.0, 2.0, 4.0, 3.0, 6.0, 7.0, 12.0, 11.0, 18.0, 15.0, 24.0, 12.0,
    646         16.0, 15.0, 20.0, 18.0, 24.0
    647     ]
    648     for (data_format, use_gpu) in GetTestConfigs():
    649       self._RunAndVerifyBackpropInput(
    650           input_sizes=[1, 3, 6, 1],
    651           filter_sizes=[2, 2, 1, 1],
    652           output_sizes=[1, 2, 3, 1],
    653           strides=[1, 2],
    654           padding="VALID",
    655           expected=expected_output,
    656           data_format=data_format,
    657           use_gpu=use_gpu,
    658           err=1e-5)
    659 
    660   @test_util.run_in_graph_and_eager_modes()
    661   def testConv2DStrideTwoFilterOneSameBackpropInput(self):
    662     expected_output = [
    663         1.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 4.0, 0.0, 0.0, 0.0,
    664         0.0, 0.0
    665     ]
    666     for (data_format, use_gpu) in GetTestConfigs():
    667       self._RunAndVerifyBackpropInput(
    668           input_sizes=[1, 4, 4, 1],
    669           filter_sizes=[1, 1, 1, 1],
    670           output_sizes=[1, 2, 2, 1],
    671           strides=[2, 2],
    672           padding="SAME",
    673           expected=expected_output,
    674           data_format=data_format,
    675           use_gpu=use_gpu,
    676           err=1e-5)
    677 
    678   @test_util.run_in_graph_and_eager_modes()
    679   def testConv2DKernelSizeMatchesInputSizeBackpropInput(self):
    680     expected_output = [5.0, 11.0, 17.0, 23.0]
    681     for (data_format, use_gpu) in GetTestConfigs():
    682       self._RunAndVerifyBackpropInput(
    683           input_sizes=[1, 2, 2, 1],
    684           filter_sizes=[2, 2, 1, 2],
    685           output_sizes=[1, 1, 1, 2],
    686           strides=[1, 1],
    687           padding="VALID",
    688           expected=expected_output,
    689           data_format=data_format,
    690           use_gpu=use_gpu,
    691           err=1e-5)
    692 
    693   # Testing for backprops
    694   def _RunAndVerifyBackpropFilter(self, input_sizes, filter_sizes, output_sizes,
    695                                   strides, padding, expected, data_format,
    696                                   use_gpu):
    697     total_input_size = 1
    698     total_output_size = 1
    699     for s in input_sizes:
    700       total_input_size *= s
    701     for s in output_sizes:
    702       total_output_size *= s
    703     # Initializes the input tensor with array containing incrementing
    704     # numbers from 1.
    705     x0 = [f * 1.0 for f in range(1, total_input_size + 1)]
    706     x2 = [f * 1.0 for f in range(1, total_output_size + 1)]
    707     for dtype in self._DtypesToTest(use_gpu=use_gpu):
    708       with test_util.device(use_gpu):
    709         t0 = constant_op.constant(x0, shape=input_sizes, dtype=dtype)
    710         t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
    711         t2 = constant_op.constant(x2, shape=output_sizes, dtype=dtype)
    712         explicit_strides = [1] + strides + [1]
    713         if data_format == "NCHW":
    714           t0 = test_util.NHWCToNCHW(t0)
    715           t2 = test_util.NHWCToNCHW(t2)
    716           explicit_strides = test_util.NHWCToNCHW(explicit_strides)
    717         conv = nn_ops.conv2d_backprop_filter(
    718             t0,
    719             t1,
    720             t2,
    721             strides=explicit_strides,
    722             padding=padding,
    723             data_format=data_format)
    724         value = self.evaluate(conv)
    725         self.assertShapeEqual(value, conv)
    726       print("expected = ", expected)
    727       print("actual = ", value)
    728       self.assertArrayNear(expected, value.flatten(), 1e-5)
    729 
    730   def _CompareBackFilter(self, input_sizes, filter_sizes, output_sizes,
    731                          conv_strides, padding):
    732     x0 = np.random.rand(*input_sizes).astype(np.float32)
    733     x2 = np.random.rand(*output_sizes).astype(np.float32)
    734 
    735     def _GetVal(data_format, use_gpu):
    736       with test_util.device(use_gpu):
    737         t0 = constant_op.constant(x0, shape=input_sizes)
    738         t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
    739         t2 = constant_op.constant(x2, shape=output_sizes)
    740         strides = [1] + conv_strides + [1]
    741         if data_format == "NCHW":
    742           t0 = test_util.NHWCToNCHW(t0)
    743           t2 = test_util.NHWCToNCHW(t2)
    744           strides = test_util.NHWCToNCHW(strides)
    745         conv = nn_ops.conv2d_backprop_filter(
    746             t0,
    747             t1,
    748             t2,
    749             strides=strides,
    750             padding=padding,
    751             data_format=data_format)
    752         ret = self.evaluate(conv)
    753         self.assertShapeEqual(ret, conv)
    754         return ret
    755 
    756     values = []
    757     for (data_format, use_gpu) in GetTestConfigs():
    758       values.append(_GetVal(data_format, use_gpu))
    759     for i in range(1, len(values)):
    760       self.assertAllClose(values[0], values[i], rtol=1e-4, atol=1e-4)
    761 
    762   @test_util.run_in_graph_and_eager_modes()
    763   def testConv2D2x2Depth1ValidBackpropFilter(self):
    764     expected = [5.0, 8.0, 14.0, 17.0]
    765     for (data_format, use_gpu) in GetTestConfigs():
    766       self._RunAndVerifyBackpropFilter(
    767           input_sizes=[1, 2, 3, 1],
    768           filter_sizes=[2, 2, 1, 1],
    769           output_sizes=[1, 1, 2, 1],
    770           strides=[1, 1],
    771           padding="VALID",
    772           expected=expected,
    773           data_format=data_format,
    774           use_gpu=use_gpu)
    775 
    776   @test_util.run_in_graph_and_eager_modes()
    777   def testConv2DEmptyBackpropFilter(self):
    778     expected = []
    779     for (data_format, use_gpu) in GetTestConfigs():
    780       self._RunAndVerifyBackpropFilter(
    781           input_sizes=[1, 2, 3, 1],
    782           filter_sizes=[2, 2, 1, 0],
    783           output_sizes=[1, 1, 2, 0],
    784           strides=[1, 1],
    785           padding="VALID",
    786           expected=expected,
    787           data_format=data_format,
    788           use_gpu=use_gpu)
    789 
    790   @test_util.run_in_graph_and_eager_modes()
    791   def testConv2DBackpropFilterWithEmptyInput(self):
    792     expected = [0, 0, 0, 0]
    793     for (data_format, use_gpu) in GetTestConfigs():
    794       self._RunAndVerifyBackpropFilter(
    795           input_sizes=[0, 2, 3, 1],
    796           filter_sizes=[2, 2, 1, 1],
    797           output_sizes=[0, 1, 2, 1],
    798           strides=[1, 1],
    799           padding="VALID",
    800           expected=expected,
    801           data_format=data_format,
    802           use_gpu=use_gpu)
    803 
    804   @test_util.run_in_graph_and_eager_modes()
    805   def testConv2D2x2Depth3ValidBackpropFilter(self):
    806     expected = [
    807         17.0, 22.0, 27.0, 22.0, 29.0, 36.0, 27.0, 36.0, 45.0, 32.0, 43.0, 54.0,
    808         37.0, 50.0, 63.0, 42.0, 57.0, 72.0, 62.0, 85.0, 108.0, 67.0, 92.0,
    809         117.0, 72.0, 99.0, 126.0, 77.0, 106.0, 135.0, 82.0, 113.0, 144.0, 87.0,
    810         120.0, 153.0
    811     ]
    812     for (data_format, use_gpu) in GetTestConfigs():
    813       self._RunAndVerifyBackpropFilter(
    814           input_sizes=[1, 2, 3, 3],
    815           filter_sizes=[2, 2, 3, 3],
    816           output_sizes=[1, 1, 2, 3],
    817           strides=[1, 1],
    818           padding="VALID",
    819           expected=expected,
    820           data_format=data_format,
    821           use_gpu=use_gpu)
    822 
    823   @test_util.run_in_graph_and_eager_modes()
    824   def testConv2D2x2Depth3ValidBackpropFilterStride1x2(self):
    825     expected = [161.0, 182.0, 287.0, 308.0]
    826     for (data_format, use_gpu) in GetTestConfigs():
    827       self._RunAndVerifyBackpropFilter(
    828           input_sizes=[1, 3, 6, 1],
    829           filter_sizes=[2, 2, 1, 1],
    830           output_sizes=[1, 2, 3, 1],
    831           strides=[1, 2],
    832           padding="VALID",
    833           expected=expected,
    834           data_format=data_format,
    835           use_gpu=use_gpu)
    836 
    837   @test_util.run_in_graph_and_eager_modes()
    838   def testConv2DStrideTwoFilterOneSameBackpropFilter(self):
    839     expected_output = [78.]
    840     for (data_format, use_gpu) in GetTestConfigs():
    841       self._RunAndVerifyBackpropFilter(
    842           input_sizes=[1, 4, 4, 1],
    843           filter_sizes=[1, 1, 1, 1],
    844           output_sizes=[1, 2, 2, 1],
    845           strides=[2, 2],
    846           padding="SAME",
    847           expected=expected_output,
    848           data_format=data_format,
    849           use_gpu=use_gpu)
    850 
    851   @test_util.run_in_graph_and_eager_modes()
    852   def testConv2DKernelSizeMatchesInputSizeBackpropFilter(self):
    853     expected_output = [1.0, 2.0, 2.0, 4.0, 3.0, 6.0, 4.0, 8.0]
    854     for (data_format, use_gpu) in GetTestConfigs():
    855       self._RunAndVerifyBackpropFilter(
    856           input_sizes=[1, 2, 2, 1],
    857           filter_sizes=[2, 2, 1, 2],
    858           output_sizes=[1, 1, 1, 2],
    859           strides=[1, 1],
    860           padding="VALID",
    861           expected=expected_output,
    862           data_format=data_format,
    863           use_gpu=use_gpu)
    864 
    865   # Testing for backprops
    866   def _RunAndVerifyBackpropInputDilation(self, input_sizes, filter_sizes,
    867                                          output_sizes, strides, dilations,
    868                                          padding, data_format, use_gpu, err):
    869     total_input_size = 1
    870     total_filter_size = 1
    871     for s in input_sizes:
    872       total_input_size *= s
    873     for s in filter_sizes:
    874       total_filter_size *= s
    875     # Initializes the input tensor with array containing incrementing
    876     # numbers from 1.
    877     x1 = [f * 1.0 for f in range(1, total_input_size + 1)]
    878     x2 = [f * 1.0 for f in range(1, total_filter_size + 1)]
    879     default_dilations = (dilations[0] == 1 and dilations[1] == 1)
    880     if default_dilations or use_gpu:
    881       with self.test_session(use_gpu=use_gpu) as sess:
    882         if data_format == "NCHW":
    883           input_sizes = test_util.NHWCToNCHW(input_sizes)
    884         t1 = constant_op.constant(x1, shape=input_sizes)
    885         t2 = constant_op.constant(x2, shape=filter_sizes)
    886         full_strides = [1] + strides + [1]
    887         full_dilations = [1] + dilations + [1]
    888         if data_format == "NCHW":
    889           full_strides = test_util.NHWCToNCHW(full_strides)
    890           full_dilations = test_util.NHWCToNCHW(full_dilations)
    891         conv_forward = nn_ops.conv2d(
    892             t1,
    893             t2,
    894             strides=full_strides,
    895             dilations=full_dilations,
    896             padding=padding,
    897             data_format=data_format)
    898         conv_forward_2 = nn_ops.convolution(
    899             t1,
    900             t2,
    901             padding=padding,
    902             strides=strides,
    903             dilation_rate=dilations,
    904             data_format=data_format)
    905         if data_format == "NCHW":
    906           conv_forward = test_util.NCHWToNHWC(conv_forward)
    907           conv_forward_2 = test_util.NCHWToNHWC(conv_forward_2)
    908         conv = gradients_impl.gradients(conv_forward, t1)[0]
    909         conv_2 = gradients_impl.gradients(conv_forward_2, t1)[0]
    910         # "values" consists of two tensors for two backprops
    911         value = sess.run(conv)
    912         value_2 = sess.run(conv_2)
    913         self.assertShapeEqual(value, conv)
    914         self.assertShapeEqual(value_2, conv_2)
    915       print("expected = ", value_2)
    916       print("actual = ", value)
    917       self.assertArrayNear(value_2.flatten(), value.flatten(), err)
    918 
    919   # Testing for backprops
    920   def _RunAndVerifyBackpropFilterDilation(self, input_sizes, filter_sizes,
    921                                           output_sizes, strides, dilations,
    922                                           padding, data_format, use_gpu, err):
    923     total_input_size = 1
    924     total_filter_size = 1
    925     for s in input_sizes:
    926       total_input_size *= s
    927     for s in filter_sizes:
    928       total_filter_size *= s
    929     # Initializes the input tensor with array containing incrementing
    930     # numbers from 1.
    931     x1 = [f * 1.0 for f in range(1, total_input_size + 1)]
    932     x2 = [f * 1.0 for f in range(1, total_filter_size + 1)]
    933     default_dilations = (dilations[0] == 1 and dilations[1] == 1)
    934     if default_dilations or use_gpu:
    935       with self.test_session(use_gpu=use_gpu) as sess:
    936         if data_format == "NCHW":
    937           input_sizes = test_util.NHWCToNCHW(input_sizes)
    938         t1 = constant_op.constant(x1, shape=input_sizes)
    939         t2 = constant_op.constant(x2, shape=filter_sizes)
    940         full_strides = [1] + strides + [1]
    941         full_dilations = [1] + dilations + [1]
    942         if data_format == "NCHW":
    943           full_strides = test_util.NHWCToNCHW(full_strides)
    944           full_dilations = test_util.NHWCToNCHW(full_dilations)
    945         conv_forward = nn_ops.conv2d(
    946             t1,
    947             t2,
    948             strides=full_strides,
    949             dilations=full_dilations,
    950             padding=padding,
    951             data_format=data_format)
    952         conv_forward_2 = nn_ops.convolution(
    953             t1,
    954             t2,
    955             padding=padding,
    956             strides=strides,
    957             dilation_rate=dilations,
    958             data_format=data_format)
    959         if data_format == "NCHW":
    960           conv_forward = test_util.NCHWToNHWC(conv_forward)
    961           conv_forward_2 = test_util.NCHWToNHWC(conv_forward_2)
    962         conv = gradients_impl.gradients(conv_forward, t2)[0]
    963         conv_2 = gradients_impl.gradients(conv_forward, t2)[0]
    964         value = sess.run(conv)
    965         value_2 = sess.run(conv_2)
    966         self.assertShapeEqual(value, conv)
    967         self.assertShapeEqual(value_2, conv_2)
    968       print("expected = ", value_2)
    969       print("actual = ", value)
    970       self.assertArrayNear(value_2.flatten(), value.flatten(), err)
    971 
    972   def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self):
    973     if test.is_gpu_available(cuda_only=True):
    974       for (data_format, use_gpu) in GetTestConfigs():
    975         self._RunAndVerifyBackpropFilterDilation(
    976             input_sizes=[1, 3, 6, 1],
    977             filter_sizes=[2, 2, 1, 1],
    978             output_sizes=[1, 1, 5, 1],
    979             strides=[1, 1],
    980             dilations=[2, 1],
    981             padding="VALID",
    982             data_format=data_format,
    983             use_gpu=use_gpu,
    984             err=1e-5)
    985 
    986   def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self):
    987     if test.is_gpu_available(cuda_only=True):
    988       for (data_format, use_gpu) in GetTestConfigs():
    989         self._RunAndVerifyBackpropFilterDilation(
    990             input_sizes=[1, 2, 3, 1],
    991             filter_sizes=[2, 2, 1, 1],
    992             output_sizes=[1, 1, 2, 1],
    993             strides=[1, 1],
    994             dilations=[1, 2],
    995             padding="VALID",
    996             data_format=data_format,
    997             use_gpu=use_gpu,
    998             err=1e-5)
    999 
   1000   def testConv2DEmptyBackpropFilterDilation1x2(self):
   1001     if test.is_gpu_available(cuda_only=True):
   1002       for (data_format, use_gpu) in GetTestConfigs():
   1003         self._RunAndVerifyBackpropFilterDilation(
   1004             input_sizes=[1, 2, 3, 1],
   1005             filter_sizes=[2, 2, 1, 0],
   1006             output_sizes=[1, 1, 2, 0],
   1007             strides=[1, 1],
   1008             dilations=[1, 2],
   1009             padding="VALID",
   1010             data_format=data_format,
   1011             use_gpu=use_gpu,
   1012             err=1e-5)
   1013 
   1014   def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self):
   1015     if test.is_gpu_available(cuda_only=True):
   1016       for (data_format, use_gpu) in GetTestConfigs():
   1017         self._RunAndVerifyBackpropFilterDilation(
   1018             input_sizes=[1, 3, 4, 3],
   1019             filter_sizes=[2, 2, 3, 3],
   1020             output_sizes=[1, 1, 2, 3],
   1021             strides=[1, 1],
   1022             dilations=[2, 2],
   1023             padding="VALID",
   1024             data_format=data_format,
   1025             use_gpu=use_gpu,
   1026             err=1e-5)
   1027 
   1028   def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self):
   1029     if test.is_gpu_available(cuda_only=True):
   1030       for (data_format, use_gpu) in GetTestConfigs():
   1031         self._RunAndVerifyBackpropFilterDilation(
   1032             input_sizes=[1, 3, 3, 1],
   1033             filter_sizes=[2, 2, 1, 2],
   1034             output_sizes=[1, 1, 1, 2],
   1035             strides=[1, 1],
   1036             dilations=[2, 2],
   1037             padding="VALID",
   1038             data_format=data_format,
   1039             use_gpu=use_gpu,
   1040             err=1e-5)
   1041 
   1042   def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self):
   1043     if test.is_gpu_available(cuda_only=True):
   1044       for (data_format, use_gpu) in GetTestConfigs():
   1045         self._RunAndVerifyBackpropInputDilation(
   1046             input_sizes=[1, 3, 6, 1],
   1047             filter_sizes=[2, 2, 1, 1],
   1048             output_sizes=[1, 1, 5, 1],
   1049             strides=[1, 1],
   1050             dilations=[2, 1],
   1051             padding="VALID",
   1052             data_format=data_format,
   1053             use_gpu=use_gpu,
   1054             err=1e-5)
   1055 
   1056   def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self):
   1057     if test.is_gpu_available(cuda_only=True):
   1058       for (data_format, use_gpu) in GetTestConfigs():
   1059         self._RunAndVerifyBackpropInputDilation(
   1060             input_sizes=[1, 2, 3, 1],
   1061             filter_sizes=[2, 2, 1, 1],
   1062             output_sizes=[1, 1, 2, 1],
   1063             strides=[1, 1],
   1064             dilations=[1, 2],
   1065             padding="VALID",
   1066             data_format=data_format,
   1067             use_gpu=use_gpu,
   1068             err=1e-5)
   1069 
   1070   def testConv2DEmptyBackpropInputDilation1x2(self):
   1071     if test.is_gpu_available(cuda_only=True):
   1072       for (data_format, use_gpu) in GetTestConfigs():
   1073         self._RunAndVerifyBackpropInputDilation(
   1074             input_sizes=[0, 2, 3, 1],
   1075             filter_sizes=[2, 2, 1, 1],
   1076             output_sizes=[0, 1, 2, 1],
   1077             strides=[1, 1],
   1078             dilations=[1, 2],
   1079             padding="VALID",
   1080             data_format=data_format,
   1081             use_gpu=use_gpu,
   1082             err=1e-5)
   1083 
   1084   def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self):
   1085     if test.is_gpu_available(cuda_only=True):
   1086       for (data_format, use_gpu) in GetTestConfigs():
   1087         # The GPU version of this test is not very stable. So adjusting the
   1088         # error threshold to 1e-4.
   1089         self._RunAndVerifyBackpropInputDilation(
   1090             input_sizes=[1, 3, 2, 3],
   1091             filter_sizes=[2, 2, 3, 3],
   1092             output_sizes=[1, 1, 2, 3],
   1093             strides=[1, 1],
   1094             dilations=[2, 1],
   1095             padding="VALID",
   1096             data_format=data_format,
   1097             use_gpu=use_gpu,
   1098             err=1e-4)
   1099 
   1100   def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(self):
   1101     if test.is_gpu_available(cuda_only=True):
   1102       for (data_format, use_gpu) in GetTestConfigs():
   1103         self._RunAndVerifyBackpropInputDilation(
   1104             input_sizes=[1, 3, 3, 1],
   1105             filter_sizes=[2, 2, 1, 2],
   1106             output_sizes=[1, 1, 1, 2],
   1107             strides=[1, 1],
   1108             dilations=[2, 2],
   1109             padding="VALID",
   1110             data_format=data_format,
   1111             use_gpu=use_gpu,
   1112             err=1e-5)
   1113 
   1114   # Gradient checkers
   1115   def ConstructAndTestGradient(self, batch, input_rows, input_cols, filter_rows,
   1116                                filter_cols, in_depth, out_depth, stride_rows,
   1117                                stride_cols, padding, test_input, data_format,
   1118                                use_gpu):
   1119     input_shape = [batch, input_rows, input_cols, in_depth]
   1120     filter_shape = [filter_rows, filter_cols, in_depth, out_depth]
   1121     # TODO(yangke): re-factor the computation of output shape.
   1122     if padding == "VALID":
   1123       output_rows = (input_rows - filter_rows + stride_rows) // stride_rows
   1124       output_cols = (input_cols - filter_cols + stride_cols) // stride_cols
   1125     else:
   1126       output_rows = (input_rows + stride_rows - 1) // stride_rows
   1127       output_cols = (input_cols + stride_cols - 1) // stride_cols
   1128     output_shape = [batch, output_rows, output_cols, out_depth]
   1129     input_size = 1
   1130     for x in input_shape:
   1131       input_size *= x
   1132     filter_size = 1
   1133     for x in filter_shape:
   1134       filter_size *= x
   1135     input_data = [x * 1.0 / input_size for x in range(0, input_size)]
   1136     filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
   1137     # Conv2DGrad functions are not compiled for double due to
   1138     # a problem in the way Eigen's Conv2DGrad works for double.
   1139     # So we disable the DOUBLE path.  We should re-enable this
   1140     # when double support returns for CPU and/or GPU.
   1141     for dtype in self._DtypesToTest(use_gpu=use_gpu):
   1142       with self.test_session(use_gpu=use_gpu):
   1143         input_tensor = constant_op.constant(
   1144             input_data, shape=input_shape, dtype=dtype, name="input")
   1145         filter_tensor = constant_op.constant(
   1146             filter_data, shape=filter_shape, dtype=dtype, name="filter")
   1147         strides = [1, stride_rows, stride_cols, 1]
   1148         if data_format == "NCHW":
   1149           new_input_tensor = test_util.NHWCToNCHW(input_tensor)
   1150           strides = test_util.NHWCToNCHW(strides)
   1151         else:
   1152           new_input_tensor = input_tensor
   1153         conv = nn_ops.conv2d(
   1154             new_input_tensor,
   1155             filter_tensor,
   1156             strides,
   1157             padding,
   1158             data_format=data_format,
   1159             name="conv")
   1160         if data_format == "NCHW":
   1161           conv = test_util.NCHWToNHWC(conv)
   1162         self.assertEqual(output_shape, conv.get_shape())
   1163         if test_input:
   1164           jacob_t, jacob_n = gradient_checker.compute_gradient(input_tensor,
   1165                                                                input_shape,
   1166                                                                conv,
   1167                                                                output_shape)
   1168         else:
   1169           jacob_t, jacob_n = gradient_checker.compute_gradient(filter_tensor,
   1170                                                                filter_shape,
   1171                                                                conv,
   1172                                                                output_shape)
   1173         if dtype == dtypes.float32:
   1174           reference_jacob_t = jacob_t
   1175           err = np.fabs(jacob_t - jacob_n).max()
   1176         else:
   1177           # Compare fp16 theoretical gradients to fp32 theoretical gradients,
   1178           # since fp16 numerical gradients are too imprecise.
   1179           err = np.fabs(jacob_t - reference_jacob_t).max()
   1180 
   1181         print("conv_2d gradient error = ", err)
   1182         self.assertLess(err, 0.002)
   1183 
   1184   def testInputGradientValidPaddingStrideOne(self):
   1185     for (data_format, use_gpu) in GetTestConfigs():
   1186       self.ConstructAndTestGradient(
   1187           batch=2,
   1188           input_rows=5,
   1189           input_cols=4,
   1190           filter_rows=3,
   1191           filter_cols=3,
   1192           in_depth=2,
   1193           out_depth=3,
   1194           stride_rows=1,
   1195           stride_cols=1,
   1196           padding="VALID",
   1197           test_input=True,
   1198           data_format=data_format,
   1199           use_gpu=use_gpu)
   1200 
   1201   def testFilterGradientValidPaddingStrideOne(self):
   1202     for (data_format, use_gpu) in GetTestConfigs():
   1203       self.ConstructAndTestGradient(
   1204           batch=4,
   1205           input_rows=6,
   1206           input_cols=5,
   1207           filter_rows=2,
   1208           filter_cols=2,
   1209           in_depth=2,
   1210           out_depth=3,
   1211           stride_rows=1,
   1212           stride_cols=1,
   1213           padding="VALID",
   1214           test_input=False,
   1215           data_format=data_format,
   1216           use_gpu=use_gpu)
   1217 
   1218   def testInputGradientValidPaddingStrideTwo(self):
   1219     for (data_format, use_gpu) in GetTestConfigs():
   1220       self.ConstructAndTestGradient(
   1221           batch=2,
   1222           input_rows=4,
   1223           input_cols=5,
   1224           filter_rows=3,
   1225           filter_cols=3,
   1226           in_depth=2,
   1227           out_depth=3,
   1228           stride_rows=2,
   1229           stride_cols=2,
   1230           padding="VALID",
   1231           test_input=True,
   1232           data_format=data_format,
   1233           use_gpu=use_gpu)
   1234 
   1235   def testFilterGradientValidPaddingStrideTwo(self):
   1236     for (data_format, use_gpu) in GetTestConfigs():
   1237       self.ConstructAndTestGradient(
   1238           batch=4,
   1239           input_rows=6,
   1240           input_cols=5,
   1241           filter_rows=2,
   1242           filter_cols=2,
   1243           in_depth=2,
   1244           out_depth=3,
   1245           stride_rows=2,
   1246           stride_cols=2,
   1247           padding="VALID",
   1248           test_input=False,
   1249           data_format=data_format,
   1250           use_gpu=use_gpu)
   1251 
   1252   def testInputGradientValidPaddingStrideThree(self):
   1253     for (data_format, use_gpu) in GetTestConfigs():
   1254       self.ConstructAndTestGradient(
   1255           batch=2,
   1256           input_rows=7,
   1257           input_cols=6,
   1258           filter_rows=3,
   1259           filter_cols=3,
   1260           in_depth=4,
   1261           out_depth=5,
   1262           stride_rows=3,
   1263           stride_cols=3,
   1264           padding="VALID",
   1265           test_input=True,
   1266           data_format=data_format,
   1267           use_gpu=use_gpu)
   1268 
   1269   def testFilterGradientValidPaddingStrideThree(self):
   1270     for (data_format, use_gpu) in GetTestConfigs():
   1271       self.ConstructAndTestGradient(
   1272           batch=2,
   1273           input_rows=8,
   1274           input_cols=7,
   1275           filter_rows=4,
   1276           filter_cols=4,
   1277           in_depth=2,
   1278           out_depth=3,
   1279           stride_rows=3,
   1280           stride_cols=3,
   1281           padding="VALID",
   1282           test_input=False,
   1283           data_format=data_format,
   1284           use_gpu=use_gpu)
   1285 
   1286   def testInputGradientSamePaddingStrideOne(self):
   1287     for (data_format, use_gpu) in GetTestConfigs():
   1288       self.ConstructAndTestGradient(
   1289           batch=2,
   1290           input_rows=7,
   1291           input_cols=6,
   1292           filter_rows=3,
   1293           filter_cols=3,
   1294           in_depth=2,
   1295           out_depth=3,
   1296           stride_rows=1,
   1297           stride_cols=1,
   1298           padding="SAME",
   1299           test_input=True,
   1300           data_format=data_format,
   1301           use_gpu=use_gpu)
   1302 
   1303   def testFilterGradientSamePaddingStrideOne(self):
   1304     for (data_format, use_gpu) in GetTestConfigs():
   1305       self.ConstructAndTestGradient(
   1306           batch=4,
   1307           input_rows=6,
   1308           input_cols=5,
   1309           filter_rows=2,
   1310           filter_cols=2,
   1311           in_depth=2,
   1312           out_depth=3,
   1313           stride_rows=1,
   1314           stride_cols=1,
   1315           padding="SAME",
   1316           test_input=False,
   1317           data_format=data_format,
   1318           use_gpu=use_gpu)
   1319 
   1320   def testInputGradientSamePaddingStrideTwo(self):
   1321     for (data_format, use_gpu) in GetTestConfigs():
   1322       self.ConstructAndTestGradient(
   1323           batch=2,
   1324           input_rows=5,
   1325           input_cols=4,
   1326           filter_rows=3,
   1327           filter_cols=3,
   1328           in_depth=3,
   1329           out_depth=3,
   1330           stride_rows=2,
   1331           stride_cols=2,
   1332           padding="SAME",
   1333           test_input=True,
   1334           data_format=data_format,
   1335           use_gpu=use_gpu)
   1336 
   1337   def testFilterGradientSamePaddingStrideTwo(self):
   1338     for (data_format, use_gpu) in GetTestConfigs():
   1339       self.ConstructAndTestGradient(
   1340           batch=4,
   1341           input_rows=6,
   1342           input_cols=5,
   1343           filter_rows=2,
   1344           filter_cols=2,
   1345           in_depth=2,
   1346           out_depth=3,
   1347           stride_rows=2,
   1348           stride_cols=2,
   1349           padding="SAME",
   1350           test_input=False,
   1351           data_format=data_format,
   1352           use_gpu=use_gpu)
   1353 
   1354   def testInputGradientSamePaddingStrideThree(self):
   1355     for (data_format, use_gpu) in GetTestConfigs():
   1356       self.ConstructAndTestGradient(
   1357           batch=2,
   1358           input_rows=7,
   1359           input_cols=6,
   1360           filter_rows=3,
   1361           filter_cols=3,
   1362           in_depth=4,
   1363           out_depth=5,
   1364           stride_rows=3,
   1365           stride_cols=3,
   1366           padding="SAME",
   1367           test_input=True,
   1368           data_format=data_format,
   1369           use_gpu=use_gpu)
   1370 
   1371   def testFilterGradientSamePaddingStrideThree(self):
   1372     for (data_format, use_gpu) in GetTestConfigs():
   1373       self.ConstructAndTestGradient(
   1374           batch=2,
   1375           input_rows=8,
   1376           input_cols=7,
   1377           filter_rows=4,
   1378           filter_cols=4,
   1379           in_depth=2,
   1380           out_depth=3,
   1381           stride_rows=3,
   1382           stride_cols=3,
   1383           padding="SAME",
   1384           test_input=False,
   1385           data_format=data_format,
   1386           use_gpu=use_gpu)
   1387 
   1388   def testFilterGradientSamePaddingStride2x1(self):
   1389     for (data_format, use_gpu) in GetTestConfigs():
   1390       self.ConstructAndTestGradient(
   1391           batch=2,
   1392           input_rows=8,
   1393           input_cols=7,
   1394           filter_rows=4,
   1395           filter_cols=4,
   1396           in_depth=2,
   1397           out_depth=3,
   1398           stride_rows=2,
   1399           stride_cols=1,
   1400           padding="SAME",
   1401           test_input=False,
   1402           data_format=data_format,
   1403           use_gpu=use_gpu)
   1404 
   1405   def testInputGradientKernelSizeMatchesInputSize(self):
   1406     for (data_format, use_gpu) in GetTestConfigs():
   1407       self.ConstructAndTestGradient(
   1408           batch=2,
   1409           input_rows=4,
   1410           input_cols=3,
   1411           filter_rows=4,
   1412           filter_cols=3,
   1413           in_depth=2,
   1414           out_depth=3,
   1415           stride_rows=1,
   1416           stride_cols=1,
   1417           padding="VALID",
   1418           test_input=True,
   1419           data_format=data_format,
   1420           use_gpu=use_gpu)
   1421 
   1422   def testFilterGradientKernelSizeMatchesInputSize(self):
   1423     for (data_format, use_gpu) in GetTestConfigs():
   1424       self.ConstructAndTestGradient(
   1425           batch=2,
   1426           input_rows=4,
   1427           input_cols=3,
   1428           filter_rows=4,
   1429           filter_cols=3,
   1430           in_depth=2,
   1431           out_depth=3,
   1432           stride_rows=1,
   1433           stride_cols=1,
   1434           padding="VALID",
   1435           test_input=False,
   1436           data_format=data_format,
   1437           use_gpu=use_gpu)
   1438 
   1439   def testShapeFunctionEdgeCases(self):
   1440     # All shapes unknown.
   1441     c1 = nn_ops.conv2d(
   1442         array_ops.placeholder(dtypes.float32),
   1443         array_ops.placeholder(dtypes.float32),
   1444         strides=[1, 1, 1, 1],
   1445         padding="SAME")
   1446     self.assertEqual([None, None, None, None], c1.get_shape().as_list())
   1447 
   1448     # Incorrect input shape.
   1449     with self.assertRaises(ValueError):
   1450       nn_ops.conv2d(
   1451           array_ops.placeholder(
   1452               dtypes.float32, shape=[1, 3]),
   1453           array_ops.placeholder(dtypes.float32),
   1454           strides=[1, 1, 1, 1],
   1455           padding="SAME")
   1456 
   1457     # Incorrect filter shape.
   1458     with self.assertRaises(ValueError):
   1459       nn_ops.conv2d(
   1460           array_ops.placeholder(dtypes.float32),
   1461           array_ops.placeholder(
   1462               dtypes.float32, shape=[1, 3]),
   1463           strides=[1, 1, 1, 1],
   1464           padding="SAME")
   1465 
   1466     # Depth mismatch.
   1467     with self.assertRaises(ValueError):
   1468       nn_ops.conv2d(
   1469           array_ops.placeholder(
   1470               dtypes.float32, shape=[32, 20, 20, 3]),
   1471           array_ops.placeholder(
   1472               dtypes.float32, shape=[4, 4, 2, 2]),
   1473           strides=[1, 1, 1, 1],
   1474           padding="SAME")
   1475 
   1476   def testOpEdgeCases(self):
   1477     with self.test_session() as sess:
   1478       # Illegal strides.
   1479       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
   1480                                    "strides in the batch and depth"):
   1481         sess.run(
   1482             nn_ops.conv2d(
   1483                 array_ops.placeholder(dtypes.float32),
   1484                 array_ops.placeholder(dtypes.float32),
   1485                 strides=[2, 1, 1, 1],
   1486                 padding="SAME"))
   1487       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
   1488                                    "strides in the batch and depth"):
   1489         sess.run(
   1490             nn_ops.conv2d(
   1491                 array_ops.placeholder(dtypes.float32),
   1492                 array_ops.placeholder(dtypes.float32),
   1493                 strides=[1, 1, 1, 2],
   1494                 padding="SAME"))
   1495 
   1496       # Filter larger than input.
   1497       with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
   1498         sess.run(
   1499             nn_ops.conv2d(
   1500                 array_ops.placeholder(
   1501                     dtypes.float32, shape=[32, 20, 20, 3]),
   1502                 array_ops.placeholder(
   1503                     dtypes.float32, shape=[20, 21, 3, 2]),
   1504                 strides=[1, 1, 1, 1],
   1505                 padding="VALID"))
   1506       with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
   1507         sess.run(
   1508             nn_ops.conv2d(
   1509                 array_ops.placeholder(
   1510                     dtypes.float32, shape=[32, 20, 20, 3]),
   1511                 array_ops.placeholder(
   1512                     dtypes.float32, shape=[21, 20, 3, 2]),
   1513                 strides=[1, 1, 1, 1],
   1514                 padding="VALID"))
   1515 
   1516   def testCPUConv2DNCHWUnimplemented(self):
   1517     with self.test_session(use_gpu=False):
   1518       with self.assertRaisesRegexp(errors_impl.UnimplementedError,
   1519                                    "NHWC tensor format for now"):
   1520         conv = self._SetupValuesForDevice(
   1521             tensor_in_sizes=[1, 4, 4, 1],
   1522             filter_in_sizes=[2, 2, 1, 1],
   1523             dilations=[1, 1],
   1524             strides=[1, 1],
   1525             padding="VALID",
   1526             data_format="NCHW",
   1527             dtype=dtypes.float32,
   1528             use_gpu=False)
   1529         self.evaluate(conv)
   1530 
   1531 
   1532 class DepthwiseConv2DTest(test.TestCase):
   1533 
   1534   def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, stride, padding,
   1535                     expected):
   1536     """Verifies the output values of the convolution function.
   1537 
   1538     Args:
   1539       tensor_in_sizes: Input tensor dimensions in
   1540         [batch, input_rows, input_cols, input_depth].
   1541       filter_in_sizes: Filter tensor dimensions in
   1542         [filter_rows, filter_cols, input_depth, depth_multiplier].
   1543       stride: Stride.
   1544       padding: Padding type.
   1545       expected: An array containing the expected operation outputs.
   1546     """
   1547     total_size_1 = 1
   1548     total_size_2 = 1
   1549     for s in tensor_in_sizes:
   1550       total_size_1 *= s
   1551     for s in filter_in_sizes:
   1552       total_size_2 *= s
   1553     # Initializes the input tensor with array containing incrementing
   1554     # numbers from 1.
   1555     x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
   1556     x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
   1557     with self.test_session() as sess:
   1558       t1 = constant_op.constant(x1, shape=tensor_in_sizes)
   1559       t1.set_shape(tensor_in_sizes)
   1560       t2 = constant_op.constant(x2, shape=filter_in_sizes)
   1561       conv = nn_impl.depthwise_conv2d(
   1562           t1, t2, strides=[1, stride, stride, 1], padding=padding)
   1563       value = sess.run(conv)
   1564     print("value = ", value)
   1565     self.assertArrayNear(expected, np.ravel(value), 1e-5)
   1566     self.assertShapeEqual(value, conv)
   1567 
   1568   def testConv2D2x2Filter(self):
   1569     # The inputs look like this (it's a 3 x 2 matrix, each of depth 2):
   1570     #
   1571     # [ (1.0, 2.0), (3.0,  4.0), ( 5.0,  6.0) ]
   1572     # [ (7.0, 8.0), (9.0, 10.0), (11.0, 12.0) ]
   1573     #  We can view this as two inputs
   1574     #
   1575     #  input depth 0:
   1576     #
   1577     #  [ 1.0,  3.0,  5.0 ]
   1578     #  [ 7.0,  9.0, 11.0 ]
   1579     #
   1580     #  input depth 1:
   1581     #
   1582     #  [ 2.0,  4.0,  6.0 ]
   1583     #  [ 8.0, 10.0, 12.0 ]
   1584     #
   1585     # The filter looks like this (it has two 2 x 2 patches, each generating 2
   1586     # depths):
   1587     #
   1588     #  filter #0:
   1589     #
   1590     #  [ (1.0,  3.0), ( 5.0,  7.0)]
   1591     #  [ (9.0, 11.0), (13.0, 15.0)]
   1592     #
   1593     #  filter #1:
   1594     #
   1595     #  [ ( 2.0,  4.0), ( 6.0,  8.0)]
   1596     #  [ (10.0, 12.0), (14.0, 16.0)]
   1597     #
   1598     # So the outputs are:
   1599     #
   1600     # (position 0, 0: in_depth 0, output_depth 0 -- using filter #0)
   1601     #  1.0 * 1.0 + 7.0 * 9.0 + 3.0 * 5.0 + 9.0 * 13.0 = 196
   1602     # (position 0, 0: in_depth 0, output_depth 1 -- using filter #1)
   1603     #  1.0 * 2.0 + 7.0 * 10.0 + 3.0 * 6.0 + 9.0 * 14.0 = 216
   1604     # (position 0, 0: in_depth 1, output_depth 2 -- using filter #0)
   1605     #  2.0 * 3.0 + 8.0 * 11.0 + 4.0 * 7.0 + 10.0 * 15.0 = 272
   1606     # (position 0, 0: in_depth 1, output_depth 3 -- using filter #1)
   1607     #  2.0 * 4.0 + 8.0 * 12.0 + 4.0 * 8.0 + 10.0 * 16.0 = 296
   1608     #
   1609     # (position 1, 0: in_depth 0, output_depth 0 -- using filter #0)
   1610     #  3.0 * 1.0 + 9.0 * 9.0 + 5.0 * 5.0 + 11.0 * 13.0 = 252
   1611     # (position 1, 0: in_depth 0, output_depth 1 -- using filter #1)
   1612     #  3.0 * 2.0 + 9.0 * 10.0 + 5.0 * 6.0 + 11.0 * 14.0 = 280
   1613     # (position 1, 0: in_depth 1, output_depth 2 -- using filter #0)
   1614     #  4.0 * 3.0 + 10.0 * 11.0 + 6.0 * 7.0 + 12.0 * 15.0 = 344
   1615     # (position 1, 0: in_depth 1, output_depth 3 -- using filter #1)
   1616     #  4.0 * 4.0 + 10.0 * 12.0 + 6.0 * 8.0 + 12.0 * 16.0 = 376
   1617     expected_output = [196, 216, 272, 296, 252, 280, 344, 376]
   1618     self._VerifyValues(
   1619         tensor_in_sizes=[1, 2, 3, 2],
   1620         filter_in_sizes=[2, 2, 2, 2],
   1621         stride=1,
   1622         padding="VALID",
   1623         expected=expected_output)
   1624 
   1625 
class SeparableConv2DTest(test.TestCase):
  """Value tests for nn_impl.separable_conv2d in NHWC and NCHW layouts."""

  def _InitValues(self, sizes):
    """Initializes values for input tensors.

    The tensor is filled, in flattened order, with 0.5, 1.0, 1.5, ... up to
    0.5 * (number of elements).

    Args:
      sizes: Tensor dimensions.

    Returns:
      Tensor initialized to values.
    """
    total_size = 1
    for s in sizes:
      total_size *= s
    x = [f * 0.5 for f in range(1, total_size + 1)]
    return constant_op.constant(x, shape=sizes)

  def _VerifyValues(self,
                    tensor_in_sizes,
                    depthwise_filter_in_sizes,
                    pointwise_filter_in_sizes,
                    stride,
                    padding,
                    expected,
                    data_format="NHWC"):
    """Verifies the output values of the separable convolution function.

    Inputs and filters are generated with _InitValues. For "NCHW" the input is
    transposed to channels-first before the op and the result is transposed
    back, so `expected` is always compared in NHWC order.

    Args:
      tensor_in_sizes: Input tensor dimensions.
      depthwise_filter_in_sizes: Depthwise filter tensor dimensions.
      pointwise_filter_in_sizes: Pointwise filter tensor dimensions.
      stride: Stride.
      padding: Padding type.
      expected: An array containing the expected operation outputs.
      data_format: string data format for input tensor.
    """
    with self.test_session(use_gpu=True) as sess:
      t1 = self._InitValues(tensor_in_sizes)
      f1 = self._InitValues(depthwise_filter_in_sizes)
      f1.set_shape(depthwise_filter_in_sizes)
      f2 = self._InitValues(pointwise_filter_in_sizes)

      # Defaults describe the NHWC layout; both are replaced for NCHW.
      real_t1 = t1
      strides = [1, stride, stride, 1]
      if data_format == "NCHW":
        # Move the channel axis from last to second position and shift the
        # spatial strides to the last two positions to match.
        real_t1 = array_ops.transpose(t1, [0, 3, 1, 2])
        strides = [1, 1, stride, stride]

      conv = nn_impl.separable_conv2d(
          real_t1,
          f1,
          f2,
          strides=strides,
          padding=padding,
          data_format=data_format)

      if data_format == "NCHW":
        # Back to channels-last so the values line up with `expected`.
        conv = array_ops.transpose(conv, [0, 2, 3, 1])

      value = sess.run(conv)
    print("value = ", value)
    self.assertArrayNear(expected, np.ravel(value), 1e-5)
    self.assertShapeEqual(value, conv)

  def _testSeparableConv2D(self, data_format):
    # The output is the result of two convolutions:
    # First with tensor_in[1, 4, 4, 2] * filter1[2, 2, 2, 3].
    # Second with intermediate_out[1, 4, 4, 6] * filter2[1, 1, 6, 7].
    # Complexity is O(2*3*2*2 + 6*7*1*1) as opposed to O(2*7*2*2).
    expected_output = [
        6644.5, 6971.5, 7298.5, 7625.5, 7952.5, 8279.5, 8606.5, 8154.5, 8556.5,
        8958.5, 9360.5, 9762.5, 10164.5, 10566.5, 9664.5, 10141.5, 10618.5,
        11095.5, 11572.5, 12049.5, 12526.5, 4145.5, 4346.5, 4547.5, 4748.5,
        4949.5, 5150.5, 5351.5, 12684.5, 13311.5, 13938.5, 14565.5, 15192.5,
        15819.5, 16446.5, 14194.5, 14896.5, 15598.5, 16300.5, 17002.5, 17704.5,
        18406.5, 15704.5, 16481.5, 17258.5, 18035.5, 18812.5, 19589.5, 20366.5,
        6499.5, 6814.5, 7129.5, 7444.5, 7759.5, 8074.5, 8389.5, 18724.5,
        19651.5, 20578.5, 21505.5, 22432.5, 23359.5, 24286.5, 20234.5, 21236.5,
        22238.5, 23240.5, 24242.5, 25244.5, 26246.5, 21744.5, 22821.5, 23898.5,
        24975.5, 26052.5, 27129.5, 28206.5, 8853.5, 9282.5, 9711.5, 10140.5,
        10569.5, 10998.5, 11427.5, 5746.75, 6010.75, 6274.75, 6538.75, 6802.75,
        7066.75, 7330.75, 6168.75, 6452.25, 6735.75, 7019.25, 7302.75, 7586.25,
        7869.75, 6590.75, 6893.75, 7196.75, 7499.75, 7802.75, 8105.75, 8408.75,
        2036.25, 2119.5, 2202.75, 2286.0, 2369.25, 2452.5, 2535.75
    ]

    self._VerifyValues(
        tensor_in_sizes=[1, 4, 4, 2],
        depthwise_filter_in_sizes=[2, 2, 2, 3],
        pointwise_filter_in_sizes=[1, 1, 6, 7],
        stride=1,
        padding="SAME",
        expected=expected_output,
        data_format=data_format)

  def testSeparableConv2D(self):
    self._testSeparableConv2D("NHWC")

  def testSeparableConv2DNCHW(self):
    # The NCHW variant is only run when a GPU is available; otherwise the test
    # returns early and is reported as passing.
    if not test.is_gpu_available():
      return
    self._testSeparableConv2D("NCHW")

  def _testSeparableConv2DEqualInputOutputDepth(self, data_format):
    # The output is the result of two convolutions:
    # First with tensor_in[1, 4, 4, 2] * filter1[2, 2, 2, 3].
    # Second with intermediate_out[1, 4, 4, 6] * filter2[1, 1, 6, 6].
    # Complexity is O(2*3*2*2 + 6*6*1*1) as opposed to O(2*6*2*2).
    expected_output = [
        5742.0, 6069.0, 6396.0, 6723.0, 7050.0, 7377.0, 7047.0, 7449.0, 7851.0,
        8253.0, 8655.0, 9057.0, 8352.0, 8829.0, 9306.0, 9783.0, 10260.0,
        10737.0, 3582.0, 3783.0, 3984.0, 4185.0, 4386.0, 4587.0, 10962.0,
        11589.0, 12216.0, 12843.0, 13470.0, 14097.0, 12267.0, 12969.0, 13671.0,
        14373.0, 15075.0, 15777.0, 13572.0, 14349.0, 15126.0, 15903.0, 16680.0,
        17457.0, 5616.0, 5931.0, 6246.0, 6561.0, 6876.0, 7191.0, 16182.0,
        17109.0, 18036.0, 18963.0, 19890.0, 20817.0, 17487.0, 18489.0, 19491.0,
        20493.0, 21495.0, 22497.0, 18792.0, 19869.0, 20946.0, 22023.0, 23100.0,
        24177.0, 7650.0, 8079.0, 8508.0, 8937.0, 9366.0, 9795.0, 4963.5, 5227.5,
        5491.5, 5755.5, 6019.5, 6283.5, 5328.0, 5611.5, 5895.0, 6178.5, 6462.0,
        6745.5, 5692.5, 5995.5, 6298.5, 6601.5, 6904.5, 7207.5, 1757.25, 1840.5,
        1923.75, 2007.0, 2090.25, 2173.5
    ]

    self._VerifyValues(
        tensor_in_sizes=[1, 4, 4, 2],
        depthwise_filter_in_sizes=[2, 2, 2, 3],
        pointwise_filter_in_sizes=[1, 1, 6, 6],
        stride=1,
        padding="SAME",
        expected=expected_output,
        data_format=data_format)

  def testSeparableConv2DEqualInputOutputDepth(self):
    self._testSeparableConv2DEqualInputOutputDepth("NHWC")

  def testSeparableConv2DEqualInputOutputDepthNCHW(self):
    # The NCHW variant is only run when a GPU is available; otherwise the test
    # returns early and is reported as passing.
    if not test.is_gpu_available():
      return
    self._testSeparableConv2DEqualInputOutputDepth("NCHW")
   1765 
   1766 
   1767 class DeepConv2DTest(test.TestCase):
   1768 
   1769   def _CompareFwdConv2D(self, tensor_in_sizes, filter_in_sizes, conv_strides,
   1770                         padding):
   1771     """Verifies that DeepConv2D and Conv2D produce the same values.
   1772 
   1773     Args:
   1774       tensor_in_sizes: Input tensor dimensions in
   1775         [batch, input_rows, input_cols, input_depth].
   1776       filter_in_sizes: Filter tensor dimensions in
   1777         [kernel_rows, kernel_cols, input_depth, output_depth].
   1778       conv_strides: [row_stride, col_stride] for the convolution;
   1779       padding: Padding type.
   1780     """
   1781     x1 = np.random.rand(*tensor_in_sizes).astype(np.float32)
   1782     x2 = np.random.rand(*filter_in_sizes).astype(np.float32)
   1783 
   1784     with self.test_session(use_gpu=False) as sess:
   1785       t1 = constant_op.constant(x1, shape=tensor_in_sizes)
   1786       t2 = constant_op.constant(x2, shape=filter_in_sizes)
   1787       strides = [1] + conv_strides + [1]
   1788 
   1789       conv = nn_ops.conv2d(t1, t2, strides=strides, padding=padding)
   1790 
   1791       os.environ["TF_USE_DEEP_CONV2D"] = "0"
   1792       values_expect = sess.run([conv])
   1793 
   1794       os.environ["TF_USE_DEEP_CONV2D"] = "1"
   1795       values_test = sess.run([conv])
   1796 
   1797       self.assertAllClose(values_expect, values_test, rtol=1e-5, atol=1e-5)
   1798 
   1799   def _RunTestCases(self, conv_strides, padding):
   1800     input_sizes = [[5, 5, 5, 1248], [3, 17, 17, 192], [2, 35, 35, 288],
   1801                    [2, 6, 8, 517], [2, 7, 4, 81], [3, 11, 3, 77]]
   1802     filter_sizes = [[3, 3, 1248, 128], [3, 3, 192, 192], [3, 3, 288, 384],
   1803                     [3, 3, 517, 64], [3, 3, 81, 77], [3, 3, 77, 181]]
   1804     for input_shape, filter_shape in zip(input_sizes, filter_sizes):
   1805       self._CompareFwdConv2D(input_shape, filter_shape, conv_strides, padding)
   1806 
   1807   def testConv2D3x3FilterStride1x1Valid(self):
   1808     self._RunTestCases([1, 1], "VALID")
   1809 
   1810   def testConv2D3x3FilterStride1x1Same(self):
   1811     self._RunTestCases([1, 1], "SAME")
   1812 
   1813 
   1814 class Conv2DBenchmark(test.Benchmark):
   1815 
   1816   def benchmarkGPUConvStackFirst(self):
   1817     # Benchmark the first iteration of a conv-net with many identical conv
   1818     # operations.
   1819     if not test.is_gpu_available():
   1820       return
   1821 
   1822     with ops.Graph().as_default(), session_lib.Session() as session:
   1823       batch_size = 1
   1824       timesteps = 600
   1825       features = 1
   1826 
   1827       inputs = random_ops.random_uniform(
   1828           [batch_size, 1, timesteps, features], seed=1234)
   1829       num_outputs_list = [512] * 40 + [1]
   1830       kernel_w = 3
   1831       x = inputs
   1832       for num_outputs in num_outputs_list:
   1833         x = layers.convolution2d(x, num_outputs, [1, kernel_w])
   1834       outputs = x
   1835 
   1836       variables.global_variables_initializer().run()
   1837       num_iterations = 4
   1838       for iter_index in xrange(num_iterations):
   1839         start = time.time()
   1840         session.run(outputs)
   1841         wall_time = time.time() - start
   1842         self.report_benchmark(
   1843             name="conv_stack_iter_%d" % iter_index, wall_time=wall_time)
   1844         print("conv_stack_iter_%d: %.4f" % (iter_index, wall_time))
   1845 
   1846 
   1847 def GetInceptionFwdTest(input_size, filter_size, stride, padding,
   1848                         gpu_only=False):
   1849 
   1850   def Test(self):
   1851     if gpu_only and not test.is_gpu_available():
   1852       tf_logging.info("Skipping InceptionFwd %s", (input_size, filter_size,
   1853                                                    stride, padding))
   1854       return
   1855     tf_logging.info("Testing InceptionFwd %s", (input_size, filter_size, stride,
   1856                                                 padding))
   1857     self._CompareFwdValues(input_size, filter_size, [stride, stride], padding)
   1858 
   1859   return Test
   1860 
   1861 
   1862 def GetInceptionFwdDilatedConvTest(input_size, filter_size, stride, padding):
   1863 
   1864   def Test(self):
   1865     if stride == 1:
   1866       tf_logging.info("Testing InceptionFwd with dilations %s",
   1867                       (input_size, filter_size, stride, padding))
   1868       self._VerifyDilatedConvValues(
   1869           tensor_in_sizes=input_size,
   1870           filter_in_sizes=filter_size,
   1871           strides=[stride, stride],
   1872           dilations=[2, 2],
   1873           padding=padding)
   1874 
   1875   return Test
   1876 
   1877 
   1878 def GetInceptionBackInputTest(input_size, filter_size, output_size, stride,
   1879                               padding,
   1880                               gpu_only=False):
   1881 
   1882   def Test(self):
   1883     if gpu_only and not test.is_gpu_available():
   1884       tf_logging.info("Skipping InceptionBackInput %s",
   1885                       (input_size, filter_size, output_size, stride, padding))
   1886       return
   1887     tf_logging.info("Testing InceptionBackInput %s",
   1888                     (input_size, filter_size, output_size, stride, padding))
   1889     self._CompareBackpropInput(input_size, filter_size, output_size,
   1890                                [stride, stride], padding)
   1891 
   1892   return Test
   1893 
   1894 
   1895 def GetInceptionBackFilterTest(input_size, filter_size, output_size, strides,
   1896                                padding, gpu_only=False):
   1897 
   1898   def Test(self):
   1899     if gpu_only and not test.is_gpu_available():
   1900       tf_logging.info("Skipping InceptionBackFilter %s",
   1901                       (input_size, filter_size, output_size, strides, padding))
   1902       return
   1903     tf_logging.info("Testing InceptionBackFilter %s",
   1904                     (input_size, filter_size, output_size, strides, padding))
   1905     self._CompareBackFilter(input_size, filter_size, output_size, strides,
   1906                             padding)
   1907 
   1908   return Test
   1909 
   1910 
   1911 if __name__ == "__main__":
   1912   for index, (input_size_, filter_size_, output_size_, stride_,
   1913               padding_) in enumerate(GetShrunkInceptionShapes()):
   1914     setattr(Conv2DTest, "testInceptionFwd_" + str(index),
   1915             test_util.run_in_graph_and_eager_modes()(
   1916                 GetInceptionFwdTest(input_size_, filter_size_, stride_,
   1917                                     padding_)))
   1918     setattr(
   1919         Conv2DTest, "testInceptionFwdDilatedConv_" + str(index),
   1920         test_util.run_in_graph_and_eager_modes()(GetInceptionFwdDilatedConvTest(
   1921             input_size_, filter_size_, stride_, padding_)))
   1922     setattr(Conv2DTest, "testInceptionBackInput_" + str(index),
   1923             test_util.run_in_graph_and_eager_modes()(
   1924                 GetInceptionBackInputTest(input_size_, filter_size_,
   1925                                           output_size_, stride_, padding_)))
   1926     setattr(Conv2DTest, "testInceptionBackFilter_" + str(index),
   1927             test_util.run_in_graph_and_eager_modes()(
   1928                 GetInceptionBackFilterTest(input_size_, filter_size_,
   1929                                            output_size_, [stride_, stride_],
   1930                                            padding_)))
   1931 
   1932   # TODO(b/35359731)
   1933   # Fwd, BckInput, and BackFilter to test that for certain input parameter
   1934   # set, winograd nonfused algorithm will be excluded from conv autotune. If
   1935   # in such case, winograd nonfused algorithm is added as one option of the
   1936   # conv autotune, and cuDNN version is smaller than 7, the following tests
   1937   # will fail.
   1938   ishape = [1, 400, 400, 1]
   1939   fshape = [1, 1, 1, 256]
   1940   oshape = [1, 400, 400, 256]
   1941   setattr(Conv2DTest, "testInceptionFwd_No_Winograd_Nonfused",
   1942           test_util.run_in_graph_and_eager_modes()(
   1943               GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True)))
   1944   setattr(Conv2DTest, "testInceptionFwdDilatedConv_No_Winograd_Nonfused",
   1945           test_util.run_in_graph_and_eager_modes()(
   1946               GetInceptionFwdDilatedConvTest(ishape, fshape, 1, "SAME")))
   1947   setattr(Conv2DTest, "testInceptionBackInput_No_Winograd_Nonfused",
   1948           test_util.run_in_graph_and_eager_modes()(
   1949               GetInceptionBackInputTest(ishape, fshape, oshape, 1, "SAME",
   1950                                         gpu_only=True)))
   1951   setattr(Conv2DTest, "testInceptionBackFilter_No_Winograd_Nonfused",
   1952           test_util.run_in_graph_and_eager_modes()(
   1953               GetInceptionBackFilterTest(ishape, fshape, oshape, [1, 1], "SAME",
   1954                                          gpu_only=True)))
   1955   test.main()
   1956