Home | History | Annotate | Download | only in estimator_batch
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """GTFlow Estimator definition."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.boosted_trees.estimator_batch import model
     22 from tensorflow.contrib.boosted_trees.python.utils import losses
     23 from tensorflow.contrib.learn.python.learn.estimators import estimator
     24 from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
     25 from tensorflow.python.ops import math_ops
     26 
     27 
     28 class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
     29   """An estimator using gradient boosted decision trees."""
     30 
     31   def __init__(self,
     32                learner_config,
     33                examples_per_layer,
     34                n_classes=2,
     35                num_trees=None,
     36                feature_columns=None,
     37                weight_column_name=None,
     38                model_dir=None,
     39                config=None,
     40                label_keys=None,
     41                feature_engineering_fn=None,
     42                logits_modifier_function=None,
     43                center_bias=True):
     44     """Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
     45 
     46     Args:
     47       learner_config: A config for the learner.
     48       examples_per_layer: Number of examples to accumulate before growing a
     49         layer. It can also be a function that computes the number of examples
     50         based on the depth of the layer that's being built.
     51       n_classes: Number of classes in the classification.
     52       num_trees: An int, number of trees to build.
     53       feature_columns: A list of feature columns.
     54       weight_column_name: Name of the column for weights, or None if not
     55         weighted.
     56       model_dir: Directory for model exports, etc.
     57       config: `RunConfig` object to configure the runtime settings.
     58       label_keys: Optional list of strings with size `[n_classes]` defining the
     59         label vocabulary. Only supported for `n_classes` > 2.
     60       feature_engineering_fn: Feature engineering function. Takes features and
     61         labels which are the output of `input_fn` and returns features and
     62         labels which will be fed into the model.
     63       logits_modifier_function: A modifier function for the logits.
     64       center_bias: Whether a separate tree should be created for first fitting
     65         the bias.
     66 
     67     Raises:
     68       ValueError: If learner_config is not valid.
     69     """
     70     if n_classes > 2:
     71       # For multi-class classification, use our loss implementation that
     72       # supports second order derivative.
     73       def loss_fn(labels, logits, weights=None):
     74         result = losses.per_example_maxent_loss(
     75             labels=labels, logits=logits, weights=weights,
     76             num_classes=n_classes)
     77         return math_ops.reduce_mean(result[0])
     78     else:
     79       loss_fn = None
     80     head = head_lib.multi_class_head(
     81         n_classes=n_classes,
     82         weight_column_name=weight_column_name,
     83         enable_centered_bias=False,
     84         loss_fn=loss_fn)
     85     if learner_config.num_classes == 0:
     86       learner_config.num_classes = n_classes
     87     elif learner_config.num_classes != n_classes:
     88       raise ValueError("n_classes (%d) doesn't match learner_config (%d)." %
     89                        (learner_config.num_classes, n_classes))
     90     super(GradientBoostedDecisionTreeClassifier, self).__init__(
     91         model_fn=model.model_builder,
     92         params={
     93             'head': head,
     94             'feature_columns': feature_columns,
     95             'learner_config': learner_config,
     96             'num_trees': num_trees,
     97             'weight_column_name': weight_column_name,
     98             'examples_per_layer': examples_per_layer,
     99             'center_bias': center_bias,
    100             'logits_modifier_function': logits_modifier_function,
    101         },
    102         model_dir=model_dir,
    103         config=config,
    104         feature_engineering_fn=feature_engineering_fn)
    105 
    106 
    107 class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
    108   """An estimator using gradient boosted decision trees."""
    109 
    110   def __init__(self,
    111                learner_config,
    112                examples_per_layer,
    113                label_dimension=1,
    114                num_trees=None,
    115                feature_columns=None,
    116                label_name=None,
    117                weight_column_name=None,
    118                model_dir=None,
    119                config=None,
    120                feature_engineering_fn=None,
    121                logits_modifier_function=None,
    122                center_bias=True):
    123     """Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
    124 
    125     Args:
    126       learner_config: A config for the learner.
    127       examples_per_layer: Number of examples to accumulate before growing a
    128         layer. It can also be a function that computes the number of examples
    129         based on the depth of the layer that's being built.
    130       label_dimension: Number of regression labels per example. This is the size
    131         of the last dimension of the labels `Tensor` (typically, this has shape
    132         `[batch_size, label_dimension]`).
    133       num_trees: An int, number of trees to build.
    134       feature_columns: A list of feature columns.
    135       label_name: String, name of the key in label dict. Can be null if label
    136           is a tensor (single headed models).
    137       weight_column_name: Name of the column for weights, or None if not
    138         weighted.
    139       model_dir: Directory for model exports, etc.
    140       config: `RunConfig` object to configure the runtime settings.
    141       feature_engineering_fn: Feature engineering function. Takes features and
    142         labels which are the output of `input_fn` and returns features and
    143         labels which will be fed into the model.
    144       logits_modifier_function: A modifier function for the logits.
    145       center_bias: Whether a separate tree should be created for first fitting
    146         the bias.
    147     """
    148     head = head_lib.regression_head(
    149         label_name=label_name,
    150         label_dimension=label_dimension,
    151         weight_column_name=weight_column_name,
    152         enable_centered_bias=False)
    153     if label_dimension == 1:
    154       learner_config.num_classes = 2
    155     else:
    156       learner_config.num_classes = label_dimension
    157     super(GradientBoostedDecisionTreeRegressor, self).__init__(
    158         model_fn=model.model_builder,
    159         params={
    160             'head': head,
    161             'feature_columns': feature_columns,
    162             'learner_config': learner_config,
    163             'num_trees': num_trees,
    164             'weight_column_name': weight_column_name,
    165             'examples_per_layer': examples_per_layer,
    166             'logits_modifier_function': logits_modifier_function,
    167             'center_bias': center_bias,
    168         },
    169         model_dir=model_dir,
    170         config=config,
    171         feature_engineering_fn=feature_engineering_fn)
    172 
    173 
    174 class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
    175   """An estimator using gradient boosted decision trees.
    176 
    177   Useful for training with user specified `Head`.
    178   """
    179 
    180   def __init__(self,
    181                learner_config,
    182                examples_per_layer,
    183                head,
    184                num_trees=None,
    185                feature_columns=None,
    186                weight_column_name=None,
    187                model_dir=None,
    188                config=None,
    189                feature_engineering_fn=None,
    190                logits_modifier_function=None,
    191                center_bias=True):
    192     """Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
    193 
    194     Args:
    195       learner_config: A config for the learner.
    196       examples_per_layer: Number of examples to accumulate before growing a
    197         layer. It can also be a function that computes the number of examples
    198         based on the depth of the layer that's being built.
    199       head: `Head` instance.
    200       num_trees: An int, number of trees to build.
    201       feature_columns: A list of feature columns.
    202       weight_column_name: Name of the column for weights, or None if not
    203         weighted.
    204       model_dir: Directory for model exports, etc.
    205       config: `RunConfig` object to configure the runtime settings.
    206       feature_engineering_fn: Feature engineering function. Takes features and
    207         labels which are the output of `input_fn` and returns features and
    208         labels which will be fed into the model.
    209       logits_modifier_function: A modifier function for the logits.
    210       center_bias: Whether a separate tree should be created for first fitting
    211         the bias.
    212     """
    213     super(GradientBoostedDecisionTreeEstimator, self).__init__(
    214         model_fn=model.model_builder,
    215         params={
    216             'head': head,
    217             'feature_columns': feature_columns,
    218             'learner_config': learner_config,
    219             'num_trees': num_trees,
    220             'weight_column_name': weight_column_name,
    221             'examples_per_layer': examples_per_layer,
    222             'logits_modifier_function': logits_modifier_function,
    223             'center_bias': center_bias,
    224         },
    225         model_dir=model_dir,
    226         config=config,
    227         feature_engineering_fn=feature_engineering_fn)
    228