/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <memory>
#include <vector>

#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace {

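// concat_fixed_kernel concatenates inputs that all have the same width along
// the column dimension: each input contributes exactly split_size columns, so
// the source input for output column gidx is gidx / split_size and the column
// within that input is gidx % split_size. A 2D grid-stride loop covers the
// whole output matrix.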
template <typename T, typename IntType>
__global__ void concat_fixed_kernel(
    CudaDeviceArrayStruct<const T*> input_ptr_data, int split_size,
    int total_rows, int total_cols, T* output) {
  const T** input_ptrs = GetCudaDeviceArrayOnDevice(&input_ptr_data);
  IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;

  for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
    IntType gidy = blockIdx.y * blockDim.y + threadIdx.y;

    IntType split = gidx / split_size;
    const T* input_ptr = input_ptrs[split];
    IntType col_offset = gidx % split_size;
#pragma unroll
    for (; gidy < total_rows; gidy += blockDim.y * gridDim.y) {
      output[gidy * total_cols + gidx] =
          input_ptr[gidy * split_size + col_offset];
    }
  }
}

}  // end namespace

// cannot be in anonymous namespace due to extern shared memory
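// concat_variable_kernel concatenates inputs whose column widths may differ.
// output_scan is expected to hold the exclusive prefix sum of the input
// widths followed by the total width (num_inputs + 1 entries), so that
// col_scan[i + 1] - col_scan[i] is the width of input i. For example, with
// input widths {3, 5, 2} the scan is {0, 3, 8, 10}; output column 6 then
// falls in segment 1 at local column 3. When useSmem is true, the scan is
// first staged into shared memory to avoid repeated global-memory reads.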
template <typename T, typename IntType, bool useSmem>
__global__ void concat_variable_kernel(
    CudaDeviceArrayStruct<const T*> input_ptr_data,
    CudaDeviceArrayStruct<IntType> output_scan, IntType total_rows,
    IntType total_cols, T* output) {
  const T** input_ptrs = GetCudaDeviceArrayOnDevice(&input_ptr_data);
  IntType* col_scan = GetCudaDeviceArrayOnDevice(&output_scan);

  // do upper_bound on col to find which pointer we should be using
  IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
  IntType num_inputs = input_ptr_data.size;

  // verbose declaration needed due to template
  extern __shared__ __align__(sizeof(T)) unsigned char smem[];
  IntType* smem_col_scan = reinterpret_cast<IntType*>(smem);

  if (useSmem) {
    IntType lidx = threadIdx.y * blockDim.x + threadIdx.x;
    IntType blockSize = blockDim.x * blockDim.y;

    for (IntType i = lidx; i < output_scan.size; i += blockSize) {
      smem_col_scan[i] = col_scan[i];
    }

    __syncthreads();

    col_scan = smem_col_scan;
  }
  // Do an initial binary search, then scan linearly from there. This works
  // well both when there are many small segments and when the segments are
  // long.
  IntType segment =
      cuda_helper::upper_bound<IntType>(col_scan, num_inputs, gidx) - 1;

  IntType curr_offset = col_scan[segment];
  IntType curr_segment = segment;
  for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
    IntType curr_col_offset;
    while ((curr_col_offset = col_scan[curr_segment + 1]) <= gidx) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    IntType local_col = gidx - curr_offset;
    IntType segment_width = curr_col_offset - curr_offset;
    const T* input_ptr = input_ptrs[curr_segment];

    IntType gidy = blockIdx.y * blockDim.y + threadIdx.y;
    for (; gidy < total_rows; gidy += blockDim.y * gridDim.y)
      output[gidy * total_cols + gidx] =
          input_ptr[gidy * segment_width + local_col];
  }
}

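// ConcatGPUSlice concatenates the inputs with plain Eigen slice assignments:
// each input is copied into its column range of the output on the given GPU
// device. When IntType is int32, the expressions are routed through To32Bit
// so that Eigen uses 32-bit indexing.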
template <typename T, typename IntType>
void ConcatGPUSlice(
    const Eigen::GpuDevice& gpu_device,
    const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
        inputs_flat,
    typename TTypes<T, 2>::Matrix* output) {
  Eigen::array<IntType, 2> offset{0, 0};
  for (int i = 0; i < inputs_flat.size(); ++i) {
    Eigen::array<IntType, 2> size;
    size[0] = inputs_flat[i]->dimension(0);
    size[1] = inputs_flat[i]->dimension(1);
    if (std::is_same<IntType, int32>::value) {
      To32Bit(*output).slice(offset, size).device(gpu_device) =
          To32Bit(*inputs_flat[i]);
    } else {
      output->slice(offset, size).device(gpu_device) = *inputs_flat[i];
    }

    offset[1] += size[1];
  }
}

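// ConcatGPUImpl launches one of the kernels above over a 2D grid sized from
// the output matrix: concat_fixed_kernel when every input has the same width
// (fixed_size), otherwise concat_variable_kernel, picking the shared-memory
// variant whenever the column scan fits both the device's shared-memory limit
// and the performance threshold below.
//
// A minimal host-side sketch, assuming `input_ptrs` and `output_scan` have
// already been staged on the device (the variable names here are
// illustrative; the real call sites are not in this file):
//
//   ConcatGPUImpl<float, int32>(gpu_device, input_ptrs, output_scan,
//                               /*fixed_size=*/false, /*split_size=*/0,
//                               &output_matrix);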
template <typename T, typename IntType>
void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
                   const CudaDeviceArrayStruct<const T*>& input_ptrs,
                   const CudaDeviceArrayStruct<IntType>& output_scan,
                   bool fixed_size, int split_size,
                   typename TTypes<T, 2>::Matrix* output) {
  auto config = GetCuda2DLaunchConfig(output->dimension(1),
                                      output->dimension(0), gpu_device);

  if (fixed_size) {
    concat_fixed_kernel<T, IntType>
        <<<config.block_count, config.thread_per_block, 0,
           gpu_device.stream()>>>(input_ptrs, split_size, output->dimension(0),
                                  output->dimension(1), output->data());
  } else {
    IntType smem_max = gpu_device.sharedMemPerBlock();
    IntType smem_usage = output_scan.size * sizeof(IntType);
    // The performance crossover point is below the maximum available shared
    // memory on most processors, possibly due to decreasing occupancy.
    // 4096 inputs is a lot; most cases will take the shared-memory path.
    const int32 kMaxSmemBytesPerformance = 16384;
    if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance)
      concat_variable_kernel<T, IntType, true>
          <<<config.block_count, config.thread_per_block, smem_usage,
             gpu_device.stream()>>>(input_ptrs, output_scan,
                                    output->dimension(0), output->dimension(1),
                                    output->data());
    else
      concat_variable_kernel<T, IntType, false>
          <<<config.block_count, config.thread_per_block, 0,
             gpu_device.stream()>>>(input_ptrs, output_scan,
                                    output->dimension(0), output->dimension(1),
                                    output->data());
  }
}

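// Explicit instantiations for the GPU-supported element types, in both 32-bit
// and 64-bit index variants; the 32-bit variants can be used whenever the
// output fits within 32-bit indexing.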
#define REGISTER_GPUCONCAT32(T)                                               \
  template void ConcatGPUSlice<T, int32>(                                     \
      const Eigen::GpuDevice& gpu_device,                                     \
      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
          inputs_flat,                                                        \
      typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPUCONCAT64(T)                                               \
  template void ConcatGPUSlice<T, int64>(                                     \
      const Eigen::GpuDevice& gpu_device,                                     \
      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
          inputs_flat,                                                        \
      typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPU32(T)                                               \
  template void ConcatGPUImpl<T, int32>(                                \
      const Eigen::GpuDevice& d,                                        \
      const CudaDeviceArrayStruct<const T*>& input_ptrs,                \
      const CudaDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size, \
      int split_size, typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPU64(T)                                               \
  template void ConcatGPUImpl<T, int64>(                                \
      const Eigen::GpuDevice& d,                                        \
      const CudaDeviceArrayStruct<const T*>& input_ptrs,                \
      const CudaDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size, \
      int split_size, typename TTypes<T, 2>::Matrix* output);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT32);
TF_CALL_complex64(REGISTER_GPUCONCAT32);
TF_CALL_complex128(REGISTER_GPUCONCAT32);
TF_CALL_int64(REGISTER_GPUCONCAT32);
REGISTER_GPUCONCAT32(bfloat16);
REGISTER_GPUCONCAT32(bool);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT64);
TF_CALL_complex64(REGISTER_GPUCONCAT64);
TF_CALL_complex128(REGISTER_GPUCONCAT64);
TF_CALL_int64(REGISTER_GPUCONCAT64);
REGISTER_GPUCONCAT64(bfloat16);
REGISTER_GPUCONCAT64(bool);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU32);
TF_CALL_complex64(REGISTER_GPU32);
TF_CALL_complex128(REGISTER_GPU32);
TF_CALL_int64(REGISTER_GPU32);
REGISTER_GPU32(bfloat16);
REGISTER_GPU32(bool);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU64);
TF_CALL_complex64(REGISTER_GPU64);
TF_CALL_complex128(REGISTER_GPU64);
TF_CALL_int64(REGISTER_GPU64);
REGISTER_GPU64(bfloat16);
REGISTER_GPU64(bool);

#undef REGISTER_GPUCONCAT32
#undef REGISTER_GPUCONCAT64
#undef REGISTER_GPU32
#undef REGISTER_GPU64

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA