Home | History | Annotate | Download | only in gtl
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // This simple class finds the top n elements of an incrementally provided set
     17 // of elements which you push one at a time.  If the number of elements exceeds
     18 // n, the lowest elements are incrementally dropped.  At the end you get
     19 // a vector of the top elements sorted in descending order (through Extract() or
     20 // ExtractNondestructive()), or a vector of the top elements but not sorted
     21 // (through ExtractUnsorted() or ExtractUnsortedNondestructive()).
     22 //
     23 // The value n is specified in the constructor.  If there are p elements pushed
     24 // altogether:
     25 //   The total storage requirements are O(min(n, p)) elements
     26 //   The running time is O(p * log(min(n, p))) comparisons
     27 // If n is a constant, the total storage required is a constant and the running
     28 // time is linear in p.
     29 //
     30 // NOTE(zhifengc): There is a way to do this in O(min(n, p)) storage and O(p)
     31 // runtime. The basic idea is to repeatedly fill up a buffer of 2 * n elements,
     32 // discarding the lowest n elements whenever the buffer is full using a linear-
     33 // time median algorithm. This may have better performance when the input
     34 // sequence is partially sorted.
     35 //
     36 // NOTE(zhifengc): This class should be redesigned to avoid reallocating a
     37 // vector for each Extract.
     38 
     39 #ifndef TENSORFLOW_LIB_GTL_TOP_N_H_
     40 #define TENSORFLOW_LIB_GTL_TOP_N_H_
     41 
     42 #include <stddef.h>
     43 #include <algorithm>
     44 #include <functional>
     45 #include <string>
     46 #include <vector>
     47 
     48 #include "tensorflow/core/platform/logging.h"
     49 
     50 namespace tensorflow {
     51 namespace gtl {
     52 
     53 // Cmp is an stl binary predicate.  Note that Cmp is the "greater" predicate,
     54 // not the more commonly used "less" predicate.
     55 //
     56 // If you use a "less" predicate here, the TopN will pick out the bottom N
     57 // elements out of the ones passed to it, and it will return them sorted in
     58 // ascending order.
     59 //
     60 // TopN is rule-of-zero copyable and movable if its members are.
     61 template <class T, class Cmp = std::greater<T> >
     62 class TopN {
     63  public:
     64   // The TopN is in one of the three states:
     65   //
     66   //  o UNORDERED: this is the state an instance is originally in,
     67   //    where the elements are completely orderless.
     68   //
     69   //  o BOTTOM_KNOWN: in this state, we keep the invariant that there
     70   //    is at least one element in it, and the lowest element is at
     71   //    position 0. The elements in other positions remain
     72   //    unsorted. This state is reached if the state was originally
     73   //    UNORDERED and a peek_bottom() function call is invoked.
     74   //
     75   //  o HEAP_SORTED: in this state, the array is kept as a heap and
     76   //    there are exactly (limit_+1) elements in the array. This
     77   //    state is reached when at least (limit_+1) elements are
     78   //    pushed in.
     79   //
     80   //  The state transition graph is at follows:
     81   //
     82   //             peek_bottom()                (limit_+1) elements
     83   //  UNORDERED --------------> BOTTOM_KNOWN --------------------> HEAP_SORTED
     84   //      |                                                           ^
     85   //      |                      (limit_+1) elements                  |
     86   //      +-----------------------------------------------------------+
     87 
     88   enum State { UNORDERED, BOTTOM_KNOWN, HEAP_SORTED };
     89   using UnsortedIterator = typename std::vector<T>::const_iterator;
     90 
     91   // 'limit' is the maximum number of top results to return.
     92   explicit TopN(size_t limit) : TopN(limit, Cmp()) {}
     93   TopN(size_t limit, const Cmp &cmp) : limit_(limit), cmp_(cmp) {}
     94 
     95   size_t limit() const { return limit_; }
     96 
     97   // Number of elements currently held by this TopN object.  This
     98   // will be no greater than 'limit' passed to the constructor.
     99   size_t size() const { return std::min(elements_.size(), limit_); }
    100 
    101   bool empty() const { return size() == 0; }
    102 
    103   // If you know how many elements you will push at the time you create the
    104   // TopN object, you can call reserve to preallocate the memory that TopN
    105   // will need to process all 'n' pushes.  Calling this method is optional.
    106   void reserve(size_t n) { elements_.reserve(std::min(n, limit_ + 1)); }
    107 
    108   // Push 'v'.  If the maximum number of elements was exceeded, drop the
    109   // lowest element and return it in 'dropped' (if given). If the maximum is not
    110   // exceeded, 'dropped' will remain unchanged. 'dropped' may be omitted or
    111   // nullptr, in which case it is not filled in.
    112   // Requires: T is CopyAssignable, Swappable
    113   void push(const T &v) { push(v, nullptr); }
    114   void push(const T &v, T *dropped) { PushInternal(v, dropped); }
    115 
    116   // Move overloads of push.
    117   // Requires: T is MoveAssignable, Swappable
    118   void push(T &&v) {  // NOLINT(build/c++11)
    119     push(std::move(v), nullptr);
    120   }
    121   void push(T &&v, T *dropped) {  // NOLINT(build/c++11)
    122     PushInternal(std::move(v), dropped);
    123   }
    124 
    125   // Peeks the bottom result without calling Extract()
    126   const T &peek_bottom();
    127 
    128   // Extract the elements as a vector sorted in descending order.  The caller
    129   // assumes ownership of the vector and must delete it when done.  This is a
    130   // destructive operation.  The only method that can be called immediately
    131   // after Extract() is Reset().
    132   std::vector<T> *Extract();
    133 
    134   // Similar to Extract(), but makes no guarantees the elements are in sorted
    135   // order.  As with Extract(), the caller assumes ownership of the vector and
    136   // must delete it when done.  This is a destructive operation.  The only
    137   // method that can be called immediately after ExtractUnsorted() is Reset().
    138   std::vector<T> *ExtractUnsorted();
    139 
    140   // A non-destructive version of Extract(). Copy the elements in a new vector
    141   // sorted in descending order and return it.  The caller assumes ownership of
    142   // the new vector and must delete it when done.  After calling
    143   // ExtractNondestructive(), the caller can continue to push() new elements.
    144   std::vector<T> *ExtractNondestructive() const;
    145 
    146   // A non-destructive version of Extract(). Copy the elements to a given
    147   // vector sorted in descending order. After calling
    148   // ExtractNondestructive(), the caller can continue to push() new elements.
    149   // Note:
    150   //  1. The given argument must to be allocated.
    151   //  2. Any data contained in the vector prior to the call will be deleted
    152   //     from it. After the call the vector will contain only the elements
    153   //     from the data structure.
    154   void ExtractNondestructive(std::vector<T> *output) const;
    155 
    156   // A non-destructive version of ExtractUnsorted(). Copy the elements in a new
    157   // vector and return it, with no guarantees the elements are in sorted order.
    158   // The caller assumes ownership of the new vector and must delete it when
    159   // done.  After calling ExtractUnsortedNondestructive(), the caller can
    160   // continue to push() new elements.
    161   std::vector<T> *ExtractUnsortedNondestructive() const;
    162 
    163   // A non-destructive version of ExtractUnsorted(). Copy the elements into
    164   // a given vector, with no guarantees the elements are in sorted order.
    165   // After calling ExtractUnsortedNondestructive(), the caller can continue
    166   // to push() new elements.
    167   // Note:
    168   //  1. The given argument must to be allocated.
    169   //  2. Any data contained in the vector prior to the call will be deleted
    170   //     from it. After the call the vector will contain only the elements
    171   //     from the data structure.
    172   void ExtractUnsortedNondestructive(std::vector<T> *output) const;
    173 
    174   // Return an iterator to the beginning (end) of the container,
    175   // with no guarantees about the order of iteration. These iterators are
    176   // invalidated by mutation of the data structure.
    177   UnsortedIterator unsorted_begin() const { return elements_.begin(); }
    178   UnsortedIterator unsorted_end() const { return elements_.begin() + size(); }
    179 
    180   // Accessor for comparator template argument.
    181   Cmp *comparator() { return &cmp_; }
    182 
    183   // This removes all elements.  If Extract() or ExtractUnsorted() have been
    184   // called, this will put it back in an empty but useable state.
    185   void Reset();
    186 
    187  private:
    188   template <typename U>
    189   void PushInternal(U &&v, T *dropped);  // NOLINT(build/c++11)
    190 
    191   // elements_ can be in one of two states:
    192   //   elements_.size() <= limit_:  elements_ is an unsorted vector of elements
    193   //      pushed so far.
    194   //   elements_.size() > limit_:  The last element of elements_ is unused;
    195   //      the other elements of elements_ are an stl heap whose size is exactly
    196   //      limit_.  In this case elements_.size() is exactly one greater than
    197   //      limit_, but don't use "elements_.size() == limit_ + 1" to check for
    198   //      that because you'll get a false positive if limit_ == size_t(-1).
    199   std::vector<T> elements_;
    200   size_t limit_;  // Maximum number of elements to find
    201   Cmp cmp_;       // Greater-than comparison function
    202   State state_ = UNORDERED;
    203 };
    204 
    205 // ----------------------------------------------------------------------
    206 // Implementations of non-inline functions
    207 
    208 template <class T, class Cmp>
    209 template <typename U>
    210 void TopN<T, Cmp>::PushInternal(U &&v, T *dropped) {  // NOLINT(build/c++11)
    211   if (limit_ == 0) {
    212     if (dropped) *dropped = std::forward<U>(v);  // NOLINT(build/c++11)
    213     return;
    214   }
    215   if (state_ != HEAP_SORTED) {
    216     elements_.push_back(std::forward<U>(v));  // NOLINT(build/c++11)
    217     if (state_ == UNORDERED || cmp_(elements_.back(), elements_.front())) {
    218       // Easy case: we just pushed the new element back
    219     } else {
    220       // To maintain the BOTTOM_KNOWN state, we need to make sure that
    221       // the element at position 0 is always the smallest. So we put
    222       // the new element at position 0 and push the original bottom
    223       // element in the back.
    224       // Warning: this code is subtle.
    225       using std::swap;
    226       swap(elements_.front(), elements_.back());
    227     }
    228     if (elements_.size() == limit_ + 1) {
    229       // Transition from unsorted vector to a heap.
    230       std::make_heap(elements_.begin(), elements_.end(), cmp_);
    231       if (dropped) *dropped = std::move(elements_.front());
    232       std::pop_heap(elements_.begin(), elements_.end(), cmp_);
    233       state_ = HEAP_SORTED;
    234     }
    235   } else {
    236     // Only insert the new element if it is greater than the least element.
    237     if (cmp_(v, elements_.front())) {
    238       elements_.back() = std::forward<U>(v);  // NOLINT(build/c++11)
    239       std::push_heap(elements_.begin(), elements_.end(), cmp_);
    240       if (dropped) *dropped = std::move(elements_.front());
    241       std::pop_heap(elements_.begin(), elements_.end(), cmp_);
    242     } else {
    243       if (dropped) *dropped = std::forward<U>(v);  // NOLINT(build/c++11)
    244     }
    245   }
    246 }
    247 
    248 template <class T, class Cmp>
    249 const T &TopN<T, Cmp>::peek_bottom() {
    250   CHECK(!empty());
    251   if (state_ == UNORDERED) {
    252     // We need to do a linear scan to find out the bottom element
    253     int min_candidate = 0;
    254     for (size_t i = 1; i < elements_.size(); ++i) {
    255       if (cmp_(elements_[min_candidate], elements_[i])) {
    256         min_candidate = i;
    257       }
    258     }
    259     // By swapping the element at position 0 and the minimal
    260     // element, we transition to the BOTTOM_KNOWN state
    261     if (min_candidate != 0) {
    262       using std::swap;
    263       swap(elements_[0], elements_[min_candidate]);
    264     }
    265     state_ = BOTTOM_KNOWN;
    266   }
    267   return elements_.front();
    268 }
    269 
    270 template <class T, class Cmp>
    271 std::vector<T> *TopN<T, Cmp>::Extract() {
    272   auto out = new std::vector<T>;
    273   out->swap(elements_);
    274   if (state_ != HEAP_SORTED) {
    275     std::sort(out->begin(), out->end(), cmp_);
    276   } else {
    277     out->pop_back();
    278     std::sort_heap(out->begin(), out->end(), cmp_);
    279   }
    280   return out;
    281 }
    282 
    283 template <class T, class Cmp>
    284 std::vector<T> *TopN<T, Cmp>::ExtractUnsorted() {
    285   auto out = new std::vector<T>;
    286   out->swap(elements_);
    287   if (state_ == HEAP_SORTED) {
    288     // Remove the limit_+1'th element.
    289     out->pop_back();
    290   }
    291   return out;
    292 }
    293 
    294 template <class T, class Cmp>
    295 std::vector<T> *TopN<T, Cmp>::ExtractNondestructive() const {
    296   auto out = new std::vector<T>;
    297   ExtractNondestructive(out);
    298   return out;
    299 }
    300 
    301 template <class T, class Cmp>
    302 void TopN<T, Cmp>::ExtractNondestructive(std::vector<T> *output) const {
    303   CHECK(output);
    304   *output = elements_;
    305   if (state_ != HEAP_SORTED) {
    306     std::sort(output->begin(), output->end(), cmp_);
    307   } else {
    308     output->pop_back();
    309     std::sort_heap(output->begin(), output->end(), cmp_);
    310   }
    311 }
    312 
    313 template <class T, class Cmp>
    314 std::vector<T> *TopN<T, Cmp>::ExtractUnsortedNondestructive() const {
    315   auto elements = new std::vector<T>;
    316   ExtractUnsortedNondestructive(elements);
    317   return elements;
    318 }
    319 
    320 template <class T, class Cmp>
    321 void TopN<T, Cmp>::ExtractUnsortedNondestructive(std::vector<T> *output) const {
    322   CHECK(output);
    323   *output = elements_;
    324   if (state_ == HEAP_SORTED) {
    325     // Remove the limit_+1'th element.
    326     output->pop_back();
    327   }
    328 }
    329 
    330 template <class T, class Cmp>
    331 void TopN<T, Cmp>::Reset() {
    332   elements_.clear();
    333   state_ = UNORDERED;
    334 }
    335 
    336 }  // namespace gtl
    337 }  // namespace tensorflow
    338 
    339 #endif  // TENSORFLOW_LIB_GTL_TOP_N_H_
    340