      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/math_ops.cc.
     17 
     18 #define EIGEN_USE_THREADS
     19 #if GOOGLE_CUDA
     20 #define EIGEN_USE_GPU
     21 #endif  // GOOGLE_CUDA
     22 
     23 #include "third_party/eigen3/Eigen/Core"
     24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     25 #include "tensorflow/core/kernels/segment_reduction_ops.h"
     26 #include <vector>
     27 #include "tensorflow/core/framework/numeric_op.h"
     28 #include "tensorflow/core/framework/op_kernel.h"
     29 #include "tensorflow/core/framework/register_types.h"
     30 #include "tensorflow/core/framework/tensor.h"
     31 #include "tensorflow/core/framework/tensor_types.h"
     32 #include "tensorflow/core/framework/types.h"
     33 #include "tensorflow/core/kernels/bounds_check.h"
     34 #include "tensorflow/core/lib/core/status.h"
     35 #include "tensorflow/core/platform/logging.h"
     36 #include "tensorflow/core/util/util.h"
     37 
     38 #if GOOGLE_CUDA
     39 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
     40 #include "tensorflow/core/kernels/cuda_solvers.h"
     41 #include "tensorflow/core/platform/cuda.h"
     42 
     43 using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
     44 #endif  // GOOGLE_CUDA
     45 
     46 namespace tensorflow {
     47 
     48 typedef Eigen::ThreadPoolDevice CPUDevice;
     49 typedef Eigen::GpuDevice GPUDevice;
     50 
     51 // Static routines not in the templated class to reduce code size
     52 static void SegmentReductionValidationHelper(OpKernelContext* context,
     53                                              const Tensor& input,
     54                                              const Tensor& segment_ids) {
     55   OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
     56               errors::InvalidArgument("segment_ids should be a vector."));
     57   const int64 num_indices = segment_ids.NumElements();
     58   OP_REQUIRES(context, num_indices == input.dim_size(0),
     59               errors::InvalidArgument(
     60                   "segment_ids should be the same size as dimension 0 of"
     61                   " input."));
     62 }
     63 
     64 static bool SegmentReductionDoValidation(OpKernelContext* c,
     65                                          const Tensor& input,
     66                                          const Tensor& segment_ids) {
     67   SegmentReductionValidationHelper(c, input, segment_ids);
     68   return c->status().ok();
     69 }
     70 
     71 // This operator handles reducing segments along the first dimension.
     72 // See core/ops/math_ops.cc for more details.
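         // Illustrative example of the sorted segment reduction semantics
         // (hypothetical values, not taken from any test): with
         //   data        = [[1, 2], [3, 4], [5, 6]]
         //   segment_ids = [0, 0, 1]
         // "SegmentSum" produces [[4, 6], [5, 6]]: rows that share a segment
         // id are reduced together, and each segment id selects an output row.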
     73 template <typename Device, class T, class Index, typename Reducer,
     74           int default_value>
     75 class SegmentReductionOp : public OpKernel {
     76  public:
     77   explicit SegmentReductionOp(OpKernelConstruction* context)
     78       : OpKernel(context) {}
     79 
     80   void Compute(OpKernelContext* context) override {
     81     const Tensor& input = context->input(0);
     82     const Tensor& segment_ids = context->input(1);
     83 
     84     if (!SegmentReductionDoValidation(context, input, segment_ids)) {
     85       return;
     86     }
     87 
     88     const int64 num_indices = segment_ids.NumElements();
     89     auto input_flat = input.flat_outer_dims<T>();
     90     const int64 num_col = input_flat.dimension(1);
     91 
     92     const auto segment_vec = segment_ids.vec<Index>();
     93     // Note that the current implementation assumes that segment_vec values are
     94     // sorted.
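             // Because the ids are assumed sorted, the number of output rows
             // is the last segment id plus one. For example (hypothetical
             // values), segment_ids = [0, 0, 1, 3] yields output_rows == 4.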
     95     const Index output_rows =
     96         num_indices > 0
     97             ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
     98             : 0;
     99     OP_REQUIRES(context, output_rows >= 0,
    100                 errors::InvalidArgument("segment ids must be >= 0"));
    101 
    102     TensorShape output_shape = input.shape();
    103     output_shape.set_dim(0, output_rows);
    104 
    105     // Note that we do not initialize the output buffer with a default value, so
    106     // we need to explicitly set missing indices to the default value.
    107     Tensor* output = nullptr;
    108     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    109     if (num_indices == 0) return;
    110     OP_REQUIRES(context, output_rows > 0,
    111                 errors::InvalidArgument("segment ids must be >= 0"));
    112     auto output_flat = output->flat_outer_dims<T>();
    113 
    114 #if !defined(EIGEN_HAS_INDEX_LIST)
    115     Eigen::DSizes<Eigen::DenseIndex, 1> dims_to_reduce;
    116     dims_to_reduce[0] = 0;
    117 #else
    118     Eigen::IndexList<Eigen::type2index<0> > dims_to_reduce;
    119 #endif
    120     Index start = 0, end = 1;
    121 
    122     Index uninitialized_index = 0;  // Index from which the output is not set.
    123     Index out_index = internal::SubtleMustCopy(segment_vec(start));
    124 
    125     // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it
    126     // across threads.
    127     Eigen::DSizes<Eigen::DenseIndex, 1> out_slice_shape(num_col);
    128     while (end <= num_indices) {
    129       // We initialize next_index to 0 to avoid "warning: 'next_index' may be
    130       // used uninitialized in this function" in the Mac build (since the
    131       // compiler isn't smart enough to realize the code is safe).
    132       Index next_index = 0;
    133       if (end < num_indices) {
    134         next_index = internal::SubtleMustCopy(segment_vec(end));
    135         if (out_index == next_index) {
    136           ++end;
    137           continue;
    138         }
    139         // We have a new segment here.  Verify that the segment ids are growing.
    140         OP_REQUIRES(context, out_index < next_index,
    141                     errors::InvalidArgument("segment ids are not increasing"));
    142       }
    143 
    144       // Process segment [start, end)
    145       const T* in_slice_ptr = &input_flat(start, 0);
    146       typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
    147                                Eigen::Unaligned>
    148           OutT;
    149 
    150       OP_REQUIRES(
    151           context, FastBoundsCheck(out_index, output_rows),
    152           errors::InvalidArgument(
    153               "Segment id ", out_index, " out of range [0, ", output_rows,
    154               "), possibly because 'segment_ids' input is not sorted."));
    155 
    156       // If there is a gap between two indices, we need to set that gap to the
    157       // default value.
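               // For example (hypothetical values), segment_ids = [0, 0, 3]
               // provides no data for output rows 1 and 2, so those rows are
               // filled with T(default_value) here.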
    158       if (out_index > uninitialized_index) {
    159         Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
    160             out_index - uninitialized_index, num_col);
    161         Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
    162             gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
    163         gap_slice.setConstant(T(default_value));
    164       }
    165 
    166       T* out_slice_ptr = &output_flat(out_index, 0);
    167       OutT out_slice(out_slice_ptr, out_slice_shape);
    168       // We don't use out_slice.device(context->eigen_device<Device>)
    169       // because these pieces of work are likely to be very small and
    170       // the context switching overhead dwarfs any benefit we get from
    171       // using another thread to do this work.
    172       if (start == end - 1) {
    173         typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
    174                                  Eigen::Unaligned>
    175             InT;
    176         InT in_slice(in_slice_ptr, out_slice_shape);
    177         out_slice = in_slice;
    178       } else {
    179         Eigen::DSizes<Eigen::DenseIndex, 2> in_slice_shape(end - start,
    180                                                            num_col);
    181         typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
    182                                  Eigen::Unaligned>
    183             InT;
    184         InT in_slice(in_slice_ptr, in_slice_shape);
    185 
    186         out_slice = in_slice.reduce(dims_to_reduce, Reducer());
    187       }
    188       if (end >= num_indices) break;
    189       start = end;
    190       ++end;
    191       uninitialized_index = out_index + 1;
    192       out_index = next_index;
    193     }
    194   }
    195 };
    196 
    197 #ifdef GOOGLE_CUDA
     198 //  SegmentSumGPUOp is a segment sum operator implemented for GPU only.
     199 //  TODO: This implementation of SegmentSumGPUOp is sometimes slower than
     200 //  its unsorted counterpart (mostly when the problem size is small), for
     201 //  the two main reasons below; a cost-effective way to resolve these
     202 //  problems is desirable.
     203 //  1. Sorted segment sum requires a memory transfer from device to host in
     204 //     order to know the size of the output dimension, whereas unsorted
     205 //     segment sum receives the size of the output dimension as an input
     206 //     parameter.
     207 //  2. Sorted segment sum is essentially a tiled version of unsorted segment
     208 //     sum, so the tiling comes at an inherent cost that may not be justified
     209 //     when the problem size is small. Whether to use the tiled or the
     210 //     untiled version depends on many factors, including data alignment, the
     211 //     ratio of computation to memory traffic and, of course, the problem size.
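         //  Note that output_rows is obtained by copying the last element of
         //  segment_ids from device memory back to the host (the ThenMemcpy
         //  below); the reduction itself is launched from a callback that runs
         //  once that copy has completed.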
    212 template <class T, class Index>
    213 class SegmentSumGPUOp : public AsyncOpKernel {
    214  public:
    215   explicit SegmentSumGPUOp(OpKernelConstruction* context)
    216       : AsyncOpKernel(context) {}
    217 
    218   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    219     const Tensor& input = context->input(0);
    220     const Tensor& segment_ids = context->input(1);
    221 
    222     OP_REQUIRES_ASYNC(
    223         context, TensorShapeUtils::IsVector(segment_ids.shape()),
    224         errors::InvalidArgument("segment_ids should be a vector."), done);
    225 
    226     const int64 num_indices = segment_ids.NumElements();
    227     OP_REQUIRES_ASYNC(
    228         context, num_indices == input.dim_size(0),
    229         errors::InvalidArgument(
    230             "segment_ids should be the same size as dimension 0 of"
    231             " input."),
    232         done);
    233 
    234     if (num_indices == 0) {
    235       TensorShape output_shape = input.shape();
    236       output_shape.set_dim(0, 0);
    237 
    238       Tensor* output = nullptr;
    239       OP_REQUIRES_OK_ASYNC(
    240           context, context->allocate_output(0, output_shape, &output), done);
    241       done();
    242       return;
    243     }
    244 
    245     perftools::gputools::DeviceMemoryBase output_rows_device(
    246         const_cast<Tensor&>(segment_ids).template flat<Index>().data() +
    247         (num_indices - 1));
    248     ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);
    249 
    250     auto stream = context->op_device_context()->stream();
    251     OP_REQUIRES_ASYNC(
    252         context,
    253         stream
    254             ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
    255                          sizeof(Index))
    256             .ok(),
    257         errors::Internal(
    258             "SegmentSumGPUOp: failed to copy output_rows from device"),
    259         done);
    260 
    261     functor::SegmentSumFunctor<T, Index> functor_;
    262     auto create_and_check_output = [context, output_rows_host, &input,
    263                                     &segment_ids, &functor_, done]() {
    264       // Ensure that within the callback, the proper GPU settings are
    265       // configured.
    266       auto stream = context->op_device_context()->stream();
    267       ScopedActivateExecutorContext scoped_activation{stream->parent()};
    268 
    269       Index output_rows = *output_rows_host.data();
    270       output_rows++;
    271       OP_REQUIRES_ASYNC(context, output_rows > 0,
    272                         errors::InvalidArgument("segment ids must be >= 0"),
    273                         done);
    274 
    275       TensorShape output_shape = input.shape();
    276       output_shape.set_dim(0, output_rows);
    277 
    278       Tensor* output = nullptr;
    279       OP_REQUIRES_OK_ASYNC(
    280           context, context->allocate_output(0, output_shape, &output), done);
    281 
    282       auto output_flat = output->flat_outer_dims<T>();
    283       auto data_ptr = input.template flat<T>().data();
    284       auto segment_flat = segment_ids.flat<Index>();
    285       functor_(context, context->eigen_device<GPUDevice>(), output_rows,
    286                segment_ids.shape(), segment_flat, input.NumElements(), data_ptr,
    287                output_flat);
    288 
    289       done();
    290     };
    291 
    292     context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
    293         stream, create_and_check_output);
    294   }
    295 };
    296 #endif  // GOOGLE_CUDA
    297 
    298 #define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
    299                                     default_value)                   \
    300   REGISTER_KERNEL_BUILDER(                                           \
    301       Name(name)                                                     \
    302           .Device(DEVICE_CPU)                                        \
    303           .TypeConstraint<type>("T")                                 \
    304           .TypeConstraint<index_type>("Tindices"),                   \
    305       SegmentReductionOp<CPUDevice, type, index_type, functor, default_value>)
    306 
    307 #define REGISTER_REAL_CPU_KERNELS(type, index_type)                            \
    308   REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
    309                               type, index_type, 0);                            \
    310   REGISTER_CPU_KERNEL_SEGMENT(                                                 \
    311       "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
    312   REGISTER_CPU_KERNEL_SEGMENT(                                                 \
    313       "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1); \
    314   REGISTER_CPU_KERNEL_SEGMENT("SegmentMin", Eigen::internal::MinReducer<type>, \
    315                               type, index_type, 0);                            \
    316   REGISTER_CPU_KERNEL_SEGMENT("SegmentMax", Eigen::internal::MaxReducer<type>, \
    317                               type, index_type, 0)
    318 
    319 #define REGISTER_COMPLEX_CPU_KERNELS(type, index_type)                         \
    320   REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
    321                               type, index_type, 0);                            \
    322   REGISTER_CPU_KERNEL_SEGMENT(                                                 \
    323       "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1)
    324 
    325 #define REGISTER_REAL_CPU_KERNELS_ALL(type) \
    326   REGISTER_REAL_CPU_KERNELS(type, int32);   \
    327   REGISTER_REAL_CPU_KERNELS(type, int64)
    328 
    329 #define REGISTER_COMPLEX_CPU_KERNELS_ALL(type) \
    330   REGISTER_COMPLEX_CPU_KERNELS(type, int32);   \
    331   REGISTER_COMPLEX_CPU_KERNELS(type, int64)
    332 
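         // For reference, a single expansion such as
         // REGISTER_REAL_CPU_KERNELS(float, int32) registers the CPU kernels
         // for SegmentSum, SegmentMean, SegmentProd, SegmentMin and SegmentMax
         // with T = float and Tindices = int32.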
    333 TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_KERNELS_ALL);
    334 REGISTER_COMPLEX_CPU_KERNELS_ALL(complex64);
    335 REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
    336 #undef REGISTER_CPU_KERNEL_SEGMENT
    337 #undef REGISTER_REAL_CPU_KERNELS
    338 #undef REGISTER_COMPLEX_CPU_KERNELS
    339 #undef REGISTER_REAL_CPU_KERNELS_ALL
    340 #undef REGISTER_COMPLEX_CPU_KERNELS_ALL
    341 
    342 #if GOOGLE_CUDA
    343 #define REGISTER_GPU_SORTED_KERNELS(type, index_type)                  \
    344   REGISTER_KERNEL_BUILDER(Name("SegmentSum")                           \
    345                               .Device(DEVICE_GPU)                      \
    346                               .TypeConstraint<type>("T")               \
    347                               .TypeConstraint<index_type>("Tindices"), \
    348                           SegmentSumGPUOp<type, index_type>)
    349 
    350 #define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
    351   REGISTER_GPU_SORTED_KERNELS(type, int32);   \
    352   REGISTER_GPU_SORTED_KERNELS(type, int64);
    353 
    354 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
    355 #undef REGISTER_GPU_SORTED_KERNELS
    356 #undef REGISTER_GPU_SORTED_KERNELS_ALL
    357 #endif  // GOOGLE_CUDA
    358 
    359 // ____________________________________________________________________________
    360 // Unsorted segment reduction ops.
    361 
    362 namespace functor {
    363 
    364 // The ReductionFunctor implementation for CPU.
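         // Illustrative example of the unsorted semantics (hypothetical
         // values): with
         //   data         = [[1, 2], [3, 4], [5, 6]]
         //   segment_ids  = [1, 1, 0]   and   num_segments = 3
         // an unsorted segment sum produces [[5, 6], [4, 6], [0, 0]]; negative
         // segment ids are skipped, and row 2 keeps the initial value produced
         // by InitialValueF.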
    365 template <typename T, typename Index, typename InitialValueF,
    366           typename ReductionF>
    367 struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
    368   void operator()(OpKernelContext* ctx, const Index num_segments,
    369                   const TensorShape& segment_ids_shape,
    370                   typename TTypes<Index>::ConstFlat segment_ids,
    371                   const Index data_size, const T* data,
    372                   typename TTypes<T, 2>::Tensor output) {
    373     output.setConstant(InitialValueF()());
    374     if (data_size == 0) {
    375       return;
    376     }
    377     const int64 N = segment_ids.dimension(0);
    378     ReductionF reduction;
    379     auto data_flat = typename TTypes<T, 2>::ConstTensor(data, N, data_size / N);
    380     for (int64 i = 0; i < N; ++i) {
    381       Index j = internal::SubtleMustCopy(segment_ids(i));
    382       if (j < 0) {
    383         continue;
    384       }
    385       OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments),
    386                   errors::InvalidArgument(
    387                       "segment_ids", SliceDebugString(segment_ids_shape, i),
    388                       " = ", j, " is out of range [0, ", num_segments, ")"));
    389       reduction(data_flat.template chip<0>(i), output.template chip<0>(j));
    390     }
    391   }
    392 };
    393 
    394 template <typename T>
    395 using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes<T, 2>::Matrix>;
    396 
    397 template <typename T>
    398 using constMatrixChip =
    399     Eigen::TensorChippingOp<0l, const typename TTypes<T, 2>::ConstMatrix>;
    400 
    401 // reduction functors
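         // Each functor folds one input row (a chip of the data matrix) into
         // the corresponding output row, which has already been initialized
         // with the value produced by the matching initial value functor.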
    402 template <typename T>
    403 struct SumOp {
    404   void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    405     output += data;
    406   }
    407 };
    408 
    409 template <typename T>
    410 struct MaxOp {
    411   void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    412     output = data.cwiseMax(output);
    413   }
    414 };
    415 
    416 template <typename T>
    417 struct MinOp {
    418   void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    419     output = data.cwiseMin(output);
    420   }
    421 };
    422 
    423 template <typename T>
    424 struct ProdOp {
    425   void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    426     output *= data;
    427   }
    428 };
    429 }  // namespace functor
    430 
    431 // Static check routines not in the templated class to reduce code size
    432 static void UnsortedSegmentReductionValidation(OpKernel* op_kernel,
    433                                                OpKernelContext* context,
    434                                                const Tensor& data,
    435                                                const Tensor& segment_ids,
    436                                                const Tensor& num_segments) {
    437   OP_REQUIRES(
    438       context, op_kernel->IsLegacyScalar(num_segments.shape()),
    439       errors::InvalidArgument("num_segments should be a scalar, not shape ",
    440                               num_segments.shape().DebugString()));
    441   OP_REQUIRES(
    442       context, TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape()),
    443       errors::InvalidArgument("data.shape = ", data.shape().DebugString(),
    444                               " does not start with segment_ids.shape = ",
    445                               segment_ids.shape().DebugString()));
    446 }
    447 
    448 static bool UnsortedSegmentReductionDoValidation(OpKernel* op_kernel,
    449                                                  OpKernelContext* context,
    450                                                  const Tensor& data,
    451                                                  const Tensor& segment_ids,
    452                                                  const Tensor& num_segments) {
    453   UnsortedSegmentReductionValidation(op_kernel, context, data, segment_ids,
    454                                      num_segments);
    455   return context->status().ok();
    456 }
    457 
    458 // The UnsortedSegmentReduction OpKernel. The DeviceReductionFunctor
    459 // is the device specific implementation of the reduction. These device
    460 // specific implementations are templated themselves with the corresponding
    461 // initial value functors and reduction functors.
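         // Illustrative example (hypothetical values): for UnsortedSegmentMax
         // with data = [1, 2, 3, 4], segment_ids = [0, 1, 0, 1] and
         // num_segments = 2, the output is [3, 4], i.e. the per-segment
         // maximum of the inputs.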
    462 template <typename T, typename Index, typename DeviceReductionFunctor>
    463 class UnsortedSegmentReductionOp : public OpKernel {
    464  public:
    465   explicit UnsortedSegmentReductionOp(OpKernelConstruction* context)
    466       : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {}
    467 
    468   void Compute(OpKernelContext* context) override {
    469     const Tensor& data = context->input(0);
    470     const Tensor& segment_ids = context->input(1);
    471     const Tensor& num_segments = context->input(2);
    472     if (!UnsortedSegmentReductionDoValidation(this, context, data, segment_ids,
    473                                               num_segments)) {
    474       return;
    475     }
    476     const auto segment_flat = segment_ids.flat<Index>();
    477     const Index output_rows =
    478         internal::SubtleMustCopy(num_segments.scalar<int32>()());
    479     OP_REQUIRES(context, output_rows >= 0,
    480                 errors::InvalidArgument("Input num_segments == ", output_rows,
    481                                         " must not be negative."));
    482     TensorShape output_shape;
    483     output_shape.AddDim(output_rows);
    484     for (int i = segment_ids.dims(); i < data.dims(); i++) {
    485       output_shape.AddDim(data.dim_size(i));
    486     }
    487     Tensor* output = nullptr;
    488     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    489     auto output_flat = output->flat_outer_dims<T>();
    490     auto data_ptr = data.template flat<T>().data();
    491     reduction_functor_(context, output_rows, segment_ids.shape(), segment_flat,
    492                        data.NumElements(), data_ptr, output_flat);
    493   }
    494 
    495  protected:
    496   DeviceReductionFunctor reduction_functor_;
    497 };
    498 
    499 #define REGISTER_CPU_KERNEL_UNSORTEDSEGMENT(                           \
    500     name, type, index_type, initial_value_functor, reduction_functor)  \
    501   REGISTER_KERNEL_BUILDER(                                             \
    502       Name(name)                                                       \
    503           .Device(DEVICE_CPU)                                          \
    504           .TypeConstraint<type>("T")                                   \
    505           .TypeConstraint<index_type>("Tindices"),                     \
    506       UnsortedSegmentReductionOp<                                      \
    507           type, index_type,                                            \
    508           functor::UnsortedSegmentFunctor<CPUDevice, type, index_type, \
    509                                           initial_value_functor,       \
    510                                           reduction_functor> >)
    511 
    512 #define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type)                   \
    513   REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type,  \
    514                                       functor::Zero<type>,                     \
    515                                       functor::SumOp<type>);                   \
    516   REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type,  \
    517                                       functor::Lowest<type>,                   \
    518                                       functor::MaxOp<type>);                   \
    519   REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type,  \
    520                                       functor::Highest<type>,                  \
    521                                       functor::MinOp<type>);                   \
    522   REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
    523                                       functor::One<type>,                      \
    524                                       functor::ProdOp<type>);
    525 
    526 #define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type)                \
    527   REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type,  \
    528                                       functor::Zero<type>,                     \
    529                                       functor::SumOp<type>);                   \
    530   REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
    531                                       functor::One<type>,                      \
    532                                       functor::ProdOp<type>)
    533 
    534 #define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \
    535   REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32);   \
    536   REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int64)
    537 
    538 #define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(type) \
    539   REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int32);   \
    540   REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int64)
    541 
    542 TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL);
    543 REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64);
    544 REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
    545 
    546 #undef REGISTER_REAL_CPU_UNSORTED_KERNELS
    547 #undef REGISTER_CPU_KERNEL_UNSORTEDSEGMENT
    548 #undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS
    549 #undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL
    550 #undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL
    551 
    552 #if GOOGLE_CUDA
    553 #define REGISTER_GPU_KERNEL_UNSORTEDSEGMENT(                                 \
    554     name, type, index_type, initial_value_functor, reduction_kernel_functor) \
    555   REGISTER_KERNEL_BUILDER(                                                   \
    556       Name(name)                                                             \
    557           .Device(DEVICE_GPU)                                                \
    558           .HostMemory("num_segments")                                        \
    559           .TypeConstraint<type>("T")                                         \
    560           .TypeConstraint<index_type>("Tindices"),                           \
    561       UnsortedSegmentReductionOp<                                            \
    562           type, index_type,                                                  \
    563           functor::UnsortedSegmentFunctor<GPUDevice, type, index_type,       \
    564                                           initial_value_functor,             \
    565                                           reduction_kernel_functor> >)
    566 
     567 // Sum is currently the only op that supports all input types.
    568 #define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type)                   \
    569   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type,  \
    570                                       functor::Lowest<type>,                   \
    571                                       functor::MaxOpGpu<type>);                \
    572   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type,  \
    573                                       functor::Highest<type>,                  \
    574                                       functor::MinOpGpu<type>);                \
    575   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
    576                                       functor::One<type>,                      \
    577                                       functor::ProdOpGpu<type>);
    578 
    579 #define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type)                   \
    580   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
    581                                       functor::Zero<type>,                    \
    582                                       functor::SumOpGpu<type>);
    583 
    584 #define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
    585   REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32);   \
    586   REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int64);
    587 
    588 #define REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL(type) \
    589   REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int32);   \
    590   REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int64);
     592 
    593 TF_CALL_GPU_NUMBER_TYPES(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
    594 TF_CALL_int32(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
    595 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
    596 TF_CALL_int32(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
    597 TF_CALL_complex64(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
    598 TF_CALL_complex128(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
    599 
    600 #undef REGISTER_GPU_KERNEL_UNSORTEDSEGMENT
    601 #undef REGISTER_REAL_GPU_UNSORTED_KERNELS
    602 #undef REGISTER_SUM_GPU_UNSORTED_KERNELS
    603 #undef REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL
    604 #undef REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL
    605 
    606 #endif  // GOOGLE_CUDA
    607 
    608 // ____________________________________________________________________________
    609 // Sparse segment reduction ops.
    610 
    611 // Same as SegmentReductionOp but takes as input a "sparse" tensor, represented
    612 // by two dense tensors, one containing the data, and the other containing
    613 // indices into the data.
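         // Illustrative example (hypothetical values): with
         //   data        = [[1, 2], [3, 4], [5, 6]]
         //   indices     = [0, 2]
         //   segment_ids = [0, 0]
         // SparseSegmentSum produces [[6, 8]]: only the rows selected by
         // `indices` participate, and `segment_ids` groups the selected rows
         // just as in the dense segment ops above.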
    614 template <typename Device, class T>
    615 class SparseSegmentReductionOpBase : public OpKernel {
    616  public:
    617   explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
    618                                         bool is_mean, bool is_sqrtn,
    619                                         bool has_num_segments, T default_value)
    620       : OpKernel(context),
    621         is_mean_(is_mean),
    622         is_sqrtn_(is_sqrtn),
    623         has_num_segments_(has_num_segments),
    624         default_value_(default_value) {}
    625 
    626   void Compute(OpKernelContext* context) override {
    627     const Tensor& input = context->input(0);
    628     const Tensor& indices = context->input(1);
    629     const Tensor& segment_ids = context->input(2);
    630 
    631     Index output_rows = -1;
    632     if (has_num_segments_) {
    633       const Tensor& num_segments = context->input(3);
    634 
    635       OP_REQUIRES(
    636           context, num_segments.shape().dims() == 0,
    637           errors::InvalidArgument("num_segments should be a scalar, not shape ",
    638                                   num_segments.shape().DebugString()));
    639       output_rows = internal::SubtleMustCopy(num_segments.scalar<int32>()());
    640       OP_REQUIRES(context, output_rows >= 0,
    641                   errors::InvalidArgument("segment ids must be >= 0"));
    642     }
    643 
    644     OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
    645                 errors::InvalidArgument("indices should be a vector."));
    646     OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
    647                 errors::InvalidArgument("segment_ids should be a vector."));
    648 
    649     const int64 num_indices = indices.NumElements();
    650     OP_REQUIRES(context, num_indices == segment_ids.NumElements(),
    651                 errors::InvalidArgument(
    652                     "segment_ids and indices should have same size."));
    653 
    654     auto input_flat = input.flat_outer_dims<T>();
    655     const int64 num_col = input_flat.dimension(1);
    656     const auto indices_vec = indices.vec<Index>();
    657     typedef int32 OutputRow;
    658     const auto segment_vec = segment_ids.vec<OutputRow>();
    659     // Note that the current implementation assumes that segment_vec values are
    660     // sorted.
    661     const OutputRow last_segment_id_plus_one =
    662         num_indices > 0
    663             ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
    664             : 0;
    665     if (has_num_segments_) {
    666       OP_REQUIRES(
    667           context, output_rows >= last_segment_id_plus_one,
    668           errors::InvalidArgument("segment ids must be < num_segments"));
    669     } else {
    670       output_rows = last_segment_id_plus_one;
    671     }
    672     OP_REQUIRES(context, output_rows >= 0,
    673                 errors::InvalidArgument("segment ids must be >= 0"));
    674 
    675     TensorShape output_shape = input.shape();
    676     output_shape.set_dim(0, output_rows);
    677 
    678     // Note that we do not initialize the output buffer with a default value, so
    679     // we need to explicitly set missing indices to the default value.
    680     Tensor* output = nullptr;
    681     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    682     if (num_indices == 0) return;
    683     OP_REQUIRES(context, output_rows > 0,
    684                 errors::InvalidArgument("segment ids must be >= 0"));
    685     auto output_flat = output->flat_outer_dims<T>();
    686 
    687     int64 start = 0, end = 1;
    688     // Index from which the output is not initialized.
    689     OutputRow uninitialized_index = 0;
    690     OutputRow out_index = internal::SubtleMustCopy(segment_vec(start));
    691 
    692     while (true) {
    693       // We initialize next_index to 0 to avoid "warning: 'next_index' may be
    694       // used uninitialized in this function" in the Mac build (since the
    695       // compiler isn't smart enough to realize the code is safe).
    696       OutputRow next_index = 0;
    697       if (end < num_indices) {
    698         next_index = internal::SubtleMustCopy(segment_vec(end));
    699         if (out_index == next_index) {
    700           ++end;
    701           continue;
    702         }
    703         // We have a new segment here.  Verify that the segment ids are growing.
    704         OP_REQUIRES(context, out_index < next_index,
    705                     errors::InvalidArgument("segment ids are not increasing"));
    706       }
    707 
    708       OP_REQUIRES(
    709           context, FastBoundsCheck(out_index, output_rows),
    710           errors::InvalidArgument(
    711               "Segment id ", out_index, " out of range [0, ", output_rows,
    712               "), possibly because 'segment_ids' input is not sorted."));
    713 
    714       // If there is a gap between two indices, we need to set that gap to the
    715       // default value.
    716       if (out_index > uninitialized_index) {
    717         Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
    718             out_index - uninitialized_index, num_col);
    719         Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
    720             gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
    721         gap_slice.setConstant(default_value_);
    722       }
    723 
    724       auto out = output_flat.template chip<0>(out_index);
    725       const int bad_offset =
    726           Reduce(input_flat, indices_vec, start, end - start, out);
    727       OP_REQUIRES(context, bad_offset < 0,
    728                   errors::InvalidArgument(
    729                       "Bad: indices[", start + bad_offset,
    730                       "] == ", indices_vec(start + bad_offset),
    731                       " out of range [0, ", input_flat.dimension(0), ")"));
    732 
    733       start = end;
    734       ++end;
    735       uninitialized_index = out_index + 1;
    736       out_index = next_index;
    737       if (end > num_indices) break;
    738     }
    739 
    740     // Fill the gap at the end with the default value.
    741     if (uninitialized_index < output_rows) {
    742       Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
    743           output_rows - uninitialized_index, num_col);
    744       Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
    745           gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
    746       gap_slice.setConstant(default_value_);
    747     }
    748   }
    749 
    750  private:
    751   typedef int32 Index;
    752 
    753   int64 Reduce(const typename TTypes<T>::ConstMatrix& input_flat,
    754                const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
    755                int64 num,
    756                Eigen::TensorChippingOp<0, typename TTypes<T>::Matrix> out) {
    757 #define INDEX(n, i)                               \
    758   const auto index##n = indices_vec(start + (i)); \
    759   if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);
    760 
    761 #define L(n) input_flat.template chip<0>(index##n)
    762 
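             // The reduction below is manually unrolled: the switch handles
             // the first num % 8 rows (8 or 9 rows when the remainder is 0 or
             // 1), and the trailing loop then accumulates the remaining rows
             // eight at a time. INDEX(n, i) also bounds-checks
             // indices_vec(start + i) and makes Reduce return the offending
             // offset (relative to start) on failure.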
    763     if (num == 1) {
    764       INDEX(0, 0);
    765       out = L(0);
    766     } else {
    767       int64 r = num % 8;
    768       T m(1);
    769       if (is_mean_ && (num < 10)) {
    770         m = T(num);
    771       }
    772       if (is_sqrtn_ && (num < 10)) {
    773         m = T(sqrt(num));
    774       }
    775       switch (r) {
    776         case 2: {
    777           INDEX(0, 0);
    778           INDEX(1, 1);
    779           out = (L(0) + L(1)) / m;
    780           break;
    781         }
    782         case 3: {
    783           INDEX(0, 0);
    784           INDEX(1, 1);
    785           INDEX(2, 2);
    786           out = (L(0) + L(1) + L(2)) / m;
    787           break;
    788         }
    789         case 4: {
    790           INDEX(0, 0);
    791           INDEX(1, 1);
    792           INDEX(2, 2);
    793           INDEX(3, 3);
    794           out = (L(0) + L(1) + L(2) + L(3)) / m;
    795           break;
    796         }
    797         case 5: {
    798           INDEX(0, 0);
    799           INDEX(1, 1);
    800           INDEX(2, 2);
    801           INDEX(3, 3);
    802           INDEX(4, 4);
    803           out = (L(0) + L(1) + L(2) + L(3) + L(4)) / m;
    804           break;
    805         }
    806         case 6: {
    807           INDEX(0, 0);
    808           INDEX(1, 1);
    809           INDEX(2, 2);
    810           INDEX(3, 3);
    811           INDEX(4, 4);
    812           INDEX(5, 5);
    813           out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m;
    814           break;
    815         }
    816         case 7: {
    817           INDEX(0, 0);
    818           INDEX(1, 1);
    819           INDEX(2, 2);
    820           INDEX(3, 3);
    821           INDEX(4, 4);
    822           INDEX(5, 5);
    823           INDEX(6, 6);
    824           out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m;
    825           break;
    826         }
    827         case 0: {
    828           INDEX(0, 0);
    829           INDEX(1, 1);
    830           INDEX(2, 2);
    831           INDEX(3, 3);
    832           INDEX(4, 4);
    833           INDEX(5, 5);
    834           INDEX(6, 6);
    835           INDEX(7, 7);
    836           out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m;
    837           r = 8;
    838           break;
    839         }
    840         case 1: {
    841           INDEX(0, 0);
    842           INDEX(1, 1);
    843           INDEX(2, 2);
    844           INDEX(3, 3);
    845           INDEX(4, 4);
    846           INDEX(5, 5);
    847           INDEX(6, 6);
    848           INDEX(7, 7);
    849           INDEX(8, 8);
    850           out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) /
    851                 m;
    852           r = 9;
    853           break;
    854         }
    855       }
    856       for (; r < num; r += 8) {
    857         INDEX(0, r);
    858         INDEX(1, r + 1);
    859         INDEX(2, r + 2);
    860         INDEX(3, r + 3);
    861         INDEX(4, r + 4);
    862         INDEX(5, r + 5);
    863         INDEX(6, r + 6);
    864         INDEX(7, r + 7);
    865         out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
    866       }
    867       if (is_mean_ && num >= 10) {
    868         out = out / static_cast<T>(num);
    869       }
    870       if (is_sqrtn_ && num >= 10) {
    871         out = out / static_cast<T>(sqrt(num));
    872       }
    873     }
    874 
    875     return -1;
    876 #undef L
    877 #undef INDEX
    878   }
    879 
    880   const bool is_mean_;
    881   const bool is_sqrtn_;
    882   const bool has_num_segments_;
    883   const T default_value_;
    884 };
    885 
    886 template <typename Device, class T>
    887 class SparseSegmentReductionMeanOp
    888     : public SparseSegmentReductionOpBase<Device, T> {
    889  public:
    890   explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
    891       : SparseSegmentReductionOpBase<Device, T>(
    892             context, true /*is_mean*/, false /*is_sqrtn*/,
    893             false /* has_num_segments */, T(0) /* default_value */) {}
    894 };
    895 
    896 template <typename Device, class T>
    897 class SparseSegmentReductionMeanWithNumSegmentsOp
    898     : public SparseSegmentReductionOpBase<Device, T> {
    899  public:
    900   explicit SparseSegmentReductionMeanWithNumSegmentsOp(
    901       OpKernelConstruction* context)
    902       : SparseSegmentReductionOpBase<Device, T>(
    903             context, true /*is_mean*/, false /*is_sqrtn*/,
    904             true /* has_num_segments */, T(0) /* default_value */) {}
    905 };
    906 
    907 template <typename Device, class T>
    908 class SparseSegmentReductionSqrtNOp
    909     : public SparseSegmentReductionOpBase<Device, T> {
    910  public:
    911   explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
    912       : SparseSegmentReductionOpBase<Device, T>(
    913             context, false /*is_mean*/, true /*is_sqrtn*/,
    914             false /* has_num_segments */, T(0) /* default_value */) {}
    915 };
    916 
    917 template <typename Device, class T>
    918 class SparseSegmentReductionSqrtNWithNumSegmentsOp
    919     : public SparseSegmentReductionOpBase<Device, T> {
    920  public:
    921   explicit SparseSegmentReductionSqrtNWithNumSegmentsOp(
    922       OpKernelConstruction* context)
    923       : SparseSegmentReductionOpBase<Device, T>(
    924             context, false /*is_mean*/, true /*is_sqrtn*/,
    925             true /* has_num_segments */, T(0) /* default_value */) {}
    926 };
    927 
    928 template <typename Device, class T>
    929 class SparseSegmentReductionSumOp
    930     : public SparseSegmentReductionOpBase<Device, T> {
    931  public:
    932   explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
    933       : SparseSegmentReductionOpBase<Device, T>(
    934             context, false /*is_mean*/, false /*is_sqrtn*/,
    935             false /* has_num_segments */, T(0) /* default_value */) {}
    936 };
    937 
    938 template <typename Device, class T>
    939 class SparseSegmentReductionSumWithNumSegmentsOp
    940     : public SparseSegmentReductionOpBase<Device, T> {
    941  public:
    942   explicit SparseSegmentReductionSumWithNumSegmentsOp(
    943       OpKernelConstruction* context)
    944       : SparseSegmentReductionOpBase<Device, T>(
    945             context, false /*is_mean*/, false /*is_sqrtn*/,
    946             true /* has_num_segments */, T(0) /* default_value */) {}
    947 };
    948 
    949 #define REGISTER_CPU_SPARSE_KERNELS(type)                                \
    950   REGISTER_KERNEL_BUILDER(Name("SparseSegmentSum")                       \
    951                               .Device(DEVICE_CPU)                        \
    952                               .TypeConstraint<type>("T")                 \
    953                               .TypeConstraint<int32>("Tidx"),            \
    954                           SparseSegmentReductionSumOp<CPUDevice, type>); \
    955   REGISTER_KERNEL_BUILDER(                                               \
    956       Name("SparseSegmentSumWithNumSegments")                            \
    957           .Device(DEVICE_CPU)                                            \
    958           .TypeConstraint<type>("T")                                     \
    959           .TypeConstraint<int32>("Tidx"),                                \
    960       SparseSegmentReductionSumWithNumSegmentsOp<CPUDevice, type>);
    961 TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS);
    962 #undef REGISTER_CPU_SPARSE_KERNELS
    963 
    964 #define REGISTER_CPU_SPARSE_KERNELS(type)                                 \
    965   REGISTER_KERNEL_BUILDER(Name("SparseSegmentMean")                       \
    966                               .Device(DEVICE_CPU)                         \
    967                               .TypeConstraint<type>("T")                  \
    968                               .TypeConstraint<int32>("Tidx"),             \
    969                           SparseSegmentReductionMeanOp<CPUDevice, type>); \
    970   REGISTER_KERNEL_BUILDER(                                                \
    971       Name("SparseSegmentMeanWithNumSegments")                            \
    972           .Device(DEVICE_CPU)                                             \
    973           .TypeConstraint<type>("T")                                      \
    974           .TypeConstraint<int32>("Tidx"),                                 \
    975       SparseSegmentReductionMeanWithNumSegmentsOp<CPUDevice, type>);
    976 REGISTER_CPU_SPARSE_KERNELS(float);
    977 REGISTER_CPU_SPARSE_KERNELS(double);
    978 #undef REGISTER_CPU_SPARSE_KERNELS
    979 
    980 #define REGISTER_CPU_SPARSE_KERNELS(type)                                  \
    981   REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtN")                       \
    982                               .Device(DEVICE_CPU)                          \
    983                               .TypeConstraint<type>("T")                   \
    984                               .TypeConstraint<int32>("Tidx"),              \
    985                           SparseSegmentReductionSqrtNOp<CPUDevice, type>); \
    986   REGISTER_KERNEL_BUILDER(                                                 \
    987       Name("SparseSegmentSqrtNWithNumSegments")                            \
    988           .Device(DEVICE_CPU)                                              \
    989           .TypeConstraint<type>("T")                                       \
    990           .TypeConstraint<int32>("Tidx"),                                  \
    991       SparseSegmentReductionSqrtNWithNumSegmentsOp<CPUDevice, type>);
    992 REGISTER_CPU_SPARSE_KERNELS(float);
    993 REGISTER_CPU_SPARSE_KERNELS(double);
    994 #undef REGISTER_CPU_SPARSE_KERNELS
    995 
    996 template <class T>
    997 class SparseSegmentGradOpBase : public OpKernel {
    998  public:
    999   explicit SparseSegmentGradOpBase(OpKernelConstruction* context, bool is_sqrtn)
   1000       : OpKernel(context), is_sqrtn_(is_sqrtn) {}
   1001 
   1002   void Compute(OpKernelContext* context) override {
   1003     const Tensor& input = context->input(0);
   1004     const Tensor& indices = context->input(1);
   1005     const Tensor& segment_ids = context->input(2);
   1006     const Tensor& output_dim0 = context->input(3);
   1007 
   1008     OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
   1009                 errors::InvalidArgument("indices should be a vector."));
   1010     OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
   1011                 errors::InvalidArgument("segment_ids should be a vector."));
   1012     OP_REQUIRES(context, IsLegacyScalar(output_dim0.shape()),
   1013                 errors::InvalidArgument("output_dim0 should be a scalar."));
   1014 
   1015     const int64 N = indices.NumElements();
   1016     OP_REQUIRES(context, N == segment_ids.NumElements(),
   1017                 errors::InvalidArgument(
   1018                     "segment_ids and indices should have same size."));
   1019     typedef int32 SegmentId;
   1020     const SegmentId M =
   1021         internal::SubtleMustCopy(output_dim0.scalar<SegmentId>()());
   1022 
   1023     auto input_flat = input.flat_outer_dims<T>();
   1024     typedef int32 Index;
   1025     const auto indices_vec = indices.vec<Index>();
   1026     const auto segment_vec = segment_ids.vec<SegmentId>();
   1027 
   1028     TensorShape output_shape = input.shape();
   1029     output_shape.set_dim(0, M);
   1030     Tensor* output = nullptr;
   1031     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
   1032     if (M == 0 || N == 0) return;
   1033 
   1034     // Note that similar to SparseSegmentMean, we assume that segment_vec is
   1035     // already sorted and has non-negative values.
   1036     const SegmentId num_segments = input.dim_size(0);
   1037     const SegmentId last_segment_id_plus_one =
   1038         internal::SubtleMustCopy(segment_vec(N - 1)) + 1;
   1039     OP_REQUIRES(context, last_segment_id_plus_one <= num_segments,
   1040                 errors::InvalidArgument("Invalid number of segments"));
   1041 
   1042     // Compute scaling factors for input.
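             // Each segment id's scale is 1 / max(count, 1) for the mean
             // gradient and 1 / sqrt(max(count, 1)) for the sqrt(N) gradient,
             // where count is the number of occurrences of the id in
             // segment_vec.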
   1043     std::vector<double> scaling(num_segments, 0.0);
   1044     for (int64 i = 0; i < N; ++i) {
   1045       const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
   1046       OP_REQUIRES(
   1047           context, FastBoundsCheck(idx, num_segments),
   1048           errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
   1049                                   num_segments, ")."));
   1050       scaling[idx] += 1;
   1051     }
   1052     for (size_t i = 0; i < scaling.size(); ++i) {
   1053       if (is_sqrtn_) {
   1054         scaling[i] = 1.0 / sqrt(std::max(scaling[i], 1.0));
   1055       } else {
   1056         scaling[i] = 1.0 / std::max(scaling[i], 1.0);
   1057       }
   1058     }
   1059 
   1060     auto output_flat = output->flat_outer_dims<T>();
   1061     output_flat.setZero();
   1062     std::vector<bool> is_modified(M, false);
   1063 
   1064     for (int64 i = 0; i < N; ++i) {
   1065       const Index output_idx = internal::SubtleMustCopy(indices_vec(i));
   1066       OP_REQUIRES(context, FastBoundsCheck(output_idx, M),
   1067                   errors::InvalidArgument("Index ", output_idx,
   1068                                           " out of range [0, ", M, ")."));
   1069 
   1070       const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
   1071       OP_REQUIRES(
   1072           context, FastBoundsCheck(idx, num_segments),
   1073           errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
   1074                                   num_segments, ")."));
   1075 
   1076       const T scale = static_cast<T>(scaling[idx]);
   1077       if (is_modified[output_idx]) {
   1078         if (scale == 1.0) {
   1079           output_flat.template chip<0>(output_idx) +=
   1080               input_flat.template chip<0>(idx);
   1081         } else {
   1082           output_flat.template chip<0>(output_idx) +=
   1083               input_flat.template chip<0>(idx) * scale;
   1084         }
   1085       } else {
   1086         if (scale == 1.0) {
   1087           output_flat.template chip<0>(output_idx) =
   1088               input_flat.template chip<0>(idx);
   1089         } else {
   1090           output_flat.template chip<0>(output_idx) =
   1091               input_flat.template chip<0>(idx) * scale;
   1092         }
   1093       }
   1094       is_modified[output_idx] = true;
   1095     }
   1096   }
   1097 
   1098  private:
   1099   const bool is_sqrtn_;
   1100 };
   1101 
   1102 template <class T>
   1103 class SparseSegmentMeanGradOp : public SparseSegmentGradOpBase<T> {
   1104  public:
   1105   explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
   1106       : SparseSegmentGradOpBase<T>(context, false /*is_sqrtn*/) {}
   1107 };
   1108 
   1109 template <class T>
   1110 class SparseSegmentSqrtNGradOp : public SparseSegmentGradOpBase<T> {
   1111  public:
   1112   explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context)
   1113       : SparseSegmentGradOpBase<T>(context, true /*is_sqrtn*/) {}
   1114 };
   1115 
   1116 #define REGISTER_CPU_SPARSE_KERNELS(type)                     \
   1117   REGISTER_KERNEL_BUILDER(Name("SparseSegmentMeanGrad")       \
   1118                               .Device(DEVICE_CPU)             \
   1119                               .TypeConstraint<type>("T")      \
   1120                               .TypeConstraint<int32>("Tidx"), \
   1121                           SparseSegmentMeanGradOp<type>);
   1122 REGISTER_CPU_SPARSE_KERNELS(float);
   1123 REGISTER_CPU_SPARSE_KERNELS(double);
   1124 #undef REGISTER_CPU_SPARSE_KERNELS
   1125 
   1126 #define REGISTER_CPU_SPARSE_KERNELS(type)                     \
   1127   REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtNGrad")      \
   1128                               .Device(DEVICE_CPU)             \
   1129                               .TypeConstraint<type>("T")      \
   1130                               .TypeConstraint<int32>("Tidx"), \
   1131                           SparseSegmentSqrtNGradOp<type>);
   1132 REGISTER_CPU_SPARSE_KERNELS(float);
   1133 REGISTER_CPU_SPARSE_KERNELS(double);
   1134 #undef REGISTER_CPU_SPARSE_KERNELS
   1135 }  // namespace tensorflow