1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #define EIGEN_USE_THREADS 16 17 #include "tensorflow/core/kernels/lookup_table_init_op.h" 18 19 #include <algorithm> 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/register_types.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor_shape.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/kernels/lookup_util.h" 30 #include "tensorflow/core/lib/core/errors.h" 31 #include "tensorflow/core/lib/core/status.h" 32 #include "tensorflow/core/lib/io/inputbuffer.h" 33 #include "tensorflow/core/lib/strings/numbers.h" 34 #include "tensorflow/core/lib/strings/str_util.h" 35 #include "tensorflow/core/platform/macros.h" 36 37 namespace tensorflow { 38 39 // Kernel to initialize a look table given a key and value tensors. 40 // After this operation, the table becomes read-only. 41 class InitializeTableOp : public OpKernel { 42 public: 43 explicit InitializeTableOp(OpKernelConstruction* context) 44 : OpKernel(context) {} 45 46 void Compute(OpKernelContext* ctx) override { 47 mutex_lock l(mu_); 48 lookup::InitializableLookupTable* table; 49 OP_REQUIRES_OK(ctx, 50 GetInitializableLookupTable("table_handle", ctx, &table)); 51 core::ScopedUnref unref_me(table); 52 53 DataType expected_input_0 = 54 (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; 55 DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(), 56 table->value_dtype()}; 57 DataTypeVector expected_outputs = {}; 58 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); 59 60 const Tensor& keys = ctx->input(1); 61 OP_REQUIRES( 62 ctx, TensorShapeUtils::IsVector(keys.shape()), 63 errors::InvalidArgument("Keys must be a vector, but received shape", 64 keys.shape().DebugString())); 65 66 const Tensor& values = ctx->input(2); 67 OP_REQUIRES( 68 ctx, TensorShapeUtils::IsVector(values.shape()), 69 errors::InvalidArgument("Values must be a vector, but received shape", 70 values.shape().DebugString())); 71 72 OP_REQUIRES(ctx, keys.NumElements() == values.NumElements(), 73 errors::InvalidArgument( 74 "Keys and values must have the same size ", 75 keys.NumElements(), " vs ", values.NumElements())); 76 77 lookup::KeyValueTensorIterator iter(&keys, &values); 78 79 int memory_used_before = 0; 80 if (ctx->track_allocations()) { 81 memory_used_before = table->MemoryUsed(); 82 } 83 OP_REQUIRES_OK(ctx, table->Initialize(iter)); 84 if (ctx->track_allocations()) { 85 ctx->record_persistent_memory_allocation(table->MemoryUsed() - 86 memory_used_before); 87 } 88 } 89 90 private: 91 mutex mu_; 92 }; 93 94 REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU), 95 InitializeTableOp); 96 REGISTER_KERNEL_BUILDER(Name("InitializeTableV2").Device(DEVICE_CPU), 97 InitializeTableOp); 98 99 // Kernel to initialize a lookup table from a text file. 100 // 101 // After this operation, the table becomes read-only. 102 class InitializeTableFromTextFileOp : public OpKernel { 103 public: 104 explicit InitializeTableFromTextFileOp(OpKernelConstruction* ctx) 105 : OpKernel(ctx) { 106 OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_size", &vocab_size_)); 107 OP_REQUIRES_OK(ctx, ctx->GetAttr("key_index", &key_index_)); 108 OP_REQUIRES_OK(ctx, ctx->GetAttr("value_index", &value_index_)); 109 string delimiter; 110 OP_REQUIRES_OK(ctx, ctx->GetAttr("delimiter", &delimiter)); 111 OP_REQUIRES(ctx, delimiter.size() == 1, 112 errors::InvalidArgument("delimiter should be only 1 char")); 113 delimiter_ = delimiter[0]; 114 } 115 116 void Compute(OpKernelContext* ctx) override { 117 mutex_lock l(mu_); 118 lookup::InitializableLookupTable* table; 119 OP_REQUIRES_OK(ctx, 120 GetInitializableLookupTable("table_handle", ctx, &table)); 121 core::ScopedUnref unref_me(table); 122 123 DataType expected_input_0 = 124 (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF; 125 DataTypeVector expected_inputs = {expected_input_0, DT_STRING}; 126 DataTypeVector expected_outputs = {}; 127 OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs)); 128 129 const Tensor& vocab_filename_tensor = ctx->input(1); 130 OP_REQUIRES( 131 ctx, TensorShapeUtils::IsScalar(vocab_filename_tensor.shape()), 132 errors::InvalidArgument("filename should be a single string, but got ", 133 vocab_filename_tensor.shape().DebugString())); 134 135 string vocab_filename = vocab_filename_tensor.scalar<string>()(); 136 OP_REQUIRES(ctx, !vocab_filename.empty(), 137 errors::InvalidArgument("filename cannot be empty.")); 138 139 int64 memory_used_before = 0; 140 if (ctx->track_allocations()) { 141 memory_used_before = table->MemoryUsed(); 142 } 143 OP_REQUIRES_OK(ctx, lookup::InitializeTableFromTextFile( 144 vocab_filename, vocab_size_, delimiter_, key_index_, 145 value_index_, ctx->env(), table)); 146 if (ctx->track_allocations()) { 147 ctx->record_persistent_memory_allocation(table->MemoryUsed() - 148 memory_used_before); 149 } 150 } 151 152 private: 153 mutex mu_; 154 int64 vocab_size_; 155 char delimiter_; 156 int64 key_index_; 157 int64 value_index_; 158 159 TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromTextFileOp); 160 }; 161 162 REGISTER_KERNEL_BUILDER(Name("InitializeTableFromTextFile").Device(DEVICE_CPU), 163 InitializeTableFromTextFileOp); 164 REGISTER_KERNEL_BUILDER( 165 Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU), 166 InitializeTableFromTextFileOp); 167 168 } // namespace tensorflow 169