/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {
namespace {

class ResourceApplyGradientDescent : public XlaOpKernel {
 public:
  explicit ResourceApplyGradientDescent(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp handle;
    DataType type = ctx->input_type(1);
    TensorShape var_shape;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle));

    TensorShape alpha_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_shape.DebugString()));

    TensorShape delta_shape = ctx->InputShape(2);
    OP_REQUIRES(
        ctx, var_shape.IsSameSize(delta_shape),
        errors::InvalidArgument("var and delta do not have the same shape: ",
                                var_shape.DebugString(), " vs ",
                                delta_shape.DebugString()));

    handle = handle - ctx->Input(1) * ctx->Input(2);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
  }
};
REGISTER_XLA_OP(
    Name("ResourceApplyGradientDescent").TypeConstraint("T", kFloatTypes),
    ResourceApplyGradientDescent);

// Soft-thresholding (proximal) update shared by the proximal gradient descent
// and proximal Adagrad kernels below:
//   prox_var = var - lr * grad
//   if l1 > 0: var <- sign(prox_var) * max(|prox_var| - lr * l1, 0) / (1 + lr * l2)
//   else:      var <- prox_var / (1 + lr * l2)
xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr,
                                         xla::XlaOp l1, xla::XlaOp l2,
                                         xla::XlaOp grad) {
  xla::XlaOp one = xla::ScalarLike(lr, 1.0);
  xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
  xla::XlaOp prox_var = var - grad * lr;
  xla::XlaOp l1_gt_zero = xla::Sign(prox_var) *
                          xla::Max(xla::Abs(prox_var) - lr * l1, zero) /
                          (one + lr * l2);
  xla::XlaOp l1_le_zero = prox_var / (one + lr * l2);
  return xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
}

class ResourceApplyProximalGradientDescent : public XlaOpKernel {
 public:
  explicit ResourceApplyProximalGradientDescent(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp var;
    TensorShape var_shape;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));

    TensorShape alpha_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_shape.DebugString()));
    TensorShape l1_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsScalar(l1_shape),
                errors::InvalidArgument("l1 is not a scalar: ",
                                        l1_shape.DebugString()));
    TensorShape l2_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
                errors::InvalidArgument("l2 is not a scalar: ",
                                        l2_shape.DebugString()));
    TensorShape delta_shape = ctx->InputShape(4);
    OP_REQUIRES(
        ctx, var_shape.IsSameSize(delta_shape),
        errors::InvalidArgument("var and delta do not have the same shape: ",
                                var_shape.DebugString(), " vs ",
                                delta_shape.DebugString()));
    xla::XlaOp alpha = ctx->Input(1);
    xla::XlaOp l1 = ctx->Input(2);
    xla::XlaOp l2 = ctx->Input(3);
    xla::XlaOp delta = ctx->Input(4);
    var = ProximalGradientDescentUpdate(var, alpha, l1, l2, delta);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyProximalGradientDescent")
                    .TypeConstraint("T", kFloatTypes),
                ResourceApplyProximalGradientDescent);

class ResourceApplyMomentum : public XlaOpKernel {
 public:
  explicit ResourceApplyMomentum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    TensorShape momentum_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp grad = ctx->Input(3);
    xla::XlaOp momentum = ctx->Input(4);

    accum = accum * momentum + grad;
    if (use_nesterov_) {
      // See https://github.com/tensorflow/tensorflow/pull/2798 for an
      // explanation of the reparameterization used here.
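      // With the accumulator update above (accum = momentum * accum + grad)
      // already applied, this is equivalent to:
      //   var -= lr * (grad + momentum * accum)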
      var = var - (grad * lr + accum * momentum * lr);
    } else {
      var = var - accum * lr;
    }
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }

 private:
  bool use_nesterov_;
};
REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes),
                ResourceApplyMomentum);

class ResourceApplyKerasMomentum : public XlaOpKernel {
 public:
  explicit ResourceApplyKerasMomentum(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    TensorShape momentum_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp grad = ctx->Input(3);
    xla::XlaOp momentum = ctx->Input(4);

    accum = accum * momentum - grad * lr;
    if (use_nesterov_) {
      // See https://github.com/tensorflow/tensorflow/pull/2798 for an
      // explanation of the reparameterization used here.
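      // Unlike ResourceApplyMomentum, the Keras variant folds the learning
      // rate into the accumulator (accum = momentum * accum - lr * grad), so
      // the Nesterov step below is:
      //   var += momentum * accum - lr * grad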
      var = var + accum * momentum - grad * lr;
    } else {
      var = var + accum;
    }
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }

 private:
  bool use_nesterov_;
};
REGISTER_XLA_OP(
    Name("ResourceApplyKerasMomentum").TypeConstraint("T", kFloatTypes),
    ResourceApplyKerasMomentum);

class ResourceApplyAdagrad : public XlaOpKernel {
 public:
  explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(2);

    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp grad = ctx->Input(3);

    accum = accum + xla::Square(grad);
    var = var - grad * lr * xla::Rsqrt(accum);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
  }
};
REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdagrad);

class ResourceApplyProximalAdagrad : public XlaOpKernel {
 public:
  explicit ResourceApplyProximalAdagrad(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, accum_shape;
    xla::XlaOp var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));

    TensorShape lr_shape = ctx->InputShape(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    TensorShape l1_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
                errors::InvalidArgument("l1 is not a scalar: ",
                                        l1_shape.DebugString()));
    TensorShape l2_shape = ctx->InputShape(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
                errors::InvalidArgument("l2 is not a scalar: ",
                                        l2_shape.DebugString()));
    TensorShape grad_shape = ctx->InputShape(5);
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape: ",
                    var_shape.DebugString(), " vs ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input(2);
    xla::XlaOp l1 = ctx->Input(3);
    xla::XlaOp l2 =
        ctx->Input(4);
    xla::XlaOp grad = ctx->Input(5);
    accum = accum + xla::Square(grad);
    // Adagrad learning rate.
    xla::XlaOp adagrad_lr = lr * xla::Rsqrt(accum);
    var = ProximalGradientDescentUpdate(var, adagrad_lr, l1, l2, grad);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(
    Name("ResourceApplyProximalAdagrad").TypeConstraint("T", kFloatTypes),
    ResourceApplyProximalAdagrad);

class ResourceApplyAdagradDA : public XlaOpKernel {
 public:
  explicit ResourceApplyAdagradDA(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, accum_shape, squared_accum_shape;
    xla::XlaOp var, accum, squared_accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &squared_accum_shape,
                                               &squared_accum));
    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
                    "var and accum do not have the same shape",
                    var_shape.DebugString(), " ", accum_shape.DebugString()));
    OP_REQUIRES(
        ctx, var_shape.IsSameSize(squared_accum_shape),
        errors::InvalidArgument(
            "var and squared accum do not have the same shape",
            var_shape.DebugString(), " ", squared_accum_shape.DebugString()));

    TensorShape grad_shape = ctx->InputShape(3);
    TensorShape lr_shape = ctx->InputShape(4);
    TensorShape l1_shape = ctx->InputShape(5);
    TensorShape l2_shape = ctx->InputShape(6);
    TensorShape global_step_shape = ctx->InputShape(7);

    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
                errors::InvalidArgument("l1 is not a scalar: ",
                                        l1_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
                errors::InvalidArgument("l2 is not a scalar: ",
                                        l2_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step_shape),
                errors::InvalidArgument("global step is not a scalar: ",
                                        global_step_shape.DebugString()));

    xla::XlaOp grad = ctx->Input(3);
    xla::XlaOp lr = ctx->Input(4);
    xla::XlaOp l1 = ctx->Input(5);
    xla::XlaOp l2 = ctx->Input(6);
    xla::XlaOp global_step =
        XlaHelpers::ConvertElementType(ctx->Input(7), dtype_);

    accum = accum + grad;
    squared_accum = squared_accum + xla::Square(grad);
    xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
    xla::XlaOp denominator = global_step * lr * l2 + xla::Sqrt(squared_accum);
    xla::XlaOp l1_le_zero = -lr * accum / denominator;
    xla::XlaOp l1_gt_zero =
        -lr * xla::Sign(accum) *
        xla::Max(xla::Abs(accum) - global_step * l1, zero) / denominator;

    var = xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, squared_accum));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdagradDA").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdagradDA);

class ResourceApplyAdam : public XlaOpKernel {
 public:
  explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, m_shape, v_shape;
    xla::XlaOp var, m, v;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));

    TensorShape beta1_power_shape = ctx->InputShape(3);
    TensorShape beta2_power_shape = ctx->InputShape(4);
    TensorShape lr_shape = ctx->InputShape(5);
    TensorShape beta1_shape = ctx->InputShape(6);
    TensorShape beta2_shape = ctx->InputShape(7);
    TensorShape epsilon_shape = ctx->InputShape(8);
    TensorShape grad_shape = ctx->InputShape(9);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape),
                errors::InvalidArgument("beta1_power is not a scalar: ",
                                        beta1_power_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power_shape),
                errors::InvalidArgument("beta2_power is not a scalar: ",
                                        beta2_power_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape),
                errors::InvalidArgument("beta1 is not a scalar: ",
                                        beta1_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape),
                errors::InvalidArgument("beta2 is not a scalar: ",
                                        beta2_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));

    OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        m_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape),
                errors::InvalidArgument("var and v do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        v_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp beta1_power = ctx->Input(3);
    xla::XlaOp beta2_power = ctx->Input(4);
    xla::XlaOp lr = ctx->Input(5);
    xla::XlaOp beta1 = ctx->Input(6);
    xla::XlaOp beta2 = ctx->Input(7);
    xla::XlaOp epsilon = ctx->Input(8);
    xla::XlaOp grad = ctx->Input(9);

    // alpha <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
    // m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
    // v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t
    // variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon)

    xla::XlaBuilder* b = ctx->builder();
    xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);

    xla::XlaOp alpha = lr * xla::Sqrt(one - beta2_power) / (one - beta1_power);
    m = m + (grad - m) * (one - beta1);
    v = v + (xla::Square(grad) - v) * (one - beta2);
    var = var - m * alpha / (xla::Sqrt(v) + epsilon);

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx,
                   ctx->AssignVariable(1, dtype_, m));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdam);

class ResourceApplyAdaMax : public XlaOpKernel {
 public:
  explicit ResourceApplyAdaMax(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, m_shape, v_shape;
    xla::XlaOp var, m, v;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));

    TensorShape beta1_power_shape = ctx->InputShape(3);
    TensorShape lr_shape = ctx->InputShape(4);
    TensorShape beta1_shape = ctx->InputShape(5);
    TensorShape beta2_shape = ctx->InputShape(6);
    TensorShape epsilon_shape = ctx->InputShape(7);
    TensorShape grad_shape = ctx->InputShape(8);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape),
                errors::InvalidArgument("beta1_power is not a scalar: ",
                                        beta1_power_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape),
                errors::InvalidArgument("beta1 is not a scalar: ",
                                        beta1_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape),
                errors::InvalidArgument("beta2 is not a scalar: ",
                                        beta2_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        m_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape),
                errors::InvalidArgument("var and v do not have the same shape",
                                        var_shape.DebugString(), " ",
                                        v_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp beta1_power = ctx->Input(3);
    xla::XlaOp lr = ctx->Input(4);
    xla::XlaOp beta1 = ctx->Input(5);
    xla::XlaOp beta2 = ctx->Input(6);
    xla::XlaOp epsilon = ctx->Input(7);
    xla::XlaOp grad = ctx->Input(8);

    // m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
    // v_t <- max(beta2 * v_{t-1}, |g_t|)
    // variable <- variable - lr / (1 - beta1^t) * m_t / (v_t + epsilon)
    xla::XlaOp one = xla::ScalarLike(lr, 1.0);
    m = beta1 * m + (one - beta1) * grad;
    v = xla::Max(beta2 * v, xla::Abs(grad));
    var = var - lr / (one - beta1_power) * (m / (v + epsilon));

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v));
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdaMax").TypeConstraint("T", kFloatTypes),
                ResourceApplyAdaMax);

class ResourceApplyRMSProp : public XlaOpKernel {
 public:
  explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape var_shape, ms_shape, mom_shape,
        mg_shape;
    xla::XlaOp var, ms, mom, mg;
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput("var", dtype_, &var_shape, &var));
    if (centered_) {
      OP_REQUIRES_OK(ctx,
                     ctx->ReadVariableInput("mg", dtype_, &mg_shape, &mg));
    }
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("ms", dtype_, &ms_shape, &ms));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput("mom", dtype_, &mom_shape, &mom));

    TensorShape lr_shape = ctx->InputShape("lr");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr_shape.DebugString()));
    TensorShape rho_shape = ctx->InputShape("rho");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
                errors::InvalidArgument("rho is not a scalar: ",
                                        rho_shape.DebugString()));
    TensorShape momentum_shape = ctx->InputShape("momentum");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum_shape.DebugString()));
    TensorShape epsilon_shape = ctx->InputShape("epsilon");
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon_shape.DebugString()));
    TensorShape grad_shape = ctx->InputShape("grad");

    // var should be the same shape as mom and ms.
    OP_REQUIRES(ctx, var_shape.IsSameSize(ms_shape),
                errors::InvalidArgument(
                    "var and ms do not have the same shape",
                    var_shape.DebugString(), " ", ms_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(mom_shape),
                errors::InvalidArgument(
                    "var and mom do not have the same shape",
                    var_shape.DebugString(), " ", mom_shape.DebugString()));
    OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
                errors::InvalidArgument(
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::XlaOp lr = ctx->Input("lr");
    xla::XlaOp rho = ctx->Input("rho");
    xla::XlaOp momentum = ctx->Input("momentum");
    xla::XlaOp epsilon = ctx->Input("epsilon");
    xla::XlaOp grad = ctx->Input("grad");

    // ms <- rho * ms_{t-1} + (1-rho) * grad * grad
    // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
    // var <- var - mom
    //
    // We use an alternate formulation of the ms equation:
    //
    //    ms <- ms + (grad**2 - ms) * (1 - rho)
    //
    // Which expands to:
    //
    //    ms <- ms + grad**2 - rho * grad**2 - ms + ms * rho
    //
    // Which simplifies to:
    //
    //    ms <- grad**2 * (1 - rho) + ms * rho
    //
    // Which is the equation listed above.
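    //
    // For the centered variant (ResourceApplyCenteredRMSProp), mg tracks a
    // decayed mean of the gradient and its square is subtracted from ms in
    // the denominator:
    //
    //    mg <- rho * mg_{t-1} + (1-rho) * grad
    //    mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg^2 + epsilon)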
    xla::XlaOp one = xla::ScalarLike(ms, 1.0);
    xla::XlaOp new_ms = xla::Square(grad) * (one - rho) + ms * rho;
    xla::XlaOp denominator;
    if (centered_) {
      mg = grad * (one - rho) + mg * rho;
      denominator = new_ms - xla::Square(mg) + epsilon;
    } else {
      denominator = new_ms + epsilon;
    }
    xla::XlaOp new_mom = mom * momentum + grad * lr * xla::Rsqrt(denominator);
    xla::XlaOp new_var = var - new_mom;

    OP_REQUIRES_OK(ctx, ctx->AssignVariable("var", dtype_, new_var));
    if (centered_) {
      OP_REQUIRES_OK(ctx, ctx->AssignVariable("mg", dtype_, mg));
    }
    OP_REQUIRES_OK(ctx, ctx->AssignVariable("ms", dtype_, new_ms));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable("mom", dtype_, new_mom));
  }

 protected:
  bool centered_ = false;

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes),
                ResourceApplyRMSProp);

class ResourceApplyCenteredRMSProp : public ResourceApplyRMSProp {
 public:
  explicit ResourceApplyCenteredRMSProp(OpKernelConstruction* ctx)
      : ResourceApplyRMSProp(ctx) {
    centered_ = true;
  }
};
REGISTER_XLA_OP(
    Name("ResourceApplyCenteredRMSProp").TypeConstraint("T", kFloatTypes),
    ResourceApplyCenteredRMSProp);

void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
                 bool has_l2_shrinkage) {
  xla::XlaBuilder* b = ctx->builder();

  TensorShape var_shape, accum_shape, linear_shape;
  xla::XlaOp var, accum, linear;
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var));
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum));
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear));

  OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
              errors::InvalidArgument(
                  "var and accum do not have the same shape",
                  var_shape.DebugString(), " ", accum_shape.DebugString()));

  OP_REQUIRES(ctx, var_shape.IsSameSize(linear_shape),
              errors::InvalidArgument(
                  "var and linear do not have the same shape",
                  var_shape.DebugString(), " ", linear_shape.DebugString()));

  TensorShape grad_shape = ctx->InputShape(3);
  TensorShape lr_shape = ctx->InputShape(4);
  TensorShape l1_shape = ctx->InputShape(5);
  TensorShape l2_shape = ctx->InputShape(6);
  TensorShape l2_shrinkage_shape;
  TensorShape lr_power_shape;
  if (has_l2_shrinkage) {
    l2_shrinkage_shape = ctx->InputShape(7);
    lr_power_shape = ctx->InputShape(8);
  } else {
    lr_power_shape = ctx->InputShape(7);
  }

  OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
              errors::InvalidArgument("var and grad do not have the same shape",
                                      var_shape.DebugString(), " ",
                                      grad_shape.DebugString()));

  OP_REQUIRES(
      ctx, TensorShapeUtils::IsScalar(lr_shape),
      errors::InvalidArgument("lr is not a scalar: ", lr_shape.DebugString()));

  OP_REQUIRES(
      ctx, TensorShapeUtils::IsScalar(l1_shape),
      errors::InvalidArgument("l1 is not a scalar: ", l1_shape.DebugString()));

  OP_REQUIRES(
      ctx, TensorShapeUtils::IsScalar(l2_shape),
      errors::InvalidArgument("l2 is not a scalar: ", l2_shape.DebugString()));

  if (has_l2_shrinkage) {
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shrinkage_shape),
                errors::InvalidArgument("l2_shrinkage is not a scalar: ",
                                        l2_shrinkage_shape.DebugString()));
  }

  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power_shape),
errors::InvalidArgument("lr_power is not a scalar: ", 732 lr_power_shape.DebugString())); 733 734 xla::XlaOp grad = ctx->Input(3); 735 xla::XlaOp lr = ctx->Input(4); 736 xla::XlaOp l1 = ctx->Input(5); 737 xla::XlaOp l2 = ctx->Input(6); 738 xla::XlaOp l2_shrinkage; 739 xla::XlaOp lr_power; 740 if (has_l2_shrinkage) { 741 l2_shrinkage = ctx->Input(7); 742 lr_power = ctx->Input(8); 743 } else { 744 lr_power = ctx->Input(7); 745 } 746 747 // grad_to_use = grad + 2 * l2_shrinkage * var 748 // new_accum = accum + grad * grad 749 // linear += grad_to_use - 750 // (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var 751 // quadratic = (new_accum^(-lr_power) / lr) + 2 * l2 752 // linear_clipped = clamp linear in [-l1, l1] 753 // var = (linear_clipped - linear) / quadratic 754 // accum = new_accum 755 756 xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0); 757 xla::XlaOp grad_to_use; 758 if (has_l2_shrinkage) { 759 grad_to_use = grad + two * l2_shrinkage * var; 760 } else { 761 grad_to_use = grad; 762 } 763 764 xla::XlaOp new_accum = accum + xla::Square(grad); 765 xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power); 766 xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power); 767 linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var; 768 xla::XlaOp linear_clipped = xla::Clamp(-l1, linear, l1); 769 xla::XlaOp quadratic = new_accum_lr_pow / lr + two * l2; 770 var = (linear_clipped - linear) / quadratic; 771 accum = new_accum; 772 773 OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var)); 774 OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype, accum)); 775 OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype, linear)); 776 } 777 778 class ResourceApplyFtrl : public XlaOpKernel { 779 public: 780 explicit ResourceApplyFtrl(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 781 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); 782 } 783 784 void Compile(XlaOpKernelContext* ctx) override { 785 CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/false); 786 } 787 788 private: 789 DataType dtype_; 790 }; 791 REGISTER_XLA_OP(Name("ResourceApplyFtrl").TypeConstraint("T", kFloatTypes), 792 ResourceApplyFtrl); 793 794 class ResourceApplyFtrlV2 : public XlaOpKernel { 795 public: 796 explicit ResourceApplyFtrlV2(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 797 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); 798 } 799 800 void Compile(XlaOpKernelContext* ctx) override { 801 CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/true); 802 } 803 804 private: 805 DataType dtype_; 806 }; 807 REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes), 808 ResourceApplyFtrlV2); 809 810 class ResourceApplyAdadelta : public XlaOpKernel { 811 public: 812 explicit ResourceApplyAdadelta(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 813 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); 814 } 815 816 void Compile(XlaOpKernelContext* ctx) override { 817 TensorShape var_shape, accum_shape, accum_update_shape; 818 xla::XlaOp var, accum, accum_update; 819 OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); 820 OP_REQUIRES_OK(ctx, 821 ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum)); 822 OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &accum_update_shape, 823 &accum_update)); 824 825 TensorShape lr_shape = ctx->InputShape(3); 826 TensorShape rho_shape = ctx->InputShape(4); 827 TensorShape epsilon_shape = ctx->InputShape(5); 828 TensorShape grad_shape = ctx->InputShape(6); 829 830 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), 831 
errors::InvalidArgument("lr is not a scalar: ", 832 lr_shape.DebugString())); 833 834 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape), 835 errors::InvalidArgument("rho is not a scalar: ", 836 rho_shape.DebugString())); 837 838 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape), 839 errors::InvalidArgument("epsilon is not a scalar: ", 840 epsilon_shape.DebugString())); 841 842 OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), 843 errors::InvalidArgument( 844 "var and accum do not have the same shape", 845 var_shape.DebugString(), " ", accum_shape.DebugString())); 846 847 OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), 848 errors::InvalidArgument( 849 "var and grad do not have the same shape", 850 var_shape.DebugString(), " ", grad_shape.DebugString())); 851 852 xla::XlaOp lr = ctx->Input(3); 853 xla::XlaOp rho = ctx->Input(4); 854 xla::XlaOp epsilon = ctx->Input(5); 855 xla::XlaOp grad = ctx->Input(6); 856 857 xla::XlaBuilder* b = ctx->builder(); 858 xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); 859 860 accum = rho * accum + (one - rho) * xla::Square(grad); 861 xla::XlaOp update = 862 xla::Sqrt(accum_update + epsilon) * xla::Rsqrt(accum + epsilon) * grad; 863 accum_update = rho * accum_update + (one - rho) * xla::Square(update); 864 var = var - update * lr; 865 OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); 866 OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); 867 OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, accum_update)); 868 } 869 870 private: 871 DataType dtype_; 872 }; 873 REGISTER_XLA_OP(Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatTypes), 874 ResourceApplyAdadelta); 875 876 class ResourceApplySignBase : public XlaOpKernel { 877 public: 878 explicit ResourceApplySignBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 879 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); 880 } 881 882 void Compile(XlaOpKernelContext* ctx) override { 883 TensorShape var_shape, m_shape; 884 xla::XlaOp var, m; 885 OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); 886 OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m)); 887 OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape), 888 errors::InvalidArgument("var and m do not have the same shape", 889 var_shape.DebugString(), " ", 890 m_shape.DebugString())); 891 TensorShape grad_shape = ctx->InputShape(6); 892 OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), 893 errors::InvalidArgument( 894 "var and grad do not have the same shape", 895 var_shape.DebugString(), " ", grad_shape.DebugString())); 896 CheckScalarParams(ctx); 897 898 xla::XlaOp lr = ctx->Input(2); 899 xla::XlaOp alpha = ctx->Input(3); 900 xla::XlaOp sign_decay = ctx->Input(4); 901 xla::XlaOp beta = ctx->Input(5); 902 xla::XlaOp grad = ctx->Input(6); 903 904 m = m * beta + grad * (xla::ScalarLike(beta, 1.0) - beta); 905 xla::XlaOp decay = xla::Sign(grad) * xla::Sign(m) * sign_decay; 906 907 xla::XlaOp grad_scale = ComputeGradientScale(alpha, decay); 908 var = var - lr * grad_scale * grad; 909 OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); 910 OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m)); 911 } 912 913 virtual void CheckScalarParams(XlaOpKernelContext* ctx) { 914 TensorShape lr_shape = ctx->InputShape(2); 915 TensorShape sign_decay_shape = ctx->InputShape(4); 916 TensorShape beta_shape = ctx->InputShape(5); 917 918 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), 919 errors::InvalidArgument("lr is not a scalar: ", 920 lr_shape.DebugString())); 921 922 
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay_shape),
                errors::InvalidArgument("sign_decay is not a scalar: ",
                                        sign_decay_shape.DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta_shape),
                errors::InvalidArgument("beta is not a scalar: ",
                                        beta_shape.DebugString()));
  }

  virtual xla::XlaOp ComputeGradientScale(xla::XlaOp alpha,
                                          xla::XlaOp decay) = 0;

 private:
  DataType dtype_;
};

class ResourceApplyAddSign : public ResourceApplySignBase {
 public:
  explicit ResourceApplyAddSign(OpKernelConstruction* ctx)
      : ResourceApplySignBase(ctx) {}

  void CheckScalarParams(XlaOpKernelContext* ctx) override {
    ResourceApplySignBase::CheckScalarParams(ctx);
    TensorShape alpha_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_shape.DebugString()));
  }

  xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
    return alpha + decay;
  }
};
REGISTER_XLA_OP(Name("ResourceApplyAddSign").TypeConstraint("T", kFloatTypes),
                ResourceApplyAddSign);

class ResourceApplyPowerSign : public ResourceApplySignBase {
 public:
  explicit ResourceApplyPowerSign(OpKernelConstruction* ctx)
      : ResourceApplySignBase(ctx) {}

  void CheckScalarParams(XlaOpKernelContext* ctx) override {
    ResourceApplySignBase::CheckScalarParams(ctx);
    TensorShape logbase_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase_shape),
                errors::InvalidArgument("logbase is not a scalar: ",
                                        logbase_shape.DebugString()));
  }

  xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
    return xla::Exp(alpha * decay);
  }
};
REGISTER_XLA_OP(Name("ResourceApplyPowerSign").TypeConstraint("T", kFloatTypes),
                ResourceApplyPowerSign);

}  // namespace
}  // namespace tensorflow