/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <algorithm>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/bias_op.h"
#include "tensorflow/core/kernels/bias_op_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

// There are no native fp16 atomics (we simulate them using 32-bit atomics),
// so fp16 sums are done in fp32 internally. (We don't have a lot of shared
// memory traffic; BiasGradNCHW_SharedAtomics in particular works almost
// entirely on a local variable.)
template <class T>
struct AccumulatorType {
  typedef T type;
};

template <>
struct AccumulatorType<Eigen::half> {
  typedef float type;
};

// Definition of the GPU implementations declared in bias_op.cc.

template <typename T>
__global__ void BiasNHWCKernel(int32 nthreads, const T* input, const T* bias,
                               T* output, int32 bias_size) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int32 bias_offset = index % bias_size;
    output[index] = ldg(input + index) + ldg(bias + bias_offset);
  }
}

template <typename T>
__global__ void BiasNCHWKernel(int32 nthreads, const T* input, const T* bias,
                               T* output, int32 bias_size, int32 image_size) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int32 index2 = index / image_size;
    int32 bias_offset = index2 % bias_size;
    output[index] = ldg(input + index) + ldg(bias + bias_offset);
  }
}

// Add "bias" to "input", broadcasting it on all dimensions but the bias
// dimension.
template <typename T>
void BiasGPU<T>::compute(const GPUDevice& d, const T* input, const T* bias,
                         T* output, int32 batch, int32 height, int32 width,
                         int32 channel, TensorFormat data_format) {
  const int32 bias_size = channel;
  const int32 image_size = height * width;
  const int32 total_count = batch * bias_size * image_size;
  if (total_count == 0) {
    return;
  }
  CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
  if (data_format == FORMAT_NHWC) {
    BiasNHWCKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            config.virtual_thread_count, input, bias, output, bias_size);
  } else {
    BiasNCHWKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            config.virtual_thread_count, input, bias, output, bias_size,
            image_size);
  }
}

// A naive implementation that is functional on all cases.
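// Every element of output_backprop is added to bias_backprop with a global
// atomic, which is correct for any shape but heavily contended when
// bias_size is small.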
template <typename T>
__global__ void BiasGradNHWC_Naive(int32 nthreads, const T* output_backprop,
                                   T* bias_backprop, int32 bias_size) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int32 bias_offset = index % bias_size;
    CudaAtomicAdd(bias_backprop + bias_offset, ldg(output_backprop + index));
  }
}

// A naive implementation that is functional on all cases.
template <typename T>
__global__ void BiasGradNCHW_Naive(int32 nthreads, const T* output_backprop,
                                   T* bias_backprop, int32 bias_size,
                                   int32 image_size) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    int32 index2 = index / image_size;
    int32 bias_offset = index2 % bias_size;
    CudaAtomicAdd(bias_backprop + bias_offset, ldg(output_backprop + index));
  }
}

extern __shared__ char s_buf[];

template <typename T>
__global__ void BiasGradNHWC_SharedAtomics(int32 nthreads,
                                           const T* output_backprop,
                                           T* bias_backprop, int32 bias_size) {
  typedef typename AccumulatorType<T>::type AccT;
  AccT* s_data = reinterpret_cast<AccT*>(s_buf);
  for (int32 index = threadIdx.x; index < bias_size; index += blockDim.x) {
    s_data[index] = AccT(0);
  }
  __syncthreads();

  for (int32 index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int32 bias_offset = index % bias_size;
    CudaAtomicAdd(s_data + bias_offset, AccT(ldg(output_backprop + index)));
  }
  __syncthreads();

  for (int32 index = threadIdx.x; index < bias_size; index += blockDim.x) {
    CudaAtomicAdd(bias_backprop + index, T(s_data[index]));
  }
}

template <typename T>
__global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
                                           T* bias_backprop, int32 batch,
                                           int32 bias_size, int32 image_size,
                                           int group_size) {
  // Initialize the shared memory.
  typedef typename AccumulatorType<T>::type AccT;
  const int32 kSDataSize = 32;
  __shared__ AccT s_data[kSDataSize];
  for (int32 index = threadIdx.x; index < kSDataSize; index += blockDim.x) {
    s_data[index] = AccT(0);
  }
  __syncthreads();

  // Accumulate all the values within this thread. They all have the same bias
  // index.
  int32 bias_index = blockIdx.x % bias_size;
  int32 group_index = blockIdx.x / bias_size;
  int32 total_count = batch * image_size;
  AccT sum(0);
  for (int32 index = group_index * blockDim.x + threadIdx.x;
       index < total_count; index += blockDim.x * group_size) {
    int32 image_offset = index % image_size;
    int32 batch = index / image_size;
    T val = ldg(output_backprop +
                (batch * bias_size + bias_index) * image_size + image_offset);
    sum += AccT(val);
  }

  // Write the accumulated sum in this thread to the shared memory. Each thread
  // shifts their write location to avoid bank conflict.
  int bias_offset = threadIdx.x % 32;
  CudaAtomicAdd(s_data + bias_offset, sum);
  __syncthreads();

  // Accumulate the results in the shared memory into the first element.
  // No syncthreads is needed since this is only in the same warp.
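  // A warp shuffle reduction folds the 32 partial sums into lane 0, which then
  // publishes the block's contribution to bias_backprop with one global atomic.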
  int32 thread_index = threadIdx.x;
  if (thread_index < 32) {
    AccT data = s_data[thread_index];
    for (int32 delta = warpSize / 2; delta > 0; delta /= 2) {
      data += CudaShuffleXorSync(kCudaWarpAll, data, delta);
    }
    if (thread_index == 0) {
      CudaAtomicAdd(bias_backprop + bias_index, T(data));
    }
  }
}

template <typename T>
void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
                             T* bias_backprop, int32 batch, int32 height,
                             int32 width, int32 channel,
                             TensorFormat data_format) {
  const int32 bias_size = channel;
  const int32 image_size = height * width;
  const int32 total_count = batch * bias_size * image_size;
  if (total_count == 0) {
    return;
  }
  static constexpr int32 kWarpSize = 32;
  CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);

  const int max_shared_memory_size = d.sharedMemPerBlock() / 2;
  int32 shared_memory_size = 0;
  if (data_format == FORMAT_NHWC) {
    shared_memory_size = bias_size * sizeof(typename AccumulatorType<T>::type);
  }
  // Check if we have enough shared memory.
  if (shared_memory_size <= max_shared_memory_size) {
    if (data_format == FORMAT_NHWC) {
      BiasGradNHWC_SharedAtomics<T>
          <<<config.block_count, config.thread_per_block, shared_memory_size,
             d.stream()>>>(total_count, output_backprop, bias_backprop,
                           bias_size);
    } else {
      // Round up the block count to multiple of bias_size.
      int group_size = (config.block_count + bias_size - 1) / bias_size;
      config.block_count = group_size * bias_size;
      if (config.thread_per_block < kWarpSize) {
        config.thread_per_block = kWarpSize;
      }
      BiasGradNCHW_SharedAtomics<T>
          <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
              output_backprop, bias_backprop, batch, bias_size, image_size,
              group_size);
    }
  } else {
    // Note that even if we don't have enough shared memory to fit the entire
    // output block, it is possible to process one group of elements at a time.
    // But for now, we simply fall back to the naive implementation.
    if (data_format == FORMAT_NHWC) {
      BiasGradNHWC_Naive<T>
          <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
              total_count, output_backprop, bias_backprop, bias_size);
    } else {
      BiasGradNCHW_Naive<T>
          <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
              total_count, output_backprop, bias_backprop, bias_size,
              image_size);
    }
  }
}

#define DEFINE_GPU_SPECS(T)   \
  template struct BiasGPU<T>; \
  template struct BiasGradGPU<T>;

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA