/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
#define TENSORFLOW_CORE_LIB_GTL_FLATREP_H_

#include <string.h>
#include <utility>
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace gtl {
namespace internal {

// Internal representation for FlatMap and FlatSet.
//
// The representation is an open-addressed hash table.  Conceptually,
// the representation is a flat array of entries.  However we
// structure it as an array of buckets where each bucket holds
// kWidth entries along with metadata for the kWidth entries.  The
// metadata marker is
//
//  (a) kEmpty: the entry is empty
//  (b) kDeleted: the entry has been deleted
//  (c) other: the entry is occupied and has low-8 bits of its hash.
//      These hash bits can be used to avoid potentially expensive
//      key comparisons.
//
// FlatMap passes in a bucket that contains keys and values, FlatSet
// passes in a bucket that does not contain values.
     44 template <typename Key, typename Bucket, class Hash, class Eq>
     45 class FlatRep {
     46  public:
     47   // kWidth is the number of entries stored in a bucket.
     48   static const uint32 kBase = 3;
     49   static const uint32 kWidth = (1 << kBase);
     50 
     51   FlatRep(size_t N, const Hash& hf, const Eq& eq) : hash_(hf), equal_(eq) {
     52     Init(N);
     53   }
     54   explicit FlatRep(const FlatRep& src) : hash_(src.hash_), equal_(src.equal_) {
     55     Init(src.size());
     56     CopyEntries(src.array_, src.end_, CopyEntry());
     57   }
     58   ~FlatRep() {
     59     clear_no_resize();
     60     delete[] array_;
     61   }
     62 
     63   // Simple accessors.
     64   size_t size() const { return not_empty_ - deleted_; }
     65   size_t bucket_count() const { return mask_ + 1; }
     66   Bucket* start() const { return array_; }
     67   Bucket* limit() const { return end_; }
     68   const Hash& hash_function() const { return hash_; }
     69   const Eq& key_eq() const { return equal_; }
     70 
     71   // Overwrite contents of *this with contents of src.
     72   void CopyFrom(const FlatRep& src) {
     73     if (this != &src) {
     74       clear_no_resize();
     75       delete[] array_;
     76       Init(src.size());
     77       CopyEntries(src.array_, src.end_, CopyEntry());
     78     }
     79   }
     80 
     81   void clear_no_resize() {
     82     for (Bucket* b = array_; b != end_; b++) {
     83       for (uint32 i = 0; i < kWidth; i++) {
     84         if (b->marker[i] >= 2) {
     85           b->Destroy(i);
     86           b->marker[i] = kEmpty;
     87         }
     88       }
     89     }
     90     not_empty_ = 0;
     91     deleted_ = 0;
     92   }
     93 
     94   void clear() {
     95     clear_no_resize();
     96     grow_ = 0;  // Consider shrinking in MaybeResize()
     97     MaybeResize();
     98   }
     99 
    100   void swap(FlatRep& x) {
    101     using std::swap;
    102     swap(array_, x.array_);
    103     swap(end_, x.end_);
    104     swap(lglen_, x.lglen_);
    105     swap(mask_, x.mask_);
    106     swap(not_empty_, x.not_empty_);
    107     swap(deleted_, x.deleted_);
    108     swap(grow_, x.grow_);
    109     swap(shrink_, x.shrink_);
    110   }
    111 
    112   struct SearchResult {
    113     bool found;
    114     Bucket* b;
    115     uint32 index;
    116   };
    117 
    118   // Hash value is partitioned as follows:
    119   // 1. Bottom 8 bits are stored in bucket to help speed up comparisons.
    120   // 2. Next 3 bits give index inside bucket.
    121   // 3. Remaining bits give bucket number.
    122 
    123   // Find bucket/index for key k.
    124   SearchResult Find(const Key& k) const {
    125     size_t h = hash_(k);
    126     const uint32 marker = Marker(h & 0xff);
    127     size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
    128     uint32 num_probes = 1;            // Needed for quadratic probing
    129     while (true) {
    130       uint32 bi = index & (kWidth - 1);
    131       Bucket* b = &array_[index >> kBase];
    132       const uint32 x = b->marker[bi];
    133       if (x == marker && equal_(b->key(bi), k)) {
    134         return {true, b, bi};
    135       } else if (x == kEmpty) {
    136         return {false, nullptr, 0};
    137       }
    138       index = NextIndex(index, num_probes);
    139       num_probes++;
    140     }
    141   }
    142 
    143   // Find bucket/index for key k, creating a new one if necessary.
    144   //
    145   // KeyType is a template parameter so that k's type is deduced and it
    146   // becomes a universal reference which allows the key initialization
    147   // below to use an rvalue constructor if available.
    148   template <typename KeyType>
    149   SearchResult FindOrInsert(KeyType&& k) {
    150     size_t h = hash_(k);
    151     const uint32 marker = Marker(h & 0xff);
    152     size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
    153     uint32 num_probes = 1;            // Needed for quadratic probing
    154     Bucket* del = nullptr;            // First encountered deletion for kInsert
    155     uint32 di = 0;
    156     while (true) {
    157       uint32 bi = index & (kWidth - 1);
    158       Bucket* b = &array_[index >> kBase];
    159       const uint32 x = b->marker[bi];
    160       if (x == marker && equal_(b->key(bi), k)) {
    161         return {true, b, bi};
    162       } else if (!del && x == kDeleted) {
    163         // Remember deleted index to use for insertion.
    164         del = b;
    165         di = bi;
    166       } else if (x == kEmpty) {
    167         if (del) {
    168           // Store in the first deleted slot we encountered
    169           b = del;
    170           bi = di;
    171           deleted_--;  // not_empty_ does not change
    172         } else {
    173           not_empty_++;
    174         }
    175         b->marker[bi] = marker;
    176         new (&b->key(bi)) Key(std::forward<KeyType>(k));
    177         return {false, b, bi};
    178       }
    179       index = NextIndex(index, num_probes);
    180       num_probes++;
    181     }
    182   }
    183 
    184   void Erase(Bucket* b, uint32 i) {
    185     b->Destroy(i);
    186     b->marker[i] = kDeleted;
    187     deleted_++;
    188     grow_ = 0;  // Consider shrinking on next insert
    189   }
    190 
    191   void Prefetch(const Key& k) const {
    192     size_t h = hash_(k);
    193     size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
    194     uint32 bi = index & (kWidth - 1);
    195     Bucket* b = &array_[index >> kBase];
    196     port::prefetch<port::PREFETCH_HINT_T0>(&b->marker[bi]);
    197     port::prefetch<port::PREFETCH_HINT_T0>(&b->storage.key[bi]);
    198   }
    199 
    200   inline void MaybeResize() {
    201     if (not_empty_ < grow_) {
    202       return;  // Nothing to do
    203     }
    204     if (grow_ == 0) {
    205       // Special value set by erase to cause shrink on next insert.
    206       if (size() >= shrink_) {
    207         // Not small enough to shrink.
    208         grow_ = static_cast<size_t>(bucket_count() * 0.8);
    209         if (not_empty_ < grow_) return;
    210       }
    211     }
    212     Resize(size() + 1);
    213   }
    214 
    215   void Resize(size_t N) {
    216     Bucket* old = array_;
    217     Bucket* old_end = end_;
    218     Init(N);
    219     CopyEntries(old, old_end, MoveEntry());
    220     delete[] old;
    221   }
    222 
    223  private:
    224   enum { kEmpty = 0, kDeleted = 1 };  // Special markers for an entry.
    225 
    226   Hash hash_;         // User-supplied hasher
    227   Eq equal_;          // User-supplied comparator
    228   uint8 lglen_;       // lg(#buckets)
    229   Bucket* array_;     // array of length (1 << lglen_)
    230   Bucket* end_;       // Points just past last bucket in array_
    231   size_t mask_;       // (# of entries in table) - 1
    232   size_t not_empty_;  // Count of entries with marker != kEmpty
    233   size_t deleted_;    // Count of entries with marker == kDeleted
    234   size_t grow_;       // Grow array when not_empty_ >= grow_
    235   size_t shrink_;     // Shrink array when size() < shrink_
    236 
    237   // Avoid kEmpty and kDeleted markers when computing hash values to
    238   // store in Bucket::marker[].
    239   static uint32 Marker(uint32 hb) { return hb + (hb < 2 ? 2 : 0); }
    240 
    241   void Init(size_t N) {
    242     // Make enough room for N elements.
    243     size_t lg = 0;  // Smallest table is just one bucket.
    244     while (N >= 0.8 * ((1 << lg) * kWidth)) {
    245       lg++;
    246     }
    247     const size_t n = (1 << lg);
    248     Bucket* array = new Bucket[n];
    249     for (size_t i = 0; i < n; i++) {
    250       Bucket* b = &array[i];
    251       memset(b->marker, kEmpty, kWidth);
    252     }
    253     const size_t capacity = (1 << lg) * kWidth;
    254     lglen_ = lg;
    255     mask_ = capacity - 1;
    256     array_ = array;
    257     end_ = array + n;
    258     not_empty_ = 0;
    259     deleted_ = 0;
    260     grow_ = static_cast<size_t>(capacity * 0.8);
    261     if (lg == 0) {
    262       // Already down to one bucket; no more shrinking.
    263       shrink_ = 0;
    264     } else {
    265       shrink_ = static_cast<size_t>(grow_ * 0.4);  // Must be less than 0.5
    266     }
    267   }
    268 
    269   // Used by FreshInsert when we should copy from source.
    270   struct CopyEntry {
    271     inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) {
    272       dst->CopyFrom(dsti, src, srci);
    273     }
    274   };
    275 
    276   // Used by FreshInsert when we should move from source.
    277   struct MoveEntry {
    278     inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) {
    279       dst->MoveFrom(dsti, src, srci);
    280       src->Destroy(srci);
    281       src->marker[srci] = kDeleted;
    282     }
    283   };
    284 
    285   template <typename Copier>
    286   void CopyEntries(Bucket* start, Bucket* end, Copier copier) {
    287     for (Bucket* b = start; b != end; b++) {
    288       for (uint32 i = 0; i < kWidth; i++) {
    289         if (b->marker[i] >= 2) {
    290           FreshInsert(b, i, copier);
    291         }
    292       }
    293     }
    294   }
    295 
    296   // Create an entry for the key numbered src_index in *src and return
    297   // its bucket/index.  Used for insertion into a fresh table.  We
    298   // assume that there are no deletions, and k does not already exist
    299   // in the table.
    300   template <typename Copier>
    301   void FreshInsert(Bucket* src, uint32 src_index, Copier copier) {
    302     size_t h = hash_(src->key(src_index));
    303     const uint32 marker = Marker(h & 0xff);
    304     size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
    305     uint32 num_probes = 1;            // Needed for quadratic probing
    306     while (true) {
    307       uint32 bi = index & (kWidth - 1);
    308       Bucket* b = &array_[index >> kBase];
    309       const uint32 x = b->marker[bi];
    310       if (x == 0) {
    311         b->marker[bi] = marker;
    312         not_empty_++;
    313         copier(b, bi, src, src_index);
    314         return;
    315       }
    316       index = NextIndex(index, num_probes);
    317       num_probes++;
    318     }
    319   }
    320 
    321   inline size_t NextIndex(size_t i, uint32 num_probes) const {
    322     // Quadratic probing.
    323     return (i + num_probes) & mask_;
    324   }
    325 };

}  // namespace internal
}  // namespace gtl
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_LIB_GTL_FLATREP_H_