Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
     17 #define TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
     18 
#include <atomic>

#include "tensorflow/core/framework/lookup_interface.h"
#include "tensorflow/core/platform/macros.h"
     21 
     22 namespace tensorflow {
     23 namespace lookup {
     24 
     25 // Base class for lookup tables that require initialization.
     26 class InitializableLookupTable : public LookupInterface {
     27  public:
     28   class InitTableIterator;
     29 
     30   // Performs batch lookups, for every element in the key tensor, Find returns
     31   // the corresponding value into the values tensor.
     32   // If an element is not present in the table, the given default value is used.
     33   //
     34   // For tables that require initialization, `Find` is available once the table
     35   // is marked as initialized.
     36   //
     37   // Returns the following statuses:
     38   // - OK: when the find finishes successfully.
     39   // - FailedPrecondition: if the table is not initialized.
     40   // - InvalidArgument: if any of the preconditions on the lookup key or value
     41   //   fails.
     42   // - In addition, other implementations may provide another non-OK status
     43   //   specific to their failure modes.
     44   Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values,
     45               const Tensor& default_value) final;
     46 
     47   // Returns errors::Unimplemented.
     48   Status Insert(OpKernelContext* ctx, const Tensor& keys,
     49                 const Tensor& values) final {
     50     return errors::Unimplemented(
     51         "Insert not supported by InitializableLookupTable implementations");
     52   }
     53 
     54   Status ExportValues(OpKernelContext* context) final {
     55     return errors::Unimplemented(
     56         "ExportValues not supported by InitializableLookupTable "
     57         "implementations");
     58   }
     59 
     60   Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
     61                       const Tensor& values) final {
     62     return errors::Unimplemented(
     63         "ImportValues not supported by InitializableLookupTable "
     64         "implementations");
     65   }
     66 
     67   TensorShape key_shape() const final { return TensorShape(); }
     68 
     69   TensorShape value_shape() const final { return TensorShape(); }
     70 
     71   // Returns whether the table was initialized and is ready to serve lookups.
     72   bool is_initialized() const { return is_initialized_; }
     73 
     74   // Initializes the table from the given init table iterator.
     75   //
     76   // Atomically, this operation prepares the table, populates it with the given
     77   // iterator, and mark the table as initialized.
     78   //
     79   // Returns the following statuses:
     80   // - OK: when the initialization was successful.
     81   // - InvalidArgument: if any of the preconditions on the lookup key or value
     82   //   fails.
     83   // - FailedPrecondition: if the table is already initialized and
     84   //   fail_if_initialized is set to true.
     85   // - In addition, other implementations may provide another non-OK status
     86   //   specific to their failure modes.
     87   Status Initialize(InitTableIterator& iter);
     88 
     89   // Basic iterator to initialize lookup tables.
     90   // It yields a sequence of pairs of `keys()` and `values()` Tensors, so that
     91   // the consumer may insert key-value pairs in batches.
     92   //
     93   // Then the iterator is exhausted, valid returns false and status returns
     94   // Status::OutOfRange.
     95   class InitTableIterator {
     96    public:
     97     InitTableIterator() {}
     98 
     99     virtual ~InitTableIterator() {}
    100 
    101     // Prepares the next batch of key and value tensors.
    102     virtual void Next() = 0;
    103 
    104     // Returns true if keys and values point to valid tensors.
    105     virtual bool Valid() const = 0;
    106 
    107     // Returns a tensor that contains the current batch of 'key' values.
    108     virtual const Tensor& keys() const = 0;
    109 
    110     // Returns a tensor that contains the current batch of 'value' values.
    111     virtual const Tensor& values() const = 0;
    112 
    113     // Returns an error if one has occurred, otherwise returns Status::OK.
    114     virtual Status status() const = 0;
    115 
    116     // Returns the total number of elements that the iterator will produce.
    117     virtual int64 total_size() const = 0;
    118 
    119    private:
    120     TF_DISALLOW_COPY_AND_ASSIGN(InitTableIterator);
    121   };
    122 
    123   InitializableLookupTable* GetInitializableLookupTable() override {
    124     return this;
    125   }
    126 
    127  protected:
    128   // Prepares and allocates the underlying data structure to store the given
    129   // number of expected elements.
    130   virtual Status DoPrepare(size_t expected_num_elements) = 0;
    131 
    132   // Populates the table in batches given keys and values as tensors into the
    133   // underlying data structure.
    134   virtual Status DoInsert(const Tensor& keys, const Tensor& values) = 0;
    135 
    136   // Performs the batch find operation on the underlying data structure.
    137   virtual Status DoFind(const Tensor& keys, Tensor* values,
    138                         const Tensor& default_value) = 0;
    139 
    140   mutex mu_;
    141   bool is_initialized_ = false;
    142 };
    143 
    144 }  // namespace lookup
    145 }  // namespace tensorflow
    146 
    147 #endif  // TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
    148