/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/gather_nd_op.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, typename Index>
class GatherNdOp : public OpKernel {
 public:
  explicit GatherNdOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType index_t = DataTypeToEnum<Index>::v();
    OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& params = c->input(0);
    const Tensor& indices = c->input(1);

    Tensor out;
    OP_REQUIRES_OK(
        c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
    c->set_output(0, out);
  }
};

#define REGISTER_GATHER_ND_FULL(dev, type, index_type)                 \
  REGISTER_KERNEL_BUILDER(Name("GatherNd")                             \
                              .Device(DEVICE_##dev)                    \
                              .TypeConstraint<type>("Tparams")         \
                              .TypeConstraint<index_type>("Tindices"), \
                          GatherNdOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
  REGISTER_GATHER_ND_FULL(dev, type, int32);      \
  REGISTER_GATHER_ND_FULL(dev, type, int64)

#define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)

// TODO(ebrevdo): This is a pure data-movement kernel. It shouldn't be
// instantiated for all different types. Instead, all the types should
// be coalesced. So we should only have int8, int16, int32, int64 support.
// And float is redirected to int32, double is redirected to int64,
// and complex<float> is redirected to int32 with twice the number of
// entries, similarly for complex<double>.
//
// Same for the GPU kernel.
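//
// A minimal sketch of that coalescing idea (illustrative only; the names and
// shapes here are assumptions, not part of this kernel): because GatherNd
// only moves bytes, same-width types could share one instantiation by
// bit-casting the buffers, e.g.
//
//   // Hypothetical dispatch of a float gather through the int32 kernel.
//   // sizeof(float) == sizeof(int32), so element counts are unchanged.
//   auto params_i32 =
//       params.bit_casted_shaped<int32, IXDIM + 1>(params_flat_shape);
//   auto out_i32 = out->bit_casted_shaped<int32, 2>({N_result, slice_size});
//
// A complex64 gather would bit-cast to int32 with slice_size doubled, and
// complex128 to int64, as described above.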
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);

#undef REGISTER_GATHER_ND_CPU

namespace functor {

template <typename Device, typename T, typename Index>
Status DoGatherNd(OpKernelContext* c, const Tensor& params,
                  const Tensor& indices, Tensor* out) {
  if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) {
    return errors::InvalidArgument("params must be at least a vector");
  }
  if (!TensorShapeUtils::IsVectorOrHigher(indices.shape())) {
    return errors::InvalidArgument("indices must be at least a vector");
  }
  if (indices.dim_size(indices.dims() - 1) > params.dims()) {
    return errors::InvalidArgument(
        "index innermost dimension length must be <= params rank; saw: ",
        indices.dim_size(indices.dims() - 1), " vs. ", params.dims());
  }

  const TensorShape& indices_shape(indices.shape());
  const int64 indices_nd = indices_shape.dim_size(indices_shape.dims() - 1);

  // Check that we have enough index space.
  int64 N_big = 1;
  for (int i = 0; i < indices_shape.dims() - 1; ++i) {
    N_big *= indices_shape.dim_size(i);
  }
  if (N_big > std::numeric_limits<int>::max()) {
    return errors::InvalidArgument(
        "indices has too many elements for int indexing: ", N_big, " > ",
        std::numeric_limits<int>::max());
  }
  if (params.NumElements() > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("params.NumElements() too large for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", params.NumElements(), " > ",
                                   std::numeric_limits<Index>::max());
  }

  // The result shape is
  //   indices.shape[:-1] + params.shape[indices.shape[-1]:]
  Index N_result = 1;
  for (int i = 0; i < indices_shape.dims() - 1; ++i) {
    N_result *= indices_shape.dim_size(i);
  }

  const TensorShape& params_shape(params.shape());
  Index total_nd = params_shape.dims();

  TensorShape result_shape(indices_shape);
  result_shape.RemoveLastDims(1);

  int64 slice_size_big = 1;
  for (Index i = indices_nd; i < total_nd; ++i) {
    slice_size_big *= params_shape.dim_size(i);
    result_shape.AddDim(params_shape.dim_size(i));
  }

  if (slice_size_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument(
        "slice size is too large for indexing: ", slice_size_big, " > ",
        std::numeric_limits<Index>::max());
  }

  const Index slice_size = static_cast<Index>(slice_size_big);

  TF_RETURN_IF_ERROR(
      c->allocate_temp(DataTypeToEnum<T>::value, result_shape, out));

  if (N_result > 0) {
    if (params_shape.num_elements() == 0) {
      return errors::InvalidArgument(
          "Requested more than 0 entries, but "
          "params is empty. Params shape: ",
          params_shape.DebugString());
    }

    auto indices_mat = indices.flat_inner_dims<Index>();

    Index bad_i = -1;

    // Copy slices / subtensors: view the output as a matrix whose rows are
    // the gathered slices, slice_size columns each.
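    //
    // Worked example (illustrative numbers, not from the original source):
    // with params of shape [4, 5, 6] and indices of shape [2, 3, 2] we get
    // indices_nd == 2, N_result == 2 * 3 == 6, slice_size == 6, and
    // result_shape == [2, 3, 6], i.e.
    // indices.shape[:-1] + params.shape[indices.shape[-1]:].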
    auto out_mat = out->shaped<T, 2>({N_result, slice_size});
    Tensor scratch;
    TF_RETURN_IF_ERROR(c->allocate_temp(DT_INT32, TensorShape(), &scratch));
    auto scratch_scalar = scratch.scalar<int32>();

    switch (indices_nd) {
#define PARAMS_CASE(IXDIM)                                              \
  case IXDIM: {                                                         \
    functor::GatherNdSlice<Device, T, Index, IXDIM> func;               \
    auto params_flat = params.flat_outer_dims<T, IXDIM + 1>();          \
    bad_i = func(c->eigen_device<Device>(), slice_size, scratch_scalar, \
                 params_flat, indices_mat, out_mat);                    \
  } break

      PARAMS_CASE(0);
      PARAMS_CASE(1);
      PARAMS_CASE(2);
      PARAMS_CASE(3);
      PARAMS_CASE(4);
      PARAMS_CASE(5);
      PARAMS_CASE(6);
      PARAMS_CASE(7);

#undef PARAMS_CASE

      default:
        return errors::InvalidArgument(
            "Only indices.shape[-1] values between 0 and 7 "
            "are currently supported. Requested rank: ",
            indices_nd);
    }

    // bad_i is only set to a value >= 0 by the CPU implementation right now.
    if (bad_i >= 0) {
      return errors::InvalidArgument(
          "flat indices[", bad_i, ", :] = [",
          str_util::Join(
              gtl::ArraySlice<Index>(&indices_mat(bad_i, 0), indices_nd),
              ", "),
          "] does not index into param (shape: ", params.shape().DebugString(),
          ").");
    }
  }
  return Status::OK();
}

}  // namespace functor

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {

#define DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, NDIM)          \
  template <>                                                 \
  Index GatherNdSlice<GPUDevice, T, Index, NDIM>::operator()( \
      const GPUDevice& d, const Index slice_size,             \
      typename TTypes<int32>::Scalar Tscratch,                \
      typename TTypes<T, NDIM + 1>::ConstTensor Tparams,      \
      typename TTypes<Index>::ConstMatrix Tindices,           \
      typename TTypes<T>::Matrix Tout);                       \
  extern template struct GatherNdSlice<GPUDevice, T, Index, NDIM>;

#define DECLARE_GPU_SPECS_INDEX(T, Index)    \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 0); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 1); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 2); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 3); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 4); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 5); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 6); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 7);

#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64)

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_INDEX_NDIM

}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
TF_CALL_complex64(REGISTER_GATHER_ND_GPU);
TF_CALL_complex128(REGISTER_GATHER_ND_GPU);

#undef REGISTER_GATHER_ND_GPU

#endif  // GOOGLE_CUDA

#undef REGISTER_GATHER_ND_ALL_INDICES
#undef REGISTER_GATHER_ND_FULL

}  // namespace tensorflow
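
// Usage sketch (an assumption for illustration, not part of this file):
// DoGatherNd is declared in gather_nd_op.h so that other kernels can reuse
// the gather logic, following the same pattern as GatherNdOp::Compute above:
//
//   Tensor gathered;
//   OP_REQUIRES_OK(ctx, functor::DoGatherNd<CPUDevice, float, int32>(
//                           ctx, params, indices, &gathered));
//   ctx->set_output(0, gathered);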