# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Ftrl-proximal for TensorFlow."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.framework import constant_op
     21 from tensorflow.python.framework import ops
     22 from tensorflow.python.ops import math_ops
     23 from tensorflow.python.training import optimizer
     24 from tensorflow.python.training import training_ops
     25 from tensorflow.python.util.tf_export import tf_export
     26 
     27 
     28 @tf_export("train.FtrlOptimizer")
     29 class FtrlOptimizer(optimizer.Optimizer):
     30   """Optimizer that implements the FTRL algorithm.
     31 
     32   See this [paper](
     33   https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
     34   This version has support for both online L2 (the L2 penalty given in the paper
     35   above) and shrinkage-type L2 (which is the addition of an L2 penalty to the
     36   loss function).
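
  A minimal usage sketch (assumes `import tensorflow as tf` and a scalar
  `loss` tensor from your model; the regularization strengths below are
  illustrative values only):

  ```python
  optimizer = tf.train.FtrlOptimizer(learning_rate=0.01,
                                     l1_regularization_strength=0.001,
                                     l2_regularization_strength=0.001)
  train_op = optimizer.minimize(loss)
  ```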
     37   """
     38 
     39   def __init__(self,
     40                learning_rate,
     41                learning_rate_power=-0.5,
     42                initial_accumulator_value=0.1,
     43                l1_regularization_strength=0.0,
     44                l2_regularization_strength=0.0,
     45                use_locking=False,
     46                name="Ftrl",
     47                accum_name=None,
     48                linear_name=None,
     49                l2_shrinkage_regularization_strength=0.0):
     50     r"""Construct a new FTRL optimizer.
     51 
     52     Args:
     53       learning_rate: A float value or a constant float `Tensor`.
     54       learning_rate_power: A float value, must be less or equal to zero.
     55       initial_accumulator_value: The starting value for accumulators.
     56         Only positive values are allowed.
     57       l1_regularization_strength: A float value, must be greater than or
     58         equal to zero.
     59       l2_regularization_strength: A float value, must be greater than or
     60         equal to zero.
     61       use_locking: If `True` use locks for update operations.
     62       name: Optional name prefix for the operations created when applying
     63         gradients.  Defaults to "Ftrl".
     64       accum_name: The suffix for the variable that keeps the gradient squared
     65         accumulator.  If not present, defaults to name.
     66       linear_name: The suffix for the variable that keeps the linear gradient
     67         accumulator.  If not present, defaults to name + "_1".
     68       l2_shrinkage_regularization_strength: A float value, must be greater than
     69         or equal to zero. This differs from L2 above in that the L2 above is a
     70         stabilization penalty, whereas this L2 shrinkage is a magnitude penalty.
     71         The FTRL formulation can be written as:
     72         w_{t+1} = argmin_w(\hat{g}_{1:t}w + L1*||w||_1 + L2*||w||_2^2), where
     73         \hat{g} = g + (2*L2_shrinkage*w), and g is the gradient of the loss
     74         function w.r.t. the weights w.
     75         Specifically, in the absence of L1 regularization, it is equivalent to
     76         the following update rule:
     77         w_{t+1} = w_t - lr_t / (1 + 2*L2*lr_t) * g_t -
     78                   2*L2_shrinkage*lr_t / (1 + 2*L2*lr_t) * w_t
     79         where lr_t is the learning rate at t.
     80         When input is sparse shrinkage will only happen on the active weights.
     81 
     82     Raises:
     83       ValueError: If one of the arguments is invalid.
     84     """
    super(FtrlOptimizer, self).__init__(use_locking, name)

    if initial_accumulator_value <= 0.0:
      raise ValueError("initial_accumulator_value %f needs to be positive" %
                       initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError("learning_rate_power %f needs to be negative or zero" %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          "l1_regularization_strength %f needs to be positive or zero" %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          "l2_regularization_strength %f needs to be positive or zero" %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          "l2_shrinkage_regularization_strength %f needs to be positive"
          " or zero" % l2_shrinkage_regularization_strength)

    self._learning_rate = learning_rate
    self._learning_rate_power = learning_rate_power
    self._initial_accumulator_value = initial_accumulator_value
    self._l1_regularization_strength = l1_regularization_strength
    self._l2_regularization_strength = l2_regularization_strength
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)
    self._learning_rate_tensor = None
    self._learning_rate_power_tensor = None
    self._l1_regularization_strength_tensor = None
    self._l2_regularization_strength_tensor = None
    self._l2_shrinkage_regularization_strength_tensor = None
    self._accum_name = accum_name
    self._linear_name = linear_name

  def _create_slots(self, var_list):
    # Create the "accum" and "linear" slots.
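    # "accum" holds the running sum of squared gradients (initialized to
    # initial_accumulator_value) and "linear" holds the FTRL linear term
    # accumulated across steps; both are colocated with their variable.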
    for v in var_list:
      with ops.colocate_with(v):
        val = constant_op.constant(
            self._initial_accumulator_value, dtype=v.dtype, shape=v.get_shape())
        self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)
        self._zeros_slot(v, "linear", self._linear_name or self._name)

  def _prepare(self):
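    # Convert the Python-number hyperparameters to tensors once here so the
    # apply ops below can consume them directly.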
    self._learning_rate_tensor = ops.convert_to_tensor(
        self._learning_rate, name="learning_rate")
    self._l1_regularization_strength_tensor = ops.convert_to_tensor(
        self._l1_regularization_strength, name="l1_regularization_strength")
    self._l2_regularization_strength_tensor = ops.convert_to_tensor(
        self._l2_regularization_strength, name="l2_regularization_strength")
    self._l2_shrinkage_regularization_strength_tensor = ops.convert_to_tensor(
        self._l2_shrinkage_regularization_strength,
        name="l2_shrinkage_regularization_strength")
    self._learning_rate_power_tensor = ops.convert_to_tensor(
        self._learning_rate_power, name="learning_rate_power")

  def _apply_dense(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
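    # Without shrinkage-type L2, dispatch to the plain FTRL kernel; otherwise
    # use the *_v2 kernel, which takes the extra shrinkage strength.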
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.apply_ftrl(
          var,
          accum,
          linear,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.apply_ftrl_v2(
          var,
          accum,
          linear,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)

  def _resource_apply_dense(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
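    # Resource variables are updated through their handles rather than by
    # passing the variables themselves.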
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.resource_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)

  def _apply_sparse(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
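    # `grad` is an `IndexedSlices`; the sparse kernels take its values and
    # indices separately, so shrinkage (when enabled) only touches the rows
    # that are active in this step.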
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.sparse_apply_ftrl(
          var,
          accum,
          linear,
          grad.values,
          grad.indices,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.sparse_apply_ftrl_v2(
          var,
          accum,
          linear,
          grad.values,
          grad.indices,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        grad.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
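    # The sparse resource kernels expect the hyperparameters cast to the
    # gradient dtype and update only the rows selected by `indices`.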
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_sparse_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._l2_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.resource_sparse_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._l2_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        grad.dtype),
          math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
          use_locking=self._use_locking)