Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_RECORD_YIELDER_H_
     17 #define TENSORFLOW_KERNELS_RECORD_YIELDER_H_
     18 
#include <algorithm>
#include <atomic>
#include <random>
#include <string>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/thread_annotations.h"
     30 
     31 namespace tensorflow {
     32 
     33 // RecordYielder produces value records from a set of tfrecord files
     34 // in a random order.
     35 //
     36 // It guarantees that:
     37 //   1) all records in tfrecords are yielded within every epoch;
     38 //   2) each record is yielded only once within every epoch;
     39 //   3) the order in which records are yielded is highly randomized.
     40 //   4) the peak memory usage is roughly avg record size *
     41 //      (opts.bufsize + opts.parellelism * 16).
     42 //
     43 // Usage example:
     44 //   RecordYielder::Options opts;
     45 //   opts.file_pattern = "input-*";
     46 //   opts.seed = 301;
     47 //   opts.bufsize = 1000000;    // A randomized buffer with 1M records.
     48 //   opts.parallelism = 8;      // Uses 8 tfrecord iterators to iterate
     49 //                              // through all files.
     50 //   RecordYielder yielder(opts);
     51 //   string val;
     52 //   while (true) {
     53 //     yielder.YieldOne(&val);
     54 //     // process val
     55 //   }
     56 //
     57 // RecordYielder can be accessed by multiple threads concurrently.
     58 class RecordYielder {
     59  public:
     60   struct Options {
     61     // Glob pattern for tfrecords.
     62     string file_pattern;
     63 
     64     // Random seed. It determines how data files are shuffled and how
     65     // records are shuffled.
     66     int64 seed = 0;
     67 
     68     // Each epoch, all files are first shuffled according to the
     69     // random seed and the epoch number, and then all files are
     70     // left-shifted by file_shuffle_shift_ratio * num_files slots.  If
     71     // file_shuffle_shift_ratio is not within [0, 1), the
     72     // implementation clip it to [0, 1).
     73     float file_shuffle_shift_ratio = 0;
     74 
     75     // Randomization buffer keeps these many records.
     76     uint64 bufsize = 1;
     77 
     78     // Uses these many concurrent tfrecord iterators to iterate through
     79     // tfrecords.
     80     int32 parallelism = 1;
     81 
     82     string compression_type;
     83   };
     84 
     85   explicit RecordYielder(OpKernelConstruction* context,
     86                          const RecordYielder::Options& opts);
     87   ~RecordYielder();
     88 
     89   RecordYielder(const RecordYielder&) = delete;
     90   RecordYielder& operator=(const RecordYielder&) = delete;
     91 
     92   // Yields one 'value'.
     93   Status YieldOne(string* value);
     94 
     95   // Returns the current epoch number.
     96   int64 current_epoch() const { return epoch_; }
     97 
     98  private:
     99   typedef RecordYielder ME;
    100 
    101   Options opts_;
    102 
    103   // Backgrounds threads. Owned.
    104   thread::ThreadPool* thread_;
    105 
    106   // Epoch number.
    107   std::atomic<int64> epoch_;
    108 
    109   mutex mu_;
    110 
    111   // Turned to true when this is deleted.
    112   bool stop_ GUARDED_BY(mu_) = false;
    113   Status status_ GUARDED_BY(mu_);
    114 
    115   // PRG used for randomization.
    116   std::mt19937_64 rnd_ GUARDED_BY(mu_);
    117 
    118   // Randomization buffer.
    119   std::vector<string> buf_ GUARDED_BY(mu_);
    120 
    121   // True iff we are draining an epoch.
    122   bool epoch_end_ = false;
    123 
    124   int64 num_records_added_in_epoch_ = 0;
    125   int64 num_records_yielded_in_epoch_ = 0;
    126 
    127   // Trigger when the main loop has exited.
    128   Notification main_loop_done_;
    129 
    130   // condition_variables.
    131   condition_variable buf_empty_;
    132   bool BufEmpty() const SHARED_LOCKS_REQUIRED(mu_) {
    133     return stop_ || buf_.empty();
    134   }
    135 
    136   condition_variable buf_not_full_;
    137   bool BufNotFull() const SHARED_LOCKS_REQUIRED(mu_) {
    138     return stop_ || buf_.size() < opts_.bufsize;
    139   }
    140 
    141   condition_variable buf_enough_;
    142   bool BufEnough() const SHARED_LOCKS_REQUIRED(mu_) {
    143     // NOTE: Unless we are finishing an epoch, we want to make sure
    144     // the buf_ contains enough randomized elements before yielding
    145     // any.
    146     return stop_ || !status_.ok() || (epoch_end_ && !buf_.empty()) ||
    147            (!epoch_end_ &&
    148             buf_.size() >= std::max<uint64>(1, opts_.bufsize / 2));
    149   }
    150 
    151   void MainLoop();
    152   struct Shard;
    153   void ShardLoop(Shard* shard);
    154   bool ShouldFinish(const Status& s);
    155   bool Add(std::vector<string>* values);
    156 };
    157 
    158 }  // namespace tensorflow
    159 
    160 #endif  // TENSORFLOW_KERNELS_RECORD_YIELDER_H_
    161