/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include <vector>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/util.h"

#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/platform/cuda.h"

using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Static routines not in the templated class to reduce code size
static void SegmentReductionValidationHelper(OpKernelContext* context,
                                             const Tensor& input,
                                             const Tensor& segment_ids) {
  OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
              errors::InvalidArgument("segment_ids should be a vector."));
  const int64 num_indices = segment_ids.NumElements();
  OP_REQUIRES(context, num_indices == input.dim_size(0),
              errors::InvalidArgument(
                  "segment_ids should be the same size as dimension 0 of"
                  " input."));
}

static bool SegmentReductionDoValidation(OpKernelContext* c,
                                         const Tensor& input,
                                         const Tensor& segment_ids) {
  SegmentReductionValidationHelper(c, input, segment_ids);
  return c->status().ok();
}

// This operator handles reducing segments along the first dimension.
// See core/ops/math_ops.cc for more details.
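// As an illustration (values chosen for exposition only): with input
// [[1, 2], [3, 4], [5, 6]] and segment_ids [0, 0, 1], SegmentSum produces
// [[4, 6], [5, 6]] -- rows that share a segment id are reduced together, and
// the segment ids are expected to be sorted and non-negative.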
template <typename Device, class T, class Index, typename Reducer,
          int default_value>
class SegmentReductionOp : public OpKernel {
 public:
  explicit SegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    if (!SegmentReductionDoValidation(context, input, segment_ids)) {
      return;
    }

    const int64 num_indices = segment_ids.NumElements();
    auto input_flat = input.flat_outer_dims<T>();
    const int64 num_col = input_flat.dimension(1);

    const auto segment_vec = segment_ids.vec<Index>();
    // Note that the current implementation assumes that segment_vec values are
    // sorted.
    const Index output_rows =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, output_rows);

    // Note that we do not initialize the output buffer with a default value,
    // so we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) return;
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::DenseIndex, 1> dims_to_reduce;
    dims_to_reduce[0] = 0;
#else
    Eigen::IndexList<Eigen::type2index<0> > dims_to_reduce;
#endif
    Index start = 0, end = 1;

    Index uninitialized_index = 0;  // Index from which the output is not set.
    Index out_index = internal::SubtleMustCopy(segment_vec(start));

    // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it
    // across threads.
    Eigen::DSizes<Eigen::DenseIndex, 1> out_slice_shape(num_col);
    while (end <= num_indices) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      Index next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here. Verify that the segment ids are growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      // Process segment [start, end)
      const T* in_slice_ptr = &input_flat(start, 0);
      typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
                               Eigen::Unaligned>
          OutT;

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
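      // For example (an illustrative case, not taken from a test): with
      // segment_ids [0, 0, 3], output rows 1 and 2 are never written by the
      // reduction below, so they are filled with default_value here.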
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                         Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(T(default_value));
      }

      T* out_slice_ptr = &output_flat(out_index, 0);
      OutT out_slice(out_slice_ptr, out_slice_shape);
      // We don't use out_slice.device(context->eigen_device<Device>)
      // because these pieces of work are likely to be very small and
      // the context switching overhead dwarfs any benefit we get from
      // using another thread to do this work.
      if (start == end - 1) {
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, out_slice_shape);
        out_slice = in_slice;
      } else {
        Eigen::DSizes<Eigen::DenseIndex, 2> in_slice_shape(end - start,
                                                           num_col);
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, in_slice_shape);

        out_slice = in_slice.reduce(dims_to_reduce, Reducer());
      }
      if (end >= num_indices) break;
      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
    }
  }
};

#ifdef GOOGLE_CUDA
// SegmentSumGPUOp is a segment sum operator implemented for GPU only.
// TODO: This implementation of SegmentSumGPUOp is sometimes slower than
// its unsorted counterpart (mostly when the problem size is small). This is
// due to the two main reasons below, and a cost-effective way to resolve
// them would be desirable.
// 1. Sorted segment sum requires a memory transfer from device to host in
//    order to know the size of the output dimension, whereas unsorted segment
//    sum receives the size of the output dimension as an input parameter.
// 2. Sorted segment sum is essentially a tiled version of unsorted segment
//    sum, and such an optimization comes at an inherent cost. However, that
//    cost may not be justified when the problem size is small. When to use
//    the tiled version or the untiled version depends on many factors,
//    including data alignment, the ratio of computation to memory traffic
//    and, obviously, the problem sizes.
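// In outline, the kernel below first copies the last (and, because the ids
// are sorted, largest) segment id from device to host to determine the number
// of output rows, and then, once that copy has completed, allocates the
// output and launches the reduction functor from an EventMgr callback.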
template <class T, class Index>
class SegmentSumGPUOp : public AsyncOpKernel {
 public:
  explicit SegmentSumGPUOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    OP_REQUIRES_ASYNC(
        context, TensorShapeUtils::IsVector(segment_ids.shape()),
        errors::InvalidArgument("segment_ids should be a vector."), done);

    const int64 num_indices = segment_ids.NumElements();
    OP_REQUIRES_ASYNC(
        context, num_indices == input.dim_size(0),
        errors::InvalidArgument(
            "segment_ids should be the same size as dimension 0 of"
            " input."),
        done);

    if (num_indices == 0) {
      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, 0);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);
      done();
      return;
    }

    perftools::gputools::DeviceMemoryBase output_rows_device(
        const_cast<Tensor&>(segment_ids).template flat<Index>().data() +
        (num_indices - 1));
    ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);

    auto stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
                         sizeof(Index))
            .ok(),
        errors::Internal(
            "SegmentSumGPUOp: failed to copy output_rows from device"),
        done);

    functor::SegmentSumFunctor<T, Index> functor_;
    auto create_and_check_output = [context, output_rows_host, &input,
                                    &segment_ids, &functor_, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
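      // (This callback is queued via EventMgr::ThenExecute below, so by the
      // time it runs the ThenMemcpy above has completed and output_rows_host
      // holds the last segment id; re-activating the executor context keeps
      // the allocation and kernel launch below on the right device.)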
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      Index output_rows = *output_rows_host.data();
      output_rows++;
      OP_REQUIRES_ASYNC(context, output_rows > 0,
                        errors::InvalidArgument("segment ids must be >= 0"),
                        done);

      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, output_rows);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);

      auto output_flat = output->flat_outer_dims<T>();
      auto data_ptr = input.template flat<T>().data();
      auto segment_flat = segment_ids.flat<Index>();
      functor_(context, context->eigen_device<GPUDevice>(), output_rows,
               segment_ids.shape(), segment_flat, input.NumElements(),
               data_ptr, output_flat);

      done();
    };

    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, create_and_check_output);
  }
};
#endif  // GOOGLE_CUDA

#define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
                                    default_value)                   \
  REGISTER_KERNEL_BUILDER(                                           \
      Name(name)                                                     \
          .Device(DEVICE_CPU)                                        \
          .TypeConstraint<type>("T")                                 \
          .TypeConstraint<index_type>("Tindices"),                   \
      SegmentReductionOp<CPUDevice, type, index_type, functor, default_value>)

#define REGISTER_REAL_CPU_KERNELS(type, index_type)                            \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1); \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMin", Eigen::internal::MinReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMax", Eigen::internal::MaxReducer<type>, \
                              type, index_type, 0)

#define REGISTER_COMPLEX_CPU_KERNELS(type, index_type)                         \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0);                            \
  REGISTER_CPU_KERNEL_SEGMENT(                                                 \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1)

#define REGISTER_REAL_CPU_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_KERNELS(type, int32);   \
  REGISTER_REAL_CPU_KERNELS(type, int64)

#define REGISTER_COMPLEX_CPU_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_KERNELS(type, int32);   \
  REGISTER_COMPLEX_CPU_KERNELS(type, int64)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_KERNELS_ALL);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
#undef REGISTER_CPU_KERNEL_SEGMENT
#undef REGISTER_REAL_CPU_KERNELS
#undef REGISTER_COMPLEX_CPU_KERNELS
#undef REGISTER_REAL_CPU_KERNELS_ALL
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL

#if GOOGLE_CUDA
#define REGISTER_GPU_SORTED_KERNELS(type, index_type)                  \
  REGISTER_KERNEL_BUILDER(Name("SegmentSum")                           \
                              .Device(DEVICE_GPU)                      \
                              .TypeConstraint<type>("T")               \
                              .TypeConstraint<index_type>("Tindices"), \
                          SegmentSumGPUOp<type, index_type>)

#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
  REGISTER_GPU_SORTED_KERNELS(type, int32);   \
  REGISTER_GPU_SORTED_KERNELS(type, int64);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
#undef REGISTER_GPU_SORTED_KERNELS
#undef REGISTER_GPU_SORTED_KERNELS_ALL
#endif  // GOOGLE_CUDA

// ____________________________________________________________________________
// Unsorted segment reduction ops.

namespace functor {

// The UnsortedSegmentFunctor implementation for CPU.
template <typename T, typename Index, typename InitialValueF,
          typename ReductionF>
struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
  void operator()(OpKernelContext* ctx, const Index num_segments,
                  const TensorShape& segment_ids_shape,
                  typename TTypes<Index>::ConstFlat segment_ids,
                  const Index data_size, const T* data,
                  typename TTypes<T, 2>::Tensor output) {
    output.setConstant(InitialValueF()());
    if (data_size == 0) {
      return;
    }
    const int64 N = segment_ids.dimension(0);
    ReductionF reduction;
    auto data_flat =
        typename TTypes<T, 2>::ConstTensor(data, N, data_size / N);
    for (int64 i = 0; i < N; ++i) {
      Index j = internal::SubtleMustCopy(segment_ids(i));
      if (j < 0) {
        continue;
      }
      OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments),
                  errors::InvalidArgument(
                      "segment_ids", SliceDebugString(segment_ids_shape, i),
                      " = ", j, " is out of range [0, ", num_segments, ")"));
      reduction(data_flat.template chip<0>(i), output.template chip<0>(j));
    }
  }
};

template <typename T>
using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes<T, 2>::Matrix>;

template <typename T>
using constMatrixChip =
    Eigen::TensorChippingOp<0l, const typename TTypes<T, 2>::ConstMatrix>;

// Reduction functors.
template <typename T>
struct SumOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output += data;
  }
};

template <typename T>
struct MaxOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMax(output);
  }
};

template <typename T>
struct MinOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMin(output);
  }
};

template <typename T>
struct ProdOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output *= data;
  }
};
}  // namespace functor

// Static check routines not in the templated class to reduce code size
static void UnsortedSegmentReductionValidation(OpKernel* op_kernel,
                                               OpKernelContext* context,
                                               const Tensor& data,
                                               const Tensor& segment_ids,
                                               const Tensor& num_segments) {
  OP_REQUIRES(
      context, op_kernel->IsLegacyScalar(num_segments.shape()),
      errors::InvalidArgument("num_segments should be a scalar, not shape ",
                              num_segments.shape().DebugString()));
  OP_REQUIRES(
      context, TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape()),
      errors::InvalidArgument("data.shape = ", data.shape().DebugString(),
                              " does not start with segment_ids.shape = ",
                              segment_ids.shape().DebugString()));
}

static bool UnsortedSegmentReductionDoValidation(OpKernel* op_kernel,
                                                 OpKernelContext* context,
                                                 const Tensor& data,
                                                 const Tensor& segment_ids,
                                                 const Tensor& num_segments) {
  UnsortedSegmentReductionValidation(op_kernel, context, data, segment_ids,
                                     num_segments);
  return context->status().ok();
}

// The UnsortedSegmentReduction OpKernel. The DeviceReductionFunctor is the
// device-specific implementation of the reduction. These device-specific
// implementations are themselves templated with the corresponding initial
// value functors and reduction functors.
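// As an illustration (values chosen for exposition only): UnsortedSegmentSum
// with data [[1, 2], [3, 4], [5, 6]], segment_ids [0, 2, 0] and
// num_segments 3 produces [[6, 8], [0, 0], [3, 4]]; segments that receive no
// input keep the initial value (0 for sum), and negative segment ids are
// skipped by the CPU functor above.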
template <typename T, typename Index, typename DeviceReductionFunctor>
class UnsortedSegmentReductionOp : public OpKernel {
 public:
  explicit UnsortedSegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& data = context->input(0);
    const Tensor& segment_ids = context->input(1);
    const Tensor& num_segments = context->input(2);
    if (!UnsortedSegmentReductionDoValidation(this, context, data, segment_ids,
                                              num_segments)) {
      return;
    }
    const auto segment_flat = segment_ids.flat<Index>();
    const Index output_rows =
        internal::SubtleMustCopy(num_segments.scalar<int32>()());
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("Input num_segments == ", output_rows,
                                        " must not be negative."));
    TensorShape output_shape;
    output_shape.AddDim(output_rows);
    for (int i = segment_ids.dims(); i < data.dims(); i++) {
      output_shape.AddDim(data.dim_size(i));
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));
    auto output_flat = output->flat_outer_dims<T>();
    auto data_ptr = data.template flat<T>().data();
    reduction_functor_(context, output_rows, segment_ids.shape(), segment_flat,
                       data.NumElements(), data_ptr, output_flat);
  }

 protected:
  DeviceReductionFunctor reduction_functor_;
};

#define REGISTER_CPU_KERNEL_UNSORTEDSEGMENT(                           \
    name, type, index_type, initial_value_functor, reduction_functor)  \
  REGISTER_KERNEL_BUILDER(                                             \
      Name(name)                                                       \
          .Device(DEVICE_CPU)                                          \
          .TypeConstraint<type>("T")                                   \
          .TypeConstraint<index_type>("Tindices"),                     \
      UnsortedSegmentReductionOp<                                      \
          type, index_type,                                            \
          functor::UnsortedSegmentFunctor<CPUDevice, type, index_type, \
                                          initial_value_functor,       \
                                          reduction_functor> >)

#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type)                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type,  \
                                      functor::Zero<type>,                     \
                                      functor::SumOp<type>);                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type,  \
                                      functor::Lowest<type>,                   \
                                      functor::MaxOp<type>);                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type,  \
                                      functor::Highest<type>,                  \
                                      functor::MinOp<type>);                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>,                      \
                                      functor::ProdOp<type>);

#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type)                \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type,  \
                                      functor::Zero<type>,                     \
                                      functor::SumOp<type>);                   \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>,                      \
                                      functor::ProdOp<type>)

#define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32);   \
  REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int64)

#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int32);   \
  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int64)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL);
REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);

#undef REGISTER_REAL_CPU_UNSORTED_KERNELS
#undef REGISTER_CPU_KERNEL_UNSORTEDSEGMENT
#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS
#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL
#undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL

#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL_UNSORTEDSEGMENT(                                 \
    name, type, index_type, initial_value_functor, reduction_kernel_functor) \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name(name)                                                             \
          .Device(DEVICE_GPU)                                                \
          .HostMemory("num_segments")                                        \
          .TypeConstraint<type>("T")                                         \
          .TypeConstraint<index_type>("Tindices"),                           \
      UnsortedSegmentReductionOp<                                            \
          type, index_type,                                                  \
          functor::UnsortedSegmentFunctor<GPUDevice, type, index_type,       \
                                          initial_value_functor,             \
                                          reduction_kernel_functor> >)

// Sum is currently the only op that supports all input types.
#define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type)                   \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type,  \
                                      functor::Lowest<type>,                   \
                                      functor::MaxOpGpu<type>);                \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type,  \
                                      functor::Highest<type>,                  \
                                      functor::MinOpGpu<type>);                \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>,                      \
                                      functor::ProdOpGpu<type>);

#define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type)                   \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
                                      functor::Zero<type>,                    \
                                      functor::SumOpGpu<type>);

#define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32);   \
  REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int64);

#define REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int32);   \
  REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int64);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_int32(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_int32(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_complex64(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_complex128(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);

#undef REGISTER_GPU_KERNEL_UNSORTEDSEGMENT
#undef REGISTER_REAL_GPU_UNSORTED_KERNELS
#undef REGISTER_SUM_GPU_UNSORTED_KERNELS
#undef REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL
#undef REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL

#endif  // GOOGLE_CUDA

// ____________________________________________________________________________
// Sparse segment reduction ops.

// Same as SegmentReductionOp but takes as input a "sparse" tensor, represented
// by two dense tensors, one containing the data, and the other containing
// indices into the data.
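// As an illustration (values chosen for exposition only): SparseSegmentSum
// with data [[1, 2], [3, 4], [5, 6]], indices [0, 2, 1] and segment_ids
// [0, 0, 1] first gathers rows 0, 2 and 1 of the data and then sums them by
// segment, giving [[6, 8], [3, 4]].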
template <typename Device, class T>
class SparseSegmentReductionOpBase : public OpKernel {
 public:
  explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
                                        bool is_mean, bool is_sqrtn,
                                        bool has_num_segments, T default_value)
      : OpKernel(context),
        is_mean_(is_mean),
        is_sqrtn_(is_sqrtn),
        has_num_segments_(has_num_segments),
        default_value_(default_value) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);

    Index output_rows = -1;
    if (has_num_segments_) {
      const Tensor& num_segments = context->input(3);

      OP_REQUIRES(
          context, num_segments.shape().dims() == 0,
          errors::InvalidArgument("num_segments should be a scalar, not shape ",
                                  num_segments.shape().DebugString()));
      output_rows = internal::SubtleMustCopy(num_segments.scalar<int32>()());
      OP_REQUIRES(context, output_rows >= 0,
                  errors::InvalidArgument("segment ids must be >= 0"));
    }

    OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
                errors::InvalidArgument("segment_ids should be a vector."));

    const int64 num_indices = indices.NumElements();
    OP_REQUIRES(context, num_indices == segment_ids.NumElements(),
                errors::InvalidArgument(
                    "segment_ids and indices should have same size."));

    auto input_flat = input.flat_outer_dims<T>();
    const int64 num_col = input_flat.dimension(1);
    const auto indices_vec = indices.vec<Index>();
    typedef int32 OutputRow;
    const auto segment_vec = segment_ids.vec<OutputRow>();
    // Note that the current implementation assumes that segment_vec values are
    // sorted.
    const OutputRow last_segment_id_plus_one =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;
    if (has_num_segments_) {
      OP_REQUIRES(
          context, output_rows >= last_segment_id_plus_one,
          errors::InvalidArgument("segment ids must be < num_segments"));
    } else {
      output_rows = last_segment_id_plus_one;
    }
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, output_rows);

    // Note that we do not initialize the output buffer with a default value,
    // so we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) return;
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

    int64 start = 0, end = 1;
    // Index from which the output is not initialized.
    OutputRow uninitialized_index = 0;
    OutputRow out_index = internal::SubtleMustCopy(segment_vec(start));

    while (true) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      OutputRow next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here. Verify that the segment ids are growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                         Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(default_value_);
      }

      auto out = output_flat.template chip<0>(out_index);
      const int bad_offset =
          Reduce(input_flat, indices_vec, start, end - start, out);
      OP_REQUIRES(context, bad_offset < 0,
                  errors::InvalidArgument(
                      "Bad: indices[", start + bad_offset,
                      "] == ", indices_vec(start + bad_offset),
                      " out of range [0, ", input_flat.dimension(0), ")"));

      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
      if (end > num_indices) break;
    }

    // Fill the gap at the end with the default value.
    if (uninitialized_index < output_rows) {
      Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
          output_rows - uninitialized_index, num_col);
      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
          gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
      gap_slice.setConstant(default_value_);
    }
  }

 private:
  typedef int32 Index;

  // Reduces the `num` rows of `input_flat` selected by
  // indices_vec[start .. start + num) into `out`, unrolled in blocks of 8 and
  // applying the mean / sqrt(n) scaling when requested. Returns the offset of
  // the first out-of-range index, or -1 if all indices were in bounds.
  int64 Reduce(const typename TTypes<T>::ConstMatrix& input_flat,
               const typename TTypes<Index>::ConstVec& indices_vec, int64 start,
               int64 num,
               Eigen::TensorChippingOp<0, typename TTypes<T>::Matrix> out) {
#define INDEX(n, i)                               \
  const auto index##n = indices_vec(start + (i)); \
  if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);

#define L(n) input_flat.template chip<0>(index##n)

    if (num == 1) {
      INDEX(0, 0);
      out = L(0);
    } else {
      int64 r = num % 8;
      T m(1);
      if (is_mean_ && (num < 10)) {
        m = T(num);
      }
      if (is_sqrtn_ && (num < 10)) {
        m = T(sqrt(num));
      }
      switch (r) {
        case 2: {
          INDEX(0, 0);
          INDEX(1, 1);
          out = (L(0) + L(1)) / m;
          break;
        }
        case 3: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          out = (L(0) + L(1) + L(2)) / m;
          break;
        }
        case 4: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          out = (L(0) + L(1) + L(2) + L(3)) / m;
          break;
        }
        case 5: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          out = (L(0) + L(1) + L(2) + L(3) + L(4)) / m;
          break;
        }
        case 6: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m;
          break;
        }
        case 7: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m;
          break;
        }
        case 0: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m;
          r = 8;
          break;
        }
        case 1: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          INDEX(8, 8);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) /
                m;
          r = 9;
          break;
        }
      }
      for (; r < num; r += 8) {
        INDEX(0, r);
        INDEX(1, r + 1);
        INDEX(2, r + 2);
        INDEX(3, r + 3);
        INDEX(4, r + 4);
        INDEX(5, r + 5);
        INDEX(6, r + 6);
        INDEX(7, r + 7);
        out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
      }
      if (is_mean_ && num >= 10) {
        out = out / static_cast<T>(num);
      }
      if (is_sqrtn_ && num >= 10) {
        out = out / static_cast<T>(sqrt(num));
      }
    }

    return -1;
#undef L
#undef INDEX
  }

  const bool is_mean_;
  const bool is_sqrtn_;
  const bool has_num_segments_;
  const T default_value_;
};

template <typename Device, class T>
class SparseSegmentReductionMeanOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionMeanWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionMeanWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSqrtNOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSqrtNWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSqrtNWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSumOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSumWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSumWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

#define REGISTER_CPU_SPARSE_KERNELS(type)                                \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSum")                       \
                              .Device(DEVICE_CPU)                        \
                              .TypeConstraint<type>("T")                 \
                              .TypeConstraint<int32>("Tidx"),            \
                          SparseSegmentReductionSumOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("SparseSegmentSumWithNumSegments")                            \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<type>("T")                                     \
          .TypeConstraint<int32>("Tidx"),                                \
      SparseSegmentReductionSumWithNumSegmentsOp<CPUDevice, type>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type)                                 \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentMean")                       \
                              .Device(DEVICE_CPU)                         \
                              .TypeConstraint<type>("T")                  \
                              .TypeConstraint<int32>("Tidx"),             \
                          SparseSegmentReductionMeanOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("SparseSegmentMeanWithNumSegments")                            \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<type>("T")                                      \
          .TypeConstraint<int32>("Tidx"),                                 \
      SparseSegmentReductionMeanWithNumSegmentsOp<CPUDevice, type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type)                                  \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtN")                       \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<type>("T")                   \
                              .TypeConstraint<int32>("Tidx"),              \
                          SparseSegmentReductionSqrtNOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("SparseSegmentSqrtNWithNumSegments")                            \
          .Device(DEVICE_CPU)                                              \
          .TypeConstraint<type>("T")                                       \
          .TypeConstraint<int32>("Tidx"),                                  \
      SparseSegmentReductionSqrtNWithNumSegmentsOp<CPUDevice, type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS

template <class T>
class SparseSegmentGradOpBase : public OpKernel {
 public:
  explicit SparseSegmentGradOpBase(OpKernelConstruction* context,
                                   bool is_sqrtn)
      : OpKernel(context), is_sqrtn_(is_sqrtn) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);
    const Tensor& output_dim0 = context->input(3);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
                errors::InvalidArgument("segment_ids should be a vector."));
    OP_REQUIRES(context, IsLegacyScalar(output_dim0.shape()),
                errors::InvalidArgument("output_dim0 should be a scalar."));

    const int64 N = indices.NumElements();
    OP_REQUIRES(context, N == segment_ids.NumElements(),
                errors::InvalidArgument(
                    "segment_ids and indices should have same size."));
    typedef int32 SegmentId;
    const SegmentId M =
        internal::SubtleMustCopy(output_dim0.scalar<SegmentId>()());

    auto input_flat = input.flat_outer_dims<T>();
    typedef int32 Index;
    const auto indices_vec = indices.vec<Index>();
    const auto segment_vec = segment_ids.vec<SegmentId>();

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, M);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));
    if (M == 0 || N == 0) return;

    // Note that similar to SparseSegmentMean, we assume that segment_vec is
    // already sorted and has non-negative values.
    const SegmentId num_segments = input.dim_size(0);
    const SegmentId last_segment_id_plus_one =
        internal::SubtleMustCopy(segment_vec(N - 1)) + 1;
    OP_REQUIRES(context, last_segment_id_plus_one <= num_segments,
                errors::InvalidArgument("Invalid number of segments"));

    // Compute scaling factors for input.
    std::vector<double> scaling(num_segments, 0.0);
    for (int64 i = 0; i < N; ++i) {
      const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
      OP_REQUIRES(
          context, FastBoundsCheck(idx, num_segments),
          errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                  num_segments, ")."));
      scaling[idx] += 1;
    }
    for (size_t i = 0; i < scaling.size(); ++i) {
      if (is_sqrtn_) {
        scaling[i] = 1.0 / sqrt(std::max(scaling[i], 1.0));
      } else {
        scaling[i] = 1.0 / std::max(scaling[i], 1.0);
      }
    }

    auto output_flat = output->flat_outer_dims<T>();
    output_flat.setZero();
    std::vector<bool> is_modified(M, false);

    for (int64 i = 0; i < N; ++i) {
      const Index output_idx = internal::SubtleMustCopy(indices_vec(i));
      OP_REQUIRES(context, FastBoundsCheck(output_idx, M),
                  errors::InvalidArgument("Index ", output_idx,
                                          " out of range [0, ", M, ")."));

      const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
      OP_REQUIRES(
          context, FastBoundsCheck(idx, num_segments),
          errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                  num_segments, ")."));

      const T scale = static_cast<T>(scaling[idx]);
      if (is_modified[output_idx]) {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx) * scale;
        }
      } else {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx) * scale;
        }
      }
      is_modified[output_idx] = true;
    }
  }

 private:
  const bool is_sqrtn_;
};

template <class T>
class SparseSegmentMeanGradOp : public SparseSegmentGradOpBase<T> {
 public:
  explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<T>(context, false /*is_sqrtn*/) {}
};

template <class T>
class SparseSegmentSqrtNGradOp : public SparseSegmentGradOpBase<T> {
 public:
  explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<T>(context, true /*is_sqrtn*/) {}
};

#define REGISTER_CPU_SPARSE_KERNELS(type)                     \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentMeanGrad")       \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("T")      \
                              .TypeConstraint<int32>("Tidx"), \
                          SparseSegmentMeanGradOp<type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type)                     \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtNGrad")      \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("T")      \
                              .TypeConstraint<int32>("Tidx"), \
                          SparseSegmentSqrtNGradOp<type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS
}  // namespace tensorflow