# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Proximal stochastic dual coordinate ascent optimizer for linear models."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import collections
     21 
     22 from six.moves import range
     23 
     24 from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.framework.ops import internal_convert_to_tensor
     28 from tensorflow.python.framework.ops import name_scope
     29 from tensorflow.python.ops import array_ops
     30 from tensorflow.python.ops import control_flow_ops
     31 from tensorflow.python.ops import gen_sdca_ops
     32 from tensorflow.python.ops import math_ops
     33 from tensorflow.python.ops import nn_ops
     34 from tensorflow.python.ops import state_ops
     35 from tensorflow.python.ops import variables as var_ops
     36 from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits
     37 from tensorflow.python.summary import summary
     38 
     39 __all__ = ['SdcaModel']
     40 
     41 
# TODO(sibyl-Aix6ihai): add name_scope to appropriate methods.
class SdcaModel(object):
  """Stochastic dual coordinate ascent solver for linear models.

    This class currently only supports a single machine (multi-threaded)
    implementation. We expect the weights and duals to fit on a single machine.

    Loss functions supported:

     * Binary logistic loss
     * Squared loss
     * Hinge loss
     * Smooth hinge loss

    This class defines an optimizer API to train a linear model.

    ### Usage

    ```python
    # Create a solver with the desired parameters.
    lr = tf.contrib.linear_optimizer.SdcaModel(examples, variables, options)
    min_op = lr.minimize()
    opt_op = lr.update_weights(min_op)

    predictions = lr.predictions(examples)
    # Primal loss + L1 loss + L2 loss.
    regularized_loss = lr.regularized_loss(examples)
    # Primal loss only
    unregularized_loss = lr.unregularized_loss(examples)

    examples: {
      sparse_features: list of SparseFeatureColumn.
      dense_features: list of dense tensors of type float32.
      example_labels: a tensor of type float32 and shape [Num examples]
      example_weights: a tensor of type float32 and shape [Num examples]
      example_ids: a tensor of type string and shape [Num examples]
    }
    variables: {
      sparse_features_weights: list of tensors of shape [vocab size]
      dense_features_weights: list of tensors of shape [dense_feature_dimension]
    }
    options: {
      symmetric_l1_regularization: 0.0
      symmetric_l2_regularization: 1.0
      loss_type: "logistic_loss"
      num_loss_partitions: 1 (Optional, with default value of 1. Number of
      partitions of the global loss function, 1 means single machine solver,
      and >1 when we have more than one optimizer working concurrently.)
      num_table_shards: 1 (Optional, with default value of 1. Number of shards
      of the internal state table, typically set to match the number of
      parameter servers for large data sets.)
    }
    ```

    In the training program you will just have to run the returned Op from
    minimize() (in the example above it is wrapped by update_weights(), which
    additionally applies the proximal step).

    ```python
    # Execute opt_op and train for num_steps.
    for _ in range(num_steps):
      opt_op.run()

    # You can also check for convergence by calling
    lr.approximate_duality_gap()
    ```
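
    As a rough, illustrative sketch only (the tensor names, sizes and the
    SparseFeatureColumn argument order below are assumptions for the example,
    not guarantees of this module), the dictionaries above might be assembled
    as:

    ```python
    sparse_column = tf.contrib.linear_optimizer.SparseFeatureColumn(
        example_indices=indices, feature_indices=ids, feature_values=values)
    examples = {
        'sparse_features': [sparse_column],
        'dense_features': [dense_tensor],
        'example_labels': labels,
        'example_weights': weights,
        'example_ids': example_id_strings,
    }
    variables = {
        'sparse_features_weights': [tf.Variable(tf.zeros([vocab_size]))],
        'dense_features_weights': [tf.Variable(tf.zeros([dense_dimension]))],
    }
    options = {
        'loss_type': 'logistic_loss',
        'symmetric_l1_regularization': 0.0,
        'symmetric_l2_regularization': 1.0,
    }
    ```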
  """

  def __init__(self, examples, variables, options):
    """Create a new sdca optimizer."""

    if not examples or not variables or not options:
      raise ValueError('examples, variables and options must all be specified.')

    supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss',
                        'smooth_hinge_loss')
    if options['loss_type'] not in supported_losses:
      raise ValueError('Unsupported loss_type: %s' % options['loss_type'])

    self._assertSpecified([
        'example_labels', 'example_weights', 'example_ids', 'sparse_features',
        'dense_features'
    ], examples)
    self._assertList(['sparse_features', 'dense_features'], examples)

    self._assertSpecified(['sparse_features_weights', 'dense_features_weights'],
                          variables)
    self._assertList(['sparse_features_weights', 'dense_features_weights'],
                     variables)

    self._assertSpecified([
        'loss_type', 'symmetric_l2_regularization',
        'symmetric_l1_regularization'
    ], options)

    for name in ['symmetric_l1_regularization', 'symmetric_l2_regularization']:
      value = options[name]
      if value < 0.0:
        raise ValueError('%s should be non-negative. Found (%f)' %
                         (name, value))

    self._examples = examples
    self._variables = variables
    self._options = options
    self._create_slots()
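    # Per-example solver state lives in a sharded hash table keyed by a
    # fingerprint of the example id. Each entry holds four float32 values;
    # judging by the indexing in approximate_duality_gap(), slots 1-3 hold the
    # per-example primal loss, dual loss and example weight (slot 0 presumably
    # holds the dual variable).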
    self._hashtable = ShardedMutableDenseHashTable(
        key_dtype=dtypes.int64,
        value_dtype=dtypes.float32,
        num_shards=self._num_table_shards(),
        default_value=[0.0, 0.0, 0.0, 0.0],
        # SdcaFprint never returns 0 or 1 for the low64 bits, so this is a safe
        # empty_key (that will never collide with actual payloads).
        empty_key=[0, 0])

    summary.scalar('approximate_duality_gap', self.approximate_duality_gap())
    summary.scalar('examples_seen', self._hashtable.size())

  def _symmetric_l1_regularization(self):
    return self._options['symmetric_l1_regularization']

  def _symmetric_l2_regularization(self):
    # Algorithmic requirement (for now) is to have minimal l2 of 1.0.
    return max(self._options['symmetric_l2_regularization'], 1.0)

  def _num_loss_partitions(self):
    # Number of partitions of the global objective.
    # TODO(andreasst): set num_loss_partitions automatically based on the number
    # of workers
    return self._options.get('num_loss_partitions', 1)

  def _num_table_shards(self):
    # Number of hash table shards.
    # Return 1 if not specified or if the value is None.
    # TODO(andreasst): set num_table_shards automatically based on the number
    # of parameter servers
    num_shards = self._options.get('num_table_shards')
    return 1 if num_shards is None else num_shards

  # TODO(sibyl-Aix6ihai): Use optimizer interface to make use of slot creation logic.
  def _create_slots(self):
    # Make internal variables which hold the updated weights before the L1
    # regularization (proximal shrink step) is applied.
    self._slots = collections.defaultdict(list)
    for name in ['sparse_features_weights', 'dense_features_weights']:
      for var in self._variables[name]:
        with ops.device(var.device):
          # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is
          # fixed
          self._slots['unshrinked_' + name].append(
              var_ops.Variable(
                  array_ops.zeros_like(var.initialized_value(), dtypes.float32),
                  name=var.op.name + '_unshrinked/SDCAOptimizer'))

  def _assertSpecified(self, items, check_in):
    for x in items:
      if check_in[x] is None:
        raise ValueError(x + ' must be specified.')

  def _assertList(self, items, check_in):
    for x in items:
      if not isinstance(check_in[x], list):
        raise ValueError(x + ' must be a list.')

  def _l1_loss(self):
    """Computes the (un-normalized) l1 loss of the model."""
    with name_scope('sdca/l1_loss'):
      sums = []
      for name in ['sparse_features_weights', 'dense_features_weights']:
        for weights in self._convert_n_to_tensor(self._variables[name]):
          with ops.device(weights.device):
            sums.append(
                math_ops.reduce_sum(
                    math_ops.abs(math_ops.cast(weights, dtypes.float64))))
      # SDCA L1 regularization cost is: l1 * sum(|weights|)
      return self._options['symmetric_l1_regularization'] * math_ops.add_n(sums)

  def _l2_loss(self, l2):
    """Computes the (un-normalized) l2 loss of the model."""
    with name_scope('sdca/l2_loss'):
      sums = []
      for name in ['sparse_features_weights', 'dense_features_weights']:
        for weights in self._convert_n_to_tensor(self._variables[name]):
          with ops.device(weights.device):
            sums.append(
                math_ops.reduce_sum(
                    math_ops.square(math_ops.cast(weights, dtypes.float64))))
      # SDCA L2 regularization cost is: l2 * sum(weights^2) / 2
      return l2 * math_ops.add_n(sums) / 2.0

  def _convert_n_to_tensor(self, input_list, as_ref=False):
    """Converts the input list to a list of tensors."""
    return [internal_convert_to_tensor(x, as_ref=as_ref) for x in input_list]

  def _linear_predictions(self, examples):
    """Returns predictions of the form w*x."""
    with name_scope('sdca/prediction'):
      sparse_variables = self._convert_n_to_tensor(self._variables[
          'sparse_features_weights'])
      result_sparse = 0.0
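      # For each sparse column, gather the weights of the active feature ids,
      # scale them by the feature values, and sum the products per example
      # (segment_sum groups the contributions by example_indices).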
      for sfc, sv in zip(examples['sparse_features'], sparse_variables):
        # TODO(sibyl-Aix6ihai): following does not take care of missing features.
        result_sparse += math_ops.segment_sum(
            math_ops.multiply(
                array_ops.gather(sv, sfc.feature_indices), sfc.feature_values),
            sfc.example_indices)
      dense_features = self._convert_n_to_tensor(examples['dense_features'])
      dense_variables = self._convert_n_to_tensor(self._variables[
          'dense_features_weights'])

      result_dense = 0.0
      for i in range(len(dense_variables)):
        result_dense += math_ops.matmul(dense_features[i],
                                        array_ops.expand_dims(
                                            dense_variables[i], -1))

    # Reshaping to allow shape inference at graph construction time.
    return array_ops.reshape(result_dense, [-1]) + result_sparse

  def predictions(self, examples):
    """Add operations to compute predictions by the model.

    If logistic_loss is being used, predicted probabilities are returned.
    Otherwise, (raw) linear predictions (w*x) are returned.

    Args:
      examples: Examples to compute predictions on.

    Returns:
      An Operation that computes the predictions for examples.

    Raises:
      ValueError: if examples are not well defined.
    """
    self._assertSpecified(
        ['example_weights', 'sparse_features', 'dense_features'], examples)
    self._assertList(['sparse_features', 'dense_features'], examples)

    result = self._linear_predictions(examples)
    if self._options['loss_type'] == 'logistic_loss':
      # Convert logits to probability for logistic loss predictions.
      with name_scope('sdca/logistic_prediction'):
        result = math_ops.sigmoid(result)
    return result

  def minimize(self, global_step=None, name=None):
    """Add operations to train a linear model by minimizing the loss function.

    Args:
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.

    Returns:
      An Operation that updates the variables passed in the constructor.
    """
    # Technically, the op depends on a lot more than the variables,
    # but we'll keep the list short.
    with name_scope(name, 'sdca/minimize'):
      sparse_example_indices = []
      sparse_feature_indices = []
      sparse_features_values = []
      for sf in self._examples['sparse_features']:
        sparse_example_indices.append(sf.example_indices)
        sparse_feature_indices.append(sf.feature_indices)
        # If feature values are missing, sdca assumes a value of 1.0f.
        if sf.feature_values is not None:
          sparse_features_values.append(sf.feature_values)

      # pylint: disable=protected-access
      example_ids_hashed = gen_sdca_ops.sdca_fprint(
          internal_convert_to_tensor(self._examples['example_ids']))
      # pylint: enable=protected-access
      example_state_data = self._hashtable.lookup(example_ids_hashed)
      # Solver returns example_state_update, new delta sparse_feature_weights
      # and delta dense_feature_weights.

      weights_tensor = self._convert_n_to_tensor(self._slots[
          'unshrinked_sparse_features_weights'])
      sparse_weights = []
      sparse_indices = []
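      # For each sparse column, restrict the lookup to the unique feature ids
      # that actually occur in this batch, so the solver only reads (and later
      # scatter-adds into) the touched rows of the weight variable.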
      for w, i in zip(weights_tensor, sparse_feature_indices):
        # Find the feature ids to lookup in the variables.
        with ops.device(w.device):
          sparse_indices.append(
              math_ops.cast(
                  array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
                  dtypes.int64))
          sparse_weights.append(array_ops.gather(w, sparse_indices[-1]))

      # pylint: disable=protected-access
      esu, sfw, dfw = gen_sdca_ops.sdca_optimizer(
          sparse_example_indices,
          sparse_feature_indices,
          sparse_features_values,
          self._convert_n_to_tensor(self._examples['dense_features']),
          internal_convert_to_tensor(self._examples['example_weights']),
          internal_convert_to_tensor(self._examples['example_labels']),
          sparse_indices,
          sparse_weights,
          self._convert_n_to_tensor(self._slots[
              'unshrinked_dense_features_weights']),
          example_state_data,
          loss_type=self._options['loss_type'],
          l1=self._options['symmetric_l1_regularization'],
          l2=self._symmetric_l2_regularization(),
          num_loss_partitions=self._num_loss_partitions(),
          num_inner_iterations=1)
      # pylint: enable=protected-access

      with ops.control_dependencies([esu]):
        update_ops = [self._hashtable.insert(example_ids_hashed, esu)]
        # Update the weights before the proximal step.
        for w, i, u in zip(self._slots['unshrinked_sparse_features_weights'],
                           sparse_indices, sfw):
          update_ops.append(state_ops.scatter_add(w, i, u))
        for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw):
          update_ops.append(w.assign_add(u))

      if not global_step:
        return control_flow_ops.group(*update_ops)
      with ops.control_dependencies(update_ops):
        return state_ops.assign_add(global_step, 1, name=name).op

  def update_weights(self, train_op):
    """Updates the model weights.

    This function must be called on at least one worker after `minimize`.
    In distributed training this call can be omitted on non-chief workers to
    speed up training.

    Args:
      train_op: The operation returned by the `minimize` call.

    Returns:
      An Operation that updates the model weights.
    """
    with ops.control_dependencies([train_op]):
      update_ops = []
      # Copy over unshrinked weights to user provided variables.
      for name in ['sparse_features_weights', 'dense_features_weights']:
        for var, slot_var in zip(self._variables[name],
                                 self._slots['unshrinked_' + name]):
          update_ops.append(var.assign(slot_var))

    # Apply proximal step.
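    # sdca_shrink_l1 applies the L1 proximal step (a soft-thresholding style
    # shrink when l1 > 0) to the user-provided variables in place; the unshrunk
    # copies kept in the slots are left untouched for the next minimize() call.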
    with ops.control_dependencies(update_ops):
      update_ops = []
      for name in ['sparse_features_weights', 'dense_features_weights']:
        for var in self._variables[name]:
          with ops.device(var.device):
            # pylint: disable=protected-access
            update_ops.append(
                gen_sdca_ops.sdca_shrink_l1(
                    self._convert_n_to_tensor(
                        [var], as_ref=True),
                    l1=self._symmetric_l1_regularization(),
                    l2=self._symmetric_l2_regularization()))
      return control_flow_ops.group(*update_ops)

  def approximate_duality_gap(self):
    """Add operations to compute the approximate duality gap.

    Returns:
      An Operation that computes the approximate duality gap over all
      examples.
    """
    with name_scope('sdca/approximate_duality_gap'):
      _, values_list = self._hashtable.export_sharded()
      shard_sums = []
      for values in values_list:
        with ops.device(values.device):
          # For large tables to_double() below allocates a large temporary
          # tensor that is freed once the sum operation completes. To reduce
          # peak memory usage in cases where we have multiple large tables on a
          # single device, we serialize these operations.
          # Note that we need double precision to get accurate results.
          with ops.control_dependencies(shard_sums):
            shard_sums.append(
                math_ops.reduce_sum(math_ops.to_double(values), 0))
      summed_values = math_ops.add_n(shard_sums)

      primal_loss = summed_values[1]
      dual_loss = summed_values[2]
      example_weights = summed_values[3]
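      # The gap estimate below is (primal_loss + dual_loss + L1 + 2 * L2),
      # normalized by the total example weight accumulated in the state table.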
      # Note: we return NaN if there are no weights or all weights are 0, e.g.
      # if no examples have been processed.
      return (primal_loss + dual_loss + self._l1_loss() +
              (2.0 * self._l2_loss(self._symmetric_l2_regularization()))
             ) / example_weights

  def unregularized_loss(self, examples):
    """Add operations to compute the loss (without the regularization loss).

    Args:
      examples: Examples to compute unregularized loss on.

    Returns:
      An Operation that computes mean (unregularized) loss for the given set of
      examples.

    Raises:
      ValueError: if examples are not well defined.
    """
    self._assertSpecified([
        'example_labels', 'example_weights', 'sparse_features', 'dense_features'
    ], examples)
    self._assertList(['sparse_features', 'dense_features'], examples)
    with name_scope('sdca/unregularized_loss'):
      predictions = math_ops.cast(
          self._linear_predictions(examples), dtypes.float64)
      labels = math_ops.cast(
          internal_convert_to_tensor(examples['example_labels']),
          dtypes.float64)
      weights = math_ops.cast(
          internal_convert_to_tensor(examples['example_weights']),
          dtypes.float64)

      if self._options['loss_type'] == 'logistic_loss':
        return math_ops.reduce_sum(math_ops.multiply(
            sigmoid_cross_entropy_with_logits(labels=labels,
                                              logits=predictions),
            weights)) / math_ops.reduce_sum(weights)

      if self._options['loss_type'] in ['hinge_loss', 'smooth_hinge_loss']:
        # hinge_loss = max{0, 1 - y_i w*x} where y_i \in {-1, 1}. So, we need to
        # first convert 0/1 labels into -1/1 labels.
        all_ones = array_ops.ones_like(predictions)
        adjusted_labels = math_ops.subtract(2 * labels, all_ones)
        # Tensor that contains (unweighted) error (hinge loss) per
        # example.
        error = nn_ops.relu(
            math_ops.subtract(all_ones,
                              math_ops.multiply(adjusted_labels, predictions)))
        weighted_error = math_ops.multiply(error, weights)
        return math_ops.reduce_sum(weighted_error) / math_ops.reduce_sum(
            weights)

      # squared loss
      err = math_ops.subtract(labels, predictions)

      weighted_squared_err = math_ops.multiply(math_ops.square(err), weights)
      # SDCA squared loss function is sum(err^2) / (2*sum(weights))
      return (math_ops.reduce_sum(weighted_squared_err) /
              (2.0 * math_ops.reduce_sum(weights)))

  def regularized_loss(self, examples):
    """Add operations to compute the loss with regularization loss included.

    Args:
      examples: Examples to compute loss on.

    Returns:
      An Operation that computes mean (regularized) loss for the given set of
      examples.

    Raises:
      ValueError: if examples are not well defined.
    """
    self._assertSpecified([
        'example_labels', 'example_weights', 'sparse_features', 'dense_features'
    ], examples)
    self._assertList(['sparse_features', 'dense_features'], examples)
    with name_scope('sdca/regularized_loss'):
      weights = internal_convert_to_tensor(examples['example_weights'])
      return ((
          self._l1_loss() +
          # Note that here we are using the raw regularization
          # (as specified by the user) and *not*
          # self._symmetric_l2_regularization().
          self._l2_loss(self._options['symmetric_l2_regularization'])) /
              math_ops.reduce_sum(math_ops.cast(weights, dtypes.float64)) +
              self.unregularized_loss(examples))