      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #if GOOGLE_CUDA
     17 
     18 #define EIGEN_USE_GPU
     19 
     20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     21 #include "external/cub_archive/cub/device/device_reduce.cuh"
     22 #include "external/cub_archive/cub/device/device_segmented_reduce.cuh"
     23 #include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
     24 #include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
     25 #include "external/cub_archive/cub/warp/warp_reduce.cuh"
     26 #include "cuda/include/cuComplex.h"
     27 #include "tensorflow/core/kernels/reduction_ops.h"
     28 #include "tensorflow/core/lib/core/bits.h"
     29 #include "tensorflow/core/util/cuda_kernel_helper.h"
     30 #include "tensorflow/core/util/permutation_input_iterator.h"
     31 #include "tensorflow/core/util/transform_output_iterator.h"
     32 
     33 #include <sstream>
     34 
     35 namespace tensorflow {
     36 namespace functor {
     37 
     38 typedef Eigen::GpuDevice GPUDevice;
     39 
     40 template <typename T>
     41 struct Sum {
     42   __host__ __device__ T operator()(const T& a, const T& b) const {
     43     return a + b;
     44   }
     45 };
     46 
      47 // Specialization needed to work around an nvcc compiler bug: it does not
      48 // seem to handle the overloaded addition operator for std::complex.
     49 template <>
     50 struct Sum<std::complex<float>> {
     51   __host__ __device__ std::complex<float> operator()(
     52       const std::complex<float>& a, const std::complex<float>& b) const {
     53     auto result = cuCaddf(make_cuComplex(a.real(), a.imag()),
     54                           make_cuComplex(b.real(), b.imag()));
     55     return std::complex<float>(result.x, result.y);
     56   }
     57 };
     58 
     59 template <>
     60 struct Sum<std::complex<double>> {
     61   __host__ __device__ std::complex<double> operator()(
     62       const std::complex<double>& a, const std::complex<double>& b) const {
     63     auto result = cuCadd(make_cuDoubleComplex(a.real(), a.imag()),
     64                          make_cuDoubleComplex(b.real(), b.imag()));
     65     return std::complex<double>(result.x, result.y);
     66   }
     67 };
     68 
     69 template <typename T>
     70 struct Prod {
     71   __host__ __device__ T operator()(const T& a, const T& b) const {
     72     return a * b;
     73   }
     74 };
     75 
      76 // Specialization needed to work around an nvcc compiler bug: it does not
      77 // seem to handle the overloaded multiplication operator for std::complex.
     78 template <>
     79 struct Prod<std::complex<float>> {
     80   __host__ __device__ std::complex<float> operator()(
     81       const std::complex<float>& a, const std::complex<float>& b) const {
     82     auto result = cuCmulf(make_cuComplex(a.real(), a.imag()),
     83                           make_cuComplex(b.real(), b.imag()));
     84     return std::complex<float>(result.x, result.y);
     85   }
     86 };
     87 
     88 template <>
     89 struct Prod<std::complex<double>> {
     90   __host__ __device__ std::complex<double> operator()(
     91       const std::complex<double>& a, const std::complex<double>& b) const {
     92     auto result = cuCmul(make_cuDoubleComplex(a.real(), a.imag()),
     93                          make_cuDoubleComplex(b.real(), b.imag()));
     94     return std::complex<double>(result.x, result.y);
     95   }
     96 };
     97 
     98 template <typename T, typename outT = T>
     99 struct DividesBy {
    100   T divisor;
    101 
    102   __host__ __device__ explicit DividesBy(T divisor) : divisor(divisor) {}
    103 
    104   __host__ __device__ outT operator()(const T& x) const { return x / divisor; }
    105 };
    106 
     107 // Specializations needed to work around an nvcc compiler bug: it does not
     108 // seem to handle the overloaded operators for std::complex.
    109 template <>
    110 struct DividesBy<std::complex<float>> {
    111   cuFloatComplex divisor;
    112 
    113   __host__ __device__ explicit DividesBy(std::complex<float> divisor)
    114       : divisor(make_cuComplex(divisor.real(), divisor.imag())) {}
    115 
     116   // Implements x / divisor for std::complex<float> via cuCdivf.
    117   __host__ __device__ std::complex<float> operator()(
    118       const std::complex<float>& x) const {
    119     auto result = cuCdivf(make_cuComplex(x.real(), x.imag()), divisor);
    120     return std::complex<float>(result.x, result.y);
    121   }
    122 };
    123 
    124 template <>
    125 struct DividesBy<std::complex<double>> {
    126   cuDoubleComplex divisor;
    127 
    128   __host__ __device__ explicit DividesBy(std::complex<double> divisor)
    129       : divisor(make_cuDoubleComplex(divisor.real(), divisor.imag())) {}
    130 
     131   // Implements x / divisor for std::complex<double> via cuCdiv.
    132   __host__ __device__ std::complex<double> operator()(
    133       const std::complex<double>& x) const {
    134     auto result = cuCdiv(make_cuDoubleComplex(x.real(), x.imag()), divisor);
    135     return std::complex<double>(result.x, result.y);
    136   }
    137 };
    138 
    139 template <>
    140 struct DividesBy<float, Eigen::half> {
    141   float divisor;
    142 
    143   __host__ __device__ explicit DividesBy(float divisor) : divisor(divisor) {}
    144 
    145   __host__ __device__ Eigen::half operator()(const float& x) const {
    146     return Eigen::half(x / divisor);
    147   }
    148 };
    149 
    150 struct HalfToFloat {
    151   __host__ __device__ float operator()(const Eigen::half& x) const {
    152     return Eigen::half_impl::half_to_float(x);
    153   }
    154 };
    155 
    156 struct FloatToHalf {
    157   __host__ __device__ Eigen::half operator()(const float& x) const {
    158     return Eigen::half_impl::float_to_half_rtne(x);
    159   }
    160 };
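         // HalfToFloat/FloatToHalf are used by the MeanReducer<Eigen::half>
         // specialization below so that half-precision reductions accumulate in
         // float; float_to_half_rtne rounds the final result to the nearest
         // representable half (ties to even).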
    161 
    162 struct And {
    163   __host__ __device__ bool operator()(const bool& a, const bool& b) const {
    164     return a && b;
    165   }
    166 };
    167 
    168 struct Or {
    169   __host__ __device__ bool operator()(const bool& a, const bool& b) const {
    170     return a || b;
    171   }
    172 };
    173 
     174 // Each block performs a grid-strided loop and reduces its values locally.
     175 // A single block is used for low-latency small reductions to scalars.
    176 template <typename T, typename outT, int num_threads, typename Op>
    177 __global__ void BlockReduceKernel(
    178     T in, outT out, int num_elems, Op op,
    179     typename std::iterator_traits<T>::value_type initVal) {
    180   const int bid = blockIdx.x;
    181   const int tid = threadIdx.x;
    182 
    183   const int gid = bid * blockDim.x + tid;
    184   const int stride = blockDim.x * gridDim.x;
    185 
    186   typedef typename std::iterator_traits<T>::value_type value_type;
    187 
    188   value_type sum = initVal;
    189   if (gid < num_elems) {
    190     sum = in[gid];
    191     for (int pos = gid + stride; pos < num_elems; pos += stride) {
    192       sum = op(sum, in[pos]);
    193     }
    194   }
    195 
    196   typedef cub::BlockReduce<value_type, num_threads> BlockReduce;
    197 
    198   __shared__ typename BlockReduce::TempStorage temp_storage;
    199 
    200   // only include input values in the reduction
    201   //
    202   // elements: -----------------
    203   // grid:     |====|====|====|====|====|
    204   const int num_elements_to_reduce =
    205       max(min(num_elems - bid * blockDim.x, num_threads), 0);
    206 
    207   sum = BlockReduce(temp_storage).Reduce(sum, op, num_elements_to_reduce);
    208 
    209   if (tid == 0) out[bid] = sum;
    210 }
    211 
    212 // maps a warp to each row
    213 template <typename T, typename outT, typename Op>
    214 __global__ void RowReduceKernel(
    215     T in, outT out, int num_rows, int num_cols, Op op,
    216     typename std::iterator_traits<T>::value_type initVal) {
    217   typedef typename std::iterator_traits<T>::value_type value_type;
    218   const int row = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
    219   const int lane = threadIdx.x % 32;
    220 
    221   if (num_cols == 1) {
    222     int gid = threadIdx.x + blockIdx.x * blockDim.x;
    223     if (gid < num_rows) out[gid] = in[gid];
    224     return;
    225   }
    226 
    227   value_type sum = initVal;
    228   int col = lane;
    229 
    230   if (row < num_rows && col < num_cols) {
    231     sum = in[row * num_cols + col];
    232     col += 32;
    233     for (; col < num_cols; col += 32) {
    234       sum = op(sum, in[row * num_cols + col]);
    235     }
    236   }
    237 
    238   typedef cub::WarpReduce<value_type> WarpReduce;
    239 
    240   __shared__ typename WarpReduce::TempStorage temp_storage;
    241 
    242   sum = WarpReduce(temp_storage).Reduce(sum, op, min(num_cols, 32));
    243 
    244   if (row < num_rows && lane == 0) out[row] = sum;
    245 }
    246 
     247 // Works only if there are <= 16 columns.
     248 // Each warp sums over multiple rows at once.
    249 template <typename T, typename outT, typename Op>
    250 __global__ void ColumnReduceMax16ColumnsKernel(
    251     T in, outT out, int num_rows, int num_cols, Op op,
    252     typename std::iterator_traits<T>::value_type initVal) {
    253   typedef typename std::iterator_traits<T>::value_type value_type;
    254   int rows_per_warp = 32 / num_cols;
    255 
    256   const int lane = threadIdx.x % 32;
    257   const int lane_row = lane / num_cols;
    258 
    259   const int start_row_warp =
    260       rows_per_warp * (blockIdx.y * blockDim.y + threadIdx.y);
    261   const int start_row_lane = start_row_warp + lane_row;
    262   int row = start_row_lane;
    263   int col = lane % num_cols;
    264 
    265   value_type sum = initVal;
    266   if (row * num_cols + col < num_rows * num_cols)
    267     sum = in[row * num_cols + col];
    268 
    269   // 1D array necessary due to bug in CUDA 9 compiler.
    270   // TODO(nluehr) revert to 2D array when compiler is ready.
    271   __shared__ value_type partial_sums[32 * 33];
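           // Note: the row stride of 33 (rather than 32) presumably pads the
           // logical 32x32 tile so column accesses hit distinct shared-memory
           // banks.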
    272 
    273   row += rows_per_warp * gridDim.y * blockDim.y;
    274   for (; row < num_rows; row += rows_per_warp * gridDim.y * blockDim.y) {
    275     int global_pos = row * num_cols + col;
    276     if (global_pos < (num_rows * num_cols))
    277       sum = op(sum, in[row * num_cols + col]);
    278   }
    279 
    280   const int rows_in_this_warp = min(rows_per_warp, num_rows - start_row_warp);
    281   // not the most efficient way to do this sum
    282   for (int i = 1; i < rows_in_this_warp; ++i) {
    283     value_type tmp =
    284         cub::ShuffleIndex(sum, threadIdx.x + i * num_cols, 32, 0xffffffff);
    285     if (lane < num_cols) sum = op(sum, tmp);
    286   }
    287 
    288   if (lane < num_cols) partial_sums[lane * 33 + threadIdx.y] = sum;
    289 
    290   __syncthreads();
    291 
    292   if (threadIdx.y == 0 && threadIdx.x < num_cols) {
    293     value_type s = partial_sums[threadIdx.x * 33];
    294 
    295     if (blockDim.y > 1) {
    296       for (int row = 1; row < blockDim.y; ++row) {
    297         s = op(s, partial_sums[threadIdx.x * 33 + row]);
    298       }
    299     }
    300 
    301     out[col * gridDim.y + blockIdx.y] = s;
    302   }
    303 }
    304 
    305 // Maps each block to a column range 32 wide
    306 template <typename T, typename outT, typename Op>
    307 __global__ void ColumnReduceKernel(
    308     T in, outT out, int num_rows, int num_cols, Op op,
    309     typename std::iterator_traits<T>::value_type initVal) {
    310   typedef typename std::iterator_traits<T>::value_type value_type;
    311   int row = blockIdx.y * blockDim.y + threadIdx.y;
    312   int col = blockIdx.x * 32 + threadIdx.x;
    313 
    314   value_type sum = initVal;
    315   if (row < num_rows && col < num_cols) sum = in[row * num_cols + col];
    316 
    317   // 1D array necessary due to bug in CUDA 9 compiler.
    318   // TODO(nluehr) revert to 2D array when compiler is ready.
    319   __shared__ value_type partial_sums[32 * 33];
    320 
    321   row += gridDim.y * blockDim.y;
    322 
    323   if (col < num_cols) {
    324     for (; row < num_rows; row += gridDim.y * blockDim.y) {
    325       sum = op(sum, in[row * num_cols + col]);
    326     }
    327   }
    328 
    329   partial_sums[threadIdx.x * 33 + threadIdx.y] = sum;
    330 
    331   __syncthreads();
    332 
    333   if (threadIdx.y == 0 && col < num_cols) {
    334     value_type s = partial_sums[threadIdx.x * 33];
    335 
    336     // only include input values in the reduction
    337     // elem   block_rows
    338     //  -         =
    339     //  -         =
    340     //  #         #  block boundary
    341     //  -         =
    342     //  -         =
    343     //  #         #  block boundary
    344     //  -         =
    345     //            =
    346     const int numRowsThisBlock =
    347         min(blockDim.y, num_rows - blockIdx.y * blockDim.y);
    348 
    349     for (int row = 1; row < numRowsThisBlock; ++row) {
    350       s = op(s, partial_sums[threadIdx.x * 33 + row]);
    351     }
    352 
    353     out[col * gridDim.y + blockIdx.y] = s;
    354   }
    355 }
    356 
     357 // Does multiple warp-sized segmented reductions in parallel.
     358 // Segments cannot cross warp boundaries (mainly used for reducing the segments
     359 // that come from the Max16Columns column reduction kernel).
    360 template <typename T, typename outT, typename Op>
    361 __global__ void CleanupSegments(
    362     T partial_sums, outT out, int num_rows, int num_cols, int segment_size,
    363     Op op, typename std::iterator_traits<T>::value_type initVal) {
    364   typedef typename std::iterator_traits<T>::value_type value_type;
    365   const int tid = threadIdx.x + blockIdx.x * blockDim.x;
    366 
    367   value_type val = initVal;
    368   if (tid < segment_size * num_cols) val = partial_sums[tid];
    369 
    370   typedef cub::WarpReduce<value_type> WarpReduce;
    371 
    372   __shared__ typename WarpReduce::TempStorage temp_storage;
    373 
    374   const bool head_flag = (threadIdx.x % segment_size) == 0;
    375   value_type sum =
    376       WarpReduce(temp_storage).HeadSegmentedReduce(val, head_flag, op);
    377 
    378   if (head_flag && tid < segment_size * num_cols) {
    379     out[tid / segment_size] = sum;
    380   }
    381 }
    382 
     383 // Assigns one thread to each (plane, column) pair.
    384 template <typename T, typename outT, typename Op>
    385 __global__ void ColumnReduceSimpleKernel(T in, outT out, int num_planes,
    386                                          int num_rows, int num_cols, Op op) {
    387   typedef typename std::iterator_traits<T>::value_type value_type;
    388   const int gid = threadIdx.x + blockIdx.x * blockDim.x;
    389   const int elems_per_plane = num_rows * num_cols;
    390 
    391   const int plane = gid / num_cols;
    392   const int col = gid % num_cols;
    393 
    394   if (plane >= num_planes) return;
    395 
    396   if (num_rows == 1) {
    397     out[plane * elems_per_plane + col] = in[plane * elems_per_plane + col];
    398     return;
    399   }
    400 
    401   value_type sum = op(in[plane * elems_per_plane + col],
    402                       in[plane * elems_per_plane + num_cols + col]);
    403   for (int row = 2; row < num_rows; ++row) {
    404     sum = op(sum, in[plane * elems_per_plane + row * num_cols + col]);
    405   }
    406 
    407   out[plane * num_cols + col] = sum;
    408 }
    409 
    410 struct RowOffset {
    411   __host__ __device__ explicit RowOffset(const int& cols) : cols_(cols) {}
    412 
    413   __host__ __device__ int operator()(const int& x) const { return cols_ * x; }
    414 
    415   int cols_;
    416 };
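         // RowOffset is composed with a counting iterator to generate segment
         // offsets: RowOffset(cols) maps 0, 1, 2, ... to 0, cols, 2 * cols, ...
         // These offsets serve as the per-row begin/end markers handed to
         // cub::DeviceSegmentedReduce (see LaunchRowReduction below).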
    417 
    418 struct GatherOp {
    419   __host__ __device__ GatherOp(const int& extent_x, const int& extent_y,
    420                                const int& extent_z, bool kOne)
    421       : extent_x_(extent_x),
    422         extent_y_(extent_y),
    423         extent_z_(extent_z),
    424         kOne_(kOne) {
    425     if (kOne_)
    426       group_size_ = extent_y_;
    427     else
    428       group_size_ = extent_x_ * extent_z_;
    429   }
    430 
    431   __host__ __device__ int operator()(const int& ind) const {
    432     const int group = kOne_ ? ind / group_size_ : ind % group_size_;
    433     const int offset = kOne_ ? ind % group_size_ : ind / group_size_;
    434 
    435     const int x = group / extent_z_;
    436     const int z = group % extent_z_;
    437 
    438     return x * extent_y_ * extent_z_ + z + offset * extent_z_;
    439   }
    440 
    441   int extent_x_;
    442   int extent_y_;
    443   int extent_z_;
    444   bool kOne_;
    445   int group_size_;
    446 };
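         // Worked example (illustration only): for a row-major [x, y, z] tensor
         // with extent_x = 2, extent_y = 3, extent_z = 2 and kOne = false,
         // group_size_ = extent_x * extent_z = 4 and consecutive indices gather
         //   0, 1, 6, 7,   2, 3, 8, 9,   4, 5, 10, 11
         // i.e. each group of 4 collects every (x, z) element sharing one y, so
         // a segmented reduction over groups of this size reduces over x and z.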
    447 
    448 template <typename T, typename Op, typename OUT_T, typename IN_T>
    449 void LaunchScalarReduction(OpKernelContext* ctx, OUT_T out, IN_T in,
    450                            int in_size, Op op, T init,
    451                            const cudaStream_t& cu_stream) {
     452   // Handle small, latency-sensitive reductions better than CUB does.
    453   if (in_size <= 4096) {
    454     const int num_blocks = 1;
    455     const int num_threads = 256;
    456     BlockReduceKernel<IN_T, OUT_T, num_threads>
    457         <<<num_blocks, num_threads, 0, cu_stream>>>(in, out, in_size, op, init);
    458     return;
    459   } else if (in_size <= 1 << 19) {
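             // Medium-sized inputs: each block produces one partial result and
             // CleanupSegments then combines the per-block partials with a
             // head-flagged segmented warp reduction.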
    460     const int num_threads = 256;
    461     const int num_blocks = std::min(32, Eigen::divup(in_size, num_threads));
     462     // It seems like tailoring the block count to the GPU
     463     // would be more effective, but all attempts
     464     // at making it a multiple of the number of
     465     // multiprocessors have led to lower performance
     466     // in general.
     467     // TODO(eriche) investigate this more
    468 
    469     Tensor temp_storage;
    470     OP_REQUIRES_OK(
    471         ctx,
    472         ctx->allocate_temp(
    473             DT_INT8, TensorShape({static_cast<int64>(num_blocks * sizeof(T))}),
    474             &temp_storage));
    475 
    476     BlockReduceKernel<IN_T, T*, num_threads>
    477         <<<num_blocks, num_threads, 0, cu_stream>>>(
    478             in, (T*)temp_storage.flat<int8_t>().data(), in_size, op, init);
    479 
    480     // take care that we only reduce blocks that had some valid elements in them
    481     // TODO(eriche): CUB currently has a bug in HeadSegmentedReduce that
    482     // requires it to be used with a full warp.  Can reduce 32 -> num_blocks
    483     // when this is fixed.
    484     CleanupSegments<<<1, 32, 0, cu_stream>>>(
    485         (T*)temp_storage.flat<int8_t>().data(), out, 1, 1, num_blocks, op,
    486         init);
    487     return;
    488   }
    489   std::size_t temp_storage_bytes = 0;
    490 
    491   Tensor temp_storage;
     492   // Written as a loop because it reduces clutter.
     493   // The first pass sizes and allocates temp storage; the second launches the kernel(s).
    494   for (int i = 0; i < 2; ++i) {
    495     auto success = cub::DeviceReduce::Reduce(
    496         i == 0 ? nullptr : temp_storage.flat<int8_t>().data(),
    497         temp_storage_bytes, in, out, in_size, op, init, cu_stream);
    498 
    499     OP_REQUIRES(
    500         ctx, success == 0,
    501         errors::Internal("CUB reduce error", cudaGetErrorString(success)));
    502 
    503     if (i == 0)
    504       OP_REQUIRES_OK(
    505           ctx,
    506           ctx->allocate_temp(
    507               DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
    508               &temp_storage));
    509   }
    510 }
    511 
    512 template <typename T, typename Op, typename OUT_T, typename IN_T>
    513 void LaunchRowReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int num_rows,
    514                         int num_cols, Op op, T init,
    515                         const cudaStream_t& cu_stream) {
    516   if (num_cols < 1024) {
    517     const int threads_per_block = 128;
    518     const int warps_per_block = threads_per_block / 32;
    519     int num_blocks = (num_rows + warps_per_block - 1) / warps_per_block;
    520 
    521     RowReduceKernel<<<num_blocks, threads_per_block, 0, cu_stream>>>(
    522         in, out, num_rows, num_cols, op, init);
    523     return;
    524   }
    525 
     526   // Set up segment offsets with a counting and a transform iterator.
    527   RowOffset row_offset_op(num_cols);
    528   cub::CountingInputIterator<int> counting_iter(0);
    529   cub::TransformInputIterator<int, RowOffset, cub::CountingInputIterator<int>>
    530       transform_iter(counting_iter, row_offset_op);
    531 
    532   std::size_t temp_storage_bytes = 0;
    533   Tensor temp_storage;
    534   for (int i = 0; i < 2; ++i) {
    535     auto success = cub::DeviceSegmentedReduce::Reduce(
    536         i == 0 ? nullptr : temp_storage.flat<int8_t>().data(),
    537         temp_storage_bytes, in, out, num_rows, transform_iter,
    538         transform_iter + 1, op, init, cu_stream);
    539 
    540     OP_REQUIRES(ctx, success == 0,
    541                 errors::Internal("CUB segmented reduce error",
    542                                  cudaGetErrorString(success)));
    543 
    544     if (i == 0)
    545       OP_REQUIRES_OK(
    546           ctx,
    547           ctx->allocate_temp(
    548               DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
    549               &temp_storage));
    550   }
    551 }
    552 
    553 template <typename T, typename Op, typename OUT_T, typename IN_T>
    554 void LaunchColumnReduction_LTE16Cols(OpKernelContext* ctx, OUT_T out, IN_T in,
    555                                      int extent_x, int extent_y, Op op, T init,
    556                                      const cudaStream_t& cu_stream) {
    557   int rows_per_warp = 32 / extent_y;
    558   dim3 block_dim(32, std::min(Eigen::divup(extent_x, rows_per_warp), 32), 1);
    559   dim3 grid_dim(1,
    560                 Eigen::divup(static_cast<unsigned int>(extent_x),
    561                              rows_per_warp * block_dim.y),
    562                 1);
    563 
    564   grid_dim.y = std::min((int)grid_dim.y, 32);
    565 
    566   if (grid_dim.y > 2 && grid_dim.y < 32) {
    567     int log2 = Log2Floor(grid_dim.y);
    568     grid_dim.y = 1 << log2;
    569   }
    570 
    571   if (grid_dim.y == 1) {
    572     ColumnReduceMax16ColumnsKernel<<<grid_dim, block_dim, 0, cu_stream>>>(
    573         in, out, extent_x, extent_y, op, init);
    574   } else {
    575     Tensor temp_storage;
    576     OP_REQUIRES_OK(ctx,
    577                    ctx->allocate_temp(DT_INT8,
    578                                       TensorShape({static_cast<int64>(
    579                                           sizeof(T) * extent_y * grid_dim.y)}),
    580                                       &temp_storage));
    581     ColumnReduceMax16ColumnsKernel<<<grid_dim, block_dim, 0, cu_stream>>>(
    582         in, (T*)temp_storage.flat<int8_t>().data(), extent_x, extent_y, op,
    583         init);
    584 
    585     dim3 new_grid_dim((grid_dim.y * extent_y + 31) / 32, 1, 1);
    586     dim3 num_threads(128, 1, 1);
    587     CleanupSegments<<<new_grid_dim, num_threads, 0, cu_stream>>>(
    588         (T*)temp_storage.flat<int8_t>().data(), out, extent_x, extent_y,
    589         grid_dim.y, op, init);
    590   }
    591 }
    592 
    593 template <typename T, typename Op, typename OUT_T, typename IN_T>
    594 void LaunchColumnReduction_LTE4096Cols(OpKernelContext* ctx, OUT_T out, IN_T in,
    595                                        int extent_x, int extent_y, Op op,
    596                                        T init, const cudaStream_t& cu_stream) {
    597   dim3 block_dim(32, std::min(extent_x, 32), 1);
    598   dim3 grid_dim((extent_y + 31) / 32, 1, 1);
    599 
    600   if (grid_dim.x < 16) grid_dim.y = std::min((extent_x + 31) / 32, 32);
    601 
    602   if (grid_dim.y > 2 && grid_dim.y < 32) {
    603     int log2 = Log2Floor(grid_dim.y);
    604     grid_dim.y = 1 << log2;
    605   }
    606 
    607   if (grid_dim.y == 1) {
    608     ColumnReduceKernel<<<grid_dim, block_dim, 0, cu_stream>>>(
    609         in, out, extent_x, extent_y, op, init);
    610   } else {
    611     Tensor temp_storage;
    612     OP_REQUIRES_OK(ctx,
    613                    ctx->allocate_temp(DT_INT8,
    614                                       TensorShape({static_cast<int64>(
    615                                           sizeof(T) * extent_y * grid_dim.y)}),
    616                                       &temp_storage));
    617 
    618     ColumnReduceKernel<<<grid_dim, block_dim, 0, cu_stream>>>(
    619         in, (T*)temp_storage.flat<int8_t>().data(), extent_x, extent_y, op,
    620         init);
    621 
    622     dim3 new_grid_dim((grid_dim.y * extent_y + 31) / 32, 1, 1);
    623     dim3 num_threads(128, 1, 1);
     624     CleanupSegments<<<new_grid_dim, num_threads, 0, cu_stream>>>(
    625         (T*)temp_storage.flat<int8_t>().data(), out, extent_x, extent_y,
    626         grid_dim.y, op, init);
    627   }
    628 }
    629 
    630 template <typename T, typename Op, typename OUT_T, typename IN_T>
    631 void LaunchColumnReduction(OpKernelContext* ctx, OUT_T out, IN_T in,
    632                            int extent_x, int extent_y, Op op, T init,
    633                            const cudaStream_t& cu_stream) {
    634   if (extent_y <= 16) {
    635     LaunchColumnReduction_LTE16Cols(ctx, out, in, extent_x, extent_y, op, init,
    636                                     cu_stream);
    637   } else if (extent_y <= 4096) {
    638     LaunchColumnReduction_LTE4096Cols(ctx, out, in, extent_x, extent_y, op,
    639                                       init, cu_stream);
    640   } else {
    641     int threads_per_block = 128;
    642     int num_blocks = Eigen::divup(extent_y, threads_per_block);
    643 
    644     ColumnReduceSimpleKernel<<<num_blocks, threads_per_block, 0, cu_stream>>>(
    645         in, out, 1, extent_x, extent_y, op);
    646   }
    647 }
    648 
    649 template <typename T, typename Op, typename OUT_T, typename IN_T>
    650 void Launch3DYReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
    651                         int extent_y, int extent_z, Op op, T init,
    652                         const cudaStream_t& cu_stream) {
    653   int threads_per_block = 128;
    654   int num_blocks =
    655       (extent_x * extent_z + threads_per_block - 1) / threads_per_block;
    656 
     657   // TODO(eriche): this won't be very good in the case of small x,
     658   //               small z, and large y.
    659   ColumnReduceSimpleKernel<<<num_blocks, threads_per_block, 0, cu_stream>>>(
    660       in, out, extent_x, extent_y, extent_z, op);
    661 }
    662 
    663 template <typename T, typename Op, typename OUT_T, typename IN_T>
    664 void Launch3DXZReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
    665                          int extent_y, int extent_z, Op op, T init,
    666                          const cudaStream_t& cu_stream) {
     667   // Set up segment offsets with a counting and a transform iterator.
    668   RowOffset row_offset_op(extent_x * extent_z);
    669   cub::CountingInputIterator<int> counting_iter(0);
    670   cub::TransformInputIterator<int, RowOffset, cub::CountingInputIterator<int>>
    671       transform_iter(counting_iter, row_offset_op);
    672 
    673   GatherOp gather_op(extent_x, extent_y, extent_z, false);
    674   typedef cub::TransformInputIterator<int, GatherOp,
    675                                       cub::CountingInputIterator<int>>
    676       gatherIterType;
    677   gatherIterType gather_iter(counting_iter, gather_op);
    678 
    679   PermutationInputIterator<T, IN_T, gatherIterType> permute_iter(in,
    680                                                                  gather_iter);
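           // The permutation iterator presents the [x, y, z] input as extent_y
           // contiguous runs of length extent_x * extent_z (one per y), so the
           // segmented reduce below emits one output element per y.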
    681 
    682   std::size_t temp_storage_bytes = 0;
    683   Tensor temp_storage;
    684 
    685   for (int i = 0; i < 2; ++i) {
    686     auto success = cub::DeviceSegmentedReduce::Reduce(
    687         i == 0 ? nullptr : temp_storage.flat<int8_t>().data(),
    688         temp_storage_bytes, permute_iter, out, extent_y, transform_iter,
    689         transform_iter + 1, op, init, cu_stream);
    690 
    691     OP_REQUIRES(ctx, success == 0,
    692                 errors::Internal("CUB segmented reduce error",
    693                                  cudaGetErrorString(success)));
    694 
    695     if (i == 0)
    696       OP_REQUIRES_OK(
    697           ctx,
    698           ctx->allocate_temp(
    699               DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
    700               &temp_storage));
    701   }
    702 }
    703 
    704 namespace reduction_op_helper {
    705 
    706 template <typename T, typename Op>
    707 struct IsSum {
    708   constexpr static bool value =
    709       (std::is_same<Op, cub::Sum>::value ||
    710        std::is_same<Op, Eigen::internal::SumReducer<T>>::value ||
    711        std::is_same<Op, Sum<T>>::value);
    712 };
    713 
    714 template <typename T, typename Op>
    715 struct IsMax {
    716   constexpr static bool value =
    717       (std::is_same<Op, cub::Max>::value ||
    718        std::is_same<Op, Eigen::internal::MaxReducer<T>>::value);
    719 };
    720 
    721 template <typename T, typename Op>
    722 struct IsMin {
    723   constexpr static bool value =
    724       (std::is_same<Op, cub::Min>::value ||
    725        std::is_same<Op, Eigen::internal::MinReducer<T>>::value);
    726 };
    727 
    728 template <typename T, typename Op>
    729 struct IsProd {
    730   constexpr static bool value =
    731       (std::is_same<Op, Prod<T>>::value ||
    732        std::is_same<Op, Eigen::internal::ProdReducer<T>>::value);
    733 };
    734 
    735 template <typename T, typename Op>
    736 struct IdentityValue {
    737   static_assert(IsSum<T, Op>::value || IsMax<T, Op>::value ||
    738                     IsMin<T, Op>::value || IsProd<T, Op>::value ||
    739                     std::is_same<Op, And>::value || std::is_same<Op, Or>::value,
    740                 "IdentityValue not yet defined for this type");
    741 
    742   template <typename U = T, typename OpCopy = Op>
    743   U operator()(
    744       typename std::enable_if<IsSum<U, OpCopy>::value, U>::type t = U(0)) {
    745     return t;
    746   }
    747 
    748   template <typename U = T, typename OpCopy = Op>
    749   U operator()(typename std::enable_if<IsMax<U, OpCopy>::value, U>::type t =
    750                    Eigen::NumTraits<U>::lowest()) {
    751     return t;
    752   }
    753 
    754   template <typename U = T, typename OpCopy = Op>
    755   U operator()(typename std::enable_if<IsMin<U, OpCopy>::value, U>::type t =
    756                    Eigen::NumTraits<U>::highest()) {
    757     return t;
    758   }
    759 
    760   template <typename U = T, typename OpCopy = Op>
    761   U operator()(
    762       typename std::enable_if<IsProd<U, OpCopy>::value, U>::type t = U(1)) {
    763     return t;
    764   }
    765 
    766   template <typename U = T, typename OpCopy = Op>
    767   U operator()(typename std::enable_if<std::is_same<OpCopy, And>::value,
    768                                        bool>::type t = true) {
    769     return t;
    770   }
    771 
    772   template <typename U = T, typename OpCopy = Op>
    773   U operator()(typename std::enable_if<std::is_same<OpCopy, Or>::value,
    774                                        bool>::type t = false) {
    775     return t;
    776   }
    777 };
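         // For illustration: IdentityValue<float, Sum<float>>()() yields 0.0f,
         // IdentityValue<float, cub::Max>()() the lowest float,
         // IdentityValue<float, cub::Min>()() the highest float,
         // IdentityValue<float, Prod<float>>()() yields 1.0f, and the And/Or
         // overloads yield true and false respectively.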
    778 
    779 }  // namespace reduction_op_helper
    780 
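         // ReduceImpl dispatches to one of the launchers above based on the
         // input/output ranks and the reduction axes:
         //   out_rank 0                         -> full reduction to a scalar
         //   in_rank 2, out_rank 1, axes {1}    -> row reduction
         //   in_rank 2, out_rank 1, axes {0}    -> column reduction
         //   in_rank 3, out_rank 2, axes {1}    -> reduce the middle dimension
         //   in_rank 3, out_rank 1, axes {0, 2} -> reduce the first and last dims
         // Any other combination is a fatal error.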
    781 template <typename T, typename Op, typename OUT_T, typename IN_T,
    782           typename ReductionAxes>
    783 void ReduceImpl(OpKernelContext* ctx, OUT_T out, IN_T in, int in_rank,
    784                 int in_dim0, int in_dim1, int in_dim2, int out_rank,
    785                 const ReductionAxes& reduction_axes, Op op) {
    786   T init = reduction_op_helper::IdentityValue<T, Op>()();
    787   const cudaStream_t& cu_stream = GetCudaStream(ctx);
    788   if (out_rank == 0) {
    789     const int in_size = in_dim0 * in_dim1 * in_dim2;
    790     LaunchScalarReduction(ctx, out, in, in_size, op, init, cu_stream);
    791   } else if (in_rank == 2 && out_rank == 1 &&
    792              reduction_axes[0] == 1) {  // row reduction
    793     LaunchRowReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream);
    794   } else if (in_rank == 2 && out_rank == 1 &&
    795              reduction_axes[0] == 0) {  // column reduction
    796     LaunchColumnReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream);
    797   } else if (in_rank == 3 && out_rank == 2 && reduction_axes[0] == 1) {
    798     Launch3DYReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
    799                        cu_stream);
    800   } else if (in_rank == 3 && out_rank == 1 && reduction_axes[0] == 0 &&
    801              reduction_axes[1] == 2) {
    802     Launch3DXZReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
    803                         cu_stream);
    804   } else {
    805     std::stringstream ss;
    806     ss << "Invalid reduction requested: in_rank, out_rank, axes " << in_rank
    807        << " " << out_rank;
     808     if (in_rank - out_rank >= 1) ss << " " << reduction_axes[0];
     809     if (in_rank - out_rank >= 2) ss << " " << reduction_axes[1];
    810     LOG(FATAL) << ss.str();
    811   }
    812 }
    813 
    814 template <typename Reducer>
    815 struct ReduceFunctor<GPUDevice, Reducer> {
    816   template <typename OUT_T, typename IN_T, typename ReductionAxes>
    817   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
    818                      const ReductionAxes& reduction_axes,
    819                      const Reducer& reducer);
    820 };
    821 
    822 template <typename T>
    823 struct ReduceFunctor<GPUDevice, Eigen::internal::SumReducer<T>> {
    824   template <typename OUT_T, typename IN_T, typename ReductionAxes>
    825   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
    826                      const ReductionAxes& reduction_axes,
    827                      const Eigen::internal::SumReducer<T>& reducer) {
    828     ReduceImpl<T, Sum<T>, T*, T*, ReductionAxes>(
    829         ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
    830         in.rank() >= 2 ? in.dimension(1) : 1,
    831         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
    832         Sum<T>());
    833   }
    834 
    835   template <typename OUT_T>
    836   static void FillIdentity(const GPUDevice& d, OUT_T out,
    837                            const Eigen::internal::SumReducer<T>& reducer) {
    838     FillIdentityEigenImpl(d, To32Bit(out), reducer);
    839   }
    840 };
    841 
    842 template <typename T>
    843 struct ReduceFunctor<GPUDevice, Eigen::internal::MeanReducer<T>> {
    844   template <typename OUT_T, typename IN_T, typename ReductionAxes>
    845   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
    846                      const ReductionAxes& reduction_axes,
    847                      const Eigen::internal::MeanReducer<T>& reducer) {
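             // The divisor is the number of input elements folded into each
             // output value, e.g. reducing a [4, 5] matrix along axis 1 divides
             // each row sum by 5.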
    848     int divisor = 1;
    849     if (out.rank() == 0)
    850       divisor = in.size();
    851     else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0)
    852       divisor = in.dimension(0);
    853     else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1)
    854       divisor = in.dimension(1);
    855     else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 &&
    856              reduction_axes[1] == 2)
    857       divisor = in.dimension(0) * in.dimension(2);
    858     else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1)
    859       divisor = in.dimension(1);
    860 
    861     DividesBy<T> div_op(static_cast<T>(divisor));
    862     TransformOutputIterator<T, T, DividesBy<T>> itr((T*)out.data(), div_op);
    863     ReduceImpl<T, Sum<T>, TransformOutputIterator<T, T, DividesBy<T>>, T*,
    864                ReductionAxes>(ctx, itr, (T*)in.data(), in.rank(),
    865                               in.dimension(0),
    866                               in.rank() >= 2 ? in.dimension(1) : 1,
    867                               in.rank() >= 3 ? in.dimension(2) : 1, out.rank(),
    868                               reduction_axes, Sum<T>());
    869   }
    870 
    871   template <typename OUT_T>
    872   static void FillIdentity(const GPUDevice& d, OUT_T out,
    873                            const Eigen::internal::MeanReducer<T>& reducer) {
    874     FillIdentityEigenImpl(d, To32Bit(out), reducer);
    875   }
    876 };
    877 
    878 template <>
    879 struct ReduceFunctor<GPUDevice, Eigen::internal::MeanReducer<Eigen::half>> {
    880   template <typename OUT_T, typename IN_T, typename ReductionAxes>
    881   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
    882                      const ReductionAxes& reduction_axes,
    883                      const Eigen::internal::MeanReducer<Eigen::half>& reducer) {
    884     float divisor = 1.f;
    885     if (out.rank() == 0)
    886       divisor = in.size();
    887     else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0)
    888       divisor = in.dimension(0);
    889     else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1)
    890       divisor = in.dimension(1);
    891     else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 &&
    892              reduction_axes[1] == 2)
    893       divisor = in.dimension(0) * in.dimension(2);
    894     else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1)
    895       divisor = in.dimension(1);
    896     DividesBy<float, Eigen::half> div_op(divisor);
    897 
    898     typedef cub::TransformInputIterator<float, HalfToFloat, Eigen::half*>
    899         inputIterType;
    900     inputIterType input_itr((Eigen::half*)in.data(), HalfToFloat());
    901 
    902     typedef TransformOutputIterator<Eigen::half, float,
    903                                     DividesBy<float, Eigen::half>>
    904         outputIterType;
    905     outputIterType itr((Eigen::half*)out.data(), div_op);
    906 
    907     ReduceImpl<float, cub::Sum, outputIterType, inputIterType, ReductionAxes>(
    908         ctx, itr, input_itr, in.rank(), in.dimension(0),
    909         in.rank() >= 2 ? in.dimension(1) : 1,
    910         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
    911         cub::Sum());
    912   }
    913 
    914   template <typename OUT_T>
    915   static void FillIdentity(
    916       const GPUDevice& d, OUT_T out,
    917       const Eigen::internal::MeanReducer<Eigen::half>& reducer) {
    918     FillIdentityEigenImpl(d, To32Bit(out), reducer);
    919   }
    920 };
    921 
    922 template <typename T>
    923 struct ReduceFunctor<GPUDevice, Eigen::internal::MaxReducer<T>> {
    924   template <typename OUT_T, typename IN_T, typename ReductionAxes>
    925   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
    926                      const ReductionAxes& reduction_axes,
    927                      const Eigen::internal::MaxReducer<T>& reducer) {
    928     ReduceImpl<T, cub::Max, T*, T*, ReductionAxes>(
    929         ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
    930         in.rank() >= 2 ? in.dimension(1) : 1,
    931         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
    932         cub::Max());
    933   }
    934 
    935   template <typename OUT_T>
    936   static void FillIdentity(const GPUDevice& d, OUT_T out,
    937                            const Eigen::internal::MaxReducer<T>& reducer) {
    938     FillIdentityEigenImpl(d, To32Bit(out), reducer);
    939   }
    940 };
    941 
    942 template <typename T>
    943 struct ReduceFunctor<GPUDevice, Eigen::internal::MinReducer<T>> {
    944   template <typename OUT_T, typename IN_T, typename ReductionAxes>
    945   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
    946                      const ReductionAxes& reduction_axes,
    947                      const Eigen::internal::MinReducer<T>& reducer) {
    948     ReduceImpl<T, cub::Min, T*, T*, ReductionAxes>(
    949         ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
    950         in.rank() >= 2 ? in.dimension(1) : 1,
    951         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
    952         cub::Min());
    953   }
    954 
    955   template <typename OUT_T>
    956   static void FillIdentity(const GPUDevice& d, OUT_T out,
    957                            const Eigen::internal::MinReducer<T>& reducer) {
    958     FillIdentityEigenImpl(d, To32Bit(out), reducer);
    959   }
    960 };
    961 
    962 template <typename T>
    963 struct ReduceFunctor<GPUDevice, Eigen::internal::ProdReducer<T>> {
    964   template <typename OUT_T, typename IN_T, typename ReductionAxes>
    965   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
    966                      const ReductionAxes& reduction_axes,
    967                      const Eigen::internal::ProdReducer<T>& reducer) {
    968     ReduceImpl<T, Prod<T>, T*, T*, ReductionAxes>(
    969         ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
    970         in.rank() >= 2 ? in.dimension(1) : 1,
    971         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
    972         Prod<T>());
    973   }
    974 
    975   template <typename OUT_T>
    976   static void FillIdentity(const GPUDevice& d, OUT_T out,
    977                            const Eigen::internal::ProdReducer<T>& reducer) {
    978     FillIdentityEigenImpl(d, To32Bit(out), reducer);
    979   }
    980 };
    981 
    982 template <>
    983 struct ReduceFunctor<GPUDevice, Eigen::internal::AndReducer> {
    984   template <typename OUT_T, typename IN_T, typename ReductionAxes>
    985   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
    986                      const ReductionAxes& reduction_axes,
    987                      const Eigen::internal::AndReducer& reducer) {
    988     ReduceImpl<bool, And, bool*, bool*, ReductionAxes>(
    989         ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0),
    990         in.rank() >= 2 ? in.dimension(1) : 1,
    991         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
    992         And());
    993   }
    994 
    995   template <typename OUT_T>
    996   static void FillIdentity(const GPUDevice& d, OUT_T out,
    997                            const Eigen::internal::AndReducer& reducer) {
    998     FillIdentityEigenImpl(d, To32Bit(out), reducer);
    999   }
   1000 };
   1001 
   1002 template <>
   1003 struct ReduceFunctor<GPUDevice, Eigen::internal::OrReducer> {
   1004   template <typename OUT_T, typename IN_T, typename ReductionAxes>
   1005   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
   1006                      const ReductionAxes& reduction_axes,
   1007                      const Eigen::internal::OrReducer& reducer) {
   1008     ReduceImpl<bool, Or, bool*, bool*, ReductionAxes>(
   1009         ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0),
   1010         in.rank() >= 2 ? in.dimension(1) : 1,
   1011         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, Or());
   1012   }
   1013 
   1014   template <typename OUT_T>
   1015   static void FillIdentity(const GPUDevice& d, OUT_T out,
   1016                            const Eigen::internal::OrReducer& reducer) {
   1017     FillIdentityEigenImpl(d, To32Bit(out), reducer);
   1018   }
   1019 };
   1020 
   1021 }  // namespace functor
   1022 }  // namespace tensorflow
   1023 
    1024 #endif  // GOOGLE_CUDA
   1025