Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for atrous convolution functionality in tensorflow.ops.nn."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import contextlib
     22 
     23 import numpy as np
     24 
     25 from tensorflow.python.eager import context
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import test_util
     29 from tensorflow.python.ops import array_ops
     30 from tensorflow.python.ops import gradient_checker
     31 from tensorflow.python.ops import nn_ops
     32 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
     33 from tensorflow.python.platform import test
     34 
     35 
     36 def upsample_filters(filters, rate):
     37   """Upsamples the filters by a factor of rate along the spatial dimensions.
     38 
     39   Args:
     40     filters: spatial_shape + [in_channels, out_channels]
     41       Original filters.
     42     rate: A list of len(spatial_shape) positive ints, specifying the
     43       upsampling rate.
     44 
     45   Returns:
     46     filters_up: output_spatial_shape + [in_channels, out_channels].
     47       Upsampled filters with
     48       output_spatial_shape[i] = (spatial_shape[i] - 1) * rate[i] + 1
     49       containing (rate[i] - 1) zeros between consecutive filter values along
     50       spatial dimension i.
     51   """
     52   num_spatial_dims = len(rate)
     53   spatial_shape = np.array(filters.shape[:num_spatial_dims])
     54   output_spatial_shape = (spatial_shape - 1) * rate + 1
     55   output = np.zeros(
     56       tuple(output_spatial_shape) + tuple(filters.shape[-2:]), filters.dtype)
     57   output[tuple(np.s_[::rate[i]] for i in range(num_spatial_dims))] = filters
     58   return output
     59 
     60 
     61 class AtrousConvolutionTest(test.TestCase):
     62 
     63   @contextlib.contextmanager
     64   def _delay_checks(self):
     65     """Context manager for combining checks depending on tensor evaluations.
     66 
     67     Each call to Session.run has some overhead, and this overhead can easily
     68     account for the majority of the time spent in tests that call Session.run
     69     (or Tensor.eval) many times.
     70 
     71     This context manager provides a mechanism for registering callback functions
     72     and associated tensors.  When the context is exited, all of the tensors
     73     associated with all of the registrations are evaluated with a single call to
     74     Session.run, and then each registered callback function is called with the
     75     values of its associated tensors.
     76 
     77     Yields:
     78       A function `add_check(check, *args, **kwargs)` where `check` is the
     79       callback function to be invoked, and `*args` and `**kwargs` specify the
     80       associated Tensors. When in EAGER mode, check is executed in add_check,
     81       otherwise, it's delayed after the context.
     82     """
     83     checks = []
     84 
     85     def add_check(check, *args, **kwargs):
     86       if context.executing_eagerly():
     87         args_val, kwargs_val = self.evaluate([args, kwargs])
     88         check(*args_val, **kwargs_val)
     89       else:
     90         checks.append((check, args, kwargs))
     91 
     92     yield add_check
     93     if not context.executing_eagerly():
     94       all_values = self.evaluate([[args, kwargs] for _, args, kwargs in checks])
     95       for (check, _, _), (args, kwargs) in zip(checks, all_values):
     96         check(*args, **kwargs)
     97 
     98   def _test_atrous_convolution(self, add_check, input_shape, filter_shape,
     99                                dilation_rate, **kwargs):
    100     filters = np.arange(
    101         np.prod(filter_shape), dtype=np.float32).reshape(filter_shape)
    102     filters_upsampled = upsample_filters(filters, dilation_rate)
    103     x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape)
    104     y1 = nn_ops.convolution(
    105         input=x, filter=filters, dilation_rate=dilation_rate, **kwargs)
    106     y2 = nn_ops.convolution(input=x, filter=filters_upsampled, **kwargs)
    107 
    108     def check(y1_eval, y2_eval):
    109       self.assertAllClose(y1_eval, y2_eval, rtol=1e-2, atol=1e-2)
    110 
    111     add_check(check, y1, y2)
    112 
    113   @test_util.run_v1_only("b/120545219")
    114   def test_unknown_spatial_dims_for_channel_last_format(self):
    115     x = array_ops.placeholder(dtypes.float32, [1, None, None, 10])
    116     w = array_ops.zeros([3, 3, 10, 20])
    117     y = nn_ops.convolution(
    118         x, w, "VALID", dilation_rate=[2, 2], data_format="NHWC")
    119     self.assertEqual(y.shape.as_list(), [1, None, None, 20])
    120 
    121   @test_util.run_v1_only("b/120545219")
    122   def test_unknown_spatial_dims_for_channel_first_format(self):
    123     x = array_ops.placeholder(dtypes.float32, [1, 10, None, None])
    124     w = array_ops.zeros([3, 3, 10, 20])
    125     y = nn_ops.convolution(
    126         x, w, "VALID", dilation_rate=[2, 2], data_format="NCHW")
    127     self.assertEqual(y.shape.as_list(), [1, 20, None, None])
    128 
    129   @test_util.run_in_graph_and_eager_modes
    130   def testAtrousConvolution2D(self):
    131     with self._delay_checks() as add_check:
    132       for padding in ["SAME", "VALID"]:
    133         for height, width in [[9, 9], [9, 10]]:
    134           for kernel_height, kernel_width in [[1, 1], [2, 2], [2, 3]]:
    135             for dilation_rate in [[1, 1], [3, 2], [2, 1]]:
    136               self._test_atrous_convolution(
    137                   add_check=add_check,
    138                   input_shape=[2, height, width, 2],
    139                   filter_shape=[kernel_height, kernel_width, 2, 2],
    140                   padding=padding,
    141                   dilation_rate=dilation_rate,
    142               )
    143 
    144   @test_util.run_in_graph_and_eager_modes
    145   def testAtrousConvolution3D(self):
    146     with self._delay_checks() as add_check:
    147       for padding in ["SAME", "VALID"]:
    148         for depth, height, width in [[9, 9, 10], [9, 10, 9]]:
    149           for kernel_depth, kernel_height, kernel_width in [[3, 3,
    150                                                              3], [3, 2, 2],
    151                                                             [2, 1, 3]]:
    152             for dilation_rate in [[1, 1, 1], [3, 3, 3], [3, 2, 3], [3, 1, 2]]:
    153               self._test_atrous_convolution(
    154                   add_check=add_check,
    155                   input_shape=[2, depth, height, width, 2],
    156                   filter_shape=[
    157                       kernel_depth, kernel_height, kernel_width, 2, 2
    158                   ],
    159                   padding=padding,
    160                   dilation_rate=dilation_rate,
    161               )
    162 
    163   @test_util.run_in_graph_and_eager_modes
    164   def testAtrousConvolution1D(self):
    165     with self._delay_checks() as add_check:
    166       for padding in ["SAME", "VALID"]:
    167         for width in [9, 10]:
    168           for kernel_width in range(1, 4):
    169             for rate in range(1, 4):
    170               self._test_atrous_convolution(
    171                   add_check=add_check,
    172                   input_shape=[2, width, 2],
    173                   filter_shape=[kernel_width, 2, 2],
    174                   padding=padding,
    175                   dilation_rate=[rate],
    176               )
    177 
    178   @test_util.run_in_graph_and_eager_modes
    179   def testAtrousConvolutionNC(self):
    180     if test.is_gpu_available(cuda_only=True):
    181       # "NCW" and "NCHW" formats are currently supported only on CUDA.
    182       with test_util.device(use_gpu=True):
    183         with self._delay_checks() as add_check:
    184           for padding in ["SAME", "VALID"]:
    185             self._test_atrous_convolution(
    186                 add_check=add_check,
    187                 input_shape=[2, 2, 9],
    188                 padding=padding,
    189                 filter_shape=[3, 2, 2],
    190                 dilation_rate=[2],
    191                 data_format="NCW",
    192             )
    193             self._test_atrous_convolution(
    194                 add_check=add_check,
    195                 input_shape=[2, 2, 9, 5],
    196                 padding=padding,
    197                 filter_shape=[3, 3, 2, 2],
    198                 dilation_rate=[2, 1],
    199                 data_format="NCHW",
    200             )
    201 
    202   @test_util.run_in_graph_and_eager_modes
    203   def testAtrousSequence(self):
    204     """Tests optimization of sequence of atrous convolutions.
    205 
    206     See the documentation of with_space_to_batch.
    207     """
    208     with self._delay_checks() as add_check:
    209       for padding in ["SAME", "VALID"]:
    210         for height in range(15, 17):
    211           for width in range(15, 17):
    212             x_shape = [3, height, width, 2]
    213             x = np.random.random_sample(x_shape).astype(np.float32)
    214 
    215             kernel_sizes = [1, 3] if padding == "SAME" else range(1, 3)
    216             for kernel in kernel_sizes:
    217               f_shape = [kernel, kernel, 2, 2]
    218               f1 = 1e-2 * np.random.random_sample(f_shape).astype(np.float32)
    219               f2 = 1e-2 * np.random.random_sample(f_shape).astype(np.float32)
    220 
    221               def combined_op(converted_input, num_spatial_dims, padding_arg):  # pylint: disable=unused-argument
    222                 # pylint: disable=cell-var-from-loop
    223                 result = nn_ops.convolution(
    224                     input=converted_input, filter=f1, padding=padding)
    225                 result = nn_ops.convolution(
    226                     input=result, filter=f2, padding=padding)
    227                 # pylint: enable=cell-var-from-loop
    228                 return result
    229 
    230               for rate_height in range(2, 4):
    231                 for rate_width in range(2, 4):
    232                   dilation_rate = [rate_height, rate_width]
    233                   y1 = nn_ops.convolution(
    234                       input=x,
    235                       filter=f1,
    236                       padding=padding,
    237                       dilation_rate=dilation_rate)
    238                   y1 = nn_ops.convolution(
    239                       input=y1,
    240                       filter=f2,
    241                       padding=padding,
    242                       dilation_rate=dilation_rate)
    243                   y2 = nn_ops.with_space_to_batch(
    244                       input=x,
    245                       dilation_rate=dilation_rate,
    246                       op=combined_op,
    247                       padding="VALID")
    248 
    249                   def check(y1_eval, y2_eval):
    250                     self.assertAllClose(y1_eval, y2_eval, rtol=1e-2, atol=1e-2)
    251 
    252                   add_check(check, y1, y2)
    253 
    254   def _test_gradient(self, x_shape, f_shape, dilation_rate, padding):
    255     x_val = np.random.random_sample(x_shape).astype(np.float32)
    256     f_val = np.random.random_sample(f_shape).astype(np.float32)
    257     x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
    258     f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
    259     output = nn_ops.convolution(
    260         input=x, filter=f, dilation_rate=dilation_rate, padding=padding)
    261     y_shape = output.get_shape().as_list()
    262     err = gradient_checker.compute_gradient_error([x, f], [x_shape, f_shape],
    263                                                   output, y_shape)
    264     err_tolerance = 1e-3
    265     self.assertLess(err, err_tolerance)
    266 
    267   @test_util.run_v1_only("b/120545219")
    268   def testGradient(self):
    269     with self.cached_session():
    270       for padding in ["SAME", "VALID"]:
    271         for rate_width in range(1, 3):
    272           for rate_height in range(1, 3):
    273             self._test_gradient(
    274                 x_shape=[2, 5, 6, 2],
    275                 f_shape=[3, 3, 2, 2],
    276                 dilation_rate=[rate_height, rate_width],
    277                 padding=padding)
    278 
    279 
# Script entry point: hand control to the test runner from
# tensorflow.python.platform.test when this file is executed directly.
if __name__ == "__main__":
  test.main()
    282