Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <string>
     17 
     18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     19 #include "tensorflow/core/framework/kernel_def_builder.h"
     20 #include "tensorflow/core/framework/op_kernel.h"
     21 #include "tensorflow/core/framework/tensor.h"
     22 #include "tensorflow/core/framework/types.h"
     23 #include "tensorflow/core/kernels/lookup_table_init_op.h"
     24 #include "tensorflow/core/kernels/lookup_table_op.h"
     25 #include "tensorflow/core/lib/core/errors.h"
     26 #include "tensorflow/core/lib/core/status.h"
     27 
     28 namespace tensorflow {
namespace {
// lookup::InitializeTableFromTextFile requires a delimiter argument even when
// key_index/value_index select the line number and the whole line, in which
// case no splitting occurs; this value is passed but never consulted.
constexpr char kUnusedLookupDelim = '\t';
}  // namespace
     34 
     35 // This Op generates a vocab remapping Tensor from an old and new vocabulary
     36 // file that maps new ID's to old ID's.
     37 class GenerateVocabRemappingOp : public OpKernel {
     38  public:
     39   explicit GenerateVocabRemappingOp(OpKernelConstruction* context)
     40       : OpKernel(context) {
     41     OP_REQUIRES_OK(context,
     42                    context->GetAttr("new_vocab_offset", &new_vocab_offset_));
     43     OP_REQUIRES_OK(context, context->GetAttr("num_new_vocab", &num_new_vocab_));
     44     OP_REQUIRES_OK(context,
     45                    context->GetAttr("old_vocab_size", &old_vocab_size_));
     46   }
     47 
     48   void Compute(OpKernelContext* context) override {
     49     const Tensor* new_vocab_file_tensor;
     50     OP_REQUIRES_OK(context,
     51                    context->input("new_vocab_file", &new_vocab_file_tensor));
     52     OP_REQUIRES(context,
     53                 TensorShapeUtils::IsScalar(new_vocab_file_tensor->shape()),
     54                 errors::InvalidArgument(
     55                     "new_vocab_file should be a single string, but got ",
     56                     new_vocab_file_tensor->shape().DebugString()));
     57 
     58     // Build a new ID->token lookup table.
     59     const string& new_vocab_filename =
     60         new_vocab_file_tensor->scalar<string>()();
     61     OP_REQUIRES(context, !new_vocab_filename.empty(),
     62                 errors::InvalidArgument("new vocab filename cannot be empty."));
     63     lookup::HashTable<int64, string>* new_vocab_table =
     64         new lookup::HashTable<int64, string>(context, this);
     65     core::ScopedUnref unref_new(new_vocab_table);
     66     // Note: we pass -1 (unknown) for vocab_size, which is supposed to be the
     67     // total elements in file.  This is different from num_new_vocab_, which
     68     // accounts for partitioning.
     69     OP_REQUIRES_OK(context, lookup::InitializeTableFromTextFile(
     70                                 new_vocab_filename,
     71                                 -1,  // vocab_size
     72                                 kUnusedLookupDelim,
     73                                 -1,  // key_index, use the line number.
     74                                 -2,  // value_index, use the whole line/token.
     75                                 context->env(), new_vocab_table));
     76     OP_REQUIRES(context,
     77                 new_vocab_offset_ + num_new_vocab_ <= new_vocab_table->size(),
     78                 errors::InvalidArgument("lookup table size must be larger than "
     79                                         "last new vocab entry's line"));
     80 
     81     const Tensor* old_vocab_file_tensor;
     82     OP_REQUIRES_OK(context,
     83                    context->input("old_vocab_file", &old_vocab_file_tensor));
     84     OP_REQUIRES(context,
     85                 TensorShapeUtils::IsScalar(old_vocab_file_tensor->shape()),
     86                 errors::InvalidArgument(
     87                     "old_vocab_file should be a single string, but got ",
     88                     old_vocab_file_tensor->shape().DebugString()));
     89     // Build a token->old ID lookup table.
     90     const string& old_vocab_filename =
     91         old_vocab_file_tensor->scalar<string>()();
     92     OP_REQUIRES(context, !old_vocab_filename.empty(),
     93                 errors::InvalidArgument("new vocab filename cannot be empty."));
     94     lookup::HashTable<string, int64>* old_vocab_table =
     95         new lookup::HashTable<string, int64>(context, this);
     96     core::ScopedUnref unref_old(old_vocab_table);
     97     // Note: If old_vocab_size_ is -1 (unknown), we retrieve all elements in
     98     // file (see TextFileLineIterator).
     99     OP_REQUIRES_OK(context,
    100                    lookup::InitializeTableFromTextFile(
    101                        old_vocab_filename, old_vocab_size_, kUnusedLookupDelim,
    102                        -2,  // key_index, use the whole line/token.
    103                        -1,  // value_index, use the line number.
    104                        context->env(), old_vocab_table));
    105 
    106     // Fill out new_ids = [new_vocab_offset, new_vocab_offset + 1, ...,
    107     //                     new_vocab_offset + num_new_vocab_]
    108     // The double look-up requires a few temporary Tensors.
    109     Tensor new_ids;
    110     OP_REQUIRES_OK(
    111         context, context->allocate_temp(DT_INT64, TensorShape({num_new_vocab_}),
    112                                         &new_ids));
    113     auto new_ids_vec = new_ids.vec<int64>();
    114     // Note that we should always be able to find tokens for all new ID's, given
    115     // that the lookup table is constructed with the vocabulary file itself
    116     // (see the check on offset and table size post-initialization).
    117     Tensor default_token;
    118     OP_REQUIRES_OK(
    119         context, context->allocate_temp(
    120                      DT_STRING, TensorShape({num_new_vocab_}), &default_token));
    121     auto default_token_vec = default_token.vec<string>();
    122     default_token_vec.setConstant("" /* NOT_FOUND_TOKEN */);
    123 
    124     Tensor default_id;
    125     OP_REQUIRES_OK(
    126         context, context->allocate_temp(DT_INT64, TensorShape({num_new_vocab_}),
    127                                         &default_id));
    128     auto default_id_vec = default_id.vec<int64>();
    129     default_id_vec.setConstant(-1 /* NOT_FOUND_ID */);
    130 
    131     for (int i = 0; i < num_new_vocab_; ++i) {
    132       new_ids_vec(i) = static_cast<int64>(i + new_vocab_offset_);
    133     }
    134     Tensor tokens;
    135     OP_REQUIRES_OK(context,
    136                    context->allocate_temp(
    137                        DT_STRING, TensorShape({num_new_vocab_}), &tokens));
    138     Tensor* remapping;
    139     OP_REQUIRES_OK(context,
    140                    context->allocate_output(
    141                        "remapping", TensorShape({num_new_vocab_}), &remapping));
    142     // In the corner case where num_new_vocab_ is 0 (we are dealing with an
    143     // OOV-only partition), we should not do this lookup.
    144     if (num_new_vocab_ != 0) {
    145       OP_REQUIRES_OK(context, new_vocab_table->Find(context, new_ids, &tokens,
    146                                                     default_token));
    147       OP_REQUIRES_OK(context, old_vocab_table->Find(context, tokens, remapping,
    148                                                     default_id));
    149     }
    150     // Iterate through remapping to calculate num_present.
    151     const auto remapping_vec = remapping->vec<int64>();
    152     int num_present = 0;
    153     for (int i = 0; i < num_new_vocab_; ++i) {
    154       if (remapping_vec(i) != -1 /* NOT_FOUND_ID */) {
    155         ++num_present;
    156       }
    157     }
    158     Tensor* num_present_t;
    159     OP_REQUIRES_OK(context,
    160                    context->allocate_output("num_present", TensorShape({}),
    161                                             &num_present_t));
    162     num_present_t->scalar<int>()() = num_present;
    163   }
    164 
    165  private:
    166   int new_vocab_offset_;
    167   int num_new_vocab_;
    168   int old_vocab_size_;
    169 };
    170 
// Register the CPU kernel; the op reads vocab files via Env, so no GPU
// implementation is provided.
REGISTER_KERNEL_BUILDER(Name("GenerateVocabRemapping").Device(DEVICE_CPU),
                        GenerateVocabRemappingOp);
    173 
    174 }  // namespace tensorflow
    175