// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h"

#include <cfloat>
#include <queue>
#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h"
#include "tensorflow/core/lib/random/distribution_sampler.h"

namespace tensorflow {
namespace tensorforest {

// When creating evaluators for the split candidates, use these
// for the left and right return values.
static const int32 LEFT_INDEX = 0;
static const int32 RIGHT_INDEX = 1;

GrowStats::GrowStats(const TensorForestParams& params, int32 depth)
    : weight_sum_(0),
      depth_(depth),
      params_(params),
      split_after_samples_(ResolveParam(params.split_after_samples(), depth)),
      num_splits_to_consider_(
          ResolveParam(params.num_splits_to_consider(), depth)),
      num_outputs_(params.num_outputs()) {}

void GrowStats::AddSplit(const decision_trees::BinaryNode& split,
                         const std::unique_ptr<TensorDataSet>& input_data,
                         const InputTarget* target, int example) {
  // It's possible that the split collection calls AddSplit, but we actually
  // have all the splits we need and are just waiting for them to be fully
  // initialized.
  if (splits_.size() < num_splits_to_consider_) {
    splits_.push_back(split);
    evaluators_.emplace_back(
        CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX));
    AddSplitStats(target, example);
  }

  if (input_data != nullptr && target != nullptr &&
      params_.initialize_average_splits()) {
    AdditionalInitializationExample(input_data, target, example);
  }
}

void GrowStats::RemoveSplit(int split_num) {
  splits_.erase(splits_.begin() + split_num);
  evaluators_.erase(evaluators_.begin() + split_num);
  RemoveSplitStats(split_num);
}

// ------------------------ Classification --------------------------- //

ClassificationStats::ClassificationStats(const TensorForestParams& params,
                                         int32 depth)
    : GrowStats(params, depth), finish_early_(false) {
  // Early splitting params.
  if (params.finish_type().type() == SPLIT_FINISH_BASIC) {
    min_split_samples_ = split_after_samples_;
    finish_sample_epoch_ = 1;
    finish_check_every_ = split_after_samples_ * 2;
  } else {
    if (!params.has_dominate_fraction() || !params.has_min_split_samples()) {
      LOG(FATAL) << "dominate_fraction and min_split_samples "
                 << "required for early-finish strategy.";
    } else {
      min_split_samples_ = ResolveParam(params.min_split_samples(), depth);
      finish_check_every_ =
          ResolveParam(params.finish_type().check_every_steps(), depth);
      finish_sample_epoch_ = min_split_samples_ / finish_check_every_;

      dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_);
      if (dominate_fraction_ <= 0 || dominate_fraction_ > 1.0) {
        LOG(FATAL) << "Invalid dominate fraction " << dominate_fraction_;
      }
    }
  }

  // Pruning params.
  if (params.pruning_type().type() != SPLIT_PRUNE_NONE) {
    prune_check_every_ =
        ResolveParam(params.pruning_type().prune_every_samples(), depth);
    prune_sample_epoch_ = 1;
    prune_fraction_ = 0.0;
    switch (params_.pruning_type().type()) {
      case SPLIT_PRUNE_HALF:
        prune_fraction_ = 0.5;
        break;
      case SPLIT_PRUNE_QUARTER:
        prune_fraction_ = 0.25;
        break;
      case SPLIT_PRUNE_10_PERCENT:
        prune_fraction_ = 0.10;
        break;
      case SPLIT_PRUNE_HOEFFDING:
        dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_);
        half_ln_dominate_frac_ = 0.5 * log(1.0 / (1.0 - dominate_fraction_));
        break;
      default:
        LOG(WARNING) << "Unknown pruning type";
    }
  } else {
    prune_check_every_ = split_after_samples_ * 2;
    prune_sample_epoch_ = 1;
  }

  if (params.use_running_stats_method()) {
    left_gini_.reset(new RunningGiniScores());
    right_gini_.reset(new RunningGiniScores());
  }

  uint64 time_seed = static_cast<uint64>(std::clock());
  single_rand_ = std::unique_ptr<random::PhiloxRandom>(
      new random::PhiloxRandom(time_seed));
  rng_ = std::unique_ptr<random::SimplePhilox>(
      new random::SimplePhilox(single_rand_.get()));
}
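
// A note on the counters initialized above: each (*_sample_epoch_,
// *_check_every_) pair gates how often a check runs.  CheckFinishEarly and
// CheckPrune fire once weight_sum_ reaches epoch * check_every and then
// advance the epoch, so each check runs roughly once per check_every units
// of accumulated example weight.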

void ClassificationStats::AdditionalInitializationExample(
    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
    int example) {
  const int32 new_target = target->GetTargetAsClassIndex(example, 0);
  std::unordered_set<int> to_erase;
  for (auto it = half_initialized_splits_.begin();
       it != half_initialized_splits_.end(); ++it) {
    if (it->second != new_target) {
      auto& split = splits_[it->first];
      if (split.has_inequality_left_child_test()) {
        auto& test = split.inequality_left_child_test();
        auto* thresh =
            split.mutable_inequality_left_child_test()->mutable_threshold();
        if (test.has_feature_id()) {
          const float val =
              input_data->GetExampleValue(example, test.feature_id());
          thresh->set_float_value((thresh->float_value() + val) / 2);
        }
      }
      to_erase.insert(it->first);
    }
  }

  for (const int split_id : to_erase) {
    half_initialized_splits_.erase(split_id);
  }
}
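
// Context for the loop above: with initialize_average_splits, a split built
// from a single example is only half initialized; its threshold is that
// example's feature value.  The first time an example with a different class
// arrives, the threshold is moved to the midpoint of the two examples'
// feature values and the split is dropped from half_initialized_splits_.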

bool ClassificationStats::IsFinished() const {
  bool basic = (weight_sum_ >= split_after_samples_) && !is_pure();
  return basic || finish_early_;
}

float ClassificationStats::MaybeCachedGiniScore(int split, float* left_sum,
                                                float* right_sum) const {
  if (left_gini_ == nullptr) {
    return GiniScore(split, left_sum, right_sum);
  } else {
    *left_sum = left_gini_->sum(split);
    const float left = WeightedSmoothedGini(
        *left_sum, left_gini_->square(split), num_outputs_);

    *right_sum = right_gini_->sum(split);
    const float right = WeightedSmoothedGini(
        *right_sum, right_gini_->square(split), num_outputs_);

    return left + right;
  }
}
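
// Conceptually (see WeightedSmoothedGini in stat_utils), each side's score
// is a Laplace-smoothed, weight-scaled Gini impurity: with smoothed counts
// c_i' = c_i + 1 and S = sum_i(c_i'), the score is
//   S * (1 - sum_i(c_i'^2) / S^2) = S - sum_i(c_i'^2) / S.
// Lower is better, and scaling by S lets the left and right scores be added
// directly.  With use_running_stats_method, RunningGiniScores maintains the
// per-split sum and sum-of-squares incrementally instead of recomputing them
// from raw counts on every call.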

void ClassificationStats::AddExample(
    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
    int example) {
  const int64 int_label = target->GetTargetAsClassIndex(example, 0);
  const float weight = target->GetTargetWeight(example);

  for (int i = 0; i < num_splits(); ++i) {
    auto& eval = evaluators_[i];
    if (eval->Decide(input_data, example) == LEFT_INDEX) {
      if (left_gini_ != nullptr) {
        left_gini_->update(i, left_count(i, int_label), weight);
      }
      ClassificationAddLeftExample(i, int_label, weight);
    } else {
      if (right_gini_ != nullptr) {
        right_gini_->update(i, right_count(i, int_label), weight);
      }
      ClassificationAddRightExample(i, int_label, weight);
    }
  }

  ClassificationAddTotalExample(int_label, weight);

  weight_sum_ += weight;

  CheckFinishEarly();
  CheckPrune();
}

void ClassificationStats::CheckPrune() {
  if (params_.pruning_type().type() == SPLIT_PRUNE_NONE || IsFinished() ||
      weight_sum_ < prune_sample_epoch_ * prune_check_every_) {
    return;
  }
  ++prune_sample_epoch_;

  if (params_.pruning_type().type() == SPLIT_PRUNE_HOEFFDING) {
    CheckPruneHoeffding();
    return;
  }

  const int to_remove = num_splits() * prune_fraction_;
  if (to_remove <= 0) {
    return;
  }

  // Pair ordering is first-then-second by default, so no custom comparison is
  // needed.  std::greater makes this a min-heap, whose top is the smallest
  // (best) score among the currently-kept candidates for removal.
  std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>,
                      std::greater<std::pair<float, int>>>
      worst;

  // Track indices that are in the heap so we can iterate over them
  // by largest-first later.
  std::set<int> indices;

  for (int i = 0; i < num_splits(); ++i) {
    float left, right;
    const float split_score = MaybeCachedGiniScore(i, &left, &right);
    if (worst.size() < to_remove) {
      worst.push(std::pair<float, int>(split_score, i));
      indices.insert(i);
    } else if (worst.top().first < split_score) {
      indices.erase(worst.top().second);
      worst.pop();
      worst.push(std::pair<float, int>(split_score, i));
      indices.insert(i);
    }
  }

  // Traverse indices from largest to smallest so that removing one split
  // doesn't shift the positions of the splits still waiting to be removed.
  for (auto it = indices.rbegin(); it != indices.rend(); ++it) {
    RemoveSplit(*it);
  }
}
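
// For example, with 10 candidate splits and SPLIT_PRUNE_QUARTER, to_remove
// is int(10 * 0.25) = 2; the heap then ends up holding the two candidates
// with the highest (worst) Gini scores, and exactly those two are removed.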

void ClassificationStats::CheckPruneHoeffding() {
  std::vector<float> split_scores(num_splits());
  // Find the best split score.
  float best_split_score = FLT_MAX;
  for (int i = 0; i < num_splits(); ++i) {
    float left, right;
    split_scores[i] = MaybeCachedGiniScore(i, &left, &right);
    if (split_scores[i] < best_split_score) {
      best_split_score = split_scores[i];
    }
  }

  // We apply the Hoeffding bound to the difference between the best split
  // score and the i-th split score.
  // Raw Gini ranges from 0 to 1 - (1/n), but our Gini score is weighted.
  const float num_classes = params_.num_outputs();
  const float gini_diff_range = weight_sum_ * (1.0 - 1.0 / num_classes);
  float epsilon = gini_diff_range * sqrt(half_ln_dominate_frac_ / weight_sum_);
  for (int i = num_splits() - 1; i >= 0; i--) {
    if (split_scores[i] - best_split_score > epsilon) {
      RemoveSplit(i);
    }
  }
}
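
// The epsilon above is the standard Hoeffding bound
//   epsilon = R * sqrt(ln(1 / delta) / (2 * n)),
// with R = gini_diff_range, n = weight_sum_, and delta = 1 -
// dominate_fraction_.  half_ln_dominate_frac_ precomputes 0.5 * ln(1 / delta),
// so the two forms are identical.  A split trailing the best score by more
// than epsilon is worse with probability at least dominate_fraction_, and is
// therefore pruned.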

void ClassificationStats::CheckFinishEarly() {
  if (weight_sum_ < min_split_samples_ ||
      weight_sum_ < finish_sample_epoch_ * finish_check_every_) {
    return;
  }
  ++finish_sample_epoch_;

  if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_HOEFFDING) {
    CheckFinishEarlyHoeffding();
  } else if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_BOOTSTRAP) {
    CheckFinishEarlyBootstrap();
  }
}

void ClassificationStats::CheckFinishEarlyHoeffding() {
  // Each term in the Gini impurity can range from 0 to 0.5 * 0.5.
  float range = 0.25 * static_cast<float>(params_.num_outputs()) * weight_sum_;

  float hoeffding_bound =
      range * sqrt(log(1.0 / (1.0 - dominate_fraction_)) / (2.0 * weight_sum_));

  float unused_left_sum, unused_right_sum;
  std::function<float(int)> score_fn =
      std::bind(&ClassificationStats::MaybeCachedGiniScore, this,
                std::placeholders::_1, &unused_left_sum, &unused_right_sum);

  float best_score;
  int32 best_index;
  float second_best_score;
  int32 second_best_index;
  GetTwoBest(num_splits(), score_fn, &best_score, &best_index,
             &second_best_score, &second_best_index);

  finish_early_ = (second_best_score - best_score) > hoeffding_bound;
}

void ClassificationStats::MakeBootstrapWeights(int index,
                                               std::vector<float>* weights) {
  int n = weight_sum_;
  float denom = static_cast<float>(n) + static_cast<float>(num_outputs_);
  for (int i = 0; i < num_outputs_; ++i) {
    // Use the Laplace smoothed per-class probabilities when generating the
    // bootstrap samples.
    (*weights)[i] = (left_count(index, i) + 1.0) / denom;
    (*weights)[num_outputs_ + i] = (right_count(index, i) + 1.0) / denom;
  }
}

int ClassificationStats::NumBootstrapSamples() const {
  float p = 1.0 - dominate_fraction_;
  int bootstrap_samples = 1;
  while (p < 1.0) {
    ++bootstrap_samples;
    p = p * 2;
  }
  return bootstrap_samples;
}
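
// This returns the smallest k for which (1 - dominate_fraction_) * 2^(k-1)
// reaches 1.0, i.e. roughly 1 + ceil(log2(1 / (1 - dominate_fraction_))).
// For example, dominate_fraction_ = 0.99 gives p = 0.01, which doubles seven
// times before reaching 1.0, so 8 bootstrap samples are drawn.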

void ClassificationStats::CheckFinishEarlyBootstrap() {
  float unused_left_sum, unused_right_sum;
  std::function<float(int)> score_fn =
      std::bind(&ClassificationStats::MaybeCachedGiniScore, this,
                std::placeholders::_1, &unused_left_sum, &unused_right_sum);

  float best_score;
  int32 best_index;
  float second_best_score;
  int32 second_best_index;
  GetTwoBest(num_splits(), score_fn, &best_score, &best_index,
             &second_best_score, &second_best_index);

  std::vector<float> weights1(num_outputs_ * 2);
  MakeBootstrapWeights(best_index, &weights1);
  random::DistributionSampler ds1(weights1);

  std::vector<float> weights2(num_outputs_ * 2);
  MakeBootstrapWeights(second_best_index, &weights2);
  random::DistributionSampler ds2(weights2);

  const int bootstrap_samples = NumBootstrapSamples();

  int worst_g1 = 0;
  for (int i = 0; i < bootstrap_samples; i++) {
    int g1 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds1, rng_.get());
    worst_g1 = std::max(worst_g1, g1);
  }

  // 99 is assumed to exceed any value BootstrapGini can return here.
  int best_g2 = 99;
  for (int i = 0; i < bootstrap_samples; i++) {
    int g2 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds2, rng_.get());
    best_g2 = std::min(best_g2, g2);
  }

  finish_early_ = worst_g1 < best_g2;
}
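
// Since lower Gini is better, the test above asks whether even the worst
// bootstrapped score of the best split beats the best bootstrapped score of
// the runner-up.  If so, the best split dominates and collection finishes
// early.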

bool ClassificationStats::BestSplit(SplitCandidate* best) const {
  float min_score = FLT_MAX;
  int best_index = -1;
  float best_left_sum, best_right_sum;

  // Calculate sums.
  for (int i = 0; i < num_splits(); ++i) {
    float left_sum, right_sum;
    const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
    // Find the lowest Gini, skipping useless splits that send every example
    // to the same side.
    if (left_sum > 0 && right_sum > 0 && split_score < min_score) {
      min_score = split_score;
      best_index = i;
      best_left_sum = left_sum;
      best_right_sum = right_sum;
    }
  }

  // This could happen if all the splits are useless.
  if (best_index < 0) {
    return false;
  }

  // Fill in stats to be used for leaf model.
  *best->mutable_split() = splits_[best_index];
  auto* left = best->mutable_left_stats();
  left->set_weight_sum(best_left_sum);
  auto* right = best->mutable_right_stats();
  right->set_weight_sum(best_right_sum);
  InitLeafClassStats(best_index, left, right);

  return true;
}

// ------------------------ Dense Classification --------------------------- //
void DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  const int32 num_classes = params_.num_outputs();
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
  const auto& class_stats =
      slot.post_init_leaf_stats().classification().dense_counts();

  // Total counts.
  for (int i = 0; i < num_classes; ++i) {
    total_counts_[i] = class_stats.value(i).float_value();
    num_outputs_seen_ += total_counts_[i] != 0;
  }

  // Candidate counts and splits.
  int split_num = 0;
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().dense_counts();
    for (int i = 0; i < num_classes; ++i) {
      const float val = left_stats.value(i).float_value();
      mutable_left_count(split_num, i) = val;
      MaybeInitializeRunningCount(split_num, val);
    }
    ++split_num;
  }
}

void DenseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  auto* class_stats = slot->mutable_post_init_leaf_stats()
                          ->mutable_classification()
                          ->mutable_dense_counts();
  for (int i = 0; i < num_outputs_; ++i) {
    class_stats->add_value()->set_float_value(total_counts_[i]);
  }

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_dense_counts();
    for (int i = 0; i < num_outputs_; ++i) {
      left_stats->add_value()->set_float_value(left_count(split_num, i));
    }
  }
}
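
// Note that only the left-side counts are serialized per candidate; the
// right-side counts are implicit, and GiniScore / InitLeafClassStats below
// reconstruct them as total_counts_ - left_count.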

float DenseClassificationGrowStats::GiniScore(int split, float* left_sum,
                                              float* right_sum) const {
  float left_square = 0, right_square = 0;
  *left_sum = 0;
  *right_sum = 0;
  for (int j = 0; j < num_outputs_; ++j) {
    const float left = left_count(split, j);
    *left_sum += left;
    left_square += left * left;
    const float right = right_count(split, j);
    *right_sum += right;
    right_square += right * right;
  }

  const float left_score =
      WeightedSmoothedGini(*left_sum, left_square, num_outputs_);
  const float right_score =
      WeightedSmoothedGini(*right_sum, right_square, num_outputs_);
  return left_score + right_score;
}

void DenseClassificationGrowStats::InitLeafClassStats(
    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
  auto* left_class_stats = left_stats->mutable_classification();
  auto* left_counts = left_class_stats->mutable_dense_counts();
  for (int i = 0; i < params_.num_outputs(); ++i) {
    left_counts->add_value()->set_float_value(left_count(best_split_index, i));
  }

  auto* right_class_stats = right_stats->mutable_classification();
  auto* right_counts = right_class_stats->mutable_dense_counts();
  for (int i = 0; i < params_.num_outputs(); ++i) {
    right_counts->add_value()->set_float_value(total_counts_[i] -
                                               left_count(best_split_index, i));
  }
}

// ------------------------ Sparse Classification --------------------------- //
void SparseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
  const auto& class_stats =
      slot.post_init_leaf_stats().classification().sparse_counts();

  // Total counts.
  for (auto const& entry : class_stats.sparse_value()) {
    total_counts_[entry.first] = entry.second.float_value();
  }

  // Candidate counts and splits.
  int split_num = 0;
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().sparse_counts();
    for (auto const& entry : left_stats.sparse_value()) {
      const float val = entry.second.float_value();
      left_counts_[split_num][entry.first] = val;
      MaybeInitializeRunningCount(split_num, val);
    }
    ++split_num;
  }
}

void SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  auto* class_stats = slot->mutable_post_init_leaf_stats()
                          ->mutable_classification()
                          ->mutable_sparse_counts()
                          ->mutable_sparse_value();
  for (const auto& entry : total_counts_) {
    decision_trees::Value val;
    val.set_float_value(entry.second);
    (*class_stats)[entry.first] = val;
  }

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_sparse_counts()
                           ->mutable_sparse_value();
    for (const auto& entry : left_counts_[split_num]) {
      decision_trees::Value val;
      val.set_float_value(entry.second);
      (*left_stats)[entry.first] = val;
    }
  }
}

float SparseClassificationGrowStats::GiniScore(int split, float* left_sum,
                                               float* right_sum) const {
  float left_square = 0, right_square = 0;
  *left_sum = 0;
  *right_sum = 0;
  for (const auto& entry : total_counts_) {
    const int label = entry.first;
    float left = 0;
    float right = 0;
    auto it = left_counts_[split].find(label);
    if (it == left_counts_[split].end()) {
      right = entry.second;
    } else {
      left = it->second;
      right = entry.second - it->second;
    }
    *left_sum += left;
    left_square += left * left;
    *right_sum += right;
    right_square += right * right;
  }
  const int32 num_classes = params_.num_outputs();
  const float left_score =
      WeightedSmoothedGini(*left_sum, left_square, num_classes);
  const float right_score =
      WeightedSmoothedGini(*right_sum, right_square, num_classes);
  return left_score + right_score;
}

void SparseClassificationGrowStats::InitLeafClassStats(
    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
  auto* left_class_stats = left_stats->mutable_classification();
  auto* left_counts =
      left_class_stats->mutable_sparse_counts()->mutable_sparse_value();
  auto* right_class_stats = right_stats->mutable_classification();
  auto* right_counts =
      right_class_stats->mutable_sparse_counts()->mutable_sparse_value();

  for (const auto& entry : total_counts_) {
    auto it = left_counts_[best_split_index].find(entry.first);
    if (it == left_counts_[best_split_index].end()) {
      (*right_counts)[entry.first].set_float_value(entry.second);
    } else {
      const float left = it->second;
      const float right = entry.second - it->second;
      (*left_counts)[entry.first].set_float_value(left);
      if (right > 0) {
        (*right_counts)[entry.first].set_float_value(right);
      }
    }
  }
}

// -------------------- FixedSizeClassStats --------------------------------- //

// FixedSizeClassStats implements the "SpaceSaving" algorithm by
// Ahmed Metwally, Divyakant Agrawal and Amr El Abbadi.  See for example
// https://pdfs.semanticscholar.org/72f1/5aba2e67b1cc9cd1fb12c99e101c4c1aae4b.pdf

int argmin(const std::unordered_map<int, float>& m) {
  int c = -1;
  float f = FLT_MAX;
  for (const auto it : m) {
    if (it.second < f) {
      f = it.second;
      c = it.first;
    }
  }
  return c;
}

void FixedSizeClassStats::accumulate(int c, float w) {
  auto it = class_weights_.find(c);
  if (it != class_weights_.end()) {
    it->second += w;
    if (c == smallest_weight_class_) {
      smallest_weight_class_ = argmin(class_weights_);
    }
    return;
  }

  if (class_weights_.size() < n_) {
    class_weights_.insert(it, std::pair<int, float>(c, w));
    if (class_weights_.size() == n_) {
      // Can't assume the last-added class has the smallest weight, because
      // the w's might all be different.
      smallest_weight_class_ = argmin(class_weights_);
    }
    return;
  }

  // This is the slightly unintuitive heart of the SpaceSaving algorithm:
  // if the map is full and we see a new class, we find the entry with the
  // smallest weight and "take it over": we add our weight to its weight,
  // and assign it all to the newly seen class.
  it = class_weights_.find(smallest_weight_class_);
  float new_weight = it->second + w;
  class_weights_.erase(it);
  class_weights_[c] = new_weight;
  smallest_weight_class_ = argmin(class_weights_);
}
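
// A small worked example with n_ = 2: after accumulate(a, 1.0) and
// accumulate(b, 2.0), class_weights_ is {a: 1, b: 2} and a is the smallest.
// A new class c arriving via accumulate(c, 1.0) evicts a and takes over its
// weight, leaving {b: 2, c: 2}.  c's stored weight overstates its true
// weight by at most the evicted weight, which is the SpaceSaving error
// bound.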

float FixedSizeClassStats::get_weight(int c) const {
  // Every entry in class_weights_ might be overstated by as much as the
  // smallest_weight.  We therefore assume that each has been overstated
  // by smallest_weight / 2.0, and we re-distribute that mass over all
  // num_classes_ classes.
  float smallest_weight = 0.0;
  auto it = class_weights_.find(smallest_weight_class_);
  if (it != class_weights_.end()) {
    smallest_weight = it->second;
  }
  float w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
  it = class_weights_.find(c);
  if (it != class_weights_.end()) {
    w += it->second - smallest_weight / 2.0;
  }
  return w;
}
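
// Continuing the example above with num_classes_ = 4 and class_weights_ =
// {b: 2, c: 2}: smallest_weight is 2, so the redistributed mass per class is
// (2 / 2) * 2 / 4 = 0.5.  A tracked class then gets 0.5 + (2 - 1) = 1.5, an
// untracked class gets 0.5, and all four weights sum to the total observed
// weight of 4.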

void FixedSizeClassStats::set_sum_and_square(float* sum, float* square) const {
  *sum = 0.0;
  *square = 0.0;

  float smallest_weight = 0.0;
  auto it = class_weights_.find(smallest_weight_class_);
  if (it != class_weights_.end()) {
    smallest_weight = it->second;
  }

  float w;
  for (const auto it : class_weights_) {
    *sum += it.second;
    w = get_weight(it.first);
    *square += w * w;
  }

  w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
  *square += (num_classes_ - n_) * w * w;
}

void FixedSizeClassStats::ExtractFromProto(
    const decision_trees::SparseVector& sparse_vector) {
  for (const auto& it : sparse_vector.sparse_value()) {
    class_weights_[it.first] = it.second.float_value();
  }
  if (class_weights_.size() == n_) {
    smallest_weight_class_ = argmin(class_weights_);
  }
}

void FixedSizeClassStats::PackToProto(
    decision_trees::SparseVector* sparse_vector) const {
  for (const auto it : class_weights_) {
    (*sparse_vector->mutable_sparse_value())[it.first].set_float_value(
        it.second);
  }
}

// --------------------- FixedSizeSparseClassificationGrowStats ------------- //

void FixedSizeSparseClassificationGrowStats::ExtractFromProto(
    const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();

  // Candidate counts and splits.
  int split_num = 0;
  left_counts_.clear();
  right_counts_.clear();
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().sparse_counts();
    left_counts_.emplace_back(params_.num_classes_to_track(),
                              params_.num_outputs());
    left_counts_[split_num].ExtractFromProto(left_stats);
    const auto& right_stats =
        cand.right_stats().classification().sparse_counts();
    right_counts_.emplace_back(params_.num_classes_to_track(),
                               params_.num_outputs());
    right_counts_[split_num].ExtractFromProto(right_stats);
    ++split_num;
  }
}

void FixedSizeSparseClassificationGrowStats::PackToProto(
    FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_sparse_counts();
    left_counts_[split_num].PackToProto(left_stats);
    auto* right_stats = cand->mutable_right_stats()
                            ->mutable_classification()
                            ->mutable_sparse_counts();
    right_counts_[split_num].PackToProto(right_stats);
  }
}

float FixedSizeSparseClassificationGrowStats::GiniScore(
    int split, float* left_sum, float* right_sum) const {
  float left_square, right_square;
  left_counts_[split].set_sum_and_square(left_sum, &left_square);
  right_counts_[split].set_sum_and_square(right_sum, &right_square);
  const int32 num_classes = params_.num_outputs();
  const float left_score =
      WeightedSmoothedGini(*left_sum, left_square, num_classes);
  const float right_score =
      WeightedSmoothedGini(*right_sum, right_square, num_classes);
  return left_score + right_score;
}

void FixedSizeSparseClassificationGrowStats::InitLeafClassStats(
    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
  auto* left_class_stats = left_stats->mutable_classification();
  auto* left_counts = left_class_stats->mutable_sparse_counts();
  left_counts_[best_split_index].PackToProto(left_counts);

  auto* right_class_stats = right_stats->mutable_classification();
  auto* right_counts = right_class_stats->mutable_sparse_counts();
  right_counts_[best_split_index].PackToProto(right_counts);
}

// --------------------- Least Squares Regression --------------------------- //
void LeastSquaresRegressionGrowStats::ExtractFromProto(
    const FertileSlot& slot) {
  const int32 num_outputs = params_.num_outputs();
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
  const auto& total_sums =
      slot.post_init_leaf_stats().regression().mean_output();
  const auto& total_squares =
      slot.post_init_leaf_stats().regression().mean_output_squares();

  // Total counts.
  for (int i = 0; i < num_outputs; ++i) {
    total_sum_[i] = total_sums.value(i).float_value();
    total_sum_squares_[i] = total_squares.value(i).float_value();
  }

  // Candidate counts and splits.
  int split_num = 0;
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& sums = cand.left_stats().regression().mean_output();
    const auto& squares = cand.left_stats().regression().mean_output_squares();
    for (int i = 0; i < num_outputs; ++i) {
      left_sum(split_num, i) = sums.value(i).float_value();
      left_square(split_num, i) = squares.value(i).float_value();
    }
    left_counts_[split_num] = cand.left_stats().weight_sum();
    ++split_num;
  }
}

void LeastSquaresRegressionGrowStats::PackToProto(FertileSlot* slot) const {
  const int32 num_outputs = params_.num_outputs();
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  auto* total_sums = slot->mutable_post_init_leaf_stats()
                         ->mutable_regression()
                         ->mutable_mean_output();
  auto* total_squares = slot->mutable_post_init_leaf_stats()
                            ->mutable_regression()
                            ->mutable_mean_output_squares();

  for (int i = 0; i < total_sum_.size(); ++i) {
    total_sums->add_value()->set_float_value(total_sum_[i]);
    total_squares->add_value()->set_float_value(total_sum_squares_[i]);
  }

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* sums =
        cand->mutable_left_stats()->mutable_regression()->mutable_mean_output();
    auto* squares = cand->mutable_left_stats()
                        ->mutable_regression()
                        ->mutable_mean_output_squares();
    for (int i = 0; i < num_outputs; ++i) {
      sums->add_value()->set_float_value(left_sum(split_num, i));
      squares->add_value()->set_float_value(left_square(split_num, i));
    }
    cand->mutable_left_stats()->set_weight_sum(left_counts_[split_num]);
  }
}

void LeastSquaresRegressionGrowStats::AddExample(
    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
    int example) {
  const int32 num_outputs = params_.num_outputs();
  // Update splits.
  for (int i = 0; i < num_splits(); ++i) {
    auto& eval = evaluators_[i];
    if (eval->Decide(input_data, example) == LEFT_INDEX) {
      for (int j = 0; j < num_outputs; ++j) {
        const float output = target->GetTargetAsContinuous(example, j);
        left_sum(i, j) += output;
        left_square(i, j) += output * output;
      }
      ++left_counts_[i];
    }
  }

  // Update totals.
  for (int i = 0; i < num_outputs; ++i) {
    const float output = target->GetTargetAsContinuous(example, i);
    total_sum_[i] += output;
    total_sum_squares_[i] += output * output;
  }
  weight_sum_ += 1.0;
}

float LeastSquaresRegressionGrowStats::SplitVariance(int split) const {
  float total_variance = 0;
  for (int i = 0; i < params_.num_outputs(); ++i) {
    // Left side
    const float le_x = left_sum(split, i) / left_counts_[split];

    const float le_x2 = left_square(split, i) / left_counts_[split];
    total_variance += le_x2 - le_x * le_x;

    // Right side
    const float re_x = (total_sum_[i] - left_sum(split, i)) /
                       (weight_sum_ - left_counts_[split]);

    const float re_x2 = (total_sum_squares_[i] - left_square(split, i)) /
                        (weight_sum_ - left_counts_[split]);
    total_variance += re_x2 - re_x * re_x;
  }
  return total_variance;
}
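
// For each output dimension this computes Var[x] = E[x^2] - (E[x])^2 on each
// side of the split from the running sums and sums of squares, then adds the
// left and right variances across outputs.  BestSplit below picks the
// candidate minimizing this total.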

bool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const {
  float min_score = FLT_MAX;
  int best_index = -1;
  const int32 num_outputs = params_.num_outputs();
  for (int i = 0; i < num_splits(); ++i) {
    if (left_counts_[i] > 0 && weight_sum_ - left_counts_[i] > 0) {
      const float split_score = SplitVariance(i);
      if (split_score < min_score) {
        min_score = split_score;
        best_index = i;
      }
    }
  }

  // This could happen if all the splits are useless.
  if (best_index < 0) {
    return false;
  }

  // Fill in stats to be used for leaf model.
  *best->mutable_split() = splits_[best_index];
  // Left
  auto* left = best->mutable_left_stats();
  auto* left_reg_stats = left->mutable_regression();
  left->set_weight_sum(left_counts_[best_index]);
  auto* left_output_sum = left_reg_stats->mutable_mean_output();
  for (int i = 0; i < num_outputs; ++i) {
    left_output_sum->add_value()->set_float_value(left_sum(best_index, i));
  }

  // Right
  auto* right = best->mutable_right_stats();
  auto* right_reg_stats = right->mutable_regression();
  right->set_weight_sum(weight_sum_ - left_counts_[best_index]);
  auto* right_output_sum = right_reg_stats->mutable_mean_output();
  for (int i = 0; i < num_outputs; ++i) {
    right_output_sum->add_value()->set_float_value(total_sum_[i] -
                                                   left_sum(best_index, i));
  }
  return true;
}

bool LeastSquaresRegressionGrowStats::IsFinished() const {
  return weight_sum_ >= split_after_samples_;
}

}  // namespace tensorforest
}  // namespace tensorflow