Home | History | Annotate | Download | only in kernels

Lines Matching defs:learning_rate

171     auto learning_rate = accum.constant(lr()) * accum.rsqrt();
174 prox_var.device(d) -= grad * learning_rate;
178 (prox_var.abs() - learning_rate * prox_var.constant(l1()))
180 (var.constant(1.0) + var.constant(l2()) * learning_rate);
183 prox_var / (var.constant(1.0) + var.constant(l2()) * learning_rate);
973 // compute learning_rate for current step.
974 auto learning_rate = v.constant(lr_scalar);
976 // v = w - g * learning_rate.
977 prox_v -= g * learning_rate;
981 (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
983 (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
986 (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
1005 auto learning_rate = lr_scalar;
1007 prox_v -= learning_rate * g;
1011 std::max(std::abs(prox_v) - learning_rate * l1_scalar,
1013 (1.0 + l2_scalar * learning_rate);
1015 var_flat(index) = prox_v / (1.0 + l2_scalar * learning_rate);
1467 // compute learning_rate for current step.
1468 auto learning_rate = a.constant(lr_scalar) * a.rsqrt();
1470 // v = w - g * learning_rate.
1471 prox_v -= g * learning_rate;
1475 (prox_v.abs() - learning_rate * prox_v.constant(l1_scalar))
1477 (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
1480 (v.constant(1.0) + v.constant(l2_scalar) * learning_rate);
1502 auto learning_rate = lr_scalar / std::sqrt(a);
1504 prox_v -= learning_rate * g;
1508 std::max(std::abs(prox_v) - learning_rate * l1_scalar,
1510 (1.0 + l2_scalar * learning_rate);
1512 var_flat(index) = prox_v / (1.0 + l2_scalar * learning_rate);