Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include <algorithm>
     19 
     20 #include "tensorflow/core/framework/op_kernel.h"
     21 #include "tensorflow/core/framework/register_types.h"
     22 #include "tensorflow/core/kernels/bounds_check.h"
     23 #include "tensorflow/core/kernels/training_op_helpers.h"
     24 #include "tensorflow/core/kernels/training_ops.h"
     25 #include "tensorflow/core/kernels/variable_ops.h"
     26 
     27 #ifdef TENSORFLOW_USE_SYCL
     28 #include "tensorflow/core/common_runtime/sycl/sycl_util.h"
     29 #endif  // TENSORFLOW_USE_SYCL
     30 
     31 namespace tensorflow {
     32 
     33 using CPUDevice = Eigen::ThreadPoolDevice;
     34 using GPUDevice = Eigen::GpuDevice;
     35 using SYCLDevice = Eigen::SyclDevice;
     36 
     37 namespace {
     38 template <class T>
     39 inline T sgn(const T x) {
     40   T zero(0);
     41   T one(1);
     42   return (x == zero ? zero : (x < zero ? -one : one));
     43 }
     44 }  // namespace
     45 
     46 namespace functor {
     47 template <typename T>
     48 struct ApplyGradientDescent<CPUDevice, T> {
     49   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
     50                   typename TTypes<T>::ConstScalar lr,
     51                   typename TTypes<T>::ConstFlat grad) {
     52     var.device(d) -= grad * lr();
     53   }
     54 };
     55 
     56 #ifdef TENSORFLOW_USE_SYCL
     57 template <typename T>
     58 struct ApplyGradientDescentSYCL {
     59   void operator()(const SYCLDevice& d, typename TTypes<T>::Flat var, T lr,
     60                   typename TTypes<T>::ConstFlat grad) {
     61     var.device(d) -= grad * lr;
     62   }
     63 };
     64 #endif
     65 
     66 template <typename T>
     67 struct ApplyAdadelta<CPUDevice, T> {
     68   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
     69                   typename TTypes<T>::Flat accum,
     70                   typename TTypes<T>::Flat accum_update,
     71                   typename TTypes<T>::ConstScalar lr,
     72                   typename TTypes<T>::ConstScalar rho,
     73                   typename TTypes<T>::ConstScalar epsilon,
     74                   typename TTypes<T>::ConstFlat grad) {
     75     accum.device(d) =
     76         accum * rho() + grad.square() * (static_cast<T>(1) - rho());
     77     const auto update =
     78         (accum_update + epsilon()).sqrt() * (accum + epsilon()).rsqrt() * grad;
     79     var.device(d) -= update * lr();
     80     accum_update.device(d) =
     81         accum_update * rho() + update.square() * (static_cast<T>(1) - rho());
     82   }
     83 };
     84 
     85 template <typename T>
     86 struct ApplyProximalGradientDescent<CPUDevice, T> {
     87   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
     88                   typename TTypes<T>::ConstScalar lr,
     89                   typename TTypes<T>::ConstScalar l1,
     90                   typename TTypes<T>::ConstScalar l2,
     91                   typename TTypes<T>::ConstFlat grad) {
     92     // Note that here is Fobos update, for details please refer:
     93     // http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf
     94     // TODO(xbing): merge the logic for ProximalGradientDescent and
     95     // ProximalAdagrad.
     96     auto prox_var = var;
     97     // compute v = w - lr * grad.
     98     prox_var.device(d) -= grad * lr();
     99     if (l1() > 0) {
    100       // compute sign(v) * max(|v| - lr * l1, 0)
    101       var.device(d) =
    102           prox_var.sign() *
    103           (prox_var.abs() - var.constant(lr() * l1())).cwiseMax(T(0.0)) /
    104           (var.constant(1.0) + var.constant(l2() * lr()));
    105     } else {
    106       var.device(d) =
    107           prox_var / (var.constant(1.0) + var.constant(l2() * lr()));
    108     }
    109   }
    110 };
    111 
    112 template <typename T>
    113 struct ApplyAdagradDA<CPUDevice, T> {
    114   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
    115                   typename TTypes<T>::Flat gradient_accum,
    116                   typename TTypes<T>::Flat gradient_squared_accum,
    117                   typename TTypes<T>::ConstScalar lr, int64 global_step,
    118                   typename TTypes<T>::ConstScalar l1,
    119                   typename TTypes<T>::ConstScalar l2,
    120                   typename TTypes<T>::ConstFlat grad) {
    121     // Accumulate gradient, and gradient_squared
    122     gradient_accum.device(d) += grad;
    123     gradient_squared_accum.device(d) += grad.square();
    124 
    125     // AdagradDA update:
    126     // Let g to be gradient accumulator, gg to be gradient squared accumulator,
    127     // T be the global step, lr is the learning rate, and k the initial
    128     // gradient squared accumulator value.
    129     // w = \dfrac{sign(-g)*lr*|g - l1*T|_{+}}{l2*T*lr + \sqrt{k+gg})}
    130     if (l1() > 0) {
    131       var.device(d) =
    132           lr() * var.constant(-1.0) * gradient_accum.sign() *
    133           (gradient_accum.abs() -
    134            var.constant(static_cast<float>(global_step)) * var.constant(l1()))
    135               .cwiseMax(T(0.0)) /
    136           (var.constant(l2()) *
    137                var.constant(static_cast<float>(global_step) * lr()) +
    138            gradient_squared_accum.sqrt());
    139     } else {
    140       var.device(d) =
    141           lr() * gradient_accum * var.constant(-1.0) /
    142           (var.constant(l2()) *
    143                var.constant(static_cast<float>(global_step) * lr()) +
    144            gradient_squared_accum.sqrt());
    145     }
    146   }
    147 };
    148 
    149 template <typename T>
    150 struct ApplyAdagrad<CPUDevice, T> {
    151   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
    152                   typename TTypes<T>::Flat accum,
    153                   typename TTypes<T>::ConstScalar lr,
    154                   typename TTypes<T>::ConstFlat grad) {
    155     accum.device(d) += grad.square();
    156     var.device(d) -= grad * lr() * accum.rsqrt();
    157   }
    158 };
    159 
    160 template <typename T>
    161 struct ApplyProximalAdagrad<CPUDevice, T> {
    162   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
    163                   typename TTypes<T>::Flat accum,
    164                   typename TTypes<T>::ConstScalar lr,
    165                   typename TTypes<T>::ConstScalar l1,
    166                   typename TTypes<T>::ConstScalar l2,
    167                   typename TTypes<T>::ConstFlat grad) {
    168     // Fobos update per paper with Adagrad learning rate.
    169     accum.device(d) += grad.square();
    170     // Adagrad learning rate.
    171     auto learning_rate = accum.constant(lr()) * accum.rsqrt();
    172     auto prox_var = var;
    173     // compute v = w - lr * grad.
    174     prox_var.device(d) -= grad * learning_rate;
    175     if (l1() > 0) {
    176       // compute sign(v) * max(|v| - lr * l1, 0)
    177       var.device(d) = prox_var.sign() *
    178                       (prox_var.abs() - learning_rate * prox_var.constant(l1()))
    179                           .cwiseMax(T(0.0)) /
    180                       (var.constant(1.0) + var.constant(l2()) * learning_rate);
    181     } else {
    182       var.device(d) =
    183           prox_var / (var.constant(1.0) + var.constant(l2()) * learning_rate);
    184     }
    185   }
    186 };
    187 
    188 template <typename T>
    189 struct ApplyFtrlV2<CPUDevice, T> {
    190   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
    191                   typename TTypes<T>::Flat accum,
    192                   typename TTypes<T>::Flat linear,
    193                   typename TTypes<T>::ConstFlat grad,
    194                   typename TTypes<T>::ConstScalar lr,
    195                   typename TTypes<T>::ConstScalar l1,
    196                   typename TTypes<T>::ConstScalar l2,
    197                   typename TTypes<T>::ConstScalar l2_shrinkage,
    198                   typename TTypes<T>::ConstScalar lr_power) {
    199     auto grad_with_shrinkage = grad + static_cast<T>(2) * l2_shrinkage() * var;
    200     auto new_accum = accum + grad_with_shrinkage.square();
    201     // special case for which lr_power=-0.5.
    202     if (lr_power() == static_cast<T>(-0.5)) {
    203       linear.device(d) +=
    204           grad_with_shrinkage - (new_accum.sqrt() - accum.sqrt()) / lr() * var;
    205     } else {
    206       linear.device(d) +=
    207           grad_with_shrinkage -
    208           (new_accum.pow(-lr_power()) - accum.pow(-lr_power())) / lr() * var;
    209     }
    210     auto x = (linear.constant(l1()) * linear.sign() - linear);
    211     if (lr_power() == static_cast<T>(-0.5)) {
    212       auto y = new_accum.sqrt() / new_accum.constant(lr()) +
    213                linear.constant(static_cast<T>(2) * l2());
    214       auto pre_shrink = x / y;
    215       var.device(d) = (linear.abs() > linear.constant(l1()))
    216                           .select(pre_shrink, var.constant(static_cast<T>(0)));
    217 
    218     } else {
    219       auto y = new_accum.pow(-lr_power()) / new_accum.constant(lr()) +
    220                linear.constant(static_cast<T>(2) * l2());
    221       auto pre_shrink = x / y;
    222       var.device(d) = (linear.abs() > linear.constant(l1()))
    223                           .select(pre_shrink, var.constant(static_cast<T>(0)));
    224     }
    225     accum.device(d) += grad_with_shrinkage.square();
    226   }
    227 };
    228 
    229 template <typename T>
    230 struct ApplyFtrl<CPUDevice, T> {
    231   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
    232                   typename TTypes<T>::Flat accum,
    233                   typename TTypes<T>::Flat linear,
    234                   typename TTypes<T>::ConstFlat grad,
    235                   typename TTypes<T>::ConstScalar lr,
    236                   typename TTypes<T>::ConstScalar l1,
    237                   typename TTypes<T>::ConstScalar l2,
    238                   typename TTypes<T>::ConstScalar lr_power) {
    239     auto new_accum = accum + grad.square();
    240     // special case for which lr_power=-0.5.
    241     if (lr_power() == static_cast<T>(-0.5)) {
    242       linear.device(d) += grad - (new_accum.sqrt() - accum.sqrt()) / lr() * var;
    243     } else {
    244       linear.device(d) +=
    245           grad -
    246           (new_accum.pow(-lr_power()) - accum.pow(-lr_power())) / lr() * var;
    247     }
    248     auto x = (linear.constant(l1()) * linear.sign() - linear);
    249     if (lr_power() == static_cast<T>(-0.5)) {
    250       auto y = new_accum.sqrt() / new_accum.constant(lr()) +
    251                linear.constant(static_cast<T>(2) * l2());
    252       auto pre_shrink = x / y;
    253       var.device(d) = (linear.abs() > linear.constant(l1()))
    254                           .select(pre_shrink, var.constant(static_cast<T>(0)));
    255 
    256     } else {
    257       auto y = new_accum.pow(-lr_power()) / new_accum.constant(lr()) +
    258                linear.constant(static_cast<T>(2) * l2());
    259       auto pre_shrink = x / y;
    260       var.device(d) = (linear.abs() > linear.constant(l1()))
    261                           .select(pre_shrink, var.constant(static_cast<T>(0)));
    262     }
    263     accum.device(d) += grad.square();
    264   }
    265 };
    266 
    267 template <typename T>
    268 struct ApplyMomentum<CPUDevice, T> {
    269   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
    270                   typename TTypes<T>::Flat accum,
    271                   typename TTypes<T>::ConstScalar lr,
    272                   typename TTypes<T>::ConstFlat grad,
    273                   typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
    274     accum.device(d) = accum * momentum() + grad;
    275     if (use_nesterov) {
    276       var.device(d) -= grad * lr() + accum * momentum() * lr();
    277     } else {
    278       var.device(d) -= accum * lr();
    279     }
    280   }
    281 };
    282 
    283 template <typename Device, typename T>
    284 struct ApplyAdamNonCuda {
    285   void operator()(const Device& d, typename TTypes<T>::Flat var,
    286                   typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
    287                   typename TTypes<T>::ConstScalar beta1_power,
    288                   typename TTypes<T>::ConstScalar beta2_power,
    289                   typename TTypes<T>::ConstScalar lr,
    290                   typename TTypes<T>::ConstScalar beta1,
    291                   typename TTypes<T>::ConstScalar beta2,
    292                   typename TTypes<T>::ConstScalar epsilon,
    293                   typename TTypes<T>::ConstFlat grad, bool use_nesterov) {
    294     const T alpha = lr() * Eigen::numext::sqrt(T(1) - beta2_power()) /
    295                     (T(1) - beta1_power());
    296     // beta1 == 
    297     // beta2 == 
    298     // v     == n
    299     // var   == 
    300 
    301     m.device(d) += (grad - m) * (T(1) - beta1());
    302     v.device(d) += (grad.square() - v) * (T(1) - beta2());
    303     if (use_nesterov) {
    304       var.device(d) -= ((grad * (T(1) - beta1()) + beta1() * m) * alpha) /
    305                        (v.sqrt() + epsilon());
    306     } else {
    307       var.device(d) -= (m * alpha) / (v.sqrt() + epsilon());
    308     }
    309   }
    310 };
    311 
    312 #ifdef TENSORFLOW_USE_SYCL
    313 template <typename T>
    314 struct ApplyAdamSYCL {
    315   void operator()(const SYCLDevice& d, typename TTypes<T>::Flat var,
    316                   typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
    317                   T beta1_power, T beta2_power, T lr, T beta1, T beta2,
    318                   T epsilon, typename TTypes<T>::ConstFlat grad) {
    319     const T alpha =
    320         lr * Eigen::numext::sqrt(T(1) - beta2_power) / (T(1) - beta1_power);
    321     m.device(d) += (grad - m) * (T(1) - beta1);
    322     v.device(d) += (grad.square() - v) * (T(1) - beta2);
    323     var.device(d) -= (m * alpha) / (v.sqrt() + epsilon);
    324   }
    325 };
    326 #endif  // TENSORFLOW_USE_SYCL
    327 
    328 template <typename T>
    329 struct ApplyAdam<CPUDevice, T> : ApplyAdamNonCuda<CPUDevice, T> {};
    330 
    331 template <typename T>
    332 struct ApplyRMSProp<CPUDevice, T> {
    333   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
    334                   typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
    335                   typename TTypes<T>::ConstScalar lr,
    336                   typename TTypes<T>::ConstScalar rho,
    337                   typename TTypes<T>::ConstScalar momentum,
    338                   typename TTypes<T>::ConstScalar epsilon,
    339                   typename TTypes<T>::ConstFlat grad) {
    340     ms.device(d) += (grad.square() - ms) * (static_cast<T>(1) - rho());
    341     mom.device(d) =
    342         mom * momentum() + (grad * lr()) / ((ms + epsilon()).sqrt());
    343     var.device(d) -= mom;
    344   }
    345 };
    346 
    347 template <typename T>
    348 struct ApplyCenteredRMSProp<CPUDevice, T> {
    349   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
    350                   typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
    351                   typename TTypes<T>::Flat mom,
    352                   typename TTypes<T>::ConstScalar lr,
    353                   typename TTypes<T>::ConstScalar rho,
    354                   typename TTypes<T>::ConstScalar momentum,
    355                   typename TTypes<T>::ConstScalar epsilon,
    356                   typename TTypes<T>::ConstFlat grad) {
    357     ms.device(d) += (grad.square() - ms) * (static_cast<T>(1) - rho());
    358     mg.device(d) += (grad - mg) * (static_cast<T>(1) - rho());
    359     auto denom = (ms - mg.square()) + epsilon();
    360     mom.device(d) = mom * momentum() + (grad * lr()) / denom.sqrt();
    361     var.device(d) -= mom;
    362   }
    363 };
    364 
    365 template <typename T>
    366 struct ApplyAddSign<CPUDevice, T> {
    367   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
    368                   typename TTypes<T>::Flat m,
    369                   typename TTypes<T>::ConstScalar lr,
    370                   typename TTypes<T>::ConstScalar alpha,
    371                   typename TTypes<T>::ConstScalar sign_decay,
    372                   typename TTypes<T>::ConstScalar beta,
    373                   typename TTypes<T>::ConstFlat grad) {
    374     m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
    375     auto sign_gm = grad.sign() * m.sign();
    376     var.device(d) -= lr() * (alpha() + sign_decay() * sign_gm) * grad;
    377   }
    378 };
    379 
    380 template <typename T>
    381 struct ApplyPowerSign<CPUDevice, T> {
    382   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
    383                   typename TTypes<T>::Flat m,
    384                   typename TTypes<T>::ConstScalar lr,
    385                   typename TTypes<T>::ConstScalar logbase,
    386                   typename TTypes<T>::ConstScalar sign_decay,
    387                   typename TTypes<T>::ConstScalar beta,
    388                   typename TTypes<T>::ConstFlat grad) {
    389     m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
    390     auto sign_gm = grad.sign() * m.sign();
    391     auto grad_scale = (logbase() * sign_decay() * sign_gm).exp();
    392     var.device(d) -= lr() * grad_scale * grad;
    393   }
    394 };
    395 
    396 }  // namespace functor
    397 
    398 template <typename Device, typename T>
    399 class ApplyGradientDescentOp : public OpKernel {
    400  public:
    401   explicit ApplyGradientDescentOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    402     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
    403   }
    404 
    405   void Compute(OpKernelContext* ctx) override {
    406     auto locks =
    407         MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
    408     Tensor var;
    409     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
    410                             ctx, 0, use_exclusive_lock_, false, &var));
    411 
    412     OP_REQUIRES(
    413         ctx, var.IsInitialized(),
    414         errors::FailedPrecondition(
    415             "Attempting to use uninitialized variables: ", requested_input(0)));
    416     const Tensor& alpha = ctx->input(1);
    417     OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()),
    418                 errors::InvalidArgument("alpha is not a scalar: ",
    419                                         alpha.shape().DebugString()));
    420     const Tensor& delta = ctx->input(2);
    421     OP_REQUIRES(
    422         ctx, var.shape().IsSameSize(delta.shape()),
    423         errors::InvalidArgument("var and delta do not have the same shape",
    424                                 var.shape().DebugString(), " ",
    425                                 delta.shape().DebugString()));
    426 
    427     const Device& device = ctx->template eigen_device<Device>();
    428     functor::ApplyGradientDescent<Device, T>()(
    429         device, var.flat<T>(), alpha.scalar<T>(), delta.flat<T>());
    430 
    431     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
    432   }
    433 
    434  private:
    435   bool use_exclusive_lock_;
    436 };
    437 
    438 #ifdef TENSORFLOW_USE_SYCL
    439 template <typename T>
    440 class ApplyGradientDescentOp<SYCLDevice, T> : public OpKernel {
    441  public:
    442   explicit ApplyGradientDescentOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    443     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
    444   }
    445 
    446   void Compute(OpKernelContext* ctx) override {
    447     auto locks =
    448         MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
    449     Tensor var;
    450     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<SYCLDevice, T>(
    451                             ctx, 0, use_exclusive_lock_, false, &var));
    452 
    453     OP_REQUIRES(
    454         ctx, var.IsInitialized(),
    455         errors::FailedPrecondition(
    456             "Attempting to use uninitialized variables: ", requested_input(0)));
    457     const Tensor& alpha_dev = ctx->input(1);
    458     OP_REQUIRES(ctx, IsLegacyScalar(alpha_dev.shape()),
    459                 errors::InvalidArgument("alpha is not a scalar: ",
    460                                         alpha_dev.shape().DebugString()));
    461     const Tensor& delta = ctx->input(2);
    462     OP_REQUIRES(
    463         ctx, var.shape().IsSameSize(delta.shape()),
    464         errors::InvalidArgument("var and delta do not have the same shape",
    465                                 var.shape().DebugString(), " ",
    466                                 delta.shape().DebugString()));
    467 
    468     auto device = ctx->eigen_sycl_device();
    469     auto size = sizeof(T);
    470     T alpha = T(0);
    471     auto src_ptr = GetBase(&alpha_dev);
    472     device.memcpyDeviceToHost(&alpha, static_cast<const T*>(src_ptr), size);
    473 
    474     functor::ApplyGradientDescentSYCL<T>()(device, var.flat<T>(), alpha,
    475                                            delta.flat<T>());
    476 
    477     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
    478   }
    479 
    480  private:
    481   bool use_exclusive_lock_;
    482 };
    483 #endif  // TENSORFLOW_USE_SYCL
    484 
    485 #define REGISTER_KERNELS(D, T)                                                \
    486   REGISTER_KERNEL_BUILDER(                                                    \
    487       Name("ApplyGradientDescent").Device(DEVICE_##D).TypeConstraint<T>("T"), \
    488       ApplyGradientDescentOp<D##Device, T>);                                  \
    489   REGISTER_KERNEL_BUILDER(Name("ResourceApplyGradientDescent")                \
    490                               .Device(DEVICE_##D)                             \
    491                               .HostMemory("var")                              \
    492                               .TypeConstraint<T>("T"),                        \
    493                           ApplyGradientDescentOp<D##Device, T>);
    494 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
    495 
    496 TF_CALL_half(REGISTER_CPU_KERNELS);
    497 TF_CALL_float(REGISTER_CPU_KERNELS);
    498 TF_CALL_double(REGISTER_CPU_KERNELS);
    499 
    500 #if GOOGLE_CUDA
    501 // Forward declarations of the functor specializations for GPU.
    502 namespace functor {
    503 #define DECLARE_GPU_SPEC(T)                             \
    504   template <>                                           \
    505   void ApplyGradientDescent<GPUDevice, T>::operator()(  \
    506       const GPUDevice& d, typename TTypes<T>::Flat var, \
    507       typename TTypes<T>::ConstScalar alpha,            \
    508       typename TTypes<T>::ConstFlat delta);             \
    509   extern template struct ApplyGradientDescent<GPUDevice, T>;
    510 DECLARE_GPU_SPEC(Eigen::half);
    511 DECLARE_GPU_SPEC(float);
    512 DECLARE_GPU_SPEC(double);
    513 #undef DECLARE_GPU_SPEC
    514 }  // namespace functor
    515 
    516 REGISTER_KERNELS(GPU, Eigen::half);
    517 REGISTER_KERNELS(GPU, float);
    518 REGISTER_KERNELS(GPU, double);
    519 #endif
    520 
    521 #ifdef TENSORFLOW_USE_SYCL
    522 #define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T);
    523 TF_CALL_float(REGISTER_SYCL_KERNELS);
    524 TF_CALL_double(REGISTER_SYCL_KERNELS);
    525 #undef REGISTER_SYCL_KERNELS
    526 #endif  // TENSORFLOW_USE_SYCL
    527 
    528 #undef REGISTER_CPU_KERNELS
    529 #undef REGISTER_KERNELS
    530 
    531 template <typename Device, typename T>
    532 class ApplyAdadeltaOp : public OpKernel {
    533  public:
    534   explicit ApplyAdadeltaOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    535     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
    536   }
    537 
    538   void Compute(OpKernelContext* ctx) override {
    539     mutex* mu = GetTrainingVariableMutex(ctx, 0);
    540     if (use_exclusive_lock_ && mu != nullptr) {
    541       mutex_lock l1(*mu);
    542       // Don't try to acquire a lock on the second ref as they share the same
    543       // mutex.
    544       //
    545       // mutex_lock l2(*ctx->input_ref_mutex(1));
    546       DoValidate(ctx);
    547       if (!ctx->status().ok()) return;
    548       DoCompute(ctx);
    549     } else {
    550       DoValidate(ctx);
    551       if (!ctx->status().ok()) return;
    552       DoCompute(ctx);
    553     }
    554     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
    555   }
    556 
    557  private:
    558   bool use_exclusive_lock_;
    559 
    560   void DoValidate(OpKernelContext* ctx) {
    561     Tensor var;
    562     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
    563                             ctx, 0, use_exclusive_lock_, false, &var));
    564     Tensor accum;
    565     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
    566                             ctx, 1, use_exclusive_lock_, false, &accum));
    567     Tensor accum_update;
    568     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
    569                             ctx, 2, use_exclusive_lock_, false, &accum_update));
    570 
    571     OP_REQUIRES(
    572         ctx, var.IsInitialized(),
    573         errors::FailedPrecondition(
    574             "Attempting to use uninitialized variables: ", requested_input(0)));
    575     OP_REQUIRES(
    576         ctx, accum.IsInitialized(),
    577         errors::FailedPrecondition(
    578             "Attempting to use uninitialized variables: ", requested_input(1)));
    579     OP_REQUIRES(
    580         ctx, accum_update.IsInitialized(),
    581         errors::FailedPrecondition(
    582             "Attempting to use uninitialized variables: ", requested_input(2)));
    583 
    584     const Tensor& lr = ctx->input(3);
    585     const Tensor& rho = ctx->input(4);
    586     const Tensor& epsilon = ctx->input(5);
    587     const Tensor& grad = ctx->input(6);
    588 
    589     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
    590                 errors::InvalidArgument("lr is not a scalar: ",
    591                                         lr.shape().DebugString()));
    592 
    593     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
    594                 errors::InvalidArgument("rho is not a scalar: ",
    595                                         rho.shape().DebugString()));
    596 
    597     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
    598                 errors::InvalidArgument("epsilon is not a scalar: ",
    599                                         epsilon.shape().DebugString()));
    600 
    601     OP_REQUIRES(
    602         ctx, var.shape().IsSameSize(accum.shape()),
    603         errors::InvalidArgument("var and accum do not have the same shape",
    604                                 var.shape().DebugString(), " ",
    605                                 accum.shape().DebugString()));
    606     OP_REQUIRES(
    607         ctx, var.shape().IsSameSize(grad.shape()),
    608         errors::InvalidArgument("var and grad do not have the same shape",
    609                                 var.shape().DebugString(), " ",
    610                                 grad.shape().DebugString()));
    611   }
    612 
    613   void DoCompute(OpKernelContext* ctx) {
    614     const Device& device = ctx->template eigen_device<Device>();
    615     Tensor var;
    616     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
    617                             ctx, 0, use_exclusive_lock_, false, &var));
    618     Tensor accum;
    619     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
    620                             ctx, 1, use_exclusive_lock_, false, &accum));
    621     Tensor accum_update;
    622     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
    623                             ctx, 2, use_exclusive_lock_, false, &accum_update));
    624 
    625     const Tensor& lr = ctx->input(3);
    626     const Tensor& rho = ctx->input(4);
    627     const Tensor& epsilon = ctx->input(5);
    628     const Tensor& grad = ctx->input(6);
    629 
    630     functor::ApplyAdadelta<Device, T>()(
    631         device, var.flat<T>(), accum.flat<T>(), accum_update.flat<T>(),
    632         lr.scalar<T>(), rho.scalar<T>(), epsilon.scalar<T>(), grad.flat<T>());
    633   }
    634 };
    635 
    636 #define REGISTER_KERNELS(D, T)                                         \
    637   REGISTER_KERNEL_BUILDER(                                             \
    638       Name("ApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"), \
    639       ApplyAdadeltaOp<D##Device, T>);                                  \
    640   REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdadelta")                \
    641                               .Device(DEVICE_##D)                      \
    642                               .HostMemory("var")                       \
    643                               .HostMemory("accum")                     \
    644                               .HostMemory("accum_update")              \
    645                               .TypeConstraint<T>("T"),                 \
    646                           ApplyAdadeltaOp<D##Device, T>);
    647 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
    648 
    649 TF_CALL_half(REGISTER_CPU_KERNELS);
    650 TF_CALL_float(REGISTER_CPU_KERNELS);
    651 TF_CALL_double(REGISTER_CPU_KERNELS);
    652 
    653 #if GOOGLE_CUDA
    654 // Forward declarations of the functor specializations for GPU.
    655 namespace functor {
    656 #define DECLARE_GPU_SPEC(T)                                                    \
    657   template <>                                                                  \
    658   void ApplyAdadelta<GPUDevice, T>::operator()(                                \
    659       const GPUDevice& d, typename TTypes<T>::Flat var,                        \
    660       typename TTypes<T>::Flat accum, typename TTypes<T>::Flat accum_update,   \
    661       typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar rho, \
    662       typename TTypes<T>::ConstScalar epsilon,                                 \
    663       typename TTypes<T>::ConstFlat grad);                                     \
    664   extern template struct ApplyAdadelta<GPUDevice, T>;
    665 DECLARE_GPU_SPEC(Eigen::half);
    666 DECLARE_GPU_SPEC(float);
    667 DECLARE_GPU_SPEC(double);
    668 #undef DECLARE_GPU_SPEC
    669 }  // namespace functor
    670 
    671 REGISTER_KERNELS(GPU, Eigen::half);
    672 REGISTER_KERNELS(GPU, float);
    673 REGISTER_KERNELS(GPU, double);
    674 #endif
    675 #undef REGISTER_CPU_KERNELS
    676 #undef REGISTER_KERNELS
    677 
    678 // Note, this op works on cpu only.
    679 template <typename T, typename Tindex>
    680 class SparseApplyAdadeltaOp : public OpKernel {
    681  public:
    682   explicit SparseApplyAdadeltaOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    683     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
    684   }
    685 
    686   void Compute(OpKernelContext* ctx) override {
    687     mutex* mu = GetTrainingVariableMutex(ctx, 0);
    688     // mu_accum is actually the same mutex as mu_var since currently we use a
    689     // global mutex.
    690     //
    691     // mutex* mu_accum = ctx->input_ref_mutex(1);
    692     if (use_exclusive_lock_ && mu != nullptr) {
    693       mutex_lock ml(*mu);
    694       DoCompute(ctx);
    695     } else {
    696       DoCompute(ctx);
    697     }
    698   }
    699 
    700   void DoCompute(OpKernelContext* ctx) {
    701     Tensor var;
    702     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
    703                             ctx, 0, use_exclusive_lock_, true, &var));
    704     Tensor accum_grad;
    705     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
    706                             ctx, 1, use_exclusive_lock_, true, &accum_grad));
    707     Tensor accum_update;
    708     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
    709                             ctx, 2, use_exclusive_lock_, true, &accum_update));
    710     OP_REQUIRES(
    711         ctx, var.IsInitialized(),
    712         errors::FailedPrecondition(
    713             "Attempting to use uninitialized variables: ", requested_input(0)));
    714     OP_REQUIRES(
    715         ctx, accum_grad.IsInitialized(),
    716         errors::FailedPrecondition(
    717             "Attempting to use uninitialized variables: ", requested_input(1)));
    718     OP_REQUIRES(
    719         ctx, accum_update.IsInitialized(),
    720         errors::FailedPrecondition(
    721             "Attempting to use uninitialized variables: ", requested_input(2)));
    722     OP_REQUIRES(
    723         ctx, var.shape().IsSameSize(accum_grad.shape()),
    724         errors::InvalidArgument("var and accum_grad do not have the same shape",
    725                                 var.shape().DebugString(), " ",
    726                                 accum_grad.shape().DebugString()));
    727     OP_REQUIRES(ctx, var.shape().IsSameSize(accum_update.shape()),
    728                 errors::InvalidArgument(
    729                     "var and accum_update do not have the same shape",
    730                     var.shape().DebugString(), " ",
    731                     accum_update.shape().DebugString()));
    732     OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
    733                 errors::InvalidArgument("var must be at least 1 dimensional"));
    734 
    735     const Tensor& lr = ctx->input(3);
    736     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
    737                 errors::InvalidArgument("lr is not a scalar: ",
    738                                         lr.shape().DebugString()));
    739     const Tensor& rho = ctx->input(4);
    740     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
    741                 errors::InvalidArgument("rho is not a scalar: ",
    742                                         rho.shape().DebugString()));
    743     const Tensor& epsilon = ctx->input(5);
    744     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
    745                 errors::InvalidArgument("epsilon is not a scalar: ",
    746                                         epsilon.shape().DebugString()));
    747     const Tensor& grad = ctx->input(6);
    748     const Tensor& indices = ctx->input(7);
    749     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
    750                 errors::InvalidArgument("indices must be one-dimensional"));
    751 
    752     for (int d = 1; d < var.dims(); d++) {
    753       OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
    754                   errors::InvalidArgument(strings::StrCat(
    755                       "var and grad must match in dimension ", d)));
    756     }
    757     const Tindex N = indices.dim_size(0);
    758     OP_REQUIRES(
    759         ctx, grad.dim_size(0) == N,
    760         errors::InvalidArgument(
    761             "grad must be the same size as indices in the first dimension."));
    762 
    763     if (N > 0) {
    764       const Tindex first_dim_size = var.dim_size(0);
    765       // Validate all the indices are in range
    766       auto indices_vec = indices.vec<Tindex>();
    767       for (Tindex i = 0; i < N; i++) {
    768         const Tindex index = indices_vec(i);
    769         OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
    770                     errors::InvalidArgument(
    771                         strings::StrCat("Index ", index, " at offset ", i,
    772                                         " in indices is out of range")));
    773       }
    774 
    775       auto var_flat = var.flat_outer_dims<T>();
    776       auto accum_grad_flat = accum_grad.flat_outer_dims<T>();
    777       auto accum_update_flat = accum_update.flat_outer_dims<T>();
    778       auto grad_flat = grad.flat_outer_dims<T>();
    779       const T lr_scalar = lr.scalar<T>()();
    780       const T rho_scalar = rho.scalar<T>()();
    781       const T epsilon_scalar = epsilon.scalar<T>()();
    782 
    783       for (Tindex i = 0; i < N; i++) {
    784         const Tindex index = indices_vec(i);
    785         auto accum_ = accum_grad_flat.template chip<0>(index);
    786         auto accum_update_ = accum_update_flat.template chip<0>(index);
    787         auto grad_ = grad_flat.template chip<0>(i);
    788 
    789         accum_ = accum_ * accum_.constant(rho_scalar) +
    790                  grad_.square() * grad_.constant(T(1) - rho_scalar);
    791         const auto update =
    792             (accum_update_ + accum_update_.constant(epsilon_scalar)).sqrt() *
    793             (accum_ + accum_.constant(epsilon_scalar)).rsqrt() * grad_;
    794         auto v = var_flat.template chip<0>(index);
    795         v -= update * update.constant(lr_scalar);
    796         accum_update_ =
    797             accum_update_ * accum_update_.constant(rho_scalar) +
    798             update.square() * update.constant(static_cast<T>(1) - rho_scalar);
    799       }
    800     }
    801 
    802     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
    803   }
    804 
    805  private:
    806   bool use_exclusive_lock_;
    807 };
    808 
    809 #define REGISTER_KERNELS(T, Tindices)                                \
    810   REGISTER_KERNEL_BUILDER(Name("SparseApplyAdadelta")                \
    811                               .Device(DEVICE_CPU)                    \
    812                               .TypeConstraint<T>("T")                \
    813                               .TypeConstraint<Tindices>("Tindices"), \
    814                           SparseApplyAdadeltaOp<T, Tindices>);       \
    815   REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdadelta")        \
    816                               .Device(DEVICE_CPU)                    \
    817                               .TypeConstraint<T>("T")                \
    818                               .TypeConstraint<Tindices>("Tindices"), \
    819                           SparseApplyAdadeltaOp<T, Tindices>);
    820 #define REGISTER_CPU_KERNELS(T) \
    821   REGISTER_KERNELS(T, int32);   \
    822   REGISTER_KERNELS(T, int64);
    823 
    824 TF_CALL_half(REGISTER_CPU_KERNELS);
    825 TF_CALL_float(REGISTER_CPU_KERNELS);
    826 TF_CALL_double(REGISTER_CPU_KERNELS);
    827 
    828 #undef REGISTER_CPU_KERNELS
    829 #undef REGISTER_KERNELS
    830 
    831 // Note, this op works on cpu only.
    832 template <typename Device, typename T>
    833 class ApplyProximalGradientDescentOp : public OpKernel {
    834  public:
    835   explicit ApplyProximalGradientDescentOp(OpKernelConstruction* ctx)
    836       : OpKernel(ctx) {
    837     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
    838   }
    839 
    840   void Compute(OpKernelContext* ctx) override {
    841     auto locks =
    842         MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
    843     Tensor var;
    844     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
    845                             ctx, 0, use_exclusive_lock_, false, &var));
    846 
    847     OP_REQUIRES(
    848         ctx, var.IsInitialized(),
    849         errors::FailedPrecondition(
    850             "Attempting to use uninitialized variables: ", requested_input(0)));
    851     const Tensor& alpha = ctx->input(1);
    852     OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()),
    853                 errors::InvalidArgument("alpha is not a scalar: ",
    854                                         alpha.shape().DebugString()));
    855     const Tensor& l1 = ctx->input(2);
    856     OP_REQUIRES(
    857         ctx, TensorShapeUtils::IsScalar(l1.shape()),
    858         errors::InvalidArgument("l1 regularization strength is not a scalar: ",
    859                                 l1.shape().DebugString()));
    860     const Tensor& l2 = ctx->input(3);
    861     OP_REQUIRES(
    862         ctx, TensorShapeUtils::IsScalar(l2.shape()),
    863         errors::InvalidArgument("l2 regularization strength is not a scalar: ",
    864                                 l2.shape().DebugString()));
    865 
    866     const Tensor& delta = ctx->input(4);
    867     OP_REQUIRES(
    868         ctx, var.shape().IsSameSize(delta.shape()),
    869         errors::InvalidArgument("var and delta do not have the same shape",
    870                                 var.shape().DebugString(), " ",
    871                                 delta.shape().DebugString()));
    872 
    873     const Device& device = ctx->template eigen_device<Device>();
    874     functor::ApplyProximalGradientDescent<Device, T>()(
    875         device, var.flat<T>(), alpha.scalar<T>(), l1.scalar<T>(),
    876         l2.scalar<T>(), delta.flat<T>());
    877 
    878     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
    879   }
    880 
    881  private:
    882   bool use_exclusive_lock_;
    883 };
    884 
    885 #define REGISTER_KERNELS(D, T)                                           \
    886   REGISTER_KERNEL_BUILDER(Name("ApplyProximalGradientDescent")           \
    887                               .Device(DEVICE_##D)                        \
    888                               .TypeConstraint<T>("T"),                   \
    889                           ApplyProximalGradientDescentOp<D##Device, T>); \
    890   REGISTER_KERNEL_BUILDER(Name("ResourceApplyProximalGradientDescent")   \
    891                               .HostMemory("var")                         \
    892                               .Device(DEVICE_##D)                        \
    893                               .TypeConstraint<T>("T"),                   \
    894                           ApplyProximalGradientDescentOp<D##Device, T>);
    895 
    896 REGISTER_KERNELS(CPU, float);
    897 REGISTER_KERNELS(CPU, double);
    898 #undef REGISTER_KERNELS
    899 
    900 // Note, this op works on cpu only.
    901 template <typename T, typename Tindex>
    902 class SparseApplyProximalGradientDescentOp : public OpKernel {
    903  public:
    904   explicit SparseApplyProximalGradientDescentOp(OpKernelConstruction* ctx)
    905       : OpKernel(ctx) {
    906     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
    907   }
    908 
    909   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
    910     auto locks =
    911         MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
    912     Tensor var;
    913     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
    914                             ctx, 0, use_exclusive_lock_, true, &var));
    915     OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
    916                 errors::InvalidArgument("var must be at least 1 dimensional"));
    917 
    918     const Tensor& lr = ctx->input(1);
    919     OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
    920                 errors::InvalidArgument("lr is not a scalar: ",
    921                                         lr.shape().DebugString()));
    922     const Tensor& l1 = ctx->input(2);
    923     OP_REQUIRES(
    924         ctx, TensorShapeUtils::IsScalar(l1.shape()),
    925         errors::InvalidArgument("l1 regularization strength is not a scalar: ",
    926                                 l1.shape().DebugString()));
    927     const Tensor& l2 = ctx->input(3);
    928     OP_REQUIRES(
    929         ctx, TensorShapeUtils::IsScalar(l2.shape()),
    930         errors::InvalidArgument("l2 regularization strength is not a scalar: ",
    931                                 l2.shape().DebugString()));
    932 
    933     const Tensor& grad = ctx->input(4);
    934     const Tensor& indices = ctx->input(5);
    935     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
    936                 errors::InvalidArgument("indices must be one-dimensional"));
    937 
    938     int64 inner_dim = 1;
    939     for (int d = 1; d < var.dims(); d++) {
    940       OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
    941                   errors::InvalidArgument(strings::StrCat(
    942                       "var and grad must match in dimension ", d)));
    943       inner_dim *= grad.dim_size(d);
    944     }
    945     const Tindex N = indices.dim_size(0);
    946     OP_REQUIRES(
    947         ctx, grad.dim_size(0) == N,
    948         errors::InvalidArgument(
    949             "grad must be the same size as indices in the first dimension."));
    950     OP_REQUIRES(ctx, inner_dim > 0,
    951                 errors::InvalidArgument(
    952                     "Inner dimension should be greater than zero."));
    953 
    954     if (N > 0) {
    955       if (inner_dim > 1) {
    956         const Tindex first_dim_size = var.dim_size(0);
    957         auto indices_vec = indices.vec<Tindex>();
    958         auto var_flat = var.flat_outer_dims<T>();
    959         auto grad_flat = grad.flat_outer_dims<T>();
    960         T lr_scalar = lr.scalar<T>()();
    961         T l1_scalar = l1.scalar<T>()();
    962         T l2_scalar = l2.scalar<T>()();
    963 
    964         // TODO(xbing): extract the common logic for the Fobos update.
    965         for (Tindex i = 0; i < N; i++) {
    966           const Tindex index = internal::SubtleMustCopy(indices_vec(i));
    967           OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
    968                       errors::InvalidArgument(
    969                           strings::StrCat("Index ", index, " at offset ", i,
    970                                           " in indices is out of range")));
    971           auto g = grad_flat.template chip<0>(i);
    972           auto v = var_flat.template chip<0>(index);
    973           // compute learning_rate for current step.
    974           auto learning_rate = v.constant(lr_scalar);
    975           auto prox_v = v;
    976           // v = w - g * learning_rate.
    977           prox_v -= g * learning_rate;
    978           if (l1_scalar > 0) {
    979             // compute sign(v) * max(|v|, 0)
    980             v = prox_v.sign() *
    981                 (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
    982                     .cwiseMax(static_cast<T>(0.0)) /
    983                 (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
    984           } else {
    985             v = prox_v /
    986                 (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
    987           }
    988         }
    989       } else {
    990         auto indices_vec = indices.vec<Tindex>();
    991         auto var_flat = var.flat<T>();
    992         auto grad_flat = grad.flat<T>();
    993         T lr_scalar = lr.scalar<T>()();
    994         T l1_scalar = l1.scalar<T>()();
    995         T l2_scalar = l2.scalar<T>()();
    996         const Tindex first_dim_size = var_flat.size();
    997 
    998         for (Tindex i = 0; i < N; i++) {
    999           const Tindex index = internal::SubtleMustCopy(indices_vec(i));
   1000           OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
   1001                       errors::InvalidArgument(
   1002                           strings::StrCat("Index ", index, " at offset ", i,
   1003                                           " in indices is out of range")));
   1004           const T& g = grad_flat(i);
   1005           auto learning_rate = lr_scalar;
   1006           auto prox_v = var_flat(index);
   1007           prox_v -= learning_rate * g;
   1008           if (l1_scalar > 0) {
   1009             var_flat(index) =
   1010                 sgn(prox_v) *
   1011                 std::max(std::abs(prox_v) - learning_rate * l1_scalar,
   1012                          static_cast<T>(0.0)) /
   1013                 (1.0 + l2_scalar * learning_rate);
   1014           } else {
   1015             var_flat(index) = prox_v / (1.0 + l2_scalar * learning_rate);
   1016           }
   1017         }
   1018       }
   1019     }
   1020 
   1021     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   1022   }
   1023 
   1024  private:
   1025   bool use_exclusive_lock_;
   1026 };
   1027 
   1028 #define REGISTER_KERNELS(T, Tindices)                                         \
   1029   REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalGradientDescent")          \
   1030                               .Device(DEVICE_CPU)                             \
   1031                               .TypeConstraint<T>("T")                         \
   1032                               .TypeConstraint<Tindices>("Tindices"),          \
   1033                           SparseApplyProximalGradientDescentOp<T, Tindices>); \
   1034   REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyProximalGradientDescent")  \
   1035                               .Device(DEVICE_CPU)                             \
   1036                               .TypeConstraint<T>("T")                         \
   1037                               .TypeConstraint<Tindices>("Tindices"),          \
   1038                           SparseApplyProximalGradientDescentOp<T, Tindices>);
   1039 
   1040 REGISTER_KERNELS(float, int32);
   1041 REGISTER_KERNELS(float, int64);
   1042 REGISTER_KERNELS(double, int32);
   1043 REGISTER_KERNELS(double, int64);
   1044 #undef REGISTER_KERNELS
   1045 
   1046 template <typename Device, typename T>
   1047 class ApplyAdagradOp : public OpKernel {
   1048  public:
   1049   explicit ApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   1050     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   1051   }
   1052 
   1053   void Compute(OpKernelContext* ctx) override {
   1054     auto locks =
   1055         MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
   1056     Tensor var;
   1057     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   1058                             ctx, 0, use_exclusive_lock_, false, &var));
   1059     Tensor accum;
   1060     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   1061                             ctx, 1, use_exclusive_lock_, false, &accum));
   1062     OP_REQUIRES(
   1063         ctx, var.IsInitialized(),
   1064         errors::FailedPrecondition(
   1065             "Attempting to use uninitialized variables: ", requested_input(0)));
   1066     OP_REQUIRES(
   1067         ctx, accum.IsInitialized(),
   1068         errors::FailedPrecondition(
   1069             "Attempting to use uninitialized variables: ", requested_input(1)));
   1070     const Tensor& lr = ctx->input(2);
   1071     OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
   1072                 errors::InvalidArgument("lr is not a scalar: ",
   1073                                         lr.shape().DebugString()));
   1074     const Tensor& grad = ctx->input(3);
   1075     OP_REQUIRES(
   1076         ctx, var.shape().IsSameSize(accum.shape()),
   1077         errors::InvalidArgument("var and accum do not have the same shape",
   1078                                 var.shape().DebugString(), " ",
   1079                                 accum.shape().DebugString()));
   1080     OP_REQUIRES(
   1081         ctx, var.shape().IsSameSize(grad.shape()),
   1082         errors::InvalidArgument("var and grad do not have the same shape",
   1083                                 var.shape().DebugString(), " ",
   1084                                 grad.shape().DebugString()));
   1085 
   1086     const Device& device = ctx->template eigen_device<Device>();
   1087     functor::ApplyAdagrad<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
   1088                                        lr.scalar<T>(), grad.flat<T>());
   1089 
   1090     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   1091   }
   1092 
   1093  private:
   1094   bool use_exclusive_lock_;
   1095 };
   1096 
   1097 #define REGISTER_KERNELS(D, T)                                        \
   1098   REGISTER_KERNEL_BUILDER(                                            \
   1099       Name("ApplyAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
   1100       ApplyAdagradOp<D##Device, T>);                                  \
   1101   REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdagrad")                \
   1102                               .HostMemory("var")                      \
   1103                               .HostMemory("accum")                    \
   1104                               .Device(DEVICE_##D)                     \
   1105                               .TypeConstraint<T>("T"),                \
   1106                           ApplyAdagradOp<D##Device, T>);
   1107 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
   1108 
   1109 TF_CALL_half(REGISTER_CPU_KERNELS);
   1110 TF_CALL_float(REGISTER_CPU_KERNELS);
   1111 TF_CALL_double(REGISTER_CPU_KERNELS);
   1112 
   1113 #if GOOGLE_CUDA
   1114 // Forward declarations of the functor specializations for GPU.
   1115 namespace functor {
   1116 #define DECLARE_GPU_SPEC(T)                                               \
   1117   template <>                                                             \
   1118   void ApplyAdagrad<GPUDevice, T>::operator()(                            \
   1119       const GPUDevice& d, typename TTypes<T>::Flat var,                   \
   1120       typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
   1121       typename TTypes<T>::ConstFlat grad);                                \
   1122   extern template struct ApplyAdagrad<GPUDevice, T>;
   1123 DECLARE_GPU_SPEC(Eigen::half);
   1124 DECLARE_GPU_SPEC(float);
   1125 DECLARE_GPU_SPEC(double);
   1126 #undef DECLARE_GPU_SPEC
   1127 }  // namespace functor
   1128 
   1129 REGISTER_KERNELS(GPU, Eigen::half);
   1130 REGISTER_KERNELS(GPU, float);
   1131 REGISTER_KERNELS(GPU, double);
   1132 #endif
   1133 #undef REGISTER_CPU_KERNELS
   1134 #undef REGISTER_KERNELS
   1135 
   1136 template <typename Device, typename T>
   1137 class ApplyProximalAdagradOp : public OpKernel {
   1138  public:
   1139   explicit ApplyProximalAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   1140     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   1141   }
   1142 
   1143   void Compute(OpKernelContext* ctx) override {
   1144     auto locks =
   1145         MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
   1146     Tensor var;
   1147     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   1148                             ctx, 0, use_exclusive_lock_, false, &var));
   1149     Tensor accum;
   1150     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   1151                             ctx, 1, use_exclusive_lock_, false, &accum));
   1152     OP_REQUIRES(
   1153         ctx, var.IsInitialized(),
   1154         errors::FailedPrecondition(
   1155             "Attempting to use uninitialized variables: ", requested_input(0)));
   1156     OP_REQUIRES(
   1157         ctx, accum.IsInitialized(),
   1158         errors::FailedPrecondition(
   1159             "Attempting to use uninitialized variables: ", requested_input(1)));
   1160     OP_REQUIRES(
   1161         ctx, var.shape().IsSameSize(accum.shape()),
   1162         errors::InvalidArgument("var and accum do not have the same shape",
   1163                                 var.shape().DebugString(), " ",
   1164                                 accum.shape().DebugString()));
   1165     const Tensor& lr = ctx->input(2);
   1166     OP_REQUIRES(ctx,
   1167                 TensorShapeUtils::IsScalar(lr.shape()) &&
   1168                     lr.scalar<T>()() > static_cast<T>(0),
   1169                 errors::InvalidArgument("lr is not a positive scalar: ",
   1170                                         lr.shape().DebugString()));
   1171     const Tensor& l1 = ctx->input(3);
   1172     OP_REQUIRES(ctx,
   1173                 TensorShapeUtils::IsScalar(l1.shape()) &&
   1174                     l1.scalar<T>()() >= static_cast<T>(0),
   1175                 errors::InvalidArgument("l1 regularization strength is not a "
   1176                                         "non-negative scalar: ",
   1177                                         l1.shape().DebugString()));
   1178     const Tensor& l2 = ctx->input(4);
   1179     OP_REQUIRES(ctx,
   1180                 TensorShapeUtils::IsScalar(l2.shape()) &&
   1181                     l2.scalar<T>()() >= static_cast<T>(0),
   1182                 errors::InvalidArgument("l2 regularization strength is not a "
   1183                                         "non-negative scalar: ",
   1184                                         l2.shape().DebugString()));
   1185     const Tensor& grad = ctx->input(5);
   1186     OP_REQUIRES(
   1187         ctx, var.shape().IsSameSize(grad.shape()),
   1188         errors::InvalidArgument("var and grad do not have the same shape",
   1189                                 var.shape().DebugString(), " ",
   1190                                 grad.shape().DebugString()));
   1191 
   1192     const Device& device = ctx->template eigen_device<Device>();
   1193     functor::ApplyProximalAdagrad<Device, T>()(
   1194         device, var.flat<T>(), accum.flat<T>(), lr.scalar<T>(), l1.scalar<T>(),
   1195         l2.scalar<T>(), grad.flat<T>());
   1196 
   1197     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   1198   }
   1199 
   1200  private:
   1201   bool use_exclusive_lock_;
   1202 };
   1203 
   1204 #define REGISTER_KERNELS(D, T)                                                \
   1205   REGISTER_KERNEL_BUILDER(                                                    \
   1206       Name("ApplyProximalAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
   1207       ApplyProximalAdagradOp<D##Device, T>);                                  \
   1208   REGISTER_KERNEL_BUILDER(Name("ResourceApplyProximalAdagrad")                \
   1209                               .Device(DEVICE_##D)                             \
   1210                               .HostMemory("var")                              \
   1211                               .HostMemory("accum")                            \
   1212                               .TypeConstraint<T>("T"),                        \
   1213                           ApplyProximalAdagradOp<D##Device, T>);
   1214 
   1215 REGISTER_KERNELS(CPU, float);
   1216 REGISTER_KERNELS(CPU, double);
   1217 #undef REGISTER_KERNELS
   1218 
   1219 namespace {
   1220 
   1221 template <typename T>
   1222 inline T FtrlCompute(const T& accum, const T& linear, const T& lr, const T& l1,
   1223                      const T& l2, const T& lr_power) {
   1224   T quadratic;
   1225   if (lr_power == static_cast<T>(-0.5)) {
   1226     quadratic = Eigen::numext::sqrt(accum) / lr + static_cast<T>(2) * l2;
   1227   } else {
   1228     quadratic =
   1229         Eigen::numext::pow(accum, -lr_power) / lr + static_cast<T>(2) * l2;
   1230   }
   1231   auto l1_reg_adjust = std::max(std::min(linear, l1), -l1);
   1232   return (l1_reg_adjust - linear) / quadratic;
   1233 }
   1234 }  // namespace
   1235 
   1236 // Note, this op works on cpu only.
   1237 template <typename T, typename Tindex>
   1238 class SparseApplyAdagradOp : public OpKernel {
   1239  public:
   1240   explicit SparseApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   1241     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   1242   }
   1243 
   1244   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
   1245     auto locks =
   1246         MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
   1247     Tensor var;
   1248     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
   1249                             ctx, 0, use_exclusive_lock_, true, &var));
   1250     Tensor accum;
   1251     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
   1252                             ctx, 1, use_exclusive_lock_, true, &accum));
   1253     OP_REQUIRES(
   1254         ctx, var.IsInitialized(),
   1255         errors::FailedPrecondition(
   1256             "Attempting to use uninitialized variables: ", requested_input(0)));
   1257     OP_REQUIRES(
   1258         ctx, accum.IsInitialized(),
   1259         errors::FailedPrecondition(
   1260             "Attempting to use uninitialized variables: ", requested_input(1)));
   1261     OP_REQUIRES(
   1262         ctx, var.shape().IsSameSize(accum.shape()),
   1263         errors::InvalidArgument("var and accum do not have the same shape",
   1264                                 var.shape().DebugString(), " ",
   1265                                 accum.shape().DebugString()));
   1266     OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
   1267                 errors::InvalidArgument("var must be at least 1 dimensional"));
   1268 
   1269     const Tensor& lr = ctx->input(2);
   1270     OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
   1271                 errors::InvalidArgument("lr is not a scalar: ",
   1272                                         lr.shape().DebugString()));
   1273     const Tensor& grad = ctx->input(3);
   1274     const Tensor& indices = ctx->input(4);
   1275     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
   1276                 errors::InvalidArgument("indices must be one-dimensional"));
   1277 
   1278     int64 inner_dim = 1;
   1279     for (int d = 1; d < var.dims(); d++) {
   1280       OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
   1281                   errors::InvalidArgument(strings::StrCat(
   1282                       "var and grad must match in dimension ", d)));
   1283       inner_dim *= grad.dim_size(d);
   1284     }
   1285     const Tindex N = indices.dim_size(0);
   1286     OP_REQUIRES(
   1287         ctx, grad.dim_size(0) == N,
   1288         errors::InvalidArgument(
   1289             "grad must be the same size as indices in the first dimension."));
   1290 
   1291     OP_REQUIRES(ctx, inner_dim > 0,
   1292                 errors::InvalidArgument(
   1293                     "Inner dimension should be greater than zero."));
   1294 
   1295     if (N > 0) {
   1296       if (inner_dim > 1) {
   1297         const Tindex first_dim_size = var.dim_size(0);
   1298         auto indices_vec = indices.vec<Tindex>();
   1299         auto var_flat = var.flat_outer_dims<T>();
   1300         auto accum_flat = accum.flat_outer_dims<T>();
   1301         auto grad_flat = grad.flat_outer_dims<T>();
   1302         T lr_scalar = lr.scalar<T>()();
   1303 
   1304         // Note(yonghui): It might be worth multi-threading square() and
   1305         // rsqrt().
   1306         for (Tindex i = 0; i < N; i++) {
   1307           const Tindex index = internal::SubtleMustCopy(indices_vec(i));
   1308           OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
   1309                       errors::InvalidArgument(
   1310                           strings::StrCat("Index ", index, " at offset ", i,
   1311                                           " in indices is out of range")));
   1312           auto a = accum_flat.template chip<0>(index);
   1313           auto g = grad_flat.template chip<0>(i);
   1314           auto v = var_flat.template chip<0>(index);
   1315           a += g.square();
   1316           v -= g.constant(lr_scalar) * g * a.rsqrt();
   1317         }
   1318       } else {
   1319         auto indices_vec = indices.vec<Tindex>();
   1320         auto var_flat = var.flat<T>();
   1321         auto accum_flat = accum.flat<T>();
   1322         auto grad_flat = grad.flat<T>();
   1323         T lr_scalar = lr.scalar<T>()();
   1324         const Tindex first_dim_size = accum_flat.size();
   1325 
   1326         for (Tindex i = 0; i < N; i++) {
   1327           const Tindex index = internal::SubtleMustCopy(indices_vec(i));
   1328           OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
   1329                       errors::InvalidArgument(
   1330                           strings::StrCat("Index ", index, " at offset ", i,
   1331                                           " in indices is out of range")));
   1332           T& a = accum_flat(index);
   1333           const T& g = grad_flat(i);
   1334           a += g * g;
   1335           var_flat(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
   1336         }
   1337       }
   1338     }
   1339 
   1340     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   1341   }
   1342 
   1343  private:
   1344   bool use_exclusive_lock_;
   1345 };
   1346 
   1347 #define REGISTER_KERNELS(T, Tindices)                                \
   1348   REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagrad")                 \
   1349                               .Device(DEVICE_CPU)                    \
   1350                               .TypeConstraint<T>("T")                \
   1351                               .TypeConstraint<Tindices>("Tindices"), \
   1352                           SparseApplyAdagradOp<T, Tindices>);        \
   1353   REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagrad")         \
   1354                               .Device(DEVICE_CPU)                    \
   1355                               .TypeConstraint<T>("T")                \
   1356                               .TypeConstraint<Tindices>("Tindices"), \
   1357                           SparseApplyAdagradOp<T, Tindices>);
   1358 #define REGISTER_CPU_KERNELS(T) \
   1359   REGISTER_KERNELS(T, int32);   \
   1360   REGISTER_KERNELS(T, int64);
   1361 
   1362 TF_CALL_half(REGISTER_CPU_KERNELS);
   1363 TF_CALL_float(REGISTER_CPU_KERNELS);
   1364 TF_CALL_double(REGISTER_CPU_KERNELS);
   1365 
   1366 #undef REGISTER_CPU_KERNELS
   1367 #undef REGISTER_KERNELS
   1368 
   1369 // Note, this op works on cpu only.
   1370 template <typename T, typename Tindex>
   1371 class SparseApplyProximalAdagradOp : public OpKernel {
   1372  public:
   1373   explicit SparseApplyProximalAdagradOp(OpKernelConstruction* ctx)
   1374       : OpKernel(ctx) {
   1375     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   1376   }
   1377 
   1378   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
   1379     auto locks =
   1380         MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
   1381     Tensor var;
   1382     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
   1383                             ctx, 0, use_exclusive_lock_, true, &var));
   1384     Tensor accum;
   1385     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
   1386                             ctx, 1, use_exclusive_lock_, true, &accum));
   1387     OP_REQUIRES(
   1388         ctx, var.IsInitialized(),
   1389         errors::FailedPrecondition(
   1390             "Attempting to use uninitialized variables: ", requested_input(0)));
   1391     OP_REQUIRES(
   1392         ctx, accum.IsInitialized(),
   1393         errors::FailedPrecondition(
   1394             "Attempting to use uninitialized variables: ", requested_input(1)));
   1395     OP_REQUIRES(
   1396         ctx, var.shape().IsSameSize(accum.shape()),
   1397         errors::InvalidArgument("var and accum do not have the same shape",
   1398                                 var.shape().DebugString(), " ",
   1399                                 accum.shape().DebugString()));
   1400     OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
   1401                 errors::InvalidArgument("var must be at least 1 dimensional"));
   1402 
   1403     const Tensor& lr = ctx->input(2);
   1404     OP_REQUIRES(ctx,
   1405                 TensorShapeUtils::IsScalar(lr.shape()) &&
   1406                     lr.scalar<T>()() > static_cast<T>(0),
   1407                 errors::InvalidArgument("lr is not a positive scalar: ",
   1408                                         lr.shape().DebugString()));
   1409     const Tensor& l1 = ctx->input(3);
   1410     OP_REQUIRES(ctx,
   1411                 TensorShapeUtils::IsScalar(l1.shape()) &&
   1412                     l1.scalar<T>()() >= static_cast<T>(0),
   1413                 errors::InvalidArgument("l1 regularization strength is not a "
   1414                                         "non-negative scalar: ",
   1415                                         l1.shape().DebugString()));
   1416     const Tensor& l2 = ctx->input(4);
   1417     OP_REQUIRES(ctx,
   1418                 TensorShapeUtils::IsScalar(l2.shape()) &&
   1419                     l2.scalar<T>()() >= static_cast<T>(0),
   1420                 errors::InvalidArgument("l2 regularization strength is not a "
   1421                                         "non-negative scalar: ",
   1422                                         l2.shape().DebugString()));
   1423 
   1424     const Tensor& grad = ctx->input(5);
   1425     const Tensor& indices = ctx->input(6);
   1426     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
   1427                 errors::InvalidArgument("indices must be one-dimensional"));
   1428 
   1429     int64 inner_dim = 1;
   1430     for (int d = 1; d < var.dims(); d++) {
   1431       OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
   1432                   errors::InvalidArgument(strings::StrCat(
   1433                       "var and grad must match in dimension ", d)));
   1434       inner_dim *= grad.dim_size(d);
   1435     }
   1436     const Tindex N = indices.dim_size(0);
   1437     OP_REQUIRES(
   1438         ctx, grad.dim_size(0) == N,
   1439         errors::InvalidArgument(
   1440             "grad must be the same size as indices in the first dimension."));
   1441 
   1442     OP_REQUIRES(ctx, inner_dim > 0,
   1443                 errors::InvalidArgument(
   1444                     "Inner dimension should be greater than zero."));
   1445 
   1446     if (N > 0) {
   1447       if (inner_dim > 1) {
   1448         const Tindex first_dim_size = var.dim_size(0);
   1449         auto indices_vec = indices.vec<Tindex>();
   1450         auto var_flat = var.flat_outer_dims<T>();
   1451         auto accum_flat = accum.flat_outer_dims<T>();
   1452         auto grad_flat = grad.flat_outer_dims<T>();
   1453         T lr_scalar = lr.scalar<T>()();
   1454         T l1_scalar = l1.scalar<T>()();
   1455         T l2_scalar = l2.scalar<T>()();
   1456 
   1457         for (Tindex i = 0; i < N; i++) {
   1458           const Tindex index = internal::SubtleMustCopy(indices_vec(i));
   1459           OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
   1460                       errors::InvalidArgument(
   1461                           strings::StrCat("Index ", index, " at offset ", i,
   1462                                           " in indices is out of range")));
   1463           auto a = accum_flat.template chip<0>(index);
   1464           auto g = grad_flat.template chip<0>(i);
   1465           auto v = var_flat.template chip<0>(index);
   1466           a += g.square();
   1467           // compute learning_rate for current step.
   1468           auto learning_rate = a.constant(lr_scalar) * a.rsqrt();
   1469           auto prox_v = v;
   1470           // v = w - g * learning_rate.
   1471           prox_v -= g * learning_rate;
   1472           if (l1_scalar > 0) {
   1473             // compute sign(v) * max(|v|, 0)
   1474             v = prox_v.sign() *
   1475                 (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
   1476                     .cwiseMax(static_cast<T>(0.0)) /
   1477                 (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
   1478           } else {
   1479             v = prox_v /
   1480                 (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
   1481           }
   1482         }
   1483       } else {
   1484         auto indices_vec = indices.vec<Tindex>();
   1485         auto var_flat = var.flat<T>();
   1486         auto accum_flat = accum.flat<T>();
   1487         auto grad_flat = grad.flat<T>();
   1488         T lr_scalar = lr.scalar<T>()();
   1489         T l1_scalar = l1.scalar<T>()();
   1490         T l2_scalar = l2.scalar<T>()();
   1491         const Tindex first_dim_size = accum_flat.size();
   1492 
   1493         for (Tindex i = 0; i < N; i++) {
   1494           const Tindex index = internal::SubtleMustCopy(indices_vec(i));
   1495           OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
   1496                       errors::InvalidArgument(
   1497                           strings::StrCat("Index ", index, " at offset ", i,
   1498                                           " in indices is out of range")));
   1499           T& a = accum_flat(index);
   1500           const T& g = grad_flat(i);
   1501           a += g * g;
   1502           auto learning_rate = lr_scalar / std::sqrt(a);
   1503           auto prox_v = var_flat(index);
   1504           prox_v -= learning_rate * g;
   1505           if (l1_scalar > 0) {
   1506             var_flat(index) =
   1507                 sgn(prox_v) *
   1508                 std::max(std::abs(prox_v) - learning_rate * l1_scalar,
   1509                          static_cast<T>(0.0)) /
   1510                 (1.0 + l2_scalar * learning_rate);
   1511           } else {
   1512             var_flat(index) = prox_v / (1.0 + l2_scalar * learning_rate);
   1513           }
   1514         }
   1515       }
   1516     }
   1517 
   1518     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   1519   }
   1520 
   1521  private:
   1522   bool use_exclusive_lock_;
   1523 };
   1524 
   1525 #define REGISTER_KERNELS(T, Tindices)                                 \
   1526   REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalAdagrad")          \
   1527                               .Device(DEVICE_CPU)                     \
   1528                               .TypeConstraint<T>("T")                 \
   1529                               .TypeConstraint<Tindices>("Tindices"),  \
   1530                           SparseApplyProximalAdagradOp<T, Tindices>); \
   1531   REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyProximalAdagrad")  \
   1532                               .Device(DEVICE_CPU)                     \
   1533                               .TypeConstraint<T>("T")                 \
   1534                               .TypeConstraint<Tindices>("Tindices"),  \
   1535                           SparseApplyProximalAdagradOp<T, Tindices>);
   1536 
   1537 REGISTER_KERNELS(float, int32);
   1538 REGISTER_KERNELS(float, int64);
   1539 REGISTER_KERNELS(double, int32);
   1540 REGISTER_KERNELS(double, int64);
   1541 #undef REGISTER_KERNELS
   1542 
   1543 template <typename Device, typename T>
   1544 class ApplyAdagradDAOp : public OpKernel {
   1545  public:
   1546   explicit ApplyAdagradDAOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   1547     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   1548   }
   1549 
   1550   void Compute(OpKernelContext* ctx) override {
   1551     auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
   1552                                                       {0, 1, 2});
   1553     Tensor var;
   1554     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   1555                             ctx, 0, use_exclusive_lock_, false, &var));
   1556     Tensor gradient_accum;
   1557     OP_REQUIRES_OK(
   1558         ctx, GetInputTensorFromVariable<Device, T>(ctx, 1, use_exclusive_lock_,
   1559                                                    false, &gradient_accum));
   1560     Tensor gradient_squared_accum;
   1561     OP_REQUIRES_OK(
   1562         ctx, GetInputTensorFromVariable<Device, T>(
   1563                  ctx, 2, use_exclusive_lock_, false, &gradient_squared_accum));
   1564     OP_REQUIRES(
   1565         ctx, var.IsInitialized(),
   1566         errors::FailedPrecondition(
   1567             "Attempting to use uninitialized variables: ", requested_input(0)));
   1568     OP_REQUIRES(
   1569         ctx, gradient_accum.IsInitialized(),
   1570         errors::FailedPrecondition(
   1571             "Attempting to use uninitialized variables: ", requested_input(1)));
   1572     OP_REQUIRES(
   1573         ctx, gradient_squared_accum.IsInitialized(),
   1574         errors::FailedPrecondition(
   1575             "Attempting to use uninitialized variables: ", requested_input(2)));
   1576     OP_REQUIRES(
   1577         ctx, var.shape().IsSameSize(gradient_accum.shape()),
   1578         errors::InvalidArgument("var and accum do not have the same shape",
   1579                                 var.shape().DebugString(), " ",
   1580                                 gradient_accum.shape().DebugString()));
   1581     OP_REQUIRES(
   1582         ctx, var.shape().IsSameSize(gradient_squared_accum.shape()),
   1583         errors::InvalidArgument("var and accum do not have the same shape",
   1584                                 var.shape().DebugString(), " ",
   1585                                 gradient_squared_accum.shape().DebugString()));
   1586 
   1587     const Tensor& grad = ctx->input(3);
   1588     OP_REQUIRES(
   1589         ctx, var.shape().IsSameSize(grad.shape()),
   1590         errors::InvalidArgument("var and grad do not have the same shape",
   1591                                 var.shape().DebugString(), " ",
   1592                                 grad.shape().DebugString()));
   1593 
   1594     const Tensor& lr = ctx->input(4);
   1595     OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
   1596                 errors::InvalidArgument("lr is not a scalar: ",
   1597                                         lr.shape().DebugString()));
   1598     const Tensor& l1 = ctx->input(5);
   1599     OP_REQUIRES(
   1600         ctx, TensorShapeUtils::IsScalar(l1.shape()),
   1601         errors::InvalidArgument("l1 regularization strength is not a scalar: ",
   1602                                 l1.shape().DebugString()));
   1603     const Tensor& l2 = ctx->input(6);
   1604     OP_REQUIRES(
   1605         ctx, TensorShapeUtils::IsScalar(l2.shape()),
   1606         errors::InvalidArgument("l2 regularization strength is not a scalar: ",
   1607                                 l2.shape().DebugString()));
   1608     const Tensor& global_step = ctx->input(7);
   1609     OP_REQUIRES(ctx, IsLegacyScalar(global_step.shape()),
   1610                 errors::InvalidArgument("global_step is not a scalar: ",
   1611                                         global_step.shape().DebugString()));
   1612 
   1613     const Device& device = ctx->template eigen_device<Device>();
   1614     functor::ApplyAdagradDA<Device, T>()(
   1615         device, var.flat<T>(), gradient_accum.flat<T>(),
   1616         gradient_squared_accum.flat<T>(), lr.scalar<T>(),
   1617         global_step.scalar<int64>()(), l1.scalar<T>(), l2.scalar<T>(),
   1618         grad.flat<T>());
   1619 
   1620     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   1621   }
   1622 
   1623  private:
   1624   bool use_exclusive_lock_;
   1625 };
   1626 
   1627 #define REGISTER_KERNELS(D, T)                                            \
   1628   REGISTER_KERNEL_BUILDER(                                                \
   1629       Name("ApplyAdagradDA").Device(DEVICE_##D).TypeConstraint<T>("T"),   \
   1630       ApplyAdagradDAOp<D##Device, T>);                                    \
   1631   REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdagradDA")                  \
   1632                               .Device(DEVICE_##D)                         \
   1633                               .HostMemory("var")                          \
   1634                               .HostMemory("gradient_accumulator")         \
   1635                               .HostMemory("gradient_squared_accumulator") \
   1636                               .TypeConstraint<T>("T"),                    \
   1637                           ApplyAdagradDAOp<D##Device, T>);
   1638 
   1639 REGISTER_KERNELS(CPU, float);
   1640 REGISTER_KERNELS(CPU, double);
   1641 #undef REGISTER_KERNELS
   1642 
   1643 // Note, this op works on cpu only.
   1644 template <typename T, typename Tindex>
   1645 class SparseApplyAdagradDAOp : public OpKernel {
   1646  public:
   1647   explicit SparseApplyAdagradDAOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   1648     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   1649   }
   1650 
   1651   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
   1652     auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
   1653                                                       {0, 1, 2});
   1654     Tensor var;
   1655     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
   1656                             ctx, 0, use_exclusive_lock_, true, &var));
   1657     Tensor gradient_accum;
   1658     OP_REQUIRES_OK(ctx,
   1659                    GetInputTensorFromVariable<CPUDevice, T>(
   1660                        ctx, 1, use_exclusive_lock_, true, &gradient_accum));
   1661     Tensor gradient_squared_accum;
   1662     OP_REQUIRES_OK(
   1663         ctx, GetInputTensorFromVariable<CPUDevice, T>(
   1664                  ctx, 2, use_exclusive_lock_, true, &gradient_squared_accum));
   1665     OP_REQUIRES(
   1666         ctx, var.IsInitialized(),
   1667         errors::FailedPrecondition(
   1668             "Attempting to use uninitialized variables: ", requested_input(0)));
   1669     OP_REQUIRES(
   1670         ctx, gradient_accum.IsInitialized(),
   1671         errors::FailedPrecondition(
   1672             "Attempting to use uninitialized variables: ", requested_input(1)));
   1673     OP_REQUIRES(
   1674         ctx, gradient_squared_accum.IsInitialized(),
   1675         errors::FailedPrecondition(
   1676             "Attempting to use uninitialized variables: ", requested_input(2)));
   1677     OP_REQUIRES(
   1678         ctx, var.shape().IsSameSize(gradient_accum.shape()),
   1679         errors::InvalidArgument("var and accum do not have the same shape",
   1680                                 var.shape().DebugString(), " ",
   1681                                 gradient_accum.shape().DebugString()));
   1682     OP_REQUIRES(
   1683         ctx, var.shape().IsSameSize(gradient_squared_accum.shape()),
   1684         errors::InvalidArgument("var and accum do not have the same shape",
   1685                                 var.shape().DebugString(), " ",
   1686                                 gradient_squared_accum.shape().DebugString()));
   1687 
   1688     OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
   1689                 errors::InvalidArgument("var must be at least 1 dimensional"));
   1690 
   1691     const Tensor& grad = ctx->input(3);
   1692     const Tensor& indices = ctx->input(4);
   1693     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
   1694                 errors::InvalidArgument("indices must be one-dimensional"));
   1695 
   1696     const Tensor& lr = ctx->input(5);
   1697     OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
   1698                 errors::InvalidArgument("lr is not a scalar: ",
   1699                                         lr.shape().DebugString()));
   1700 
   1701     const Tensor& l1 = ctx->input(6);
   1702     OP_REQUIRES(
   1703         ctx, TensorShapeUtils::IsScalar(l1.shape()),
   1704         errors::InvalidArgument("l1 regularization strength is not a scalar: ",
   1705                                 l1.shape().DebugString()));
   1706 
   1707     const Tensor& l2 = ctx->input(7);
   1708     OP_REQUIRES(
   1709         ctx, TensorShapeUtils::IsScalar(l2.shape()),
   1710         errors::InvalidArgument("l2 regularization strength is not a scalar: ",
   1711                                 l2.shape().DebugString()));
   1712 
   1713     const Tensor& global_step = ctx->input(8);
   1714     OP_REQUIRES(ctx, IsLegacyScalar(global_step.shape()),
   1715                 errors::InvalidArgument("global_step is not a scalar: ",
   1716                                         global_step.shape().DebugString()));
   1717 
   1718     int64 inner_dim = 1;
   1719     for (int d = 1; d < var.dims(); d++) {
   1720       OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
   1721                   errors::InvalidArgument(strings::StrCat(
   1722                       "var and grad must match in dimension ", d)));
   1723       inner_dim *= grad.dim_size(d);
   1724     }
   1725     const Tindex N = indices.dim_size(0);
   1726     OP_REQUIRES(
   1727         ctx, grad.dim_size(0) == N,
   1728         errors::InvalidArgument(
   1729             "grad must be the same size as indices in the first dimension."));
   1730 
   1731     OP_REQUIRES(ctx, inner_dim > 0,
   1732                 errors::InvalidArgument(
   1733                     "Inner dimension should be greater than zero."));
   1734 
   1735     // AdagradDA update:
   1736     // Let g to be gradient accumulator, gg to be gradient squared accumulator,
   1737     // T be the global step, lr is the learning rate, and k the initial
   1738     // gradient squared accumulator value.
   1739     // w = \dfrac{sign(-g)*lr*|g - l1*T|_{+}}{l2*T*lr + \sqrt{k+gg})}
   1740     if (N > 0) {
   1741       if (inner_dim > 1) {
   1742         const Tindex first_dim_size = var.dim_size(0);
   1743         auto indices_vec = indices.vec<Tindex>();
   1744         auto var_flat = var.flat_outer_dims<T>();
   1745         auto gradient_accum_flat = gradient_accum.flat_outer_dims<T>();
   1746         auto gradient_squared_accum_flat =
   1747             gradient_squared_accum.flat_outer_dims<T>();
   1748         auto grad_flat = grad.flat_outer_dims<T>();
   1749         T lr_scalar = lr.scalar<T>()();
   1750         T global_step_scalar = global_step.scalar<int64>()();
   1751         T l1_scalar = l1.scalar<T>()();
   1752         T l2_scalar = l2.scalar<T>()();
   1753         const double gs_lr = global_step_scalar * lr_scalar;
   1754 
   1755         for (Tindex i = 0; i < N; i++) {
   1756           const Tindex index = internal::SubtleMustCopy(indices_vec(i));
   1757           OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
   1758                       errors::InvalidArgument(
   1759                           strings::StrCat("Index ", index, " at offset ", i,
   1760                                           " in indices is out of range")));
   1761           auto ga = gradient_accum_flat.template chip<0>(index);
   1762           auto da = gradient_squared_accum_flat.template chip<0>(index);
   1763           auto g = grad_flat.template chip<0>(i);
   1764           auto v = var_flat.template chip<0>(index);
   1765           ga += g;
   1766           da += g.square();
   1767           if (l1_scalar > 0) {
   1768             v = ga.constant(-1.0) * ga.sign() *
   1769                 ((ga.abs() / ga.constant(global_step_scalar)) -
   1770                  ga.constant(l1_scalar))
   1771                     .cwiseMax(static_cast<T>(0.0)) /
   1772                 (v.constant(l2_scalar) + da.sqrt() / v.constant(gs_lr));
   1773           } else {
   1774             v = ga.constant(-1.0) * (ga / ga.constant(global_step_scalar)) /
   1775                 (v.constant(l2_scalar) + da.sqrt() / v.constant(gs_lr));
   1776           }
   1777         }
   1778       } else {
   1779         auto indices_vec = indices.vec<Tindex>();
   1780         auto var_flat = var.flat<T>();
   1781         auto gradient_accum_flat = gradient_accum.flat<T>();
   1782         auto gradient_squared_accum_flat = gradient_squared_accum.flat<T>();
   1783         auto grad_flat = grad.flat<T>();
   1784         const double lr_scalar = lr.scalar<T>()();
   1785         const int64 global_step_scalar = global_step.scalar<int64>()();
   1786         const double l1_scalar = l1.scalar<T>()();
   1787         const double l2_scalar = l2.scalar<T>()();
   1788         const Tindex first_dim_size = var_flat.size();
   1789         const double gs_l1 = global_step_scalar * l1_scalar;
   1790         const double gs_l2_lr = global_step_scalar * l2_scalar * lr_scalar;
   1791 
   1792         for (Tindex i = 0; i < N; i++) {
   1793           const Tindex index = internal::SubtleMustCopy(indices_vec(i));
   1794           OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
   1795                       errors::InvalidArgument(
   1796                           strings::StrCat("Index ", index, " at offset ", i,
   1797                                           " in indices is out of range")));
   1798           T& ga = gradient_accum_flat(index);
   1799           T& da = gradient_squared_accum_flat(index);
   1800           const double g = grad_flat(i);
   1801           ga += g;
   1802           da += g * g;
   1803           if (l1_scalar > 0) {
   1804             var_flat(index) = sgn(-ga) * lr_scalar *
   1805                               std::max((std::abs(ga) - gs_l1), 0.0) /
   1806                               (gs_l2_lr + std::sqrt(da));
   1807           } else {
   1808             var_flat(index) = (-ga * lr_scalar) / (gs_l2_lr + std::sqrt(da));
   1809           }
   1810         }
   1811       }
   1812     }
   1813 
   1814     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   1815   }
   1816 
   1817  private:
   1818   bool use_exclusive_lock_;
   1819 };
   1820 
   1821 #define REGISTER_KERNELS(T, Tindices)                                     \
   1822   REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagradDA")                    \
   1823                               .Device(DEVICE_CPU)                         \
   1824                               .TypeConstraint<T>("T")                     \
   1825                               .TypeConstraint<Tindices>("Tindices"),      \
   1826                           SparseApplyAdagradDAOp<T, Tindices>);           \
   1827   REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagradDA")            \
   1828                               .Device(DEVICE_CPU)                         \
   1829                               .HostMemory("var")                          \
   1830                               .HostMemory("gradient_accumulator")         \
   1831                               .HostMemory("gradient_squared_accumulator") \
   1832                               .TypeConstraint<T>("T")                     \
   1833                               .TypeConstraint<Tindices>("Tindices"),      \
   1834                           SparseApplyAdagradDAOp<T, Tindices>);
   1835 
   1836 REGISTER_KERNELS(float, int32);
   1837 REGISTER_KERNELS(float, int64);
   1838 REGISTER_KERNELS(double, int32);
   1839 REGISTER_KERNELS(double, int64);
   1840 #undef REGISTER_KERNELS
   1841 
   1842 template <typename Device, typename T, bool has_l2_shrinkage>
   1843 class ApplyFtrlOp : public OpKernel {
   1844  public:
   1845   explicit ApplyFtrlOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   1846     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   1847   }
   1848 
   1849   void Compute(OpKernelContext* ctx) override {
   1850     auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
   1851                                                       {0, 1, 2});
   1852 
   1853     Tensor var;
   1854     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   1855                             ctx, 0, use_exclusive_lock_, false, &var));
   1856     Tensor accum;
   1857     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   1858                             ctx, 1, use_exclusive_lock_, false, &accum));
   1859     Tensor linear;
   1860     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   1861                             ctx, 2, use_exclusive_lock_, false, &linear));
   1862     OP_REQUIRES(
   1863         ctx, var.IsInitialized(),
   1864         errors::FailedPrecondition(
   1865             "Attempting to use uninitialized variables: ", requested_input(0)));
   1866     OP_REQUIRES(
   1867         ctx, accum.IsInitialized(),
   1868         errors::FailedPrecondition(
   1869             "Attempting to use uninitialized variables: ", requested_input(1)));
   1870     OP_REQUIRES(
   1871         ctx, linear.IsInitialized(),
   1872         errors::FailedPrecondition(
   1873             "Attempting to use uninitialized variables: ", requested_input(2)));
   1874 
   1875     const Tensor& grad = ctx->input(3);
   1876     OP_REQUIRES(
   1877         ctx, var.shape().IsSameSize(accum.shape()),
   1878         errors::InvalidArgument("var and accum do not have the same shape",
   1879                                 var.shape().DebugString(), " ",
   1880                                 accum.shape().DebugString()));
   1881     OP_REQUIRES(
   1882         ctx, var.shape().IsSameSize(linear.shape()),
   1883         errors::InvalidArgument("var and linear do not have the same shape",
   1884                                 var.shape().DebugString(), " ",
   1885                                 linear.shape().DebugString()));
   1886     OP_REQUIRES(
   1887         ctx, var.shape().IsSameSize(grad.shape()),
   1888         errors::InvalidArgument("var and grad do not have the same shape",
   1889                                 var.shape().DebugString(), " ",
   1890                                 grad.shape().DebugString()));
   1891 
   1892     const Tensor& lr = ctx->input(4);
   1893     OP_REQUIRES(ctx,
   1894                 TensorShapeUtils::IsScalar(lr.shape()) &&
   1895                     lr.scalar<T>()() > static_cast<T>(0),
   1896                 errors::InvalidArgument("lr is not a positive scalar: ",
   1897                                         lr.shape().DebugString()));
   1898     const Tensor& l1 = ctx->input(5);
   1899     OP_REQUIRES(ctx,
   1900                 TensorShapeUtils::IsScalar(l1.shape()) &&
   1901                     l1.scalar<T>()() >= static_cast<T>(0),
   1902                 errors::InvalidArgument("l1 regularization strength is not a "
   1903                                         "non-negative scalar: ",
   1904                                         l1.shape().DebugString()));
   1905     const Tensor& l2 = ctx->input(6);
   1906     OP_REQUIRES(ctx,
   1907                 TensorShapeUtils::IsScalar(l2.shape()) &&
   1908                     l2.scalar<T>()() >= static_cast<T>(0),
   1909                 errors::InvalidArgument("l2 regularization strength is not a "
   1910                                         "non-negative scalar: ",
   1911                                         l2.shape().DebugString()));
   1912     const int lr_power_index = has_l2_shrinkage ? 8 : 7;
   1913     const Tensor& lr_power = ctx->input(lr_power_index);
   1914     OP_REQUIRES(ctx,
   1915                 TensorShapeUtils::IsScalar(lr_power.shape()) &&
   1916                     lr_power.scalar<T>()() <= static_cast<T>(0),
   1917                 errors::InvalidArgument("lr_power is not a"
   1918                                         " non-positive scalar: ",
   1919                                         lr_power.shape().DebugString()));
   1920 
   1921     const Device& device = ctx->template eigen_device<Device>();
   1922     if (has_l2_shrinkage) {
   1923       const Tensor& l2_shrinkage = ctx->input(7);
   1924       OP_REQUIRES(
   1925           ctx,
   1926           TensorShapeUtils::IsScalar(l2_shrinkage.shape()) &&
   1927               l2_shrinkage.scalar<T>()() >= static_cast<T>(0),
   1928           errors::InvalidArgument("l2 shrinkage regularization strength "
   1929                                   "is not a non-negative scalar: ",
   1930                                   l2_shrinkage.shape().DebugString()));
   1931       functor::ApplyFtrlV2<Device, T>()(
   1932           device, var.flat<T>(), accum.flat<T>(), linear.flat<T>(),
   1933           grad.flat<T>(), lr.scalar<T>(), l1.scalar<T>(), l2.scalar<T>(),
   1934           l2_shrinkage.scalar<T>(), lr_power.scalar<T>());
   1935     } else {
   1936       functor::ApplyFtrl<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
   1937                                       linear.flat<T>(), grad.flat<T>(),
   1938                                       lr.scalar<T>(), l1.scalar<T>(),
   1939                                       l2.scalar<T>(), lr_power.scalar<T>());
   1940     }
   1941 
   1942     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   1943   }
   1944 
   1945  private:
   1946   bool use_exclusive_lock_;
   1947 };
   1948 
   1949 #define REGISTER_KERNELS(D, T)                                     \
   1950   REGISTER_KERNEL_BUILDER(                                         \
   1951       Name("ApplyFtrl").Device(DEVICE_##D).TypeConstraint<T>("T"), \
   1952       ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/false>);      \
   1953   REGISTER_KERNEL_BUILDER(                                         \
   1954       Name("ResourceApplyFtrl")                                    \
   1955           .HostMemory("var")                                       \
   1956           .HostMemory("accum")                                     \
   1957           .HostMemory("linear")                                    \
   1958           .Device(DEVICE_##D)                                      \
   1959           .TypeConstraint<T>("T"),                                 \
   1960       ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/false>);
   1961 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
   1962 
   1963 TF_CALL_half(REGISTER_CPU_KERNELS);
   1964 TF_CALL_float(REGISTER_CPU_KERNELS);
   1965 TF_CALL_double(REGISTER_CPU_KERNELS);
   1966 
   1967 #undef REGISTER_CPU_KERNELS
   1968 #undef REGISTER_KERNELS
   1969 
   1970 #define REGISTER_KERNELS(D, T)                                       \
   1971   REGISTER_KERNEL_BUILDER(                                           \
   1972       Name("ApplyFtrlV2").Device(DEVICE_##D).TypeConstraint<T>("T"), \
   1973       ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/true>);         \
   1974   REGISTER_KERNEL_BUILDER(                                           \
   1975       Name("ResourceApplyFtrlV2")                                    \
   1976           .HostMemory("var")                                         \
   1977           .HostMemory("accum")                                       \
   1978           .HostMemory("linear")                                      \
   1979           .Device(DEVICE_##D)                                        \
   1980           .TypeConstraint<T>("T"),                                   \
   1981       ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/true>);
   1982 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
   1983 
   1984 TF_CALL_half(REGISTER_CPU_KERNELS);
   1985 TF_CALL_float(REGISTER_CPU_KERNELS);
   1986 TF_CALL_double(REGISTER_CPU_KERNELS);
   1987 
   1988 #undef REGISTER_CPU_KERNELS
   1989 #undef REGISTER_KERNELS
   1990 
   1991 // Note, this op works on cpu only.
   1992 template <typename Device, typename T, typename Tindex, bool has_l2_shrinkage>
   1993 class SparseApplyFtrlOp : public OpKernel {
   1994  public:
   1995   explicit SparseApplyFtrlOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   1996     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   1997   }
   1998 
   1999   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
   2000     auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
   2001                                                       {0, 1, 2});
   2002     Tensor var;
   2003     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2004                             ctx, 0, use_exclusive_lock_, true, &var));
   2005     Tensor accum;
   2006     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2007                             ctx, 1, use_exclusive_lock_, true, &accum));
   2008     Tensor linear;
   2009     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2010                             ctx, 2, use_exclusive_lock_, true, &linear));
   2011     OP_REQUIRES(
   2012         ctx, var.IsInitialized(),
   2013         errors::FailedPrecondition(
   2014             "Attempting to use uninitialized variables: ", requested_input(0)));
   2015     OP_REQUIRES(
   2016         ctx, accum.IsInitialized(),
   2017         errors::FailedPrecondition(
   2018             "Attempting to use uninitialized variables: ", requested_input(1)));
   2019     OP_REQUIRES(
   2020         ctx, linear.IsInitialized(),
   2021         errors::FailedPrecondition(
   2022             "Attempting to use uninitialized variables: ", requested_input(2)));
   2023     OP_REQUIRES(
   2024         ctx, var.shape().IsSameSize(accum.shape()),
   2025         errors::InvalidArgument("var and accum do not have the same shape",
   2026                                 var.shape().DebugString(), " ",
   2027                                 accum.shape().DebugString()));
   2028     OP_REQUIRES(
   2029         ctx, var.shape().IsSameSize(linear.shape()),
   2030         errors::InvalidArgument("var and linear do not have the same shape",
   2031                                 var.shape().DebugString(), " ",
   2032                                 linear.shape().DebugString()));
   2033     OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
   2034                 errors::InvalidArgument("var must be at least 1 dimensional"));
   2035 
   2036     const Tensor& grad = ctx->input(3);
   2037     const Tensor& indices = ctx->input(4);
   2038     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
   2039                 errors::InvalidArgument("indices must be one-dimensional"));
   2040 
   2041     const Tensor& lr = ctx->input(5);
   2042     OP_REQUIRES(ctx,
   2043                 TensorShapeUtils::IsScalar(lr.shape()) &&
   2044                     lr.scalar<T>()() > static_cast<T>(0),
   2045                 errors::InvalidArgument("lr is not a positive scalar: ",
   2046                                         lr.shape().DebugString()));
   2047 
   2048     const Tensor& l1 = ctx->input(6);
   2049     OP_REQUIRES(ctx,
   2050                 TensorShapeUtils::IsScalar(l1.shape()) &&
   2051                     l1.scalar<T>()() >= static_cast<T>(0),
   2052                 errors::InvalidArgument("l1 regularization strength is not a "
   2053                                         "non-negative scalar: ",
   2054                                         l1.shape().DebugString()));
   2055     const Tensor& l2 = ctx->input(7);
   2056     OP_REQUIRES(ctx,
   2057                 TensorShapeUtils::IsScalar(l2.shape()) &&
   2058                     l2.scalar<T>()() >= static_cast<T>(0),
   2059                 errors::InvalidArgument("l2 regularization strength is not a "
   2060                                         "non-negative scalar: ",
   2061                                         l2.shape().DebugString()));
   2062     const int lr_power_index = has_l2_shrinkage ? 9 : 8;
   2063     const Tensor& lr_power = ctx->input(lr_power_index);
   2064     OP_REQUIRES(ctx,
   2065                 TensorShapeUtils::IsScalar(lr_power.shape()) &&
   2066                     lr_power.scalar<T>()() <= static_cast<T>(0),
   2067                 errors::InvalidArgument("lr_power is not a "
   2068                                         "non-positive scalar: ",
   2069                                         lr_power.shape().DebugString()));
   2070     int64 inner_dim = 1;
   2071     for (int d = 1; d < var.dims(); d++) {
   2072       OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
   2073                   errors::InvalidArgument(strings::StrCat(
   2074                       "var and grad must match in dimension ", d)));
   2075       inner_dim *= grad.dim_size(d);
   2076     }
   2077     const Tindex N = indices.dim_size(0);
   2078     OP_REQUIRES(
   2079         ctx, grad.dim_size(0) == N,
   2080         errors::InvalidArgument(
   2081             "grad must be the same size as indices in the first dimension."));
   2082 
   2083     OP_REQUIRES(ctx, inner_dim > 0,
   2084                 errors::InvalidArgument(
   2085                     "Inner dimension should be greater than zero."));
   2086 
   2087     const Tensor* l2_shrinkage;
   2088     if (has_l2_shrinkage) {
   2089       l2_shrinkage = &ctx->input(8);
   2090       OP_REQUIRES(
   2091           ctx,
   2092           TensorShapeUtils::IsScalar(l2_shrinkage->shape()) &&
   2093               l2_shrinkage->scalar<T>()() >= static_cast<T>(0),
   2094           errors::InvalidArgument("l2 shrinkage regularization strength "
   2095                                   "is not a non-negative scalar: ",
   2096                                   l2_shrinkage->shape().DebugString()));
   2097     }
   2098 
   2099     if (N > 0) {
   2100       if (inner_dim > 1) {
   2101         const Tindex first_dim_size = var.dim_size(0);
   2102         auto indices_vec = indices.vec<Tindex>();
   2103         auto var_flat = var.flat_outer_dims<T>();
   2104         auto accum_flat = accum.flat_outer_dims<T>();
   2105         auto linear_flat = linear.flat_outer_dims<T>();
   2106         auto grad_flat = grad.flat_outer_dims<T>();
   2107         T lr_scalar = lr.scalar<T>()();
   2108         T l1_scalar = l1.scalar<T>()();
   2109         T l2_scalar = l2.scalar<T>()();
   2110         T l2_shrinkage_scalar;
   2111         if (has_l2_shrinkage) {
   2112           l2_shrinkage_scalar = l2_shrinkage->scalar<T>()();
   2113         }
   2114         T lr_power_scalar = lr_power.scalar<T>()();
   2115 
   2116         for (Tindex i = 0; i < N; i++) {
   2117           const Tindex index = internal::SubtleMustCopy(indices_vec(i));
   2118           OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
   2119                       errors::InvalidArgument(
   2120                           strings::StrCat("Index ", index, " at offset ", i,
   2121                                           " in indices is out of range")));
   2122           auto accum = accum_flat.template chip<0>(index);
   2123           auto linear = linear_flat.template chip<0>(index);
   2124           auto grad = grad_flat.template chip<0>(i);
   2125           auto var = var_flat.template chip<0>(index);
   2126 
   2127 // Use a macro to implement the computation here due to the templating of the
   2128 // eigen tensor library.
   2129 #define COMPUTE_FTRL(grad_to_use)                                              \
   2130   auto new_accum = accum + grad_to_use.square();                               \
   2131   if (lr_power_scalar == static_cast<T>(-0.5)) {                               \
   2132     linear +=                                                                  \
   2133         grad_to_use - (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var;     \
   2134   } else {                                                                     \
   2135     linear += grad_to_use - (new_accum.pow(-lr_power_scalar) -                 \
   2136                              accum.pow(-lr_power_scalar)) /                    \
   2137                                 lr_scalar * var;                               \
   2138   }                                                                            \
   2139   auto l1_reg_adjust = linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar);        \
   2140   auto x = l1_reg_adjust - linear;                                             \
   2141   if (lr_power_scalar == static_cast<T>(-0.5)) {                               \
   2142     auto y = new_accum.sqrt() / new_accum.constant(lr_scalar) +                \
   2143              linear.constant(static_cast<T>(2) * l2_scalar);                   \
   2144     var = x / y;                                                               \
   2145   } else {                                                                     \
   2146     auto y = new_accum.pow(-lr_power_scalar) / new_accum.constant(lr_scalar) + \
   2147              linear.constant(static_cast<T>(2) * l2_scalar);                   \
   2148     var = x / y;                                                               \
   2149   }                                                                            \
   2150   accum += grad_to_use.square();
   2151 
   2152           if (has_l2_shrinkage) {
   2153             auto grad_with_shrinkage =
   2154                 grad + static_cast<T>(2) * l2_shrinkage_scalar * var;
   2155             COMPUTE_FTRL(grad_with_shrinkage);
   2156           } else {
   2157             COMPUTE_FTRL(grad);
   2158           }
   2159         }
   2160 #undef COMPUTE_FTRL
   2161       } else {
   2162         T lr_scalar = lr.scalar<T>()();
   2163         T l1_scalar = l1.scalar<T>()();
   2164         T l2_scalar = l2.scalar<T>()();
   2165         T lr_power_scalar = lr_power.scalar<T>()();
   2166         T l2_shrinkage_scalar;
   2167         if (has_l2_shrinkage) {
   2168           l2_shrinkage_scalar = l2_shrinkage->scalar<T>()();
   2169         }
   2170 
   2171         auto indices_vec = indices.vec<Tindex>();
   2172         auto var_flat = var.flat<T>();
   2173         auto accum_flat = accum.flat<T>();
   2174         auto linear_flat = linear.flat<T>();
   2175         auto grad_flat = grad.flat<T>();
   2176         const Tindex first_dim_size = accum_flat.size();
   2177 
   2178         for (Tindex i = 0; i < N; i++) {
   2179           const Tindex index = internal::SubtleMustCopy(indices_vec(i));
   2180           OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
   2181                       errors::InvalidArgument(
   2182                           strings::StrCat("Index ", index, " at offset ", i,
   2183                                           " in indices is out of range")));
   2184           T& a = accum_flat(index);
   2185           T& l = linear_flat(index);
   2186           T& v = var_flat(index);
   2187           T g;
   2188           if (has_l2_shrinkage) {
   2189             g = grad_flat(i) +
   2190                 (static_cast<T>(2) * l2_shrinkage_scalar * var_flat(i));
   2191           } else {
   2192             g = grad_flat(i);
   2193           }
   2194 
   2195           T updated_a = a + g * g;
   2196           using Eigen::numext::pow;
   2197           T sigma = pow(updated_a, -lr_power_scalar) - pow(a, -lr_power_scalar);
   2198           sigma /= lr_scalar;
   2199           T updated_l = l + g - sigma * v;
   2200           v = FtrlCompute(updated_a, updated_l, lr_scalar, l1_scalar, l2_scalar,
   2201                           lr_power_scalar);
   2202           a = updated_a;
   2203           l = updated_l;
   2204         }
   2205       }
   2206     }
   2207 
   2208     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   2209   }
   2210 
   2211  private:
   2212   bool use_exclusive_lock_;
   2213 };
   2214 
   2215 #define REGISTER_KERNELS(T, Tindices)                                         \
   2216   REGISTER_KERNEL_BUILDER(                                                    \
   2217       Name("SparseApplyFtrl")                                                 \
   2218           .Device(DEVICE_CPU)                                                 \
   2219           .TypeConstraint<T>("T")                                             \
   2220           .TypeConstraint<Tindices>("Tindices"),                              \
   2221       SparseApplyFtrlOp<CPUDevice, T, Tindices, /*has_l2_shrinkage=*/false>); \
   2222   REGISTER_KERNEL_BUILDER(                                                    \
   2223       Name("ResourceSparseApplyFtrl")                                         \
   2224           .Device(DEVICE_CPU)                                                 \
   2225           .TypeConstraint<T>("T")                                             \
   2226           .TypeConstraint<Tindices>("Tindices"),                              \
   2227       SparseApplyFtrlOp<CPUDevice, T, Tindices, /*has_l2_shrinkage=*/false>);
   2228 #define REGISTER_CPU_KERNELS(T) \
   2229   REGISTER_KERNELS(T, int32);   \
   2230   REGISTER_KERNELS(T, int64);
   2231 
   2232 TF_CALL_half(REGISTER_CPU_KERNELS);
   2233 TF_CALL_float(REGISTER_CPU_KERNELS);
   2234 TF_CALL_double(REGISTER_CPU_KERNELS);
   2235 
   2236 #undef REGISTER_CPU_KERNELS
   2237 #undef REGISTER_KERNELS
   2238 
   2239 #define REGISTER_KERNELS(T, Tindices)                                        \
   2240   REGISTER_KERNEL_BUILDER(                                                   \
   2241       Name("SparseApplyFtrlV2")                                              \
   2242           .Device(DEVICE_CPU)                                                \
   2243           .TypeConstraint<T>("T")                                            \
   2244           .TypeConstraint<Tindices>("Tindices"),                             \
   2245       SparseApplyFtrlOp<CPUDevice, T, Tindices, /*has_l2_shrinkage=*/true>); \
   2246   REGISTER_KERNEL_BUILDER(                                                   \
   2247       Name("ResourceSparseApplyFtrlV2")                                      \
   2248           .Device(DEVICE_CPU)                                                \
   2249           .TypeConstraint<T>("T")                                            \
   2250           .TypeConstraint<Tindices>("Tindices"),                             \
   2251       SparseApplyFtrlOp<CPUDevice, T, Tindices, /*has_l2_shrinkage=*/true>);
   2252 #define REGISTER_CPU_KERNELS(T) \
   2253   REGISTER_KERNELS(T, int32);   \
   2254   REGISTER_KERNELS(T, int64);
   2255 
   2256 TF_CALL_half(REGISTER_CPU_KERNELS);
   2257 TF_CALL_float(REGISTER_CPU_KERNELS);
   2258 TF_CALL_double(REGISTER_CPU_KERNELS);
   2259 
   2260 #undef REGISTER_CPU_KERNELS
   2261 #undef REGISTER_KERNELS
   2262 
   2263 template <typename Device, typename T>
   2264 class ApplyMomentumOp : public OpKernel {
   2265  public:
   2266   explicit ApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   2267     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   2268     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
   2269   }
   2270 
   2271   void Compute(OpKernelContext* ctx) override {
   2272     auto locks =
   2273         MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
   2274 
   2275     Tensor var;
   2276     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2277                             ctx, 0, use_exclusive_lock_, false, &var));
   2278     Tensor accum;
   2279     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2280                             ctx, 1, use_exclusive_lock_, false, &accum));
   2281     OP_REQUIRES(
   2282         ctx, var.IsInitialized(),
   2283         errors::FailedPrecondition(
   2284             "Attempting to use uninitialized variables: ", requested_input(0)));
   2285     OP_REQUIRES(
   2286         ctx, accum.IsInitialized(),
   2287         errors::FailedPrecondition(
   2288             "Attempting to use uninitialized variables: ", requested_input(1)));
   2289     const Tensor& lr = ctx->input(2);
   2290     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
   2291                 errors::InvalidArgument("lr is not a scalar: ",
   2292                                         lr.shape().DebugString()));
   2293     const Tensor& grad = ctx->input(3);
   2294     OP_REQUIRES(
   2295         ctx, var.shape().IsSameSize(accum.shape()),
   2296         errors::InvalidArgument("var and accum do not have the same shape",
   2297                                 var.shape().DebugString(), " ",
   2298                                 accum.shape().DebugString()));
   2299     OP_REQUIRES(
   2300         ctx, var.shape().IsSameSize(grad.shape()),
   2301         errors::InvalidArgument("var and grad do not have the same shape",
   2302                                 var.shape().DebugString(), " ",
   2303                                 grad.shape().DebugString()));
   2304 
   2305     const Tensor& momentum = ctx->input(4);
   2306     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
   2307                 errors::InvalidArgument("momentum is not a scalar: ",
   2308                                         momentum.shape().DebugString()));
   2309 
   2310     const Device& device = ctx->template eigen_device<Device>();
   2311     functor::ApplyMomentum<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
   2312                                         lr.scalar<T>(), grad.flat<T>(),
   2313                                         momentum.scalar<T>(), use_nesterov_);
   2314     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   2315   }
   2316 
   2317  private:
   2318   bool use_exclusive_lock_;
   2319   bool use_nesterov_;
   2320 };
   2321 
   2322 #define REGISTER_KERNELS(D, T)                                         \
   2323   REGISTER_KERNEL_BUILDER(                                             \
   2324       Name("ApplyMomentum").Device(DEVICE_##D).TypeConstraint<T>("T"), \
   2325       ApplyMomentumOp<D##Device, T>);                                  \
   2326   REGISTER_KERNEL_BUILDER(Name("ResourceApplyMomentum")                \
   2327                               .Device(DEVICE_##D)                      \
   2328                               .HostMemory("var")                       \
   2329                               .HostMemory("accum")                     \
   2330                               .TypeConstraint<T>("T"),                 \
   2331                           ApplyMomentumOp<D##Device, T>);
   2332 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
   2333 
   2334 TF_CALL_half(REGISTER_CPU_KERNELS);
   2335 TF_CALL_float(REGISTER_CPU_KERNELS);
   2336 TF_CALL_double(REGISTER_CPU_KERNELS);
   2337 
   2338 #if GOOGLE_CUDA
   2339 // Forward declarations of the functor specializations for GPU.
   2340 namespace functor {
   2341 #define DECLARE_GPU_SPEC(T)                                               \
   2342   template <>                                                             \
   2343   void ApplyMomentum<GPUDevice, T>::operator()(                           \
   2344       const GPUDevice& d, typename TTypes<T>::Flat var,                   \
   2345       typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
   2346       typename TTypes<T>::ConstFlat grad,                                 \
   2347       typename TTypes<T>::ConstScalar momentum, bool use_nesterov);       \
   2348   extern template struct ApplyMomentum<GPUDevice, T>;
   2349 DECLARE_GPU_SPEC(Eigen::half);
   2350 DECLARE_GPU_SPEC(float);
   2351 DECLARE_GPU_SPEC(double);
   2352 #undef DECLARE_GPU_SPEC
   2353 }  // namespace functor
   2354 
   2355 REGISTER_KERNELS(GPU, Eigen::half);
   2356 REGISTER_KERNELS(GPU, float);
   2357 REGISTER_KERNELS(GPU, double);
   2358 #endif
   2359 #undef REGISTER_CPU_KERNELS
   2360 #undef REGISTER_KERNELS
   2361 
   2362 // Note, this op works on cpu only.
   2363 template <typename T, typename Tindex>
   2364 class SparseApplyMomentumOp : public OpKernel {
   2365  public:
   2366   explicit SparseApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   2367     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   2368     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
   2369   }
   2370 
   2371   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
   2372     auto locks =
   2373         MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
   2374 
   2375     Tensor var;
   2376     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
   2377                             ctx, 0, use_exclusive_lock_, true, &var));
   2378     Tensor accum;
   2379     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
   2380                             ctx, 1, use_exclusive_lock_, true, &accum));
   2381     OP_REQUIRES(
   2382         ctx, var.IsInitialized(),
   2383         errors::FailedPrecondition(
   2384             "Attempting to use uninitialized variables: ", requested_input(0)));
   2385     OP_REQUIRES(
   2386         ctx, accum.IsInitialized(),
   2387         errors::FailedPrecondition(
   2388             "Attempting to use uninitialized variables: ", requested_input(1)));
   2389     OP_REQUIRES(
   2390         ctx, var.shape().IsSameSize(accum.shape()),
   2391         errors::InvalidArgument("var and accum do not have the same shape",
   2392                                 var.shape().DebugString(), " ",
   2393                                 accum.shape().DebugString()));
   2394     OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
   2395                 errors::InvalidArgument("var must be at least 1 dimensional"));
   2396 
   2397     const Tensor& lr = ctx->input(2);
   2398     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
   2399                 errors::InvalidArgument("lr is not a scalar : ",
   2400                                         lr.shape().DebugString()));
   2401     const Tensor& grad = ctx->input(3);
   2402     const Tensor& indices = ctx->input(4);
   2403     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
   2404                 errors::InvalidArgument("indices must be one-dimensional"));
   2405 
   2406     for (int d = 1; d < var.dims(); d++) {
   2407       OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
   2408                   errors::InvalidArgument(strings::StrCat(
   2409                       "var and grad must match in dimension ", d)));
   2410     }
   2411     const Tindex N = indices.dim_size(0);
   2412     OP_REQUIRES(
   2413         ctx, grad.dim_size(0) == N,
   2414         errors::InvalidArgument(
   2415             "grad must be the same size as indices in the first dimension."));
   2416 
   2417     const Tensor& momentum = ctx->input(5);
   2418     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
   2419                 errors::InvalidArgument("momentum is not a scalar: ",
   2420                                         momentum.shape().DebugString()));
   2421 
   2422     if (N > 0) {
   2423       const Tindex first_dim_size = var.dim_size(0);
   2424       auto indices_vec = indices.vec<Tindex>();
   2425       auto var_flat = var.flat_outer_dims<T>();
   2426       auto accum_flat = accum.flat_outer_dims<T>();
   2427       auto grad_flat = grad.flat_outer_dims<T>();
   2428       T lr_scalar = lr.scalar<T>()();
   2429       T momentum_scalar = momentum.scalar<T>()();
   2430 
   2431       for (Tindex i = 0; i < N; i++) {
   2432         const Tindex index = internal::SubtleMustCopy(indices_vec(i));
   2433         OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
   2434                     errors::InvalidArgument(
   2435                         strings::StrCat("Index ", index, " at offset ", i,
   2436                                         " in indices is out of range")));
   2437         auto a = accum_flat.template chip<0>(index);
   2438         auto g = grad_flat.template chip<0>(i);
   2439         auto v = var_flat.template chip<0>(index);
   2440         a = a * a.constant(momentum_scalar) + g;
   2441         if (use_nesterov_) {
   2442           v -= g.constant(lr_scalar) * g +
   2443                a.constant(lr_scalar) * a.constant(momentum_scalar) * a;
   2444         } else {
   2445           v -= a.constant(lr_scalar) * a;
   2446         }
   2447       }
   2448     }
   2449 
   2450     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   2451   }
   2452 
   2453  private:
   2454   bool use_exclusive_lock_;
   2455   bool use_nesterov_;
   2456 };
   2457 
   2458 #define REGISTER_KERNELS(T, Tindices)                                \
   2459   REGISTER_KERNEL_BUILDER(Name("SparseApplyMomentum")                \
   2460                               .Device(DEVICE_CPU)                    \
   2461                               .TypeConstraint<T>("T")                \
   2462                               .TypeConstraint<Tindices>("Tindices"), \
   2463                           SparseApplyMomentumOp<T, Tindices>);       \
   2464   REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyMomentum")        \
   2465                               .Device(DEVICE_CPU)                    \
   2466                               .TypeConstraint<T>("T")                \
   2467                               .TypeConstraint<Tindices>("Tindices"), \
   2468                           SparseApplyMomentumOp<T, Tindices>);
   2469 #define REGISTER_CPU_KERNELS(T) \
   2470   REGISTER_KERNELS(T, int32);   \
   2471   REGISTER_KERNELS(T, int64);
   2472 
   2473 TF_CALL_half(REGISTER_CPU_KERNELS);
   2474 TF_CALL_float(REGISTER_CPU_KERNELS);
   2475 TF_CALL_double(REGISTER_CPU_KERNELS);
   2476 
   2477 #undef REGISTER_CPU_KERNELS
   2478 #undef REGISTER_KERNELS
   2479 
   2480 template <typename Device, typename T>
   2481 class ApplyAdamOp : public OpKernel {
   2482  public:
   2483   explicit ApplyAdamOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   2484     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   2485     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
   2486   }
   2487 
   2488   void Compute(OpKernelContext* ctx) override {
   2489     auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
   2490                                                       {0, 1, 2});
   2491 
   2492     Tensor var;
   2493     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2494                             ctx, 0, use_exclusive_lock_, false, &var));
   2495     Tensor m;
   2496     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2497                             ctx, 1, use_exclusive_lock_, false, &m));
   2498     Tensor v;
   2499     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2500                             ctx, 2, use_exclusive_lock_, false, &v));
   2501     OP_REQUIRES(
   2502         ctx, var.IsInitialized(),
   2503         errors::FailedPrecondition(
   2504             "Attempting to use uninitialized variables: ", requested_input(0)));
   2505     OP_REQUIRES(
   2506         ctx, m.IsInitialized(),
   2507         errors::FailedPrecondition(
   2508             "Attempting to use uninitialized variables: ", requested_input(1)));
   2509     OP_REQUIRES(
   2510         ctx, v.IsInitialized(),
   2511         errors::FailedPrecondition(
   2512             "Attempting to use uninitialized variables: ", requested_input(2)));
   2513 
   2514     const Tensor& beta1_power = ctx->input(3);
   2515     const Tensor& beta2_power = ctx->input(4);
   2516     const Tensor& lr = ctx->input(5);
   2517     const Tensor& beta1 = ctx->input(6);
   2518     const Tensor& beta2 = ctx->input(7);
   2519     const Tensor& epsilon = ctx->input(8);
   2520 
   2521     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()),
   2522                 errors::InvalidArgument("beta1_power is not a scalar: ",
   2523                                         beta1_power.shape().DebugString()));
   2524     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power.shape()),
   2525                 errors::InvalidArgument("beta2_power is not a scalar: ",
   2526                                         beta2_power.shape().DebugString()));
   2527     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
   2528                 errors::InvalidArgument("lr is not a scalar : ",
   2529                                         lr.shape().DebugString()));
   2530     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()),
   2531                 errors::InvalidArgument("beta1 is not a scalar: ",
   2532                                         beta1.shape().DebugString()));
   2533     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()),
   2534                 errors::InvalidArgument("beta2 is not a scalar: ",
   2535                                         beta2.shape().DebugString()));
   2536     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
   2537                 errors::InvalidArgument("epsilon is not a scalar: ",
   2538                                         epsilon.shape().DebugString()));
   2539 
   2540     const Tensor& grad = ctx->input(9);
   2541     OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
   2542                 errors::InvalidArgument("var and m do not have the same shape",
   2543                                         var.shape().DebugString(), " ",
   2544                                         m.shape().DebugString()));
   2545     OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()),
   2546                 errors::InvalidArgument("var and v do not have the same shape",
   2547                                         var.shape().DebugString(), " ",
   2548                                         v.shape().DebugString()));
   2549     OP_REQUIRES(
   2550         ctx, var.shape().IsSameSize(grad.shape()),
   2551         errors::InvalidArgument("var and grad do not have the same shape",
   2552                                 var.shape().DebugString(), " ",
   2553                                 grad.shape().DebugString()));
   2554 
   2555     const Device& device = ctx->template eigen_device<Device>();
   2556     functor::ApplyAdam<Device, T>()(
   2557         device, var.flat<T>(), m.flat<T>(), v.flat<T>(),
   2558         beta1_power.scalar<T>(), beta2_power.scalar<T>(), lr.scalar<T>(),
   2559         beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(),
   2560         grad.flat<T>(), use_nesterov_);
   2561 
   2562     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   2563   }
   2564 
   2565  private:
   2566   bool use_exclusive_lock_;
   2567   bool use_nesterov_;
   2568 };
   2569 
   2570 #ifdef TENSORFLOW_USE_SYCL
   2571 template <typename T>
   2572 class ApplyAdamOp<SYCLDevice, T> : public OpKernel {
   2573  public:
   2574   explicit ApplyAdamOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   2575     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   2576   }
   2577 
   2578   void Compute(OpKernelContext* ctx) override {
   2579     auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
   2580                                                       {0, 1, 2});
   2581 
   2582     Tensor var;
   2583     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<SYCLDevice, T>(
   2584                             ctx, 0, use_exclusive_lock_, false, &var));
   2585     Tensor m;
   2586     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<SYCLDevice, T>(
   2587                             ctx, 1, use_exclusive_lock_, false, &m));
   2588     Tensor v;
   2589     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<SYCLDevice, T>(
   2590                             ctx, 2, use_exclusive_lock_, false, &v));
   2591     OP_REQUIRES(
   2592         ctx, var.IsInitialized(),
   2593         errors::FailedPrecondition(
   2594             "Attempting to use uninitialized variables: ", requested_input(0)));
   2595     OP_REQUIRES(
   2596         ctx, m.IsInitialized(),
   2597         errors::FailedPrecondition(
   2598             "Attempting to use uninitialized variables: ", requested_input(1)));
   2599     OP_REQUIRES(
   2600         ctx, v.IsInitialized(),
   2601         errors::FailedPrecondition(
   2602             "Attempting to use uninitialized variables: ", requested_input(2)));
   2603 
   2604     const Tensor& beta1_power_dev = ctx->input(3);
   2605     const Tensor& beta2_power_dev = ctx->input(4);
   2606     const Tensor& lr_dev = ctx->input(5);
   2607     const Tensor& beta1_dev = ctx->input(6);
   2608     const Tensor& beta2_dev = ctx->input(7);
   2609     const Tensor& epsilon_dev = ctx->input(8);
   2610 
   2611     T beta1_power = 0;
   2612     T beta2_power = 0;
   2613     T lr = 0;
   2614     T beta1 = 0;
   2615     T beta2 = 0;
   2616     T epsilon = 0;
   2617 
   2618     auto device = ctx->eigen_sycl_device();
   2619     auto size = sizeof(T);
   2620     auto src_ptr = GetBase(&beta1_power_dev);
   2621     device.memcpyDeviceToHost(&beta1_power, static_cast<const T*>(src_ptr),
   2622                               size);
   2623 
   2624     src_ptr = GetBase(&beta2_power_dev);
   2625     device.memcpyDeviceToHost(&beta2_power, static_cast<const T*>(src_ptr),
   2626                               size);
   2627 
   2628     src_ptr = GetBase(&lr_dev);
   2629     device.memcpyDeviceToHost(&lr, static_cast<const T*>(src_ptr), size);
   2630 
   2631     src_ptr = GetBase(&beta1_dev);
   2632     device.memcpyDeviceToHost(&beta1, static_cast<const T*>(src_ptr), size);
   2633 
   2634     src_ptr = GetBase(&beta2_dev);
   2635     device.memcpyDeviceToHost(&beta2, static_cast<const T*>(src_ptr), size);
   2636 
   2637     src_ptr = GetBase(&epsilon_dev);
   2638     device.memcpyDeviceToHost(&epsilon, static_cast<const T*>(src_ptr), size);
   2639 
   2640     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_dev.shape()),
   2641                 errors::InvalidArgument("beta1_power is not a scalar: ",
   2642                                         beta1_power_dev.shape().DebugString()));
   2643     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power_dev.shape()),
   2644                 errors::InvalidArgument("beta2_power is not a scalar: ",
   2645                                         beta2_power_dev.shape().DebugString()));
   2646     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_dev.shape()),
   2647                 errors::InvalidArgument("lr is not a scalar : ",
   2648                                         lr_dev.shape().DebugString()));
   2649     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_dev.shape()),
   2650                 errors::InvalidArgument("beta1 is not a scalar: ",
   2651                                         beta1_dev.shape().DebugString()));
   2652     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_dev.shape()),
   2653                 errors::InvalidArgument("beta2 is not a scalar: ",
   2654                                         beta2_dev.shape().DebugString()));
   2655     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_dev.shape()),
   2656                 errors::InvalidArgument("epsilon is not a scalar: ",
   2657                                         epsilon_dev.shape().DebugString()));
   2658 
   2659     const Tensor& grad = ctx->input(9);
   2660 
   2661     OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
   2662                 errors::InvalidArgument("var and m do not have the same shape",
   2663                                         var.shape().DebugString(), " ",
   2664                                         m.shape().DebugString()));
   2665     OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()),
   2666                 errors::InvalidArgument("var and v do not have the same shape",
   2667                                         var.shape().DebugString(), " ",
   2668                                         v.shape().DebugString()));
   2669     OP_REQUIRES(
   2670         ctx, var.shape().IsSameSize(grad.shape()),
   2671         errors::InvalidArgument("var and grad do not have the same shape",
   2672                                 var.shape().DebugString(), " ",
   2673                                 grad.shape().DebugString()));
   2674 
   2675     functor::ApplyAdamSYCL<T>()(device, var.flat<T>(), m.flat<T>(), v.flat<T>(),
   2676                                 beta1_power, beta2_power, lr, beta1, beta2,
   2677                                 epsilon, grad.flat<T>());
   2678 
   2679     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   2680   }
   2681 
   2682  private:
   2683   bool use_exclusive_lock_;
   2684 };
   2685 #endif  // TENSORFLOW_USE_SYCL
   2686 
   2687 #define REGISTER_KERNELS(D, T)                                     \
   2688   REGISTER_KERNEL_BUILDER(                                         \
   2689       Name("ApplyAdam").Device(DEVICE_##D).TypeConstraint<T>("T"), \
   2690       ApplyAdamOp<D##Device, T>);                                  \
   2691   REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdam")                \
   2692                               .HostMemory("var")                   \
   2693                               .HostMemory("m")                     \
   2694                               .HostMemory("v")                     \
   2695                               .Device(DEVICE_##D)                  \
   2696                               .TypeConstraint<T>("T"),             \
   2697                           ApplyAdamOp<D##Device, T>);
   2698 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
   2699 
   2700 TF_CALL_half(REGISTER_CPU_KERNELS);
   2701 TF_CALL_float(REGISTER_CPU_KERNELS);
   2702 TF_CALL_double(REGISTER_CPU_KERNELS);
   2703 
   2704 #ifdef TENSORFLOW_USE_SYCL
   2705 #define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T);
   2706 
   2707 TF_CALL_float(REGISTER_SYCL_KERNELS);
   2708 TF_CALL_double(REGISTER_SYCL_KERNELS);
   2709 #endif
   2710 
   2711 #if GOOGLE_CUDA
   2712 // Forward declarations of the functor specializations for GPU.
   2713 namespace functor {
   2714 #define DECLARE_GPU_SPEC(T)                                   \
   2715   template <>                                                 \
   2716   void ApplyAdam<GPUDevice, T>::operator()(                   \
   2717       const GPUDevice& d, typename TTypes<T>::Flat var,       \
   2718       typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \
   2719       typename TTypes<T>::ConstScalar beta1_power,            \
   2720       typename TTypes<T>::ConstScalar beta2_power,            \
   2721       typename TTypes<T>::ConstScalar lr,                     \
   2722       typename TTypes<T>::ConstScalar beta1,                  \
   2723       typename TTypes<T>::ConstScalar beta2,                  \
   2724       typename TTypes<T>::ConstScalar epsilon,                \
   2725       typename TTypes<T>::ConstFlat grad, bool use_nesterov); \
   2726   extern template struct ApplyAdam<GPUDevice, T>;
   2727 DECLARE_GPU_SPEC(Eigen::half);
   2728 DECLARE_GPU_SPEC(float);
   2729 DECLARE_GPU_SPEC(double);
   2730 #undef DECLARE_GPU_SPEC
   2731 }  // namespace functor
   2732 
   2733 REGISTER_KERNELS(GPU, Eigen::half);
   2734 REGISTER_KERNELS(GPU, float);
   2735 REGISTER_KERNELS(GPU, double);
   2736 #endif
   2737 #undef REGISTER_CPU_KERNELS
   2738 #undef REGISTER_KERNELS
   2739 
   2740 template <typename Device, typename T>
   2741 class ApplyRMSPropOp : public OpKernel {
   2742  public:
   2743   explicit ApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   2744     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   2745   }
   2746 
   2747   void Compute(OpKernelContext* ctx) override {
   2748     auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
   2749                                                       {0, 1, 2});
   2750 
   2751     Tensor var;
   2752     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2753                             ctx, 0, use_exclusive_lock_, false, &var));
   2754     Tensor ms;
   2755     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2756                             ctx, 1, use_exclusive_lock_, false, &ms));
   2757     Tensor mom;
   2758     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2759                             ctx, 2, use_exclusive_lock_, false, &mom));
   2760 
   2761     OP_REQUIRES(
   2762         ctx, var.IsInitialized(),
   2763         errors::FailedPrecondition(
   2764             "Attempting to use uninitialized variables: ", requested_input(0)));
   2765     OP_REQUIRES(
   2766         ctx, ms.IsInitialized(),
   2767         errors::FailedPrecondition(
   2768             "Attempting to use uninitialized variables: ", requested_input(1)));
   2769     OP_REQUIRES(
   2770         ctx, mom.IsInitialized(),
   2771         errors::FailedPrecondition(
   2772             "Attempting to use uninitialized variables: ", requested_input(2)));
   2773 
   2774     const Tensor& lr = ctx->input(3);
   2775     const Tensor& rho = ctx->input(4);
   2776     const Tensor& momentum = ctx->input(5);
   2777     const Tensor& epsilon = ctx->input(6);
   2778     const Tensor& grad = ctx->input(7);
   2779 
   2780     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
   2781                 errors::InvalidArgument("lr is not a scalar : ",
   2782                                         lr.shape().DebugString()));
   2783     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
   2784                 errors::InvalidArgument("rho is not a scalar: ",
   2785                                         rho.shape().DebugString()));
   2786     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
   2787                 errors::InvalidArgument("momentum is not a scalar: ",
   2788                                         momentum.shape().DebugString()));
   2789     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
   2790                 errors::InvalidArgument("epsilon is not a scalar: ",
   2791                                         epsilon.shape().DebugString()));
   2792 
   2793     OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
   2794                 errors::InvalidArgument("var and ms do not have the same shape",
   2795                                         var.shape().DebugString(), " ",
   2796                                         ms.shape().DebugString()));
   2797 
   2798     OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
   2799                 errors::InvalidArgument(
   2800                     "var and mom do not have the same shape",
   2801                     var.shape().DebugString(), " ", mom.shape().DebugString()));
   2802 
   2803     OP_REQUIRES(
   2804         ctx, var.shape().IsSameSize(grad.shape()),
   2805         errors::InvalidArgument("var and grad do not have the same shape",
   2806                                 var.shape().DebugString(), " ",
   2807                                 grad.shape().DebugString()));
   2808 
   2809     const Device& device = ctx->template eigen_device<Device>();
   2810     functor::ApplyRMSProp<Device, T>()(device, var.flat<T>(), ms.flat<T>(),
   2811                                        mom.flat<T>(), lr.scalar<T>(),
   2812                                        rho.scalar<T>(), momentum.scalar<T>(),
   2813                                        epsilon.scalar<T>(), grad.flat<T>());
   2814 
   2815     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   2816   }
   2817 
   2818  private:
   2819   bool use_exclusive_lock_;
   2820 };
   2821 
   2822 template <typename Device, typename T>
   2823 class ApplyCenteredRMSPropOp : public OpKernel {
   2824  public:
   2825   explicit ApplyCenteredRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   2826     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   2827   }
   2828 
   2829   void Compute(OpKernelContext* ctx) override {
   2830     auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
   2831                                                       {0, 1, 2, 3});
   2832 
   2833     Tensor var;
   2834     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2835                             ctx, 0, use_exclusive_lock_, false, &var));
   2836     Tensor mg;
   2837     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2838                             ctx, 1, use_exclusive_lock_, false, &mg));
   2839     Tensor ms;
   2840     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2841                             ctx, 2, use_exclusive_lock_, false, &ms));
   2842     Tensor mom;
   2843     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   2844                             ctx, 3, use_exclusive_lock_, false, &mom));
   2845 
   2846     OP_REQUIRES(
   2847         ctx, var.IsInitialized(),
   2848         errors::FailedPrecondition(
   2849             "Attempting to use uninitialized variables: ", requested_input(0)));
   2850     OP_REQUIRES(
   2851         ctx, mg.IsInitialized(),
   2852         errors::FailedPrecondition(
   2853             "Attempting to use uninitialized variables: ", requested_input(1)));
   2854     OP_REQUIRES(
   2855         ctx, ms.IsInitialized(),
   2856         errors::FailedPrecondition(
   2857             "Attempting to use uninitialized variables: ", requested_input(2)));
   2858     OP_REQUIRES(
   2859         ctx, mom.IsInitialized(),
   2860         errors::FailedPrecondition(
   2861             "Attempting to use uninitialized variables: ", requested_input(3)));
   2862 
   2863     const Tensor& lr = ctx->input(4);
   2864     const Tensor& rho = ctx->input(5);
   2865     const Tensor& momentum = ctx->input(6);
   2866     const Tensor& epsilon = ctx->input(7);
   2867     const Tensor& grad = ctx->input(8);
   2868 
   2869     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
   2870                 errors::InvalidArgument("lr is not a scalar : ",
   2871                                         lr.shape().DebugString()));
   2872     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
   2873                 errors::InvalidArgument("rho is not a scalar: ",
   2874                                         rho.shape().DebugString()));
   2875     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
   2876                 errors::InvalidArgument("momentum is not a scalar: ",
   2877                                         momentum.shape().DebugString()));
   2878     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
   2879                 errors::InvalidArgument("epsilon is not a scalar: ",
   2880                                         epsilon.shape().DebugString()));
   2881 
   2882     OP_REQUIRES(ctx, var.shape().IsSameSize(mg.shape()),
   2883                 errors::InvalidArgument("var and mg do not have the same shape",
   2884                                         var.shape().DebugString(), " ",
   2885                                         ms.shape().DebugString()));
   2886 
   2887     OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
   2888                 errors::InvalidArgument("var and ms do not have the same shape",
   2889                                         var.shape().DebugString(), " ",
   2890                                         ms.shape().DebugString()));
   2891 
   2892     OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
   2893                 errors::InvalidArgument(
   2894                     "var and mom do not have the same shape",
   2895                     var.shape().DebugString(), " ", mom.shape().DebugString()));
   2896 
   2897     OP_REQUIRES(
   2898         ctx, var.shape().IsSameSize(grad.shape()),
   2899         errors::InvalidArgument("var and grad do not have the same shape",
   2900                                 var.shape().DebugString(), " ",
   2901                                 grad.shape().DebugString()));
   2902 
   2903     const Device& device = ctx->template eigen_device<Device>();
   2904     functor::ApplyCenteredRMSProp<Device, T>()(
   2905         device, var.flat<T>(), mg.flat<T>(), ms.flat<T>(), mom.flat<T>(),
   2906         lr.scalar<T>(), rho.scalar<T>(), momentum.scalar<T>(),
   2907         epsilon.scalar<T>(), grad.flat<T>());
   2908     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   2909   }
   2910 
   2911  private:
   2912   bool use_exclusive_lock_;
   2913 };
   2914 
   2915 #define REGISTER_KERNELS(D, T)                                                \
   2916   REGISTER_KERNEL_BUILDER(                                                    \
   2917       Name("ApplyRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"),         \
   2918       ApplyRMSPropOp<D##Device, T>);                                          \
   2919   REGISTER_KERNEL_BUILDER(                                                    \
   2920       Name("ApplyCenteredRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"), \
   2921       ApplyCenteredRMSPropOp<D##Device, T>);                                  \
   2922   REGISTER_KERNEL_BUILDER(Name("ResourceApplyRMSProp")                        \
   2923                               .Device(DEVICE_##D)                             \
   2924                               .HostMemory("var")                              \
   2925                               .HostMemory("ms")                               \
   2926                               .HostMemory("mom")                              \
   2927                               .TypeConstraint<T>("T"),                        \
   2928                           ApplyRMSPropOp<D##Device, T>);                      \
   2929   REGISTER_KERNEL_BUILDER(Name("ResourceApplyCenteredRMSProp")                \
   2930                               .Device(DEVICE_##D)                             \
   2931                               .HostMemory("var")                              \
   2932                               .HostMemory("mg")                               \
   2933                               .HostMemory("ms")                               \
   2934                               .HostMemory("mom")                              \
   2935                               .TypeConstraint<T>("T"),                        \
   2936                           ApplyCenteredRMSPropOp<D##Device, T>);
   2937 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
   2938 
   2939 TF_CALL_half(REGISTER_CPU_KERNELS);
   2940 TF_CALL_float(REGISTER_CPU_KERNELS);
   2941 TF_CALL_double(REGISTER_CPU_KERNELS);
   2942 
   2943 #if GOOGLE_CUDA
   2944 // Forward declarations of the functor specializations for GPU.
   2945 namespace functor {
   2946 #define DECLARE_GPU_SPEC(T)                                                    \
   2947   template <>                                                                  \
   2948   void ApplyRMSProp<GPUDevice, T>::operator()(                                 \
   2949       const GPUDevice& d, typename TTypes<T>::Flat var,                        \
   2950       typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,               \
   2951       typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar rho, \
   2952       typename TTypes<T>::ConstScalar momentum,                                \
   2953       typename TTypes<T>::ConstScalar epsilon,                                 \
   2954       typename TTypes<T>::ConstFlat grad);                                     \
   2955   extern template struct ApplyRMSProp<GPUDevice, T>;                           \
   2956   template <>                                                                  \
   2957   void ApplyCenteredRMSProp<GPUDevice, T>::operator()(                         \
   2958       const GPUDevice& d, typename TTypes<T>::Flat var,                        \
   2959       typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,                \
   2960       typename TTypes<T>::Flat mom, typename TTypes<T>::ConstScalar lr,        \
   2961       typename TTypes<T>::ConstScalar rho,                                     \
   2962       typename TTypes<T>::ConstScalar momentum,                                \
   2963       typename TTypes<T>::ConstScalar epsilon,                                 \
   2964       typename TTypes<T>::ConstFlat grad);                                     \
   2965   extern template struct ApplyCenteredRMSProp<GPUDevice, T>;
   2966 DECLARE_GPU_SPEC(Eigen::half);
   2967 DECLARE_GPU_SPEC(float);
   2968 DECLARE_GPU_SPEC(double);
   2969 #undef DECLARE_GPU_SPEC
   2970 }  // namespace functor
   2971 
   2972 REGISTER_KERNELS(GPU, Eigen::half);
   2973 REGISTER_KERNELS(GPU, float);
   2974 REGISTER_KERNELS(GPU, double);
   2975 #endif
   2976 #undef REGISTER_CPU_KERNELS
   2977 #undef REGISTER_KERNELS
   2978 
   2979 // Note, this op works on cpu only.
   2980 template <typename T, typename Tindex>
   2981 class SparseApplyRMSPropOp : public OpKernel {
   2982  public:
   2983   explicit SparseApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   2984     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   2985   }
   2986 
   2987   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
   2988     auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
   2989                                                       {0, 1, 2});
   2990 
   2991     Tensor var;
   2992     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
   2993                             ctx, 0, use_exclusive_lock_, true, &var));
   2994     Tensor ms;
   2995     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
   2996                             ctx, 1, use_exclusive_lock_, true, &ms));
   2997     Tensor mom;
   2998     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
   2999                             ctx, 2, use_exclusive_lock_, true, &mom));
   3000 
   3001     OP_REQUIRES(
   3002         ctx, var.IsInitialized(),
   3003         errors::FailedPrecondition(
   3004             "Attempting to use uninitialized variables: ", requested_input(0)));
   3005     OP_REQUIRES(
   3006         ctx, ms.IsInitialized(),
   3007         errors::FailedPrecondition(
   3008             "Attempting to use uninitialized variables: ", requested_input(1)));
   3009     OP_REQUIRES(
   3010         ctx, mom.IsInitialized(),
   3011         errors::FailedPrecondition(
   3012             "Attempting to use uninitialized variables: ", requested_input(2)));
   3013 
   3014     const Tensor& lr = ctx->input(3);
   3015     const Tensor& rho = ctx->input(4);
   3016     const Tensor& momentum = ctx->input(5);
   3017     const Tensor& epsilon = ctx->input(6);
   3018     const Tensor& grad = ctx->input(7);
   3019     const Tensor& indices = ctx->input(8);
   3020 
   3021     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
   3022                 errors::InvalidArgument("lr is not a scalar: ",
   3023                                         lr.shape().DebugString()));
   3024     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
   3025                 errors::InvalidArgument("rho is not a scalar: ",
   3026                                         rho.shape().DebugString()));
   3027     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
   3028                 errors::InvalidArgument("momentum is not a scalar: ",
   3029                                         momentum.shape().DebugString()));
   3030     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
   3031                 errors::InvalidArgument("epsilon is not a scalar: ",
   3032                                         epsilon.shape().DebugString()));
   3033 
   3034     OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
   3035                 errors::InvalidArgument("var and ms do not have the same shape",
   3036                                         var.shape().DebugString(), " ",
   3037                                         ms.shape().DebugString()));
   3038 
   3039     OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
   3040                 errors::InvalidArgument(
   3041                     "var and mom do not have the same shape",
   3042                     var.shape().DebugString(), " ", mom.shape().DebugString()));
   3043 
   3044     OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
   3045                 errors::InvalidArgument("var must be at least 1 dimensional"));
   3046 
   3047     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
   3048                 errors::InvalidArgument("indices must be one-dimensional"));
   3049 
   3050     for (int d = 1; d < var.dims(); d++) {
   3051       OP_REQUIRES(
   3052           ctx, var.dim_size(d) == grad.dim_size(d),
   3053           errors::InvalidArgument("var and grad must match in dimension ", d));
   3054     }
   3055     const Tindex N = indices.dim_size(0);
   3056     OP_REQUIRES(
   3057         ctx, grad.dim_size(0) == N,
   3058         errors::InvalidArgument(
   3059             "grad must be the same size as indices in the first dimension."));
   3060 
   3061     if (N > 0) {
   3062       const Tindex first_dim_size = var.dim_size(0);
   3063       // Validate all the indices are in range
   3064       auto indices_vec = indices.vec<Tindex>();
   3065       for (Tindex i = 0; i < N; i++) {
   3066         const Tindex index = indices_vec(i);
   3067         OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
   3068                     errors::InvalidArgument(
   3069                         strings::StrCat("Index ", index, " at offset ", i,
   3070                                         " in indices is out of range")));
   3071       }
   3072 
   3073       auto var_flat = var.flat_outer_dims<T>();
   3074       auto ms_flat = ms.flat_outer_dims<T>();
   3075       auto mom_flat = mom.flat_outer_dims<T>();
   3076       auto grad_flat = grad.flat_outer_dims<T>();
   3077       const T lr_scalar = lr.scalar<T>()();
   3078       const T rho_scalar = rho.scalar<T>()();
   3079       const T epsilon_scalar = epsilon.scalar<T>()();
   3080       const T momentum_scalar = momentum.scalar<T>()();
   3081 
   3082       for (Tindex i = 0; i < N; i++) {
   3083         const Tindex index = indices_vec(i);
   3084 
   3085         auto ms_ = ms_flat.template chip<0>(index);
   3086         auto mom_ = mom_flat.template chip<0>(index);
   3087         auto grad_ = grad_flat.template chip<0>(i);
   3088 
   3089         ms_ = ms_ * ms_.constant(rho_scalar) +
   3090               grad_.square() * grad_.constant(T(1) - rho_scalar);
   3091         mom_ = mom_ * mom_.constant(momentum_scalar) +
   3092                (ms_ + ms_.constant(epsilon_scalar)).rsqrt() *
   3093                    ms_.constant(lr_scalar) * grad_;
   3094 
   3095         auto v = var_flat.template chip<0>(index);
   3096         v -= mom_;
   3097       }
   3098     }
   3099 
   3100     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   3101   }
   3102 
   3103  private:
   3104   bool use_exclusive_lock_;
   3105 };
   3106 
   3107 // Note, this op works on cpu only.
   3108 template <typename T, typename Tindex>
   3109 class SparseApplyCenteredRMSPropOp : public OpKernel {
   3110  public:
   3111   explicit SparseApplyCenteredRMSPropOp(OpKernelConstruction* ctx)
   3112       : OpKernel(ctx) {
   3113     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   3114   }
   3115 
   3116   void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
   3117     auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
   3118                                                       {0, 1, 2, 3});
   3119 
   3120     Tensor var;
   3121     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
   3122                             ctx, 0, use_exclusive_lock_, true, &var));
   3123     Tensor mg;
   3124     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
   3125                             ctx, 1, use_exclusive_lock_, true, &mg));
   3126     Tensor ms;
   3127     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
   3128                             ctx, 2, use_exclusive_lock_, true, &ms));
   3129     Tensor mom;
   3130     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
   3131                             ctx, 3, use_exclusive_lock_, true, &mom));
   3132 
   3133     OP_REQUIRES(
   3134         ctx, var.IsInitialized(),
   3135         errors::FailedPrecondition(
   3136             "Attempting to use uninitialized variables: ", requested_input(0)));
   3137     OP_REQUIRES(
   3138         ctx, ms.IsInitialized(),
   3139         errors::FailedPrecondition(
   3140             "Attempting to use uninitialized variables: ", requested_input(2)));
   3141     OP_REQUIRES(
   3142         ctx, mom.IsInitialized(),
   3143         errors::FailedPrecondition(
   3144             "Attempting to use uninitialized variables: ", requested_input(3)));
   3145 
   3146     const Tensor& lr = ctx->input(4);
   3147     const Tensor& rho = ctx->input(5);
   3148     const Tensor& momentum = ctx->input(6);
   3149     const Tensor& epsilon = ctx->input(7);
   3150     const Tensor& grad = ctx->input(8);
   3151     const Tensor& indices = ctx->input(9);
   3152 
   3153     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
   3154                 errors::InvalidArgument("lr is not a scalar: ",
   3155                                         lr.shape().DebugString()));
   3156     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
   3157                 errors::InvalidArgument("rho is not a scalar: ",
   3158                                         rho.shape().DebugString()));
   3159     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
   3160                 errors::InvalidArgument("momentum is not a scalar: ",
   3161                                         momentum.shape().DebugString()));
   3162     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
   3163                 errors::InvalidArgument("epsilon is not a scalar: ",
   3164                                         epsilon.shape().DebugString()));
   3165 
   3166     OP_REQUIRES(ctx, var.shape().IsSameSize(mg.shape()),
   3167                 errors::InvalidArgument("var and mg do not have the same shape",
   3168                                         var.shape().DebugString(), " ",
   3169                                         mg.shape().DebugString()));
   3170 
   3171     OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
   3172                 errors::InvalidArgument("var and ms do not have the same shape",
   3173                                         var.shape().DebugString(), " ",
   3174                                         ms.shape().DebugString()));
   3175 
   3176     OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
   3177                 errors::InvalidArgument(
   3178                     "var and mom do not have the same shape",
   3179                     var.shape().DebugString(), " ", mom.shape().DebugString()));
   3180 
   3181     OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
   3182                 errors::InvalidArgument("var must be at least 1 dimensional"));
   3183 
   3184     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
   3185                 errors::InvalidArgument("indices must be one-dimensional"));
   3186 
   3187     for (int d = 1; d < var.dims(); d++) {
   3188       OP_REQUIRES(
   3189           ctx, var.dim_size(d) == grad.dim_size(d),
   3190           errors::InvalidArgument("var and grad must match in dimension ", d));
   3191     }
   3192     const Tindex N = indices.dim_size(0);
   3193     OP_REQUIRES(
   3194         ctx, grad.dim_size(0) == N,
   3195         errors::InvalidArgument(
   3196             "grad must be the same size as indices in the first dimension."));
   3197 
   3198     if (N > 0) {
   3199       const Tindex first_dim_size = var.dim_size(0);
   3200       // Validate all the indices are in range
   3201       auto indices_vec = indices.vec<Tindex>();
   3202       for (Tindex i = 0; i < N; i++) {
   3203         const Tindex index = indices_vec(i);
   3204         OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
   3205                     errors::InvalidArgument(
   3206                         strings::StrCat("Index ", index, " at offset ", i,
   3207                                         " in indices is out of range")));
   3208       }
   3209 
   3210       auto var_flat = var.flat_outer_dims<T>();
   3211       auto ms_flat = ms.flat_outer_dims<T>();
   3212       auto mg_flat = mg.flat_outer_dims<T>();
   3213       auto mom_flat = mom.flat_outer_dims<T>();
   3214       auto grad_flat = grad.flat_outer_dims<T>();
   3215       const T lr_scalar = lr.scalar<T>()();
   3216       const T rho_scalar = rho.scalar<T>()();
   3217       const T epsilon_scalar = epsilon.scalar<T>()();
   3218       const T momentum_scalar = momentum.scalar<T>()();
   3219 
   3220       for (Tindex i = 0; i < N; i++) {
   3221         const Tindex index = indices_vec(i);
   3222 
   3223         auto ms_ = ms_flat.template chip<0>(index);
   3224         auto mom_ = mom_flat.template chip<0>(index);
   3225         auto grad_ = grad_flat.template chip<0>(i);
   3226 
   3227         ms_ = ms_ * ms_.constant(rho_scalar) +
   3228               grad_.square() * grad_.constant(T(1) - rho_scalar);
   3229 
   3230         auto mg_ = mg_flat.template chip<0>(index);
   3231         mg_ = mg_ * mg_.constant(rho_scalar) +
   3232               grad_ * grad_.constant(T(1) - rho_scalar);
   3233         auto denom_ = ms_ + ms_.constant(epsilon_scalar) - mg_.square();
   3234         mom_ = mom_ * mom_.constant(momentum_scalar) +
   3235                denom_.rsqrt() * ms_.constant(lr_scalar) * grad_;
   3236         auto v = var_flat.template chip<0>(index);
   3237         v -= mom_;
   3238       }
   3239     }
   3240 
   3241     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   3242   }
   3243 
   3244  private:
   3245   bool use_exclusive_lock_;
   3246 };
   3247 
   3248 #define REGISTER_KERNELS(T, Tindices)                                 \
   3249   REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp")                  \
   3250                               .Device(DEVICE_CPU)                     \
   3251                               .TypeConstraint<T>("T")                 \
   3252                               .TypeConstraint<Tindices>("Tindices"),  \
   3253                           SparseApplyRMSPropOp<T, Tindices>);         \
   3254   REGISTER_KERNEL_BUILDER(Name("SparseApplyCenteredRMSProp")          \
   3255                               .Device(DEVICE_CPU)                     \
   3256                               .TypeConstraint<T>("T")                 \
   3257                               .TypeConstraint<Tindices>("Tindices"),  \
   3258                           SparseApplyCenteredRMSPropOp<T, Tindices>); \
   3259   REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyRMSProp")          \
   3260                               .Device(DEVICE_CPU)                     \
   3261                               .TypeConstraint<T>("T")                 \
   3262                               .TypeConstraint<Tindices>("Tindices"),  \
   3263                           SparseApplyRMSPropOp<T, Tindices>);         \
   3264   REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyCenteredRMSProp")  \
   3265                               .Device(DEVICE_CPU)                     \
   3266                               .TypeConstraint<T>("T")                 \
   3267                               .TypeConstraint<Tindices>("Tindices"),  \
   3268                           SparseApplyCenteredRMSPropOp<T, Tindices>);
   3269 
   3270 REGISTER_KERNELS(Eigen::half, int32);
   3271 REGISTER_KERNELS(Eigen::half, int64);
   3272 REGISTER_KERNELS(float, int32);
   3273 REGISTER_KERNELS(float, int64);
   3274 REGISTER_KERNELS(double, int32);
   3275 REGISTER_KERNELS(double, int64);
   3276 
   3277 #undef REGISTER_KERNELS
   3278 
   3279 template <typename Device, typename T>
   3280 class ApplyAddSignOp : public OpKernel {
   3281  public:
   3282   explicit ApplyAddSignOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   3283     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   3284   }
   3285 
   3286   void Compute(OpKernelContext* ctx) override {
   3287     auto locks =
   3288         MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
   3289 
   3290     Tensor var;
   3291     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   3292                             ctx, 0, use_exclusive_lock_, false, &var));
   3293     Tensor m;
   3294     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   3295                             ctx, 1, use_exclusive_lock_, false, &m));
   3296     OP_REQUIRES(
   3297         ctx, var.IsInitialized(),
   3298         errors::FailedPrecondition(
   3299             "Attempting to use uninitialized variables: ", requested_input(0)));
   3300     OP_REQUIRES(
   3301         ctx, m.IsInitialized(),
   3302         errors::FailedPrecondition(
   3303             "Attempting to use uninitialized variables: ", requested_input(1)));
   3304     const Tensor& lr = ctx->input(2);
   3305     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
   3306                 errors::InvalidArgument("lr is not a scalar: ",
   3307                                         lr.shape().DebugString()));
   3308     const Tensor& alpha = ctx->input(3);
   3309     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha.shape()),
   3310                 errors::InvalidArgument("alpha is not a scalar: ",
   3311                                         alpha.shape().DebugString()));
   3312     const Tensor& sign_decay = ctx->input(4);
   3313     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha.shape()),
   3314                 errors::InvalidArgument("sign_decay is not a scalar: ",
   3315                                         sign_decay.shape().DebugString()));
   3316     const Tensor& beta = ctx->input(5);
   3317     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta.shape()),
   3318                 errors::InvalidArgument("beta is not a scalar: ",
   3319                                         beta.shape().DebugString()));
   3320     const Tensor& grad = ctx->input(6);
   3321     OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
   3322                 errors::InvalidArgument("var and m do not have the same shape",
   3323                                         var.shape().DebugString(), " ",
   3324                                         m.shape().DebugString()));
   3325     OP_REQUIRES(
   3326         ctx, var.shape().IsSameSize(grad.shape()),
   3327         errors::InvalidArgument("var and grad do not have the same shape",
   3328                                 var.shape().DebugString(), " ",
   3329                                 grad.shape().DebugString()));
   3330 
   3331     const Device& device = ctx->template eigen_device<Device>();
   3332     functor::ApplyAddSign<Device, T>()(
   3333         device, var.flat<T>(), m.flat<T>(), lr.scalar<T>(), alpha.scalar<T>(),
   3334         sign_decay.scalar<T>(), beta.scalar<T>(), grad.flat<T>());
   3335     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   3336   }
   3337 
   3338  private:
   3339   bool use_exclusive_lock_;
   3340 };
   3341 
   3342 #define REGISTER_KERNELS(D, T)                                        \
   3343   REGISTER_KERNEL_BUILDER(                                            \
   3344       Name("ApplyAddSign").Device(DEVICE_##D).TypeConstraint<T>("T"), \
   3345       ApplyAddSignOp<D##Device, T>);                                  \
   3346   REGISTER_KERNEL_BUILDER(Name("ResourceApplyAddSign")                \
   3347                               .Device(DEVICE_##D)                     \
   3348                               .HostMemory("var")                      \
   3349                               .HostMemory("m")                        \
   3350                               .TypeConstraint<T>("T"),                \
   3351                           ApplyAddSignOp<D##Device, T>);
   3352 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
   3353 
   3354 TF_CALL_half(REGISTER_CPU_KERNELS);
   3355 TF_CALL_float(REGISTER_CPU_KERNELS);
   3356 TF_CALL_double(REGISTER_CPU_KERNELS);
   3357 
   3358 #if GOOGLE_CUDA
   3359 // Forward declarations of the functor specializations for GPU.
   3360 namespace functor {
   3361 #define DECLARE_GPU_SPEC(T)                                           \
   3362   template <>                                                         \
   3363   void ApplyAddSign<GPUDevice, T>::operator()(                        \
   3364       const GPUDevice& d, typename TTypes<T>::Flat var,               \
   3365       typename TTypes<T>::Flat m, typename TTypes<T>::ConstScalar lr, \
   3366       typename TTypes<T>::ConstScalar alpha,                          \
   3367       typename TTypes<T>::ConstScalar sign_decay,                     \
   3368       typename TTypes<T>::ConstScalar beta,                           \
   3369       typename TTypes<T>::ConstFlat grad);                            \
   3370   extern template struct ApplyAddSign<GPUDevice, T>;
   3371 DECLARE_GPU_SPEC(Eigen::half);
   3372 DECLARE_GPU_SPEC(float);
   3373 DECLARE_GPU_SPEC(double);
   3374 #undef DECLARE_GPU_SPEC
   3375 }  // namespace functor
   3376 
   3377 REGISTER_KERNELS(GPU, Eigen::half);
   3378 REGISTER_KERNELS(GPU, float);
   3379 REGISTER_KERNELS(GPU, double);
   3380 #endif
   3381 #undef REGISTER_CPU_KERNELS
   3382 #undef REGISTER_KERNELS
   3383 
   3384 template <typename Device, typename T>
   3385 class ApplyPowerSignOp : public OpKernel {
   3386  public:
   3387   explicit ApplyPowerSignOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   3388     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
   3389   }
   3390 
   3391   void Compute(OpKernelContext* ctx) override {
   3392     auto locks =
   3393         MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
   3394 
   3395     Tensor var;
   3396     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   3397                             ctx, 0, use_exclusive_lock_, false, &var));
   3398     Tensor m;
   3399     OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
   3400                             ctx, 1, use_exclusive_lock_, false, &m));
   3401     OP_REQUIRES(
   3402         ctx, var.IsInitialized(),
   3403         errors::FailedPrecondition(
   3404             "Attempting to use uninitialized variables: ", requested_input(0)));
   3405     OP_REQUIRES(
   3406         ctx, m.IsInitialized(),
   3407         errors::FailedPrecondition(
   3408             "Attempting to use uninitialized variables: ", requested_input(1)));
   3409     const Tensor& lr = ctx->input(2);
   3410     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
   3411                 errors::InvalidArgument("lr is not a scalar: ",
   3412                                         lr.shape().DebugString()));
   3413     const Tensor& logbase = ctx->input(3);
   3414     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase.shape()),
   3415                 errors::InvalidArgument("logbase is not a scalar: ",
   3416                                         logbase.shape().DebugString()));
   3417     const Tensor& sign_decay = ctx->input(4);
   3418     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase.shape()),
   3419                 errors::InvalidArgument("sign_decay is not a scalar: ",
   3420                                         sign_decay.shape().DebugString()));
   3421     const Tensor& beta = ctx->input(5);
   3422     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta.shape()),
   3423                 errors::InvalidArgument("beta is not a scalar: ",
   3424                                         beta.shape().DebugString()));
   3425     const Tensor& grad = ctx->input(6);
   3426     OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
   3427                 errors::InvalidArgument("var and m do not have the same shape",
   3428                                         var.shape().DebugString(), " ",
   3429                                         m.shape().DebugString()));
   3430     OP_REQUIRES(
   3431         ctx, var.shape().IsSameSize(grad.shape()),
   3432         errors::InvalidArgument("var and grad do not have the same shape",
   3433                                 var.shape().DebugString(), " ",
   3434                                 grad.shape().DebugString()));
   3435 
   3436     const Device& device = ctx->template eigen_device<Device>();
   3437     functor::ApplyPowerSign<Device, T>()(
   3438         device, var.flat<T>(), m.flat<T>(), lr.scalar<T>(), logbase.scalar<T>(),
   3439         sign_decay.scalar<T>(), beta.scalar<T>(), grad.flat<T>());
   3440     MaybeForwardRefInputToRefOutput(ctx, 0, 0);
   3441   }
   3442 
   3443  private:
   3444   bool use_exclusive_lock_;
   3445 };
   3446 
   3447 #define REGISTER_KERNELS(D, T)                                          \
   3448   REGISTER_KERNEL_BUILDER(                                              \
   3449       Name("ApplyPowerSign").Device(DEVICE_##D).TypeConstraint<T>("T"), \
   3450       ApplyPowerSignOp<D##Device, T>);                                  \
   3451   REGISTER_KERNEL_BUILDER(Name("ResourceApplyPowerSign")                \
   3452                               .Device(DEVICE_##D)                       \
   3453                               .HostMemory("var")                        \
   3454                               .HostMemory("m")                          \
   3455                               .TypeConstraint<T>("T"),                  \
   3456                           ApplyPowerSignOp<D##Device, T>);
   3457 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
   3458 
   3459 TF_CALL_half(REGISTER_CPU_KERNELS);
   3460 TF_CALL_float(REGISTER_CPU_KERNELS);
   3461 TF_CALL_double(REGISTER_CPU_KERNELS);
   3462 
   3463 #if GOOGLE_CUDA
   3464 // Forward declarations of the functor specializations for GPU.
   3465 namespace functor {
   3466 #define DECLARE_GPU_SPEC(T)                                           \
   3467   template <>                                                         \
   3468   void ApplyPowerSign<GPUDevice, T>::operator()(                      \
   3469       const GPUDevice& d, typename TTypes<T>::Flat var,               \
   3470       typename TTypes<T>::Flat m, typename TTypes<T>::ConstScalar lr, \
   3471       typename TTypes<T>::ConstScalar logbase,                        \
   3472       typename TTypes<T>::ConstScalar sign_decay,                     \
   3473       typename TTypes<T>::ConstScalar beta,                           \
   3474       typename TTypes<T>::ConstFlat grad);                            \
   3475   extern template struct ApplyPowerSign<GPUDevice, T>;
   3476 DECLARE_GPU_SPEC(Eigen::half);
   3477 DECLARE_GPU_SPEC(float);
   3478 DECLARE_GPU_SPEC(double);
   3479 #undef DECLARE_GPU_SPEC
   3480 }  // namespace functor
   3481 
   3482 REGISTER_KERNELS(GPU, Eigen::half);
   3483 REGISTER_KERNELS(GPU, float);
   3484 REGISTER_KERNELS(GPU, double);
   3485 #endif
   3486 #undef REGISTER_CPU_KERNELS
   3487 #undef REGISTER_KERNELS
   3488 
   3489 }  // namespace tensorflow
   3490