/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/scatter_nd_op.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace {

template <typename T, scatter_nd_op::UpdateOp Op>
struct LeftUpdate {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val);
};

template <typename T>
struct LeftUpdate<T, scatter_nd_op::UpdateOp::ASSIGN> {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
    *out = val;
  }
};

template <typename T>
struct LeftUpdate<T, scatter_nd_op::UpdateOp::ADD> {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
    CudaAtomicAdd(out, val);
  }
};

template <typename T>
struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(T* out, const T& val) {
    CudaAtomicSub(out, val);
  }
};

// Specializations for std::complex, updating real and imaginary part
// individually. Even though this is not an atomic op anymore, it is safe
// because there is only one type of op per kernel.
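// For example, adding the complex value (a + bi) issues two independent
// atomic adds: one to the real component at reinterpret_cast<T*>(out) and
// one to the imaginary component at reinterpret_cast<T*>(out) + 1.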
template <typename T>
struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD> {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
      std::complex<T>* out, const std::complex<T>& val) {
    T* ptr = reinterpret_cast<T*>(out);
    CudaAtomicAdd(ptr, val.real());
    CudaAtomicAdd(ptr + 1, val.imag());
  }
};

template <typename T>
struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::SUB> {
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
      std::complex<T>* out, const std::complex<T>& val) {
    LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD>()(out, -val);
  }
};

}  // namespace

template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
__global__ void ScatterNdOpKernel(
    const Index* indices, const T* updates, T* out,
    const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
    const Eigen::array<int64, IXDIM> batch_strides, const int64 num_indices,
    const Index slice_size) {
  auto update = LeftUpdate<T, op>();

  CUDA_1D_KERNEL_LOOP(index, num_indices) {
    Index i = 0;
    bool out_of_bounds = false;
#pragma unroll
    for (int dim = 0; dim < IXDIM; ++dim) {
      int offset = (IXDIM * index + dim);
      const Index ix_d = internal::SubtleMustCopy(ldg(indices + offset));
      out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]);
      i += ix_d * batch_strides[dim] * slice_size;
    }
    if (!out_of_bounds) {
#pragma unroll
      for (int si = 0; si < slice_size; si++) {
        update(out + i + si, ldg(updates + (index * slice_size + si)));
      }
    }
  }
}

namespace functor {

// Functor used by ScatterNdOp to do the computations.
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {
  Index operator()(
      const GPUDevice& d, const Index slice_size,
      const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
      typename TTypes<T, 2>::Tensor Tparams,
      typename TTypes<Index, 2>::ConstTensor Tindices,
      typename TTypes<T, 2>::ConstTensor Tupdates,
      typename TTypes<T, 2>::Tensor Toutput) {
    // TODO(ebrevdo): The performance of this for small indices (large
    // slices) is poor.  Write a kernel whose splitting is
    // independent of the slice size.  Same for CPU.  See the
    // gather_nd kernel for an example.
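
    // The batch_strides computed below are suffix products of
    // output_shape_prefix.  For example (illustrative values only), with
    // output_shape_prefix = {2, 3, 4} (IXDIM = 3) the strides are {12, 4, 1},
    // so an index tuple (i0, i1, i2) addresses the slice starting at
    // (i0 * 12 + i1 * 4 + i2) * slice_size in the flattened output.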

    const Eigen::DenseIndex batch_size = Tindices.dimension(0);

    Eigen::array<int64, IXDIM> batch_strides;
    for (int dim = IXDIM - 1; dim >= 0; --dim) {
      if (dim == IXDIM - 1) {
        batch_strides[dim] = 1;
      } else {
        batch_strides[dim] =
            batch_strides[dim + 1] * output_shape_prefix[dim + 1];
      }
    }

    CudaLaunchConfig config = GetCudaLaunchConfig(Toutput.size(), d);
    // clang-format off
    ScatterNdOpKernel<T, Index, op, IXDIM>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            Tindices.data(), Tupdates.data(), Toutput.data(),
            output_shape_prefix, batch_strides, batch_size, slice_size);
    // clang-format on

    // The GPU kernel drops out-of-bounds indices rather than reporting them,
    // so always return -1 ("no bad index found").
    return -1;
  }
};

}  // namespace functor

#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
  template struct functor::ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op)     \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 2); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 3); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 4); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 5); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 6); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 7);

#define DECLARE_GPU_SPECS_INDEX(T, Index)                                \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ADD);    \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB)

#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64)

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_INDEX_OP

}  // namespace tensorflow

#endif  // GOOGLE_CUDA