Home | History | Annotate | Download | only in base
      1 /*
      2  * Copyright (C) 2014 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef ART_RUNTIME_BASE_HASH_SET_H_
     18 #define ART_RUNTIME_BASE_HASH_SET_H_
     19 
     20 #include <functional>
     21 #include <memory>
     22 #include <stdint.h>
     23 #include <utility>
     24 
     25 #include "logging.h"
     26 
     27 namespace art {
     28 
     29 // Returns true if an item is empty.
     30 template <class T>
     31 class DefaultEmptyFn {
     32  public:
     33   void MakeEmpty(T& item) const {
     34     item = T();
     35   }
     36   bool IsEmpty(const T& item) const {
     37     return item == T();
     38   }
     39 };
     40 
     41 template <class T>
     42 class DefaultEmptyFn<T*> {
     43  public:
     44   void MakeEmpty(T*& item) const {
     45     item = nullptr;
     46   }
     47   bool IsEmpty(const T*& item) const {
     48     return item == nullptr;
     49   }
     50 };
     51 
     52 // Low memory version of a hash set, uses less memory than std::unordered_set since elements aren't
     53 // boxed. Uses linear probing.
     54 // EmptyFn needs to implement two functions MakeEmpty(T& item) and IsEmpty(const T& item)
     55 template <class T, class EmptyFn = DefaultEmptyFn<T>, class HashFn = std::hash<T>,
     56     class Pred = std::equal_to<T>, class Alloc = std::allocator<T>>
     57 class HashSet {
     58  public:
     59   static constexpr double kDefaultMinLoadFactor = 0.5;
     60   static constexpr double kDefaultMaxLoadFactor = 0.9;
     61   static constexpr size_t kMinBuckets = 1000;
     62 
     63   class Iterator {
     64    public:
     65     Iterator(const Iterator&) = default;
     66     Iterator(HashSet* hash_set, size_t index) : hash_set_(hash_set), index_(index) {
     67     }
     68     Iterator& operator=(const Iterator&) = default;
     69     bool operator==(const Iterator& other) const {
     70       return hash_set_ == other.hash_set_ && index_ == other.index_;
     71     }
     72     bool operator!=(const Iterator& other) const {
     73       return !(*this == other);
     74     }
     75     Iterator operator++() {  // Value after modification.
     76       index_ = NextNonEmptySlot(index_);
     77       return *this;
     78     }
     79     Iterator operator++(int) {
     80       Iterator temp = *this;
     81       index_ = NextNonEmptySlot(index_);
     82       return temp;
     83     }
     84     T& operator*() {
     85       DCHECK(!hash_set_->IsFreeSlot(GetIndex()));
     86       return hash_set_->ElementForIndex(index_);
     87     }
     88     const T& operator*() const {
     89       DCHECK(!hash_set_->IsFreeSlot(GetIndex()));
     90       return hash_set_->ElementForIndex(index_);
     91     }
     92     T* operator->() {
     93       return &**this;
     94     }
     95     const T* operator->() const {
     96       return &**this;
     97     }
     98     // TODO: Operator -- --(int)
     99 
    100    private:
    101     HashSet* hash_set_;
    102     size_t index_;
    103 
    104     size_t GetIndex() const {
    105       return index_;
    106     }
    107     size_t NextNonEmptySlot(size_t index) const {
    108       const size_t num_buckets = hash_set_->NumBuckets();
    109       DCHECK_LT(index, num_buckets);
    110       do {
    111         ++index;
    112       } while (index < num_buckets && hash_set_->IsFreeSlot(index));
    113       return index;
    114     }
    115 
    116     friend class HashSet;
    117   };
    118 
    119   void Clear() {
    120     DeallocateStorage();
    121     AllocateStorage(1);
    122     num_elements_ = 0;
    123     elements_until_expand_ = 0;
    124   }
    125   HashSet() : num_elements_(0), num_buckets_(0), data_(nullptr),
    126       min_load_factor_(kDefaultMinLoadFactor), max_load_factor_(kDefaultMaxLoadFactor) {
    127     Clear();
    128   }
    129   HashSet(const HashSet& other) : num_elements_(0), num_buckets_(0), data_(nullptr) {
    130     *this = other;
    131   }
    132   HashSet(HashSet&& other) : num_elements_(0), num_buckets_(0), data_(nullptr) {
    133     *this = std::move(other);
    134   }
    135   ~HashSet() {
    136     DeallocateStorage();
    137   }
    138   HashSet& operator=(HashSet&& other) {
    139     std::swap(data_, other.data_);
    140     std::swap(num_buckets_, other.num_buckets_);
    141     std::swap(num_elements_, other.num_elements_);
    142     std::swap(elements_until_expand_, other.elements_until_expand_);
    143     std::swap(min_load_factor_, other.min_load_factor_);
    144     std::swap(max_load_factor_, other.max_load_factor_);
    145     return *this;
    146   }
    147   HashSet& operator=(const HashSet& other) {
    148     DeallocateStorage();
    149     AllocateStorage(other.NumBuckets());
    150     for (size_t i = 0; i < num_buckets_; ++i) {
    151       ElementForIndex(i) = other.data_[i];
    152     }
    153     num_elements_ = other.num_elements_;
    154     elements_until_expand_ = other.elements_until_expand_;
    155     min_load_factor_ = other.min_load_factor_;
    156     max_load_factor_ = other.max_load_factor_;
    157     return *this;
    158   }
    159   // Lower case for c++11 for each.
    160   Iterator begin() {
    161     Iterator ret(this, 0);
    162     if (num_buckets_ != 0 && IsFreeSlot(ret.GetIndex())) {
    163       ++ret;  // Skip all the empty slots.
    164     }
    165     return ret;
    166   }
    167   // Lower case for c++11 for each.
    168   Iterator end() {
    169     return Iterator(this, NumBuckets());
    170   }
    171   bool Empty() {
    172     return begin() == end();
    173   }
    174   // Erase algorithm:
    175   // Make an empty slot where the iterator is pointing.
    176   // Scan fowards until we hit another empty slot.
    177   // If an element inbetween doesn't rehash to the range from the current empty slot to the
    178   // iterator. It must be before the empty slot, in that case we can move it to the empty slot
    179   // and set the empty slot to be the location we just moved from.
    180   // Relies on maintaining the invariant that there's no empty slots from the 'ideal' index of an
    181   // element to its actual location/index.
    182   Iterator Erase(Iterator it) {
    183     // empty_index is the index that will become empty.
    184     size_t empty_index = it.GetIndex();
    185     DCHECK(!IsFreeSlot(empty_index));
    186     size_t next_index = empty_index;
    187     bool filled = false;  // True if we filled the empty index.
    188     while (true) {
    189       next_index = NextIndex(next_index);
    190       T& next_element = ElementForIndex(next_index);
    191       // If the next element is empty, we are done. Make sure to clear the current empty index.
    192       if (emptyfn_.IsEmpty(next_element)) {
    193         emptyfn_.MakeEmpty(ElementForIndex(empty_index));
    194         break;
    195       }
    196       // Otherwise try to see if the next element can fill the current empty index.
    197       const size_t next_hash = hashfn_(next_element);
    198       // Calculate the ideal index, if it is within empty_index + 1 to next_index then there is
    199       // nothing we can do.
    200       size_t next_ideal_index = IndexForHash(next_hash);
    201       // Loop around if needed for our check.
    202       size_t unwrapped_next_index = next_index;
    203       if (unwrapped_next_index < empty_index) {
    204         unwrapped_next_index += NumBuckets();
    205       }
    206       // Loop around if needed for our check.
    207       size_t unwrapped_next_ideal_index = next_ideal_index;
    208       if (unwrapped_next_ideal_index < empty_index) {
    209         unwrapped_next_ideal_index += NumBuckets();
    210       }
    211       if (unwrapped_next_ideal_index <= empty_index ||
    212           unwrapped_next_ideal_index > unwrapped_next_index) {
    213         // If the target index isn't within our current range it must have been probed from before
    214         // the empty index.
    215         ElementForIndex(empty_index) = std::move(next_element);
    216         filled = true;  // TODO: Optimize
    217         empty_index = next_index;
    218       }
    219     }
    220     --num_elements_;
    221     // If we didn't fill the slot then we need go to the next non free slot.
    222     if (!filled) {
    223       ++it;
    224     }
    225     return it;
    226   }
    227   // Find an element, returns end() if not found.
    228   // Allows custom K types, example of when this is useful.
    229   // Set of Class* sorted by name, want to find a class with a name but can't allocate a dummy
    230   // object in the heap for performance solution.
    231   template <typename K>
    232   Iterator Find(const K& element) {
    233     return FindWithHash(element, hashfn_(element));
    234   }
    235   template <typename K>
    236   Iterator FindWithHash(const K& element, size_t hash) {
    237     DCHECK_EQ(hashfn_(element), hash);
    238     size_t index = IndexForHash(hash);
    239     while (true) {
    240       T& slot = ElementForIndex(index);
    241       if (emptyfn_.IsEmpty(slot)) {
    242         return end();
    243       }
    244       if (pred_(slot, element)) {
    245         return Iterator(this, index);
    246       }
    247       index = NextIndex(index);
    248     }
    249   }
    250   // Insert an element, allows duplicates.
    251   void Insert(const T& element) {
    252     InsertWithHash(element, hashfn_(element));
    253   }
    254   void InsertWithHash(const T& element, size_t hash) {
    255     DCHECK_EQ(hash, hashfn_(element));
    256     if (num_elements_ >= elements_until_expand_) {
    257       Expand();
    258       DCHECK_LT(num_elements_, elements_until_expand_);
    259     }
    260     const size_t index = FirstAvailableSlot(IndexForHash(hash));
    261     data_[index] = element;
    262     ++num_elements_;
    263   }
    264   size_t Size() const {
    265     return num_elements_;
    266   }
    267   void ShrinkToMaximumLoad() {
    268     Resize(Size() / max_load_factor_);
    269   }
    270   // To distance that inserted elements were probed. Used for measuring how good hash functions
    271   // are.
    272   size_t TotalProbeDistance() const {
    273     size_t total = 0;
    274     for (size_t i = 0; i < NumBuckets(); ++i) {
    275       const T& element = ElementForIndex(i);
    276       if (!emptyfn_.IsEmpty(element)) {
    277         size_t ideal_location = IndexForHash(hashfn_(element));
    278         if (ideal_location > i) {
    279           total += i + NumBuckets() - ideal_location;
    280         } else {
    281           total += i - ideal_location;
    282         }
    283       }
    284     }
    285     return total;
    286   }
    287   // Calculate the current load factor and return it.
    288   double CalculateLoadFactor() const {
    289     return static_cast<double>(Size()) / static_cast<double>(NumBuckets());
    290   }
    291   // Make sure that everything reinserts in the right spot. Returns the number of errors.
    292   size_t Verify() {
    293     size_t errors = 0;
    294     for (size_t i = 0; i < num_buckets_; ++i) {
    295       T& element = data_[i];
    296       if (!emptyfn_.IsEmpty(element)) {
    297         T temp;
    298         emptyfn_.MakeEmpty(temp);
    299         std::swap(temp, element);
    300         size_t first_slot = FirstAvailableSlot(IndexForHash(hashfn_(temp)));
    301         if (i != first_slot) {
    302           LOG(ERROR) << "Element " << i << " should be in slot " << first_slot;
    303           ++errors;
    304         }
    305         std::swap(temp, element);
    306       }
    307     }
    308     return errors;
    309   }
    310 
    311  private:
    312   T& ElementForIndex(size_t index) {
    313     DCHECK_LT(index, NumBuckets());
    314     DCHECK(data_ != nullptr);
    315     return data_[index];
    316   }
    317   const T& ElementForIndex(size_t index) const {
    318     DCHECK_LT(index, NumBuckets());
    319     DCHECK(data_ != nullptr);
    320     return data_[index];
    321   }
    322   size_t IndexForHash(size_t hash) const {
    323     return hash % num_buckets_;
    324   }
    325   size_t NextIndex(size_t index) const {
    326     if (UNLIKELY(++index >= num_buckets_)) {
    327       DCHECK_EQ(index, NumBuckets());
    328       return 0;
    329     }
    330     return index;
    331   }
    332   bool IsFreeSlot(size_t index) const {
    333     return emptyfn_.IsEmpty(ElementForIndex(index));
    334   }
    335   size_t NumBuckets() const {
    336     return num_buckets_;
    337   }
    338   // Allocate a number of buckets.
    339   void AllocateStorage(size_t num_buckets) {
    340     num_buckets_ = num_buckets;
    341     data_ = allocfn_.allocate(num_buckets_);
    342     for (size_t i = 0; i < num_buckets_; ++i) {
    343       allocfn_.construct(allocfn_.address(data_[i]));
    344       emptyfn_.MakeEmpty(data_[i]);
    345     }
    346   }
    347   void DeallocateStorage() {
    348     if (num_buckets_ != 0) {
    349       for (size_t i = 0; i < NumBuckets(); ++i) {
    350         allocfn_.destroy(allocfn_.address(data_[i]));
    351       }
    352       allocfn_.deallocate(data_, NumBuckets());
    353       data_ = nullptr;
    354       num_buckets_ = 0;
    355     }
    356   }
    357   // Expand the set based on the load factors.
    358   void Expand() {
    359     size_t min_index = static_cast<size_t>(Size() / min_load_factor_);
    360     if (min_index < kMinBuckets) {
    361       min_index = kMinBuckets;
    362     }
    363     // Resize based on the minimum load factor.
    364     Resize(min_index);
    365     // When we hit elements_until_expand_, we are at the max load factor and must expand again.
    366     elements_until_expand_ = NumBuckets() * max_load_factor_;
    367   }
    368   // Expand / shrink the table to the new specified size.
    369   void Resize(size_t new_size) {
    370     DCHECK_GE(new_size, Size());
    371     T* old_data = data_;
    372     size_t old_num_buckets = num_buckets_;
    373     // Reinsert all of the old elements.
    374     AllocateStorage(new_size);
    375     for (size_t i = 0; i < old_num_buckets; ++i) {
    376       T& element = old_data[i];
    377       if (!emptyfn_.IsEmpty(element)) {
    378         data_[FirstAvailableSlot(IndexForHash(hashfn_(element)))] = std::move(element);
    379       }
    380       allocfn_.destroy(allocfn_.address(element));
    381     }
    382     allocfn_.deallocate(old_data, old_num_buckets);
    383   }
    384   ALWAYS_INLINE size_t FirstAvailableSlot(size_t index) const {
    385     while (!emptyfn_.IsEmpty(data_[index])) {
    386       index = NextIndex(index);
    387     }
    388     return index;
    389   }
    390 
    391   Alloc allocfn_;  // Allocator function.
    392   HashFn hashfn_;  // Hashing function.
    393   EmptyFn emptyfn_;  // IsEmpty/SetEmpty function.
    394   Pred pred_;  // Equals function.
    395   size_t num_elements_;  // Number of inserted elements.
    396   size_t num_buckets_;  // Number of hash table buckets.
    397   size_t elements_until_expand_;  // Maxmimum number of elements until we expand the table.
    398   T* data_;  // Backing storage.
    399   double min_load_factor_;
    400   double max_load_factor_;
    401 
    402   friend class Iterator;
    403 };
    404 
    405 }  // namespace art
    406 
    407 #endif  // ART_RUNTIME_BASE_HASH_SET_H_
    408