Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_
     17 #define TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_
     18 
     19 #include "tensorflow/core/framework/lookup_interface.h"
     20 #include "tensorflow/core/framework/op_kernel.h"
     21 #include "tensorflow/core/kernels/initializable_lookup_table.h"
     22 
     23 namespace tensorflow {
     24 namespace lookup {
     25 
// Gets the LookupTable stored in ctx->resource_manager() under the key passed
// by the attribute named `input_name`, and writes it to `*table`.
//
// NOTE(review): the previous comment said this "returns null if the table
// doesn't exist", but the function returns a Status. Confirm in the
// implementation whether a missing table yields a non-OK Status, a null
// `*table`, or both.
Status GetLookupTable(const string& input_name, OpKernelContext* ctx,
                      LookupInterface** table);
     31 
// Gets the InitializableLookupTable stored in ctx->resource_manager() under
// the key passed by the attribute named `input_name`, and writes it to
// `*table`.
//
// NOTE(review): the previous comment said this "returns null if the table
// doesn't exist", but the function returns a Status. Confirm in the
// implementation whether a missing table yields a non-OK Status, a null
// `*table`, or both.
Status GetInitializableLookupTable(const string& input_name,
                                   OpKernelContext* ctx,
                                   InitializableLookupTable** table);
     38 
// Verifies that the given `key_dtype` and `value_dtype` match the
// corresponding data types of `table`.
//
// NOTE(review): `table_name` is presumably used only to label the error
// message on a mismatch, and the exact error code returned on mismatch is
// not visible here. Confirm both in the implementation.
Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype,
                           DataType value_dtype, const string& table_name);
     43 
// Initializes `table` from the contents of the text file `filename`.
//
// NOTE(review): the parameter semantics below are inferred from names only.
// Confirm them against the implementation:
//   - `delimiter` presumably separates the fields within each line.
//   - `key_index` / `value_index` presumably select which field supplies the
//     key and which supplies the value. Special (e.g. negative) values may
//     select something else, such as the whole line or the line number.
//     TODO confirm.
//   - `vocab_size` presumably limits or validates the number of entries read.
//     TODO confirm what a sentinel such as -1 means.
//   - `env` is presumably the Env used to open and read the file.
Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
                                   char delimiter, int32 key_index,
                                   int32 value_index, Env* env,
                                   InitializableLookupTable* table);
     48 
     49 // Iterator to initialize tables given 'keys' and 'values' tensors.
     50 //
     51 // The two tensors are returned in the first iteration. It doesn't loop
     52 // over each element of the tensor since insertions in the lookup table can
     53 // process batches.
     54 class KeyValueTensorIterator
     55     : public InitializableLookupTable::InitTableIterator {
     56  public:
     57   // keys and values are not owned by the iterator.
     58   explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values)
     59       : keys_(keys), values_(values), valid_(true), status_(Status::OK()) {
     60     TensorShape key_shape = keys_->shape();
     61     if (!key_shape.IsSameSize(values_->shape())) {
     62       valid_ = false;
     63       status_ = errors::InvalidArgument(
     64           "keys and values should have the same dimension.",
     65           key_shape.DebugString(), " vs ", values_->shape().DebugString());
     66     }
     67     if (key_shape.num_elements() == 0) {
     68       valid_ = false;
     69       status_ =
     70           errors::InvalidArgument("keys and values cannot be empty tensors.");
     71     }
     72   }
     73 
     74   bool Valid() const override { return valid_; }
     75 
     76   void Next() override {
     77     valid_ = false;
     78     status_ = errors::OutOfRange("No more data.");
     79   }
     80 
     81   const Tensor& keys() const override { return *keys_; }
     82 
     83   const Tensor& values() const override { return *values_; }
     84 
     85   Status status() const override { return status_; }
     86 
     87   int64 total_size() const override {
     88     return keys_ == nullptr ? -1 : keys_->NumElements();
     89   }
     90 
     91  private:
     92   TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator);
     93 
     94   const Tensor* keys_;    // Doesn't own it.
     95   const Tensor* values_;  // Doesn't own it.
     96   bool valid_;            // true if the iterator points to an existing range.
     97   Status status_;
     98 };
     99 
    100 }  // namespace lookup
    101 }  // namespace tensorflow
    102 
    103 #endif  // TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_
    104