# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Strategy to export custom proto formats."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os

from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
from tensorflow.contrib.decision_trees.proto import generic_tree_model_extensions_pb2
from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2
from tensorflow.contrib.learn.python.learn import export_strategy
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader as saved_model_loader
from tensorflow.python.saved_model import tag_constants

_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d"
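# Note (added for clarity): the template above is filled with
# (feature_name, dimension_id), so dimension 0 of a multi-dimensional sparse
# float column named "prices" is exported under the name "prices_0".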


def make_custom_export_strategy(name,
                                convert_fn,
                                feature_columns,
                                export_input_fn):
  """Makes custom exporter of GTFlow tree format.

  Args:
    name: A string, for the name of the export strategy.
    convert_fn: A function that converts the tree proto to desired format and
      saves it to the desired location. Can be None to skip conversion.
    feature_columns: A list of feature columns.
    export_input_fn: A function that takes no arguments and returns an
      `InputFnOps`.

  Returns:
    An `ExportStrategy`.
  """
  base_strategy = saved_model_export_utils.make_export_strategy(
      serving_input_fn=export_input_fn)
  input_fn = export_input_fn()
  (sorted_feature_names, dense_floats, sparse_float_indices, _, _,
   sparse_int_indices, _, _) = gbdt_batch.extract_features(
       input_fn.features, feature_columns)

  def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None):
    """A wrapper to export to SavedModel, and convert it to other formats."""
    result_dir = base_strategy.export(estimator, export_dir,
                                      checkpoint_path,
                                      eval_result)
    with ops.Graph().as_default() as graph:
      with tf_session.Session(graph=graph) as sess:
        saved_model_loader.load(
            sess, [tag_constants.SERVING], result_dir)
        # Note: This is GTFlow internal API and might change.
        ensemble_model = graph.get_operation_by_name(
            "ensemble_model/TreeEnsembleSerialize")
        _, dfec_str = sess.run(ensemble_model.outputs)
        dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
        dtec.ParseFromString(dfec_str)
        # Export the result in the same folder as the saved model.
        if convert_fn:
          convert_fn(dtec, sorted_feature_names,
                     len(dense_floats),
                     len(sparse_float_indices),
                     len(sparse_int_indices), result_dir, eval_result)
        feature_importances = _get_feature_importances(
            dtec, sorted_feature_names,
            len(dense_floats),
            len(sparse_float_indices), len(sparse_int_indices))
        sorted_by_importance = sorted(
            feature_importances.items(), key=lambda x: -x[1])
        assets_dir = os.path.join(result_dir, "assets.extra")
        gfile.MakeDirs(assets_dir)
        with gfile.GFile(os.path.join(assets_dir, "feature_importances"),
                         "w") as f:
          f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance))
    return result_dir
  return export_strategy.ExportStrategy(name, export_fn)
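

# A minimal sketch (added for illustration, not referenced by the code above):
# a `convert_fn` with the contract that `export_fn` expects. The positional
# arguments mirror the call site in `export_fn`: (dtec, sorted_feature_names,
# num_dense, num_sparse_float, num_sparse_int, result_dir, eval_result). The
# output file name "tree_ensemble.pbtxt" is an arbitrary choice for this
# example. Such a function could be passed to the factory above, e.g.
# make_custom_export_strategy("dump_trees", _example_convert_fn,
#                             feature_columns, export_input_fn).
def _example_convert_fn(dtec, sorted_feature_names, num_dense,
                        num_sparse_float, num_sparse_int, result_dir,
                        eval_result):
  """Example convert_fn that writes the raw ensemble proto as text."""
  del sorted_feature_names, num_dense, num_sparse_float  # Unused here.
  del num_sparse_int, eval_result  # Unused here.
  with gfile.GFile(os.path.join(result_dir, "tree_ensemble.pbtxt"), "w") as f:
    f.write(str(dtec))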


def convert_to_universal_format(dtec, sorted_feature_names,
                                num_dense, num_sparse_float,
                                num_sparse_int,
                                feature_name_to_proto=None):
  """Convert GTFlow trees to universal format."""
  del num_sparse_int  # unused.
  model_and_features = generic_tree_model_pb2.ModelAndFeatures()
  # TODO(jonasz): Feature descriptions should contain information about how each
  # feature is processed before it's fed to the model (e.g. bucketing
  # information). As of now, this serves as a list of features the model uses.
  for feature_name in sorted_feature_names:
    if not feature_name_to_proto:
      model_and_features.features[feature_name].SetInParent()
    else:
      model_and_features.features[feature_name].CopyFrom(
          feature_name_to_proto[feature_name])
  model = model_and_features.model
  model.ensemble.summation_combination_technique.SetInParent()
  for tree_idx in range(len(dtec.trees)):
    gtflow_tree = dtec.trees[tree_idx]
    tree_weight = dtec.tree_weights[tree_idx]
    member = model.ensemble.members.add()
    member.submodel_id.value = tree_idx
    tree = member.submodel.decision_tree
    for node_idx in range(len(gtflow_tree.nodes)):
      gtflow_node = gtflow_tree.nodes[node_idx]
      node = tree.nodes.add()
      node_type = gtflow_node.WhichOneof("node")
      node.node_id.value = node_idx
      if node_type == "leaf":
        leaf = gtflow_node.leaf
        if leaf.HasField("vector"):
          for weight in leaf.vector.value:
            new_value = node.leaf.vector.value.add()
            new_value.float_value = weight * tree_weight
        else:
          for index, weight in zip(
              leaf.sparse_vector.index, leaf.sparse_vector.value):
            new_value = node.leaf.sparse_vector.sparse_value[index]
            new_value.float_value = weight * tree_weight
      else:
        node = node.binary_node
        # Binary nodes here.
        if node_type == "dense_float_binary_split":
          split = gtflow_node.dense_float_binary_split
          feature_id = split.feature_column
          inequality_test = node.inequality_left_child_test
          inequality_test.feature_id.id.value = sorted_feature_names[feature_id]
          inequality_test.type = (
              generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
          inequality_test.threshold.float_value = split.threshold
        elif node_type == "sparse_float_binary_split_default_left":
          split = gtflow_node.sparse_float_binary_split_default_left.split
          node.default_direction = (generic_tree_model_pb2.BinaryNode.LEFT)
          feature_id = split.feature_column + num_dense
          inequality_test = node.inequality_left_child_test
          inequality_test.feature_id.id.value = (
              _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE %
              (sorted_feature_names[feature_id], split.dimension_id))
          model_and_features.features.pop(sorted_feature_names[feature_id])
          (model_and_features.features[inequality_test.feature_id.id.value]
           .SetInParent())
          inequality_test.type = (
              generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
          inequality_test.threshold.float_value = split.threshold
        elif node_type == "sparse_float_binary_split_default_right":
          split = gtflow_node.sparse_float_binary_split_default_right.split
          node.default_direction = (
              generic_tree_model_pb2.BinaryNode.RIGHT)
          # TODO(nponomareva): adjust this id assignment when we allow multi-
          # column sparse tensors.
          feature_id = split.feature_column + num_dense
          inequality_test = node.inequality_left_child_test
          inequality_test.feature_id.id.value = (
              _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE %
              (sorted_feature_names[feature_id], split.dimension_id))
          model_and_features.features.pop(sorted_feature_names[feature_id])
          (model_and_features.features[inequality_test.feature_id.id.value]
           .SetInParent())
          inequality_test.type = (
              generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
          inequality_test.threshold.float_value = split.threshold
        elif node_type == "categorical_id_binary_split":
          split = gtflow_node.categorical_id_binary_split
          node.default_direction = generic_tree_model_pb2.BinaryNode.RIGHT
          feature_id = split.feature_column + num_dense + num_sparse_float
          categorical_test = (
              generic_tree_model_extensions_pb2.MatchingValuesTest())
          categorical_test.feature_id.id.value = sorted_feature_names[
              feature_id]
          matching_id = categorical_test.value.add()
          matching_id.int64_value = split.feature_id
          node.custom_left_child_test.Pack(categorical_test)
        else:
          raise ValueError("Unexpected node type %s" % node_type)
        node.left_child_id.value = split.left_id
        node.right_child_id.value = split.right_id
  return model_and_features
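

# Illustrative sketch (added for clarity, not called by the exporter): the
# global feature-id layout assumed by the conversion above and by
# `_get_feature_importances` below is dense float columns first, then sparse
# float columns, then categorical (sparse int) columns. For example, with
# num_dense=3 and num_sparse_float=1, a categorical split whose feature_column
# is 2 refers to sorted_feature_names[3 + 1 + 2] = sorted_feature_names[6].
def _example_global_feature_index(feature_column, column_kind, num_dense,
                                  num_sparse_float):
  """Maps a per-kind column index to an index into sorted_feature_names."""
  if column_kind == "dense_float":
    return feature_column
  if column_kind == "sparse_float":
    return num_dense + feature_column
  # Categorical (sparse int) columns come after dense and sparse float ones.
  return num_dense + num_sparse_float + feature_column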


def _get_feature_importances(dtec, feature_names, num_dense_floats,
                             num_sparse_float, num_sparse_int):
  """Export the feature importance per feature column."""
  del num_sparse_int    # Unused.
  sums = collections.defaultdict(lambda: 0)
  for tree_idx in range(len(dtec.trees)):
    tree = dtec.trees[tree_idx]
    for tree_node in tree.nodes:
      node_type = tree_node.WhichOneof("node")
      if node_type == "dense_float_binary_split":
        split = tree_node.dense_float_binary_split
        split_column = feature_names[split.feature_column]
      elif node_type == "sparse_float_binary_split_default_left":
        split = tree_node.sparse_float_binary_split_default_left.split
        split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % (
            feature_names[split.feature_column + num_dense_floats],
            split.dimension_id)
      elif node_type == "sparse_float_binary_split_default_right":
        split = tree_node.sparse_float_binary_split_default_right.split
        split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % (
            feature_names[split.feature_column + num_dense_floats],
            split.dimension_id)
      elif node_type == "categorical_id_binary_split":
        split = tree_node.categorical_id_binary_split
        split_column = feature_names[split.feature_column + num_dense_floats +
                                     num_sparse_float]
      elif node_type == "categorical_id_set_membership_binary_split":
        split = tree_node.categorical_id_set_membership_binary_split
        split_column = feature_names[split.feature_column + num_dense_floats +
                                     num_sparse_float]
      elif node_type == "leaf":
        assert tree_node.node_metadata.gain == 0
        continue
      else:
        raise ValueError("Unexpected split type %s" % node_type)
      # Apply shrinkage factor. It is important since it is not always uniform
      # across different trees.
      sums[split_column] += (
          tree_node.node_metadata.gain * dtec.tree_weights[tree_idx])
  return dict(sums)
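

# Worked example (illustrative): with dtec.tree_weights = [0.1, 0.2], a split
# on dense column "age" with gain 5.0 in tree 0 and another split on "age"
# with gain 3.0 in tree 1 export an importance of 5.0 * 0.1 + 3.0 * 0.2 = 1.1
# for "age". Leaf nodes carry no gain and are skipped.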
    239