/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/gather_nd_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

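// GatherNdOp gathers slices of `params` into an output tensor, using the
// innermost dimension of `indices` as (possibly partial) coordinates into the
// leading dimensions of `params`. A small illustrative example of the op's
// semantics (not code from this file):
//   params  = [['a', 'b'], ['c', 'd']]
//   indices = [[0, 0], [1, 1]]
//   output  = ['a', 'd']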
template <typename Device, typename T, typename Index>
class GatherNdOp : public OpKernel {
 public:
  explicit GatherNdOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType index_t = DataTypeToEnum<Index>::v();
    OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& params = c->input(0);
    const Tensor& indices = c->input(1);

    Tensor out;
    OP_REQUIRES_OK(
        c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
    c->set_output(0, out);
  }
};

#define REGISTER_GATHER_ND_FULL(dev, type, index_type)                 \
  REGISTER_KERNEL_BUILDER(Name("GatherNd")                             \
                              .Device(DEVICE_##dev)                    \
                              .TypeConstraint<type>("Tparams")         \
                              .TypeConstraint<index_type>("Tindices"), \
                          GatherNdOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
  REGISTER_GATHER_ND_FULL(dev, type, int32);      \
  REGISTER_GATHER_ND_FULL(dev, type, int64)

#define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)
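// For example, REGISTER_GATHER_ND_CPU(float) expands to two kernel
// registrations for the "GatherNd" op: GatherNdOp<CPUDevice, float, int32>
// and GatherNdOp<CPUDevice, float, int64>.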

// TODO(ebrevdo): This is a pure data-movement kernel. It shouldn't be
// instantiated for all the different types. Instead, all types should be
// coalesced so that only int8, int16, int32, and int64 support is needed:
// float would then be handled as int32, double as int64, and complex<float>
// as int32 with twice the number of entries (similarly for complex<double>).
//
// The same applies to the GPU kernel.
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);

#undef REGISTER_GATHER_ND_CPU

namespace functor {
template <typename Device, typename T, typename Index>
Status DoGatherNd(OpKernelContext* c, const Tensor& params,
                  const Tensor& indices, Tensor* out) {
  if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) {
    return errors::InvalidArgument("params must be at least a vector");
  }
  if (!TensorShapeUtils::IsVectorOrHigher(indices.shape())) {
    return errors::InvalidArgument("indices must be at least a vector");
  }
  if (indices.dim_size(indices.dims() - 1) > params.dims()) {
    return errors::InvalidArgument(
        "index innermost dimension length must be <= params rank; saw: ",
        indices.dim_size(indices.dims() - 1), " vs. ", params.dims());
  }

  const TensorShape& indices_shape(indices.shape());
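  // indices_nd is the length of the innermost dimension of indices, i.e. the
  // number of leading params dimensions that each index row addresses.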
  const int64 indices_nd = indices_shape.dim_size(indices_shape.dims() - 1);

  // Check that we have enough index space
  int64 N_big = 1;
  for (int i = 0; i < indices_shape.dims() - 1; ++i) {
    N_big *= indices_shape.dim_size(i);
  }
  if (N_big > std::numeric_limits<int>::max()) {
    return errors::InvalidArgument(
        "indices has too many elements for int indexing: ", N_big, " > ",
        std::numeric_limits<int>::max());
  }
  if (params.NumElements() > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("params.NumElements() too large for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", params.NumElements(), " > ",
                                   std::numeric_limits<Index>::max());
  }

  // The result shape is
  //   indices.shape[:-1] + params.shape[indices.shape[-1]:]
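  // For example, params.shape = [4, 5, 6] with indices.shape = [2, 3, 2]
  // yields a result of shape [2, 3, 6] (N_result = 6, slice_size = 6).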
  Index N_result = 1;
  for (int i = 0; i < indices_shape.dims() - 1; ++i) {
    N_result *= indices_shape.dim_size(i);
  }

  const TensorShape& params_shape(params.shape());
  Index total_nd = params_shape.dims();

  TensorShape result_shape(indices_shape);
  result_shape.RemoveLastDims(1);

  int64 slice_size_big = 1;
  for (Index i = indices_nd; i < total_nd; ++i) {
    slice_size_big *= params_shape.dim_size(i);
    result_shape.AddDim(params_shape.dim_size(i));
  }

  if (slice_size_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument(
        "slice size is too large for indexing: ", slice_size_big, " > ",
        std::numeric_limits<Index>::max());
  }

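  // slice_size is the number of params elements copied per index row, i.e.
  // the product of params.shape[indices_nd:].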
  const Index slice_size = static_cast<Index>(slice_size_big);

  TF_RETURN_IF_ERROR(
      c->allocate_temp(DataTypeToEnum<T>::value, result_shape, out));

  if (N_result > 0) {
    if (params_shape.num_elements() == 0) {
      return errors::InvalidArgument(
          "Requested more than 0 entries, but "
          "params is empty.  Params shape: ",
          params_shape.DebugString());
    }

    auto indices_mat = indices.flat_inner_dims<Index>();

    Index bad_i = -1;

    // Request to copy slices / subtensors.
    // Reshape the output into a matrix: one row per gathered slice, with
    // slice_size elements per row.
    auto out_mat = out->shaped<T, 2>({N_result, slice_size});
    Tensor scratch;
    TF_RETURN_IF_ERROR(c->allocate_temp(DT_INT32, TensorShape(), &scratch));
    auto scratch_scalar = scratch.scalar<int32>();

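    // Dispatch on indices_nd so that GatherNdSlice is instantiated with a
    // compile-time index rank (IXDIM). Each specialization copies slice_size
    // elements of params per index row into out_mat and reports the first
    // out-of-range index row via its return value (negative when no bad
    // index is detected).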
    switch (indices_nd) {
#define PARAMS_CASE(IXDIM)                                              \
  case IXDIM: {                                                         \
    functor::GatherNdSlice<Device, T, Index, IXDIM> func;               \
    auto params_flat = params.flat_outer_dims<T, IXDIM + 1>();          \
    bad_i = func(c->eigen_device<Device>(), slice_size, scratch_scalar, \
                 params_flat, indices_mat, out_mat);                    \
  } break
      PARAMS_CASE(0);
      PARAMS_CASE(1);
      PARAMS_CASE(2);
      PARAMS_CASE(3);
      PARAMS_CASE(4);
      PARAMS_CASE(5);
      PARAMS_CASE(6);
      PARAMS_CASE(7);
#undef PARAMS_CASE
      default:
        return errors::InvalidArgument(
            "Only indices.shape[-1] values between 0 and 7 "
            "are currently supported.  Requested rank: ",
            indices_nd);
    }

    // bad_i will only be set to a value >= 0 by the CPU kernels right now.
    if (bad_i >= 0) {
      return errors::InvalidArgument(
          "flat indices[", bad_i, ", :] = [",
          str_util::Join(
              gtl::ArraySlice<Index>(&indices_mat(bad_i, 0), indices_nd), ", "),
          "] does not index into params (shape: ", params.shape().DebugString(),
          ").");
    }
  }
  return Status::OK();
}

}  // namespace functor

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, NDIM)          \
  template <>                                                 \
  Index GatherNdSlice<GPUDevice, T, Index, NDIM>::operator()( \
      const GPUDevice& d, const Index slice_size,             \
      typename TTypes<int32>::Scalar Tscratch,                \
      typename TTypes<T, NDIM + 1>::ConstTensor Tparams,      \
      typename TTypes<Index>::ConstMatrix Tindices,           \
      typename TTypes<T>::Matrix Tout);                       \
  extern template struct GatherNdSlice<GPUDevice, T, Index, NDIM>;

#define DECLARE_GPU_SPECS_INDEX(T, Index)    \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 0); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 1); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 2); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 3); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 4); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 5); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 6); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 7);

#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64)

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
TF_CALL_complex64(REGISTER_GATHER_ND_GPU);
TF_CALL_complex128(REGISTER_GATHER_ND_GPU);

#undef REGISTER_GATHER_ND_GPU

#endif  // GOOGLE_CUDA

#undef REGISTER_GATHER_ND_ALL_INDICES
#undef REGISTER_GATHER_ND_FULL

}  // namespace tensorflow