/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This simple class finds the top n elements of an incrementally provided set
// of elements which you push one at a time.  If the number of elements exceeds
// n, the lowest elements are incrementally dropped.  At the end you get
// a vector of the top elements sorted in descending order (through Extract() or
// ExtractNondestructive()), or a vector of the top elements but not sorted
// (through ExtractUnsorted() or ExtractUnsortedNondestructive()).
//
// The value n is specified in the constructor.  If there are p elements pushed
// altogether:
//   The total storage requirements are O(min(n, p)) elements
//   The running time is O(p * log(min(n, p))) comparisons
// If n is a constant, the total storage required is a constant and the running
// time is linear in p.
//
// NOTE(zhifengc): There is a way to do this in O(min(n, p)) storage and O(p)
// runtime. The basic idea is to repeatedly fill up a buffer of 2 * n elements,
// discarding the lowest n elements whenever the buffer is full using a linear-
// time median algorithm. This may have better performance when the input
// sequence is partially sorted.
//
// NOTE(zhifengc): This class should be redesigned to avoid reallocating a
// vector for each Extract.
38 39 #ifndef TENSORFLOW_LIB_GTL_TOP_N_H_ 40 #define TENSORFLOW_LIB_GTL_TOP_N_H_ 41 42 #include <stddef.h> 43 #include <algorithm> 44 #include <functional> 45 #include <string> 46 #include <vector> 47 48 #include "tensorflow/core/platform/logging.h" 49 50 namespace tensorflow { 51 namespace gtl { 52 53 // Cmp is an stl binary predicate. Note that Cmp is the "greater" predicate, 54 // not the more commonly used "less" predicate. 55 // 56 // If you use a "less" predicate here, the TopN will pick out the bottom N 57 // elements out of the ones passed to it, and it will return them sorted in 58 // ascending order. 59 // 60 // TopN is rule-of-zero copyable and movable if its members are. 61 template <class T, class Cmp = std::greater<T> > 62 class TopN { 63 public: 64 // The TopN is in one of the three states: 65 // 66 // o UNORDERED: this is the state an instance is originally in, 67 // where the elements are completely orderless. 68 // 69 // o BOTTOM_KNOWN: in this state, we keep the invariant that there 70 // is at least one element in it, and the lowest element is at 71 // position 0. The elements in other positions remain 72 // unsorted. This state is reached if the state was originally 73 // UNORDERED and a peek_bottom() function call is invoked. 74 // 75 // o HEAP_SORTED: in this state, the array is kept as a heap and 76 // there are exactly (limit_+1) elements in the array. This 77 // state is reached when at least (limit_+1) elements are 78 // pushed in. 79 // 80 // The state transition graph is at follows: 81 // 82 // peek_bottom() (limit_+1) elements 83 // UNORDERED --------------> BOTTOM_KNOWN --------------------> HEAP_SORTED 84 // | ^ 85 // | (limit_+1) elements | 86 // +-----------------------------------------------------------+ 87 88 enum State { UNORDERED, BOTTOM_KNOWN, HEAP_SORTED }; 89 using UnsortedIterator = typename std::vector<T>::const_iterator; 90 91 // 'limit' is the maximum number of top results to return. 
92 explicit TopN(size_t limit) : TopN(limit, Cmp()) {} 93 TopN(size_t limit, const Cmp &cmp) : limit_(limit), cmp_(cmp) {} 94 95 size_t limit() const { return limit_; } 96 97 // Number of elements currently held by this TopN object. This 98 // will be no greater than 'limit' passed to the constructor. 99 size_t size() const { return std::min(elements_.size(), limit_); } 100 101 bool empty() const { return size() == 0; } 102 103 // If you know how many elements you will push at the time you create the 104 // TopN object, you can call reserve to preallocate the memory that TopN 105 // will need to process all 'n' pushes. Calling this method is optional. 106 void reserve(size_t n) { elements_.reserve(std::min(n, limit_ + 1)); } 107 108 // Push 'v'. If the maximum number of elements was exceeded, drop the 109 // lowest element and return it in 'dropped' (if given). If the maximum is not 110 // exceeded, 'dropped' will remain unchanged. 'dropped' may be omitted or 111 // nullptr, in which case it is not filled in. 112 // Requires: T is CopyAssignable, Swappable 113 void push(const T &v) { push(v, nullptr); } 114 void push(const T &v, T *dropped) { PushInternal(v, dropped); } 115 116 // Move overloads of push. 117 // Requires: T is MoveAssignable, Swappable 118 void push(T &&v) { // NOLINT(build/c++11) 119 push(std::move(v), nullptr); 120 } 121 void push(T &&v, T *dropped) { // NOLINT(build/c++11) 122 PushInternal(std::move(v), dropped); 123 } 124 125 // Peeks the bottom result without calling Extract() 126 const T &peek_bottom(); 127 128 // Extract the elements as a vector sorted in descending order. The caller 129 // assumes ownership of the vector and must delete it when done. This is a 130 // destructive operation. The only method that can be called immediately 131 // after Extract() is Reset(). 132 std::vector<T> *Extract(); 133 134 // Similar to Extract(), but makes no guarantees the elements are in sorted 135 // order. 
As with Extract(), the caller assumes ownership of the vector and 136 // must delete it when done. This is a destructive operation. The only 137 // method that can be called immediately after ExtractUnsorted() is Reset(). 138 std::vector<T> *ExtractUnsorted(); 139 140 // A non-destructive version of Extract(). Copy the elements in a new vector 141 // sorted in descending order and return it. The caller assumes ownership of 142 // the new vector and must delete it when done. After calling 143 // ExtractNondestructive(), the caller can continue to push() new elements. 144 std::vector<T> *ExtractNondestructive() const; 145 146 // A non-destructive version of Extract(). Copy the elements to a given 147 // vector sorted in descending order. After calling 148 // ExtractNondestructive(), the caller can continue to push() new elements. 149 // Note: 150 // 1. The given argument must to be allocated. 151 // 2. Any data contained in the vector prior to the call will be deleted 152 // from it. After the call the vector will contain only the elements 153 // from the data structure. 154 void ExtractNondestructive(std::vector<T> *output) const; 155 156 // A non-destructive version of ExtractUnsorted(). Copy the elements in a new 157 // vector and return it, with no guarantees the elements are in sorted order. 158 // The caller assumes ownership of the new vector and must delete it when 159 // done. After calling ExtractUnsortedNondestructive(), the caller can 160 // continue to push() new elements. 161 std::vector<T> *ExtractUnsortedNondestructive() const; 162 163 // A non-destructive version of ExtractUnsorted(). Copy the elements into 164 // a given vector, with no guarantees the elements are in sorted order. 165 // After calling ExtractUnsortedNondestructive(), the caller can continue 166 // to push() new elements. 167 // Note: 168 // 1. The given argument must to be allocated. 169 // 2. Any data contained in the vector prior to the call will be deleted 170 // from it. 
After the call the vector will contain only the elements 171 // from the data structure. 172 void ExtractUnsortedNondestructive(std::vector<T> *output) const; 173 174 // Return an iterator to the beginning (end) of the container, 175 // with no guarantees about the order of iteration. These iterators are 176 // invalidated by mutation of the data structure. 177 UnsortedIterator unsorted_begin() const { return elements_.begin(); } 178 UnsortedIterator unsorted_end() const { return elements_.begin() + size(); } 179 180 // Accessor for comparator template argument. 181 Cmp *comparator() { return &cmp_; } 182 183 // This removes all elements. If Extract() or ExtractUnsorted() have been 184 // called, this will put it back in an empty but useable state. 185 void Reset(); 186 187 private: 188 template <typename U> 189 void PushInternal(U &&v, T *dropped); // NOLINT(build/c++11) 190 191 // elements_ can be in one of two states: 192 // elements_.size() <= limit_: elements_ is an unsorted vector of elements 193 // pushed so far. 194 // elements_.size() > limit_: The last element of elements_ is unused; 195 // the other elements of elements_ are an stl heap whose size is exactly 196 // limit_. In this case elements_.size() is exactly one greater than 197 // limit_, but don't use "elements_.size() == limit_ + 1" to check for 198 // that because you'll get a false positive if limit_ == size_t(-1). 
199 std::vector<T> elements_; 200 size_t limit_; // Maximum number of elements to find 201 Cmp cmp_; // Greater-than comparison function 202 State state_ = UNORDERED; 203 }; 204 205 // ---------------------------------------------------------------------- 206 // Implementations of non-inline functions 207 208 template <class T, class Cmp> 209 template <typename U> 210 void TopN<T, Cmp>::PushInternal(U &&v, T *dropped) { // NOLINT(build/c++11) 211 if (limit_ == 0) { 212 if (dropped) *dropped = std::forward<U>(v); // NOLINT(build/c++11) 213 return; 214 } 215 if (state_ != HEAP_SORTED) { 216 elements_.push_back(std::forward<U>(v)); // NOLINT(build/c++11) 217 if (state_ == UNORDERED || cmp_(elements_.back(), elements_.front())) { 218 // Easy case: we just pushed the new element back 219 } else { 220 // To maintain the BOTTOM_KNOWN state, we need to make sure that 221 // the element at position 0 is always the smallest. So we put 222 // the new element at position 0 and push the original bottom 223 // element in the back. 224 // Warning: this code is subtle. 225 using std::swap; 226 swap(elements_.front(), elements_.back()); 227 } 228 if (elements_.size() == limit_ + 1) { 229 // Transition from unsorted vector to a heap. 230 std::make_heap(elements_.begin(), elements_.end(), cmp_); 231 if (dropped) *dropped = std::move(elements_.front()); 232 std::pop_heap(elements_.begin(), elements_.end(), cmp_); 233 state_ = HEAP_SORTED; 234 } 235 } else { 236 // Only insert the new element if it is greater than the least element. 
237 if (cmp_(v, elements_.front())) { 238 elements_.back() = std::forward<U>(v); // NOLINT(build/c++11) 239 std::push_heap(elements_.begin(), elements_.end(), cmp_); 240 if (dropped) *dropped = std::move(elements_.front()); 241 std::pop_heap(elements_.begin(), elements_.end(), cmp_); 242 } else { 243 if (dropped) *dropped = std::forward<U>(v); // NOLINT(build/c++11) 244 } 245 } 246 } 247 248 template <class T, class Cmp> 249 const T &TopN<T, Cmp>::peek_bottom() { 250 CHECK(!empty()); 251 if (state_ == UNORDERED) { 252 // We need to do a linear scan to find out the bottom element 253 int min_candidate = 0; 254 for (size_t i = 1; i < elements_.size(); ++i) { 255 if (cmp_(elements_[min_candidate], elements_[i])) { 256 min_candidate = i; 257 } 258 } 259 // By swapping the element at position 0 and the minimal 260 // element, we transition to the BOTTOM_KNOWN state 261 if (min_candidate != 0) { 262 using std::swap; 263 swap(elements_[0], elements_[min_candidate]); 264 } 265 state_ = BOTTOM_KNOWN; 266 } 267 return elements_.front(); 268 } 269 270 template <class T, class Cmp> 271 std::vector<T> *TopN<T, Cmp>::Extract() { 272 auto out = new std::vector<T>; 273 out->swap(elements_); 274 if (state_ != HEAP_SORTED) { 275 std::sort(out->begin(), out->end(), cmp_); 276 } else { 277 out->pop_back(); 278 std::sort_heap(out->begin(), out->end(), cmp_); 279 } 280 return out; 281 } 282 283 template <class T, class Cmp> 284 std::vector<T> *TopN<T, Cmp>::ExtractUnsorted() { 285 auto out = new std::vector<T>; 286 out->swap(elements_); 287 if (state_ == HEAP_SORTED) { 288 // Remove the limit_+1'th element. 
289 out->pop_back(); 290 } 291 return out; 292 } 293 294 template <class T, class Cmp> 295 std::vector<T> *TopN<T, Cmp>::ExtractNondestructive() const { 296 auto out = new std::vector<T>; 297 ExtractNondestructive(out); 298 return out; 299 } 300 301 template <class T, class Cmp> 302 void TopN<T, Cmp>::ExtractNondestructive(std::vector<T> *output) const { 303 CHECK(output); 304 *output = elements_; 305 if (state_ != HEAP_SORTED) { 306 std::sort(output->begin(), output->end(), cmp_); 307 } else { 308 output->pop_back(); 309 std::sort_heap(output->begin(), output->end(), cmp_); 310 } 311 } 312 313 template <class T, class Cmp> 314 std::vector<T> *TopN<T, Cmp>::ExtractUnsortedNondestructive() const { 315 auto elements = new std::vector<T>; 316 ExtractUnsortedNondestructive(elements); 317 return elements; 318 } 319 320 template <class T, class Cmp> 321 void TopN<T, Cmp>::ExtractUnsortedNondestructive(std::vector<T> *output) const { 322 CHECK(output); 323 *output = elements_; 324 if (state_ == HEAP_SORTED) { 325 // Remove the limit_+1'th element. 326 output->pop_back(); 327 } 328 } 329 330 template <class T, class Cmp> 331 void TopN<T, Cmp>::Reset() { 332 elements_.clear(); 333 state_ = UNORDERED; 334 } 335 336 } // namespace gtl 337 } // namespace tensorflow 338 339 #endif // TENSORFLOW_LIB_GTL_TOP_N_H_ 340