Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Functional tests for segment reduction ops."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import itertools
     22 
     23 import numpy as np
     24 
     25 from tensorflow.python.client import session
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import dtypes as dtypes_lib
     28 from tensorflow.python.framework import ops
     29 from tensorflow.python.ops import gradient_checker
     30 from tensorflow.python.ops import math_ops
     31 from tensorflow.python.ops import variables
     32 from tensorflow.python.platform import test
     33 
     34 
     35 class SegmentReductionHelper(test.TestCase):
     36 
     37   def _input(self, input_shape, dtype=dtypes_lib.int32):
     38     num_elem = 1
     39     for x in input_shape:
     40       num_elem *= x
     41     values = np.arange(1, num_elem + 1)
     42     np_values = values.reshape(input_shape).astype(dtype.as_numpy_dtype)
     43     # Add a non-zero imaginary component to complex types.
     44     if dtype.is_complex:
     45       np_values -= 1j * np_values
     46     return constant_op.constant(
     47         np_values, shape=input_shape, dtype=dtype), np_values
     48 
     49   def _segmentReduce(self, indices, x, op1, op2=None, num_segments=None,
     50                      initial_value=0):
     51     if not x.size:
     52       return np.array([])
     53     indices = np.asarray(indices)
     54     if num_segments is None:
     55       num_segments = indices[-1] + 1
     56     output = [None] * num_segments
     57     slice_shape = x.shape[indices.ndim:]
     58     x_flat = x.reshape((indices.size,) + slice_shape)
     59     for i, index in enumerate(indices.ravel()):
     60       if (output[index] is not None) and op1 == np.max:
     61         for j in range(0, output[index].shape[0]):
     62           output[index][j] = op1([output[index][j], x_flat[i][j]])
     63       elif output[index] is not None:
     64         output[index] = op1(output[index], x_flat[i])
     65       else:
     66         output[index] = x_flat[i]
     67     # zero initialize values that are still uncalcuated.
     68     initial_value_slice = np.ones(slice_shape) * initial_value
     69     output = [o if o is not None else initial_value_slice for o in output]
     70     if op2 is not None:
     71       output = [op2(o) for o in output]
     72     output = [o.reshape(slice_shape) for o in output]
     73     return np.array(output)
     74 
     75   def _mean_cum_op(self, x, y):
     76     return (x[0] + y, x[1] + 1) if isinstance(x, tuple) else (x + y, 2)
     77 
     78   def _mean_reduce_op(self, x):
     79     return x[0] / x[1] if isinstance(x, tuple) else x
     80 
     81   def _sqrt_n_reduce_op(self, x):
     82     return x[0] / np.sqrt(x[1]) if isinstance(x, tuple) else x
     83 
     84 
     85 class SegmentReductionOpTest(SegmentReductionHelper):
     86 
     87   def testValues(self):
     88     dtypes = [
     89         dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64,
     90         dtypes_lib.int32, dtypes_lib.complex64, dtypes_lib.complex128
     91     ]
     92 
     93     # Each item is np_op1, np_op2, tf_op
     94     ops_list = [(np.add, None, math_ops.segment_sum), (self._mean_cum_op,
     95                                                        self._mean_reduce_op,
     96                                                        math_ops.segment_mean),
     97                 (np.ndarray.__mul__, None, math_ops.segment_prod),
     98                 (np.minimum, None, math_ops.segment_min),
     99                 (np.maximum, None, math_ops.segment_max)]
    100 
    101     # A subset of ops has been enabled for complex numbers
    102     complex_ops_list = [(np.add, None, math_ops.segment_sum),
    103                         (np.ndarray.__mul__, None, math_ops.segment_prod)]
    104 
    105     n = 10
    106     shape = [n, 2]
    107     indices = [i // 3 for i in range(n)]
    108     for dtype in dtypes:
    109       if dtype in (dtypes_lib.complex64, dtypes_lib.complex128):
    110         curr_ops_list = complex_ops_list
    111       else:
    112         curr_ops_list = ops_list
    113       for use_gpu in [True, False]:
    114         with self.test_session(use_gpu=use_gpu):
    115           tf_x, np_x = self._input(shape, dtype=dtype)
    116           for np_op1, np_op2, tf_op in curr_ops_list:
    117             np_ans = self._segmentReduce(indices, np_x, np_op1, np_op2)
    118             s = tf_op(data=tf_x, segment_ids=indices)
    119             tf_ans = s.eval()
    120             self.assertAllClose(np_ans, tf_ans)
    121             # NOTE(mrry): The static shape inference that computes
    122             # `tf_ans.shape` can only infer that sizes from dimension 1
    123             # onwards, because the size of dimension 0 is data-dependent
    124             # and may therefore vary dynamically.
    125             self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
    126 
    127   def testSegmentIdsShape(self):
    128     shape = [4, 4]
    129     tf_x, _ = self._input(shape)
    130     indices = constant_op.constant([0, 1, 2, 2], shape=[2, 2])
    131     with self.assertRaises(ValueError):
    132       math_ops.segment_sum(data=tf_x, segment_ids=indices)
    133 
    134   def testSegmentIdsSize(self):
    135     shape = [4, 4]
    136     for use_gpu in [True, False]:
    137       with self.test_session(use_gpu=use_gpu):
    138         tf_x, _ = self._input(shape)
    139         indices = [0, 1]
    140         s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
    141         with self.assertRaisesOpError("segment_ids should be the same size"):
    142           s.eval()
    143 
    144   def testSegmentIdsValid(self):
    145     # This is a baseline for the following SegmentIdsInvalid* tests.
    146     shape = [4, 4]
    147     for use_gpu in [True, False]:
    148       with self.test_session(use_gpu=use_gpu):
    149         tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
    150         indices = [0, 0, 0, 1]
    151         result = math_ops.segment_sum(data=tf_x, segment_ids=indices).eval()
    152         self.assertAllEqual([[15, 18, 21, 24], [13, 14, 15, 16]], result)
    153 
    154   def testSegmentIdsGreaterThanZero(self):
    155     shape = [4, 4]
    156     for use_gpu in [True, False]:
    157       with self.test_session(use_gpu=use_gpu):
    158         tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32)
    159         indices = [1, 1, 2, 2]
    160         np_ans = self._segmentReduce(indices, np_x, np.add)
    161         s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
    162         tf_ans = s.eval()
    163         self.assertAllClose(np_ans, tf_ans)
    164 
    165   def testSegmentIdsHole(self):
    166     shape = [4, 4]
    167     for use_gpu in [True, False]:
    168       with self.test_session(use_gpu=use_gpu):
    169         tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32)
    170         indices = [0, 0, 3, 3]
    171         np_ans = self._segmentReduce(indices, np_x, np.add)
    172         s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
    173         tf_ans = s.eval()
    174         self.assertAllClose(np_ans, tf_ans)
    175 
    176   def testSegmentIdsInvalid1(self):
    177     shape = [4, 4]
    178     with self.test_session():
    179       tf_x, _ = self._input(shape)
    180       indices = [-1, -1, 0, 0]
    181       s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
    182       with self.assertRaisesOpError(
    183           r"Segment id -1 out of range \[0, 1\), possibly because "
    184           "'segment_ids' input is not sorted."):
    185         s.eval()
    186 
    187   def testSegmentIdsInvalid2(self):
    188     shape = [4, 4]
    189     with self.test_session():
    190       tf_x, _ = self._input(shape)
    191       indices = [0, 1, 0, 1]
    192       s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
    193       with self.assertRaisesOpError("segment ids are not increasing"):
    194         s.eval()
    195 
    196   def testSegmentIdsInvalid3(self):
    197     shape = [4, 4]
    198     with self.test_session():
    199       tf_x, _ = self._input(shape)
    200       indices = [0, 1, 2, 0]
    201       s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
    202       with self.assertRaisesOpError(
    203           r"Segment id 1 out of range \[0, 1\), possibly "
    204           "because 'segment_ids' input is not sorted."):
    205         s.eval()
    206 
    207   def testSegmentIdsInvalid4(self):
    208     shape = [4, 4]
    209     for use_gpu in [True, False]:
    210       with self.test_session(use_gpu=use_gpu):
    211         tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
    212         indices = [0, 0, 0, -1]
    213         s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
    214         with self.assertRaisesOpError("segment ids must be >= 0"):
    215           s.eval()
    216 
    217   def testSegmentIdsInvalid5(self):
    218     shape = [4, 4]
    219     for use_gpu in [True, False]:
    220       with self.test_session(use_gpu=use_gpu):
    221         tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
    222         indices = [0, 0, 0, -2]
    223         s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
    224         with self.assertRaisesOpError("segment ids must be >= 0"):
    225           s.eval()
    226 
    227   def testGradient(self):
    228     shape = [4, 4]
    229     indices = [0, 1, 2, 2]
    230     for tf_op in [
    231         math_ops.segment_sum, math_ops.segment_mean, math_ops.segment_min,
    232         math_ops.segment_max
    233     ]:
    234       with self.test_session():
    235         tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
    236         s = tf_op(data=tf_x, segment_ids=indices)
    237         jacob_t, jacob_n = gradient_checker.compute_gradient(
    238             tf_x,
    239             shape,
    240             s, [3, 4],
    241             x_init_value=np_x.astype(np.double),
    242             delta=1)
    243       self.assertAllClose(jacob_t, jacob_n)
    244 
    245 
    246 class UnsortedSegmentTest(SegmentReductionHelper):
    247 
    248   def __init__(self, methodName='runTest'):
    249     # Each item is np_op1, np_op2, tf_op, initial_value functor
    250     self.ops_list = [(np.add, None,
    251                       math_ops.unsorted_segment_sum, lambda t: 0),
    252                      (self._mean_cum_op, self._mean_reduce_op,
    253                       math_ops.unsorted_segment_mean, lambda t: 0),
    254                      (self._mean_cum_op, self._sqrt_n_reduce_op,
    255                       math_ops.unsorted_segment_sqrt_n, lambda t: 0),
    256                      (np.ndarray.__mul__, None,
    257                       math_ops.unsorted_segment_prod, lambda t: 1),
    258                      (np.minimum, None,
    259                       math_ops.unsorted_segment_min, lambda t: t.max),
    260                      (np.maximum, None,
    261                       math_ops.unsorted_segment_max, lambda t: t.min)]
    262 
    263     # A subset of ops has been enabled for complex numbers
    264     self.complex_ops_list = [(np.add, None,
    265                               math_ops.unsorted_segment_sum, lambda t: 0)]
    266     self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32,
    267                                   dtypes_lib.float64]
    268     self.all_dtypes = (self.differentiable_dtypes +
    269                        [dtypes_lib.bfloat16,
    270                         dtypes_lib.int64, dtypes_lib.int32,
    271                         dtypes_lib.complex64, dtypes_lib.complex128])
    272     super(UnsortedSegmentTest, self).__init__(methodName=methodName)
    273 
    274   def testValues(self):
    275     indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
    276     num_segments = 12
    277     for indices in indices_flat, indices_flat.reshape(5, 2):
    278       shape = indices.shape + (2,)
    279       for dtype in self.all_dtypes:
    280         ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
    281         tf_x, np_x = self._input(shape, dtype=dtype)
    282         for use_gpu in [True, False]:
    283           with self.test_session(use_gpu=True):
    284             for np_op1, np_op2, tf_op, init_op in ops_list:
    285               # sqrt_n doesn't support integers
    286               if (np_op2 == self._sqrt_n_reduce_op and dtype.is_integer):
    287                 continue
    288               # todo(philjd): enable this test once real_div supports bfloat16
    289               if (np_op2 in [self._sqrt_n_reduce_op, self._mean_reduce_op] and
    290                   dtype == dtypes_lib.bfloat16):
    291                 continue
    292               np_ans = self._segmentReduce(
    293                   indices, np_x, np_op1, np_op2, num_segments=num_segments,
    294                   initial_value=init_op(dtype))
    295               s = tf_op(tf_x, segment_ids=indices, num_segments=num_segments)
    296               tf_ans = s.eval()
    297               if dtype is dtypes_lib.bfloat16:
    298                 tf_ans = tf_ans.astype(np.float32)
    299               self.assertAllClose(np_ans, tf_ans)
    300               self.assertShapeEqual(np_ans, s)
    301 
    302   def testNumSegmentsTypes(self):
    303     dtypes = [dtypes_lib.int32, dtypes_lib.int64]
    304     indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
    305     num_segments = 12
    306     for indices in indices_flat, indices_flat.reshape(5, 2):
    307       shape = indices.shape + (2,)
    308       for dtype in dtypes:
    309         with self.test_session(use_gpu=True):
    310           tf_x, np_x = self._input(shape)
    311           num_segments_constant = constant_op.constant(
    312               num_segments, dtype=dtype)
    313           np_ans = self._segmentReduce(
    314               indices, np_x, np.add, op2=None, num_segments=num_segments)
    315           s = math_ops.unsorted_segment_sum(
    316               data=tf_x,
    317               segment_ids=indices,
    318               num_segments=num_segments_constant)
    319           tf_ans = s.eval()
    320         self.assertAllClose(np_ans, tf_ans)
    321         self.assertShapeEqual(np_ans, s)
    322 
    323   def testGradients(self):
    324     num_cols = 2
    325     indices_flat = np.array([0, 4, 0, -1, 3, -1, 4, 7, 7, 3])
    326     num_segments = max(indices_flat) + 3
    327     for dtype in self.differentiable_dtypes:
    328       ops_list = self.complex_ops_list if dtype.is_complex else self.ops_list
    329       for indices in indices_flat, indices_flat.reshape(5, 2):
    330         shape = indices.shape + (num_cols,)
    331         # test CPU and GPU as tf.gather behaves differently on each device
    332         for use_gpu in [False, True]:
    333           with self.test_session(use_gpu=use_gpu):
    334             for _, _, tf_op, _ in ops_list:
    335               tf_x, np_x = self._input(shape, dtype=dtype)
    336               s = tf_op(tf_x, indices, num_segments)
    337               jacob_t, jacob_n = gradient_checker.compute_gradient(
    338                   tf_x,
    339                   shape,
    340                   s, [num_segments, num_cols],
    341                   x_init_value=np_x,
    342                   delta=1)
    343             self.assertAllClose(jacob_t, jacob_n)
    344 
    345   def testProdGrad(self):
    346     # additional test for the prod gradient to ensure correct handling of zeros
    347     values = np.array([0, 0, 1, 0, 2, 2, 3, 3, 3], dtype=np.float32)
    348     indices = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=np.int32)
    349     indices_neg = np.array([-1, 0, 0, -1, 1, 1, -1, 2, 2], dtype=np.int32)
    350     values_tf = constant_op.constant(values)
    351     # ground truth partial derivatives
    352     gradients_indices = np.zeros((9, 3), dtype=np.float32)
    353     gradients_indices_neg = np.zeros((9, 3), dtype=np.float32)
    354     # the derivative w.r.t. to the other segments is zero, so here we only
    355     # explicitly set the grad values for the corresponding segment
    356     gradients_indices[range(9), indices] = [0, 0, 0, 4, 0, 0, 9, 9, 9]
    357     gradients_indices_neg[range(9), indices_neg] = [0, 1, 0, 0, 2, 2, 0, 3, 3]
    358     for use_gpu in [False, True]:
    359       with self.test_session(use_gpu=use_gpu):
    360         for ind, grad_gt in [(indices, gradients_indices),
    361                              (indices_neg, gradients_indices_neg)]:
    362           s = math_ops.unsorted_segment_prod(values_tf,
    363                                              constant_op.constant(ind), 3)
    364           jacob_t, jacob_n = gradient_checker.compute_gradient(
    365               values_tf, (9,), s, (3,), x_init_value=values, delta=1)
    366           self.assertAllClose(jacob_t, jacob_n)
    367           self.assertAllClose(jacob_t, grad_gt)
    368 
    369   def testGradientMatchesSegmentSum(self):
    370     # Strategy: compute the gradient for UnsortedSegmentSum and SegmentSum
    371     # and compare the outputs, which should be identical.
    372     # NB: for this test to work, indices must be valid for SegmentSum, namely
    373     # it must be sorted, the indices must be contiguous, and num_segments
    374     # must be max(indices) + 1.
    375     indices = [0, 0, 1, 1, 1, 2, 3, 4, 5]
    376     n = len(indices)
    377     num_cols = 2
    378     shape = [n, num_cols]
    379     num_segments = max(indices) + 1
    380     for dtype in self.differentiable_dtypes:
    381       with self.test_session(use_gpu=True):
    382         tf_x, np_x = self._input(shape, dtype=dtype)
    383         # Results from UnsortedSegmentSum
    384         unsorted_s = math_ops.unsorted_segment_sum(
    385             data=tf_x, segment_ids=indices, num_segments=num_segments)
    386         unsorted_jacob_t, unsorted_jacob_n = (
    387             gradient_checker.compute_gradient(tf_x, shape, unsorted_s,
    388                                               [num_segments, num_cols],
    389                                               x_init_value=np_x, delta=1))
    390 
    391         # Results from SegmentSum
    392         sorted_s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
    393         sorted_jacob_t, sorted_jacob_n = gradient_checker.compute_gradient(
    394             tf_x,
    395             shape,
    396             sorted_s, [num_segments, num_cols],
    397             x_init_value=np_x,
    398             delta=1)
    399       self.assertAllClose(unsorted_jacob_t, sorted_jacob_t)
    400       self.assertAllClose(unsorted_jacob_n, sorted_jacob_n)
    401 
    402   def testBadIndices(self):
    403     # Note: GPU kernel does not return the out-of-range error needed for this
    404     # test, so this test is marked as cpu-only.
    405     # Note: With PR #13055 a negative index will be ignored silently.
    406     with self.test_session(use_gpu=False):
    407       for bad in [[2]], [[7]]:
    408         unsorted = math_ops.unsorted_segment_sum([[17]], bad, num_segments=2)
    409         with self.assertRaisesOpError(
    410             r"segment_ids\[0,0\] = %d is out of range \[0, 2\)" % bad[0][0]):
    411           unsorted.eval()
    412 
    413   def testEmptySecondDimension(self):
    414     dtypes = [np.float16, np.float32, np.float64, np.int64, np.int32,
    415               np.complex64, np.complex128]
    416     with self.test_session(use_gpu=True):
    417       for dtype in dtypes:
    418         for itype in (np.int32, np.int64):
    419           data = np.zeros((2, 0), dtype=dtype)
    420           segment_ids = np.array([0, 1], dtype=itype)
    421           unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2)
    422           self.assertAllEqual(unsorted.eval(), np.zeros((2, 0), dtype=dtype))
    423 
    424   def testDropNegatives(self):
    425     # Note: the test is done by replacing segment_ids with 8 to -1
    426     # for index  and replace values generated by numpy with 0.
    427     indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
    428     num_segments = 12
    429     for indices in indices_flat, indices_flat.reshape(5, 2):
    430       shape = indices.shape + (2,)
    431       for dtype in self.all_dtypes:
    432         with self.test_session(use_gpu=True):
    433           tf_x, np_x = self._input(shape, dtype=dtype)
    434           np_ans = self._segmentReduce(
    435               indices, np_x, np.add, op2=None, num_segments=num_segments)
    436           # Replace np_ans[8] with 0 for the value
    437           np_ans[8:] = 0
    438           # Replace 8 with -1 in indices
    439           np.place(indices, indices == 8, [-1])
    440           s = math_ops.unsorted_segment_sum(
    441               data=tf_x, segment_ids=indices, num_segments=num_segments)
    442           tf_ans = s.eval()
    443         self.assertAllClose(np_ans, tf_ans)
    444         self.assertShapeEqual(np_ans, s)
    445 
    446 
    447 class SparseSegmentReductionHelper(SegmentReductionHelper):
    448 
    449   def _sparse_input(self, input_shape, num_indices, dtype=dtypes_lib.int32):
    450     a, b = super(SparseSegmentReductionHelper, self)._input(input_shape, dtype)
    451     indices = np.random.randint(0, input_shape[0], num_indices).astype(np.int32)
    452     return (constant_op.constant(
    453         indices, dtype=dtypes_lib.int32), indices, a, b)
    454 
    455   def _sparseSegmentReduce(self,
    456                            x,
    457                            indices,
    458                            segment_indices,
    459                            op1,
    460                            op2=None,
    461                            num_segments=None):
    462     return self._segmentReduce(
    463         segment_indices, x[indices], op1, op2, num_segments=num_segments)
    464 
    465 
    466 class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
    467 
    468   def testValues(self):
    469     dtypes = [
    470         dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int64,
    471         dtypes_lib.int32
    472     ]
    473 
    474     mean_dtypes = [dtypes_lib.float32, dtypes_lib.float64]
    475 
    476     # Each item is np_op1, np_op2, tf_op
    477     ops_list = [(np.add, None, math_ops.sparse_segment_sum),
    478                 (self._mean_cum_op, self._mean_reduce_op,
    479                  math_ops.sparse_segment_mean)]
    480 
    481     n = 400
    482     shape = [n, 2]
    483     segment_indices = []
    484     for i in range(20):
    485       for _ in range(i + 1):
    486         segment_indices.append(i)
    487     num_indices = len(segment_indices)
    488     for dtype in dtypes:
    489       with self.test_session(use_gpu=False):
    490         tf_indices, np_indices, tf_x, np_x = self._sparse_input(
    491             shape, num_indices, dtype=dtype)
    492         for np_op1, np_op2, tf_op in ops_list:
    493           if tf_op == math_ops.sparse_segment_mean and dtype not in mean_dtypes:
    494             continue
    495           np_ans = self._sparseSegmentReduce(np_x, np_indices, segment_indices,
    496                                              np_op1, np_op2)
    497           s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
    498           tf_ans = s.eval()
    499           self.assertAllClose(np_ans, tf_ans)
    500           # NOTE(mrry): The static shape inference that computes
    501           # `tf_ans.shape` can only infer that sizes from dimension 1
    502           # onwards, because the size of dimension 0 is data-dependent
    503           # and may therefore vary dynamically.
    504           self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
    505 
    506   def testSegmentIdsHole(self):
    507     tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32)
    508     ops_list = [(np.add, None, math_ops.sparse_segment_sum), (
    509         self._mean_cum_op, self._mean_reduce_op, math_ops.sparse_segment_mean)]
    510     segment_indices = [0, 2, 2, 2]
    511     tf_indices = [8, 3, 0, 9]
    512     with self.test_session(use_gpu=False):
    513       for np_op1, np_op2, tf_op in ops_list:
    514         np_ans = self._sparseSegmentReduce(np_x, tf_indices, segment_indices,
    515                                            np_op1, np_op2)
    516         s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
    517         tf_ans = s.eval()
    518         self.assertAllClose(np_ans, tf_ans)
    519 
    520   def testWithNumSegments(self):
    521     tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32)
    522     ops_list = [(np.add, None, math_ops.sparse_segment_sum_with_num_segments),
    523                 (self._mean_cum_op, self._mean_reduce_op,
    524                  math_ops.sparse_segment_mean_with_num_segments)]
    525     segment_indices = [0, 2, 2, 2]
    526     tf_indices = [8, 3, 0, 9]
    527     num_segments = 5
    528     with self.test_session(use_gpu=False):
    529       for np_op1, np_op2, tf_op in ops_list:
    530         np_ans = self._sparseSegmentReduce(
    531             np_x,
    532             tf_indices,
    533             segment_indices,
    534             np_op1,
    535             np_op2,
    536             num_segments=num_segments)
    537         s = tf_op(
    538             data=tf_x,
    539             indices=tf_indices,
    540             segment_ids=segment_indices,
    541             num_segments=num_segments)
    542         tf_ans = s.eval()
    543         self.assertAllClose(np_ans, tf_ans)
    544 
    545   def testSegmentIdsGreaterThanZero(self):
    546     tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32)
    547     ops_list = [(np.add, None, math_ops.sparse_segment_sum), (
    548         self._mean_cum_op, self._mean_reduce_op, math_ops.sparse_segment_mean)]
    549     segment_indices = [1, 2, 2, 2]
    550     tf_indices = [8, 3, 0, 9]
    551     with self.test_session(use_gpu=False):
    552       for np_op1, np_op2, tf_op in ops_list:
    553         np_ans = self._sparseSegmentReduce(np_x, tf_indices, segment_indices,
    554                                            np_op1, np_op2)
    555         s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
    556         tf_ans = s.eval()
    557         self.assertAllClose(np_ans, tf_ans)
    558 
    559   def testValid(self):
    560     # Baseline for the test*Invalid* methods below.
    561     tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
    562     ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
    563     segment_indices = [0, 1, 2, 2]
    564     tf_indices = [8, 3, 0, 9]
    565     with self.test_session(use_gpu=False):
    566       for tf_op in ops_list:
    567         s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
    568         s.eval()
    569 
    570   def testIndicesInvalid1(self):
    571     tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
    572     ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
    573     segment_indices = [0, 1, 2, 2]
    574     tf_indices = [8, -1, 0, 9]
    575     with self.test_session(use_gpu=False):
    576       for tf_op in ops_list:
    577         s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
    578         with self.assertRaisesOpError(
    579             r"indices\[1\] == -1 out of range \[0, 10\)"):
    580           s.eval()
    581 
    582   def testIndicesInvalid2(self):
    583     tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
    584     ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
    585     segment_indices = [0, 1, 2, 2]
    586     tf_indices = [8, 3, 0, 10]
    587     with self.test_session(use_gpu=False):
    588       for tf_op in ops_list:
    589         s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
    590         with self.assertRaisesOpError(
    591             r"indices\[3\] == 10 out of range \[0, 10\)"):
    592           s.eval()
    593 
    594   def testSegmentsInvalid2(self):
    595     tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
    596     ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
    597     segment_indices = [0, 1, 0, 1]
    598     tf_indices = [8, 3, 0, 9]
    599     with self.test_session(use_gpu=False):
    600       for tf_op in ops_list:
    601         s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
    602         with self.assertRaisesOpError("segment ids are not increasing"):
    603           s.eval()
    604 
    605   def testSegmentsInvalid3(self):
    606     tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
    607     ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
    608     segment_indices = [0, 1, 2, 0]
    609     tf_indices = [8, 3, 0, 9]
    610     with self.test_session(use_gpu=False):
    611       for tf_op in ops_list:
    612         s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
    613         with self.assertRaisesOpError(
    614             r"Segment id 1 out of range \[0, 1\), possibly because "
    615             "'segment_ids' input is not sorted"):
    616           s.eval()
    617 
    618   def testSegmentsInvalid4(self):
    619     tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
    620     ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
    621     segment_indices = [-1, 0, 1, 1]
    622     tf_indices = [8, 3, 0, 9]
    623     with self.test_session(use_gpu=False):
    624       for tf_op in ops_list:
    625         s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
    626         with self.assertRaisesOpError(
    627             r"Segment id -1 out of range \[0, 2\), possibly because "
    628             "'segment_ids' input is not sorted"):
    629           s.eval()
    630 
    631   def testSegmentsInvalid6(self):
    632     tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
    633     ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
    634     segment_indices = [0, 0, 0, -1]
    635     tf_indices = [8, 3, 0, 9]
    636     with self.test_session(use_gpu=False):
    637       for tf_op in ops_list:
    638         s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
    639         with self.assertRaisesOpError("segment ids must be >= 0"):
    640           s.eval()
    641 
    642   def testSegmentsInvalid7(self):
    643     tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
    644     ops_list = [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]
    645     segment_indices = [0, 0, 0, -2]
    646     tf_indices = [8, 3, 0, 9]
    647     with self.test_session(use_gpu=False):
    648       for tf_op in ops_list:
    649         s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
    650         with self.assertRaisesOpError("segment ids must be >= 0"):
    651           s.eval()
    652 
    653   def testSegmentWithNumSegmentsValid(self):
    654     # Baseline for the test*WithNumSegmentsInvalid* methods below.
    655     tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
    656     ops_list = [
    657         math_ops.sparse_segment_sum_with_num_segments,
    658         math_ops.sparse_segment_mean_with_num_segments,
    659     ]
    660     num_segments = 5
    661     segment_indices = [0, 1, 3, 3]
    662     tf_indices = [8, 3, 0, 9]
    663     with self.test_session(use_gpu=False):
    664       for tf_op in ops_list:
    665         s = tf_op(
    666             data=tf_x,
    667             indices=tf_indices,
    668             segment_ids=segment_indices,
    669             num_segments=num_segments)
    670         s.eval()
    671 
    672   def testSegmentWithNumSegmentsInvalid1(self):
    673     tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
    674     ops_list = [
    675         math_ops.sparse_segment_sum_with_num_segments,
    676         math_ops.sparse_segment_mean_with_num_segments,
    677     ]
    678     num_segments = 5
    679     segment_indices = [0, 1, 3, 5]
    680     tf_indices = [8, 3, 0, 9]
    681     with self.test_session(use_gpu=False):
    682       for tf_op in ops_list:
    683         s = tf_op(
    684             data=tf_x,
    685             indices=tf_indices,
    686             segment_ids=segment_indices,
    687             num_segments=num_segments)
    688         with self.assertRaisesOpError("segment ids must be < num_segments"):
    689           s.eval()
    690 
    691   def testSegmentWithNumSegmentsInvalid2(self):
    692     tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
    693     ops_list = [
    694         math_ops.sparse_segment_sum_with_num_segments,
    695         math_ops.sparse_segment_mean_with_num_segments,
    696     ]
    697     num_segments = -2
    698     segment_indices = [0, 1, 3, 3]
    699     tf_indices = [8, 3, 0, 9]
    700     with self.test_session(use_gpu=False):
    701       for tf_op in ops_list:
    702         with self.assertRaisesRegexp(
    703             ValueError, "Cannot specify a negative value for num_segments"):
    704           tf_op(
    705               data=tf_x,
    706               indices=tf_indices,
    707               segment_ids=segment_indices,
    708               num_segments=num_segments)
    709 
    710   def testGradient(self):
    711     shape = [10, 4]
    712 
    713     segment_indices = [0, 1, 2, 2]
    714     num_indices = len(segment_indices)
    715     for tf_op in [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]:
    716       with self.test_session():
    717         tf_indices, _, tf_x, np_x = self._sparse_input(
    718             shape, num_indices, dtype=dtypes_lib.float64)
    719         s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
    720         jacob_t, jacob_n = gradient_checker.compute_gradient(
    721             tf_x,
    722             shape,
    723             s, [3, 4],
    724             x_init_value=np_x.astype(np.double),
    725             delta=1)
    726       self.assertAllClose(jacob_t, jacob_n)
    727 
    728   def testGradientWithEmptySegmentsAtEnd(self):
    729     shape = [10, 4]
    730 
    731     num_segments = 5
    732     segment_indices = [0, 1, 2, 2]
    733     num_indices = len(segment_indices)
    734     for tf_op in [
    735         math_ops.sparse_segment_sum_with_num_segments,
    736         math_ops.sparse_segment_mean_with_num_segments,
    737     ]:
    738       with self.test_session():
    739         tf_indices, _, tf_x, np_x = self._sparse_input(
    740             shape, num_indices, dtype=dtypes_lib.float64)
    741         s = tf_op(
    742             data=tf_x,
    743             indices=tf_indices,
    744             segment_ids=segment_indices,
    745             num_segments=num_segments)
    746         jacob_t, jacob_n = gradient_checker.compute_gradient(
    747             tf_x,
    748             shape,
    749             s, [5, 4],
    750             x_init_value=np_x.astype(np.double),
    751             delta=1)
    752       self.assertAllClose(jacob_t, jacob_n)
    753 
    754   def testGradientValid(self):
    755     # Baseline for the testGradient*Invalid* methods below.
    756     tf_x, _ = self._input([3, 4], dtype=dtypes_lib.float32)
    757     ops_list = [
    758         math_ops.sparse_segment_mean_grad, math_ops.sparse_segment_sqrt_n_grad
    759     ]
    760     segment_indices = [0, 1, 2, 2]
    761     tf_indices = [8, 3, 0, 9]
    762     with self.test_session(use_gpu=False):
    763       for tf_op in ops_list:
    764         s = tf_op(tf_x, tf_indices, segment_indices, 10)
    765         s.eval()
    766 
    767   def testGradientIndicesInvalid1(self):
    768     tf_x, _ = self._input([3, 4], dtype=dtypes_lib.float32)
    769     ops_list = [
    770         math_ops.sparse_segment_mean_grad, math_ops.sparse_segment_sqrt_n_grad
    771     ]
    772     segment_indices = [0, 1, 2, 2]
    773     tf_indices = [8, 3, 0, 10]
    774     with self.test_session(use_gpu=False):
    775       for tf_op in ops_list:
    776         s = tf_op(tf_x, tf_indices, segment_indices, 10)
    777         with self.assertRaisesOpError(r"Index 10 out of range \[0, 10\)"):
    778           s.eval()
    779 
    780   def testGradientIndicesInvalid2(self):
    781     tf_x, _ = self._input([3, 4], dtype=dtypes_lib.float32)
    782     ops_list = [
    783         math_ops.sparse_segment_mean_grad, math_ops.sparse_segment_sqrt_n_grad
    784     ]
    785     segment_indices = [0, 1, 2, 2]
    786     tf_indices = [8, 3, -1, 9]
    787     with self.test_session(use_gpu=False):
    788       for tf_op in ops_list:
    789         s = tf_op(tf_x, tf_indices, segment_indices, 10)
    790         with self.assertRaisesOpError(r"Index -1 out of range \[0, 10\)"):
    791           s.eval()
    792 
    793   def testGradientSegmentsInvalid1(self):
    794     tf_x, _ = self._input(
    795         [3, 4], dtype=dtypes_lib.float32)  # expecting 3 segments
    796     ops_list = [
    797         math_ops.sparse_segment_mean_grad, math_ops.sparse_segment_sqrt_n_grad
    798     ]
    799     segment_indices = [0, 1, 1, 4]  # 5 segments
    800     tf_indices = [8, 3, 0, 9]
    801     with self.test_session(use_gpu=False):
    802       for tf_op in ops_list:
    803         s = tf_op(tf_x, tf_indices, segment_indices, 10)
    804         with self.assertRaisesOpError("Invalid number of segments"):
    805           s.eval()
    806 
    807   def testGradientSegmentsInvalid2(self):
    808     tf_x, _ = self._input([1, 4], dtype=dtypes_lib.float32)
    809     ops_list = [
    810         math_ops.sparse_segment_mean_grad, math_ops.sparse_segment_sqrt_n_grad
    811     ]
    812     segment_indices = [0, 1, 2, 0]
    813     tf_indices = [8, 3, 0, 9]
    814     with self.test_session(use_gpu=False):
    815       for tf_op in ops_list:
    816         s = tf_op(tf_x, tf_indices, segment_indices, 10)
    817         with self.assertRaisesOpError(r"Segment id 1 out of range \[0, 1\)"):
    818           s.eval()
    819 
    820   def testGradientSegmentsInvalid3(self):
    821     tf_x, _ = self._input([2, 4], dtype=dtypes_lib.float32)
    822     ops_list = [
    823         math_ops.sparse_segment_mean_grad, math_ops.sparse_segment_sqrt_n_grad
    824     ]
    825     segment_indices = [-1, 0, 1, 1]
    826     tf_indices = [8, 3, 0, 9]
    827     with self.test_session(use_gpu=False):
    828       for tf_op in ops_list:
    829         s = tf_op(tf_x, tf_indices, segment_indices, 10)
    830         with self.assertRaisesOpError(r"Segment id -1 out of range \[0, 2\)"):
    831           s.eval()
    832 
    833   def testGradientSegmentsInvalid4(self):
    834     tf_x, _ = self._input([0, 4], dtype=dtypes_lib.float32)
    835     ops_list = [
    836         math_ops.sparse_segment_mean_grad, math_ops.sparse_segment_sqrt_n_grad
    837     ]
    838     segment_indices = [0, 1, 2, -1]
    839     tf_indices = [8, 3, 0, 9]
    840     with self.test_session(use_gpu=False):
    841       for tf_op in ops_list:
    842         s = tf_op(tf_x, tf_indices, segment_indices, 10)
    843         with self.assertRaisesOpError(r"Segment id 0 out of range \[0, 0\)"):
    844           s.eval()
    845 
    846 class SegmentReductionOpBenchmark(test.Benchmark):
    847   outer_dim_options = [2**x for x in range(9, 14, 2)]
    848   ratio_options = [2**x for x in range(1, 6, 2)]
    849   inner_dim_options = [2**x for x in range(9, 14, 2)]
    850   # randomly generated sizes with less alignments
    851   inner_dim_options += [
    852       1120, 1215, 1856, 1302, 1329, 1531, 1313, 1672, 1851, 1584
    853   ]
    854   dtype_options = [np.float32, np.float64]
    855   options = (outer_dim_options, ratio_options, inner_dim_options, dtype_options)
    856   # pylint: disable=g-long-lambda
    857   op_functors = [lambda vc, vs, seg_ids:
    858                  ("sorted", math_ops.segment_sum(vc, vs)),
    859                  lambda vc, vs, seg_ids:
    860                  ("unsorted",
    861                   math_ops.unsorted_segment_sum(vc, vs, seg_ids[-1]+1))]
    862   # pylint: enable=g-long-lambda
    863   repeat = 10
    864 
    865   def _npTypeToStr(self, t):
    866     if t == np.float32:
    867       return "fp32"
    868     if t == np.float64:
    869       return "fp64"
    870 
    871   def _runGraph(self, op_functor, outer_dim, ratio, inner_dim, dtype):
    872     output_outer_dim = int(outer_dim / ratio)
    873     const = np.random.randint(5, size=(outer_dim, inner_dim))
    874     seg_ids = np.sort(np.random.randint(output_outer_dim, size=outer_dim))
    875     vs = variables.Variable(seg_ids.astype(np.int32))
    876     with ops.device("/gpu:0"):
    877       vc = variables.Variable(const.astype(dtype))
    878     name, op = op_functor(vc, vs, seg_ids)
    879     with session.Session() as sess:
    880       variables.global_variables_initializer().run()
    881       r = self.run_op_benchmark(
    882           sess,
    883           op,
    884           min_iters=self.repeat,
    885           name="_".join(
    886               map(str,
    887                   [name, outer_dim, ratio, inner_dim,
    888                    self._npTypeToStr(dtype)])))
    889     return name, r["wall_time"]
    890 
    891   def benchmarkSegmentSumGPU(self):
    892     if not test.is_gpu_available(cuda_only=True):
    893       return
    894     for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options):
    895       op_functor = self.op_functors[0]
    896       with ops.Graph().as_default():
    897         self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype)
    898 
    899   def benchmarkUnsortedSegmentSumGPU(self):
    900     if not test.is_gpu_available(cuda_only=True):
    901       return
    902     for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options):
    903       op_functor = self.op_functors[1]
    904       with ops.Graph().as_default():
    905         self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype)
    906 
    907 
if __name__ == "__main__":
  # Script entry point: hand control to TensorFlow's test runner
  # (tensorflow.python.platform.test.main).
  test.main()
    910