1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tensorflow.ops.reshape_op.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.framework import constant_op 24 from tensorflow.python.framework import dtypes 25 from tensorflow.python.ops import array_ops 26 from tensorflow.python.ops import gradient_checker 27 from tensorflow.python.platform import test 28 29 30 class ReshapeTest(test.TestCase): 31 32 def _testReshape(self, x, y, use_gpu=False): 33 with self.test_session(use_gpu=use_gpu): 34 np_ans = x.reshape(y) 35 tf_ans = array_ops.reshape(x, y) 36 out = tf_ans.eval() 37 self.assertEqual(tf_ans.get_shape(), out.shape) 38 self.assertShapeEqual(np_ans, tf_ans) 39 40 # Repeat with an int64 shape tensor. 41 y64 = constant_op.constant(y, dtype=dtypes.int64) 42 tf_ans = array_ops.reshape(x, y64) 43 out = tf_ans.eval() 44 self.assertEqual(tf_ans.get_shape(), out.shape) 45 self.assertShapeEqual(np_ans, tf_ans) 46 47 def _testBothReshape(self, x, y): 48 self._testReshape(x, y, False) 49 self._testReshape(x, y, True) 50 51 def testBoolBasic(self): 52 x = np.arange(1., 7.).reshape([1, 6]) > 3 53 self._testBothReshape(x, [2, 3]) 54 55 def testFloatBasic(self): 56 x = np.arange(1., 7.).reshape([1, 6]).astype(np.float32) 57 self._testBothReshape(x, [2, 3]) 58 59 def testDoubleBasic(self): 60 x = np.arange(1., 7.).reshape([1, 6]).astype(np.float64) 61 self._testBothReshape(x, [2, 3]) 62 63 def testInt32Basic(self): 64 x = np.arange(1., 7.).reshape([1, 6]).astype(np.int32) 65 self._testBothReshape(x, [2, 3]) 66 67 def testComplex64Basic(self): 68 x = np.arange(1., 7.).reshape([1, 6]).astype(np.complex64) 69 self._testBothReshape(x, [2, 3]) 70 71 def testComplex128Basic(self): 72 x = np.arange(1., 7.).reshape([1, 6]).astype(np.complex128) 73 self._testBothReshape(x, [2, 3]) 74 75 def testFloatReshapeThreeDimensions(self): 76 x = np.arange(1., 28.).reshape([1, 27]).astype(np.float32) 77 self._testBothReshape(x, [3, 3, 3]) 78 79 def testFloatUnspecifiedDimOnly(self): 80 x = np.arange(1., 7.).reshape([6]).astype(np.float32) 81 self._testBothReshape(x, [-1]) 82 83 def testFloatUnspecifiedDimBegin(self): 84 x = np.arange(1., 7.).reshape([6]).astype(np.float32) 85 self._testBothReshape(x, [-1, 2]) 86 87 def testFloatUnspecifiedDimEnd(self): 88 x = np.arange(1., 7.).reshape([6]).astype(np.float32) 89 self._testBothReshape(x, [3, -1]) 90 91 # TODO(vrv): Add tests for failure conditions once python test_util 92 # reports errors. 93 94 def testFloatReshapeGradThreeDimensions(self): 95 x = np.arange(1., 25.).reshape([2, 3, 4]).astype(np.float32) 96 s = list(np.shape(x)) 97 with self.test_session(): 98 input_tensor = constant_op.constant(x) 99 reshape_out = array_ops.reshape(input_tensor, [1, 8, 3]) 100 err = gradient_checker.compute_gradient_error( 101 input_tensor, s, reshape_out, s, x_init_value=x) 102 print("Reshape gradient error = " % err) 103 self.assertLess(err, 1e-3) 104 105 def testFloatEmpty(self): 106 x = np.empty((0, 0, 0, 0), dtype=np.float32) 107 self._testBothReshape(x, [1, 2, 3, 0]) 108 self._testBothReshape(x, [1, 0, 0, 4]) 109 self._testBothReshape(x, [0, 0, 0, 0]) 110 self._testBothReshape(x, [1, 2, 0]) 111 self._testBothReshape(x, [0, 0, 0]) 112 self._testBothReshape(x, [1, -1, 5]) 113 114 def testErrors(self): 115 y = constant_op.constant(0.0, shape=[23, 29, 31]) 116 with self.assertRaisesRegexp(ValueError, "must be evenly divisible by 17"): 117 array_ops.reshape(y, [17, -1]) 118 119 z = constant_op.constant(0.0, shape=[32, 128]) 120 with self.assertRaisesRegexp(ValueError, 121 "Cannot reshape a tensor with 4096 elements"): 122 array_ops.reshape(z, [4095]) 123 124 def testPartialShapes(self): 125 x = array_ops.placeholder(dtypes.float32) 126 127 # Unknown input shape, partial new shape. 128 y = array_ops.reshape(x, [1, 1, -1, 1]) 129 self.assertEqual([1, 1, None, 1], y.get_shape().as_list()) 130 131 # Unknown input shape, unknown new shape. 132 y = array_ops.reshape(x, array_ops.placeholder(dtypes.int32)) 133 self.assertEqual(None, y.get_shape().ndims) 134 135 # Unknown input shape, known rank for new shape. 136 y = array_ops.reshape(x, array_ops.placeholder(dtypes.int32, shape=(3,))) 137 self.assertEqual([None, None, None], y.get_shape().as_list()) 138 139 # Unknown input shape, partial new shape using `tf.stack()`. 140 y = array_ops.reshape(x, [array_ops.placeholder(dtypes.int32), 37]) 141 self.assertEqual([None, 37], y.get_shape().as_list()) 142 143 # Unknown input shape, partial new shape using `tf.concat()`. 144 y = array_ops.reshape( 145 x, 146 array_ops.concat( 147 [array_ops.placeholder( 148 dtypes.int32, shape=(2,)), [37, 42]], 0)) 149 self.assertEqual([None, None, 37, 42], y.get_shape().as_list()) 150 151 # Unknown input shape, partial new shape using `tf.shape()`. 152 y = array_ops.reshape( 153 x, 154 array_ops.shape( 155 array_ops.placeholder( 156 dtypes.float32, shape=[None, 37, None]))) 157 self.assertEqual([None, 37, None], y.get_shape().as_list()) 158 159 160 if __name__ == "__main__": 161 test.main() 162