/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <memory>
#include <vector>

#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace {

template <typename T, typename IntType>
__global__ void concat_fixed_kernel(
    CudaDeviceArrayStruct<const T*> input_ptr_data, int split_size,
    int total_rows, int total_cols, T* output) {
  const T** input_ptrs = GetCudaDeviceArrayOnDevice(&input_ptr_data);
  IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;

  for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
    IntType gidy = blockIdx.y * blockDim.y + threadIdx.y;

    IntType split = gidx / split_size;
    const T* input_ptr = input_ptrs[split];
    IntType col_offset = gidx % split_size;
#pragma unroll
    for (; gidy < total_rows; gidy += blockDim.y * gridDim.y) {
      output[gidy * total_cols + gidx] =
          input_ptr[gidy * split_size + col_offset];
    }
  }
}

}  // end namespace

// cannot be in anonymous namespace due to extern shared memory
template <typename T, typename IntType, bool useSmem>
__global__ void concat_variable_kernel(
    CudaDeviceArrayStruct<const T*> input_ptr_data,
    CudaDeviceArrayStruct<IntType> output_scan, IntType total_rows,
    IntType total_cols, T* output) {
  const T** input_ptrs = GetCudaDeviceArrayOnDevice(&input_ptr_data);
  IntType* col_scan = GetCudaDeviceArrayOnDevice(&output_scan);

  // do upper_bound on col to find which pointer we should be using
  IntType gidx = blockIdx.x * blockDim.x + threadIdx.x;
  IntType num_inputs = input_ptr_data.size;

  // verbose declaration needed due to template
  extern __shared__ __align__(sizeof(T)) unsigned char smem[];
  IntType* smem_col_scan = reinterpret_cast<IntType*>(smem);

  if (useSmem) {
    IntType lidx = threadIdx.y * blockDim.x + threadIdx.x;
    IntType blockSize = blockDim.x * blockDim.y;

    for (IntType i = lidx; i < output_scan.size; i += blockSize) {
      smem_col_scan[i] = col_scan[i];
    }

    __syncthreads();

    col_scan = smem_col_scan;
  }

  // do an initial binary search and then scan linearly from there
  // works well both when there are many small segments and when the
  // segments are much longer
  IntType segment =
      cuda_helper::upper_bound<IntType>(col_scan, num_inputs, gidx) - 1;

  IntType curr_offset = col_scan[segment];
  IntType curr_segment = segment;
  for (; gidx < total_cols; gidx += blockDim.x * gridDim.x) {
    IntType curr_col_offset;
    while ((curr_col_offset = col_scan[curr_segment + 1]) <= gidx) {
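      // advance curr_segment until col_scan[curr_segment + 1] exceeds gidx,
      // i.e. until curr_segment is the input whose column range contains gidx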
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    IntType local_col = gidx - curr_offset;
    IntType segment_width = curr_col_offset - curr_offset;
    const T* input_ptr = input_ptrs[curr_segment];

    IntType gidy = blockIdx.y * blockDim.y + threadIdx.y;
    for (; gidy < total_rows; gidy += blockDim.y * gridDim.y)
      output[gidy * total_cols + gidx] =
          input_ptr[gidy * segment_width + local_col];
  }
}

template <typename T, typename IntType>
void ConcatGPUSlice(
    const Eigen::GpuDevice& gpu_device,
    const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
        inputs_flat,
    typename TTypes<T, 2>::Matrix* output) {
  Eigen::array<IntType, 2> offset{0, 0};
  for (int i = 0; i < inputs_flat.size(); ++i) {
    Eigen::array<IntType, 2> size;
    size[0] = inputs_flat[i]->dimension(0);
    size[1] = inputs_flat[i]->dimension(1);
    if (std::is_same<IntType, int32>::value) {
      To32Bit(*output).slice(offset, size).device(gpu_device) =
          To32Bit(*inputs_flat[i]);
    } else {
      output->slice(offset, size).device(gpu_device) = *inputs_flat[i];
    }

    offset[1] += size[1];
  }
}

template <typename T, typename IntType>
void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
                   const CudaDeviceArrayStruct<const T*>& input_ptrs,
                   const CudaDeviceArrayStruct<IntType>& output_scan,
                   bool fixed_size, int split_size,
                   typename TTypes<T, 2>::Matrix* output) {
  auto config = GetCuda2DLaunchConfig(output->dimension(1),
                                      output->dimension(0), gpu_device);

  if (fixed_size) {
    concat_fixed_kernel<T, IntType>
        <<<config.block_count, config.thread_per_block, 0,
           gpu_device.stream()>>>(input_ptrs, split_size, output->dimension(0),
                                  output->dimension(1), output->data());
  } else {
    IntType smem_max = gpu_device.sharedMemPerBlock();
    IntType smem_usage = output_scan.size * sizeof(IntType);
    // the performance crossover point is below the maximum available shared
    // memory on most processors, possibly due to decreasing occupancy
    // 4096 inputs is a lot, most code will take the smem path
    const int32 kMaxSmemBytesPerformance = 16384;
    if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance)
      concat_variable_kernel<T, IntType, true>
          <<<config.block_count, config.thread_per_block, smem_usage,
             gpu_device.stream()>>>(input_ptrs, output_scan,
                                    output->dimension(0), output->dimension(1),
                                    output->data());
    else
      concat_variable_kernel<T, IntType, false>
          <<<config.block_count, config.thread_per_block, 0,
             gpu_device.stream()>>>(input_ptrs, output_scan,
                                    output->dimension(0), output->dimension(1),
                                    output->data());
  }
}

#define REGISTER_GPUCONCAT32(T)                                               \
  template void ConcatGPUSlice<T, int32>(                                     \
      const Eigen::GpuDevice& gpu_device,                                     \
      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
          inputs_flat,                                                        \
      typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPUCONCAT64(T)                                               \
  template void ConcatGPUSlice<T, int64>(                                     \
      const Eigen::GpuDevice& gpu_device,                                     \
      const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
          inputs_flat,                                                        \
      typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPU32(T)                                                \
  template void ConcatGPUImpl<T, int32>(                                 \
      const Eigen::GpuDevice& d,                                         \
      const CudaDeviceArrayStruct<const T*>& input_ptrs,                 \
      const CudaDeviceArrayStruct<int32>& ptr_offsets, bool fixed_size,  \
      int split_size,                                                    \
      typename TTypes<T, 2>::Matrix* output);

#define REGISTER_GPU64(T)                                                \
  template void ConcatGPUImpl<T, int64>(                                 \
      const Eigen::GpuDevice& d,                                         \
      const CudaDeviceArrayStruct<const T*>& input_ptrs,                 \
      const CudaDeviceArrayStruct<int64>& ptr_offsets, bool fixed_size,  \
      int split_size, typename TTypes<T, 2>::Matrix* output);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT32);
TF_CALL_complex64(REGISTER_GPUCONCAT32);
TF_CALL_complex128(REGISTER_GPUCONCAT32);
TF_CALL_int64(REGISTER_GPUCONCAT32);
REGISTER_GPUCONCAT32(bfloat16);
REGISTER_GPUCONCAT32(bool);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT64);
TF_CALL_complex64(REGISTER_GPUCONCAT64);
TF_CALL_complex128(REGISTER_GPUCONCAT64);
TF_CALL_int64(REGISTER_GPUCONCAT64);
REGISTER_GPUCONCAT64(bfloat16);
REGISTER_GPUCONCAT64(bool);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU32);
TF_CALL_complex64(REGISTER_GPU32);
TF_CALL_complex128(REGISTER_GPU32);
TF_CALL_int64(REGISTER_GPU32);
REGISTER_GPU32(bfloat16);
REGISTER_GPU32(bool);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU64);
TF_CALL_complex64(REGISTER_GPU64);
TF_CALL_complex128(REGISTER_GPU64);
TF_CALL_int64(REGISTER_GPU64);
REGISTER_GPU64(bfloat16);
REGISTER_GPU64(bool);

#undef REGISTER_GPUCONCAT32
#undef REGISTER_GPUCONCAT64
#undef REGISTER_GPU32
#undef REGISTER_GPU64

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA
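// Worked example for the variable-width path (values chosen purely for
// illustration, inferred from how concat_variable_kernel reads its arguments):
// concatenating three inputs with 4 rows and column widths {3, 5, 2} along
// their second dimension gives total_rows = 4, total_cols = 10, and an
// output_scan holding the exclusive prefix sum of the widths, {0, 3, 8, 10}.
// For output column gidx = 6, the upper_bound search returns segment 1
// (curr_offset = 3); the while loop leaves curr_segment = 1 with
// curr_col_offset = 8, so segment_width = 5 and local_col = 3, and each row r
// copies input_ptrs[1][r * 5 + 3] into output[r * 10 + 6].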