/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <stdio.h>

#include "tensorflow/core/kernels/split_lib.h"

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {
namespace functor {

template <typename Device, typename T>
void Split<Device, T>::operator()(
    const Device& d, typename TTypes<T, 3>::Tensor output,
    typename TTypes<T, 3>::ConstTensor input,
    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
  To32Bit(output).device(d) = To32Bit(input).slice(slice_indices, slice_sizes);
}

template <typename Device, typename T>
void SplitCustom<Device, T>::operator()(
    const Device& d, typename TTypes<T, 2>::Tensor output,
    typename TTypes<T, 2>::ConstTensor input,
    const Eigen::DSizes<Eigen::DenseIndex, 2>& slice_indices,
    const Eigen::DSizes<Eigen::DenseIndex, 2>& slice_sizes) {
  To32Bit(output).device(d) = To32Bit(input).slice(slice_indices, slice_sizes);
}

#define DEFINE_GPU_KERNELS(T) template struct Split<Eigen::GpuDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_complex64(DEFINE_GPU_KERNELS);
TF_CALL_complex128(DEFINE_GPU_KERNELS);
TF_CALL_bfloat16(DEFINE_GPU_KERNELS);

#undef DEFINE_GPU_KERNELS
#define DEFINE_GPU_KERNELS(T) template struct SplitCustom<Eigen::GpuDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_complex64(DEFINE_GPU_KERNELS);
TF_CALL_complex128(DEFINE_GPU_KERNELS);
TF_CALL_bfloat16(DEFINE_GPU_KERNELS);

#undef DEFINE_GPU_KERNELS

}  // namespace functor

namespace {

// Copies the input into num_split contiguous pieces along the split
// dimension. Each output is a [prefix_dim_size, piece_size, suffix_dim_size]
// slab of the [prefix_dim_size, split_dim_size, suffix_dim_size] input.
template <typename T>
__global__ void SplitOpKernel(const T* input, int32 prefix_dim_size,
                              int32 split_dim_size, int32 suffix_dim_size,
                              CudaDeviceArrayStruct<T*> output_ptr_data) {
  const int32 num_split = output_ptr_data.size;
  T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);

  eigen_assert(blockDim.y == 1);
  eigen_assert(blockDim.z == 1);
  eigen_assert(split_dim_size % num_split == 0);

  int32 size = prefix_dim_size * split_dim_size * suffix_dim_size;
  int32 piece_size = split_dim_size / num_split;

  CUDA_1D_KERNEL_LOOP(offset, size) {
    // Calculate the index (i, j, k) into the input from the flat offset.
    int32 i = offset / (split_dim_size * suffix_dim_size);
    int32 j = (offset % (split_dim_size * suffix_dim_size)) / suffix_dim_size;
    int32 k = offset % suffix_dim_size;

    // Find the output buffer that should be written to.
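    // Illustrative example (the numbers are assumptions, not from the
    // original source): with split_dim_size = 8 and num_split = 4,
    // piece_size = 2, so j in {0, 1} maps to output 0, j in {2, 3} to
    // output 1, and so on; j / piece_size selects the split and
    // j % piece_size is the position within it.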
    T* output_ptr = output_ptrs[j / piece_size];
    // output_ptr points to an array of size
    // [prefix_dim_size][piece_size][suffix_dim_size].
    //
    // output_ptr[i][j % piece_size][k] = input[offset];
    // Linearize (i, j % piece_size, k) into an offset.
    int32 output_offset = i * piece_size * suffix_dim_size +
                          (j % piece_size) * suffix_dim_size + k;
    *(output_ptr + output_offset) = ldg(input + offset);
  }
}

}  // namespace

// This kernel cannot be in the anonymous namespace because it uses extern
// shared memory. It is very similar to the concat kernel, except that the
// input/output logic is reversed.
template <typename T, typename IntType, bool useSmem>
__global__ void split_v_kernel(const T* input_ptr,
                               CudaDeviceArrayStruct<IntType> output_scan,
                               IntType total_rows, IntType total_cols,
                               CudaDeviceArrayStruct<T*> output_ptr_data) {
  T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);
  IntType* col_scan = GetCudaDeviceArrayOnDevice(&output_scan);

  // Do an upper_bound on the column to find which output pointer to use.
  IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
  int num_outputs = output_ptr_data.size;

  // The verbose declaration is needed because the kernel is templated.
  extern __shared__ __align__(sizeof(T)) unsigned char smem[];
  IntType* smem_col_scan = reinterpret_cast<IntType*>(smem);

  if (useSmem) {
    // Cooperatively stage the column scan in shared memory so the repeated
    // searches below do not hit global memory.
    IntType lidx = threadIdx.y * blockDim.x + threadIdx.x;
    IntType blockSize = blockDim.x * blockDim.y;

    for (IntType i = lidx; i < output_scan.size; i += blockSize) {
      smem_col_scan[i] = col_scan[i];
    }

    __syncthreads();

    col_scan = smem_col_scan;
  }

  // Do an initial binary search, then scan linearly from there. This works
  // well both when there are many small segments and when the segments are
  // much longer.
  IntType segment =
      cuda_helper::upper_bound<IntType>(col_scan, num_outputs, gidx) - 1;

  IntType curr_offset = col_scan[segment];
  IntType curr_segment = segment;
  for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
    IntType curr_col_offset;
    while ((curr_col_offset = col_scan[curr_segment + 1]) <= gidx) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    IntType local_col = gidx - curr_offset;
    IntType segment_width = curr_col_offset - curr_offset;
    T* output_ptr = output_ptrs[curr_segment];

    IntType gidy = blockIdx.y * blockDim.y + threadIdx.y;
    for (; gidy < total_rows; gidy += blockDim.y * gridDim.y)
      output_ptr[gidy * segment_width + local_col] =
          input_ptr[gidy * total_cols + gidx];
  }
}

// Differs from the original split implementation because the problem is 2D
// rather than 3D. This version is likely faster because it does less integer
// math.
template <typename T>
__global__ void SplitVOpKernel_fixed(
    const T* input, int32 prefix_dim_size, int32 suffix_dim_size,
    CudaDeviceArrayStruct<T*> output_ptr_data) {
  const int32 num_split = output_ptr_data.size;
  T** output_ptrs = GetCudaDeviceArrayOnDevice(&output_ptr_data);

  eigen_assert(blockDim.y == 1);
  eigen_assert(blockDim.z == 1);

  int32 size = prefix_dim_size * suffix_dim_size;
  int32 piece_size = suffix_dim_size / num_split;

  CUDA_1D_KERNEL_LOOP(offset, size) {
    // Calculate the index (i, j) into the input from the flat offset.
    int32 i = offset / suffix_dim_size;
    int32 j = offset % suffix_dim_size;

    // Find the output buffer that should be written to.
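    // As in SplitOpKernel above, j / piece_size selects the output and
    // j % piece_size is the column within it. Illustrative example (assumed
    // numbers): suffix_dim_size = 6 and num_split = 3 give piece_size = 2,
    // so input column j = 5 lands in output 2 at local column 1.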
    T* output_ptr = output_ptrs[j / piece_size];
    int32 output_offset = i * piece_size + (j % piece_size);
    output_ptr[output_offset] = input[offset];
  }
}

template <typename T>
struct SplitOpGPULaunch {
  void Run(const Eigen::GpuDevice& d, const T* input, int32 prefix_dim_size,
           int32 split_dim_size, int32 suffix_dim_size,
           const CudaDeviceArrayStruct<T*>& output_ptr_data) {
    CudaLaunchConfig config = GetCudaLaunchConfig(
        prefix_dim_size * split_dim_size * suffix_dim_size, d);

    SplitOpKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            input, prefix_dim_size, split_dim_size, suffix_dim_size,
            output_ptr_data);
  }
};

template <typename T, typename IntType>
struct SplitVOpGPULaunch {
  void Run(const Eigen::GpuDevice& gpu_device, bool fixed_size,
           const T* input_ptr, int total_rows, int total_cols,
           const CudaDeviceArrayStruct<IntType>& output_scan,
           const CudaDeviceArrayStruct<T*>& output_ptr_data) {
    if (fixed_size) {
      CudaLaunchConfig config =
          GetCudaLaunchConfig(total_rows * total_cols, gpu_device);

      SplitVOpKernel_fixed<T><<<config.block_count, config.thread_per_block, 0,
                                gpu_device.stream()>>>(
          input_ptr, total_rows, total_cols, output_ptr_data);
    } else {
      auto config = GetCuda2DLaunchConfig(total_cols, total_rows, gpu_device);
      IntType smem_max = gpu_device.sharedMemPerBlock();
      IntType smem_usage = output_scan.size * sizeof(IntType);
      // The performance crossover point is below the maximum available shared
      // memory on most processors, possibly due to decreasing occupancy.
      // 4096 inputs is a lot; most cases will take the smem path.
      const int32 kMaxSmemBytesPerformance = 16384;
      if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance)
        split_v_kernel<T, IntType, true>
            <<<config.block_count, config.thread_per_block, smem_usage,
               gpu_device.stream()>>>(input_ptr, output_scan, total_rows,
                                      total_cols, output_ptr_data);
      else
        split_v_kernel<T, IntType, false>
            <<<config.block_count, config.thread_per_block, 0,
               gpu_device.stream()>>>(input_ptr, output_scan, total_rows,
                                      total_cols, output_ptr_data);
    }
  }
};

#define REGISTER_GPU_KERNEL(T) template struct SplitOpGPULaunch<T>;

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_complex64(REGISTER_GPU_KERNEL);
TF_CALL_complex128(REGISTER_GPU_KERNEL);
TF_CALL_bfloat16(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#define REGISTER_GPU_KERNEL(T)                 \
  template struct SplitVOpGPULaunch<T, int32>; \
  template struct SplitVOpGPULaunch<T, int64>;

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_complex64(REGISTER_GPU_KERNEL);
TF_CALL_complex128(REGISTER_GPU_KERNEL);
TF_CALL_bfloat16(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

}  // namespace tensorflow

#endif  // GOOGLE_CUDA