Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #define EIGEN_USE_THREADS
     16 
     17 #include "tensorflow/core/kernels/lookup_table_init_op.h"
     18 
     19 #include <algorithm>
     20 #include <memory>
     21 #include <string>
     22 #include <vector>
     23 
     24 #include "tensorflow/core/framework/op_kernel.h"
     25 #include "tensorflow/core/framework/register_types.h"
     26 #include "tensorflow/core/framework/tensor.h"
     27 #include "tensorflow/core/framework/tensor_shape.h"
     28 #include "tensorflow/core/framework/types.h"
     29 #include "tensorflow/core/kernels/lookup_util.h"
     30 #include "tensorflow/core/lib/core/errors.h"
     31 #include "tensorflow/core/lib/core/status.h"
     32 #include "tensorflow/core/lib/io/inputbuffer.h"
     33 #include "tensorflow/core/lib/strings/numbers.h"
     34 #include "tensorflow/core/lib/strings/str_util.h"
     35 #include "tensorflow/core/platform/macros.h"
     36 
     37 namespace tensorflow {
     38 
     39 // Kernel to initialize a look table given a key and value tensors.
     40 // After this operation, the table becomes read-only.
     41 class InitializeTableOp : public OpKernel {
     42  public:
     43   explicit InitializeTableOp(OpKernelConstruction* context)
     44       : OpKernel(context) {}
     45 
     46   void Compute(OpKernelContext* ctx) override {
     47     mutex_lock l(mu_);
     48     lookup::InitializableLookupTable* table;
     49     OP_REQUIRES_OK(ctx,
     50                    GetInitializableLookupTable("table_handle", ctx, &table));
     51     core::ScopedUnref unref_me(table);
     52 
     53     DataType expected_input_0 =
     54         (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
     55     DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
     56                                       table->value_dtype()};
     57     DataTypeVector expected_outputs = {};
     58     OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
     59 
     60     const Tensor& keys = ctx->input(1);
     61     OP_REQUIRES(
     62         ctx, TensorShapeUtils::IsVector(keys.shape()),
     63         errors::InvalidArgument("Keys must be a vector, but received shape",
     64                                 keys.shape().DebugString()));
     65 
     66     const Tensor& values = ctx->input(2);
     67     OP_REQUIRES(
     68         ctx, TensorShapeUtils::IsVector(values.shape()),
     69         errors::InvalidArgument("Values must be a vector, but received shape",
     70                                 values.shape().DebugString()));
     71 
     72     OP_REQUIRES(ctx, keys.NumElements() == values.NumElements(),
     73                 errors::InvalidArgument(
     74                     "Keys and values must have the same size ",
     75                     keys.NumElements(), " vs ", values.NumElements()));
     76 
     77     lookup::KeyValueTensorIterator iter(&keys, &values);
     78 
     79     int memory_used_before = 0;
     80     if (ctx->track_allocations()) {
     81       memory_used_before = table->MemoryUsed();
     82     }
     83     OP_REQUIRES_OK(ctx, table->Initialize(iter));
     84     if (ctx->track_allocations()) {
     85       ctx->record_persistent_memory_allocation(table->MemoryUsed() -
     86                                                memory_used_before);
     87     }
     88   }
     89 
     90  private:
     91   mutex mu_;
     92 };
     93 
     94 REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU),
     95                         InitializeTableOp);
     96 REGISTER_KERNEL_BUILDER(Name("InitializeTableV2").Device(DEVICE_CPU),
     97                         InitializeTableOp);
     98 
     99 // Kernel to initialize a lookup table from a text file.
    100 //
    101 // After this operation, the table becomes read-only.
    102 class InitializeTableFromTextFileOp : public OpKernel {
    103  public:
    104   explicit InitializeTableFromTextFileOp(OpKernelConstruction* ctx)
    105       : OpKernel(ctx) {
    106     OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_size", &vocab_size_));
    107     OP_REQUIRES_OK(ctx, ctx->GetAttr("key_index", &key_index_));
    108     OP_REQUIRES_OK(ctx, ctx->GetAttr("value_index", &value_index_));
    109     string delimiter;
    110     OP_REQUIRES_OK(ctx, ctx->GetAttr("delimiter", &delimiter));
    111     OP_REQUIRES(ctx, delimiter.size() == 1,
    112                 errors::InvalidArgument("delimiter should be only 1 char"));
    113     delimiter_ = delimiter[0];
    114   }
    115 
    116   void Compute(OpKernelContext* ctx) override {
    117     mutex_lock l(mu_);
    118     lookup::InitializableLookupTable* table;
    119     OP_REQUIRES_OK(ctx,
    120                    GetInitializableLookupTable("table_handle", ctx, &table));
    121     core::ScopedUnref unref_me(table);
    122 
    123     DataType expected_input_0 =
    124         (ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
    125     DataTypeVector expected_inputs = {expected_input_0, DT_STRING};
    126     DataTypeVector expected_outputs = {};
    127     OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
    128 
    129     const Tensor& vocab_filename_tensor = ctx->input(1);
    130     OP_REQUIRES(
    131         ctx, TensorShapeUtils::IsScalar(vocab_filename_tensor.shape()),
    132         errors::InvalidArgument("filename should be a single string, but got ",
    133                                 vocab_filename_tensor.shape().DebugString()));
    134 
    135     string vocab_filename = vocab_filename_tensor.scalar<string>()();
    136     OP_REQUIRES(ctx, !vocab_filename.empty(),
    137                 errors::InvalidArgument("filename cannot be empty."));
    138 
    139     int64 memory_used_before = 0;
    140     if (ctx->track_allocations()) {
    141       memory_used_before = table->MemoryUsed();
    142     }
    143     OP_REQUIRES_OK(ctx, lookup::InitializeTableFromTextFile(
    144                             vocab_filename, vocab_size_, delimiter_, key_index_,
    145                             value_index_, ctx->env(), table));
    146     if (ctx->track_allocations()) {
    147       ctx->record_persistent_memory_allocation(table->MemoryUsed() -
    148                                                memory_used_before);
    149     }
    150   }
    151 
    152  private:
    153   mutex mu_;
    154   int64 vocab_size_;
    155   char delimiter_;
    156   int64 key_index_;
    157   int64 value_index_;
    158 
    159   TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromTextFileOp);
    160 };
    161 
    162 REGISTER_KERNEL_BUILDER(Name("InitializeTableFromTextFile").Device(DEVICE_CPU),
    163                         InitializeTableFromTextFileOp);
    164 REGISTER_KERNEL_BUILDER(
    165     Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU),
    166     InitializeTableFromTextFileOp);
    167 
    168 }  // namespace tensorflow
    169