/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/lookup_table_op.h"
#define EIGEN_USE_THREADS

#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/initializable_lookup_table.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/hash/hash.h"

namespace tensorflow {
namespace lookup {
// Lookup table that wraps an unordered_map, where the key and value data
// types are specified. Each individual value must be a scalar. If vector
// values are required, use MutableHashTableOfTensors.
//
// This table is mutable and thread safe: Insert can be called at any time.
//
// Sample use case:
//
// MutableHashTableOfScalars<int64, int64> table;  // int64 -> int64.
// // Populate the table; elements can be added in one or multiple calls.
// table.Insert(key_tensor, value_tensor);
//
// table.Find(in_t, &out_t, default_t);
//
template <class K, class V>
class MutableHashTableOfScalars final : public LookupInterface {
 public:
  MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {}

  size_t size() const override {
    mutex_lock l(mu_);
    return table_.size();
  }

  Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
              const Tensor& default_value) override {
    const V default_val = default_value.flat<V>()(0);
    const auto key_values = key.flat<K>();
    auto value_values = value->flat<V>();

    mutex_lock l(mu_);
    for (int64 i = 0; i < key_values.size(); ++i) {
      value_values(i) = gtl::FindWithDefault(
          table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)),
          default_val);
    }

    return Status::OK();
  }

  Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) {
    const auto key_values = keys.flat<K>();
    const auto value_values = values.flat<V>();

    mutex_lock l(mu_);
    if (clear) {
      table_.clear();
    }
    for (int64 i = 0; i < key_values.size(); ++i) {
      gtl::InsertOrUpdate(&table_,
                          SubtleMustCopyUnlessStringOrFloat(key_values(i)),
                          SubtleMustCopyUnlessStringOrFloat(value_values(i)));
    }
    return Status::OK();
  }

  Status Insert(OpKernelContext* ctx, const Tensor& keys,
                const Tensor& values) override {
    return DoInsert(false, keys, values);
  }

  Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
                      const Tensor& values) override {
    return DoInsert(true, keys, values);
  }

  Status ExportValues(OpKernelContext* ctx) override {
    mutex_lock l(mu_);
    int64 size = table_.size();

    Tensor* keys;
    Tensor* values;
    TF_RETURN_IF_ERROR(
        ctx->allocate_output("keys", TensorShape({size}), &keys));
    TF_RETURN_IF_ERROR(
        ctx->allocate_output("values", TensorShape({size}), &values));

    auto keys_data = keys->flat<K>();
    auto values_data = values->flat<V>();
    int64 i = 0;
    for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
      keys_data(i) = it->first;
      values_data(i) = it->second;
    }
    return Status::OK();
  }

  DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }

  DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }

  TensorShape key_shape() const final { return TensorShape(); }

  TensorShape value_shape() const override { return TensorShape(); }

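  // Note: a rough estimate only. This counts one unit per empty bucket and
  // one per stored entry rather than exact bytes, so it under-reports the
  // true footprint of the unordered_map (and of any string payloads).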
  int64 MemoryUsed() const override {
    int64 ret = 0;
    mutex_lock l(mu_);
    for (unsigned i = 0; i < table_.bucket_count(); ++i) {
      size_t bucket_size = table_.bucket_size(i);
      if (bucket_size == 0) {
        ret++;
      } else {
        ret += bucket_size;
      }
    }
    return sizeof(MutableHashTableOfScalars) + ret;
  }

 private:
  // TODO(andreasst): consider using a read/write lock or a concurrent map
  mutable mutex mu_;
  std::unordered_map<K, V> table_ GUARDED_BY(mu_);
};

// Lookup table that wraps an unordered_map. Behaves identically to
// MutableHashTableOfScalars except that each value must be a vector.
template <class K, class V>
class MutableHashTableOfTensors final : public LookupInterface {
 public:
  MutableHashTableOfTensors(OpKernelContext* ctx, OpKernel* kernel) {
    OP_REQUIRES_OK(ctx,
                   GetNodeAttr(kernel->def(), "value_shape", &value_shape_));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(value_shape_),
        errors::InvalidArgument("Default value must be a vector, got shape ",
                                value_shape_.DebugString()));
  }

  size_t size() const override {
    mutex_lock l(mu_);
    return table_.size();
  }

  Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
              const Tensor& default_value) override {
    const auto default_flat = default_value.flat<V>();
    const auto key_values = key.flat<K>();
    auto value_values = value->flat_inner_dims<V, 2>();
    int64 value_dim = value_shape_.dim_size(0);

    mutex_lock l(mu_);
    for (int64 i = 0; i < key_values.size(); ++i) {
      ValueArray* value_vec = gtl::FindOrNull(
          table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)));
      if (value_vec != nullptr) {
        for (int64 j = 0; j < value_dim; j++) {
          value_values(i, j) = value_vec->at(j);
        }
      } else {
        for (int64 j = 0; j < value_dim; j++) {
          value_values(i, j) = default_flat(j);
        }
      }
    }

    return Status::OK();
  }

  Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) {
    const auto key_values = keys.flat<K>();
    const auto value_values = values.flat_inner_dims<V, 2>();
    int64 value_dim = value_shape_.dim_size(0);

    mutex_lock l(mu_);
    if (clear) {
      table_.clear();
    }
    for (int64 i = 0; i < key_values.size(); ++i) {
      ValueArray value_vec;
      for (int64 j = 0; j < value_dim; j++) {
        V value = value_values(i, j);
        value_vec.push_back(value);
      }
      gtl::InsertOrUpdate(
          &table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)), value_vec);
    }
    return Status::OK();
  }

  Status Insert(OpKernelContext* ctx, const Tensor& keys,
                const Tensor& values) override {
    return DoInsert(false, keys, values);
  }

  Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
                      const Tensor& values) override {
    return DoInsert(true, keys, values);
  }

  Status ExportValues(OpKernelContext* ctx) override {
    mutex_lock l(mu_);
    int64 size = table_.size();
    int64 value_dim = value_shape_.dim_size(0);

    Tensor* keys;
    Tensor* values;
    TF_RETURN_IF_ERROR(
        ctx->allocate_output("keys", TensorShape({size}), &keys));
    TF_RETURN_IF_ERROR(ctx->allocate_output(
        "values", TensorShape({size, value_dim}), &values));

    auto keys_data = keys->flat<K>();
    auto values_data = values->matrix<V>();
    int64 i = 0;
    for (auto it = table_.begin(); it != table_.end(); ++it, ++i) {
      K key = it->first;
      ValueArray value = it->second;
      keys_data(i) = key;
      for (int64 j = 0; j < value_dim; j++) {
        values_data(i, j) = value[j];
      }
    }
    return Status::OK();
  }

  DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }

  DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }

  TensorShape key_shape() const final { return TensorShape(); }

  TensorShape value_shape() const override { return value_shape_; }

  int64 MemoryUsed() const override {
    int64 ret = 0;
    mutex_lock l(mu_);
    for (unsigned i = 0; i < table_.bucket_count(); ++i) {
      size_t bucket_size = table_.bucket_size(i);
      if (bucket_size == 0) {
        ret++;
      } else {
        ret += bucket_size;
      }
    }
    return sizeof(MutableHashTableOfTensors) + ret;
  }

 private:
  TensorShape value_shape_;
  // TODO(andreasst): consider using a read/write lock or a concurrent map
  mutable mutex mu_;
  typedef gtl::InlinedVector<V, 4> ValueArray;
  std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_);
};

namespace {

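// Per-scalar hash helpers: integral keys hash to their own value, while
// string keys go through Hash64. MutableDenseHashTable below combines these
// per-element hashes for composite (vector) keys.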
template <typename T>
inline uint64 HashScalar(const T& key) {
  return static_cast<uint64>(key);
}

inline uint64 HashScalar(const string& key) { return Hash64(key); }

// If the given shape is a scalar, return {1} instead; otherwise leave it
// alone.
TensorShape MaybeVectorizeShape(const TensorShape& shape) {
  if (shape.dims() == 0) {
    return TensorShape({1});
  }
  return shape;
}

}  // namespace

// Modeled after densehashtable in https://github.com/sparsehash/sparsehash
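//
// The table is a flat, open-addressed hash table: keys and values live in
// two dense bucket tensors whose bucket count is kept at a power of two, a
// caller-supplied empty_key sentinel marks unused buckets, and collisions
// are resolved by probing (see Find and DoInsert below). The table grows by
// rebucketing once the load factor would exceed max_load_factor.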
template <class K, class V>
class MutableDenseHashTable final : public LookupInterface {
 public:
  MutableDenseHashTable(OpKernelContext* ctx, OpKernel* kernel) {
    OP_REQUIRES_OK(
        ctx, GetNodeAttr(kernel->def(), "max_load_factor", &max_load_factor_));
    OP_REQUIRES(ctx, max_load_factor_ > 0 && max_load_factor_ < 1,
                errors::InvalidArgument(
                    "max_load_factor must be between 0 and 1, got: ",
                    max_load_factor_));

    OP_REQUIRES_OK(ctx,
                   GetNodeAttr(kernel->def(), "value_shape", &value_shape_));
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsScalar(value_shape_) ||
                    TensorShapeUtils::IsVector(value_shape_),
                errors::InvalidArgument(
                    "Empty value must be a scalar or a vector, got shape ",
                    value_shape_.DebugString()));

    const Tensor* empty_key_input;
    OP_REQUIRES_OK(ctx, ctx->input("empty_key", &empty_key_input));
    key_shape_ = empty_key_input->shape();
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsScalar(key_shape_) ||
                    TensorShapeUtils::IsVector(key_shape_),
                errors::InvalidArgument(
                    "Empty key must be a scalar or a vector, got shape ",
                    key_shape_.DebugString()));
    empty_key_ = PersistentTensor(*empty_key_input);
    empty_key_hash_ = HashKey(
        empty_key_input->template shaped<K, 2>({1, key_shape_.num_elements()}),
        0);

    int64 initial_num_buckets;
    OP_REQUIRES_OK(ctx, GetNodeAttr(kernel->def(), "initial_num_buckets",
                                    &initial_num_buckets));
    OP_REQUIRES_OK(ctx, AllocateBuckets(ctx, initial_num_buckets));
  }

  size_t size() const override LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    return num_entries_;
  }

  Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
              const Tensor& default_value) override LOCKS_EXCLUDED(mu_) {
    const int64 num_elements = key.dim_size(0);
    const int64 key_size = key_shape_.num_elements();
    const int64 value_size = value_shape_.num_elements();
    if (key.NumElements() != num_elements * key_size) {
      TensorShape expected_shape({num_elements});
      expected_shape.AppendShape(key_shape_);
      return errors::InvalidArgument("Expected key shape ",
                                     expected_shape.DebugString(), " got ",
                                     key.shape().DebugString());
    }
    const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
    auto value_matrix = value->shaped<V, 2>({num_elements, value_size});
    const auto default_flat = default_value.flat<V>();

    mutex_lock l(mu_);
    const auto key_buckets_matrix =
        key_buckets_.AccessTensor(ctx)->template matrix<K>();
    const auto value_buckets_matrix =
        value_buckets_.AccessTensor(ctx)->template matrix<V>();
    const auto empty_key_matrix =
        empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
    const int64 bit_mask = num_buckets_ - 1;
    // TODO(andreasst): parallelize using work_sharder
    for (int64 i = 0; i < num_elements; ++i) {
      const uint64 key_hash = HashKey(key_matrix, i);
      if (empty_key_hash_ == key_hash &&
          IsEqualKey(empty_key_matrix, 0, key_matrix, i)) {
        return errors::InvalidArgument(
            "Using the empty_key as a table key is not allowed");
      }
      int64 bucket_index = key_hash & bit_mask;
      int64 num_probes = 0;
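      // Probe with a step that grows by one each iteration, so the offsets
      // from the home bucket are the triangular numbers 1, 3, 6, ... With a
      // power-of-two bucket count this sequence visits every bucket, so the
      // loop below terminates before num_probes reaches num_buckets_.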
      while (true) {
        if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
          for (int64 j = 0; j < value_size; ++j) {
            // TODO(andreasst): check if we can get rid of SubtleMustCopy
            // here and elsewhere in this file.
            value_matrix(i, j) = SubtleMustCopyUnlessStringOrFloat(
                value_buckets_matrix(bucket_index, j));
          }
          break;
        }
        if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_matrix, 0)) {
          for (int64 j = 0; j < value_size; ++j) {
            value_matrix(i, j) =
                SubtleMustCopyUnlessStringOrFloat(default_flat(j));
          }
          break;
        }
        ++num_probes;
        bucket_index =
            (bucket_index + num_probes) & bit_mask;  // quadratic probing
        if (num_probes >= num_buckets_) {
          return errors::Internal(
              "Internal error in MutableDenseHashTable lookup");
        }
      }
    }
    return Status::OK();
  }

  Status Insert(OpKernelContext* ctx, const Tensor& key,
                const Tensor& value) override LOCKS_EXCLUDED(mu_) {
    if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) {
      TensorShape expected_shape({key.dim_size(0)});
      expected_shape.AppendShape(key_shape_);
      return errors::InvalidArgument("Expected key shape ",
                                     expected_shape.DebugString(), " got ",
                                     key.shape().DebugString());
    }
    mutex_lock l(mu_);
    // For simplicity we assume that all keys in the input result in inserts
    // rather than updates. That means we may grow the table even though we
    // don't need to. As long as the number of keys inserted in one call is
    // small compared to the size of the map, the impact of this is minimal.
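    // For example, with max_load_factor = 0.5 and 16 buckets, a pending
    // total of 9 entries doubles the table to 32 buckets (doubling repeats
    // until the load factor constraint holds again).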
    const int64 pending_num_entries = num_entries_ + key.dim_size(0);
    if (pending_num_entries > num_buckets_ * max_load_factor_) {
      int64 new_num_buckets = num_buckets_;
      do {
        new_num_buckets <<= 1;
      } while (pending_num_entries > new_num_buckets * max_load_factor_);
      TF_RETURN_IF_ERROR(Rebucket(ctx, new_num_buckets));
    }
    return DoInsert(ctx, key, value, false);
  }

  Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
                      const Tensor& values) override LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    num_buckets_ = keys.dim_size(0);
    key_buckets_ = PersistentTensor(keys);
    value_buckets_ = PersistentTensor(values);
    // Count the number of keys that are not the empty_key. This requires
    // iterating through the whole table but that is OK as we only execute it
    // during checkpoint restore.
    num_entries_ = 0;
    const auto empty_key_tensor =
        empty_key_.AccessTensor(ctx)->template shaped<K, 2>(
            {1, key_shape_.num_elements()});
    const auto key_buckets_tensor =
        key_buckets_.AccessTensor(ctx)->template matrix<K>();
    for (int64 i = 0; i < num_buckets_; ++i) {
      if (!IsEqualKey(key_buckets_tensor, i, empty_key_tensor, 0)) {
        ++num_entries_;
      }
    }
    return Status::OK();
  }

  Status ExportValues(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    Tensor key_buckets_tensor = *key_buckets_.AccessTensor(ctx);
    Tensor value_buckets_tensor = *value_buckets_.AccessTensor(ctx);
    TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_tensor));
    TF_RETURN_IF_ERROR(ctx->set_output("values", value_buckets_tensor));
    return Status::OK();
  }

  Status CheckKeyAndValueTensorsForImport(const Tensor& keys,
                                          const Tensor& values) override {
    TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(keys, values));
    TF_RETURN_IF_ERROR(CheckKeyShape(keys.shape()));

    // The storage format in key_buckets_ and value_buckets_ is always vectors,
    // even if the inputs are scalars. This is what eventually gets exported
    // and is expected by the import method as well.
    TensorShape key_shape = MaybeVectorizeShape(key_shape_);
    TensorShape value_shape = MaybeVectorizeShape(value_shape_);

    // Compute the final expected shape of the value by starting with the shape
    // of all keys, removing the dimensions particular to each key and then
    // appending the shape of a single value.
    TensorShape expected_value_shape = keys.shape();
    expected_value_shape.RemoveLastDims(key_shape.dims());
    expected_value_shape.AppendShape(value_shape);
    if (values.shape() != expected_value_shape) {
      return errors::InvalidArgument(
          "Expected shape ", expected_value_shape.DebugString(),
          " for value, got ", values.shape().DebugString());
    }
    return Status::OK();
  }

  DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }

  DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }

  TensorShape key_shape() const override { return key_shape_; }

  TensorShape value_shape() const override { return value_shape_; }

  int64 MemoryUsed() const override {
    mutex_lock l(mu_);
    return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() +
           value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes();
  }

 private:
  Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value,
                  bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const int64 num_elements = key.dim_size(0);
    const int64 value_size = value_shape_.num_elements();
    const int64 key_size = key_shape_.num_elements();
    const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
    auto value_matrix = value.shaped<V, 2>({num_elements, value_size});

    auto key_buckets_matrix =
        key_buckets_.AccessTensor(ctx)->template matrix<K>();
    auto value_buckets_matrix =
        value_buckets_.AccessTensor(ctx)->template matrix<V>();
    const auto empty_key_tensor =
        empty_key_.AccessTensor(ctx)->template shaped<K, 2>({1, key_size});
    const int64 bit_mask = num_buckets_ - 1;
    for (int64 i = 0; i < num_elements; ++i) {
      const uint64 key_hash = HashKey(key_matrix, i);
      if (empty_key_hash_ == key_hash &&
          IsEqualKey(empty_key_tensor, 0, key_matrix, i)) {
        if (ignore_empty_key) {
          continue;
        }
        return errors::InvalidArgument(
            "Using the empty_key as a table key is not allowed");
      }
      int64 bucket_index = key_hash & bit_mask;
      int64 num_probes = 0;
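      // Same triangular probe sequence as in Find() above: an existing match
      // updates the stored value in place, while the first empty bucket on
      // the probe path claims the key.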
      while (true) {
        if (IsEqualKey(key_buckets_matrix, bucket_index, key_matrix, i)) {
          for (int64 j = 0; j < value_size; ++j) {
            value_buckets_matrix(bucket_index, j) =
                SubtleMustCopyUnlessStringOrFloat(value_matrix(i, j));
          }
          break;
        }
        if (IsEqualKey(key_buckets_matrix, bucket_index, empty_key_tensor, 0)) {
          ++num_entries_;
          for (int64 j = 0; j < key_size; ++j) {
            key_buckets_matrix(bucket_index, j) =
                SubtleMustCopyUnlessStringOrFloat(key_matrix(i, j));
          }
          for (int64 j = 0; j < value_size; ++j) {
            value_buckets_matrix(bucket_index, j) =
                SubtleMustCopyUnlessStringOrFloat(value_matrix(i, j));
          }
          break;
        }
        ++num_probes;
        bucket_index =
            (bucket_index + num_probes) & bit_mask;  // quadratic probing
        if (num_probes >= num_buckets_) {
          return errors::Internal(
              "Internal error in MutableDenseHashTable insert");
        }
      }
    }
    return Status::OK();
  }

  Status AllocateBuckets(OpKernelContext* ctx, int64 new_num_buckets)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
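    // A power of two has exactly one bit set, so x & (x - 1) == 0 tests for
    // it; keeping the bucket count a power of two lets the probe loops
    // reduce hashes with a bit mask instead of a modulo.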
    if (new_num_buckets < 4 ||
        ((new_num_buckets & (new_num_buckets - 1)) != 0)) {
      return errors::InvalidArgument(
          "Number of buckets must be at least 4 and a power of 2, got: ",
          new_num_buckets);
    }
    num_buckets_ = new_num_buckets;
    num_entries_ = 0;

    const int64 key_size = key_shape_.num_elements();
    Tensor* key_buckets_tensor;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        key_dtype(), TensorShape({num_buckets_, key_size}), &key_buckets_,
        &key_buckets_tensor));
    auto key_buckets_matrix = key_buckets_tensor->matrix<K>();
    const auto empty_key_flat =
        empty_key_.AccessTensor(ctx)->template flat<K>();
    for (int64 i = 0; i < num_buckets_; ++i) {
      for (int64 j = 0; j < key_size; ++j) {
        key_buckets_matrix(i, j) = empty_key_flat(j);
      }
    }

    const int64 value_size = value_shape_.num_elements();
    Tensor* value_buckets_tensor;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        value_dtype(), TensorShape({num_buckets_, value_size}), &value_buckets_,
        &value_buckets_tensor));
    auto value_buckets_matrix = value_buckets_tensor->matrix<V>();
    for (int64 i = 0; i < num_buckets_; ++i) {
      for (int64 j = 0; j < value_size; ++j) {
        // Initialize values to the default value for the type to avoid
        // exposing uninitialized memory in ExportValues().
        value_buckets_matrix(i, j) = V();
      }
    }
    return Status::OK();
  }

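  // Grows the table by allocating fresh buckets and re-inserting every entry
  // from the old buckets; DoInsert is called with ignore_empty_key so the
  // empty-key sentinel rows of the old table are skipped.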
  Status Rebucket(OpKernelContext* ctx, int64 num_new_buckets)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    Tensor old_key_buckets = *key_buckets_.AccessTensor(ctx);
    Tensor old_value_buckets = *value_buckets_.AccessTensor(ctx);
    TF_RETURN_IF_ERROR(AllocateBuckets(ctx, num_new_buckets));
    return DoInsert(ctx, old_key_buckets, old_value_buckets, true);
  }

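  // Hashes one row of the key matrix: a single-element key hashes directly,
  // while composite (vector) keys fold the per-element hashes together with
  // Hash64Combine.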
  uint64 HashKey(typename TTypes<K>::ConstMatrix key, int64 index) const {
    if (key_shape_.num_elements() == 1) {
      return HashScalar(key(index, 0));
    }
    uint64 result = 0;
    for (int64 i = 0; i < key_shape_.num_elements(); ++i) {
      result = Hash64Combine(result, HashScalar(key(index, i)));
    }
    return result;
  }

  // Use a template to allow this function to be used both with Matrix and
  // ConstMatrix types.
  template <typename MT2>
  bool IsEqualKey(typename TTypes<K>::Matrix tensor1, int64 index1, MT2 tensor2,
                  int64 index2) const {
    for (int64 i = 0; i < key_shape_.num_elements(); ++i) {
      if (tensor1(index1, i) != tensor2(index2, i)) {
        return false;
      }
    }
    return true;
  }

  TensorShape key_shape_;
  TensorShape value_shape_;
  float max_load_factor_;
  mutable mutex mu_;
  int64 num_entries_ GUARDED_BY(mu_);
  int64 num_buckets_ GUARDED_BY(mu_);
  PersistentTensor key_buckets_ GUARDED_BY(mu_);
  PersistentTensor value_buckets_ GUARDED_BY(mu_);
  PersistentTensor empty_key_;
  uint64 empty_key_hash_;
};

}  // namespace lookup

// Table lookup op. Performs the lookup operation on the given table.
class LookupTableFindOp : public OpKernel {
 public:
  explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    lookup::LookupInterface* table;
    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
    core::ScopedUnref unref_me(table);

    // Input 0 could be a STRING_REF or a RESOURCE
    DataType expected_input_0 =
        (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
    DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
                                      table->value_dtype()};
    DataTypeVector expected_outputs = {table->value_dtype()};
    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));

    const Tensor& key = ctx->input(1);
    const Tensor& default_value = ctx->input(2);
    OP_REQUIRES_OK(ctx, table->CheckFindArguments(key, default_value));

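    // The output holds one value (of the table's value shape) per key: keep
    // the batch dimensions of the key input, drop the per-key dimensions,
    // and append the value shape.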
    TensorShape output_shape = key.shape();
    output_shape.RemoveLastDims(table->key_shape().dims());
    output_shape.AppendShape(table->value_shape());
    Tensor* out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("values", output_shape, &out));

    OP_REQUIRES_OK(ctx, table->Find(ctx, key, out, default_value));
  }
};

REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU),
                        LookupTableFindOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU),
                        LookupTableFindOp);

// Table insert op.
class LookupTableInsertOp : public OpKernel {
 public:
  explicit LookupTableInsertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    lookup::LookupInterface* table;
    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
    core::ScopedUnref unref_me(table);

    DataType expected_input_0 =
        (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
    DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
                                      table->value_dtype()};
    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));

    const Tensor& keys = ctx->input(1);
    const Tensor& values = ctx->input(2);
    OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForInsert(keys, values));

    int64 memory_used_before = 0;
    if (ctx->track_allocations()) {
      memory_used_before = table->MemoryUsed();
    }
    OP_REQUIRES_OK(ctx, table->Insert(ctx, keys, values));
    if (ctx->track_allocations()) {
      ctx->record_persistent_memory_allocation(table->MemoryUsed() -
                                               memory_used_before);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU),
                        LookupTableInsertOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU),
                        LookupTableInsertOp);

// Op that returns the size of the given table.
class LookupTableSizeOp : public OpKernel {
 public:
  explicit LookupTableSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    lookup::LookupInterface* table;
    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
    core::ScopedUnref unref_me(table);

    Tensor* out;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("size", TensorShape({}), &out));
    out->flat<int64>().setConstant(table->size());
  }
};

REGISTER_KERNEL_BUILDER(Name("LookupTableSize").Device(DEVICE_CPU),
                        LookupTableSizeOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableSizeV2").Device(DEVICE_CPU),
                        LookupTableSizeOp);

// Op that outputs tensors of all keys and all values.
class LookupTableExportOp : public OpKernel {
 public:
  explicit LookupTableExportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    lookup::LookupInterface* table;
    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
    core::ScopedUnref unref_me(table);

    OP_REQUIRES_OK(ctx, table->ExportValues(ctx));
  }
};

REGISTER_KERNEL_BUILDER(Name("LookupTableExport").Device(DEVICE_CPU),
                        LookupTableExportOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableExportV2").Device(DEVICE_CPU),
                        LookupTableExportOp);

// Clear the table and insert data.
class LookupTableImportOp : public OpKernel {
 public:
  explicit LookupTableImportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    lookup::LookupInterface* table;
    OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
    core::ScopedUnref unref_me(table);

    DataType expected_input_0 =
        (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
    DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
                                      table->value_dtype()};
    OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));

    const Tensor& keys = ctx->input(1);
    const Tensor& values = ctx->input(2);
    OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensorsForImport(keys, values));

    int64 memory_used_before = 0;
    if (ctx->track_allocations()) {
      memory_used_before = table->MemoryUsed();
    }
    OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values));
    if (ctx->track_allocations()) {
      ctx->record_persistent_memory_allocation(table->MemoryUsed() -
                                               memory_used_before);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("LookupTableImport").Device(DEVICE_CPU),
                        LookupTableImportOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU),
                        LookupTableImportOp);

// Register the HashTable op with the currently supported key and value types.
#define REGISTER_KERNEL(key_dtype, value_dtype)                           \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("HashTable")                                                   \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<key_dtype>("key_dtype")                         \
          .TypeConstraint<value_dtype>("value_dtype"),                    \
      LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
                    value_dtype>)                                         \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("HashTableV2")                                                 \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<key_dtype>("key_dtype")                         \
          .TypeConstraint<value_dtype>("value_dtype"),                    \
      LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
                    value_dtype>)

REGISTER_KERNEL(string, double);
REGISTER_KERNEL(string, float);
REGISTER_KERNEL(string, int32);
REGISTER_KERNEL(string, int64);
REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, float);
REGISTER_KERNEL(string, string);
REGISTER_KERNEL(string, bool);
REGISTER_KERNEL(int32, int32);

#undef REGISTER_KERNEL

// Register the MutableHashTable op.
#define REGISTER_KERNEL(key_dtype, value_dtype)                                \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MutableHashTable")                                                 \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<key_dtype>("key_dtype")                              \
          .TypeConstraint<value_dtype>("value_dtype"),                         \
      LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>)                                   \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MutableHashTableV2")                                               \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<key_dtype>("key_dtype")                              \
          .TypeConstraint<value_dtype>("value_dtype"),                         \
      LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>)

REGISTER_KERNEL(string, float);
REGISTER_KERNEL(string, int64);
REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(string, bool);
REGISTER_KERNEL(int64, float);

#undef REGISTER_KERNEL

// Register the MutableHashTableOfTensors op.
#define REGISTER_KERNEL(key_dtype, value_dtype)                                \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MutableHashTableOfTensors")                                        \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<key_dtype>("key_dtype")                              \
          .TypeConstraint<value_dtype>("value_dtype"),                         \
      LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>)                                   \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MutableHashTableOfTensorsV2")                                      \
          .Device(DEVICE_CPU)                                                  \
          .TypeConstraint<key_dtype>("key_dtype")                              \
          .TypeConstraint<value_dtype>("value_dtype"),                         \
      LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>)

REGISTER_KERNEL(string, float);
REGISTER_KERNEL(string, int64);
REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(string, bool);

#undef REGISTER_KERNEL

// Register the MutableDenseHashTable op.
#define REGISTER_KERNEL(key_dtype, value_dtype)                            \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("MutableDenseHashTable")                                        \
          .Device(DEVICE_CPU)                                              \
          .TypeConstraint<key_dtype>("key_dtype")                          \
          .TypeConstraint<value_dtype>("value_dtype"),                     \
      LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>)                               \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("MutableDenseHashTableV2")                                      \
          .Device(DEVICE_CPU)                                              \
          .TypeConstraint<key_dtype>("key_dtype")                          \
          .TypeConstraint<value_dtype>("value_dtype"),                     \
      LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \
                    key_dtype, value_dtype>)

REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, float);
REGISTER_KERNEL(int64, double);
REGISTER_KERNEL(string, float);
REGISTER_KERNEL(string, bool);
REGISTER_KERNEL(int64, bool);

#undef REGISTER_KERNEL

}  // namespace tensorflow