Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_
     17 #define TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_
     18 
     19 #include <cassert>
     20 #include <cstdint>
     21 #include <cstdlib>
     22 #include <vector>
     23 
     24 namespace tensorflow {
     25 namespace nearest_neighbor {
     26 
// A simple binary heap. We use our own implementation because multiprobe for
// the cross-polytope hash interacts with the heap in such a way that about half
// of the insertion operations are guaranteed to be at the top of the heap. We
// make use of this fact in the AugmentedHeap below.
     31 
     32 // HeapBase is a base class for both the SimpleHeap and AugmentedHeap below.
     33 template <typename KeyType, typename DataType>
     34 class HeapBase {
     35  public:
     36   class Item {
     37    public:
     38     KeyType key;
     39     DataType data;
     40 
     41     Item() {}
     42     Item(const KeyType& k, const DataType& d) : key(k), data(d) {}
     43 
     44     bool operator<(const Item& i2) const { return key < i2.key; }
     45   };
     46 
     47   void ExtractMin(KeyType* key, DataType* data) {
     48     *key = v_[0].key;
     49     *data = v_[0].data;
     50     num_elements_ -= 1;
     51     v_[0] = v_[num_elements_];
     52     HeapDown(0);
     53   }
     54 
     55   bool IsEmpty() { return num_elements_ == 0; }
     56 
     57   // This method adds an element at the end of the internal array without
     58   // "heapifying" the array afterwards. This is useful for setting up a heap
     59   // where a single call to heapify at the end of the initial insertion
     60   // operations suffices.
     61   void InsertUnsorted(const KeyType& key, const DataType& data) {
     62     if (v_.size() == static_cast<size_t>(num_elements_)) {
     63       v_.push_back(Item(key, data));
     64     } else {
     65       v_[num_elements_].key = key;
     66       v_[num_elements_].data = data;
     67     }
     68     num_elements_ += 1;
     69   }
     70 
     71   void Insert(const KeyType& key, const DataType& data) {
     72     if (v_.size() == static_cast<size_t>(num_elements_)) {
     73       v_.push_back(Item(key, data));
     74     } else {
     75       v_[num_elements_].key = key;
     76       v_[num_elements_].data = data;
     77     }
     78     num_elements_ += 1;
     79     HeapUp(num_elements_ - 1);
     80   }
     81 
     82   void Heapify() {
     83     int_fast32_t rightmost = parent(num_elements_ - 1);
     84     for (int_fast32_t cur_loc = rightmost; cur_loc >= 0; --cur_loc) {
     85       HeapDown(cur_loc);
     86     }
     87   }
     88 
     89   void Reset() { num_elements_ = 0; }
     90 
     91   void Resize(size_t new_size) { v_.resize(new_size); }
     92 
     93  protected:
     94   int_fast32_t lchild(int_fast32_t x) { return 2 * x + 1; }
     95 
     96   int_fast32_t rchild(int_fast32_t x) { return 2 * x + 2; }
     97 
     98   int_fast32_t parent(int_fast32_t x) { return (x - 1) / 2; }
     99 
    100   void SwapEntries(int_fast32_t a, int_fast32_t b) {
    101     Item tmp = v_[a];
    102     v_[a] = v_[b];
    103     v_[b] = tmp;
    104   }
    105 
    106   void HeapUp(int_fast32_t cur_loc) {
    107     int_fast32_t p = parent(cur_loc);
    108     while (cur_loc > 0 && v_[p].key > v_[cur_loc].key) {
    109       SwapEntries(p, cur_loc);
    110       cur_loc = p;
    111       p = parent(cur_loc);
    112     }
    113   }
    114 
    115   void HeapDown(int_fast32_t cur_loc) {
    116     while (true) {
    117       int_fast32_t lc = lchild(cur_loc);
    118       int_fast32_t rc = rchild(cur_loc);
    119       if (lc >= num_elements_) {
    120         return;
    121       }
    122 
    123       if (v_[cur_loc].key <= v_[lc].key) {
    124         if (rc >= num_elements_ || v_[cur_loc].key <= v_[rc].key) {
    125           return;
    126         } else {
    127           SwapEntries(cur_loc, rc);
    128           cur_loc = rc;
    129         }
    130       } else {
    131         if (rc >= num_elements_ || v_[lc].key <= v_[rc].key) {
    132           SwapEntries(cur_loc, lc);
    133           cur_loc = lc;
    134         } else {
    135           SwapEntries(cur_loc, rc);
    136           cur_loc = rc;
    137         }
    138       }
    139     }
    140   }
    141 
    142   std::vector<Item> v_;
    143   int_fast32_t num_elements_ = 0;
    144 };
    145 
    146 // A "simple" binary heap.
    147 template <typename KeyType, typename DataType>
    148 class SimpleHeap : public HeapBase<KeyType, DataType> {
    149  public:
    150   void ReplaceTop(const KeyType& key, const DataType& data) {
    151     this->v_[0].key = key;
    152     this->v_[0].data = data;
    153     this->HeapDown(0);
    154   }
    155 
    156   KeyType MinKey() { return this->v_[0].key; }
    157 
    158   std::vector<typename HeapBase<KeyType, DataType>::Item>& GetData() {
    159     return this->v_;
    160   }
    161 };
    162 
    163 // An "augmented" heap that can hold an extra element that is guaranteed to
    164 // be at the top of the heap. This is useful if a significant fraction of the
    165 // insertion operations are guaranteed insertions at the top. However, the heap
    166 // only stores at most one such special top element, i.e., the heap assumes
    167 // that extract_min() is called at least once between successive calls to
    168 // insert_guaranteed_top().
    169 template <typename KeyType, typename DataType>
    170 class AugmentedHeap : public HeapBase<KeyType, DataType> {
    171  public:
    172   void ExtractMin(KeyType* key, DataType* data) {
    173     if (has_guaranteed_top_) {
    174       has_guaranteed_top_ = false;
    175       *key = guaranteed_top_.key;
    176       *data = guaranteed_top_.data;
    177     } else {
    178       *key = this->v_[0].key;
    179       *data = this->v_[0].data;
    180       this->num_elements_ -= 1;
    181       this->v_[0] = this->v_[this->num_elements_];
    182       this->HeapDown(0);
    183     }
    184   }
    185 
    186   bool IsEmpty() { return this->num_elements_ == 0 && !has_guaranteed_top_; }
    187 
    188   void InsertGuaranteedTop(const KeyType& key, const DataType& data) {
    189     assert(!has_guaranteed_top_);
    190     has_guaranteed_top_ = true;
    191     guaranteed_top_.key = key;
    192     guaranteed_top_.data = data;
    193   }
    194 
    195   void Reset() {
    196     this->num_elements_ = 0;
    197     has_guaranteed_top_ = false;
    198   }
    199 
    200  protected:
    201   typename HeapBase<KeyType, DataType>::Item guaranteed_top_;
    202   bool has_guaranteed_top_ = false;
    203 };
    204 
    205 }  // namespace nearest_neighbor
    206 }  // namespace tensorflow
    207 
    208 #endif  // TENSORFLOW_CONTRIB_NEAREST_NEIGHBOR_KERNELS_HEAP_H_
    209