1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Functional tests for scan ops.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.compiler.tests.xla_test import XLATestCase 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import errors_impl 26 from tensorflow.python.framework import ops 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.ops import math_ops 29 from tensorflow.python.platform import test 30 31 32 def numpy_reverse(x, axis): 33 length = len(x.shape) 34 if axis < 0: 35 axis = length + axis 36 37 ix = [ 38 slice(None, None, -1) if i == axis else slice(None) for i in range(length) 39 ] 40 return x[ix] 41 42 43 def handle_options(func, x, axis, exclusive, reverse): 44 """Adds tf options to numpy scan ops.""" 45 length = len(x.shape) 46 if axis < 0: 47 axis = length + axis 48 49 if reverse: 50 x = numpy_reverse(x, axis) 51 52 if exclusive: 53 ix_head = [slice(0, 1) if i == axis else slice(None) for i in range(length)] 54 ix_init = [ 55 slice(0, -1) if i == axis else slice(None) for i in range(length) 56 ] 57 if func == np.cumsum: 58 init = np.zeros_like(x[ix_head]) 59 elif func == np.cumprod: 60 init = np.ones_like(x[ix_head]) 61 else: 62 raise ValueError("Unknown scan function.") 63 x = np.concatenate([init, func(x[ix_init], axis)], axis=axis) 64 else: 65 x = func(x, axis=axis) 66 67 if reverse: 68 x = numpy_reverse(x, axis) 69 return x 70 71 72 class CumsumTest(XLATestCase): 73 74 valid_dtypes = [np.float32] 75 76 def axis_dtypes(self): 77 return set(self.int_types).intersection([np.int32, np.int64]) 78 79 def _compare(self, x, axis, exclusive, reverse): 80 np_out = handle_options(np.cumsum, x, axis, exclusive, reverse) 81 with self.test_session(), self.test_scope(): 82 p = array_ops.placeholder(x.dtype) 83 tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval( 84 feed_dict={p: x}) 85 86 self.assertAllClose(np_out, tf_out) 87 88 def _compareAll(self, x, axis): 89 for exclusive in [True, False]: 90 for reverse in [True, False]: 91 self._compare(x, axis, exclusive, reverse) 92 93 def testEmpty(self): 94 for dtype in self.valid_dtypes: 95 x = np.zeros([0]).astype(dtype) 96 for axis in (-1, 0): 97 self._compareAll(x, axis) 98 99 def testAxisType(self): 100 for dtype in self.valid_dtypes: 101 x = np.arange(1, 6).reshape([5]).astype(dtype) 102 for axis_dtype in self.axis_dtypes(): 103 with self.test_session(), self.test_scope(): 104 p = array_ops.placeholder(x.dtype) 105 axis = constant_op.constant(0, axis_dtype) 106 math_ops.cumsum(p, axis).eval(feed_dict={p: x}) 107 108 def test1D(self): 109 for dtype in self.valid_dtypes: 110 x = np.arange(1, 6).reshape([5]).astype(dtype) 111 for axis in (-1, 0): 112 self._compareAll(x, axis) 113 114 def test2D(self): 115 for dtype in self.valid_dtypes: 116 x = np.arange(0, 10).reshape([2, 5]).astype(dtype) 117 for axis in (-2, -1, 0, 1): 118 self._compareAll(x, axis) 119 120 def test3D(self): 121 for dtype in self.valid_dtypes: 122 x = np.arange(0, 20).reshape([2, 2, 5]).astype(dtype) 123 for axis in (-3, -2, -1, 0, 1, 2): 124 self._compareAll(x, axis) 125 126 def test6D(self): 127 for dtype in self.valid_dtypes: 128 x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) 129 for axis in range(-6, 6, 3): 130 self._compareAll(x, axis) 131 132 def testInvalidAxis(self): 133 x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) 134 with self.test_session(), self.test_scope(): 135 input_tensor = ops.convert_to_tensor(x) 136 with self.assertRaisesWithPredicateMatch( 137 errors_impl.InvalidArgumentError, 138 lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): 139 math_ops.cumsum(input_tensor, -3).eval() 140 with self.assertRaisesWithPredicateMatch( 141 errors_impl.InvalidArgumentError, 142 lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): 143 math_ops.cumsum(input_tensor, 2).eval() 144 with self.assertRaisesWithPredicateMatch( 145 errors_impl.InvalidArgumentError, 146 lambda e: "axis must be a scalar" in str(e)): 147 math_ops.cumsum(input_tensor, [0]).eval() 148 149 150 class CumprodTest(XLATestCase): 151 152 valid_dtypes = [np.float32] 153 154 def axis_dtypes(self): 155 return set(self.int_types).intersection([np.int32, np.int64]) 156 157 def _compare(self, x, axis, exclusive, reverse): 158 np_out = handle_options(np.cumprod, x, axis, exclusive, reverse) 159 with self.test_session(), self.test_scope(): 160 p = array_ops.placeholder(x.dtype) 161 prod = math_ops.cumprod(p, axis, exclusive, reverse) 162 tf_out = prod.eval(feed_dict={p: x}) 163 164 self.assertAllClose(np_out, tf_out) 165 166 def _compareAll(self, x, axis): 167 for exclusive in [True, False]: 168 for reverse in [True, False]: 169 self._compare(x, axis, exclusive, reverse) 170 171 def testEmpty(self): 172 for dtype in self.valid_dtypes: 173 x = np.zeros([0]).astype(dtype) 174 for axis in (-1, 0): 175 self._compareAll(x, axis) 176 177 def testAxisType(self): 178 for dtype in self.valid_dtypes: 179 x = np.arange(1, 6).reshape([5]).astype(dtype) 180 for axis_dtype in self.axis_dtypes(): 181 with self.test_session(), self.test_scope(): 182 p = array_ops.placeholder(x.dtype) 183 axis = constant_op.constant(0, axis_dtype) 184 math_ops.cumprod(x, axis).eval(feed_dict={p: x}) 185 186 def test1D(self): 187 for dtype in self.valid_dtypes: 188 x = np.arange(1, 6).reshape([5]).astype(dtype) 189 for axis in (-1, 0): 190 self._compareAll(x, axis) 191 192 def test2D(self): 193 for dtype in self.valid_dtypes: 194 x = np.arange(1, 11).reshape([2, 5]).astype(dtype) 195 for axis in (-2, -1, 0, 1): 196 self._compareAll(x, axis) 197 198 def test3D(self): 199 for dtype in self.valid_dtypes: 200 x = np.arange(1, 21).reshape([2, 2, 5]).astype(dtype) 201 for axis in (-3, -2, -1, 0, 1, 2): 202 self._compareAll(x, axis) 203 204 def test6D(self): 205 for dtype in self.valid_dtypes: 206 x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) 207 for axis in range(-6, 6, 3): 208 self._compareAll(x, axis) 209 210 def testInvalidAxis(self): 211 x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) 212 with self.test_session(), self.test_scope(): 213 input_tensor = ops.convert_to_tensor(x) 214 with self.assertRaisesWithPredicateMatch( 215 errors_impl.InvalidArgumentError, 216 lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): 217 math_ops.cumprod(input_tensor, -3).eval() 218 with self.assertRaisesWithPredicateMatch( 219 errors_impl.InvalidArgumentError, 220 lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): 221 math_ops.cumprod(input_tensor, 2).eval() 222 with self.assertRaisesWithPredicateMatch( 223 errors_impl.InvalidArgumentError, 224 lambda e: "axis must be a scalar" in str(e)): 225 math_ops.cumprod(input_tensor, [0]).eval() 226 227 228 if __name__ == "__main__": 229 test.main() 230