Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_
     17 #define TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_
     18 
     19 #include <memory>
     20 #include <string>
     21 #include "tensorflow/core/framework/queue_interface.h"
     22 #include "tensorflow/core/framework/reader_interface.h"
     23 #include "tensorflow/core/lib/core/stringpiece.h"
     24 
     25 namespace tensorflow {
     26 
     27 class ReaderBaseState;  // Used below only by pointer/reference.
     28 
     29 // Default implementation of ReaderInterface.  Descendants supply ReadLocked().
     30 class ReaderBase : public ReaderInterface {
     31  public:
     32   // name: For use in error messages, should mention both the name of
     33   // the op and the node.
     34   explicit ReaderBase(const string& name);
     35 
     36   // Note that methods with names ending in "Locked" are called while
     37   // the ReaderBase's mutex is held.
     38 
     39   // Implement this function in descendants -----------------------------------
     40 
     41   // Produce the next key/value pair from the current work item.
     42   // This is called "Locked" since it is executed under a mutex
     43   // that serializes all Reader calls.
     44   // Usage:
     45   //  a) If a record was successfully produced, set *produced = true,
     46   //     and fill in *key and *value.
     47   //  b) If no more records will be produced for this work item, set
     48   //     *at_end = true.
     49   //  c) If a record was produced, but no more will be produced, you
     50   //     may either do both (a) and (b), or do (a) in this call and do (b) in
     51   //     the next call to ReadLocked().
     52   //  d) If producing a record failed (e.g. an error reading the file,
     53   //     data corruption), return a non-OK status.  ReadLocked() may be
     54   //     called again if the user reruns this part of the graph.
     55   virtual Status ReadLocked(string* key, string* value, bool* produced,
     56                             bool* at_end) = 0;
     57 
     58   // Descendants may optionally implement these -------------------------------
     59 
     60   // Produce up to num_records next key/value pairs from the current
     61   // work item, in the same manner as ReadLocked().
     62   virtual Status ReadUpToLocked(int64 num_records, std::vector<string>* keys,
     63                                 std::vector<string>* values, int64* num_read,
     64                                 bool* at_end);
     65 
     66   // Hooks called when work starts / finishes.  The defaults just return OK.
     67   virtual Status OnWorkStartedLocked() { return Status::OK(); }
     68   virtual Status OnWorkFinishedLocked() { return Status::OK(); }
     69 
     70   // Called to reset the Reader to a newly constructed state.
     71   virtual Status ResetLocked();
     72 
     73   // The default implementations generate an Unimplemented error.
     74   // See the protected helper methods below.
     75   virtual Status SerializeStateLocked(string* state);
     76   virtual Status RestoreStateLocked(const string& state);
     77 
     78   // Accessors ----------------------------------------------------------------
     79 
     80   // True iff started work is unfinished; always true during ReadLocked().
     81   bool work_in_progress() const { return work_finished_ < work_started_; }
     82 
     83   // Returns the name of the current work item (valid if
     84   // work_in_progress() returns true).  May change between calls to
     85   // ReadLocked().
     86   const string& current_work() const { return work_; }
     87 
     88   // The name that was passed to the constructor.
     89   const string& name() const { return name_; }
     90 
     91   // Produce the key name (from current_work and the actual key).
     92   string KeyName(const string& key) const;
     93 
     94  protected:
     95   // For descendants wishing to implement serialize & restore state.
     96 
     97   // Writes ReaderBase state to *state.
     98   void SaveBaseState(ReaderBaseState* state) const;
     99 
    100   // Restores ReaderBase state from state. Assumes state was filled
    101   // using SaveBaseState() above.
    102   Status RestoreBaseState(const ReaderBaseState& state);
    103 
    104  private:
    105   // Descendants may override this (private virtual) to obtain the next work
    106   // item in a different way.  Used to implement Read(): dequeues the next work
    107   // item from *queue and, if successful, returns it as a string.  May block.
    108   virtual string GetNextWorkLocked(QueueInterface* queue,
    109                                    OpKernelContext* context) const;
    110 
    111   // Implementations of ReaderInterface methods.  These ensure thread-safety
    112   // and call the methods above to do the work.
    113   void Read(QueueInterface* queue, string* key, string* value,
    114             OpKernelContext* context) override;
    115 
    116   // Produces up to num_records records.
    117   // In this implementation all the records come from the same work unit.
    118   int64 ReadUpTo(const int64 num_records, QueueInterface* queue,
    119                  std::vector<string>* keys, std::vector<string>* value,
    120                  OpKernelContext* context) override;
    121 
    122   Status Reset() override;
    123   int64 NumRecordsProduced() override;
    124   int64 NumWorkUnitsCompleted() override;
    125   Status SerializeState(string* state) override;
    126   Status RestoreState(const string& state) override;
    127 
    128   mutable mutex mu_;  // The ReaderBase's mutex; see the "Locked" note above.
    129   const string name_;  // Returned by name().
    130   int64 work_started_ = 0;   // Compared by work_in_progress().
    131   int64 work_finished_ = 0;  // Compared by work_in_progress().
    132   int64 num_records_produced_ = 0;  // Presumably what NumRecordsProduced() reports -- TODO confirm.
    133   string work_;  // Returned by current_work().
    134 };
    135 
    136 }  // namespace tensorflow
    137 
    138 #endif  // TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_
    139