     15 #include <atomic>
     16 #include <deque>
     17 #include <utility>
     19 #include "tensorflow/core/common_runtime/function.h"
     20 #include "tensorflow/core/framework/dataset.h"
     21 #include "tensorflow/core/framework/partial_tensor_shape.h"
     22 #include "tensorflow/core/framework/stats_aggregator.h"
     23 #include "tensorflow/core/framework/tensor.h"
     24 #include "tensorflow/core/kernels/data/captured_function.h"
     25 #include "tensorflow/core/kernels/data/dataset_utils.h"
     26 #include "tensorflow/core/lib/core/threadpool.h"
     27 #include "tensorflow/core/lib/gtl/cleanup.h"
     28 #include "tensorflow/core/lib/random/random.h"
     30 namespace tensorflow {
     31 namespace data {
     32 namespace {
     34 // See documentation in ../../ops/dataset_ops.cc for a high-level
     35 // description of the following op.
     37 class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
     38  public:
     39   explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
     40       : UnaryDatasetOpKernel(ctx) {
     41     OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
     42     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
     43     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
     44   }
     46   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
     47                    DatasetBase** output) override {
     48     int64 cycle_length = 0;
     49     OP_REQUIRES_OK(ctx,
     50                    ParseScalarArgument(ctx, "cycle_length", &cycle_length));
     51     OP_REQUIRES(ctx, cycle_length > 0,
     52                 errors::InvalidArgument("`cycle_length` must be > 0"));
     54     int64 block_length = 0;
     55     OP_REQUIRES_OK(ctx,
     56                    ParseScalarArgument(ctx, "block_length", &block_length));
     57     OP_REQUIRES(ctx, block_length > 0,
     58                 errors::InvalidArgument("`block_length` must be > 0"));
     60     bool sloppy = false;
     61     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "sloppy", &sloppy));
     63     int64 buffer_output_elements = 0;
     64     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "buffer_output_elements",
     65                                             &buffer_output_elements));
     66     OP_REQUIRES(
     67         ctx, buffer_output_elements > 0,
     68         errors::InvalidArgument("`buffer_output_elements` must be > 0"));
     70     int64 prefetch_input_elements = 0;
     71     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefetch_input_elements",
     72                                             &prefetch_input_elements));
     73     OP_REQUIRES(
     74         ctx, prefetch_input_elements >= 0,
     75         errors::InvalidArgument("`prefetch_input_elements` must be >= 0"));
     77     std::unique_ptr<CapturedFunction> captured_func;
     78     OP_REQUIRES_OK(
     79         ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
     80                                       &captured_func));
     82     *output =
     83         new Dataset(ctx, input, interleave_func_, std::move(captured_func),
     84                     cycle_length, block_length, sloppy, buffer_output_elements,
     85                     prefetch_input_elements, output_types_, output_shapes_);
     86   }
     88  private:
     89   class Dataset : public DatasetBase {
     90    public:
     91     Dataset(OpKernelContext* ctx, const DatasetBase* input,
     92             const NameAttrList& func,
     93             std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
     94             int64 block_length, bool sloppy, int64 buffer_output_elements,
     95             int64 prefetch_input_elements, const DataTypeVector& output_types,
     96             const std::vector<PartialTensorShape>& output_shapes)
     97         : DatasetBase(DatasetContext(ctx)),
     98           input_(input),
     99           interleave_func_(func),
    100           captured_func_(std::move(captured_func)),
    101           cycle_length_(cycle_length),
    102           block_length_(block_length),
    103           sloppy_(sloppy),
    104           buffer_output_elements_(buffer_output_elements),
    105           prefetch_input_elements_(prefetch_input_elements),
    106           output_types_(output_types),
    107           output_shapes_(output_shapes) {
    108       input_->Ref();
    109     }
    111     ~Dataset() override { input_->Unref(); }
    113     std::unique_ptr<IteratorBase> MakeIteratorInternal(
    114         const string& prefix) const override {
    115       return absl::make_unique<Iterator>(Iterator::Params{
    116           this, strings::StrCat(prefix, "::ParallelInterleave")});
    117     }
    119     const DataTypeVector& output_dtypes() const override {
    120       return output_types_;
    121     }
    123     const std::vector<PartialTensorShape>& output_shapes() const override {
    124       return output_shapes_;
    125     }
    127     string DebugString() const override {
    128       return "ParallelInterleaveDatasetOp::Dataset";
    129     }
    131    protected:
    132     Status AsGraphDefInternal(SerializationContext* ctx,
    133                               DatasetGraphDefBuilder* b,
    134                               Node** output) const override {
    135       TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
    136       Node* input_node;
    137       TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    138       Node* cycle_length_node;
    139       TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
    140       Node* block_length_node;
    141       TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
    142       Node* sloppy_node;
    143       TF_RETURN_IF_ERROR(b->AddScalar(sloppy_, &sloppy_node));
    144       Node* buffer_output_elements_node;
    145       TF_RETURN_IF_ERROR(
    146           b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
    147       Node* prefetch_input_elements_node;
    148       TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_,
    149                                       &prefetch_input_elements_node));
    150       DataTypeVector other_arguments_types;
    151       other_arguments_types.reserve(captured_func_->captured_inputs().size());
    152       std::vector<Node*> other_arguments;
    153       other_arguments.reserve(captured_func_->captured_inputs().size());
    154       for (const Tensor& t : captured_func_->captured_inputs()) {
    155         Node* node;
    156         DatasetBase* input;
    157         Status s = GetDatasetFromVariantTensor(t, &input);
    158         if (s.ok()) {
    159           TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
    160         } else {
    161           TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
    162         }
    163         other_arguments.emplace_back(node);
    164         other_arguments_types.emplace_back(t.dtype());
    165       }
    166       AttrValue f;
    167       b->BuildAttrValue(interleave_func_, &f);
    168       AttrValue other_arguments_types_attr;
    169       b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
    171       TF_RETURN_IF_ERROR(b->AddDataset(
    172           this,
    173           {{0, input_node},
    174            {2, cycle_length_node},
    175            {3, block_length_node},
    176            {4, sloppy_node},
    177            {5, buffer_output_elements_node},
    178            {6, prefetch_input_elements_node}},
    179           {{1, other_arguments}},
    180           {{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
    181       return Status::OK();
    182     }
    184    private:
    185     int64 num_threads() const {
    186       return cycle_length_ + prefetch_input_elements_;
    187     }
    189     // Parallel interleave's implementation is designed around a few principles:
    190     //  1. Thread creation is relatively expensive. (Not reusing
    191     //     threads causes a number of indirect costs such as poorer tcmalloc
    192     //     performance due to thread-local caches, etc.) We allocate a fixed
    193     //     number of threads at the start and never change. This is why we've
    194     //     fused functionality that is theoretically orthogonal (i.e.
    195     //     .prefetch()) into the implementation.
    //  2. Drop-in replacement for standard interleave. The goal is to
    //     automatically opt users into an optimized implementation without
    //     any work on their part. We thus go to great pains to maintain
    //     identical iteration orders and full determinism (which can be
    //     disabled only via the `sloppy` flag), etc.
    201     //  3. Performance across a variety of environments and I/O envelopes.
    202     //
    203     // The actual implementation centers around a collection of worker threads
    204     // and their corresponding worker state (tracked in the `workers_` vector).
    205     // Worker threads repeatedly receive a vector of Tensors that are used as
    206     // input to the flat-map function (`captured_func_`). The output of this
    207     // function must be a dataset. The worker thread then repeatedly calls
    208     // `GetNext()`, maintaining a buffer of elements to minimize the likelihood
    209     // that a caller will block waiting for an element to be produced.
    210     //
    // Indices of these worker states are kept in 2 disjoint data structures:
    212     //  1. `interleave_indices_` is a vector containing indices of WorkerStates
    213     //     in `workers_` that we are interleaving. Worker threads backing these
    214     //     WorkerStates should be regularly producing values.
    215     //  2. `staging_indices_` is a deque containing indices of WorkerStates in
    216     //     `workers_` that we will move to `interleave_indices_` when an
    217     //     iterator in `interleave_indices_` is exhausted.
    218     //
    219     // The client calls `GetNext[Internal]()` to retrieve an output element. The
    220     // internal implementation updates the state of `interleave_indices_` and
    221     // `staging_indices_` as output iterators (run by the worker threads) are
    222     // exhausted.
    223     //
    224     // `input_impl_` is the input iterator that generates arguments for the
    225     // flat-map function (`captured_func_`). It is set to an iterator at
    226     // Iterator construction, and is fixed until we consume all input elements.
    227     // Once it is exhausted, we reset the unique_ptr to eagerly deallocate
    228     // memory.
    229     //
    // A few invariants are maintained:
    //  1. No element of `interleave_indices_` is -1 unless `staging_indices_`
    //     is empty and `input_impl_` is null (the input is exhausted).
    //  2. Every `workers_` element is referenced by at most one element of the
    //     union of `interleave_indices_` and `staging_indices_`.
    //  3. Unless `input_impl_` is null, every `workers_` element must be
    //     referenced by an element of `interleave_indices_` or
    //     `staging_indices_`.
    237     class Iterator : public DatasetIterator<Dataset> {
    238      public:
    239       explicit Iterator(const Params& params)
    240           : DatasetIterator<Dataset>(params),
    241             workers_(dataset()->num_threads()),
    242             worker_thread_states_(dataset()->num_threads()) {}
    244       ~Iterator() override {
    245         mutex_lock l(mu_);
    246         cancelled_ = true;
    247         // Notify all workers in case they are blocked.
    248         for (auto& worker : workers_) {
    249           worker.cond_var.notify_all();
    250         }
    251       }
    253       Status Initialize(IteratorContext* ctx) override {
    254         TF_RETURN_IF_ERROR(
    255             dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
    256         return dataset()->captured_func_->Instantiate(
    257             ctx, &instantiated_captured_func_);
    258       }
      // Produces the next interleaved element.
      //
      // It is implemented so that it matches the deterministic interleave
      // unless getting the next element would block and we are allowed to be
      // sloppy.
      //
      // Starts the worker threads on the first call, then repeatedly scans
      // `interleave_indices_` beginning at `next_index_`:
      //  - a worker with a buffered output yields it (its buffered `status`
      //    is returned, so a worker-side error surfaces here);
      //  - an exhausted worker (not producing, no outputs) is recycled: it is
      //    handed the next input element and appended to `staging_indices_`,
      //    and the front of `staging_indices_` takes over the slot.
      // Sets `*end_of_sequence` once no worker may produce more and
      // `input_impl_` is exhausted; otherwise blocks on a condition variable
      // and rescans. Returns `Cancelled` once `cancelled_` is set.
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
        while (!cancelled_) {
          // Wait for an item to become available, blocking if necessary. If we
          // are allowed to be sloppy, we can skip over input datasets that do
          // not have an item readily available.
          bool can_produce_elements = false;
          bool must_wait_for_input = true;
          for (int64 i = 0; i < interleave_indices_.size(); ++i) {
            int64 index = (next_index_ + i) % interleave_indices_.size();
            int64 current_worker_index = interleave_indices_[index];
            if (current_worker_index < 0) {
              continue;  // Empty interleave elements.
            }
            WorkerState* current_worker = &workers_[current_worker_index];
            can_produce_elements |= current_worker->MayHaveElements();
            if (!current_worker->outputs.empty()) {
              // We have an element!
              next_index_ = index;
              // NOTE(review): only `i > 1` counts as sloppy, so an element
              // taken from the slot right after `next_index_` (i == 1) is
              // treated as in-order and `block_count_` is not reset — confirm
              // whether `i > 0` was intended.
              const bool element_acquired_sloppily =
                  dataset()->sloppy_ && i > 1;
              if (!element_acquired_sloppily) {
                // If the element was acquired in the regular (non-sloppy)
                // order, then advance the current block and cycle pointers to
                // the next element in the regular order.
                block_count_++;
                if (block_count_ == dataset()->block_length_) {
                  next_index_ = (index + 1) % interleave_indices_.size();
                  block_count_ = 0;
                }
              } else {
                block_count_ = 0;
              }
              *end_of_sequence = false;
              Status s = current_worker->outputs.front().status;
              current_worker->outputs.front().output.swap(*out_tensors);
              current_worker->outputs.pop_front();
              // Wake the worker: its output buffer now has room.
              current_worker->cond_var.notify_one();
              return s;
            } else if (current_worker->is_producing && !dataset()->sloppy_) {
              // current_worker.outputs.empty(), and we must wait for this
              // iterator.
              if (next_index_ != index) {
                // We have advanced to a new iterator; reset block counts.
                next_index_ = index;
                block_count_ = 0;
              }
              break;
            } else if (!current_worker->is_producing) {
              // This iterator has reached end of input.
              interleave_indices_[index] = -1;
              if (input_impl_) {
                // Start prefetching a new iterator.
                std::vector<Tensor> args;
                bool end_of_input = false;
                Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
                if (end_of_input) {
                  // Release the exhausted input iterator eagerly.
                  input_impl_.reset();
                } else {
                  // Reuse the exhausted worker for the new input element. A
                  // non-OK `s` is queued as an output of this worker instead.
                  current_worker->SetInputs(s, std::move(args));
                  staging_indices_.emplace_back(current_worker_index);
                }
              }

              if (!staging_indices_.empty()) {
                // Move a worker from `staging_indices_` to
                // `interleave_indices_`.
                interleave_indices_[index] = staging_indices_.front();
                staging_indices_.pop_front();

                next_index_ = (index + 1) % interleave_indices_.size();
                block_count_ = 0;
                // Restart the inner [for] loop
                can_produce_elements = true;
                must_wait_for_input = false;
                break;
              }
            }
          }

          if (!can_produce_elements && !input_impl_) {
            // No potential for future values.
            *end_of_sequence = true;
            return Status::OK();
          }

          if (must_wait_for_input) {
            // Wait for elements to become available. In sloppy mode any
            // worker may make progress, so wait on the shared
            // `sloppy_cond_var_` (presumably signalled by the worker threads —
            // not visible here); otherwise wait on the worker the scan
            // stopped at, `next_index_`.
            RecordStop(ctx);
            if (dataset()->sloppy_) {
              sloppy_cond_var_.wait(l);
            } else {
              workers_[interleave_indices_[next_index_]].cond_var.wait(l);
            }
            RecordStart(ctx);
          }
        }
        return errors::Cancelled(
            "ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
      }
    367      protected:
    368       std::shared_ptr<model::Node> CreateNode(
    369           IteratorContext* ctx, model::Node::Args args) const override {
    370         return model::MakeAsyncInterleaveManyNode(std::move(args),
    371                                                   /*parameters=*/{});
    372       }
    374       Status SaveInternal(IteratorStateWriter* writer) override {
    375         // The order of locking is important here to avoid deadlock.
    376         mutex_lock l(mu_);
    377         mutex_lock ckpt_l(ckpt_mu_);
    378         if (input_impl_) {
    379           TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
    380         } else {
    381           TF_RETURN_IF_ERROR(
    382               writer->WriteScalar(full_name("input_exhausted"), ""));
    383         }
    384         TF_RETURN_IF_ERROR(
    385             writer->WriteScalar(full_name("next_index"), next_index_));
    386         TF_RETURN_IF_ERROR(
    387             writer->WriteScalar(full_name("block_count"), block_count_));
    388         TF_RETURN_IF_ERROR(
    389             writer->WriteScalar(full_name("workers_size"), workers_.size()));
    390         for (int i = 0; i < workers_.size(); ++i) {
    391           TF_RETURN_IF_ERROR(WriteWorkerStateLocked(writer, i));
    392         }
    393         for (int i = 0; i < worker_thread_states_.size(); ++i) {
    394           TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(writer, i));
    395         }
    396         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("interleave_size"),
    397                                                interleave_indices_.size()));
    398         for (int i = 0; i < interleave_indices_.size(); ++i) {
    399           TF_RETURN_IF_ERROR(writer->WriteScalar(
    400               full_name(strings::StrCat("interleave_indices_", i)),
    401               interleave_indices_[i]));
    402         }
    403         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("staging_size"),
    404                                                staging_indices_.size()));
    405         for (int i = 0; i < staging_indices_.size(); ++i) {
    406           TF_RETURN_IF_ERROR(writer->WriteScalar(
    407               full_name(strings::StrCat("staging_indices_", i)),
    408               staging_indices_[i]));
    409         }
    410         if (!worker_threads_.empty()) {
    411           TF_RETURN_IF_ERROR(
    412               writer->WriteScalar(full_name("worker_threads_running"), ""));
    413         }
    414         return Status::OK();
    415       }
    417       Status RestoreInternal(IteratorContext* ctx,
    418                              IteratorStateReader* reader) override {
    419         // The order of locking is important here to avoid deadlock.
    420         mutex_lock l(mu_);
    421         mutex_lock ckpt_l(ckpt_mu_);
    422         if (!reader->Contains(full_name("input_exhausted"))) {
    423           TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
    424         } else {
    425           input_impl_.reset();
    426         }
    427         int64 temp;
    428         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("next_index"), &temp));
    429         next_index_ = size_t(temp);
    430         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("block_count"), &temp));
    431         block_count_ = size_t(temp);
    433         // Restore WorkerStates.
    434         TF_RETURN_IF_ERROR(
    435             reader->ReadScalar(full_name("workers_size"), &temp));
    436         if (temp != dataset()->num_threads()) {
    437           return errors::Internal("Expected ", dataset()->num_threads(),
    438                                   " worker states but found ", temp, ".");
    439         }
    440         for (size_t i = 0; i < dataset()->num_threads(); ++i) {
    441           TF_RETURN_IF_ERROR(ReadWorkerStateLocked(reader, i, ctx));
    442         }
    443         for (size_t i = 0; i < dataset()->num_threads(); ++i) {
    444           TF_RETURN_IF_ERROR(ReadWorkerThreadStateLocked(reader, i, ctx));
    445         }
    447         // Restore `interleave_indices_`.
    448         std::set<int64> all_indices;
    449         {
    450           int64 interleave_size;
    451           TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("interleave_size"),
    452                                                 &interleave_size));
    453           interleave_indices_.reserve(interleave_size);
    454           for (int64 i = 0; i < interleave_size; ++i) {
    455             int64 temp;
    456             TF_RETURN_IF_ERROR(reader->ReadScalar(
    457                 full_name(strings::StrCat("interleave_indices_", i)), &temp));
    458             if (temp >= 0 && all_indices.find(temp) != all_indices.end()) {
    459               return errors::Internal(
    460                   "Duplicate entry for ", temp,
    461                   " found when reading interleave and staging indices.");
    462             }
    463             if (temp >= 0) {
    464               all_indices.insert(temp);
    465             }
    466             interleave_indices_.emplace_back(temp);
    467           }
    468         }
    470         // Restore `staging_indices_`.
    471         {
    472           int64 staging_size;
    473           TF_RETURN_IF_ERROR(
    474               reader->ReadScalar(full_name("staging_size"), &staging_size));
    475           for (int i = 0; i < staging_size; ++i) {
    476             int64 temp;
    477             TF_RETURN_IF_ERROR(reader->ReadScalar(
    478                 full_name(strings::StrCat("staging_indices_", i)), &temp));
    479             if (all_indices.find(temp) != all_indices.end()) {
    480               return errors::Internal(
    481                   "Duplicate entry for ", temp,
    482                   " found when reading interleave and staging indices.");
    483             }
    484             if (temp >= 0) {
    485               all_indices.insert(temp);
    486             }
    487             staging_indices_.emplace_back(temp);
    488           }
    489         }
    491         // Start Worker threads.
    492         if (reader->Contains(full_name("worker_threads_running"))) {
    493           worker_threads_.reserve(dataset()->num_threads());
    494           for (size_t i = 0; i < dataset()->num_threads(); ++i) {
    495             std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
    496             worker_threads_.emplace_back(ctx->StartThread(
    497                 strings::StrCat("tf_data_parallel_interleave_worker_", i),
    498                 [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
    499           }
    500         }
    501         return Status::OK();
    502       }
    504      private:
    505       // OutputElem contains the information from a call to GetNext by an output
    506       // iterator.
    507       struct OutputElem {
    508         // The output iterator sets `status` if getting the output element
    509         // fails.
    510         Status status;
    511         // The buffered data element.
    512         std::vector<Tensor> output;
    514         explicit OutputElem(const Status& s) : status(s) {}
    515       };
    517       // Worker threads operate on their relevant WorkerState structs.
    518       //
    519       // WorkerState's fields are all protected by mu_;
    520       struct WorkerState {
    521         // The arguments to be used to construct an output iterator.
    522         std::vector<Tensor> input;
    523         // The buffered output elements.
    524         std::deque<OutputElem> outputs;
    525         // Set to true iff the worker thread expects to append more elements to
    526         // outputs. is_producing can be false despite !outputs.empty().
    527         // Concretely, all output elements will have been consumed only when:
    528         // is_producing == false && outputs.empty();
    529         bool is_producing = false;
    530         // Condition variable used to coordinate between threads. The worker
    531         // thread waits on this condition variable when it is either (1) waiting
    532         // for the main thread to add arguments to `input`, or (2) waiting for
    533         // the main thread to consume an element of `outputs`. The main thread
    534         // waits on cond_var if it is waiting for the worker thread to produce
    535         // an element into `outputs` (this implies sloppy_==false).
    536         condition_variable cond_var;
    538         inline bool MayHaveElements() const {
    539           return is_producing || !outputs.empty();
    540         }
    542         // Sets inputs for a worker thread and notifies it to start processing.
    543         void SetInputs(const Status& s, std::vector<Tensor> input_arguments) {
    544           if (s.ok()) {
    545             DCHECK(!MayHaveElements())
    546                 << "Tried to start inputs, despite already producing!";
    547             input = std::move(input_arguments);
    548             is_producing = true;
    549             cond_var.notify_one();
    550           } else {
    551             outputs.emplace_back(s);
    552           }
    553         }
    554       };
    556       // The internal state of a worker thread that is not already captured
    557       // in its `WorkerState`.
    558       //
    559       // This is needed only for checkpointing purposes. We keep this
    560       // separate from `WorkerState` and guard its fields using a separate
    561       // lock `ckpt_mu_` so as to not affect the performance of main pipeline.
    562       struct WorkerThreadState {
    563         // The output element that has been produced from the input iterator
    564         // and is waiting to be added to `WorkerState.outputs`.
    565         OutputElem output_elem;
    567         // Whether the input iterator returned an `end_of_sequence`.
    568         bool end_of_sequence = false;
    570         // Status returned from `MakeIteratorFromInputElement`.
    571         Status iterator_creation_status;
    573         // The arguments to be used to construct `iterator`.
    574         std::vector<Tensor> input;
    576         std::unique_ptr<IteratorBase> iterator;
    578         WorkerThreadState() : output_elem(Status::OK()) {}
    579       };
    581       Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
    582           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    583         if (worker_threads_.empty()) {
    584           worker_threads_.reserve(dataset()->num_threads());
    585           for (int64 i = 0; i < dataset()->num_threads(); ++i) {
    586             std::vector<Tensor> args;
    587             bool end_of_input = false;
    588             Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
    589             if (end_of_input) {
    590               input_impl_.reset();
    591               return Status::OK();
    592             }
    593             workers_[i].SetInputs(s, std::move(args));
    594             std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
    595             worker_threads_.push_back(ctx->StartThread(
    596                 strings::StrCat("tf_data_parallel_interleave_worker_", i),
    597                 [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
    598             if (i < dataset()->cycle_length_) {
    599               interleave_indices_.push_back(i);
    600             } else {
    601               staging_indices_.push_back(i);
    602             }
    603           }
    604           DCHECK(interleave_indices_.size() == dataset()->cycle_length_);
    605           DCHECK(staging_indices_.size() ==
    606                  dataset()->prefetch_input_elements_);
    607         }
    608         return Status::OK();
    609       }
    611       // Produces elements into the worker's output buffers.
      // Main loop of the worker thread at `thread_index`.
      //
      // Each pass of the outer loop does two things:
      //   (1) It obtains an iterator. This is either the one already held in
      //       `worker_thread_states_[thread_index]` (e.g. after a restore) or a
      //       new one built by `MakeIteratorFromInputElement`. A new iterator is
      //       built from the input tensors handed over through
      //       `workers_[thread_index].input`.
      //   (2) It forwards output to `workers_[thread_index].outputs`. The output
      //       is either the iterator-creation error or every element the
      //       iterator produces. The thread blocks while that buffer holds
      //       `dataset()->buffer_output_elements_` entries.
      // The function returns when it finds `cancelled_` set at one of its wait
      // points. `RecordStop`/`RecordStart` are called around every
      // condition-variable wait.
      void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
                        const int64 thread_index) {
        // Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
        //
        // 1. Any local state that may need to be checkpointed should be kept
        //    in `worker_thread_states_[thread_index]`.
        // 2. `WorkerThreadState` should contain state that is needed only for
        //    checkpointing, i.e., if we were to remove checkpointing support,
        //    we could keep that state as local variables in this thread.
        // 3. This thread should only read/write state at `thread_index`
        //    and should not access other thread states.
        // 4. When restoring from checkpoint, threads are started only after
        //    the restore is complete.
        // 5. Once restored from a checkpoint, the local state is edited only
        //    by this thread. 3 & 4 allow making assumptions like temporarily
        //    caching local state in this thread and using it outside a lock
        //    e.g. `make_new_iterator`.
        // 6. `ckpt_mu_` should be wisely used to create *consistent*
        //    checkpoint markers.

        // On every exit path (the only exits are the `return`s taken on
        // cancellation), wake anyone waiting on this worker's condition
        // variable and call `RecordStop`. The lambda captures the `ctx`
        // shared_ptr by value, so the context stays alive until the cleanup
        // has run.
        RecordStart(ctx.get());
        auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
          mutex_lock l(mu_);
          workers_[thread_index].cond_var.notify_all();
          RecordStop(ctx.get());
        });
        bool make_new_iterator;
        {
          tf_shared_lock l(ckpt_mu_);
          // Decide whether a new iterator should be built.
          // 1. If there is an existing iterator, we use it.
          // 2. If there was an error in iterator creation that could not be
          //    notified to the client we attempt to send that to the client
          //    first.
          make_new_iterator =
              worker_thread_states_[thread_index].iterator == nullptr &&
              worker_thread_states_[thread_index].iterator_creation_status.ok();
        }
        // Even though `make_new_iterator` has cached values from
        // `worker_thread_states_[thread_index]` which is guarded by ckpt_mu_,
        // it is safe to *read* `make_new_iterator` outside of a lock without
        // worrying about concurrent changes to values in
        // `worker_thread_states_[thread_index]`. See comment at the start of
        // this function for details.
        while (true) {
          // Whether creation of the iterator succeeded.
          Status iterator_creation_status;
          // 1. Build a new iterator or use the existing one.
          if (make_new_iterator) {
            // 1a. Get new input tensors or use the existing ones.
            bool read_new_input;
            {
              tf_shared_lock l(ckpt_mu_);
              // worker_thread_states_[thread_index].input will be non-empty
              // if checkpointing happened at CHECKPOINT_MARKER_A.
              read_new_input =
                  worker_thread_states_[thread_index].input.empty();
            }

            if (read_new_input) {
              mutex_lock l(mu_);
              // Wait until this worker's `is_producing` flag is set
              // (presumably by the client thread when it assigns a new input
              // element) or the iterator is cancelled.
              while (!cancelled_ && !workers_[thread_index].is_producing) {
                RecordStop(ctx.get());
                workers_[thread_index].cond_var.wait(l);
                RecordStart(ctx.get());
              }
              if (cancelled_) return;
              // Copy the input tensors so that we do not need to block on `mu_`
              // when building the iterator.
              // We keep a copy of the input tensors in
              // `WorkerThreadState.input` till the iterator is in use. This is
              // used in `RestoreInternal` to re-build the iterator.
              // TODO(b/78046638): Explore ways to avoid tracking the input
              // tensors.
              tf_shared_lock ckpt_l(ckpt_mu_);
              worker_thread_states_[thread_index].input.swap(
                  workers_[thread_index].input);
              // CHECKPOINT_MARKER_A
              // We have the input tensors but have not built the iterator yet.
            }

            // 1b. Run the user defined function to produce a new iterator.
            {
              tf_shared_lock l(ckpt_mu_);
              worker_thread_states_[thread_index].iterator_creation_status =
                  MakeIteratorFromInputElement(
                      ctx.get(), worker_thread_states_[thread_index].input,
                      thread_index, *instantiated_captured_func_, prefix(),
                      &worker_thread_states_[thread_index].iterator);
              iterator_creation_status =
                  worker_thread_states_[thread_index].iterator_creation_status;
              // On failure, drop the input. `read_new_input` is then true on
              // the next pass, so that pass waits for a fresh input element.
              if (!iterator_creation_status.ok()) {
                worker_thread_states_[thread_index].input.clear();
              }
              // CHECKPOINT_MARKER_B
              // Either an iterator has been successfully built and placed in
              // `worker_thread_states_[thread_index].iterator` or it failed and
              // a non-OK status has been put in
              // `worker_thread_states_[thread_index].iterator_creation_status`.
            }
          } else {
            tf_shared_lock l(ckpt_mu_);
            iterator_creation_status =
                worker_thread_states_[thread_index].iterator_creation_status;
            // Mark that we have used up the restored iterator.
            make_new_iterator = true;
          }
          // 2. Start producing elements or send error state to client if
          //    iterator creation failed.
          if (!iterator_creation_status.ok()) {
            mutex_lock l(mu_);
            // Wait for space in the prefetch queue.
            while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                      dataset()->buffer_output_elements_) {
              RecordStop(ctx.get());
              workers_[thread_index].cond_var.wait(l);
              RecordStart(ctx.get());
            }
            if (cancelled_) return;
            tf_shared_lock ckpt_l(ckpt_mu_);
            // Deliver the creation error as an output element and mark the
            // worker idle. Reset the stored status so the error is reported
            // only once.
            workers_[thread_index].outputs.emplace_back(
                iterator_creation_status);
            workers_[thread_index].is_producing = false;
            worker_thread_states_[thread_index].iterator_creation_status =
                Status::OK();
            // CHECKPOINT_MARKER_C
            // Non-OK iterator creation status has been notified to the
            // client.
            workers_[thread_index].cond_var.notify_one();
          } else {
            bool end_of_sequence = false;
            while (!end_of_sequence) {
              // 3.a Produce an element!
              {
                tf_shared_lock ckpt_l(ckpt_mu_);
                // Call GetNext only when there is no unsent element, error or
                // end_of_sequence left over. Leftover state exists e.g. after
                // a restore from a checkpoint taken at CHECKPOINT_MARKER_D.
                if (worker_thread_states_[thread_index]
                        .output_elem.status.ok() &&
                    worker_thread_states_[thread_index]
                        .output_elem.output.empty() &&
                    !worker_thread_states_[thread_index].end_of_sequence) {
                  worker_thread_states_[thread_index].output_elem.status =
                      worker_thread_states_[thread_index].iterator->GetNext(
                          ctx.get(),
                          &worker_thread_states_[thread_index]
                               .output_elem.output,
                          &worker_thread_states_[thread_index].end_of_sequence);
                  end_of_sequence =
                      worker_thread_states_[thread_index].end_of_sequence;
                } else {
                  end_of_sequence =
                      worker_thread_states_[thread_index].end_of_sequence;
                }
                // CHECKPOINT_MARKER_D
                // An element has been read or an error or end_of_sequence has
                // been received from the input iterator and is waiting to be
                // sent to client.
              }

              // 3.b Make it available to the client.
              {
                mutex_lock l(mu_);

                // Wait for space in the prefetch queue.
                while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                          dataset()->buffer_output_elements_) {
                  RecordStop(ctx.get());
                  workers_[thread_index].cond_var.wait(l);
                  RecordStart(ctx.get());
                }
                if (cancelled_) return;

                tf_shared_lock ckpt_l(ckpt_mu_);
                workers_[thread_index].is_producing = !end_of_sequence;

                // Output the element.

                // Move the temporary state in WorkerThreadState to WorkerState
                // and mark it as used.
                if (end_of_sequence) {
                  // Iterator exhausted. Drop it and its input so the next pass
                  // of the outer loop builds a new one.
                  worker_thread_states_[thread_index].iterator.reset();
                  worker_thread_states_[thread_index].input.clear();
                  worker_thread_states_[thread_index].end_of_sequence = false;
                } else {
                  workers_[thread_index].outputs.emplace_back(
                      worker_thread_states_[thread_index].output_elem.status);
                  workers_[thread_index].outputs.back().output.swap(
                      worker_thread_states_[thread_index].output_elem.output);
                }
                worker_thread_states_[thread_index].output_elem.status =
                    Status::OK();
                // In sloppy mode the consumer waits on `sloppy_cond_var_`;
                // otherwise it waits on this worker's own condition variable.
                if (dataset()->sloppy_) {
                  sloppy_cond_var_.notify_one();
                } else {
                  workers_[thread_index].cond_var.notify_one();
                }
                // CHECKPOINT_MARKER_E
                // Output element or iterator status has been sent to the
                // client.
              }
            }
          }
        }
      }
    818       Status WriteWorkerStateLocked(IteratorStateWriter* writer, int index)
    819           EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
    820         string prefix = strings::StrCat("worker_", index);
    821         TF_RETURN_IF_ERROR(writer->WriteScalar(
    822             full_name(strings::StrCat(prefix, "_input_size")),
    823             workers_[index].input.size()));
    824         for (int i = 0; i < workers_[index].input.size(); ++i) {
    825           TF_RETURN_IF_ERROR(writer->WriteTensor(
    826               full_name(strings::StrCat(prefix, "_input_", i)),
    827               workers_[index].input[i]));
    828         }
    829         TF_RETURN_IF_ERROR(writer->WriteScalar(
    830             full_name(strings::StrCat(prefix, "_outputs_size")),
    831             workers_[index].outputs.size()));
    832         for (int i = 0; i < workers_[index].outputs.size(); ++i) {
    833           TF_RETURN_IF_ERROR(WriteOutputElemLocked(
    834               writer, workers_[index].outputs[i],
    835               full_name(strings::StrCat(prefix, "_outputs_", i))));
    836         }
    837         if (workers_[index].is_producing) {
    838           TF_RETURN_IF_ERROR(writer->WriteScalar(
    839               full_name(strings::StrCat(prefix, "_is_producing")), ""));
    840         }
    841         return Status::OK();
    842       }
    844       Status ReadWorkerStateLocked(IteratorStateReader* reader, int index,
    845                                    IteratorContext* ctx)
    846           EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
    847         string worker_prefix = strings::StrCat("worker_", index);
    848         // Restore inputs.
    849         int64 input_size;
    850         TF_RETURN_IF_ERROR(reader->ReadScalar(
    851             full_name(strings::StrCat(worker_prefix, "_input_size")),
    852             &input_size));
    853         workers_[index].input.reserve(input_size);
    854         for (int i = 0; i < input_size; ++i) {
    855           workers_[index].input.emplace_back();
    856           TF_RETURN_IF_ERROR(reader->ReadTensor(
    857               full_name(strings::StrCat(worker_prefix, "_input_", i)),
    858               &workers_[index].input.back()));
    859         }
    860         int64 outputs_size;
    861         TF_RETURN_IF_ERROR(reader->ReadScalar(
    862             full_name(strings::StrCat(worker_prefix, "_outputs_size")),
    863             &outputs_size));
    864         for (int i = 0; i < outputs_size; ++i) {
    865           workers_[index].outputs.emplace_back(Status::OK());
    866           TF_RETURN_IF_ERROR(ReadOutputElemLocked(
    867               reader, &workers_[index].outputs.back(),
    868               full_name(strings::StrCat(worker_prefix, "_outputs_", i))));
    869         }
    870         if (reader->Contains(
    871                 full_name(strings::StrCat(worker_prefix, "_is_producing")))) {
    872           workers_[index].is_producing = true;
    873         } else {
    874           workers_[index].is_producing = false;
    875         }
    876         return Status::OK();
    877       }
    879       Status WriteWorkerThreadStateLocked(IteratorStateWriter* writer,
    880                                           int index)
    881           EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
    882         string prefix = strings::StrCat("worker_thread_", index);
    883         if (worker_thread_states_[index].iterator != nullptr) {
    884           TF_RETURN_IF_ERROR(
    885               SaveInput(writer, worker_thread_states_[index].iterator));
    886         } else {
    887           TF_RETURN_IF_ERROR(writer->WriteScalar(
    888               full_name(strings::StrCat(prefix, "_iterator_exhausted")), ""));
    889         }
    890         TF_RETURN_IF_ERROR(writer->WriteScalar(
    891             full_name(strings::StrCat(prefix, "_input_size")),
    892             worker_thread_states_[index].input.size()));
    893         for (int i = 0; i < worker_thread_states_[index].input.size(); ++i) {
    894           TF_RETURN_IF_ERROR(writer->WriteTensor(
    895               full_name(strings::StrCat(prefix, "_input_", i)),
    896               worker_thread_states_[index].input[i]));
    897         }
    898         TF_RETURN_IF_ERROR(WriteStatusLocked(
    899             writer, strings::StrCat(prefix, "_iterator_creation_status"),
    900             worker_thread_states_[index].iterator_creation_status));
    901         TF_RETURN_IF_ERROR(WriteOutputElemLocked(
    902             writer, worker_thread_states_[index].output_elem,
    903             full_name(strings::StrCat(prefix, "_output"))));
    904         if (worker_thread_states_[index].end_of_sequence) {
    905           TF_RETURN_IF_ERROR(writer->WriteScalar(
    906               full_name(strings::StrCat(prefix, "_end_of_sequence")), ""));
    907         }
    908         return Status::OK();
    909       }
    911       Status ReadWorkerThreadStateLocked(IteratorStateReader* reader, int index,
    912                                          IteratorContext* ctx)
    913           EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
    914         string worker_prefix = strings::StrCat("worker_thread_", index);
    915         // Restore inputs.
    916         int64 input_size;
    917         TF_RETURN_IF_ERROR(reader->ReadScalar(
    918             full_name(strings::StrCat(worker_prefix, "_input_size")),
    919             &input_size));
    920         worker_thread_states_[index].input.reserve(input_size);
    921         for (int i = 0; i < input_size; ++i) {
    922           worker_thread_states_[index].input.emplace_back();
    923           TF_RETURN_IF_ERROR(reader->ReadTensor(
    924               full_name(strings::StrCat(worker_prefix, "_input_", i)),
    925               &worker_thread_states_[index].input.back()));
    926         }
    927         // Restore iterator.
    928         if (reader->Contains(full_name(
    929                 strings::StrCat(worker_prefix, "_iterator_exhausted")))) {
    930           worker_thread_states_[index].iterator.reset();
    931         } else {
    932           std::unique_ptr<IteratorBase> iterator;
    933           Status s = MakeIteratorFromInputElement(
    934               ctx, worker_thread_states_[index].input, index,
    935               *instantiated_captured_func_, prefix(), &iterator);
    936           TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
    937           worker_thread_states_[index].iterator.swap(iterator);
    938         }
    939         TF_RETURN_IF_ERROR(ReadStatusLocked(
    940             reader, strings::StrCat(worker_prefix, "_iterator_creation_status"),
    941             &worker_thread_states_[index].iterator_creation_status));
    942         TF_RETURN_IF_ERROR(ReadOutputElemLocked(
    943             reader, &worker_thread_states_[index].output_elem,
    944             full_name(strings::StrCat(worker_prefix, "_output"))));
    945         if (reader->Contains(full_name(
    946                 strings::StrCat(worker_prefix, "_end_of_sequence")))) {
    947           worker_thread_states_[index].end_of_sequence = true;
    948         } else {
    949           worker_thread_states_[index].end_of_sequence = false;
    950         }
    951         return Status::OK();
    952       }
    954       Status WriteOutputElemLocked(IteratorStateWriter* writer,
    955                                    const OutputElem& output_elem,
    956                                    const string& prefix)
    957           EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
    958         TF_RETURN_IF_ERROR(WriteStatusLocked(
    959             writer, strings::StrCat(prefix, "_status"), output_elem.status));
    960         TF_RETURN_IF_ERROR(
    961             writer->WriteScalar(strings::StrCat(prefix, "_output_size"),
    962                                 output_elem.output.size()));
    963         for (int i = 0; i < output_elem.output.size(); ++i) {
    964           TF_RETURN_IF_ERROR(writer->WriteTensor(
    965               strings::StrCat(prefix, "_output_", i), output_elem.output[i]));
    966         }
    967         return Status::OK();
    968       }
    970       Status ReadOutputElemLocked(IteratorStateReader* reader,
    971                                   OutputElem* output_elem, const string& prefix)
    972           EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
    973         TF_RETURN_IF_ERROR(ReadStatusLocked(
    974             reader, strings::StrCat(prefix, "_status"), &output_elem->status));
    975         int64 output_size;
    976         TF_RETURN_IF_ERROR(reader->ReadScalar(
    977             strings::StrCat(prefix, "_output_size"), &output_size));
    978         output_elem->output.reserve(output_size);
    979         for (int i = 0; i < output_size; ++i) {
    980           output_elem->output.emplace_back();
    981           TF_RETURN_IF_ERROR(
    982               reader->ReadTensor(strings::StrCat(prefix, "_output_", i),
    983                                  &output_elem->output.back()));
    984         }
    985         return Status::OK();
    986       }
    988       Status WriteStatusLocked(IteratorStateWriter* writer,
    989                                const string& prefix, const Status& status)
    990           EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
    991         TF_RETURN_IF_ERROR(
    992             writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")),
    993                                 static_cast<int64>(status.code())));
    994         if (!status.ok()) {
    995           TF_RETURN_IF_ERROR(
    996               writer->WriteScalar(full_name(strings::StrCat(prefix, "_msg")),
    997                                   status.error_message()));
    998         }
    999         return Status::OK();
   1000       }
   1002       Status ReadStatusLocked(IteratorStateReader* reader, const string& prefix,
   1003                               Status* status)
   1004           EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
   1005         int64 code_int;
   1006         TF_RETURN_IF_ERROR(reader->ReadScalar(
   1007             full_name(strings::StrCat(prefix, "_code")), &code_int));
   1008         error::Code code = static_cast<error::Code>(code_int);
   1010         if (code != error::Code::OK) {
   1011           string error_message;
   1012           TF_RETURN_IF_ERROR(reader->ReadScalar(
   1013               full_name(strings::StrCat(prefix, "_msg")), &error_message));
   1014           *status = Status(code, error_message);
   1015         } else {
   1016           *status = Status::OK();
   1017         }
   1018         return Status::OK();
   1019       }
      // Mutex & condition variable to guard mutable iterator internals and
      // coordinate among worker threads and client thread[s].
      // Lock order: when both are held, `mu_` is acquired before `ckpt_mu_`.
      mutex mu_ ACQUIRED_BEFORE(ckpt_mu_);
      // The main thread waits on this condition variable if running in sloppy
      // mode and no values are available. In sloppy mode, worker threads notify
      // this variable, instead of their per-worker `cond_var`, when they hand
      // off an element or end of sequence.
      condition_variable sloppy_cond_var_;
      // Mutex used to wait for a consistent state while checkpointing.
      // Only Save and Restore require an exclusive lock on this mutex. In
      // other scenarios we just acquire a shared lock so the pipeline's
      // performance should not be affected in the absence of checkpointing.
      // A thread must not wait on any condition variable while holding
      // `ckpt_mu_` in either shared or exclusive modes.
      mutex ckpt_mu_;

      // The iterator producing elements which are converted to datasets by
      // the dataset()->captured_func_ then interleaved together.
      // input_impl_ is reset when we have exhausted its input.
      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);

      // Instantiated form of the captured function. Worker threads and the
      // restore path pass it to `MakeIteratorFromInputElement` to build
      // per-input-element iterators.
      std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;

      // The WorkerState structs the worker threads operate on.
      // workers_ elements are in at most one of interleave_indices_ and
      // staging_indices_.
      std::vector<WorkerState> workers_ GUARDED_BY(mu_);

      // Stores the temporary state of WorkerThreads which is not stored in
      // WorkerState. This is used for checkpointing purposes only.
      std::vector<WorkerThreadState> worker_thread_states_ GUARDED_BY(ckpt_mu_);

      // Indices in `workers_` of iterators to interleave.
      std::vector<int64> interleave_indices_ GUARDED_BY(mu_);
      // Indices in `workers_` of prefetched iterators.
      std::deque<int64> staging_indices_ GUARDED_BY(mu_);

      // The index into output_elements_ for next element to produce.
      // NOTE(review): no `output_elements_` member is declared in this class;
      // this presumably indexes `interleave_indices_`. Confirm against its
      // users.
      size_t next_index_ GUARDED_BY(mu_) = 0;
      // The number of items produced so far within the block
      size_t block_count_ GUARDED_BY(mu_) = 0;
      // Flag to instruct the worker threads to exit.
      bool cancelled_ GUARDED_BY(mu_) = false;
      // The worker threads. This must be last to ensure the
      // threads have exited before any other members are deallocated.
      // TODO(b/65178177): Avoid allocating additional threads.
      std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
   1065     };
   1067     const DatasetBase* const input_;
   1068     const NameAttrList interleave_func_;
   1069     const std::unique_ptr<CapturedFunction> captured_func_;
   1070     const int64 cycle_length_;
   1071     const int64 block_length_;
   1072     const bool sloppy_;
   1073     const int64 buffer_output_elements_;
   1074     const int64 prefetch_input_elements_;
   1075     const DataTypeVector output_types_;
   1076     const std::vector<PartialTensorShape> output_shapes_;
   1077   };
   1079   DataTypeVector output_types_;
   1080   std::vector<PartialTensorShape> output_shapes_;
   1081   NameAttrList interleave_func_;
   1082 };
   1085     Name("ExperimentalParallelInterleaveDataset").Device(DEVICE_CPU),
   1086     ParallelInterleaveDatasetOp);
   1088 }  // namespace
   1089 }  // namespace data
   1090 }  // namespace tensorflow