Home | History | Annotate | Download | only in kernel_tests
      1 # =============================================================================
      2 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #     http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 # =============================================================================
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy
     22 
     23 from tensorflow.contrib.periodic_resample import periodic_resample
     24 from tensorflow.python.framework import errors_impl
     25 from tensorflow.python.framework import test_util
     26 from tensorflow.python.ops import variables
     27 from tensorflow.python.platform import googletest
     28 
     29 
     30 class PeriodicResampleTest(test_util.TensorFlowTestCase):
     31 
     32   def testPeriodicResampleBasic2D(self):
     33 
     34     input_tensor = numpy.arange(12).reshape((3, 4))
     35     desired_shape = numpy.array([6, None])
     36     output_tensor = input_tensor.reshape((6, 2))
     37 
     38     with self.test_session():
     39       variables.global_variables_initializer().run()
     40       result = periodic_resample(input_tensor, desired_shape).eval()
     41       self.assertAllEqual(result, output_tensor)
     42 
     43   def testPeriodicResampleTruncatedBasic2D(self):
     44 
     45     input_tensor = numpy.arange(12).reshape((3, 4))
     46     desired_shape = numpy.array([5, None])
     47     output_tensor = input_tensor.reshape((6, 2))[:-1]
     48 
     49     with self.test_session():
     50       variables.global_variables_initializer().run()
     51       result = periodic_resample(input_tensor, desired_shape).eval()
     52       self.assertAllEqual(result, output_tensor)
     53 
     54   def testPeriodicResampleBasic3D(self):
     55 
     56     input_tensor = numpy.arange(2 * 2 * 4).reshape((2, 2, 4))
     57     desired_shape = numpy.array([4, 4, None])
     58     output_tensor = numpy.array([[[0], [2], [4], [6]], [[1], [3], [5], [7]],
     59                                  [[8], [10], [12], [14]], [[9], [11], [13],
     60                                                            [15]]])
     61 
     62     # NOTE: output_tensor != input_tensor.reshape((4, 4, -1))
     63     with self.test_session():
     64       variables.global_variables_initializer().run()
     65       result = periodic_resample(input_tensor, desired_shape).eval()
     66       # input_tensor[0, 0, 0] == result[0, 0, 0]
     67       # input_tensor[0, 0, 1] == result[1, 0, 0]
     68       # input_tensor[0, 0, 2] == result[0, 1, 0]
     69       # input_tensor[0, 0, 3] == result[1, 1, 0]
     70       self.assertAllEqual(result, output_tensor)
     71 
     72   def testPeriodicResampleBasic4D(self):
     73 
     74     input_tensor = numpy.arange(2 * 2 * 2 * 8).reshape((2, 2, 2, 8))
     75     desired_shape = numpy.array([4, 4, 4, None])
     76     output_tensor = numpy.array(
     77         [[[[0], [4], [8], [12]], [[2], [6], [10], [14]],
     78           [[16], [20], [24], [28]], [[18], [22], [26], [30]]],
     79          [[[1], [5], [9], [13]], [[3], [7], [11], [15]], [[17], [21], [25],
     80                                                           [29]],
     81           [[19], [23], [27],
     82            [31]]], [[[32], [36], [40], [44]], [[34], [38], [42], [46]],
     83                     [[48], [52], [56], [60]], [[50], [54], [58], [62]]],
     84          [[[33], [37], [41], [45]], [[35], [39], [43], [47]],
     85           [[49], [53], [57], [61]], [[51], [55], [59], [63]]]])
     86 
     87     # NOTE: output_tensor != input_tensor.reshape((4, 4, 4, -1))
     88     with self.test_session():
     89       variables.global_variables_initializer().run()
     90       result = periodic_resample(input_tensor, desired_shape).eval()
     91       self.assertAllEqual(result, output_tensor)
     92 
     93   def testPeriodicResampleErrors(self):
     94     input_tensor = numpy.zeros(shape=[1, 2, 2, 4])
     95     with self.test_session():
     96       variables.global_variables_initializer().run()
     97       with self.assertRaisesWithPredicateMatch(
     98           errors_impl.InvalidArgumentError,
     99           'Dimension 3 input tensor has size 4, desired shape has size 1'):
    100         periodic_resample(input_tensor, [None, 4, 4, 1]).eval()
    101       with self.assertRaisesWithPredicateMatch(
    102           errors_impl.InvalidArgumentError,
    103           '4, to be the same as the length of the desired shape, 3'):
    104         periodic_resample(input_tensor, [None, 4, 4]).eval()
    105 
    106 
    107 if __name__ == '__main__':
    108   googletest.main()
    109