/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

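// Gathers slices from `params` (viewed as a flattened
// [batch_size, gather_dim_size, slice_size] tensor) into `out`, producing one
// output element per loop iteration. When `is_axis_zero` is true there is no
// real batch dimension and the batch-index arithmetic is skipped.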
template <typename T, typename Index, bool is_axis_zero>
__global__ void GatherOpKernel(const T* params, const Index* indices, T* out,
                               int64 gather_dim_size, int64 indices_size,
                               int64 slice_size, int64 out_size) {
  CUDA_1D_KERNEL_LOOP(i, out_size) {
    Index batch_i = 0;
    Index indices_i = 0;
    Index slice_i = 0;
    if (is_axis_zero) {
      indices_i = i / slice_size;
      slice_i = i - indices_i * slice_size;
    } else {
      Index batch_indices_i = i / slice_size;
      // The batch index into params to use for i.
      batch_i = batch_indices_i / indices_size;
      // The index into indices to use for i.
      indices_i = batch_indices_i - batch_i * indices_size;
      // Index into the current slice in params to use for i.
      slice_i = i - batch_indices_i * slice_size;
    }
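    // Example (batched case): with indices_size = 4 and slice_size = 3, output
    // element i = 22 decomposes into batch_indices_i = 7, batch_i = 1,
    // indices_i = 3, slice_i = 1.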

    // Index into the gather axis to use for i.
    Index gather_i = ldg(indices + indices_i);

    // Check gather_i is in [0, gather_dim_size).
    if (!FastBoundsCheck(gather_i, gather_dim_size)) {
      // Set the output for out-of-range indices to zero.
      // TODO(fpmc): Log an error for transfer back to host.
      out[i] = T(0);
    } else {
      // params is a [batch_size, gather_dim_size, slice_size] tensor. Read
      // params[batch_i, gather_i, slice_i] and write it to the i'th position
      // in out.
      Index params_i =
          (batch_i * gather_dim_size + gather_i) * slice_size + slice_i;
      out[i] = ldg(params + params_i);
    }
  }
}

namespace functor {
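// GPU specialization of GatherFunctor (see gather_functor.h). `params` is a
// 3-D view shaped [batch_size, gather_dim_size, slice_size]. The int64 return
// value is meant to report an out-of-range index; the GPU path always returns
// -1 because indices are not validated here (see the TODO below).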
template <typename T, typename Index>
struct GatherFunctor<GPUDevice, T, Index> {
  int64 operator()(OpKernelContext* ctx,
                   typename TTypes<T, 3>::ConstTensor params,
                   typename TTypes<Index>::ConstFlat indices,
                   typename TTypes<T, 3>::Tensor out) {
    const GPUDevice& d = ctx->eigen_gpu_device();
    const int64 out_size = out.size();
    if (out_size == 0) {
      // We need this check because the CPU version still runs its loop when
      // the indices are nonempty but the slices are empty, so that it can do
      // its error checking. On the GPU we do not know how to do that error
      // checking, so we skip the kernel launch entirely.
      return -1;
    }
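    // A batch (outer) dimension of 1 means the gather is effectively along
    // axis 0, so the kernel can skip the batch-index arithmetic.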
    const bool is_axis_zero = params.dimension(0) == 1;
    const int64 gather_dim_size = params.dimension(1);
    const int64 indices_size = indices.size();
    const int64 slice_size = params.dimension(2);

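    // GetCudaLaunchConfig picks a block/thread shape sized for out_size
    // elements; CUDA_1D_KERNEL_LOOP in the kernel grid-strides over the rest.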
    CudaLaunchConfig config = GetCudaLaunchConfig(out_size, d);
    if (is_axis_zero) {
      // clang-format off
      GatherOpKernel<T, Index, true>
          <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
              params.data(), indices.data(), out.data(), gather_dim_size,
              indices_size, slice_size, out_size);
      // clang-format on
    } else {
      // clang-format off
      GatherOpKernel<T, Index, false>
          <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
              params.data(), indices.data(), out.data(), gather_dim_size,
              indices_size, slice_size, out_size);
      // clang-format on
    }
    // TODO(fpmc): enable indices validation on GPU.
    // Right now, checking for out-of-bounds indices in the kernel would
    // require copying data between the GPU and the CPU, and would thus be
    // slow.
    return -1;
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_