1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for atrous convolution functionality in tensorflow.ops.nn.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import contextlib 22 23 import numpy as np 24 25 from tensorflow.python.eager import context 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.framework import dtypes 28 from tensorflow.python.framework import test_util 29 from tensorflow.python.ops import array_ops 30 from tensorflow.python.ops import gradient_checker 31 from tensorflow.python.ops import nn_ops 32 import tensorflow.python.ops.nn_grad # pylint: disable=unused-import 33 from tensorflow.python.platform import test 34 35 36 def upsample_filters(filters, rate): 37 """Upsamples the filters by a factor of rate along the spatial dimensions. 38 39 Args: 40 filters: spatial_shape + [in_channels, out_channels] 41 Original filters. 42 rate: A list of len(spatial_shape) positive ints, specifying the 43 upsampling rate. 44 45 Returns: 46 filters_up: output_spatial_shape + [in_channels, out_channels]. 47 Upsampled filters with 48 output_spatial_shape[i] = (spatial_shape[i] - 1) * rate[i] + 1 49 containing (rate[i] - 1) zeros between consecutive filter values along 50 spatial dimension i. 51 """ 52 num_spatial_dims = len(rate) 53 spatial_shape = np.array(filters.shape[:num_spatial_dims]) 54 output_spatial_shape = (spatial_shape - 1) * rate + 1 55 output = np.zeros( 56 tuple(output_spatial_shape) + tuple(filters.shape[-2:]), filters.dtype) 57 output[tuple(np.s_[::rate[i]] for i in range(num_spatial_dims))] = filters 58 return output 59 60 61 class AtrousConvolutionTest(test.TestCase): 62 63 @contextlib.contextmanager 64 def _delay_checks(self): 65 """Context manager for combining checks depending on tensor evaluations. 66 67 Each call to Session.run has some overhead, and this overhead can easily 68 account for the majority of the time spent in tests that call Session.run 69 (or Tensor.eval) many times. 70 71 This context manager provides a mechanism for registering callback functions 72 and associated tensors. When the context is exited, all of the tensors 73 associated with all of the registrations are evaluated with a single call to 74 Session.run, and then each registered callback function is called with the 75 values of its associated tensors. 76 77 Yields: 78 A function `add_check(check, *args, **kwargs)` where `check` is the 79 callback function to be invoked, and `*args` and `**kwargs` specify the 80 associated Tensors. When in EAGER mode, check is executed in add_check, 81 otherwise, it's delayed after the context. 82 """ 83 checks = [] 84 85 def add_check(check, *args, **kwargs): 86 if context.executing_eagerly(): 87 args_val, kwargs_val = self.evaluate([args, kwargs]) 88 check(*args_val, **kwargs_val) 89 else: 90 checks.append((check, args, kwargs)) 91 92 yield add_check 93 if not context.executing_eagerly(): 94 all_values = self.evaluate([[args, kwargs] for _, args, kwargs in checks]) 95 for (check, _, _), (args, kwargs) in zip(checks, all_values): 96 check(*args, **kwargs) 97 98 def _test_atrous_convolution(self, add_check, input_shape, filter_shape, 99 dilation_rate, **kwargs): 100 filters = np.arange( 101 np.prod(filter_shape), dtype=np.float32).reshape(filter_shape) 102 filters_upsampled = upsample_filters(filters, dilation_rate) 103 x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) 104 y1 = nn_ops.convolution( 105 input=x, filter=filters, dilation_rate=dilation_rate, **kwargs) 106 y2 = nn_ops.convolution(input=x, filter=filters_upsampled, **kwargs) 107 108 def check(y1_eval, y2_eval): 109 self.assertAllClose(y1_eval, y2_eval, rtol=1e-2, atol=1e-2) 110 111 add_check(check, y1, y2) 112 113 @test_util.run_v1_only("b/120545219") 114 def test_unknown_spatial_dims_for_channel_last_format(self): 115 x = array_ops.placeholder(dtypes.float32, [1, None, None, 10]) 116 w = array_ops.zeros([3, 3, 10, 20]) 117 y = nn_ops.convolution( 118 x, w, "VALID", dilation_rate=[2, 2], data_format="NHWC") 119 self.assertEqual(y.shape.as_list(), [1, None, None, 20]) 120 121 @test_util.run_v1_only("b/120545219") 122 def test_unknown_spatial_dims_for_channel_first_format(self): 123 x = array_ops.placeholder(dtypes.float32, [1, 10, None, None]) 124 w = array_ops.zeros([3, 3, 10, 20]) 125 y = nn_ops.convolution( 126 x, w, "VALID", dilation_rate=[2, 2], data_format="NCHW") 127 self.assertEqual(y.shape.as_list(), [1, 20, None, None]) 128 129 @test_util.run_in_graph_and_eager_modes 130 def testAtrousConvolution2D(self): 131 with self._delay_checks() as add_check: 132 for padding in ["SAME", "VALID"]: 133 for height, width in [[9, 9], [9, 10]]: 134 for kernel_height, kernel_width in [[1, 1], [2, 2], [2, 3]]: 135 for dilation_rate in [[1, 1], [3, 2], [2, 1]]: 136 self._test_atrous_convolution( 137 add_check=add_check, 138 input_shape=[2, height, width, 2], 139 filter_shape=[kernel_height, kernel_width, 2, 2], 140 padding=padding, 141 dilation_rate=dilation_rate, 142 ) 143 144 @test_util.run_in_graph_and_eager_modes 145 def testAtrousConvolution3D(self): 146 with self._delay_checks() as add_check: 147 for padding in ["SAME", "VALID"]: 148 for depth, height, width in [[9, 9, 10], [9, 10, 9]]: 149 for kernel_depth, kernel_height, kernel_width in [[3, 3, 150 3], [3, 2, 2], 151 [2, 1, 3]]: 152 for dilation_rate in [[1, 1, 1], [3, 3, 3], [3, 2, 3], [3, 1, 2]]: 153 self._test_atrous_convolution( 154 add_check=add_check, 155 input_shape=[2, depth, height, width, 2], 156 filter_shape=[ 157 kernel_depth, kernel_height, kernel_width, 2, 2 158 ], 159 padding=padding, 160 dilation_rate=dilation_rate, 161 ) 162 163 @test_util.run_in_graph_and_eager_modes 164 def testAtrousConvolution1D(self): 165 with self._delay_checks() as add_check: 166 for padding in ["SAME", "VALID"]: 167 for width in [9, 10]: 168 for kernel_width in range(1, 4): 169 for rate in range(1, 4): 170 self._test_atrous_convolution( 171 add_check=add_check, 172 input_shape=[2, width, 2], 173 filter_shape=[kernel_width, 2, 2], 174 padding=padding, 175 dilation_rate=[rate], 176 ) 177 178 @test_util.run_in_graph_and_eager_modes 179 def testAtrousConvolutionNC(self): 180 if test.is_gpu_available(cuda_only=True): 181 # "NCW" and "NCHW" formats are currently supported only on CUDA. 182 with test_util.device(use_gpu=True): 183 with self._delay_checks() as add_check: 184 for padding in ["SAME", "VALID"]: 185 self._test_atrous_convolution( 186 add_check=add_check, 187 input_shape=[2, 2, 9], 188 padding=padding, 189 filter_shape=[3, 2, 2], 190 dilation_rate=[2], 191 data_format="NCW", 192 ) 193 self._test_atrous_convolution( 194 add_check=add_check, 195 input_shape=[2, 2, 9, 5], 196 padding=padding, 197 filter_shape=[3, 3, 2, 2], 198 dilation_rate=[2, 1], 199 data_format="NCHW", 200 ) 201 202 @test_util.run_in_graph_and_eager_modes 203 def testAtrousSequence(self): 204 """Tests optimization of sequence of atrous convolutions. 205 206 See the documentation of with_space_to_batch. 207 """ 208 with self._delay_checks() as add_check: 209 for padding in ["SAME", "VALID"]: 210 for height in range(15, 17): 211 for width in range(15, 17): 212 x_shape = [3, height, width, 2] 213 x = np.random.random_sample(x_shape).astype(np.float32) 214 215 kernel_sizes = [1, 3] if padding == "SAME" else range(1, 3) 216 for kernel in kernel_sizes: 217 f_shape = [kernel, kernel, 2, 2] 218 f1 = 1e-2 * np.random.random_sample(f_shape).astype(np.float32) 219 f2 = 1e-2 * np.random.random_sample(f_shape).astype(np.float32) 220 221 def combined_op(converted_input, num_spatial_dims, padding_arg): # pylint: disable=unused-argument 222 # pylint: disable=cell-var-from-loop 223 result = nn_ops.convolution( 224 input=converted_input, filter=f1, padding=padding) 225 result = nn_ops.convolution( 226 input=result, filter=f2, padding=padding) 227 # pylint: enable=cell-var-from-loop 228 return result 229 230 for rate_height in range(2, 4): 231 for rate_width in range(2, 4): 232 dilation_rate = [rate_height, rate_width] 233 y1 = nn_ops.convolution( 234 input=x, 235 filter=f1, 236 padding=padding, 237 dilation_rate=dilation_rate) 238 y1 = nn_ops.convolution( 239 input=y1, 240 filter=f2, 241 padding=padding, 242 dilation_rate=dilation_rate) 243 y2 = nn_ops.with_space_to_batch( 244 input=x, 245 dilation_rate=dilation_rate, 246 op=combined_op, 247 padding="VALID") 248 249 def check(y1_eval, y2_eval): 250 self.assertAllClose(y1_eval, y2_eval, rtol=1e-2, atol=1e-2) 251 252 add_check(check, y1, y2) 253 254 def _test_gradient(self, x_shape, f_shape, dilation_rate, padding): 255 x_val = np.random.random_sample(x_shape).astype(np.float32) 256 f_val = np.random.random_sample(f_shape).astype(np.float32) 257 x = constant_op.constant(x_val, name="x", dtype=dtypes.float32) 258 f = constant_op.constant(f_val, name="f", dtype=dtypes.float32) 259 output = nn_ops.convolution( 260 input=x, filter=f, dilation_rate=dilation_rate, padding=padding) 261 y_shape = output.get_shape().as_list() 262 err = gradient_checker.compute_gradient_error([x, f], [x_shape, f_shape], 263 output, y_shape) 264 err_tolerance = 1e-3 265 self.assertLess(err, err_tolerance) 266 267 @test_util.run_v1_only("b/120545219") 268 def testGradient(self): 269 with self.cached_session(): 270 for padding in ["SAME", "VALID"]: 271 for rate_width in range(1, 3): 272 for rate_height in range(1, 3): 273 self._test_gradient( 274 x_shape=[2, 5, 6, 2], 275 f_shape=[3, 3, 2, 2], 276 dilation_rate=[rate_height, rate_width], 277 padding=padding) 278 279 280 if __name__ == "__main__": 281 test.main() 282