1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h" 16 17 #include <iterator> 18 #include <numeric> 19 #include <unordered_set> 20 21 #include "tensorflow/core/lib/core/errors.h" 22 #include "tensorflow/core/lib/random/philox_random.h" 23 #include "tensorflow/core/lib/random/simple_philox.h" 24 #include "tensorflow/core/platform/logging.h" 25 26 using tensorflow::Status; 27 using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig; 28 using tensorflow::random::PhiloxRandom; 29 using tensorflow::random::SimplePhilox; 30 31 namespace tensorflow { 32 namespace boosted_trees { 33 namespace utils { 34 35 Status DropoutUtils::DropOutTrees( 36 const uint64 seed, const LearningRateDropoutDrivenConfig& config, 37 const std::unordered_set<int32>& trees_not_to_drop, 38 const std::vector<float>& weights, std::vector<int32>* dropped_trees, 39 std::vector<float>* original_weights) { 40 // Verify params. 41 if (dropped_trees == nullptr) { 42 return errors::Internal("Dropped trees is nullptr."); 43 } 44 if (original_weights == nullptr) { 45 return errors::InvalidArgument("Original weights is nullptr."); 46 } 47 const float dropout_probability = config.dropout_probability(); 48 if (dropout_probability < 0 || dropout_probability > 1) { 49 return errors::InvalidArgument( 50 "Dropout probability must be in [0,1] range"); 51 } 52 const float probability_of_skipping_dropout = 53 config.probability_of_skipping_dropout(); 54 if (probability_of_skipping_dropout < 0 || 55 probability_of_skipping_dropout > 1) { 56 return errors::InvalidArgument( 57 "Probability of skiping dropout must be in [0,1] range"); 58 } 59 const auto num_trees = weights.size(); 60 61 dropped_trees->clear(); 62 original_weights->clear(); 63 64 // If dropout is no op, return. 65 if (dropout_probability == 0 || probability_of_skipping_dropout == 1.0) { 66 return Status::OK(); 67 } 68 69 // Roll the dice for each tree. 70 PhiloxRandom philox(seed); 71 SimplePhilox rng(&philox); 72 73 std::vector<int32> trees_to_keep; 74 75 // What is the probability of skipping dropout altogether. 76 if (probability_of_skipping_dropout != 0) { 77 // First roll the dice - do we do dropout 78 double roll = rng.RandDouble(); 79 if (roll < probability_of_skipping_dropout) { 80 // don't do dropout 81 return Status::OK(); 82 } 83 } 84 85 for (int32 i = 0; i < num_trees; ++i) { 86 // We can't drop some of the trees: for example, bias tree in batch mode, 87 // or current tree that is built, in the batch mode. 88 if (trees_not_to_drop.find(i) != trees_not_to_drop.end()) { 89 continue; 90 } 91 double roll = rng.RandDouble(); 92 if (roll >= dropout_probability) { 93 trees_to_keep.push_back(i); 94 } else { 95 dropped_trees->push_back(i); 96 } 97 } 98 99 // Sort the dropped trees indices. 100 std::sort(dropped_trees->begin(), dropped_trees->end()); 101 for (const int32 dropped_tree : *dropped_trees) { 102 original_weights->push_back(weights[dropped_tree]); 103 } 104 105 return Status::OK(); 106 } 107 108 void DropoutUtils::GetTreesWeightsForAddingTrees( 109 const std::vector<int32>& dropped_trees, 110 const std::vector<float>& dropped_trees_original_weights, 111 const int32 new_trees_first_index, const int32 num_trees_to_add, 112 std::vector<float>* current_weights, std::vector<int32>* num_updates) { 113 CHECK(num_updates->size() == current_weights->size()); 114 // combined weight of trees that were dropped out 115 116 const float dropped_sum = 117 std::accumulate(dropped_trees_original_weights.begin(), 118 dropped_trees_original_weights.end(), 0.0); 119 120 const int num_dropped = dropped_trees.size(); 121 122 // Allocate additional weight for the new tree 123 const float total_new_trees_weight = dropped_sum / (num_dropped + 1); 124 125 for (int i = 0; i < num_trees_to_add; ++i) { 126 const int32 new_tree_index = new_trees_first_index + i; 127 if (new_tree_index < current_weights->size()) { 128 // We have the entries in weights and updates for this tree already 129 (*current_weights)[new_tree_index] = 130 total_new_trees_weight / num_trees_to_add; 131 (*num_updates)[new_tree_index]++; 132 } else { 133 // We need to add a new entry. This is non-batch mode. 134 current_weights->push_back(total_new_trees_weight / num_trees_to_add); 135 num_updates->push_back(1); 136 } 137 } 138 139 for (int32 i = 0; i < dropped_trees.size(); ++i) { 140 const int32 dropped = dropped_trees[i]; 141 const float original_weight = dropped_trees_original_weights[i]; 142 const float new_weight = original_weight * num_dropped / (num_dropped + 1); 143 (*current_weights)[dropped] = new_weight; 144 // Update the number of updates per tree. 145 ++(*num_updates)[dropped]; 146 } 147 } 148 149 } // namespace utils 150 } // namespace boosted_trees 151 } // namespace tensorflow 152