Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/kernels/record_yielder.h"
     17 
     18 #include "tensorflow/core/lib/io/record_reader.h"
     19 #include "tensorflow/core/lib/strings/str_util.h"
     20 #include "tensorflow/core/platform/env.h"
     21 
     22 namespace tensorflow {
     23 
     24 RecordYielder::RecordYielder(OpKernelConstruction* context,
     25                              const RecordYielder::Options& opts)
     26     : opts_(opts),
     27       thread_(new thread::ThreadPool(context->env(), ThreadOptions(),
     28                                      "record_yielder", 1 + opts.parallelism,
     29                                      /* low_latency_hint */ false)),
     30       epoch_(0),
     31       rnd_(opts.seed) {
     32   thread_->Schedule([this]() { MainLoop(); });
     33 }
     34 
     35 RecordYielder::~RecordYielder() {
     36   {
     37     mutex_lock l(mu_);
     38     stop_ = true;
     39     buf_empty_.notify_all();
     40     buf_enough_.notify_all();
     41     buf_not_full_.notify_all();
     42   }
     43   main_loop_done_.WaitForNotification();
     44   delete thread_;
     45 }
     46 
     47 Status RecordYielder::YieldOne(string* value) {
     48   mutex_lock l(mu_);
     49   while (!BufEnough() && status_.ok()) {
     50     buf_enough_.wait(l);
     51   }
     52   if (status_.ok()) {
     53     bool notify_no_longer_full = !BufNotFull();
     54     CHECK(!stop_ && !buf_.empty());
     55     *value = std::move(buf_.back());
     56     buf_.pop_back();
     57     ++num_records_yielded_in_epoch_;
     58     // Assumption is that an epoch always has something in the buffer
     59     // until it ends.  If the input pipeline was slower than the consumers
     60     // by a lot this might not be true.  Not sure how to handle.
     61     if (buf_.empty()) {
     62       buf_empty_.notify_all();
     63     }
     64     if (notify_no_longer_full) {
     65       buf_not_full_.notify_all();
     66     }
     67   }
     68   return status_;
     69 }
     70 
     71 struct RecordYielder::Shard {
     72   int index;                      // Shard index.
     73   std::vector<string> filenames;  // File names given to this shard.
     74   Notification done;              // Notified when this shard is done.
     75   Status status;                  // Shard status.
     76 };
     77 
     78 bool RecordYielder::ShouldFinish(const Status& s) {
     79   mutex_lock l(mu_);
     80   status_.Update(s);
     81   return stop_ || !status_.ok();
     82 }
     83 
     84 static Status MatchFiles(const string& patterns,
     85                          std::vector<string>* filenames) {
     86   for (const auto& file_pattern : str_util::Split(patterns, ',')) {
     87     std::vector<string> tmp_filenames;
     88     TF_RETURN_IF_ERROR(
     89         Env::Default()->GetMatchingPaths(file_pattern, &tmp_filenames));
     90     filenames->insert(filenames->end(),
     91                       std::make_move_iterator(tmp_filenames.begin()),
     92                       std::make_move_iterator(tmp_filenames.end()));
     93   }
     94   return Status::OK();
     95 }
     96 
     97 void RecordYielder::MainLoop() {
     98   while (true) {
     99     ++epoch_;
    100     num_records_yielded_in_epoch_ = 0;
    101     num_records_added_in_epoch_ = 0;
    102 
    103     // Finds all files.
    104     std::vector<string> filenames;
    105     Status s = MatchFiles(opts_.file_pattern, &filenames);
    106 
    107     if (filenames.empty()) {
    108       s = errors::NotFound("Found no files at ", opts_.file_pattern);
    109       if (ShouldFinish(s)) {
    110         buf_enough_.notify_all();
    111         break;
    112       }
    113     }
    114 
    115     if (ShouldFinish(s)) break;
    116 
    117     // Shuffles these files according to the epoch # and random seed.
    118     std::mt19937_64 shuffle_rnd(
    119         Hash64(reinterpret_cast<char*>(&epoch_), sizeof(epoch_), opts_.seed));
    120     std::shuffle(filenames.begin(), filenames.end(), shuffle_rnd);
    121 
    122     // Left-shift the filename list.
    123     const std::vector<string>::size_type num = filenames.size();
    124     int64 shift;
    125     if (0 <= opts_.file_shuffle_shift_ratio &&
    126         opts_.file_shuffle_shift_ratio < 1) {
    127       shift = opts_.file_shuffle_shift_ratio * num;
    128       std::rotate(filenames.begin(), filenames.begin() + shift,
    129                   filenames.end());
    130     }
    131 
    132     // Shards files and use one thread to go through each shard.
    133     const int N = opts_.parallelism;
    134     std::vector<Shard> shards(N);
    135     for (int i = 0; i < N; ++i) {
    136       Shard* shard = &shards[i];
    137       shard->index = i;
    138       for (std::vector<string>::size_type j = i; j < filenames.size(); j += N) {
    139         shard->filenames.push_back(filenames[j]);
    140       }
    141       thread_->Schedule([this, shard]() { ShardLoop(shard); });
    142     }
    143     for (int i = 0; i < N; ++i) {
    144       shards[i].done.WaitForNotification();
    145       s.Update(shards[i].status);
    146     }
    147 
    148     if (num_records_added_in_epoch_ < opts_.bufsize) {
    149       mutex_lock l(mu_);
    150       opts_.bufsize = num_records_added_in_epoch_;
    151     }
    152 
    153     if (ShouldFinish(s)) {
    154       buf_enough_.notify_all();
    155       break;
    156     }
    157 
    158     // Starts the next epoch once all buffered records are consumed.
    159     {
    160       mutex_lock l(mu_);
    161       epoch_end_ = true;
    162       if (BufEnough()) {
    163         buf_enough_.notify_all();
    164       }
    165       while (!BufEmpty()) {
    166         buf_empty_.wait(l);
    167       }
    168       epoch_end_ = false;
    169     }
    170   }
    171   main_loop_done_.Notify();
    172 }
    173 
    174 bool RecordYielder::Add(std::vector<string>* values) {
    175   mutex_lock l(mu_);
    176   while (!BufNotFull()) {
    177     buf_not_full_.wait(l);
    178   }
    179   while (BufNotFull() && !values->empty()) {
    180     // Adds values->back(). Swaps its position with another random
    181     // element.
    182     auto index = rnd_() % (buf_.size() + 1);
    183     if (index == buf_.size()) {
    184       buf_.push_back(std::move(values->back()));
    185     } else {
    186       buf_.push_back(std::move(buf_[index]));
    187       buf_[index] = std::move(values->back());
    188     }
    189     values->pop_back();
    190     num_records_added_in_epoch_++;
    191   }
    192   if (BufEnough()) {
    193     buf_enough_.notify_all();
    194   }
    195   return stop_;
    196 }
    197 
    198 void RecordYielder::ShardLoop(Shard* shard) {
    199   std::vector<string> values;
    200   const int64 kRecords = 16;
    201   for (const string& filename : shard->filenames) {
    202     std::unique_ptr<RandomAccessFile> file;
    203     if (ShouldFinish(Status::OK())) break;
    204     Status s = Env::Default()->NewRandomAccessFile(filename, &file);
    205     if (!s.ok()) {
    206       shard->status = errors::InvalidArgument("Can't open ", filename);
    207       break;
    208     }
    209     io::RecordReaderOptions options =
    210         io::RecordReaderOptions::CreateRecordReaderOptions(
    211             opts_.compression_type);
    212     io::RecordReader rdr(file.get(), options);
    213     uint64 offset = 0;
    214     string record;
    215     while (true) {
    216       Status s = rdr.ReadRecord(&offset, &record);
    217       if (s.ok()) {
    218         values.emplace_back(std::move(record));
    219         if (values.size() >= kRecords && Add(&values)) {
    220           shard->status = errors::Aborted("stopped");
    221           break;
    222         }
    223       } else if (errors::IsOutOfRange(s)) {
    224         break;
    225       } else {
    226         shard->status = s;
    227         break;
    228       }
    229     }
    230   }
    231   // Adds the remaining values of this shard to buf_.
    232   while (!values.empty()) {
    233     Add(&values);
    234   }
    235   shard->done.Notify();
    236 }
    237 
    238 }  // namespace tensorflow
    239