1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ 17 #define TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ 18 19 #include <stddef.h> 20 #include <functional> 21 #include <initializer_list> 22 #include <iterator> 23 #include <utility> 24 #include "tensorflow/core/lib/gtl/flatrep.h" 25 #include "tensorflow/core/lib/hash/hash.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace tensorflow { 30 namespace gtl { 31 32 // FlatSet<K,...> provides a set of K. 33 // 34 // The map is implemented using an open-addressed hash table. A 35 // single array holds entire map contents and collisions are resolved 36 // by probing at a sequence of locations in the array. 37 template <typename Key, class Hash = hash<Key>, class Eq = std::equal_to<Key>> 38 class FlatSet { 39 private: 40 // Forward declare some internal types needed in public section. 41 struct Bucket; 42 43 public: 44 typedef Key key_type; 45 typedef Key value_type; 46 typedef Hash hasher; 47 typedef Eq key_equal; 48 typedef size_t size_type; 49 typedef ptrdiff_t difference_type; 50 typedef value_type* pointer; 51 typedef const value_type* const_pointer; 52 typedef value_type& reference; 53 typedef const value_type& const_reference; 54 55 FlatSet() : FlatSet(1) {} 56 57 explicit FlatSet(size_t N, const Hash& hf = Hash(), const Eq& eq = Eq()) 58 : rep_(N, hf, eq) {} 59 60 FlatSet(const FlatSet& src) : rep_(src.rep_) {} 61 62 template <typename InputIter> 63 FlatSet(InputIter first, InputIter last, size_t N = 1, 64 const Hash& hf = Hash(), const Eq& eq = Eq()) 65 : FlatSet(N, hf, eq) { 66 insert(first, last); 67 } 68 69 FlatSet(std::initializer_list<value_type> init, size_t N = 1, 70 const Hash& hf = Hash(), const Eq& eq = Eq()) 71 : FlatSet(init.begin(), init.end(), N, hf, eq) {} 72 73 FlatSet& operator=(const FlatSet& src) { 74 rep_.CopyFrom(src.rep_); 75 return *this; 76 } 77 78 ~FlatSet() {} 79 80 void swap(FlatSet& x) { rep_.swap(x.rep_); } 81 void clear_no_resize() { rep_.clear_no_resize(); } 82 void clear() { rep_.clear(); } 83 void reserve(size_t N) { rep_.Resize(std::max(N, size())); } 84 void rehash(size_t N) { rep_.Resize(std::max(N, size())); } 85 void resize(size_t N) { rep_.Resize(std::max(N, size())); } 86 size_t size() const { return rep_.size(); } 87 bool empty() const { return size() == 0; } 88 size_t bucket_count() const { return rep_.bucket_count(); } 89 hasher hash_function() const { return rep_.hash_function(); } 90 key_equal key_eq() const { return rep_.key_eq(); } 91 92 class const_iterator { 93 public: 94 typedef typename FlatSet::difference_type difference_type; 95 typedef typename FlatSet::value_type value_type; 96 typedef typename FlatSet::const_pointer pointer; 97 typedef typename FlatSet::const_reference reference; 98 typedef ::std::forward_iterator_tag iterator_category; 99 100 const_iterator() : b_(nullptr), end_(nullptr), i_(0) {} 101 102 // Make iterator pointing at first element at or after b. 103 const_iterator(Bucket* b, Bucket* end) : b_(b), end_(end), i_(0) { 104 SkipUnused(); 105 } 106 107 // Make iterator pointing exactly at ith element in b, which must exist. 108 const_iterator(Bucket* b, Bucket* end, uint32 i) 109 : b_(b), end_(end), i_(i) {} 110 111 reference operator*() const { return key(); } 112 pointer operator->() const { return &key(); } 113 bool operator==(const const_iterator& x) const { 114 return b_ == x.b_ && i_ == x.i_; 115 } 116 bool operator!=(const const_iterator& x) const { return !(*this == x); } 117 const_iterator& operator++() { 118 DCHECK(b_ != end_); 119 i_++; 120 SkipUnused(); 121 return *this; 122 } 123 const_iterator operator++(int /*indicates postfix*/) { 124 const_iterator tmp(*this); 125 ++*this; 126 return tmp; 127 } 128 129 private: 130 friend class FlatSet; 131 Bucket* b_; 132 Bucket* end_; 133 uint32 i_; 134 135 reference key() const { return b_->key(i_); } 136 void SkipUnused() { 137 while (b_ < end_) { 138 if (i_ >= Rep::kWidth) { 139 i_ = 0; 140 b_++; 141 } else if (b_->marker[i_] < 2) { 142 i_++; 143 } else { 144 break; 145 } 146 } 147 } 148 }; 149 150 typedef const_iterator iterator; 151 152 iterator begin() { return iterator(rep_.start(), rep_.limit()); } 153 iterator end() { return iterator(rep_.limit(), rep_.limit()); } 154 const_iterator begin() const { 155 return const_iterator(rep_.start(), rep_.limit()); 156 } 157 const_iterator end() const { 158 return const_iterator(rep_.limit(), rep_.limit()); 159 } 160 161 size_t count(const Key& k) const { return rep_.Find(k).found ? 1 : 0; } 162 iterator find(const Key& k) { 163 auto r = rep_.Find(k); 164 return r.found ? iterator(r.b, rep_.limit(), r.index) : end(); 165 } 166 const_iterator find(const Key& k) const { 167 auto r = rep_.Find(k); 168 return r.found ? const_iterator(r.b, rep_.limit(), r.index) : end(); 169 } 170 171 std::pair<iterator, bool> insert(const Key& k) { return Insert(k); } 172 template <typename InputIter> 173 void insert(InputIter first, InputIter last) { 174 for (; first != last; ++first) { 175 insert(*first); 176 } 177 } 178 179 template <typename... Args> 180 std::pair<iterator, bool> emplace(Args&&... args) { 181 rep_.MaybeResize(); 182 auto r = rep_.FindOrInsert(std::forward<Args>(args)...); 183 const bool inserted = !r.found; 184 return {iterator(r.b, rep_.limit(), r.index), inserted}; 185 } 186 187 size_t erase(const Key& k) { 188 auto r = rep_.Find(k); 189 if (!r.found) return 0; 190 rep_.Erase(r.b, r.index); 191 return 1; 192 } 193 iterator erase(iterator pos) { 194 rep_.Erase(pos.b_, pos.i_); 195 ++pos; 196 return pos; 197 } 198 iterator erase(iterator pos, iterator last) { 199 for (; pos != last; ++pos) { 200 rep_.Erase(pos.b_, pos.i_); 201 } 202 return pos; 203 } 204 205 std::pair<iterator, iterator> equal_range(const Key& k) { 206 auto pos = find(k); 207 if (pos == end()) { 208 return std::make_pair(pos, pos); 209 } else { 210 auto next = pos; 211 ++next; 212 return std::make_pair(pos, next); 213 } 214 } 215 std::pair<const_iterator, const_iterator> equal_range(const Key& k) const { 216 auto pos = find(k); 217 if (pos == end()) { 218 return std::make_pair(pos, pos); 219 } else { 220 auto next = pos; 221 ++next; 222 return std::make_pair(pos, next); 223 } 224 } 225 226 bool operator==(const FlatSet& x) const { 227 if (size() != x.size()) return false; 228 for (const auto& elem : x) { 229 auto i = find(elem); 230 if (i == end()) return false; 231 } 232 return true; 233 } 234 bool operator!=(const FlatSet& x) const { return !(*this == x); } 235 236 // If key exists in the table, prefetch it. This is a hint, and may 237 // have no effect. 238 void prefetch_value(const Key& key) const { rep_.Prefetch(key); } 239 240 private: 241 using Rep = internal::FlatRep<Key, Bucket, Hash, Eq>; 242 243 // Bucket stores kWidth <marker, key, value> triples. 244 // The data is organized as three parallel arrays to reduce padding. 245 struct Bucket { 246 uint8 marker[Rep::kWidth]; 247 248 // Wrap keys in union to control construction and destruction. 249 union Storage { 250 Key key[Rep::kWidth]; 251 Storage() {} 252 ~Storage() {} 253 } storage; 254 255 Key& key(uint32 i) { 256 DCHECK_GE(marker[i], 2); 257 return storage.key[i]; 258 } 259 void Destroy(uint32 i) { storage.key[i].Key::~Key(); } 260 void MoveFrom(uint32 i, Bucket* src, uint32 src_index) { 261 new (&storage.key[i]) Key(std::move(src->storage.key[src_index])); 262 } 263 void CopyFrom(uint32 i, Bucket* src, uint32 src_index) { 264 new (&storage.key[i]) Key(src->storage.key[src_index]); 265 } 266 }; 267 268 std::pair<iterator, bool> Insert(const Key& k) { 269 rep_.MaybeResize(); 270 auto r = rep_.FindOrInsert(k); 271 const bool inserted = !r.found; 272 return {iterator(r.b, rep_.limit(), r.index), inserted}; 273 } 274 275 Rep rep_; 276 }; 277 278 } // namespace gtl 279 } // namespace tensorflow 280 281 #endif // TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ 282