// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_

#include <algorithm>
#include <functional>
#include <limits>
#include <ostream>
#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"

namespace tensorflow {
namespace boosted_trees {
namespace quantiles {

// Summary holding a sorted block of entries with upper-bound guarantees
// on the approximation error.
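//
// Example usage (an illustrative sketch, not part of the original
// documentation; assumes the companion WeightedQuantilesBuffer exposes its
// sorted entries via GenerateEntryList()):
//
//   WeightedQuantilesSummary<double, double> summary;
//   summary.BuildFromBufferEntries(buffer.GenerateEntryList());
//   summary.Compress(/*size_hint=*/64);
//   std::vector<double> quartiles = summary.GenerateQuantiles(4);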
template <typename ValueType, typename WeightType,
          typename CompareFn = std::less<ValueType>>
class WeightedQuantilesSummary {
 public:
  using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>;
  using BufferEntry = typename Buffer::BufferEntry;

  struct SummaryEntry {
    SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
                 const WeightType& max)
        : value(v), weight(w), min_rank(min), max_rank(max) {}

    SummaryEntry() : value(ValueType()), weight(0), min_rank(0), max_rank(0) {}

    bool operator==(const SummaryEntry& other) const {
      return value == other.value && weight == other.weight &&
             min_rank == other.min_rank && max_rank == other.max_rank;
    }
    friend std::ostream& operator<<(std::ostream& strm,
                                    const SummaryEntry& entry) {
      return strm << "{" << entry.value << ", " << entry.weight << ", "
                  << entry.min_rank << ", " << entry.max_rank << "}";
    }

    // Max rank estimate for the previous smaller value.
    WeightType PrevMaxRank() const { return max_rank - weight; }

    // Min rank estimate for the next larger value.
    WeightType NextMinRank() const { return min_rank + weight; }

    ValueType value;
    WeightType weight;
    WeightType min_rank;
    WeightType max_rank;
  };

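  // Rank bookkeeping example (illustrative, using the entry (4, 2, 3, 5)
  // from the Merge() example below): the entry carries weight 2 with rank
  // bounds [3, 5], so PrevMaxRank() = 5 - 2 = 3 is the max rank estimate
  // for the previous smaller value and NextMinRank() = 3 + 2 = 5 is the
  // min rank estimate for the next larger value.
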
  // Re-constructs the summary from the specified buffer entries.
  void BuildFromBufferEntries(const std::vector<BufferEntry>& buffer_entries) {
    entries_.clear();
    entries_.reserve(buffer_entries.size());
    WeightType cumulative_weight = 0;
    for (const auto& entry : buffer_entries) {
      WeightType current_weight = entry.weight;
      entries_.emplace_back(entry.value, current_weight, cumulative_weight,
                            cumulative_weight + current_weight);
      cumulative_weight += current_weight;
    }
  }

  // Re-constructs the summary from the specified summary entries.
  void BuildFromSummaryEntries(
      const std::vector<SummaryEntry>& summary_entries) {
    entries_.assign(summary_entries.begin(), summary_entries.end());
  }

  // Merges two summaries using an algorithm derived from merge sort, while
  // guaranteeing that the max approximation error of the merged summary is
  // no greater than the approximation errors of the individual summaries.
  // For example, consider summaries where each entry is of the form
  // (element, weight, min rank, max rank):
  // summary entries 1: (1, 3, 0, 3), (4, 2, 3, 5)
  // summary entries 2: (3, 1, 0, 1), (4, 1, 1, 2)
  // merged: (1, 3, 0, 3), (3, 1, 3, 4), (4, 3, 4, 7).
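  // To see how the merged ranks above arise (a worked step, following the
  // code below): the entry (3, 1, 3, 4) takes its min rank from
  // 0 + NextMinRank() of the already-consumed (1, 3, 0, 3), i.e. 0 + 3 = 3,
  // and its max rank from 1 + PrevMaxRank() of the still-pending
  // (4, 2, 3, 5), i.e. 1 + (5 - 2) = 4.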
  void Merge(const WeightedQuantilesSummary& other_summary) {
    // Make sure we have something to merge.
    const auto& other_entries = other_summary.entries_;
    if (other_entries.empty()) {
      return;
    }
    if (entries_.empty()) {
      BuildFromSummaryEntries(other_summary.entries_);
      return;
    }

    // Move current entries to make room for a new buffer.
    std::vector<SummaryEntry> base_entries(std::move(entries_));
    entries_.clear();
    entries_.reserve(base_entries.size() + other_entries.size());

    // Merge entries while maintaining ranks. The idea is to stack values
    // in order, which we can do in linear time since the two summaries are
    // already sorted. We keep track of the next lower rank from either
    // summary and update it as we pop elements from the summaries.
    // We handle the special case where the next two elements from either
    // summary are equal, in which case we just merge the two elements
    // and update both ranks simultaneously.
    auto it1 = base_entries.cbegin();
    auto it2 = other_entries.cbegin();
    WeightType next_min_rank1 = 0;
    WeightType next_min_rank2 = 0;
    while (it1 != base_entries.cend() && it2 != other_entries.cend()) {
      if (kCompFn(it1->value, it2->value)) {  // value1 < value2
        // Take value1 and use the last added value2 to compute
        // the min rank and the current value2 to compute the max rank.
        entries_.emplace_back(it1->value, it1->weight,
                              it1->min_rank + next_min_rank2,
                              it1->max_rank + it2->PrevMaxRank());
        // Update next min rank 1.
        next_min_rank1 = it1->NextMinRank();
        ++it1;
      } else if (kCompFn(it2->value, it1->value)) {  // value1 > value2
        // Take value2 and use the last added value1 to compute
        // the min rank and the current value1 to compute the max rank.
        entries_.emplace_back(it2->value, it2->weight,
                              it2->min_rank + next_min_rank1,
                              it2->max_rank + it1->PrevMaxRank());
        // Update next min rank 2.
        next_min_rank2 = it2->NextMinRank();
        ++it2;
      } else {  // value1 == value2
        // Straight additive merger of the two entries into one.
        entries_.emplace_back(it1->value, it1->weight + it2->weight,
                              it1->min_rank + it2->min_rank,
                              it1->max_rank + it2->max_rank);
        // Update next min ranks for both.
        next_min_rank1 = it1->NextMinRank();
        next_min_rank2 = it2->NextMinRank();
        ++it1;
        ++it2;
      }
    }

    // Fill in any residual.
    while (it1 != base_entries.cend()) {
      entries_.emplace_back(it1->value, it1->weight,
                            it1->min_rank + next_min_rank2,
                            it1->max_rank + other_entries.back().max_rank);
      ++it1;
    }
    while (it2 != other_entries.cend()) {
      entries_.emplace_back(it2->value, it2->weight,
                            it2->min_rank + next_min_rank1,
                            it2->max_rank + base_entries.back().max_rank);
      ++it2;
    }
  }

  // Compresses the summary to the desired size. The size specification is
  // considered a hint, as we always keep the first and last elements and
  // maintain strict approximation error bounds.
  // The approximation error delta is taken as the max of the requested
  // min error and 1 / size_hint.
  // After compression, the approximation error is guaranteed to increase
  // by no more than that error delta.
  // This algorithm is linear in the original size of the summary and is
  // designed to be cache-friendly.
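  // Worked numbers (illustrative, not from the original source): with
  // TotalWeight() == 100, size_hint = 10 and min_eps = 0, eps_delta below
  // is 100 * max(1.0 / 10, 0) = 10 rank units, so runs of adjacent entries
  // whose rank gap stays within 10 may be collapsed into a single entry.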
  void Compress(int64 size_hint, double min_eps = 0) {
    // No-op if we're already within the size requirement.
    size_hint = std::max(size_hint, int64{2});
    if (entries_.size() <= static_cast<size_t>(size_hint)) {
      return;
    }

    // First compute the max error bound delta resulting from this compression.
    double eps_delta = TotalWeight() * std::max(1.0 / size_hint, min_eps);

    // Compress elements, ensuring that both the approximation bounds and
    // element diversity are maintained.
    int64 add_accumulator = 0, add_step = entries_.size();
    auto write_it = entries_.begin() + 1, last_it = write_it;
    for (auto read_it = entries_.begin(); read_it + 1 != entries_.end();) {
      auto next_it = read_it + 1;
      while (next_it != entries_.end() && add_accumulator < add_step &&
             next_it->PrevMaxRank() - read_it->NextMinRank() <= eps_delta) {
        add_accumulator += size_hint;
        ++next_it;
      }
      if (read_it == next_it - 1) {
        ++read_it;
      } else {
        read_it = next_it - 1;
      }
      (*write_it++) = (*read_it);
      last_it = read_it;
      add_accumulator -= add_step;
    }
    // Write the last element and resize.
    if (last_it + 1 != entries_.end()) {
      (*write_it++) = entries_.back();
    }
    entries_.resize(write_it - entries_.begin());
  }

  // To construct the boundaries we first run a soft compress over a copy
  // of the summary and retrieve the values.
  // The resulting boundaries are guaranteed to both contain at least
  // num_boundaries unique elements and maintain approximation bounds.
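  // A minimal usage sketch (illustrative): since Compress() always keeps
  // the first and last entries, the returned vector starts at MinValue()
  // and ends at MaxValue():
  //
  //   std::vector<ValueType> bounds = summary.GenerateBoundaries(16);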
  std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const {
    std::vector<ValueType> output;
    if (entries_.empty()) {
      return output;
    }

    // Generate a soft compressed summary.
    WeightedQuantilesSummary<ValueType, WeightType, CompareFn>
        compressed_summary;
    compressed_summary.BuildFromSummaryEntries(entries_);
    // Set an epsilon for compression that's at most 1.0 / num_boundaries
    // more than the epsilon of our original summary, since the compression
    // operation adds ~1.0 / num_boundaries to the final approximation error.
    float compression_eps = ApproximationError() + (1.0 / num_boundaries);
    compressed_summary.Compress(num_boundaries, compression_eps);

    // Return boundaries.
    output.reserve(compressed_summary.entries_.size());
    for (const auto& entry : compressed_summary.entries_) {
      output.push_back(entry.value);
    }
    return output;
  }

  // To construct the desired n-quantiles, we repeatedly query n ranks from
  // the original summary. The following algorithm is an efficient,
  // cache-friendly O(n) implementation of that idea, avoiding the
  // O(n log n) cost of repeated full rank queries.
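  // For example (illustrative): num_quantiles = 4 returns 5 values that
  // approximate the min, the 25th/50th/75th percentiles, and the max,
  // using rank targets d = k * TotalWeight() / 4 for k = 0..4.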
  std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const {
    std::vector<ValueType> output;
    if (entries_.empty()) {
      return output;
    }
    num_quantiles = std::max(num_quantiles, int64{2});
    output.reserve(num_quantiles + 1);

    // Make successive rank queries to get boundaries.
    // We always keep the first (min) and last (max) entries.
    for (size_t cur_idx = 0, rank = 0;
         rank <= static_cast<size_t>(num_quantiles); ++rank) {
      // This step boils down to finding the next element sub-range defined by
      // r = (rmax[i + 1] + rmin[i + 1]) / 2 where the desired rank d < r.
      // We track 2 * d and 2 * r to avoid dividing by two.
      WeightType d_2 = 2 * (rank * entries_.back().max_rank / num_quantiles);
      size_t next_idx = cur_idx + 1;
      while (next_idx < entries_.size() &&
             d_2 >= entries_[next_idx].min_rank + entries_[next_idx].max_rank) {
        ++next_idx;
      }
      cur_idx = next_idx - 1;

      // Determine insertion order.
      if (next_idx == entries_.size() ||
          d_2 < entries_[cur_idx].NextMinRank() +
                    entries_[next_idx].PrevMaxRank()) {
        output.push_back(entries_[cur_idx].value);
      } else {
        output.push_back(entries_[next_idx].value);
      }
    }
    return output;
  }

  // Calculates the current approximation error, which should always be <= eps.
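  // Concretely, this returns the largest rank uncertainty across entries,
  //   max_i max(rmax[i] - rmin[i] - w[i],
  //             (rmax[i] - w[i]) - (rmin[i-1] + w[i-1])),
  // normalized by TotalWeight().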
  double ApproximationError() const {
    if (entries_.empty()) {
      return 0;
    }

    WeightType max_gap = 0;
    for (auto it = entries_.cbegin() + 1; it != entries_.cend(); ++it) {
      max_gap = std::max(max_gap,
                         std::max(it->max_rank - it->min_rank - it->weight,
                                  it->PrevMaxRank() - (it - 1)->NextMinRank()));
    }
    return static_cast<double>(max_gap) / TotalWeight();
  }

  ValueType MinValue() const {
    return !entries_.empty() ? entries_.front().value
                             : std::numeric_limits<ValueType>::max();
  }
  ValueType MaxValue() const {
    return !entries_.empty() ? entries_.back().value
                             : std::numeric_limits<ValueType>::max();
  }
  WeightType TotalWeight() const {
    return !entries_.empty() ? entries_.back().max_rank : 0;
  }
  int64 Size() const { return entries_.size(); }
  void Clear() { entries_.clear(); }
  const std::vector<SummaryEntry>& GetEntryList() const { return entries_; }

 private:
  // Comparison function.
  static constexpr decltype(CompareFn()) kCompFn = CompareFn();

  // Summary entries.
  std::vector<SummaryEntry> entries_;
};

template <typename ValueType, typename WeightType, typename CompareFn>
constexpr decltype(CompareFn())
    WeightedQuantilesSummary<ValueType, WeightType, CompareFn>::kCompFn;

}  // namespace quantiles
}  // namespace boosted_trees
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_