/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <algorithm>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/bias_op.h"
#include "tensorflow/core/kernels/bias_op_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

// There are no native fp16 atomics (we simulate them using 32-bit atomics),
// so fp16 sums are done in fp32 internally. (We don't have a lot of shared
// memory traffic; BiasGradNCHW_SharedAtomics in particular works almost
// entirely on a local variable.)
template <class T>
struct AccumulatorType {
  typedef T type;
};

template <>
struct AccumulatorType<Eigen::half> {
  typedef float type;
};
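
// AccumulatorType<T>::type is the type used for the partial sums in the
// gradient kernels below: T itself for float and double, and float for
// Eigen::half.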

// Definition of the GPU implementations declared in bias_op.cc.

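// In NHWC layout the channel is the innermost (fastest-varying) dimension of
// the flattened input, so the bias element for a given linear index is simply
// index % bias_size.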
template <typename T>
__global__ void BiasNHWCKernel(int32 nthreads, const T* input, const T* bias,
                               T* output, int32 bias_size) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int32 bias_offset = index % bias_size;
    output[index] = ldg(input + index) + ldg(bias + bias_offset);
  }
}

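// In NCHW layout each channel owns a contiguous block of image_size elements,
// so index / image_size enumerates (batch, channel) pairs and taking that
// quotient modulo bias_size recovers the channel.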
template <typename T>
__global__ void BiasNCHWKernel(int32 nthreads, const T* input, const T* bias,
                               T* output, int32 bias_size, int32 image_size) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int32 index2 = index / image_size;
    int32 bias_offset = index2 % bias_size;
    output[index] = ldg(input + index) + ldg(bias + bias_offset);
  }
}

// Add "bias" to "input", broadcasting it on all dimensions but the bias
// dimension.
template <typename T>
void BiasGPU<T>::compute(const GPUDevice& d, const T* input, const T* bias,
                         T* output, int32 batch, int32 height, int32 width,
                         int32 channel, TensorFormat data_format) {
  const int32 bias_size = channel;
  const int32 image_size = height * width;
  const int32 total_count = batch * bias_size * image_size;
  if (total_count == 0) {
    return;
  }
  CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
  if (data_format == FORMAT_NHWC) {
    BiasNHWCKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            config.virtual_thread_count, input, bias, output, bias_size);
  } else {
    BiasNCHWKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            config.virtual_thread_count, input, bias, output, bias_size,
            image_size);
  }
}

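// Gradient kernels. The naive variants below rely entirely on global atomic
// adds into bias_backprop; they handle any shape but can contend heavily on
// the same bias elements, so the shared-memory variants further down are
// preferred whenever they can be launched (see BiasGradGPU<T>::compute).
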
// A naive implementation that is functional in all cases.
template <typename T>
__global__ void BiasGradNHWC_Naive(int32 nthreads, const T* output_backprop,
                                   T* bias_backprop, int32 bias_size) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int32 bias_offset = index % bias_size;
    CudaAtomicAdd(bias_backprop + bias_offset, ldg(output_backprop + index));
  }
}

// A naive implementation that is functional in all cases.
template <typename T>
__global__ void BiasGradNCHW_Naive(int32 nthreads, const T* output_backprop,
                                   T* bias_backprop, int32 bias_size,
                                   int32 image_size) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int32 index2 = index / image_size;
    int32 bias_offset = index2 % bias_size;
    CudaAtomicAdd(bias_backprop + bias_offset, ldg(output_backprop + index));
  }
}

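// Dynamically sized shared-memory buffer used by BiasGradNHWC_SharedAtomics.
// Its size in bytes is supplied as the third kernel launch parameter in
// BiasGradGPU<T>::compute.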
extern __shared__ char s_buf[];

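// Per-block strategy for the NHWC gradient: keep one accumulator per bias
// element in shared memory (in AccumulatorType<T>::type), have every thread
// atomically add the values it loads into those shared accumulators, and only
// then flush the block's partial sums to global memory with a single atomic
// add per bias element.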
template <typename T>
__global__ void BiasGradNHWC_SharedAtomics(int32 nthreads,
                                           const T* output_backprop,
                                           T* bias_backprop, int32 bias_size) {
  typedef typename AccumulatorType<T>::type AccT;
  AccT* s_data = reinterpret_cast<AccT*>(s_buf);
  for (int32 index = threadIdx.x; index < bias_size; index += blockDim.x) {
    s_data[index] = AccT(0);
  }
  __syncthreads();

  for (int32 index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int32 bias_offset = index % bias_size;
    CudaAtomicAdd(s_data + bias_offset, AccT(ldg(output_backprop + index)));
  }
  __syncthreads();

  for (int32 index = threadIdx.x; index < bias_size; index += blockDim.x) {
    CudaAtomicAdd(bias_backprop + index, T(s_data[index]));
  }
}

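// Per-block strategy for the NCHW gradient: each block is bound to a single
// bias element (blockIdx.x % bias_size) and to one of group_size slices of
// that channel's data. Each thread first accumulates its share into a
// register, then folds that register into one of 32 shared-memory slots
// (indexed by lane), the first warp reduces those slots with shuffles, and
// thread 0 issues a single global atomic add for the channel.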
template <typename T>
__global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
                                           T* bias_backprop, int32 batch,
                                           int32 bias_size, int32 image_size,
                                           int group_size) {
  // Initialize the shared memory.
  typedef typename AccumulatorType<T>::type AccT;
  const int32 kSDataSize = 32;
  __shared__ AccT s_data[kSDataSize];
  for (int32 index = threadIdx.x; index < kSDataSize; index += blockDim.x) {
    s_data[index] = AccT(0);
  }
  __syncthreads();

  // Accumulate all the values within this thread. They all have the same bias
  // index.
  int32 bias_index = blockIdx.x % bias_size;
  int32 group_index = blockIdx.x / bias_size;
  int32 total_count = batch * image_size;
  AccT sum(0);
  for (int32 index = group_index * blockDim.x + threadIdx.x;
       index < total_count; index += blockDim.x * group_size) {
    int32 image_offset = index % image_size;
    int32 batch_index = index / image_size;
    T val = ldg(output_backprop +
                (batch_index * bias_size + bias_index) * image_size +
                image_offset);
    sum += AccT(val);
  }

  // Write this thread's accumulated sum to shared memory. Each thread shifts
  // its write location to avoid bank conflicts.
  int bias_offset = threadIdx.x % 32;
  CudaAtomicAdd(s_data + bias_offset, sum);
  __syncthreads();

  // Reduce the 32 partial sums in shared memory down to lane 0 of the first
  // warp. No __syncthreads() is needed since the reduction stays within a
  // single warp.
  int32 thread_index = threadIdx.x;
  if (thread_index < 32) {
    AccT data = s_data[thread_index];
    for (int32 delta = warpSize / 2; delta > 0; delta /= 2) {
      data += CudaShuffleXorSync(kCudaWarpAll, data, delta);
    }
    if (thread_index == 0) {
      CudaAtomicAdd(bias_backprop + bias_index, T(data));
    }
  }
}

template <typename T>
void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
                             T* bias_backprop, int32 batch, int32 height,
                             int32 width, int32 channel,
                             TensorFormat data_format) {
  const int32 bias_size = channel;
  const int32 image_size = height * width;
  const int32 total_count = batch * bias_size * image_size;
  if (total_count == 0) {
    return;
  }
  static constexpr int32 kWarpSize = 32;
  CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);

  const int max_shared_memory_size = d.sharedMemPerBlock() / 2;
  int32 shared_memory_size = 0;
  if (data_format == FORMAT_NHWC) {
    shared_memory_size = bias_size * sizeof(typename AccumulatorType<T>::type);
  }
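  // Only half of the device's per-block shared memory is budgeted for the
  // NHWC shared-atomics kernel, which needs one accumulator per bias element.
  // For example, a float bias with 4096 channels requests
  // 4096 * sizeof(float) = 16 KiB of dynamic shared memory.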
  // Check if we have enough shared memory.
  if (shared_memory_size <= max_shared_memory_size) {
    if (data_format == FORMAT_NHWC) {
      BiasGradNHWC_SharedAtomics<T>
          <<<config.block_count, config.thread_per_block, shared_memory_size,
             d.stream()>>>(total_count, output_backprop, bias_backprop,
                           bias_size);
    } else {
      // Round up the block count to a multiple of bias_size.
      int group_size = (config.block_count + bias_size - 1) / bias_size;
      config.block_count = group_size * bias_size;
      if (config.thread_per_block < kWarpSize) {
        config.thread_per_block = kWarpSize;
      }
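      // Each bias element is now covered by exactly group_size blocks, and
      // every block runs at least one full warp so the shuffle reduction in
      // BiasGradNCHW_SharedAtomics is well defined.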
      BiasGradNCHW_SharedAtomics<T>
          <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
              output_backprop, bias_backprop, batch, bias_size, image_size,
              group_size);
    }
  } else {
    // Note that even if we don't have enough shared memory to fit the entire
    // output block, it is possible to process one group of elements at a time.
    // But for now, we simply fall back to the naive implementation.
    if (data_format == FORMAT_NHWC) {
      BiasGradNHWC_Naive<T>
          <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
              total_count, output_backprop, bias_backprop, bias_size);
    } else {
      BiasGradNCHW_Naive<T>
          <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
              total_count, output_backprop, bias_backprop, bias_size,
              image_size);
    }
  }
}

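// Explicitly instantiate BiasGPU and BiasGradGPU for every numeric type the
// GPU kernels support (see TF_CALL_GPU_NUMBER_TYPES in
// tensorflow/core/framework/register_types.h).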
#define DEFINE_GPU_SPECS(T)   \
  template struct BiasGPU<T>; \
  template struct BiasGradGPU<T>;

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA