/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_
#define TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_

#include "tensorflow/core/framework/lookup_interface.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/lookup_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {

// Lookup table op that supports different table implementations specified by
// the 'Container' template. Container must be derived from LookupInterface. The
// key and value are of the templated type "key_dtype" and "value_dtype"
// respectively.
//
// The op creates (or looks up, when shared) the table resource on first
// execution and outputs a handle to it: either a DT_RESOURCE scalar or, on
// the legacy path, a DT_STRING ref tensor holding {container, name}.
template <class Container, class key_dtype, class value_dtype>
class LookupTableOp : public OpKernel {
 public:
  // ctx is not owned by this class.
  explicit LookupTableOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), table_handle_set_(false) {
    // Reserve a DT_STRING tensor of shape {2} that will later hold the
    // {container, name} pair identifying the table (legacy ref output path).
    OP_REQUIRES_OK(ctx, ctx->allocate_persistent(tensorflow::DT_STRING,
                                                 tensorflow::TensorShape({2}),
                                                 &table_handle_, nullptr));
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("use_node_name_sharing", &use_node_name_sharing_));
  }

  // ctx is not owned by this function.
  //
  // Creates the table resource on first call (or finds a previously shared
  // one), validates its dtypes, and emits the handle output.
  void Compute(OpKernelContext* ctx) override {
    mutex_lock l(mu_);

    // cinfo_ is initialized exactly once; table_handle_set_ is only flipped
    // to true at the end of a successful Compute.
    if (!table_handle_set_) {
      OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(),
                                      use_node_name_sharing_));
    }

    // Invoked by LookupOrCreate only when no resource with this
    // {container, name} exists yet.
    auto creator = [ctx, this](lookup::LookupInterface** ret) {
      lookup::LookupInterface* container = new Container(ctx, this);
      if (!ctx->status().ok()) {
        // Container construction reported an error via ctx; drop our
        // reference so the partially built object is destroyed.
        container->Unref();
        return ctx->status();
      }
      if (ctx->track_allocations()) {
        ctx->record_persistent_memory_allocation(
            container->MemoryUsed() + table_handle_.AllocatedBytes());
      }
      *ret = container;
      return Status::OK();
    };

    lookup::LookupInterface* table = nullptr;
    OP_REQUIRES_OK(ctx,
                   cinfo_.resource_manager()
                       ->template LookupOrCreate<lookup::LookupInterface>(
                           cinfo_.container(), cinfo_.name(), &table, creator));
    // LookupOrCreate returned an owned reference; release it on scope exit
    // (the resource manager keeps its own reference).
    core::ScopedUnref unref_me(table);

    // Reject a shared table whose key/value dtypes do not match this op's
    // template arguments.
    OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes(
                            *table, DataTypeToEnum<key_dtype>::v(),
                            DataTypeToEnum<value_dtype>::v(), cinfo_.name()));

    if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
      // Resource-handle output path: scalar DT_RESOURCE handle.
      Tensor* handle;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
      handle->scalar<ResourceHandle>()() =
          MakeResourceHandle<lookup::LookupInterface>(ctx, cinfo_.container(),
                                                      cinfo_.name());
    } else {
      // Legacy ref-output path: fill the persistent string tensor with
      // {container, name} once, then output it by reference.
      if (!table_handle_set_) {
        auto h = table_handle_.AccessTensor(ctx)->template flat<string>();
        h(0) = cinfo_.container();
        h(1) = cinfo_.name();
      }
      ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx));
    }
    table_handle_set_ = true;
  }

  ~LookupTableOp() override {
    // If the table object was not shared, delete it.
    if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
      TF_CHECK_OK(
          cinfo_.resource_manager()->template Delete<lookup::LookupInterface>(
              cinfo_.container(), cinfo_.name()));
    }
  }

 private:
  mutex mu_;
  // Persistent DT_STRING tensor of shape {2} backing the legacy ref output.
  PersistentTensor table_handle_ GUARDED_BY(mu_);
  // True once Compute has run to completion at least once.
  bool table_handle_set_ GUARDED_BY(mu_);
  ContainerInfo cinfo_;
  bool use_node_name_sharing_;

  TF_DISALLOW_COPY_AND_ASSIGN(LookupTableOp);
};

