Home | History | Annotate | Download | only in tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Functional tests for pooling operations."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.compiler.tests.xla_test import XLATestCase
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import gen_nn_ops
     28 from tensorflow.python.ops import nn_ops
     29 from tensorflow.python.platform import googletest
     30 
     31 
     32 def NHWCToNCHW(input_tensor):
     33   """Convert the input from NHWC format to NCHW.
     34 
     35   Args:
     36     input_tensor:  a 4-D tensor, or a 4-element array representing the same.
     37 
     38   Returns:
     39     the converted tensor or a shape array
     40   """
     41   if isinstance(input_tensor, ops.Tensor):
     42     return array_ops.transpose(input_tensor, [0, 3, 1, 2])
     43   else:
     44     return [input_tensor[0], input_tensor[3], input_tensor[1], input_tensor[2]]
     45 
     46 
     47 def NCHWToNHWC(input_tensor):
     48   """Convert the input from NCHW format to NHWC.
     49 
     50   Args:
     51     input_tensor:  a 4-D tensor, or a 4-element array representing the same.
     52 
     53   Returns:
     54     the converted tensor or a shape array
     55   """
     56   if isinstance(input_tensor, ops.Tensor):
     57     return array_ops.transpose(input_tensor, [0, 2, 3, 1])
     58   else:
     59     return [input_tensor[0], input_tensor[2], input_tensor[3], input_tensor[1]]
     60 
     61 
     62 def GetTestConfigs():
     63   """Get all the valid tests configs to run.
     64 
     65   Returns:
     66     all the valid test configs
     67   """
     68   test_configs = ["NHWC", "NCHW"]
     69   return test_configs
     70 
     71 
     72 class PoolingTest(XLATestCase):
     73 
     74   def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
     75                      data_format, expected):
     76     """Verifies the output values of the pooling function.
     77 
     78     Args:
     79       pool_func: Function to be called, currently only co.MaxPool.
     80       input_sizes: Input tensor dimensions.
     81       ksize: The kernel size dimensions
     82       strides: The stride dimensions
     83       padding: Padding type.
     84       data_format: The data format we use to run the pooling operation.
     85       expected: An array containing the expected operation outputs.
     86     """
     87     total_size = np.prod(input_sizes)
     88     # Initializes the input tensor with array containing incrementing
     89     # numbers from 1.
     90     x = np.array([f * 1.0 for f in range(1, total_size + 1)], dtype=np.float32)
     91     x = x.reshape(input_sizes)
     92     with self.test_session() as sess:
     93       with self.test_scope():
     94         inputs = array_ops.placeholder(dtypes.float32)
     95         t = inputs
     96         if data_format == "NCHW":
     97           t = NHWCToNCHW(t)
     98           ksize = NHWCToNCHW(ksize)
     99           strides = NHWCToNCHW(strides)
    100         t = pool_func(t,
    101                       ksize=ksize,
    102                       strides=strides,
    103                       padding=padding,
    104                       data_format=data_format)
    105         if data_format == "NCHW":
    106           t = NCHWToNHWC(t)
    107       actual = sess.run(t, {inputs: x})
    108       self.assertAllClose(expected, actual.flatten(), rtol=1e-5, atol=1e-6)
    109 
    110   def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding,
    111                     expected):
    112     """Verifies the output values of the pooling function.
    113 
    114     Args:
    115       pool_func: Function to be called, co.MaxPool, co.AvgPool,
    116         or the Lua version.
    117       input_sizes: Input tensor dimensions.
    118       ksize: The kernel size dimensions
    119       strides: The stride dimensions
    120       padding: Padding type.
    121       expected: An array containing the expected operation outputs.
    122     """
    123     for data_format in GetTestConfigs():
    124       self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding,
    125                           data_format, expected)
    126 
    127   def testMaxPoolValidPadding(self):
    128     expected_output = [13.0, 14.0, 15.0]
    129     self._VerifyValues(nn_ops.max_pool,
    130                        input_sizes=[1, 3, 3, 3],
    131                        ksize=[1, 2, 2, 1],
    132                        strides=[1, 2, 2, 1],
    133                        padding="VALID",
    134                        expected=expected_output)
    135 
    136   def testMaxPoolSamePadding(self):
    137     expected_output = [13.0, 14.0, 15.0, 16.0, 17.0, 18.0]
    138     self._VerifyValues(nn_ops.max_pool,
    139                        input_sizes=[1, 2, 3, 3],
    140                        ksize=[1, 2, 2, 1],
    141                        strides=[1, 2, 2, 1],
    142                        padding="SAME",
    143                        expected=expected_output)
    144 
    145   def testMaxPoolSamePaddingNonSquareWindow(self):
    146     # input is:
    147     # [1.0, 2.0
    148     #  3.0  4.0]
    149     #
    150     # Window of [x, x] should do:
    151     #
    152     #  [max(1.0, 2.0), max(2.0, padded0),
    153     #   max(3.0, 4.0), max(4.0, padded0)]
    154     self._VerifyValues(
    155         nn_ops.max_pool,
    156         input_sizes=[1, 2, 2, 1],
    157         ksize=[1, 1, 2, 1],
    158         strides=[1, 1, 1, 1],
    159         padding="SAME",
    160         expected=[2.0, 2.0, 4.0, 4.0])
    161 
    162   def testMaxPoolValidPaddingUnevenStride(self):
    163     self._VerifyValues(
    164         nn_ops.max_pool,
    165         input_sizes=[1, 4, 4, 1],
    166         ksize=[1, 2, 2, 1],
    167         strides=[1, 1, 2, 1],
    168         padding="VALID",
    169         expected=[6.0, 8.0, 10.0, 12.0, 14.0, 16.0])
    170     self._VerifyValues(
    171         nn_ops.max_pool,
    172         input_sizes=[1, 4, 4, 1],
    173         ksize=[1, 2, 2, 1],
    174         strides=[1, 2, 1, 1],
    175         padding="VALID",
    176         expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0])
    177 
    178   def testMaxPoolSamePaddingFilter4(self):
    179     expected_output = [
    180         21.0, 22.0, 23.0, 24.0, 29.0, 30.0, 31.0, 32.0, 53.0, 54.0, 55.0, 56.0,
    181         61.0, 62.0, 63.0, 64.0
    182     ]
    183     self._VerifyValues(
    184         nn_ops.max_pool,
    185         input_sizes=[1, 4, 4, 4],
    186         ksize=[1, 2, 2, 1],
    187         strides=[1, 2, 2, 1],
    188         padding="SAME",
    189         expected=expected_output)
    190 
    191   def testMaxPoolSamePaddingFilter8(self):
    192     expected_output = [
    193         145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 161.0, 162.0,
    194         163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 177.0, 178.0, 179.0, 180.0,
    195         181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0,
    196         191.0, 192.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0,
    197         289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, 305.0, 306.0,
    198         307.0, 308.0, 309.0, 310.0, 311.0, 312.0, 313.0, 314.0, 315.0, 316.0,
    199         317.0, 318.0, 319.0, 320.0, 401.0, 402.0, 403.0, 404.0, 405.0, 406.0,
    200         407.0, 408.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0,
    201         433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0, 441.0, 442.0,
    202         443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 465.0, 466.0, 467.0, 468.0,
    203         469.0, 470.0, 471.0, 472.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0,
    204         487.0, 488.0, 497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0,
    205         505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, 512.0
    206     ]
    207     self._VerifyValues(
    208         nn_ops.max_pool,
    209         input_sizes=[1, 8, 8, 8],
    210         ksize=[1, 3, 3, 1],
    211         strides=[1, 2, 2, 1],
    212         padding="SAME",
    213         expected=expected_output)
    214 
    215   # Tests for DepthwiseMaxPooling on CPU only.
    216   def testDepthwiseMaxPool1x1DepthWindow1(self):
    217     # input is:
    218     # [1.0, ..., 10.0] along depth,
    219     #
    220     # We maxpool by depth in patches of 2.
    221     self._VerifyValues(
    222         nn_ops.max_pool,
    223         input_sizes=[1, 1, 1, 10],
    224         ksize=[1, 1, 1, 2],
    225         strides=[1, 1, 1, 2],
    226         padding="SAME",
    227         expected=[2.0, 4.0, 6.0, 8.0, 10.0])
    228 
    229   def testDepthwiseMaxPool2x2DepthWindow3(self):
    230     # input is:
    231     #
    232     # a 2x2x6 cube, and we depthwise max across 3 to produce a 2x2x2
    233     # output.  Each node has contiguous values, so the depthwise max
    234     # should be multiples of 3.0.
    235     self._VerifyValues(
    236         nn_ops.max_pool,
    237         input_sizes=[1, 2, 2, 6],
    238         ksize=[1, 1, 1, 3],
    239         strides=[1, 1, 1, 3],
    240         padding="SAME",
    241         expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0])
    242 
    243   def testKernelSmallerThanStrideValid(self):
    244     self._VerifyValues(
    245         nn_ops.max_pool,
    246         input_sizes=[1, 7, 7, 1],
    247         ksize=[1, 2, 2, 1],
    248         strides=[1, 3, 3, 1],
    249         padding="VALID",
    250         expected=[9, 12, 30, 33])
    251 
    252   def testKernelSmallerThanStrideSame(self):
    253     self._VerifyValues(
    254         nn_ops.max_pool,
    255         input_sizes=[1, 3, 3, 1],
    256         ksize=[1, 1, 1, 1],
    257         strides=[1, 2, 2, 1],
    258         padding="SAME",
    259         expected=[1, 3, 7, 9])
    260 
    261     self._VerifyValues(
    262         nn_ops.max_pool,
    263         input_sizes=[1, 4, 4, 1],
    264         ksize=[1, 1, 1, 1],
    265         strides=[1, 2, 2, 1],
    266         padding="SAME",
    267         expected=[1, 3, 9, 11])
    268 
    269   # Average pooling
    270   def testAvgPoolValidPadding(self):
    271     expected_output = [7, 8, 9]
    272     self._VerifyValues(
    273         nn_ops.avg_pool,
    274         input_sizes=[1, 3, 3, 3],
    275         ksize=[1, 2, 2, 1],
    276         strides=[1, 2, 2, 1],
    277         padding="VALID",
    278         expected=expected_output)
    279 
    280   def testAvgPoolSamePadding(self):
    281     expected_output = [7., 8., 9., 11.5, 12.5, 13.5]
    282     self._VerifyValues(
    283         nn_ops.avg_pool,
    284         input_sizes=[1, 2, 3, 3],
    285         ksize=[1, 2, 2, 1],
    286         strides=[1, 2, 2, 1],
    287         padding="SAME",
    288         expected=expected_output)
    289 
    290 
    291 class PoolGradTest(XLATestCase):
    292 
    293   CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
    294 
    295   def _VerifyOneTest(self, pool_func, pool_grad_func, input_sizes, ksize,
    296                      strides, padding, data_format):
    297     """Verifies the output values of the pooling gradient function.
    298 
    299     Args:
    300       pool_func: Forward pooling function
    301       pool_grad_func: Pooling gradient function for pool_grad_func
    302       input_sizes: Input tensor dimensions.
    303       ksize: The kernel size dimensions
    304       strides: The stride dimensions
    305       padding: Padding type.
    306       data_format: The data format we use to run the pooling operation.
    307     """
    308     total_size = np.prod(input_sizes)
    309     x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes)
    310     with self.test_session() as sess:
    311       # Use the forward pool function to compute some corresponding outputs
    312       # (needed for the CPU device, and we need the shape in both cases).
    313       with ops.device(self.CPU_DEVICE):
    314         inputs = array_ops.placeholder(dtypes.float32, shape=input_sizes)
    315         outputs = pool_func(
    316             inputs,
    317             ksize=ksize,
    318             strides=strides,
    319             padding=padding,
    320             data_format="NHWC")
    321 
    322       output_vals = np.array(sess.run(outputs, {inputs: x}))
    323       output_gradient_vals = np.arange(
    324           1, output_vals.size + 1, dtype=np.float32)
    325       output_gradient_vals = output_gradient_vals.reshape(output_vals.shape)
    326 
    327       # Use the Tensorflow CPU pooling gradient to compute the expected input
    328       # gradients.
    329       with ops.device(self.CPU_DEVICE):
    330         output_gradients = array_ops.placeholder(
    331             dtypes.float32, shape=output_vals.shape)
    332         expected_input_gradients = pool_grad_func(
    333             inputs,
    334             outputs,
    335             output_gradients,
    336             ksize=ksize,
    337             strides=strides,
    338             padding=padding,
    339             data_format="NHWC")
    340         expected_input_gradient_vals = sess.run(
    341             expected_input_gradients,
    342             {inputs: x,
    343              output_gradients: output_gradient_vals})
    344 
    345       # Run the gradient op on the XLA device
    346       with self.test_scope():
    347         outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape)
    348         xla_inputs = inputs
    349         xla_outputs = outputs
    350         xla_output_gradients = output_gradients
    351         xla_ksize = ksize
    352         xla_strides = strides
    353         if data_format == "NCHW":
    354           xla_inputs = NHWCToNCHW(inputs)
    355           xla_outputs = NHWCToNCHW(outputs)
    356           xla_output_gradients = NHWCToNCHW(output_gradients)
    357           xla_ksize = NHWCToNCHW(ksize)
    358           xla_strides = NHWCToNCHW(strides)
    359         actual_input_gradients = pool_grad_func(
    360             xla_inputs,
    361             xla_outputs,
    362             xla_output_gradients,
    363             ksize=xla_ksize,
    364             strides=xla_strides,
    365             padding=padding,
    366             data_format=data_format)
    367         if data_format == "NCHW":
    368           actual_input_gradients = NCHWToNHWC(actual_input_gradients)
    369       actual = sess.run(actual_input_gradients, {
    370           inputs: x,
    371           outputs: output_vals,
    372           output_gradients: output_gradient_vals
    373       })
    374 
    375       # Compare the Tensorflow and XLA results.
    376       self.assertAllClose(
    377           expected_input_gradient_vals.flatten(),
    378           actual.flatten(),
    379           rtol=1e-4,
    380           atol=1e-6)
    381       self.assertShapeEqual(actual, inputs)
    382 
    383   def _VerifyValues(self, pool_func, pool_grad_func, input_sizes, ksize,
    384                     strides, padding):
    385     """Verifies the output values of the pooling function.
    386 
    387     Args:
    388       pool_func: Pooling function to be called, e.g., tf.nn.max_pool
    389       pool_grad_func: Corresponding pooling gradient function.
    390       input_sizes: Input tensor dimensions.
    391       ksize: The kernel size dimensions
    392       strides: The stride dimensions
    393       padding: Padding type.
    394     """
    395     for data_format in GetTestConfigs():
    396       self._VerifyOneTest(pool_func, pool_grad_func, input_sizes, ksize,
    397                           strides, padding, data_format)
    398 
    399   def _TestPooling(self, forward_op, backward_op):
    400     # VALID padding
    401     self._VerifyValues(
    402         forward_op,
    403         backward_op,
    404         input_sizes=[1, 3, 3, 3],
    405         ksize=[1, 2, 2, 1],
    406         strides=[1, 2, 2, 1],
    407         padding="VALID")
    408 
    409     # SAME padding
    410     self._VerifyValues(
    411         forward_op,
    412         backward_op,
    413         input_sizes=[1, 2, 3, 3],
    414         ksize=[1, 2, 2, 1],
    415         strides=[1, 2, 2, 1],
    416         padding="SAME")
    417 
    418     # SAME padding, non square window
    419     self._VerifyValues(
    420         forward_op,
    421         backward_op,
    422         input_sizes=[1, 2, 2, 1],
    423         ksize=[1, 1, 2, 1],
    424         strides=[1, 1, 1, 1],
    425         padding="SAME")
    426 
    427     # VALID padding, uneven stride
    428     self._VerifyValues(
    429         forward_op,
    430         backward_op,
    431         input_sizes=[1, 4, 4, 1],
    432         ksize=[1, 2, 2, 1],
    433         strides=[1, 1, 2, 1],
    434         padding="VALID")
    435     self._VerifyValues(
    436         forward_op,
    437         backward_op,
    438         input_sizes=[1, 4, 4, 1],
    439         ksize=[1, 2, 2, 1],
    440         strides=[1, 2, 1, 1],
    441         padding="VALID")
    442 
    443     # SAME padding, size 4 input
    444     self._VerifyValues(
    445         forward_op,
    446         backward_op,
    447         input_sizes=[1, 4, 4, 4],
    448         ksize=[1, 2, 2, 1],
    449         strides=[1, 2, 2, 1],
    450         padding="SAME")
    451 
    452     # SAME padding, size 8 input
    453     self._VerifyValues(
    454         forward_op,
    455         backward_op,
    456         input_sizes=[1, 8, 8, 8],
    457         ksize=[1, 3, 3, 1],
    458         strides=[1, 2, 2, 1],
    459         padding="SAME")
    460 
    461   def testMaxPool(self):
    462     self._TestPooling(nn_ops.max_pool, gen_nn_ops._max_pool_grad)
    463 
    464   def testAvgPool(self):
    465     # Wrapper around AvgPoolGrad that ignores extra arguments needed by
    466     # MaxPoolGrad.
    467     def AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding,
    468                     data_format):
    469       del outputs  # Unused by average-pooling gradients.
    470       return gen_nn_ops._avg_pool_grad(
    471           inputs.get_shape().as_list(),
    472           output_gradients,
    473           ksize=ksize,
    474           strides=strides,
    475           padding=padding,
    476           data_format=data_format)
    477 
    478     self._TestPooling(nn_ops.avg_pool, AvgPoolGrad)
    479 
    480   # The CPU implementation of AvgPoolGrad doesn't accept kernels smaller than
    481   # the stride size, so we only run the following tests on MaxPoolGrad.
    482 
    483   def testMaxPoolKernelSmallerThanStrideValid(self):
    484     self._VerifyValues(
    485         nn_ops.max_pool,
    486         gen_nn_ops._max_pool_grad,
    487         input_sizes=[1, 7, 7, 1],
    488         ksize=[1, 2, 2, 1],
    489         strides=[1, 3, 3, 1],
    490         padding="VALID")
    491 
    492   def testMaxPoolKernelSmallerThanStrideSame(self):
    493     self._VerifyValues(
    494         nn_ops.max_pool,
    495         gen_nn_ops._max_pool_grad,
    496         input_sizes=[1, 3, 3, 1],
    497         ksize=[1, 1, 1, 1],
    498         strides=[1, 2, 2, 1],
    499         padding="SAME")
    500 
    501     self._VerifyValues(
    502         nn_ops.max_pool,
    503         gen_nn_ops._max_pool_grad,
    504         input_sizes=[1, 4, 4, 1],
    505         ksize=[1, 1, 1, 1],
    506         strides=[1, 2, 2, 1],
    507         padding="SAME")
    508 
    509 
# When this file is executed directly (rather than imported), hand control to
# googletest.main().
if __name__ == "__main__":
  googletest.main()
    512