/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/cuda_device_functions.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

using GPUDevice = Eigen::GpuDevice;

// SortedSegmentSumCustomKernel reduces input data just as
// UnsortedSegmentCustomKernel does, except that the input data is
// partitioned along the outer reduction dimension. This is because
// consecutive rows (elements in a row share the same outer-dimension
// index) in the flattened 2D input data likely belong to the same
// segment in a sorted segment sum operation. This partitioning strategy
// therefore has two advantages over UnsortedSegmentCustomKernel:
// 1. Each thread reduces across multiple rows before writing its result
//    to global memory, so reduction results are written out less often.
// 2. The non-decreasing segment ids may tell us that the current thread
//    is the only contributor to an output element, in which case no
//    atomic operation is needed to write the result to global memory.
// In the flattened view of the input data (with only an outer and an
// inner dimension), every thread processes a strip of input data of
// size OuterDimTileSize x 1. This strip runs across multiple rows of
// the input data, and all of its reduction elements share the same
// inner-dimension index.
template <typename T, typename Index, int OuterDimTileSize>
__global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size,
                                             const Index inner_dim_size,
                                             const Index output_outer_dim_size,
                                             const Index* segment_ids,
                                             const T* input, T* output,
                                             const Index total_stripe_count) {
  for (int stripe_index : CudaGridRangeX(total_stripe_count)) {
    const Index segment_offset = stripe_index % inner_dim_size;
    const Index input_outer_dim_index_base =
        stripe_index / inner_dim_size * Index(OuterDimTileSize);

    T sum = T(0);
    Index first_segment_id = segment_ids[input_outer_dim_index_base];
    // Initialize to an out-of-range id so that no write happens before
    // the first row has been accumulated.
    Index last_output_segment_id = output_outer_dim_size;

    const Index actual_stripe_height =
        min(Index(OuterDimTileSize),
            input_outer_dim_size - input_outer_dim_index_base);
    for (Index j = 0; j < actual_stripe_height; j++) {
      Index current_output_segment_id =
          segment_ids[input_outer_dim_index_base + j];
      // Decide whether to write the result to global memory: the result
      // is only written out when we move to another segment; otherwise
      // we keep accumulating locally.
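      // Illustrative walk-through (not part of the original comments):
      // with OuterDimTileSize = 4 and segment_ids = [0, 0, 1, 1], this
      // thread accumulates rows 0 and 1 locally, flushes the segment-0
      // sum when it first sees segment id 1, and then accumulates rows
      // 2 and 3 until the final atomic write after the loop.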
      if (current_output_segment_id > last_output_segment_id) {
        const Index output_index =
            last_output_segment_id * inner_dim_size + segment_offset;
        // Decide whether to write the result to global memory using an
        // atomic operation: only the first segment in a strip can also
        // receive contributions from the thread handling the preceding
        // strip, so only that segment needs an atomic add.
        if (last_output_segment_id == first_segment_id) {
          CudaAtomicAdd(output + output_index, sum);
        } else {
          *(output + output_index) = sum;
        }
        sum = T(0);
      }
      sum += ldg(input + (input_outer_dim_index_base + j) * inner_dim_size +
                 segment_offset);
      last_output_segment_id = current_output_segment_id;
    }
    // For the last result in a strip, always write using an atomic
    // operation, due to possible race conditions with threads computing
    // the following strip.
    const Index output_index =
        last_output_segment_id * inner_dim_size + segment_offset;
    CudaAtomicAdd(output + output_index, sum);
  }
}

// UnsortedSegmentCustomKernel processes 'input_total_size' elements.
// Each element is mapped from input to output by a combination of its
// 'segment_ids' mapping and 'inner_dim_size'.
template <typename T, typename Index, typename KernelReductionFunctor>
__global__ void UnsortedSegmentCustomKernel(const Index input_outer_dim_size,
                                            const Index inner_dim_size,
                                            const Index output_outer_dim_size,
                                            const Index* segment_ids,
                                            const T* input, T* output) {
  const Index input_total_size = input_outer_dim_size * inner_dim_size;
  for (int input_index : CudaGridRangeX(input_total_size)) {
    const Index input_segment_index = input_index / inner_dim_size;
    const Index segment_offset = input_index % inner_dim_size;
    const Index output_segment_index = segment_ids[input_segment_index];
    // Skip out-of-range segment ids. Segment ids index output rows, so
    // they are checked against 'output_outer_dim_size' rather than the
    // total output size (the latter would permit out-of-bounds writes
    // whenever inner_dim_size > 1).
    if (output_segment_index < 0 ||
        output_segment_index >= output_outer_dim_size) {
      continue;
    }
    const Index output_index =
        output_segment_index * inner_dim_size + segment_offset;
    KernelReductionFunctor()(output + output_index, ldg(input + input_index));
  }
}

namespace functor {

template <typename T, typename Index>
void SegmentSumFunctor<T, Index>::operator()(
    OpKernelContext* ctx, const GPUDevice& d, const Index output_rows,
    const TensorShape& segment_ids_shape,
    typename TTypes<Index>::ConstFlat segment_ids, const Index data_size,
    const T* data, typename TTypes<T, 2>::Tensor output) {
  if (output.size() == 0) {
    return;
  }
  // Set 'output' to zeros.
  CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d);
  SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
      output.size(), output.data());
  if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
    return;
  }

  // Launch the kernel to compute the sorted segment sum.
  // Notes:
  // *) 'input_total_size' is the total number of elements to process.
  // *) 'segment_ids.shape' is a prefix of data's shape.
  // *) 'input_outer_dim_size' is the number of input rows, i.e. the
  //    number of segment ids to process.
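  // Illustrative example (not part of the original comments): for data
  // of shape [8, 4] with OuterDimTileSize = 8, input_outer_dim_size is
  // 8, input_inner_dim_size is 4, and total_stripe_count is 4 -- one
  // stripe of up to 8 rows per inner-dimension index.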
  const Index input_total_size = data_size;
  const Index input_outer_dim_size = segment_ids.dimension(0);
  const Index input_inner_dim_size = input_total_size / input_outer_dim_size;

  const int OuterDimTileSize = 8;

  const Index input_outer_dim_num_stripe =
      Eigen::divup(input_outer_dim_size, Index(OuterDimTileSize));

  const Index total_stripe_count =
      input_inner_dim_size * input_outer_dim_num_stripe;

  config = GetCudaLaunchConfig(total_stripe_count, d);
  SortedSegmentSumCustomKernel<T, Index, OuterDimTileSize>
      <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
          input_outer_dim_size, input_inner_dim_size, output_rows,
          segment_ids.data(), data, output.data(), total_stripe_count);
}

template <typename T, typename Index, typename InitialValueF,
          typename ReductionF>
struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
  void operator()(OpKernelContext* ctx, const Index num_segments,
                  const TensorShape& segment_ids_shape,
                  typename TTypes<Index>::ConstFlat segment_ids,
                  const Index data_size, const T* data,
                  typename TTypes<T, 2>::Tensor output) {
    if (output.size() == 0) {
      return;
    }
    // Set 'output' to the initial value of the reduction.
    GPUDevice d = ctx->template eigen_device<GPUDevice>();
    CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d);
    SetToValue<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
        output.size(), output.data(), InitialValueF()());
    if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
      return;
    }
    // Launch the kernel to compute the unsorted segment reduction.
    // Notes:
    // *) 'data_size' is the total number of elements to process.
    // *) 'segment_ids.shape' is a prefix of data's shape.
    // *) 'input_outer_dim_size' is the number of input rows, i.e. the
    //    number of segment ids to process.
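    // Illustrative example (not part of the original comments): for data
    // of shape [3, 2], segment_ids = [2, 0, 2], and num_segments = 3,
    // input row 1 is reduced into output row 0, input rows 0 and 2 are
    // combined into output row 2 via ReductionF, and output row 1 keeps
    // its initial value.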
    const Index input_outer_dim_size = segment_ids.dimension(0);
    const Index input_inner_dim_size = data_size / input_outer_dim_size;
    config = GetCudaLaunchConfig(data_size, d);

    UnsortedSegmentCustomKernel<T, Index, ReductionF>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            input_outer_dim_size, input_inner_dim_size, num_segments,
            segment_ids.data(), data, output.data());
  }
};

#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \
  template struct SegmentSumFunctor<T, Index>

#define DEFINE_SORTED_GPU_SPECS(T)         \
  DEFINE_SORTED_GPU_SPECS_INDEX(T, int32); \
  DEFINE_SORTED_GPU_SPECS_INDEX(T, int64);

TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS);

#define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index)                 \
  template struct UnsortedSegmentFunctor<                              \
      GPUDevice, T, Index, functor::Lowest<T>, functor::MaxOpGpu<T>>;  \
  template struct UnsortedSegmentFunctor<                              \
      GPUDevice, T, Index, functor::Highest<T>, functor::MinOpGpu<T>>; \
  template struct UnsortedSegmentFunctor<                              \
      GPUDevice, T, Index, functor::One<T>, functor::ProdOpGpu<T>>;

// Sum is currently the only op that supports all input types.
#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
  template struct UnsortedSegmentFunctor<             \
      GPUDevice, T, Index, functor::Zero<T>, functor::SumOpGpu<T>>;

#define DEFINE_REAL_GPU_SPECS(T)                  \
  DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int32); \
  DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int64);

#define DEFINE_SUM_GPU_SPECS(T)                  \
  DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, int32); \
  DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, int64);

TF_CALL_GPU_NUMBER_TYPES(DEFINE_REAL_GPU_SPECS);
TF_CALL_int32(DEFINE_REAL_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DEFINE_SUM_GPU_SPECS);
TF_CALL_int32(DEFINE_SUM_GPU_SPECS);
TF_CALL_complex64(DEFINE_SUM_GPU_SPECS);
TF_CALL_complex128(DEFINE_SUM_GPU_SPECS);

#undef DEFINE_SORTED_GPU_SPECS_INDEX
#undef DEFINE_SORTED_GPU_SPECS
#undef DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX
#undef DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX
#undef DEFINE_REAL_GPU_SPECS
#undef DEFINE_SUM_GPU_SPECS

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA
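// Illustrative expansion of the instantiation macros above (not part of
// the original file): for T = float, DEFINE_SUM_GPU_SPECS(float) expands
// to the two explicit instantiations
//   template struct UnsortedSegmentFunctor<
//       GPUDevice, float, int32, functor::Zero<float>,
//       functor::SumOpGpu<float>>;
//   template struct UnsortedSegmentFunctor<
//       GPUDevice, float, int64, functor::Zero<float>,
//       functor::SumOpGpu<float>>;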