/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/util.h"

#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/platform/cuda.h"

using stream_executor::cuda::ScopedActivateExecutorContext;
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Static routines not in the templated class to reduce code size.
static void SegmentReductionValidationHelper(OpKernelContext* context,
                                             const Tensor& input,
                                             const Tensor& segment_ids) {
  OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
              errors::InvalidArgument("segment_ids should be a vector."));
  const int64 num_indices = segment_ids.NumElements();
  OP_REQUIRES(context, num_indices == input.dim_size(0),
              errors::InvalidArgument(
                  "segment_ids should be the same size as dimension 0 of"
                  " input."));
}

static bool SegmentReductionDoValidation(OpKernelContext* c,
                                         const Tensor& input,
                                         const Tensor& segment_ids) {
  SegmentReductionValidationHelper(c, input, segment_ids);
  return c->status().ok();
}

// This operator handles reducing segments along the first dimension.
// See core/ops/math_ops.cc for more details.
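// Illustrative example (added comment, not from the original source): for
// SegmentSum with input = [[1, 2], [3, 4], [5, 6]] and segment_ids = [0, 0, 1],
// the output is [[1 + 3, 2 + 4], [5, 6]] = [[4, 6], [5, 6]]. segment_ids must
// be sorted, and the number of output rows is the last segment id plus one.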
template <typename Device, class T, class Index, typename Reducer,
          int default_value>
class SegmentReductionOp : public OpKernel {
 public:
  explicit SegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    if (!SegmentReductionDoValidation(context, input, segment_ids)) {
      return;
    }

    const int64 num_indices = segment_ids.NumElements();
    auto input_flat = input.flat_outer_dims<T>();
    const int64 num_col = input_flat.dimension(1);

    const auto segment_vec = segment_ids.vec<Index>();
    // Note that the current implementation assumes that segment_vec values are
    // sorted.
    const Index output_rows =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, output_rows);

    // Note that we do not initialize the output buffer with a default value,
    // so we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) return;
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::DenseIndex, 1> dims_to_reduce;
    dims_to_reduce[0] = 0;
#else
    Eigen::IndexList<Eigen::type2index<0> > dims_to_reduce;
#endif
    Index start = 0, end = 1;

    Index uninitialized_index = 0;  // Index from which the output is not set.
    Index out_index = internal::SubtleMustCopy(segment_vec(start));

    // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it
    // across threads.
    Eigen::DSizes<Eigen::DenseIndex, 1> out_slice_shape(num_col);
    while (end <= num_indices) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      Index next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here. Verify that the segment ids are growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      // Process segment [start, end).
      const T* in_slice_ptr = &input_flat(start, 0);
      typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
                               Eigen::Unaligned>
          OutT;

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
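      // For example (added comment, not in the original source): with sorted
      // segment_ids = [0, 0, 3], output rows 1 and 2 are never written by the
      // reduction below, so they are filled with T(default_value) here
      // (0 for Sum/Mean/Min/Max, 1 for Prod, per the registrations at the
      // bottom of this file).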
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                         Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(T(default_value));
      }

      T* out_slice_ptr = &output_flat(out_index, 0);
      OutT out_slice(out_slice_ptr, out_slice_shape);
      // We don't use out_slice.device(context->eigen_device<Device>())
      // because these pieces of work are likely to be very small and
      // the context switching overhead dwarfs any benefit we get from
      // using another thread to do this work.
      if (start == end - 1) {
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, out_slice_shape);
        out_slice = in_slice;
      } else {
        Eigen::DSizes<Eigen::DenseIndex, 2> in_slice_shape(end - start,
                                                           num_col);
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, in_slice_shape);

        out_slice = in_slice.reduce(dims_to_reduce, Reducer());
      }
      if (end >= num_indices) break;
      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
    }
  }
};

#ifdef GOOGLE_CUDA
// SegmentSumGPUOp is a segment sum operator implemented for GPU only.
// TODO: This implementation of SegmentSumGPUOp is sometimes slower than its
// unsorted counterpart (mostly when the problem size is small); a
// cost-effective way to resolve this is desirable. There are two main reasons:
// 1. Sorted segment sum requires a memory transfer from device to host in
//    order to know the size of the output dimension, whereas unsorted segment
//    sum receives the size of the output dimension as an input parameter.
// 2. Sorted segment sum is essentially a tiled version of unsorted segment
//    sum, and such optimization comes at an inherent cost. That cost may not
//    be justified when the problem size is small. Whether to use the tiled or
//    the untiled version depends on many factors, including data alignment,
//    the ratio of computation to memory traffic and, obviously, the problem
//    size.
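// Added note: the implementation below copies the last element of
// `segment_ids` from device to host to determine the number of output rows,
// then allocates the output and invokes functor::SegmentSumFunctor from an
// EventMgr callback once that copy has completed.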
template <class T, class Index>
class SegmentSumGPUOp : public AsyncOpKernel {
 public:
  explicit SegmentSumGPUOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    OP_REQUIRES_ASYNC(
        context, TensorShapeUtils::IsVector(segment_ids.shape()),
        errors::InvalidArgument("segment_ids should be a vector."), done);

    const int64 num_indices = segment_ids.NumElements();
    OP_REQUIRES_ASYNC(
        context, num_indices == input.dim_size(0),
        errors::InvalidArgument(
            "segment_ids should be the same size as dimension 0 of"
            " input."),
        done);

    if (num_indices == 0) {
      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, 0);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);
      done();
      return;
    }

    se::DeviceMemoryBase output_rows_device(
        const_cast<Tensor&>(segment_ids).template flat<Index>().data() +
        (num_indices - 1));
    ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);

    auto stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
                         sizeof(Index))
            .ok(),
        errors::Internal(
            "SegmentSumGPUOp: failed to copy output_rows from device"),
        done);

    functor::SegmentSumFunctor<T, Index> functor_;
    auto create_and_check_output = [context, output_rows_host, &input,
                                    &segment_ids, &functor_, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      Index output_rows = *output_rows_host.data();
      output_rows++;
      OP_REQUIRES_ASYNC(context, output_rows > 0,
                        errors::InvalidArgument("segment ids must be >= 0"),
                        done);

      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, output_rows);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);

      auto output_flat = output->flat_outer_dims<T>();
      auto data_ptr = input.template flat<T>().data();
      auto segment_flat = segment_ids.flat<Index>();
      functor_(context, context->eigen_device<GPUDevice>(), output_rows,
               segment_ids.shape(), segment_flat, input.NumElements(),
               data_ptr, output_flat);

      done();
    };

    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, create_and_check_output);
  }
};
#endif  // GOOGLE_CUDA

#define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
                                    default_value) \
  REGISTER_KERNEL_BUILDER( \
      Name(name) \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tindices"), \
      SegmentReductionOp<CPUDevice, type, index_type, functor, default_value>)

#define REGISTER_REAL_CPU_KERNELS(type, index_type) \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT( \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT( \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1); \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMin", Eigen::internal::MinReducer<type>, \
                              type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentMax", Eigen::internal::MaxReducer<type>, \
                              type, index_type, 0)

#define REGISTER_COMPLEX_CPU_KERNELS(type, index_type) \
  REGISTER_CPU_KERNEL_SEGMENT("SegmentSum", Eigen::internal::SumReducer<type>, \
                              type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT( \
      "SegmentMean", Eigen::internal::MeanReducer<type>, type, index_type, 0); \
  REGISTER_CPU_KERNEL_SEGMENT( \
      "SegmentProd", Eigen::internal::ProdReducer<type>, type, index_type, 1);

#define REGISTER_REAL_CPU_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_KERNELS(type, int32); \
  REGISTER_REAL_CPU_KERNELS(type, int64)

#define REGISTER_COMPLEX_CPU_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_KERNELS(type, int32); \
  REGISTER_COMPLEX_CPU_KERNELS(type, int64)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_KERNELS_ALL);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
#undef REGISTER_CPU_KERNEL_SEGMENT
#undef REGISTER_REAL_CPU_KERNELS
#undef REGISTER_COMPLEX_CPU_KERNELS
#undef REGISTER_REAL_CPU_KERNELS_ALL
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL

#if GOOGLE_CUDA
#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
  REGISTER_KERNEL_BUILDER(Name("SegmentSum") \
                              .Device(DEVICE_GPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<index_type>("Tindices"), \
                          SegmentSumGPUOp<type, index_type>)

#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
  REGISTER_GPU_SORTED_KERNELS(type, int32); \
  REGISTER_GPU_SORTED_KERNELS(type, int64);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
#undef REGISTER_GPU_SORTED_KERNELS
#undef REGISTER_GPU_SORTED_KERNELS_ALL
#endif  // GOOGLE_CUDA

// ____________________________________________________________________________
// Unsorted segment reduction ops.

namespace functor {

// The ReductionFunctor implementation for CPU.
template <typename T, typename Index, typename InitialValueF,
          typename ReductionF>
struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
  void operator()(OpKernelContext* ctx, const Index num_segments,
                  const TensorShape& segment_ids_shape,
                  typename TTypes<Index>::ConstFlat segment_ids,
                  const Index data_size, const T* data,
                  typename TTypes<T, 2>::Tensor output) {
    output.setConstant(InitialValueF()());
    if (data_size == 0) {
      return;
    }
    const int64 N = segment_ids.dimension(0);
    ReductionF reduction;
    auto data_flat = typename TTypes<T, 2>::ConstTensor(data, N, data_size / N);
    for (int64 i = 0; i < N; ++i) {
      Index j = internal::SubtleMustCopy(segment_ids(i));
      if (j < 0) {
        continue;
      }
      OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments),
                  errors::InvalidArgument(
                      "segment_ids", SliceDebugString(segment_ids_shape, i),
                      " = ", j, " is out of range [0, ", num_segments, ")"));
      reduction(data_flat.template chip<0>(i), output.template chip<0>(j));
    }
  }
};

template <typename T>
using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes<T, 2>::Matrix>;

template <typename T>
using constMatrixChip =
    Eigen::TensorChippingOp<0l, const typename TTypes<T, 2>::ConstMatrix>;

// Reduction functors.
template <typename T>
struct SumOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output += data;
  }
};

template <typename T>
struct MaxOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMax(output);
  }
};

template <typename T>
struct MinOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMin(output);
  }
};

template <typename T>
struct ProdOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output *= data;
  }
};
}  // namespace functor

// Static check routines not in the templated class to reduce code size.
static void UnsortedSegmentReductionValidation(OpKernel* op_kernel,
                                               OpKernelContext* context,
                                               const Tensor& data,
                                               const Tensor& segment_ids,
                                               const Tensor& num_segments) {
  OP_REQUIRES(
      context, op_kernel->IsLegacyScalar(num_segments.shape()),
      errors::InvalidArgument("num_segments should be a scalar, not shape ",
                              num_segments.shape().DebugString()));
  OP_REQUIRES(
      context, TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape()),
      errors::InvalidArgument("data.shape = ", data.shape().DebugString(),
                              " does not start with segment_ids.shape = ",
                              segment_ids.shape().DebugString()));
}

static bool UnsortedSegmentReductionDoValidation(OpKernel* op_kernel,
                                                 OpKernelContext* context,
                                                 const Tensor& data,
                                                 const Tensor& segment_ids,
                                                 const Tensor& num_segments) {
  UnsortedSegmentReductionValidation(op_kernel, context, data, segment_ids,
                                     num_segments);
  return context->status().ok();
}

// The UnsortedSegmentReduction OpKernel. The DeviceReductionFunctor
// is the device-specific implementation of the reduction. These
// device-specific implementations are themselves templated on the
// corresponding initial value functors and reduction functors.
template <typename T, typename Index, typename DeviceReductionFunctor>
class UnsortedSegmentReductionOp : public OpKernel {
 public:
  explicit UnsortedSegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& data = context->input(0);
    const Tensor& segment_ids = context->input(1);
    const Tensor& num_segments = context->input(2);
    if (!UnsortedSegmentReductionDoValidation(this, context, data, segment_ids,
                                              num_segments)) {
      return;
    }
    const auto segment_flat = segment_ids.flat<Index>();
    const Index output_rows =
        internal::SubtleMustCopy(num_segments.scalar<int32>()());
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("Input num_segments == ", output_rows,
                                        " must not be negative."));
    TensorShape output_shape;
    output_shape.AddDim(output_rows);
    for (int i = segment_ids.dims(); i < data.dims(); i++) {
      output_shape.AddDim(data.dim_size(i));
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    auto output_flat = output->flat_outer_dims<T>();
    auto data_ptr = data.template flat<T>().data();
    reduction_functor_(context, output_rows, segment_ids.shape(), segment_flat,
                       data.NumElements(), data_ptr, output_flat);
  }

 protected:
  DeviceReductionFunctor reduction_functor_;
};

#define REGISTER_CPU_KERNEL_UNSORTEDSEGMENT( \
    name, type, index_type, initial_value_functor, reduction_functor) \
  REGISTER_KERNEL_BUILDER( \
      Name(name) \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tindices"), \
      UnsortedSegmentReductionOp< \
          type, index_type, \
          functor::UnsortedSegmentFunctor<CPUDevice, type, index_type, \
                                          initial_value_functor, \
                                          reduction_functor> >)

#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type) \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
                                      functor::Zero<type>, \
                                      functor::SumOp<type>); \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \
                                      functor::Lowest<type>, \
                                      functor::MaxOp<type>); \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \
                                      functor::Highest<type>, \
                                      functor::MinOp<type>); \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>, \
                                      functor::ProdOp<type>);

#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type) \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
                                      functor::Zero<type>, \
                                      functor::SumOp<type>); \
  REGISTER_CPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>, \
                                      functor::ProdOp<type>)

#define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32); \
  REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int64)

#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int32); \
  REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int64)

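// Illustrative expansion (added comment): for example,
// REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(float) registers the
// "UnsortedSegmentSum", "UnsortedSegmentMax", "UnsortedSegmentMin" and
// "UnsortedSegmentProd" CPU kernels for T = float, with both int32 and int64
// "Tindices".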
TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL);
REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64);
REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);

#undef REGISTER_REAL_CPU_UNSORTED_KERNELS
#undef REGISTER_CPU_KERNEL_UNSORTEDSEGMENT
#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS
#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL
#undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL

#if GOOGLE_CUDA
#define REGISTER_GPU_KERNEL_UNSORTEDSEGMENT( \
    name, type, index_type, initial_value_functor, reduction_kernel_functor) \
  REGISTER_KERNEL_BUILDER( \
      Name(name) \
          .Device(DEVICE_GPU) \
          .HostMemory("num_segments") \
          .TypeConstraint<type>("T") \
          .TypeConstraint<index_type>("Tindices"), \
      UnsortedSegmentReductionOp< \
          type, index_type, \
          functor::UnsortedSegmentFunctor<GPUDevice, type, index_type, \
                                          initial_value_functor, \
                                          reduction_kernel_functor> >)

// Sum is the only op that currently supports all input types.
#define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type) \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \
                                      functor::Lowest<type>, \
                                      functor::MaxOpGpu<type>); \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \
                                      functor::Highest<type>, \
                                      functor::MinOpGpu<type>); \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                      functor::One<type>, \
                                      functor::ProdOpGpu<type>);

#define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type) \
  REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
                                      functor::Zero<type>, \
                                      functor::SumOpGpu<type>);

#define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32); \
  REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int64);

#define REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL(type) \
  REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int32); \
  REGISTER_SUM_GPU_UNSORTED_KERNELS(type, int64);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_int32(REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_int32(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_complex64(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);
TF_CALL_complex128(REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL);

#undef REGISTER_GPU_KERNEL_UNSORTEDSEGMENT
#undef REGISTER_REAL_GPU_UNSORTED_KERNELS
#undef REGISTER_SUM_GPU_UNSORTED_KERNELS
#undef REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL
#undef REGISTER_SUM_GPU_UNSORTED_KERNELS_ALL

#endif  // GOOGLE_CUDA

// ____________________________________________________________________________
// Sparse segment reduction ops.

// Same as SegmentReductionOp but takes as input a "sparse" tensor, represented
// by two dense tensors, one containing the data, and the other containing
// indices into the data.
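// Illustrative example (added comment, not from the original source): for
// SparseSegmentSum with input = [[1, 2], [3, 4], [5, 6]],
// indices = [2, 0, 1] and segment_ids = [0, 0, 1], the output is
// [input[2] + input[0], input[1]] = [[6, 8], [3, 4]]. SparseSegmentMean and
// SparseSegmentSqrtN additionally divide each segment by its size or by
// sqrt(size), respectively.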
template <typename Device, class T>
class SparseSegmentReductionOpBase : public OpKernel {
 public:
  explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
                                        bool is_mean, bool is_sqrtn,
                                        bool has_num_segments, T default_value)
      : OpKernel(context),
        is_mean_(is_mean),
        is_sqrtn_(is_sqrtn),
        has_num_segments_(has_num_segments),
        default_value_(default_value) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);

    Index output_rows = -1;
    if (has_num_segments_) {
      const Tensor& num_segments = context->input(3);

      OP_REQUIRES(
          context, num_segments.shape().dims() == 0,
          errors::InvalidArgument("num_segments should be a scalar, not shape ",
                                  num_segments.shape().DebugString()));
      output_rows = internal::SubtleMustCopy(num_segments.scalar<int32>()());
      OP_REQUIRES(context, output_rows >= 0,
                  errors::InvalidArgument("segment ids must be >= 0"));
    }

    OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
                errors::InvalidArgument("segment_ids should be a vector."));

    const int64 num_indices = indices.NumElements();
    OP_REQUIRES(context, num_indices == segment_ids.NumElements(),
                errors::InvalidArgument(
                    "segment_ids and indices should have same size."));

    auto input_flat = input.flat_outer_dims<T>();
    const int64 num_col = input_flat.dimension(1);
    const auto indices_vec = indices.vec<Index>();
    typedef int32 OutputRow;
    const auto segment_vec = segment_ids.vec<OutputRow>();
    // Note that the current implementation assumes that segment_vec values are
    // sorted.
    const OutputRow last_segment_id_plus_one =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;
    if (has_num_segments_) {
      OP_REQUIRES(
          context, output_rows >= last_segment_id_plus_one,
          errors::InvalidArgument("segment ids must be < num_segments"));
    } else {
      output_rows = last_segment_id_plus_one;
    }
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, output_rows);

    // Note that we do not initialize the output buffer with a default value,
    // so we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) {
      if (output_rows > 0) {
        output->flat_outer_dims<T>().setConstant(default_value_);
      }
      return;
    }
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

    int64 start = 0, end = 1;
    // Index from which the output is not initialized.
    OutputRow uninitialized_index = 0;
    OutputRow out_index = internal::SubtleMustCopy(segment_vec(start));

    while (true) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      OutputRow next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here. Verify that the segment ids are growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                         Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(default_value_);
      }

      auto out = output_flat.template chip<0>(out_index);
      const int bad_offset =
          Reduce(input_flat, indices_vec, start, end - start, out);
      OP_REQUIRES(context, bad_offset < 0,
                  errors::InvalidArgument(
                      "Bad: indices[", start + bad_offset,
                      "] == ", indices_vec(start + bad_offset),
                      " out of range [0, ", input_flat.dimension(0), ")"));

      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
      if (end > num_indices) break;
    }

    // Fill the gap at the end with the default value.
    if (uninitialized_index < output_rows) {
      Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
          output_rows - uninitialized_index, num_col);
      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
          gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
      gap_slice.setConstant(default_value_);
    }
  }

 private:
  typedef int32 Index;

  int64 Reduce(const typename TTypes<T>::ConstMatrix& input_flat,
               const typename TTypes<Index>::ConstVec& indices_vec,
               int64 start, int64 num,
               Eigen::TensorChippingOp<0, typename TTypes<T>::Matrix> out) {
#define INDEX(n, i) \
  const auto index##n = indices_vec(start + (i)); \
  if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);

#define L(n) input_flat.template chip<0>(index##n)

    if (num == 1) {
      INDEX(0, 0);
      out = L(0);
    } else {
      int64 r = num % 8;
      T m(1);
      if (is_mean_ && (num < 10)) {
        m = T(num);
      }
      if (is_sqrtn_ && (num < 10)) {
        m = T(sqrt(num));
      }
      switch (r) {
        case 2: {
          INDEX(0, 0);
          INDEX(1, 1);
          out = (L(0) + L(1)) / m;
          break;
        }
        case 3: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          out = (L(0) + L(1) + L(2)) / m;
          break;
        }
        case 4: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          out = (L(0) + L(1) + L(2) + L(3)) / m;
          break;
        }
        case 5: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          out = (L(0) + L(1) + L(2) + L(3) + L(4)) / m;
          break;
        }
        case 6: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) / m;
          break;
        }
        case 7: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) / m;
          break;
        }
        case 0: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) / m;
          r = 8;
          break;
        }
        case 1: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          INDEX(8, 8);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) +
                 L(8)) /
                m;
          r = 9;
          break;
        }
      }
      for (; r < num; r += 8) {
        INDEX(0, r);
        INDEX(1, r + 1);
        INDEX(2, r + 2);
        INDEX(3, r + 3);
        INDEX(4, r + 4);
        INDEX(5, r + 5);
        INDEX(6, r + 6);
        INDEX(7, r + 7);
        out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
      }
      if (is_mean_ && num >= 10) {
        out = out / static_cast<T>(num);
      }
      if (is_sqrtn_ && num >= 10) {
        out = out / static_cast<T>(sqrt(num));
      }
    }

    return -1;
#undef L
#undef INDEX
  }

  const bool is_mean_;
  const bool is_sqrtn_;
  const bool has_num_segments_;
  const T default_value_;
};

template <typename Device, class T>
class SparseSegmentReductionMeanOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionMeanWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionMeanWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSqrtNOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSqrtNWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSqrtNWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSumOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T>
class SparseSegmentReductionSumWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T> {
 public:
  explicit SparseSegmentReductionSumWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

#define REGISTER_CPU_SPARSE_KERNELS(type) \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSum") \
                              .Device(DEVICE_CPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<int32>("Tidx"), \
                          SparseSegmentReductionSumOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSumWithNumSegments") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<int32>("Tidx"), \
      SparseSegmentReductionSumWithNumSegmentsOp<CPUDevice, type>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type) \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentMean") \
                              .Device(DEVICE_CPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<int32>("Tidx"), \
                          SparseSegmentReductionMeanOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentMeanWithNumSegments") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<int32>("Tidx"), \
      SparseSegmentReductionMeanWithNumSegmentsOp<CPUDevice, type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type) \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtN") \
                              .Device(DEVICE_CPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<int32>("Tidx"), \
                          SparseSegmentReductionSqrtNOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSegmentSqrtNWithNumSegments") \
          .Device(DEVICE_CPU) \
          .TypeConstraint<type>("T") \
          .TypeConstraint<int32>("Tidx"), \
      SparseSegmentReductionSqrtNWithNumSegmentsOp<CPUDevice, type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS

template <class T>
class SparseSegmentGradOpBase : public OpKernel {
 public:
  explicit SparseSegmentGradOpBase(OpKernelConstruction* context,
                                   bool is_sqrtn)
      : OpKernel(context), is_sqrtn_(is_sqrtn) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);
    const Tensor& output_dim0 = context->input(3);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
                errors::InvalidArgument("segment_ids should be a vector."));
    OP_REQUIRES(context, IsLegacyScalar(output_dim0.shape()),
                errors::InvalidArgument("output_dim0 should be a scalar."));

    const int64 N = indices.NumElements();
    OP_REQUIRES(context, N == segment_ids.NumElements(),
                errors::InvalidArgument(
                    "segment_ids and indices should have same size."));
    typedef int32 SegmentId;
    const SegmentId M =
        internal::SubtleMustCopy(output_dim0.scalar<SegmentId>()());

    auto input_flat = input.flat_outer_dims<T>();
    typedef int32 Index;
    const auto indices_vec = indices.vec<Index>();
    const auto segment_vec = segment_ids.vec<SegmentId>();

    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, M);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (M == 0 || N == 0) return;

    // Note that similar to SparseSegmentMean, we assume that segment_vec is
    // already sorted and has non-negative values.
    const SegmentId num_segments = input.dim_size(0);
    const SegmentId last_segment_id_plus_one =
        internal::SubtleMustCopy(segment_vec(N - 1)) + 1;
    OP_REQUIRES(context, last_segment_id_plus_one <= num_segments,
                errors::InvalidArgument("Invalid number of segments"));

    // Compute scaling factors for input.
    std::vector<double> scaling(num_segments, 0.0);
    for (int64 i = 0; i < N; ++i) {
      const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
      OP_REQUIRES(
          context, FastBoundsCheck(idx, num_segments),
          errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                  num_segments, ")."));
      scaling[idx] += 1;
    }
    for (size_t i = 0; i < scaling.size(); ++i) {
      if (is_sqrtn_) {
        scaling[i] = 1.0 / sqrt(std::max(scaling[i], 1.0));
      } else {
        scaling[i] = 1.0 / std::max(scaling[i], 1.0);
      }
    }

    auto output_flat = output->flat_outer_dims<T>();
    output_flat.setZero();
    std::vector<bool> is_modified(M, false);

    for (int64 i = 0; i < N; ++i) {
      const Index output_idx = internal::SubtleMustCopy(indices_vec(i));
      OP_REQUIRES(context, FastBoundsCheck(output_idx, M),
                  errors::InvalidArgument("Index ", output_idx,
                                          " out of range [0, ", M, ")."));

      const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
      OP_REQUIRES(
          context, FastBoundsCheck(idx, num_segments),
          errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                  num_segments, ")."));

      const T scale = static_cast<T>(scaling[idx]);
      if (is_modified[output_idx]) {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx) * scale;
        }
      } else {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx) * scale;
        }
      }
      is_modified[output_idx] = true;
    }
  }

 private:
  const bool is_sqrtn_;
};

template <class T>
class SparseSegmentMeanGradOp : public SparseSegmentGradOpBase<T> {
 public:
  explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<T>(context, false /*is_sqrtn*/) {}
};

template <class T>
class SparseSegmentSqrtNGradOp : public SparseSegmentGradOpBase<T> {
 public:
  explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<T>(context, true /*is_sqrtn*/) {}
};

#define REGISTER_CPU_SPARSE_KERNELS(type) \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentMeanGrad") \
                              .Device(DEVICE_CPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<int32>("Tidx"), \
                          SparseSegmentMeanGradOp<type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS

#define REGISTER_CPU_SPARSE_KERNELS(type) \
  REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtNGrad") \
                              .Device(DEVICE_CPU) \
                              .TypeConstraint<type>("T") \
                              .TypeConstraint<int32>("Tidx"), \
                          SparseSegmentSqrtNGradOp<type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS

}  // namespace tensorflow