Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for atrous convolution functionality in tensorflow.ops.nn."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import contextlib
     22 
     23 import numpy as np
     24 
     25 from tensorflow.python.eager import context
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import test_util
     29 from tensorflow.python.ops import array_ops
     30 from tensorflow.python.ops import gradient_checker
     31 from tensorflow.python.ops import nn_ops
     32 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
     33 from tensorflow.python.platform import test
     34 
     35 
     36 def upsample_filters(filters, rate):
     37   """Upsamples the filters by a factor of rate along the spatial dimensions.
     38 
     39   Args:
     40     filters: spatial_shape + [in_channels, out_channels]
     41       Original filters.
     42     rate: A list of len(spatial_shape) positive ints, specifying the
     43       upsampling rate.
     44 
     45   Returns:
     46     filters_up: output_spatial_shape + [in_channels, out_channels].
     47       Upsampled filters with
     48       output_spatial_shape[i] = (spatial_shape[i] - 1) * rate[i] + 1
     49       containing (rate[i] - 1) zeros between consecutive filter values along
     50       spatial dimension i.
     51   """
     52   num_spatial_dims = len(rate)
     53   spatial_shape = np.array(filters.shape[:num_spatial_dims])
     54   output_spatial_shape = (spatial_shape - 1) * rate + 1
     55   output = np.zeros(
     56       tuple(output_spatial_shape) + tuple(filters.shape[-2:]), filters.dtype)
     57   output[tuple(np.s_[::rate[i]] for i in range(num_spatial_dims))] = filters
     58   return output
     59 
     60 
     61 class AtrousConvolutionTest(test.TestCase):
     62 
     63   @contextlib.contextmanager
     64   def _delay_checks(self):
     65     """Context manager for combining checks depending on tensor evaluations.
     66 
     67     Each call to Session.run has some overhead, and this overhead can easily
     68     account for the majority of the time spent in tests that call Session.run
     69     (or Tensor.eval) many times.
     70 
     71     This context manager provides a mechanism for registering callback functions
     72     and associated tensors.  When the context is exited, all of the tensors
     73     associated with all of the registrations are evaluated with a single call to
     74     Session.run, and then each registered callback function is called with the
     75     values of its associated tensors.
     76 
     77     Yields:
     78       A function `add_check(check, *args, **kwargs)` where `check` is the
     79       callback function to be invoked, and `*args` and `**kwargs` specify the
     80       associated Tensors. When in EAGER mode, check is executed in add_check,
     81       otherwise, it's delayed after the context.
     82     """
     83     checks = []
     84 
     85     def add_check(check, *args, **kwargs):
     86       if context.in_eager_mode():
     87         args_val, kwargs_val = self.evaluate([args, kwargs])
     88         check(*args_val, **kwargs_val)
     89       else:
     90         checks.append((check, args, kwargs))
     91 
     92     yield add_check
     93     if context.in_graph_mode():
     94       all_values = self.evaluate([[args, kwargs] for _, args, kwargs in checks])
     95       for (check, _, _), (args, kwargs) in zip(checks, all_values):
     96         check(*args, **kwargs)
     97 
     98   def _test_atrous_convolution(self, add_check, input_shape, filter_shape,
     99                                dilation_rate, **kwargs):
    100     filters = np.arange(
    101         np.prod(filter_shape), dtype=np.float32).reshape(filter_shape)
    102     filters_upsampled = upsample_filters(filters, dilation_rate)
    103     x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape)
    104     y1 = nn_ops.convolution(
    105         input=x, filter=filters, dilation_rate=dilation_rate, **kwargs)
    106     y2 = nn_ops.convolution(input=x, filter=filters_upsampled, **kwargs)
    107 
    108     def check(y1_eval, y2_eval):
    109       self.assertAllClose(y1_eval, y2_eval, rtol=1e-2, atol=1e-2)
    110 
    111     add_check(check, y1, y2)
    112 
    113   def test_unknown_spatial_dims_for_channel_last_format(self):
    114     x = array_ops.placeholder(dtypes.float32, [1, None, None, 10])
    115     w = array_ops.zeros([3, 3, 10, 20])
    116     y = nn_ops.convolution(
    117         x, w, "VALID", dilation_rate=[2, 2], data_format="NHWC")
    118     self.assertEqual(y.shape.as_list(), [1, None, None, 20])
    119 
    120   def test_unknown_spatial_dims_for_channel_first_format(self):
    121     x = array_ops.placeholder(dtypes.float32, [1, 10, None, None])
    122     w = array_ops.zeros([3, 3, 10, 20])
    123     y = nn_ops.convolution(
    124         x, w, "VALID", dilation_rate=[2, 2], data_format="NCHW")
    125     self.assertEqual(y.shape.as_list(), [1, 20, None, None])
    126 
    127   @test_util.run_in_graph_and_eager_modes()
    128   def testAtrousConvolution2D(self):
    129     with self._delay_checks() as add_check:
    130       for padding in ["SAME", "VALID"]:
    131         for height, width in [[9, 9], [9, 10]]:
    132           for kernel_height, kernel_width in [[1, 1], [2, 2], [2, 3]]:
    133             for dilation_rate in [[1, 1], [3, 2], [2, 1]]:
    134               self._test_atrous_convolution(
    135                   add_check=add_check,
    136                   input_shape=[2, height, width, 2],
    137                   filter_shape=[kernel_height, kernel_width, 2, 2],
    138                   padding=padding,
    139                   dilation_rate=dilation_rate,
    140               )
    141 
    142   @test_util.run_in_graph_and_eager_modes()
    143   def testAtrousConvolution3D(self):
    144     with self._delay_checks() as add_check:
    145       for padding in ["SAME", "VALID"]:
    146         for depth, height, width in [[9, 9, 10], [9, 10, 9]]:
    147           for kernel_depth, kernel_height, kernel_width in [[3, 3,
    148                                                              3], [3, 2, 2],
    149                                                             [2, 1, 3]]:
    150             for dilation_rate in [[1, 1, 1], [3, 3, 3], [3, 2, 3], [3, 1, 2]]:
    151               self._test_atrous_convolution(
    152                   add_check=add_check,
    153                   input_shape=[2, depth, height, width, 2],
    154                   filter_shape=[
    155                       kernel_depth, kernel_height, kernel_width, 2, 2
    156                   ],
    157                   padding=padding,
    158                   dilation_rate=dilation_rate,
    159               )
    160 
    161   @test_util.run_in_graph_and_eager_modes()
    162   def testAtrousConvolution1D(self):
    163     with self._delay_checks() as add_check:
    164       for padding in ["SAME", "VALID"]:
    165         for width in [9, 10]:
    166           for kernel_width in range(1, 4):
    167             for rate in range(1, 4):
    168               self._test_atrous_convolution(
    169                   add_check=add_check,
    170                   input_shape=[2, width, 2],
    171                   filter_shape=[kernel_width, 2, 2],
    172                   padding=padding,
    173                   dilation_rate=[rate],
    174               )
    175 
    176   @test_util.run_in_graph_and_eager_modes()
    177   def testAtrousConvolutionNC(self):
    178     if test.is_gpu_available(cuda_only=True):
    179       # "NCW" and "NCHW" formats are currently supported only on CUDA.
    180       with test_util.device(use_gpu=True):
    181         with self._delay_checks() as add_check:
    182           for padding in ["SAME", "VALID"]:
    183             self._test_atrous_convolution(
    184                 add_check=add_check,
    185                 input_shape=[2, 2, 9],
    186                 padding=padding,
    187                 filter_shape=[3, 2, 2],
    188                 dilation_rate=[2],
    189                 data_format="NCW",
    190             )
    191             self._test_atrous_convolution(
    192                 add_check=add_check,
    193                 input_shape=[2, 2, 9, 5],
    194                 padding=padding,
    195                 filter_shape=[3, 3, 2, 2],
    196                 dilation_rate=[2, 1],
    197                 data_format="NCHW",
    198             )
    199 
    200   @test_util.run_in_graph_and_eager_modes()
    201   def testAtrousSequence(self):
    202     """Tests optimization of sequence of atrous convolutions.
    203 
    204     See the documentation of with_space_to_batch.
    205     """
    206     with self._delay_checks() as add_check:
    207       for padding in ["SAME", "VALID"]:
    208         for height in range(15, 17):
    209           for width in range(15, 17):
    210             x_shape = [3, height, width, 2]
    211             x = np.random.random_sample(x_shape).astype(np.float32)
    212 
    213             kernel_sizes = [1, 3] if padding == "SAME" else range(1, 3)
    214             for kernel in kernel_sizes:
    215               f_shape = [kernel, kernel, 2, 2]
    216               f1 = 1e-2 * np.random.random_sample(f_shape).astype(np.float32)
    217               f2 = 1e-2 * np.random.random_sample(f_shape).astype(np.float32)
    218 
    219               def combined_op(converted_input, num_spatial_dims, padding_arg):  # pylint: disable=unused-argument
    220                 # pylint: disable=cell-var-from-loop
    221                 result = nn_ops.convolution(
    222                     input=converted_input, filter=f1, padding=padding)
    223                 result = nn_ops.convolution(
    224                     input=result, filter=f2, padding=padding)
    225                 # pylint: enable=cell-var-from-loop
    226                 return result
    227 
    228               for rate_height in range(2, 4):
    229                 for rate_width in range(2, 4):
    230                   dilation_rate = [rate_height, rate_width]
    231                   y1 = nn_ops.convolution(
    232                       input=x,
    233                       filter=f1,
    234                       padding=padding,
    235                       dilation_rate=dilation_rate)
    236                   y1 = nn_ops.convolution(
    237                       input=y1,
    238                       filter=f2,
    239                       padding=padding,
    240                       dilation_rate=dilation_rate)
    241                   y2 = nn_ops.with_space_to_batch(
    242                       input=x,
    243                       dilation_rate=dilation_rate,
    244                       op=combined_op,
    245                       padding="VALID")
    246 
    247                   def check(y1_eval, y2_eval):
    248                     self.assertAllClose(y1_eval, y2_eval, rtol=1e-2, atol=1e-2)
    249 
    250                   add_check(check, y1, y2)
    251 
    252   def _test_gradient(self, x_shape, f_shape, dilation_rate, padding):
    253     x_val = np.random.random_sample(x_shape).astype(np.float32)
    254     f_val = np.random.random_sample(f_shape).astype(np.float32)
    255     x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
    256     f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
    257     output = nn_ops.convolution(
    258         input=x, filter=f, dilation_rate=dilation_rate, padding=padding)
    259     y_shape = output.get_shape().as_list()
    260     err = gradient_checker.compute_gradient_error([x, f], [x_shape, f_shape],
    261                                                   output, y_shape)
    262     err_tolerance = 1e-3
    263     self.assertLess(err, err_tolerance)
    264 
    265   def testGradient(self):
    266     with self.test_session():
    267       for padding in ["SAME", "VALID"]:
    268         for rate_width in range(1, 3):
    269           for rate_height in range(1, 3):
    270             self._test_gradient(
    271                 x_shape=[2, 5, 6, 2],
    272                 f_shape=[3, 3, 2, 2],
    273                 dilation_rate=[rate_height, rate_width],
    274                 padding=padding)
    275 
    276 
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()
    279