# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for depthwise convolutional operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


# Reference implementation of depthwise_conv2d.
def ReferenceDepthwiseConv2D(input_tensor, filter_tensor, strides, padding,
                             data_format=None):
  """Computes depthwise convolution by slicing and using regular conv2d.

  Each input channel is convolved with its own filter slice, and the
  per-channel results are concatenated back along the channel dimension.
  """
  convs = []
  in_channels = filter_tensor.shape[2]
  for channel in xrange(in_channels):
    # Slice the input along the channel dimension.
    if data_format == "NCHW":
      input_slice = input_tensor[:, channel:channel + 1, :, :]
    else:
      input_slice = input_tensor[:, :, :, channel:channel + 1]

    # Slice the filters. Filters have shape [H, W, InC, DepthMultiplier].
    filter_slice = filter_tensor[:, :, channel:channel + 1, :]
    # Convolve the single-channel slice with its filters.
    convs.append(nn_ops.conv2d(input_slice, filter_slice,
                               strides, padding,
                               data_format=data_format,
                               name="depthwise_slice_%d" % channel))

  # Concatenate the per-channel results along the channel dimension.
  if data_format == "NCHW":
    return array_ops.concat(convs, 1)
  else:
    return array_ops.concat(convs, 3)


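# For reference, the same computation can also be written directly in NumPy.
# The helper below is an illustrative sketch only (it is not used by the
# tests): it assumes NHWC layout and 'VALID' padding, and its name is ours
# rather than part of any TensorFlow API.
def _NumpyDepthwiseConv2DValid(input_array, filter_array, stride):
  """Computes depthwise conv2d with 'VALID' padding in plain NumPy.

  Args:
    input_array: Array of shape [batch, height, width, in_channels].
    filter_array: Array of shape
      [filter_height, filter_width, in_channels, depth_multiplier].
    stride: Integer stride applied to both spatial dimensions.

  Returns:
    Array of shape
    [batch, out_height, out_width, in_channels * depth_multiplier].
  """
  batch, height, width, in_channels = input_array.shape
  filter_height, filter_width, _, depth_multiplier = filter_array.shape
  out_height = (height - filter_height) // stride + 1
  out_width = (width - filter_width) // stride + 1
  output = np.zeros(
      (batch, out_height, out_width, in_channels * depth_multiplier),
      dtype=input_array.dtype)
  for i in range(out_height):
    for j in range(out_width):
      # Spatial window of shape [batch, filter_height, filter_width,
      # in_channels] starting at the strided position (i, j).
      window = input_array[:, i * stride:i * stride + filter_height,
                           j * stride:j * stride + filter_width, :]
      for c in range(in_channels):
        for m in range(depth_multiplier):
          # Each (input channel, depth multiplier) pair has its own 2-D
          # filter; outputs are laid out as channel c * depth_multiplier + m.
          output[:, i, j, c * depth_multiplier + m] = np.sum(
              window[:, :, :, c] * filter_array[:, :, c, m], axis=(1, 2))
  return output

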
def ConfigsToTest():
  """Iterator for different convolution shapes, strides and paddings.

  Yields:
    Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
    convolution parameters.
  """
  input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 9, 27, 8],
                 [4, 31, 31, 7], [4, 35, 35, 2], [4, 147, 147, 2],
                 [3, 299, 299, 3], [5, 183, 183, 1]]
  filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [3, 3, 8, 1],
                  [3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3, 8],
                  [5, 5, 1, 2]]
  out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 9, 27, 8],
               [4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16],
               [3, 150, 150, 24], [5, 92, 92, 2]]
  strides = [1, 1, 1, 1, 1, 1, 3, 2, 2]
  # pylint: disable=invalid-name
  VALID = "VALID"
  SAME = "SAME"
  # pylint: enable=invalid-name
  paddings = [SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME]
  for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
                           paddings):
    yield i, f, o, s, p


def CheckGradConfigsToTest():
  """Iterator for different convolution shapes, strides and paddings.

  compute_gradient_error() is very expensive, so these configs are kept
  relatively small.

  Yields:
    Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
    convolution parameters.
  """
  input_sizes = [[2, 5, 8, 1], [4, 5, 5, 1], [2, 4, 4, 2], [1, 15, 15, 2],
                 [2, 15, 16, 1]]
  filter_sizes = [[4, 4, 1, 2], [2, 2, 1, 2], [3, 1, 2, 2], [1, 3, 2, 1],
                  [3, 3, 1, 2]]
  out_sizes = [[2, 5, 8, 2], [4, 2, 2, 2], [2, 4, 4, 4], [1, 15, 15, 2],
               [2, 5, 5, 2]]
  strides = [1, 2, 1, 1, 3]
  # pylint: disable=invalid-name
  VALID = "VALID"
  SAME = "SAME"
  # pylint: enable=invalid-name
  paddings = [SAME, VALID, SAME, SAME, VALID]
  for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
                           paddings):
    yield i, f, o, s, p


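# The out_sizes in the configs above follow the standard TensorFlow
# output-size rules for convolutions. The helper below is an illustrative
# sketch only (it is not used by the tests) of how those sizes can be derived
# from an input size, filter size, stride and padding.
def _SpatialOutputSize(input_size, filter_size, stride, padding):
  """Returns the output size of one spatial dimension of a convolution."""
  if padding == "SAME":
    # 'SAME' padding: ceil(input_size / stride).
    return (input_size + stride - 1) // stride
  # 'VALID' padding: ceil((input_size - filter_size + 1) / stride).
  return (input_size - filter_size + stride) // stride

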
class DepthwiseConv2DTest(XLATestCase):

  # This tests that depthwise_conv2d_native and the sliced reference
  # implementation produce the same results. It also tests that the 'NCHW'
  # and 'NHWC' data formats agree, by checking that depthwise_conv2d_native
  # run in 'NCHW' format (with transpositions) matches the 'NHWC' result.
  def _VerifyValues(self,
                    tensor_in_sizes,
                    filter_in_sizes,
                    stride,
                    padding,
                    data_type,
                    data_format="NHWC"):
    """Verifies the output values of the convolution function.

    Args:
      tensor_in_sizes: Input tensor dimensions in
        [batch, input_rows, input_cols, input_depth].
      filter_in_sizes: Filter tensor dimensions in
        [filter_rows, filter_cols, input_depth, depth_multiplier].
      stride: Stride.
      padding: Padding type.
      data_type: The data type to use.
      data_format: The data_format of the input. "NHWC" or "NCHW".
    """
    total_size_1 = 1
    total_size_2 = 1
    for s in tensor_in_sizes:
      total_size_1 *= s
    for s in filter_in_sizes:
      total_size_2 *= s
    # Initializes the input and filter tensors with numbers incrementing
    # from 1.
    x1 = np.array([f * 1.0 for f in range(1, total_size_1 + 1)],
                  dtype=data_type).reshape(tensor_in_sizes)
    x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)],
                  dtype=data_type).reshape(filter_in_sizes)
    with self.test_session() as sess:
      if data_type == np.float32:
        tolerance = 1e-5
      else:
        self.assertEqual(data_type, np.float64)
        tolerance = 1e-8

      t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=data_type)
      t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=data_type)

      native_t1 = t1
      strides = [1, stride, stride, 1]
      if data_format == "NCHW":
        # Transpose from NHWC input to NCHW,
        # e.g. [4, 5, 5, 48] to [4, 48, 5, 5].
        native_t1 = array_ops.transpose(t1, [0, 3, 1, 2])
        strides = [1, 1, stride, stride]

      with self.test_scope():
        conv_native = nn_ops.depthwise_conv2d_native(
            native_t1,
            t2,
            strides=strides,
            data_format=data_format,
            padding=padding)

      if data_format == "NCHW":
        # Transpose back from NCHW to NHWC.
        conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])

      with ops.device("CPU"):
        conv_interface = ReferenceDepthwiseConv2D(
            t1, t2, strides=[1, stride, stride, 1], padding=padding)

      native_result = sess.run(conv_native, {t1: x1, t2: x2})
      interface_result = sess.run(conv_interface, {t1: x1, t2: x2})

    print("data_type:", data_type, "max diff = ",
          np.amax(np.absolute(native_result - interface_result)))
    self.assertAllClose(
        np.ravel(native_result), np.ravel(interface_result), rtol=tolerance)

  def testDepthwiseConv2D(self):
    for index, (input_size, filter_size, _, stride,
                padding) in enumerate(ConfigsToTest()):
      print("Testing DepthwiseConv2D, config %d:" % index, input_size, "*",
            filter_size, "stride:", stride, "padding:", padding)
      for data_type in self.float_types:
        # TODO(phawkins): the reference implementation only supports float32.
        if data_type == np.float32:
          self._VerifyValues(
              input_size, filter_size, stride, padding, data_type)

  def testDepthwiseConv2DFormat(self):
    for index, (input_size, filter_size, _, stride,
                padding) in enumerate(ConfigsToTest()):
      print("Testing DepthwiseConv2DFormat, config %d:" % index, input_size,
            "*", filter_size, "stride:", stride, "padding:", padding)
      for data_type in self.float_types:
        # TODO(phawkins): the reference implementation only supports float32.
        if data_type == np.float32:
          self._VerifyValues(
              input_size,
              filter_size,
              stride,
              padding,
              data_type,
              data_format="NCHW")

  # This tests against hand-calculated results.
  def _VerifyHandValues(self, tensor_in_sizes, filter_in_sizes, stride, padding,
                        expected):
    """Verifies the output values of the depthwise convolution function.

    Args:
      tensor_in_sizes: Input tensor dimensions in
        [batch, input_rows, input_cols, input_depth].
      filter_in_sizes: Filter tensor dimensions in
        [filter_rows, filter_cols, input_depth, depth_multiplier].
      stride: Stride.
      padding: Padding type.
      expected: An array containing the expected operation outputs.
    """
    total_size_1 = 1
    total_size_2 = 1
    for s in tensor_in_sizes:
      total_size_1 *= s
    for s in filter_in_sizes:
      total_size_2 *= s
    # Initializes the input and filter tensors with numbers incrementing
    # from 1.
    x1 = np.array([f * 1.0 for f in range(1, total_size_1 + 1)],
                  dtype=np.float32).reshape(tensor_in_sizes)
    x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)],
                  dtype=np.float32).reshape(filter_in_sizes)
    with self.test_session() as sess:
      t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=np.float32)
      t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=np.float32)
      with self.test_scope():
        conv = nn_ops.depthwise_conv2d_native(
            t1, t2, strides=[1, stride, stride, 1], padding=padding)
      value = sess.run(conv, {t1: x1, t2: x2})
    print("value = ", value)
    self.assertArrayNear(expected, np.ravel(value), 1e-5)
    self.assertShapeEqual(value, conv)

  def testConv2D2x2Filter(self):
    # The inputs look like this (it's a 3 x 2 matrix, each of depth 2):
    #
    # [ (1.0, 2.0), (3.0,  4.0), ( 5.0,  6.0) ]
    # [ (7.0, 8.0), (9.0, 10.0), (11.0, 12.0) ]
    #  We can view this as two inputs
    #
    #  input depth 0:
    #
    #  [ 1.0,  3.0,  5.0 ]
    #  [ 7.0,  9.0, 11.0 ]
    #
    #  input depth 1:
    #
    #  [ 2.0,  4.0,  6.0 ]
    #  [ 8.0, 10.0, 12.0 ]
    #
    # The filter looks like this (it has two 2 x 2 patches, each generating 2
    # depths):
    #
    #  filter #0:
    #
    #  [ (1.0,  3.0), ( 5.0,  7.0)]
    #  [ (9.0, 11.0), (13.0, 15.0)]
    #
    #  filter #1:
    #
    #  [ ( 2.0,  4.0), ( 6.0,  8.0)]
    #  [ (10.0, 12.0), (14.0, 16.0)]
    #
    # So the outputs are:
    #
    # (position 0, 0: in_depth 0, output_depth 0 -- using filter #0)
    #  1.0 * 1.0 + 7.0 * 9.0 + 3.0 * 5.0 + 9.0 * 13.0 = 196
    # (position 0, 0: in_depth 0, output_depth 1 -- using filter #1)
    #  1.0 * 2.0 + 7.0 * 10.0 + 3.0 * 6.0 + 9.0 * 14.0 = 216
    # (position 0, 0: in_depth 1, output_depth 2 -- using filter #0)
    #  2.0 * 3.0 + 8.0 * 11.0 + 4.0 * 7.0 + 10.0 * 15.0 = 272
    # (position 0, 0: in_depth 1, output_depth 3 -- using filter #1)
    #  2.0 * 4.0 + 8.0 * 12.0 + 4.0 * 8.0 + 10.0 * 16.0 = 296
    #
    # (position 1, 0: in_depth 0, output_depth 0 -- using filter #0)
    #  3.0 * 1.0 + 9.0 * 9.0 + 5.0 * 5.0 + 11.0 * 13.0 = 252
    # (position 1, 0: in_depth 0, output_depth 1 -- using filter #1)
    #  3.0 * 2.0 + 9.0 * 10.0 + 5.0 * 6.0 + 11.0 * 14.0 = 280
    # (position 1, 0: in_depth 1, output_depth 2 -- using filter #0)
    #  4.0 * 3.0 + 10.0 * 11.0 + 6.0 * 7.0 + 12.0 * 15.0 = 344
    # (position 1, 0: in_depth 1, output_depth 3 -- using filter #1)
    #  4.0 * 4.0 + 10.0 * 12.0 + 6.0 * 8.0 + 12.0 * 16.0 = 376
    expected_output = [196, 216, 272, 296, 252, 280, 344, 376]
    self._VerifyHandValues(
        tensor_in_sizes=[1, 2, 3, 2],
        filter_in_sizes=[2, 2, 2, 2],
        stride=1,
        padding="VALID",
        expected=expected_output)
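
    # As a sanity check on the hand computation above, the first output value
    # (position (0, 0), in_depth 0, filter #0) can be reproduced directly in
    # NumPy; this is illustrative only and is not exercised by the test:
    #
    #   patch = np.array([[1., 3.], [7., 9.]])    # input depth 0, 2x2 window
    #   kernel = np.array([[1., 5.], [9., 13.]])  # filter values [:, :, 0, 0]
    #   np.sum(patch * kernel)                    # == 196.0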

  def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes,
                            stride, padding):
    x1 = np.random.rand(*filter_sizes).astype(np.float32)
    x2 = np.random.rand(*output_sizes).astype(np.float32)

    def _GetVal(use_xla):
      with self.test_session():
        # The first argument of the backprop op is the shape of the input,
        # not the input values themselves.
        t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
        t1 = array_ops.placeholder(np.float32, shape=filter_sizes)
        t2 = array_ops.placeholder(np.float32, shape=output_sizes)
        if use_xla:
          with self.test_scope():
            backprop = nn_ops.depthwise_conv2d_native_backprop_input(
                t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
        else:
          backprop = nn_ops.depthwise_conv2d_native_backprop_input(
              t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)

        ret = backprop.eval({t1: x1, t2: x2})
        self.assertShapeEqual(ret, backprop)
        return ret

    xla_value = _GetVal(use_xla=True)
    ref_value = _GetVal(use_xla=False)
    self.assertAllClose(ref_value, xla_value, rtol=1e-4, atol=1e-4)

  def testDepthwiseConv2DInputGradCompare(self):
    for index, (input_size, filter_size, output_size, stride,
                padding) in enumerate(ConfigsToTest()):
      print("Testing DepthwiseConv2DInputGradCompare, config %d:" % index,
            input_size, "*", filter_size, "stride:", stride, "padding:",
            padding)
      self._CompareBackpropInput(input_size, filter_size, output_size, stride,
                                 padding)

  def _CompareBackpropFilter(self, input_sizes, filter_sizes, output_sizes,
                             stride, padding):
    x0 = np.random.rand(*input_sizes).astype(np.float32)
    x2 = np.random.rand(*output_sizes).astype(np.float32)

    def _GetVal(use_xla):
      with self.test_session():
        t0 = array_ops.placeholder(np.float32, shape=input_sizes)
        # The second argument of the backprop op is the shape of the filter,
        # not the filter values themselves.
        t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
        t2 = array_ops.placeholder(np.float32, shape=output_sizes)
        if use_xla:
          with self.test_scope():
            backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
                t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
        else:
          backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
              t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
        ret = backprop.eval({t0: x0, t2: x2})
        self.assertShapeEqual(ret, backprop)
        return ret

    xla_value = _GetVal(use_xla=True)
    ref_value = _GetVal(use_xla=False)
    self.assertAllClose(ref_value, xla_value, rtol=1e-4, atol=1e-4)

  def testDepthwiseConv2DFilterGradCompare(self):
    for index, (input_size, filter_size, output_size, stride,
                padding) in enumerate(ConfigsToTest()):
      print("Testing DepthwiseConv2DFilterGradCompare, config %d:" % index,
            input_size, "*", filter_size, "stride:", stride, "padding:",
            padding)
      self._CompareBackpropFilter(input_size, filter_size, output_size,
                                  stride, padding)


if __name__ == "__main__":
  test.main()