/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_
#define TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#if CUDA_VERSION >= 9000
#define CUB_USE_COOPERATIVE_GROUPS
#endif  // CUDA_VERSION >= 9000

#include "third_party/cub/block/block_load.cuh"
#include "third_party/cub/block/block_scan.cuh"
#include "third_party/cub/block/block_store.cuh"
#include "third_party/cub/iterator/counting_input_iterator.cuh"
#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "cuda/include/cuComplex.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/cuda_launch_config.h"
#include "tensorflow/core/util/permutation_input_iterator.h"
#include "tensorflow/core/util/permutation_output_iterator.h"

#include "tensorflow/core/kernels/scan_ops.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::Index Index;

namespace functor {

// Maps a contiguous index range to the actual memory locations, depending on
// which axis the scan runs over and whether or not the scan is reversed.
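//
// Worked example (illustrative, derived from operator() below): with
// (dimx, dimy, dimz) = (2, 3, 4), block 5 of the kernel owns the sequence at
// x = 1, z = 1 (block_start = 5 * dimy = 15), and the contiguous ids
// 15, 16, 17 map to memory offsets 13, 17, 21, i.e. the scanned elements sit
// dimz apart in memory.  With reverse == true the same ids map to 21, 17, 13.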
struct MapIndexToLocation {
  __host__ __device__ MapIndexToLocation(int dimx, int dimy, int dimz,
                                         bool reverse = false)
      : dimx_(dimx), dimy_(dimy), dimz_(dimz), reverse_(reverse) {}

  __host__ __device__ int operator()(int id) const {
    if (dimx_ == 1) {
      int row = id % dimy_;
      int col = id / dimy_;

      if (reverse_) return (dimy_ - row - 1) * dimz_ + col;

      return row * dimz_ + col;
    } else if (dimz_ == 1) {
      if (reverse_) {
        int row = id / dimy_;
        int col = id % dimy_;
        return row * dimy_ + (dimy_ - col - 1);
      }
      return id;
    } else {
      int col = id % dimy_;
      int tmp = id / dimy_;

      int row1 = id / (dimy_ * dimz_);
      int col1 = tmp % dimz_;

      if (reverse_)
        return row1 * dimy_ * dimz_ + (dimy_ - col - 1) * dimz_ + col1;

      return row1 * dimy_ * dimz_ + col * dimz_ + col1;
    }
  }

  int dimx_;
  int dimy_;
  int dimz_;
  bool reverse_;
};

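// Stateful prefix functor handed to cub::BlockScan: CUB calls operator() once
// per tile with that tile's aggregate, and the returned value is used as the
// prefix for the tile being scanned.  With Sum<T>, for example, tile
// aggregates a0, a1, a2 produce the prefixes 0, a0, a0 + a1, which is how
// scan_kernel below carries the running total across tiles of
// BlockDim * ItemsPerThread elements.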
template <typename T, typename Op>
struct BlockPrefixCallbackOp {
  // Running prefix
  T running_total_;
  Op op_;

  __device__ BlockPrefixCallbackOp(T running_total, Op op)
      : running_total_(running_total), op_(op) {}

  // Callback operator to be entered by the first warp of threads in the block.
  // tid 0 is responsible for returning a value for seeding the block-wide scan.
  __device__ T operator()(T block_aggregate) {
    T old_prefix = running_total_;
    running_total_ = op_(old_prefix, block_aggregate);
    return old_prefix;
  }
};

template <typename T>
struct Sum {
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a + b;
  }
};

template <typename T>
struct Prod {
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a * b;
  }
};

template <typename T, typename Op>
struct IsSum {
  constexpr static bool value =
      (std::is_same<Op, Sum<T>>::value ||
       std::is_same<Op, Eigen::internal::SumReducer<T>>::value);
};

template <typename T, typename Op>
struct IsProd {
  constexpr static bool value =
      (std::is_same<Op, Prod<T>>::value ||
       std::is_same<Op, Eigen::internal::ProdReducer<T>>::value);
};

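// IdentityValue selects the identity element of the scan operator so that the
// running prefix starts out neutral.  For example,
// IdentityValue<float, Sum<float>>()() yields 0.0f and
// IdentityValue<float, Prod<float>>()() yields 1.0f; these seed the running
// total in BlockPrefixCallbackOp.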
template <typename T, typename Op>
struct IdentityValue {
  static_assert(IsSum<T, Op>::value || IsProd<T, Op>::value,
                "IdentityValue not yet defined for this type.");

  template <typename U = T, typename OpCopy = Op>
  __host__ __device__ U operator()(
      typename std::enable_if<IsSum<U, OpCopy>::value, U>::type t = U(0)) {
    return t;
  }

  template <typename U = T, typename OpCopy = Op>
  __host__ __device__ U operator()(
      typename std::enable_if<IsProd<U, OpCopy>::value, U>::type t = U(1)) {
    return t;
  }
};

// Each block is mapped to one sequence.  A contiguous range is mapped to the
// appropriate locations in memory by the permutation iterators.  This is
// ideal for 1-D and row-based scans.  Column scans would be better off doing
// a block load and then locally transposing.  CUB's device-wide scan is not
// used in the large 1-D case, even though it would be more efficient, because
// it is not deterministic.
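//
// The kernel expects one block per scanned sequence: LaunchScan below launches
// gridDim.x == dimx * dimz blocks of BlockDim threads, and each block walks
// its dimy elements in tiles of BlockDim * ItemsPerThread.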
template <typename T, typename Op, int BlockDim = 128, int ItemsPerThread = 4>
__global__ void scan_kernel(const T* in, T* out, int dimx, int dimy, int dimz,
                            bool exclusive, bool reverse, Op op) {
  typedef cub::BlockLoad<T, BlockDim, ItemsPerThread, cub::BLOCK_LOAD_TRANSPOSE>
      BlockLoad;
  typedef cub::BlockStore<T, BlockDim, ItemsPerThread,
                          cub::BLOCK_STORE_TRANSPOSE>
      BlockStore;
  typedef cub::BlockScan<T, BlockDim> BlockScan;

  // Allocate aliased shared memory for BlockLoad, BlockStore, and BlockScan
  __shared__ union {
    typename BlockLoad::TempStorage load;
    typename BlockScan::TempStorage scan;
    typename BlockStore::TempStorage store;
  } temp_storage;

  int problem_length = dimy;

  // Initialize the running total to the identity of the scan operator.
  BlockPrefixCallbackOp<T, Op> prefix_op(IdentityValue<T, Op>()(), op);

  MapIndexToLocation map_op(dimx, dimy, dimz, reverse);
  int block_start = problem_length * blockIdx.x;
  // Have the block iterate over segments of items
  for (int block_offset = block_start;
       block_offset < block_start + problem_length;
       block_offset += BlockDim * ItemsPerThread) {
    int valid_items = min(BlockDim * ItemsPerThread,
                          problem_length - (block_offset % problem_length));

    // First construct a counting iterator that has the desired start point.
    typedef cub::TransformInputIterator<int, MapIndexToLocation,
                                        cub::CountingInputIterator<int>>
        MapIterType;

    cub::CountingInputIterator<int> counting_iter(block_offset);

    // Next map the iterator to the actual locations in memory.
    MapIterType map_iter(counting_iter, map_op);

    PermutationInputIterator<T, const T*, MapIterType> permutein_iter(in,
                                                                      map_iter);
    PermutationOutputIterator<T, T*, MapIterType> permuteout_iter(out,
                                                                  map_iter);

    // Load a segment of consecutive items that are blocked across threads
    T thread_data[ItemsPerThread];
    BlockLoad(temp_storage.load).Load(permutein_iter, thread_data, valid_items);
    __syncthreads();

    // Collectively compute the block-wide scan
    if (exclusive) {
      BlockScan(temp_storage.scan)
          .ExclusiveScan(thread_data, thread_data, op, prefix_op);
    } else {
      BlockScan(temp_storage.scan)
          .InclusiveScan(thread_data, thread_data, op, prefix_op);
    }
    __syncthreads();

    // Store scanned items to output segment
    BlockStore(temp_storage.store)
        .Store(permuteout_iter, thread_data, valid_items);
    __syncthreads();
  }
}

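// Launches scan_kernel with one block per scanned sequence (dimx * dimz
// blocks), picking the largest power-of-two block size that does not exceed
// dimy / items_per_thread, bounded to [32, 1024] (1024 only for float; see the
// workaround note below).  For example, dimy = 4096 with items_per_thread = 4
// gives 1024-thread blocks for float but 512-thread blocks for double, while
// dimy = 100 falls through to 32-thread blocks.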
template <typename T, typename Op>
void LaunchScan(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
                typename TTypes<T, 3>::Tensor out, Op op, const bool reverse,
                const bool exclusive) {
  const int items_per_thread = 4;

  int dimx = in.dimension(0);
  int dimy = in.dimension(1);
  int dimz = in.dimension(2);
  int num_blocks = dimx * dimz;

  int ideal_block_size = dimy / items_per_thread;

  // There appears to be a bug when the type is not float and block_size is
  // 1024, so the 1024-thread launch is restricted to float.  Otherwise launch
  // the largest power-of-two block size that does not exceed ideal_block_size
  // (at least 32).
  if (ideal_block_size >= 1024 && std::is_same<T, float>::value) {
    const int block_size = 1024;
    TF_CHECK_OK(
        CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                         num_blocks, block_size, 0, d.stream(), in.data(),
                         out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  } else if (ideal_block_size >= 512) {
    const int block_size = 512;
    TF_CHECK_OK(
        CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                         num_blocks, block_size, 0, d.stream(), in.data(),
                         out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  } else if (ideal_block_size >= 256) {
    const int block_size = 256;
    TF_CHECK_OK(
        CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                         num_blocks, block_size, 0, d.stream(), in.data(),
                         out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  } else if (ideal_block_size >= 128) {
    const int block_size = 128;
    TF_CHECK_OK(
        CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                         num_blocks, block_size, 0, d.stream(), in.data(),
                         out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  } else if (ideal_block_size >= 64) {
    const int block_size = 64;
    TF_CHECK_OK(
        CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                         num_blocks, block_size, 0, d.stream(), in.data(),
                         out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  } else {
    const int block_size = 32;
    TF_CHECK_OK(
        CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
                         num_blocks, block_size, 0, d.stream(), in.data(),
                         out.data(), dimx, dimy, dimz, exclusive, reverse, op));
  }
}

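// GPU specializations of the Scan functor declared in scan_ops.h.  Usage
// sketch (illustrative only; `context`, `input_view`, and `output_view` are
// placeholder names, and the real call sites live in the scan op kernels):
//
//   functor::Scan<GPUDevice, Eigen::internal::SumReducer<T>, T>()(
//       context->eigen_device<GPUDevice>(), input_view, output_view,
//       Eigen::internal::SumReducer<T>(), reverse, exclusive);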
template <typename T>
struct Scan<GPUDevice, Eigen::internal::SumReducer<T>, T> {
  void operator()(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
                  typename TTypes<T, 3>::Tensor out,
                  const Eigen::internal::SumReducer<T>& reducer,
                  const bool reverse, const bool exclusive) {
    LaunchScan<T, Sum<T>>(d, in, out, Sum<T>(), reverse, exclusive);
  }
};

template <typename T>
struct Scan<GPUDevice, Eigen::internal::ProdReducer<T>, T> {
  void operator()(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
                  typename TTypes<T, 3>::Tensor out,
                  const Eigen::internal::ProdReducer<T>& reducer,
                  const bool reverse, const bool exclusive) {
    LaunchScan<T, Prod<T>>(d, in, out, Prod<T>(), reverse, exclusive);
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA

#endif  // TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_