1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/lookup_table_op.h" 17 #define EIGEN_USE_THREADS 18 19 #include <string> 20 #include <type_traits> 21 #include <utility> 22 23 #include "tensorflow/core/framework/register_types.h" 24 #include "tensorflow/core/framework/types.h" 25 #include "tensorflow/core/kernels/initializable_lookup_table.h" 26 #include "tensorflow/core/lib/gtl/inlined_vector.h" 27 #include "tensorflow/core/lib/hash/hash.h" 28 29 namespace tensorflow { 30 namespace lookup { 31 32 // Lookup table that wraps an unordered_map, where the key and value data type 33 // is specified. Each individual value must be a scalar. If vector values are 34 // required, use MutableHashTableOfTensors. 35 // 36 // This table is mutable and thread safe - Insert can be called at any time. 37 // 38 // Sample use case: 39 // 40 // MutableHashTableOfScalars<int64, int64> table; // int64 -> int64. 41 // // Populate the table, elements could be added in one or multiple calls. 42 // table.Insert(key_tensor, value_tensor); // Populate the table. 43 // 44 // table.Find(in_t, &out_t, default_t) 45 // 46 template <class K, class V> 47 class MutableHashTableOfScalars final : public LookupInterface { 48 public: 49 MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {} 50 51 size_t size() const override { 52 mutex_lock l(mu_); 53 return table_.size(); 54 } 55 56 Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, 57 const Tensor& default_value) override { 58 const V default_val = default_value.flat<V>()(0); 59 const auto key_values = key.flat<K>(); 60 auto value_values = value->flat<V>(); 61 62 mutex_lock l(mu_); 63 for (int64 i = 0; i < key_values.size(); ++i) { 64 value_values(i) = gtl::FindWithDefault( 65 table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)), 66 default_val); 67 } 68 69 return Status::OK(); 70 } 71 72 Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) { 73 const auto key_values = keys.flat<K>(); 74 const auto value_values = values.flat<V>(); 75 76 mutex_lock l(mu_); 77 if (clear) { 78 table_.clear(); 79 } 80 for (int64 i = 0; i < key_values.size(); ++i) { 81 gtl::InsertOrUpdate(&table_, 82 SubtleMustCopyUnlessStringOrFloat(key_values(i)), 83 SubtleMustCopyUnlessStringOrFloat(value_values(i))); 84 } 85 return Status::OK(); 86 } 87 88 Status Insert(OpKernelContext* ctx, const Tensor& keys, 89 const Tensor& values) override { 90 return DoInsert(false, keys, values); 91 } 92 93 Status ImportValues(OpKernelContext* ctx, const Tensor& keys, 94 const Tensor& values) override { 95 return DoInsert(true, keys, values); 96 } 97 98 Status ExportValues(OpKernelContext* ctx) override { 99 mutex_lock l(mu_); 100 int64 size = table_.size(); 101 102 Tensor* keys; 103 Tensor* values; 104 TF_RETURN_IF_ERROR( 105 ctx->allocate_output("keys", TensorShape({size}), &keys)); 106 TF_RETURN_IF_ERROR( 107 ctx->allocate_output("values", TensorShape({size}), &values)); 108 109 auto keys_data = keys->flat<K>(); 110 auto values_data = values->flat<V>(); 111 int64 i = 0; 112 for (auto it = table_.begin(); it != table_.end(); ++it, ++i) { 113 keys_data(i) = it->first; 114 values_data(i) = it->second; 115 } 116 return Status::OK(); 117 } 118 119 DataType key_dtype() const override { return DataTypeToEnum<K>::v(); } 120 121 DataType value_dtype() const override { return DataTypeToEnum<V>::v(); } 122 123 TensorShape key_shape() const final { return TensorShape(); } 124 125 TensorShape value_shape() const override { return TensorShape(); } 126 127 int64 MemoryUsed() const override { 128 int64 ret = 0; 129 mutex_lock l(mu_); 130 for (unsigned i = 0; i < table_.bucket_count(); ++i) { 131 size_t bucket_size = table_.bucket_size(i); 132 if (bucket_size == 0) { 133 ret++; 134 } else { 135 ret += bucket_size; 136 } 137 } 138 return sizeof(MutableHashTableOfScalars) + ret; 139 } 140 141 private: 142 // TODO(andreasst): consider using a read/write lock or a concurrent map 143 mutable mutex mu_; 144 std::unordered_map<K, V> table_ GUARDED_BY(mu_); 145 }; 146 147 // Lookup table that wraps an unordered_map. Behaves identical to 148 // MutableHashTableOfScalars except that each value must be a vector. 149 template <class K, class V> 150 class MutableHashTableOfTensors final : public LookupInterface { 151 public: 152 MutableHashTableOfTensors(OpKernelContext* ctx, OpKernel* kernel) { 153 OP_REQUIRES_OK(ctx, 154 GetNodeAttr(kernel->def(), "value_shape", &value_shape_)); 155 OP_REQUIRES( 156 ctx, TensorShapeUtils::IsVector(value_shape_), 157 errors::InvalidArgument("Default value must be a vector, got shape ", 158 value_shape_.DebugString())); 159 } 160 161 size_t size() const override { 162 mutex_lock l(mu_); 163 return table_.size(); 164 } 165 166 Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, 167 const Tensor& default_value) override { 168 const auto default_flat = default_value.flat<V>(); 169 const auto key_values = key.flat<K>(); 170 auto value_values = value->flat_inner_dims<V, 2>(); 171 int64 value_dim = value_shape_.dim_size(0); 172 173 mutex_lock l(mu_); 174 for (int64 i = 0; i < key_values.size(); ++i) { 175 ValueArray* value_vec = gtl::FindOrNull( 176 table_, SubtleMustCopyUnlessStringOrFloat(key_values(i))); 177 if (value_vec != nullptr) { 178 for (int64 j = 0; j < value_dim; j++) { 179 value_values(i, j) = value_vec->at(j); 180 } 181 } else { 182 for (int64 j = 0; j < value_dim; j++) { 183 value_values(i, j) = default_flat(j); 184 } 185 } 186 } 187 188 return Status::OK(); 189 } 190 191 Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) { 192 const auto key_values = keys.flat<K>(); 193 const auto value_values = values.flat_inner_dims<V, 2>(); 194 int64 value_dim = value_shape_.dim_size(0); 195 196 mutex_lock l(mu_); 197 if (clear) { 198 table_.clear(); 199 } 200 for (int64 i = 0; i < key_values.size(); ++i) { 201 ValueArray value_vec; 202 for (int64 j = 0; j < value_dim; j++) { 203 V value = value_values(i, j); 204 value_vec.push_back(value); 205 } 206 gtl::InsertOrUpdate( 207 &table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)), value_vec); 208 } 209 return Status::OK(); 210 } 211 212 Status Insert(OpKernelContext* ctx, const Tensor& keys, 213 const Tensor& values) override { 214 return DoInsert(false, keys, values); 215 } 216 217 Status ImportValues(OpKernelContext* ctx, const Tensor& keys, 218 const Tensor& values) override { 219 return DoInsert(true, keys, values); 220 } 221 222 Status ExportValues(OpKernelContext* ctx) override { 223 mutex_lock l(mu_); 224 int64 size = table_.size(); 225 int64 value_dim = value_shape_.dim_size(0); 226 227 Tensor* keys; 228 Tensor* values; 229 TF_RETURN_IF_ERROR( 230 ctx->allocate_output("keys", TensorShape({size}), &keys)); 231 TF_RETURN_IF_ERROR(ctx->allocate_output( 232 "values", TensorShape({size, value_dim}), &values)); 233 234 auto keys_data = keys->flat<K>(); 235 auto values_data = values->matrix<V>(); 236 int64 i = 0; 237 for (auto it = table_.begin(); it != table_.end(); ++it, ++i) { 238 K key = it->first; 239 ValueArray value = it->second; 240 keys_data(i) = key; 241 for (int64 j = 0; j < value_dim; j++) { 242 values_data(i, j) = value[j]; 243 } 244 } 245 return Status::OK(); 246 } 247 248 DataType key_dtype() const override { return DataTypeToEnum<K>::v(); } 249 250 DataType value_dtype() const override { return DataTypeToEnum<V>::v(); } 251 252 TensorShape key_shape() const final { return TensorShape(); } 253 254 TensorShape value_shape() const override { return value_shape_; } 255 256 int64 MemoryUsed() const override { 257 int64 ret = 0; 258 mutex_lock l(mu_); 259 for (unsigned i = 0; i < table_.bucket_count(); ++i) { 260 size_t bucket_size = table_.bucket_size(i); 261 if (bucket_size == 0) { 262 ret++; 263 } else { 264 ret += bucket_size; 265 } 266 } 267 return sizeof(MutableHashTableOfTensors) + ret; 268 } 269 270 private: 271 TensorShape value_shape_; 272 // TODO(andreasst): consider using a read/write lock or a concurrent map 273 mutable mutex mu_; 274 typedef gtl::InlinedVector<V, 4> ValueArray; 275 std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_); 276 }; 277 278 namespace { 279 280 template <typename T> 281 inline uint64 HashScalar(const T& key) { 282 return static_cast<uint64>(key); 283 } 284 285 inline uint64 HashScalar(const string& key) { return Hash64(key); } 286 287 // If the given shape is a scalar return {1} instead. Otherwise leave it alone. 288 TensorShape MaybeVectorizeShape(const TensorShape& shape) { 289 if (shape.dims() == 0) { 290 return TensorShape({1}); 291 } 292 return shape; 293 } 294 295 } // namespace 296 297 // Modeled after densehashtable in https://github.com/sparsehash/sparsehash 298 template <class K, class V> 299 class MutableDenseHashTable final : public LookupInterface { 300 public: 301 MutableDenseHashTable(OpKernelContext* ctx, OpKernel* kernel) { 302 OP_REQUIRES_OK( 303 ctx, GetNodeAttr(kernel->def(), "max_load_factor", &max_load_factor_)); 304 OP_REQUIRES(ctx, max_load_factor_ > 0 && max_load_factor_ < 1, 305 errors::InvalidArgument( 306 "max_load_factor must be between 0 and 1, got: ", 307 max_load_factor_)); 308 309 OP_REQUIRES_OK(ctx, 310 GetNodeAttr(kernel->def(), "value_shape", &value_shape_)); 311 OP_REQUIRES(ctx, 312 TensorShapeUtils::IsScalar(value_shape_) || 313 TensorShapeUtils::IsVector(value_shape_), 314 errors::InvalidArgument( 315 "Empty value must be a scalar or a vector, got shape ", 316 value_shape_.DebugString())); 317 318 const Tensor* empty_key_input; 319 OP_REQUIRES_OK(ctx, ctx->input("empty_key", &empty_key_input)); 320 key_shape_ = empty_key_input->shape(); 321 OP_REQUIRES(ctx, 322 TensorShapeUtils::IsScalar(key_shape_) || 323 TensorShapeUtils::IsVector(key_shape_), 324 errors::InvalidArgument( 325 "Empty key must be a scalar or a vector, got shape ", 326 key_shape_.DebugString())); 327 empty_key_ = PersistentTensor(*empty_key_input); 328 empty_key_hash_ = HashKey( 329 empty_key_input->template shaped<K, 2>({1, key_shape_.num_elements()}), 330 0); 331 332 int64 initial_num_buckets; 333 OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "initial_num_buckets", 334 &initial_num_buckets)); 335 OP_REQUIRES_OK(ctx, AllocateBuckets(ctx, initial_num_buckets)); 336 } 337 338 size_t size() const override LOCKS_EXCLUDED(mu_) { 339 mutex_lock l(mu_); 340 return num_entries_; 341 } 342 343 Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, 344 const Tensor& default_value) override LOCKS_EXCLUDED(mu_) { 345 const int64 num_elements = key.dim_size(0); 346 const int64 key_size = key_shape_.num_elements(); 347 const int64 value_size = value_shape_.num_elements(); 348 if (key.NumElements() != num_elements * key_size) { 349 TensorShape expected_shape({num_elements}); 350 expected_shape.AppendShape(key_shape_); 351 return errors::InvalidArgument("Expected key shape ", 352 expected_shape.DebugString(), " got ", 353 key.shape().DebugString()); 354 } 355 const auto key_matrix = key.shaped<K, 2>({num_elements, key_size}); 356 auto value_matrix = value->shaped<V, 2>({num_elements, value_size}); 357 const auto default_flat = default_value.flat<V>(); 358 359 mutex_lock l(mu_); 360 const auto key_buckets_matrix = 361 key_buckets_.AccessTensor(ctx)->template matrix<K>(); 362 const auto value_buckets_matrix = 363 value_buckets_.AccessTensor(ctx)->template matrix<V>(); 364 const auto empty_key_matrix = 365 empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); 366 const int64 bit_mask = num_buckets_ - 1; 367 // TODO(andreasst): parallelize using work_sharder 368 for (int64 i = 0; i < num_elements; ++i) { 369 const uint64 key_hash = HashKey(key_matrix, i); 370 if (empty_key_hash_ == key_hash && 371 IsEqualKey(empty_key_matrix, 0, key_matrix, i)) { 372 return errors::InvalidArgument( 373 "Using the empty_key as a table key is not allowed"); 374 } 375 int64 bucket_index = key_hash & bit_mask; 376 int64 num_probes = 0; 377 while (true) { 378 if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) { 379 for (int64 j = 0; j < value_size; ++j) { 380 // TODO(andreasst): check if we can get rid of SubtleMustCopy 381 // here and elsewhere in this file. 382 value_matrix(i, j) = SubtleMustCopyUnlessStringOrFloat( 383 value_buckets_matrix(bucket_index, j)); 384 } 385 break; 386 } 387 if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_matrix, 0)) { 388 for (int64 j = 0; j < value_size; ++j) { 389 value_matrix(i, j) = 390 SubtleMustCopyUnlessStringOrFloat(default_flat(j)); 391 } 392 break; 393 } 394 ++num_probes; 395 bucket_index = 396 (bucket_index + num_probes) & bit_mask; // quadratic probing 397 if (num_probes >= num_buckets_) { 398 return errors::Internal( 399 "Internal error in MutableDenseHashTable lookup"); 400 } 401 } 402 } 403 return Status::OK(); 404 } 405 406 Status Insert(OpKernelContext* ctx, const Tensor& key, 407 const Tensor& value) override LOCKS_EXCLUDED(mu_) { 408 if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) { 409 TensorShape expected_shape({key.dim_size(0)}); 410 expected_shape.AppendShape(key_shape_); 411 return errors::InvalidArgument("Expected key shape ", 412 expected_shape.DebugString(), " got ", 413 key.shape().DebugString()); 414 } 415 mutex_lock l(mu_); 416 // For simplicity we assume that all keys in the input result in inserts 417 // rather than updates. That means we may grow the table even though we 418 // don't need to. As long as the number of keys inserted in one call is 419 // small compared to the size of the map, the impact of this is minimal. 420 const int64 pending_num_entries = num_entries_ + key.dim_size(0); 421 if (pending_num_entries > num_buckets_ * max_load_factor_) { 422 int64 new_num_buckets = num_buckets_; 423 do { 424 new_num_buckets <<= 1; 425 } while (pending_num_entries > new_num_buckets * max_load_factor_); 426 TF_RETURN_IF_ERROR(Rebucket(ctx, new_num_buckets)); 427 } 428 return DoInsert(ctx, key, value, false); 429 } 430 431 Status ImportValues(OpKernelContext* ctx, const Tensor& keys, 432 const Tensor& values) override LOCKS_EXCLUDED(mu_) { 433 mutex_lock l(mu_); 434 num_buckets_ = keys.dim_size(0); 435 key_buckets_ = PersistentTensor(keys); 436 value_buckets_ = PersistentTensor(values); 437 // Count the number of keys that are not the empty_key. This requires 438 // iterating through the whole table but that is OK as we only execute it 439 // during checkpoint restore. 440 num_entries_ = 0; 441 const auto empty_key_tensor = 442 empty_key_.AccessTensor(ctx)->template shaped<K, 2>( 443 {1, key_shape_.num_elements()}); 444 const auto key_buckets_tensor = 445 key_buckets_.AccessTensor(ctx)->template matrix<K>(); 446 for (int64 i = 0; i < num_buckets_; ++i) { 447 if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0)) { 448 ++num_entries_; 449 } 450 } 451 return Status::OK(); 452 } 453 454 Status ExportValues(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { 455 mutex_lock l(mu_); 456 Tensor key_buckets_tensor = *key_buckets_.AccessTensor(ctx); 457 Tensor value_buckets_tensor = *value_buckets_.AccessTensor(ctx); 458 TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_tensor)); 459 TF_RETURN_IF_ERROR(ctx->set_output("values", value_buckets_tensor)); 460 return Status::OK(); 461 } 462 463 Status CheckKeyAndValueTensorsForImport(const Tensor& keys, 464 const Tensor& values) override { 465 TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(keys, values)); 466 TF_RETURN_IF_ERROR(CheckKeyShape(keys.shape())); 467 468 // The storage format in key_buckets_ and value_buckets_ is always vectors, 469 // even if the inputs are scalars. This is what eventually gets exported 470 // and is expected by the import method as well. 471 TensorShape key_shape = MaybeVectorizeShape(key_shape_); 472 TensorShape value_shape = MaybeVectorizeShape(value_shape_); 473 474 // Compute the final expected shape of the value by starting with the shape 475 // of all keys, removing the dimensions particular to each key and then 476 // appending the shape of a single value. 477 TensorShape expected_value_shape = keys.shape(); 478 expected_value_shape.RemoveLastDims(key_shape.dims()); 479 expected_value_shape.AppendShape(value_shape); 480 if (values.shape() != expected_value_shape) { 481 return errors::InvalidArgument( 482 "Expected shape ", expected_value_shape.DebugString(), 483 " for value, got ", values.shape().DebugString()); 484 } 485 return Status::OK(); 486 } 487 488 DataType key_dtype() const override { return DataTypeToEnum<K>::v(); } 489 490 DataType value_dtype() const override { return DataTypeToEnum<V>::v(); } 491 492 TensorShape key_shape() const override { return key_shape_; } 493 494 TensorShape value_shape() const override { return value_shape_; } 495 496 int64 MemoryUsed() const override { 497 mutex_lock l(mu_); 498 return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() + 499 value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes(); 500 } 501 502 private: 503 Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value, 504 bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 505 const int64 num_elements = key.dim_size(0); 506 const int64 value_size = value_shape_.num_elements(); 507 const int64 key_size = key_shape_.num_elements(); 508 const auto key_matrix = key.shaped<K, 2>({num_elements, key_size}); 509 auto value_matrix = value.shaped<V, 2>({num_elements, value_size}); 510 511 auto key_buckets_matrix = 512 key_buckets_.AccessTensor(ctx)->template matrix<K>(); 513 auto value_buckets_matrix = 514 value_buckets_.AccessTensor(ctx)->template matrix<V>(); 515 const auto empty_key_tensor = 516 empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size}); 517 const int64 bit_mask = num_buckets_ - 1; 518 for (int64 i = 0; i < num_elements; ++i) { 519 const uint64 key_hash = HashKey(key_matrix, i); 520 if (empty_key_hash_ == key_hash && 521 IsEqualKey(empty_key_tensor, 0, key_matrix, i)) { 522 if (ignore_empty_key) { 523 continue; 524 } 525 return errors::InvalidArgument( 526 "Using the empty_key as a table key is not allowed"); 527 } 528 int64 bucket_index = key_hash & bit_mask; 529 int64 num_probes = 0; 530 while (true) { 531 if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) { 532 for (int64 j = 0; j < value_size; ++j) { 533 value_buckets_matrix(bucket_index, j) = 534 SubtleMustCopyUnlessStringOrFloat(value_matrix(i, j)); 535 } 536 break; 537 } 538 if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) { 539 ++num_entries_; 540 for (int64 j = 0; j < key_size; ++j) { 541 key_buckets_matrix(bucket_index, j) = 542 SubtleMustCopyUnlessStringOrFloat(key_matrix(i, j)); 543 } 544 for (int64 j = 0; j < value_size; ++j) { 545 value_buckets_matrix(bucket_index, j) = 546 SubtleMustCopyUnlessStringOrFloat(value_matrix(i, j)); 547 } 548 break; 549 } 550 ++num_probes; 551 bucket_index = 552 (bucket_index + num_probes) & bit_mask; // quadratic probing 553 if (num_probes >= num_buckets_) { 554 return errors::Internal( 555 "Internal error in MutableDenseHashTable insert"); 556 } 557 } 558 } 559 return Status::OK(); 560 } 561 562 Status AllocateBuckets(OpKernelContext* ctx, int64 new_num_buckets) 563 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 564 if (new_num_buckets < 4 || 565 ((new_num_buckets & (new_num_buckets - 1)) != 0)) { 566 return errors::InvalidArgument( 567 "Number of buckets must be at least 4 and a power of 2, got: ", 568 new_num_buckets); 569 } 570 num_buckets_ = new_num_buckets; 571 num_entries_ = 0; 572 573 const int64 key_size = key_shape_.num_elements(); 574 Tensor* key_buckets_tensor; 575 TF_RETURN_IF_ERROR(ctx->allocate_persistent( 576 key_dtype(), TensorShape({num_buckets_, key_size}), &key_buckets_, 577 &key_buckets_tensor)); 578 auto key_buckets_matrix = key_buckets_tensor->matrix<K>(); 579 const auto empty_key_flat = 580 empty_key_.AccessTensor(ctx)->template flat<K>(); 581 for (int64 i = 0; i < num_buckets_; ++i) { 582 for (int64 j = 0; j < key_size; ++j) { 583 key_buckets_matrix(i, j) = empty_key_flat(j); 584 } 585 } 586 587 const int64 value_size = value_shape_.num_elements(); 588 Tensor* value_buckets_tensor; 589 TF_RETURN_IF_ERROR(ctx->allocate_persistent( 590 value_dtype(), TensorShape({num_buckets_, value_size}), &value_buckets_, 591 &value_buckets_tensor)); 592 auto value_buckets_matrix = value_buckets_tensor->matrix<V>(); 593 for (int64 i = 0; i < num_buckets_; ++i) { 594 for (int64 j = 0; j < value_size; ++j) { 595 // Initialize values to the default value for the type to avoid 596 // exposing uninitialized memory in ExportValues(). 597 value_buckets_matrix(i, j) = V(); 598 } 599 } 600 return Status::OK(); 601 } 602 603 Status Rebucket(OpKernelContext* ctx, int64 num_new_buckets) 604 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 605 Tensor old_key_buckets = *key_buckets_.AccessTensor(ctx); 606 Tensor old_value_buckets = *value_buckets_.AccessTensor(ctx); 607 TF_RETURN_IF_ERROR(AllocateBuckets(ctx, num_new_buckets)); 608 return DoInsert(ctx, old_key_buckets, old_value_buckets, true); 609 } 610 611 uint64 HashKey(typename TTypes<K>::ConstMatrix key, int64 index) const { 612 if (key_shape_.num_elements() == 1) { 613 return HashScalar(key(index, 0)); 614 } 615 uint64 result = 0; 616 for (int64 i = 0; i < key_shape_.num_elements(); ++i) { 617 result = Hash64Combine(result, HashScalar(key(index, i))); 618 } 619 return result; 620 } 621 622 // Use a template to allow this function to be used both with Matrix and 623 // ConstMatrix types. 624 template <typename MT2> 625 bool IsEqualKey(typename TTypes<K>::Matrix tensor1, int64 index1, MT2 tensor2, 626 int64 index2) const { 627 for (int64 i = 0; i < key_shape_.num_elements(); ++i) { 628 if (tensor1(index1, i) != tensor2(index2, i)) { 629 return false; 630 } 631 } 632 return true; 633 } 634 635 TensorShape key_shape_; 636 TensorShape value_shape_; 637 float max_load_factor_; 638 mutable mutex mu_; 639 int64 num_entries_ GUARDED_BY(mu_); 640 int64 num_buckets_ GUARDED_BY(mu_); 641 PersistentTensor key_buckets_ GUARDED_BY(mu_); 642 PersistentTensor value_buckets_ GUARDED_BY(mu_); 643 PersistentTensor empty_key_; 644 uint64 empty_key_hash_; 645 }; 646 647 } // namespace lookup 648 649 // Table lookup op. Perform the lookup operation on the given table. 650 class LookupTableFindOp : public OpKernel { 651 public: 652 explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 653 654 void Compute(OpKernelContext* ctx) override { 655 lookup::LookupInterface* table; 656 OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); 657 core::ScopedUnref unref_me(table); 658 659 // Input 0 could be a STRING_REF or a RESOURCE 660 DataType expected_input_0 = 661 (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; 662 DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), 663 table->value_dtype()}; 664 DataTypeVector expected_outputs = {table->value_dtype()}; 665 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); 666 667 const Tensor& key = ctx->input(1); 668 const Tensor& default_value = ctx->input(2); 669 OP_REQUIRES_OK(ctx, table->CheckFindArguments(key, default_value)); 670 671 TensorShape output_shape = key.shape(); 672 output_shape.RemoveLastDims(table->key_shape().dims()); 673 output_shape.AppendShape(table->value_shape()); 674 Tensor* out; 675 OP_REQUIRES_OK(ctx, ctx->allocate_output("values", output_shape, &out)); 676 677 OP_REQUIRES_OK(ctx, table->Find(ctx, key, out, default_value)); 678 } 679 }; 680 681 REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU), 682 LookupTableFindOp); 683 REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU), 684 LookupTableFindOp); 685 686 // Table insert op. 687 class LookupTableInsertOp : public OpKernel { 688 public: 689 explicit LookupTableInsertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 690 691 void Compute(OpKernelContext* ctx) override { 692 lookup::LookupInterface* table; 693 OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); 694 core::ScopedUnref unref_me(table); 695 696 DataType expected_input_0 = 697 (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; 698 DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), 699 table->value_dtype()}; 700 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); 701 702 const Tensor& keys = ctx->input(1); 703 const Tensor& values = ctx->input(2); 704 OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForInsert(keys, values)); 705 706 int64 memory_used_before = 0; 707 if (ctx->track_allocations()) { 708 memory_used_before = table->MemoryUsed(); 709 } 710 OP_REQUIRES_OK(ctx, table->Insert(ctx, keys, values)); 711 if (ctx->track_allocations()) { 712 ctx->record_persistent_memory_allocation(table->MemoryUsed() - 713 memory_used_before); 714 } 715 } 716 }; 717 718 REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU), 719 LookupTableInsertOp); 720 REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU), 721 LookupTableInsertOp); 722 723 // Op that returns the size of the given table. 724 class LookupTableSizeOp : public OpKernel { 725 public: 726 explicit LookupTableSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 727 728 void Compute(OpKernelContext* ctx) override { 729 lookup::LookupInterface* table; 730 OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); 731 core::ScopedUnref unref_me(table); 732 733 Tensor* out; 734 OP_REQUIRES_OK(ctx, ctx->allocate_output("size", TensorShape({}), &out)); 735 out->flat<int64>().setConstant(table->size()); 736 } 737 }; 738 739 REGISTER_KERNEL_BUILDER(Name("LookupTableSize").Device(DEVICE_CPU), 740 LookupTableSizeOp); 741 REGISTER_KERNEL_BUILDER(Name("LookupTableSizeV2").Device(DEVICE_CPU), 742 LookupTableSizeOp); 743 744 // Op that outputs tensors of all keys and all values. 745 class LookupTableExportOp : public OpKernel { 746 public: 747 explicit LookupTableExportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 748 749 void Compute(OpKernelContext* ctx) override { 750 lookup::LookupInterface* table; 751 OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); 752 core::ScopedUnref unref_me(table); 753 754 OP_REQUIRES_OK(ctx, table->ExportValues(ctx)); 755 } 756 }; 757 758 REGISTER_KERNEL_BUILDER(Name("LookupTableExport").Device(DEVICE_CPU), 759 LookupTableExportOp); 760 REGISTER_KERNEL_BUILDER(Name("LookupTableExportV2").Device(DEVICE_CPU), 761 LookupTableExportOp); 762 763 // Clear the table and insert data. 764 class LookupTableImportOp : public OpKernel { 765 public: 766 explicit LookupTableImportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 767 768 void Compute(OpKernelContext* ctx) override { 769 lookup::LookupInterface* table; 770 OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); 771 core::ScopedUnref unref_me(table); 772 773 DataType expected_input_0 = 774 (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; 775 DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), 776 table->value_dtype()}; 777 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); 778 779 const Tensor& keys = ctx->input(1); 780 const Tensor& values = ctx->input(2); 781 OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForImport(keys, values)); 782 783 int memory_used_before = 0; 784 if (ctx->track_allocations()) { 785 memory_used_before = table->MemoryUsed(); 786 } 787 OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values)); 788 if (ctx->track_allocations()) { 789 ctx->record_persistent_memory_allocation(table->MemoryUsed() - 790 memory_used_before); 791 } 792 } 793 }; 794 795 REGISTER_KERNEL_BUILDER(Name("LookupTableImport").Device(DEVICE_CPU), 796 LookupTableImportOp); 797 REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU), 798 LookupTableImportOp); 799 800 // Register the HashTable op with the currently supported key and value types. 801 #define REGISTER_KERNEL(key_dtype, value_dtype) \ 802 REGISTER_KERNEL_BUILDER( \ 803 Name("HashTable") \ 804 .Device(DEVICE_CPU) \ 805 .TypeConstraint<key_dtype>("key_dtype") \ 806 .TypeConstraint<value_dtype>("value_dtype"), \ 807 LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \ 808 value_dtype>) \ 809 REGISTER_KERNEL_BUILDER( \ 810 Name("HashTableV2") \ 811 .Device(DEVICE_CPU) \ 812 .TypeConstraint<key_dtype>("key_dtype") \ 813 .TypeConstraint<value_dtype>("value_dtype"), \ 814 LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \ 815 value_dtype>) 816 817 REGISTER_KERNEL(string, double); 818 REGISTER_KERNEL(string, float); 819 REGISTER_KERNEL(string, int32); 820 REGISTER_KERNEL(string, int64); 821 REGISTER_KERNEL(int64, string); 822 REGISTER_KERNEL(int64, int64); 823 REGISTER_KERNEL(int64, float); 824 REGISTER_KERNEL(string, string); 825 REGISTER_KERNEL(string, bool); 826 REGISTER_KERNEL(int32, int32); 827 828 #undef REGISTER_KERNEL 829 830 // Register the MutableHashTable op. 831 #define REGISTER_KERNEL(key_dtype, value_dtype) \ 832 REGISTER_KERNEL_BUILDER( \ 833 Name("MutableHashTable") \ 834 .Device(DEVICE_CPU) \ 835 .TypeConstraint<key_dtype>("key_dtype") \ 836 .TypeConstraint<value_dtype>("value_dtype"), \ 837 LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \ 838 key_dtype, value_dtype>) \ 839 REGISTER_KERNEL_BUILDER( \ 840 Name("MutableHashTableV2") \ 841 .Device(DEVICE_CPU) \ 842 .TypeConstraint<key_dtype>("key_dtype") \ 843 .TypeConstraint<value_dtype>("value_dtype"), \ 844 LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \ 845 key_dtype, value_dtype>) 846 847 REGISTER_KERNEL(string, float); 848 REGISTER_KERNEL(string, int64); 849 REGISTER_KERNEL(int64, string); 850 REGISTER_KERNEL(string, bool); 851 REGISTER_KERNEL(int64, float); 852 853 #undef REGISTER_KERNEL 854 855 // Register the MutableHashTableOfTensors op. 856 #define REGISTER_KERNEL(key_dtype, value_dtype) \ 857 REGISTER_KERNEL_BUILDER( \ 858 Name("MutableHashTableOfTensors") \ 859 .Device(DEVICE_CPU) \ 860 .TypeConstraint<key_dtype>("key_dtype") \ 861 .TypeConstraint<value_dtype>("value_dtype"), \ 862 LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \ 863 key_dtype, value_dtype>) \ 864 REGISTER_KERNEL_BUILDER( \ 865 Name("MutableHashTableOfTensorsV2") \ 866 .Device(DEVICE_CPU) \ 867 .TypeConstraint<key_dtype>("key_dtype") \ 868 .TypeConstraint<value_dtype>("value_dtype"), \ 869 LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \ 870 key_dtype, value_dtype>) 871 872 REGISTER_KERNEL(string, float); 873 REGISTER_KERNEL(string, int64); 874 REGISTER_KERNEL(int64, string); 875 REGISTER_KERNEL(string, bool); 876 877 #undef REGISTER_KERNEL 878 879 // Register the MutableDenseHashTable op. 880 #define REGISTER_KERNEL(key_dtype, value_dtype) \ 881 REGISTER_KERNEL_BUILDER( \ 882 Name("MutableDenseHashTable") \ 883 .Device(DEVICE_CPU) \ 884 .TypeConstraint<key_dtype>("key_dtype") \ 885 .TypeConstraint<value_dtype>("value_dtype"), \ 886 LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \ 887 key_dtype, value_dtype>) \ 888 REGISTER_KERNEL_BUILDER( \ 889 Name("MutableDenseHashTableV2") \ 890 .Device(DEVICE_CPU) \ 891 .TypeConstraint<key_dtype>("key_dtype") \ 892 .TypeConstraint<value_dtype>("value_dtype"), \ 893 LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \ 894 key_dtype, value_dtype>) 895 896 REGISTER_KERNEL(int64, int64); 897 REGISTER_KERNEL(int64, float); 898 REGISTER_KERNEL(int64, double); 899 REGISTER_KERNEL(string, float); 900 REGISTER_KERNEL(string, bool); 901 REGISTER_KERNEL(int64, bool); 902 903 #undef REGISTER_KERNEL 904 905 } // namespace tensorflow 906