Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Functional tests for depthwise convolutional operations."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.ops import gradient_checker
     27 from tensorflow.python.ops import nn_impl
     28 from tensorflow.python.ops import nn_ops
     29 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
     30 from tensorflow.python.platform import test
     31 
     32 
     33 def ConfigsToTest():
     34   """Iterator for different convolution shapes, strides and paddings.
     35 
     36   Yields:
     37     Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
     38     convolution parameters.
     39   """
     40   input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 9, 27, 8],
     41                  [4, 31, 31, 7], [4, 35, 35, 2], [4, 147, 147, 2],
     42                  [3, 299, 299, 3], [5, 183, 183, 1]]
     43   filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [3, 3, 8, 1],
     44                   [3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3,
     45                                                              8], [5, 5, 1, 2]]
     46   out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 9, 27, 8],
     47                [4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16],
     48                [3, 150, 150, 24], [5, 92, 92, 2]]
     49   strides = [1, 1, 1, 1, 1, 1, 3, 2, 2]
     50   # pylint: disable=invalid-name
     51   VALID = "VALID"
     52   SAME = "SAME"
     53   # pylint: enable=invalid-name
     54   paddings = [SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME]
     55   for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
     56                            paddings):
     57     yield i, f, o, s, p
     58 
     59 
     60 def CheckGradConfigsToTest():
     61   """Iterator for different convolution shapes, strides and paddings.
     62 
     63   compute_gradient_error() is very expensive. So the configs should be
     64   relatively small.
     65 
     66   Yields:
     67     Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
     68     convolution parameters.
     69   """
     70   input_sizes = [[2, 5, 8, 1], [4, 5, 5, 1], [2, 4, 4, 2], [1, 15, 15, 2],
     71                  [2, 15, 16, 1]]
     72   filter_sizes = [[4, 4, 1, 2], [2, 2, 1, 2], [3, 1, 2, 2], [1, 3, 2, 1],
     73                   [3, 3, 1, 2]]
     74   out_sizes = [[2, 5, 8, 2], [4, 2, 2, 2], [2, 4, 4, 4], [1, 15, 15, 2],
     75                [2, 5, 5, 2]]
     76   strides = [1, 2, 1, 1, 3]
     77   # pylint: disable=invalid-name
     78   VALID = "VALID"
     79   SAME = "SAME"
     80   # pylint: enable=invalid-name
     81   paddings = [SAME, VALID, SAME, SAME, VALID]
     82   for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
     83                            paddings):
     84     yield i, f, o, s, p
     85 
     86 
     87 class DepthwiseConv2DTest(test.TestCase):
     88 
     89   # This is testing that depthwise_conv2d and depthwise_conv2d_native
     90   # produce the same results.  It also tests that NCHW and NWHC
     91   # formats agree, by comparing the depthwise_conv2d_native with
     92   # 'NCHW' format (with transposition) matches the 'NHWC' format using
     93   # the higher level interface.
     94   def _VerifyValues(self,
     95                     tensor_in_sizes,
     96                     filter_in_sizes,
     97                     stride,
     98                     padding,
     99                     data_type,
    100                     use_gpu,
    101                     data_format="NHWC"):
    102     """Verifies the output values of the convolution function.
    103 
    104     Args:
    105       tensor_in_sizes: Input tensor dimensions in
    106         [batch, input_rows, input_cols, input_depth].
    107       filter_in_sizes: Filter tensor dimensions in
    108         [filter_rows, filter_cols, input_depth, depth_multiplier].
    109       stride: Stride.
    110       padding: Padding type.
    111       data_type: The data type to use.
    112       use_gpu: Whether to use GPU.
    113       data_format: The data_format of the input. "NHWC" or "NCHW".
    114     """
    115     total_size_1 = 1
    116     total_size_2 = 1
    117     for s in tensor_in_sizes:
    118       total_size_1 *= s
    119     for s in filter_in_sizes:
    120       total_size_2 *= s
    121     # Initializes the input and filter tensor with numbers incrementing from 1.
    122     x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
    123     x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
    124     with self.test_session(use_gpu=use_gpu) as sess:
    125       if data_type == dtypes.float16:
    126         tolerance = 1e-5
    127       elif data_type == dtypes.float32:
    128         tolerance = 1e-5
    129       else:
    130         self.assertEqual(data_type, dtypes.float64)
    131         tolerance = 1e-8
    132 
    133       t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=data_type)
    134       t1.set_shape(tensor_in_sizes)
    135       t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=data_type)
    136 
    137       native_t1 = t1
    138       strides = [1, stride, stride, 1]
    139       if data_format == "NCHW":
    140         # Transpose from NWHC input to NCHW
    141         # Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
    142         native_t1 = array_ops.transpose(t1, [0, 3, 1, 2])
    143         strides = [1, 1, stride, stride]
    144 
    145       conv_native = nn_ops.depthwise_conv2d_native(
    146           native_t1,
    147           t2,
    148           strides=strides,
    149           data_format=data_format,
    150           padding=padding)
    151 
    152       if data_format == "NCHW":
    153         # Transpose back from NCHW to NHWC
    154         conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
    155 
    156       conv_interface = nn_impl.depthwise_conv2d(
    157           t1, t2, strides=[1, stride, stride, 1], padding=padding)
    158 
    159       native_result = sess.run(conv_native)
    160       interface_result = sess.run(conv_interface)
    161 
    162     print("data_type:", data_type, "use_gpu:", use_gpu, "max diff = ",
    163           np.amax(np.absolute(native_result - interface_result)))
    164     self.assertArrayNear(
    165         np.ravel(native_result), np.ravel(interface_result), tolerance)
    166     self.assertShapeEqual(native_result, conv_native)
    167     self.assertShapeEqual(native_result, conv_interface)
    168 
    169   def testDepthwiseConv2D(self):
    170     for index, (input_size, filter_size, _, stride,
    171                 padding) in enumerate(ConfigsToTest()):
    172       print("Testing DepthwiseConv2D,", index, "th config:", input_size, "*",
    173             filter_size, "stride:", stride, "padding:", padding)
    174       for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
    175         self._VerifyValues(
    176             input_size, filter_size, stride, padding, data_type, use_gpu=True)
    177 
    178   def testDepthwiseConv2DFormat(self):
    179     if not test.is_gpu_available():
    180       return
    181 
    182     for index, (input_size, filter_size, _, stride,
    183                 padding) in enumerate(ConfigsToTest()):
    184       print("Testing DepthwiseConv2DFormat,", index, "th config:", input_size,
    185             "*", filter_size, "stride:", stride, "padding:", padding)
    186       for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
    187         self._VerifyValues(
    188             input_size,
    189             filter_size,
    190             stride,
    191             padding,
    192             data_type,
    193             use_gpu=True,
    194             data_format="NCHW")
    195 
    196 # This is testing against hand calculated results.
    197 
    198   def _VerifyHandValues(self, tensor_in_sizes, filter_in_sizes, stride, padding,
    199                         expected, use_gpu):
    200     """Verifies the output values of the depthwise convolution function.
    201 
    202     Args:
    203       tensor_in_sizes: Input tensor dimensions in
    204         [batch, input_rows, input_cols, input_depth].
    205       filter_in_sizes: Filter tensor dimensions in
    206         [filter_rows, filter_cols, input_depth, depth_multiplier].
    207       stride: Stride.
    208       padding: Padding type.
    209       expected: An array containing the expected operation outputs.
    210       use_gpu: Whether to use GPU.
    211     """
    212     total_size_1 = 1
    213     total_size_2 = 1
    214     for s in tensor_in_sizes:
    215       total_size_1 *= s
    216     for s in filter_in_sizes:
    217       total_size_2 *= s
    218     # Initializes the input tensor with array containing incrementing
    219     # numbers from 1.
    220     x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
    221     x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
    222     with self.test_session(use_gpu=use_gpu) as sess:
    223       t1 = constant_op.constant(x1, shape=tensor_in_sizes)
    224       t1.set_shape(tensor_in_sizes)
    225       t2 = constant_op.constant(x2, shape=filter_in_sizes)
    226       conv = nn_ops.depthwise_conv2d_native(
    227           t1, t2, strides=[1, stride, stride, 1], padding=padding)
    228       value = sess.run(conv)
    229     print("value = ", value)
    230     self.assertArrayNear(expected, np.ravel(value), 1e-5)
    231     self.assertShapeEqual(value, conv)
    232 
    233   def testConv2D2x2Filter(self):
    234     # The inputs look like this (it's a 3 x 2 matrix, each of depth 2):
    235     #
    236     # [ (1.0, 2.0), (3.0,  4.0), ( 5.0,  6.0) ]
    237     # [ (7.0, 8.0), (9.0, 10.0), (11.0, 12.0) ]
    238     #  We can view this as two inputs
    239     #
    240     #  input depth 0:
    241     #
    242     #  [ 1.0,  3.0,  5.0 ]
    243     #  [ 7.0,  9.0, 11.0 ]
    244     #
    245     #  input depth 1:
    246     #
    247     #  [ 2.0,  4.0,  6.0 ]
    248     #  [ 8.0, 10.0, 12.0 ]
    249     #
    250     # The filter looks like this (it has two 2 x 2 patches, each generating 2
    251     # depths):
    252     #
    253     #  filter #0:
    254     #
    255     #  [ (1.0,  3.0), ( 5.0,  7.0)]
    256     #  [ (9.0, 11.0), (13.0, 15.0)]
    257     #
    258     #  filter #1:
    259     #
    260     #  [ ( 2.0,  4.0), ( 6.0,  8.0)]
    261     #  [ (10.0, 12.0), (14.0, 16.0)]
    262     #
    263     # So the outputs are:
    264     #
    265     # (position 0, 0: in_depth 0, output_depth 0 -- using filter #0)
    266     #  1.0 * 1.0 + 7.0 * 9.0 + 3.0 * 5.0 + 9.0 * 13.0 = 196
    267     # (position 0, 0: in_depth 0, output_depth 1 -- using filter #1)
    268     #  1.0 * 2.0 + 7.0 * 10.0 + 3.0 * 6.0 + 9.0 * 14.0 = 216
    269     # (position 0, 0: in_depth 1, output_depth 2 -- using filter #0)
    270     #  2.0 * 3.0 + 8.0 * 11.0 + 4.0 * 7.0 + 10.0 * 15.0 = 272
    271     # (position 0, 0: in_depth 1, output_depth 3 -- using filter #1)
    272     #  2.0 * 4.0 + 8.0 * 12.0 + 4.0 * 8.0 + 10.0 * 16.0 = 296
    273     #
    274     # (position 1, 0: in_depth 0, output_depth 0 -- using filter #0)
    275     #  3.0 * 1.0 + 9.0 * 9.0 + 5.0 * 5.0 + 11.0 * 13.0 = 252
    276     # (position 1, 0: in_depth 0, output_depth 1 -- using filter #1)
    277     #  3.0 * 2.0 + 9.0 * 10.0 + 5.0 * 6.0 + 11.0 * 14.0 = 280
    278     # (position 1, 0: in_depth 1, output_depth 2 -- using filter #0)
    279     #  4.0 * 3.0 + 10.0 * 11.0 + 6.0 * 7.0 + 12.0 * 15.0 = 344
    280     # (position 1, 0: in_depth 1, output_depth 3 -- using filter #1)
    281     #  4.0 * 4.0 + 10.0 * 12.0 + 6.0 * 8.0 + 12.0 * 16.0 = 376
    282     expected_output = [196, 216, 272, 296, 252, 280, 344, 376]
    283     self._VerifyHandValues(
    284         tensor_in_sizes=[1, 2, 3, 2],
    285         filter_in_sizes=[2, 2, 2, 2],
    286         stride=1,
    287         padding="VALID",
    288         expected=expected_output,
    289         use_gpu=False)
    290 
    291     self._VerifyHandValues(
    292         tensor_in_sizes=[1, 2, 3, 2],
    293         filter_in_sizes=[2, 2, 2, 2],
    294         stride=1,
    295         padding="VALID",
    296         expected=expected_output,
    297         use_gpu=True)
    298 
    299   # Gradient checkers.This tests depthwise gradient computations for both
    300   # BackpropFilter and BackpropInput by comparing gradients computed by the
    301   # depthwise gradient ops with the gradients computed numerically (details can
    302   # be found in the compute_gradient_error().
    303   # Note this check is very expensive so the input should not be too big.
    304   def _ConstructAndTestGradient(self,
    305                                 input_shape,
    306                                 filter_shape,
    307                                 output_shape,
    308                                 stride,
    309                                 padding,
    310                                 data_type,
    311                                 test_input,
    312                                 use_gpu,
    313                                 data_format="NHWC"):
    314     input_size = 1
    315     for x in input_shape:
    316       input_size *= x
    317     filter_size = 1
    318     for x in filter_shape:
    319       filter_size *= x
    320     input_data = [x * 1.0 / input_size for x in range(0, input_size)]
    321     filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
    322     with self.test_session(use_gpu=use_gpu):
    323       if data_type == dtypes.float16:
    324         tolerance = 0.002
    325       elif data_type == dtypes.float32:
    326         tolerance = 0.002
    327       else:
    328         self.assertEqual(data_type, dtypes.float64)
    329         tolerance = 1e-8
    330 
    331       input_tensor = constant_op.constant(
    332           input_data, shape=input_shape, dtype=data_type, name="input")
    333       filter_tensor = constant_op.constant(
    334           filter_data, shape=filter_shape, dtype=data_type, name="filter")
    335 
    336       native_input = input_tensor
    337       strides = [1, stride, stride, 1]
    338       if data_format == "NCHW":
    339         # Transpose from NWHC input to NCHW
    340         # Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
    341         native_input = array_ops.transpose(input_tensor, [0, 3, 1, 2])
    342         input_shape = [
    343             input_shape[0], input_shape[3], input_shape[1], input_shape[2]
    344         ]
    345         output_shape = [
    346             output_shape[0], output_shape[3], output_shape[1], output_shape[2]
    347         ]
    348         strides = [1, 1, stride, stride]
    349 
    350       depthwise_conv2d = nn_ops.depthwise_conv2d_native(
    351           native_input,
    352           filter_tensor,
    353           strides,
    354           padding,
    355           data_format=data_format,
    356           name="depthwise_conv2d")
    357 
    358       self.assertEqual(output_shape, depthwise_conv2d.get_shape())
    359       if test_input:
    360         err = gradient_checker.compute_gradient_error(
    361             native_input, input_shape, depthwise_conv2d, output_shape)
    362       else:
    363         err = gradient_checker.compute_gradient_error(filter_tensor,
    364                                                       filter_shape,
    365                                                       depthwise_conv2d,
    366                                                       output_shape)
    367       print("data_type:", data_type, "use_gpu:", use_gpu, ", error = ", err)
    368       self.assertLess(err, tolerance)
    369 
    370   def testDepthwiseConv2DInputGrad(self):
    371     for index, (input_size, filter_size, output_size, stride,
    372                 padding) in enumerate(CheckGradConfigsToTest()):
    373       print("Testing DepthwiseConv2DInputGrad,", index, "th config:",
    374             input_size, "*", filter_size, "stride:", stride, "padding:",
    375             padding)
    376       # Note: float16 test for DepthwiseConv2DInputGrad is not enabled,
    377       # calculations are not very precise.
    378       for data_type in [dtypes.float32, dtypes.float64]:
    379         self._ConstructAndTestGradient(
    380             input_size,
    381             filter_size,
    382             output_size,
    383             stride,
    384             padding,
    385             data_type,
    386             test_input=True,
    387             use_gpu=True)
    388 
    389   def testDepthwiseConv2DInputGradFormat(self):
    390     if not test.is_gpu_available():
    391       return
    392 
    393     for index, (input_size, filter_size, output_size, stride,
    394                 padding) in enumerate(CheckGradConfigsToTest()):
    395       print("Testing DepthwiseConv2DInputGradFormat,", index, "th config:",
    396             input_size, "*", filter_size, "stride:", stride, "padding:",
    397             padding)
    398       # Note: float16 test for DepthwiseConv2DInputGradFormat is not enabled,
    399       # calculations are not very precise.
    400       for data_type in [dtypes.float32, dtypes.float64]:
    401         self._ConstructAndTestGradient(
    402             input_size,
    403             filter_size,
    404             output_size,
    405             stride,
    406             padding,
    407             data_type,
    408             test_input=True,
    409             use_gpu=True,
    410             data_format="NCHW")
    411 
    412   def testDepthwiseConv2DFilterGrad(self):
    413     for index, (input_size, filter_size, output_size, stride,
    414                 padding) in enumerate(CheckGradConfigsToTest()):
    415       print("Testing DepthwiseConv2DFilterGrad,", index, "th config:",
    416             input_size, "*", filter_size, "stride:", stride, "padding:",
    417             padding)
    418       # Note: float16 test for DepthwiseConv2DFilterGrad is not enabled,
    419       # calculations are not very precise.
    420       for data_type in [dtypes.float32, dtypes.float64]:
    421         self._ConstructAndTestGradient(
    422             input_size,
    423             filter_size,
    424             output_size,
    425             stride,
    426             padding,
    427             data_type,
    428             test_input=False,
    429             use_gpu=True)
    430 
    431   def testDepthwiseConv2DFilterGradFormat(self):
    432     if not test.is_gpu_available():
    433       return
    434 
    435     for index, (input_size, filter_size, output_size, stride,
    436                 padding) in enumerate(CheckGradConfigsToTest()):
    437       print("Testing DepthwiseConv2DFilterGradFormat,", index, "th config:",
    438             input_size, "*", filter_size, "stride:", stride, "padding:",
    439             padding)
    440       # Note: float16 test for DepthwiseConv2DFilterGradFormat is not enabled,
    441       # calculations are not very precise.
    442       for data_type in [dtypes.float32, dtypes.float64]:
    443         self._ConstructAndTestGradient(
    444             input_size,
    445             filter_size,
    446             output_size,
    447             stride,
    448             padding,
    449             data_type,
    450             test_input=False,
    451             use_gpu=True,
    452             data_format="NCHW")
    453 
    454   def _CompareBackpropInputFloat(self, input_sizes, filter_sizes, output_sizes,
    455                                  stride, padding):
    456     x1 = np.random.rand(*filter_sizes).astype(np.float32)
    457     x2 = np.random.rand(*output_sizes).astype(np.float32)
    458 
    459     def _GetVal(use_gpu):
    460       with self.test_session(use_gpu=use_gpu):
    461         t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
    462         t1 = constant_op.constant(x1, shape=filter_sizes)
    463         t2 = constant_op.constant(x2, shape=output_sizes)
    464         backprop = nn_ops.depthwise_conv2d_native_backprop_input(
    465             t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
    466         ret = backprop.eval()
    467         self.assertShapeEqual(ret, backprop)
    468         return ret
    469 
    470     gpu_value = _GetVal(use_gpu=True)
    471     cpu_value = _GetVal(use_gpu=False)
    472     self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
    473 
    474   def _CompareBackpropInputDouble(self, input_sizes, filter_sizes, output_sizes,
    475                                   stride, padding):
    476     x1 = np.random.rand(*filter_sizes).astype(np.float64)
    477     x2 = np.random.rand(*output_sizes).astype(np.float64)
    478 
    479     def _GetVal(use_gpu):
    480       with self.test_session(use_gpu=use_gpu):
    481         t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
    482         t1 = constant_op.constant(x1, shape=filter_sizes)
    483         t2 = constant_op.constant(x2, shape=output_sizes)
    484         backprop = nn_ops.depthwise_conv2d_native_backprop_input(
    485             t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
    486         ret = backprop.eval()
    487         self.assertShapeEqual(ret, backprop)
    488         return ret
    489 
    490     gpu_value = _GetVal(use_gpu=True)
    491     cpu_value = _GetVal(use_gpu=False)
    492     self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
    493 
    494   def testDepthwiseConv2DInputGradCompare(self):
    495     for index, (input_size, filter_size, output_size, stride,
    496                 padding) in enumerate(ConfigsToTest()):
    497       print("Testing DepthwiseConv2DInputGradCompare,", index, "th config:",
    498             input_size, "*", filter_size, "stride:", stride, "padding:",
    499             padding)
    500       self._CompareBackpropInputFloat(input_size, filter_size, output_size,
    501                                       stride, padding)
    502       self._CompareBackpropInputDouble(input_size, filter_size, output_size,
    503                                        stride, padding)
    504 
    505   def _CompareBackpropFilterFloat(self, input_sizes, filter_sizes, output_sizes,
    506                                   stride, padding):
    507     x0 = np.random.rand(*input_sizes).astype(np.float32)
    508     x2 = np.random.rand(*output_sizes).astype(np.float32)
    509 
    510     def _GetVal(use_gpu):
    511       with self.test_session(use_gpu=use_gpu):
    512         t0 = constant_op.constant(x0, shape=input_sizes)
    513         t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
    514         t2 = constant_op.constant(x2, shape=output_sizes)
    515         backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
    516             t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
    517         ret = backprop.eval()
    518         self.assertShapeEqual(ret, backprop)
    519         return ret
    520 
    521     gpu_value = _GetVal(use_gpu=True)
    522     cpu_value = _GetVal(use_gpu=False)
    523     self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
    524 
    525   def _CompareBackpropFilterDouble(self, input_sizes, filter_sizes,
    526                                    output_sizes, stride, padding):
    527     x0 = np.random.rand(*input_sizes).astype(np.float64)
    528     x2 = np.random.rand(*output_sizes).astype(np.float64)
    529 
    530     def _GetVal(use_gpu):
    531       with self.test_session(use_gpu=use_gpu):
    532         t0 = constant_op.constant(x0, shape=input_sizes)
    533         t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
    534         t2 = constant_op.constant(x2, shape=output_sizes)
    535         backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
    536             t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
    537         ret = backprop.eval()
    538         self.assertShapeEqual(ret, backprop)
    539         return ret
    540 
    541     gpu_value = _GetVal(use_gpu=True)
    542     cpu_value = _GetVal(use_gpu=False)
    543     self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
    544 
    545   def testDepthwiseConv2DFilterGradCompare(self):
    546     for index, (input_size, filter_size, output_size, stride,
    547                 padding) in enumerate(ConfigsToTest()):
    548       print("Testing DepthwiseConv2DFilterGradCompare,", index, "th config:",
    549             input_size, "*", filter_size, "stride:", stride, "padding:",
    550             padding)
    551       self._CompareBackpropFilterFloat(input_size, filter_size, output_size,
    552                                        stride, padding)
    553       self._CompareBackpropFilterDouble(input_size, filter_size, output_size,
    554                                         stride, padding)
    555 
    556 
# Hand control to the test entry point only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
  test.main()
    559