Home | History | Annotate | Download | only in testutil
      1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // =============================================================================
     15 #include "tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h"
     16 
     17 #include "tensorflow/core/lib/random/philox_random.h"
     18 #include "tensorflow/core/lib/random/simple_philox.h"
     19 #include "tensorflow/core/platform/logging.h"
     20 
     21 namespace tensorflow {
     22 namespace boosted_trees {
     23 namespace testutil {
     24 
     25 using boosted_trees::trees::DenseFloatBinarySplit;
     26 using tensorflow::boosted_trees::trees::DecisionTreeConfig;
     27 using tensorflow::boosted_trees::trees::TreeNode;
     28 
     29 namespace {
     30 
     31 // Append the given nodes to tree with transfer of pointer ownership.
     32 // nodes will not be usable upon return.
     33 template <typename T>
     34 void AppendNodes(DecisionTreeConfig* tree, T* nodes) {
     35   std::reverse(nodes->pointer_begin(), nodes->pointer_end());
     36   while (!nodes->empty()) {
     37     tree->mutable_nodes()->AddAllocated(nodes->ReleaseLast());
     38   }
     39 }
     40 
     41 DenseFloatBinarySplit* GetSplit(TreeNode* node) {
     42   switch (node->node_case()) {
     43     case TreeNode::kSparseFloatBinarySplitDefaultLeft:
     44       return node->mutable_sparse_float_binary_split_default_left()
     45           ->mutable_split();
     46     case TreeNode::kSparseFloatBinarySplitDefaultRight:
     47       return node->mutable_sparse_float_binary_split_default_right()
     48           ->mutable_split();
     49     case TreeNode::kDenseFloatBinarySplit:
     50       return node->mutable_dense_float_binary_split();
     51     default:
     52       LOG(FATAL) << "Unknown node type encountered.";
     53   }
     54   return nullptr;
     55 }
     56 
     57 }  // namespace
     58 
     59 RandomTreeGen::RandomTreeGen(tensorflow::random::SimplePhilox* rng,
     60                              int dense_feature_size, int sparse_feature_size)
     61     : rng_(rng),
     62       dense_feature_size_(dense_feature_size),
     63       sparse_feature_size_(sparse_feature_size) {}
     64 
     65 namespace {
     66 void AddWeightAndMetadata(
     67     boosted_trees::trees::DecisionTreeEnsembleConfig* ret) {
     68   // Assign the weight of the tree to 1 and say that this weight was updated
     69   // only once.
     70   ret->add_tree_weights(1.0);
     71   auto* meta = ret->add_tree_metadata();
     72   meta->set_num_tree_weight_updates(1);
     73 }
     74 
     75 }  //  namespace
     76 
     77 boosted_trees::trees::DecisionTreeEnsembleConfig
     78 RandomTreeGen::GenerateEnsemble(int depth, int tree_count) {
     79   boosted_trees::trees::DecisionTreeEnsembleConfig ret;
     80   *(ret.add_trees()) = Generate(depth);
     81   AddWeightAndMetadata(&ret);
     82   for (int i = 1; i < tree_count; ++i) {
     83     *(ret.add_trees()) = Generate(ret.trees(0));
     84     AddWeightAndMetadata(&ret);
     85   }
     86   return ret;
     87 }
     88 
     89 DecisionTreeConfig RandomTreeGen::Generate(const DecisionTreeConfig& tree) {
     90   DecisionTreeConfig ret = tree;
     91   for (auto& node : *ret.mutable_nodes()) {
     92     if (node.node_case() == TreeNode::kLeaf) {
     93       node.mutable_leaf()->mutable_sparse_vector()->set_value(
     94           0, rng_->RandFloat());
     95       continue;
     96     }
     97     // Original node is a split. Re-generate it's type but retain the split node
     98     // indices.
     99     DenseFloatBinarySplit* split = GetSplit(&node);
    100     const int left_id = split->left_id();
    101     const int right_id = split->right_id();
    102     GenerateSplit(&node, left_id, right_id);
    103   }
    104   return ret;
    105 }
    106 
    107 DecisionTreeConfig RandomTreeGen::Generate(int depth) {
    108   DecisionTreeConfig ret;
    109   // Add root,
    110   TreeNode* node = ret.add_nodes();
    111   GenerateSplit(node, 1, 2);
    112   if (depth == 1) {
    113     // Add left and right leaves.
    114     TreeNode* left = ret.add_nodes();
    115     left->mutable_leaf()->mutable_sparse_vector()->add_index(0);
    116     left->mutable_leaf()->mutable_sparse_vector()->add_value(rng_->RandFloat());
    117     TreeNode* right = ret.add_nodes();
    118     right->mutable_leaf()->mutable_sparse_vector()->add_index(0);
    119     right->mutable_leaf()->mutable_sparse_vector()->add_value(
    120         rng_->RandFloat());
    121     return ret;
    122   } else {
    123     DecisionTreeConfig left_branch = Generate(depth - 1);
    124     DecisionTreeConfig right_branch = Generate(depth - 1);
    125     Combine(&ret, &left_branch, &right_branch);
    126     return ret;
    127   }
    128 }
    129 
    130 void RandomTreeGen::Combine(DecisionTreeConfig* root,
    131                             DecisionTreeConfig* left_branch,
    132                             DecisionTreeConfig* right_branch) {
    133   const int left_branch_size = left_branch->nodes_size();
    134   CHECK_EQ(1, root->nodes_size());
    135   // left_branch starts its index at 1. right_branch starts its index at
    136   // (left_branch_size + 1).
    137   auto* root_node = root->mutable_nodes(0);
    138   DenseFloatBinarySplit* root_split = GetSplit(root_node);
    139   root_split->set_left_id(1);
    140   root_split->set_right_id(left_branch_size + 1);
    141   // Shift left/right branch's indices internally so that everything is
    142   // consistent.
    143   ShiftNodeIndex(left_branch, 1);
    144   ShiftNodeIndex(right_branch, left_branch_size + 1);
    145 
    146   // Complexity O(branch node size). No proto copying though.
    147   AppendNodes(root, left_branch->mutable_nodes());
    148   AppendNodes(root, right_branch->mutable_nodes());
    149 }
    150 
    151 void RandomTreeGen::ShiftNodeIndex(DecisionTreeConfig* tree, int shift) {
    152   for (TreeNode& node : *(tree->mutable_nodes())) {
    153     DenseFloatBinarySplit* split = nullptr;
    154     switch (node.node_case()) {
    155       case TreeNode::kLeaf:
    156         break;
    157       case TreeNode::kSparseFloatBinarySplitDefaultLeft:
    158         split = node.mutable_sparse_float_binary_split_default_left()
    159                     ->mutable_split();
    160         break;
    161       case TreeNode::kSparseFloatBinarySplitDefaultRight:
    162         split = node.mutable_sparse_float_binary_split_default_right()
    163                     ->mutable_split();
    164         break;
    165       case TreeNode::kDenseFloatBinarySplit:
    166         split = node.mutable_dense_float_binary_split();
    167         break;
    168       default:
    169         LOG(FATAL) << "Unknown node type encountered.";
    170     }
    171     if (split) {
    172       split->set_left_id(shift + split->left_id());
    173       split->set_right_id(shift + split->right_id());
    174     }
    175   }
    176 }
    177 
    178 void RandomTreeGen::GenerateSplit(TreeNode* node, int left_id, int right_id) {
    179   const double denseSplitProb =
    180       sparse_feature_size_ == 0
    181           ? 1.0
    182           : static_cast<double>(dense_feature_size_) /
    183                 (dense_feature_size_ + sparse_feature_size_);
    184   // Generate the tree such that it has equal probability of going left and
    185   // right when the feature is missing.
    186   static constexpr float kLeftProb = 0.5;
    187 
    188   DenseFloatBinarySplit* split;
    189   int feature_size;
    190   if (rng_->RandFloat() < denseSplitProb) {
    191     feature_size = dense_feature_size_;
    192     split = node->mutable_dense_float_binary_split();
    193   } else {
    194     feature_size = sparse_feature_size_;
    195     if (rng_->RandFloat() < kLeftProb) {
    196       split = node->mutable_sparse_float_binary_split_default_left()
    197                   ->mutable_split();
    198     } else {
    199       split = node->mutable_sparse_float_binary_split_default_right()
    200                   ->mutable_split();
    201     }
    202   }
    203   split->set_threshold(rng_->RandFloat());
    204   split->set_feature_column(rng_->Uniform(feature_size));
    205   split->set_left_id(left_id);
    206   split->set_right_id(right_id);
    207 }
    208 
    209 }  // namespace testutil
    210 }  // namespace boosted_trees
    211 }  // namespace tensorflow
    212