// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"

#include <algorithm>

#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace boosted_trees {
namespace testutil {

using boosted_trees::trees::DenseFloatBinarySplit;
using tensorflow::boosted_trees::trees::DecisionTreeConfig;
using tensorflow::boosted_trees::trees::TreeNode;

namespace {

// Append the given nodes to tree with transfer of pointer ownership.
// nodes will not be usable upon return.
33 template <typename T> 34 void AppendNodes(DecisionTreeConfig* tree, T* nodes) { 35 std::reverse(nodes->pointer_begin(), nodes->pointer_end()); 36 while (!nodes->empty()) { 37 tree->mutable_nodes()->AddAllocated(nodes->ReleaseLast()); 38 } 39 } 40 41 DenseFloatBinarySplit* GetSplit(TreeNode* node) { 42 switch (node->node_case()) { 43 case TreeNode::kSparseFloatBinarySplitDefaultLeft: 44 return node->mutable_sparse_float_binary_split_default_left() 45 ->mutable_split(); 46 case TreeNode::kSparseFloatBinarySplitDefaultRight: 47 return node->mutable_sparse_float_binary_split_default_right() 48 ->mutable_split(); 49 case TreeNode::kDenseFloatBinarySplit: 50 return node->mutable_dense_float_binary_split(); 51 default: 52 LOG(FATAL) << "Unknown node type encountered."; 53 } 54 return nullptr; 55 } 56 57 } // namespace 58 59 RandomTreeGen::RandomTreeGen(tensorflow::random::SimplePhilox* rng, 60 int dense_feature_size, int sparse_feature_size) 61 : rng_(rng), 62 dense_feature_size_(dense_feature_size), 63 sparse_feature_size_(sparse_feature_size) {} 64 65 namespace { 66 void AddWeightAndMetadata( 67 boosted_trees::trees::DecisionTreeEnsembleConfig* ret) { 68 // Assign the weight of the tree to 1 and say that this weight was updated 69 // only once. 
70 ret->add_tree_weights(1.0); 71 auto* meta = ret->add_tree_metadata(); 72 meta->set_num_tree_weight_updates(1); 73 } 74 75 } // namespace 76 77 boosted_trees::trees::DecisionTreeEnsembleConfig 78 RandomTreeGen::GenerateEnsemble(int depth, int tree_count) { 79 boosted_trees::trees::DecisionTreeEnsembleConfig ret; 80 *(ret.add_trees()) = Generate(depth); 81 AddWeightAndMetadata(&ret); 82 for (int i = 1; i < tree_count; ++i) { 83 *(ret.add_trees()) = Generate(ret.trees(0)); 84 AddWeightAndMetadata(&ret); 85 } 86 return ret; 87 } 88 89 DecisionTreeConfig RandomTreeGen::Generate(const DecisionTreeConfig& tree) { 90 DecisionTreeConfig ret = tree; 91 for (auto& node : *ret.mutable_nodes()) { 92 if (node.node_case() == TreeNode::kLeaf) { 93 node.mutable_leaf()->mutable_sparse_vector()->set_value( 94 0, rng_->RandFloat()); 95 continue; 96 } 97 // Original node is a split. Re-generate it's type but retain the split node 98 // indices. 99 DenseFloatBinarySplit* split = GetSplit(&node); 100 const int left_id = split->left_id(); 101 const int right_id = split->right_id(); 102 GenerateSplit(&node, left_id, right_id); 103 } 104 return ret; 105 } 106 107 DecisionTreeConfig RandomTreeGen::Generate(int depth) { 108 DecisionTreeConfig ret; 109 // Add root, 110 TreeNode* node = ret.add_nodes(); 111 GenerateSplit(node, 1, 2); 112 if (depth == 1) { 113 // Add left and right leaves. 
114 TreeNode* left = ret.add_nodes(); 115 left->mutable_leaf()->mutable_sparse_vector()->add_index(0); 116 left->mutable_leaf()->mutable_sparse_vector()->add_value(rng_->RandFloat()); 117 TreeNode* right = ret.add_nodes(); 118 right->mutable_leaf()->mutable_sparse_vector()->add_index(0); 119 right->mutable_leaf()->mutable_sparse_vector()->add_value( 120 rng_->RandFloat()); 121 return ret; 122 } else { 123 DecisionTreeConfig left_branch = Generate(depth - 1); 124 DecisionTreeConfig right_branch = Generate(depth - 1); 125 Combine(&ret, &left_branch, &right_branch); 126 return ret; 127 } 128 } 129 130 void RandomTreeGen::Combine(DecisionTreeConfig* root, 131 DecisionTreeConfig* left_branch, 132 DecisionTreeConfig* right_branch) { 133 const int left_branch_size = left_branch->nodes_size(); 134 CHECK_EQ(1, root->nodes_size()); 135 // left_branch starts its index at 1. right_branch starts its index at 136 // (left_branch_size + 1). 137 auto* root_node = root->mutable_nodes(0); 138 DenseFloatBinarySplit* root_split = GetSplit(root_node); 139 root_split->set_left_id(1); 140 root_split->set_right_id(left_branch_size + 1); 141 // Shift left/right branch's indices internally so that everything is 142 // consistent. 143 ShiftNodeIndex(left_branch, 1); 144 ShiftNodeIndex(right_branch, left_branch_size + 1); 145 146 // Complexity O(branch node size). No proto copying though. 
147 AppendNodes(root, left_branch->mutable_nodes()); 148 AppendNodes(root, right_branch->mutable_nodes()); 149 } 150 151 void RandomTreeGen::ShiftNodeIndex(DecisionTreeConfig* tree, int shift) { 152 for (TreeNode& node : *(tree->mutable_nodes())) { 153 DenseFloatBinarySplit* split = nullptr; 154 switch (node.node_case()) { 155 case TreeNode::kLeaf: 156 break; 157 case TreeNode::kSparseFloatBinarySplitDefaultLeft: 158 split = node.mutable_sparse_float_binary_split_default_left() 159 ->mutable_split(); 160 break; 161 case TreeNode::kSparseFloatBinarySplitDefaultRight: 162 split = node.mutable_sparse_float_binary_split_default_right() 163 ->mutable_split(); 164 break; 165 case TreeNode::kDenseFloatBinarySplit: 166 split = node.mutable_dense_float_binary_split(); 167 break; 168 default: 169 LOG(FATAL) << "Unknown node type encountered."; 170 } 171 if (split) { 172 split->set_left_id(shift + split->left_id()); 173 split->set_right_id(shift + split->right_id()); 174 } 175 } 176 } 177 178 void RandomTreeGen::GenerateSplit(TreeNode* node, int left_id, int right_id) { 179 const double denseSplitProb = 180 sparse_feature_size_ == 0 181 ? 1.0 182 : static_cast<double>(dense_feature_size_) / 183 (dense_feature_size_ + sparse_feature_size_); 184 // Generate the tree such that it has equal probability of going left and 185 // right when the feature is missing. 
186 static constexpr float kLeftProb = 0.5; 187 188 DenseFloatBinarySplit* split; 189 int feature_size; 190 if (rng_->RandFloat() < denseSplitProb) { 191 feature_size = dense_feature_size_; 192 split = node->mutable_dense_float_binary_split(); 193 } else { 194 feature_size = sparse_feature_size_; 195 if (rng_->RandFloat() < kLeftProb) { 196 split = node->mutable_sparse_float_binary_split_default_left() 197 ->mutable_split(); 198 } else { 199 split = node->mutable_sparse_float_binary_split_default_right() 200 ->mutable_split(); 201 } 202 } 203 split->set_threshold(rng_->RandFloat()); 204 split->set_feature_column(rng_->Uniform(feature_size)); 205 split->set_left_id(left_id); 206 split->set_right_id(right_id); 207 } 208 209 } // namespace testutil 210 } // namespace boosted_trees 211 } // namespace tensorflow 212