// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"

#include <algorithm>
#include <iterator>
#include <numeric>
#include <unordered_set>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/logging.h"

using tensorflow::Status;
using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;
using tensorflow::random::PhiloxRandom;
using tensorflow::random::SimplePhilox;

namespace tensorflow {
namespace boosted_trees {
namespace utils {

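// Decides which trees to temporarily drop for the current boosting iteration.
// Each tree, except those listed in trees_not_to_drop, is dropped with
// probability config.dropout_probability(); with probability
// config.probability_of_skipping_dropout() the whole call skips dropout.
// On return, dropped_trees holds the sorted indices of the dropped trees and
// original_weights holds their weights before dropout.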
Status DropoutUtils::DropOutTrees(
    const uint64 seed, const LearningRateDropoutDrivenConfig& config,
    const std::unordered_set<int32>& trees_not_to_drop,
    const std::vector<float>& weights, std::vector<int32>* dropped_trees,
    std::vector<float>* original_weights) {
  // Verify params.
  if (dropped_trees == nullptr) {
    return errors::Internal("Dropped trees is nullptr.");
  }
  if (original_weights == nullptr) {
    return errors::InvalidArgument("Original weights is nullptr.");
  }
  const float dropout_probability = config.dropout_probability();
  if (dropout_probability < 0 || dropout_probability > 1) {
    return errors::InvalidArgument(
        "Dropout probability must be in [0,1] range");
  }
  const float probability_of_skipping_dropout =
      config.probability_of_skipping_dropout();
  if (probability_of_skipping_dropout < 0 ||
      probability_of_skipping_dropout > 1) {
    return errors::InvalidArgument(
        "Probability of skipping dropout must be in [0,1] range");
  }
  const auto num_trees = weights.size();

  dropped_trees->clear();
  original_weights->clear();

  // If dropout is a no-op, return early.
  if (dropout_probability == 0 || probability_of_skipping_dropout == 1.0) {
    return Status::OK();
  }

  // Set up the random number generator used to roll the dice for each tree.
  PhiloxRandom philox(seed);
  SimplePhilox rng(&philox);

  std::vector<int32> trees_to_keep;

  // Check whether we skip dropout altogether for this call.
  if (probability_of_skipping_dropout != 0) {
    // First roll the dice: do we apply dropout at all?
    double roll = rng.RandDouble();
    if (roll < probability_of_skipping_dropout) {
      // Skip dropout; no trees are dropped.
      return Status::OK();
    }
  }

  for (int32 i = 0; i < num_trees; ++i) {
    // Some trees can never be dropped: for example, the bias tree, or the
    // tree that is currently being built, in batch mode.
    if (trees_not_to_drop.find(i) != trees_not_to_drop.end()) {
      continue;
    }
    double roll = rng.RandDouble();
    if (roll >= dropout_probability) {
      trees_to_keep.push_back(i);
    } else {
      dropped_trees->push_back(i);
    }
  }

  // Sort the indices of the dropped trees and record their original weights.
  std::sort(dropped_trees->begin(), dropped_trees->end());
  for (const int32 dropped_tree : *dropped_trees) {
    original_weights->push_back(weights[dropped_tree]);
  }

  return Status::OK();
}

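// Redistributes weight after dropout: the num_trees_to_add new trees together
// receive 1 / (num_dropped + 1) of the dropped trees' combined original
// weight, and each dropped tree keeps num_dropped / (num_dropped + 1) of its
// original weight. num_updates tracks how many times each tree's weight has
// been updated.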
void DropoutUtils::GetTreesWeightsForAddingTrees(
    const std::vector<int32>& dropped_trees,
    const std::vector<float>& dropped_trees_original_weights,
    const int32 new_trees_first_index, const int32 num_trees_to_add,
    std::vector<float>* current_weights, std::vector<int32>* num_updates) {
  CHECK(num_updates->size() == current_weights->size());

  // Combined weight of the trees that were dropped out.
  const float dropped_sum =
      std::accumulate(dropped_trees_original_weights.begin(),
                      dropped_trees_original_weights.end(), 0.0);

  const int num_dropped = dropped_trees.size();

  // Allocate additional weight for the new trees.
  const float total_new_trees_weight = dropped_sum / (num_dropped + 1);

  for (int i = 0; i < num_trees_to_add; ++i) {
    const int32 new_tree_index = new_trees_first_index + i;
    if (new_tree_index < current_weights->size()) {
      // We have the entries in weights and updates for this tree already.
      (*current_weights)[new_tree_index] =
          total_new_trees_weight / num_trees_to_add;
      (*num_updates)[new_tree_index]++;
    } else {
      // We need to add a new entry. This is non-batch mode.
      current_weights->push_back(total_new_trees_weight / num_trees_to_add);
      num_updates->push_back(1);
    }
  }

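  // Scale each dropped tree's weight down by num_dropped / (num_dropped + 1).
  // Together with the weight allocated to the new trees above, the combined
  // contribution of dropped plus new trees equals the dropped trees' original
  // combined weight.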
  for (int32 i = 0; i < dropped_trees.size(); ++i) {
    const int32 dropped = dropped_trees[i];
    const float original_weight = dropped_trees_original_weights[i];
    const float new_weight = original_weight * num_dropped / (num_dropped + 1);
    (*current_weights)[dropped] = new_weight;
    // Update the number of updates per tree.
    ++(*num_updates)[dropped];
  }
}

}  // namespace utils
}  // namespace boosted_trees
}  // namespace tensorflow