Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.ops.reshape_op."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.ops import gradient_checker
     27 from tensorflow.python.platform import test
     28 
     29 
     30 class ReshapeTest(test.TestCase):
     31 
     32   def _testReshape(self, x, y, use_gpu=False):
     33     with self.test_session(use_gpu=use_gpu):
     34       np_ans = x.reshape(y)
     35       tf_ans = array_ops.reshape(x, y)
     36       out = tf_ans.eval()
     37       self.assertEqual(tf_ans.get_shape(), out.shape)
     38       self.assertShapeEqual(np_ans, tf_ans)
     39 
     40       # Repeat with an int64 shape tensor.
     41       y64 = constant_op.constant(y, dtype=dtypes.int64)
     42       tf_ans = array_ops.reshape(x, y64)
     43       out = tf_ans.eval()
     44       self.assertEqual(tf_ans.get_shape(), out.shape)
     45       self.assertShapeEqual(np_ans, tf_ans)
     46 
     47   def _testBothReshape(self, x, y):
     48     self._testReshape(x, y, False)
     49     self._testReshape(x, y, True)
     50 
     51   def testBoolBasic(self):
     52     x = np.arange(1., 7.).reshape([1, 6]) > 3
     53     self._testBothReshape(x, [2, 3])
     54 
     55   def testFloatBasic(self):
     56     x = np.arange(1., 7.).reshape([1, 6]).astype(np.float32)
     57     self._testBothReshape(x, [2, 3])
     58 
     59   def testDoubleBasic(self):
     60     x = np.arange(1., 7.).reshape([1, 6]).astype(np.float64)
     61     self._testBothReshape(x, [2, 3])
     62 
     63   def testInt32Basic(self):
     64     x = np.arange(1., 7.).reshape([1, 6]).astype(np.int32)
     65     self._testBothReshape(x, [2, 3])
     66 
     67   def testComplex64Basic(self):
     68     x = np.arange(1., 7.).reshape([1, 6]).astype(np.complex64)
     69     self._testBothReshape(x, [2, 3])
     70 
     71   def testComplex128Basic(self):
     72     x = np.arange(1., 7.).reshape([1, 6]).astype(np.complex128)
     73     self._testBothReshape(x, [2, 3])
     74 
     75   def testFloatReshapeThreeDimensions(self):
     76     x = np.arange(1., 28.).reshape([1, 27]).astype(np.float32)
     77     self._testBothReshape(x, [3, 3, 3])
     78 
     79   def testFloatUnspecifiedDimOnly(self):
     80     x = np.arange(1., 7.).reshape([6]).astype(np.float32)
     81     self._testBothReshape(x, [-1])
     82 
     83   def testFloatUnspecifiedDimBegin(self):
     84     x = np.arange(1., 7.).reshape([6]).astype(np.float32)
     85     self._testBothReshape(x, [-1, 2])
     86 
     87   def testFloatUnspecifiedDimEnd(self):
     88     x = np.arange(1., 7.).reshape([6]).astype(np.float32)
     89     self._testBothReshape(x, [3, -1])
     90 
     91   # TODO(vrv): Add tests for failure conditions once python test_util
     92   # reports errors.
     93 
     94   def testFloatReshapeGradThreeDimensions(self):
     95     x = np.arange(1., 25.).reshape([2, 3, 4]).astype(np.float32)
     96     s = list(np.shape(x))
     97     with self.test_session():
     98       input_tensor = constant_op.constant(x)
     99       reshape_out = array_ops.reshape(input_tensor, [1, 8, 3])
    100       err = gradient_checker.compute_gradient_error(
    101           input_tensor, s, reshape_out, s, x_init_value=x)
    102     print("Reshape gradient error = " % err)
    103     self.assertLess(err, 1e-3)
    104 
    105   def testFloatEmpty(self):
    106     x = np.empty((0, 0, 0, 0), dtype=np.float32)
    107     self._testBothReshape(x, [1, 2, 3, 0])
    108     self._testBothReshape(x, [1, 0, 0, 4])
    109     self._testBothReshape(x, [0, 0, 0, 0])
    110     self._testBothReshape(x, [1, 2, 0])
    111     self._testBothReshape(x, [0, 0, 0])
    112     self._testBothReshape(x, [1, -1, 5])
    113 
    114   def testErrors(self):
    115     y = constant_op.constant(0.0, shape=[23, 29, 31])
    116     with self.assertRaisesRegexp(ValueError, "must be evenly divisible by 17"):
    117       array_ops.reshape(y, [17, -1])
    118 
    119     z = constant_op.constant(0.0, shape=[32, 128])
    120     with self.assertRaisesRegexp(ValueError,
    121                                  "Cannot reshape a tensor with 4096 elements"):
    122       array_ops.reshape(z, [4095])
    123 
    124   def testPartialShapes(self):
    125     x = array_ops.placeholder(dtypes.float32)
    126 
    127     # Unknown input shape, partial new shape.
    128     y = array_ops.reshape(x, [1, 1, -1, 1])
    129     self.assertEqual([1, 1, None, 1], y.get_shape().as_list())
    130 
    131     # Unknown input shape, unknown new shape.
    132     y = array_ops.reshape(x, array_ops.placeholder(dtypes.int32))
    133     self.assertEqual(None, y.get_shape().ndims)
    134 
    135     # Unknown input shape, known rank for new shape.
    136     y = array_ops.reshape(x, array_ops.placeholder(dtypes.int32, shape=(3,)))
    137     self.assertEqual([None, None, None], y.get_shape().as_list())
    138 
    139     # Unknown input shape, partial new shape using `tf.stack()`.
    140     y = array_ops.reshape(x, [array_ops.placeholder(dtypes.int32), 37])
    141     self.assertEqual([None, 37], y.get_shape().as_list())
    142 
    143     # Unknown input shape, partial new shape using `tf.concat()`.
    144     y = array_ops.reshape(
    145         x,
    146         array_ops.concat(
    147             [array_ops.placeholder(
    148                 dtypes.int32, shape=(2,)), [37, 42]], 0))
    149     self.assertEqual([None, None, 37, 42], y.get_shape().as_list())
    150 
    151     # Unknown input shape, partial new shape using `tf.shape()`.
    152     y = array_ops.reshape(
    153         x,
    154         array_ops.shape(
    155             array_ops.placeholder(
    156                 dtypes.float32, shape=[None, 37, None])))
    157     self.assertEqual([None, 37, None], y.get_shape().as_list())
    158 
    159 
    160 if __name__ == "__main__":
    161   test.main()
    162