1 # ============================================================================= 2 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 # 4 # Licensed under the Apache License, Version 2.0 (the "License"); 5 # you may not use this file except in compliance with the License. 6 # You may obtain a copy of the License at 7 # 8 # http://www.apache.org/licenses/LICENSE-2.0 9 # 10 # Unless required by applicable law or agreed to in writing, software 11 # distributed under the License is distributed on an "AS IS" BASIS, 12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 # See the License for the specific language governing permissions and 14 # limitations under the License. 15 # ============================================================================= 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy 22 23 from tensorflow.contrib.periodic_resample import periodic_resample 24 from tensorflow.python.framework import errors_impl 25 from tensorflow.python.framework import test_util 26 from tensorflow.python.ops import variables 27 from tensorflow.python.platform import googletest 28 29 30 class PeriodicResampleTest(test_util.TensorFlowTestCase): 31 32 def testPeriodicResampleBasic2D(self): 33 34 input_tensor = numpy.arange(12).reshape((3, 4)) 35 desired_shape = numpy.array([6, None]) 36 output_tensor = input_tensor.reshape((6, 2)) 37 38 with self.test_session(): 39 variables.global_variables_initializer().run() 40 result = periodic_resample(input_tensor, desired_shape).eval() 41 self.assertAllEqual(result, output_tensor) 42 43 def testPeriodicResampleTruncatedBasic2D(self): 44 45 input_tensor = numpy.arange(12).reshape((3, 4)) 46 desired_shape = numpy.array([5, None]) 47 output_tensor = input_tensor.reshape((6, 2))[:-1] 48 49 with self.test_session(): 50 variables.global_variables_initializer().run() 51 result = periodic_resample(input_tensor, desired_shape).eval() 52 self.assertAllEqual(result, output_tensor) 53 54 def testPeriodicResampleBasic3D(self): 55 56 input_tensor = numpy.arange(2 * 2 * 4).reshape((2, 2, 4)) 57 desired_shape = numpy.array([4, 4, None]) 58 output_tensor = numpy.array([[[0], [2], [4], [6]], [[1], [3], [5], [7]], 59 [[8], [10], [12], [14]], [[9], [11], [13], 60 [15]]]) 61 62 # NOTE: output_tensor != input_tensor.reshape((4, 4, -1)) 63 with self.test_session(): 64 variables.global_variables_initializer().run() 65 result = periodic_resample(input_tensor, desired_shape).eval() 66 # input_tensor[0, 0, 0] == result[0, 0, 0] 67 # input_tensor[0, 0, 1] == result[1, 0, 0] 68 # input_tensor[0, 0, 2] == result[0, 1, 0] 69 # input_tensor[0, 0, 3] == result[1, 1, 0] 70 self.assertAllEqual(result, output_tensor) 71 72 def testPeriodicResampleBasic4D(self): 73 74 input_tensor = numpy.arange(2 * 2 * 2 * 8).reshape((2, 2, 2, 8)) 75 desired_shape = numpy.array([4, 4, 4, None]) 76 output_tensor = numpy.array( 77 [[[[0], [4], [8], [12]], [[2], [6], [10], [14]], 78 [[16], [20], [24], [28]], [[18], [22], [26], [30]]], 79 [[[1], [5], [9], [13]], [[3], [7], [11], [15]], [[17], [21], [25], 80 [29]], 81 [[19], [23], [27], 82 [31]]], [[[32], [36], [40], [44]], [[34], [38], [42], [46]], 83 [[48], [52], [56], [60]], [[50], [54], [58], [62]]], 84 [[[33], [37], [41], [45]], [[35], [39], [43], [47]], 85 [[49], [53], [57], [61]], [[51], [55], [59], [63]]]]) 86 87 # NOTE: output_tensor != input_tensor.reshape((4, 4, 4, -1)) 88 with self.test_session(): 89 variables.global_variables_initializer().run() 90 result = periodic_resample(input_tensor, desired_shape).eval() 91 self.assertAllEqual(result, output_tensor) 92 93 def testPeriodicResampleErrors(self): 94 input_tensor = numpy.zeros(shape=[1, 2, 2, 4]) 95 with self.test_session(): 96 variables.global_variables_initializer().run() 97 with self.assertRaisesWithPredicateMatch( 98 errors_impl.InvalidArgumentError, 99 'Dimension 3 input tensor has size 4, desired shape has size 1'): 100 periodic_resample(input_tensor, [None, 4, 4, 1]).eval() 101 with self.assertRaisesWithPredicateMatch( 102 errors_impl.InvalidArgumentError, 103 '4, to be the same as the length of the desired shape, 3'): 104 periodic_resample(input_tensor, [None, 4, 4]).eval() 105 106 107 if __name__ == '__main__': 108 googletest.main() 109