/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
     16 #ifndef TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
     17 #define TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
     18 
#include <string.h>

#include <new>
#include <utility>

#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
     23 
     24 namespace tensorflow {
     25 namespace gtl {
     26 namespace internal {
     27 
// Internal representation for FlatMap and FlatSet.
//
// The representation is an open-addressed hash table.  Conceptually,
// the representation is a flat array of entries.  However we
// structure it as an array of buckets where each bucket holds
// kWidth entries along with metadata for the kWidth entries.  The
// metadata marker is
//
//  (a) kEmpty: the entry is empty
//  (b) kDeleted: the entry has been deleted
//  (c) other: the entry is occupied and has low-8 bits of its hash.
//      These hash bits can be used to avoid potentially expensive
//      key comparisons.
//
// FlatMap passes in a bucket that contains keys and values, FlatSet
// passes in a bucket that does not contain values.
     44 template <typename Key, typename Bucket, class Hash, class Eq>
     45 class FlatRep {
     46  public:
     47   // kWidth is the number of entries stored in a bucket.
     48   static const uint32 kBase = 3;
     49   static const uint32 kWidth = (1 << kBase);
     50 
     51   FlatRep(size_t N, const Hash& hf, const Eq& eq) : hash_(hf), equal_(eq) {
     52     Init(N);
     53   }
     54   FlatRep(const FlatRep& src) : hash_(src.hash_), equal_(src.equal_) {
     55     Init(src.size());
     56     CopyEntries(src.array_, src.end_, CopyEntry());
     57   }
     58 
     59   FlatRep(FlatRep&& src)
     60       // Copy rather than move src.hash_ and src.equal_.  This is necessary to
     61       // leave src in a valid state -- otherwise e.g. if hash_ is an
     62       // std::function, moving it would null it out.
     63       : hash_(src.hash_), equal_(src.equal_) {
     64     // TODO(jlebar): Init(1) still allocates some memory, so this isn't as cheap
     65     // as it could be.  The fundamental problem is that we need to leave src in
     66     // a valid state, and FlatRep *always* owns a nonzero amount of memory.
     67     Init(1);
     68     swap(src);
     69   }
     70 
     71   ~FlatRep() {
     72     clear_no_resize();
     73     delete[] array_;
     74   }
     75 
     76   // Simple accessors.
     77   size_t size() const { return not_empty_ - deleted_; }
     78   size_t bucket_count() const { return mask_ + 1; }
     79   Bucket* start() const { return array_; }
     80   Bucket* limit() const { return end_; }
     81   const Hash& hash_function() const { return hash_; }
     82   const Eq& key_eq() const { return equal_; }
     83 
     84   // Overwrite contents of *this with contents of src.
     85   void CopyFrom(const FlatRep& src) {
     86     if (this != &src) {
     87       clear_no_resize();
     88       delete[] array_;
     89       Init(src.size());
     90       CopyEntries(src.array_, src.end_, CopyEntry());
     91     }
     92   }
     93 
     94   void MoveFrom(FlatRep&& src) {
     95     if (this != &src) {
     96       swap(src);
     97     }
     98   }
     99 
    100   void clear_no_resize() {
    101     for (Bucket* b = array_; b != end_; b++) {
    102       for (uint32 i = 0; i < kWidth; i++) {
    103         if (b->marker[i] >= 2) {
    104           b->Destroy(i);
    105           b->marker[i] = kEmpty;
    106         }
    107       }
    108     }
    109     not_empty_ = 0;
    110     deleted_ = 0;
    111   }
    112 
    113   void clear() {
    114     clear_no_resize();
    115     grow_ = 0;  // Consider shrinking in MaybeResize()
    116     MaybeResize();
    117   }
    118 
    119   void swap(FlatRep& x) {
    120     using std::swap;
    121     swap(array_, x.array_);
    122     swap(end_, x.end_);
    123     swap(lglen_, x.lglen_);
    124     swap(mask_, x.mask_);
    125     swap(not_empty_, x.not_empty_);
    126     swap(deleted_, x.deleted_);
    127     swap(grow_, x.grow_);
    128     swap(shrink_, x.shrink_);
    129   }
    130 
    131   struct SearchResult {
    132     bool found;
    133     Bucket* b;
    134     uint32 index;
    135   };
    136 
    137   // Hash value is partitioned as follows:
    138   // 1. Bottom 8 bits are stored in bucket to help speed up comparisons.
    139   // 2. Next 3 bits give index inside bucket.
    140   // 3. Remaining bits give bucket number.
    141 
    142   // Find bucket/index for key k.
    143   SearchResult Find(const Key& k) const {
    144     size_t h = hash_(k);
    145     const uint32 marker = Marker(h & 0xff);
    146     size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
    147     uint32 num_probes = 1;            // Needed for quadratic probing
    148     while (true) {
    149       uint32 bi = index & (kWidth - 1);
    150       Bucket* b = &array_[index >> kBase];
    151       const uint32 x = b->marker[bi];
    152       if (x == marker && equal_(b->key(bi), k)) {
    153         return {true, b, bi};
    154       } else if (x == kEmpty) {
    155         return {false, nullptr, 0};
    156       }
    157       index = NextIndex(index, num_probes);
    158       num_probes++;
    159     }
    160   }
    161 
    162   // Find bucket/index for key k, creating a new one if necessary.
    163   //
    164   // KeyType is a template parameter so that k's type is deduced and it
    165   // becomes a universal reference which allows the key initialization
    166   // below to use an rvalue constructor if available.
    167   template <typename KeyType>
    168   SearchResult FindOrInsert(KeyType&& k) {
    169     size_t h = hash_(k);
    170     const uint32 marker = Marker(h & 0xff);
    171     size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
    172     uint32 num_probes = 1;            // Needed for quadratic probing
    173     Bucket* del = nullptr;            // First encountered deletion for kInsert
    174     uint32 di = 0;
    175     while (true) {
    176       uint32 bi = index & (kWidth - 1);
    177       Bucket* b = &array_[index >> kBase];
    178       const uint32 x = b->marker[bi];
    179       if (x == marker && equal_(b->key(bi), k)) {
    180         return {true, b, bi};
    181       } else if (!del && x == kDeleted) {
    182         // Remember deleted index to use for insertion.
    183         del = b;
    184         di = bi;
    185       } else if (x == kEmpty) {
    186         if (del) {
    187           // Store in the first deleted slot we encountered
    188           b = del;
    189           bi = di;
    190           deleted_--;  // not_empty_ does not change
    191         } else {
    192           not_empty_++;
    193         }
    194         b->marker[bi] = marker;
    195         new (&b->key(bi)) Key(std::forward<KeyType>(k));
    196         return {false, b, bi};
    197       }
    198       index = NextIndex(index, num_probes);
    199       num_probes++;
    200     }
    201   }
    202 
    203   void Erase(Bucket* b, uint32 i) {
    204     b->Destroy(i);
    205     b->marker[i] = kDeleted;
    206     deleted_++;
    207     grow_ = 0;  // Consider shrinking on next insert
    208   }
    209 
    210   void Prefetch(const Key& k) const {
    211     size_t h = hash_(k);
    212     size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
    213     uint32 bi = index & (kWidth - 1);
    214     Bucket* b = &array_[index >> kBase];
    215     port::prefetch<port::PREFETCH_HINT_T0>(&b->marker[bi]);
    216     port::prefetch<port::PREFETCH_HINT_T0>(&b->storage.key[bi]);
    217   }
    218 
    219   inline void MaybeResize() {
    220     if (not_empty_ < grow_) {
    221       return;  // Nothing to do
    222     }
    223     if (grow_ == 0) {
    224       // Special value set by erase to cause shrink on next insert.
    225       if (size() >= shrink_) {
    226         // Not small enough to shrink.
    227         grow_ = static_cast<size_t>(bucket_count() * 0.8);
    228         if (not_empty_ < grow_) return;
    229       }
    230     }
    231     Resize(size() + 1);
    232   }
    233 
    234   void Resize(size_t N) {
    235     Bucket* old = array_;
    236     Bucket* old_end = end_;
    237     Init(N);
    238     CopyEntries(old, old_end, MoveEntry());
    239     delete[] old;
    240   }
    241 
    242  private:
    243   enum { kEmpty = 0, kDeleted = 1 };  // Special markers for an entry.
    244 
    245   Hash hash_;         // User-supplied hasher
    246   Eq equal_;          // User-supplied comparator
    247   uint8 lglen_;       // lg(#buckets)
    248   Bucket* array_;     // array of length (1 << lglen_)
    249   Bucket* end_;       // Points just past last bucket in array_
    250   size_t mask_;       // (# of entries in table) - 1
    251   size_t not_empty_;  // Count of entries with marker != kEmpty
    252   size_t deleted_;    // Count of entries with marker == kDeleted
    253   size_t grow_;       // Grow array when not_empty_ >= grow_
    254   size_t shrink_;     // Shrink array when size() < shrink_
    255 
    256   // Avoid kEmpty and kDeleted markers when computing hash values to
    257   // store in Bucket::marker[].
    258   static uint32 Marker(uint32 hb) { return hb + (hb < 2 ? 2 : 0); }
    259 
    260   void Init(size_t N) {
    261     // Make enough room for N elements.
    262     size_t lg = 0;  // Smallest table is just one bucket.
    263     while (N >= 0.8 * ((1 << lg) * kWidth)) {
    264       lg++;
    265     }
    266     const size_t n = (1 << lg);
    267     Bucket* array = new Bucket[n];
    268     for (size_t i = 0; i < n; i++) {
    269       Bucket* b = &array[i];
    270       memset(b->marker, kEmpty, kWidth);
    271     }
    272     const size_t capacity = (1 << lg) * kWidth;
    273     lglen_ = lg;
    274     mask_ = capacity - 1;
    275     array_ = array;
    276     end_ = array + n;
    277     not_empty_ = 0;
    278     deleted_ = 0;
    279     grow_ = static_cast<size_t>(capacity * 0.8);
    280     if (lg == 0) {
    281       // Already down to one bucket; no more shrinking.
    282       shrink_ = 0;
    283     } else {
    284       shrink_ = static_cast<size_t>(grow_ * 0.4);  // Must be less than 0.5
    285     }
    286   }
    287 
    288   // Used by FreshInsert when we should copy from source.
    289   struct CopyEntry {
    290     inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) {
    291       dst->CopyFrom(dsti, src, srci);
    292     }
    293   };
    294 
    295   // Used by FreshInsert when we should move from source.
    296   struct MoveEntry {
    297     inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) {
    298       dst->MoveFrom(dsti, src, srci);
    299       src->Destroy(srci);
    300       src->marker[srci] = kDeleted;
    301     }
    302   };
    303 
    304   template <typename Copier>
    305   void CopyEntries(Bucket* start, Bucket* end, Copier copier) {
    306     for (Bucket* b = start; b != end; b++) {
    307       for (uint32 i = 0; i < kWidth; i++) {
    308         if (b->marker[i] >= 2) {
    309           FreshInsert(b, i, copier);
    310         }
    311       }
    312     }
    313   }
    314 
    315   // Create an entry for the key numbered src_index in *src and return
    316   // its bucket/index.  Used for insertion into a fresh table.  We
    317   // assume that there are no deletions, and k does not already exist
    318   // in the table.
    319   template <typename Copier>
    320   void FreshInsert(Bucket* src, uint32 src_index, Copier copier) {
    321     size_t h = hash_(src->key(src_index));
    322     const uint32 marker = Marker(h & 0xff);
    323     size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
    324     uint32 num_probes = 1;            // Needed for quadratic probing
    325     while (true) {
    326       uint32 bi = index & (kWidth - 1);
    327       Bucket* b = &array_[index >> kBase];
    328       const uint32 x = b->marker[bi];
    329       if (x == 0) {
    330         b->marker[bi] = marker;
    331         not_empty_++;
    332         copier(b, bi, src, src_index);
    333         return;
    334       }
    335       index = NextIndex(index, num_probes);
    336       num_probes++;
    337     }
    338   }
    339 
    340   inline size_t NextIndex(size_t i, uint32 num_probes) const {
    341     // Quadratic probing.
    342     return (i + num_probes) & mask_;
    343   }
    344 };
    345 
    346 }  // namespace internal
    347 }  // namespace gtl
    348 }  // namespace tensorflow
    349 
    350 #endif  // TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
    351