/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {
namespace {

class ResourceApplyGradientDescent : public XlaOpKernel {
 public:
  explicit ResourceApplyGradientDescent(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp handle;
    DataType type = ctx->input_type(1);
    TensorShape var_shape;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle));

    TensorShape alpha_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_shape.DebugString()));

    TensorShape delta_shape = ctx->InputShape(2);
    OP_REQUIRES(
        ctx, var_shape.IsSameSize(delta_shape),
        errors::InvalidArgument("var and delta do not have the same shape: ",
                                var_shape.DebugString(), " vs ",
                                delta_shape.DebugString()));

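    // var <- var - alpha * delta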
    handle = handle - ctx->Input(1) * ctx->Input(2);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
  }
};
REGISTER_XLA_OP(
    Name("ResourceApplyGradientDescent").TypeConstraint("T", kFloatTypes),
    ResourceApplyGradientDescent);

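// Proximal gradient descent update, shared by the ProximalGradientDescent and
// ProximalAdagrad kernels below:
//   prox_var <- var - lr * grad
//   if l1 > 0: var <- sign(prox_var) * max(|prox_var| - lr * l1, 0) /
//                     (1 + lr * l2)
//   else:      var <- prox_var / (1 + lr * l2)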
xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr,
                                         xla::XlaOp l1, xla::XlaOp l2,
                                         xla::XlaOp grad) {
  xla::XlaOp one = xla::ScalarLike(lr, 1.0);
  xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
  xla::XlaOp prox_var = var - grad * lr;
  xla::XlaOp l1_gt_zero = xla::Sign(prox_var) *
                          xla::Max(xla::Abs(prox_var) - lr * l1, zero) /
                          (one + lr * l2);
  xla::XlaOp l1_le_zero = prox_var / (one + lr * l2);
  return xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
}

class ResourceApplyProximalGradientDescent : public XlaOpKernel {
 public:
  explicit ResourceApplyProximalGradientDescent(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp var;
    TensorShape var_shape;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));

    TensorShape alpha_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_shape.DebugString()));
    TensorShape l1_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
                errors::InvalidArgument("l1 is not a scalar: ",
                                        l1_shape.DebugString()));
    TensorShape l2_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
                errors::InvalidArgument("l2 is not a scalar: ",
                                        l2_shape.DebugString()));
    TensorShape delta_shape = ctx->InputShape(4);
    OP_REQUIRES(
        ctx, var_shape.IsSameSize(delta_shape),
        errors::InvalidArgument("var and delta do not have the same shape: ",
                                var_shape.DebugString(), " vs ",
                                delta_shape.DebugString()));
    xla::XlaOp alpha = ctx->Input(1);
    xla::XlaOp l1 = ctx->Input(2);
    xla::XlaOp l2 = ctx->Input(3);
    xla::XlaOp delta = ctx->Input(4);
    var = ProximalGradientDescentUpdate(var, alpha, l1, l2, delta);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyProximalGradientDescent")
                    .TypeConstraint("T", kFloatTypes),
                ResourceApplyProximalGradientDescent);

class ResourceApplyMomentum : public XlaOpKernel {
 public:
  explicit ResourceApplyMomentum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    TensorShape momentum_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp grad = ctx->Input(3);
    xla::XlaOp momentum = ctx->Input(4);

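    // accum <- momentum * accum + grad
    // var <- var - lr * (grad + momentum * accum)  if use_nesterov
    // var <- var - lr * accum                      otherwise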
    accum = accum * momentum + grad;
    if (use_nesterov_) {
      // See https://github.com/tensorflow/tensorflow/pull/2798 for an
      // explanation of the reparameterization used here.
      var = var - (grad * lr + accum * momentum * lr);
    } else {
      var = var - accum * lr;
    }
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }

 private:
  bool use_nesterov_;
};
REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes),
                ResourceApplyMomentum);

class ResourceApplyKerasMomentum : public XlaOpKernel {
 public:
  explicit ResourceApplyKerasMomentum(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    TensorShape momentum_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp grad = ctx->Input(3);
    xla::XlaOp momentum = ctx->Input(4);

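    // accum <- momentum * accum - lr * grad
    // var <- var + momentum * accum - lr * grad  if use_nesterov
    // var <- var + accum                         otherwise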
    accum = accum * momentum - grad * lr;
    if (use_nesterov_) {
      // See https://github.com/tensorflow/tensorflow/pull/2798 for an
      // explanation of the reparameterization used here.
      var = var + accum * momentum - grad * lr;
    } else {
      var = var + accum;
    }
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }

 private:
  bool use_nesterov_;
};
REGISTER_XLA_OP(
    Name("ResourceApplyKerasMomentum").TypeConstraint("T", kFloatTypes),
    ResourceApplyKerasMomentum);

class ResourceApplyAdagrad : public XlaOpKernel {
 public:
  explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp grad = ctx->Input(3);

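    // accum <- accum + grad^2
    // var <- var - lr * grad / sqrt(accum)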
    accum = accum + xla::Square(grad);
    var = var - grad * lr * xla::Rsqrt(accum);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }
};
REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdagrad);

class ResourceApplyProximalAdagrad : public XlaOpKernel {
 public:
  explicit ResourceApplyProximalAdagrad(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    TensorShape l1_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
                errors::InvalidArgument("l1 is not a scalar: ",
                                        l1_shape.DebugString()));
    TensorShape l2_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
                errors::InvalidArgument("l2 is not a scalar: ",
                                        l2_shape.DebugString()));
    TensorShape grad_shape = ctx->InputShape(5);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape: ",
                    var_shape.DebugString(), " vs ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp l1 = ctx->Input(3);
    xla::XlaOp l2 = ctx->Input(4);
    xla::XlaOp grad = ctx->Input(5);
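    // accum <- accum + grad^2; the proximal update above is then applied with
    // the per-coordinate Adagrad learning rate lr / sqrt(accum).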
    accum = accum + xla::Square(grad);
    // Adagrad learning rate.
    xla::XlaOp adagrad_lr = lr * xla::Rsqrt(accum);
    var = ProximalGradientDescentUpdate(var, adagrad_lr, l1, l2, grad);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(
    Name("ResourceApplyProximalAdagrad").TypeConstraint("T", kFloatTypes),
    ResourceApplyProximalAdagrad);

class ResourceApplyAdagradDA : public XlaOpKernel {
 public:
  explicit ResourceApplyAdagradDA(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, accum_shape, squared_accum_shape;
    xla::XlaOp var, accum, squared_accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &squared_accum_shape,
                                               &squared_accum));
    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));
    OP_REQUIRES(
        ctx, var_shape.IsSameSize(squared_accum_shape),
        errors::InvalidArgument(
            "var and squared accum do not have the same shape",
            var_shape.DebugString(), " ", squared_accum_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    TensorShape lr_shape = ctx->InputShape(4);
    TensorShape l1_shape = ctx->InputShape(5);
    TensorShape l2_shape = ctx->InputShape(6);
    TensorShape global_step_shape = ctx->InputShape(7);

    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
                errors::InvalidArgument("l1 is not a scalar: ",
                                        l1_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
                errors::InvalidArgument("l2 is not a scalar: ",
                                        l2_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step_shape),
                errors::InvalidArgument("global step is not a scalar: ",
                                        global_step_shape.DebugString()));

    xla::XlaOp grad = ctx->Input(3);
    xla::XlaOp lr = ctx->Input(4);
    xla::XlaOp l1 = ctx->Input(5);
    xla::XlaOp l2 = ctx->Input(6);
    xla::XlaOp global_step =
        XlaHelpers::ConvertElementType(ctx->Input(7), dtype_);

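    // accum <- accum + grad
    // squared_accum <- squared_accum + grad^2
    // var <- -lr * accum / (global_step * lr * l2 + sqrt(squared_accum)),
    // shrinking the accumulated gradient by global_step * l1 when l1 > 0.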
    accum = accum + grad;
    squared_accum = squared_accum + xla::Square(grad);
    xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
    xla::XlaOp denominator = global_step * lr * l2 + xla::Sqrt(squared_accum);
    xla::XlaOp l1_le_zero = -lr * accum / denominator;
    xla::XlaOp l1_gt_zero = -lr * xla::Sign(accum) *
                            xla::Max(xla::Abs(accum) - global_step * l1, zero) /
                            denominator;

    var = xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, squared_accum));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdagradDA").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdagradDA);

class ResourceApplyAdam : public XlaOpKernel {
 public:
  explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, m_shape, v_shape;
    xla::XlaOp var, m, v;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));

    TensorShape beta1_power_shape = ctx->InputShape(3);
    TensorShape beta2_power_shape = ctx->InputShape(4);
    TensorShape lr_shape = ctx->InputShape(5);
    TensorShape beta1_shape = ctx->InputShape(6);
    TensorShape beta2_shape = ctx->InputShape(7);
    TensorShape epsilon_shape = ctx->InputShape(8);
    TensorShape grad_shape = ctx->InputShape(9);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape),
                errors::InvalidArgument("beta1_power is not a scalar: ",
                                        beta1_power_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power_shape),
                errors::InvalidArgument("beta2_power is not a scalar: ",
                                        beta2_power_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape),
                errors::InvalidArgument("beta1 is not a scalar: ",
                                        beta1_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape),
                errors::InvalidArgument("beta2 is not a scalar: ",
                                        beta2_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));

    OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        m_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape),
                errors::InvalidArgument("var and v do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        v_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp beta1_power = ctx->Input(3);
    xla::XlaOp beta2_power = ctx->Input(4);
    xla::XlaOp lr = ctx->Input(5);
    xla::XlaOp beta1 = ctx->Input(6);
    xla::XlaOp beta2 = ctx->Input(7);
    xla::XlaOp epsilon = ctx->Input(8);
    xla::XlaOp grad = ctx->Input(9);

    // alpha <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
    // m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
    // v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t
    // variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon)

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);

    xla::XlaOp alpha = lr * xla::Sqrt(one - beta2_power) / (one - beta1_power);
    m = m + (grad - m) * (one - beta1);
    v = v + (xla::Square(grad) - v) * (one - beta2);
    var = var - m * alpha / (xla::Sqrt(v) + epsilon);

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdam);

class ResourceApplyAdaMax : public XlaOpKernel {
 public:
  explicit ResourceApplyAdaMax(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, m_shape, v_shape;
    xla::XlaOp var, m, v;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));

    TensorShape beta1_power_shape = ctx->InputShape(3);
    TensorShape lr_shape = ctx->InputShape(4);
    TensorShape beta1_shape = ctx->InputShape(5);
    TensorShape beta2_shape = ctx->InputShape(6);
    TensorShape epsilon_shape = ctx->InputShape(7);
    TensorShape grad_shape = ctx->InputShape(8);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape),
                errors::InvalidArgument("beta1_power is not a scalar: ",
                                        beta1_power_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape),
                errors::InvalidArgument("beta1 is not a scalar: ",
                                        beta1_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape),
                errors::InvalidArgument("beta2 is not a scalar: ",
                                        beta2_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        m_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape),
                errors::InvalidArgument("var and v do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        v_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp beta1_power = ctx->Input(3);
    xla::XlaOp lr = ctx->Input(4);
    xla::XlaOp beta1 = ctx->Input(5);
    xla::XlaOp beta2 = ctx->Input(6);
    xla::XlaOp epsilon = ctx->Input(7);
    xla::XlaOp grad = ctx->Input(8);

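    // m <- beta1 * m + (1 - beta1) * grad
    // v <- max(beta2 * v, |grad|)
    // var <- var - lr / (1 - beta1_power) * m / (v + epsilon)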
    xla::XlaOp one = xla::ScalarLike(lr, 1.0);
    m = beta1 * m + (one - beta1) * grad;
    v = xla::Max(beta2 * v, xla::Abs(grad));
    var = var - lr / (one - beta1_power) * (m / (v + epsilon));

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdaMax").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdaMax);

class ResourceApplyRMSProp : public XlaOpKernel {
 public:
  explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, ms_shape, mom_shape, mg_shape;
    xla::XlaOp var, ms, mom, mg;
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput("var", dtype_, &var_shape, &var));
    if (centered_) {
      OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("mg", dtype_, &mg_shape, &mg));
    }
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("ms", dtype_, &ms_shape, &ms));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput("mom", dtype_, &mom_shape, &mom));

    TensorShape lr_shape = ctx->InputShape("lr");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    TensorShape rho_shape = ctx->InputShape("rho");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
                errors::InvalidArgument("rho is not a scalar: ",
                                        rho_shape.DebugString()));
    TensorShape momentum_shape = ctx->InputShape("momentum");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum_shape.DebugString()));
    TensorShape epsilon_shape = ctx->InputShape("epsilon");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));
    TensorShape grad_shape = ctx->InputShape("grad");

    // var should be the same shape as mom and ms.
    OP_REQUIRES(ctx, var_shape.IsSameSize(ms_shape),
                errors::InvalidArgument("var and ms do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        ms_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(mom_shape),
                errors::InvalidArgument(
                    "var and mom do not have the same shape",
                    var_shape.DebugString(), " ", mom_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input("lr");
    xla::XlaOp rho = ctx->Input("rho");
    xla::XlaOp momentum = ctx->Input("momentum");
    xla::XlaOp epsilon = ctx->Input("epsilon");
    xla::XlaOp grad = ctx->Input("grad");

    // ms <- rho * ms_{t-1} + (1-rho) * grad * grad
    // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
    // var <- var - mom
    //
    // We use an alternate formulation of the ms equation:
    //
    //    ms <- ms + (grad**2 - ms) * (1 - rho)
    //
    // Which expands to:
    //
    //    ms <- ms + grad**2 - rho * grad ** 2 - ms + ms * rho
    //
    // Which simplifies to:
    //
    //    ms <- grad**2 (1 - rho) + ms * rho
    //
    // Which is the equation listed above.
    xla::XlaOp one = xla::ScalarLike(ms, 1.0);
    xla::XlaOp new_ms = xla::Square(grad) * (one - rho) + ms * rho;
    xla::XlaOp denominator;
    if (centered_) {
      mg = grad * (one - rho) + mg * rho;
      denominator = new_ms - xla::Square(mg) + epsilon;
    } else {
      denominator = new_ms + epsilon;
    }
    xla::XlaOp new_mom = mom * momentum + grad * lr * xla::Rsqrt(denominator);
    xla::XlaOp new_var = var - new_mom;

    OP_REQUIRES_OK(ctx, ctx->AssignVariable("var", dtype_, new_var));
    if (centered_) {
      OP_REQUIRES_OK(ctx, ctx->AssignVariable("mg", dtype_, mg));
    }
    OP_REQUIRES_OK(ctx, ctx->AssignVariable("ms", dtype_, new_ms));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable("mom", dtype_, new_mom));
  }

 protected:
  bool centered_ = false;

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes),
                ResourceApplyRMSProp);

class ResourceApplyCenteredRMSProp : public ResourceApplyRMSProp {
 public:
  explicit ResourceApplyCenteredRMSProp(OpKernelConstruction* ctx)
      : ResourceApplyRMSProp(ctx) {
    centered_ = true;
  }
};
REGISTER_XLA_OP(
    Name("ResourceApplyCenteredRMSProp").TypeConstraint("T", kFloatTypes),
    ResourceApplyCenteredRMSProp);

void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
                 bool has_l2_shrinkage) {
  xla::XlaBuilder* b = ctx->builder();

  TensorShape var_shape, accum_shape, linear_shape;
  xla::XlaOp var, accum, linear;
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var));
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum));
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear));

  OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
              errors::InvalidArgument(
                  "var and accum do not have the same shape",
                  var_shape.DebugString(), " ", accum_shape.DebugString()));

  OP_REQUIRES(ctx, var_shape.IsSameSize(linear_shape),
              errors::InvalidArgument(
                  "var and linear do not have the same shape",
                  var_shape.DebugString(), " ", linear_shape.DebugString()));

  TensorShape grad_shape = ctx->InputShape(3);
  TensorShape lr_shape = ctx->InputShape(4);
  TensorShape l1_shape = ctx->InputShape(5);
  TensorShape l2_shape = ctx->InputShape(6);
  TensorShape l2_shrinkage_shape;
  TensorShape lr_power_shape;
  if (has_l2_shrinkage) {
    l2_shrinkage_shape = ctx->InputShape(7);
    lr_power_shape = ctx->InputShape(8);
  } else {
    lr_power_shape = ctx->InputShape(7);
  }

  OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
              errors::InvalidArgument("var and grad do not have the same shape",
                                      var_shape.DebugString(), " ",
                                      grad_shape.DebugString()));

  OP_REQUIRES(
      ctx, TensorShapeUtils::IsScalar(lr_shape),
      errors::InvalidArgument("lr is not a scalar: ", lr_shape.DebugString()));

  OP_REQUIRES(
      ctx, TensorShapeUtils::IsScalar(l1_shape),
      errors::InvalidArgument("l1 is not a scalar: ", l1_shape.DebugString()));

  OP_REQUIRES(
      ctx, TensorShapeUtils::IsScalar(l2_shape),
      errors::InvalidArgument("l2 is not a scalar: ", l2_shape.DebugString()));

  if (has_l2_shrinkage) {
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shrinkage_shape),
                errors::InvalidArgument("l2_shrinkage is not a scalar: ",
                                        l2_shrinkage_shape.DebugString()));
  }

  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power_shape),
              errors::InvalidArgument("lr_power is not a scalar: ",
                                      lr_power_shape.DebugString()));

  xla::XlaOp grad = ctx->Input(3);
  xla::XlaOp lr = ctx->Input(4);
  xla::XlaOp l1 = ctx->Input(5);
  xla::XlaOp l2 = ctx->Input(6);
  xla::XlaOp l2_shrinkage;
  xla::XlaOp lr_power;
  if (has_l2_shrinkage) {
    l2_shrinkage = ctx->Input(7);
    lr_power = ctx->Input(8);
  } else {
    lr_power = ctx->Input(7);
  }

  // grad_to_use = grad + 2 * l2_shrinkage * var
  // new_accum = accum + grad * grad
  // linear += grad_to_use -
  //     (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var
  // quadratic = (new_accum^(-lr_power) / lr) + 2 * l2
  // linear_clipped = clamp linear in [-l1, l1]
  // var = (linear_clipped - linear) / quadratic
  // accum = new_accum

  xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
  xla::XlaOp grad_to_use;
  if (has_l2_shrinkage) {
    grad_to_use = grad + two * l2_shrinkage * var;
  } else {
    grad_to_use = grad;
  }

  xla::XlaOp new_accum = accum + xla::Square(grad);
  xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power);
  xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power);
  linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var;
  xla::XlaOp linear_clipped = xla::Clamp(-l1, linear, l1);
  xla::XlaOp quadratic = new_accum_lr_pow / lr + two * l2;
  var = (linear_clipped - linear) / quadratic;
  accum = new_accum;

  OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var));
  OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype, accum));
  OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype, linear));
}

class ResourceApplyFtrl : public XlaOpKernel {
 public:
  explicit ResourceApplyFtrl(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/false);
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyFtrl").TypeConstraint("T", kFloatTypes),
                ResourceApplyFtrl);

class ResourceApplyFtrlV2 : public XlaOpKernel {
 public:
  explicit ResourceApplyFtrlV2(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/true);
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes),
                ResourceApplyFtrlV2);

class ResourceApplyAdadelta : public XlaOpKernel {
 public:
  explicit ResourceApplyAdadelta(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, accum_shape, accum_update_shape;
    xla::XlaOp var, accum, accum_update;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &accum_update_shape,
                                               &accum_update));

    TensorShape lr_shape = ctx->InputShape(3);
    TensorShape rho_shape = ctx->InputShape(4);
    TensorShape epsilon_shape = ctx->InputShape(5);
    TensorShape grad_shape = ctx->InputShape(6);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
                errors::InvalidArgument("rho is not a scalar: ",
                                        rho_shape.DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(3);
    xla::XlaOp rho = ctx->Input(4);
    xla::XlaOp epsilon = ctx->Input(5);
    xla::XlaOp grad = ctx->Input(6);

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);

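    // accum <- rho * accum + (1 - rho) * grad^2
    // update <- sqrt(accum_update + epsilon) / sqrt(accum + epsilon) * grad
    // accum_update <- rho * accum_update + (1 - rho) * update^2
    // var <- var - lr * update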
    accum = rho * accum + (one - rho) * xla::Square(grad);
    xla::XlaOp update =
        xla::Sqrt(accum_update + epsilon) * xla::Rsqrt(accum + epsilon) * grad;
    accum_update = rho * accum_update + (one - rho) * xla::Square(update);
    var = var - update * lr;
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, accum_update));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdadelta);

class ResourceApplySignBase : public XlaOpKernel {
 public:
  explicit ResourceApplySignBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, m_shape;
    xla::XlaOp var, m;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
    OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        m_shape.DebugString()));
    TensorShape grad_shape = ctx->InputShape(6);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));
    CheckScalarParams(ctx);

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp alpha = ctx->Input(3);
    xla::XlaOp sign_decay = ctx->Input(4);
    xla::XlaOp beta = ctx->Input(5);
    xla::XlaOp grad = ctx->Input(6);

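    // m <- beta * m + (1 - beta) * grad
    // var <- var - lr * grad_scale * grad, where grad_scale is
    // alpha + sign_decay * sign(grad) * sign(m) for AddSign and
    // exp(logbase * sign_decay * sign(grad) * sign(m)) for PowerSign.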
    m = m * beta + grad * (xla::ScalarLike(beta, 1.0) - beta);
    xla::XlaOp decay = xla::Sign(grad) * xla::Sign(m) * sign_decay;

    xla::XlaOp grad_scale = ComputeGradientScale(alpha, decay);
    var = var - lr * grad_scale * grad;
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
  }

  virtual void CheckScalarParams(XlaOpKernelContext* ctx) {
    TensorShape lr_shape = ctx->InputShape(2);
    TensorShape sign_decay_shape = ctx->InputShape(4);
    TensorShape beta_shape = ctx->InputShape(5);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay_shape),
                errors::InvalidArgument("sign_decay is not a scalar: ",
                                        sign_decay_shape.DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta_shape),
                errors::InvalidArgument("beta is not a scalar: ",
                                        beta_shape.DebugString()));
  }

  virtual xla::XlaOp ComputeGradientScale(xla::XlaOp alpha,
                                          xla::XlaOp decay) = 0;

 private:
  DataType dtype_;
};

class ResourceApplyAddSign : public ResourceApplySignBase {
 public:
  explicit ResourceApplyAddSign(OpKernelConstruction* ctx)
      : ResourceApplySignBase(ctx) {}

  void CheckScalarParams(XlaOpKernelContext* ctx) override {
    ResourceApplySignBase::CheckScalarParams(ctx);
    TensorShape alpha_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_shape.DebugString()));
  }

  xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
    return alpha + decay;
  }
};
REGISTER_XLA_OP(Name("ResourceApplyAddSign").TypeConstraint("T", kFloatTypes),
                ResourceApplyAddSign);

class ResourceApplyPowerSign : public ResourceApplySignBase {
 public:
  explicit ResourceApplyPowerSign(OpKernelConstruction* ctx)
      : ResourceApplySignBase(ctx) {}

  void CheckScalarParams(XlaOpKernelContext* ctx) override {
    ResourceApplySignBase::CheckScalarParams(ctx);
    TensorShape logbase_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase_shape),
                errors::InvalidArgument("logbase is not a scalar: ",
                                        logbase_shape.DebugString()));
  }

  xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
    return xla::Exp(alpha * decay);
  }
};
REGISTER_XLA_OP(Name("ResourceApplyPowerSign").TypeConstraint("T", kFloatTypes),
                ResourceApplyPowerSign);

}  // namespace
}  // namespace tensorflow