     19 #if GOOGLE_CUDA
     21 #define EIGEN_USE_GPU
     23 #if CUDA_VERSION >= 9000
     25 #endif  // CUDA_VERSION >= 9000
     27 #include "third_party/cub/block/block_load.cuh"
     28 #include "third_party/cub/block/block_scan.cuh"
     29 #include "third_party/cub/block/block_store.cuh"
     30 #include "third_party/cub/iterator/counting_input_iterator.cuh"
     31 #include "third_party/cub/iterator/transform_input_iterator.cuh"
     32 #include "cuda/include/cuComplex.h"
     33 #include "tensorflow/core/framework/numeric_types.h"
     34 #include "tensorflow/core/framework/register_types.h"
     35 #include "tensorflow/core/util/cuda_launch_config.h"
     36 #include "tensorflow/core/util/permutation_input_iterator.h"
     37 #include "tensorflow/core/util/permutation_output_iterator.h"
     39 #include "tensorflow/core/kernels/scan_ops.h"
     41 namespace tensorflow {
     // Shorthand aliases for the Eigen GPU device type and Eigen's index type.
     using GPUDevice = Eigen::GpuDevice;
     using Index = Eigen::Index;
     46 namespace functor {
     // Translates a contiguous index into the memory offset it refers to, given
     // the (dimx, dimy, dimz) extents and whether the dimy axis is traversed in
     // reverse.  Consecutive ids walk along the dimy axis; every dimy ids the
     // mapping moves on to the next sequence.
     struct MapIndexToLocation {
       __host__ __device__ MapIndexToLocation(int dimx, int dimy, int dimz,
                                              bool reverse = false)
           : dimx_(dimx), dimy_(dimy), dimz_(dimz), reverse_(reverse) {}

       // Returns the element offset addressed by the contiguous index `id`.
       __host__ __device__ int operator()(int id) const {
         // No outer dimension: sequences are laid out with stride dimz_.
         if (dimx_ == 1) {
           const int pos = id % dimy_;
           const int inner = id / dimy_;
           const int scan_pos = reverse_ ? (dimy_ - pos - 1) : pos;
           return scan_pos * dimz_ + inner;
         }

         // No inner dimension: each sequence is already contiguous, so only the
         // reversed traversal needs any remapping.
         if (dimz_ == 1) {
           if (!reverse_) return id;
           const int seq = id / dimy_;
           const int pos = id % dimy_;
           return seq * dimy_ + (dimy_ - pos - 1);
         }

         // General case: split the sequence number into (outer, inner) parts.
         const int pos = id % dimy_;
         const int inner = (id / dimy_) % dimz_;
         const int outer = id / (dimy_ * dimz_);
         const int scan_pos = reverse_ ? (dimy_ - pos - 1) : pos;
         return outer * dimy_ * dimz_ + scan_pos * dimz_ + inner;
       }

       int dimx_;      // extent of the outer dimension
       int dimy_;      // extent of the dimension the ids walk along
       int dimz_;      // extent of the inner dimension
       bool reverse_;  // when true, positions along dimy_ are mirrored
     };
     // Stateful functor that carries a running total from one invocation to the
     // next, combining values with `Op`.
     template <typename T, typename Op>
     struct BlockPrefixCallbackOp {
       // Total accumulated over all aggregates seen so far.
       T running_total_;
       // Binary operator used to fold a new aggregate into the running total.
       Op op_;

       __device__ BlockPrefixCallbackOp(T running_total, Op op)
           : running_total_(running_total), op_(op) {}

       // Callback operator to be entered by the first warp of threads in the
       // block.  tid 0 is responsible for returning a value for seeding the
       // block-wide scan.  Returns the total as it was *before* this call and
       // then folds `block_aggregate` into the stored total.
       __device__ T operator()(T block_aggregate) {
         const T seed = running_total_;
         running_total_ = op_(seed, block_aggregate);
         return seed;
       }
     };
    // Binary addition functor: returns lhs + rhs.
    template <typename T>
    struct Sum {
      __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
        return lhs + rhs;
      }
    };
    // Binary multiplication functor: returns lhs * rhs.
    template <typename T>
    struct Prod {
      __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
        return lhs * rhs;
      }
    };
    // Compile-time predicate: `value` is true when Op is either Sum<T> or
    // Eigen::internal::SumReducer<T>.
    template <typename T, typename Op>
    struct IsSum {
     private:
      static constexpr bool kIsLocalFunctor = std::is_same<Op, Sum<T>>::value;
      static constexpr bool kIsEigenReducer =
          std::is_same<Op, Eigen::internal::SumReducer<T>>::value;

     public:
      static constexpr bool value = kIsLocalFunctor || kIsEigenReducer;
    };
    // Compile-time predicate: `value` is true when Op is either Prod<T> or
    // Eigen::internal::ProdReducer<T>.
    template <typename T, typename Op>
    struct IsProd {
     private:
      static constexpr bool kIsLocalFunctor = std::is_same<Op, Prod<T>>::value;
      static constexpr bool kIsEigenReducer =
          std::is_same<Op, Eigen::internal::ProdReducer<T>>::value;

     public:
      static constexpr bool value = kIsLocalFunctor || kIsEigenReducer;
    };
    // Yields the identity element of Op over T: 0 for sum-like operators and 1
    // for product-like operators.  Exactly one of the two call operators takes
    // part in overload resolution, selected through enable_if on IsSum/IsProd.
    // The identity is supplied as the default argument, so calling with no
    // arguments returns it; an explicitly passed argument is returned unchanged.
    template <typename T, typename Op>
    struct IdentityValue {
      static_assert(IsSum<T, Op>::value || IsProd<T, Op>::value,
                    "IdentityValue not yet defined for this type.");

      // Sum-like operators: defaults to ValueT(0).
      template <typename ValueT = T, typename OpT = Op>
      __host__ __device__ ValueT operator()(
          typename std::enable_if<IsSum<ValueT, OpT>::value, ValueT>::type
              identity = ValueT(0)) {
        return identity;
      }

      // Product-like operators: defaults to ValueT(1).
      template <typename ValueT = T, typename OpT = Op>
      __host__ __device__ ValueT operator()(
          typename std::enable_if<IsProd<ValueT, OpT>::value, ValueT>::type
              identity = ValueT(1)) {
        return identity;
      }
    };
    // Each block is mapped to one sequence.  A contiguous range is mapped to the
    // appropriate locations in memory by the permutation iterators.  This is
    // ideal for 1-D and row based scans.  Column scans would be better if they
    // did a block load and then locally transposed.  CUB's device wide scan is not
    // used in the large 1D case, even though it would be more efficient, because
    // it is not deterministic.
    //
    // Launch contract:
    //   * blockIdx.x selects the sequence; each sequence has `dimy` elements.
    //   * The CUB collectives are instantiated for BlockDim threads, so the
    //     kernel is expected to be launched with blockDim.x == BlockDim.
    //   * Shared memory is statically allocated (the union below); no dynamic
    //     shared memory is used.
    //
    // Parameters:
    //   in, out          input / output buffers, addressed via MapIndexToLocation.
    //   dimx, dimy, dimz extents used by MapIndexToLocation; the scan runs over
    //                    dimy.
    //   exclusive        selects ExclusiveScan instead of InclusiveScan.
    //   reverse          forwarded to MapIndexToLocation to mirror the traversal.
    //   op               binary scan operator (must satisfy IsSum or IsProd so
    //                    that IdentityValue is defined).
    template <typename T, typename Op, int BlockDim = 128, int ItemsPerThread = 4>
    __global__ void scan_kernel(const T* in, T* out, int dimx, int dimy, int dimz,
                                bool exclusive, bool reverse, Op op) {
      typedef cub::BlockLoad<T, BlockDim, ItemsPerThread, cub::BLOCK_LOAD_TRANSPOSE>
          BlockLoad;
      typedef cub::BlockStore<T, BlockDim, ItemsPerThread,
                              cub::BLOCK_STORE_TRANSPOSE>
          BlockStore;
      typedef cub::BlockScan<T, BlockDim> BlockScan;

      // Allocate aliased shared memory for BlockLoad, BlockStore, and BlockScan.
      // Because the three collectives share storage, every phase below is
      // separated by __syncthreads().
      __shared__ union {
        typename BlockLoad::TempStorage load;
        typename BlockScan::TempStorage scan;
        typename BlockStore::TempStorage store;
      } temp_storage;

      // Number of items the whole block consumes per loop iteration.
      constexpr int kTileSize = BlockDim * ItemsPerThread;

      const int problem_length = dimy;

      // Identity of `op`: seeds the running total and pads the partial tail tile.
      const T identity = IdentityValue<T, Op>()();

      // Initialize running total
      BlockPrefixCallbackOp<T, Op> prefix_op(identity, op);

      MapIndexToLocation map_op(dimx, dimy, dimz, reverse);
      const int block_start = problem_length * blockIdx.x;

      // Have the block iterate over segments of items.  `tile_start` is the
      // offset of the current tile inside this block's sequence.
      for (int tile_start = 0; tile_start < problem_length;
           tile_start += kTileSize) {
        const int block_offset = block_start + tile_start;
        const int valid_items = min(kTileSize, problem_length - tile_start);

        // first construct a counting iterator that has the desired start point
        typedef cub::TransformInputIterator<int, MapIndexToLocation,
                                            cub::CountingInputIterator<int>>
            MapIterType;

        cub::CountingInputIterator<int> counting_iter(block_offset);

        // Next map the iterator to the actual locations in memory
        MapIterType map_iter(counting_iter, map_op);

        PermutationInputIterator<T, const T*, MapIterType> permutein_iter(in,
                                                                          map_iter);
        PermutationOutputIterator<T, T*, MapIterType> permuteout_iter(out,
                                                                      map_iter);

        // Load a segment of consecutive items that are blocked across threads.
        // Slots past `valid_items` are filled with the identity so the scan
        // never consumes uninitialized registers on the final, partial tile.
        T thread_data[ItemsPerThread];
        BlockLoad(temp_storage.load)
            .Load(permutein_iter, thread_data, valid_items, identity);
        __syncthreads();

        // Collectively compute the block-wide scan
        if (exclusive) {
          BlockScan(temp_storage.scan)
              .ExclusiveScan(thread_data, thread_data, op, prefix_op);
        } else {
          BlockScan(temp_storage.scan)
              .InclusiveScan(thread_data, thread_data, op, prefix_op);
        }
        __syncthreads();

        // Store scanned items to output segment; only the first `valid_items`
        // are written back.
        BlockStore(temp_storage.store)
            .Store(permuteout_iter, thread_data, valid_items);
        __syncthreads();
      }
    }
    // Launches scan_kernel instantiated for a compile-time block size, using
    // `BlockSize` both as the template argument and as the launch block
    // dimension.  A failed launch is reported through TF_CHECK_OK.
    template <typename T, typename Op, int BlockSize, int ItemsPerThread>
    void LaunchScanWithBlockSize(const GPUDevice& d, const T* in, T* out,
                                 int num_blocks, int dimx, int dimy, int dimz,
                                 bool exclusive, bool reverse, Op op) {
      TF_CHECK_OK(CudaLaunchKernel(scan_kernel<T, Op, BlockSize, ItemsPerThread>,
                                   num_blocks, BlockSize, 0, d.stream(), in, out,
                                   dimx, dimy, dimz, exclusive, reverse, op));
    }

    // Runs a scan with operator `op` along dimension 1 of the rank-3 tensor `in`,
    // writing the result to `out`.  One thread block is launched per
    // (dimension 0, dimension 2) pair.
    //
    //   d          GPU device whose stream the kernel is enqueued on.
    //   in, out    input and output tensors.
    //   op         binary scan operator.
    //   reverse    scan from the end of the axis towards the start.
    //   exclusive  compute an exclusive rather than an inclusive scan.
    template <typename T, typename Op>
    void LaunchScan(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
                    typename TTypes<T, 3>::Tensor out, Op op, const bool reverse,
                    const bool exclusive) {
      constexpr int items_per_thread = 4;

      const int dimx = in.dimension(0);
      const int dimy = in.dimension(1);
      const int dimz = in.dimension(2);
      const int num_blocks = dimx * dimz;

      // Nothing to compute for an empty tensor.  Returning here also avoids
      // requesting a zero-sized grid, which is not a valid launch configuration.
      if (num_blocks <= 0 || dimy <= 0) return;

      const int ideal_block_size = dimy / items_per_thread;

      const T* in_data = in.data();
      T* out_data = out.data();

      // There seems to be a bug when the type is not float and block_size 1024.
      // Otherwise pick the largest power-of-two block size (from 32 up to 1024)
      // that does not exceed dimy / items_per_thread, falling back to 32.
      if (ideal_block_size >= 1024 && std::is_same<T, float>::value) {
        LaunchScanWithBlockSize<T, Op, 1024, items_per_thread>(
            d, in_data, out_data, num_blocks, dimx, dimy, dimz, exclusive, reverse,
            op);
      } else if (ideal_block_size >= 512) {
        LaunchScanWithBlockSize<T, Op, 512, items_per_thread>(
            d, in_data, out_data, num_blocks, dimx, dimy, dimz, exclusive, reverse,
            op);
      } else if (ideal_block_size >= 256) {
        LaunchScanWithBlockSize<T, Op, 256, items_per_thread>(
            d, in_data, out_data, num_blocks, dimx, dimy, dimz, exclusive, reverse,
            op);
      } else if (ideal_block_size >= 128) {
        LaunchScanWithBlockSize<T, Op, 128, items_per_thread>(
            d, in_data, out_data, num_blocks, dimx, dimy, dimz, exclusive, reverse,
            op);
      } else if (ideal_block_size >= 64) {
        LaunchScanWithBlockSize<T, Op, 64, items_per_thread>(
            d, in_data, out_data, num_blocks, dimx, dimy, dimz, exclusive, reverse,
            op);
      } else {
        LaunchScanWithBlockSize<T, Op, 32, items_per_thread>(
            d, in_data, out_data, num_blocks, dimx, dimy, dimz, exclusive, reverse,
            op);
      }
    }
    // GPU specialization of Scan for Eigen's SumReducer.  The Eigen reducer
    // argument is not used; the scan is carried out with the local Sum<T>
    // functor via LaunchScan.
    template <typename T>
    struct Scan<GPUDevice, Eigen::internal::SumReducer<T>, T> {
      void operator()(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
                      typename TTypes<T, 3>::Tensor out,
                      const Eigen::internal::SumReducer<T>& /*reducer*/,
                      const bool reverse, const bool exclusive) {
        typedef Sum<T> ScanOp;
        LaunchScan<T, ScanOp>(d, in, out, ScanOp(), reverse, exclusive);
      }
    };
    // GPU specialization of Scan for Eigen's ProdReducer.  The Eigen reducer
    // argument is not used; the scan is carried out with the local Prod<T>
    // functor via LaunchScan.
    template <typename T>
    struct Scan<GPUDevice, Eigen::internal::ProdReducer<T>, T> {
      void operator()(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
                      typename TTypes<T, 3>::Tensor out,
                      const Eigen::internal::ProdReducer<T>& /*reducer*/,
                      const bool reverse, const bool exclusive) {
        typedef Prod<T> ScanOp;
        LaunchScan<T, ScanOp>(d, in, out, ScanOp(), reverse, exclusive);
      }
    };
    302 }  // namespace functor
    303 }  // end namespace tensorflow
    305 #endif  // GOOGLE_CUDA