1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Deep Neural Network estimators.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import six 22 23 from tensorflow.python.estimator import estimator 24 from tensorflow.python.estimator import model_fn 25 from tensorflow.python.estimator.canned import head as head_lib 26 from tensorflow.python.estimator.canned import optimizers 27 from tensorflow.python.feature_column import feature_column as feature_column_lib 28 from tensorflow.python.layers import core as core_layers 29 from tensorflow.python.ops import init_ops 30 from tensorflow.python.ops import nn 31 from tensorflow.python.ops import partitioned_variables 32 from tensorflow.python.ops import variable_scope 33 from tensorflow.python.ops.losses import losses 34 from tensorflow.python.summary import summary 35 from tensorflow.python.training import training_util 36 from tensorflow.python.util.tf_export import tf_export 37 38 # The default learning rate of 0.05 is a historical artifact of the initial 39 # implementation, but seems a reasonable choice. 40 _LEARNING_RATE = 0.05 41 42 43 def _add_hidden_layer_summary(value, tag): 44 summary.scalar('%s/fraction_of_zero_values' % tag, nn.zero_fraction(value)) 45 summary.histogram('%s/activation' % tag, value) 46 47 48 def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn, 49 dropout, input_layer_partitioner): 50 """Function builder for a dnn logit_fn. 51 52 Args: 53 units: An int indicating the dimension of the logit layer. In the 54 MultiHead case, this should be the sum of all component Heads' logit 55 dimensions. 56 hidden_units: Iterable of integer number of hidden units per layer. 57 feature_columns: Iterable of `feature_column._FeatureColumn` model inputs. 58 activation_fn: Activation function applied to each layer. 59 dropout: When not `None`, the probability we will drop out a given 60 coordinate. 61 input_layer_partitioner: Partitioner for input layer. 62 63 Returns: 64 A logit_fn (see below). 65 66 Raises: 67 ValueError: If units is not an int. 68 """ 69 if not isinstance(units, int): 70 raise ValueError('units must be an int. Given type: {}'.format( 71 type(units))) 72 73 def dnn_logit_fn(features, mode): 74 """Deep Neural Network logit_fn. 75 76 Args: 77 features: This is the first item returned from the `input_fn` 78 passed to `train`, `evaluate`, and `predict`. This should be a 79 single `Tensor` or `dict` of same. 80 mode: Optional. Specifies if this training, evaluation or prediction. See 81 `ModeKeys`. 82 83 Returns: 84 A `Tensor` representing the logits, or a list of `Tensor`'s representing 85 multiple logits in the MultiHead case. 86 """ 87 with variable_scope.variable_scope( 88 'input_from_feature_columns', 89 values=tuple(six.itervalues(features)), 90 partitioner=input_layer_partitioner): 91 net = feature_column_lib.input_layer( 92 features=features, feature_columns=feature_columns) 93 for layer_id, num_hidden_units in enumerate(hidden_units): 94 with variable_scope.variable_scope( 95 'hiddenlayer_%d' % layer_id, values=(net,)) as hidden_layer_scope: 96 net = core_layers.dense( 97 net, 98 units=num_hidden_units, 99 activation=activation_fn, 100 kernel_initializer=init_ops.glorot_uniform_initializer(), 101 name=hidden_layer_scope) 102 if dropout is not None and mode == model_fn.ModeKeys.TRAIN: 103 net = core_layers.dropout(net, rate=dropout, training=True) 104 _add_hidden_layer_summary(net, hidden_layer_scope.name) 105 106 with variable_scope.variable_scope('logits', values=(net,)) as logits_scope: 107 logits = core_layers.dense( 108 net, 109 units=units, 110 activation=None, 111 kernel_initializer=init_ops.glorot_uniform_initializer(), 112 name=logits_scope) 113 _add_hidden_layer_summary(logits, logits_scope.name) 114 115 return logits 116 117 return dnn_logit_fn 118 119 120 def _dnn_model_fn(features, 121 labels, 122 mode, 123 head, 124 hidden_units, 125 feature_columns, 126 optimizer='Adagrad', 127 activation_fn=nn.relu, 128 dropout=None, 129 input_layer_partitioner=None, 130 config=None): 131 """Deep Neural Net model_fn. 132 133 Args: 134 features: dict of `Tensor`. 135 labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of 136 dtype `int32` or `int64` in the range `[0, n_classes)`. 137 mode: Defines whether this is training, evaluation or prediction. 138 See `ModeKeys`. 139 head: A `head_lib._Head` instance. 140 hidden_units: Iterable of integer number of hidden units per layer. 141 feature_columns: Iterable of `feature_column._FeatureColumn` model inputs. 142 optimizer: String, `tf.Optimizer` object, or callable that creates the 143 optimizer to use for training. If not specified, will use the Adagrad 144 optimizer with a default learning rate of 0.05. 145 activation_fn: Activation function applied to each layer. 146 dropout: When not `None`, the probability we will drop out a given 147 coordinate. 148 input_layer_partitioner: Partitioner for input layer. Defaults 149 to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. 150 config: `RunConfig` object to configure the runtime settings. 151 152 Returns: 153 An `EstimatorSpec` instance. 154 155 Raises: 156 ValueError: If features has the wrong type. 157 """ 158 if not isinstance(features, dict): 159 raise ValueError('features should be a dictionary of `Tensor`s. ' 160 'Given type: {}'.format(type(features))) 161 162 optimizer = optimizers.get_optimizer_instance( 163 optimizer, learning_rate=_LEARNING_RATE) 164 num_ps_replicas = config.num_ps_replicas if config else 0 165 166 partitioner = partitioned_variables.min_max_variable_partitioner( 167 max_partitions=num_ps_replicas) 168 with variable_scope.variable_scope( 169 'dnn', 170 values=tuple(six.itervalues(features)), 171 partitioner=partitioner): 172 input_layer_partitioner = input_layer_partitioner or ( 173 partitioned_variables.min_max_variable_partitioner( 174 max_partitions=num_ps_replicas, 175 min_slice_size=64 << 20)) 176 177 logit_fn = _dnn_logit_fn_builder( 178 units=head.logits_dimension, 179 hidden_units=hidden_units, 180 feature_columns=feature_columns, 181 activation_fn=activation_fn, 182 dropout=dropout, 183 input_layer_partitioner=input_layer_partitioner) 184 logits = logit_fn(features=features, mode=mode) 185 186 def _train_op_fn(loss): 187 """Returns the op to optimize the loss.""" 188 return optimizer.minimize( 189 loss, 190 global_step=training_util.get_global_step()) 191 192 return head.create_estimator_spec( 193 features=features, 194 mode=mode, 195 labels=labels, 196 train_op_fn=_train_op_fn, 197 logits=logits) 198 199 200 @tf_export('estimator.DNNClassifier') 201 class DNNClassifier(estimator.Estimator): 202 """A classifier for TensorFlow DNN models. 203 204 Example: 205 206 ```python 207 categorical_feature_a = categorical_column_with_hash_bucket(...) 208 categorical_feature_b = categorical_column_with_hash_bucket(...) 209 210 categorical_feature_a_emb = embedding_column( 211 categorical_column=categorical_feature_a, ...) 212 categorical_feature_b_emb = embedding_column( 213 categorical_column=categorical_feature_b, ...) 214 215 estimator = DNNClassifier( 216 feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], 217 hidden_units=[1024, 512, 256]) 218 219 # Or estimator using the ProximalAdagradOptimizer optimizer with 220 # regularization. 221 estimator = DNNClassifier( 222 feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], 223 hidden_units=[1024, 512, 256], 224 optimizer=tf.train.ProximalAdagradOptimizer( 225 learning_rate=0.1, 226 l1_regularization_strength=0.001 227 )) 228 229 # Or estimator with warm-starting from a previous checkpoint. 230 estimator = DNNClassifier( 231 feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], 232 hidden_units=[1024, 512, 256], 233 warm_start_from="/path/to/checkpoint/dir") 234 235 # Input builders 236 def input_fn_train: # returns x, y 237 pass 238 estimator.train(input_fn=input_fn_train, steps=100) 239 240 def input_fn_eval: # returns x, y 241 pass 242 metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) 243 def input_fn_predict: # returns x, None 244 pass 245 predictions = estimator.predict(input_fn=input_fn_predict) 246 ``` 247 248 Input of `train` and `evaluate` should have following features, 249 otherwise there will be a `KeyError`: 250 251 * if `weight_column` is not `None`, a feature with 252 `key=weight_column` whose value is a `Tensor`. 253 * for each `column` in `feature_columns`: 254 - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` 255 whose `value` is a `SparseTensor`. 256 - if `column` is a `_WeightedCategoricalColumn`, two features: the first 257 with `key` the id column name, the second with `key` the weight column 258 name. Both features' `value` must be a `SparseTensor`. 259 - if `column` is a `_DenseColumn`, a feature with `key=column.name` 260 whose `value` is a `Tensor`. 261 262 Loss is calculated by using softmax cross entropy. 263 264 @compatibility(eager) 265 Estimators are not compatible with eager execution. 266 @end_compatibility 267 """ 268 269 def __init__( 270 self, 271 hidden_units, 272 feature_columns, 273 model_dir=None, 274 n_classes=2, 275 weight_column=None, 276 label_vocabulary=None, 277 optimizer='Adagrad', 278 activation_fn=nn.relu, 279 dropout=None, 280 input_layer_partitioner=None, 281 config=None, 282 warm_start_from=None, 283 loss_reduction=losses.Reduction.SUM, 284 ): 285 """Initializes a `DNNClassifier` instance. 286 287 Args: 288 hidden_units: Iterable of number hidden units per layer. All layers are 289 fully connected. Ex. `[64, 32]` means first layer has 64 nodes and 290 second one has 32. 291 feature_columns: An iterable containing all the feature columns used by 292 the model. All items in the set should be instances of classes derived 293 from `_FeatureColumn`. 294 model_dir: Directory to save model parameters, graph and etc. This can 295 also be used to load checkpoints from the directory into a estimator to 296 continue training a previously saved model. 297 n_classes: Number of label classes. Defaults to 2, namely binary 298 classification. Must be > 1. 299 weight_column: A string or a `_NumericColumn` created by 300 `tf.feature_column.numeric_column` defining feature column representing 301 weights. It is used to down weight or boost examples during training. It 302 will be multiplied by the loss of the example. If it is a string, it is 303 used as a key to fetch weight tensor from the `features`. If it is a 304 `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, 305 then weight_column.normalizer_fn is applied on it to get weight tensor. 306 label_vocabulary: A list of strings represents possible label values. If 307 given, labels must be string type and have any value in 308 `label_vocabulary`. If it is not given, that means labels are 309 already encoded as integer or float within [0, 1] for `n_classes=2` and 310 encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . 311 Also there will be errors if vocabulary is not provided and labels are 312 string. 313 optimizer: An instance of `tf.Optimizer` used to train the model. Defaults 314 to Adagrad optimizer. 315 activation_fn: Activation function applied to each layer. If `None`, will 316 use `tf.nn.relu`. 317 dropout: When not `None`, the probability we will drop out a given 318 coordinate. 319 input_layer_partitioner: Optional. Partitioner for input layer. Defaults 320 to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. 321 config: `RunConfig` object to configure the runtime settings. 322 warm_start_from: A string filepath to a checkpoint to warm-start from, or 323 a `WarmStartSettings` object to fully configure warm-starting. If the 324 string filepath is provided instead of a `WarmStartSettings`, then all 325 weights are warm-started, and it is assumed that vocabularies and Tensor 326 names are unchanged. 327 loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how 328 to reduce training loss over batch. Defaults to `SUM`. 329 """ 330 if n_classes == 2: 331 head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access 332 weight_column=weight_column, 333 label_vocabulary=label_vocabulary, 334 loss_reduction=loss_reduction) 335 else: 336 head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access 337 n_classes, weight_column=weight_column, 338 label_vocabulary=label_vocabulary, 339 loss_reduction=loss_reduction) 340 341 def _model_fn(features, labels, mode, config): 342 """Call the defined shared _dnn_model_fn.""" 343 return _dnn_model_fn( 344 features=features, 345 labels=labels, 346 mode=mode, 347 head=head, 348 hidden_units=hidden_units, 349 feature_columns=tuple(feature_columns or []), 350 optimizer=optimizer, 351 activation_fn=activation_fn, 352 dropout=dropout, 353 input_layer_partitioner=input_layer_partitioner, 354 config=config) 355 356 super(DNNClassifier, self).__init__( 357 model_fn=_model_fn, model_dir=model_dir, config=config, 358 warm_start_from=warm_start_from) 359 360 361 @tf_export('estimator.DNNRegressor') 362 class DNNRegressor(estimator.Estimator): 363 """A regressor for TensorFlow DNN models. 364 365 Example: 366 367 ```python 368 categorical_feature_a = categorical_column_with_hash_bucket(...) 369 categorical_feature_b = categorical_column_with_hash_bucket(...) 370 371 categorical_feature_a_emb = embedding_column( 372 categorical_column=categorical_feature_a, ...) 373 categorical_feature_b_emb = embedding_column( 374 categorical_column=categorical_feature_b, ...) 375 376 estimator = DNNRegressor( 377 feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], 378 hidden_units=[1024, 512, 256]) 379 380 # Or estimator using the ProximalAdagradOptimizer optimizer with 381 # regularization. 382 estimator = DNNRegressor( 383 feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], 384 hidden_units=[1024, 512, 256], 385 optimizer=tf.train.ProximalAdagradOptimizer( 386 learning_rate=0.1, 387 l1_regularization_strength=0.001 388 )) 389 390 # Or estimator with warm-starting from a previous checkpoint. 391 estimator = DNNRegressor( 392 feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], 393 hidden_units=[1024, 512, 256], 394 warm_start_from="/path/to/checkpoint/dir") 395 396 # Input builders 397 def input_fn_train: # returns x, y 398 pass 399 estimator.train(input_fn=input_fn_train, steps=100) 400 401 def input_fn_eval: # returns x, y 402 pass 403 metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) 404 def input_fn_predict: # returns x, None 405 pass 406 predictions = estimator.predict(input_fn=input_fn_predict) 407 ``` 408 409 Input of `train` and `evaluate` should have following features, 410 otherwise there will be a `KeyError`: 411 412 * if `weight_column` is not `None`, a feature with 413 `key=weight_column` whose value is a `Tensor`. 414 * for each `column` in `feature_columns`: 415 - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` 416 whose `value` is a `SparseTensor`. 417 - if `column` is a `_WeightedCategoricalColumn`, two features: the first 418 with `key` the id column name, the second with `key` the weight column 419 name. Both features' `value` must be a `SparseTensor`. 420 - if `column` is a `_DenseColumn`, a feature with `key=column.name` 421 whose `value` is a `Tensor`. 422 423 Loss is calculated by using mean squared error. 424 425 @compatibility(eager) 426 Estimators are not compatible with eager execution. 427 @end_compatibility 428 """ 429 430 def __init__( 431 self, 432 hidden_units, 433 feature_columns, 434 model_dir=None, 435 label_dimension=1, 436 weight_column=None, 437 optimizer='Adagrad', 438 activation_fn=nn.relu, 439 dropout=None, 440 input_layer_partitioner=None, 441 config=None, 442 warm_start_from=None, 443 loss_reduction=losses.Reduction.SUM, 444 ): 445 """Initializes a `DNNRegressor` instance. 446 447 Args: 448 hidden_units: Iterable of number hidden units per layer. All layers are 449 fully connected. Ex. `[64, 32]` means first layer has 64 nodes and 450 second one has 32. 451 feature_columns: An iterable containing all the feature columns used by 452 the model. All items in the set should be instances of classes derived 453 from `_FeatureColumn`. 454 model_dir: Directory to save model parameters, graph and etc. This can 455 also be used to load checkpoints from the directory into a estimator to 456 continue training a previously saved model. 457 label_dimension: Number of regression targets per example. This is the 458 size of the last dimension of the labels and logits `Tensor` objects 459 (typically, these have shape `[batch_size, label_dimension]`). 460 weight_column: A string or a `_NumericColumn` created by 461 `tf.feature_column.numeric_column` defining feature column representing 462 weights. It is used to down weight or boost examples during training. It 463 will be multiplied by the loss of the example. If it is a string, it is 464 used as a key to fetch weight tensor from the `features`. If it is a 465 `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, 466 then weight_column.normalizer_fn is applied on it to get weight tensor. 467 optimizer: An instance of `tf.Optimizer` used to train the model. Defaults 468 to Adagrad optimizer. 469 activation_fn: Activation function applied to each layer. If `None`, will 470 use `tf.nn.relu`. 471 dropout: When not `None`, the probability we will drop out a given 472 coordinate. 473 input_layer_partitioner: Optional. Partitioner for input layer. Defaults 474 to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. 475 config: `RunConfig` object to configure the runtime settings. 476 warm_start_from: A string filepath to a checkpoint to warm-start from, or 477 a `WarmStartSettings` object to fully configure warm-starting. If the 478 string filepath is provided instead of a `WarmStartSettings`, then all 479 weights are warm-started, and it is assumed that vocabularies and Tensor 480 names are unchanged. 481 loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how 482 to reduce training loss over batch. Defaults to `SUM`. 483 """ 484 485 def _model_fn(features, labels, mode, config): 486 """Call the defined shared _dnn_model_fn.""" 487 return _dnn_model_fn( 488 features=features, 489 labels=labels, 490 mode=mode, 491 head=head_lib. # pylint: disable=protected-access 492 _regression_head_with_mean_squared_error_loss( 493 label_dimension=label_dimension, weight_column=weight_column, 494 loss_reduction=loss_reduction), 495 hidden_units=hidden_units, 496 feature_columns=tuple(feature_columns or []), 497 optimizer=optimizer, 498 activation_fn=activation_fn, 499 dropout=dropout, 500 input_layer_partitioner=input_layer_partitioner, 501 config=config) 502 503 super(DNNRegressor, self).__init__( 504 model_fn=_model_fn, model_dir=model_dir, config=config, 505 warm_start_from=warm_start_from) 506