Home | History | Annotate | Download | only in tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Functional tests for 3d pooling operations."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.compiler.tests.xla_test import XLATestCase
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import gen_nn_ops
     28 from tensorflow.python.ops import nn_ops
     29 from tensorflow.python.platform import test
     30 
     31 
     32 # Wrapper around AvgPoolGrad that ignores extra arguments needed by
     33 # MaxPoolGrad.
     34 def _AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding):
     35   del outputs  # Unused by average-pooling gradients.
     36   return gen_nn_ops._avg_pool3d_grad(
     37       inputs.get_shape().as_list(),
     38       output_gradients,
     39       ksize=ksize,
     40       strides=strides,
     41       padding=padding)
     42 
     43 
     44 class Pooling3DTest(XLATestCase):
     45 
     46   def _VerifyValues(self, pool_func, input_sizes, window, strides, padding,
     47                     expected):
     48     """Verifies the output values of the pooling function.
     49 
     50     Args:
     51       pool_func: Function to be called: co.MaxPool, co.AvgPool.
     52       input_sizes: Input tensor dimensions.
     53       window: Tuple of kernel dims: planes, rows, cols.
     54       strides: Tuple of strides for dims: planes, rows, cols.
     55       padding: Padding type.
     56       expected: An array containing the expected operation outputs.
     57     """
     58     total_size = 1
     59     for s in input_sizes:
     60       total_size *= s
     61     # Initializes the input tensor with array containing incrementing
     62     # numbers from 1.
     63     x = np.arange(1.0, total_size + 1, dtype=np.float32)
     64     x = x.reshape(input_sizes)
     65     with self.test_session() as sess, self.test_scope():
     66       inputs = array_ops.placeholder(dtypes.float32)
     67       t = pool_func(
     68           inputs,
     69           ksize=[1] + window + [1],
     70           strides=[1] + strides + [1],
     71           padding=padding)
     72       vals = sess.run(t, {inputs: x})
     73     # Verifies values.
     74     actual = vals.flatten()
     75     self.assertAllClose(expected, actual)
     76 
     77   def testAvgPool3dValidPadding(self):
     78     expected_output = [20.5, 21.5, 22.5]
     79     self._VerifyValues(
     80         nn_ops.avg_pool3d,
     81         input_sizes=[1, 3, 3, 3, 3],
     82         window=[2, 2, 2],
     83         strides=[2, 2, 2],
     84         padding="VALID",
     85         expected=expected_output)
     86 
     87   def testAvgPool3dSamePadding(self):
     88     expected_output = [20.5, 21.5, 22.5, 26.5, 27.5, 28.5]
     89     self._VerifyValues(
     90         nn_ops.avg_pool3d,
     91         input_sizes=[1, 2, 2, 4, 3],
     92         window=[2, 2, 2],
     93         strides=[2, 2, 2],
     94         padding="SAME",
     95         expected=expected_output)
     96 
     97   def testAvgPool3dSamePaddingDifferentStrides(self):
     98     expected_output = [1.5, 4.5, 7.5, 17.5, 20.5, 23.5, 33.5, 36.5, 39.5]
     99     self._VerifyValues(
    100         nn_ops.avg_pool3d,
    101         input_sizes=[1, 5, 8, 1, 1],
    102         window=[1, 2, 3],
    103         strides=[2, 3, 1],
    104         padding="SAME",
    105         expected=expected_output)
    106 
    107   def testMaxPool3dValidPadding(self):
    108     expected_output = [40.0, 41.0, 42.0]
    109     self._VerifyValues(
    110         nn_ops.max_pool3d,
    111         input_sizes=[1, 3, 3, 3, 3],
    112         window=[2, 2, 2],
    113         strides=[2, 2, 2],
    114         padding="VALID",
    115         expected=expected_output)
    116 
    117   def testMaxPool3dSamePadding(self):
    118     expected_output = [31., 32., 33., 34., 35., 36.]
    119     self._VerifyValues(
    120         nn_ops.max_pool3d,
    121         input_sizes=[1, 2, 2, 3, 3],
    122         window=[2, 2, 2],
    123         strides=[2, 2, 2],
    124         padding="SAME",
    125         expected=expected_output)
    126 
    127   def testMaxPool3dSamePaddingDifferentStrides(self):
    128     expected_output = [2., 5., 8., 18., 21., 24., 34., 37., 40.]
    129     self._VerifyValues(
    130         nn_ops.max_pool3d,
    131         input_sizes=[1, 5, 8, 1, 1],
    132         window=[1, 2, 3],
    133         strides=[2, 3, 1],
    134         padding="SAME",
    135         expected=expected_output)
    136 
    137     # Test pooling on a larger input, with different stride and kernel
    138     # size for the 'z' dimension.
    139 
    140     # Simulate max pooling in numpy to get the expected output.
    141     input_data = np.arange(1, 5 * 27 * 27 * 64 + 1).reshape((5, 27, 27, 64))
    142     input_data = np.pad(input_data, [[0, 0], [0, 1], [0, 1], [0, 0]],
    143                         mode="constant")
    144     expected_output = input_data[:, 1::2, 1::2, :]
    145     expected_output[:, -1, :, :] = input_data[:, -2, 1::2, :]
    146     expected_output[:, :, -1, :] = input_data[:, 1::2, -2, :]
    147     expected_output[:, -1, -1, :] = input_data[:, -2, -2, :]
    148 
    149     self._VerifyValues(
    150         nn_ops.max_pool3d,
    151         input_sizes=[1, 5, 27, 27, 64],
    152         window=[1, 2, 2],
    153         strides=[1, 2, 2],
    154         padding="SAME",
    155         expected=expected_output.flatten())
    156 
    157   def testKernelSmallerThanStride(self):
    158     self._VerifyValues(
    159         nn_ops.max_pool3d,
    160         input_sizes=[1, 3, 3, 3, 1],
    161         window=[1, 1, 1],
    162         strides=[2, 2, 2],
    163         padding="SAME",
    164         expected=[1, 3, 7, 9, 19, 21, 25, 27])
    165 
    166     self._VerifyValues(
    167         nn_ops.max_pool3d,
    168         input_sizes=[1, 7, 7, 7, 1],
    169         window=[2, 2, 2],
    170         strides=[3, 3, 3],
    171         padding="VALID",
    172         expected=[58, 61, 79, 82, 205, 208, 226, 229])
    173 
    174     self._VerifyValues(
    175         nn_ops.avg_pool3d,
    176         input_sizes=[1, 3, 3, 3, 1],
    177         window=[1, 1, 1],
    178         strides=[2, 2, 2],
    179         padding="SAME",
    180         expected=[1, 3, 7, 9, 19, 21, 25, 27])
    181 
    182     self._VerifyValues(
    183         nn_ops.avg_pool3d,
    184         input_sizes=[1, 7, 7, 7, 1],
    185         window=[2, 2, 2],
    186         strides=[3, 3, 3],
    187         padding="VALID",
    188         expected=[29.5, 32.5, 50.5, 53.5, 176.5, 179.5, 197.5, 200.5])
    189 
    190   def _VerifyGradient(self, pool_func, pool_grad_func, input_sizes, ksize,
    191                       strides, padding):
    192     """Verifies the output values of the pooling gradient function.
    193 
    194     Args:
    195       pool_func: Forward pooling function
    196       pool_grad_func: Pooling gradient function for pool_grad_func
    197       input_sizes: Input tensor dimensions.
    198       ksize: The kernel size dimensions
    199       strides: The stride dimensions
    200       padding: Padding type.
    201     """
    202     ksize = [1] + ksize + [1]
    203     strides = [1] + strides + [1]
    204     total_size = np.prod(input_sizes)
    205     x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes)
    206     with self.test_session() as sess:
    207       # Use the forward pool function to compute some corresponding outputs
    208       # (needed for the CPU device, and we need the shape in both cases).
    209       with ops.device("CPU"):
    210         inputs = array_ops.placeholder(dtypes.float32, shape=input_sizes)
    211         outputs = pool_func(
    212             inputs,
    213             ksize=ksize,
    214             strides=strides,
    215             padding=padding)
    216 
    217       output_vals = np.array(sess.run(outputs, {inputs: x}))
    218       output_gradient_vals = np.arange(
    219           1, output_vals.size + 1, dtype=np.float32)
    220       output_gradient_vals = output_gradient_vals.reshape(output_vals.shape)
    221 
    222       # Use the Tensorflow CPU pooling gradient to compute the expected input
    223       # gradients.
    224       with ops.device("CPU"):
    225         output_gradients = array_ops.placeholder(
    226             dtypes.float32, shape=output_vals.shape)
    227         expected_input_gradients = pool_grad_func(
    228             inputs,
    229             outputs,
    230             output_gradients,
    231             ksize=ksize,
    232             strides=strides,
    233             padding=padding)
    234         expected_input_gradient_vals = sess.run(
    235             expected_input_gradients,
    236             {inputs: x,
    237              output_gradients: output_gradient_vals})
    238 
    239       # Run the gradient op on the XLA device
    240       with self.test_scope():
    241         outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape)
    242         actual_input_gradients = pool_grad_func(
    243             inputs,
    244             outputs,
    245             output_gradients,
    246             ksize=ksize,
    247             strides=strides,
    248             padding=padding)
    249       actual = sess.run(actual_input_gradients, {
    250           inputs: x,
    251           outputs: output_vals,
    252           output_gradients: output_gradient_vals
    253       })
    254 
    255       # Compare the Tensorflow and XLA results.
    256       self.assertAllClose(
    257           expected_input_gradient_vals.flatten(),
    258           actual.flatten(),
    259           rtol=1e-5,
    260           atol=1e-6)
    261       self.assertShapeEqual(actual, inputs)
    262 
    263   def testMaxPoolGradValidPadding1_1_3d(self):
    264     self._VerifyGradient(
    265         nn_ops.max_pool3d,
    266         gen_nn_ops._max_pool3d_grad,
    267         input_sizes=[1, 3, 3, 3, 1],
    268         ksize=[1, 1, 1],
    269         strides=[1, 1, 1],
    270         padding="VALID")
    271 
    272   def testMaxPoolGradValidPadding2_1_6_3d(self):
    273     self._VerifyGradient(
    274         nn_ops.max_pool3d,
    275         gen_nn_ops._max_pool3d_grad,
    276         input_sizes=[2, 3, 3, 6, 3],
    277         ksize=[2, 2, 2],
    278         strides=[1, 1, 1],
    279         padding="VALID")
    280 
    281   def testMaxPoolGradValidPadding2_1_7_3d(self):
    282     self._VerifyGradient(
    283         nn_ops.max_pool3d,
    284         gen_nn_ops._max_pool3d_grad,
    285         input_sizes=[2, 3, 5, 7, 3],
    286         ksize=[2, 2, 2],
    287         strides=[1, 1, 1],
    288         padding="VALID")
    289 
    290   def testMaxPoolGradValidPadding2_2_3d(self):
    291     self._VerifyGradient(
    292         nn_ops.max_pool3d,
    293         gen_nn_ops._max_pool3d_grad,
    294         input_sizes=[2, 2, 2, 2, 3],
    295         ksize=[2, 2, 2],
    296         strides=[2, 2, 2],
    297         padding="VALID")
    298 
    299   def testMaxPoolGradSamePadding1_1_3d(self):
    300     self._VerifyGradient(
    301         nn_ops.max_pool3d,
    302         gen_nn_ops._max_pool3d_grad,
    303         input_sizes=[2, 3, 2, 4, 1],
    304         ksize=[1, 1, 1],
    305         strides=[1, 1, 1],
    306         padding="SAME")
    307 
    308   def testMaxPoolGradSamePadding2_1_3d(self):
    309     self._VerifyGradient(
    310         nn_ops.max_pool3d,
    311         gen_nn_ops._max_pool3d_grad,
    312         input_sizes=[2, 3, 2, 4, 1],
    313         ksize=[2, 2, 2],
    314         strides=[1, 1, 1],
    315         padding="SAME")
    316 
    317   def testMaxPoolGradSamePadding2_2_3d(self):
    318     self._VerifyGradient(
    319         nn_ops.max_pool3d,
    320         gen_nn_ops._max_pool3d_grad,
    321         input_sizes=[2, 5, 2, 4, 3],
    322         ksize=[2, 2, 2],
    323         strides=[2, 2, 2],
    324         padding="SAME")
    325 
    326   def testMaxPoolGradSamePadding3_1_3d(self):
    327     self._VerifyGradient(
    328         nn_ops.max_pool3d,
    329         gen_nn_ops._max_pool3d_grad,
    330         input_sizes=[1, 3, 3, 7, 1],
    331         ksize=[3, 3, 3],
    332         strides=[1, 1, 1],
    333         padding="SAME")
    334 
    335   def testAvgPoolGradValidPadding1_1_3d(self):
    336     self._VerifyGradient(
    337         nn_ops.avg_pool3d,
    338         _AvgPoolGrad,
    339         input_sizes=[2, 3, 3, 3, 3],
    340         ksize=[1, 1, 1],
    341         strides=[1, 1, 1],
    342         padding="VALID")
    343 
    344   def testAvgPoolGradValidPadding2_1_3d(self):
    345     self._VerifyGradient(
    346         nn_ops.avg_pool3d,
    347         _AvgPoolGrad,
    348         input_sizes=[2, 3, 3, 3, 3],
    349         ksize=[2, 2, 2],
    350         strides=[1, 1, 1],
    351         padding="VALID")
    352 
    353   def testAvgPoolGradValidPadding2_2_3d(self):
    354     self._VerifyGradient(
    355         nn_ops.avg_pool3d,
    356         _AvgPoolGrad,
    357         input_sizes=[2, 2, 2, 2, 3],
    358         ksize=[2, 2, 2],
    359         strides=[2, 2, 2],
    360         padding="VALID")
    361 
    362   def testAvgPoolGradSamePadding1_1_3d(self):
    363     self._VerifyGradient(
    364         nn_ops.avg_pool3d,
    365         _AvgPoolGrad,
    366         input_sizes=[2, 3, 2, 4, 3],
    367         ksize=[1, 1, 1],
    368         strides=[1, 1, 1],
    369         padding="SAME")
    370 
    371   def testAvgPoolGradSamePadding2_1_3d(self):
    372     self._VerifyGradient(
    373         nn_ops.avg_pool3d,
    374         _AvgPoolGrad,
    375         input_sizes=[1, 2, 2, 2, 1],
    376         ksize=[2, 2, 2],
    377         strides=[1, 1, 1],
    378         padding="SAME")
    379 
    380   def testAvgPoolGradSamePadding2_2_3d(self):
    381     self._VerifyGradient(
    382         nn_ops.avg_pool3d,
    383         _AvgPoolGrad,
    384         input_sizes=[2, 5, 2, 4, 3],
    385         ksize=[2, 2, 2],
    386         strides=[2, 2, 2],
    387         padding="SAME")
    388 
    389   def testAvgPoolGradSamePadding3_1_3d(self):
    390     self._VerifyGradient(
    391         nn_ops.avg_pool3d,
    392         _AvgPoolGrad,
    393         input_sizes=[1, 3, 6, 7, 1],
    394         ksize=[3, 3, 3],
    395         strides=[1, 1, 1],
    396         padding="SAME")
    397 
    398 
# Script entry point: hand control to the TensorFlow platform test runner
# (tensorflow.python.platform.test) when this file is executed directly.
if __name__ == "__main__":
  test.main()
    401