namespace lookup {

// Ensure that the compiler cannot elide a copy into a local, for
// bounds checking on source tensors that might be updated asynchronously for
// integral types. However non-integer variables are not allowed and therefore
// the local copy is unnecessary.
template <typename T>
T SubtleMustCopyUnlessStringOrFloat(const T& value) {
  return internal::SubtleMustCopy(value);
}

// string overload: returns the reference unchanged (no copy needed).
inline const string& SubtleMustCopyUnlessStringOrFloat(const string& value) {
  return value;
}

// float overload: already passed by value, no extra copy needed.
inline const float SubtleMustCopyUnlessStringOrFloat(const float value) {
  return value;
}

// double overload: already passed by value, no extra copy needed.
inline const double SubtleMustCopyUnlessStringOrFloat(const double value) {
  return value;
}

// Lookup table that wraps an unordered_map, where the key and value data type
// is specified.
//
// This table is recommended for any variations to key values.
//
// For look up, the table is required to be initialized (allocated
// and populated). Once the table is marked as initialized it becomes read-only.
//
// Sample use case:
//
// HashTable<int64, int64> table;  // int64 -> int64.
// table.Prepare(10);  // Prepare the underlying data structure, the number of
//                     // elements is required by interface, but not used.
// // Populate the table, elements could be added in one or multiple calls.
// table.Insert(key_tensor, value_tensor);  // Populate the table.
159 // ... 160 // table.set_is_initialized(); 161 // 162 // table.Find(in_t, &out_t, default_t) 163 // 164 template <class K, class V> 165 class HashTable : public InitializableLookupTable { 166 public: 167 HashTable(OpKernelContext* ctx, OpKernel* kernel) {} 168 169 size_t size() const override { 170 // return the size of the table only if it's initialized, otherwise 0. 171 if (!is_initialized_) { 172 return 0; 173 } 174 std::atomic_thread_fence(std::memory_order_acquire); 175 return table_ ? table_->size() : 0; 176 } 177 178 DataType key_dtype() const override { return DataTypeToEnum<K>::v(); } 179 180 DataType value_dtype() const override { return DataTypeToEnum<V>::v(); } 181 182 protected: 183 Status DoPrepare(size_t unused) override { 184 if (is_initialized_) { 185 return errors::Aborted("HashTable already initialized."); 186 } 187 if (!table_) { 188 table_ = std::unique_ptr<std::unordered_map<K, V>>( 189 new std::unordered_map<K, V>()); 190 } 191 return Status::OK(); 192 }; 193 194 Status DoInsert(const Tensor& keys, const Tensor& values) override { 195 if (!table_) { 196 return errors::FailedPrecondition("HashTable is not prepared."); 197 } 198 199 const auto key_values = keys.flat<K>(); 200 const auto value_values = values.flat<V>(); 201 for (int64 i = 0; i < key_values.size(); ++i) { 202 const K key = SubtleMustCopyUnlessStringOrFloat(key_values(i)); 203 const V value = SubtleMustCopyUnlessStringOrFloat(value_values(i)); 204 const V& previous_value = gtl::LookupOrInsert(table_.get(), key, value); 205 if (previous_value != value) { 206 return errors::FailedPrecondition( 207 "HashTable has different value for same key. 
Key ", key, " has ", 208 previous_value, " and trying to add value ", value); 209 } 210 } 211 return Status::OK(); 212 } 213 214 Status DoFind(const Tensor& key, Tensor* value, 215 const Tensor& default_value) override { 216 const V default_val = default_value.flat<V>()(0); 217 const auto key_values = key.flat<K>(); 218 auto value_values = value->flat<V>(); 219 220 for (int64 i = 0; i < key_values.size(); ++i) { 221 value_values(i) = gtl::FindWithDefault( 222 *table_, SubtleMustCopyUnlessStringOrFloat(key_values(i)), 223 default_val); 224 } 225 return Status::OK(); 226 } 227 228 int64 MemoryUsed() const override { 229 if (table_) { 230 const int64 num_elements = table_->size(); 231 return num_elements * (sizeof(K) + sizeof(V)); 232 } else { 233 return 0; 234 } 235 } 236 237 private: 238 std::unique_ptr<std::unordered_map<K, V>> table_; 239 }; 240 241 } // namespace lookup 242 243 } // namespace tensorflow 244 245 #endif // TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_ 246