# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow estimator for Linear and DNN joined training models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import dnn_linear_combined as dnn_linear_combined_lib
from tensorflow.python.ops import nn


class DNNLinearCombinedEstimator(estimator.Estimator):
  """An estimator for TensorFlow Linear and DNN joined models with custom head.

  Note: This estimator is also known as wide-n-deep.

  Example:

  ```python
  numeric_feature = numeric_column(...)
  categorical_feature_a = categorical_column_with_hash_bucket(...)
  categorical_feature_b = categorical_column_with_hash_bucket(...)

  categorical_feature_a_x_categorical_feature_b = crossed_column(...)
  categorical_feature_a_emb = embedding_column(
      categorical_column=categorical_feature_a, ...)
  categorical_feature_b_emb = embedding_column(
      categorical_column=categorical_feature_b, ...)

  estimator = DNNLinearCombinedEstimator(
      head=tf.contrib.estimator.multi_label_head(n_classes=3),
      # wide settings
      linear_feature_columns=[categorical_feature_a_x_categorical_feature_b],
      linear_optimizer=tf.train.FtrlOptimizer(...),
      # deep settings
      dnn_feature_columns=[
          categorical_feature_a_emb, categorical_feature_b_emb,
          numeric_feature],
      dnn_hidden_units=[1000, 500, 100],
      dnn_optimizer=tf.train.ProximalAdagradOptimizer(...))

  # To apply L1 and L2 regularization, you can set the optimizers as follows:
  tf.train.ProximalAdagradOptimizer(
      learning_rate=0.1,
      l1_regularization_strength=0.001,
      l2_regularization_strength=0.001)
  # The same arguments apply to FtrlOptimizer.

  # Input builders
  def input_fn_train():  # returns x, y
    pass
  estimator.train(input_fn=input_fn_train, steps=100)

  def input_fn_eval():  # returns x, y
    pass
  metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)

  def input_fn_predict():  # returns x, None
    pass
  predictions = estimator.predict(input_fn=input_fn_predict)
  ```
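
  For concreteness, a minimal `input_fn` is sketched below. The feature names
  (`price`, `keywords`), shapes, and label layout are illustrative assumptions,
  not part of this module:

  ```python
  def input_fn_train():
    # Keys must match the names of the feature columns given to the
    # estimator; the labels match multi_label_head(n_classes=3).
    features = {
        'price': tf.constant([[10.0], [20.0]]),
        'keywords': tf.SparseTensor(indices=[[0, 0], [1, 0]],
                                    values=['sale', 'new'],
                                    dense_shape=[2, 1]),
    }
    labels = tf.constant([[0, 1, 0], [1, 0, 0]])
    return features, labels
  ```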

  Input of `train` and `evaluate` should have the following features,
  otherwise there will be a `KeyError` (an illustrative `features` dict is
  sketched after this list):

  * for each `column` in `dnn_feature_columns` + `linear_feature_columns`:
    - if `column` is a `_CategoricalColumn`, a feature with `key=column.name`
      whose `value` is a `SparseTensor`.
    - if `column` is a `_WeightedCategoricalColumn`, two features: the first
      with `key` the id column name, the second with `key` the weight column
      name. Both features' `value` must be a `SparseTensor`.
    - if `column` is a `_DenseColumn`, a feature with `key=column.name`
      whose `value` is a `Tensor`.
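
  For example, a `features` dict satisfying these rules might look as follows;
  `keywords`, `terms`, `term_weights`, and `price` are hypothetical column
  names used only for illustration:

  ```python
  features = {
      # _CategoricalColumn: a SparseTensor of ids.
      'keywords': tf.SparseTensor(indices=[[0, 0]], values=['sale'],
                                  dense_shape=[1, 1]),
      # _WeightedCategoricalColumn: an id feature plus a weight feature,
      # both SparseTensors.
      'terms': tf.SparseTensor(indices=[[0, 0]], values=['tv'],
                               dense_shape=[1, 1]),
      'term_weights': tf.SparseTensor(indices=[[0, 0]], values=[0.5],
                                      dense_shape=[1, 1]),
      # _DenseColumn: a dense Tensor.
      'price': tf.constant([[9.99]]),
  }
  ```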

  Loss and predicted output are determined by the specified `head`.

  @compatibility(eager)
  Estimators are not compatible with eager execution.
  @end_compatibility
  """

  def __init__(self,
               head,
               model_dir=None,
               linear_feature_columns=None,
               linear_optimizer='Ftrl',
               dnn_feature_columns=None,
               dnn_optimizer='Adagrad',
               dnn_hidden_units=None,
               dnn_activation_fn=nn.relu,
               dnn_dropout=None,
               input_layer_partitioner=None,
               config=None):
    """Initializes a DNNLinearCombinedEstimator instance.

    Args:
      head: A `_Head` instance constructed with a method such as
        `tf.contrib.estimator.multi_label_head`.
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model.
      linear_feature_columns: An iterable containing all the feature columns
        used by the linear part of the model. All items in the set must be
        instances of classes derived from `FeatureColumn`.
      linear_optimizer: An instance of `tf.Optimizer` used to apply gradients
        to the linear part of the model. Defaults to the FTRL optimizer.
      dnn_feature_columns: An iterable containing all the feature columns used
        by the deep part of the model. All items in the set must be instances
        of classes derived from `FeatureColumn`.
      dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to
        the deep part of the model. Defaults to the Adagrad optimizer.
      dnn_hidden_units: List of hidden units per layer. All layers are fully
        connected.
      dnn_activation_fn: Activation function applied to each layer. If None,
        will use `tf.nn.relu`.
      dnn_dropout: When not None, the probability that a given coordinate will
        be dropped out.
      input_layer_partitioner: Partitioner for the input layer. Defaults to
        `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
      config: `RunConfig` object to configure the runtime settings.

    Raises:
      ValueError: If both `linear_feature_columns` and `dnn_feature_columns`
        are empty.
    """
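    # Normalize None to empty lists so the two column sets can be
    # concatenated and validated together.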
    linear_feature_columns = linear_feature_columns or []
    dnn_feature_columns = dnn_feature_columns or []
    self._feature_columns = (
        list(linear_feature_columns) + list(dnn_feature_columns))
    if not self._feature_columns:
      raise ValueError('Either linear_feature_columns or dnn_feature_columns '
                       'must be defined.')

    def _model_fn(features, labels, mode, config):
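      # Delegate to the shared wide-and-deep model function used by the
      # canned DNNLinearCombined estimators; the supplied `head` turns the
      # combined logits into loss, predictions, and an EstimatorSpec.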
      return dnn_linear_combined_lib._dnn_linear_combined_model_fn(  # pylint: disable=protected-access
          features=features,
          labels=labels,
          mode=mode,
          head=head,
          linear_feature_columns=linear_feature_columns,
          linear_optimizer=linear_optimizer,
          dnn_feature_columns=dnn_feature_columns,
          dnn_optimizer=dnn_optimizer,
          dnn_hidden_units=dnn_hidden_units,
          dnn_activation_fn=dnn_activation_fn,
          dnn_dropout=dnn_dropout,
          input_layer_partitioner=input_layer_partitioner,
          config=config)

    super(DNNLinearCombinedEstimator, self).__init__(
        model_fn=_model_fn, model_dir=model_dir, config=config)
    165