Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for manip_ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.python.framework import constant_op
     23 from tensorflow.python.framework import errors_impl
     24 from tensorflow.python.framework import test_util
     25 from tensorflow.python.ops import gradient_checker
     26 from tensorflow.python.ops import manip_ops
     27 from tensorflow.python.platform import test as test_lib
     28 
     29 # pylint: disable=g-import-not-at-top
     30 try:
     31   from distutils.version import StrictVersion as Version
     32   # numpy.roll for multiple shifts was introduced in numpy version 1.12.0
     33   NP_ROLL_CAN_MULTISHIFT = Version(np.version.version) >= Version("1.12.0")
     34 except ImportError:
     35   NP_ROLL_CAN_MULTISHIFT = False
     36 # pylint: enable=g-import-not-at-top
     37 
     38 
     39 class RollTest(test_util.TensorFlowTestCase):
     40 
     41   def _testRoll(self, np_input, shift, axis):
     42     expected_roll = np.roll(np_input, shift, axis)
     43     with self.test_session():
     44       roll = manip_ops.roll(np_input, shift, axis)
     45       self.assertAllEqual(roll.eval(), expected_roll)
     46 
     47   def _testGradient(self, np_input, shift, axis):
     48     with self.test_session():
     49       inx = constant_op.constant(np_input.tolist())
     50       xs = list(np_input.shape)
     51       y = manip_ops.roll(inx, shift, axis)
     52       # Expected y's shape to be the same
     53       ys = xs
     54       jacob_t, jacob_n = gradient_checker.compute_gradient(
     55           inx, xs, y, ys, x_init_value=np_input)
     56       self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
     57 
     58   def _testAll(self, np_input, shift, axis):
     59     self._testRoll(np_input, shift, axis)
     60     if np_input.dtype == np.float32:
     61       self._testGradient(np_input, shift, axis)
     62 
     63   def testIntTypes(self):
     64     for t in [np.int32, np.int64]:
     65       self._testAll(np.random.randint(-100, 100, (5)).astype(t), 3, 0)
     66       if NP_ROLL_CAN_MULTISHIFT:
     67         self._testAll(
     68             np.random.randint(-100, 100, (4, 4, 3)).astype(t), [1, -2, 3],
     69             [0, 1, 2])
     70         self._testAll(
     71             np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t), [0, 1, -2],
     72             [1, 2, 3])
     73 
     74   def testFloatTypes(self):
     75     for t in [np.float32, np.float64]:
     76       self._testAll(np.random.rand(5).astype(t), 2, 0)
     77       if NP_ROLL_CAN_MULTISHIFT:
     78         self._testAll(np.random.rand(3, 4).astype(t), [1, 2], [1, 0])
     79         self._testAll(np.random.rand(1, 3, 4).astype(t), [1, 0, -3], [0, 1, 2])
     80 
     81   def testComplexTypes(self):
     82     for t in [np.complex64, np.complex128]:
     83       x = np.random.rand(4, 4).astype(t)
     84       self._testAll(x + 1j * x, 2, 0)
     85       if NP_ROLL_CAN_MULTISHIFT:
     86         x = np.random.rand(2, 5).astype(t)
     87         self._testAll(x + 1j * x, [1, 2], [1, 0])
     88         x = np.random.rand(3, 2, 1, 1).astype(t)
     89         self._testAll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2])
     90 
     91   def testRollInputMustVectorHigherRaises(self):
     92     tensor = 7
     93     shift = 1
     94     axis = 0
     95     with self.test_session():
     96       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
     97                                    "input must be 1-D or higher"):
     98         manip_ops.roll(tensor, shift, axis).eval()
     99 
    100   def testRollAxisMustBeScalarOrVectorRaises(self):
    101     tensor = [[1, 2], [3, 4]]
    102     shift = 1
    103     axis = [[0, 1]]
    104     with self.test_session():
    105       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
    106                                    "axis must be a scalar or a 1-D vector"):
    107         manip_ops.roll(tensor, shift, axis).eval()
    108 
    109   def testRollShiftMustBeScalarOrVectorRaises(self):
    110     tensor = [[1, 2], [3, 4]]
    111     shift = [[0, 1]]
    112     axis = 1
    113     with self.test_session():
    114       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
    115                                    "shift must be a scalar or a 1-D vector"):
    116         manip_ops.roll(tensor, shift, axis).eval()
    117 
    118   def testRollShiftAndAxisMustBeSameSizeRaises(self):
    119     tensor = [[1, 2], [3, 4]]
    120     shift = [1]
    121     axis = [0, 1]
    122     with self.test_session():
    123       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
    124                                    "shift and axis must have the same size"):
    125         manip_ops.roll(tensor, shift, axis).eval()
    126 
    127   def testRollAxisOutOfRangeRaises(self):
    128     tensor = [1, 2]
    129     shift = 1
    130     axis = 1
    131     with self.test_session():
    132       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
    133                                    "is out of range"):
    134         manip_ops.roll(tensor, shift, axis).eval()
    135 
    136 
    137 if __name__ == "__main__":
    138   test_lib.main()
    139