1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for reduction operators.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import functools 22 import itertools 23 from absl.testing import parameterized 24 import numpy as np 25 26 from tensorflow.compiler.tests import xla_test 27 from tensorflow.python.framework import dtypes 28 from tensorflow.python.framework import errors_impl 29 from tensorflow.python.ops import array_ops 30 from tensorflow.python.ops import math_ops 31 from tensorflow.python.platform import googletest 32 33 34 @parameterized.named_parameters(('32_bit_index', dtypes.int32), 35 ('64_bit_index', dtypes.int64)) 36 class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): 37 def _testReduction(self, 38 tf_reduce_fn, 39 np_reduce_fn, 40 dtype, 41 test_inputs, 42 index_dtype, 43 rtol=1e-4, 44 atol=1e-4): 45 """Tests that the output of 'tf_reduce_fn' matches numpy's output.""" 46 47 for test_input in test_inputs: 48 with self.cached_session() as sess: 49 with self.test_scope(): 50 a = array_ops.placeholder(dtype) 51 index = array_ops.placeholder(index_dtype) 52 out = tf_reduce_fn(a, index) 53 result = sess.run(out, {a: test_input, index: [0]}) 54 self.assertAllClose( 55 result, np_reduce_fn(test_input, axis=0), rtol=rtol, atol=atol) 56 57 result = sess.run(out, {a: test_input, index: [1]}) 58 self.assertAllClose( 59 result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol) 60 61 result = sess.run(out, {a: test_input, index: [-1]}) 62 self.assertAllClose( 63 result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol) 64 65 with self.assertRaisesWithPredicateMatch( 66 errors_impl.InvalidArgumentError, 'Invalid reduction dim'): 67 sess.run(out, {a: test_input, index: [-33]}) 68 69 with self.assertRaisesWithPredicateMatch( 70 errors_impl.InvalidArgumentError, 'Invalid reduction dim'): 71 sess.run(out, {a: test_input, index: [2]}) 72 73 REAL_DATA = [ 74 np.zeros(shape=(2, 0)), 75 np.zeros(shape=(0, 30)), 76 np.arange(1, 7).reshape(2, 3), 77 np.arange(-10, -4).reshape(2, 3), 78 np.arange(-4, 2).reshape(2, 3), 79 ] 80 COMPLEX_DATA = [ 81 np.zeros(shape=(2, 0)).astype(np.complex64), 82 np.zeros(shape=(0, 30)).astype(np.complex64), 83 np.arange(1, 13, dtype=np.float32).view(np.complex64).reshape(2, 3), 84 np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3), 85 np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3), 86 ] 87 NONEMPTY_REAL_DATA = [x for x in REAL_DATA if np.size(x) > 0] 88 NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0] 89 BOOL_DATA = [ 90 np.array([], dtype=np.bool).reshape(2, 0), 91 np.array([], dtype=np.bool).reshape(0, 3), 92 np.array([[False, True, False], [True, True, False]]), 93 ] 94 ONES = [np.ones([34000, 2])] 95 96 def testReduceSumF32(self, index_dtype): 97 self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA, 98 index_dtype) 99 100 def testReduceSumC64(self, index_dtype): 101 self._testReduction(math_ops.reduce_sum, np.sum, np.complex64, 102 self.COMPLEX_DATA, index_dtype) 103 104 def testReduceProdF32(self, index_dtype): 105 self._testReduction(math_ops.reduce_prod, np.prod, np.float32, 106 self.REAL_DATA, index_dtype) 107 108 def testReduceProdC64(self, index_dtype): 109 self._testReduction(math_ops.reduce_prod, np.prod, np.complex64, 110 self.COMPLEX_DATA, index_dtype) 111 112 def testReduceMin(self, index_dtype): 113 114 def reference_min(dtype, inp, axis): 115 """Wrapper around np.amin that returns +infinity for an empty input.""" 116 if inp.shape[axis] == 0: 117 if np.issubdtype(dtype, np.floating): 118 return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf')) 119 return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], 120 np.iinfo(dtype).max) 121 return np.amin(inp, axis) 122 123 for dtype in set(self.all_types).intersection( 124 [np.float32, np.int32, np.int64]): 125 self._testReduction(math_ops.reduce_min, 126 functools.partial(reference_min, dtype), dtype, 127 self.REAL_DATA, index_dtype) 128 129 def testReduceMax(self, index_dtype): 130 131 def reference_max(dtype, inp, axis): 132 """Wrapper around np.amax that returns -infinity for an empty input.""" 133 if inp.shape[axis] == 0: 134 if np.issubdtype(dtype, np.floating): 135 return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], 136 float('-inf')) 137 return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], 138 np.iinfo(dtype).min) 139 return np.amax(inp, axis) 140 141 for dtype in set(self.all_types).intersection( 142 [np.float32, np.int32, np.int64]): 143 self._testReduction(math_ops.reduce_max, 144 functools.partial(reference_max, dtype), dtype, 145 self.REAL_DATA, index_dtype) 146 147 def testReduceMeanF32(self, index_dtype): 148 # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when 149 # reducing across zero inputs. 150 self._testReduction(math_ops.reduce_mean, np.mean, np.float32, 151 self.NONEMPTY_REAL_DATA, index_dtype) 152 153 def testReduceMeanF16(self, index_dtype): 154 if np.float16 in self.all_types: 155 self._testReduction(math_ops.reduce_mean, np.mean, np.float16, self.ONES, 156 index_dtype) 157 158 def testReduceMeanC64(self, index_dtype): 159 self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, 160 self.NONEMPTY_COMPLEX_DATA, index_dtype) 161 162 def testReduceAll(self, index_dtype): 163 self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA, 164 index_dtype) 165 166 def testReduceAny(self, index_dtype): 167 self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA, 168 index_dtype) 169 170 171 class ReduceOpPrecisionTest(xla_test.XLATestCase): 172 173 def _testReduceSum(self, 174 expected_result, 175 dtype, 176 test_inputs, 177 rtol=1e-3, 178 atol=1e-4): 179 """Tests reduce sum on a list of input arrays. 180 181 For each array in test_inputs, check that performing reduce sum on the array 182 produces a value that is close to the expected result. 183 184 Args: 185 expected_result: the expected result. 186 dtype: the data type of the reduce sum operation. 187 test_inputs: a list of input arrays for the reduce sum operation. 188 rtol: the relative error. 189 atol: the absolute error. 190 """ 191 192 for test_input in test_inputs: 193 with self.cached_session() as sess: 194 with self.test_scope(): 195 a = array_ops.placeholder(dtype) 196 index = array_ops.placeholder(dtypes.int32) 197 out = math_ops.reduce_sum(a, index) 198 result = sess.run(out, { 199 a: np.array(test_input, dtype=dtype), 200 index: [0] 201 }) 202 # Compare the results using float32 type. 203 self.assertAllClose( 204 np.float32(result), 205 np.float32(expected_result), 206 rtol=rtol, 207 atol=atol) 208 209 def testReduceSumF16(self): 210 """Tests the reduce sum of float16 doesn't lose too much precision.""" 211 212 if np.float16 not in self.all_types: 213 return 214 215 f16_max = np.finfo(np.float16).max 216 self._testReduceSum( 217 f16_max, np.float16, 218 itertools.permutations([f16_max, f16_max, f16_max * (-1.0)], 3)) 219 220 def testReduceSumBF16(self): 221 """Tests the reduce sum of bfloat16 doesn't lose too much precision.""" 222 223 if dtypes.bfloat16.as_numpy_dtype not in self.all_types: 224 return 225 226 bf16_max = np.float32(dtypes.bfloat16.max) 227 f32_max = dtypes.float32.max 228 value = min(bf16_max, f32_max - bf16_max) / 2 229 self._testReduceSum( 230 dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype, 231 itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3)) 232 233 234 if __name__ == '__main__': 235 googletest.main() 236