Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/kernels/lookup_util.h"
     17 
     18 #include "tensorflow/core/framework/tensor.h"
     19 #include "tensorflow/core/framework/tensor_shape.h"
     20 #include "tensorflow/core/lib/core/errors.h"
     21 #include "tensorflow/core/lib/io/inputbuffer.h"
     22 
     23 namespace tensorflow {
     24 namespace lookup {
     25 namespace {
     26 
     27 static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */
     28 static const int kLineNumber = -1;
     29 static const int kWholeLine = -2;
     30 
     31 Status GetNumLinesInTextFile(Env* env, const string& vocab_file,
     32                              int64* num_lines) {
     33   std::unique_ptr<RandomAccessFile> file;
     34   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file));
     35 
     36   io::InputBuffer input_buffer(file.get(), kInputBufferSize);
     37   string line;
     38   Status s = input_buffer.ReadLine(&line);
     39   int64 next_id = 0;
     40   while (s.ok()) {
     41     next_id++;
     42     s = input_buffer.ReadLine(&line);
     43   }
     44   if (!errors::IsOutOfRange(s)) {
     45     return s;
     46   }
     47   *num_lines = next_id;
     48   return Status::OK();
     49 }
     50 
     51 // Iterator that reads a text file. Each iteration process one line, it parses
     52 // the line and populates the keys and values tensors used for initialization
     53 // with a single key and corresponding value.
     54 //
     55 // What information of the line to populate the key or values is specified by
     56 // providing key_index and value_index.
     57 class TextFileLineIterator
     58     : public InitializableLookupTable::InitTableIterator {
     59  public:
     60   TextFileLineIterator()
     61       : valid_(false),
     62         vocab_size_(-1),
     63         status_(errors::FailedPrecondition("Not initialized")) {}
     64 
     65   // Initialize iterator.
     66   //
     67   // Prepares the file 'filename' and sets the data types to return the keys and
     68   // values tensors. It requires the indices of the tokens in the line given a
     69   // delimiter to specify where to pick the data from.
     70   //
     71   // - Index -2 means the entire line as string.
     72   // - Index -1 means the line number stored in int64.
     73   // - Index >= 0 represent index (starting at zero) of the split line based on
     74   //   delimiter.
     75   Status Init(const string& filename, int64 vocab_size, char delimiter,
     76               DataType key_dtype, int64 key_index, DataType value_dtype,
     77               int64 value_index, Env* env) {
     78     if (vocab_size == -1) {
     79       TF_RETURN_IF_ERROR(GetNumLinesInTextFile(env, filename, &vocab_size));
     80     }
     81     filename_ = filename;
     82     vocab_size_ = vocab_size;
     83     delimiter_ = delimiter;
     84     key_ = Tensor(key_dtype, TensorShape({}));
     85     value_ = Tensor(value_dtype, TensorShape({}));
     86     key_index_ = key_index;
     87     value_index_ = value_index;
     88 
     89     status_ = env->NewRandomAccessFile(filename_, &file_);
     90     if (!status_.ok()) return status_;
     91 
     92     input_buffer_.reset(new io::InputBuffer(file_.get(), kInputBufferSize));
     93     valid_ = true;
     94     next_id_ = 0;
     95     ignore_split_ = std::max(key_index_, value_index_) < 0;
     96     Next();
     97     return status_;
     98   }
     99 
    100   void Next() override {
    101     if (!valid_) return;
    102 
    103     string line;
    104     status_ = input_buffer_->ReadLine(&line);
    105     if (!status_.ok()) {
    106       if (errors::IsOutOfRange(status_) && next_id_ != vocab_size_) {
    107         status_ = errors::InvalidArgument("Invalid vocab_size in ", filename_,
    108                                           ": expected ", vocab_size_,
    109                                           " but got ", next_id_);
    110       }
    111       valid_ = false;
    112       return;
    113     }
    114     if (next_id_ >= vocab_size_) {
    115       LOG(WARNING) << "Truncated " << filename_ << " before its end at "
    116                    << vocab_size_ << " records.";
    117       LOG(WARNING) << "next_id_  : " << next_id_;
    118       status_ = errors::OutOfRange("Finished reading ", vocab_size_,
    119                                    " of lines from ", filename_);
    120       valid_ = false;
    121       return;
    122     }
    123     if (line.empty()) {
    124       status_ = errors::InvalidArgument("Invalid content in ", filename_,
    125                                         ": empty line found at position ",
    126                                         input_buffer_->Tell(), ".");
    127       valid_ = false;
    128       return;
    129     }
    130 
    131     std::vector<string> tokens;
    132     if (!ignore_split_) {
    133       tokens = str_util::Split(line, delimiter_);
    134       if (std::max(key_index_, value_index_) >= tokens.size()) {
    135         status_ = errors::InvalidArgument(
    136             "Invalid number of columns in ", filename_, " line ", next_id_,
    137             " (", line, ") : expected ", std::max(key_index_, value_index_),
    138             " got ", tokens.size());
    139         valid_ = false;
    140         return;
    141       }
    142     }
    143     status_ = SetValue(line, tokens, key_index_, &key_);
    144     if (!status_.ok()) {
    145       valid_ = false;
    146       return;
    147     }
    148     status_ = SetValue(line, tokens, value_index_, &value_);
    149     if (!status_.ok()) {
    150       valid_ = false;
    151       return;
    152     }
    153 
    154     next_id_++;
    155   }
    156 
    157   bool Valid() const override { return valid_; }
    158 
    159   const Tensor& keys() const override { return key_; }
    160 
    161   const Tensor& values() const override { return value_; }
    162 
    163   Status status() const override { return status_; }
    164 
    165   int64 total_size() const override { return vocab_size_; }
    166 
    167  private:
    168   Tensor key_;
    169   Tensor value_;
    170   bool valid_;  // true if the iterator points to an existing range.
    171   int64 key_index_;
    172   int64 value_index_;
    173   int64 next_id_;
    174   int64 vocab_size_;
    175   string filename_;
    176   char delimiter_;
    177   Status status_;
    178   bool ignore_split_;
    179   std::unique_ptr<RandomAccessFile> file_;  // must outlive input_buffer_
    180   std::unique_ptr<io::InputBuffer> input_buffer_;
    181 
    182   // Set the corresponding value from line or tokens based on 'index' into the
    183   // tensor 't'. The value is transformed to the given data type 'dtype'.
    184   Status SetValue(const string& line, const std::vector<string>& tokens,
    185                   int64 index, Tensor* tensor) {
    186     if (index == kLineNumber) {
    187       tensor->flat<int64>()(0) = next_id_;
    188       return Status::OK();
    189     }
    190     const string& token = (index == kWholeLine) ? line : tokens[index];
    191     const DataType& dtype = tensor->dtype();
    192     switch (dtype) {
    193       case DT_INT32: {
    194         int32 value;
    195         if (!strings::safe_strto32(token.c_str(), &value)) {
    196           valid_ = false;
    197           return errors::InvalidArgument("Field ", token, " in line ", next_id_,
    198                                          " is not a valid int32.");
    199         }
    200         tensor->flat<int32>()(0) = value;
    201       } break;
    202       case DT_INT64: {
    203         int64 value;
    204         if (!strings::safe_strto64(token.c_str(), &value)) {
    205           valid_ = false;
    206           return errors::InvalidArgument("Field ", token, " in line ", next_id_,
    207                                          " is not a valid int64.");
    208         }
    209         tensor->flat<int64>()(0) = value;
    210       } break;
    211       case DT_FLOAT: {
    212         float value;
    213         if (!strings::safe_strtof(token.c_str(), &value)) {
    214           valid_ = false;
    215           return errors::InvalidArgument("Field ", token, " in line ", next_id_,
    216                                          " is not a valid float.");
    217         }
    218         tensor->flat<float>()(0) = value;
    219       } break;
    220       case DT_DOUBLE: {
    221         double value;
    222         if (!strings::safe_strtod(token.c_str(), &value)) {
    223           valid_ = false;
    224           return errors::InvalidArgument("Field ", token, " in line ", next_id_,
    225                                          " is not a valid double.");
    226         }
    227         tensor->flat<double>()(0) = value;
    228       } break;
    229       case DT_STRING:
    230         tensor->flat<string>()(0) = token;
    231         break;
    232       default:
    233         valid_ = false;
    234         return errors::InvalidArgument("Data type ", dtype, " not supported.");
    235     }
    236     return Status::OK();
    237   }
    238 
    239   TF_DISALLOW_COPY_AND_ASSIGN(TextFileLineIterator);
    240 };
    241 
    242 Status GetTableHandle(const string& input_name, OpKernelContext* ctx,
    243                       string* container, string* table_handle) {
    244   {
    245     mutex* mu;
    246     TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
    247     mutex_lock l(*mu);
    248     Tensor tensor;
    249     TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
    250     if (tensor.NumElements() != 2) {
    251       return errors::InvalidArgument(
    252           "Lookup table handle must be scalar, but had shape: ",
    253           tensor.shape().DebugString());
    254     }
    255     auto h = tensor.flat<string>();
    256     *container = h(0);
    257     *table_handle = h(1);
    258   }
    259   return Status::OK();
    260 }
    261 
    262 }  // namespace
    263 
    264 Status GetLookupTable(const string& input_name, OpKernelContext* ctx,
    265                       LookupInterface** table) {
    266   string container;
    267   string table_handle;
    268   DataType handle_dtype;
    269   TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype));
    270   if (handle_dtype == DT_RESOURCE) {
    271     ResourceHandle handle;
    272     TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle));
    273     return LookupResource(ctx, handle, table);
    274   } else {
    275     TF_RETURN_IF_ERROR(
    276         GetTableHandle(input_name, ctx, &container, &table_handle));
    277     return ctx->resource_manager()->Lookup(container, table_handle, table);
    278   }
    279 }
    280 
    281 Status GetInitializableLookupTable(const string& input_name,
    282                                    OpKernelContext* ctx,
    283                                    InitializableLookupTable** table) {
    284   LookupInterface* lookup_table;
    285   DataType handle_dtype;
    286   TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype));
    287   if (handle_dtype == DT_RESOURCE) {
    288     ResourceHandle handle;
    289     TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle));
    290     TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &lookup_table));
    291     *table = lookup_table->GetInitializableLookupTable();
    292     if (*table == nullptr) {
    293       lookup_table->Unref();
    294       return errors::InvalidArgument("Table ", handle.container(), " ",
    295                                      handle.name(), " is not initializable");
    296     }
    297   } else {
    298     string container;
    299     string table_handle;
    300     TF_RETURN_IF_ERROR(
    301         GetTableHandle(input_name, ctx, &container, &table_handle));
    302     TF_RETURN_IF_ERROR(ctx->resource_manager()->Lookup(container, table_handle,
    303                                                        &lookup_table));
    304     *table = lookup_table->GetInitializableLookupTable();
    305     if (*table == nullptr) {
    306       lookup_table->Unref();
    307       return errors::InvalidArgument("Table ", container, " ", table_handle,
    308                                      " is not initializable");
    309     }
    310   }
    311   return Status::OK();
    312 }
    313 
    314 Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype,
    315                            DataType value_dtype, const string& table_name) {
    316   if (table.key_dtype() != key_dtype || table.value_dtype() != value_dtype) {
    317     return errors::InvalidArgument(
    318         "Conflicting key/value dtypes ", key_dtype, "->", value_dtype, " with ",
    319         table.key_dtype(), "-", table.value_dtype(), " for table ", table_name);
    320   }
    321   return Status::OK();
    322 }
    323 
    324 // Helper function to initialize an InitializableLookupTable from a text file.
    325 Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
    326                                    char delimiter, int32 key_index,
    327                                    int32 value_index, Env* env,
    328                                    InitializableLookupTable* table) {
    329   if (key_index == kLineNumber && table->key_dtype() != DT_INT64) {
    330     return errors::InvalidArgument(
    331         "Key index for line number requires table key dtype of int64, got ",
    332         table->key_dtype());
    333   }
    334   const DataType& key_dtype = table->key_dtype();
    335   const DataType& value_dtype = table->value_dtype();
    336   if (key_index == kWholeLine && !DataTypeIsInteger(key_dtype) &&
    337       key_dtype != DT_STRING) {
    338     return errors::InvalidArgument(
    339         "Key index for whole line requires string or integer table key, got ",
    340         table->key_dtype());
    341   }
    342   if (value_index == kLineNumber && value_dtype != DT_INT64) {
    343     return errors::InvalidArgument(
    344         "Value index for line number requires table value dtype of int64, got ",
    345         table->value_dtype());
    346   }
    347   if (value_index == kWholeLine && value_dtype != DT_STRING) {
    348     return errors::InvalidArgument(
    349         "Value index for whole line requires table value dtype of string, got ",
    350         table->value_dtype());
    351   }
    352 
    353   TextFileLineIterator iter;
    354   TF_RETURN_IF_ERROR(iter.Init(filename, vocab_size, delimiter, key_dtype,
    355                                key_index, value_dtype, value_index, env));
    356   // For initialization from files, ignore if the table is already
    357   // initialized. The table shared name should contain the filename to
    358   // avoid trying to initialize the same table from the same file at the same
    359   // time.
    360   Status s = table->Initialize(iter);
    361   if (errors::IsFailedPrecondition(s) && table->is_initialized()) {
    362     LOG(INFO) << "Table trying to initialize from file " << filename
    363               << " is already initialized.";
    364     return Status::OK();
    365   }
    366   return s;
    367 }
    368 
    369 }  // namespace lookup
    370 }  // namespace tensorflow
    371