/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/state_ops.cc.
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/scatter_nd_op.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
#endif  // TENSORFLOW_USE_SYCL

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

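// ScatterNdOp builds a dense tensor of the requested shape, zero-initialized,
// and adds each updates slice into it at the locations given by indices.
// For reference, a small worked example (values taken from the tf.scatter_nd
// documentation example):
//   indices = [[4], [3], [1], [7]]   (shape [4, 1])
//   updates = [9, 10, 11, 12]        (shape [4])
//   shape   = [8]
//   output  = [0, 11, 0, 10, 9, 0, 0, 12]
// Duplicate indices accumulate, because the functor below runs with
// UpdateOp::ADD.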
template <typename Device, typename T, typename Index>
class ScatterNdOp : public OpKernel {
 public:
  explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType index_t = DataTypeToEnum<Index>::v();
    OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& indices = c->input(0);
    const Tensor& updates = c->input(1);
    const Tensor& shape_input = c->input(2);

    OP_REQUIRES(c, shape_input.dims() == 1,
                errors::InvalidArgument("Shape must be a vector"));

    auto vec = shape_input.flat<Index>();
    TensorShape shape;
    OP_REQUIRES_OK(c,
                   TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape));

    Tensor out;
    OP_REQUIRES_OK(
        c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>(
               c, indices, updates, shape, &out, true /*allocate*/));
    c->set_output(0, out);
  }
};

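// ScatterNdUpdateOp handles the in-place variants (ScatterNdUpdate,
// ScatterNdAdd, ScatterNdSub, ScatterNdNonAliasingAdd and
// ResourceScatterNdUpdate). The first input may be a resource handle, a ref
// tensor, or a plain dense tensor; the constructor dispatches on that type,
// and use_exclusive_lock_ controls whether updates are applied while holding
// the variable's mutex.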
template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp op>
class ScatterNdUpdateOp : public OpKernel {
 public:
  explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType dt_ref = DataTypeToEnum<T>::ref();
    const DataType index_t = DataTypeToEnum<Index>::v();
    dtype_ = c->input_type(0);
    if (c->input_type(0) == DT_RESOURCE) {
      // TODO(apassos): what to validate here?
      // Read the locking attr so use_exclusive_lock_ is not left
      // uninitialized for the resource-variable variant
      // (ResourceScatterNdUpdate declares use_locking, defaulting to true).
      OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
    } else if (IsRefType(c->input_type(0))) {
      OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
      OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
    } else {
      OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
      use_exclusive_lock_ = false;
    }
  }

  void Compute(OpKernelContext* c) override {
    if (dtype_ == DT_RESOURCE) {
      if (use_exclusive_lock_) {
        Var* v;
        OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
        // LookupResource hands us a reference on the variable; release it
        // once the update is done.
        core::ScopedUnref scoped_unref(v);
        mutex_lock m(*v->mu());
        DoCompute(c);
      } else {
        DoCompute(c);
      }
    } else if (use_exclusive_lock_) {
      // If we're here, it means the input type is a ref.
      DCHECK(IsRefType(c->input_dtype(0)));
      // Hold mutex while we apply updates
      mutex_lock l(*c->input_ref_mutex(0));
      DoCompute(c);
    } else {
      DoCompute(c);
    }
  }

 private:
  DataType dtype_;
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);
    Tensor params;
    TensorShape params_shape;

    if (dtype_ == DT_RESOURCE) {
      Var* v;
      OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
      // LookupResource hands us a reference on the variable; release it
      // once the update is done.
      core::ScopedUnref scoped_unref(v);
      Tensor* t = v->tensor();
      if (!use_exclusive_lock_) {
        // We're not holding the lock in the outer scope, so we need it here.
        mutex_lock m(*v->mu());
        OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
      } else {
        OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
      }
      params = *t;
      params_shape = params.shape();
    } else if (IsRefType(c->input_dtype(0))) {
      params = c->mutable_input(0, use_exclusive_lock_);
      params_shape = params.shape();
      c->forward_ref_input_to_ref_output(0, 0);
      OP_REQUIRES(c, params.IsInitialized(),
                  errors::FailedPrecondition("Null ref for params"));
    } else {
      Tensor* params_ptr;
      params_shape = c->input(0).shape();
      if (!c->forward_input_to_output_with_shape(0, 0, params_shape,
                                                 &params_ptr)) {
        // We weren't able to forward the input to output, so just
        // allocate a new output tensor and copy the values over.
        OP_REQUIRES_OK(c, c->allocate_output(0, params_shape, &params_ptr));
        params = *params_ptr;
        functor::DenseUpdate<Device, T, ASSIGN> copy;
        const Tensor& input_copy = c->input(0);
        copy(c->eigen_device<Device>(), params.flat<T>(), input_copy.flat<T>());
      } else {
        params = *params_ptr;
      }
    }

    OP_REQUIRES_OK(
        c, functor::DoScatterNd<Device, T, Index, op>(
               c, indices, updates, params_shape, &params, false /*allocate*/));
  }
};

#define REGISTER_SCATTER_ND_KERNEL_INDEX(type, index_type, dev, name) \
  REGISTER_KERNEL_BUILDER(Name(name)                                  \
                              .Device(DEVICE_##dev)                   \
                              .TypeConstraint<type>("T")              \
                              .TypeConstraint<index_type>("Tindices") \
                              .HostMemory("shape"),                   \
                          ScatterNdOp<dev##Device, type, index_type>)

#define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, dev, name, \
                                                op)                          \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name(name)                                                             \
          .Device(DEVICE_##dev)                                              \
          .TypeConstraint<type>("T")                                         \
          .TypeConstraint<index_type>("Tindices"),                           \
      ScatterNdUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, \
                                                         dev, name, op)    \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name(name)                                                           \
          .Device(DEVICE_##dev)                                            \
          .TypeConstraint<type>("T")                                       \
          .TypeConstraint<index_type>("Tindices")                          \
          .HostMemory("ref"),                                              \
      ScatterNdUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_ND_KERNEL(type, dev, name)         \
  REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \
  REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64, dev, name)

#define REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)

#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op)    \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, \
                                                   op);                    \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)

#define REGISTER_SCATTER_ND_ADD_SUB(type, dev)                            \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd",            \
                                    scatter_nd_op::UpdateOp::ADD);        \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdNonAliasingAdd", \
                                    scatter_nd_op::UpdateOp::ADD);        \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub",            \
                                    scatter_nd_op::UpdateOp::SUB);

#define REGISTER_SCATTER_ND(type, dev) \
  REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");

#define REGISTER_SCATTER_ND_UPDATE(type, dev)                         \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate",     \
                                    scatter_nd_op::UpdateOp::ASSIGN); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(                         \
      type, dev, "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN);
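
// As a concrete example of the macros above,
// REGISTER_SCATTER_ND_UPDATE(float, CPU) expands into four kernel
// registrations: ScatterNdUpdate and ResourceScatterNdUpdate for float,
// each with both int32 and int64 Tindices.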

// Registers CPU kernels.
#define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB(type, CPU);

#define REGISTER_SCATTER_ND_UPDATE_CPU(type) \
  REGISTER_SCATTER_ND_UPDATE(type, CPU);

#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
#define REGISTER_SCATTER_ND_GPU(type) REGISTER_SCATTER_ND(type, GPU);

TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA

#define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB(type, GPU);

#define REGISTER_SCATTER_ND_UPDATE_GPU(type) \
  REGISTER_SCATTER_ND_UPDATE(type, GPU);

#define REGISTER_SCATTER_ND_ALL_GPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB_GPU(type);  \
  REGISTER_SCATTER_ND_UPDATE_GPU(type);   \
  REGISTER_SCATTER_ND_GPU(type);

// TODO(b/66916790): Support half types in ScatterNd.
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_complex64(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU);

#undef REGISTER_SCATTER_ND_ALL_GPU

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SCATTER_ND_ADD_SUB_SYCL(type) \
  REGISTER_SCATTER_ND_ADD_SUB(type, SYCL);

#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \
  REGISTER_SCATTER_ND_UPDATE(type, SYCL);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL);
#undef REGISTER_SCATTER_ND_ADD_SUB_SYCL
#undef REGISTER_SCATTER_ND_UPDATE_SYCL
#endif  // TENSORFLOW_USE_SYCL

#undef REGISTER_SCATTER_ND_ADD
#undef REGISTER_SCATTER_ND_ADD_SUB
#undef REGISTER_SCATTER_ND_ADD_SUB_CPU
#undef REGISTER_SCATTER_ND_ADD_SUB_GPU
#undef REGISTER_SCATTER_ND_UPDATE
#undef REGISTER_SCATTER_ND_UPDATE_CPU
#undef REGISTER_SCATTER_ND_UPDATE_GPU
#undef REGISTER_SCATTER_ND_KERNEL
#undef REGISTER_SCATTER_ND_KERNEL_INDEX

#endif  // GOOGLE_CUDA

namespace functor {
// Check whether updates.shape = indices.shape[:batch_dim] +
// params_shape[slice_dim:]
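// For example (hypothetical shapes): with params_shape = [P0, P1, P2] and
// indices.shape = [N, 2], we get batch_dim = 1 and slice_dim = 2, so
// updates.shape must be [N, P2].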
Status ValidateUpdateShape(const TensorShape& params_shape,
                           const Tensor& indices, const Tensor& updates) {
  const int64 slice_dim =
      (indices.dims() > 1) ? indices.dim_size(indices.dims() - 1) : 1;
  const int64 batch_dim = (indices.dims() > 1) ? indices.dims() - 1 : 1;

  auto shape_err = [&]() {
    return errors::InvalidArgument(
        "Must have updates.shape = indices.shape[:batch_dim] + ",
        "params_shape[slice_dim:], got updates.shape: ",
        updates.shape().DebugString(),
        ", indices.shape: ", indices.shape().DebugString(),
        ", params_shape: ", params_shape.DebugString(),
        ", slice_dim: ", slice_dim, ", and batch_dim: ", batch_dim);
  };

  if (updates.dims() < batch_dim) return shape_err();
  if (params_shape.dims() < slice_dim + (updates.dims() - batch_dim)) {
    return shape_err();
  }
  if (updates.dims() != batch_dim + params_shape.dims() - slice_dim) {
    return shape_err();
  }
  for (int d = 0; d < batch_dim; ++d) {
    if (updates.dim_size(d) != indices.dim_size(d)) return shape_err();
  }
  for (int d = 0; d < updates.dims() - batch_dim; ++d) {
    if (updates.dim_size(d + batch_dim) !=
        params_shape.dim_size(d + slice_dim)) {
      return shape_err();
    }
  }
  return Status::OK();
}

template <typename Index>
Status PrepareAndValidateInputs(const TensorShape& params_shape,
                                const Tensor& indices, const Tensor& updates,
                                int64* slice_dim, Index* num_updates,
                                Index* slice_size) {
  const TensorShape& indices_shape(indices.shape());
  const TensorShape& updates_shape(updates.shape());

  if (!TensorShapeUtils::IsVectorOrHigher(params_shape)) {
    return errors::InvalidArgument("Output must be at least 1-D, ",
                                   "got shape: ", params_shape.DebugString());
  }

  if (!(params_shape.num_elements() > 0 ||
        (indices.NumElements() == 0 && updates.NumElements() == 0))) {
    return errors::InvalidArgument(
        "Indices and updates specified for empty output.  indices shape: ",
        indices.shape().DebugString());
  }

  if (updates.dim_size(0) != indices.dim_size(0)) {
    return errors::InvalidArgument(
        "The outermost dimension of updates and indices ",
        "must match. Got indices.shape ", indices_shape.DebugString(),
        ", updates.shape ", updates_shape.DebugString());
  }
  TF_RETURN_IF_ERROR(ValidateUpdateShape(params_shape, indices, updates));

  // Check that we have enough index space
  const int64 N_big = indices.NumElements();
  if (N_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("indices has too many elements for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", N_big, " > ",
                                   std::numeric_limits<Index>::max());
  }
  if (params_shape.dim_size(0) > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("params_shape[0] too large for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", params_shape.dim_size(0),
                                   " > ", std::numeric_limits<Index>::max());
  }

  // Calculate the size of each index vector, i.e. the length of the last
  // dimension of indices (or 1 when indices is a vector of scalar indices).
  *slice_dim = (indices_shape.dims() > 1)
                   ? indices_shape.dim_size(indices_shape.dims() - 1)
                   : 1;

  // Calculate the number of elements that make up each slice of our updated
  // tensor. This allows us to work with flattened tensors and copy over whole
  // slices at a time.
  Index total_nd = params_shape.dims();

  int64 slice_size_big = 1;
  for (int64 i = *slice_dim; i < total_nd; ++i) {
    slice_size_big *= params_shape.dim_size(i);
  }

  if (slice_size_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument(
        "slice size is too large for indexing: ", slice_size_big, " > ",
        std::numeric_limits<Index>::max());
  }

  *slice_size = static_cast<Index>(slice_size_big);

  const int64 safe_slice_dim = (*slice_dim < 1) ? 1 : *slice_dim;
  *num_updates = indices_shape.num_elements() / safe_slice_dim;

  return Status::OK();
}
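
// For example (hypothetical shapes): params_shape = [5, 4, 3],
// indices.shape = [6, 2] and updates.shape = [6, 3] yield slice_dim = 2,
// slice_size = 3 (the product of params_shape[2:]) and num_updates = 6
// (indices.NumElements() / slice_dim = 12 / 2).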
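// IndexFlattener exposes indices as a 2-D [num_updates, slice_dim] view.
// The generic version is just a reshape; the SYCL specialization below
// copies the indices from device memory to a host buffer and returns a
// host-backed view (presumably because the scatter loop reads the indices
// on the host).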
template <typename Device, typename Index>
class IndexFlattener {
 public:
  inline typename TTypes<Index, 2>::ConstTensor operator()(
      OpKernelContext*, const Tensor& indices) {
    return indices.flat_inner_dims<Index>();
  }
};

#ifdef TENSORFLOW_USE_SYCL
template <typename Index>
class IndexFlattener<SYCLDevice, Index> {
 public:
  IndexFlattener() { indices_host_ = nullptr; }
  ~IndexFlattener() { delete[] indices_host_; }

  inline typename TTypes<Index, 2>::ConstTensor operator()(
      OpKernelContext* c, const Tensor& indices) {
    size_t num_indices = indices.NumElements();
    indices_host_ = new Index[num_indices];
    auto device = c->eigen_sycl_device();
    auto size = sizeof(Index) * num_indices;
    auto src_ptr = GetBase(&indices);
    device.memcpyDeviceToHost(indices_host_, static_cast<const Index*>(src_ptr),
                              size);
    return typename TTypes<Index, 2>::ConstTensor(
        indices_host_, indices.shape().AsEigenDSizes<2>());
  }

 private:
  Index* indices_host_;
};
#endif

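// DoScatterNd validates indices/updates against `shape`, optionally
// allocates and zero-fills `out` (allocate == true, as used by ScatterNdOp),
// and then dispatches to ScatterNdFunctor specialized on the index depth
// slice_dim, which must currently be between 1 and 7. Out-of-range indices
// are reported through the functor's return value and turned into an
// InvalidArgument status below.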
template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp Op>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
                   const Tensor& updates, const TensorShape& shape, Tensor* out,
                   bool allocate) {
  int64 slice_dim;
  Index num_updates;
  Index slice_size;
  TF_RETURN_IF_ERROR(PrepareAndValidateInputs<Index>(
      shape, indices, updates, &slice_dim, &num_updates, &slice_size));

  IndexFlattener<Device, Index> index_flattener;
  auto indices_flat = index_flattener(c, indices);
  auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});

  if (allocate) {
    TF_RETURN_IF_ERROR(c->allocate_temp(DataTypeToEnum<T>::value, shape, out));
  } else {
    CHECK_NOTNULL(out);
  }

  if (shape.num_elements() == 0) {
    return Status::OK();
  }

  if (allocate) {
    // Brand new tensor, zero it out.
    functor::SetZeroFunctor<Device, T> fill;
    fill(c->eigen_device<Device>(), out->flat<T>());
  }
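  // View the output as a [num_outer_slices, slice_size] matrix so each update
  // row can be copied (or accumulated) as one contiguous slice. The functor
  // returns the position of the first out-of-range index row, or -1 if all
  // indices were valid.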
  auto output_matrix =
      out->shaped<T, 2>({shape.num_elements() / slice_size, slice_size});

  Index bad_i = -1;

  if (shape.num_elements() > 0) {
    switch (slice_dim) {
#define PARAMS_CASE(IXDIM)                                                  \
  case IXDIM: {                                                             \
    typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix;    \
    for (int i = 0; i < IXDIM; ++i) {                                       \
      output_shape_prefix[i] = shape.dim_size(i);                           \
    }                                                                       \
    functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM> functor;         \
    bad_i =                                                                 \
        functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
                output_matrix, indices_flat, updates_flat, output_matrix);  \
  } break
      // TODO(simister): Re-enable this once binary size is under control.
      //      PARAMS_CASE(0);
      PARAMS_CASE(1);
      PARAMS_CASE(2);
      PARAMS_CASE(3);
      PARAMS_CASE(4);
      PARAMS_CASE(5);
      PARAMS_CASE(6);
      PARAMS_CASE(7);
#undef PARAMS_CASE
      default:
        return errors::InvalidArgument(
            "Only indices.shape[-1] values between 1 and 7 "
            "are currently supported.  Requested rank: ",
            slice_dim);
    }
  }
  if (bad_i >= 0) {
    return errors::InvalidArgument(
        "Invalid indices: ", SliceDebugString(indices.shape(), bad_i), " = [",
        str_util::Join(
            gtl::ArraySlice<Index>(&indices_flat(bad_i, 0), slice_dim), ", "),
        "] does not index into ", shape.DebugString());
  }
  return Status::OK();
}
}  // namespace functor

#ifdef GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM)           \
  template <>                                                           \
  Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>::operator()(   \
      const GPUDevice& d, const Index slice_size,                       \
      const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
      typename TTypes<T, 2>::Tensor Tparams,                            \
      typename TTypes<Index, 2>::ConstTensor Tindices,                  \
      typename TTypes<T, 2>::ConstTensor Tupdates,                      \
      typename TTypes<T, 2>::Tensor Toutput);                           \
  extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op)     \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 2); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 3); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 4); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 5); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 6); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 7);

#define DECLARE_GPU_SPECS_INDEX(T, Index)                                \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ADD);    \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB)

#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64)

// TODO(b/66916790): Support half types in ScatterNd.
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_INDEX_OP

}  // namespace functor

#endif  // GOOGLE_CUDA

}  // namespace tensorflow