Home | History | Annotate | Download | only in estimator
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Linear estimator."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.estimator import estimator
     22 from tensorflow.python.estimator.canned import linear as linear_lib
     23 
     24 
     25 class LinearEstimator(estimator.Estimator):
     26   """An estimator for TensorFlow linear models with user-specified head.
     27 
     28   Example:
     29 
     30   ```python
     31   categorical_column_a = categorical_column_with_hash_bucket(...)
     32   categorical_column_b = categorical_column_with_hash_bucket(...)
     33 
     34   categorical_feature_a_x_categorical_feature_b = crossed_column(...)
     35 
     36   # Estimator using the default optimizer.
     37   estimator = LinearEstimator(
     38       head=tf.contrib.estimator.multi_label_head(n_classes=3),
     39       feature_columns=[categorical_column_a,
     40                        categorical_feature_a_x_categorical_feature_b])
     41 
     42   # Or estimator using the FTRL optimizer with regularization.
     43   estimator = LinearEstimator(
     44       head=tf.contrib.estimator.multi_label_head(n_classes=3),
     45       feature_columns=[categorical_column_a,
     46                        categorical_feature_a_x_categorical_feature_b])
     47       optimizer=tf.train.FtrlOptimizer(
     48           learning_rate=0.1,
     49           l1_regularization_strength=0.001
     50       ))
     51 
     52   def input_fn_train: # returns x, y (where y represents label's class index).
     53     ...
     54   estimator.train(input_fn=input_fn_train, steps=100)
     55   def input_fn_eval: # returns x, y (where y represents label's class index).
     56     ...
     57   metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
     58   def input_fn_predict: # returns x, None
     59     ...
     60   predictions = estimator.predict(input_fn=input_fn_predict)
     61   ```
     62 
     63   Input of `train` and `evaluate` should have following features,
     64   otherwise there will be a `KeyError`:
     65 
     66   * if `weight_column` is not `None`, a feature with
     67     `key=weight_column` whose value is a `Tensor`.
     68   * for each `column` in `feature_columns`:
     69     - if `column` is a `_CategoricalColumn`, a feature with `key=column.name`
     70       whose `value` is a `SparseTensor`.
     71     - if `column` is a `_WeightedCategoricalColumn`, two features: the first
     72       with `key` the id column name, the second with `key` the weight column
     73       name. Both features' `value` must be a `SparseTensor`.
     74     - if `column` is a `_DenseColumn`, a feature with `key=column.name`
     75       whose `value` is a `Tensor`.
     76 
     77   Loss and predicted output are determined by the specified head.
     78 
     79   @compatibility(eager)
     80   Estimators are not compatible with eager execution.
     81   @end_compatibility
     82   """
     83 
     84   def __init__(self,
     85                head,
     86                feature_columns,
     87                model_dir=None,
     88                optimizer='Ftrl',
     89                config=None,
     90                partitioner=None):
     91     """Initializes a `LinearEstimator` instance.
     92 
     93     Args:
     94       head: A `_Head` instance constructed with a method such as
     95         `tf.contrib.estimator.multi_label_head`.
     96       feature_columns: An iterable containing all the feature columns used by
     97         the model. All items in the set should be instances of classes derived
     98         from `FeatureColumn`.
     99       model_dir: Directory to save model parameters, graph and etc. This can
    100         also be used to load checkpoints from the directory into a estimator
    101         to continue training a previously saved model.
    102       optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
    103         to FTRL optimizer.
    104       config: `RunConfig` object to configure the runtime settings.
    105       partitioner: Optional. Partitioner for input layer.
    106     """
    107     def _model_fn(features, labels, mode, config):
    108       return linear_lib._linear_model_fn(  # pylint: disable=protected-access
    109           features=features,
    110           labels=labels,
    111           mode=mode,
    112           head=head,
    113           feature_columns=tuple(feature_columns or []),
    114           optimizer=optimizer,
    115           partitioner=partitioner,
    116           config=config)
    117     super(LinearEstimator, self).__init__(
    118         model_fn=_model_fn, model_dir=model_dir, config=config)
    119