Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Functional tests for reduction ops."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import itertools
     22 import numbers
     23 
     24 import numpy as np
     25 
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import ops
     29 from tensorflow.python.framework import tensor_shape
     30 from tensorflow.python.ops import array_ops
     31 from tensorflow.python.ops import gradient_checker
     32 from tensorflow.python.ops import math_ops
     33 from tensorflow.python.platform import test
     34 
     35 # The maximum input rank to test.
     36 _MAX_RANK = 5
     37 
     38 
     39 def _powerset(iterable):
     40   """Helper for generating all possible reduction_axes arguments.
     41 
     42   Example:
     43   powerset([0,1,2]): () (0,) (1,) (2,) (0,1) (0,2) (1,2) (0,1,2)
     44 
     45   Args:
     46     iterable: An iterable of items to generate the powerset of.
     47 
     48   Returns:
     49     The powerset of all items in iterable.
     50   """
     51   s = list(iterable)
     52   return itertools.chain.from_iterable(
     53       itertools.combinations(s, r) for r in range(len(s) + 1))
     54 
     55 
     56 class ReducedShapeTest(test.TestCase):
     57 
     58   def _check(self, shape, axes, result):
     59     output = math_ops.reduced_shape(shape, axes=axes)
     60     self.assertAllEqual(output.eval(), result)
     61 
     62   def testSimple(self):
     63     with self.test_session():
     64       self._check([3], [], [3])
     65       self._check([3], [0], [1])
     66       self._check([5, 3], [], [5, 3])
     67       self._check([5, 3], [0], [1, 3])
     68       self._check([5, 3], [1], [5, 1])
     69       self._check([5, 3], [0, 1], [1, 1])
     70 
     71   def testZeros(self):
     72     """Check that reduced_shape does the right thing with zero dimensions."""
     73     with self.test_session():
     74       self._check([0], [], [0])
     75       self._check([0], [0], [1])
     76       self._check([0, 3], [], [0, 3])
     77       self._check([0, 3], [0], [1, 3])
     78       self._check([0, 3], [1], [0, 1])
     79       self._check([0, 3], [0, 1], [1, 1])
     80       self._check([3, 0], [], [3, 0])
     81       self._check([3, 0], [0], [1, 0])
     82       self._check([3, 0], [1], [3, 1])
     83       self._check([3, 0], [0, 1], [1, 1])
     84 
     85   def testNegAxes(self):
     86     with self.test_session():
     87       self._check([10, 10, 10], [-1], [10, 10, 1])
     88       self._check([10, 10, 10], [-1, 2], [10, 10, 1])
     89       self._check([10, 10, 10], [-1, -1], [10, 10, 1])
     90       self._check([10, 10, 10], [-1, 0], [1, 10, 1])
     91       self._check([10, 10, 10], [-3], [1, 10, 10])
     92 
     93 
     94 class ReductionUnknownShape(test.TestCase):
     95 
     96   def testBasic(self):
     97     with self.test_session():
     98       for dtype, reductions in [(dtypes.float32,
     99                                  (math_ops.reduce_sum, math_ops.reduce_mean,
    100                                   math_ops.reduce_prod, math_ops.reduce_max,
    101                                   math_ops.reduce_min)),
    102                                 (dtypes.bool, (math_ops.reduce_all,
    103                                                math_ops.reduce_any))]:
    104         for reduction in reductions:
    105           x = array_ops.placeholder(
    106               dtype=dtype, shape=None)  # Some tensor w/ unknown shape.
    107           y = reduction(x)
    108           self.assertEqual(y.shape, ())
    109 
    110 
    111 class BaseReductionTest(test.TestCase):
    112 
    113   def _tf_reduce(self, x, reduction_axes, keepdims):
    114     raise NotImplementedError()
    115 
    116   def _np_reduce(self, x, reduction_axes, keepdims):
    117     raise NotImplementedError()
    118 
    119   def _makeIncremental(self, shape, dtype):
    120     data = np.arange(np.prod(shape)).reshape(shape).astype(dtype.as_numpy_dtype)
    121     if dtype.is_complex:
    122       data -= 2j * data
    123     return data
    124 
    125   def _makeRandom(self, shape, dtype):
    126     data = np.random.rand(*shape).astype(dtype.as_numpy_dtype)
    127     if dtype.is_complex:
    128       data -= 2j * data
    129     return data
    130 
    131   def _compare(self, x, reduction_axes, keepdims, feed_dict=None):
    132     np_ans = self._np_reduce(x, reduction_axes, keepdims)
    133     with self.test_session(use_gpu=True) as sess:
    134       tf_ans = self._tf_reduce(x, reduction_axes, keepdims)
    135       out = sess.run(tf_ans, feed_dict)
    136     self.assertAllClose(np_ans, out)
    137     self.assertShapeEqual(np_ans, tf_ans)
    138 
    139   def _compareAll(self, x, reduction_axes, feed_dict=None):
    140     if reduction_axes is not None and np.shape(reduction_axes) == (1,):
    141       # Test scalar reduction_axes argument
    142       self._compareAll(x, reduction_axes[0])
    143     self._compare(x, reduction_axes, keepdims=False, feed_dict=feed_dict)
    144     self._compare(x, reduction_axes, keepdims=True, feed_dict=feed_dict)
    145 
    146   def _compareAllAxes(self, x, feed_dict=None):
    147     self._compareAll(x, None)
    148     for axes in _powerset(range(x.ndim)):
    149       self._compareAll(x, axes, feed_dict)
    150 
    151   def _compareGradient(self, x, reduction_axes, rtol=1e-8, atol=1e-8):
    152     if reduction_axes is not None and np.shape(reduction_axes) == (1,):
    153       # Test scalar reduction_axes argument
    154       self._compareGradient(x, reduction_axes[0], rtol=rtol, atol=atol)
    155     with self.test_session(use_gpu=True):
    156       t = ops.convert_to_tensor(x)
    157       su = self._tf_reduce(t, reduction_axes, False)
    158       jacob_t, jacob_n = gradient_checker.compute_gradient(
    159           t, x.shape, su, su.get_shape().as_list(), x_init_value=x, delta=1)
    160     self.assertAllClose(jacob_t, jacob_n, rtol=rtol, atol=atol)
    161 
    162   def _compareGradientAxes(self, x, rtol=1e-8, atol=1e-8):
    163     self._compareGradient(x, None, rtol=rtol, atol=atol)
    164     self._compareGradient(x, [], rtol=rtol, atol=atol)
    165     self._compareGradient(x, 0, rtol=rtol, atol=atol)
    166     self._compareGradient(x, [1], rtol=rtol, atol=atol)
    167     self._compareGradient(x, [2], rtol=rtol, atol=atol)
    168     self._compareGradient(x, [1, 2], rtol=rtol, atol=atol)
    169     self._compareGradient(x, [0, 1, 2, 3], rtol=rtol, atol=atol)
    170 
    171 
    172 class SumReductionTest(BaseReductionTest):
    173 
    174   def _tf_reduce(self, x, reduction_axes, keepdims):
    175     return math_ops.reduce_sum(x, reduction_axes, keepdims)
    176 
    177   def _np_reduce(self, x, reduction_axes, keepdims):
    178     if isinstance(reduction_axes, list) or isinstance(reduction_axes,
    179                                                       np.ndarray):
    180       reduction_axes = tuple(reduction_axes)
    181     return np.sum(x, axis=reduction_axes, keepdims=keepdims)
    182 
    183   def testAxesType(self):
    184     for dtype in [dtypes.int64, dtypes.int32]:
    185       with self.test_session(use_gpu=True) as sess:
    186         v = math_ops.reduce_sum([0, 0], constant_op.constant(0, dtype=dtype))
    187         tf_v = sess.run(v)
    188       self.assertAllEqual(tf_v, 0)
    189 
    190   def testInfinity(self):
    191     for dtype in [np.float32, np.float64]:
    192       for special_value_x in [-np.inf, np.inf]:
    193         for special_value_y in [-np.inf, np.inf]:
    194           np_arr = np.array([special_value_x, special_value_y]).astype(dtype)
    195           self._compareAll(np_arr, None)
    196 
    197   def testInt32(self):
    198     for rank in range(1, _MAX_RANK + 1):
    199       np_arr = self._makeIncremental((2,) * rank, dtypes.int32)
    200       self._compareAllAxes(np_arr)
    201 
    202   def testFloat16(self):
    203     for rank in range(1, _MAX_RANK + 1):
    204       np_arr = self._makeIncremental((2,) * rank, dtypes.float16)
    205       self._compareAllAxes(np_arr)
    206 
    207     # test that mean doesn't overflow
    208     # only on GPU, since it has the more accurate implementation
    209     if not test.is_gpu_available():
    210       return
    211 
    212     arr = np.ones([68000], dtype=np.float16)
    213 
    214     with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
    215       tf_arr = array_ops.constant(arr)
    216       tf_mean = math_ops.reduce_mean(tf_arr, 0, False)
    217       tf_out_mean = sess.run(tf_mean)
    218     self.assertAllClose(tf_out_mean, 1.)
    219 
    220   def testFloat32(self):
    221     for rank in range(1, _MAX_RANK + 1):
    222       np_arr = self._makeIncremental((2,) * rank, dtypes.float32)
    223       self._compareAllAxes(np_arr)
    224 
    225     for _ in range(10):
    226       size_x = int(2**np.random.uniform(0, 15))
    227       size_y = int(2**np.random.uniform(0, 15))
    228 
    229       if size_x * size_y > 1e7:
    230         size_y = int(1e7 / size_x)
    231 
    232       arr = np.ones([size_x, size_y], dtype=np.float32)
    233       col_sum = np.sum(arr, axis=0)
    234       row_sum = np.sum(arr, axis=1)
    235 
    236       with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
    237         tf_row_sum = self._tf_reduce(arr, 1, False)
    238         tf_col_sum = self._tf_reduce(arr, 0, False)
    239         tf_out_row, tf_out_col = sess.run([tf_row_sum, tf_col_sum])
    240       self.assertAllClose(col_sum, tf_out_col)
    241       self.assertAllClose(row_sum, tf_out_row)
    242 
    243     for size_x in [1, 3, 16, 33]:
    244       for size_y in [1, 3, 16, 33]:
    245         for size_z in [1, 3, 16, 33]:
    246           arr = np.ones([size_x, size_y, size_z], dtype=np.float32)
    247           sum_y = np.sum(arr, axis=1)
    248           sum_xz = np.sum(arr, axis=(0, 2))
    249 
    250           with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
    251             tf_sum_xz = self._tf_reduce(arr, [0, 2], False)
    252             tf_sum_y = self._tf_reduce(arr, 1, False)
    253             tf_out_sum_xz, tf_out_sum_y = sess.run([tf_sum_xz, tf_sum_y])
    254           self.assertAllClose(sum_y, tf_out_sum_y)
    255           self.assertAllClose(sum_xz, tf_out_sum_xz)
    256 
    257   def testFloat64(self):
    258     for rank in range(1, _MAX_RANK + 1):
    259       np_arr = self._makeIncremental((2,) * rank, dtypes.float64)
    260       self._compareAllAxes(np_arr)
    261 
    262   def testComplex64(self):
    263     for rank in range(1, _MAX_RANK + 1):
    264       np_arr = self._makeIncremental((2,) * rank, dtypes.complex64)
    265       self._compareAllAxes(np_arr)
    266 
    267   def testComplex128(self):
    268     for rank in range(1, _MAX_RANK + 1):
    269       np_arr = self._makeIncremental((2,) * rank, dtypes.complex128)
    270       self._compareAllAxes(np_arr)
    271 
    272   def testInvalidIndex(self):
    273     np_arr = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
    274     input_tensor = ops.convert_to_tensor(np_arr)
    275     with self.assertRaisesWithPredicateMatch(
    276         ValueError, lambda e: "Invalid reduction dimension" in str(e)):
    277       math_ops.reduce_sum(input_tensor, [-3])
    278     with self.assertRaisesWithPredicateMatch(
    279         ValueError, lambda e: "Invalid reduction dimension" in str(e)):
    280       math_ops.reduce_sum(input_tensor, [2])
    281     with self.assertRaisesWithPredicateMatch(
    282         ValueError, lambda e: "Invalid reduction dimension" in str(e)):
    283       math_ops.reduce_sum(input_tensor, [0, 2])
    284 
    285   def testPartialShapes(self):
    286     np.random.seed(1618)
    287 
    288     # Input shape is unknown.
    289     reduction_axes = [1, 2]
    290     c_unknown = array_ops.placeholder(dtypes.float32)
    291     s_unknown = math_ops.reduce_sum(c_unknown, reduction_axes)
    292     self.assertEqual(tensor_shape.unknown_shape(), s_unknown.get_shape())
    293 
    294     np_input = np.random.randn(3, 3, 3)
    295     self._compareAll(np_input, reduction_axes, {c_unknown: np_input})
    296 
    297     # Input shape only has known rank.
    298     c_known_rank = array_ops.placeholder(dtypes.float32)
    299     c_known_rank.set_shape(tensor_shape.unknown_shape(ndims=3))
    300     s_known_rank = math_ops.reduce_sum(
    301         c_known_rank, reduction_axes, keepdims=True)
    302     self.assertEqual(3, s_known_rank.get_shape().ndims)
    303 
    304     np_input = np.random.randn(3, 3, 3)
    305     self._compareAll(np_input, reduction_axes, {c_known_rank: np_input})
    306 
    307     # Reduction indices are unknown.
    308     unknown_indices = array_ops.placeholder(dtypes.int32)
    309     c_unknown_indices = constant_op.constant([[10.0], [20.0]])
    310     s_unknown_indices = math_ops.reduce_sum(
    311         c_unknown_indices, unknown_indices, keepdims=False)
    312     self.assertEqual(tensor_shape.unknown_shape(),
    313                      s_unknown_indices.get_shape())
    314     s_unknown_indices_keep = math_ops.reduce_sum(
    315         c_unknown_indices, unknown_indices, keepdims=True)
    316     self.assertEqual(2, s_unknown_indices_keep.get_shape().ndims)
    317 
    318   def testWrongShapeForReductionIndices(self):
    319     reduction_axes = [[1], [2]]
    320     c_unknown = array_ops.placeholder(dtypes.float32)
    321     with self.assertRaisesWithPredicateMatch(ValueError,
    322                                              ".*must be at most rank 1.*"):
    323       math_ops.reduce_sum(c_unknown, reduction_axes)
    324 
    325   # Int64??
    326 
    327   def testGradient(self):
    328     for dtype in [
    329         dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128
    330     ]:
    331       x = self._makeIncremental([2, 3, 4, 2], dtype)
    332       self._compareGradientAxes(x)
    333 
    334   def testHighRank(self):
    335     # Do a bunch of random high dimensional reductions
    336     np.random.seed(42)
    337     for _ in range(20):
    338       rank = np.random.randint(4, 10 + 1)
    339       axes, = np.nonzero(np.random.randint(2, size=rank))
    340       shape = tuple(np.random.randint(1, 3 + 1, size=rank))
    341       data = np.random.randint(1024, size=shape)
    342       self._compareAll(data, axes)
    343     # Check some particular axis patterns
    344     for rank in 4, 7, 10:
    345       shape = tuple(np.random.randint(1, 3 + 1, size=rank))
    346       data = np.random.randint(1024, size=shape)
    347       for axes in ([], np.arange(rank), np.arange(0, rank, 2),
    348                    np.arange(1, rank, 2)):
    349         self._compareAll(data, axes)
    350 
    351   def testExpand(self):
    352     # Reduce an empty tensor to a nonempty tensor
    353     x = np.zeros((5, 0))
    354     self._compareAll(x, [1])
    355 
    356   def testEmptyGradients(self):
    357     with self.test_session(use_gpu=True):
    358       x = array_ops.zeros([0, 3])
    359       y = math_ops.reduce_sum(x, [1])
    360       error = gradient_checker.compute_gradient_error(x, [0, 3], y, [0])
    361       self.assertEqual(error, 0)
    362 
    363   def testDegenerate(self):
    364     with self.test_session(use_gpu=True):
    365       for dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
    366                     dtypes.complex64, dtypes.complex128):
    367         # A large number is needed to get Eigen to die
    368         x = array_ops.zeros((0, 9938), dtype=dtype)
    369         y = math_ops.reduce_sum(x, [0])
    370         self.assertAllEqual(y.eval(), np.zeros(9938))
    371 
    372 
    373 class MeanReductionTest(BaseReductionTest):
    374 
    375   def _tf_reduce(self, x, reduction_axes, keepdims):
    376     return math_ops.reduce_mean(x, reduction_axes, keepdims)
    377 
    378   def _np_reduce(self, x, reduction_axes, keepdims):
    379     if isinstance(reduction_axes, list) or isinstance(reduction_axes,
    380                                                       np.ndarray):
    381       reduction_axes = tuple(reduction_axes)
    382     elif isinstance(reduction_axes, numbers.Integral):
    383       reduction_axes = (reduction_axes,)
    384 
    385     if reduction_axes is None:
    386       count = np.prod(x.shape)
    387     else:
    388       count = np.prod([x.shape[ax] for ax in reduction_axes])
    389     # np.mean automatically converts integer inputs to float, while TensorFlow's
    390     # reduce_mean does not. For integer inputs, we emulate TensorFlow's behavior
    391     # using np.sum and truncating division.
    392     np_sum = np.sum(x, axis=reduction_axes, keepdims=keepdims)
    393     if np.issubdtype(x.dtype, np.integer):
    394       return np_sum // count
    395     return np_sum / count
    396 
    397   def testAxesType(self):
    398     for dtype in [dtypes.int64, dtypes.int32]:
    399       with self.test_session(use_gpu=True) as sess:
    400         v = math_ops.reduce_mean([0, 0], constant_op.constant(0, dtype=dtype))
    401         tf_v = sess.run(v)
    402       self.assertAllEqual(tf_v, 0)
    403 
    404   def testInfinity(self):
    405     for dtype in [np.float32, np.float64]:
    406       for special_value_x in [-np.inf, np.inf]:
    407         for special_value_y in [-np.inf, np.inf]:
    408           np_arr = np.array([special_value_x, special_value_y]).astype(dtype)
    409           self._compareAll(np_arr, None)
    410 
    411   def testInt32(self):
    412     for rank in range(1, _MAX_RANK + 1):
    413       np_arr = self._makeIncremental((2,) * rank, dtypes.int32)
    414       self._compareAllAxes(np_arr)
    415 
    416   def testFloat32(self):
    417     for rank in range(1, _MAX_RANK + 1):
    418       np_arr = self._makeIncremental((2,) * rank, dtypes.float32)
    419       self._compareAllAxes(np_arr)
    420 
    421   def testFloat64(self):
    422     for rank in range(1, _MAX_RANK + 1):
    423       np_arr = self._makeIncremental((2,) * rank, dtypes.float64)
    424       self._compareAllAxes(np_arr)
    425 
    426   def testComplex64(self):
    427     for rank in range(1, _MAX_RANK + 1):
    428       np_arr = self._makeIncremental((2,) * rank, dtypes.complex64)
    429       self._compareAllAxes(np_arr)
    430 
    431   def testComplex128(self):
    432     for rank in range(1, _MAX_RANK + 1):
    433       np_arr = self._makeIncremental((2,) * rank, dtypes.complex128)
    434       self._compareAllAxes(np_arr)
    435 
    436   def testGradient(self):
    437     s = [2, 3, 4, 2]
    438     for dtype in [dtypes.float32, dtypes.float64]:
    439       x = self._makeIncremental(s, dtype)
    440       self._compareGradientAxes(x, rtol=1e-3, atol=1e-3)
    441 
    442   def testEmptyGradients(self):
    443     with self.test_session(use_gpu=True):
    444       x = array_ops.zeros([0, 3])
    445       y = math_ops.reduce_mean(x, [1])
    446       error = gradient_checker.compute_gradient_error(x, [0, 3], y, [0])
    447       self.assertEqual(error, 0)
    448 
    449   def testDegenerate(self):
    450     with self.test_session(use_gpu=True):
    451       for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
    452         # A large number is needed to get Eigen to die
    453         x = array_ops.zeros((0, 9938), dtype=dtype)
    454         y = math_ops.reduce_mean(x, [0]).eval()
    455         self.assertEqual(y.shape, (9938,))
    456         self.assertTrue(np.all(np.isnan(y)))
    457 
    458 
    459 class ProdReductionTest(BaseReductionTest):
    460 
    461   def _tf_reduce(self, x, reduction_axes, keepdims):
    462     return math_ops.reduce_prod(x, reduction_axes, keepdims)
    463 
    464   def _np_reduce(self, x, reduction_axes, keepdims):
    465     if isinstance(reduction_axes, list) or isinstance(reduction_axes,
    466                                                       np.ndarray):
    467       reduction_axes = tuple(reduction_axes)
    468     return np.prod(x, axis=reduction_axes, keepdims=keepdims)
    469 
    470   def testAxesType(self):
    471     for dtype in [dtypes.int64, dtypes.int32]:
    472       with self.test_session(use_gpu=True) as sess:
    473         v = math_ops.reduce_prod([0, 0], constant_op.constant(0, dtype=dtype))
    474         tf_v = sess.run(v)
    475       self.assertAllEqual(tf_v, 0)
    476 
    477   def testInfinity(self):
    478     for dtype in [np.float32, np.float64]:
    479       for special_value_x in [-np.inf, np.inf]:
    480         for special_value_y in [-np.inf, np.inf]:
    481           np_arr = np.array([special_value_x, special_value_y]).astype(dtype)
    482           self._compareAll(np_arr, None)
    483 
    484   def testInt32(self):
    485     # Numpy automatically upgrades the type of np.prod from int32 to int64, so
    486     # Numpy does not overflow an int32 np.prod while TensorFlow does. To avoid
    487     # overflow, divide the incremental int32 array by 2.
    488     for rank in range(1, _MAX_RANK + 1):
    489       np_arr = self._makeIncremental((2,) * rank, dtypes.int32) / 2
    490       self._compareAllAxes(np_arr)
    491 
    492   def testFloat32(self):
    493     for rank in range(1, _MAX_RANK + 1):
    494       np_arr = self._makeIncremental((2,) * rank, dtypes.float32)
    495       self._compareAllAxes(np_arr)
    496 
    497   def testFloat64(self):
    498     for rank in range(1, _MAX_RANK + 1):
    499       np_arr = self._makeIncremental((2,) * rank, dtypes.float64)
    500       self._compareAllAxes(np_arr)
    501 
    502   def testComplex64(self):
    503     for rank in range(1, _MAX_RANK + 1):
    504       np_arr = self._makeIncremental((2,) * rank, dtypes.complex64)
    505       self._compareAllAxes(np_arr)
    506 
    507   def testComplex128(self):
    508     for rank in range(1, _MAX_RANK + 1):
    509       np_arr = self._makeIncremental((2,) * rank, dtypes.complex128)
    510       self._compareAllAxes(np_arr)
    511 
    512   def testGradientWithZeros(self):
    513     s = [2, 3, 4, 2]
    514     x = self._makeIncremental(s, dtypes.float32) / 20.
    515     # No zeros in input
    516     self._compareGradientAxes(x, rtol=1e-3, atol=1e-3)
    517     # Zero at beginning
    518     x1 = x.copy()
    519     x1[:, :, 0, :] = 0
    520     self._compareGradientAxes(x1, rtol=1e-3, atol=1e-3)
    521     # Zero at end
    522     x2 = x.copy()
    523     x2[:, :, -1, :] = 0
    524     self._compareGradientAxes(x2, rtol=1e-3, atol=1e-3)
    525     # Zero in middle
    526     x3 = x.copy()
    527     x3[:, :, 2, :] = 0
    528     self._compareGradientAxes(x3, rtol=1e-3, atol=1e-3)
    529     # All zeros
    530     x4 = x.copy()
    531     x4[:, :, :, :] = 0
    532     self._compareGradientAxes(x4, rtol=1e-3, atol=1e-3)
    533 
    534   def testEmptyGradients(self):
    535     with self.test_session(use_gpu=True):
    536       x = array_ops.zeros([0, 3])
    537       y = math_ops.reduce_prod(x, [1])
    538       error = gradient_checker.compute_gradient_error(x, [0, 3], y, [0])
    539       self.assertEqual(error, 0)
    540 
    541   def testDegenerate(self):
    542     with self.test_session(use_gpu=True):
    543       for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
    544         # A large number is needed to get Eigen to die
    545         x = array_ops.zeros((0, 9938), dtype=dtype)
    546         y = math_ops.reduce_prod(x, [0])
    547         self.assertAllEqual(y.eval(), np.ones(9938))
    548 
    549 
    550 class MinReductionTest(test.TestCase):
    551 
    552   def _compare(self, x, reduction_axes, keepdims, use_gpu=False):
    553     np_ans = x
    554     if reduction_axes is None:
    555       np_ans = np.amin(np_ans, keepdims=keepdims)
    556     else:
    557       for ra in reduction_axes[::-1]:
    558         np_ans = np.amin(np_ans, axis=ra, keepdims=keepdims)
    559     with self.test_session(use_gpu=use_gpu):
    560       if reduction_axes is not None:
    561         reduction_axes = np.array(reduction_axes).astype(np.int32)
    562       tf_ans = math_ops.reduce_min(x, reduction_axes, keepdims)
    563       out = tf_ans.eval()
    564     self.assertAllClose(np_ans, out)
    565     self.assertShapeEqual(np_ans, tf_ans)
    566 
    567   def _compareAll(self, x, reduction_axes):
    568     self._compare(x, reduction_axes, False, use_gpu=True)
    569     self._compare(x, reduction_axes, False, use_gpu=False)
    570     self._compare(x, reduction_axes, True, use_gpu=True)
    571     self._compare(x, reduction_axes, True, use_gpu=False)
    572 
    573   def testAxesType(self):
    574     for dtype in [dtypes.int64, dtypes.int32]:
    575       with self.test_session(use_gpu=True) as sess:
    576         v = math_ops.reduce_min([0, 0], constant_op.constant(0, dtype=dtype))
    577         tf_v = sess.run(v)
    578       self.assertAllEqual(tf_v, 0)
    579 
    580   def testInfinity(self):
    581     for dtype in [np.float32, np.float64]:
    582       for special_value_x in [-np.inf, np.inf]:
    583         for special_value_y in [-np.inf, np.inf]:
    584           np_arr = np.array([special_value_x, special_value_y]).astype(dtype)
    585           self._compareAll(np_arr, None)
    586 
    587   def testFloatReduce3D(self):
    588     # Create a 3D array of floats and reduce across all possible
    589     # dimensions
    590     np_arr = np.arange(1, 31).reshape([2, 3, 5]).astype(np.float32)
    591     self._compareAll(np_arr, None)
    592     self._compareAll(np_arr, [])
    593     self._compareAll(np_arr, [0])
    594     self._compareAll(np_arr, [1])
    595     self._compareAll(np_arr, [2])
    596     self._compareAll(np_arr, [0, 1])
    597     self._compareAll(np_arr, [1, 2])
    598     self._compareAll(np_arr, [0, 2])
    599     self._compareAll(np_arr, [0, 1, 2])
    600 
    601   def testDoubleReduce3D(self):
    602     # Create a 3D array of doubles and reduce across all possible
    603     # dimensions
    604     np_arr = np.arange(1, 31).reshape([2, 3, 5]).astype(np.float64)
    605     self._compareAll(np_arr, None)
    606     self._compareAll(np_arr, [])
    607     self._compareAll(np_arr, [0])
    608     self._compareAll(np_arr, [1])
    609     self._compareAll(np_arr, [2])
    610     self._compareAll(np_arr, [0, 1])
    611     self._compareAll(np_arr, [1, 2])
    612     self._compareAll(np_arr, [0, 2])
    613     self._compareAll(np_arr, [0, 1, 2])
    614 
    615   def testGradient(self):
    616     s = [2, 3, 4, 2]
    617     x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
    618     with self.test_session():
    619       t = ops.convert_to_tensor(x)
    620       su = math_ops.reduce_min(t, [1, 2])
    621       jacob_t, jacob_n = gradient_checker.compute_gradient(
    622           t, s, su, [2, 2], x_init_value=x, delta=1)
    623     self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
    624 
    625   def testGradient2(self):
    626     s = [2, 3, 4, 2]
    627     x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
    628     with self.test_session():
    629       t = ops.convert_to_tensor(x)
    630       su = math_ops.reduce_min(t, [1])
    631       jacob_t, jacob_n = gradient_checker.compute_gradient(
    632           t, s, su, [2, 4, 2], x_init_value=x, delta=1)
    633     self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
    634 
    635   def testGradient3(self):
    636     s = [2, 3, 4, 2]
    637     x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
    638     with self.test_session():
    639       t = ops.convert_to_tensor(x)
    640       su = math_ops.reduce_min(t, [2])
    641       jacob_t, jacob_n = gradient_checker.compute_gradient(
    642           t, s, su, [2, 3, 2], x_init_value=x, delta=1)
    643     self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
    644 
    645   def testGradient4(self):
    646     s = [2, 3, 4, 2]
    647     x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
    648     with self.test_session():
    649       t = ops.convert_to_tensor(x)
    650       su = math_ops.reduce_min(t)
    651       jacob_t, jacob_n = gradient_checker.compute_gradient(
    652           t, s, su, [1], x_init_value=x, delta=1)
    653     self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
    654 
    655   def testEmptyGradients(self):
    656     with self.test_session():
    657       x = array_ops.zeros([0, 3])
    658       y = math_ops.reduce_min(x, [1])
    659       error = gradient_checker.compute_gradient_error(x, [0, 3], y, [0])
    660       self.assertEqual(error, 0)
    661 
    662 
    663 class MaxReductionTest(test.TestCase):
    664 
    665   def _compare(self, x, reduction_axes, keepdims, use_gpu=False):
    666     np_ans = x
    667     if reduction_axes is None:
    668       np_ans = np.amax(np_ans, keepdims=keepdims)
    669     else:
    670       for ra in reduction_axes[::-1]:
    671         np_ans = np.amax(np_ans, axis=ra, keepdims=keepdims)
    672     with self.test_session(use_gpu=use_gpu):
    673       if reduction_axes is not None:
    674         reduction_axes = np.array(reduction_axes).astype(np.int32)
    675       tf_ans = math_ops.reduce_max(x, reduction_axes, keepdims)
    676       out = tf_ans.eval()
    677     self.assertAllClose(np_ans, out)
    678     self.assertShapeEqual(np_ans, tf_ans)
    679 
    680   def _compareAll(self, x, reduction_axes):
    681     self._compare(x, reduction_axes, False, use_gpu=True)
    682     self._compare(x, reduction_axes, False, use_gpu=False)
    683     self._compare(x, reduction_axes, True, use_gpu=True)
    684     self._compare(x, reduction_axes, True, use_gpu=False)
    685 
    686   def testAxesType(self):
    687     for dtype in [dtypes.int64, dtypes.int32]:
    688       with self.test_session(use_gpu=True) as sess:
    689         v = math_ops.reduce_max([0, 0], constant_op.constant(0, dtype=dtype))
    690         tf_v = sess.run(v)
    691       self.assertAllEqual(tf_v, 0)
    692 
    693   def testInfinity(self):
    694     for dtype in [np.float32, np.float64]:
    695       for special_value_x in [-np.inf, np.inf]:
    696         for special_value_y in [-np.inf, np.inf]:
    697           np_arr = np.array([special_value_x, special_value_y]).astype(dtype)
    698           self._compareAll(np_arr, None)
    699 
    700   def testInt64Reduce3D(self):
    701     # Create a 3D array of int64s and reduce across all possible
    702     # dimensions
    703     np_arr = np.arange(-31, -1).reshape([2, 3, 5]).astype(np.int64)
    704     self._compareAll(np_arr, None)
    705     self._compareAll(np_arr, [])
    706     self._compareAll(np_arr, [0])
    707     self._compareAll(np_arr, [1])
    708     self._compareAll(np_arr, [2])
    709     self._compareAll(np_arr, [0, 1])
    710     self._compareAll(np_arr, [1, 2])
    711     self._compareAll(np_arr, [0, 2])
    712     self._compareAll(np_arr, [0, 1, 2])
    713 
    714   def testFloatReduce3D(self):
    715     # Create a 3D array of floats and reduce across all possible
    716     # dimensions
    717     np_arr = np.arange(-31, -1).reshape([2, 3, 5]).astype(np.float32)
    718     self._compareAll(np_arr, None)
    719     self._compareAll(np_arr, [])
    720     self._compareAll(np_arr, [0])
    721     self._compareAll(np_arr, [1])
    722     self._compareAll(np_arr, [2])
    723     self._compareAll(np_arr, [0, 1])
    724     self._compareAll(np_arr, [1, 2])
    725     self._compareAll(np_arr, [0, 2])
    726     self._compareAll(np_arr, [0, 1, 2])
    727 
    728   def testDoubleReduce3D(self):
    729     # Create a 3D array of doubles and reduce across all possible
    730     # dimensions
    731     np_arr = np.arange(-31, -1).reshape([2, 3, 5]).astype(np.float64)
    732     self._compareAll(np_arr, None)
    733     self._compareAll(np_arr, [])
    734     self._compareAll(np_arr, [0])
    735     self._compareAll(np_arr, [1])
    736     self._compareAll(np_arr, [2])
    737     self._compareAll(np_arr, [0, 1])
    738     self._compareAll(np_arr, [1, 2])
    739     self._compareAll(np_arr, [0, 2])
    740     self._compareAll(np_arr, [0, 1, 2])
    741 
    742   def testGradient(self):
    743     s = [2, 3, 4, 2]
    744     x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
    745     with self.test_session():
    746       t = ops.convert_to_tensor(x)
    747       su = math_ops.reduce_max(t, [1, 2])
    748       jacob_t, jacob_n = gradient_checker.compute_gradient(
    749           t, s, su, [2, 2], x_init_value=x, delta=1)
    750     self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
    751 
    752   def testGradient2(self):
    753     s = [2, 3, 4, 2]
    754     x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
    755     with self.test_session():
    756       t = ops.convert_to_tensor(x)
    757       su = math_ops.reduce_max(t, [1])
    758       jacob_t, jacob_n = gradient_checker.compute_gradient(
    759           t, s, su, [2, 4, 2], x_init_value=x, delta=1)
    760     self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
    761 
    762   def testGradient3(self):
    763     s = [2, 3, 4, 2]
    764     x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
    765     with self.test_session():
    766       t = ops.convert_to_tensor(x)
    767       su = math_ops.reduce_max(t, [2])
    768       jacob_t, jacob_n = gradient_checker.compute_gradient(
    769           t, s, su, [2, 3, 2], x_init_value=x, delta=1)
    770     self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
    771 
    772   def testGradient4(self):
    773     s = [2, 3, 4, 2]
    774     x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
    775     with self.test_session():
    776       t = ops.convert_to_tensor(x)
    777       su = math_ops.reduce_max(t)
    778       jacob_t, jacob_n = gradient_checker.compute_gradient(
    779           t, s, su, [1], x_init_value=x, delta=1)
    780     self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
    781 
    782   def testEmptyGradients(self):
    783     with self.test_session():
    784       x = array_ops.zeros([0, 3])
    785       y = math_ops.reduce_max(x, [1])
    786       error = gradient_checker.compute_gradient_error(x, [0, 3], y, [0])
    787       self.assertEqual(error, 0)
    788 
    789 
    790 class AllReductionTest(test.TestCase):
    791 
    792   def _compare(self, x, reduction_axes, keepdims, use_gpu=False):
    793     np_ans = x
    794     if reduction_axes is None:
    795       np_ans = np.all(np_ans, keepdims=keepdims)
    796     else:
    797       for ra in reduction_axes[::-1]:
    798         np_ans = np.all(np_ans, axis=ra, keepdims=keepdims)
    799     with self.test_session(use_gpu=use_gpu):
    800       if reduction_axes is not None:
    801         reduction_axes = np.array(reduction_axes).astype(np.int32)
    802       tf_ans = math_ops.reduce_all(x, reduction_axes, keepdims)
    803       out = tf_ans.eval()
    804     self.assertAllEqual(np_ans, out)
    805     self.assertShapeEqual(np_ans, tf_ans)
    806 
    807   def _compareAll(self, x, reduction_axes):
    808     self._compare(x, reduction_axes, False, use_gpu=True)
    809     self._compare(x, reduction_axes, False, use_gpu=False)
    810     self._compare(x, reduction_axes, True, use_gpu=True)
    811     self._compare(x, reduction_axes, True, use_gpu=False)
    812 
    813   def testAxesType(self):
    814     for dtype in [dtypes.int64, dtypes.int32]:
    815       with self.test_session(use_gpu=True) as sess:
    816         v = math_ops.reduce_all([True, True],
    817                                 constant_op.constant(0, dtype=dtype))
    818         tf_v = sess.run(v)
    819       self.assertAllEqual(tf_v, True)
    820 
    821   def testAll3D(self):
    822     # Create a 3D array of bools and reduce across all possible
    823     # dimensions
    824     np_arr = (np.random.uniform(0, 1, 30) > 0.1).reshape([2, 3, 5])
    825     self._compareAll(np_arr, None)
    826     self._compareAll(np_arr, [])
    827     self._compareAll(np_arr, [0])
    828     self._compareAll(np_arr, [1])
    829     self._compareAll(np_arr, [2])
    830     self._compareAll(np_arr, [0, 1])
    831     self._compareAll(np_arr, [1, 2])
    832     self._compareAll(np_arr, [0, 2])
    833     self._compareAll(np_arr, [0, 1, 2])
    834 
    835   def testEmpty(self):
    836     self._compareAll([], [0])
    837 
    838 
    839 class AnyReductionTest(test.TestCase):
    840 
    841   def _compare(self, x, reduction_axes, keepdims, use_gpu=False):
    842     np_ans = x
    843     if reduction_axes is None:
    844       np_ans = np.any(np_ans, keepdims=keepdims)
    845     else:
    846       for ra in reduction_axes[::-1]:
    847         np_ans = np.any(np_ans, axis=ra, keepdims=keepdims)
    848     with self.test_session(use_gpu=use_gpu):
    849       if reduction_axes is not None:
    850         reduction_axes = np.array(reduction_axes).astype(np.int32)
    851       tf_ans = math_ops.reduce_any(x, reduction_axes, keepdims)
    852       out = tf_ans.eval()
    853     self.assertAllEqual(np_ans, out)
    854     self.assertShapeEqual(np_ans, tf_ans)
    855 
    856   def _compareAll(self, x, reduction_axes):
    857     self._compare(x, reduction_axes, False, use_gpu=True)
    858     self._compare(x, reduction_axes, False, use_gpu=False)
    859     self._compare(x, reduction_axes, True, use_gpu=True)
    860     self._compare(x, reduction_axes, True, use_gpu=False)
    861 
    862   def testAxesType(self):
    863     for dtype in [dtypes.int64, dtypes.int32]:
    864       with self.test_session(use_gpu=True) as sess:
    865         v = math_ops.reduce_any([True, True],
    866                                 constant_op.constant(0, dtype=dtype))
    867         tf_v = sess.run(v)
    868       self.assertAllEqual(tf_v, True)
    869 
    870   def testAll3D(self):
    871     # Create a 3D array of bools and reduce across all possible
    872     # dimensions
    873     np_arr = (np.random.uniform(0, 1, 30) > 0.9).reshape([2, 3, 5])
    874     self._compareAll(np_arr, None)
    875     self._compareAll(np_arr, [])
    876     self._compareAll(np_arr, [0])
    877     self._compareAll(np_arr, [1])
    878     self._compareAll(np_arr, [2])
    879     self._compareAll(np_arr, [0, 1])
    880     self._compareAll(np_arr, [1, 2])
    881     self._compareAll(np_arr, [0, 2])
    882     self._compareAll(np_arr, [0, 1, 2])
    883 
    884   def testEmpty(self):
    885     self._compareAll([], [0])
    886 
    887 
    888 class CountNonzeroReductionTest(test.TestCase):
    889 
    890   def _compare(self,
    891                x,
    892                reduction_axes,
    893                keepdims,
    894                use_gpu=False,
    895                feed_dict=None):
    896     np_ans = (x != 0).astype(np.int32)
    897     if reduction_axes is None:
    898       np_ans = np.sum(np_ans, keepdims=keepdims)
    899     else:
    900       reduction_axes = np.array(reduction_axes).astype(np.int32)
    901       for ra in reduction_axes.ravel()[::-1]:
    902         np_ans = np.sum(np_ans, axis=ra, keepdims=keepdims)
    903     with self.test_session(use_gpu=use_gpu) as sess:
    904       tf_ans = math_ops.count_nonzero(x, reduction_axes, keepdims)
    905       out = sess.run(tf_ans, feed_dict)
    906     self.assertAllClose(np_ans, out)
    907     self.assertShapeEqual(np_ans, tf_ans)
    908 
    909   def _compareAll(self, x, reduction_axes, feed_dict=None):
    910     if reduction_axes is not None and np.shape(reduction_axes) == (1,):
    911       # Test scalar reduction_axes argument
    912       self._compareAll(x, reduction_axes[0])
    913     self._compare(x, reduction_axes, False, use_gpu=True, feed_dict=feed_dict)
    914     self._compare(x, reduction_axes, False, use_gpu=False, feed_dict=feed_dict)
    915     self._compare(x, reduction_axes, True, use_gpu=True, feed_dict=feed_dict)
    916     self._compare(x, reduction_axes, True, use_gpu=False, feed_dict=feed_dict)
    917 
    918   def testBoolReduce1D(self):
    919     # Create a 1D array of floats
    920     np_arr = np.asarray([False, False, True, False, False, True])
    921     self._compareAll(np_arr, None)
    922     self._compareAll(np_arr, [])
    923     self._compareAll(np_arr, [0])
    924 
    925   def testFloatReduce1D(self):
    926     # Create a 1D array of floats
    927     np_arr = np.asarray([0.0, 1.0, -1.0, 0.0, 0.0, 3.0]).astype(np.float32)
    928     self._compareAll(np_arr, [0])
    929 
    930   def testFloatReduce4D(self):
    931     # Create a 4D array of floats and reduce across some
    932     # dimensions
    933     np_arr = np.floor(np.arange(0.0, 210.0) / 100.0).reshape([2, 3, 5,
    934                                                               7]).astype(
    935                                                                   np.float32)
    936     self._compareAll(np_arr, None)
    937     self._compareAll(np_arr, [])
    938     self._compareAll(np_arr, [0])
    939     self._compareAll(np_arr, [1])
    940     self._compareAll(np_arr, [2])
    941     self._compareAll(np_arr, [0, 1])
    942     self._compareAll(np_arr, [1, 2])
    943     # Need specialization for reduce(4D, [0, 2])
    944     # self._compareAll(np_arr, [0, 2])
    945     self._compareAll(np_arr, [0, 1, 2])
    946     self._compareAll(np_arr, [1, 2, 3])
    947     self._compareAll(np_arr, [0, 1, 2, 3])
    948 
    949   def testExpand(self):
    950     # Reduce an empty tensor to a nonempty tensor
    951     x = np.zeros((5, 0))
    952     self._compareAll(x, [1])
    953 
    954   def testDegenerate(self):
    955     for use_gpu in False, True:
    956       with self.test_session(use_gpu=use_gpu):
    957         for dtype in (dtypes.bool,):
    958           # A large number is needed to get Eigen to die
    959           x = array_ops.zeros((0, 9938), dtype=dtype)
    960           y = math_ops.count_nonzero(x, [0])
    961           self.assertAllEqual(y.eval(), np.zeros(9938))
    962 
    963 
if __name__ == "__main__":
  # Entry point: delegate to TensorFlow's test runner
  # (tensorflow.python.platform.test).
  test.main()
    966