/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Our general strategy for preventing conflicts between concurrent
// reads and writes of resource variables is to:
// * For read operations, we:
//   - acquire the variable's mutex (in "shared" mode);
//   - make a (shallow) copy of the Tensor object, which increments
//     the reference count on the variable's TensorBuffer;
//   - release the variable's mutex;
//   - use the copy of the Tensor object to do the read.
// * For write operations, we:
//   - acquire the variable's mutex (in "exclusive" mode);
//   - check the reference count of the variable's TensorBuffer and,
//     if it is >1, make a deep copy of the variable's Tensor;
//   - mutate the variable's Tensor;
//   - and release the variable's mutex.
// This allows several read operations to all use the same
// TensorBuffer without needing to copy. When it comes time to write,
// a copy is made only if there is an outstanding read using the
// buffer. Write operations are serialized by the variable's mutex.
//
// For sparse operations (scatter, gather, sparse optimizer updates),
// we need to avoid copies, since there may not be enough memory for
// two copies of the whole tensor. To support this, we make two
// modifications to the above strategy:
// * For sparse reads (gather), we hold the variable's mutex (still in
//   "shared" mode) for the duration of the whole read. This means
//   that as long as you only do sparse read operations no write will
//   see the reference count >1.
// * For sparse write operations where the user explicitly specifies
//   that they want to perform the write without locks held
//   (use_locking=false), we never copy even if the variable's
//   reference count is >1.
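//
// As a rough illustration of this pattern (simplified pseudocode only; the
// kernels below use Var, tf_shared_lock, mutex_lock, and
// PrepareToUpdateVariable from this codebase, and `deep_copy` is a
// placeholder, not a real helper):
//
//   // Dense read: shallow-copy under a shared lock, then read lock-free.
//   tf_shared_lock l(*var->mu());
//   Tensor snapshot = *var->tensor();  // bumps the TensorBuffer's refcount
//   // ... release the lock and read from `snapshot` ...
//
//   // Dense write: copy-on-write under an exclusive lock.
//   mutex_lock l(*var->mu());
//   if (!var->tensor()->RefCountIsOne()) {
//     // An outstanding read still holds the buffer, so copy before mutating.
//     *var->tensor() = deep_copy(*var->tensor());
//   }
//   // ... mutate *var->tensor() in place ...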

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

REGISTER_RESOURCE_HANDLE_KERNEL(Var);

class ReadVariableOp : public OpKernel {
 public:
  explicit ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    Var* variable = nullptr;
    ResourceHandle handle = HandleFromInput(ctx, 0);
    const auto status = LookupResource(ctx, handle, &variable);
    OP_REQUIRES(ctx, status.ok(),
                errors::FailedPrecondition(
                    "Error while reading resource variable ", handle.name(),
                    " from Container: ", handle.container(),
                    ". This could mean that the variable was uninitialized. ",
                    status.ToString()));

    core::ScopedUnref s(variable);
    // We're acquiring a reference to the underlying buffer while
    // holding a shared lock to guarantee ordering of reads and
    // writes.
    tf_shared_lock ml(*variable->mu());
    const Tensor& t = *variable->tensor();
    OP_REQUIRES(
        ctx, dtype_ == t.dtype(),
        errors::InvalidArgument(
            "Trying to read variable with wrong dtype. Expected ",
            DataTypeString(dtype_), " got ", DataTypeString(t.dtype())));
    ctx->set_output(0, t);
  }

 private:
  DataType dtype_;
};

REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
                        ReadVariableOp);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"),
    ReadVariableOp);

#define REGISTER_GPU_KERNELS(type)                             \
  namespace functor {                                          \
  template <>                                                  \
  void DenseUpdate<GPUDevice, type, ASSIGN>::operator()(       \
      const GPUDevice& d, typename TTypes<type>::Flat lhs,     \
      typename TTypes<type>::ConstFlat rhs);                   \
  extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
  }                                                            \
  REGISTER_KERNEL_BUILDER(Name("VarHandleOp")                  \
                              .Device(DEVICE_GPU)              \
                              .HostMemory("resource")          \
                              .TypeConstraint<type>("dtype"),  \
                          ResourceHandleOp<Var>)

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA

template <typename T>
class VariableShapeOp : public OpKernel {
 public:
  explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* ctx) override {
    Var* variable = nullptr;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
    core::ScopedUnref s(variable);
    variable->mu()->lock_shared();
    TensorShape shape = variable->tensor()->shape();
    variable->mu()->unlock_shared();
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output));
    for (int i = 0; i < shape.dims(); ++i) {
      output->flat<T>()(i) = shape.dim_size(i);
    }
  }
};

REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
    VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int64>("out_type"),
    VariableShapeOp<int64>);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int64>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int64>);

#endif  // GOOGLE_CUDA

class DestroyResourceOp : public OpKernel {
 public:
  explicit DestroyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
  }

  void Compute(OpKernelContext* ctx) override {
    const ResourceHandle& p = HandleFromInput(ctx, 0);
    Status status = DeleteResource(ctx, p);
    if (ignore_lookup_error_ && errors::IsNotFound(status)) {
      return;
    }
    OP_REQUIRES_OK(ctx, status);
  }

 private:
  bool ignore_lookup_error_;
};

REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp").Device(DEVICE_CPU),
                        DestroyResourceOp);
REGISTER_KERNEL_BUILDER(
    Name("DestroyResourceOp").Device(DEVICE_GPU).HostMemory("resource"),
    DestroyResourceOp);

template <typename Device, typename T>
class AssignVariableOp : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
  }

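  // Assigns the value in input 1 to the variable, creating the variable on
  // first use. If we are the last user of the incoming value tensor, its
  // buffer is adopted directly (no copy); otherwise the value is copied,
  // reusing the variable's existing buffer when no concurrent reader still
  // holds a reference to it and the shapes match.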
  void Compute(OpKernelContext* context) override {
    OP_REQUIRES(context, dtype_ == context->input(1).dtype(),
                errors::InvalidArgument(
                    "Variable and value dtypes don't match; respectively, ",
                    dtype_, " and ", context->input(1).dtype()));
    Var* variable = nullptr;
    OP_REQUIRES_OK(
        context,
        LookupOrCreateResource<Var>(
            context, HandleFromInput(context, 0), &variable,
            [this, context](Var** ptr) {
              *ptr = new Var(dtype_);
              PersistentTensor unused;
              Tensor* tmp;
              AllocatorAttributes attr;
              attr.set_gpu_compatible(true);
              attr.set_nic_compatible(true);
              TF_RETURN_IF_ERROR(context->allocate_persistent(
                  dtype_, context->input(1).shape(), &unused, &tmp, attr));
              *(*ptr)->tensor() = *tmp;
              return Status::OK();
            }));
    core::ScopedUnref s(variable);

    OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(dtype_)));

    const Tensor& value = context->input(1);
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);

    // Copying is unnecessary if we are the last user of the value tensor;
    // in that case we can just adopt the input tensor's buffer instead.
    std::unique_ptr<Tensor> input_alias =
        context->forward_input(1, dtype_, value.shape(), DEVICE_MEMORY, attr);
    mutex_lock ml(*variable->mu());
    if (input_alias) {
      *variable->tensor() = *input_alias;
      return;
    }

    // Need to copy, but maybe we can re-use the variable's buffer?
    if (!variable->tensor()->RefCountIsOne() ||
        !variable->tensor()->shape().IsSameSize(value.shape())) {
      // Copy to a new buffer.
      PersistentTensor unused;
      Tensor* tmp;
      OP_REQUIRES_OK(context, context->allocate_persistent(
                                  dtype_, value.shape(), &unused, &tmp, attr));
      *variable->tensor() = *tmp;
    }
    functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
    copy_functor(context->eigen_device<Device>(),
                 variable->tensor()->flat<T>(), value.flat<T>());
  }

 private:
  DataType dtype_;
};

template <typename Device>
Status VariantCopyFn(OpKernelContext* context, const Tensor& from, Tensor* to);

#define CPU_DENSE_COPY(T)                                                \
  case DataTypeToEnum<T>::value: {                                      \
    functor::DenseUpdate<CPUDevice, T, ASSIGN> copy_functor_;           \
    copy_functor_(context->eigen_device<CPUDevice>(), tensor->flat<T>(), \
                  from.flat<T>());                                       \
    break;                                                               \
  }

#define INSTANTIATE_GET_VARIANT_COPY_FN(Device, TYPE_CALLER, TYPE_DENSE_COPY)  \
  template <>                                                                  \
  Status VariantCopyFn<Device>(OpKernelContext * context, const Tensor& from,  \
                               Tensor* to) {                                   \
    PersistentTensor tmp;                                                      \
    Tensor* tensor;                                                            \
    AllocatorAttributes attr;                                                  \
    attr.set_gpu_compatible(true);                                             \
    attr.set_nic_compatible(true);                                             \
    TF_RETURN_IF_ERROR(context->allocate_persistent(                           \
        from.dtype(), from.shape(), &tmp, &tensor, attr));                     \
    switch (from.dtype()) {                                                    \
      TYPE_CALLER(TYPE_DENSE_COPY);                                            \
      default:                                                                 \
        return errors::InvalidArgument(                                        \
            "VariantCopyFn: Could not perform a deep copy of variant "         \
            "element of type: ",                                               \
            DataTypeString(from.dtype()),                                      \
            " using device: ", context->device()->name());                     \
    }                                                                          \
    *to = *tensor;                                                             \
    return Status::OK();                                                       \
  }

INSTANTIATE_GET_VARIANT_COPY_FN(CPUDevice, TF_CALL_ALL_TYPES, CPU_DENSE_COPY);

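// On the GPU, the variant copy functor is also instantiated for int32 and
// int64 (in addition to the usual GPU types) so that variant elements of
// those types can be deep-copied on the device.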
#if GOOGLE_CUDA
#define GPU_DENSE_COPY(T)                                                \
  case DataTypeToEnum<T>::value: {                                      \
    functor::DenseUpdate<GPUDevice, T, ASSIGN> copy_functor_;           \
    copy_functor_(context->eigen_device<GPUDevice>(), tensor->flat<T>(), \
                  from.flat<T>());                                       \
    break;                                                               \
  }
#define TF_CALL_GPU_AND_ADDITIONAL_TYPES(T) \
  TF_CALL_GPU_ALL_TYPES(T);                 \
  TF_CALL_int32(T);                         \
  TF_CALL_int64(T);
INSTANTIATE_GET_VARIANT_COPY_FN(GPUDevice, TF_CALL_GPU_AND_ADDITIONAL_TYPES,
                                GPU_DENSE_COPY);
#undef TF_CALL_GPU_AND_ADDITIONAL_TYPES
#undef GPU_DENSE_COPY
#endif  // GOOGLE_CUDA

#undef CPU_DENSE_COPY
#undef INSTANTIATE_GET_VARIANT_COPY_FN

template <typename Device>
class AssignVariableOp<Device, Variant> : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
    OP_REQUIRES(c, dtype_ == DT_VARIANT,
                errors::Internal("Variant kernel called with dtype: ",
                                 DataTypeString(dtype_)));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& value = context->input(1);
    Var* variable = nullptr;
    OP_REQUIRES_OK(context,
                   LookupOrCreateResource<Var>(
                       context, HandleFromInput(context, 0), &variable,
                       [this, context](Var** ptr) {
                         // Created on host.
                         *ptr = new Var(DT_VARIANT);
                         return Status::OK();
                       }));
    core::ScopedUnref s(variable);
    OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(DT_VARIANT)));

    mutex_lock ml(*variable->mu());

    *variable->tensor() = Tensor(DT_VARIANT, value.shape());
    const auto elements_in = value.flat<Variant>();
    auto elements_out = variable->tensor()->flat<Variant>();
    auto copy_fn = std::bind(&VariantCopyFn<Device>, context,
                             std::placeholders::_1, std::placeholders::_2);
    for (int64 i = 0; i < elements_in.size(); ++i) {
      OP_REQUIRES_OK(context, VariantDeviceCopy(
                                  VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
                                  elements_in(i), &elements_out(i), copy_fn));
    }
  }

 private:
  DataType dtype_;
};

#define REGISTER_KERNELS(type)                                \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")            \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("dtype"), \
                          AssignVariableOp<Eigen::ThreadPoolDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type)                           \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")           \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<type>("dtype") \
                              .HostMemory("resource"),       \
                          AssignVariableOp<GPUDevice, type>);

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA

template <typename Device, typename T, DenseUpdateType Op>
class AssignUpdateVariableOp : public OpKernel {
 public:
  explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    Var* variable = nullptr;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &variable));
    core::ScopedUnref s(variable);

    const Tensor& value = context->input(1);
    // TODO(apassos): We could possibly avoid the copy done by
    // PrepareToUpdateVariable() for commutative operations like Op ==
    // ADD if value's refcount was 1.
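    // Take the variable's lock in exclusive mode; PrepareToUpdateVariable()
    // deep-copies the variable's buffer first if a concurrent reader still
    // holds a reference to it, so the update below never mutates data that
    // another op may be reading.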
    mutex_lock ml(*variable->mu());
    Tensor* var_tensor = variable->tensor();
    OP_REQUIRES_OK(context,
                   PrepareToUpdateVariable<Device, T>(context, var_tensor));
    functor::DenseUpdate<Device, T, Op> update_functor;
    update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
                   value.flat<T>());
  }
};

#define REGISTER_KERNELS(type)                                     \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AssignAddVariableOp")                                  \
          .Device(DEVICE_CPU)                                      \
          .TypeConstraint<type>("dtype"),                          \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AssignSubVariableOp")                                  \
          .Device(DEVICE_CPU)                                      \
          .TypeConstraint<type>("dtype"),                          \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, SUB>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(Name("AssignSubVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, SUB>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp").Device(DEVICE_CPU),
                        IsResourceInitialized<Var>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("is_initialized"),
                        IsResourceInitialized<Var>);
#endif  // GOOGLE_CUDA

template <typename Device, typename T, typename Index>
class ResourceGatherOp : public OpKernel {
 public:
  explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Var* v = nullptr;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    core::ScopedUnref unref_v(v);
    // NOTE: We hold the lock for the whole gather operation instead
    // of increasing the reference count of v->tensor() to avoid a
    // situation where a write to the same variable will see a
    // reference count greater than one and make a copy of the
    // (potentially very large) tensor buffer.
    tf_shared_lock ml(*v->mu());
    const Tensor& params = *v->tensor();
    const Tensor& indices = c->input(1);
    OP_REQUIRES(
        c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
        errors::InvalidArgument("params must be at least 1 dimensional"));

    // Check that we have enough index space.
    const int64 N = indices.NumElements();
    OP_REQUIRES(
        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // The result shape is indices.shape + params.shape[1:].
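    // For example, indices.shape = [2, 3] and params.shape = [8, 4, 5]
    // produce a result of shape [2, 3, 4, 5].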
    TensorShape result_shape = indices.shape();
    for (int i = 1; i < params.dims(); i++) {
      result_shape.AddDim(params.dim_size(i));
    }

    Tensor* out = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
    if (N > 0) {
      const int64 gather_dim_size = params.dim_size(0);
      int64 inner_size = 1;
      for (int i = 1; i < params.dims(); i++) {
        inner_size *= params.dim_size(i);
      }
      auto params_flat = params.shaped<T, 3>({1, gather_dim_size, inner_size});
      auto indices_flat = indices.flat<Index>();
      auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});

      functor::GatherFunctor<Device, T, Index> functor;
      int64 bad_i = functor(c, params_flat, indices_flat, out_flat);

      OP_REQUIRES(
          c, bad_i < 0,
          errors::InvalidArgument(
              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
    }
  }
};

#define REGISTER_GATHER_FULL(dev, type, index_type)                    \
  REGISTER_KERNEL_BUILDER(Name("ResourceGather")                       \
                              .Device(DEVICE_##dev)                    \
                              .HostMemory("resource")                  \
                              .TypeConstraint<type>("dtype")           \
                              .TypeConstraint<index_type>("Tindices"), \
                          ResourceGatherOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ALL_INDICES(dev, type) \
  REGISTER_GATHER_FULL(dev, type, int32);      \
  REGISTER_GATHER_FULL(dev, type, int64)

#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)

// Registration of the CPU implementations.
TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA
#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_GATHER_GPU);

#endif  // GOOGLE_CUDA

#undef REGISTER_GATHER_CPU
#undef REGISTER_GATHER_GPU
#undef REGISTER_GATHER_ALL_INDICES
#undef REGISTER_GATHER_FULL

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ResourceScatterUpdateOp : public OpKernel {
 public:
  explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Var* v = nullptr;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    core::ScopedUnref unref_v(v);
    mutex_lock ml(*v->mu());
    Tensor* params = v->tensor();
    OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, params));
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);

    // Check that we have enough index space.
    const int64 N_big = indices.NumElements();
    OP_REQUIRES(
        c, N_big <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("indices has too many elements for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", N_big, " > ",
                                std::numeric_limits<Index>::max()));
    const Index N = static_cast<Index>(indices.NumElements());
    OP_REQUIRES(
        c, params->dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params->dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    if (N > 0) {
      auto indices_flat = indices.flat<Index>();
      auto params_flat = params->flat_outer_dims<T>();
      auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});

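      // The scatter functor applies the update and returns the first
      // out-of-range index it encounters, or a negative value if all
      // indices were valid.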
      functor::ScatterFunctor<Device, T, Index, op> functor;
      const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                  params_flat, updates_flat, indices_flat);
      OP_REQUIRES(c, bad_i < 0,
                  errors::InvalidArgument(
                      "indices", SliceDebugString(indices.shape(), bad_i),
                      " = ", indices_flat(bad_i), " is not in [0, ",
                      params->dim_size(0), ")"));
    }
  }
};

#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
  REGISTER_KERNEL_BUILDER(                                             \
      Name(name)                                                       \
          .Device(DEVICE_##dev)                                        \
          .HostMemory("resource")                                      \
          .TypeConstraint<type>("dtype")                               \
          .TypeConstraint<index_type>("Tindices"),                     \
      ResourceScatterUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);

// TODO(apassos) add the other types here.
#define REGISTER_SCATTER_ARITHMETIC(type, dev)                \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd",    \
                          scatter_op::UpdateOp::ADD);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
                          scatter_op::UpdateOp::ASSIGN);

// Registers CPU kernels.
#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, CPU);

TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);

REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);

// Registers GPU kernels.
#if GOOGLE_CUDA
#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, GPU);

#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);

#endif  // GOOGLE_CUDA

#undef REGISTER_SCATTER_ARITHMETIC
#undef REGISTER_SCATTER_ARITHMETIC_CPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX

}  // namespace tensorflow