/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "external/cub_archive/cub/device/device_reduce.cuh"
#include "external/cub_archive/cub/device/device_select.cuh"
#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/where_op.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

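// Expands, in place, the flat row-major index stored at the start of each
// output row (written there by cub::DeviceSelect::Flagged in Where::Compute
// below) into its NDIM per-dimension coordinates, using the row-major strides
// of the input.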
template <int NDIM, typename TIndex>
__global__ void PropagateWhereIndicesKernel(
    const TIndex output_rows, const typename Eigen::array<TIndex, NDIM> strides,
    int64* output) {
  // TODO(ebrevdo): Use a multi-dimensional loop, increasing the
  // dimensions of individual indices manually, instead of relying on
  // a scalar loop variable and using integer division.
  CUDA_1D_KERNEL_LOOP(i, output_rows) {
    TIndex index_value = ldg(output + NDIM * i);
#pragma unroll
    for (int c = 0; c < NDIM; ++c) {
      *(output + NDIM * i + c) = index_value / strides[c];
      index_value %= strides[c];
    }
  }
}

namespace {

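// Unary predicate that is true for elements that compare unequal to T(0).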
template <typename T>
struct IsNonzero {
  EIGEN_DEVICE_FUNC IsNonzero() : zero(T(0)) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x) const {
    return (x != zero);
  }
  const T zero;
};

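// Counts the nonzero elements of d_in by summing a bool-transformed view of
// the input with cub::DeviceReduce::Sum.  Follows the usual CUB two-call
// convention: when d_temp_storage is nullptr, the call only writes the
// required scratch size into temp_storage_bytes and performs no work.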
template <typename T, typename TIndex>
struct CubDeviceReduceCount {
  cudaError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
                         const T* d_in, TIndex* d_out, int num_items,
                         cudaStream_t stream = 0,
                         bool debug_synchronous = false) {
    IsNonzero<T> is_nonzero;
    cub::TransformInputIterator<bool, IsNonzero<T>, const T*> is_nonzero_iter(
        d_in, is_nonzero);
    return cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
                                  is_nonzero_iter, d_out, num_items, stream,
                                  debug_synchronous);
  }
};

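// Specialization for bool inputs: the flags can be summed directly, without
// the IsNonzero transform.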
template <typename TIndex>
struct CubDeviceReduceCount<bool, TIndex> {
  cudaError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
                         const bool* d_in, TIndex* d_out, int num_items,
                         cudaStream_t stream = 0,
                         bool debug_synchronous = false) {
    return cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in,
                                  d_out, num_items, stream, debug_synchronous);
  }
};

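// Writes to d_out the flat indices of the true / nonzero elements of d_flags:
// a CountingInputIterator supplies the candidate indices 0, 1, 2, ..., and
// d_flags (converted to bool when necessary) selects which of them survive.
// The number of selected indices is written to d_num_selected_out.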
template <typename T, typename TIndex, typename OutputIterator,
          bool IsConvertibleToBool>
struct CubDeviceSelectFlaggedCounter;

template <typename T, typename TIndex, typename OutputIterator>
struct CubDeviceSelectFlaggedCounter<T, TIndex, OutputIterator,
                                     false /*IsConvertibleToBool*/> {
  cudaError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
                         const T* d_flags, OutputIterator d_out,
                         TIndex* d_num_selected_out, int num_items,
                         cudaStream_t stream = 0,
                         bool debug_synchronous = false) {
    cub::CountingInputIterator<TIndex> select_counter(0);
    IsNonzero<T> is_nonzero;
    cub::TransformInputIterator<bool, IsNonzero<T>, const T*> is_nonzero_iter(
        d_flags, is_nonzero);
    return cub::DeviceSelect::Flagged(
        d_temp_storage, temp_storage_bytes, select_counter /*d_in*/,
        is_nonzero_iter /*d_flags*/, d_out, d_num_selected_out, num_items,
        stream, debug_synchronous);
  }
};

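// Specialization for flag types that are already convertible to bool: the
// input can be passed to cub::DeviceSelect::Flagged directly.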
template <typename T, typename TIndex, typename OutputIterator>
struct CubDeviceSelectFlaggedCounter<T, TIndex, OutputIterator,
                                     true /*IsConvertibleToBool*/> {
  cudaError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
                         const T* d_flags, OutputIterator d_out,
                         TIndex* d_num_selected_out, int num_items,
                         cudaStream_t stream = 0,
                         bool debug_synchronous = false) {
    cub::CountingInputIterator<TIndex> select_counter(0);
    return cub::DeviceSelect::Flagged(
        d_temp_storage, temp_storage_bytes, select_counter /*d_in*/, d_flags,
        d_out, d_num_selected_out, num_items, stream, debug_synchronous);
  }
};

}  // namespace

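// Computes the number of true / nonzero elements of `input` into the scalar
// `num_true` on the device.  Uses the standard CUB two-pass pattern: the first
// call (with a null temp-storage pointer) sizes the scratch buffer, which is
// then allocated as a temporary DT_INT8 tensor, and the second call performs
// the reduction on `cu_stream`.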
template <typename T, typename TIndex>
struct NumTrue<GPUDevice, T, TIndex> {
  EIGEN_ALWAYS_INLINE static Status Compute(
      OpKernelContext* ctx, const GPUDevice& d,
      typename TTypes<T>::ConstFlat input,
      typename TTypes<TIndex>::Scalar num_true) {
    const cudaStream_t& cu_stream = GetCudaStream(ctx);

    std::size_t temp_storage_bytes = 0;
    const T* input_data = input.data();
    TIndex* num_true_data = num_true.data();

    // TODO(ebrevdo): sum doesn't work; perhaps need a different
    // iterator?
    auto reducer = CubDeviceReduceCount<T, TIndex>();
    auto first_success = reducer(/*temp_storage*/ nullptr, temp_storage_bytes,
                                 /*d_in*/ input_data,
                                 /*d_out*/ num_true_data,
                                 /*num_items*/ input.size(),
                                 /*stream*/ cu_stream);

    if (first_success != cudaSuccess) {
      return errors::Internal(
          "WhereOp: Could not launch cub::DeviceReduce::Sum to calculate "
          "temp_storage_bytes, status: ",
          cudaGetErrorString(first_success));
    }

    Tensor temp_storage;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
        &temp_storage));

    auto second_success = reducer(
        /*temp_storage*/ temp_storage.flat<int8>().data(), temp_storage_bytes,
        /*d_in*/ input_data,
        /*d_out*/ num_true_data,
        /*num_items*/ input.size(),
        /*stream*/ cu_stream);

    if (second_success != cudaSuccess) {
      return errors::Internal(
          "WhereOp: Could not launch cub::DeviceReduce::Sum to count "
          "number of true / nonzero indices.  temp_storage_bytes: ",
          temp_storage_bytes, ", status: ", cudaGetErrorString(second_success));
    }

    return Status::OK();
  }
};

#define NUMTRUE_GPU_FUNCTOR(T)                  \
  template struct NumTrue<GPUDevice, T, int32>; \
  template struct NumTrue<GPUDevice, T, int64>;

// We only need to instantiate the NumTrue functor once, but this file is
// included from where_op_gpu_impl_X.cu.cc for X=1,2,...
// Only instantiate it for X = 1.
#if GPU_PROVIDED_DIM == 1

TF_CALL_WHERE_GPU_TYPES(NUMTRUE_GPU_FUNCTOR);

#endif  // GPU_PROVIDED_DIM == 1

#undef NUMTRUE_GPU_FUNCTOR

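// Output "iterator" handed to cub::DeviceSelect::Flagged in Where::Compute.
// Writing the k-th selected flat index stores it at the start of output row k
// (offset NDIM * k); PropagateWhereIndicesKernel later rewrites each row in
// place with the NDIM coordinates corresponding to that flat index.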
template <int NDIM>
class WhereOutputIterator {
 public:
  // Required iterator traits
  typedef WhereOutputIterator self_type;
  typedef std::ptrdiff_t difference_type;
  typedef void value_type;
  typedef void pointer;
  typedef int64& reference;

#if (THRUST_VERSION >= 100700)
  // Use Thrust's iterator categories so we can use these iterators in Thrust
  // 1.7 (or newer) methods
  typedef typename thrust::detail::iterator_facade_category<
      thrust::device_system_tag, thrust::random_access_traversal_tag,
      value_type,
      reference>::type iterator_category;  ///< The iterator category
#else
  typedef std::random_access_iterator_tag
      iterator_category;  ///< The iterator category
#endif  // THRUST_VERSION

  WhereOutputIterator(int64* ptr, const Eigen::DenseIndex max_row)
      : ptr_(ptr), max_row_(max_row) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int64& operator[](int n) const {
    // If the selection mechanism finds more true values than expected (because
    // the input tensor changed between the allocation of the output and now),
    // we could accidentally write past the end of the allocated memory.  When
    // `valid` is false the write is redirected to row 0 instead; the number of
    // items found in Flagged()'s d_num_selected_out is read off at the end and
    // confirmed to match the number of rows of output.
    const bool valid = FastBoundsCheck(n, max_row_);
    return *(ptr_ + (valid ? (NDIM * n) : 0));
  }

 private:
  int64* ptr_;
  const Eigen::DenseIndex max_row_;
};

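// Computes the row-major strides of `input`: the innermost dimension has
// stride 1, and each outer stride is the product of the inner dimension sizes.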
template <typename TIndex, typename T, int NDIM>
Eigen::array<TIndex, NDIM> CalculateStrides(
    typename TTypes<T, NDIM>::ConstTensor input) {
  const Eigen::DSizes<Eigen::DenseIndex, NDIM> dims = input.dimensions();
  Eigen::array<TIndex, NDIM> strides;
  EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
                       static_cast<int>(Eigen::RowMajor)),
                      INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);
  strides[NDIM - 1] = 1;
  for (int i = NDIM - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[i + 1];
  }
  return strides;
}

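// Writes the NDIM-dimensional coordinates of the true / nonzero elements of
// `input` into `output`.  cub::DeviceSelect::Flagged, invoked with the same
// two-pass CUB pattern used by NumTrue, first records each selected flat index
// at the start of its output row; PropagateWhereIndicesKernel then expands
// those flat indices into per-dimension coordinates in place.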
template <int NDIM, typename T, typename TIndex>
struct Where<GPUDevice, NDIM, T, TIndex> {
  EIGEN_ALWAYS_INLINE static Status Compute(
      OpKernelContext* ctx, const GPUDevice& d,
      typename TTypes<T, NDIM>::ConstTensor input,
      typename TTypes<int64>::Matrix output, TIndex* found_true_host) {
    if (output.dimension(0) == 0) {
      // Nothing to do.
      return Status::OK();
    }

    const cudaStream_t& cu_stream = GetCudaStream(ctx);

    std::size_t temp_storage_bytes = 0;

    Tensor found_true_t;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<TIndex>::v(),
                                          TensorShape({}), &found_true_t));
    TIndex* found_true_device = found_true_t.scalar<TIndex>().data();

    WhereOutputIterator<NDIM> output_iterator(
        output.data(),
        /* max_row */ output.dimension(0));

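    // NOTE: std::decay<T> below is the trait class itself (not ...::type), so
    // IsConvertibleToBool evaluates to false for every T and the
    // IsNonzero-transform path of CubDeviceSelectFlaggedCounter is always
    // taken.  This is still correct, since IsNonzero<T> handles bool as well.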
    typedef std::decay<T> DT;
    CubDeviceSelectFlaggedCounter<
        T, TIndex, decltype(output_iterator) /*OutputIterator*/,
        std::is_convertible<DT, bool>::value /*IsConvertibleToBool*/>
        counter;
    auto first_success = counter(/*temp_storage*/ nullptr, temp_storage_bytes,
                                 /*d_flags*/ input.data(),
                                 /*d_out*/ output_iterator,
                                 /*d_num_selected_out*/ found_true_device,
                                 /*num_items*/ input.size(),
                                 /*stream*/ cu_stream);
    if (first_success != cudaSuccess) {
      return errors::Internal(
          "WhereOp: Could not launch cub::DeviceSelect::Flagged to calculate "
          "temp_storage_bytes, status: ",
          cudaGetErrorString(first_success));
    }

    Tensor temp_storage;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
        &temp_storage));

    auto second_success = counter(
        /*temp_storage*/ temp_storage.flat<int8>().data(), temp_storage_bytes,
        /*d_flags*/ input.data(),
        /*d_out*/ output_iterator,
        /*d_num_selected_out*/ found_true_device,
        /*num_items*/ input.size(),
        /*stream*/ cu_stream);

    if (second_success != cudaSuccess) {
      return errors::Internal(
          "WhereOp: Could not launch cub::DeviceSelect::Flagged to copy "
          "indices out, status: ",
          cudaGetErrorString(second_success));
    }

    // TODO(ebrevdo): Find a way to synchronously copy back data from
    // found_true_device to *found_true_host.

    const Eigen::array<TIndex, NDIM> strides =
        CalculateStrides<TIndex, T, NDIM>(input);
    const TIndex output_rows = output.dimension(0);
    CudaLaunchConfig config = GetCudaLaunchConfig(output_rows, d);
    PropagateWhereIndicesKernel<NDIM, TIndex>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            output_rows, strides, output.data());

    return Status::OK();
  }
};

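// Explicit instantiations of the Where functor for the dimension supplied by
// the including where_op_gpu_impl_X.cu.cc file, for both int32 and int64
// index types.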
#define DECLARE_GPU_SPEC_INDEX(Dims, T, TIndex) \
  template struct Where<GPUDevice, Dims, T, TIndex>

#define DECLARE_GPU_SPEC(T)                           \
  DECLARE_GPU_SPEC_INDEX(GPU_PROVIDED_DIM, T, int32); \
  DECLARE_GPU_SPEC_INDEX(GPU_PROVIDED_DIM, T, int64)

TF_CALL_WHERE_GPU_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPEC_INDEX

}  // namespace functor

}  // namespace tensorflow

#endif  // GOOGLE_CUDA