/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/gather_nd_op.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

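// Flat GPU kernel for gather_nd: each output element i is mapped to one row
// of `indices` (a "loc") and to a position inside the gathered slice.
// `batch_strides` and `batch_indices` hold, per indexed dimension, the stride
// into `params` and the dimension size used for bounds checking.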
template <typename T, typename Index, int IXDIM>
__global__ void GatherSliceOpKernel(
    const T* params, const Index* indices, T* out,
    const Eigen::array<int64, IXDIM> batch_strides,
    const Eigen::array<int64, IXDIM> batch_indices, const int64 indices_size,
    const int64 slice_size, const int64 out_size) {
  // TODO(ebrevdo): reduce inner loop into two loops:
  // one over the number of locs, and one over the offsets inside the locs.
  CUDA_1D_KERNEL_LOOP(i, out_size) {
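    // Output element i belongs to slice `loc`; its position within that
    // slice is `loc_offset` (computed below).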
    const Index loc = i / slice_size;
    const auto indices_i = indices + IXDIM * loc;
    bool out_of_bounds = false;
    Index offset = 0;
#pragma unroll
    for (int j = 0; j < IXDIM; ++j) {
      const Index index_j = ldg(indices_i + j);
      out_of_bounds |= !FastBoundsCheck(index_j, batch_indices[j]);
      offset += batch_strides[j] * index_j;
    }
    // TODO(ebrevdo):
    // Only the copy below depends on the offset; the offset computation
    // above does not need to be repeated for every output index i.
    // Is there a way to break the outer loop into two loops: one that
    // determines how many slice_size-length locs are iterated over, and
    // another that iterates over the slice_size elements of each loc?
    // NOTE(eriche):
    // One alternative is a kernel where a warp or block is assigned to a
    // single offset.  The offset computation can then be done once, shared
    // within the warp or block, and the warp / block can cooperate on the
    // copy.  (See the illustrative sketch after this kernel.)
    const Index loc_offset = i - loc * slice_size;
    out[i] = (out_of_bounds) ? T(0) : ldg(params + offset + loc_offset);
  }
}
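
// NOTE: The following is an illustrative, uncompiled sketch (guarded by
// "#if 0") of the block-per-offset alternative described in the NOTE above.
// The kernel name, the `num_slices` parameter (== out_size / slice_size),
// and the launch shape it assumes are illustrative only; nothing in this
// file instantiates or launches it.
#if 0
template <typename T, typename Index, int IXDIM>
__global__ void GatherSlicePerBlockKernel(
    const T* params, const Index* indices, T* out,
    const Eigen::array<int64, IXDIM> batch_strides,
    const Eigen::array<int64, IXDIM> batch_indices, const int64 num_slices,
    const int64 slice_size) {
  __shared__ Index shared_offset;
  __shared__ bool shared_out_of_bounds;
  // One block per row of `indices`; grid-stride over rows.
  for (int64 loc = blockIdx.x; loc < num_slices; loc += gridDim.x) {
    if (threadIdx.x == 0) {
      // Compute the slice offset once per block instead of once per output
      // element, and publish it through shared memory.
      const Index* indices_i = indices + IXDIM * loc;
      Index offset = 0;
      bool out_of_bounds = false;
      for (int j = 0; j < IXDIM; ++j) {
        const Index index_j = ldg(indices_i + j);
        out_of_bounds |= !FastBoundsCheck(index_j, batch_indices[j]);
        offset += batch_strides[j] * index_j;
      }
      shared_offset = offset;
      shared_out_of_bounds = out_of_bounds;
    }
    __syncthreads();
    const Index offset = shared_offset;
    const bool out_of_bounds = shared_out_of_bounds;
    // The whole block cooperates on copying this slice.
    for (int64 k = threadIdx.x; k < slice_size; k += blockDim.x) {
      out[loc * slice_size + k] =
          out_of_bounds ? T(0) : ldg(params + offset + k);
    }
    // Ensure every thread has read shared_offset before thread 0 overwrites
    // it in the next iteration.
    __syncthreads();
  }
}
#endif  // 0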

namespace functor {

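// GPU specialization of the GatherNdSlice functor declared in
// gather_nd_op.h: it flattens the gather into a single 1-D kernel launch
// over all output elements.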
template <typename T, typename Index, int IXDIM>
struct GatherNdSlice<GPUDevice, T, Index, IXDIM> {
  Index operator()(const GPUDevice& d, const Index unused_slice_size,
                   typename TTypes<int32>::Scalar Tscratch,
                   typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
                   typename TTypes<Index>::ConstMatrix Tindices,
                   typename TTypes<T>::Matrix Tout) {
    const int64 indices_size = Tindices.dimension(1);
    const int64 out_size = Tout.size();
    int64 s_size = Tout.dimension(1);
    Eigen::array<int64, IXDIM> batch_strides;
    Eigen::array<int64, IXDIM> batch_indices;
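    // Build per-dimension strides and sizes from the innermost indexed
    // dimension outward: the innermost stride is the slice size, and each
    // outer stride is the next inner stride times that dimension's size.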
    if (IXDIM > 0) {
      batch_strides[size_t(IXDIM - 1)] = s_size;
      batch_indices[size_t(IXDIM - 1)] = Tparams.dimension(IXDIM - 1);
    }
    for (int i = IXDIM - 1; i > 0; --i) {
      batch_indices[i - 1] = Tparams.dimension(i - 1);
      batch_strides[i - 1] = batch_strides[i] * Tparams.dimension(i);
    }
    CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d);

    // clang-format off
    GatherSliceOpKernel<T, Index, IXDIM>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            Tparams.data(), Tindices.data(), Tout.data(), batch_strides,
            batch_indices, indices_size, s_size, out_size);
    // clang-format on

    // TODO(ebrevdo): enable indices validation on GPU.
    // Right now, checking for out-of-bounds indices in the kernel would
    // require copying code between GPU/CPU, and is too slow.
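    // -1 means "no bad index reported".  Out-of-bounds indices are not
    // flagged here; the kernel simply writes zeros for those slices.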
    return -1;
  }
};

}  // namespace functor

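// Explicitly instantiate the GPU functor for index depths 0 through 7 and
// for both int32 and int64 index types, for every supported element type.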
#define DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, NDIM) \
  template struct functor::GatherNdSlice<GPUDevice, T, Index, NDIM>;

#define DEFINE_GPU_SPECS_INDEX(T, Index)    \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 0); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 1); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 2); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 3); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 4); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 5); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 6); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 7);

#define DEFINE_GPU_SPECS(T)         \
  DEFINE_GPU_SPECS_INDEX(T, int32); \
  DEFINE_GPU_SPECS_INDEX(T, int64);

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
TF_CALL_complex64(DEFINE_GPU_SPECS);
TF_CALL_complex128(DEFINE_GPU_SPECS);

#undef DEFINE_GPU_SPECS
#undef DEFINE_GPU_SPECS_INDEX

}  // namespace tensorflow

#endif  // GOOGLE_CUDA