Home | History | Annotate | Download | only in util
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_
     17 #define TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_
     18 
     19 #if GOOGLE_CUDA
     20 
     21 #include <algorithm>
     22 
     23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     24 #include "cuda/include/cuda.h"
     25 #include "tensorflow/core/framework/op_kernel.h"
     26 #include "tensorflow/core/platform/logging.h"
     27 #include "tensorflow/core/platform/stream_executor.h"
     28 #include "tensorflow/core/platform/types.h"
     29 
     30 // Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
     31 // GetCuda3DLaunchConfig:
     32 //
     33 // There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one
     34 // version uses heuristics without any knowledge of the device kernel, the other
     35 // version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical
     36 // launch parameters that maximize occupancy. Currently, only the maximum
     37 // occupancy version of GetCuda3DLaunchConfig is available.
     38 //
// For a large number of work elements, the convention is that each kernel
// iterates through its assigned range. The return value of GetCudaLaunchConfig
// is struct CudaLaunchConfig, which contains all the information needed for the
// kernel launch, including: the virtual number of threads, the number of
// blocks, and the number of threads per block used inside <<< >>> of a kernel
// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig do the same thing
// as GetCudaLaunchConfig. The only difference is the dimension. The macros
// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP can be used for the inner loop.
     47 //
     48 /* Sample code:
     49 
     50 __global__ void MyKernel1D(CudaLaunchConfig config, other_args...) {
     51   CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
     52     do_your_job_here;
     53   }
     54 }
     55 
     56 __global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) {
     57   CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
     58     CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
     59       do_your_job_here;
     60     }
     61   }
     62 }
     63 
     64 __global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) {
     65   CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
     66     CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
     67       CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
     68         do_your_job_here;
     69       }
     70     }
     71   }
     72 }
     73 
void MyDriverFunc(const Eigen::GpuDevice &d) {
  // use heuristics
  CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d);
  MyKernel1D <<<cfg1.block_count,
                cfg1.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
  Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d);
  MyKernel2D <<<cfg2.block_count,
                cfg2.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
  // NOTE: as stated above, only the maximum occupancy version of
  // GetCuda3DLaunchConfig is currently available, so the heuristic 3D call
  // below is illustrative only.
  Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d);
  MyKernel3D <<<cfg3.block_count,
                cfg3.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);

  // maximize occupancy
  CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 );
  MyKernel1D <<<cfg4.block_count,
                cfg4.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
  Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d,
                                                  MyKernel2D, 0, 0);
  MyKernel2D <<<cfg5.block_count,
                cfg5.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
  Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d,
                                                  MyKernel3D, 0, 0);
  MyKernel3D <<<cfg6.block_count,
                cfg6.thread_per_block, 0, d.stream()>>> (cfg6, other_args...);
}
     99 
// See the tests for more examples:
//
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
    103 
    104 */
    105 
    106 namespace tensorflow {
    107 
    108 inline int DivUp(int a, int b) { return (a + b - 1) / b; }
    109 
    110 struct CudaLaunchConfig {
    111   // Logical number of thread that works on the elements. If each logical
    112   // thread works on exactly a single element, this is the same as the working
    113   // element count.
    114   int virtual_thread_count = -1;
    115   // Number of threads per block.
    116   int thread_per_block = -1;
    117   // Number of blocks for Cuda kernel launch.
    118   int block_count = -1;
    119 };
    120 
    121 // Calculate the Cuda launch config we should use for a kernel launch.
    122 // This is assuming the kernel is quite simple and will largely be
    123 // memory-limited.
    124 // REQUIRES: work_element_count > 0.
    125 inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
    126                                             const Eigen::GpuDevice& d) {
    127   CHECK_GT(work_element_count, 0);
    128   CudaLaunchConfig config;
    129   const int virtual_thread_count = work_element_count;
    130   const int physical_thread_count = std::min(
    131       d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
    132       virtual_thread_count);
    133   const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
    134   const int block_count =
    135       std::min(DivUp(physical_thread_count, thread_per_block),
    136                d.getNumCudaMultiProcessors());
    137 
    138   config.virtual_thread_count = virtual_thread_count;
    139   config.thread_per_block = thread_per_block;
    140   config.block_count = block_count;
    141   return config;
    142 }
    143 
    144 // Calculate the Cuda launch config we should use for a kernel launch. This
    145 // variant takes the resource limits of func into account to maximize occupancy.
    146 // REQUIRES: work_element_count > 0.
    147 template <typename DeviceFunc>
    148 inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
    149                                             const Eigen::GpuDevice& d,
    150                                             DeviceFunc func,
    151                                             size_t dynamic_shared_memory_size,
    152                                             int block_size_limit) {
    153   CHECK_GT(work_element_count, 0);
    154   CudaLaunchConfig config;
    155   int block_count = 0;
    156   int thread_per_block = 0;
    157 
    158   cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
    159       &block_count, &thread_per_block, func, dynamic_shared_memory_size,
    160       block_size_limit);
    161   CHECK_EQ(err, cudaSuccess);
    162 
    163   block_count =
    164       std::min(block_count, DivUp(work_element_count, thread_per_block));
    165 
    166   config.virtual_thread_count = work_element_count;
    167   config.thread_per_block = thread_per_block;
    168   config.block_count = block_count;
    169   return config;
    170 }
    171 
    172 // Calculate the Cuda launch config we should use for a kernel launch. This
    173 // variant takes the resource limits of func into account to maximize occupancy.
    174 // The returned launch config has thread_per_block set to fixed_block_size.
    175 // REQUIRES: work_element_count > 0.
    176 template <typename DeviceFunc>
    177 inline CudaLaunchConfig GetCudaLaunchConfigFixedBlockSize(
    178     int work_element_count, const Eigen::GpuDevice& d, DeviceFunc func,
    179     size_t dynamic_shared_memory_size, int fixed_block_size) {
    180   CHECK_GT(work_element_count, 0);
    181   CudaLaunchConfig config;
    182   int block_count = 0;
    183 
    184   cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
    185       &block_count, func, fixed_block_size, dynamic_shared_memory_size);
    186   CHECK_EQ(err, cudaSuccess);
    187   block_count = std::min(block_count * d.getNumCudaMultiProcessors(),
    188                          DivUp(work_element_count, fixed_block_size));
    189 
    190   config.virtual_thread_count = work_element_count;
    191   config.thread_per_block = fixed_block_size;
    192   config.block_count = block_count;
    193   return config;
    194 }
    195 
    196 struct Cuda2DLaunchConfig {
    197   dim3 virtual_thread_count = dim3(0, 0, 0);
    198   dim3 thread_per_block = dim3(0, 0, 0);
    199   dim3 block_count = dim3(0, 0, 0);
    200 };
    201 
    202 inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
    203                                                 const Eigen::GpuDevice& d) {
    204   Cuda2DLaunchConfig config;
    205 
    206   if (xdim <= 0 || ydim <= 0) {
    207     return config;
    208   }
    209 
    210   const int kThreadsPerBlock = 256;
    211   int block_cols = std::min(xdim, kThreadsPerBlock);
    212   // ok to round down here and just do more loops in the kernel
    213   int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
    214 
    215   const int physical_thread_count =
    216       d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();
    217 
    218   const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
    219 
    220   config.virtual_thread_count = dim3(xdim, ydim, 1);
    221   config.thread_per_block = dim3(block_cols, block_rows, 1);
    222 
    223   int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);
    224 
    225   config.block_count = dim3(
    226       grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
    227   return config;
    228 }
    229 
    230 // Calculate the Cuda 2D and 3D launch config we should use for a kernel launch.
    231 // This variant takes the resource limits of func into account to maximize
    232 // occupancy.
    233 using Cuda3DLaunchConfig = Cuda2DLaunchConfig;
    234 
    235 template <typename DeviceFunc>
    236 inline Cuda3DLaunchConfig GetCuda3DLaunchConfig(
    237     int xdim, int ydim, int zdim, const Eigen::GpuDevice& d, DeviceFunc func,
    238     size_t dynamic_shared_memory_size, int block_size_limit) {
    239   Cuda3DLaunchConfig config;
    240 
    241   if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
    242     return config;
    243   }
    244 
    245   int dev;
    246   cudaGetDevice(&dev);
    247   cudaDeviceProp deviceProp;
    248   cudaGetDeviceProperties(&deviceProp, dev);
    249   int xthreadlimit = deviceProp.maxThreadsDim[0];
    250   int ythreadlimit = deviceProp.maxThreadsDim[1];
    251   int zthreadlimit = deviceProp.maxThreadsDim[2];
    252   int xgridlimit = deviceProp.maxGridSize[0];
    253   int ygridlimit = deviceProp.maxGridSize[1];
    254   int zgridlimit = deviceProp.maxGridSize[2];
    255 
    256   int block_count = 0;
    257   int thread_per_block = 0;
    258   cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
    259       &block_count, &thread_per_block, func, dynamic_shared_memory_size,
    260       block_size_limit);
    261   CHECK_EQ(err, cudaSuccess);
    262 
    263   int threadsx = std::min({xdim, thread_per_block, xthreadlimit});
    264   int threadsy =
    265       std::min({ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit});
    266   int threadsz =
    267       std::min({zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
    268                 zthreadlimit});
    269 
    270   int blocksx = std::min({block_count, DivUp(xdim, threadsx), xgridlimit});
    271   int blocksy = std::min(
    272       {DivUp(block_count, blocksx), DivUp(ydim, threadsy), ygridlimit});
    273   int blocksz = std::min({DivUp(block_count, (blocksx * blocksy)),
    274                           DivUp(zdim, threadsz), zgridlimit});
    275 
    276   config.virtual_thread_count = dim3(xdim, ydim, zdim);
    277   config.thread_per_block = dim3(threadsx, threadsy, threadsz);
    278   config.block_count = dim3(blocksx, blocksy, blocksz);
    279   return config;
    280 }
    281 
    282 template <typename DeviceFunc>
    283 inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(
    284     int xdim, int ydim, const Eigen::GpuDevice& d, DeviceFunc func,
    285     size_t dynamic_shared_memory_size, int block_size_limit) {
    286   return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func,
    287                                dynamic_shared_memory_size, block_size_limit);
    288 }
    289 
    290 // Returns a raw reference to the current cuda stream.  Required by a
    291 // number of kernel calls (for which StreamInterface* does not work), i.e.
    292 // CUB and certain cublas primitives.
    293 inline const cudaStream_t& GetCudaStream(OpKernelContext* context) {
    294   const cudaStream_t* ptr = CHECK_NOTNULL(
    295       reinterpret_cast<const cudaStream_t*>(context->op_device_context()
    296                                                 ->stream()
    297                                                 ->implementation()
    298                                                 ->CudaStreamMemberHack()));
    299   return *ptr;
    300 }
    301 
    302 }  // namespace tensorflow
    303 
    304 #endif  // GOOGLE_CUDA
    305 
#endif  // TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_
    307