/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_LIB_GTL_FLATREP_H_
#define TENSORFLOW_CORE_LIB_GTL_FLATREP_H_

#include <string.h>
#include <utility>
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace gtl {
namespace internal {

// Internal representation for FlatMap and FlatSet.
//
// The representation is an open-addressed hash table. Conceptually,
// the representation is a flat array of entries. However we
// structure it as an array of buckets where each bucket holds
// kWidth entries along with metadata for the kWidth entries. The
// metadata marker is
//
// (a) kEmpty: the entry is empty
// (b) kDeleted: the entry has been deleted
// (c) other: the entry is occupied and has low-8 bits of its hash.
//     These hash bits can be used to avoid potentially expensive
//     key comparisons.
//
// FlatMap passes in a bucket that contains keys and values, FlatSet
// passes in a bucket that does not contain values.
//
// The Bucket type is expected to provide (contract inferred from usage
// below; see FlatMap/FlatSet for the concrete definitions):
//   - uint8-like array `marker[kWidth]` holding kEmpty/kDeleted/hash bits
//   - `key(i)` returning a mutable reference to the i-th key slot
//   - `Destroy(i)`, `CopyFrom(i, src, srci)`, `MoveFrom(i, src, srci)`
//   - `storage.key[i]` (used only for prefetching)
template <typename Key, typename Bucket, class Hash, class Eq>
class FlatRep {
 public:
  // kWidth is the number of entries stored in a bucket.
  static const uint32 kBase = 3;
  static const uint32 kWidth = (1 << kBase);

  // Construct a table with room for N elements (rounded up to a
  // power-of-two capacity by Init).
  FlatRep(size_t N, const Hash& hf, const Eq& eq) : hash_(hf), equal_(eq) {
    Init(N);
  }
  // Copy-construct: build a fresh table sized for src.size() and
  // re-insert (re-hash) every live entry from src.
  explicit FlatRep(const FlatRep& src) : hash_(src.hash_), equal_(src.equal_) {
    Init(src.size());
    CopyEntries(src.array_, src.end_, CopyEntry());
  }
  ~FlatRep() {
    // Destroy live entries before releasing the bucket array.
    clear_no_resize();
    delete[] array_;
  }

  // Simple accessors.
  size_t size() const { return not_empty_ - deleted_; }
  size_t bucket_count() const { return mask_ + 1; }
  Bucket* start() const { return array_; }
  Bucket* limit() const { return end_; }
  const Hash& hash_function() const { return hash_; }
  const Eq& key_eq() const { return equal_; }

  // Overwrite contents of *this with contents of src.
  void CopyFrom(const FlatRep& src) {
    if (this != &src) {
      clear_no_resize();
      delete[] array_;
      Init(src.size());
      CopyEntries(src.array_, src.end_, CopyEntry());
    }
  }

  // Destroy all live entries and reset all markers to kEmpty, but keep
  // the current bucket array (capacity unchanged).
  void clear_no_resize() {
    for (Bucket* b = array_; b != end_; b++) {
      for (uint32 i = 0; i < kWidth; i++) {
        // Markers >= 2 denote occupied slots (kEmpty==0, kDeleted==1).
        if (b->marker[i] >= 2) {
          b->Destroy(i);
          b->marker[i] = kEmpty;
        }
      }
    }
    not_empty_ = 0;
    deleted_ = 0;
  }

  // Destroy all entries and possibly shrink the table.
  void clear() {
    clear_no_resize();
    grow_ = 0;  // Consider shrinking in MaybeResize()
    MaybeResize();
  }

  // Exchange all state with x (hashers/comparators included via the
  // member swaps below; hash_ and equal_ are intentionally not swapped
  // here -- NOTE(review): callers appear to rely on the default-copied
  // functors being interchangeable; confirm against FlatMap/FlatSet).
  void swap(FlatRep& x) {
    using std::swap;
    swap(array_, x.array_);
    swap(end_, x.end_);
    swap(lglen_, x.lglen_);
    swap(mask_, x.mask_);
    swap(not_empty_, x.not_empty_);
    swap(deleted_, x.deleted_);
    swap(grow_, x.grow_);
    swap(shrink_, x.shrink_);
  }

  // Result of a lookup: if found, (b, index) locate the entry.
  struct SearchResult {
    bool found;
    Bucket* b;
    uint32 index;
  };

  // Hash value is partitioned as follows:
  // 1. Bottom 8 bits are stored in bucket to help speed up comparisons.
  // 2. Next 3 bits give index inside bucket.
  // 3. Remaining bits give bucket number.

  // Find bucket/index for key k.
  SearchResult Find(const Key& k) const {
    size_t h = hash_(k);
    const uint32 marker = Marker(h & 0xff);
    size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
    uint32 num_probes = 1;            // Needed for quadratic probing
    while (true) {
      uint32 bi = index & (kWidth - 1);
      Bucket* b = &array_[index >> kBase];
      const uint32 x = b->marker[bi];
      // Marker match is a cheap 8-bit filter; only on a match do we pay
      // for the (potentially expensive) full key comparison.
      if (x == marker && equal_(b->key(bi), k)) {
        return {true, b, bi};
      } else if (x == kEmpty) {
        // An empty slot terminates the probe sequence: k is absent.
        // (Deleted slots do NOT terminate -- the key may lie beyond.)
        return {false, nullptr, 0};
      }
      index = NextIndex(index, num_probes);
      num_probes++;
    }
  }

  // Find bucket/index for key k, creating a new one if necessary.
  //
  // KeyType is a template parameter so that k's type is deduced and it
  // becomes a universal reference which allows the key initialization
  // below to use an rvalue constructor if available.
  //
  // Returns found==true if the key already existed; found==false if a
  // fresh entry was created (its key is constructed, value slot is not).
  template <typename KeyType>
  SearchResult FindOrInsert(KeyType&& k) {
    size_t h = hash_(k);
    const uint32 marker = Marker(h & 0xff);
    size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
    uint32 num_probes = 1;            // Needed for quadratic probing
    Bucket* del = nullptr;            // First encountered deletion for kInsert
    uint32 di = 0;
    while (true) {
      uint32 bi = index & (kWidth - 1);
      Bucket* b = &array_[index >> kBase];
      const uint32 x = b->marker[bi];
      if (x == marker && equal_(b->key(bi), k)) {
        return {true, b, bi};
      } else if (!del && x == kDeleted) {
        // Remember deleted index to use for insertion.
        del = b;
        di = bi;
      } else if (x == kEmpty) {
        if (del) {
          // Store in the first deleted slot we encountered
          b = del;
          bi = di;
          deleted_--;  // not_empty_ does not change
        } else {
          not_empty_++;
        }
        b->marker[bi] = marker;
        // Placement-new the key in place; std::forward preserves the
        // caller's value category so rvalues are moved, not copied.
        new (&b->key(bi)) Key(std::forward<KeyType>(k));
        return {false, b, bi};
      }
      index = NextIndex(index, num_probes);
      num_probes++;
    }
  }

  // Destroy entry (b, i) and leave a tombstone so probe chains through
  // this slot remain intact.
  void Erase(Bucket* b, uint32 i) {
    b->Destroy(i);
    b->marker[i] = kDeleted;
    deleted_++;
    grow_ = 0;  // Consider shrinking on next insert
  }

  // Issue prefetches for the cache lines a subsequent lookup of k will
  // touch (marker byte and key storage of the first probed slot).
  void Prefetch(const Key& k) const {
    size_t h = hash_(k);
    size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
    uint32 bi = index & (kWidth - 1);
    Bucket* b = &array_[index >> kBase];
    port::prefetch<port::PREFETCH_HINT_T0>(&b->marker[bi]);
    port::prefetch<port::PREFETCH_HINT_T0>(&b->storage.key[bi]);
  }

  // Grow (or shrink, after erases) if occupancy warrants it. Called
  // before insertions so the table always has a free slot, which keeps
  // the unbounded probe loops above terminating.
  inline void MaybeResize() {
    if (not_empty_ < grow_) {
      return;  // Nothing to do
    }
    if (grow_ == 0) {
      // Special value set by erase to cause shrink on next insert.
      if (size() >= shrink_) {
        // Not small enough to shrink.
        grow_ = static_cast<size_t>(bucket_count() * 0.8);
        if (not_empty_ < grow_) return;
      }
    }
    // Rebuilding at size()+1 also purges all tombstones.
    Resize(size() + 1);
  }

  // Rebuild the table with room for N elements, re-inserting (moving)
  // every live entry.
  void Resize(size_t N) {
    Bucket* old = array_;
    Bucket* old_end = end_;
    Init(N);
    CopyEntries(old, old_end, MoveEntry());
    delete[] old;
  }

 private:
  enum { kEmpty = 0, kDeleted = 1 };  // Special markers for an entry.

  Hash hash_;         // User-supplied hasher
  Eq equal_;          // User-supplied comparator
  uint8 lglen_;       // lg(#buckets)
  Bucket* array_;     // array of length (1 << lglen_)
  Bucket* end_;       // Points just past last bucket in array_
  size_t mask_;       // (# of entries in table) - 1
  size_t not_empty_;  // Count of entries with marker != kEmpty
  size_t deleted_;    // Count of entries with marker == kDeleted
  size_t grow_;       // Grow array when not_empty_ >= grow_
  size_t shrink_;     // Shrink array when size() < shrink_

  // Avoid kEmpty and kDeleted markers when computing hash values to
  // store in Bucket::marker[].
  // Note: hash bytes 0 and 1 map to 2 and 3 and therefore alias with
  // genuine hash bytes 2 and 3; the full key comparison disambiguates.
  static uint32 Marker(uint32 hb) { return hb + (hb < 2 ? 2 : 0); }

  // Allocate and zero a fresh bucket array sized so that N elements
  // stay under 80% occupancy, then reset all bookkeeping counters.
  void Init(size_t N) {
    // Make enough room for N elements.
    size_t lg = 0;  // Smallest table is just one bucket.
    while (N >= 0.8 * ((1 << lg) * kWidth)) {
      lg++;
    }
    const size_t n = (1 << lg);
    Bucket* array = new Bucket[n];
    for (size_t i = 0; i < n; i++) {
      Bucket* b = &array[i];
      // kEmpty is 0, so a memset suffices to mark every slot empty.
      memset(b->marker, kEmpty, kWidth);
    }
    const size_t capacity = (1 << lg) * kWidth;
    lglen_ = lg;
    mask_ = capacity - 1;
    array_ = array;
    end_ = array + n;
    not_empty_ = 0;
    deleted_ = 0;
    grow_ = static_cast<size_t>(capacity * 0.8);
    if (lg == 0) {
      // Already down to one bucket; no more shrinking.
      shrink_ = 0;
    } else {
      shrink_ = static_cast<size_t>(grow_ * 0.4);  // Must be less than 0.5
    }
  }

  // Used by FreshInsert when we should copy from source.
  struct CopyEntry {
    inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) {
      dst->CopyFrom(dsti, src, srci);
    }
  };

  // Used by FreshInsert when we should move from source.
  struct MoveEntry {
    inline void operator()(Bucket* dst, uint32 dsti, Bucket* src, uint32 srci) {
      dst->MoveFrom(dsti, src, srci);
      src->Destroy(srci);
      src->marker[srci] = kDeleted;
    }
  };

  // Re-insert every occupied entry from [start, end) into *this,
  // transferring each with the supplied copier (CopyEntry or MoveEntry).
  template <typename Copier>
  void CopyEntries(Bucket* start, Bucket* end, Copier copier) {
    for (Bucket* b = start; b != end; b++) {
      for (uint32 i = 0; i < kWidth; i++) {
        if (b->marker[i] >= 2) {
          FreshInsert(b, i, copier);
        }
      }
    }
  }

  // Create an entry for the key numbered src_index in *src and return
  // its bucket/index. Used for insertion into a fresh table. We
  // assume that there are no deletions, and k does not already exist
  // in the table.
  template <typename Copier>
  void FreshInsert(Bucket* src, uint32 src_index, Copier copier) {
    size_t h = hash_(src->key(src_index));
    const uint32 marker = Marker(h & 0xff);
    size_t index = (h >> 8) & mask_;  // Holds bucket num and index-in-bucket
    uint32 num_probes = 1;            // Needed for quadratic probing
    while (true) {
      uint32 bi = index & (kWidth - 1);
      Bucket* b = &array_[index >> kBase];
      const uint32 x = b->marker[bi];
      // No deletions and no duplicates in a fresh table, so the first
      // kEmpty (== 0) slot is the insertion point.
      if (x == 0) {
        b->marker[bi] = marker;
        not_empty_++;
        copier(b, bi, src, src_index);
        return;
      }
      index = NextIndex(index, num_probes);
      num_probes++;
    }
  }

  inline size_t NextIndex(size_t i, uint32 num_probes) const {
    // Quadratic probing.
    // Increments 1, 2, 3, ... produce triangular-number offsets, which
    // visit every slot exactly once when the table size is a power of 2.
    return (i + num_probes) & mask_;
  }
};

}  // namespace internal
}  // namespace gtl
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_LIB_GTL_FLATREP_H_