Home | History | Annotate | Download | only in data
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <deque>
     17 #include <vector>
     18 
     19 #include "tensorflow/core/framework/partial_tensor_shape.h"
     20 #include "tensorflow/core/framework/tensor.h"
     21 #include "tensorflow/core/kernels/data/dataset.h"
     22 #include "tensorflow/core/lib/random/philox_random.h"
     23 #include "tensorflow/core/lib/random/random.h"
     24 #include "tensorflow/core/lib/random/random_distributions.h"
     25 
     26 namespace tensorflow {
     27 
     28 namespace {
     29 
// Interval between progress log messages while the shuffle buffer fills.
const int64 kLogIntervalMicros = 10 * 1000000;  // 10 seconds.
     31 
     32 // See documentation in ../ops/dataset_ops.cc for a high-level
     33 // description of the following op.
     34 
// Common base kernel for the shuffle ops. Concrete subclasses only differ in
// how they parse their inputs and which `ShuffleDatasetBase` subclass they
// construct.
class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
 public:
  explicit ShuffleDatasetOpBase(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {}

 protected:
  // Abstract base dataset that implements a shuffling iterator.
  class ShuffleDatasetBase : public GraphDatasetBase {
   public:
    // Takes a reference on `input` (released in the destructor).
    // `buffer_size` is the number of elements held in the shuffle buffer;
    // `count` is the number of epochs to iterate over the input, with -1
    // meaning "repeat indefinitely".
    ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                       int64 buffer_size, int64 count)
        : GraphDatasetBase(ctx),
          input_(input),
          buffer_size_(buffer_size),
          count_(count) {
      input_->Ref();
    }

    ~ShuffleDatasetBase() override { input_->Unref(); }

    // Shuffling does not change element structure, so output types and
    // shapes are forwarded from the input dataset.
    const DataTypeVector& output_dtypes() const override {
      return input_->output_dtypes();
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return input_->output_shapes();
    }

   protected:
    // Iterator that fills a fixed-size buffer from the input iterator and
    // emits buffered elements uniformly at random, using a Philox generator
    // seeded with (`seed`, `seed2`).
    class Iterator : public DatasetIterator<ShuffleDatasetBase> {
     public:
      explicit Iterator(const Params& params, int64 seed, int64 seed2)
          : DatasetIterator<ShuffleDatasetBase>(params),
            input_impl_(nullptr),
            seed_(seed),
            seed2_(seed2),
            epoch_(0),
            num_elements_(0),
            parent_generator_(seed, seed2),
            generator_(&parent_generator_) {
        // `buffer_` is treated as a circular buffer of `buffer_size_`
        // element vectors; all accesses below index it modulo
        // `buffer_size_`.
        buffer_.reset(new std::vector<Tensor>[params.dataset->buffer_size_]);
        // Start with one empty slice representing the first epoch.
        slices_.emplace_back(new Slice{0, 0});
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        int64 start_micros = ctx->env()->NowMicros();
        int64 num_log_entries = 0;
        bool first_call = false;
        if (!input_impl_ && epoch_ == 0) {
          // Lazily create the input iterator on the first call.
          first_call = true;
          input_impl_ = dataset()->input_->MakeIterator(prefix());
        }
        // Phase 1: fill the shuffle buffer until it is full or the input
        // (across all requested epochs) is exhausted.
        while (input_impl_ && num_elements_ < dataset()->buffer_size_) {
          // Filling a large buffer can take a long time; log progress at
          // most once per kLogIntervalMicros.
          if (ctx->env()->NowMicros() >
              ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
            num_log_entries++;
            LOG(INFO) << "Filling up shuffle buffer (this may take a while): "
                      << num_elements_ << " of " << dataset()->buffer_size_;
          }
          std::vector<Tensor> input_element;
          bool end_of_input_sequence = false;
          // Fetch one element, restarting the input iterator at each epoch
          // boundary until `count_` epochs are done (-1 repeats forever).
          while (dataset()->count_ == -1 || epoch_ < dataset()->count_) {
            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
                                                    &end_of_input_sequence));
            if (!end_of_input_sequence) {
              first_call = false;
              break;
            }
            if (first_call && dataset()->count_ == -1) {
              // If the first call to GetNext() fails because the end
              // of sequence has been reached, we terminate the
              // iteration immediately. (Otherwise, this iterator
              // would loop infinitely and never produce a value.)
              *end_of_sequence = true;
              return Status::OK();
            }
            // Epoch boundary: open a new (empty) slice so that elements
            // from different epochs are not shuffled together, then
            // restart the input iterator.
            epoch_++;
            int64 n = slices_.back()->end;
            slices_.emplace_back(new Slice{n, n});
            input_impl_ = dataset()->input_->MakeIterator(prefix());
          }
          if (!end_of_input_sequence) {
            // Append the fetched element to the newest slice.
            buffer_[slices_.back()->end % dataset()->buffer_size_] =
                std::move(input_element);
            num_elements_++;
            slices_.back()->end++;
          } else {
            // All epochs exhausted; release the input iterator so the
            // remaining buffered elements can drain.
            input_impl_.reset();
          }
        }
        if (num_log_entries > 0) {
          LOG(INFO) << "Shuffle buffer filled.";
        }

        // Phase 2: produce one element, sampled uniformly from the oldest
        // non-empty slice.
        if (num_elements_ > 0) {
          *end_of_sequence = false;
          // Garbage collect all empty slices.
          while (!slices_.empty() &&
                 slices_.front()->start == slices_.front()->end) {
            slices_.pop_front();
          }
          DCHECK(!slices_.empty());
          // Choose an element to produce uniformly at random from the first
          // slice, and then remove the element from the slice.
          int64 offset =
              Random() % (slices_.front()->end - slices_.front()->start);
          int64 index =
              (slices_.front()->start + offset) % dataset()->buffer_size_;
          *out_tensors = std::move(buffer_[index]);
          // Move the slice's first element into the vacated position so the
          // slice stays contiguous, then shrink the slice from the front.
          std::swap(buffer_[index],
                    buffer_[slices_.front()->start % dataset()->buffer_size_]);
          slices_.front()->start++;
          num_elements_--;
        } else {
          DCHECK(input_impl_ == nullptr);
          *end_of_sequence = true;
        }
        return Status::OK();
      }

     protected:
      // Checkpoints the RNG sample count, the input iterator (or an
      // end-of-input marker), the epoch counter, and the live contents of
      // every slice in `buffer_`.
      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);

        // Save state needed to restore the random number generators.
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"),
                                               num_random_samples_));

        // Save input iterator if it hasn't been exhausted else write
        // "end_of_input_sequence".
        if (!input_impl_) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("end_of_input_sequence"), ""));
        } else {
          TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
        }

        // Save the epoch counter, buffer, and buffer slices.
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("epoch"), epoch_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("num_elements"), num_elements_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("slices_size"), slices_.size()));
        for (size_t i = 0; i < slices_.size(); ++i) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat("slices_start_", i)),
              slices_[i]->start));
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat("slices_end_", i)), slices_[i]->end));
          // Persist each live buffer entry, keyed by its circular index.
          for (size_t j = slices_[i]->start; j < slices_[i]->end; ++j) {
            size_t index = j % dataset()->buffer_size_;
            TF_RETURN_IF_ERROR(writer->WriteScalar(
                full_name(strings::StrCat("buffer_", index, "_size")),
                buffer_[index].size()));
            for (size_t k = 0; k < buffer_[index].size(); ++k) {
              TF_RETURN_IF_ERROR(writer->WriteTensor(
                  full_name(strings::StrCat("buffer_", index, "_", k)),
                  buffer_[index][k]));
            }
          }
        }

        return Status::OK();
      }

      // Mirror image of SaveInternal(): restores RNG state, the input
      // iterator (if it was not exhausted), and rebuilds `buffer_` and
      // `slices_` from the checkpoint.
      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);

        // Restore the random number generators.
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"),
                                              &num_random_samples_));
        ResetRngs();

        // Restore the input iterator if it wasn't already exhausted.
        if (!reader->Contains(full_name("end_of_input_sequence"))) {
          input_impl_ = dataset()->input_->MakeIterator(prefix());
          TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
        } else {
          input_impl_.reset();
        }

        // Restore the epoch counter, buffer, and buffer slices.
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("epoch"), &epoch_));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name("num_elements"), &num_elements_));
        size_t slices_size;
        {
          int64 temp;
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(full_name("slices_size"), &temp));
          slices_size = static_cast<size_t>(temp);
        }
        buffer_.reset(new std::vector<Tensor>[dataset()->buffer_size_]);
        for (size_t i = 0; i < slices_size; ++i) {
          int64 start;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              full_name(strings::StrCat("slices_start_", i)), &start));
          int64 end;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              full_name(strings::StrCat("slices_end_", i)), &end));
          slices_.emplace_back(new Slice{start, end});
          for (size_t j = start; j < end; ++j) {
            size_t index = j % dataset()->buffer_size_;
            int64 list_size;
            TF_RETURN_IF_ERROR(reader->ReadScalar(
                full_name(strings::StrCat("buffer_", index, "_size")),
                &list_size));
            buffer_[index] = std::vector<Tensor>(list_size);
            for (int k = 0; k < list_size; ++k) {
              TF_RETURN_IF_ERROR(reader->ReadTensor(
                  full_name(strings::StrCat("buffer_", index, "_", k)),
                  &buffer_[index][k]));
            }
          }
        }

        return Status::OK();
      }

     private:
      // Used to represent slices of `buffer_` that belong to different epochs.
      // The invariant maintained by the implementation is: `start` <= `end`.
      // When using `start` and `end` to index into `buffer_`, their values
      // should be taken modulo the size of `buffer_` as their absolute value
      // can be greater than the range of `buffer_`.
      struct Slice {
        Slice(int64 start, int64 end) : start(start), end(end) {}

        int64 start;
        int64 end;
      };

      // Draws one sample from the generator, counting the draw so that
      // ResetRngs() can skip the generator forward after a restore.
      random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        num_random_samples_++;
        auto out = generator_();
        return out;
      }

      // Re-seeds the generators and replays `num_random_samples_` draws so
      // the RNG state matches the point at which a checkpoint was taken.
      void ResetRngs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        // Reset the generators based on the current iterator seeds.
        parent_generator_ = random::PhiloxRandom(seed_, seed2_);
        generator_ = random::SingleSampleAdapter<random::PhiloxRandom>(
            &parent_generator_);
        generator_.Skip(num_random_samples_);
      }

      mutex mu_;
      // Circular buffer of `buffer_size_` element vectors (see constructor).
      std::unique_ptr<std::vector<Tensor>[]> buffer_ GUARDED_BY(mu_);
      // Null once all requested epochs have been exhausted.
      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
      const int64 seed_ GUARDED_BY(mu_);
      const int64 seed2_ GUARDED_BY(mu_);
      int64 epoch_ GUARDED_BY(mu_);
      int64 num_elements_ GUARDED_BY(mu_);
      // Oldest slice at the front; elements are only produced from it.
      std::deque<std::unique_ptr<Slice>> slices_ GUARDED_BY(mu_);
      random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
      random::SingleSampleAdapter<random::PhiloxRandom> generator_
          GUARDED_BY(mu_);
      // Total draws from `generator_`; checkpointed to restore RNG state.
      int64 num_random_samples_ GUARDED_BY(mu_) = 0;
    };

    const DatasetBase* const input_;
    const int64 buffer_size_;
    const int64 count_;
  };
};
    305 
    306 class ShuffleDatasetOp : public ShuffleDatasetOpBase {
    307  public:
    308   explicit ShuffleDatasetOp(OpKernelConstruction* ctx)
    309       : ShuffleDatasetOpBase(ctx) {
    310     OP_REQUIRES_OK(ctx, ctx->GetAttr("reshuffle_each_iteration",
    311                                      &reshuffle_each_iteration_));
    312   }
    313 
    314   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
    315                    DatasetBase** output) override {
    316     int64 buffer_size;
    317     OP_REQUIRES_OK(
    318         ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
    319     OP_REQUIRES(
    320         ctx, buffer_size > 0,
    321         errors::InvalidArgument("buffer_size must be greater than zero."));
    322 
    323     int64 seed;
    324     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed));
    325 
    326     int64 seed2;
    327     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2));
    328 
    329     // By TensorFlow convention, passing 0 for both seeds indicates
    330     // that the shuffling should be seeded non-deterministically.
    331     if (seed == 0 && seed2 == 0) {
    332       seed = random::New64();
    333       seed2 = random::New64();
    334     }
    335 
    336     int64 count = 1;
    337     if (reshuffle_each_iteration_) {
    338       *output =
    339           new ReshufflingDataset(ctx, input, buffer_size, seed, seed2, count);
    340     } else {
    341       *output =
    342           new FixedSeedDataset(ctx, input, buffer_size, seed, seed2, count);
    343     }
    344   }
    345 
    346  private:
    347   // A dataset that uses a pseduorandom sequence of seeds for the iterators
    348   // created from it. Used when `reshuffle_each_iteration` is true.
    349   class ReshufflingDataset : public ShuffleDatasetBase {
    350    public:
    351     ReshufflingDataset(OpKernelContext* ctx, const DatasetBase* input,
    352                        int64 buffer_size, int64 seed, int64 seed2, int64 count)
    353         : ShuffleDatasetBase(ctx, input, buffer_size, count),
    354           seed_(seed),
    355           seed2_(seed2),
    356           parent_generator_(seed, seed2),
    357           generator_(&parent_generator_) {}
    358 
    359     string DebugString() override {
    360       return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
    361                              ", ", seed2_, ")::ReshufflingDataset");
    362     }
    363 
    364     std::unique_ptr<IteratorBase> MakeIterator(
    365         const string& prefix) const override {
    366       int64 iterator_seed;
    367       int64 iterator_seed2;
    368       {
    369         mutex_lock l(mu_);
    370         iterator_seed = generator_();
    371         iterator_seed2 = generator_();
    372       }
    373       return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
    374           {this, strings::StrCat(prefix, "::Shuffle")}, iterator_seed,
    375           iterator_seed2));
    376     }
    377 
    378    private:
    379     const int64 seed_;
    380     const int64 seed2_;
    381     mutable mutex mu_;
    382     mutable random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
    383     mutable random::SingleSampleAdapter<random::PhiloxRandom> generator_
    384         GUARDED_BY(mu_);
    385   };
    386 
    387   // A dataset that uses the same fixed seed for all iterators created from it.
    388   // Used when `reshuffle_each_iteration` is false.
    389   class FixedSeedDataset : public ShuffleDatasetBase {
    390    public:
    391     FixedSeedDataset(OpKernelContext* ctx, const DatasetBase* input,
    392                      int64 buffer_size, int64 seed, int64 seed2, int64 count)
    393         : ShuffleDatasetBase(ctx, input, buffer_size, count),
    394           seed_(seed),
    395           seed2_(seed) {}
    396 
    397     string DebugString() override {
    398       return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
    399                              ", ", seed2_, ")::FixedSeedDataset");
    400     }
    401 
    402     std::unique_ptr<IteratorBase> MakeIterator(
    403         const string& prefix) const override {
    404       return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
    405           {this, strings::StrCat(prefix, "::Shuffle")}, seed_, seed2_));
    406     }
    407 
    408    protected:
    409     Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
    410                               Node** output) const override {
    411       Node* input_graph_node = nullptr;
    412       TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
    413       Node* buffer_size = nullptr;
    414       Node* seed = nullptr;
    415       Node* seed2 = nullptr;
    416       AttrValue reshuffle_each_iteration;
    417 
    418       TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
    419       TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
    420       TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
    421       b->BuildAttrValue(false, &reshuffle_each_iteration);
    422       TF_RETURN_IF_ERROR(b->AddDataset(
    423           this, {input_graph_node, buffer_size, seed, seed2},  // Inputs
    424           {std::make_pair("reshuffle_each_iteration",
    425                           reshuffle_each_iteration)},  // Attrs
    426           output));
    427       return Status::OK();
    428     }
    429 
    430    private:
    431     const int64 seed_;
    432     const int64 seed2_;
    433   };
    434 
    435   bool reshuffle_each_iteration_;
    436 };
    437 
    438 class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
    439  public:
    440   explicit ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx)
    441       : ShuffleDatasetOpBase(ctx) {}
    442 
    443   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
    444                    DatasetBase** output) override {
    445     int64 buffer_size;
    446     OP_REQUIRES_OK(
    447         ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
    448     OP_REQUIRES(
    449         ctx, buffer_size > 0,
    450         errors::InvalidArgument("buffer_size must be greater than zero."));
    451 
    452     int64 seed;
    453     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed));
    454 
    455     int64 seed2;
    456     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2));
    457 
    458     int64 count;
    459     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
    460 
    461     // By TensorFlow convention, if both seeds are 0, then shuffling should be
    462     // seeded non-deterministically.
    463     if (seed == 0 && seed2 == 0) {
    464       seed = random::New64();
    465       seed2 = random::New64();
    466     }
    467 
    468     *output = new Dataset(ctx, input, buffer_size, seed, seed2, count);
    469   }
    470 
    471  private:
    472   class Dataset : public ShuffleDatasetBase {
    473    public:
    474     Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
    475             int64 seed, int64 seed2, int64 count)
    476         : ShuffleDatasetBase(ctx, input, buffer_size, count),
    477           seed_(seed),
    478           seed2_(seed2) {}
    479 
    480     string DebugString() override {
    481       return strings::StrCat("ShuffleAndRepeatDatasetOp(", buffer_size_, ", ",
    482                              seed_, ", ", seed2_, ", ", count_, ")::Dataset");
    483     }
    484 
    485     std::unique_ptr<IteratorBase> MakeIterator(
    486         const string& prefix) const override {
    487       return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
    488           {this, strings::StrCat(prefix, "::ShuffleAndRepeat")}, seed_,
    489           seed2_));
    490     }
    491 
    492    protected:
    493     Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
    494                               Node** output) const override {
    495       Node* input_graph_node = nullptr;
    496       TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
    497       Node* buffer_size = nullptr;
    498       Node* seed = nullptr;
    499       Node* seed2 = nullptr;
    500       Node* count = nullptr;
    501 
    502       TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
    503       TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
    504       TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
    505       TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
    506       TF_RETURN_IF_ERROR(b->AddDataset(
    507           this, {input_graph_node, buffer_size, seed, seed2, count},  // Inputs
    508           {},                                                         // Attrs
    509           output));
    510       return Status::OK();
    511     }
    512 
    513    private:
    514     const int64 seed_;
    515     const int64 seed2_;
    516   };
    517 };
    518 
// Register the CPU kernels for the shuffle ops declared in
// ../ops/dataset_ops.cc.
REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU),
                        ShuffleDatasetOp);

REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
                        ShuffleAndRepeatDatasetOp);
    524 
    525 }  // namespace
    526 
    527 }  // namespace tensorflow
    528