1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Proximal stochastic dual coordinate ascent optimizer for linear models.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import collections 21 22 from six.moves import range 23 24 from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import ops 27 from tensorflow.python.framework.ops import internal_convert_to_tensor 28 from tensorflow.python.framework.ops import name_scope 29 from tensorflow.python.ops import array_ops 30 from tensorflow.python.ops import control_flow_ops 31 from tensorflow.python.ops import gen_sdca_ops 32 from tensorflow.python.ops import math_ops 33 from tensorflow.python.ops import nn_ops 34 from tensorflow.python.ops import state_ops 35 from tensorflow.python.ops import variables as var_ops 36 from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits 37 from tensorflow.python.summary import summary 38 39 __all__ = ['SdcaModel'] 40 41 42 # TODO(sibyl-Aix6ihai): add name_scope to appropriate methods. 43 class SdcaModel(object): 44 """Stochastic dual coordinate ascent solver for linear models. 45 46 This class currently only supports a single machine (multi-threaded) 47 implementation. We expect the weights and duals to fit in a single machine. 48 49 Loss functions supported: 50 51 * Binary logistic loss 52 * Squared loss 53 * Hinge loss 54 * Smooth hinge loss 55 56 This class defines an optimizer API to train a linear model. 57 58 ### Usage 59 60 ```python 61 # Create a solver with the desired parameters. 62 lr = tf.contrib.linear_optimizer.SdcaModel(examples, variables, options) 63 min_op = lr.minimize() 64 opt_op = lr.update_weights(min_op) 65 66 predictions = lr.predictions(examples) 67 # Primal loss + L1 loss + L2 loss. 68 regularized_loss = lr.regularized_loss(examples) 69 # Primal loss only 70 unregularized_loss = lr.unregularized_loss(examples) 71 72 examples: { 73 sparse_features: list of SparseFeatureColumn. 74 dense_features: list of dense tensors of type float32. 75 example_labels: a tensor of type float32 and shape [Num examples] 76 example_weights: a tensor of type float32 and shape [Num examples] 77 example_ids: a tensor of type string and shape [Num examples] 78 } 79 variables: { 80 sparse_features_weights: list of tensors of shape [vocab size] 81 dense_features_weights: list of tensors of shape [dense_feature_dimension] 82 } 83 options: { 84 symmetric_l1_regularization: 0.0 85 symmetric_l2_regularization: 1.0 86 loss_type: "logistic_loss" 87 num_loss_partitions: 1 (Optional, with default value of 1. Number of 88 partitions of the global loss function, 1 means single machine solver, 89 and >1 when we have more than one optimizer working concurrently.) 90 num_table_shards: 1 (Optional, with default value of 1. Number of shards 91 of the internal state table, typically set to match the number of 92 parameter servers for large data sets. 93 } 94 ``` 95 96 In the training program you will just have to run the returned Op from 97 minimize(). 98 99 ```python 100 # Execute opt_op and train for num_steps. 101 for _ in range(num_steps): 102 opt_op.run() 103 104 # You can also check for convergence by calling 105 lr.approximate_duality_gap() 106 ``` 107 """ 108 109 def __init__(self, examples, variables, options): 110 """Create a new sdca optimizer.""" 111 112 if not examples or not variables or not options: 113 raise ValueError('examples, variables and options must all be specified.') 114 115 supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss', 116 'smooth_hinge_loss') 117 if options['loss_type'] not in supported_losses: 118 raise ValueError('Unsupported loss_type: ', options['loss_type']) 119 120 self._assertSpecified([ 121 'example_labels', 'example_weights', 'example_ids', 'sparse_features', 122 'dense_features' 123 ], examples) 124 self._assertList(['sparse_features', 'dense_features'], examples) 125 126 self._assertSpecified(['sparse_features_weights', 'dense_features_weights'], 127 variables) 128 self._assertList(['sparse_features_weights', 'dense_features_weights'], 129 variables) 130 131 self._assertSpecified([ 132 'loss_type', 'symmetric_l2_regularization', 133 'symmetric_l1_regularization' 134 ], options) 135 136 for name in ['symmetric_l1_regularization', 'symmetric_l2_regularization']: 137 value = options[name] 138 if value < 0.0: 139 raise ValueError('%s should be non-negative. Found (%f)' % 140 (name, value)) 141 142 self._examples = examples 143 self._variables = variables 144 self._options = options 145 self._create_slots() 146 self._hashtable = ShardedMutableDenseHashTable( 147 key_dtype=dtypes.int64, 148 value_dtype=dtypes.float32, 149 num_shards=self._num_table_shards(), 150 default_value=[0.0, 0.0, 0.0, 0.0], 151 # SdcaFprint never returns 0 or 1 for the low64 bits, so this a safe 152 # empty_key (that will never collide with actual payloads). 153 empty_key=[0, 0]) 154 155 summary.scalar('approximate_duality_gap', self.approximate_duality_gap()) 156 summary.scalar('examples_seen', self._hashtable.size()) 157 158 def _symmetric_l1_regularization(self): 159 return self._options['symmetric_l1_regularization'] 160 161 def _symmetric_l2_regularization(self): 162 # Algorithmic requirement (for now) is to have minimal l2 of 1.0. 163 return max(self._options['symmetric_l2_regularization'], 1.0) 164 165 def _num_loss_partitions(self): 166 # Number of partitions of the global objective. 167 # TODO(andreasst): set num_loss_partitions automatically based on the number 168 # of workers 169 return self._options.get('num_loss_partitions', 1) 170 171 def _num_table_shards(self): 172 # Number of hash table shards. 173 # Return 1 if not specified or if the value is 'None' 174 # TODO(andreasst): set num_table_shards automatically based on the number 175 # of parameter servers 176 num_shards = self._options.get('num_table_shards') 177 return 1 if num_shards is None else num_shards 178 179 # TODO(sibyl-Aix6ihai): Use optimizer interface to make use of slot creation logic. 180 def _create_slots(self): 181 # Make internal variables which have the updates before applying L1 182 # regularization. 183 self._slots = collections.defaultdict(list) 184 for name in ['sparse_features_weights', 'dense_features_weights']: 185 for var in self._variables[name]: 186 with ops.device(var.device): 187 # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is 188 # fixed 189 self._slots['unshrinked_' + name].append( 190 var_ops.Variable( 191 array_ops.zeros_like(var.initialized_value(), dtypes.float32), 192 name=var.op.name + '_unshrinked/SDCAOptimizer')) 193 194 def _assertSpecified(self, items, check_in): 195 for x in items: 196 if check_in[x] is None: 197 raise ValueError(check_in[x] + ' must be specified.') 198 199 def _assertList(self, items, check_in): 200 for x in items: 201 if not isinstance(check_in[x], list): 202 raise ValueError(x + ' must be a list.') 203 204 def _l1_loss(self): 205 """Computes the (un-normalized) l1 loss of the model.""" 206 with name_scope('sdca/l1_loss'): 207 sums = [] 208 for name in ['sparse_features_weights', 'dense_features_weights']: 209 for weights in self._convert_n_to_tensor(self._variables[name]): 210 with ops.device(weights.device): 211 sums.append( 212 math_ops.reduce_sum( 213 math_ops.abs(math_ops.cast(weights, dtypes.float64)))) 214 # SDCA L1 regularization cost is: l1 * sum(|weights|) 215 return self._options['symmetric_l1_regularization'] * math_ops.add_n(sums) 216 217 def _l2_loss(self, l2): 218 """Computes the (un-normalized) l2 loss of the model.""" 219 with name_scope('sdca/l2_loss'): 220 sums = [] 221 for name in ['sparse_features_weights', 'dense_features_weights']: 222 for weights in self._convert_n_to_tensor(self._variables[name]): 223 with ops.device(weights.device): 224 sums.append( 225 math_ops.reduce_sum( 226 math_ops.square(math_ops.cast(weights, dtypes.float64)))) 227 # SDCA L2 regularization cost is: l2 * sum(weights^2) / 2 228 return l2 * math_ops.add_n(sums) / 2.0 229 230 def _convert_n_to_tensor(self, input_list, as_ref=False): 231 """Converts input list to a set of tensors.""" 232 return [internal_convert_to_tensor(x, as_ref=as_ref) for x in input_list] 233 234 def _linear_predictions(self, examples): 235 """Returns predictions of the form w*x.""" 236 with name_scope('sdca/prediction'): 237 sparse_variables = self._convert_n_to_tensor(self._variables[ 238 'sparse_features_weights']) 239 result_sparse = 0.0 240 for sfc, sv in zip(examples['sparse_features'], sparse_variables): 241 # TODO(sibyl-Aix6ihai): following does not take care of missing features. 242 result_sparse += math_ops.segment_sum( 243 math_ops.multiply( 244 array_ops.gather(sv, sfc.feature_indices), sfc.feature_values), 245 sfc.example_indices) 246 dense_features = self._convert_n_to_tensor(examples['dense_features']) 247 dense_variables = self._convert_n_to_tensor(self._variables[ 248 'dense_features_weights']) 249 250 result_dense = 0.0 251 for i in range(len(dense_variables)): 252 result_dense += math_ops.matmul(dense_features[i], 253 array_ops.expand_dims( 254 dense_variables[i], -1)) 255 256 # Reshaping to allow shape inference at graph construction time. 257 return array_ops.reshape(result_dense, [-1]) + result_sparse 258 259 def predictions(self, examples): 260 """Add operations to compute predictions by the model. 261 262 If logistic_loss is being used, predicted probabilities are returned. 263 Otherwise, (raw) linear predictions (w*x) are returned. 264 265 Args: 266 examples: Examples to compute predictions on. 267 268 Returns: 269 An Operation that computes the predictions for examples. 270 271 Raises: 272 ValueError: if examples are not well defined. 273 """ 274 self._assertSpecified( 275 ['example_weights', 'sparse_features', 'dense_features'], examples) 276 self._assertList(['sparse_features', 'dense_features'], examples) 277 278 result = self._linear_predictions(examples) 279 if self._options['loss_type'] == 'logistic_loss': 280 # Convert logits to probability for logistic loss predictions. 281 with name_scope('sdca/logistic_prediction'): 282 result = math_ops.sigmoid(result) 283 return result 284 285 def minimize(self, global_step=None, name=None): 286 """Add operations to train a linear model by minimizing the loss function. 287 288 Args: 289 global_step: Optional `Variable` to increment by one after the 290 variables have been updated. 291 name: Optional name for the returned operation. 292 293 Returns: 294 An Operation that updates the variables passed in the constructor. 295 """ 296 # Technically, the op depends on a lot more than the variables, 297 # but we'll keep the list short. 298 with name_scope(name, 'sdca/minimize'): 299 sparse_example_indices = [] 300 sparse_feature_indices = [] 301 sparse_features_values = [] 302 for sf in self._examples['sparse_features']: 303 sparse_example_indices.append(sf.example_indices) 304 sparse_feature_indices.append(sf.feature_indices) 305 # If feature values are missing, sdca assumes a value of 1.0f. 306 if sf.feature_values is not None: 307 sparse_features_values.append(sf.feature_values) 308 309 # pylint: disable=protected-access 310 example_ids_hashed = gen_sdca_ops.sdca_fprint( 311 internal_convert_to_tensor(self._examples['example_ids'])) 312 # pylint: enable=protected-access 313 example_state_data = self._hashtable.lookup(example_ids_hashed) 314 # Solver returns example_state_update, new delta sparse_feature_weights 315 # and delta dense_feature_weights. 316 317 weights_tensor = self._convert_n_to_tensor(self._slots[ 318 'unshrinked_sparse_features_weights']) 319 sparse_weights = [] 320 sparse_indices = [] 321 for w, i in zip(weights_tensor, sparse_feature_indices): 322 # Find the feature ids to lookup in the variables. 323 with ops.device(w.device): 324 sparse_indices.append( 325 math_ops.cast( 326 array_ops.unique(math_ops.cast(i, dtypes.int32))[0], 327 dtypes.int64)) 328 sparse_weights.append(array_ops.gather(w, sparse_indices[-1])) 329 330 # pylint: disable=protected-access 331 esu, sfw, dfw = gen_sdca_ops.sdca_optimizer( 332 sparse_example_indices, 333 sparse_feature_indices, 334 sparse_features_values, 335 self._convert_n_to_tensor(self._examples['dense_features']), 336 internal_convert_to_tensor(self._examples['example_weights']), 337 internal_convert_to_tensor(self._examples['example_labels']), 338 sparse_indices, 339 sparse_weights, 340 self._convert_n_to_tensor(self._slots[ 341 'unshrinked_dense_features_weights']), 342 example_state_data, 343 loss_type=self._options['loss_type'], 344 l1=self._options['symmetric_l1_regularization'], 345 l2=self._symmetric_l2_regularization(), 346 num_loss_partitions=self._num_loss_partitions(), 347 num_inner_iterations=1) 348 # pylint: enable=protected-access 349 350 with ops.control_dependencies([esu]): 351 update_ops = [self._hashtable.insert(example_ids_hashed, esu)] 352 # Update the weights before the proximal step. 353 for w, i, u in zip(self._slots['unshrinked_sparse_features_weights'], 354 sparse_indices, sfw): 355 update_ops.append(state_ops.scatter_add(w, i, u)) 356 for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw): 357 update_ops.append(w.assign_add(u)) 358 359 if not global_step: 360 return control_flow_ops.group(*update_ops) 361 with ops.control_dependencies(update_ops): 362 return state_ops.assign_add(global_step, 1, name=name).op 363 364 def update_weights(self, train_op): 365 """Updates the model weights. 366 367 This function must be called on at least one worker after `minimize`. 368 In distributed training this call can be omitted on non-chief workers to 369 speed up training. 370 371 Args: 372 train_op: The operation returned by the `minimize` call. 373 374 Returns: 375 An Operation that updates the model weights. 376 """ 377 with ops.control_dependencies([train_op]): 378 update_ops = [] 379 # Copy over unshrinked weights to user provided variables. 380 for name in ['sparse_features_weights', 'dense_features_weights']: 381 for var, slot_var in zip(self._variables[name], 382 self._slots['unshrinked_' + name]): 383 update_ops.append(var.assign(slot_var)) 384 385 # Apply proximal step. 386 with ops.control_dependencies(update_ops): 387 update_ops = [] 388 for name in ['sparse_features_weights', 'dense_features_weights']: 389 for var in self._variables[name]: 390 with ops.device(var.device): 391 # pylint: disable=protected-access 392 update_ops.append( 393 gen_sdca_ops.sdca_shrink_l1( 394 self._convert_n_to_tensor( 395 [var], as_ref=True), 396 l1=self._symmetric_l1_regularization(), 397 l2=self._symmetric_l2_regularization())) 398 return control_flow_ops.group(*update_ops) 399 400 def approximate_duality_gap(self): 401 """Add operations to compute the approximate duality gap. 402 403 Returns: 404 An Operation that computes the approximate duality gap over all 405 examples. 406 """ 407 with name_scope('sdca/approximate_duality_gap'): 408 _, values_list = self._hashtable.export_sharded() 409 shard_sums = [] 410 for values in values_list: 411 with ops.device(values.device): 412 # For large tables to_double() below allocates a large temporary 413 # tensor that is freed once the sum operation completes. To reduce 414 # peak memory usage in cases where we have multiple large tables on a 415 # single device, we serialize these operations. 416 # Note that we need double precision to get accurate results. 417 with ops.control_dependencies(shard_sums): 418 shard_sums.append( 419 math_ops.reduce_sum(math_ops.to_double(values), 0)) 420 summed_values = math_ops.add_n(shard_sums) 421 422 primal_loss = summed_values[1] 423 dual_loss = summed_values[2] 424 example_weights = summed_values[3] 425 # Note: we return NaN if there are no weights or all weights are 0, e.g. 426 # if no examples have been processed 427 return (primal_loss + dual_loss + self._l1_loss() + 428 (2.0 * self._l2_loss(self._symmetric_l2_regularization())) 429 ) / example_weights 430 431 def unregularized_loss(self, examples): 432 """Add operations to compute the loss (without the regularization loss). 433 434 Args: 435 examples: Examples to compute unregularized loss on. 436 437 Returns: 438 An Operation that computes mean (unregularized) loss for given set of 439 examples. 440 441 Raises: 442 ValueError: if examples are not well defined. 443 """ 444 self._assertSpecified([ 445 'example_labels', 'example_weights', 'sparse_features', 'dense_features' 446 ], examples) 447 self._assertList(['sparse_features', 'dense_features'], examples) 448 with name_scope('sdca/unregularized_loss'): 449 predictions = math_ops.cast( 450 self._linear_predictions(examples), dtypes.float64) 451 labels = math_ops.cast( 452 internal_convert_to_tensor(examples['example_labels']), 453 dtypes.float64) 454 weights = math_ops.cast( 455 internal_convert_to_tensor(examples['example_weights']), 456 dtypes.float64) 457 458 if self._options['loss_type'] == 'logistic_loss': 459 return math_ops.reduce_sum(math_ops.multiply( 460 sigmoid_cross_entropy_with_logits(labels=labels, 461 logits=predictions), 462 weights)) / math_ops.reduce_sum(weights) 463 464 if self._options['loss_type'] in ['hinge_loss', 'smooth_hinge_loss']: 465 # hinge_loss = max{0, 1 - y_i w*x} where y_i \in {-1, 1}. So, we need to 466 # first convert 0/1 labels into -1/1 labels. 467 all_ones = array_ops.ones_like(predictions) 468 adjusted_labels = math_ops.subtract(2 * labels, all_ones) 469 # Tensor that contains (unweighted) error (hinge loss) per 470 # example. 471 error = nn_ops.relu( 472 math_ops.subtract(all_ones, 473 math_ops.multiply(adjusted_labels, predictions))) 474 weighted_error = math_ops.multiply(error, weights) 475 return math_ops.reduce_sum(weighted_error) / math_ops.reduce_sum( 476 weights) 477 478 # squared loss 479 err = math_ops.subtract(labels, predictions) 480 481 weighted_squared_err = math_ops.multiply(math_ops.square(err), weights) 482 # SDCA squared loss function is sum(err^2) / (2*sum(weights)) 483 return (math_ops.reduce_sum(weighted_squared_err) / 484 (2.0 * math_ops.reduce_sum(weights))) 485 486 def regularized_loss(self, examples): 487 """Add operations to compute the loss with regularization loss included. 488 489 Args: 490 examples: Examples to compute loss on. 491 492 Returns: 493 An Operation that computes mean (regularized) loss for given set of 494 examples. 495 Raises: 496 ValueError: if examples are not well defined. 497 """ 498 self._assertSpecified([ 499 'example_labels', 'example_weights', 'sparse_features', 'dense_features' 500 ], examples) 501 self._assertList(['sparse_features', 'dense_features'], examples) 502 with name_scope('sdca/regularized_loss'): 503 weights = internal_convert_to_tensor(examples['example_weights']) 504 return (( 505 self._l1_loss() + 506 # Note that here we are using the raw regularization 507 # (as specified by the user) and *not* 508 # self._symmetric_l2_regularization(). 509 self._l2_loss(self._options['symmetric_l2_regularization'])) / 510 math_ops.reduce_sum(math_ops.cast(weights, dtypes.float64)) + 511 self.unregularized_loss(examples)) 512