/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Our general strategy for preventing conflicts between concurrent
// reads and writes of resource variables is to:
// * For read operations, we:
//   - acquire the variable's mutex (in "shared" mode);
//   - make a (shallow) copy of the Tensor object, which increments
//     the reference count on the variable's TensorBuffer;
//   - release the variable's mutex;
//   - use the copy of the Tensor object to do the read.
// * For write operations, we:
//   - acquire the variable's mutex (in "exclusive" mode);
//   - check the reference count of the variable's TensorBuffer and,
//     if it is >1, make a deep copy of the variable's Tensor;
//   - mutate the variable's Tensor;
//   - and release the variable's mutex.
// This allows several read operations to all use the same
// TensorBuffer without needing to copy. When it comes time to write,
// it will only make a copy if there is an outstanding read using the
// buffer. Write operations are serialized by the variable's mutex.
//
// For sparse operations (scatter, gather, sparse optimizer updates),
// we need to avoid copies, since there may not be enough memory for
// two copies of the whole tensor. To support this, we make two
// modifications to the above strategy:
// * For sparse reads (gather), we hold the variable's mutex (still in
//   "shared" mode) for the duration of the whole read. This means
//   that as long as you only do sparse read operations no write will
//   see the reference count >1.
// * For sparse write operations where the user explicitly specifies
//   that they want to perform the write without locks held
//   (use_locking=false), we never copy even if the variable's
//   reference count is >1.

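// The sketch below is illustrative only and is not compiled into this file;
// it restates the dense read/write protocol above in terms of the Var,
// Tensor, tf_shared_lock, and mutex_lock types used by the kernels here.
// The helpers ReadSnapshot/MutateInPlace are hypothetical names, not part of
// TensorFlow.
//
//   Tensor ReadSnapshot(Var* var) {
//     tf_shared_lock l(*var->mu());      // shared lock for readers
//     Tensor snapshot = *var->tensor();  // shallow copy; bumps the
//                                        // TensorBuffer refcount
//     return snapshot;  // safe to read after the lock is released
//   }
//
//   template <typename Device, typename T>
//   Status MutateInPlace(OpKernelContext* ctx, Var* var) {
//     mutex_lock l(*var->mu());  // exclusive lock for writers
//     // Deep-copies the buffer only if an outstanding reader still holds a
//     // reference to it (refcount > 1), per the strategy above.
//     TF_RETURN_IF_ERROR(
//         PrepareToUpdateVariable<Device, T>(ctx, var->tensor()));
//     // ... mutate *var->tensor() here ...
//     return Status::OK();
//   }
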
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

REGISTER_RESOURCE_HANDLE_KERNEL(Var);

class ReadVariableOp : public OpKernel {
 public:
  explicit ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    Var* variable = nullptr;
    ResourceHandle handle = HandleFromInput(ctx, 0);
    const auto status = LookupResource(ctx, handle, &variable);
    OP_REQUIRES(ctx, status.ok(),
                errors::FailedPrecondition(
                    "Error while reading resource variable ", handle.name(),
                    " from Container: ", handle.container(),
                    ". This could mean that the variable was uninitialized. ",
                    status.ToString()));

    core::ScopedUnref s(variable);
    // We're acquiring a reference to the underlying buffer while
    // holding a shared lock to guarantee ordering of reads and
    // writes.
    tf_shared_lock ml(*variable->mu());
    const Tensor& t = *variable->tensor();
    OP_REQUIRES(
        ctx, dtype_ == t.dtype(),
        errors::InvalidArgument(
            "Trying to read variable with wrong dtype. Expected ",
            DataTypeString(dtype_), " got ", DataTypeString(t.dtype())));
    ctx->set_output(0, t);
  }

 private:
  DataType dtype_;
};

REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
                        ReadVariableOp);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"),
    ReadVariableOp);

#define REGISTER_GPU_KERNELS(type)                             \
  namespace functor {                                          \
  template <>                                                  \
  void DenseUpdate<GPUDevice, type, ASSIGN>::operator()(       \
      const GPUDevice& d, typename TTypes<type>::Flat lhs,     \
      typename TTypes<type>::ConstFlat rhs);                   \
  extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
  }                                                            \
  REGISTER_KERNEL_BUILDER(Name("VarHandleOp")                  \
                              .Device(DEVICE_GPU)              \
                              .HostMemory("resource")          \
                              .TypeConstraint<type>("dtype"),  \
                          ResourceHandleOp<Var>)

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA

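// For reference, an editor's sketch (not generated code): instantiating the
// macro above for a single type, e.g. REGISTER_GPU_KERNELS(float), expands to
// roughly the following -- a declaration of the specialized ASSIGN functor
// plus a VarHandleOp registration constrained to that dtype:
//
//   namespace functor {
//   template <>
//   void DenseUpdate<GPUDevice, float, ASSIGN>::operator()(
//       const GPUDevice& d, typename TTypes<float>::Flat lhs,
//       typename TTypes<float>::ConstFlat rhs);
//   extern template struct DenseUpdate<GPUDevice, float, ASSIGN>;
//   }
//   REGISTER_KERNEL_BUILDER(Name("VarHandleOp")
//                               .Device(DEVICE_GPU)
//                               .HostMemory("resource")
//                               .TypeConstraint<float>("dtype"),
//                           ResourceHandleOp<Var>);
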
template <typename T>
class VariableShapeOp : public OpKernel {
 public:
  explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* ctx) override {
    Var* variable = nullptr;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, HandleFromInput(ctx, 0), &variable));
    core::ScopedUnref s(variable);
    variable->mu()->lock_shared();
    TensorShape shape = variable->tensor()->shape();
    variable->mu()->unlock_shared();
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output));
    for (int i = 0; i < shape.dims(); ++i) {
      output->flat<T>()(i) = shape.dim_size(i);
    }
  }
};

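// Illustrative example (not compiled): for a variable holding a tensor of
// shape [2, 3], VariableShape produces a rank-1 output of length
// shape.dims() == 2 whose elements are the dimension sizes, i.e. the loop
// above fills the output as if by
//
//   output->flat<int32>()(0) = 2;  // shape.dim_size(0)
//   output->flat<int32>()(1) = 3;  // shape.dim_size(1)
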
REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
    VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int64>("out_type"),
    VariableShapeOp<int64>);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int64>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int64>);

#endif  // GOOGLE_CUDA

class DestroyResourceOp : public OpKernel {
 public:
  explicit DestroyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
  }

  void Compute(OpKernelContext* ctx) override {
    const ResourceHandle& p = HandleFromInput(ctx, 0);
    Status status = DeleteResource(ctx, p);
    if (ignore_lookup_error_ && errors::IsNotFound(status)) {
      return;
    }
    OP_REQUIRES_OK(ctx, status);
  }

 private:
  bool ignore_lookup_error_;
};

REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp").Device(DEVICE_CPU),
                        DestroyResourceOp);
REGISTER_KERNEL_BUILDER(
    Name("DestroyResourceOp").Device(DEVICE_GPU).HostMemory("resource"),
    DestroyResourceOp);

template <typename Device, typename T>
class AssignVariableOp : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
  }

  void Compute(OpKernelContext* context) override {
    OP_REQUIRES(context, dtype_ == context->input(1).dtype(),
                errors::InvalidArgument(
                    "Variable and value dtypes don't match; respectively, ",
                    dtype_, " and ", context->input(1).dtype()));
    Var* variable = nullptr;
    OP_REQUIRES_OK(
        context,
        LookupOrCreateResource<Var>(
            context, HandleFromInput(context, 0), &variable,
            [this, context](Var** ptr) {
              *ptr = new Var(dtype_);
              PersistentTensor unused;
              Tensor* tmp;
              AllocatorAttributes attr;
              attr.set_gpu_compatible(true);
              attr.set_nic_compatible(true);
              TF_RETURN_IF_ERROR(context->allocate_persistent(
                  dtype_, context->input(1).shape(), &unused, &tmp, attr));
              *(*ptr)->tensor() = *tmp;
              return Status::OK();
            }));
    core::ScopedUnref s(variable);

    OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(dtype_)));

    const Tensor& value = context->input(1);
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);

    // Copying is unnecessary if we are the last user of the value
    // tensor; in that case we can just adopt the input tensor's buffer.
    std::unique_ptr<Tensor> input_alias =
        context->forward_input(1, dtype_, value.shape(), DEVICE_MEMORY, attr);
    mutex_lock ml(*variable->mu());
    if (input_alias) {
      *variable->tensor() = *input_alias;
      return;
    }

    // Need to copy, but maybe we can re-use variable's buffer?
    if (!variable->tensor()->RefCountIsOne() ||
        !variable->tensor()->shape().IsSameSize(value.shape())) {
      // Copy to new buffer
      PersistentTensor unused;
      Tensor* tmp;
      OP_REQUIRES_OK(context, context->allocate_persistent(
                                  dtype_, value.shape(), &unused, &tmp, attr));
      *variable->tensor() = *tmp;
    }
    functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
    copy_functor(context->eigen_device<Device>(), variable->tensor()->flat<T>(),
                 value.flat<T>());
  }

 private:
  DataType dtype_;
};

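// Editor's note -- a condensed sketch (not compiled) of the three paths the
// dense AssignVariableOp above can take; the names mirror the code:
//
//   if (input_alias) {
//     // 1. The kernel owns the only reference to the input buffer:
//     //    adopt it directly, with no copy and no new allocation.
//   } else if (variable->tensor()->RefCountIsOne() &&
//              variable->tensor()->shape().IsSameSize(value.shape())) {
//     // 2. No concurrent reader holds the variable's buffer and the shape
//     //    matches: overwrite the existing buffer in place via DenseUpdate.
//   } else {
//     // 3. Otherwise: allocate a fresh buffer and copy the value into it,
//     //    leaving any in-flight readers on the old buffer.
//   }
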
template <typename Device>
Status VariantCopyFn(OpKernelContext* context, const Tensor& from, Tensor* to);

#define CPU_DENSE_COPY(T)                                                \
  case DataTypeToEnum<T>::value: {                                       \
    functor::DenseUpdate<CPUDevice, T, ASSIGN> copy_functor_;            \
    copy_functor_(context->eigen_device<CPUDevice>(), tensor->flat<T>(), \
                  from.flat<T>());                                       \
    break;                                                               \
  }

#define INSTANTIATE_GET_VARIANT_COPY_FN(Device, TYPE_CALLER, TYPE_DENSE_COPY) \
  template <>                                                                 \
  Status VariantCopyFn<Device>(OpKernelContext * context, const Tensor& from, \
                               Tensor* to) {                                  \
    PersistentTensor tmp;                                                     \
    Tensor* tensor;                                                           \
    AllocatorAttributes attr;                                                 \
    attr.set_gpu_compatible(true);                                            \
    attr.set_nic_compatible(true);                                            \
    TF_RETURN_IF_ERROR(context->allocate_persistent(                          \
        from.dtype(), from.shape(), &tmp, &tensor, attr));                    \
    switch (from.dtype()) {                                                   \
      TYPE_CALLER(TYPE_DENSE_COPY);                                           \
      default:                                                                \
        return errors::InvalidArgument(                                       \
            "VariantCopyFn: Could not perform a deep copy of variant "        \
            "element of type: ",                                              \
            DataTypeString(from.dtype()),                                     \
            " using device: ", context->device()->name());                    \
    }                                                                         \
    *to = *tensor;                                                            \
    return Status::OK();                                                      \
  }

INSTANTIATE_GET_VARIANT_COPY_FN(CPUDevice, TF_CALL_ALL_TYPES, CPU_DENSE_COPY);

#if GOOGLE_CUDA
#define GPU_DENSE_COPY(T)                                                \
  case DataTypeToEnum<T>::value: {                                       \
    functor::DenseUpdate<GPUDevice, T, ASSIGN> copy_functor_;            \
    copy_functor_(context->eigen_device<GPUDevice>(), tensor->flat<T>(), \
                  from.flat<T>());                                       \
    break;                                                               \
  }
#define TF_CALL_GPU_AND_ADDITIONAL_TYPES(T) \
  TF_CALL_GPU_ALL_TYPES(T);                 \
  TF_CALL_int32(T);                         \
  TF_CALL_int64(T);
INSTANTIATE_GET_VARIANT_COPY_FN(GPUDevice, TF_CALL_GPU_AND_ADDITIONAL_TYPES,
                                GPU_DENSE_COPY);
#undef TF_CALL_GPU_AND_ADDITIONAL_TYPES
#undef GPU_DENSE_COPY
#endif  // GOOGLE_CUDA

#undef CPU_DENSE_COPY
#undef INSTANTIATE_GET_VARIANT_COPY_FN

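// For reference, an editor's sketch (not generated code): each type callback
// passed to INSTANTIATE_GET_VARIANT_COPY_FN above (now #undef'd) contributes
// one case to the switch inside the instantiated VariantCopyFn. For example,
// CPU_DENSE_COPY(float) expands to roughly:
//
//   case DataTypeToEnum<float>::value: {
//     functor::DenseUpdate<CPUDevice, float, ASSIGN> copy_functor_;
//     copy_functor_(context->eigen_device<CPUDevice>(),
//                   tensor->flat<float>(), from.flat<float>());
//     break;
//   }
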
template <typename Device>
class AssignVariableOp<Device, Variant> : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
    OP_REQUIRES(c, dtype_ == DT_VARIANT,
                errors::Internal("Variant kernel called with dtype: ",
                                 DataTypeString(dtype_)));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& value = context->input(1);
    Var* variable = nullptr;
    OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
                                context, HandleFromInput(context, 0), &variable,
                                [this, context](Var** ptr) {
                                  // Created on host.
                                  *ptr = new Var(DT_VARIANT);
                                  return Status::OK();
                                }));
    core::ScopedUnref s(variable);
    OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(DT_VARIANT)));

    mutex_lock ml(*variable->mu());

    *variable->tensor() = Tensor(DT_VARIANT, value.shape());
    const auto elements_in = value.flat<Variant>();
    auto elements_out = variable->tensor()->flat<Variant>();
    auto copy_fn = std::bind(&VariantCopyFn<Device>, context,
                             std::placeholders::_1, std::placeholders::_2);
    for (int64 i = 0; i < elements_in.size(); ++i) {
      OP_REQUIRES_OK(context, VariantDeviceCopy(
                                  VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
                                  elements_in(i), &elements_out(i), copy_fn));
    }
  }

 private:
  DataType dtype_;
};

#define REGISTER_KERNELS(type)                                \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")            \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("dtype"), \
                          AssignVariableOp<Eigen::ThreadPoolDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type)                           \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")           \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<type>("dtype") \
                              .HostMemory("resource"),       \
                          AssignVariableOp<GPUDevice, type>);

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA

template <typename Device, typename T, DenseUpdateType Op>
class AssignUpdateVariableOp : public OpKernel {
 public:
  explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    Var* variable = nullptr;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &variable));
    core::ScopedUnref s(variable);

    const Tensor& value = context->input(1);
    // TODO(apassos): We could possibly avoid the copy done by
    // PrepareToUpdateVariable() for commutative operations like Op ==
    // ADD if value's refcount was 1.
    mutex_lock ml(*variable->mu());
    Tensor* var_tensor = variable->tensor();
    OP_REQUIRES_OK(context,
                   PrepareToUpdateVariable<Device, T>(context, var_tensor));
    functor::DenseUpdate<Device, T, Op> update_functor;
    update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
                   value.flat<T>());
  }
};

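// Editor's sketch (not compiled): with Op == ADD the functor above performs
// an element-wise accumulate under the exclusive lock, i.e. conceptually
//
//   var_tensor->flat<T>() += value.flat<T>();  // DenseUpdate<Device, T, ADD>
//
// and with Op == SUB the corresponding in-place subtraction. The preceding
// PrepareToUpdateVariable<Device, T>() call is there to deep-copy the buffer
// when a concurrent reader still holds a reference to it (see the strategy
// comment at the top of this file).
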
#define REGISTER_KERNELS(type)                                     \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AssignAddVariableOp")                                  \
          .Device(DEVICE_CPU)                                      \
          .TypeConstraint<type>("dtype"),                          \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AssignSubVariableOp")                                  \
          .Device(DEVICE_CPU)                                      \
          .TypeConstraint<type>("dtype"),                          \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, SUB>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(Name("AssignSubVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, SUB>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp").Device(DEVICE_CPU),
                        IsResourceInitialized<Var>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("resource")
                            .HostMemory("is_initialized"),
                        IsResourceInitialized<Var>);
#endif  // GOOGLE_CUDA

template <typename Device, typename T, typename Index>
class ResourceGatherOp : public OpKernel {
 public:
  explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Var* v = nullptr;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    core::ScopedUnref unref_v(v);
    // NOTE: We hold the lock for the whole gather operation instead
    // of increasing the reference count of v->tensor() to avoid a
    // situation where a write to the same variable will see a
    // reference count greater than one and make a copy of the
    // (potentially very large) tensor buffer.
    tf_shared_lock ml(*v->mu());
    const Tensor& params = *v->tensor();
    const Tensor& indices = c->input(1);
    OP_REQUIRES(
        c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
        errors::InvalidArgument("params must be at least 1 dimensional"));

    // Check that we have enough index space
    const int64 N = indices.NumElements();
    OP_REQUIRES(
        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    // The result shape is indices.shape + params.shape[1:].
    TensorShape result_shape = indices.shape();
    for (int i = 1; i < params.dims(); i++) {
      result_shape.AddDim(params.dim_size(i));
    }

    Tensor* out = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
    if (N > 0) {
      const int64 gather_dim_size = params.dim_size(0);
      int64 inner_size = 1;
      for (int i = 1; i < params.dims(); i++) {
        inner_size *= params.dim_size(i);
      }
      auto params_flat = params.shaped<T, 3>({1, gather_dim_size, inner_size});
      auto indices_flat = indices.flat<Index>();
      auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});

      functor::GatherFunctor<Device, T, Index> functor;
      int64 bad_i = functor(c, params_flat, indices_flat, out_flat);

      OP_REQUIRES(
          c, bad_i < 0,
          errors::InvalidArgument(
              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
    }
  }
};

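// Illustrative example (not compiled): for a params variable of shape [5, 3]
// and indices of shape [2, 2], the result shape computed above is
// indices.shape + params.shape[1:] == [2, 2, 3]. Internally the functor sees
// the flattened views
//
//   params_flat  : [1, 5, 3]   // {1, gather_dim_size, inner_size}
//   indices_flat : [4]         // N = indices.NumElements()
//   out_flat     : [1, 4, 3]   // {1, N, out->NumElements() / N}
//
// and any index outside [0, 5) is reported through bad_i.
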
#define REGISTER_GATHER_FULL(dev, type, index_type)                    \
  REGISTER_KERNEL_BUILDER(Name("ResourceGather")                       \
                              .Device(DEVICE_##dev)                    \
                              .HostMemory("resource")                  \
                              .TypeConstraint<type>("dtype")           \
                              .TypeConstraint<index_type>("Tindices"), \
                          ResourceGatherOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ALL_INDICES(dev, type) \
  REGISTER_GATHER_FULL(dev, type, int32);      \
  REGISTER_GATHER_FULL(dev, type, int64)

#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)

// Registration of the CPU implementations.
TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA
#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_GATHER_GPU);

#endif  // GOOGLE_CUDA

#undef REGISTER_GATHER_CPU
#undef REGISTER_GATHER_GPU
#undef REGISTER_GATHER_ALL_INDICES
#undef REGISTER_GATHER_FULL

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
class ResourceScatterUpdateOp : public OpKernel {
 public:
  explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Var* v = nullptr;
    OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
    core::ScopedUnref unref_v(v);
    mutex_lock ml(*v->mu());
    Tensor* params = v->tensor();
    OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, params));
    const Tensor& indices = c->input(1);
    const Tensor& updates = c->input(2);

    // Check that we have enough index space
    const int64 N_big = indices.NumElements();
    OP_REQUIRES(
        c, N_big <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("indices has too many elements for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", N_big, " > ",
                                std::numeric_limits<Index>::max()));
    const Index N = static_cast<Index>(indices.NumElements());
    OP_REQUIRES(
        c, params->dim_size(0) <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.shape[0] too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params->dim_size(0), " > ",
                                std::numeric_limits<Index>::max()));

    if (N > 0) {
      auto indices_flat = indices.flat<Index>();
      auto params_flat = params->flat_outer_dims<T>();
      auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});

      functor::ScatterFunctor<Device, T, Index, op> functor;
      const Index bad_i = functor(c, c->template eigen_device<Device>(),
                                  params_flat, updates_flat, indices_flat);
      OP_REQUIRES(c, bad_i < 0,
                  errors::InvalidArgument(
                      "indices", SliceDebugString(indices.shape(), bad_i),
                      " = ", indices_flat(bad_i), " is not in [0, ",
                      params->dim_size(0), ")"));
    }
  }
};

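// Illustrative example (not compiled): for a params variable of shape [4, 2],
// indices = [3, 1], and updates of shape [2, 2], ResourceScatterAdd performs,
// conceptually and under the exclusive lock,
//
//   params_flat(3, :) += updates_flat(0, :);
//   params_flat(1, :) += updates_flat(1, :);
//
// while ResourceScatterUpdate assigns the rows instead of accumulating.
// Out-of-range indices (here anything outside [0, 4)) are reported via bad_i.
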
#define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
  REGISTER_KERNEL_BUILDER(                                             \
      Name(name)                                                       \
          .Device(DEVICE_##dev)                                        \
          .HostMemory("resource")                                      \
          .TypeConstraint<type>("dtype")                               \
          .TypeConstraint<index_type>("Tindices"),                     \
      ResourceScatterUpdateOp<dev##Device, type, index_type, op>)

#define REGISTER_SCATTER_KERNEL(type, dev, name, op)         \
  REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
  REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);

// TODO(apassos) add the other types here.
#define REGISTER_SCATTER_ARITHEMTIC(type, dev)                \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd",    \
                          scatter_op::UpdateOp::ADD);         \
  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
                          scatter_op::UpdateOp::ASSIGN);

// Registers CPU kernels.
#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \
  REGISTER_SCATTER_ARITHEMTIC(type, CPU);

TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU);

REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
                        scatter_op::UpdateOp::ASSIGN);

// Registers GPU kernels.
#if GOOGLE_CUDA
#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \
  REGISTER_SCATTER_ARITHEMTIC(type, GPU);

#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU);

#endif  // GOOGLE_CUDA

#undef REGISTER_SCATTER_ARITHEMTIC
#undef REGISTER_SCATTER_ARITHEMTIC_CPU
#undef REGISTER_SCATTER_KERNEL
#undef REGISTER_SCATTER_KERNEL_INDEX

}  // namespace tensorflow