Home | History | Annotate | Download | only in layers
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tf.layers.pooling."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.layers import pooling as pooling_layers
     22 from tensorflow.python.ops import array_ops
     23 from tensorflow.python.ops import random_ops
     24 from tensorflow.python.platform import test
     25 
     26 
     27 class PoolingTest(test.TestCase):
     28 
     29   def testInvalidDataFormat(self):
     30     height, width = 7, 9
     31     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     32     with self.assertRaisesRegexp(ValueError, 'data_format'):
     33       pooling_layers.max_pooling2d(images, 3, strides=2, data_format='invalid')
     34 
     35   def testInvalidStrides(self):
     36     height, width = 7, 9
     37     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     38     with self.assertRaisesRegexp(ValueError, 'strides'):
     39       pooling_layers.max_pooling2d(images, 3, strides=(1, 2, 3))
     40 
     41     with self.assertRaisesRegexp(ValueError, 'strides'):
     42       pooling_layers.max_pooling2d(images, 3, strides=None)
     43 
     44   def testInvalidPoolSize(self):
     45     height, width = 7, 9
     46     images = random_ops.random_uniform((5, height, width, 3), seed=1)
     47     with self.assertRaisesRegexp(ValueError, 'pool_size'):
     48       pooling_layers.max_pooling2d(images, (1, 2, 3), strides=2)
     49 
     50     with self.assertRaisesRegexp(ValueError, 'pool_size'):
     51       pooling_layers.max_pooling2d(images, None, strides=2)
     52 
     53   def testCreateMaxPooling2D(self):
     54     height, width = 7, 9
     55     images = random_ops.random_uniform((5, height, width, 4))
     56     layer = pooling_layers.MaxPooling2D([2, 2], strides=2)
     57     output = layer.apply(images)
     58     self.assertListEqual(output.get_shape().as_list(), [5, 3, 4, 4])
     59 
     60   def testCreateAveragePooling2D(self):
     61     height, width = 7, 9
     62     images = random_ops.random_uniform((5, height, width, 4))
     63     layer = pooling_layers.AveragePooling2D([2, 2], strides=2)
     64     output = layer.apply(images)
     65     self.assertListEqual(output.get_shape().as_list(), [5, 3, 4, 4])
     66 
     67   def testCreateMaxPooling2DChannelsFirst(self):
     68     height, width = 7, 9
     69     images = random_ops.random_uniform((5, 2, height, width))
     70     layer = pooling_layers.MaxPooling2D([2, 2],
     71                                         strides=1,
     72                                         data_format='channels_first')
     73     output = layer.apply(images)
     74     self.assertListEqual(output.get_shape().as_list(), [5, 2, 6, 8])
     75 
     76   def testCreateAveragePooling2DChannelsFirst(self):
     77     height, width = 5, 6
     78     images = random_ops.random_uniform((3, 4, height, width))
     79     layer = pooling_layers.AveragePooling2D((2, 2),
     80                                             strides=(1, 1),
     81                                             padding='valid',
     82                                             data_format='channels_first')
     83     output = layer.apply(images)
     84     self.assertListEqual(output.get_shape().as_list(), [3, 4, 4, 5])
     85 
     86   def testCreateAveragePooling2DChannelsFirstWithNoneBatch(self):
     87     height, width = 5, 6
     88     images = array_ops.placeholder(dtype='float32',
     89                                    shape=(None, 4, height, width))
     90     layer = pooling_layers.AveragePooling2D((2, 2),
     91                                             strides=(1, 1),
     92                                             padding='valid',
     93                                             data_format='channels_first')
     94     output = layer.apply(images)
     95     self.assertListEqual(output.get_shape().as_list(), [None, 4, 4, 5])
     96 
     97   def testCreateMaxPooling1D(self):
     98     width = 7
     99     channels = 3
    100     images = random_ops.random_uniform((5, width, channels))
    101     layer = pooling_layers.MaxPooling1D(2, strides=2)
    102     output = layer.apply(images)
    103     self.assertListEqual(output.get_shape().as_list(),
    104                          [5, width // 2, channels])
    105 
    106   def testCreateAveragePooling1D(self):
    107     width = 7
    108     channels = 3
    109     images = random_ops.random_uniform((5, width, channels))
    110     layer = pooling_layers.AveragePooling1D(2, strides=2)
    111     output = layer.apply(images)
    112     self.assertListEqual(output.get_shape().as_list(),
    113                          [5, width // 2, channels])
    114 
    115   def testCreateMaxPooling1DChannelsFirst(self):
    116     width = 7
    117     channels = 3
    118     images = random_ops.random_uniform((5, channels, width))
    119     layer = pooling_layers.MaxPooling1D(
    120         2, strides=2, data_format='channels_first')
    121     output = layer.apply(images)
    122     self.assertListEqual(output.get_shape().as_list(),
    123                          [5, channels, width // 2])
    124 
    125   def testCreateAveragePooling1DChannelsFirst(self):
    126     width = 7
    127     channels = 3
    128     images = random_ops.random_uniform((5, channels, width))
    129     layer = pooling_layers.AveragePooling1D(
    130         2, strides=2, data_format='channels_first')
    131     output = layer.apply(images)
    132     self.assertListEqual(output.get_shape().as_list(),
    133                          [5, channels, width // 2])
    134 
    135   def testCreateMaxPooling3D(self):
    136     depth, height, width = 6, 7, 9
    137     images = random_ops.random_uniform((5, depth, height, width, 4))
    138     layer = pooling_layers.MaxPooling3D([2, 2, 2], strides=2)
    139     output = layer.apply(images)
    140     self.assertListEqual(output.get_shape().as_list(), [5, 3, 3, 4, 4])
    141 
    142   def testCreateAveragePooling3D(self):
    143     depth, height, width = 6, 7, 9
    144     images = random_ops.random_uniform((5, depth, height, width, 4))
    145     layer = pooling_layers.AveragePooling3D([2, 2, 2], strides=2)
    146     output = layer.apply(images)
    147     self.assertListEqual(output.get_shape().as_list(), [5, 3, 3, 4, 4])
    148 
    149   def testMaxPooling3DChannelsFirst(self):
    150     depth, height, width = 6, 7, 9
    151     images = random_ops.random_uniform((5, 2, depth, height, width))
    152     layer = pooling_layers.MaxPooling3D(
    153         [2, 2, 2], strides=2, data_format='channels_first')
    154     output = layer.apply(images)
    155     self.assertListEqual(output.get_shape().as_list(), [5, 2, 3, 3, 4])
    156 
    157   def testAveragePooling3DChannelsFirst(self):
    158     depth, height, width = 6, 7, 9
    159     images = random_ops.random_uniform((5, 2, depth, height, width))
    160     layer = pooling_layers.AveragePooling3D(
    161         [2, 2, 2], strides=2, data_format='channels_first')
    162     output = layer.apply(images)
    163     self.assertListEqual(output.get_shape().as_list(), [5, 2, 3, 3, 4])
    164 
    165   def testCreateMaxPooling2DIntegerPoolSize(self):
    166     height, width = 7, 9
    167     images = random_ops.random_uniform((5, height, width, 4))
    168     layer = pooling_layers.MaxPooling2D(2, strides=2)
    169     output = layer.apply(images)
    170     self.assertListEqual(output.get_shape().as_list(), [5, 3, 4, 4])
    171 
    172   def testMaxPooling2DPaddingSame(self):
    173     height, width = 7, 9
    174     images = random_ops.random_uniform((5, height, width, 4), seed=1)
    175     layer = pooling_layers.MaxPooling2D(
    176         images.get_shape()[1:3], strides=2, padding='same')
    177     output = layer.apply(images)
    178     self.assertListEqual(output.get_shape().as_list(), [5, 4, 5, 4])
    179 
    180   def testCreatePooling2DWithStrides(self):
    181     height, width = 6, 8
    182     # Test strides tuple
    183     images = random_ops.random_uniform((5, height, width, 3), seed=1)
    184     layer = pooling_layers.MaxPooling2D([2, 2], strides=(2, 2), padding='same')
    185     output = layer.apply(images)
    186     self.assertListEqual(output.get_shape().as_list(),
    187                          [5, height / 2, width / 2, 3])
    188 
    189     # Test strides integer
    190     layer = pooling_layers.MaxPooling2D([2, 2], strides=2, padding='same')
    191     output = layer.apply(images)
    192     self.assertListEqual(output.get_shape().as_list(),
    193                          [5, height / 2, width / 2, 3])
    194 
    195     # Test unequal strides
    196     layer = pooling_layers.MaxPooling2D([2, 2], strides=(2, 1), padding='same')
    197     output = layer.apply(images)
    198     self.assertListEqual(output.get_shape().as_list(),
    199                          [5, height / 2, width, 3])
    200 
    201 
    202 if __name__ == '__main__':
    203   test.main()
    204