# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Linear estimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import linear as linear_lib


class LinearEstimator(estimator.Estimator):
  """An estimator for TensorFlow linear models with user-specified head.

  Example:

  ```python
  categorical_column_a = categorical_column_with_hash_bucket(...)
  categorical_column_b = categorical_column_with_hash_bucket(...)

  categorical_feature_a_x_categorical_feature_b = crossed_column(...)

  # Estimator using the default optimizer.
  estimator = LinearEstimator(
      head=tf.contrib.estimator.multi_label_head(n_classes=3),
      feature_columns=[categorical_column_a,
                       categorical_feature_a_x_categorical_feature_b])

  # Or estimator using the FTRL optimizer with regularization.
  estimator = LinearEstimator(
      head=tf.contrib.estimator.multi_label_head(n_classes=3),
      feature_columns=[categorical_column_a,
                       categorical_feature_a_x_categorical_feature_b],
      optimizer=tf.train.FtrlOptimizer(
          learning_rate=0.1,
          l1_regularization_strength=0.001
      ))

  def input_fn_train():  # returns x, y (where y represents label's class index).
    ...
  estimator.train(input_fn=input_fn_train, steps=100)
  def input_fn_eval():  # returns x, y (where y represents label's class index).
    ...
  metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
  def input_fn_predict():  # returns x, None
    ...
  predictions = estimator.predict(input_fn=input_fn_predict)
  ```

  Input of `train` and `evaluate` should have the following features,
  otherwise there will be a `KeyError`:

  * if `weight_column` is not `None`, a feature with
    `key=weight_column` whose value is a `Tensor`.
  * for each `column` in `feature_columns`:
    - if `column` is a `_CategoricalColumn`, a feature with `key=column.name`
      whose `value` is a `SparseTensor`.
    - if `column` is a `_WeightedCategoricalColumn`, two features: the first
      with `key` the id column name, the second with `key` the weight column
      name. Both features' `value` must be a `SparseTensor`.
    - if `column` is a `_DenseColumn`, a feature with `key=column.name`
      whose `value` is a `Tensor`.

  Loss and predicted output are determined by the specified head.

  @compatibility(eager)
  Estimators are not compatible with eager execution.
  @end_compatibility
  """

  def __init__(self,
               head,
               feature_columns,
               model_dir=None,
               optimizer='Ftrl',
               config=None,
               partitioner=None):
    """Initializes a `LinearEstimator` instance.

    Args:
      head: A `_Head` instance constructed with a method such as
        `tf.contrib.estimator.multi_label_head`.
      feature_columns: An iterable containing all the feature columns used by
        the model. All items should be instances of classes derived from
        `FeatureColumn`. The iterable is read once, at construction time, and
        stored as a tuple; `None` is treated as an empty iterable.
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model.
      optimizer: An instance of `tf.Optimizer` used to train the model.
        Defaults to the string `'Ftrl'`, i.e. the FTRL optimizer. The value is
        forwarded unchanged to the canned linear model function.
      config: `RunConfig` object to configure the runtime settings.
      partitioner: Optional. Partitioner for input layer.
    """
    # Materialize the feature columns exactly once. `_model_fn` may be called
    # several times over the estimator's lifetime (e.g. once for `train` and
    # again for `evaluate` / `predict`). If the caller passed a one-shot
    # iterable such as a generator, converting it inside `_model_fn` would
    # exhaust it on the first call. Every later call would then silently
    # build a model with no feature columns.
    feature_columns = tuple(feature_columns or [])

    def _model_fn(features, labels, mode, config):
      # Delegates to the canned linear model function, closing over the
      # constructor arguments; `config` here is the one passed in by the
      # Estimator at call time.
      return linear_lib._linear_model_fn(  # pylint: disable=protected-access
          features=features,
          labels=labels,
          mode=mode,
          head=head,
          feature_columns=feature_columns,
          optimizer=optimizer,
          partitioner=partitioner,
          config=config)

    super(LinearEstimator, self).__init__(
        model_fn=_model_fn, model_dir=model_dir, config=config)