Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_
     17 #define TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_
     18 
     19 #include "tensorflow/core/framework/lookup_interface.h"
     20 #include "tensorflow/core/framework/op_kernel.h"
     21 #include "tensorflow/core/framework/resource_mgr.h"
     22 #include "tensorflow/core/framework/tensor.h"
     23 #include "tensorflow/core/framework/tensor_shape.h"
     24 #include "tensorflow/core/kernels/bounds_check.h"
     25 #include "tensorflow/core/kernels/lookup_util.h"
     26 #include "tensorflow/core/lib/core/errors.h"
     27 #include "tensorflow/core/lib/core/status.h"
     28 #include "tensorflow/core/lib/gtl/map_util.h"
     29 #include "tensorflow/core/platform/macros.h"
     30 #include "tensorflow/core/platform/thread_annotations.h"
     31 
     32 namespace tensorflow {
     33 
     34 // Lookup table op that supports different table implementations specified by
     35 // the 'Container' template. Container must be derived from LookupInterface. The
     36 // key and value are of the templated type "key_dtype" and "value_dtype"
     37 // respectively.
template <class Container, class key_dtype, class value_dtype>
class LookupTableOp : public OpKernel {
 public:
  // ctx is not owned by this class.
  explicit LookupTableOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), table_handle_set_(false) {
    // Pre-allocate a DT_STRING vector of length 2 that is later filled with
    // the {container, name} pair identifying the table resource; it is only
    // emitted on the legacy ref-output path in Compute().
    OP_REQUIRES_OK(ctx, ctx->allocate_persistent(tensorflow::DT_STRING,
                                                 tensorflow::TensorShape({2}),
                                                 &table_handle_, nullptr));
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("use_node_name_sharing", &use_node_name_sharing_));
  }

  // ctx is not owned by this function.
  //
  // Looks up the table in the ResourceMgr (creating it via 'Container' on
  // first use), validates its dtypes, and outputs a handle to it — either a
  // DT_RESOURCE scalar or the legacy string ref tensor, depending on the
  // op's declared output type.
  void Compute(OpKernelContext* ctx) override {
    mutex_lock l(mu_);

    // Resolve the container/shared-name only on the first call; afterwards
    // the cached cinfo_ is reused.
    if (!table_handle_set_) {
      OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(),
                                      use_node_name_sharing_));
    }

    // Factory invoked by LookupOrCreate only when the resource does not
    // exist yet.
    auto creator = [ctx, this](lookup::LookupInterface** ret) {
      lookup::LookupInterface* container = new Container(ctx, this);
      if (!ctx->status().ok()) {
        // Construction failed: drop our ref so the partially-built table is
        // destroyed rather than leaked.
        container->Unref();
        return ctx->status();
      }
      if (ctx->track_allocations()) {
        ctx->record_persistent_memory_allocation(
            container->MemoryUsed() + table_handle_.AllocatedBytes());
      }
      *ret = container;
      return Status::OK();
    };

    lookup::LookupInterface* table = nullptr;
    OP_REQUIRES_OK(ctx,
                   cinfo_.resource_manager()
                       ->template LookupOrCreate<lookup::LookupInterface>(
                           cinfo_.container(), cinfo_.name(), &table, creator));
    // LookupOrCreate handed us a reference; release it at scope exit (the
    // ResourceMgr retains its own reference to the table).
    core::ScopedUnref unref_me(table);

    // Reject a pre-existing table whose key/value dtypes differ from this
    // op's template parameters.
    OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes(
                            *table, DataTypeToEnum<key_dtype>::v(),
                            DataTypeToEnum<value_dtype>::v(), cinfo_.name()));

    if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
      // Modern path: emit a scalar ResourceHandle naming the table.
      Tensor* handle;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
      handle->scalar<ResourceHandle>()() =
          MakeResourceHandle<lookup::LookupInterface>(ctx, cinfo_.container(),
                                                      cinfo_.name());
    } else {
      // Legacy path: emit a string ref tensor [container, name]. Fill it
      // only once; subsequent calls reuse the same persistent tensor.
      if (!table_handle_set_) {
        auto h = table_handle_.AccessTensor(ctx)->template flat<string>();
        h(0) = cinfo_.container();
        h(1) = cinfo_.name();
      }
      ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx));
    }
    table_handle_set_ = true;
  }

  // Removes the table from the ResourceMgr when it was private to this
  // kernel; shared tables are left in place for other users.
  ~LookupTableOp() override {
    // If the table object was not shared, delete it.
    if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
      TF_CHECK_OK(
          cinfo_.resource_manager()->template Delete<lookup::LookupInterface>(
              cinfo_.container(), cinfo_.name()));
    }
  }

 private:
  // Serializes Compute() and guards the handle state below.
  mutex mu_;
  // Legacy [container, name] string handle tensor (ref-output path only).
  PersistentTensor table_handle_ GUARDED_BY(mu_);
  // True once cinfo_ is initialized and a handle has been produced.
  bool table_handle_set_ GUARDED_BY(mu_);
  ContainerInfo cinfo_;
  bool use_node_name_sharing_;

  TF_DISALLOW_COPY_AND_ASSIGN(LookupTableOp);
};
    120 
    121 namespace lookup {
    122 
    123 // Ensure that the compiler cannot elide a copy into a local, for
    124 // bounds checking on source tensors that might be updated asynchronously for
    125 // integral types. However non-integer variables are not allowed and therefore
    126 // the local copy is unnecessary.
    127 template <typename T>
    128 T SubtleMustCopyUnlessStringOrFloat(const T& value) {
    129   return internal::SubtleMustCopy(value);
    130 }
    131 
    132 inline const string& SubtleMustCopyUnlessStringOrFloat(const string& value) {
    133   return value;
    134 }
    135 
    136 inline const float SubtleMustCopyUnlessStringOrFloat(const float value) {
    137   return value;
    138 }
    139 
    140 inline const double SubtleMustCopyUnlessStringOrFloat(const double value) {
    141   return value;
    142 }
    143 
    144 // Lookup table that wraps an unordered_map, where the key and value data type
    145 // is specified.
    146 //
    147 // This table is recommended for any variations to key values.
    148 //
    149 // For look up, the table is required to be initialized (allocated
    150 // and populated). Once the table is marked as initialized it becomes read-only.
    151 //
    152 // Sample use case:
    153 //
    154 // HashTable<int64, int64> table;  // int64 -> int64.
    155 // table.Prepare(10); // Prepare the underlying data structure, the number of
    156 //                    // elements is required by interface, but not used.
    157 // // Populate the table, elements could be added in one or multiple calls.
    158 // table.Insert(key_tensor, value_tensor); // Populate the table.
    159 // ...
    160 // table.set_is_initialized();
    161 //
    162 // table.Find(in_t, &out_t, default_t)
    163 //
    164 template <class K, class V>
    165 class HashTable : public InitializableLookupTable {
    166  public:
    167   HashTable(OpKernelContext* ctx, OpKernel* kernel) {}
    168 
    169   size_t size() const override {
    170     // return the size of the table only if it's initialized, otherwise 0.
    171     if (!is_initialized_) {
    172       return 0;
    173     }
    174     std::atomic_thread_fence(std::memory_order_acquire);
    175     return table_ ? table_->size() : 0;
    176   }
    177 
    178   DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
    179 
    180   DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
    181 
    182  protected:
    183   Status DoPrepare(size_t unused) override {
    184     if (is_initialized_) {
    185       return errors::Aborted("HashTable already initialized.");
    186     }
    187     if (!table_) {
    188       table_ = std::unique_ptr<std::unordered_map<K, V>>(
    189           new std::unordered_map<K, V>());
    190     }
    191     return Status::OK();
    192   };
    193 
    194   Status DoInsert(const Tensor& keys, const Tensor& values) override {
    195     if (!table_) {
    196       return errors::FailedPrecondition("HashTable is not prepared.");
    197     }
    198 
    199     const auto key_values = keys.flat<K>();
    200     const auto value_values = values.flat<V>();
    201     for (int64 i = 0; i < key_values.size(); ++i) {
    202       const K key = SubtleMustCopyUnlessStringOrFloat(key_values(i));
    203       const V value = SubtleMustCopyUnlessStringOrFloat(value_values(i));
    204       const V& previous_value = gtl::LookupOrInsert(table_.get(), key, value);
    205       if (previous_value != value) {
    206         return errors::FailedPrecondition(
    207             "HashTable has different value for same key. Key ", key, " has ",
    208             previous_value, " and trying to add value ", value);
    209       }
    210     }
    211     return Status::OK();
    212   }
    213 
    214   Status DoFind(const Tensor& key, Tensor* value,
    215                 const Tensor& default_value) override {
    216     const V default_val = default_value.flat<V>()(0);
    217     const auto key_values = key.flat<K>();
    218     auto value_values = value->flat<V>();
    219 
    220     for (int64 i = 0; i < key_values.size(); ++i) {
    221       value_values(i) = gtl::FindWithDefault(
    222           *table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)),
    223           default_val);
    224     }
    225     return Status::OK();
    226   }
    227 
    228   int64 MemoryUsed() const override {
    229     if (table_) {
    230       const int64 num_elements = table_->size();
    231       return num_elements * (sizeof(K) + sizeof(V));
    232     } else {
    233       return 0;
    234     }
    235   }
    236 
    237  private:
    238   std::unique_ptr<std::unordered_map<K, V>> table_;
    239 };
    240 
    241 }  // namespace lookup
    242 
    243 }  // namespace tensorflow
    244 
    245 #endif  // TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_
    246