Home | History | Annotate | Download | only in trees
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
     16 #include "tensorflow/core/platform/macros.h"
     17 
     18 #include <algorithm>
     19 
     20 namespace tensorflow {
     21 namespace boosted_trees {
     22 namespace trees {
     23 
     24 constexpr int kInvalidLeaf = -1;
     25 int DecisionTree::Traverse(const DecisionTreeConfig& config,
     26                            const int32 sub_root_id,
     27                            const utils::Example& example) {
     28   if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
     29     return kInvalidLeaf;
     30   }
     31 
     32   // Traverse tree starting at the provided sub-root.
     33   int32 node_id = sub_root_id;
     34   while (true) {
     35     const auto& current_node = config.nodes(node_id);
     36     switch (current_node.node_case()) {
     37       case TreeNode::kLeaf: {
     38         return node_id;
     39       }
     40       case TreeNode::kDenseFloatBinarySplit: {
     41         const auto& split = current_node.dense_float_binary_split();
     42         node_id = example.dense_float_features[split.feature_column()] <=
     43                           split.threshold()
     44                       ? split.left_id()
     45                       : split.right_id();
     46         break;
     47       }
     48       case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
     49         const auto& split =
     50             current_node.sparse_float_binary_split_default_left().split();
     51         auto sparse_feature =
     52             example.sparse_float_features[split.feature_column()];
     53         // Feature id for the split when multivalent sparse float column, or 0
     54         // by default.
     55         const int32 dimension_id = split.dimension_id();
     56 
     57         node_id = !sparse_feature[dimension_id].has_value() ||
     58                           sparse_feature[dimension_id].get_value() <=
     59                               split.threshold()
     60                       ? split.left_id()
     61                       : split.right_id();
     62         break;
     63       }
     64       case TreeNode::kSparseFloatBinarySplitDefaultRight: {
     65         const auto& split =
     66             current_node.sparse_float_binary_split_default_right().split();
     67         auto sparse_feature =
     68             example.sparse_float_features[split.feature_column()];
     69         // Feature id for the split when multivalent sparse float column, or 0
     70         // by default.
     71         const int32 dimension_id = split.dimension_id();
     72         node_id = sparse_feature[dimension_id].has_value() &&
     73                           sparse_feature[dimension_id].get_value() <=
     74                               split.threshold()
     75                       ? split.left_id()
     76                       : split.right_id();
     77         break;
     78       }
     79       case TreeNode::kCategoricalIdBinarySplit: {
     80         const auto& split = current_node.categorical_id_binary_split();
     81         const auto& features =
     82             example.sparse_int_features[split.feature_column()];
     83         node_id = features.find(split.feature_id()) != features.end()
     84                       ? split.left_id()
     85                       : split.right_id();
     86         break;
     87       }
     88       case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
     89         const auto& split =
     90             current_node.categorical_id_set_membership_binary_split();
     91         // The new node_id = left_id if a feature is found, or right_id.
     92         node_id = split.right_id();
     93         for (const int64 feature_id :
     94              example.sparse_int_features[split.feature_column()]) {
     95           if (std::binary_search(split.feature_ids().begin(),
     96                                  split.feature_ids().end(), feature_id)) {
     97             node_id = split.left_id();
     98             break;
     99           }
    100         }
    101         break;
    102       }
    103       case TreeNode::NODE_NOT_SET: {
    104         LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString();
    105         break;
    106       }
    107     }
    108     DCHECK_NE(node_id, 0) << "Malformed tree, cycles found to root:"
    109                           << current_node.DebugString();
    110   }
    111 }
    112 
    113 void DecisionTree::LinkChildren(const std::vector<int32>& children,
    114                                 TreeNode* parent_node) {
    115   // Decide how to link children depending on the parent node's type.
    116   auto children_it = children.begin();
    117   switch (parent_node->node_case()) {
    118     case TreeNode::kLeaf: {
    119       // Essentially no-op.
    120       QCHECK(children.empty()) << "A leaf node cannot have children.";
    121       break;
    122     }
    123     case TreeNode::kDenseFloatBinarySplit: {
    124       QCHECK(children.size() == 2)
    125           << "A binary split node must have exactly two children.";
    126       auto* split = parent_node->mutable_dense_float_binary_split();
    127       split->set_left_id(*children_it);
    128       split->set_right_id(*++children_it);
    129       break;
    130     }
    131     case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
    132       QCHECK(children.size() == 2)
    133           << "A binary split node must have exactly two children.";
    134       auto* split =
    135           parent_node->mutable_sparse_float_binary_split_default_left()
    136               ->mutable_split();
    137       split->set_left_id(*children_it);
    138       split->set_right_id(*++children_it);
    139       break;
    140     }
    141     case TreeNode::kSparseFloatBinarySplitDefaultRight: {
    142       QCHECK(children.size() == 2)
    143           << "A binary split node must have exactly two children.";
    144       auto* split =
    145           parent_node->mutable_sparse_float_binary_split_default_right()
    146               ->mutable_split();
    147       split->set_left_id(*children_it);
    148       split->set_right_id(*++children_it);
    149       break;
    150     }
    151     case TreeNode::kCategoricalIdBinarySplit: {
    152       QCHECK(children.size() == 2)
    153           << "A binary split node must have exactly two children.";
    154       auto* split = parent_node->mutable_categorical_id_binary_split();
    155       split->set_left_id(*children_it);
    156       split->set_right_id(*++children_it);
    157       break;
    158     }
    159     case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
    160       QCHECK(children.size() == 2)
    161           << "A binary split node must have exactly two children.";
    162       auto* split =
    163           parent_node->mutable_categorical_id_set_membership_binary_split();
    164       split->set_left_id(*children_it);
    165       split->set_right_id(*++children_it);
    166       break;
    167     }
    168     case TreeNode::NODE_NOT_SET: {
    169       LOG(QFATAL) << "A non-set node cannot have children.";
    170       break;
    171     }
    172   }
    173 }
    174 
    175 std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
    176   // A node's children depend on its type.
    177   switch (node.node_case()) {
    178     case TreeNode::kLeaf: {
    179       return {};
    180     }
    181     case TreeNode::kDenseFloatBinarySplit: {
    182       const auto& split = node.dense_float_binary_split();
    183       return {split.left_id(), split.right_id()};
    184     }
    185     case TreeNode::kSparseFloatBinarySplitDefaultLeft: {
    186       const auto& split = node.sparse_float_binary_split_default_left().split();
    187       return {split.left_id(), split.right_id()};
    188     }
    189     case TreeNode::kSparseFloatBinarySplitDefaultRight: {
    190       const auto& split =
    191           node.sparse_float_binary_split_default_right().split();
    192       return {split.left_id(), split.right_id()};
    193     }
    194     case TreeNode::kCategoricalIdBinarySplit: {
    195       const auto& split = node.categorical_id_binary_split();
    196       return {split.left_id(), split.right_id()};
    197     }
    198     case TreeNode::kCategoricalIdSetMembershipBinarySplit: {
    199       const auto& split = node.categorical_id_set_membership_binary_split();
    200       return {split.left_id(), split.right_id()};
    201     }
    202     case TreeNode::NODE_NOT_SET: {
    203       return {};
    204     }
    205   }
    206 }
    207 
    208 }  // namespace trees
    209 }  // namespace boosted_trees
    210 }  // namespace tensorflow
    211