/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <algorithm>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/training_ops.h"
#include "tensorflow/core/kernels/variable_ops.h"

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
#endif  // TENSORFLOW_USE_SYCL

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
using SYCLDevice = Eigen::SyclDevice;

namespace {
template <class T>
inline T sgn(const T x) {
  T zero(0);
  T one(1);
  return (x == zero ? zero : (x < zero ? -one : one));
}
}  // namespace

namespace functor {
template <typename T>
struct ApplyGradientDescent<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad) {
    var.device(d) -= grad * lr();
  }
};

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct ApplyGradientDescentSYCL {
  void operator()(const SYCLDevice& d, typename TTypes<T>::Flat var, T lr,
                  typename TTypes<T>::ConstFlat grad) {
    var.device(d) -= grad * lr;
  }
};
#endif

template <typename T>
struct ApplyAdadelta<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat accum_update,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    accum.device(d) =
        accum * rho() + grad.square() * (static_cast<T>(1) - rho());
    const auto update =
        (accum_update + epsilon()).sqrt() * (accum + epsilon()).rsqrt() * grad;
    var.device(d) -= update * lr();
    accum_update.device(d) =
        accum_update * rho() + update.square() * (static_cast<T>(1) - rho());
  }
};
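// For reference, the Adadelta functor above amounts to the following
// per-element recurrence (illustrative sketch only, not code compiled into
// the kernel):
//   accum        = rho * accum        + (1 - rho) * grad^2
//   update       = sqrt(accum_update + epsilon) / sqrt(accum + epsilon) * grad
//   var         -= lr * update
//   accum_update = rho * accum_update + (1 - rho) * update^2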
template <typename T>
struct ApplyProximalGradientDescent<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad) {
    // Note: this is the FOBOS update; for details please refer to:
    // http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf
    // TODO(xbing): merge the logic for ProximalGradientDescent and
    // ProximalAdagrad.
    auto prox_var = var;
    // compute v = w - lr * grad.
    prox_var.device(d) -= grad * lr();
    if (l1() > 0) {
      // compute sign(v) * max(|v| - lr * l1, 0)
      var.device(d) =
          prox_var.sign() *
          (prox_var.abs() - var.constant(lr() * l1())).cwiseMax(T(0.0)) /
          (var.constant(1.0) + var.constant(l2() * lr()));
    } else {
      var.device(d) =
          prox_var / (var.constant(1.0) + var.constant(l2() * lr()));
    }
  }
};
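// The FOBOS / proximal step used above (and with an Adagrad learning rate in
// ApplyProximalAdagrad below), written element-wise. Illustrative sketch only;
// the kernel uses the vectorized Eigen expressions above:
//   v = w - lr * grad                          // gradient step
//   w = sign(v) * max(|v| - lr * l1, 0)        // l1 soft-thresholding
//         / (1 + lr * l2)                      // l2 shrinkage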
template <typename T>
struct ApplyAdagradDA<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat gradient_accum,
                  typename TTypes<T>::Flat gradient_squared_accum,
                  typename TTypes<T>::ConstScalar lr, int64 global_step,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad) {
    // Accumulate gradient, and gradient_squared
    gradient_accum.device(d) += grad;
    gradient_squared_accum.device(d) += grad.square();

    // AdagradDA update:
    // Let g be the gradient accumulator, gg the gradient squared accumulator,
    // T the global step, lr the learning rate, and k the initial gradient
    // squared accumulator value.
    // w = \dfrac{sign(-g)*lr*(|g| - l1*T)_{+}}{l2*T*lr + \sqrt{k+gg}}
    if (l1() > 0) {
      var.device(d) =
          lr() * var.constant(-1.0) * gradient_accum.sign() *
          (gradient_accum.abs() -
           var.constant(static_cast<float>(global_step)) * var.constant(l1()))
              .cwiseMax(T(0.0)) /
          (var.constant(l2()) *
               var.constant(static_cast<float>(global_step) * lr()) +
           gradient_squared_accum.sqrt());
    } else {
      var.device(d) =
          lr() * gradient_accum * var.constant(-1.0) /
          (var.constant(l2()) *
               var.constant(static_cast<float>(global_step) * lr()) +
           gradient_squared_accum.sqrt());
    }
  }
};

template <typename T>
struct ApplyAdagrad<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad) {
    accum.device(d) += grad.square();
    var.device(d) -= grad * lr() * accum.rsqrt();
  }
};

template <typename T>
struct ApplyProximalAdagrad<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad) {
    // FOBOS update per paper with Adagrad learning rate.
    accum.device(d) += grad.square();
    // Adagrad learning rate.
    auto learning_rate = accum.constant(lr()) * accum.rsqrt();
    auto prox_var = var;
    // compute v = w - lr * grad.
    prox_var.device(d) -= grad * learning_rate;
    if (l1() > 0) {
      // compute sign(v) * max(|v| - lr * l1, 0)
      var.device(d) = prox_var.sign() *
                      (prox_var.abs() - learning_rate * prox_var.constant(l1()))
                          .cwiseMax(T(0.0)) /
                      (var.constant(1.0) + var.constant(l2()) * learning_rate);
    } else {
      var.device(d) =
          prox_var / (var.constant(1.0) + var.constant(l2()) * learning_rate);
    }
  }
};

template <typename T>
struct ApplyFtrlV2<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar l2_shrinkage,
                  typename TTypes<T>::ConstScalar lr_power) {
    auto grad_with_shrinkage = grad + static_cast<T>(2) * l2_shrinkage() * var;
    auto new_accum = accum + grad_with_shrinkage.square();
    // special case for which lr_power=-0.5.
    if (lr_power() == static_cast<T>(-0.5)) {
      linear.device(d) +=
          grad_with_shrinkage - (new_accum.sqrt() - accum.sqrt()) / lr() * var;
    } else {
      linear.device(d) +=
          grad_with_shrinkage -
          (new_accum.pow(-lr_power()) - accum.pow(-lr_power())) / lr() * var;
    }
    auto x = (linear.constant(l1()) * linear.sign() - linear);
    if (lr_power() == static_cast<T>(-0.5)) {
      auto y = new_accum.sqrt() / new_accum.constant(lr()) +
               linear.constant(static_cast<T>(2) * l2());
      auto pre_shrink = x / y;
      var.device(d) = (linear.abs() > linear.constant(l1()))
                          .select(pre_shrink, var.constant(static_cast<T>(0)));

    } else {
      auto y = new_accum.pow(-lr_power()) / new_accum.constant(lr()) +
               linear.constant(static_cast<T>(2) * l2());
      auto pre_shrink = x / y;
      var.device(d) = (linear.abs() > linear.constant(l1()))
                          .select(pre_shrink, var.constant(static_cast<T>(0)));
    }
    accum.device(d) += grad_with_shrinkage.square();
  }
};
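// The FTRL-proximal step used by ApplyFtrlV2 above and ApplyFtrl below, per
// element (illustrative sketch only; shown for the common lr_power = -0.5
// case, where new_accum^(-lr_power) == sqrt(new_accum)):
//   new_accum = accum + g^2            // g includes the l2_shrinkage term in FtrlV2
//   linear   += g - (sqrt(new_accum) - sqrt(accum)) / lr * var
//   quad      = sqrt(new_accum) / lr + 2 * l2
//   var       = |linear| > l1 ? (l1 * sign(linear) - linear) / quad : 0
//   accum     = new_accum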
template <typename T>
struct ApplyFtrl<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar lr_power) {
    auto new_accum = accum + grad.square();
    // special case for which lr_power=-0.5.
    if (lr_power() == static_cast<T>(-0.5)) {
      linear.device(d) += grad - (new_accum.sqrt() - accum.sqrt()) / lr() * var;
    } else {
      linear.device(d) +=
          grad -
          (new_accum.pow(-lr_power()) - accum.pow(-lr_power())) / lr() * var;
    }
    auto x = (linear.constant(l1()) * linear.sign() - linear);
    if (lr_power() == static_cast<T>(-0.5)) {
      auto y = new_accum.sqrt() / new_accum.constant(lr()) +
               linear.constant(static_cast<T>(2) * l2());
      auto pre_shrink = x / y;
      var.device(d) = (linear.abs() > linear.constant(l1()))
                          .select(pre_shrink, var.constant(static_cast<T>(0)));

    } else {
      auto y = new_accum.pow(-lr_power()) / new_accum.constant(lr()) +
               linear.constant(static_cast<T>(2) * l2());
      auto pre_shrink = x / y;
      var.device(d) = (linear.abs() > linear.constant(l1()))
                          .select(pre_shrink, var.constant(static_cast<T>(0)));
    }
    accum.device(d) += grad.square();
  }
};

template <typename T>
struct ApplyMomentum<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
    accum.device(d) = accum * momentum() + grad;
    if (use_nesterov) {
      var.device(d) -= grad * lr() + accum * momentum() * lr();
    } else {
      var.device(d) -= accum * lr();
    }
  }
};

template <typename Device, typename T>
struct ApplyAdamNonCuda {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar beta2_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool use_nesterov) {
    const T alpha = lr() * Eigen::numext::sqrt(T(1) - beta2_power()) /
                    (T(1) - beta1_power());
    m.device(d) += (grad - m) * (T(1) - beta1());
    v.device(d) += (grad.square() - v) * (T(1) - beta2());
    if (use_nesterov) {
      var.device(d) -= ((grad * (T(1) - beta1()) + beta1() * m) * alpha) /
                       (v.sqrt() + epsilon());
    } else {
      var.device(d) -= (m * alpha) / (v.sqrt() + epsilon());
    }
  }
};

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct ApplyAdamSYCL {
  void operator()(const SYCLDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  T beta1_power, T beta2_power, T lr, T beta1, T beta2,
                  T epsilon, typename TTypes<T>::ConstFlat grad) {
    const T alpha =
        lr * Eigen::numext::sqrt(T(1) - beta2_power) / (T(1) - beta1_power);
    m.device(d) += (grad - m) * (T(1) - beta1);
    v.device(d) += (grad.square() - v) * (T(1) - beta2);
    var.device(d) -= (m * alpha) / (v.sqrt() + epsilon);
  }
};
#endif  // TENSORFLOW_USE_SYCL

template <typename T>
struct ApplyAdam<CPUDevice, T> : ApplyAdamNonCuda<CPUDevice, T> {};
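// The Adam step implemented in ApplyAdamNonCuda above, element-wise
// (illustrative sketch only):
//   alpha = lr * sqrt(1 - beta2_power) / (1 - beta1_power)
//   m    += (1 - beta1) * (g - m)
//   v    += (1 - beta2) * (g^2 - v)
//   var  -= alpha * m / (sqrt(v) + epsilon)
// With use_nesterov, the last line uses (beta1 * m + (1 - beta1) * g) in place
// of m.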
template <typename T>
struct ApplyRMSProp<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    ms.device(d) += (grad.square() - ms) * (static_cast<T>(1) - rho());
    mom.device(d) =
        mom * momentum() + (grad * lr()) / ((ms + epsilon()).sqrt());
    var.device(d) -= mom;
  }
};

template <typename T>
struct ApplyCenteredRMSProp<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
                  typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    ms.device(d) += (grad.square() - ms) * (static_cast<T>(1) - rho());
    mg.device(d) += (grad - mg) * (static_cast<T>(1) - rho());
    auto denom = (ms - mg.square()) + epsilon();
    mom.device(d) = mom * momentum() + (grad * lr()) / denom.sqrt();
    var.device(d) -= mom;
  }
};

template <typename T>
struct ApplyAddSign<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar alpha,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad) {
    m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
    auto sign_gm = grad.sign() * m.sign();
    var.device(d) -= lr() * (alpha() + sign_decay() * sign_gm) * grad;
  }
};

template <typename T>
struct ApplyPowerSign<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar logbase,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad) {
    m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
    auto sign_gm = grad.sign() * m.sign();
    auto grad_scale = (logbase() * sign_decay() * sign_gm).exp();
    var.device(d) -= lr() * grad_scale * grad;
  }
};
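// AddSign and PowerSign above share the same moving average; element-wise
// they reduce to (illustrative sketch only):
//   m       = beta * m + (1 - beta) * g
//   sign_gm = sign(g) * sign(m)
//   AddSign:   var -= lr * (alpha + sign_decay * sign_gm) * g
//   PowerSign: var -= lr * exp(logbase * sign_decay * sign_gm) * g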
}  // namespace functor

template <typename Device, typename T>
class ApplyGradientDescentOp : public OpKernel {
 public:
  explicit ApplyGradientDescentOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override {
    auto locks =
        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 0, use_exclusive_lock_, false, &var));

    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    const Tensor& alpha = ctx->input(1);
    OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha.shape().DebugString()));
    const Tensor& delta = ctx->input(2);
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(delta.shape()),
        errors::InvalidArgument("var and delta do not have the same shape",
                                var.shape().DebugString(), " ",
                                delta.shape().DebugString()));

    const Device& device = ctx->template eigen_device<Device>();
    functor::ApplyGradientDescent<Device, T>()(
        device, var.flat<T>(), alpha.scalar<T>(), delta.flat<T>());

    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
class ApplyGradientDescentOp<SYCLDevice, T> : public OpKernel {
 public:
  explicit ApplyGradientDescentOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override {
    auto locks =
        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<SYCLDevice, T>(
                            ctx, 0, use_exclusive_lock_, false, &var));

    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    const Tensor& alpha_dev = ctx->input(1);
    OP_REQUIRES(ctx, IsLegacyScalar(alpha_dev.shape()),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_dev.shape().DebugString()));
    const Tensor& delta = ctx->input(2);
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(delta.shape()),
        errors::InvalidArgument("var and delta do not have the same shape",
                                var.shape().DebugString(), " ",
                                delta.shape().DebugString()));

    auto device = ctx->eigen_sycl_device();
    auto size = sizeof(T);
    T alpha = T(0);
    auto src_ptr = GetBase(&alpha_dev);
    device.memcpyDeviceToHost(&alpha, static_cast<const T*>(src_ptr), size);

    functor::ApplyGradientDescentSYCL<T>()(device, var.flat<T>(), alpha,
                                           delta.flat<T>());

    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};
#endif  // TENSORFLOW_USE_SYCL

#define REGISTER_KERNELS(D, T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("ApplyGradientDescent").Device(DEVICE_##D).TypeConstraint<T>("T"), \
      ApplyGradientDescentOp<D##Device, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("ResourceApplyGradientDescent")                \
                              .Device(DEVICE_##D)                             \
                              .HostMemory("var")                              \
                              .TypeConstraint<T>("T"),                        \
                          ApplyGradientDescentOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);
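// For example, REGISTER_CPU_KERNELS(float) above expands (via
// REGISTER_KERNELS(CPU, float)) into two REGISTER_KERNEL_BUILDER calls that
// register ApplyGradientDescentOp<CPUDevice, float> for both the
// "ApplyGradientDescent" and "ResourceApplyGradientDescent" ops.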
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                              \
  template <>                                            \
  void ApplyGradientDescent<GPUDevice, T>::operator()(   \
      const GPUDevice& d, typename TTypes<T>::Flat var,  \
      typename TTypes<T>::ConstScalar alpha,             \
      typename TTypes<T>::ConstFlat delta);              \
  extern template struct ApplyGradientDescent<GPUDevice, T>;
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#endif

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T);
TF_CALL_float(REGISTER_SYCL_KERNELS);
TF_CALL_double(REGISTER_SYCL_KERNELS);
#undef REGISTER_SYCL_KERNELS
#endif  // TENSORFLOW_USE_SYCL

#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS

template <typename Device, typename T>
class ApplyAdadeltaOp : public OpKernel {
 public:
  explicit ApplyAdadeltaOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override {
    mutex* mu = GetTrainingVariableMutex(ctx, 0);
    if (use_exclusive_lock_ && mu != nullptr) {
      mutex_lock l1(*mu);
      // Don't try to acquire a lock on the second ref as they share the same
      // mutex.
      //
      // mutex_lock l2(*ctx->input_ref_mutex(1));
      DoValidate(ctx);
      if (!ctx->status().ok()) return;
      DoCompute(ctx);
    } else {
      DoValidate(ctx);
      if (!ctx->status().ok()) return;
      DoCompute(ctx);
    }
    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;

  void DoValidate(OpKernelContext* ctx) {
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 0, use_exclusive_lock_, false, &var));
    Tensor accum;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 1, use_exclusive_lock_, false, &accum));
    Tensor accum_update;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 2, use_exclusive_lock_, false, &accum_update));

    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, accum.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    OP_REQUIRES(
        ctx, accum_update.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(2)));

    const Tensor& lr = ctx->input(3);
    const Tensor& rho = ctx->input(4);
    const Tensor& epsilon = ctx->input(5);
    const Tensor& grad = ctx->input(6);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
                errors::InvalidArgument("rho is not a scalar: ",
                                        rho.shape().DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon.shape().DebugString()));

    OP_REQUIRES(
        ctx, var.shape().IsSameSize(accum.shape()),
        errors::InvalidArgument("var and accum do not have the same shape",
                                var.shape().DebugString(), " ",
                                accum.shape().DebugString()));
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(grad.shape()),
        errors::InvalidArgument("var and grad do not have the same shape",
                                var.shape().DebugString(), " ",
                                grad.shape().DebugString()));
  }

  void DoCompute(OpKernelContext* ctx) {
    const Device& device = ctx->template eigen_device<Device>();
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 0, use_exclusive_lock_, false, &var));
    Tensor accum;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 1, use_exclusive_lock_, false, &accum));
    Tensor accum_update;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 2, use_exclusive_lock_, false, &accum_update));

    const Tensor& lr = ctx->input(3);
    const Tensor& rho = ctx->input(4);
    const Tensor& epsilon = ctx->input(5);
    const Tensor& grad = ctx->input(6);

    functor::ApplyAdadelta<Device, T>()(
        device, var.flat<T>(), accum.flat<T>(), accum_update.flat<T>(),
        lr.scalar<T>(), rho.scalar<T>(), epsilon.scalar<T>(), grad.flat<T>());
  }
};

#define REGISTER_KERNELS(D, T)                                         \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("ApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"), \
      ApplyAdadeltaOp<D##Device, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdadelta")                \
                              .Device(DEVICE_##D)                      \
                              .HostMemory("var")                       \
                              .HostMemory("accum")                     \
                              .HostMemory("accum_update")              \
                              .TypeConstraint<T>("T"),                 \
                          ApplyAdadeltaOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                    \
  template <>                                                                  \
  void ApplyAdadelta<GPUDevice, T>::operator()(                                \
      const GPUDevice& d, typename TTypes<T>::Flat var,                       \
      typename TTypes<T>::Flat accum, typename TTypes<T>::Flat accum_update,  \
      typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar rho, \
      typename TTypes<T>::ConstScalar epsilon,                                \
      typename TTypes<T>::ConstFlat grad);                                    \
  extern template struct ApplyAdadelta<GPUDevice, T>;
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#endif
#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS

// Note, this op works on cpu only.
template <typename T, typename Tindex>
class SparseApplyAdadeltaOp : public OpKernel {
 public:
  explicit SparseApplyAdadeltaOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override {
    mutex* mu = GetTrainingVariableMutex(ctx, 0);
    // mu_accum is actually the same mutex as mu_var since currently we use a
    // global mutex.
    //
    // mutex* mu_accum = ctx->input_ref_mutex(1);
    if (use_exclusive_lock_ && mu != nullptr) {
      mutex_lock ml(*mu);
      DoCompute(ctx);
    } else {
      DoCompute(ctx);
    }
  }

  void DoCompute(OpKernelContext* ctx) {
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
                            ctx, 0, use_exclusive_lock_, true, &var));
    Tensor accum_grad;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
                            ctx, 1, use_exclusive_lock_, true, &accum_grad));
    Tensor accum_update;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
                            ctx, 2, use_exclusive_lock_, true, &accum_update));
    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, accum_grad.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    OP_REQUIRES(
        ctx, accum_update.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(2)));
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(accum_grad.shape()),
        errors::InvalidArgument("var and accum_grad do not have the same shape",
                                var.shape().DebugString(), " ",
                                accum_grad.shape().DebugString()));
    OP_REQUIRES(ctx, var.shape().IsSameSize(accum_update.shape()),
                errors::InvalidArgument(
                    "var and accum_update do not have the same shape",
                    var.shape().DebugString(), " ",
                    accum_update.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
                errors::InvalidArgument("var must be at least 1 dimensional"));

    const Tensor& lr = ctx->input(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
    const Tensor& rho = ctx->input(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
                errors::InvalidArgument("rho is not a scalar: ",
                                        rho.shape().DebugString()));
    const Tensor& epsilon = ctx->input(5);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon.shape().DebugString()));
    const Tensor& grad = ctx->input(6);
    const Tensor& indices = ctx->input(7);
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices must be one-dimensional"));

    for (int d = 1; d < var.dims(); d++) {
      OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
                  errors::InvalidArgument(strings::StrCat(
                      "var and grad must match in dimension ", d)));
    }
    const Tindex N = indices.dim_size(0);
    OP_REQUIRES(
        ctx, grad.dim_size(0) == N,
        errors::InvalidArgument(
            "grad must be the same size as indices in the first dimension."));

    if (N > 0) {
      const Tindex first_dim_size = var.dim_size(0);
      // Validate all the indices are in range
      auto indices_vec = indices.vec<Tindex>();
      for (Tindex i = 0; i < N; i++) {
        const Tindex index = indices_vec(i);
        OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
                    errors::InvalidArgument(
                        strings::StrCat("Index ", index, " at offset ", i,
                                        " in indices is out of range")));
      }

      auto var_flat = var.flat_outer_dims<T>();
      auto accum_grad_flat = accum_grad.flat_outer_dims<T>();
      auto accum_update_flat = accum_update.flat_outer_dims<T>();
      auto grad_flat = grad.flat_outer_dims<T>();
      const T lr_scalar = lr.scalar<T>()();
      const T rho_scalar = rho.scalar<T>()();
      const T epsilon_scalar = epsilon.scalar<T>()();

      for (Tindex i = 0; i < N; i++) {
        const Tindex index = indices_vec(i);
        auto accum_ = accum_grad_flat.template chip<0>(index);
        auto accum_update_ = accum_update_flat.template chip<0>(index);
        auto grad_ = grad_flat.template chip<0>(i);

        accum_ = accum_ * accum_.constant(rho_scalar) +
                 grad_.square() * grad_.constant(T(1) - rho_scalar);
        const auto update =
            (accum_update_ + accum_update_.constant(epsilon_scalar)).sqrt() *
            (accum_ + accum_.constant(epsilon_scalar)).rsqrt() * grad_;
        auto v = var_flat.template chip<0>(index);
        v -= update * update.constant(lr_scalar);
        accum_update_ =
            accum_update_ * accum_update_.constant(rho_scalar) +
            update.square() * update.constant(static_cast<T>(1) - rho_scalar);
      }
    }

    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

#define REGISTER_KERNELS(T, Tindices)                                \
  REGISTER_KERNEL_BUILDER(Name("SparseApplyAdadelta")                \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<T>("T")                \
                              .TypeConstraint<Tindices>("Tindices"), \
                          SparseApplyAdadeltaOp<T, Tindices>);       \
  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdadelta")        \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<T>("T")                \
                              .TypeConstraint<Tindices>("Tindices"), \
                          SparseApplyAdadeltaOp<T, Tindices>);
#define REGISTER_CPU_KERNELS(T) \
  REGISTER_KERNELS(T, int32);   \
  REGISTER_KERNELS(T, int64);

TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);

#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS
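// A note on the sparse kernels in this file (e.g. SparseApplyAdadeltaOp
// above): flat_outer_dims<T>() views a tensor of shape [d0, d1, ..., dn] as a
// 2-D array of shape [d0, d1 * ... * dn], and chip<0>(index) selects one row
// of that view. Each loop iteration therefore applies the dense update rule to
// the single row of var (and the matching slot rows) selected by indices(i).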
// Note, this op works on cpu only.
template <typename Device, typename T>
class ApplyProximalGradientDescentOp : public OpKernel {
 public:
  explicit ApplyProximalGradientDescentOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override {
    auto locks =
        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 0, use_exclusive_lock_, false, &var));

    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    const Tensor& alpha = ctx->input(1);
    OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha.shape().DebugString()));
    const Tensor& l1 = ctx->input(2);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(l1.shape()),
        errors::InvalidArgument("l1 regularization strength is not a scalar: ",
                                l1.shape().DebugString()));
    const Tensor& l2 = ctx->input(3);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(l2.shape()),
        errors::InvalidArgument("l2 regularization strength is not a scalar: ",
                                l2.shape().DebugString()));

    const Tensor& delta = ctx->input(4);
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(delta.shape()),
        errors::InvalidArgument("var and delta do not have the same shape",
                                var.shape().DebugString(), " ",
                                delta.shape().DebugString()));

    const Device& device = ctx->template eigen_device<Device>();
    functor::ApplyProximalGradientDescent<Device, T>()(
        device, var.flat<T>(), alpha.scalar<T>(), l1.scalar<T>(),
        l2.scalar<T>(), delta.flat<T>());

    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

#define REGISTER_KERNELS(D, T)                                           \
  REGISTER_KERNEL_BUILDER(Name("ApplyProximalGradientDescent")           \
                              .Device(DEVICE_##D)                        \
                              .TypeConstraint<T>("T"),                   \
                          ApplyProximalGradientDescentOp<D##Device, T>); \
  REGISTER_KERNEL_BUILDER(Name("ResourceApplyProximalGradientDescent")   \
                              .HostMemory("var")                         \
                              .Device(DEVICE_##D)                        \
                              .TypeConstraint<T>("T"),                   \
                          ApplyProximalGradientDescentOp<D##Device, T>);

REGISTER_KERNELS(CPU, float);
REGISTER_KERNELS(CPU, double);
#undef REGISTER_KERNELS
// Note, this op works on cpu only.
template <typename T, typename Tindex>
class SparseApplyProximalGradientDescentOp : public OpKernel {
 public:
  explicit SparseApplyProximalGradientDescentOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
    auto locks =
        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0});
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
                            ctx, 0, use_exclusive_lock_, true, &var));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
                errors::InvalidArgument("var must be at least 1 dimensional"));

    const Tensor& lr = ctx->input(1);
    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
    const Tensor& l1 = ctx->input(2);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(l1.shape()),
        errors::InvalidArgument("l1 regularization strength is not a scalar: ",
                                l1.shape().DebugString()));
    const Tensor& l2 = ctx->input(3);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(l2.shape()),
        errors::InvalidArgument("l2 regularization strength is not a scalar: ",
                                l2.shape().DebugString()));

    const Tensor& grad = ctx->input(4);
    const Tensor& indices = ctx->input(5);
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices must be one-dimensional"));

    int64 inner_dim = 1;
    for (int d = 1; d < var.dims(); d++) {
      OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
                  errors::InvalidArgument(strings::StrCat(
                      "var and grad must match in dimension ", d)));
      inner_dim *= grad.dim_size(d);
    }
    const Tindex N = indices.dim_size(0);
    OP_REQUIRES(
        ctx, grad.dim_size(0) == N,
        errors::InvalidArgument(
            "grad must be the same size as indices in the first dimension."));
    OP_REQUIRES(ctx, inner_dim > 0,
                errors::InvalidArgument(
                    "Inner dimension should be greater than zero."));

    if (N > 0) {
      if (inner_dim > 1) {
        const Tindex first_dim_size = var.dim_size(0);
        auto indices_vec = indices.vec<Tindex>();
        auto var_flat = var.flat_outer_dims<T>();
        auto grad_flat = grad.flat_outer_dims<T>();
        T lr_scalar = lr.scalar<T>()();
        T l1_scalar = l1.scalar<T>()();
        T l2_scalar = l2.scalar<T>()();

        // TODO(xbing): extract the common logic for the Fobos update.
        for (Tindex i = 0; i < N; i++) {
          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                      errors::InvalidArgument(
                          strings::StrCat("Index ", index, " at offset ", i,
                                          " in indices is out of range")));
          auto g = grad_flat.template chip<0>(i);
          auto v = var_flat.template chip<0>(index);
          // compute learning_rate for current step.
          auto learning_rate = v.constant(lr_scalar);
          auto prox_v = v;
          // v = w - g * learning_rate.
          prox_v -= g * learning_rate;
          if (l1_scalar > 0) {
            // compute sign(v) * max(|v| - learning_rate * l1, 0)
            v = prox_v.sign() *
                (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
                    .cwiseMax(static_cast<T>(0.0)) /
                (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
          } else {
            v = prox_v /
                (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
          }
        }
      } else {
        auto indices_vec = indices.vec<Tindex>();
        auto var_flat = var.flat<T>();
        auto grad_flat = grad.flat<T>();
        T lr_scalar = lr.scalar<T>()();
        T l1_scalar = l1.scalar<T>()();
        T l2_scalar = l2.scalar<T>()();
        const Tindex first_dim_size = var_flat.size();

        for (Tindex i = 0; i < N; i++) {
          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                      errors::InvalidArgument(
                          strings::StrCat("Index ", index, " at offset ", i,
                                          " in indices is out of range")));
          const T& g = grad_flat(i);
          auto learning_rate = lr_scalar;
          auto prox_v = var_flat(index);
          prox_v -= learning_rate * g;
          if (l1_scalar > 0) {
            var_flat(index) =
                sgn(prox_v) *
                std::max(std::abs(prox_v) - learning_rate * l1_scalar,
                         static_cast<T>(0.0)) /
                (1.0 + l2_scalar * learning_rate);
          } else {
            var_flat(index) = prox_v / (1.0 + l2_scalar * learning_rate);
          }
        }
      }
    }

    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

#define REGISTER_KERNELS(T, Tindices)                                         \
  REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalGradientDescent")          \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<T>("T")                         \
                              .TypeConstraint<Tindices>("Tindices"),          \
                          SparseApplyProximalGradientDescentOp<T, Tindices>); \
  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyProximalGradientDescent")  \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<T>("T")                         \
                              .TypeConstraint<Tindices>("Tindices"),          \
                          SparseApplyProximalGradientDescentOp<T, Tindices>);

REGISTER_KERNELS(float, int32);
REGISTER_KERNELS(float, int64);
REGISTER_KERNELS(double, int32);
REGISTER_KERNELS(double, int64);
#undef REGISTER_KERNELS
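// The kernel above (like the other sparse kernels in this file) keeps two code
// paths: when inner_dim > 1 each selected row is updated with vectorized Eigen
// expressions, and when inner_dim == 1 each row is a single element, so plain
// scalar arithmetic is used instead.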
template <typename Device, typename T>
class ApplyAdagradOp : public OpKernel {
 public:
  explicit ApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override {
    auto locks =
        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 0, use_exclusive_lock_, false, &var));
    Tensor accum;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 1, use_exclusive_lock_, false, &accum));
    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, accum.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    const Tensor& lr = ctx->input(2);
    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
    const Tensor& grad = ctx->input(3);
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(accum.shape()),
        errors::InvalidArgument("var and accum do not have the same shape",
                                var.shape().DebugString(), " ",
                                accum.shape().DebugString()));
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(grad.shape()),
        errors::InvalidArgument("var and grad do not have the same shape",
                                var.shape().DebugString(), " ",
                                grad.shape().DebugString()));

    const Device& device = ctx->template eigen_device<Device>();
    functor::ApplyAdagrad<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
                                       lr.scalar<T>(), grad.flat<T>());

    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

#define REGISTER_KERNELS(D, T)                                        \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("ApplyAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
      ApplyAdagradOp<D##Device, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdagrad")                \
                              .HostMemory("var")                      \
                              .HostMemory("accum")                    \
                              .Device(DEVICE_##D)                     \
                              .TypeConstraint<T>("T"),                \
                          ApplyAdagradOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                               \
  template <>                                                             \
  void ApplyAdagrad<GPUDevice, T>::operator()(                            \
      const GPUDevice& d, typename TTypes<T>::Flat var,                   \
      typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
      typename TTypes<T>::ConstFlat grad);                                \
  extern template struct ApplyAdagrad<GPUDevice, T>;
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#endif
#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS

template <typename Device, typename T>
class ApplyProximalAdagradOp : public OpKernel {
 public:
  explicit ApplyProximalAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override {
    auto locks =
        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 0, use_exclusive_lock_, false, &var));
    Tensor accum;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 1, use_exclusive_lock_, false, &accum));
    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, accum.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(accum.shape()),
        errors::InvalidArgument("var and accum do not have the same shape",
                                var.shape().DebugString(), " ",
                                accum.shape().DebugString()));
    const Tensor& lr = ctx->input(2);
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsScalar(lr.shape()) &&
                    lr.scalar<T>()() > static_cast<T>(0),
                errors::InvalidArgument("lr is not a positive scalar: ",
                                        lr.shape().DebugString()));
    const Tensor& l1 = ctx->input(3);
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsScalar(l1.shape()) &&
                    l1.scalar<T>()() >= static_cast<T>(0),
                errors::InvalidArgument("l1 regularization strength is not a "
                                        "non-negative scalar: ",
                                        l1.shape().DebugString()));
    const Tensor& l2 = ctx->input(4);
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsScalar(l2.shape()) &&
                    l2.scalar<T>()() >= static_cast<T>(0),
                errors::InvalidArgument("l2 regularization strength is not a "
                                        "non-negative scalar: ",
                                        l2.shape().DebugString()));
    const Tensor& grad = ctx->input(5);
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(grad.shape()),
        errors::InvalidArgument("var and grad do not have the same shape",
                                var.shape().DebugString(), " ",
                                grad.shape().DebugString()));

    const Device& device = ctx->template eigen_device<Device>();
    functor::ApplyProximalAdagrad<Device, T>()(
        device, var.flat<T>(), accum.flat<T>(), lr.scalar<T>(), l1.scalar<T>(),
        l2.scalar<T>(), grad.flat<T>());

    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

#define REGISTER_KERNELS(D, T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("ApplyProximalAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
      ApplyProximalAdagradOp<D##Device, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("ResourceApplyProximalAdagrad")                \
                              .Device(DEVICE_##D)                             \
                              .HostMemory("var")                              \
                              .HostMemory("accum")                            \
                              .TypeConstraint<T>("T"),                        \
                          ApplyProximalAdagradOp<D##Device, T>);

REGISTER_KERNELS(CPU, float);
REGISTER_KERNELS(CPU, double);
#undef REGISTER_KERNELS

namespace {

template <typename T>
inline T FtrlCompute(const T& accum, const T& linear, const T& lr, const T& l1,
                     const T& l2, const T& lr_power) {
  T quadratic;
  if (lr_power == static_cast<T>(-0.5)) {
    quadratic = Eigen::numext::sqrt(accum) / lr + static_cast<T>(2) * l2;
  } else {
    quadratic =
        Eigen::numext::pow(accum, -lr_power) / lr + static_cast<T>(2) * l2;
  }
  auto l1_reg_adjust = std::max(std::min(linear, l1), -l1);
  return (l1_reg_adjust - linear) / quadratic;
}
}  // namespace
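// FtrlCompute above is the scalar form of the FTRL closed-form solution (cf.
// the ApplyFtrl/ApplyFtrlV2 functors earlier in this file): it clips linear
// into [-l1, l1] and divides by the quadratic term, so the result is zero
// whenever |linear| <= l1. For example (illustrative only), with accum = 4,
// linear = 3, lr = 1, l1 = 1, l2 = 0 and lr_power = -0.5 it returns
// (1 - 3) / 2 = -1.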
// Note, this op works on cpu only.
template <typename T, typename Tindex>
class SparseApplyAdagradOp : public OpKernel {
 public:
  explicit SparseApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
    auto locks =
        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
                            ctx, 0, use_exclusive_lock_, true, &var));
    Tensor accum;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
                            ctx, 1, use_exclusive_lock_, true, &accum));
    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, accum.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(accum.shape()),
        errors::InvalidArgument("var and accum do not have the same shape",
                                var.shape().DebugString(), " ",
                                accum.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
                errors::InvalidArgument("var must be at least 1 dimensional"));

    const Tensor& lr = ctx->input(2);
    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
    const Tensor& grad = ctx->input(3);
    const Tensor& indices = ctx->input(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices must be one-dimensional"));

    int64 inner_dim = 1;
    for (int d = 1; d < var.dims(); d++) {
      OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
                  errors::InvalidArgument(strings::StrCat(
                      "var and grad must match in dimension ", d)));
      inner_dim *= grad.dim_size(d);
    }
    const Tindex N = indices.dim_size(0);
    OP_REQUIRES(
        ctx, grad.dim_size(0) == N,
        errors::InvalidArgument(
            "grad must be the same size as indices in the first dimension."));

    OP_REQUIRES(ctx, inner_dim > 0,
                errors::InvalidArgument(
                    "Inner dimension should be greater than zero."));

    if (N > 0) {
      if (inner_dim > 1) {
        const Tindex first_dim_size = var.dim_size(0);
        auto indices_vec = indices.vec<Tindex>();
        auto var_flat = var.flat_outer_dims<T>();
        auto accum_flat = accum.flat_outer_dims<T>();
        auto grad_flat = grad.flat_outer_dims<T>();
        T lr_scalar = lr.scalar<T>()();

        // Note(yonghui): It might be worth multi-threading square() and
        // rsqrt().
        for (Tindex i = 0; i < N; i++) {
          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                      errors::InvalidArgument(
                          strings::StrCat("Index ", index, " at offset ", i,
                                          " in indices is out of range")));
          auto a = accum_flat.template chip<0>(index);
          auto g = grad_flat.template chip<0>(i);
          auto v = var_flat.template chip<0>(index);
          a += g.square();
          v -= g.constant(lr_scalar) * g * a.rsqrt();
        }
      } else {
        auto indices_vec = indices.vec<Tindex>();
        auto var_flat = var.flat<T>();
        auto accum_flat = accum.flat<T>();
        auto grad_flat = grad.flat<T>();
        T lr_scalar = lr.scalar<T>()();
        const Tindex first_dim_size = accum_flat.size();

        for (Tindex i = 0; i < N; i++) {
          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                      errors::InvalidArgument(
                          strings::StrCat("Index ", index, " at offset ", i,
                                          " in indices is out of range")));
          T& a = accum_flat(index);
          const T& g = grad_flat(i);
          a += g * g;
          var_flat(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
        }
      }
    }

    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

#define REGISTER_KERNELS(T, Tindices)                                \
  REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagrad")                 \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<T>("T")                \
                              .TypeConstraint<Tindices>("Tindices"), \
                          SparseApplyAdagradOp<T, Tindices>);        \
  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagrad")         \
                              .Device(DEVICE_CPU)                    \
                              .TypeConstraint<T>("T")                \
                              .TypeConstraint<Tindices>("Tindices"), \
                          SparseApplyAdagradOp<T, Tindices>);
#define REGISTER_CPU_KERNELS(T) \
  REGISTER_KERNELS(T, int32);   \
  REGISTER_KERNELS(T, int64);

TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);

#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS
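// Per selected row, the sparse Adagrad update above is simply (illustrative
// sketch only):
//   accum[row] += grad[i]^2
//   var[row]   -= lr * grad[i] / sqrt(accum[row])
// which matches the dense ApplyAdagrad functor applied one row at a time.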
// Note, this op works on cpu only.
template <typename T, typename Tindex>
class SparseApplyProximalAdagradOp : public OpKernel {
 public:
  explicit SparseApplyProximalAdagradOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
    auto locks =
        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
                            ctx, 0, use_exclusive_lock_, true, &var));
    Tensor accum;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
                            ctx, 1, use_exclusive_lock_, true, &accum));
    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, accum.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(accum.shape()),
        errors::InvalidArgument("var and accum do not have the same shape",
                                var.shape().DebugString(), " ",
                                accum.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
                errors::InvalidArgument("var must be at least 1 dimensional"));

    const Tensor& lr = ctx->input(2);
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsScalar(lr.shape()) &&
                    lr.scalar<T>()() > static_cast<T>(0),
                errors::InvalidArgument("lr is not a positive scalar: ",
                                        lr.shape().DebugString()));
    const Tensor& l1 = ctx->input(3);
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsScalar(l1.shape()) &&
                    l1.scalar<T>()() >= static_cast<T>(0),
                errors::InvalidArgument("l1 regularization strength is not a "
                                        "non-negative scalar: ",
                                        l1.shape().DebugString()));
    const Tensor& l2 = ctx->input(4);
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsScalar(l2.shape()) &&
                    l2.scalar<T>()() >= static_cast<T>(0),
                errors::InvalidArgument("l2 regularization strength is not a "
                                        "non-negative scalar: ",
                                        l2.shape().DebugString()));

    const Tensor& grad = ctx->input(5);
    const Tensor& indices = ctx->input(6);
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices must be one-dimensional"));

    int64 inner_dim = 1;
    for (int d = 1; d < var.dims(); d++) {
      OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
                  errors::InvalidArgument(strings::StrCat(
                      "var and grad must match in dimension ", d)));
      inner_dim *= grad.dim_size(d);
    }
    const Tindex N = indices.dim_size(0);
    OP_REQUIRES(
        ctx, grad.dim_size(0) == N,
        errors::InvalidArgument(
            "grad must be the same size as indices in the first dimension."));

    OP_REQUIRES(ctx, inner_dim > 0,
                errors::InvalidArgument(
                    "Inner dimension should be greater than zero."));

    if (N > 0) {
      if (inner_dim > 1) {
        const Tindex first_dim_size = var.dim_size(0);
        auto indices_vec = indices.vec<Tindex>();
        auto var_flat = var.flat_outer_dims<T>();
        auto accum_flat = accum.flat_outer_dims<T>();
        auto grad_flat = grad.flat_outer_dims<T>();
        T lr_scalar = lr.scalar<T>()();
        T l1_scalar = l1.scalar<T>()();
        T l2_scalar = l2.scalar<T>()();

        for (Tindex i = 0; i < N; i++) {
          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                      errors::InvalidArgument(
                          strings::StrCat("Index ", index, " at offset ", i,
                                          " in indices is out of range")));
          auto a = accum_flat.template chip<0>(index);
          auto g = grad_flat.template chip<0>(i);
          auto v = var_flat.template chip<0>(index);
          a += g.square();
          // compute learning_rate for current step.
          auto learning_rate = a.constant(lr_scalar) * a.rsqrt();
          auto prox_v = v;
          // v = w - g * learning_rate.
          prox_v -= g * learning_rate;
          if (l1_scalar > 0) {
            // compute sign(v) * max(|v| - learning_rate * l1, 0)
            v = prox_v.sign() *
                (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
                    .cwiseMax(static_cast<T>(0.0)) /
                (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
          } else {
            v = prox_v /
                (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
          }
        }
      } else {
        auto indices_vec = indices.vec<Tindex>();
        auto var_flat = var.flat<T>();
        auto accum_flat = accum.flat<T>();
        auto grad_flat = grad.flat<T>();
        T lr_scalar = lr.scalar<T>()();
        T l1_scalar = l1.scalar<T>()();
        T l2_scalar = l2.scalar<T>()();
        const Tindex first_dim_size = accum_flat.size();

        for (Tindex i = 0; i < N; i++) {
          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                      errors::InvalidArgument(
                          strings::StrCat("Index ", index, " at offset ", i,
                                          " in indices is out of range")));
          T& a = accum_flat(index);
          const T& g = grad_flat(i);
          a += g * g;
          auto learning_rate = lr_scalar / std::sqrt(a);
          auto prox_v = var_flat(index);
          prox_v -= learning_rate * g;
          if (l1_scalar > 0) {
            var_flat(index) =
                sgn(prox_v) *
                std::max(std::abs(prox_v) - learning_rate * l1_scalar,
                         static_cast<T>(0.0)) /
                (1.0 + l2_scalar * learning_rate);
          } else {
            var_flat(index) = prox_v / (1.0 + l2_scalar * learning_rate);
          }
        }
      }
    }

    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

#define REGISTER_KERNELS(T, Tindices)                                 \
  REGISTER_KERNEL_BUILDER(Name("SparseApplyProximalAdagrad")          \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<T>("T")                 \
                              .TypeConstraint<Tindices>("Tindices"),  \
                          SparseApplyProximalAdagradOp<T, Tindices>); \
  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyProximalAdagrad")  \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<T>("T")                 \
                              .TypeConstraint<Tindices>("Tindices"),  \
                          SparseApplyProximalAdagradOp<T, Tindices>);

REGISTER_KERNELS(float, int32);
REGISTER_KERNELS(float, int64);
REGISTER_KERNELS(double, int32);
REGISTER_KERNELS(double, int64);
#undef REGISTER_KERNELS

template <typename Device, typename T>
class ApplyAdagradDAOp : public OpKernel {
 public:
  explicit ApplyAdagradDAOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override {
    auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
                                                      {0, 1, 2});
    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 0, use_exclusive_lock_, false, &var));
    Tensor gradient_accum;
    OP_REQUIRES_OK(
        ctx, GetInputTensorFromVariable<Device, T>(ctx, 1, use_exclusive_lock_,
                                                   false, &gradient_accum));
T>(ctx, 1, use_exclusive_lock_, 1559 false, &gradient_accum)); 1560 Tensor gradient_squared_accum; 1561 OP_REQUIRES_OK( 1562 ctx, GetInputTensorFromVariable<Device, T>( 1563 ctx, 2, use_exclusive_lock_, false, &gradient_squared_accum)); 1564 OP_REQUIRES( 1565 ctx, var.IsInitialized(), 1566 errors::FailedPrecondition( 1567 "Attempting to use uninitialized variables: ", requested_input(0))); 1568 OP_REQUIRES( 1569 ctx, gradient_accum.IsInitialized(), 1570 errors::FailedPrecondition( 1571 "Attempting to use uninitialized variables: ", requested_input(1))); 1572 OP_REQUIRES( 1573 ctx, gradient_squared_accum.IsInitialized(), 1574 errors::FailedPrecondition( 1575 "Attempting to use uninitialized variables: ", requested_input(2))); 1576 OP_REQUIRES( 1577 ctx, var.shape().IsSameSize(gradient_accum.shape()), 1578 errors::InvalidArgument("var and accum do not have the same shape", 1579 var.shape().DebugString(), " ", 1580 gradient_accum.shape().DebugString())); 1581 OP_REQUIRES( 1582 ctx, var.shape().IsSameSize(gradient_squared_accum.shape()), 1583 errors::InvalidArgument("var and accum do not have the same shape", 1584 var.shape().DebugString(), " ", 1585 gradient_squared_accum.shape().DebugString())); 1586 1587 const Tensor& grad = ctx->input(3); 1588 OP_REQUIRES( 1589 ctx, var.shape().IsSameSize(grad.shape()), 1590 errors::InvalidArgument("var and grad do not have the same shape", 1591 var.shape().DebugString(), " ", 1592 grad.shape().DebugString())); 1593 1594 const Tensor& lr = ctx->input(4); 1595 OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()), 1596 errors::InvalidArgument("lr is not a scalar: ", 1597 lr.shape().DebugString())); 1598 const Tensor& l1 = ctx->input(5); 1599 OP_REQUIRES( 1600 ctx, TensorShapeUtils::IsScalar(l1.shape()), 1601 errors::InvalidArgument("l1 regularization strength is not a scalar: ", 1602 l1.shape().DebugString())); 1603 const Tensor& l2 = ctx->input(6); 1604 OP_REQUIRES( 1605 ctx, TensorShapeUtils::IsScalar(l2.shape()), 1606 errors::InvalidArgument("l2 regularization strength is not a scalar: ", 1607 l2.shape().DebugString())); 1608 const Tensor& global_step = ctx->input(7); 1609 OP_REQUIRES(ctx, IsLegacyScalar(global_step.shape()), 1610 errors::InvalidArgument("global_step is not a scalar: ", 1611 global_step.shape().DebugString())); 1612 1613 const Device& device = ctx->template eigen_device<Device>(); 1614 functor::ApplyAdagradDA<Device, T>()( 1615 device, var.flat<T>(), gradient_accum.flat<T>(), 1616 gradient_squared_accum.flat<T>(), lr.scalar<T>(), 1617 global_step.scalar<int64>()(), l1.scalar<T>(), l2.scalar<T>(), 1618 grad.flat<T>()); 1619 1620 MaybeForwardRefInputToRefOutput(ctx, 0, 0); 1621 } 1622 1623 private: 1624 bool use_exclusive_lock_; 1625 }; 1626 1627 #define REGISTER_KERNELS(D, T) \ 1628 REGISTER_KERNEL_BUILDER( \ 1629 Name("ApplyAdagradDA").Device(DEVICE_##D).TypeConstraint<T>("T"), \ 1630 ApplyAdagradDAOp<D##Device, T>); \ 1631 REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdagradDA") \ 1632 .Device(DEVICE_##D) \ 1633 .HostMemory("var") \ 1634 .HostMemory("gradient_accumulator") \ 1635 .HostMemory("gradient_squared_accumulator") \ 1636 .TypeConstraint<T>("T"), \ 1637 ApplyAdagradDAOp<D##Device, T>); 1638 1639 REGISTER_KERNELS(CPU, float); 1640 REGISTER_KERNELS(CPU, double); 1641 #undef REGISTER_KERNELS 1642 1643 // Note, this op works on cpu only. 
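// SparseApplyAdagradDAOp applies the AdagradDA update above only to the rows
// of var selected by `indices`: each selected row's gradient accumulator and
// gradient-squared accumulator are updated with the corresponding slice of
// grad, with the indices bounds-checked against the first dimension of var.
// Two code paths are used below: an Eigen "chip" path when the trailing
// dimensions carry data (inner_dim > 1), and a plain scalar loop when var is
// effectively one-dimensional.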
1644 template <typename T, typename Tindex> 1645 class SparseApplyAdagradDAOp : public OpKernel { 1646 public: 1647 explicit SparseApplyAdagradDAOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 1648 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); 1649 } 1650 1651 void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { 1652 auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, 1653 {0, 1, 2}); 1654 Tensor var; 1655 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( 1656 ctx, 0, use_exclusive_lock_, true, &var)); 1657 Tensor gradient_accum; 1658 OP_REQUIRES_OK(ctx, 1659 GetInputTensorFromVariable<CPUDevice, T>( 1660 ctx, 1, use_exclusive_lock_, true, &gradient_accum)); 1661 Tensor gradient_squared_accum; 1662 OP_REQUIRES_OK( 1663 ctx, GetInputTensorFromVariable<CPUDevice, T>( 1664 ctx, 2, use_exclusive_lock_, true, &gradient_squared_accum)); 1665 OP_REQUIRES( 1666 ctx, var.IsInitialized(), 1667 errors::FailedPrecondition( 1668 "Attempting to use uninitialized variables: ", requested_input(0))); 1669 OP_REQUIRES( 1670 ctx, gradient_accum.IsInitialized(), 1671 errors::FailedPrecondition( 1672 "Attempting to use uninitialized variables: ", requested_input(1))); 1673 OP_REQUIRES( 1674 ctx, gradient_squared_accum.IsInitialized(), 1675 errors::FailedPrecondition( 1676 "Attempting to use uninitialized variables: ", requested_input(2))); 1677 OP_REQUIRES( 1678 ctx, var.shape().IsSameSize(gradient_accum.shape()), 1679 errors::InvalidArgument("var and accum do not have the same shape", 1680 var.shape().DebugString(), " ", 1681 gradient_accum.shape().DebugString())); 1682 OP_REQUIRES( 1683 ctx, var.shape().IsSameSize(gradient_squared_accum.shape()), 1684 errors::InvalidArgument("var and accum do not have the same shape", 1685 var.shape().DebugString(), " ", 1686 gradient_squared_accum.shape().DebugString())); 1687 1688 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), 1689 errors::InvalidArgument("var must be at least 1 dimensional")); 1690 1691 const Tensor& grad = ctx->input(3); 1692 const Tensor& indices = ctx->input(4); 1693 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), 1694 errors::InvalidArgument("indices must be one-dimensional")); 1695 1696 const Tensor& lr = ctx->input(5); 1697 OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()), 1698 errors::InvalidArgument("lr is not a scalar: ", 1699 lr.shape().DebugString())); 1700 1701 const Tensor& l1 = ctx->input(6); 1702 OP_REQUIRES( 1703 ctx, TensorShapeUtils::IsScalar(l1.shape()), 1704 errors::InvalidArgument("l1 regularization strength is not a scalar: ", 1705 l1.shape().DebugString())); 1706 1707 const Tensor& l2 = ctx->input(7); 1708 OP_REQUIRES( 1709 ctx, TensorShapeUtils::IsScalar(l2.shape()), 1710 errors::InvalidArgument("l2 regularization strength is not a scalar: ", 1711 l2.shape().DebugString())); 1712 1713 const Tensor& global_step = ctx->input(8); 1714 OP_REQUIRES(ctx, IsLegacyScalar(global_step.shape()), 1715 errors::InvalidArgument("global_step is not a scalar: ", 1716 global_step.shape().DebugString())); 1717 1718 int64 inner_dim = 1; 1719 for (int d = 1; d < var.dims(); d++) { 1720 OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), 1721 errors::InvalidArgument(strings::StrCat( 1722 "var and grad must match in dimension ", d))); 1723 inner_dim *= grad.dim_size(d); 1724 } 1725 const Tindex N = indices.dim_size(0); 1726 OP_REQUIRES( 1727 ctx, grad.dim_size(0) == N, 1728 errors::InvalidArgument( 1729 "grad must be the same size as indices in 
the first dimension."));

    OP_REQUIRES(ctx, inner_dim > 0,
                errors::InvalidArgument(
                    "Inner dimension should be greater than zero."));

    // AdagradDA update:
    // Let g be the gradient accumulator, gg the gradient squared accumulator,
    // T the global step, lr the learning rate, and k the initial gradient
    // squared accumulator value.
    // w = \dfrac{sign(-g) * lr * (|g| - l1 * T)_{+}}{l2 * T * lr + \sqrt{k + gg}}
    if (N > 0) {
      if (inner_dim > 1) {
        const Tindex first_dim_size = var.dim_size(0);
        auto indices_vec = indices.vec<Tindex>();
        auto var_flat = var.flat_outer_dims<T>();
        auto gradient_accum_flat = gradient_accum.flat_outer_dims<T>();
        auto gradient_squared_accum_flat =
            gradient_squared_accum.flat_outer_dims<T>();
        auto grad_flat = grad.flat_outer_dims<T>();
        T lr_scalar = lr.scalar<T>()();
        T global_step_scalar = global_step.scalar<int64>()();
        T l1_scalar = l1.scalar<T>()();
        T l2_scalar = l2.scalar<T>()();
        const double gs_lr = global_step_scalar * lr_scalar;

        for (Tindex i = 0; i < N; i++) {
          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                      errors::InvalidArgument(
                          strings::StrCat("Index ", index, " at offset ", i,
                                          " in indices is out of range")));
          auto ga = gradient_accum_flat.template chip<0>(index);
          auto da = gradient_squared_accum_flat.template chip<0>(index);
          auto g = grad_flat.template chip<0>(i);
          auto v = var_flat.template chip<0>(index);
          ga += g;
          da += g.square();
          if (l1_scalar > 0) {
            v = ga.constant(-1.0) * ga.sign() *
                ((ga.abs() / ga.constant(global_step_scalar)) -
                 ga.constant(l1_scalar))
                    .cwiseMax(static_cast<T>(0.0)) /
                (v.constant(l2_scalar) + da.sqrt() / v.constant(gs_lr));
          } else {
            v = ga.constant(-1.0) * (ga / ga.constant(global_step_scalar)) /
                (v.constant(l2_scalar) + da.sqrt() / v.constant(gs_lr));
          }
        }
      } else {
        auto indices_vec = indices.vec<Tindex>();
        auto var_flat = var.flat<T>();
        auto gradient_accum_flat = gradient_accum.flat<T>();
        auto gradient_squared_accum_flat = gradient_squared_accum.flat<T>();
        auto grad_flat = grad.flat<T>();
        const double lr_scalar = lr.scalar<T>()();
        const int64 global_step_scalar = global_step.scalar<int64>()();
        const double l1_scalar = l1.scalar<T>()();
        const double l2_scalar = l2.scalar<T>()();
        const Tindex first_dim_size = var_flat.size();
        const double gs_l1 = global_step_scalar * l1_scalar;
        const double gs_l2_lr = global_step_scalar * l2_scalar * lr_scalar;

        for (Tindex i = 0; i < N; i++) {
          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                      errors::InvalidArgument(
                          strings::StrCat("Index ", index, " at offset ", i,
                                          " in indices is out of range")));
          T& ga = gradient_accum_flat(index);
          T& da = gradient_squared_accum_flat(index);
          const double g = grad_flat(i);
          ga += g;
          da += g * g;
          if (l1_scalar > 0) {
            var_flat(index) = sgn(-ga) * lr_scalar *
                              std::max((std::abs(ga) - gs_l1), 0.0) /
                              (gs_l2_lr + std::sqrt(da));
          } else {
            var_flat(index) = (-ga * lr_scalar) / (gs_l2_lr + std::sqrt(da));
          }
        }
      }
    }

    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

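// Both the ref-variable op ("SparseApplyAdagradDA") and the resource-variable
// op ("ResourceSparseApplyAdagradDA") are backed by the same kernel class;
// only the registrations below differ.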
#define REGISTER_KERNELS(T, Tindices) \ 1822 REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagradDA") \ 1823 .Device(DEVICE_CPU) \ 1824 .TypeConstraint<T>("T") \ 1825 .TypeConstraint<Tindices>("Tindices"), \ 1826 SparseApplyAdagradDAOp<T, Tindices>); \ 1827 REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyAdagradDA") \ 1828 .Device(DEVICE_CPU) \ 1829 .HostMemory("var") \ 1830 .HostMemory("gradient_accumulator") \ 1831 .HostMemory("gradient_squared_accumulator") \ 1832 .TypeConstraint<T>("T") \ 1833 .TypeConstraint<Tindices>("Tindices"), \ 1834 SparseApplyAdagradDAOp<T, Tindices>); 1835 1836 REGISTER_KERNELS(float, int32); 1837 REGISTER_KERNELS(float, int64); 1838 REGISTER_KERNELS(double, int32); 1839 REGISTER_KERNELS(double, int64); 1840 #undef REGISTER_KERNELS 1841 1842 template <typename Device, typename T, bool has_l2_shrinkage> 1843 class ApplyFtrlOp : public OpKernel { 1844 public: 1845 explicit ApplyFtrlOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 1846 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); 1847 } 1848 1849 void Compute(OpKernelContext* ctx) override { 1850 auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, 1851 {0, 1, 2}); 1852 1853 Tensor var; 1854 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 1855 ctx, 0, use_exclusive_lock_, false, &var)); 1856 Tensor accum; 1857 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 1858 ctx, 1, use_exclusive_lock_, false, &accum)); 1859 Tensor linear; 1860 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 1861 ctx, 2, use_exclusive_lock_, false, &linear)); 1862 OP_REQUIRES( 1863 ctx, var.IsInitialized(), 1864 errors::FailedPrecondition( 1865 "Attempting to use uninitialized variables: ", requested_input(0))); 1866 OP_REQUIRES( 1867 ctx, accum.IsInitialized(), 1868 errors::FailedPrecondition( 1869 "Attempting to use uninitialized variables: ", requested_input(1))); 1870 OP_REQUIRES( 1871 ctx, linear.IsInitialized(), 1872 errors::FailedPrecondition( 1873 "Attempting to use uninitialized variables: ", requested_input(2))); 1874 1875 const Tensor& grad = ctx->input(3); 1876 OP_REQUIRES( 1877 ctx, var.shape().IsSameSize(accum.shape()), 1878 errors::InvalidArgument("var and accum do not have the same shape", 1879 var.shape().DebugString(), " ", 1880 accum.shape().DebugString())); 1881 OP_REQUIRES( 1882 ctx, var.shape().IsSameSize(linear.shape()), 1883 errors::InvalidArgument("var and linear do not have the same shape", 1884 var.shape().DebugString(), " ", 1885 linear.shape().DebugString())); 1886 OP_REQUIRES( 1887 ctx, var.shape().IsSameSize(grad.shape()), 1888 errors::InvalidArgument("var and grad do not have the same shape", 1889 var.shape().DebugString(), " ", 1890 grad.shape().DebugString())); 1891 1892 const Tensor& lr = ctx->input(4); 1893 OP_REQUIRES(ctx, 1894 TensorShapeUtils::IsScalar(lr.shape()) && 1895 lr.scalar<T>()() > static_cast<T>(0), 1896 errors::InvalidArgument("lr is not a positive scalar: ", 1897 lr.shape().DebugString())); 1898 const Tensor& l1 = ctx->input(5); 1899 OP_REQUIRES(ctx, 1900 TensorShapeUtils::IsScalar(l1.shape()) && 1901 l1.scalar<T>()() >= static_cast<T>(0), 1902 errors::InvalidArgument("l1 regularization strength is not a " 1903 "non-negative scalar: ", 1904 l1.shape().DebugString())); 1905 const Tensor& l2 = ctx->input(6); 1906 OP_REQUIRES(ctx, 1907 TensorShapeUtils::IsScalar(l2.shape()) && 1908 l2.scalar<T>()() >= static_cast<T>(0), 1909 errors::InvalidArgument("l2 regularization strength is not a " 1910 "non-negative scalar: ", 
1911 l2.shape().DebugString())); 1912 const int lr_power_index = has_l2_shrinkage ? 8 : 7; 1913 const Tensor& lr_power = ctx->input(lr_power_index); 1914 OP_REQUIRES(ctx, 1915 TensorShapeUtils::IsScalar(lr_power.shape()) && 1916 lr_power.scalar<T>()() <= static_cast<T>(0), 1917 errors::InvalidArgument("lr_power is not a" 1918 " non-positive scalar: ", 1919 lr_power.shape().DebugString())); 1920 1921 const Device& device = ctx->template eigen_device<Device>(); 1922 if (has_l2_shrinkage) { 1923 const Tensor& l2_shrinkage = ctx->input(7); 1924 OP_REQUIRES( 1925 ctx, 1926 TensorShapeUtils::IsScalar(l2_shrinkage.shape()) && 1927 l2_shrinkage.scalar<T>()() >= static_cast<T>(0), 1928 errors::InvalidArgument("l2 shrinkage regularization strength " 1929 "is not a non-negative scalar: ", 1930 l2_shrinkage.shape().DebugString())); 1931 functor::ApplyFtrlV2<Device, T>()( 1932 device, var.flat<T>(), accum.flat<T>(), linear.flat<T>(), 1933 grad.flat<T>(), lr.scalar<T>(), l1.scalar<T>(), l2.scalar<T>(), 1934 l2_shrinkage.scalar<T>(), lr_power.scalar<T>()); 1935 } else { 1936 functor::ApplyFtrl<Device, T>()(device, var.flat<T>(), accum.flat<T>(), 1937 linear.flat<T>(), grad.flat<T>(), 1938 lr.scalar<T>(), l1.scalar<T>(), 1939 l2.scalar<T>(), lr_power.scalar<T>()); 1940 } 1941 1942 MaybeForwardRefInputToRefOutput(ctx, 0, 0); 1943 } 1944 1945 private: 1946 bool use_exclusive_lock_; 1947 }; 1948 1949 #define REGISTER_KERNELS(D, T) \ 1950 REGISTER_KERNEL_BUILDER( \ 1951 Name("ApplyFtrl").Device(DEVICE_##D).TypeConstraint<T>("T"), \ 1952 ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/false>); \ 1953 REGISTER_KERNEL_BUILDER( \ 1954 Name("ResourceApplyFtrl") \ 1955 .HostMemory("var") \ 1956 .HostMemory("accum") \ 1957 .HostMemory("linear") \ 1958 .Device(DEVICE_##D) \ 1959 .TypeConstraint<T>("T"), \ 1960 ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/false>); 1961 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); 1962 1963 TF_CALL_half(REGISTER_CPU_KERNELS); 1964 TF_CALL_float(REGISTER_CPU_KERNELS); 1965 TF_CALL_double(REGISTER_CPU_KERNELS); 1966 1967 #undef REGISTER_CPU_KERNELS 1968 #undef REGISTER_KERNELS 1969 1970 #define REGISTER_KERNELS(D, T) \ 1971 REGISTER_KERNEL_BUILDER( \ 1972 Name("ApplyFtrlV2").Device(DEVICE_##D).TypeConstraint<T>("T"), \ 1973 ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/true>); \ 1974 REGISTER_KERNEL_BUILDER( \ 1975 Name("ResourceApplyFtrlV2") \ 1976 .HostMemory("var") \ 1977 .HostMemory("accum") \ 1978 .HostMemory("linear") \ 1979 .Device(DEVICE_##D) \ 1980 .TypeConstraint<T>("T"), \ 1981 ApplyFtrlOp<D##Device, T, /*has_l2_shrinkage=*/true>); 1982 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); 1983 1984 TF_CALL_half(REGISTER_CPU_KERNELS); 1985 TF_CALL_float(REGISTER_CPU_KERNELS); 1986 TF_CALL_double(REGISTER_CPU_KERNELS); 1987 1988 #undef REGISTER_CPU_KERNELS 1989 #undef REGISTER_KERNELS 1990 1991 // Note, this op works on cpu only. 
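// SparseApplyFtrlOp implements the FTRL-Proximal update for the rows of var
// selected by `indices`. The has_l2_shrinkage template flag selects the "V2"
// variants, which read one extra scalar input (l2_shrinkage) and add
// 2 * l2_shrinkage * var to the gradient before the update. lr must be
// positive, l1/l2/l2_shrinkage non-negative, and lr_power non-positive;
// indices are bounds-checked against the first dimension of var.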
1992 template <typename Device, typename T, typename Tindex, bool has_l2_shrinkage> 1993 class SparseApplyFtrlOp : public OpKernel { 1994 public: 1995 explicit SparseApplyFtrlOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 1996 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); 1997 } 1998 1999 void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { 2000 auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, 2001 {0, 1, 2}); 2002 Tensor var; 2003 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2004 ctx, 0, use_exclusive_lock_, true, &var)); 2005 Tensor accum; 2006 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2007 ctx, 1, use_exclusive_lock_, true, &accum)); 2008 Tensor linear; 2009 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2010 ctx, 2, use_exclusive_lock_, true, &linear)); 2011 OP_REQUIRES( 2012 ctx, var.IsInitialized(), 2013 errors::FailedPrecondition( 2014 "Attempting to use uninitialized variables: ", requested_input(0))); 2015 OP_REQUIRES( 2016 ctx, accum.IsInitialized(), 2017 errors::FailedPrecondition( 2018 "Attempting to use uninitialized variables: ", requested_input(1))); 2019 OP_REQUIRES( 2020 ctx, linear.IsInitialized(), 2021 errors::FailedPrecondition( 2022 "Attempting to use uninitialized variables: ", requested_input(2))); 2023 OP_REQUIRES( 2024 ctx, var.shape().IsSameSize(accum.shape()), 2025 errors::InvalidArgument("var and accum do not have the same shape", 2026 var.shape().DebugString(), " ", 2027 accum.shape().DebugString())); 2028 OP_REQUIRES( 2029 ctx, var.shape().IsSameSize(linear.shape()), 2030 errors::InvalidArgument("var and linear do not have the same shape", 2031 var.shape().DebugString(), " ", 2032 linear.shape().DebugString())); 2033 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), 2034 errors::InvalidArgument("var must be at least 1 dimensional")); 2035 2036 const Tensor& grad = ctx->input(3); 2037 const Tensor& indices = ctx->input(4); 2038 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), 2039 errors::InvalidArgument("indices must be one-dimensional")); 2040 2041 const Tensor& lr = ctx->input(5); 2042 OP_REQUIRES(ctx, 2043 TensorShapeUtils::IsScalar(lr.shape()) && 2044 lr.scalar<T>()() > static_cast<T>(0), 2045 errors::InvalidArgument("lr is not a positive scalar: ", 2046 lr.shape().DebugString())); 2047 2048 const Tensor& l1 = ctx->input(6); 2049 OP_REQUIRES(ctx, 2050 TensorShapeUtils::IsScalar(l1.shape()) && 2051 l1.scalar<T>()() >= static_cast<T>(0), 2052 errors::InvalidArgument("l1 regularization strength is not a " 2053 "non-negative scalar: ", 2054 l1.shape().DebugString())); 2055 const Tensor& l2 = ctx->input(7); 2056 OP_REQUIRES(ctx, 2057 TensorShapeUtils::IsScalar(l2.shape()) && 2058 l2.scalar<T>()() >= static_cast<T>(0), 2059 errors::InvalidArgument("l2 regularization strength is not a " 2060 "non-negative scalar: ", 2061 l2.shape().DebugString())); 2062 const int lr_power_index = has_l2_shrinkage ? 
9 : 8; 2063 const Tensor& lr_power = ctx->input(lr_power_index); 2064 OP_REQUIRES(ctx, 2065 TensorShapeUtils::IsScalar(lr_power.shape()) && 2066 lr_power.scalar<T>()() <= static_cast<T>(0), 2067 errors::InvalidArgument("lr_power is not a " 2068 "non-positive scalar: ", 2069 lr_power.shape().DebugString())); 2070 int64 inner_dim = 1; 2071 for (int d = 1; d < var.dims(); d++) { 2072 OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), 2073 errors::InvalidArgument(strings::StrCat( 2074 "var and grad must match in dimension ", d))); 2075 inner_dim *= grad.dim_size(d); 2076 } 2077 const Tindex N = indices.dim_size(0); 2078 OP_REQUIRES( 2079 ctx, grad.dim_size(0) == N, 2080 errors::InvalidArgument( 2081 "grad must be the same size as indices in the first dimension.")); 2082 2083 OP_REQUIRES(ctx, inner_dim > 0, 2084 errors::InvalidArgument( 2085 "Inner dimension should be greater than zero.")); 2086 2087 const Tensor* l2_shrinkage; 2088 if (has_l2_shrinkage) { 2089 l2_shrinkage = &ctx->input(8); 2090 OP_REQUIRES( 2091 ctx, 2092 TensorShapeUtils::IsScalar(l2_shrinkage->shape()) && 2093 l2_shrinkage->scalar<T>()() >= static_cast<T>(0), 2094 errors::InvalidArgument("l2 shrinkage regularization strength " 2095 "is not a non-negative scalar: ", 2096 l2_shrinkage->shape().DebugString())); 2097 } 2098 2099 if (N > 0) { 2100 if (inner_dim > 1) { 2101 const Tindex first_dim_size = var.dim_size(0); 2102 auto indices_vec = indices.vec<Tindex>(); 2103 auto var_flat = var.flat_outer_dims<T>(); 2104 auto accum_flat = accum.flat_outer_dims<T>(); 2105 auto linear_flat = linear.flat_outer_dims<T>(); 2106 auto grad_flat = grad.flat_outer_dims<T>(); 2107 T lr_scalar = lr.scalar<T>()(); 2108 T l1_scalar = l1.scalar<T>()(); 2109 T l2_scalar = l2.scalar<T>()(); 2110 T l2_shrinkage_scalar; 2111 if (has_l2_shrinkage) { 2112 l2_shrinkage_scalar = l2_shrinkage->scalar<T>()(); 2113 } 2114 T lr_power_scalar = lr_power.scalar<T>()(); 2115 2116 for (Tindex i = 0; i < N; i++) { 2117 const Tindex index = internal::SubtleMustCopy(indices_vec(i)); 2118 OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size), 2119 errors::InvalidArgument( 2120 strings::StrCat("Index ", index, " at offset ", i, 2121 " in indices is out of range"))); 2122 auto accum = accum_flat.template chip<0>(index); 2123 auto linear = linear_flat.template chip<0>(index); 2124 auto grad = grad_flat.template chip<0>(i); 2125 auto var = var_flat.template chip<0>(index); 2126 2127 // Use a macro to implement the computation here due to the templating of the 2128 // eigen tensor library. 
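// In scalar form, COMPUTE_FTRL(g) below performs, for each selected row:
//   new_accum = accum + g^2
//   linear   += g - (new_accum^{-lr_power} - accum^{-lr_power}) / lr * var
//   var       = (clip(linear, -l1, l1) - linear) /
//               (new_accum^{-lr_power} / lr + 2 * l2)
//   accum     = new_accum
// with the lr_power == -0.5 case specialized to use sqrt() instead of pow().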
#define COMPUTE_FTRL(grad_to_use)                                             \
  auto new_accum = accum + grad_to_use.square();                              \
  if (lr_power_scalar == static_cast<T>(-0.5)) {                              \
    linear +=                                                                 \
        grad_to_use - (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var;    \
  } else {                                                                    \
    linear += grad_to_use - (new_accum.pow(-lr_power_scalar) -                \
                             accum.pow(-lr_power_scalar)) /                   \
                                lr_scalar * var;                              \
  }                                                                           \
  auto l1_reg_adjust = linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar);       \
  auto x = l1_reg_adjust - linear;                                            \
  if (lr_power_scalar == static_cast<T>(-0.5)) {                              \
    auto y = new_accum.sqrt() / new_accum.constant(lr_scalar) +               \
             linear.constant(static_cast<T>(2) * l2_scalar);                  \
    var = x / y;                                                              \
  } else {                                                                    \
    auto y = new_accum.pow(-lr_power_scalar) / new_accum.constant(lr_scalar) + \
             linear.constant(static_cast<T>(2) * l2_scalar);                  \
    var = x / y;                                                              \
  }                                                                           \
  accum += grad_to_use.square();

          if (has_l2_shrinkage) {
            auto grad_with_shrinkage =
                grad + static_cast<T>(2) * l2_shrinkage_scalar * var;
            COMPUTE_FTRL(grad_with_shrinkage);
          } else {
            COMPUTE_FTRL(grad);
          }
        }
#undef COMPUTE_FTRL
      } else {
        T lr_scalar = lr.scalar<T>()();
        T l1_scalar = l1.scalar<T>()();
        T l2_scalar = l2.scalar<T>()();
        T lr_power_scalar = lr_power.scalar<T>()();
        T l2_shrinkage_scalar;
        if (has_l2_shrinkage) {
          l2_shrinkage_scalar = l2_shrinkage->scalar<T>()();
        }

        auto indices_vec = indices.vec<Tindex>();
        auto var_flat = var.flat<T>();
        auto accum_flat = accum.flat<T>();
        auto linear_flat = linear.flat<T>();
        auto grad_flat = grad.flat<T>();
        const Tindex first_dim_size = accum_flat.size();

        for (Tindex i = 0; i < N; i++) {
          const Tindex index = internal::SubtleMustCopy(indices_vec(i));
          OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size),
                      errors::InvalidArgument(
                          strings::StrCat("Index ", index, " at offset ", i,
                                          " in indices is out of range")));
          T& a = accum_flat(index);
          T& l = linear_flat(index);
          T& v = var_flat(index);
          T g;
          if (has_l2_shrinkage) {
            // Shrink toward the selected row of var (the row being updated).
            g = grad_flat(i) +
                (static_cast<T>(2) * l2_shrinkage_scalar * var_flat(index));
          } else {
            g = grad_flat(i);
          }

          T updated_a = a + g * g;
          using Eigen::numext::pow;
          T sigma = pow(updated_a, -lr_power_scalar) - pow(a, -lr_power_scalar);
          sigma /= lr_scalar;
          T updated_l = l + g - sigma * v;
          v = FtrlCompute(updated_a, updated_l, lr_scalar, l1_scalar, l2_scalar,
                          lr_power_scalar);
          a = updated_a;
          l = updated_l;
        }
      }
    }

    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

#define REGISTER_KERNELS(T, Tindices)                                         \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("SparseApplyFtrl")                                                 \
          .Device(DEVICE_CPU)                                                 \
          .TypeConstraint<T>("T")                                             \
          .TypeConstraint<Tindices>("Tindices"),                              \
      SparseApplyFtrlOp<CPUDevice, T, Tindices, /*has_l2_shrinkage=*/false>); \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("ResourceSparseApplyFtrl")                                         \
          .Device(DEVICE_CPU)                                                 \
          .TypeConstraint<T>("T")                                             \
          .TypeConstraint<Tindices>("Tindices"),                              \
      SparseApplyFtrlOp<CPUDevice, T, Tindices, /*has_l2_shrinkage=*/false>);
#define REGISTER_CPU_KERNELS(T) \
  REGISTER_KERNELS(T, int32);   \
  REGISTER_KERNELS(T, int64);

TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);
2235 2236 #undef REGISTER_CPU_KERNELS 2237 #undef REGISTER_KERNELS 2238 2239 #define REGISTER_KERNELS(T, Tindices) \ 2240 REGISTER_KERNEL_BUILDER( \ 2241 Name("SparseApplyFtrlV2") \ 2242 .Device(DEVICE_CPU) \ 2243 .TypeConstraint<T>("T") \ 2244 .TypeConstraint<Tindices>("Tindices"), \ 2245 SparseApplyFtrlOp<CPUDevice, T, Tindices, /*has_l2_shrinkage=*/true>); \ 2246 REGISTER_KERNEL_BUILDER( \ 2247 Name("ResourceSparseApplyFtrlV2") \ 2248 .Device(DEVICE_CPU) \ 2249 .TypeConstraint<T>("T") \ 2250 .TypeConstraint<Tindices>("Tindices"), \ 2251 SparseApplyFtrlOp<CPUDevice, T, Tindices, /*has_l2_shrinkage=*/true>); 2252 #define REGISTER_CPU_KERNELS(T) \ 2253 REGISTER_KERNELS(T, int32); \ 2254 REGISTER_KERNELS(T, int64); 2255 2256 TF_CALL_half(REGISTER_CPU_KERNELS); 2257 TF_CALL_float(REGISTER_CPU_KERNELS); 2258 TF_CALL_double(REGISTER_CPU_KERNELS); 2259 2260 #undef REGISTER_CPU_KERNELS 2261 #undef REGISTER_KERNELS 2262 2263 template <typename Device, typename T> 2264 class ApplyMomentumOp : public OpKernel { 2265 public: 2266 explicit ApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 2267 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); 2268 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_)); 2269 } 2270 2271 void Compute(OpKernelContext* ctx) override { 2272 auto locks = 2273 MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1}); 2274 2275 Tensor var; 2276 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2277 ctx, 0, use_exclusive_lock_, false, &var)); 2278 Tensor accum; 2279 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2280 ctx, 1, use_exclusive_lock_, false, &accum)); 2281 OP_REQUIRES( 2282 ctx, var.IsInitialized(), 2283 errors::FailedPrecondition( 2284 "Attempting to use uninitialized variables: ", requested_input(0))); 2285 OP_REQUIRES( 2286 ctx, accum.IsInitialized(), 2287 errors::FailedPrecondition( 2288 "Attempting to use uninitialized variables: ", requested_input(1))); 2289 const Tensor& lr = ctx->input(2); 2290 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), 2291 errors::InvalidArgument("lr is not a scalar: ", 2292 lr.shape().DebugString())); 2293 const Tensor& grad = ctx->input(3); 2294 OP_REQUIRES( 2295 ctx, var.shape().IsSameSize(accum.shape()), 2296 errors::InvalidArgument("var and accum do not have the same shape", 2297 var.shape().DebugString(), " ", 2298 accum.shape().DebugString())); 2299 OP_REQUIRES( 2300 ctx, var.shape().IsSameSize(grad.shape()), 2301 errors::InvalidArgument("var and grad do not have the same shape", 2302 var.shape().DebugString(), " ", 2303 grad.shape().DebugString())); 2304 2305 const Tensor& momentum = ctx->input(4); 2306 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), 2307 errors::InvalidArgument("momentum is not a scalar: ", 2308 momentum.shape().DebugString())); 2309 2310 const Device& device = ctx->template eigen_device<Device>(); 2311 functor::ApplyMomentum<Device, T>()(device, var.flat<T>(), accum.flat<T>(), 2312 lr.scalar<T>(), grad.flat<T>(), 2313 momentum.scalar<T>(), use_nesterov_); 2314 MaybeForwardRefInputToRefOutput(ctx, 0, 0); 2315 } 2316 2317 private: 2318 bool use_exclusive_lock_; 2319 bool use_nesterov_; 2320 }; 2321 2322 #define REGISTER_KERNELS(D, T) \ 2323 REGISTER_KERNEL_BUILDER( \ 2324 Name("ApplyMomentum").Device(DEVICE_##D).TypeConstraint<T>("T"), \ 2325 ApplyMomentumOp<D##Device, T>); \ 2326 REGISTER_KERNEL_BUILDER(Name("ResourceApplyMomentum") \ 2327 .Device(DEVICE_##D) \ 2328 .HostMemory("var") \ 2329 
.HostMemory("accum") \ 2330 .TypeConstraint<T>("T"), \ 2331 ApplyMomentumOp<D##Device, T>); 2332 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); 2333 2334 TF_CALL_half(REGISTER_CPU_KERNELS); 2335 TF_CALL_float(REGISTER_CPU_KERNELS); 2336 TF_CALL_double(REGISTER_CPU_KERNELS); 2337 2338 #if GOOGLE_CUDA 2339 // Forward declarations of the functor specializations for GPU. 2340 namespace functor { 2341 #define DECLARE_GPU_SPEC(T) \ 2342 template <> \ 2343 void ApplyMomentum<GPUDevice, T>::operator()( \ 2344 const GPUDevice& d, typename TTypes<T>::Flat var, \ 2345 typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \ 2346 typename TTypes<T>::ConstFlat grad, \ 2347 typename TTypes<T>::ConstScalar momentum, bool use_nesterov); \ 2348 extern template struct ApplyMomentum<GPUDevice, T>; 2349 DECLARE_GPU_SPEC(Eigen::half); 2350 DECLARE_GPU_SPEC(float); 2351 DECLARE_GPU_SPEC(double); 2352 #undef DECLARE_GPU_SPEC 2353 } // namespace functor 2354 2355 REGISTER_KERNELS(GPU, Eigen::half); 2356 REGISTER_KERNELS(GPU, float); 2357 REGISTER_KERNELS(GPU, double); 2358 #endif 2359 #undef REGISTER_CPU_KERNELS 2360 #undef REGISTER_KERNELS 2361 2362 // Note, this op works on cpu only. 2363 template <typename T, typename Tindex> 2364 class SparseApplyMomentumOp : public OpKernel { 2365 public: 2366 explicit SparseApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 2367 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); 2368 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_)); 2369 } 2370 2371 void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { 2372 auto locks = 2373 MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1}); 2374 2375 Tensor var; 2376 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( 2377 ctx, 0, use_exclusive_lock_, true, &var)); 2378 Tensor accum; 2379 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( 2380 ctx, 1, use_exclusive_lock_, true, &accum)); 2381 OP_REQUIRES( 2382 ctx, var.IsInitialized(), 2383 errors::FailedPrecondition( 2384 "Attempting to use uninitialized variables: ", requested_input(0))); 2385 OP_REQUIRES( 2386 ctx, accum.IsInitialized(), 2387 errors::FailedPrecondition( 2388 "Attempting to use uninitialized variables: ", requested_input(1))); 2389 OP_REQUIRES( 2390 ctx, var.shape().IsSameSize(accum.shape()), 2391 errors::InvalidArgument("var and accum do not have the same shape", 2392 var.shape().DebugString(), " ", 2393 accum.shape().DebugString())); 2394 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), 2395 errors::InvalidArgument("var must be at least 1 dimensional")); 2396 2397 const Tensor& lr = ctx->input(2); 2398 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), 2399 errors::InvalidArgument("lr is not a scalar : ", 2400 lr.shape().DebugString())); 2401 const Tensor& grad = ctx->input(3); 2402 const Tensor& indices = ctx->input(4); 2403 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), 2404 errors::InvalidArgument("indices must be one-dimensional")); 2405 2406 for (int d = 1; d < var.dims(); d++) { 2407 OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d), 2408 errors::InvalidArgument(strings::StrCat( 2409 "var and grad must match in dimension ", d))); 2410 } 2411 const Tindex N = indices.dim_size(0); 2412 OP_REQUIRES( 2413 ctx, grad.dim_size(0) == N, 2414 errors::InvalidArgument( 2415 "grad must be the same size as indices in the first dimension.")); 2416 2417 const Tensor& momentum = ctx->input(5); 2418 
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), 2419 errors::InvalidArgument("momentum is not a scalar: ", 2420 momentum.shape().DebugString())); 2421 2422 if (N > 0) { 2423 const Tindex first_dim_size = var.dim_size(0); 2424 auto indices_vec = indices.vec<Tindex>(); 2425 auto var_flat = var.flat_outer_dims<T>(); 2426 auto accum_flat = accum.flat_outer_dims<T>(); 2427 auto grad_flat = grad.flat_outer_dims<T>(); 2428 T lr_scalar = lr.scalar<T>()(); 2429 T momentum_scalar = momentum.scalar<T>()(); 2430 2431 for (Tindex i = 0; i < N; i++) { 2432 const Tindex index = internal::SubtleMustCopy(indices_vec(i)); 2433 OP_REQUIRES(ctx, FastBoundsCheck(index, first_dim_size), 2434 errors::InvalidArgument( 2435 strings::StrCat("Index ", index, " at offset ", i, 2436 " in indices is out of range"))); 2437 auto a = accum_flat.template chip<0>(index); 2438 auto g = grad_flat.template chip<0>(i); 2439 auto v = var_flat.template chip<0>(index); 2440 a = a * a.constant(momentum_scalar) + g; 2441 if (use_nesterov_) { 2442 v -= g.constant(lr_scalar) * g + 2443 a.constant(lr_scalar) * a.constant(momentum_scalar) * a; 2444 } else { 2445 v -= a.constant(lr_scalar) * a; 2446 } 2447 } 2448 } 2449 2450 MaybeForwardRefInputToRefOutput(ctx, 0, 0); 2451 } 2452 2453 private: 2454 bool use_exclusive_lock_; 2455 bool use_nesterov_; 2456 }; 2457 2458 #define REGISTER_KERNELS(T, Tindices) \ 2459 REGISTER_KERNEL_BUILDER(Name("SparseApplyMomentum") \ 2460 .Device(DEVICE_CPU) \ 2461 .TypeConstraint<T>("T") \ 2462 .TypeConstraint<Tindices>("Tindices"), \ 2463 SparseApplyMomentumOp<T, Tindices>); \ 2464 REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyMomentum") \ 2465 .Device(DEVICE_CPU) \ 2466 .TypeConstraint<T>("T") \ 2467 .TypeConstraint<Tindices>("Tindices"), \ 2468 SparseApplyMomentumOp<T, Tindices>); 2469 #define REGISTER_CPU_KERNELS(T) \ 2470 REGISTER_KERNELS(T, int32); \ 2471 REGISTER_KERNELS(T, int64); 2472 2473 TF_CALL_half(REGISTER_CPU_KERNELS); 2474 TF_CALL_float(REGISTER_CPU_KERNELS); 2475 TF_CALL_double(REGISTER_CPU_KERNELS); 2476 2477 #undef REGISTER_CPU_KERNELS 2478 #undef REGISTER_KERNELS 2479 2480 template <typename Device, typename T> 2481 class ApplyAdamOp : public OpKernel { 2482 public: 2483 explicit ApplyAdamOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 2484 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); 2485 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_)); 2486 } 2487 2488 void Compute(OpKernelContext* ctx) override { 2489 auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, 2490 {0, 1, 2}); 2491 2492 Tensor var; 2493 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2494 ctx, 0, use_exclusive_lock_, false, &var)); 2495 Tensor m; 2496 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2497 ctx, 1, use_exclusive_lock_, false, &m)); 2498 Tensor v; 2499 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2500 ctx, 2, use_exclusive_lock_, false, &v)); 2501 OP_REQUIRES( 2502 ctx, var.IsInitialized(), 2503 errors::FailedPrecondition( 2504 "Attempting to use uninitialized variables: ", requested_input(0))); 2505 OP_REQUIRES( 2506 ctx, m.IsInitialized(), 2507 errors::FailedPrecondition( 2508 "Attempting to use uninitialized variables: ", requested_input(1))); 2509 OP_REQUIRES( 2510 ctx, v.IsInitialized(), 2511 errors::FailedPrecondition( 2512 "Attempting to use uninitialized variables: ", requested_input(2))); 2513 2514 const Tensor& beta1_power = ctx->input(3); 2515 const Tensor& beta2_power = 
ctx->input(4); 2516 const Tensor& lr = ctx->input(5); 2517 const Tensor& beta1 = ctx->input(6); 2518 const Tensor& beta2 = ctx->input(7); 2519 const Tensor& epsilon = ctx->input(8); 2520 2521 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()), 2522 errors::InvalidArgument("beta1_power is not a scalar: ", 2523 beta1_power.shape().DebugString())); 2524 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power.shape()), 2525 errors::InvalidArgument("beta2_power is not a scalar: ", 2526 beta2_power.shape().DebugString())); 2527 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), 2528 errors::InvalidArgument("lr is not a scalar : ", 2529 lr.shape().DebugString())); 2530 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()), 2531 errors::InvalidArgument("beta1 is not a scalar: ", 2532 beta1.shape().DebugString())); 2533 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()), 2534 errors::InvalidArgument("beta2 is not a scalar: ", 2535 beta2.shape().DebugString())); 2536 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), 2537 errors::InvalidArgument("epsilon is not a scalar: ", 2538 epsilon.shape().DebugString())); 2539 2540 const Tensor& grad = ctx->input(9); 2541 OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()), 2542 errors::InvalidArgument("var and m do not have the same shape", 2543 var.shape().DebugString(), " ", 2544 m.shape().DebugString())); 2545 OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()), 2546 errors::InvalidArgument("var and v do not have the same shape", 2547 var.shape().DebugString(), " ", 2548 v.shape().DebugString())); 2549 OP_REQUIRES( 2550 ctx, var.shape().IsSameSize(grad.shape()), 2551 errors::InvalidArgument("var and grad do not have the same shape", 2552 var.shape().DebugString(), " ", 2553 grad.shape().DebugString())); 2554 2555 const Device& device = ctx->template eigen_device<Device>(); 2556 functor::ApplyAdam<Device, T>()( 2557 device, var.flat<T>(), m.flat<T>(), v.flat<T>(), 2558 beta1_power.scalar<T>(), beta2_power.scalar<T>(), lr.scalar<T>(), 2559 beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(), 2560 grad.flat<T>(), use_nesterov_); 2561 2562 MaybeForwardRefInputToRefOutput(ctx, 0, 0); 2563 } 2564 2565 private: 2566 bool use_exclusive_lock_; 2567 bool use_nesterov_; 2568 }; 2569 2570 #ifdef TENSORFLOW_USE_SYCL 2571 template <typename T> 2572 class ApplyAdamOp<SYCLDevice, T> : public OpKernel { 2573 public: 2574 explicit ApplyAdamOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 2575 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); 2576 } 2577 2578 void Compute(OpKernelContext* ctx) override { 2579 auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, 2580 {0, 1, 2}); 2581 2582 Tensor var; 2583 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<SYCLDevice, T>( 2584 ctx, 0, use_exclusive_lock_, false, &var)); 2585 Tensor m; 2586 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<SYCLDevice, T>( 2587 ctx, 1, use_exclusive_lock_, false, &m)); 2588 Tensor v; 2589 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<SYCLDevice, T>( 2590 ctx, 2, use_exclusive_lock_, false, &v)); 2591 OP_REQUIRES( 2592 ctx, var.IsInitialized(), 2593 errors::FailedPrecondition( 2594 "Attempting to use uninitialized variables: ", requested_input(0))); 2595 OP_REQUIRES( 2596 ctx, m.IsInitialized(), 2597 errors::FailedPrecondition( 2598 "Attempting to use uninitialized variables: ", requested_input(1))); 2599 OP_REQUIRES( 2600 ctx, v.IsInitialized(), 2601 errors::FailedPrecondition( 2602 "Attempting to use 
uninitialized variables: ", requested_input(2))); 2603 2604 const Tensor& beta1_power_dev = ctx->input(3); 2605 const Tensor& beta2_power_dev = ctx->input(4); 2606 const Tensor& lr_dev = ctx->input(5); 2607 const Tensor& beta1_dev = ctx->input(6); 2608 const Tensor& beta2_dev = ctx->input(7); 2609 const Tensor& epsilon_dev = ctx->input(8); 2610 2611 T beta1_power = 0; 2612 T beta2_power = 0; 2613 T lr = 0; 2614 T beta1 = 0; 2615 T beta2 = 0; 2616 T epsilon = 0; 2617 2618 auto device = ctx->eigen_sycl_device(); 2619 auto size = sizeof(T); 2620 auto src_ptr = GetBase(&beta1_power_dev); 2621 device.memcpyDeviceToHost(&beta1_power, static_cast<const T*>(src_ptr), 2622 size); 2623 2624 src_ptr = GetBase(&beta2_power_dev); 2625 device.memcpyDeviceToHost(&beta2_power, static_cast<const T*>(src_ptr), 2626 size); 2627 2628 src_ptr = GetBase(&lr_dev); 2629 device.memcpyDeviceToHost(&lr, static_cast<const T*>(src_ptr), size); 2630 2631 src_ptr = GetBase(&beta1_dev); 2632 device.memcpyDeviceToHost(&beta1, static_cast<const T*>(src_ptr), size); 2633 2634 src_ptr = GetBase(&beta2_dev); 2635 device.memcpyDeviceToHost(&beta2, static_cast<const T*>(src_ptr), size); 2636 2637 src_ptr = GetBase(&epsilon_dev); 2638 device.memcpyDeviceToHost(&epsilon, static_cast<const T*>(src_ptr), size); 2639 2640 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_dev.shape()), 2641 errors::InvalidArgument("beta1_power is not a scalar: ", 2642 beta1_power_dev.shape().DebugString())); 2643 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power_dev.shape()), 2644 errors::InvalidArgument("beta2_power is not a scalar: ", 2645 beta2_power_dev.shape().DebugString())); 2646 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_dev.shape()), 2647 errors::InvalidArgument("lr is not a scalar : ", 2648 lr_dev.shape().DebugString())); 2649 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_dev.shape()), 2650 errors::InvalidArgument("beta1 is not a scalar: ", 2651 beta1_dev.shape().DebugString())); 2652 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_dev.shape()), 2653 errors::InvalidArgument("beta2 is not a scalar: ", 2654 beta2_dev.shape().DebugString())); 2655 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_dev.shape()), 2656 errors::InvalidArgument("epsilon is not a scalar: ", 2657 epsilon_dev.shape().DebugString())); 2658 2659 const Tensor& grad = ctx->input(9); 2660 2661 OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()), 2662 errors::InvalidArgument("var and m do not have the same shape", 2663 var.shape().DebugString(), " ", 2664 m.shape().DebugString())); 2665 OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()), 2666 errors::InvalidArgument("var and v do not have the same shape", 2667 var.shape().DebugString(), " ", 2668 v.shape().DebugString())); 2669 OP_REQUIRES( 2670 ctx, var.shape().IsSameSize(grad.shape()), 2671 errors::InvalidArgument("var and grad do not have the same shape", 2672 var.shape().DebugString(), " ", 2673 grad.shape().DebugString())); 2674 2675 functor::ApplyAdamSYCL<T>()(device, var.flat<T>(), m.flat<T>(), v.flat<T>(), 2676 beta1_power, beta2_power, lr, beta1, beta2, 2677 epsilon, grad.flat<T>()); 2678 2679 MaybeForwardRefInputToRefOutput(ctx, 0, 0); 2680 } 2681 2682 private: 2683 bool use_exclusive_lock_; 2684 }; 2685 #endif // TENSORFLOW_USE_SYCL 2686 2687 #define REGISTER_KERNELS(D, T) \ 2688 REGISTER_KERNEL_BUILDER( \ 2689 Name("ApplyAdam").Device(DEVICE_##D).TypeConstraint<T>("T"), \ 2690 ApplyAdamOp<D##Device, T>); \ 2691 REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdam") \ 2692 .HostMemory("var") \ 
2693 .HostMemory("m") \ 2694 .HostMemory("v") \ 2695 .Device(DEVICE_##D) \ 2696 .TypeConstraint<T>("T"), \ 2697 ApplyAdamOp<D##Device, T>); 2698 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); 2699 2700 TF_CALL_half(REGISTER_CPU_KERNELS); 2701 TF_CALL_float(REGISTER_CPU_KERNELS); 2702 TF_CALL_double(REGISTER_CPU_KERNELS); 2703 2704 #ifdef TENSORFLOW_USE_SYCL 2705 #define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T); 2706 2707 TF_CALL_float(REGISTER_SYCL_KERNELS); 2708 TF_CALL_double(REGISTER_SYCL_KERNELS); 2709 #endif 2710 2711 #if GOOGLE_CUDA 2712 // Forward declarations of the functor specializations for GPU. 2713 namespace functor { 2714 #define DECLARE_GPU_SPEC(T) \ 2715 template <> \ 2716 void ApplyAdam<GPUDevice, T>::operator()( \ 2717 const GPUDevice& d, typename TTypes<T>::Flat var, \ 2718 typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \ 2719 typename TTypes<T>::ConstScalar beta1_power, \ 2720 typename TTypes<T>::ConstScalar beta2_power, \ 2721 typename TTypes<T>::ConstScalar lr, \ 2722 typename TTypes<T>::ConstScalar beta1, \ 2723 typename TTypes<T>::ConstScalar beta2, \ 2724 typename TTypes<T>::ConstScalar epsilon, \ 2725 typename TTypes<T>::ConstFlat grad, bool use_nesterov); \ 2726 extern template struct ApplyAdam<GPUDevice, T>; 2727 DECLARE_GPU_SPEC(Eigen::half); 2728 DECLARE_GPU_SPEC(float); 2729 DECLARE_GPU_SPEC(double); 2730 #undef DECLARE_GPU_SPEC 2731 } // namespace functor 2732 2733 REGISTER_KERNELS(GPU, Eigen::half); 2734 REGISTER_KERNELS(GPU, float); 2735 REGISTER_KERNELS(GPU, double); 2736 #endif 2737 #undef REGISTER_CPU_KERNELS 2738 #undef REGISTER_KERNELS 2739 2740 template <typename Device, typename T> 2741 class ApplyRMSPropOp : public OpKernel { 2742 public: 2743 explicit ApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 2744 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); 2745 } 2746 2747 void Compute(OpKernelContext* ctx) override { 2748 auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, 2749 {0, 1, 2}); 2750 2751 Tensor var; 2752 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2753 ctx, 0, use_exclusive_lock_, false, &var)); 2754 Tensor ms; 2755 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2756 ctx, 1, use_exclusive_lock_, false, &ms)); 2757 Tensor mom; 2758 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2759 ctx, 2, use_exclusive_lock_, false, &mom)); 2760 2761 OP_REQUIRES( 2762 ctx, var.IsInitialized(), 2763 errors::FailedPrecondition( 2764 "Attempting to use uninitialized variables: ", requested_input(0))); 2765 OP_REQUIRES( 2766 ctx, ms.IsInitialized(), 2767 errors::FailedPrecondition( 2768 "Attempting to use uninitialized variables: ", requested_input(1))); 2769 OP_REQUIRES( 2770 ctx, mom.IsInitialized(), 2771 errors::FailedPrecondition( 2772 "Attempting to use uninitialized variables: ", requested_input(2))); 2773 2774 const Tensor& lr = ctx->input(3); 2775 const Tensor& rho = ctx->input(4); 2776 const Tensor& momentum = ctx->input(5); 2777 const Tensor& epsilon = ctx->input(6); 2778 const Tensor& grad = ctx->input(7); 2779 2780 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), 2781 errors::InvalidArgument("lr is not a scalar : ", 2782 lr.shape().DebugString())); 2783 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()), 2784 errors::InvalidArgument("rho is not a scalar: ", 2785 rho.shape().DebugString())); 2786 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), 2787 errors::InvalidArgument("momentum is not a 
scalar: ", 2788 momentum.shape().DebugString())); 2789 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), 2790 errors::InvalidArgument("epsilon is not a scalar: ", 2791 epsilon.shape().DebugString())); 2792 2793 OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()), 2794 errors::InvalidArgument("var and ms do not have the same shape", 2795 var.shape().DebugString(), " ", 2796 ms.shape().DebugString())); 2797 2798 OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()), 2799 errors::InvalidArgument( 2800 "var and mom do not have the same shape", 2801 var.shape().DebugString(), " ", mom.shape().DebugString())); 2802 2803 OP_REQUIRES( 2804 ctx, var.shape().IsSameSize(grad.shape()), 2805 errors::InvalidArgument("var and grad do not have the same shape", 2806 var.shape().DebugString(), " ", 2807 grad.shape().DebugString())); 2808 2809 const Device& device = ctx->template eigen_device<Device>(); 2810 functor::ApplyRMSProp<Device, T>()(device, var.flat<T>(), ms.flat<T>(), 2811 mom.flat<T>(), lr.scalar<T>(), 2812 rho.scalar<T>(), momentum.scalar<T>(), 2813 epsilon.scalar<T>(), grad.flat<T>()); 2814 2815 MaybeForwardRefInputToRefOutput(ctx, 0, 0); 2816 } 2817 2818 private: 2819 bool use_exclusive_lock_; 2820 }; 2821 2822 template <typename Device, typename T> 2823 class ApplyCenteredRMSPropOp : public OpKernel { 2824 public: 2825 explicit ApplyCenteredRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 2826 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); 2827 } 2828 2829 void Compute(OpKernelContext* ctx) override { 2830 auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, 2831 {0, 1, 2, 3}); 2832 2833 Tensor var; 2834 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2835 ctx, 0, use_exclusive_lock_, false, &var)); 2836 Tensor mg; 2837 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2838 ctx, 1, use_exclusive_lock_, false, &mg)); 2839 Tensor ms; 2840 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2841 ctx, 2, use_exclusive_lock_, false, &ms)); 2842 Tensor mom; 2843 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( 2844 ctx, 3, use_exclusive_lock_, false, &mom)); 2845 2846 OP_REQUIRES( 2847 ctx, var.IsInitialized(), 2848 errors::FailedPrecondition( 2849 "Attempting to use uninitialized variables: ", requested_input(0))); 2850 OP_REQUIRES( 2851 ctx, mg.IsInitialized(), 2852 errors::FailedPrecondition( 2853 "Attempting to use uninitialized variables: ", requested_input(1))); 2854 OP_REQUIRES( 2855 ctx, ms.IsInitialized(), 2856 errors::FailedPrecondition( 2857 "Attempting to use uninitialized variables: ", requested_input(2))); 2858 OP_REQUIRES( 2859 ctx, mom.IsInitialized(), 2860 errors::FailedPrecondition( 2861 "Attempting to use uninitialized variables: ", requested_input(3))); 2862 2863 const Tensor& lr = ctx->input(4); 2864 const Tensor& rho = ctx->input(5); 2865 const Tensor& momentum = ctx->input(6); 2866 const Tensor& epsilon = ctx->input(7); 2867 const Tensor& grad = ctx->input(8); 2868 2869 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), 2870 errors::InvalidArgument("lr is not a scalar : ", 2871 lr.shape().DebugString())); 2872 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()), 2873 errors::InvalidArgument("rho is not a scalar: ", 2874 rho.shape().DebugString())); 2875 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), 2876 errors::InvalidArgument("momentum is not a scalar: ", 2877 momentum.shape().DebugString())); 2878 OP_REQUIRES(ctx, 
TensorShapeUtils::IsScalar(epsilon.shape()), 2879 errors::InvalidArgument("epsilon is not a scalar: ", 2880 epsilon.shape().DebugString())); 2881 2882 OP_REQUIRES(ctx, var.shape().IsSameSize(mg.shape()), 2883 errors::InvalidArgument("var and mg do not have the same shape", 2884 var.shape().DebugString(), " ", 2885 ms.shape().DebugString())); 2886 2887 OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()), 2888 errors::InvalidArgument("var and ms do not have the same shape", 2889 var.shape().DebugString(), " ", 2890 ms.shape().DebugString())); 2891 2892 OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()), 2893 errors::InvalidArgument( 2894 "var and mom do not have the same shape", 2895 var.shape().DebugString(), " ", mom.shape().DebugString())); 2896 2897 OP_REQUIRES( 2898 ctx, var.shape().IsSameSize(grad.shape()), 2899 errors::InvalidArgument("var and grad do not have the same shape", 2900 var.shape().DebugString(), " ", 2901 grad.shape().DebugString())); 2902 2903 const Device& device = ctx->template eigen_device<Device>(); 2904 functor::ApplyCenteredRMSProp<Device, T>()( 2905 device, var.flat<T>(), mg.flat<T>(), ms.flat<T>(), mom.flat<T>(), 2906 lr.scalar<T>(), rho.scalar<T>(), momentum.scalar<T>(), 2907 epsilon.scalar<T>(), grad.flat<T>()); 2908 MaybeForwardRefInputToRefOutput(ctx, 0, 0); 2909 } 2910 2911 private: 2912 bool use_exclusive_lock_; 2913 }; 2914 2915 #define REGISTER_KERNELS(D, T) \ 2916 REGISTER_KERNEL_BUILDER( \ 2917 Name("ApplyRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"), \ 2918 ApplyRMSPropOp<D##Device, T>); \ 2919 REGISTER_KERNEL_BUILDER( \ 2920 Name("ApplyCenteredRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"), \ 2921 ApplyCenteredRMSPropOp<D##Device, T>); \ 2922 REGISTER_KERNEL_BUILDER(Name("ResourceApplyRMSProp") \ 2923 .Device(DEVICE_##D) \ 2924 .HostMemory("var") \ 2925 .HostMemory("ms") \ 2926 .HostMemory("mom") \ 2927 .TypeConstraint<T>("T"), \ 2928 ApplyRMSPropOp<D##Device, T>); \ 2929 REGISTER_KERNEL_BUILDER(Name("ResourceApplyCenteredRMSProp") \ 2930 .Device(DEVICE_##D) \ 2931 .HostMemory("var") \ 2932 .HostMemory("mg") \ 2933 .HostMemory("ms") \ 2934 .HostMemory("mom") \ 2935 .TypeConstraint<T>("T"), \ 2936 ApplyCenteredRMSPropOp<D##Device, T>); 2937 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); 2938 2939 TF_CALL_half(REGISTER_CPU_KERNELS); 2940 TF_CALL_float(REGISTER_CPU_KERNELS); 2941 TF_CALL_double(REGISTER_CPU_KERNELS); 2942 2943 #if GOOGLE_CUDA 2944 // Forward declarations of the functor specializations for GPU. 
2945 namespace functor { 2946 #define DECLARE_GPU_SPEC(T) \ 2947 template <> \ 2948 void ApplyRMSProp<GPUDevice, T>::operator()( \ 2949 const GPUDevice& d, typename TTypes<T>::Flat var, \ 2950 typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom, \ 2951 typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar rho, \ 2952 typename TTypes<T>::ConstScalar momentum, \ 2953 typename TTypes<T>::ConstScalar epsilon, \ 2954 typename TTypes<T>::ConstFlat grad); \ 2955 extern template struct ApplyRMSProp<GPUDevice, T>; \ 2956 template <> \ 2957 void ApplyCenteredRMSProp<GPUDevice, T>::operator()( \ 2958 const GPUDevice& d, typename TTypes<T>::Flat var, \ 2959 typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms, \ 2960 typename TTypes<T>::Flat mom, typename TTypes<T>::ConstScalar lr, \ 2961 typename TTypes<T>::ConstScalar rho, \ 2962 typename TTypes<T>::ConstScalar momentum, \ 2963 typename TTypes<T>::ConstScalar epsilon, \ 2964 typename TTypes<T>::ConstFlat grad); \ 2965 extern template struct ApplyCenteredRMSProp<GPUDevice, T>; 2966 DECLARE_GPU_SPEC(Eigen::half); 2967 DECLARE_GPU_SPEC(float); 2968 DECLARE_GPU_SPEC(double); 2969 #undef DECLARE_GPU_SPEC 2970 } // namespace functor 2971 2972 REGISTER_KERNELS(GPU, Eigen::half); 2973 REGISTER_KERNELS(GPU, float); 2974 REGISTER_KERNELS(GPU, double); 2975 #endif 2976 #undef REGISTER_CPU_KERNELS 2977 #undef REGISTER_KERNELS 2978 2979 // Note, this op works on cpu only. 2980 template <typename T, typename Tindex> 2981 class SparseApplyRMSPropOp : public OpKernel { 2982 public: 2983 explicit SparseApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 2984 OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); 2985 } 2986 2987 void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { 2988 auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, 2989 {0, 1, 2}); 2990 2991 Tensor var; 2992 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( 2993 ctx, 0, use_exclusive_lock_, true, &var)); 2994 Tensor ms; 2995 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( 2996 ctx, 1, use_exclusive_lock_, true, &ms)); 2997 Tensor mom; 2998 OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>( 2999 ctx, 2, use_exclusive_lock_, true, &mom)); 3000 3001 OP_REQUIRES( 3002 ctx, var.IsInitialized(), 3003 errors::FailedPrecondition( 3004 "Attempting to use uninitialized variables: ", requested_input(0))); 3005 OP_REQUIRES( 3006 ctx, ms.IsInitialized(), 3007 errors::FailedPrecondition( 3008 "Attempting to use uninitialized variables: ", requested_input(1))); 3009 OP_REQUIRES( 3010 ctx, mom.IsInitialized(), 3011 errors::FailedPrecondition( 3012 "Attempting to use uninitialized variables: ", requested_input(2))); 3013 3014 const Tensor& lr = ctx->input(3); 3015 const Tensor& rho = ctx->input(4); 3016 const Tensor& momentum = ctx->input(5); 3017 const Tensor& epsilon = ctx->input(6); 3018 const Tensor& grad = ctx->input(7); 3019 const Tensor& indices = ctx->input(8); 3020 3021 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), 3022 errors::InvalidArgument("lr is not a scalar: ", 3023 lr.shape().DebugString())); 3024 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()), 3025 errors::InvalidArgument("rho is not a scalar: ", 3026 rho.shape().DebugString())); 3027 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), 3028 errors::InvalidArgument("momentum is not a scalar: ", 3029 momentum.shape().DebugString())); 3030 OP_REQUIRES(ctx, 
                TensorShapeUtils::IsScalar(epsilon.shape()),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon.shape().DebugString()));

    OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
                errors::InvalidArgument("var and ms do not have the same shape",
                                        var.shape().DebugString(), " ",
                                        ms.shape().DebugString()));

    OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
                errors::InvalidArgument(
                    "var and mom do not have the same shape",
                    var.shape().DebugString(), " ", mom.shape().DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
                errors::InvalidArgument("var must be at least 1 dimensional"));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices must be one-dimensional"));

    for (int d = 1; d < var.dims(); d++) {
      OP_REQUIRES(
          ctx, var.dim_size(d) == grad.dim_size(d),
          errors::InvalidArgument("var and grad must match in dimension ", d));
    }
    const Tindex N = indices.dim_size(0);
    OP_REQUIRES(
        ctx, grad.dim_size(0) == N,
        errors::InvalidArgument(
            "grad must be the same size as indices in the first dimension."));

    if (N > 0) {
      const Tindex first_dim_size = var.dim_size(0);
      // Validate all the indices are in range
      auto indices_vec = indices.vec<Tindex>();
      for (Tindex i = 0; i < N; i++) {
        const Tindex index = indices_vec(i);
        OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
                    errors::InvalidArgument(
                        strings::StrCat("Index ", index, " at offset ", i,
                                        " in indices is out of range")));
      }

      auto var_flat = var.flat_outer_dims<T>();
      auto ms_flat = ms.flat_outer_dims<T>();
      auto mom_flat = mom.flat_outer_dims<T>();
      auto grad_flat = grad.flat_outer_dims<T>();
      const T lr_scalar = lr.scalar<T>()();
      const T rho_scalar = rho.scalar<T>()();
      const T epsilon_scalar = epsilon.scalar<T>()();
      const T momentum_scalar = momentum.scalar<T>()();

      for (Tindex i = 0; i < N; i++) {
        const Tindex index = indices_vec(i);

        auto ms_ = ms_flat.template chip<0>(index);
        auto mom_ = mom_flat.template chip<0>(index);
        auto grad_ = grad_flat.template chip<0>(i);

        ms_ = ms_ * ms_.constant(rho_scalar) +
              grad_.square() * grad_.constant(T(1) - rho_scalar);
        mom_ = mom_ * mom_.constant(momentum_scalar) +
               (ms_ + ms_.constant(epsilon_scalar)).rsqrt() *
                   ms_.constant(lr_scalar) * grad_;

        auto v = var_flat.template chip<0>(index);
        v -= mom_;
      }
    }

    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};
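// As a reading aid (a summary of the loop above, not an additional contract),
// the per-row update applied by SparseApplyRMSPropOp is, with g the sampled
// gradient row and all operations elementwise:
//   ms  <- rho * ms + (1 - rho) * g^2
//   mom <- momentum * mom + lr * g / sqrt(ms + epsilon)
//   var <- var - mom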
// Note, this op works on cpu only.
template <typename T, typename Tindex>
class SparseApplyCenteredRMSPropOp : public OpKernel {
 public:
  explicit SparseApplyCenteredRMSPropOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
    auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
                                                      {0, 1, 2, 3});

    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
                            ctx, 0, use_exclusive_lock_, true, &var));
    Tensor mg;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
                            ctx, 1, use_exclusive_lock_, true, &mg));
    Tensor ms;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
                            ctx, 2, use_exclusive_lock_, true, &ms));
    Tensor mom;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<CPUDevice, T>(
                            ctx, 3, use_exclusive_lock_, true, &mom));

    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, mg.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    OP_REQUIRES(
        ctx, ms.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(2)));
    OP_REQUIRES(
        ctx, mom.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(3)));

    const Tensor& lr = ctx->input(4);
    const Tensor& rho = ctx->input(5);
    const Tensor& momentum = ctx->input(6);
    const Tensor& epsilon = ctx->input(7);
    const Tensor& grad = ctx->input(8);
    const Tensor& indices = ctx->input(9);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
                errors::InvalidArgument("rho is not a scalar: ",
                                        rho.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
                errors::InvalidArgument("epsilon is not a scalar: ",
                                        epsilon.shape().DebugString()));

    OP_REQUIRES(ctx, var.shape().IsSameSize(mg.shape()),
                errors::InvalidArgument("var and mg do not have the same shape",
                                        var.shape().DebugString(), " ",
                                        mg.shape().DebugString()));

    OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
                errors::InvalidArgument("var and ms do not have the same shape",
                                        var.shape().DebugString(), " ",
                                        ms.shape().DebugString()));

    OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
                errors::InvalidArgument(
                    "var and mom do not have the same shape",
                    var.shape().DebugString(), " ", mom.shape().DebugString()));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
                errors::InvalidArgument("var must be at least 1 dimensional"));

    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices must be one-dimensional"));

    for (int d = 1; d < var.dims(); d++) {
      OP_REQUIRES(
          ctx, var.dim_size(d) == grad.dim_size(d),
          errors::InvalidArgument("var and grad must match in dimension ", d));
    }
    const Tindex N = indices.dim_size(0);
    OP_REQUIRES(
        ctx, grad.dim_size(0) == N,
        errors::InvalidArgument(
            "grad must be the same size as indices in the first dimension."));

    if (N > 0) {
      const Tindex first_dim_size = var.dim_size(0);
      // Validate all the indices are in range
      auto indices_vec = indices.vec<Tindex>();
      for (Tindex i = 0; i < N; i++) {
        const Tindex index = indices_vec(i);
        OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
                    errors::InvalidArgument(
                        strings::StrCat("Index ", index, " at offset ", i,
                                        " in indices is out of range")));
      }

      auto var_flat = var.flat_outer_dims<T>();
      auto ms_flat = ms.flat_outer_dims<T>();
      auto mg_flat = mg.flat_outer_dims<T>();
      auto mom_flat = mom.flat_outer_dims<T>();
      auto grad_flat = grad.flat_outer_dims<T>();
      const T lr_scalar = lr.scalar<T>()();
      const T rho_scalar = rho.scalar<T>()();
      const T epsilon_scalar = epsilon.scalar<T>()();
      const T momentum_scalar = momentum.scalar<T>()();

      for (Tindex i = 0; i < N; i++) {
        const Tindex index = indices_vec(i);

        auto ms_ = ms_flat.template chip<0>(index);
        auto mom_ = mom_flat.template chip<0>(index);
        auto grad_ = grad_flat.template chip<0>(i);

        ms_ = ms_ * ms_.constant(rho_scalar) +
              grad_.square() * grad_.constant(T(1) - rho_scalar);

        auto mg_ = mg_flat.template chip<0>(index);
        mg_ = mg_ * mg_.constant(rho_scalar) +
              grad_ * grad_.constant(T(1) - rho_scalar);
        auto denom_ = ms_ + ms_.constant(epsilon_scalar) - mg_.square();
        mom_ = mom_ * mom_.constant(momentum_scalar) +
               denom_.rsqrt() * ms_.constant(lr_scalar) * grad_;
        auto v = var_flat.template chip<0>(index);
        v -= mom_;
      }
    }

    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};
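// Summarizing the loop above (again, a reading aid rather than a contract),
// the centered variant additionally tracks the first-moment estimate mg and
// uses it to center the second moment in the denominator:
//   mg  <- rho * mg + (1 - rho) * g
//   ms  <- rho * ms + (1 - rho) * g^2
//   mom <- momentum * mom + lr * g / sqrt(ms - mg^2 + epsilon)
//   var <- var - mom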
#define REGISTER_KERNELS(T, Tindices)                                 \
  REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp")                  \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<T>("T")                 \
                              .TypeConstraint<Tindices>("Tindices"),  \
                          SparseApplyRMSPropOp<T, Tindices>);         \
  REGISTER_KERNEL_BUILDER(Name("SparseApplyCenteredRMSProp")          \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<T>("T")                 \
                              .TypeConstraint<Tindices>("Tindices"),  \
                          SparseApplyCenteredRMSPropOp<T, Tindices>); \
  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyRMSProp")          \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<T>("T")                 \
                              .TypeConstraint<Tindices>("Tindices"),  \
                          SparseApplyRMSPropOp<T, Tindices>);         \
  REGISTER_KERNEL_BUILDER(Name("ResourceSparseApplyCenteredRMSProp")  \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<T>("T")                 \
                              .TypeConstraint<Tindices>("Tindices"),  \
                          SparseApplyCenteredRMSPropOp<T, Tindices>);

REGISTER_KERNELS(Eigen::half, int32);
REGISTER_KERNELS(Eigen::half, int64);
REGISTER_KERNELS(float, int32);
REGISTER_KERNELS(float, int64);
REGISTER_KERNELS(double, int32);
REGISTER_KERNELS(double, int64);

#undef REGISTER_KERNELS

template <typename Device, typename T>
class ApplyAddSignOp : public OpKernel {
 public:
  explicit ApplyAddSignOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override {
    auto locks =
        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});

    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 0, use_exclusive_lock_, false, &var));
    Tensor m;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 1, use_exclusive_lock_, false, &m));
    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, m.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    const Tensor& lr = ctx->input(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
    const Tensor& alpha = ctx->input(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha.shape()),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha.shape().DebugString()));
    const Tensor& sign_decay = ctx->input(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay.shape()),
                errors::InvalidArgument("sign_decay is not a scalar: ",
                                        sign_decay.shape().DebugString()));
    const Tensor& beta = ctx->input(5);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta.shape()),
                errors::InvalidArgument("beta is not a scalar: ",
                                        beta.shape().DebugString()));
    const Tensor& grad = ctx->input(6);
    OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var.shape().DebugString(), " ",
                                        m.shape().DebugString()));
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(grad.shape()),
        errors::InvalidArgument("var and grad do not have the same shape",
                                var.shape().DebugString(), " ",
                                grad.shape().DebugString()));

    const Device& device = ctx->template eigen_device<Device>();
    functor::ApplyAddSign<Device, T>()(
        device, var.flat<T>(), m.flat<T>(), lr.scalar<T>(), alpha.scalar<T>(),
        sign_decay.scalar<T>(), beta.scalar<T>(), grad.flat<T>());
    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};

#define REGISTER_KERNELS(D, T)                                        \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("ApplyAddSign").Device(DEVICE_##D).TypeConstraint<T>("T"), \
      ApplyAddSignOp<D##Device, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("ResourceApplyAddSign")                \
                              .Device(DEVICE_##D)                     \
                              .HostMemory("var")                      \
                              .HostMemory("m")                        \
                              .TypeConstraint<T>("T"),                \
                          ApplyAddSignOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);

TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);
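// For context: the AddSign kernels registered above delegate the arithmetic to
// functor::ApplyAddSign, whose specializations are defined with the other
// functors (not shown here). Up to formatting, the elementwise rule they are
// expected to apply is:
//   m   <- beta * m + (1 - beta) * g
//   var <- var - lr * (alpha + sign_decay * sign(g) * sign(m)) * g
// This summary is a reading aid; the functor definitions are authoritative.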
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                            \
  template <>                                                          \
  void ApplyAddSign<GPUDevice, T>::operator()(                         \
      const GPUDevice& d, typename TTypes<T>::Flat var,                \
      typename TTypes<T>::Flat m, typename TTypes<T>::ConstScalar lr,  \
      typename TTypes<T>::ConstScalar alpha,                           \
      typename TTypes<T>::ConstScalar sign_decay,                      \
      typename TTypes<T>::ConstScalar beta,                            \
      typename TTypes<T>::ConstFlat grad);                             \
  extern template struct ApplyAddSign<GPUDevice, T>;
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

REGISTER_KERNELS(GPU, Eigen::half);
REGISTER_KERNELS(GPU, float);
REGISTER_KERNELS(GPU, double);
#endif
#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS

template <typename Device, typename T>
class ApplyPowerSignOp : public OpKernel {
 public:
  explicit ApplyPowerSignOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
  }

  void Compute(OpKernelContext* ctx) override {
    auto locks =
        MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});

    Tensor var;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 0, use_exclusive_lock_, false, &var));
    Tensor m;
    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                            ctx, 1, use_exclusive_lock_, false, &m));
    OP_REQUIRES(
        ctx, var.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(0)));
    OP_REQUIRES(
        ctx, m.IsInitialized(),
        errors::FailedPrecondition(
            "Attempting to use uninitialized variables: ", requested_input(1)));
    const Tensor& lr = ctx->input(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                errors::InvalidArgument("lr is not a scalar: ",
                                        lr.shape().DebugString()));
    const Tensor& logbase = ctx->input(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase.shape()),
                errors::InvalidArgument("logbase is not a scalar: ",
                                        logbase.shape().DebugString()));
    const Tensor& sign_decay = ctx->input(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay.shape()),
                errors::InvalidArgument("sign_decay is not a scalar: ",
                                        sign_decay.shape().DebugString()));
    const Tensor& beta = ctx->input(5);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta.shape()),
                errors::InvalidArgument("beta is not a scalar: ",
                                        beta.shape().DebugString()));
    const Tensor& grad = ctx->input(6);
    OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
                errors::InvalidArgument("var and m do not have the same shape",
                                        var.shape().DebugString(), " ",
                                        m.shape().DebugString()));
    OP_REQUIRES(
        ctx, var.shape().IsSameSize(grad.shape()),
        errors::InvalidArgument("var and grad do not have the same shape",
                                var.shape().DebugString(), " ",
                                grad.shape().DebugString()));

    const Device& device = ctx->template eigen_device<Device>();
    functor::ApplyPowerSign<Device, T>()(
        device, var.flat<T>(), m.flat<T>(), lr.scalar<T>(),
        logbase.scalar<T>(), sign_decay.scalar<T>(), beta.scalar<T>(),
        grad.flat<T>());
    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
  }

 private:
  bool use_exclusive_lock_;
};
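// As with AddSign above, a brief summary of the rule functor::ApplyPowerSign
// is expected to compute (the functor definitions are authoritative): the
// scaling factor is exponential rather than additive,
//   m   <- beta * m + (1 - beta) * g
//   var <- var - lr * exp(logbase * sign_decay * sign(g) * sign(m)) * g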
Name("ApplyPowerSign").Device(DEVICE_##D).TypeConstraint<T>("T"), \ 3450 ApplyPowerSignOp<D##Device, T>); \ 3451 REGISTER_KERNEL_BUILDER(Name("ResourceApplyPowerSign") \ 3452 .Device(DEVICE_##D) \ 3453 .HostMemory("var") \ 3454 .HostMemory("m") \ 3455 .TypeConstraint<T>("T"), \ 3456 ApplyPowerSignOp<D##Device, T>); 3457 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); 3458 3459 TF_CALL_half(REGISTER_CPU_KERNELS); 3460 TF_CALL_float(REGISTER_CPU_KERNELS); 3461 TF_CALL_double(REGISTER_CPU_KERNELS); 3462 3463 #if GOOGLE_CUDA 3464 // Forward declarations of the functor specializations for GPU. 3465 namespace functor { 3466 #define DECLARE_GPU_SPEC(T) \ 3467 template <> \ 3468 void ApplyPowerSign<GPUDevice, T>::operator()( \ 3469 const GPUDevice& d, typename TTypes<T>::Flat var, \ 3470 typename TTypes<T>::Flat m, typename TTypes<T>::ConstScalar lr, \ 3471 typename TTypes<T>::ConstScalar logbase, \ 3472 typename TTypes<T>::ConstScalar sign_decay, \ 3473 typename TTypes<T>::ConstScalar beta, \ 3474 typename TTypes<T>::ConstFlat grad); \ 3475 extern template struct ApplyPowerSign<GPUDevice, T>; 3476 DECLARE_GPU_SPEC(Eigen::half); 3477 DECLARE_GPU_SPEC(float); 3478 DECLARE_GPU_SPEC(double); 3479 #undef DECLARE_GPU_SPEC 3480 } // namespace functor 3481 3482 REGISTER_KERNELS(GPU, Eigen::half); 3483 REGISTER_KERNELS(GPU, float); 3484 REGISTER_KERNELS(GPU, double); 3485 #endif 3486 #undef REGISTER_CPU_KERNELS 3487 #undef REGISTER_KERNELS 3488 3489 } // namespace tensorflow 3490