Home | History | Annotate | Download | only in tests
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Functional tests for scan ops."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.compiler.tests.xla_test import XLATestCase
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import errors_impl
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import math_ops
     29 from tensorflow.python.platform import test
     30 
     31 
     32 def numpy_reverse(x, axis):
     33   length = len(x.shape)
     34   if axis < 0:
     35     axis = length + axis
     36 
     37   ix = [
     38       slice(None, None, -1) if i == axis else slice(None) for i in range(length)
     39   ]
     40   return x[ix]
     41 
     42 
     43 def handle_options(func, x, axis, exclusive, reverse):
     44   """Adds tf options to numpy scan ops."""
     45   length = len(x.shape)
     46   if axis < 0:
     47     axis = length + axis
     48 
     49   if reverse:
     50     x = numpy_reverse(x, axis)
     51 
     52   if exclusive:
     53     ix_head = [slice(0, 1) if i == axis else slice(None) for i in range(length)]
     54     ix_init = [
     55         slice(0, -1) if i == axis else slice(None) for i in range(length)
     56     ]
     57     if func == np.cumsum:
     58       init = np.zeros_like(x[ix_head])
     59     elif func == np.cumprod:
     60       init = np.ones_like(x[ix_head])
     61     else:
     62       raise ValueError("Unknown scan function.")
     63     x = np.concatenate([init, func(x[ix_init], axis)], axis=axis)
     64   else:
     65     x = func(x, axis=axis)
     66 
     67   if reverse:
     68     x = numpy_reverse(x, axis)
     69   return x
     70 
     71 
     72 class CumsumTest(XLATestCase):
     73 
     74   valid_dtypes = [np.float32]
     75 
     76   def axis_dtypes(self):
     77     return set(self.int_types).intersection([np.int32, np.int64])
     78 
     79   def _compare(self, x, axis, exclusive, reverse):
     80     np_out = handle_options(np.cumsum, x, axis, exclusive, reverse)
     81     with self.test_session(), self.test_scope():
     82       p = array_ops.placeholder(x.dtype)
     83       tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval(
     84           feed_dict={p: x})
     85 
     86     self.assertAllClose(np_out, tf_out)
     87 
     88   def _compareAll(self, x, axis):
     89     for exclusive in [True, False]:
     90       for reverse in [True, False]:
     91         self._compare(x, axis, exclusive, reverse)
     92 
     93   def testEmpty(self):
     94     for dtype in self.valid_dtypes:
     95       x = np.zeros([0]).astype(dtype)
     96       for axis in (-1, 0):
     97         self._compareAll(x, axis)
     98 
     99   def testAxisType(self):
    100     for dtype in self.valid_dtypes:
    101       x = np.arange(1, 6).reshape([5]).astype(dtype)
    102       for axis_dtype in self.axis_dtypes():
    103         with self.test_session(), self.test_scope():
    104           p = array_ops.placeholder(x.dtype)
    105           axis = constant_op.constant(0, axis_dtype)
    106           math_ops.cumsum(p, axis).eval(feed_dict={p: x})
    107 
    108   def test1D(self):
    109     for dtype in self.valid_dtypes:
    110       x = np.arange(1, 6).reshape([5]).astype(dtype)
    111       for axis in (-1, 0):
    112         self._compareAll(x, axis)
    113 
    114   def test2D(self):
    115     for dtype in self.valid_dtypes:
    116       x = np.arange(0, 10).reshape([2, 5]).astype(dtype)
    117       for axis in (-2, -1, 0, 1):
    118         self._compareAll(x, axis)
    119 
    120   def test3D(self):
    121     for dtype in self.valid_dtypes:
    122       x = np.arange(0, 20).reshape([2, 2, 5]).astype(dtype)
    123       for axis in (-3, -2, -1, 0, 1, 2):
    124         self._compareAll(x, axis)
    125 
    126   def test6D(self):
    127     for dtype in self.valid_dtypes:
    128       x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype)
    129       for axis in range(-6, 6, 3):
    130         self._compareAll(x, axis)
    131 
    132   def testInvalidAxis(self):
    133     x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
    134     with self.test_session(), self.test_scope():
    135       input_tensor = ops.convert_to_tensor(x)
    136       with self.assertRaisesWithPredicateMatch(
    137           errors_impl.InvalidArgumentError,
    138           lambda e: "Expected scan axis in the range [-2, 2)" in str(e)):
    139         math_ops.cumsum(input_tensor, -3).eval()
    140       with self.assertRaisesWithPredicateMatch(
    141           errors_impl.InvalidArgumentError,
    142           lambda e: "Expected scan axis in the range [-2, 2)" in str(e)):
    143         math_ops.cumsum(input_tensor, 2).eval()
    144       with self.assertRaisesWithPredicateMatch(
    145           errors_impl.InvalidArgumentError,
    146           lambda e: "axis must be a scalar" in str(e)):
    147         math_ops.cumsum(input_tensor, [0]).eval()
    148 
    149 
    150 class CumprodTest(XLATestCase):
    151 
    152   valid_dtypes = [np.float32]
    153 
    154   def axis_dtypes(self):
    155     return set(self.int_types).intersection([np.int32, np.int64])
    156 
    157   def _compare(self, x, axis, exclusive, reverse):
    158     np_out = handle_options(np.cumprod, x, axis, exclusive, reverse)
    159     with self.test_session(), self.test_scope():
    160       p = array_ops.placeholder(x.dtype)
    161       prod = math_ops.cumprod(p, axis, exclusive, reverse)
    162       tf_out = prod.eval(feed_dict={p: x})
    163 
    164     self.assertAllClose(np_out, tf_out)
    165 
    166   def _compareAll(self, x, axis):
    167     for exclusive in [True, False]:
    168       for reverse in [True, False]:
    169         self._compare(x, axis, exclusive, reverse)
    170 
    171   def testEmpty(self):
    172     for dtype in self.valid_dtypes:
    173       x = np.zeros([0]).astype(dtype)
    174       for axis in (-1, 0):
    175         self._compareAll(x, axis)
    176 
    177   def testAxisType(self):
    178     for dtype in self.valid_dtypes:
    179       x = np.arange(1, 6).reshape([5]).astype(dtype)
    180       for axis_dtype in self.axis_dtypes():
    181         with self.test_session(), self.test_scope():
    182           p = array_ops.placeholder(x.dtype)
    183           axis = constant_op.constant(0, axis_dtype)
    184           math_ops.cumprod(x, axis).eval(feed_dict={p: x})
    185 
    186   def test1D(self):
    187     for dtype in self.valid_dtypes:
    188       x = np.arange(1, 6).reshape([5]).astype(dtype)
    189       for axis in (-1, 0):
    190         self._compareAll(x, axis)
    191 
    192   def test2D(self):
    193     for dtype in self.valid_dtypes:
    194       x = np.arange(1, 11).reshape([2, 5]).astype(dtype)
    195       for axis in (-2, -1, 0, 1):
    196         self._compareAll(x, axis)
    197 
    198   def test3D(self):
    199     for dtype in self.valid_dtypes:
    200       x = np.arange(1, 21).reshape([2, 2, 5]).astype(dtype)
    201       for axis in (-3, -2, -1, 0, 1, 2):
    202         self._compareAll(x, axis)
    203 
    204   def test6D(self):
    205     for dtype in self.valid_dtypes:
    206       x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype)
    207       for axis in range(-6, 6, 3):
    208         self._compareAll(x, axis)
    209 
    210   def testInvalidAxis(self):
    211     x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
    212     with self.test_session(), self.test_scope():
    213       input_tensor = ops.convert_to_tensor(x)
    214       with self.assertRaisesWithPredicateMatch(
    215           errors_impl.InvalidArgumentError,
    216           lambda e: "Expected scan axis in the range [-2, 2)" in str(e)):
    217         math_ops.cumprod(input_tensor, -3).eval()
    218       with self.assertRaisesWithPredicateMatch(
    219           errors_impl.InvalidArgumentError,
    220           lambda e: "Expected scan axis in the range [-2, 2)" in str(e)):
    221         math_ops.cumprod(input_tensor, 2).eval()
    222       with self.assertRaisesWithPredicateMatch(
    223           errors_impl.InvalidArgumentError,
    224           lambda e: "axis must be a scalar" in str(e)):
    225         math_ops.cumprod(input_tensor, [0]).eval()
    226 
    227 
if __name__ == "__main__":
  # Script entry point: hand off to TensorFlow's test runner
  # (tensorflow.python.platform.test) to run this module's test cases.
  test.main()
    230