/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Our general strategy for preventing conflicts between concurrent
// reads and writes of resource variables is as follows:
// * For read operations, we:
//   - acquire the variable's mutex (in "shared" mode);
//   - make a (shallow) copy of the Tensor object, which increments
//     the reference count on the variable's TensorBuffer;
//   - release the variable's mutex;
//   - use the copy of the Tensor object to do the read.
// * For write operations, we:
//   - acquire the variable's mutex (in "exclusive" mode);
//   - check the reference count of the variable's TensorBuffer and,
//     if it is >1, make a deep copy of the variable's Tensor;
//   - mutate the variable's Tensor;
//   - and release the variable's mutex.
// This allows several read operations to all use the same
// TensorBuffer without needing to copy. When it comes time to write,
// a copy is made only if there is an outstanding read using the
// buffer. Write operations are serialized by the variable's mutex.
//
// For sparse operations (scatter, gather, sparse optimizer updates),
// we need to avoid copies, since there may not be enough memory for
// two copies of the whole tensor. To support this, we make two
// modifications to the above strategy:
// * For sparse reads (gather), we hold the variable's mutex (still in
//   "shared" mode) for the duration of the whole read. This means
//   that as long as only sparse read operations are performed, no
//   write will see a reference count >1.
// * For sparse write operations where the user explicitly specifies
//   that they want to perform the write without locks held
//   (use_locking=false), we never copy, even if the variable's
//   reference count is >1.

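// A minimal sketch of how a dense read and a dense write interleave
// under the scheme described above. This is not the actual Var API as
// exercised by the kernels below; ConsumeForRead, DeepCopy and
// MutateInPlace are illustrative placeholders.
//
//   // Read: shallow-copy the tensor under a shared lock, then use the
//   // snapshot without the lock held.
//   Tensor snapshot;
//   {
//     tf_shared_lock l(*var->mu());
//     snapshot = *var->tensor();  // bumps the TensorBuffer refcount
//   }
//   ConsumeForRead(snapshot);
//
//   // Write: under an exclusive lock, deep-copy first if a reader still
//   // holds a reference to the buffer, then mutate in place.
//   {
//     mutex_lock l(*var->mu());
//     if (!var->tensor()->RefCountIsOne()) {
//       *var->tensor() = DeepCopy(*var->tensor());
//     }
//     MutateInPlace(var->tensor());
//   }
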
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include <memory>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

REGISTER_RESOURCE_HANDLE_KERNEL(Var);
REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
                        ResourceHandlesOp<Var>);

ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
  OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
}

namespace {

Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
  Tensor* output;
  Notification n;
  Status status;
  AllocatorAttributes attr;
  if (t->dtype() == DT_VARIANT) {
    attr.set_on_host(true);
  }
  TF_RETURN_IF_ERROR(
      ctx->allocate_output(output_idx, t->shape(), &output, attr));
  if (t->dtype() == DT_VARIANT) {
    output->flat<Variant>() = t->flat<Variant>();
  } else if (ctx->op_device_context() != nullptr) {
    // TODO(apassos): remove the down_cast by just returning Device* from
    // OpKernelContext
    Device* device = static_cast<Device*>(ctx->device());
    ctx->op_device_context()->CopyTensorInSameDevice(
        t, device, output, [&n, &status](const Status& s) {
          status = s;
          n.Notify();
        });
    n.WaitForNotification();
    return status;
  } else {
    switch (t->dtype()) {
#define HANDLER(type)                       \
  case DataTypeToEnum<type>::value:         \
    output->flat<type>() = t->flat<type>(); \
    break;
      TF_CALL_ALL_TYPES(HANDLER);
#undef HANDLER
      default:
        return errors::Internal("Unsupported dtype ", t->dtype());
    }
  }
  return Status::OK();
}

}  // namespace

void ReadVariableOp::Compute(OpKernelContext* ctx) {
  Var* variable = nullptr;
  const ResourceHandle& handle = HandleFromInput(ctx, 0);
  const auto status = LookupResource(ctx, handle, &variable);
  OP_REQUIRES(ctx, status.ok(),
              errors::FailedPrecondition(
                  "Error while reading resource variable ", handle.name(),
                  " from Container: ", handle.container(),
                  ". This could mean that the variable was uninitialized. ",
                  status.ToString()));

  core::ScopedUnref s(variable);
  // We're acquiring a reference to the underlying buffer while
  // holding a shared lock to guarantee ordering of reads and
  // writes.
  tf_shared_lock ml(*variable->mu());
  const Tensor* t = variable->tensor();
  OP_REQUIRES(ctx, dtype_ == t->dtype(),
              errors::InvalidArgument(
                  "Trying to read variable with wrong dtype. Expected ",
                  DataTypeString(dtype_), " got ", DataTypeString(t->dtype())));
  if (variable->copy_on_read_mode.load()) {
    OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t));
  } else {
    ctx->set_output(0, *t);
  }
}

ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
  int n;
  OP_REQUIRES_OK(c, c->GetAttr("N", &n));
  OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
  OP_REQUIRES(c, n == dtypes_.size(),
              errors::InvalidArgument(
                  "Mismatched number of arguments to ReadVariablesOp (", n,
                  " vs. ", dtypes_.size(), ")"));
}

void ReadVariablesOp::Compute(OpKernelContext* ctx) {
  std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables(
      dtypes_.size());
  std::vector<const ResourceHandle*> handles(dtypes_.size());
  for (size_t i = 0; i < dtypes_.size(); ++i) {
    handles[i] = &HandleFromInput(ctx, i);
  }

  OP_REQUIRES_OK(ctx, LookupResources(ctx, handles, &variables));

  std::vector<string> uninitialized_vars;
  for (int64 i = 0; i < variables.size(); i++) {
    if (variables[i] == nullptr) {
      uninitialized_vars.push_back(handles[i]->name());
    }
  }

  OP_REQUIRES(
      ctx, uninitialized_vars.empty(),
      errors::InvalidArgument("In ReadVariableOp the following variables were "
                              "found uninitialized: ",
                              absl::StrJoin(uninitialized_vars, ", ")));

  for (size_t i = 0; i < dtypes_.size(); ++i) {
    // We're acquiring a reference to the underlying buffer while
    // holding a shared lock to guarantee ordering of reads and
    // writes.
    tf_shared_lock ml(*variables[i]->mu());
    OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(),
                errors::InvalidArgument(
                    "Trying to read variable ", handles[i]->name(),
                    " from Container: ", handles[i]->container(),
                    " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
                    " got ", DataTypeString(variables[i]->tensor()->dtype())));
    if (variables[i]->copy_on_read_mode.load()) {
      OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor()));
    } else {
      const Tensor& t = *variables[i]->tensor();
      ctx->set_output(i, t);
    }
  }
}

REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
                        ReadVariableOp);
REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
                        ReadVariablesOp);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"),
    ReadVariableOp);
REGISTER_KERNEL_BUILDER(
    Name("_ReadVariablesOp").Device(DEVICE_GPU).HostMemory("resources"),
    ReadVariablesOp);

#define REGISTER_GPU_KERNELS(type)                             \
  namespace functor {                                          \
  template <>                                                  \
  void DenseUpdate<GPUDevice, type, ASSIGN>::operator()(       \
      const GPUDevice& d, typename TTypes<type>::Flat lhs,     \
      typename TTypes<type>::ConstFlat rhs);                   \
  extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
  }                                                            \
  REGISTER_KERNEL_BUILDER(Name("VarHandleOp")                  \
                              .Device(DEVICE_GPU)              \
                              .HostMemory("resource")          \
                              .TypeConstraint<type>("dtype"),  \
                          ResourceHandleOp<Var>)
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS

REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("resources")
                            .TypeConstraint("dtypes",
                                            {DT_INT64, DT_COMPLEX64,
                                             DT_COMPLEX128, DT_HALF, DT_FLOAT,
                                             DT_DOUBLE, DT_BOOL, DT_VARIANT}),
                        ResourceHandlesOp<Var>);

#endif  // GOOGLE_CUDA

template <typename T>
class VariableShapeOp : public OpKernel {
 public:
  explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* ctx) override {
    Var* variable = nullptr;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
    core::ScopedUnref s(variable);
    variable->mu()->lock_shared();
    TensorShape shape = variable->tensor()->shape();
    variable->mu()->unlock_shared();
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output));
    for (int i = 0; i < shape.dims(); ++i) {
      output->flat<T>()(i) = shape.dim_size(i);
    }
  }
};

REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
    VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int64>("out_type"),
    VariableShapeOp<int64>);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int64>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int64>);

#endif  // GOOGLE_CUDA

DestroyResourceOp::DestroyResourceOp(OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
}

void DestroyResourceOp::Compute(OpKernelContext* ctx) {
  const ResourceHandle& p = HandleFromInput(ctx, 0);
  Status status = DeleteResource(ctx, p);
  if (ignore_lookup_error_ && errors::IsNotFound(status)) {
    return;
  }
  OP_REQUIRES_OK(ctx, status);
}

REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp").Device(DEVICE_CPU),
                        DestroyResourceOp);
REGISTER_KERNEL_BUILDER(
    Name("DestroyResourceOp").Device(DEVICE_GPU).HostMemory("resource"),
    DestroyResourceOp);

template <typename Device, typename T>
class AssignVariableOp : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
    if (!c->GetAttr("_grappler_relax_allocator_constraints",
                    &relax_constraints_)
             .ok()) {
      relax_constraints_ = false;
    }
  }

  void Compute(OpKernelContext* context) override {
    OP_REQUIRES(context, dtype_ == context->input(1).dtype(),
                errors::InvalidArgument(
                    "Variable and value dtypes don't match; respectively, ",
                    DataTypeString(dtype_), " and ",
                    DataTypeString(context->input(1).dtype())));
    Var* variable = nullptr;
    const Tensor& value = context->input(1);
    // Note: every resource-variable-manipulating op assumes copy-on-write
    // semantics, and creates a copy of the variable's Tensor if its refcount is
    // bigger than 1 when we try to modify it. This means we never need to copy
    // the original tensor for AssignVariableOp; even if there are other live
    // users of it, we know none of them can modify it, so this is always safe.
    // This holds even in esoteric cases where the same tensor is used to
    // initialize multiple variables or the tensor is a constant, since future
    // writes will trigger copies.
    OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
                                context, HandleFromInput(context, 0), &variable,
                                [this, &value](Var** ptr) {
                                  *ptr = new Var(dtype_);
                                  *(*ptr)->tensor() = value;
                                  (*ptr)->is_initialized = true;
                                  return Status::OK();
                                }));
    core::ScopedUnref s(variable);
    mutex_lock ml(*variable->mu());
    OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(dtype_)));
    if (variable->copy_on_read_mode.load()) {
      PersistentTensor unused;
      Tensor* tmp;
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      attr.set_nic_compatible(true);
      OP_REQUIRES_OK(context,
                     context->allocate_persistent(value.dtype(), value.shape(),
                                                  &unused, &tmp, attr));
      functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
      copy_functor(context->eigen_device<Device>(), tmp->flat<T>(),
                   value.flat<T>());
      *variable->tensor() = *tmp;
    } else {
      *variable->tensor() = value;
    }
    variable->is_initialized = true;
  }

 private:
  DataType dtype_;
  bool relax_constraints_;
};

template <typename Device>
class AssignVariableOp<Device, Variant> : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
    OP_REQUIRES(c, dtype_ == DT_VARIANT,
                errors::Internal("Variant kernel called with dtype: ",
                                 DataTypeString(dtype_)));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& value = context->input(1);
    Var* variable = nullptr;
    OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
                                context, HandleFromInput(context, 0), &variable,
                                [](Var** ptr) {
                                  // Created on host.
                                  *ptr = new Var(DT_VARIANT);
                                  return Status::OK();
                                }));
    core::ScopedUnref s(variable);

    // For purposes of forwarding DT_VARIANT, we want the least
    // restrictive attr; we already know the input is on host.
    AllocatorAttributes attr;

    // Copying is unnecessary if we are the last user of the value
    // tensor; we can just adopt the input tensor's buffer instead.
    // Note that Variant objects themselves always reside on host.
    //
    // We nevertheless want to signal to the runtime that the tensor
    // should reside in memory of the associated device, as Variant
    // tensors may be marked as sitting on either CPU or GPU.  This
    // helps to elide one or more copies.
    std::unique_ptr<Tensor> input_alias = context->forward_input(
        1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
        value.shape(),
        DEVICE_MEMORY /* HOST_MEMORY is only reserved for special cases */,
        attr);

    mutex_lock ml(*variable->mu());
    OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(DT_VARIANT)));
    variable->is_initialized = true;
    *variable->tensor() = Tensor(DT_VARIANT, value.shape());

    if (input_alias) {
      *variable->tensor() = *input_alias;
      return;
    }

    // Need to copy, but maybe we can re-use the variable's buffer?
    if (!variable->tensor()->RefCountIsOne() ||
        !variable->tensor()->shape().IsSameSize(value.shape())) {
      PersistentTensor unused;
      Tensor* tmp;
      // Allocation of DT_VARIANT is always on host.
      attr.set_on_host(true);
      OP_REQUIRES_OK(context,
                     context->allocate_persistent(DT_VARIANT, value.shape(),
                                                  &unused, &tmp, attr));
      *variable->tensor() = *tmp;
    }

    const auto elements_in = value.flat<Variant>();
    auto elements_out = variable->tensor()->flat<Variant>();
    for (int64 i = 0; i < elements_in.size(); ++i) {
      elements_out(i) = elements_in(i);
    }
  }

 private:
  DataType dtype_;
};

#define REGISTER_KERNELS(type)                                \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")            \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("dtype"), \
                          AssignVariableOp<Eigen::ThreadPoolDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type)                           \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")           \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<type>("dtype") \
                              .HostMemory("resource"),       \
                          AssignVariableOp<GPUDevice, type>);

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA

template <typename Device, typename T, DenseUpdateType Op>
class AssignUpdateVariableOp : public OpKernel {
 public:
  explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    Var* variable = nullptr;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &variable));
    core::ScopedUnref s(variable);

    const Tensor& value = context->input(1);
    // TODO(apassos): We could possibly avoid the copy done by
    // PrepareToUpdateVariable() for commutative operations like Op ==
    // ADD if value's refcount was 1.
    mutex_lock ml(*variable->mu());
    Tensor* var_tensor = variable->tensor();
    OP_REQUIRES(context, var_tensor->shape().IsSameSize(value.shape()),
                errors::InvalidArgument("Cannot update variable with shape ",
                                        var_tensor->shape().DebugString(),
                                        " using a Tensor with shape ",
                                        value.shape().DebugString(),
                                        ", shapes must be equal."));
    OP_REQUIRES_OK(
        context, PrepareToUpdateVariable<Device, T>(
                     context, var_tensor, variable->copy_on_read_mode.load()));
    functor::DenseUpdate<Device, T, Op> update_functor;
    update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
                   value.flat<T>());
  }
};

#define REGISTER_KERNELS(type)                                     \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AssignAddVariableOp")                                  \
          .Device(DEVICE_CPU)                                      \
          .TypeConstraint<type>("dtype"),                          \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AssignSubVariableOp")                                  \
          .Device(DEVICE_CPU)                                      \
          .TypeConstraint<type>("dtype"),                          \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, SUB>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(Name("AssignSubVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, SUB>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA

class VarIsInitializedOp : public OpKernel {
 public:
  explicit VarIsInitializedOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output));
    auto output_tensor = output->tensor<bool, 0>();
    Var* variable = nullptr;
    Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
    if (!s.ok()) {
      output_tensor() = false;
      return;
    }
    core::ScopedUnref su(variable);
    mutex_lock ml(*variable->mu());
    output_tensor() = variable->is_initialized;
  }
};

REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp").Device(DEVICE_CPU),
                        VarIsInitializedOp);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("is_initialized"),
                        IsResourceInitialized<Var>);
#endif  // GOOGLE_CUDA

template <typename Device, typename T, typename Index>
class ResourceGatherOp : public OpKernel {
 private:
  int32 batch_dims_ = 0;

  // Add the batch offset derived from params to each batch of indices.
  // Example: batch_dims = 1, indices = [[0, 1, 2], [0, 1, 2]].
  // If indexing into a params dimension of size 4, then the indices will
  // become [0, 1, 2, 4, 5, 6].
  void AddBatchOffsets(Tensor* indices, const Tensor& params) {
    int64 batch_size = 1;  // The size of all batch dimensions.
    for (int idx = 0; idx < batch_dims_; ++idx) {
      batch_size *= params.dim_size(idx);
    }

    auto indices_flat = indices->flat<Index>();
    int64 const index_inner_size = indices->NumElements() / batch_size;
    int64 const batch_offset = params.dim_size(batch_dims_);
    for (int64 batch_idx = 0, dest_idx = 0; batch_idx < batch_size;
         ++batch_idx) {
      for (int64 idx = 0; idx < index_inner_size; ++idx) {
        indices_flat(dest_idx++) += batch_offset * batch_idx;
      }
    }
  }

 public:
  explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_));
  }

  void Compute(OpKernelContext* c) override {
    Var* v = nullptr;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    core::ScopedUnref su(v);
    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
    // NOTE: We hold the lock for the whole gather operation instead
    // of increasing the reference count of v->tensor() to avoid a
    // situation where a write to the same variable will see a
    // reference count greater than one and make a copy of the
    // (potentially very large) tensor buffer.
    tf_shared_lock ml(*v->mu());
    const Tensor& params = *v->tensor();
    const Tensor& indices = c->input(1);
    OP_REQUIRES(
        c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
        errors::InvalidArgument("params must be at least 1 dimensional"));

    // Check that we have enough index space
    const int64 N = indices.NumElements();
    OP_REQUIRES(
        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // The result shape is params.shape[:batch_dims] +
    // indices.shape[batch_dims:] + params.shape[batch_dims+1:].
    TensorShape result_shape;
    for (int i = 0; i < batch_dims_; ++i) {
      result_shape.AddDim(params.dim_size(i));
    }
    for (int i = batch_dims_; i < indices.dims(); ++i) {
      result_shape.AddDim(indices.dim_size(i));
    }
    for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
      result_shape.AddDim(params.dim_size(i));
    }

    Tensor* out = nullptr;
    Tensor tmp;
    if (params.dtype() == DT_VARIANT) {
      tmp = Tensor(DT_VARIANT, result_shape);
      c->set_output(0, tmp);
      out = &tmp;
    } else {
      OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
    }

    if (N > 0) {
      Tensor tmp_indices;

      // Points to the original or updated (if batch_dims is set) indices.
      const Tensor* op_indices = &indices;
      if (batch_dims_ > 0) {
        OP_REQUIRES_OK(c, c->allocate_temp(indices.dtype(), indices.shape(),
                                           &tmp_indices));
        functor::DenseUpdate<Device, Index, ASSIGN> copy_functor;
        copy_functor(c->eigen_device<Device>(), tmp_indices.flat<Index>(),
                     indices.flat<Index>());

        AddBatchOffsets(&tmp_indices, params);
        op_indices = &tmp_indices;
      }

      int64 gather_dim_size = 1;
      for (int idx = 0; idx <= batch_dims_; ++idx) {
        gather_dim_size *= params.dim_size(idx);
      }
      int64 inner_size = 1;
      for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
        inner_size *= params.dim_size(i);
      }
      auto params_flat = params.shaped<T, 3>({1, gather_dim_size, inner_size});
      const auto indices_flat = op_indices->flat<Index>();
      auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});

      functor::GatherFunctor<Device, T, Index> functor;
      int64 bad_i = functor(c, params_flat, indices_flat, out_flat);

      OP_REQUIRES(
          c, bad_i < 0,
          errors::InvalidArgument(
              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
    }
  }
};

#define REGISTER_GATHER_FULL(dev, type, index_type)                    \
  REGISTER_KERNEL_BUILDER(Name("ResourceGather")                       \
                              .Device(DEVICE_##dev)                    \
                              .HostMemory("resource")                  \
                              .TypeConstraint<type>("dtype")           \
                              .TypeConstraint<index_type>("Tindices"), \
                          ResourceGatherOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ALL_INDICES(dev, type) \
  REGISTER_GATHER_FULL(dev, type, int32);      \
  REGISTER_GATHER_FULL(dev, type, int64)

#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)

// Registration of the CPU implementations.
TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA
#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU);

// Variant objects themselves sit on CPU, even if they contain data
// pointing to a device.
REGISTER_KERNEL_BUILDER(Name("ResourceGather")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceGatherOp<GPUDevice, Variant, int32>)
REGISTER_KERNEL_BUILDER(Name("ResourceGather")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int64>("Tindices"),
                        ResourceGatherOp<GPUDevice, Variant, int64>)

#endif  // GOOGLE_CUDA

#undef REGISTER_GATHER_CPU
#undef REGISTER_GATHER_GPU
#undef REGISTER_GATHER_ALL_INDICES
#undef REGISTER_GATHER_FULL

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ResourceScatterUpdateOp : public OpKernel {
 public:
  explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Var* v = nullptr;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    core::ScopedUnref unref_v(v);
    OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v));
    tf_shared_lock ml(*v->mu());
    Tensor* params = v->tensor();
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);

    // Check that we have enough index space
    const int64 N_big = indices.NumElements();
    OP_REQUIRES(
        c, N_big <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("indices has too many elements for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", N_big, " > ",
                                std::numeric_limits<Index>::max()));
    const Index N = static_cast<Index>(N_big);
    OP_REQUIRES(
        c, params->dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params->dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    if (N > 0) {
      auto indices_flat = indices.flat<Index>();
      auto params_flat = params->flat_outer_dims<T>();
      if (TensorShapeUtils::IsScalar(updates.shape())) {
        const auto update = updates.scalar<T>();

        functor::ScatterScalarFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, update, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params->dim_size(0), ")"));
      } else {
        int64 num_updates = updates.NumElements();
        OP_REQUIRES(c, num_updates % N == 0,
                    errors::InvalidArgument(
                        "shape of indices (", indices.shape().DebugString(),
                        ") is not compatible with the shape of updates (",
                        updates.shape().DebugString(), ")"));
        auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});

        functor::ScatterFunctor<Device, T, Index, op> functor;
        const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                    params_flat, updates_flat, indices_flat);
        OP_REQUIRES(c, bad_i < 0,
                    errors::InvalidArgument(
                        "indices", SliceDebugString(indices.shape(), bad_i),
                        " = ", indices_flat(bad_i), " is not in [0, ",
                        params->dim_size(0), ")"));
      }
    }
  }
};

#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
  REGISTER_KERNEL_BUILDER(                                             \
      Name(name)                                                       \
          .Device(DEVICE_##dev)                                        \
          .HostMemory("resource")                                      \
          .TypeConstraint<type>("dtype")                               \
          .TypeConstraint<index_type>("Tindices"),                     \
      ResourceScatterUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);

#define REGISTER_SCATTER_ARITHMETIC(type, dev)                \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd",    \
                          scatter_op::UpdateOp::ADD);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub",    \
                          scatter_op::UpdateOp::SUB);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul",    \
                          scatter_op::UpdateOp::MUL);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv",    \
                          scatter_op::UpdateOp::DIV);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
                          scatter_op::UpdateOp::ASSIGN);
#define REGISTER_SCATTER_MINMAX(type, dev)                 \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
                          scatter_op::UpdateOp::MIN);      \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
                          scatter_op::UpdateOp::MAX);

// Registers CPU kernels.
#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, CPU);
#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);

TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);

REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);
REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);

// Registers GPU kernels.
#if GOOGLE_CUDA
#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
  REGISTER_SCATTER_ARITHMETIC(type, GPU);
#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);

#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);

REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, Variant, int32,
                                                scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .TypeConstraint<bool>("dtype")
                            .TypeConstraint<int32>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, bool, int32,
                                                scatter_op::UpdateOp::ASSIGN>)
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("indices")
                            .TypeConstraint<Variant>("dtype")
                            .TypeConstraint<int64>("Tindices"),
                        ResourceScatterUpdateOp<GPUDevice, Variant, int64,
                                                scatter_op::UpdateOp::ASSIGN>)

#endif  // GOOGLE_CUDA

#undef REGISTER_SCATTER_ARITHMETIC
#undef REGISTER_SCATTER_ARITHMETIC_CPU
#undef REGISTER_SCATTER_MINMAX
#undef REGISTER_SCATTER_MINMAX_CPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX

}  // namespace tensorflow