1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <string> 17 18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 19 #include "tensorflow/core/framework/kernel_def_builder.h" 20 #include "tensorflow/core/framework/op_kernel.h" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/framework/types.h" 23 #include "tensorflow/core/kernels/lookup_table_init_op.h" 24 #include "tensorflow/core/kernels/lookup_table_op.h" 25 #include "tensorflow/core/lib/core/errors.h" 26 #include "tensorflow/core/lib/core/status.h" 27 28 namespace tensorflow { 29 namespace { 30 // lookup::InitializeTableFromTextFile requires a delimiter even though we use 31 // the entire line for vocabularies. 32 constexpr char kUnusedLookupDelim = '\t'; 33 } // namespace 34 35 // This Op generates a vocab remapping Tensor from an old and new vocabulary 36 // file that maps new ID's to old ID's. 37 class GenerateVocabRemappingOp : public OpKernel { 38 public: 39 explicit GenerateVocabRemappingOp(OpKernelConstruction* context) 40 : OpKernel(context) { 41 OP_REQUIRES_OK(context, 42 context->GetAttr("new_vocab_offset", &new_vocab_offset_)); 43 OP_REQUIRES_OK(context, context->GetAttr("num_new_vocab", &num_new_vocab_)); 44 OP_REQUIRES_OK(context, 45 context->GetAttr("old_vocab_size", &old_vocab_size_)); 46 } 47 48 void Compute(OpKernelContext* context) override { 49 const Tensor* new_vocab_file_tensor; 50 OP_REQUIRES_OK(context, 51 context->input("new_vocab_file", &new_vocab_file_tensor)); 52 OP_REQUIRES(context, 53 TensorShapeUtils::IsScalar(new_vocab_file_tensor->shape()), 54 errors::InvalidArgument( 55 "new_vocab_file should be a single string, but got ", 56 new_vocab_file_tensor->shape().DebugString())); 57 58 // Build a new ID->token lookup table. 59 const string& new_vocab_filename = 60 new_vocab_file_tensor->scalar<string>()(); 61 OP_REQUIRES(context, !new_vocab_filename.empty(), 62 errors::InvalidArgument("new vocab filename cannot be empty.")); 63 lookup::HashTable<int64, string>* new_vocab_table = 64 new lookup::HashTable<int64, string>(context, this); 65 core::ScopedUnref unref_new(new_vocab_table); 66 // Note: we pass -1 (unknown) for vocab_size, which is supposed to be the 67 // total elements in file. This is different from num_new_vocab_, which 68 // accounts for partitioning. 69 OP_REQUIRES_OK(context, lookup::InitializeTableFromTextFile( 70 new_vocab_filename, 71 -1, // vocab_size 72 kUnusedLookupDelim, 73 -1, // key_index, use the line number. 74 -2, // value_index, use the whole line/token. 75 context->env(), new_vocab_table)); 76 OP_REQUIRES(context, 77 new_vocab_offset_ + num_new_vocab_ <= new_vocab_table->size(), 78 errors::InvalidArgument("lookup table size must be larger than " 79 "last new vocab entry's line")); 80 81 const Tensor* old_vocab_file_tensor; 82 OP_REQUIRES_OK(context, 83 context->input("old_vocab_file", &old_vocab_file_tensor)); 84 OP_REQUIRES(context, 85 TensorShapeUtils::IsScalar(old_vocab_file_tensor->shape()), 86 errors::InvalidArgument( 87 "old_vocab_file should be a single string, but got ", 88 old_vocab_file_tensor->shape().DebugString())); 89 // Build a token->old ID lookup table. 90 const string& old_vocab_filename = 91 old_vocab_file_tensor->scalar<string>()(); 92 OP_REQUIRES(context, !old_vocab_filename.empty(), 93 errors::InvalidArgument("new vocab filename cannot be empty.")); 94 lookup::HashTable<string, int64>* old_vocab_table = 95 new lookup::HashTable<string, int64>(context, this); 96 core::ScopedUnref unref_old(old_vocab_table); 97 // Note: If old_vocab_size_ is -1 (unknown), we retrieve all elements in 98 // file (see TextFileLineIterator). 99 OP_REQUIRES_OK(context, 100 lookup::InitializeTableFromTextFile( 101 old_vocab_filename, old_vocab_size_, kUnusedLookupDelim, 102 -2, // key_index, use the whole line/token. 103 -1, // value_index, use the line number. 104 context->env(), old_vocab_table)); 105 106 // Fill out new_ids = [new_vocab_offset, new_vocab_offset + 1, ..., 107 // new_vocab_offset + num_new_vocab_] 108 // The double look-up requires a few temporary Tensors. 109 Tensor new_ids; 110 OP_REQUIRES_OK( 111 context, context->allocate_temp(DT_INT64, TensorShape({num_new_vocab_}), 112 &new_ids)); 113 auto new_ids_vec = new_ids.vec<int64>(); 114 // Note that we should always be able to find tokens for all new ID's, given 115 // that the lookup table is constructed with the vocabulary file itself 116 // (see the check on offset and table size post-initialization). 117 Tensor default_token; 118 OP_REQUIRES_OK( 119 context, context->allocate_temp( 120 DT_STRING, TensorShape({num_new_vocab_}), &default_token)); 121 auto default_token_vec = default_token.vec<string>(); 122 default_token_vec.setConstant("" /* NOT_FOUND_TOKEN */); 123 124 Tensor default_id; 125 OP_REQUIRES_OK( 126 context, context->allocate_temp(DT_INT64, TensorShape({num_new_vocab_}), 127 &default_id)); 128 auto default_id_vec = default_id.vec<int64>(); 129 default_id_vec.setConstant(-1 /* NOT_FOUND_ID */); 130 131 for (int i = 0; i < num_new_vocab_; ++i) { 132 new_ids_vec(i) = static_cast<int64>(i + new_vocab_offset_); 133 } 134 Tensor tokens; 135 OP_REQUIRES_OK(context, 136 context->allocate_temp( 137 DT_STRING, TensorShape({num_new_vocab_}), &tokens)); 138 Tensor* remapping; 139 OP_REQUIRES_OK(context, 140 context->allocate_output( 141 "remapping", TensorShape({num_new_vocab_}), &remapping)); 142 // In the corner case where num_new_vocab_ is 0 (we are dealing with an 143 // OOV-only partition), we should not do this lookup. 144 if (num_new_vocab_ != 0) { 145 OP_REQUIRES_OK(context, new_vocab_table->Find(context, new_ids, &tokens, 146 default_token)); 147 OP_REQUIRES_OK(context, old_vocab_table->Find(context, tokens, remapping, 148 default_id)); 149 } 150 // Iterate through remapping to calculate num_present. 151 const auto remapping_vec = remapping->vec<int64>(); 152 int num_present = 0; 153 for (int i = 0; i < num_new_vocab_; ++i) { 154 if (remapping_vec(i) != -1 /* NOT_FOUND_ID */) { 155 ++num_present; 156 } 157 } 158 Tensor* num_present_t; 159 OP_REQUIRES_OK(context, 160 context->allocate_output("num_present", TensorShape({}), 161 &num_present_t)); 162 num_present_t->scalar<int>()() = num_present; 163 } 164 165 private: 166 int new_vocab_offset_; 167 int num_new_vocab_; 168 int old_vocab_size_; 169 }; 170 171 REGISTER_KERNEL_BUILDER(Name("GenerateVocabRemapping").Device(DEVICE_CPU), 172 GenerateVocabRemappingOp); 173 174 } // namespace tensorflow 175