/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/state_ops.cc.
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/scatter_nd_op.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
#endif  // TENSORFLOW_USE_SYCL

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// Produces a tensor of the requested `shape` that is zero everywhere except
// where `updates` are scattered at `indices`; duplicate indices are summed.
template <typename Device, typename T, typename Index>
class ScatterNdOp : public OpKernel {
 public:
  explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType index_t = DataTypeToEnum<Index>::v();
    OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& indices = c->input(0);
    const Tensor& updates = c->input(1);
    const Tensor& shape_input = c->input(2);

    OP_REQUIRES(c, shape_input.dims() == 1,
                errors::InvalidArgument("Shape must be a vector"));

    auto vec = shape_input.flat<Index>();
    TensorShape shape;
    OP_REQUIRES_OK(c,
                   TensorShapeUtils::MakeShape(vec.data(), vec.size(), &shape));

    Tensor out;
    OP_REQUIRES_OK(
        c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>(
               c, indices, updates, shape, &out, true /*allocate*/));
    c->set_output(0, out);
  }
};

// Applies sparse `updates` at `indices` to an existing tensor, which may be a
// dense input, a ref variable, or a resource variable; `op` selects whether
// the updates assign, add, or subtract.
template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp op>
class ScatterNdUpdateOp : public OpKernel {
 public:
  explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType dt_ref = DataTypeToEnum<T>::ref();
    const DataType index_t = DataTypeToEnum<Index>::v();
    dtype_ = c->input_type(0);
    if (c->input_type(0) == DT_RESOURCE) {
      // TODO(apassos): what to validate here?
      // Read use_locking so use_exclusive_lock_ is not left uninitialized for
      // resource variables.
      OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
    } else if (IsRefType(c->input_type(0))) {
      OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
      OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
    } else {
      OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
      use_exclusive_lock_ = false;
    }
  }

  void Compute(OpKernelContext* c) override {
    if (dtype_ == DT_RESOURCE) {
      if (use_exclusive_lock_) {
        Var* v;
        OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
        mutex_lock m(*v->mu());
        DoCompute(c);
      } else {
        DoCompute(c);
      }
    } else if (use_exclusive_lock_) {
      // If we're here, it means the input type is a ref.
      DCHECK(IsRefType(c->input_dtype(0)));
      // Hold mutex while we apply updates
      mutex_lock l(*c->input_ref_mutex(0));
      DoCompute(c);
    } else {
      DoCompute(c);
    }
  }

 private:
  DataType dtype_;
  bool use_exclusive_lock_;

  void DoCompute(OpKernelContext* c) {
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);
    Tensor params;
    TensorShape params_shape;

    if (dtype_ == DT_RESOURCE) {
      Var* v;
      OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
      Tensor* t = v->tensor();
      if (!use_exclusive_lock_) {
        // We're not holding the lock in the outer scope so need it here.
        mutex_lock m(*v->mu());
        OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
      } else {
        OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
      }
      params = *t;
      params_shape = params.shape();
    } else if (IsRefType(c->input_dtype(0))) {
      params = c->mutable_input(0, use_exclusive_lock_);
      params_shape = params.shape();
      c->forward_ref_input_to_ref_output(0, 0);
      OP_REQUIRES(c, params.IsInitialized(),
                  errors::FailedPrecondition("Null ref for params"));
    } else {
      Tensor* params_ptr;
      params_shape = c->input(0).shape();
      if (!c->forward_input_to_output_with_shape(0, 0, params_shape,
                                                 &params_ptr)) {
        // We weren't able to forward the input to output, so just
        // allocate a new output tensor and copy the values over.
        OP_REQUIRES_OK(c, c->allocate_output(0, params_shape, &params_ptr));
        params = *params_ptr;
        functor::DenseUpdate<Device, T, ASSIGN> copy;
        const Tensor& input_copy = c->input(0);
        copy(c->eigen_device<Device>(), params.flat<T>(),
             input_copy.flat<T>());
      } else {
        params = *params_ptr;
      }
    }

    OP_REQUIRES_OK(
        c, functor::DoScatterNd<Device, T, Index, op>(
               c, indices, updates, params_shape, &params, false /*allocate*/));
  }
};

#define REGISTER_SCATTER_ND_KERNEL_INDEX(type, index_type, dev, name) \
  REGISTER_KERNEL_BUILDER(Name(name)                                  \
                              .Device(DEVICE_##dev)                   \
                              .TypeConstraint<type>("T")              \
                              .TypeConstraint<index_type>("Tindices") \
                              .HostMemory("shape"),                   \
                          ScatterNdOp<dev##Device, type, index_type>)

#define REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, dev, name, \
                                                op)                          \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name(name)                                                             \
          .Device(DEVICE_##dev)                                              \
          .TypeConstraint<type>("T")                                         \
          .TypeConstraint<index_type>("Tindices"),                           \
      ScatterNdUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, \
                                                         dev, name, op)    \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name(name)                                                           \
          .Device(DEVICE_##dev)                                            \
          .TypeConstraint<type>("T")                                       \
          .TypeConstraint<index_type>("Tindices")                          \
          .HostMemory("ref"),                                              \
      ScatterNdUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_ND_KERNEL(type, dev, name)         \
  REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \
  REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64, dev, name)

#define REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)

#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op)    \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, \
                                                   op);                    \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)

#define REGISTER_SCATTER_ND_ADD_SUB(type, dev)                            \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd",            \
                                    scatter_nd_op::UpdateOp::ADD);        \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdNonAliasingAdd", \
                                    scatter_nd_op::UpdateOp::ADD);        \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub",            \
                                    scatter_nd_op::UpdateOp::SUB);

#define REGISTER_SCATTER_ND(type, dev) \
  REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");

#define REGISTER_SCATTER_ND_UPDATE(type, dev)                         \
  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate",     \
                                    scatter_nd_op::UpdateOp::ASSIGN); \
  REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(                         \
      type, dev, "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN);

// Registers CPU kernels.
#define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB(type, CPU);

#define REGISTER_SCATTER_ND_UPDATE_CPU(type) \
  REGISTER_SCATTER_ND_UPDATE(type, CPU);

#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
#define REGISTER_SCATTER_ND_GPU(type) REGISTER_SCATTER_ND(type, GPU);

TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA

#define REGISTER_SCATTER_ND_ADD_SUB_GPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB(type, GPU);

#define REGISTER_SCATTER_ND_UPDATE_GPU(type) \
  REGISTER_SCATTER_ND_UPDATE(type, GPU);

#define REGISTER_SCATTER_ND_ALL_GPU(type) \
  REGISTER_SCATTER_ND_ADD_SUB_GPU(type);  \
  REGISTER_SCATTER_ND_UPDATE_GPU(type);   \
  REGISTER_SCATTER_ND_GPU(type);

// TODO(b/66916790): Support half types in ScatterNd.
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_complex64(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU);

#undef REGISTER_SCATTER_ND_ALL_GPU

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SCATTER_ND_ADD_SUB_SYCL(type) \
  REGISTER_SCATTER_ND_ADD_SUB(type, SYCL);

#define REGISTER_SCATTER_ND_UPDATE_SYCL(type) \
  REGISTER_SCATTER_ND_UPDATE(type, SYCL);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL);
#undef REGISTER_SCATTER_ND_ADD_SUB_SYCL
#undef REGISTER_SCATTER_ND_UPDATE_SYCL
#endif  // TENSORFLOW_USE_SYCL

#undef REGISTER_SCATTER_ND_ADD
#undef REGISTER_SCATTER_ND_ADD_SUB
#undef REGISTER_SCATTER_ND_ADD_SUB_CPU
#undef REGISTER_SCATTER_ND_ADD_SUB_GPU
#undef REGISTER_SCATTER_ND_UPDATE
#undef REGISTER_SCATTER_ND_UPDATE_CPU
#undef REGISTER_SCATTER_ND_UPDATE_GPU
#undef REGISTER_SCATTER_ND_KERNEL
#undef REGISTER_SCATTER_ND_KERNEL_INDEX

#endif  // GOOGLE_CUDA

namespace functor {
// Check whether updates.shape = indices.shape[:batch_dim] +
// params_shape[slice_dim:]
Status ValidateUpdateShape(const TensorShape& params_shape,
                           const Tensor& indices, const Tensor& updates) {
  const int64 slice_dim =
      (indices.dims() > 1) ? indices.dim_size(indices.dims() - 1) : 1;
  const int64 batch_dim =
      (indices.dims() > 1) ? indices.dims() - 1 : 1;

  auto shape_err = [&]() {
    return errors::InvalidArgument(
        "Must have updates.shape = indices.shape[:batch_dim] + ",
        "params_shape[slice_dim:], got updates.shape: ",
        updates.shape().DebugString(),
        ", indices.shape: ", indices.shape().DebugString(),
        ", params_shape: ", params_shape.DebugString(),
        ", slice_dim: ", slice_dim, ", and batch_dim: ", batch_dim);
  };

  if (updates.dims() < batch_dim) return shape_err();
  if (params_shape.dims() < slice_dim + (updates.dims() - batch_dim)) {
    return shape_err();
  }
  if (updates.dims() != batch_dim + params_shape.dims() - slice_dim) {
    return shape_err();
  }
  for (int d = 0; d < batch_dim; ++d) {
    if (updates.dim_size(d) != indices.dim_size(d)) return shape_err();
  }
  for (int d = 0; d < updates.dims() - batch_dim; ++d) {
    if (updates.dim_size(d + batch_dim) !=
        params_shape.dim_size(d + slice_dim)) {
      return shape_err();
    }
  }
  return Status::OK();
}

template <typename Index>
Status PrepareAndValidateInputs(const TensorShape& params_shape,
                                const Tensor& indices, const Tensor& updates,
                                int64* slice_dim, Index* num_updates,
                                Index* slice_size) {
  const TensorShape& indices_shape(indices.shape());
  const TensorShape& updates_shape(updates.shape());

  if (!TensorShapeUtils::IsVectorOrHigher(params_shape)) {
    return errors::InvalidArgument("Output must be at least 1-D, ",
                                   "got shape: ", params_shape.DebugString());
  }

  if (!(params_shape.num_elements() > 0 ||
        (indices.NumElements() == 0 && updates.NumElements() == 0))) {
    return errors::InvalidArgument(
        "Indices and updates specified for empty output. indices shape: ",
        indices.shape().DebugString());
  }

  if (updates.dim_size(0) != indices.dim_size(0)) {
    return errors::InvalidArgument(
        "The outermost dimension of updates and indices ",
        "must match. Got indices.shape ", indices_shape.DebugString(),
        ", updates.shape ", updates_shape.DebugString());
  }
  TF_RETURN_IF_ERROR(ValidateUpdateShape(params_shape, indices, updates));

  // Check that we have enough index space
  const int64 N_big = indices.NumElements();
  if (N_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("indices has too many elements for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", N_big, " > ",
                                   std::numeric_limits<Index>::max());
  }
  if (params_shape.dim_size(0) > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument("params_shape[0] too large for ",
                                   DataTypeString(DataTypeToEnum<Index>::v()),
                                   " indexing: ", params_shape.dim_size(0),
                                   " > ", std::numeric_limits<Index>::max());
  }

  // Calculate the number of dimensions in indices
  *slice_dim = (indices_shape.dims() > 1)
                   ? indices_shape.dim_size(indices_shape.dims() - 1)
                   : 1;

  // Calculate the number of elements that make up each slice of our updated
  // tensor. This allows us to work with flattened tensors and copy over whole
  // slices at a time.
  Index total_nd = params_shape.dims();

  int64 slice_size_big = 1;
  for (int64 i = *slice_dim; i < total_nd; ++i) {
    slice_size_big *= params_shape.dim_size(i);
  }

  if (slice_size_big > std::numeric_limits<Index>::max()) {
    return errors::InvalidArgument(
        "slice size is too large for indexing: ", slice_size_big, " > ",
        std::numeric_limits<Index>::max());
  }

  *slice_size = static_cast<Index>(slice_size_big);

  const int64 safe_slice_dim = (*slice_dim < 1) ? 1 : *slice_dim;
  *num_updates = indices_shape.num_elements() / safe_slice_dim;

  return Status::OK();
}

template <typename Device, typename Index>
class IndexFlattener {
 public:
  inline typename TTypes<Index, 2>::ConstTensor operator()(
      OpKernelContext*, const Tensor& indices) {
    return indices.flat_inner_dims<Index>();
  }
};

#ifdef TENSORFLOW_USE_SYCL
template <typename Index>
class IndexFlattener<SYCLDevice, Index> {
 public:
  IndexFlattener() { indices_host_ = nullptr; }
  ~IndexFlattener() { delete[] indices_host_; }

  inline typename TTypes<Index, 2>::ConstTensor operator()(
      OpKernelContext* c, const Tensor& indices) {
    size_t num_indices = indices.NumElements();
    indices_host_ = new Index[num_indices];
    auto device = c->eigen_sycl_device();
    auto size = sizeof(Index) * num_indices;
    auto src_ptr = GetBase(&indices);
    device.memcpyDeviceToHost(indices_host_,
                              static_cast<const Index*>(src_ptr), size);
    return typename TTypes<Index, 2>::ConstTensor(
        indices_host_, indices.shape().AsEigenDSizes<2>());
  }

 private:
  Index* indices_host_;
};
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp Op>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
                   const Tensor& updates, const TensorShape& shape,
                   Tensor* out, bool allocate) {
  int64 slice_dim;
  Index num_updates;
  Index slice_size;
  TF_RETURN_IF_ERROR(PrepareAndValidateInputs<Index>(
      shape, indices, updates, &slice_dim, &num_updates, &slice_size));

  IndexFlattener<Device, Index> index_flattener;
  auto indices_flat = index_flattener(c, indices);
  auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});

  if (allocate) {
    TF_RETURN_IF_ERROR(c->allocate_temp(DataTypeToEnum<T>::value, shape, out));
  } else {
    CHECK_NOTNULL(out);
  }

  if (shape.num_elements() == 0) {
    return Status::OK();
  }

  if (allocate) {
    // Brand new tensor, zero it out.
    functor::SetZeroFunctor<Device, T> fill;
    fill(c->eigen_device<Device>(), out->flat<T>());
  }
  auto output_matrix =
      out->shaped<T, 2>({shape.num_elements() / slice_size, slice_size});

  Index bad_i = -1;

  if (shape.num_elements() > 0) {
    switch (slice_dim) {
#define PARAMS_CASE(IXDIM)                                                  \
  case IXDIM: {                                                             \
    typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix;    \
    for (int i = 0; i < IXDIM; ++i) {                                       \
      output_shape_prefix[i] = shape.dim_size(i);                           \
    }                                                                       \
    functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM> functor;         \
    bad_i =                                                                 \
        functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
                output_matrix, indices_flat, updates_flat, output_matrix);  \
  } break
      // TODO(simister): Re-enable this once binary size is under control.
      // PARAMS_CASE(0);
      PARAMS_CASE(1);
      PARAMS_CASE(2);
      PARAMS_CASE(3);
      PARAMS_CASE(4);
      PARAMS_CASE(5);
      PARAMS_CASE(6);
      PARAMS_CASE(7);
#undef PARAMS_CASE
      default:
        return errors::InvalidArgument(
            "Only indices.shape[-1] values between 1 and 7 "
            "are currently supported. Requested rank: ",
            slice_dim);
    }
  }
  if (bad_i >= 0) {
    return errors::InvalidArgument(
        "Invalid indices: ", SliceDebugString(indices.shape(), bad_i), " = [",
        str_util::Join(
            gtl::ArraySlice<Index>(&indices_flat(bad_i, 0), slice_dim), ", "),
        "] does not index into ", shape.DebugString());
  }
  return Status::OK();
}
}  // namespace functor

#ifdef GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM)             \
  template <>                                                             \
  Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>::operator()(     \
      const GPUDevice& d, const Index slice_size,                         \
      const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,   \
      typename TTypes<T, 2>::Tensor Tparams,                              \
      typename TTypes<Index, 2>::ConstTensor Tindices,                    \
      typename TTypes<T, 2>::ConstTensor Tupdates,                        \
      typename TTypes<T, 2>::Tensor Toutput);                             \
  extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op)     \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 2); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 3); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 4); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 5); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 6); \
  DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 7);

#define DECLARE_GPU_SPECS_INDEX(T, Index)                                \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::ADD);    \
  DECLARE_GPU_SPECS_INDEX_OP(T, Index, scatter_nd_op::UpdateOp::SUB)

#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64)

// TODO(b/66916790): Support half types in ScatterNd.
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_complex64(DECLARE_GPU_SPECS);
TF_CALL_complex128(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_INDEX_OP

}  // namespace functor

#endif  // GOOGLE_CUDA

}  // namespace tensorflow
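
// A minimal worked example of the semantics implemented above (illustrative
// only; values shown are assumptions about a typical call, not output of this
// file). For the ScatterNd kernel, with
//
//   indices = [[4], [3], [1], [7]]   // shape [4, 1], so slice_dim = 1
//   updates = [9, 10, 11, 12]        // shape [4], one scalar slice per index
//   shape   = [8]
//
// DoScatterNd zero-initializes an output of shape [8] and applies
// UpdateOp::ADD at each index, producing [0, 11, 0, 10, 9, 0, 0, 12].
// Because the update op is ADD, duplicate indices accumulate rather than
// overwrite; ScatterNdUpdate with UpdateOp::ASSIGN overwrites instead.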