// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_
#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_
#include <set>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h"
#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"

namespace tensorflow {
namespace tensorforest {

// Base class for tracking stats necessary to split a leaf.
// Holds and tracks stats for every candidate split.
class GrowStats {
 public:
  virtual ~GrowStats() {}
  // Perform any initialization.
  virtual void Initialize() = 0;

  // Add an example to any stats being collected.
  virtual void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
                          const InputTarget* target, int example) = 0;

  // Fill in the best split; return false if none were valid.
  virtual bool BestSplit(SplitCandidate* best) const = 0;

  // Return true if this leaf is finished splitting.
  virtual bool IsFinished() const = 0;

  // Get the BinaryNode for candidate split number split_num.
  const decision_trees::BinaryNode& Split(int split_num) const {
    return splits_[split_num];
  }

  // Clear all state.
  virtual void Clear() {
    weight_sum_ = 0;
    splits_.clear();
    evaluators_.clear();
    ClearInternal();
  }

  virtual void ExtractFromProto(const FertileSlot& slot) = 0;
  virtual void PackToProto(FertileSlot* slot) const = 0;

  // Add split to the list of candidate splits.
  void AddSplit(const decision_trees::BinaryNode& split,
                const std::unique_ptr<TensorDataSet>& input_data,
                const InputTarget* target, int example);
  virtual void AdditionalInitializationExample(
      const std::unique_ptr<TensorDataSet>& input_data,
      const InputTarget* target, int example) {}
  void RemoveSplit(int split_num);

  int num_splits() const { return splits_.size(); }

  float weight_sum() const { return weight_sum_; }

  virtual bool IsInitialized() const {
    return weight_sum_ > 0 || splits_.size() == num_splits_to_consider_;
  }

  int32 depth() const { return depth_; }

 protected:
  GrowStats(const TensorForestParams& params, int32 depth);

  // Function called by AddSplit for subclasses to initialize stats for a split.
  virtual void AddSplitStats(const InputTarget* target, int example) = 0;

  virtual void RemoveSplitStats(int split_num) = 0;

  // Function called by Clear for subclasses to clear their state.
  virtual void ClearInternal() = 0;

  std::vector<decision_trees::BinaryNode> splits_;
  std::vector<std::unique_ptr<DecisionNodeEvaluator>> evaluators_;

  float weight_sum_;

  const int32 depth_;

  const TensorForestParams& params_;

  // We cache these because they're used often.
  const int split_after_samples_;
  const int num_splits_to_consider_;

  const int32 num_outputs_;
};
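
// Illustrative sketch (not part of the library) of how a caller might drive
// the GrowStats interface; `params`, `candidate_node`, `dataset`, `target`,
// and the example ids below are placeholders:
//
//   std::unique_ptr<GrowStats> stats(
//       new DenseClassificationGrowStats(params, /*depth=*/0));
//   stats->Initialize();
//   stats->AddSplit(candidate_node, dataset, target, example_id);
//   while (!stats->IsFinished()) {
//     stats->AddExample(dataset, target, next_example_id);
//   }
//   SplitCandidate best;
//   if (stats->BestSplit(&best)) {
//     // Grow the tree using `best`.
//   }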

// Doesn't track anything; useful for systems that want to track split
// candidates but train the model in some other way.
class SimpleStats : public GrowStats {
 public:
  SimpleStats(const TensorForestParams& params, int32 depth)
      : GrowStats(params, depth) {}
  void Initialize() override {}

  void ExtractFromProto(const FertileSlot& slot) override {}
  void PackToProto(FertileSlot* slot) const override {}

  void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
                  const InputTarget* target, int example) override {
    weight_sum_ += target->GetTargetWeight(example);
  }

  bool BestSplit(SplitCandidate* best) const override { return false; }

  bool IsFinished() const override {
    return weight_sum_ >= split_after_samples_;
  }

 protected:
  void AddSplitStats(const InputTarget* target, int example) override {}
  void RemoveSplitStats(int split_num) override {}
  void ClearInternal() override {}
};

// Tracks, for each candidate split, the sum and the sum of squares of the
// per-class weights on one side of that split, for use in Gini calculations.
class RunningGiniScores {
 public:
  float sum(int split) const { return sum_[split]; }
  float square(int split) const { return square_[split]; }

  void update(int split, float old_val, float weight) {
    sum_[split] += weight;
    const float new_val = old_val + weight;
    square_[split] = square_[split] - old_val * old_val + new_val * new_val;
  }

  void add_split() {
    sum_.push_back(0);
    square_.push_back(0);
  }

  void remove_split(int i) {
    sum_.erase(sum_.begin() + i);
    square_.erase(square_.begin() + i);
  }

 private:
  std::vector<float> sum_;
  std::vector<float> square_;
};
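
// A short worked sketch (illustrative, not part of the library) of how the
// running totals above can feed a Gini computation.  If w_c is the weight
// accumulated for class c on one side of a split, then
//   sum    = \sum_c w_c
//   square = \sum_c w_c^2
// and the usual Gini impurity of that side is 1 - square / (sum * sum).
// update(split, old_val, weight) keeps `square` consistent when one class's
// weight grows from old_val to old_val + weight, since
//   square' = square - old_val^2 + (old_val + weight)^2.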

class ClassificationStats : public GrowStats {
 public:
  ClassificationStats(const TensorForestParams& params, int32 depth);

  bool IsFinished() const override;

  void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
                  const InputTarget* target, int example) override;

  void AdditionalInitializationExample(
      const std::unique_ptr<TensorDataSet>& input_data,
      const InputTarget* target, int example) override;

  bool IsInitialized() const override {
    return weight_sum_ > 0 || (splits_.size() == num_splits_to_consider_ &&
                               half_initialized_splits_.empty());
  }

  bool BestSplit(SplitCandidate* best) const override;
  // Once best_split_index has been chosen as the best split,
  // InitLeafClassStats is used to initialize the LeafStats of the two
  // new leaves.
  virtual void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
                                  LeafStat* right_stats) const = 0;

 protected:
  virtual float GiniScore(int split, float* left_sum,
                          float* right_sum) const = 0;

  // is_pure should return true if at most one class label has been seen
  // at the node, and false if two or more have been seen.
  virtual bool is_pure() const = 0;
  virtual float left_count(int split, int class_num) const = 0;
  virtual float right_count(int split, int class_num) const = 0;

  virtual void ClassificationAddLeftExample(int split, int64 int_label,
                                            float weight) = 0;
  virtual void ClassificationAddRightExample(int split, int64 int_label,
                                             float weight) {
    // Does nothing by default, but sub-classes can override.
  }
  virtual void ClassificationAddTotalExample(int64 int_label, float weight) = 0;

  virtual void ClassificationAddSplitStats() = 0;
  virtual void ClassificationRemoveSplitStats(int split) = 0;

  void AddSplitStats(const InputTarget* target, int example) override {
    if (left_gini_ != nullptr) {
      left_gini_->add_split();
      right_gini_->add_split();
    }
    if (params_.initialize_average_splits()) {
      if (splits_[splits_.size() - 1].has_inequality_left_child_test()) {
        half_initialized_splits_[splits_.size() - 1] =
            target->GetTargetAsClassIndex(example, 0);
      }
    }
    ClassificationAddSplitStats();
  }
  void RemoveSplitStats(int split) override {
    if (left_gini_ != nullptr) {
      left_gini_->remove_split(split);
      right_gini_->remove_split(split);
    }
    ClassificationRemoveSplitStats(split);
  }

  // Virtual so we can override them in tests.
  virtual void CheckFinishEarly();
  virtual void CheckFinishEarlyHoeffding();
  virtual void CheckFinishEarlyBootstrap();

  virtual void CheckPrune();

  // Implements SplitPruningStrategyType::SPLIT_PRUNE_HOEFFDING.
  void CheckPruneHoeffding();
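  // Background sketch of the standard Hoeffding-bound argument (an
  // illustration; the exact check is implemented outside this header).
  // Pruning checks of this kind typically compare the gap between two splits'
  // scores against a bound of the form
  //   epsilon = sqrt(R^2 * ln(1 / delta) / (2 * n)),
  // where R bounds the score's range, n is the number of samples seen, and
  // delta is the allowed failure probability (here roughly
  // 1 - dominate_fraction_).  With half_ln_dominate_frac_ defined below as
  // 0.5 * ln(1 / (1 - dominate_fraction_)), this is equivalent to
  //   epsilon = R * sqrt(half_ln_dominate_frac_ / n).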

  // Return the Gini score, either calculated from the sums and squares saved
  // in left_gini_ and right_gini_, or otherwise calculated from raw counts.
  float MaybeCachedGiniScore(int split, float* left_sum,
                             float* right_sum) const;

  // Initialize the sum and squares of left_gini_ and right_gini_ for the given
  // split and value (as extracted from a proto), if left_gini_ isn't null.
  void MaybeInitializeRunningCount(int split, float val) {
    if (left_gini_ != nullptr) {
      left_gini_->update(split, 0, val);
      right_gini_->update(split, 0, val);
    }
  }

  int NumBootstrapSamples() const;

  // Populate *weights with the smoothed per-class frequencies needed to
  // initialize a DistributionSampler.
  void MakeBootstrapWeights(int index, std::vector<float>* weights);

  // Accessors for RunningGiniScores objects, for testing.
  virtual const std::unique_ptr<RunningGiniScores>& get_left_gini() const {
    return left_gini_;
  }
  virtual const std::unique_ptr<RunningGiniScores>& get_right_gini() const {
    return right_gini_;
  }

 private:
  // Tracks how many check_every_samples epochs have gone by, as measured by
  // weight_sum_.
  int32 finish_sample_epoch_;
  int32 finish_check_every_;
  int32 prune_sample_epoch_;
  int32 prune_check_every_;
  bool finish_early_;
  int32 min_split_samples_;
  float dominate_fraction_;
  float prune_fraction_;

  // When using SPLIT_PRUNE_HOEFFDING, we precompute and store
  // 0.5 * ln(1 / (1.0 - dominate_fraction_)).
  float half_ln_dominate_frac_;

  std::unique_ptr<random::PhiloxRandom> single_rand_;
  std::unique_ptr<random::SimplePhilox> rng_;

  std::unique_ptr<RunningGiniScores> left_gini_;
  std::unique_ptr<RunningGiniScores> right_gini_;

  // Stores split number -> class that was first seen.
  std::unordered_map<int, int32> half_initialized_splits_;
};

// Tracks classification stats by storing class counts densely.
class DenseClassificationGrowStats : public ClassificationStats {
 public:
  DenseClassificationGrowStats(const TensorForestParams& params, int32 depth)
      : ClassificationStats(params, depth) {}

  void Initialize() override {
    Clear();
    total_counts_.resize(num_outputs_);
  }

  void ExtractFromProto(const FertileSlot& slot) override;
  void PackToProto(FertileSlot* slot) const override;

  void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
                          LeafStat* right_stats) const override;

 protected:
  void ClassificationAddSplitStats() override {
    left_counts_.resize(num_outputs_ * num_splits());
  }
  void ClassificationRemoveSplitStats(int split_num) override {
    left_counts_.erase(left_counts_.begin() + num_outputs_ * split_num,
                       left_counts_.begin() + num_outputs_ * (split_num + 1));
  }
  void ClearInternal() override {
    total_counts_.clear();
    left_counts_.clear();
    num_outputs_seen_ = 0;
  }

  bool is_pure() const override { return num_outputs_seen_ <= 1; }

  void ClassificationAddLeftExample(int split, int64 int_label,
                                    float weight) override {
    mutable_left_count(split, int_label) += weight;
  }
  void ClassificationAddTotalExample(int64 int_label, float weight) override {
    num_outputs_seen_ += total_counts_[int_label] == 0 && weight > 0;
    total_counts_[int_label] += weight;
  }

  float GiniScore(int split, float* left_sum, float* right_sum) const override;

  float left_count(int split, int class_num) const override {
    return left_counts_[split * num_outputs_ + class_num];
  }
  float right_count(int split, int class_num) const override {
    return total_counts_[class_num] -
           left_counts_[split * num_outputs_ + class_num];
  }

 private:
  inline float& mutable_left_count(int split, int class_num) {
    return left_counts_[split * num_outputs_ + class_num];
  }
  // Total class counts seen at this leaf
  std::vector<float> total_counts_;

  // Also track the number of distinct classes seen, so we can avoid splitting
  // pure leaves.
  int num_outputs_seen_;

  // Left-branch taken class counts at this leaf for each split.
  // This is a flat vector for memory-performance reasons.
  // left_counts_[i * num_outputs_ + j] has the j-th class count for split i.
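  // For example (pure arithmetic, just illustrating the layout above): with
  // num_outputs_ == 3, the counts for split 2 occupy indices 6 through 8.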
  std::vector<float> left_counts_;
};

// Tracks classification stats by storing class counts sparsely.
class SparseClassificationGrowStats : public ClassificationStats {
 public:
  SparseClassificationGrowStats(const TensorForestParams& params, int32 depth)
      : ClassificationStats(params, depth) {}

  void Initialize() override { Clear(); }

  void ExtractFromProto(const FertileSlot& slot) override;
  void PackToProto(FertileSlot* slot) const override;

  void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
                          LeafStat* right_stats) const override;

 protected:
  void ClassificationAddSplitStats() override {
    left_counts_.resize(num_splits());
  }
  void ClassificationRemoveSplitStats(int split_num) override {
    left_counts_.erase(left_counts_.begin() + split_num,
                       left_counts_.begin() + (split_num + 1));
  }
  void ClearInternal() override {
    total_counts_.clear();
    left_counts_.clear();
  }

  bool is_pure() const override { return total_counts_.size() <= 1; }

  void ClassificationAddLeftExample(int split, int64 int_label,
                                    float weight) override {
    left_counts_[split][int_label] += weight;
  }
  void ClassificationAddTotalExample(int64 int_label, float weight) override {
    total_counts_[int_label] += weight;
  }

  float GiniScore(int split, float* left_sum, float* right_sum) const override;

  float left_count(int split, int class_num) const override {
    return left_counts_[split].at(class_num);
  }
  float right_count(int split, int class_num) const override {
    return total_counts_.at(class_num) - left_counts_[split].at(class_num);
  }

 private:
  // Total class counts seen at this leaf
  std::unordered_map<int, float> total_counts_;

  // Left-branch taken class counts at this leaf for each split.
  // left_counts_[i][j] has the j-th class count for split i.
  std::vector<std::unordered_map<int, float>> left_counts_;
};

// Accumulates weights for the most popular classes while only using a
// fixed amount of space.
class FixedSizeClassStats {
 public:
  // n specifies how many classes are tracked.
  FixedSizeClassStats(int n, int num_classes)
      : n_(n), num_classes_(num_classes), smallest_weight_class_(-1) {}

  // Add weight w to the class c.
  void accumulate(int c, float w);

  // Return the approximate accumulated weight for class c.  If c isn't one
  // of the n-most popular classes, this can be 0 even if c has accumulated
  // some weight.
  float get_weight(int c) const;

  // Put the sum of all weights seen into *sum, and
  // \sum_c get_weight(c)^2
  // into *square.  *sum will be exact, but *square will be approximate.
  void set_sum_and_square(float* sum, float* square) const;

  void ExtractFromProto(const decision_trees::SparseVector& sparse_vector);
  void PackToProto(decision_trees::SparseVector* sparse_vector) const;

 private:
  // For our typical use cases, n_ is between 10 and 100, so there's no
  // need to track the smallest weight with a min_heap or the like.
  int n_;
  int num_classes_;

  // Tracks the class with the smallest weight; not set until
  // class_weights_.size() == n_.
  int smallest_weight_class_;

  std::unordered_map<int, float> class_weights_;
};
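
// Illustrative sketch (not part of the library) of how FixedSizeClassStats
// might be used; the class ids and weights below are made up:
//
//   FixedSizeClassStats stats(/*n=*/2, /*num_classes=*/1000);
//   stats.accumulate(/*c=*/3, /*w=*/1.0);
//   stats.accumulate(/*c=*/7, /*w=*/2.0);
//   stats.accumulate(/*c=*/9, /*w=*/0.5);  // Only the top n classes are kept.
//   float w3 = stats.get_weight(3);        // Approximate; may be 0 if class 3
//                                          // fell out of the top n.
//   float sum, square;
//   stats.set_sum_and_square(&sum, &square);  // sum exact, square approximate.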

// Tracks classification stats sparsely in a fixed amount of space.
class FixedSizeSparseClassificationGrowStats : public ClassificationStats {
 public:
  FixedSizeSparseClassificationGrowStats(const TensorForestParams& params,
                                         int32 depth)
      : ClassificationStats(params, depth) {}

  void Initialize() override { Clear(); }

  void ExtractFromProto(const FertileSlot& slot) override;
  void PackToProto(FertileSlot* slot) const override;

  void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
                          LeafStat* right_stats) const override;

 protected:
  void ClassificationAddSplitStats() override {
    FixedSizeClassStats stats(params_.num_classes_to_track(),
                              params_.num_outputs());
    left_counts_.resize(num_splits(), stats);
    right_counts_.resize(num_splits(), stats);
  }
  void ClassificationRemoveSplitStats(int split_num) override {
    left_counts_.erase(left_counts_.begin() + split_num,
                       left_counts_.begin() + (split_num + 1));
    right_counts_.erase(right_counts_.begin() + split_num,
                        right_counts_.begin() + (split_num + 1));
  }
  void ClearInternal() override {
    left_counts_.clear();
    right_counts_.clear();
  }

  bool is_pure() const override { return first_two_classes_seen_.size() <= 1; }

  void ClassificationAddLeftExample(int split, int64 int_label,
                                    float weight) override {
    left_counts_[split].accumulate(int_label, weight);
  }
  void ClassificationAddRightExample(int split, int64 int_label,
                                     float weight) override {
    right_counts_[split].accumulate(int_label, weight);
  }
  void ClassificationAddTotalExample(int64 int_label, float weight) override {
    if (is_pure()) {
      first_two_classes_seen_.insert(int_label);
    }
  }

  float GiniScore(int split, float* left_sum, float* right_sum) const override;

  float left_count(int split, int class_num) const override {
    return left_counts_[split].get_weight(class_num);
  }

  float right_count(int split, int class_num) const override {
    return right_counts_[split].get_weight(class_num);
  }

 private:
  std::vector<FixedSizeClassStats> left_counts_;
  std::vector<FixedSizeClassStats> right_counts_;

  // We keep track of the first two class labels seen, so we can tell if
  // the node is pure (= all of one class) or not.
  std::set<int> first_two_classes_seen_;
};

// Tracks regression stats using least-squares minimization.
class LeastSquaresRegressionGrowStats : public GrowStats {
 public:
  LeastSquaresRegressionGrowStats(const TensorForestParams& params, int32 depth)
      : GrowStats(params, depth) {}

  void Initialize() override {
    Clear();
    total_sum_.resize(num_outputs_);
    total_sum_squares_.resize(num_outputs_);
  }

  void ExtractFromProto(const FertileSlot& slot) override;
  void PackToProto(FertileSlot* slot) const override;

  void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
                  const InputTarget* target, int example) override;
  bool BestSplit(SplitCandidate* best) const override;
  bool IsFinished() const override;

 protected:
  // Returns the variance of the given split.
  float SplitVariance(int split) const;

  void AddSplitStats(const InputTarget* target, int example) override {
    left_sums_.resize(num_outputs_ * num_splits());
    left_squares_.resize(num_outputs_ * num_splits());
    left_counts_.push_back(0);
  }
  void RemoveSplitStats(int split_num) override {
    left_sums_.erase(left_sums_.begin() + num_outputs_ * split_num,
                     left_sums_.begin() + num_outputs_ * (split_num + 1));
    left_squares_.erase(left_squares_.begin() + num_outputs_ * split_num,
                        left_squares_.begin() + num_outputs_ * (split_num + 1));
    left_counts_.erase(left_counts_.begin() + split_num,
                       left_counts_.begin() + (split_num + 1));
  }

  void ClearInternal() override {
    total_sum_.clear();
    total_sum_squares_.clear();
    left_sums_.clear();
    left_squares_.clear();
  }

 private:
  // Convenience methods for accessing the flat sum and square vectors.
  inline const float& left_sum(int split, int output_num) const {
    return left_sums_[split * num_outputs_ + output_num];
  }
  inline float& left_sum(int split, int output_num) {
    return left_sums_[split * num_outputs_ + output_num];
  }
  inline const float& left_square(int split, int output_num) const {
    return left_squares_[split * num_outputs_ + output_num];
  }
  inline float& left_square(int split, int output_num) {
    return left_squares_[split * num_outputs_ + output_num];
  }

  // Total sums and squares seen at this leaf.
  // sum[i] is the sum of the i-th output.
  std::vector<float> total_sum_;
  std::vector<float> total_sum_squares_;

  // Per-split sums and squares, stored flat for performance.
  // left_sums_[i * num_outputs_ + j] has the j-th sum for split i.
  std::vector<float> left_sums_;
  std::vector<float> left_squares_;

  // The number of examples seen at each split.
  std::vector<int64> left_counts_;
};
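
// Illustrative note (a standard identity, not a statement from this file):
// with per-output sums and sums of squares, the variance of output j on one
// side of a split can be computed as
//   mean     = sum_j / count
//   variance = sum_squares_j / count - mean * mean,
// which is why only flat sum / square / count vectors need to be tracked per
// candidate split.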

}  // namespace tensorforest
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_