/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/util.h"

#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/platform/cuda.h"

using stream_executor::cuda::ScopedActivateExecutorContext;
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Static routines not in the templated class to reduce code size
static void SegmentReductionValidationHelper(OpKernelContext* context,
                                             const Tensor& input,
                                             const Tensor& segment_ids) {
  OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
              errors::InvalidArgument("segment_ids should be a vector."));
  const int64 num_indices = segment_ids.NumElements();
  OP_REQUIRES(context, num_indices == input.dim_size(0),
              errors::InvalidArgument(
                  "segment_ids should be the same size as dimension 0 of"
                  " input."));
}

static bool SegmentReductionDoValidation(OpKernelContext* c,
                                         const Tensor& input,
                                         const Tensor& segment_ids) {
  SegmentReductionValidationHelper(c, input, segment_ids);
  return c->status().ok();
}

// This operator handles reducing segments along the first dimension.
// See core/ops/math_ops.cc for more details.
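//
// For illustration only: SegmentSum with
//   data        = [[1, 2], [3, 4], [5, 6]]
//   segment_ids = [0, 0, 1]
// produces [[1 + 3, 2 + 4], [5, 6]] = [[4, 6], [5, 6]]. segment_ids must be
// sorted, and output row i reduces all input rows whose segment id equals i.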
template <typename Device, class T, class Index, typename Reducer,
          int default_value>
class SegmentReductionOp : public OpKernel {
 public:
  explicit SegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    if (!SegmentReductionDoValidation(context, input, segment_ids)) {
      return;
    }

    const int64 num_indices = segment_ids.NumElements();
    auto input_flat = input.flat_outer_dims<T>();
    const int64 num_col = input_flat.dimension(1);

    const auto segment_vec = segment_ids.vec<Index>();
    // Note that the current implementation assumes that segment_vec values are
    // sorted.
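    // For example, segment_ids = [0, 0, 1, 1, 3] gives output_rows = 4; the
    // empty segment 2 is later filled with default_value.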
    const Index output_rows =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, output_rows);

    // Note that we do not initialize the output buffer with a default value, so
    // we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) return;
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::DenseIndex, 1> dims_to_reduce;
    dims_to_reduce[0] = 0;
#else
    Eigen::IndexList<Eigen::type2index<0> > dims_to_reduce;
#endif
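    // dims_to_reduce selects dimension 0, so reducing an (end - start) x
    // num_col input slice below collapses it into a single output row of
    // num_col elements.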
    Index start = 0, end = 1;

    Index uninitialized_index = 0;  // Index from which the output is not set.
    Index out_index = internal::SubtleMustCopy(segment_vec(start));

    // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it
    // across threads.
    Eigen::DSizes<Eigen::DenseIndex, 1> out_slice_shape(num_col);
    while (end <= num_indices) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      Index next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here.  Verify that the segment ids are growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      // Process segment [start, end)
      const T* in_slice_ptr = &input_flat(start, 0);
      typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
                               Eigen::Unaligned>
          OutT;

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(T(default_value));
      }

      T* out_slice_ptr = &output_flat(out_index, 0);
      OutT out_slice(out_slice_ptr, out_slice_shape);
      // We don't use out_slice.device(context->eigen_device<Device>)
      // because these pieces of work are likely to be very small and
      // the context switching overhead dwarfs any benefit we get from
      // using another thread to do this work.
      if (start == end - 1) {
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, out_slice_shape);
        out_slice = in_slice;
      } else {
        Eigen::DSizes<Eigen::DenseIndex, 2> in_slice_shape(end - start,
                                                           num_col);
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, in_slice_shape);

        out_slice = in_slice.reduce(dims_to_reduce, Reducer());
      }
      if (end >= num_indices) break;
      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
    }
  }
};

#ifdef GOOGLE_CUDA
//  SegmentSumGPUOp is a segment sum operator implemented for GPU only.
//  TODO: This implementation of SegmentSumGPUOp is sometimes slower than
//  its unsorted counterpart (mostly when problem size is small).
//  This is due to the following two main reasons and a cost-effective way
//  to resolve these problems is desirable.
//  1. Sorted segment sum requires a memory transfer from device to host in
//     order to know the size of the output dimension whereas unsorted segment
//     sum receives the size of the output dimension as an input parameter.
//  2. Sorted segment sum is essentially a tiled version of unsorted segment
//     sum and therefore such optimization comes at an inherent cost. However
//     such cost may not be justified when the problem size is small. When to
//     use the tiled version or the untiled version depends on many factors
//     including data alignments, ratio of calculation to memory traffic and
//     obviously, the problem sizes.
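//
//  Accordingly, the implementation below copies the last segment id back to
//  the host to size the output, then finishes allocation and launches the
//  actual reduction from an EventMgr callback once that copy has completed.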
template <class T, class Index>
class SegmentSumGPUOp : public AsyncOpKernel {
 public:
  explicit SegmentSumGPUOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    OP_REQUIRES_ASYNC(
        context, TensorShapeUtils::IsVector(segment_ids.shape()),
        errors::InvalidArgument("segment_ids should be a vector."), done);

    const int64 num_indices = segment_ids.NumElements();
    OP_REQUIRES_ASYNC(
        context, num_indices == input.dim_size(0),
        errors::InvalidArgument(
            "segment_ids should be the same size as dimension 0 of"
            " input."),
        done);

    if (num_indices == 0) {
      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, 0);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);
      done();
      return;
    }

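    // Wrap the device address of the last segment id; its value plus one is
    // the number of output rows, and it is copied to the host below.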
    se::DeviceMemoryBase output_rows_device(
        const_cast<Tensor&>(segment_ids).template flat<Index>().data() +
        (num_indices - 1));
    ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);

    auto stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
                         sizeof(Index))
            .ok(),
        errors::Internal(
            "SegmentSumGPUOp: failed to copy output_rows from device"),
        done);

    functor::SegmentSumFunctor<T, Index> functor_;
    auto create_and_check_output = [context, output_rows_host, &input,
                                    &segment_ids, &functor_, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      Index output_rows = *output_rows_host.data();
      output_rows++;
      OP_REQUIRES_ASYNC(context, output_rows > 0,
                        errors::InvalidArgument("segment ids must be >= 0"),
                        done);

      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, output_rows);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);

      auto output_flat = output->flat_outer_dims<T>();
      auto data_ptr = input.template flat<T>().data();
      auto segment_flat = segment_ids.flat<Index>();
      functor_(context, context->eigen_device<GPUDevice>(), output_rows,
               segment_ids.shape(), segment_flat, input.NumElements(), data_ptr,
               output_flat);

      done();
    };

    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, create_and_check_output);
  }
};
#endif  // GOOGLE_CUDA

#define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
                                    default_value)                   \
  REGISTER_KERNEL_BUILDER(                                           \
      Name(name)                                                     \
          .Device(DEVICE_CPU)                                        \
          .TypeConstraint<type>("T")                                 \
          .TypeConstraint<index_type>("Tindices"),                   \
      SegmentReductionOp<CPUDevice, type, index_type, functor, default_value>)

#define REGISTER_REAL_CPU_KERNELS(type, index_type)                            \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1); \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMin", Eigen::internal::MinReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMax", Eigen::internal::MaxReducer<type>, \
                              type, index_type, 0)
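
// For instance, REGISTER_REAL_CPU_KERNELS(float, int32) registers the CPU
// kernels for SegmentSum, SegmentMean, SegmentProd, SegmentMin and SegmentMax
// with T = float and Tindices = int32.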

#define REGISTER_COMPLEX_CPU_KERNELS(type, index_type)                         \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1);

#define REGISTER_REAL_CPU_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_KERNELS(type, int32);   \
  REGISTER_REAL_CPU_KERNELS(type, int64)

#define REGISTER_COMPLEX_CPU_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_KERNELS(type, int32);   \
  REGISTER_COMPLEX_CPU_KERNELS(type, int64)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_KERNELS_ALL);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
#undef REGISTER_CPU_KERNEL_SEGMENT
#undef REGISTER_REAL_CPU_KERNELS
#undef REGISTER_COMPLEX_CPU_KERNELS
#undef REGISTER_REAL_CPU_KERNELS_ALL
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL

#if GOOGLE_CUDA
#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                  \
  REGISTER_KERNEL_BUILDER(Name("SegmentSum")                           \
                              .Device(DEVICE_GPU)                      \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<index_type>("Tindices"), \
                          SegmentSumGPUOp<type, index_type>)

#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
  REGISTER_GPU_SORTED_KERNELS(type, int32);   \
  REGISTER_GPU_SORTED_KERNELS(type, int64);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
#undef REGISTER_GPU_SORTED_KERNELS
#undef REGISTER_GPU_SORTED_KERNELS_ALL
#endif  // GOOGLE_CUDA

// ____________________________________________________________________________
// Unsorted segment reduction ops.

namespace functor {

// The ReductionFunctor implementation for CPU.
template <typename T, typename Index, typename InitialValueF,
          typename ReductionF>
struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
  void operator()(OpKernelContext* ctx, const Index num_segments,
                  const TensorShape& segment_ids_shape,
                  typename TTypes<Index>::ConstFlat segment_ids,
                  const Index data_size, const T* data,
                  typename TTypes<T, 2>::Tensor output) {
    output.setConstant(InitialValueF()());
    if (data_size == 0) {
      return;
    }
    const int64 N = segment_ids.dimension(0);
    ReductionF reduction;
    auto data_flat = typename TTypes<T, 2>::ConstTensor(data, N, data_size / N);
    for (int64 i = 0; i < N; ++i) {
      Index j = internal::SubtleMustCopy(segment_ids(i));
      if (j < 0) {
        continue;
      }
      OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments),
                  errors::InvalidArgument(
                      "segment_ids", SliceDebugString(segment_ids_shape, i),
                      " = ", j, " is out of range [0, ", num_segments, ")"));
      reduction(data_flat.template chip<0>(i), output.template chip<0>(j));
    }
  }
};
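
// For illustration only: UnsortedSegmentSum on the CPU with
//   data         = [10, 20, 30]
//   segment_ids  = [1, 0, 1]
//   num_segments = 3
// yields [20, 40, 0]: the output is first filled with InitialValueF()() (zero
// for sums), rows with negative segment ids are skipped, and ReductionF folds
// each remaining data row into output row segment_ids(i).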

template <typename T>
using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes<T, 2>::Matrix>;

template <typename T>
using constMatrixChip =
    Eigen::TensorChippingOp<0l, const typename TTypes<T, 2>::ConstMatrix>;

// reduction functors
template <typename T>
struct SumOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output += data;
  }
};

template <typename T>
struct MaxOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMax(output);
  }
};

template <typename T>
struct MinOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMin(output);
  }
};

template <typename T>
struct ProdOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output *= data;
  }
};
}  // namespace functor

// Static check routines not in the templated class to reduce code size
static void UnsortedSegmentReductionValidation(OpKernel* op_kernel,
                                               OpKernelContext* context,
                                               const Tensor& data,
                                               const Tensor& segment_ids,
                                               const Tensor& num_segments) {
  OP_REQUIRES(
      context, op_kernel->IsLegacyScalar(num_segments.shape()),
      errors::InvalidArgument("num_segments should be a scalar, not shape ",
                              num_segments.shape().DebugString()));
  OP_REQUIRES(
      context, TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape()),
      errors::InvalidArgument("data.shape = ", data.shape().DebugString(),
                              " does not start with segment_ids.shape = ",
                              segment_ids.shape().DebugString()));
}

static bool UnsortedSegmentReductionDoValidation(OpKernel* op_kernel,
                                                 OpKernelContext* context,
                                                 const Tensor& data,
                                                 const Tensor& segment_ids,
                                                 const Tensor& num_segments) {
  UnsortedSegmentReductionValidation(op_kernel, context, data, segment_ids,
                                     num_segments);
  return context->status().ok();
}

// The UnsortedSegmentReduction OpKernel. The DeviceReductionFunctor
// is the device specific implementation of the reduction. These device
// specific implementations are templated themselves with the corresponding
// initial value functors and reduction functors.
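// For example, the registrations below compose UnsortedSegmentMax from
// functor::Lowest<T> (the initial value) and functor::MaxOp<T> (the per-row
// reduction), so output rows that receive no data are left at the lowest
// representable value of T.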
template <typename T, typename Index, typename DeviceReductionFunctor>
class UnsortedSegmentReductionOp : public OpKernel {
 public:
  explicit UnsortedSegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& data = context->input(0);
    const Tensor& segment_ids = context->input(1);
    const Tensor& num_segments = context->input(2);
    if (!UnsortedSegmentReductionDoValidation(this, context, data, segment_ids,
                                              num_segments)) {
      return;
    }
    const auto segment_flat = segment_ids.flat<Index>();
    const Index output_rows =
        internal::SubtleMustCopy(num_segments.scalar<int32>()());
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("Input num_segments == ", output_rows,
                                        " must not be negative."));
    TensorShape output_shape;
    output_shape.AddDim(output_rows);
    for (int i = segment_ids.dims(); i < data.dims(); i++) {
      output_shape.AddDim(data.dim_size(i));
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    auto output_flat = output->flat_outer_dims<T>();
    auto data_ptr = data.template flat<T>().data();
    reduction_functor_(context, output_rows, segment_ids.shape(), segment_flat,
                       data.NumElements(), data_ptr, output_flat);
  }

 protected:
  DeviceReductionFunctor reduction_functor_;
};

#define REGISTER_CPU_KERNEL_UNSORTEDSEGMENT(                           \
    name, type, index_type, initial_value_functor, reduction_functor)  \
  REGISTER_KERNEL_BUILDER(                                             \
      Name(name)                                                       \
          .Device(DEVICE_CPU)                                          \
          .TypeConstraint<type>("T")                                   \
          .TypeConstraint<index_type>("Tindices"),                     \
      UnsortedSegmentReductionOp<                                      \
          type, index_type,                                            \
          functor::UnsortedSegmentFunctor<CPUDevice, type, index_type, \
                                          initial_value_functor,       \
                                          reduction_functor> >)

#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type)                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type,  \
                                      functor::Zero<type>,                     \
                                      functor::SumOp<type>);                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type,  \
                                      functor::Lowest<type>,                   \
                                      functor::MaxOp<type>);                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type,  \
                                      functor::Highest<type>,                  \
                                      functor::MinOp<type>);                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>,                      \
                                      functor::ProdOp<type>);

#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type)                \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type,  \
                                      functor::Zero<type>,                     \
                                      functor::SumOp<type>);                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>,                      \
                                      functor::ProdOp<type>)

#define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32);   \
  REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int64)

#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int32);   \
  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int64)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL);
REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);

#undef REGISTER_REAL_CPU_UNSORTED_KERNELS
#undef REGISTER_CPU_KERNEL_UNSORTEDSEGMENT
#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS
#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL
#undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL

#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL_UNSORTEDSEGMENT(                                 \
    name, type, index_type, initial_value_functor, reduction_kernel_functor) \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name(name)                                                             \
          .Device(DEVICE_GPU)                                                \
          .HostMemory("num_segments")                                        \
          .TypeConstraint<type>("T")                                         \
          .TypeConstraint<index_type>("Tindices"),                           \
      UnsortedSegmentReductionOp<                                            \
          type, index_type,                                                  \
          functor::UnsortedSegmentFunctor<GPUDevice, type, index_type,       \
                                          initial_value_functor,             \
                                          reduction_kernel_functor> >)
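
// .HostMemory("num_segments") keeps the scalar num_segments input in host
// memory, so the output shape can be computed without the device-to-host copy
// that the sorted SegmentSumGPUOp above needs.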

// sum is the only op that supports all input types currently
#define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type)                   \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type,  \
                                      functor::Lowest<type>,                   \
                                      functor::MaxOpGpu<type>);                \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type,  \
                                      functor::Highest<type>,                  \
                                      functor::MinOpGpu<type>);                \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>,                      \
                                      functor::ProdOpGpu<type>);

#define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type)                   \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
                                      functor::Zero<type>,                    \
                                      functor::SumOpGpu<type>);

#define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32);   \
  REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int64);

#define REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int32);   \
  REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int64);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_int32(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_int32(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_complex64(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_complex128(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);

#undef REGISTER_GPU_KERNEL_UNSORTEDSEGMENT
#undef REGISTER_REAL_GPU_UNSORTED_KERNELS
#undef REGISTER_SUM_GPU_UNSORTED_KERNELS
#undef REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL
#undef REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL

#endif  // GOOGLE_CUDA

// ____________________________________________________________________________
// Sparse segment reduction ops.

// Same as SegmentReductionOp but takes as input a "sparse" tensor, represented
// by two dense tensors, one containing the data, and the other containing
// indices into the data.
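//
// Conceptually, SparseSegmentSum(data, indices, segment_ids) is equivalent to
// SegmentSum(Gather(data, indices), segment_ids). For illustration only, with
//   data        = [[1, 2], [3, 4], [5, 6]]
//   indices     = [0, 2]
//   segment_ids = [0, 0]
// the result is [[1 + 5, 2 + 6]] = [[6, 8]]. The Mean and SqrtN variants
// divide each segment by its length or by sqrt(length), respectively.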
template <typename Device, class T>
class SparseSegmentReductionOpBase : public OpKernel {
 public:
  explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
                                        bool is_mean, bool is_sqrtn,
                                        bool has_num_segments, T default_value)
      : OpKernel(context),
        is_mean_(is_mean),
        is_sqrtn_(is_sqrtn),
        has_num_segments_(has_num_segments),
        default_value_(default_value) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);

    Index output_rows = -1;
    if (has_num_segments_) {
      const Tensor& num_segments = context->input(3);

      OP_REQUIRES(
          context, num_segments.shape().dims() == 0,
          errors::InvalidArgument("num_segments should be a scalar, not shape ",
                                  num_segments.shape().DebugString()));
      output_rows = internal::SubtleMustCopy(num_segments.scalar<int32>()());
      OP_REQUIRES(context, output_rows >= 0,
                  errors::InvalidArgument("segment ids must be >= 0"));
    }

    OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
                errors::InvalidArgument("segment_ids should be a vector."));

    const int64 num_indices = indices.NumElements();
    OP_REQUIRES(context, num_indices == segment_ids.NumElements(),
                errors::InvalidArgument(
                    "segment_ids and indices should have same size."));

    auto input_flat = input.flat_outer_dims<T>();
    const int64 num_col = input_flat.dimension(1);
    const auto indices_vec = indices.vec<Index>();
    typedef int32 OutputRow;
    const auto segment_vec = segment_ids.vec<OutputRow>();
    // Note that the current implementation assumes that segment_vec values are
    // sorted.
    const OutputRow last_segment_id_plus_one =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;
    if (has_num_segments_) {
      OP_REQUIRES(
          context, output_rows >= last_segment_id_plus_one,
          errors::InvalidArgument("segment ids must be < num_segments"));
    } else {
      output_rows = last_segment_id_plus_one;
    }
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, output_rows);

    // Note that we do not initialize the output buffer with a default value, so
    // we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) {
      if (output_rows > 0) {
        output->flat_outer_dims<T>().setConstant(default_value_);
      }
      return;
    }
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

    int64 start = 0, end = 1;
    // Index from which the output is not initialized.
    OutputRow uninitialized_index = 0;
    OutputRow out_index = internal::SubtleMustCopy(segment_vec(start));

    while (true) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      OutputRow next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here.  Verify that the segment ids are growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(default_value_);
      }

      auto out = output_flat.template chip<0>(out_index);
      const int bad_offset =
          Reduce(input_flat, indices_vec, start, end - start, out);
      OP_REQUIRES(context, bad_offset < 0,
                  errors::InvalidArgument(
                      "Bad: indices[", start + bad_offset,
                      "] == ", indices_vec(start + bad_offset),
                      " out of range [0, ", input_flat.dimension(0), ")"));

      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
      if (end > num_indices) break;
    }

    // Fill the gap at the end with the default value.
    if (uninitialized_index < output_rows) {
      Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
          output_rows - uninitialized_index, num_col);
      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
          gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
      gap_slice.setConstant(default_value_);
    }
  }

 private:
  typedef int32 Index;

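  // Reduces rows input_flat(indices_vec(start)) ...
  // input_flat(indices_vec(start + num - 1)) into 'out', unrolling the sum in
  // groups of eight rows and applying the mean / sqrt(N) scaling when
  // requested. Returns -1 on success, or the offset (relative to 'start') of
  // the first out-of-range index.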
  int64 Reduce(const typename TTypes<T>::ConstMatrix& input_flat,
               const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
               int64 num,
               Eigen::TensorChippingOp<0, typename TTypes<T>::Matrix> out) {
#define INDEX(n, i)                               \
  const auto index##n = indices_vec(start + (i)); \
  if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);

#define L(n) input_flat.template chip<0>(index##n)

    if (num == 1) {
      INDEX(0, 0);
      out = L(0);
    } else {
      int64 r = num % 8;
      T m(1);
      if (is_mean_ && (num < 10)) {
        m = T(num);
      }
      if (is_sqrtn_ && (num < 10)) {
        m = T(sqrt(num));
      }
      switch (r) {
        case 2: {
          INDEX(0, 0);
          INDEX(1, 1);
          out = (L(0) + L(1)) / m;
          break;
        }
        case 3: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          out = (L(0) + L(1) + L(2)) / m;
          break;
        }
        case 4: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          out = (L(0) + L(1) + L(2) + L(3)) / m;
          break;
        }
        case 5: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          out = (L(0) + L(1) + L(2) + L(3) + L(4)) / m;
          break;
        }
        case 6: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m;
          break;
        }
        case 7: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m;
          break;
        }
        case 0: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m;
          r = 8;
          break;
        }
        case 1: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          INDEX(8, 8);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) /
                m;
          r = 9;
          break;
        }
      }
      for (; r < num; r += 8) {
        INDEX(0, r);
        INDEX(1, r + 1);
        INDEX(2, r + 2);
        INDEX(3, r + 3);
        INDEX(4, r + 4);
        INDEX(5, r + 5);
        INDEX(6, r + 6);
        INDEX(7, r + 7);
        out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
      }
      if (is_mean_ && num >= 10) {
        out = out / static_cast<T>(num);
      }
      if (is_sqrtn_ && num >= 10) {
        out = out / static_cast<T>(sqrt(num));
      }
    }

    return -1;
#undef L
#undef INDEX
  }

  const bool is_mean_;
  const bool is_sqrtn_;
  const bool has_num_segments_;
  const T default_value_;
};

template <typename Device, class T>
class SparseSegmentReductionMeanOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionMeanWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionMeanWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSqrtNOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSqrtNWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSqrtNWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSumOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSumWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSumWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

#define REGISTER_CPU_SPARSE_KERNELS(type)                                \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSum")                       \
                              .Device(DEVICE_CPU)                        \
                              .TypeConstraint<type>("T")                 \
                              .TypeConstraint<int32>("Tidx"),            \
                          SparseSegmentReductionSumOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("SparseSegmentSumWithNumSegments")                            \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<type>("T")                                     \
          .TypeConstraint<int32>("Tidx"),                                \
      SparseSegmentReductionSumWithNumSegmentsOp<CPUDevice, type>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type)                                 \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentMean")                       \
                              .Device(DEVICE_CPU)                         \
                              .TypeConstraint<type>("T")                  \
                              .TypeConstraint<int32>("Tidx"),             \
                          SparseSegmentReductionMeanOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("SparseSegmentMeanWithNumSegments")                            \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<type>("T")                                      \
          .TypeConstraint<int32>("Tidx"),                                 \
      SparseSegmentReductionMeanWithNumSegmentsOp<CPUDevice, type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type)                                  \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtN")                       \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<type>("T")                   \
                              .TypeConstraint<int32>("Tidx"),              \
                          SparseSegmentReductionSqrtNOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("SparseSegmentSqrtNWithNumSegments")                            \
          .Device(DEVICE_CPU)                                              \
          .TypeConstraint<type>("T")                                       \
          .TypeConstraint<int32>("Tidx"),                                  \
      SparseSegmentReductionSqrtNWithNumSegmentsOp<CPUDevice, type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS

template <class T>
class SparseSegmentGradOpBase : public OpKernel {
 public:
  explicit SparseSegmentGradOpBase(OpKernelConstruction* context, bool is_sqrtn)
      : OpKernel(context), is_sqrtn_(is_sqrtn) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);
    const Tensor& output_dim0 = context->input(3);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
                errors::InvalidArgument("segment_ids should be a vector."));
    OP_REQUIRES(context, IsLegacyScalar(output_dim0.shape()),
                errors::InvalidArgument("output_dim0 should be a scalar."));

    const int64 N = indices.NumElements();
    OP_REQUIRES(context, N == segment_ids.NumElements(),
                errors::InvalidArgument(
                    "segment_ids and indices should have same size."));
    typedef int32 SegmentId;
    const SegmentId M =
        internal::SubtleMustCopy(output_dim0.scalar<SegmentId>()());

    auto input_flat = input.flat_outer_dims<T>();
    typedef int32 Index;
    const auto indices_vec = indices.vec<Index>();
    const auto segment_vec = segment_ids.vec<SegmentId>();

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, M);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (M == 0 || N == 0) return;

    // Note that similar to SparseSegmentMean, we assume that segment_vec is
    // already sorted and has non-negative values.
    const SegmentId num_segments = input.dim_size(0);
    const SegmentId last_segment_id_plus_one =
        internal::SubtleMustCopy(segment_vec(N - 1)) + 1;
    OP_REQUIRES(context, last_segment_id_plus_one <= num_segments,
                errors::InvalidArgument("Invalid number of segments"));

    // Compute scaling factors for input.
    std::vector<double> scaling(num_segments, 0.0);
    for (int64 i = 0; i < N; ++i) {
      const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
      OP_REQUIRES(
          context, FastBoundsCheck(idx, num_segments),
          errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                  num_segments, ")."));
      scaling[idx] += 1;
    }
    for (size_t i = 0; i < scaling.size(); ++i) {
      if (is_sqrtn_) {
        scaling[i] = 1.0 / sqrt(std::max(scaling[i], 1.0));
      } else {
        scaling[i] = 1.0 / std::max(scaling[i], 1.0);
      }
    }
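    // For example, with segment_vec = [0, 0, 1] the scaling computed above is
    // [0.5, 1.0, ...]: each gradient row is divided by the number of elements
    // averaged into its segment (or by the square root of that count for
    // SparseSegmentSqrtNGrad).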

    auto output_flat = output->flat_outer_dims<T>();
    output_flat.setZero();
    std::vector<bool> is_modified(M, false);

    for (int64 i = 0; i < N; ++i) {
      const Index output_idx = internal::SubtleMustCopy(indices_vec(i));
      OP_REQUIRES(context, FastBoundsCheck(output_idx, M),
                  errors::InvalidArgument("Index ", output_idx,
                                          " out of range [0, ", M, ")."));

      const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
      OP_REQUIRES(
          context, FastBoundsCheck(idx, num_segments),
          errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                  num_segments, ")."));

      const T scale = static_cast<T>(scaling[idx]);
      if (is_modified[output_idx]) {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx) * scale;
        }
      } else {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx) * scale;
        }
      }
      is_modified[output_idx] = true;
    }
  }

 private:
  const bool is_sqrtn_;
};

template <class T>
class SparseSegmentMeanGradOp : public SparseSegmentGradOpBase<T> {
 public:
  explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<T>(context, false /*is_sqrtn*/) {}
};

template <class T>
class SparseSegmentSqrtNGradOp : public SparseSegmentGradOpBase<T> {
 public:
  explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<T>(context, true /*is_sqrtn*/) {}
};

#define REGISTER_CPU_SPARSE_KERNELS(type)                     \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentMeanGrad")       \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("T")      \
                              .TypeConstraint<int32>("Tidx"), \
                          SparseSegmentMeanGradOp<type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type)                     \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtNGrad")      \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("T")      \
                              .TypeConstraint<int32>("Tidx"), \
                          SparseSegmentSqrtNGradOp<type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS
}  // namespace tensorflow