/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <stdio.h>

#include "tensorflow/core/kernels/split_lib.h"

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {
namespace functor {

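// Eigen-based 3-D Split functor: copies the contiguous slice of `input`
// described by slice_indices/slice_sizes into `output` on the given device,
// using 32-bit indexing via To32Bit.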
template <typename Device, typename T>
void Split<Device, T>::operator()(
    const Device& d, typename TTypes<T, 3>::Tensor output,
    typename TTypes<T, 3>::ConstTensor input,
    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
  To32Bit(output).device(d) = To32Bit(input).slice(slice_indices, slice_sizes);
}

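// 2-D counterpart of the functor above, operating on a flattened (rows, cols)
// view of the input; the slice copy is otherwise identical.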
template <typename Device, typename T>
void SplitCustom<Device, T>::operator()(
    const Device& d, typename TTypes<T, 2>::Tensor output,
    typename TTypes<T, 2>::ConstTensor input,
    const Eigen::DSizes<Eigen::DenseIndex, 2>& slice_indices,
    const Eigen::DSizes<Eigen::DenseIndex, 2>& slice_sizes) {
  To32Bit(output).device(d) = To32Bit(input).slice(slice_indices, slice_sizes);
}

#define DEFINE_GPU_KERNELS(T) template struct Split<Eigen::GpuDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_complex64(DEFINE_GPU_KERNELS);
TF_CALL_complex128(DEFINE_GPU_KERNELS);
TF_CALL_bfloat16(DEFINE_GPU_KERNELS);

#undef DEFINE_GPU_KERNELS
#define DEFINE_GPU_KERNELS(T) template struct SplitCustom<Eigen::GpuDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_complex64(DEFINE_GPU_KERNELS);
TF_CALL_complex128(DEFINE_GPU_KERNELS);
TF_CALL_bfloat16(DEFINE_GPU_KERNELS);

#undef DEFINE_GPU_KERNELS

}  // namespace functor

namespace {

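// Scatters an equal-sized split: each thread copies one element of the
// flattened [prefix_dim_size, split_dim_size, suffix_dim_size] input into the
// output buffer that owns its piece of the split dimension.
//
// Illustrative example: with prefix_dim_size = 2, split_dim_size = 6,
// suffix_dim_size = 4 and num_split = 3 (so piece_size = 2), input offset 30
// decomposes to (i, j, k) = (1, 1, 2), is routed to output_ptrs[0]
// (j / piece_size == 0), and is written at output offset
// 1 * 2 * 4 + 1 * 4 + 2 = 14.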
template <typename T>
__global__ void SplitOpKernel(const T* input, int32 prefix_dim_size,
                              int32 split_dim_size, int32 suffix_dim_size,
                              CudaDeviceArrayStruct<T*> output_ptr_data) {
  const int32 num_split = output_ptr_data.size;
  T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);

  eigen_assert(blockDim.y == 1);
  eigen_assert(blockDim.z == 1);
  eigen_assert(split_dim_size % num_split == 0);

  int32 size = prefix_dim_size * split_dim_size * suffix_dim_size;
  int32 piece_size = split_dim_size / num_split;

  CUDA_1D_KERNEL_LOOP(offset, size) {
    // Calculate the index into input from offset.
    int32 i = offset / (split_dim_size * suffix_dim_size);
    int32 j = (offset % (split_dim_size * suffix_dim_size)) / suffix_dim_size;
    int32 k = offset % suffix_dim_size;

    // Find the output buffer that should be written to.
    T* output_ptr = output_ptrs[j / piece_size];
    // output_ptr is pointing to an array of size
    //  [prefix_dim_size][piece_size][suffix_dim_size].
    //
    // output_ptr[i][j % piece_size][k] = input[offset];
    // Linearize (i, j % piece_size, k) into an offset.
    int32 output_offset = i * piece_size * suffix_dim_size +
                          (j % piece_size) * suffix_dim_size + k;
    *(output_ptr + output_offset) = ldg(input + offset);
  }
}

}  // namespace

// Cannot be in the anonymous namespace because of the extern shared memory.
// Very similar to the concat kernel, except that the input/output logic
// is reversed.
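// output_scan is expected to hold the column-offset scan of the outputs with
// a trailing entry equal to total_cols, so that col_scan[s + 1] below is
// always valid. For illustration, output widths {3, 2, 5} would give
// col_scan = {0, 3, 5, 10}; global column gidx = 6 then falls in segment 2
// (output_ptrs[2]) at local_col = 6 - 5 = 1 with segment_width = 10 - 5 = 5.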
template <typename T, typename IntType, bool useSmem>
__global__ void split_v_kernel(const T* input_ptr,
                               CudaDeviceArrayStruct<IntType> output_scan,
                               IntType total_rows, IntType total_cols,
                               CudaDeviceArrayStruct<T*> output_ptr_data) {
  T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);
  IntType* col_scan = GetCudaDeviceArrayOnDevice(&output_scan);

  // do upper_bound on col to find which pointer we should be using
  IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
  int num_outputs = output_ptr_data.size;

  // The verbose declaration is needed due to the template: a typed
  // extern __shared__ array would clash across instantiations, so declare
  // raw bytes and reinterpret_cast to IntType.
  extern __shared__ __align__(sizeof(T)) unsigned char smem[];
  IntType* smem_col_scan = reinterpret_cast<IntType*>(smem);

  if (useSmem) {
    IntType lidx = threadIdx.y * blockDim.x + threadIdx.x;
    IntType blockSize = blockDim.x * blockDim.y;

    for (IntType i = lidx; i < output_scan.size; i += blockSize) {
      smem_col_scan[i] = col_scan[i];
    }

    __syncthreads();

    col_scan = smem_col_scan;
  }

  // Do an initial binary search and then scan linearly from there.
  // This works well both when there are many small segments and when the
  // segments are much longer.
  IntType segment =
      cuda_helper::upper_bound<IntType>(col_scan, num_outputs, gidx) - 1;

  IntType curr_offset = col_scan[segment];
  IntType curr_segment = segment;
  for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
    IntType curr_col_offset;
    while ((curr_col_offset = col_scan[curr_segment + 1]) <= gidx) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    IntType local_col = gidx - curr_offset;
    IntType segment_width = curr_col_offset - curr_offset;
    T* output_ptr = output_ptrs[curr_segment];

    IntType gidy = blockIdx.y * blockDim.y + threadIdx.y;
    for (; gidy < total_rows; gidy += blockDim.y * gridDim.y)
      output_ptr[gidy * segment_width + local_col] =
          input_ptr[gidy * total_cols + gidx];
  }
}

// Different from the original split implementation because it works on 2-D
// rather than 3-D dimensions. This version is likely faster due to less
// integer math.
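// Illustrative example: with prefix_dim_size = 2, suffix_dim_size = 6 and
// num_split = 3 (so piece_size = 2), input offset 10 decomposes to
// (i, j) = (1, 4), is routed to output_ptrs[2], and is written at output
// offset 1 * 2 + 0 = 2.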
template <typename T>
__global__ void SplitVOpKernel_fixed(
    const T* input, int32 prefix_dim_size, int32 suffix_dim_size,
    CudaDeviceArrayStruct<T*> output_ptr_data) {
  const int32 num_split = output_ptr_data.size;
  T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);

  eigen_assert(blockDim.y == 1);
  eigen_assert(blockDim.z == 1);

  int32 size = prefix_dim_size * suffix_dim_size;
  int32 piece_size = suffix_dim_size / num_split;

  CUDA_1D_KERNEL_LOOP(offset, size) {
    // Calculate the index into input from offset.
    int32 i = offset / suffix_dim_size;
    int32 j = offset % suffix_dim_size;

    // Find the output buffer that should be written to.
    T* output_ptr = output_ptrs[j / piece_size];
    int32 output_offset = i * piece_size + (j % piece_size);
    output_ptr[output_offset] = input[offset];
  }
}

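// Host-side launcher for the equal-sized split: computes a 1-D launch
// configuration over all prefix * split * suffix elements and dispatches
// SplitOpKernel on the device's stream.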
template <typename T>
struct SplitOpGPULaunch {
  void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
           int32 split_dim_size, int32 suffix_dim_size,
           const CudaDeviceArrayStruct<T*>& output_ptr_data) {
    CudaLaunchConfig config = GetCudaLaunchConfig(
        prefix_dim_size * split_dim_size * suffix_dim_size, d);

    SplitOpKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            input, prefix_dim_size, split_dim_size, suffix_dim_size,
            output_ptr_data);
  }
};

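// Host-side launcher for SplitV. When all outputs are the same size
// (fixed_size), the cheaper SplitVOpKernel_fixed path is used; otherwise
// split_v_kernel is launched, staging the column scan in shared memory
// whenever it fits under both the hardware limit and the performance
// threshold below.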
template <typename T, typename IntType>
struct SplitVOpGPULaunch {
  void Run(const Eigen::GpuDevice& gpu_device, bool fixed_size,
           const T* input_ptr, int total_rows, int total_cols,
           const CudaDeviceArrayStruct<IntType>& output_scan,
           const CudaDeviceArrayStruct<T*>& output_ptr_data) {
    if (fixed_size) {
      CudaLaunchConfig config =
          GetCudaLaunchConfig(total_rows * total_cols, gpu_device);

      SplitVOpKernel_fixed<T><<<config.block_count, config.thread_per_block, 0,
                                gpu_device.stream()>>>(
          input_ptr, total_rows, total_cols, output_ptr_data);
    } else {
      auto config = GetCuda2DLaunchConfig(total_cols, total_rows, gpu_device);
      IntType smem_max = gpu_device.sharedMemPerBlock();
      IntType smem_usage = output_scan.size * sizeof(IntType);
      // The performance crossover is below the maximum available shared
      // memory on most processors, possibly due to decreasing occupancy.
      // 16 KB is room for about 4096 int32 scan entries, which is a lot;
      // most calls will take the smem path.
      const int32 kMaxSmemBytesPerformance = 16384;
      if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance)
        split_v_kernel<T, IntType, true>
            <<<config.block_count, config.thread_per_block, smem_usage,
               gpu_device.stream()>>>(input_ptr, output_scan, total_rows,
                                      total_cols, output_ptr_data);
      else
        split_v_kernel<T, IntType, false>
            <<<config.block_count, config.thread_per_block, 0,
               gpu_device.stream()>>>(input_ptr, output_scan, total_rows,
                                      total_cols, output_ptr_data);
    }
  }
};

#define REGISTER_GPU_KERNEL(T) template struct SplitOpGPULaunch<T>;

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_complex64(REGISTER_GPU_KERNEL);
TF_CALL_complex128(REGISTER_GPU_KERNEL);
TF_CALL_bfloat16(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#define REGISTER_GPU_KERNEL(T)                 \
  template struct SplitVOpGPULaunch<T, int32>; \
  template struct SplitVOpGPULaunch<T, int64>;

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_complex64(REGISTER_GPU_KERNEL);
TF_CALL_complex128(REGISTER_GPU_KERNEL);
TF_CALL_bfloat16(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

}  // namespace tensorflow

#endif  // GOOGLE_CUDA