/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/cuda_device_functions.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

using GPUDevice = Eigen::GpuDevice;

// SortedSegmentSumCustomKernel reduces input data just as
// UnsortedSegmentCustomKernel does, except that the input data is
// partitioned along the outer reduction dimension. This is because
// consecutive rows (elements in a row share the same outer dimension
// index) in the flattened 2D input data likely belong to the same
// segment in a sorted segment sum operation. This partitioning strategy
// therefore has two advantages over UnsortedSegmentCustomKernel:
// 1. Each thread reduces across multiple rows before writing its
//    answer to global memory, so reduction results are written to
//    global memory less often.
// 2. Because segment ids are non-decreasing, a thread can often tell
//    that it is the only contributor to an output element. In such
//    cases it does not need atomic operations to write the result to
//    global memory.
// In the flattened view of the input data (with only an outer and an
// inner dimension), every thread processes a strip of input of size
// OuterDimTileSize x 1. This strip runs across multiple rows of the
// input data, and all of its reduction elements share one inner
// dimension index.
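//
// Illustrative walk-through (example values only, not part of the kernel's
// contract): assume OuterDimTileSize = 8, inner_dim_size = 1, and
// segment_ids = [0, 0, 0, 1, 1, 2, 2, 2, 2, ...]. The thread owning the
// first stripe (rows 0..7) accumulates rows 0-2 locally, then flushes
// segment 0 with CudaAtomicAdd (segment 0 is the stripe's first segment,
// so in general the preceding stripe may also contribute to it). It then
// accumulates rows 3-4 and flushes segment 1 with a plain store, since
// this thread is its sole contributor; it accumulates rows 5-7 and finally
// flushes segment 2 with CudaAtomicAdd, because the following stripe
// (starting at row 8) may continue segment 2.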
template <typename T, typename Index, int OuterDimTileSize>
__global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size,
                                             const Index inner_dim_size,
                                             const Index output_outer_dim_size,
                                             const Index* segment_ids,
                                             const T* input, T* output,
                                             const Index total_stripe_count) {
  for (int stripe_index : CudaGridRangeX(total_stripe_count)) {
    const Index segment_offset = stripe_index % inner_dim_size;
    const Index input_outer_dim_index_base =
        stripe_index / inner_dim_size * Index(OuterDimTileSize);

    T sum = T(0);
    Index first_segment_id = segment_ids[input_outer_dim_index_base];
    Index last_output_segment_id = output_outer_dim_size;

    const Index actual_stripe_height =
        min(Index(OuterDimTileSize),
            input_outer_dim_size - input_outer_dim_index_base);
    for (Index j = 0; j < actual_stripe_height; j++) {
      Index current_output_segment_id =
          segment_ids[input_outer_dim_index_base + j];
      // The result is only written to global memory when we move on to
      // another segment; otherwise we keep accumulating locally.
      if (current_output_segment_id > last_output_segment_id) {
        const Index output_index =
            last_output_segment_id * inner_dim_size + segment_offset;
        // Use an atomic add only for the stripe's first segment, to which
        // the previous stripe may also contribute. For any later segment
        // this thread is the sole contributor, so a plain store suffices.
        if (last_output_segment_id == first_segment_id) {
          CudaAtomicAdd(output + output_index, sum);
        } else {
          *(output + output_index) = sum;
        }
        sum = T(0);
      }
      sum += ldg(input + (input_outer_dim_index_base + j) * inner_dim_size +
                 segment_offset);
      last_output_segment_id = current_output_segment_id;
    }
    // The last partial sum in a stripe is always written with an atomic add,
    // because the thread handling the following stripe may accumulate into
    // the same segment.
    const Index output_index =
        last_output_segment_id * inner_dim_size + segment_offset;
    CudaAtomicAdd(output + output_index, sum);
  }
}

// UnsortedSegmentCustomKernel processes 'input_total_size' elements.
// Each element is mapped from input to output by a combination of its
// 'segment_ids' entry and 'inner_dim_size'.
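//
// Illustrative example (hypothetical values): with inner_dim_size = 3 and
// segment_ids = [2, 0, 2], input element 4 has
// input_segment_index = 4 / 3 = 1 and segment_offset = 4 % 3 = 1, so
// output_segment_index = segment_ids[1] = 0 and the element is reduced
// into output element 0 * 3 + 1 = 1.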
template <typename T, typename Index, typename KernelReductionFunctor>
__global__ void UnsortedSegmentCustomKernel(const Index input_outer_dim_size,
                                            const Index inner_dim_size,
                                            const Index output_outer_dim_size,
                                            const Index* segment_ids,
                                            const T* input, T* output) {
  const Index input_total_size = input_outer_dim_size * inner_dim_size;
  for (int input_index : CudaGridRangeX(input_total_size)) {
    const Index input_segment_index = input_index / inner_dim_size;
    const Index segment_offset = input_index % inner_dim_size;
    const Index output_segment_index = segment_ids[input_segment_index];
    // Segment ids outside [0, output_outer_dim_size) do not contribute to
    // any output row and are skipped. Bounding by the number of output rows
    // (rather than by the total number of output elements) keeps the write
    // below in range.
    if (output_segment_index < 0 ||
        output_segment_index >= output_outer_dim_size) {
      continue;
    }
    const Index output_index =
        output_segment_index * inner_dim_size + segment_offset;
    KernelReductionFunctor()(output + output_index, ldg(input + input_index));
  }
}

namespace functor {

template <typename T, typename Index>
void SegmentSumFunctor<T, Index>::operator()(
    OpKernelContext* ctx, const GPUDevice& d, const Index output_rows,
    const TensorShape& segment_ids_shape,
    typename TTypes<Index>::ConstFlat segment_ids, const Index data_size,
    const T* data, typename TTypes<T, 2>::Tensor output) {
  if (output.size() == 0) {
    return;
  }
  // Set 'output' to zeros.
  CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d);
  SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
      output.size(), output.data());
  if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
    return;
  }

  // Launch kernel to compute sorted segment sum.
  // Notes:
  // *) 'input_total_size' is the total number of elements to process.
  // *) 'segment_ids.shape' is a prefix of data's shape.
  // *) 'input_outer_dim_size' is the number of input rows to reduce,
  //    i.e. the length of 'segment_ids'.
  const Index input_total_size = data_size;
  const Index input_outer_dim_size = segment_ids.dimension(0);
  const Index input_inner_dim_size = input_total_size / input_outer_dim_size;

  const int OuterDimTileSize = 8;

  const Index input_outer_dim_num_stripe =
      Eigen::divup(input_outer_dim_size, Index(OuterDimTileSize));

  const Index total_stripe_count =
      input_inner_dim_size * input_outer_dim_num_stripe;
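  // Worked example (illustrative numbers): with input_outer_dim_size = 100,
  // input_inner_dim_size = 16, and OuterDimTileSize = 8, each column of the
  // flattened input is covered by divup(100, 8) = 13 stripes, for a
  // total_stripe_count of 13 * 16 = 208.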

  config = GetCudaLaunchConfig(total_stripe_count, d);
  SortedSegmentSumCustomKernel<T, Index, OuterDimTileSize>
      <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
          input_outer_dim_size, input_inner_dim_size, output_rows,
          segment_ids.data(), data, output.data(), total_stripe_count);
}

template <typename T, typename Index, typename InitialValueF,
          typename ReductionF>
struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
  void operator()(OpKernelContext* ctx, const Index num_segments,
                  const TensorShape& segment_ids_shape,
                  typename TTypes<Index>::ConstFlat segment_ids,
                  const Index data_size, const T* data,
                  typename TTypes<T, 2>::Tensor output) {
    if (output.size() == 0) {
      return;
    }
    // Set 'output' to initial value.
    GPUDevice d = ctx->template eigen_device<GPUDevice>();
    CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d);
    SetToValue<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
        output.size(), output.data(), InitialValueF()());
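    // The initial value is the identity of the reduction. Per the explicit
    // instantiations at the bottom of this file, Lowest<T> pairs with
    // MaxOpGpu, Highest<T> with MinOpGpu, One<T> with ProdOpGpu, and
    // Zero<T> with SumOpGpu.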
    if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
      return;
    }
    // Launch kernel to compute unsorted segment reduction.
    // Notes:
    // *) 'data_size' is the total number of elements to process.
    // *) 'segment_ids.shape' is a prefix of data's shape.
    // *) 'input_outer_dim_size' is the number of input rows to reduce,
    //    i.e. the length of 'segment_ids'.
    const Index input_outer_dim_size = segment_ids.dimension(0);
    const Index input_inner_dim_size = data_size / input_outer_dim_size;
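    // Worked example (illustrative numbers): for data of shape [5, 4],
    // data_size = 20 and input_outer_dim_size = 5, so input_inner_dim_size
    // = 4; the kernel below is then launched over all 20 input elements.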
    config = GetCudaLaunchConfig(data_size, d);

    UnsortedSegmentCustomKernel<T, Index, ReductionF>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            input_outer_dim_size, input_inner_dim_size, num_segments,
            segment_ids.data(), data, output.data());
  }
};

#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \
  template struct SegmentSumFunctor<T, Index>

#define DEFINE_SORTED_GPU_SPECS(T)         \
  DEFINE_SORTED_GPU_SPECS_INDEX(T, int32); \
  DEFINE_SORTED_GPU_SPECS_INDEX(T, int64);

TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS);
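// For example, DEFINE_SORTED_GPU_SPECS(float) expands to the explicit
// instantiations 'template struct SegmentSumFunctor<float, int32>;' and
// 'template struct SegmentSumFunctor<float, int64>;', and
// TF_CALL_GPU_NUMBER_TYPES repeats this for every type it covers.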

#define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index)                         \
  template struct UnsortedSegmentFunctor<                                      \
      GPUDevice, T, Index, functor::Lowest<T>, functor::MaxOpGpu<T>>;          \
  template struct UnsortedSegmentFunctor<                                      \
      GPUDevice, T, Index, functor::Highest<T>, functor::MinOpGpu<T>>;         \
  template struct UnsortedSegmentFunctor<GPUDevice, T, Index, functor::One<T>, \
                                         functor::ProdOpGpu<T>>;

// Sum is currently the only op that supports all input types.
#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
  template struct UnsortedSegmentFunctor<             \
      GPUDevice, T, Index, functor::Zero<T>, functor::SumOpGpu<T>>;

#define DEFINE_REAL_GPU_SPECS(T)                  \
  DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int32); \
  DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int64);

#define DEFINE_SUM_GPU_SPECS(T)                  \
  DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, int32); \
  DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, int64);

TF_CALL_GPU_NUMBER_TYPES(DEFINE_REAL_GPU_SPECS);
TF_CALL_int32(DEFINE_REAL_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DEFINE_SUM_GPU_SPECS);
TF_CALL_int32(DEFINE_SUM_GPU_SPECS);
TF_CALL_complex64(DEFINE_SUM_GPU_SPECS);
TF_CALL_complex128(DEFINE_SUM_GPU_SPECS);

#undef DEFINE_SORTED_GPU_SPECS_INDEX
#undef DEFINE_SORTED_GPU_SPECS
#undef DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX
#undef DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX
#undef DEFINE_REAL_GPU_SPECS
#undef DEFINE_SUM_GPU_SPECS

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA