Home | History | Annotate | Download | only in gtl
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
     17 #define TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
     18 
#include <stddef.h>

#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <utility>

#include "tensorflow/core/lib/gtl/flatrep.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
     28 
     29 namespace tensorflow {
     30 namespace gtl {
     31 
     32 // FlatSet<K,...> provides a set of K.
     33 //
     34 // The map is implemented using an open-addressed hash table.  A
     35 // single array holds entire map contents and collisions are resolved
     36 // by probing at a sequence of locations in the array.
     37 template <typename Key, class Hash = hash<Key>, class Eq = std::equal_to<Key>>
     38 class FlatSet {
     39  private:
     40   // Forward declare some internal types needed in public section.
     41   struct Bucket;
     42 
     43  public:
     44   typedef Key key_type;
     45   typedef Key value_type;
     46   typedef Hash hasher;
     47   typedef Eq key_equal;
     48   typedef size_t size_type;
     49   typedef ptrdiff_t difference_type;
     50   typedef value_type* pointer;
     51   typedef const value_type* const_pointer;
     52   typedef value_type& reference;
     53   typedef const value_type& const_reference;
     54 
     55   FlatSet() : FlatSet(1) {}
     56 
     57   explicit FlatSet(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq())
     58       : rep_(N, hf, eq) {}
     59 
     60   FlatSet(const FlatSet& src) : rep_(src.rep_) {}
     61 
     62   template <typename InputIter>
     63   FlatSet(InputIter first, InputIter last, size_t N = 1,
     64           const Hash& hf = Hash(), const Eq& eq = Eq())
     65       : FlatSet(N, hf, eq) {
     66     insert(first, last);
     67   }
     68 
     69   FlatSet(std::initializer_list<value_type> init, size_t N = 1,
     70           const Hash& hf = Hash(), const Eq& eq = Eq())
     71       : FlatSet(init.begin(), init.end(), N, hf, eq) {}
     72 
     73   FlatSet& operator=(const FlatSet& src) {
     74     rep_.CopyFrom(src.rep_);
     75     return *this;
     76   }
     77 
     78   ~FlatSet() {}
     79 
     80   void swap(FlatSet& x) { rep_.swap(x.rep_); }
     81   void clear_no_resize() { rep_.clear_no_resize(); }
     82   void clear() { rep_.clear(); }
     83   void reserve(size_t N) { rep_.Resize(std::max(N, size())); }
     84   void rehash(size_t N) { rep_.Resize(std::max(N, size())); }
     85   void resize(size_t N) { rep_.Resize(std::max(N, size())); }
     86   size_t size() const { return rep_.size(); }
     87   bool empty() const { return size() == 0; }
     88   size_t bucket_count() const { return rep_.bucket_count(); }
     89   hasher hash_function() const { return rep_.hash_function(); }
     90   key_equal key_eq() const { return rep_.key_eq(); }
     91 
     92   class const_iterator {
     93    public:
     94     typedef typename FlatSet::difference_type difference_type;
     95     typedef typename FlatSet::value_type value_type;
     96     typedef typename FlatSet::const_pointer pointer;
     97     typedef typename FlatSet::const_reference reference;
     98     typedef ::std::forward_iterator_tag iterator_category;
     99 
    100     const_iterator() : b_(nullptr), end_(nullptr), i_(0) {}
    101 
    102     // Make iterator pointing at first element at or after b.
    103     const_iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) {
    104       SkipUnused();
    105     }
    106 
    107     // Make iterator pointing exactly at ith element in b, which must exist.
    108     const_iterator(Bucket* b, Bucket* end, uint32 i)
    109         : b_(b), end_(end), i_(i) {}
    110 
    111     reference operator*() const { return key(); }
    112     pointer operator->() const { return &key(); }
    113     bool operator==(const const_iterator& x) const {
    114       return b_ == x.b_ && i_ == x.i_;
    115     }
    116     bool operator!=(const const_iterator& x) const { return !(*this == x); }
    117     const_iterator& operator++() {
    118       DCHECK(b_ != end_);
    119       i_++;
    120       SkipUnused();
    121       return *this;
    122     }
    123     const_iterator operator++(int /*indicates postfix*/) {
    124       const_iterator tmp(*this);
    125       ++*this;
    126       return tmp;
    127     }
    128 
    129    private:
    130     friend class FlatSet;
    131     Bucket* b_;
    132     Bucket* end_;
    133     uint32 i_;
    134 
    135     reference key() const { return b_->key(i_); }
    136     void SkipUnused() {
    137       while (b_ < end_) {
    138         if (i_ >= Rep::kWidth) {
    139           i_ = 0;
    140           b_++;
    141         } else if (b_->marker[i_] < 2) {
    142           i_++;
    143         } else {
    144           break;
    145         }
    146       }
    147     }
    148   };
    149 
    150   typedef const_iterator iterator;
    151 
    152   iterator begin() { return iterator(rep_.start(), rep_.limit()); }
    153   iterator end() { return iterator(rep_.limit(), rep_.limit()); }
    154   const_iterator begin() const {
    155     return const_iterator(rep_.start(), rep_.limit());
    156   }
    157   const_iterator end() const {
    158     return const_iterator(rep_.limit(), rep_.limit());
    159   }
    160 
    161   size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; }
    162   iterator find(const Key& k) {
    163     auto r = rep_.Find(k);
    164     return r.found ? iterator(r.b, rep_.limit(), r.index) : end();
    165   }
    166   const_iterator find(const Key& k) const {
    167     auto r = rep_.Find(k);
    168     return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end();
    169   }
    170 
    171   std::pair<iterator, bool> insert(const Key& k) { return Insert(k); }
    172   template <typename InputIter>
    173   void insert(InputIter first, InputIter last) {
    174     for (; first != last; ++first) {
    175       insert(*first);
    176     }
    177   }
    178 
    179   template <typename... Args>
    180   std::pair<iterator, bool> emplace(Args&&... args) {
    181     rep_.MaybeResize();
    182     auto r = rep_.FindOrInsert(std::forward<Args>(args)...);
    183     const bool inserted = !r.found;
    184     return {iterator(r.b, rep_.limit(), r.index), inserted};
    185   }
    186 
    187   size_t erase(const Key& k) {
    188     auto r = rep_.Find(k);
    189     if (!r.found) return 0;
    190     rep_.Erase(r.b, r.index);
    191     return 1;
    192   }
    193   iterator erase(iterator pos) {
    194     rep_.Erase(pos.b_, pos.i_);
    195     ++pos;
    196     return pos;
    197   }
    198   iterator erase(iterator pos, iterator last) {
    199     for (; pos != last; ++pos) {
    200       rep_.Erase(pos.b_, pos.i_);
    201     }
    202     return pos;
    203   }
    204 
    205   std::pair<iterator, iterator> equal_range(const Key& k) {
    206     auto pos = find(k);
    207     if (pos == end()) {
    208       return std::make_pair(pos, pos);
    209     } else {
    210       auto next = pos;
    211       ++next;
    212       return std::make_pair(pos, next);
    213     }
    214   }
    215   std::pair<const_iterator, const_iterator> equal_range(const Key& k) const {
    216     auto pos = find(k);
    217     if (pos == end()) {
    218       return std::make_pair(pos, pos);
    219     } else {
    220       auto next = pos;
    221       ++next;
    222       return std::make_pair(pos, next);
    223     }
    224   }
    225 
    226   bool operator==(const FlatSet& x) const {
    227     if (size() != x.size()) return false;
    228     for (const auto& elem : x) {
    229       auto i = find(elem);
    230       if (i == end()) return false;
    231     }
    232     return true;
    233   }
    234   bool operator!=(const FlatSet& x) const { return !(*this == x); }
    235 
    236   // If key exists in the table, prefetch it.  This is a hint, and may
    237   // have no effect.
    238   void prefetch_value(const Key& key) const { rep_.Prefetch(key); }
    239 
    240  private:
    241   using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>;
    242 
    243   // Bucket stores kWidth <marker, key, value> triples.
    244   // The data is organized as three parallel arrays to reduce padding.
    245   struct Bucket {
    246     uint8 marker[Rep::kWidth];
    247 
    248     // Wrap keys in union to control construction and destruction.
    249     union Storage {
    250       Key key[Rep::kWidth];
    251       Storage() {}
    252       ~Storage() {}
    253     } storage;
    254 
    255     Key& key(uint32 i) {
    256       DCHECK_GE(marker[i], 2);
    257       return storage.key[i];
    258     }
    259     void Destroy(uint32 i) { storage.key[i].Key::~Key(); }
    260     void MoveFrom(uint32 i, Bucket* src, uint32 src_index) {
    261       new (&storage.key[i]) Key(std::move(src->storage.key[src_index]));
    262     }
    263     void CopyFrom(uint32 i, Bucket* src, uint32 src_index) {
    264       new (&storage.key[i]) Key(src->storage.key[src_index]);
    265     }
    266   };
    267 
    268   std::pair<iterator, bool> Insert(const Key& k) {
    269     rep_.MaybeResize();
    270     auto r = rep_.FindOrInsert(k);
    271     const bool inserted = !r.found;
    272     return {iterator(r.b, rep_.limit(), r.index), inserted};
    273   }
    274 
    275   Rep rep_;
    276 };
    277 
    278 }  // namespace gtl
    279 }  // namespace tensorflow
    280 
    281 #endif  // TENSORFLOW_CORE_LIB_GTL_FLATSET_H_
    282