/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <atomic>
#include <deque>
#include <utility>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    int64 cycle_length = 0;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument(ctx, "cycle_length", &cycle_length));
    OP_REQUIRES(ctx, cycle_length > 0,
                errors::InvalidArgument("`cycle_length` must be > 0"));

    int64 block_length = 0;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument(ctx, "block_length", &block_length));
    OP_REQUIRES(ctx, block_length > 0,
                errors::InvalidArgument("`block_length` must be > 0"));

    bool sloppy = false;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "sloppy", &sloppy));

    int64 buffer_output_elements = 0;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "buffer_output_elements",
                                            &buffer_output_elements));
    OP_REQUIRES(
        ctx, buffer_output_elements > 0,
        errors::InvalidArgument("`buffer_output_elements` must be > 0"));

    int64 prefetch_input_elements = 0;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefetch_input_elements",
                                            &prefetch_input_elements));
    OP_REQUIRES(
        ctx, prefetch_input_elements >= 0,
        errors::InvalidArgument("`prefetch_input_elements` must be >= 0"));

    std::unique_ptr<CapturedFunction> captured_func;
    OP_REQUIRES_OK(
        ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
                                      &captured_func));

    *output =
        new Dataset(ctx, input, interleave_func_, std::move(captured_func),
                    cycle_length, block_length, sloppy, buffer_output_elements,
                    prefetch_input_elements, output_types_, output_shapes_);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const DatasetBase* input,
            const NameAttrList& func,
            std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
            int64 block_length, bool sloppy, int64 buffer_output_elements,
            int64 prefetch_input_elements, const DataTypeVector& output_types,
            const std::vector<PartialTensorShape>& output_shapes)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          interleave_func_(func),
          captured_func_(std::move(captured_func)),
          cycle_length_(cycle_length),
          block_length_(block_length),
          sloppy_(sloppy),
          buffer_output_elements_(buffer_output_elements),
          prefetch_input_elements_(prefetch_input_elements),
          output_types_(output_types),
          output_shapes_(output_shapes) {
      input_->Ref();
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<Iterator>(Iterator::Params{
          this, strings::StrCat(prefix, "::ParallelInterleave")});
    }

    const DataTypeVector& output_dtypes() const override {
      return output_types_;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return output_shapes_;
    }

    string DebugString() const override {
      return "ParallelInterleaveDatasetOp::Dataset";
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
      Node* input_node;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
      Node* cycle_length_node;
      TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
      Node* block_length_node;
      TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
      Node* sloppy_node;
      TF_RETURN_IF_ERROR(b->AddScalar(sloppy_, &sloppy_node));
      Node* buffer_output_elements_node;
      TF_RETURN_IF_ERROR(
          b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
      Node* prefetch_input_elements_node;
      TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_,
                                      &prefetch_input_elements_node));
      DataTypeVector other_arguments_types;
      other_arguments_types.reserve(captured_func_->captured_inputs().size());
      std::vector<Node*> other_arguments;
      other_arguments.reserve(captured_func_->captured_inputs().size());
      for (const Tensor& t : captured_func_->captured_inputs()) {
        Node* node;
        DatasetBase* input;
        Status s = GetDatasetFromVariantTensor(t, &input);
        if (s.ok()) {
          TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
        } else {
          TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
        }
        other_arguments.emplace_back(node);
        other_arguments_types.emplace_back(t.dtype());
      }
      AttrValue f;
      b->BuildAttrValue(interleave_func_, &f);
      AttrValue other_arguments_types_attr;
      b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);

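      // For reference, the op expects its inputs in this order:
      //   input_dataset, other_arguments (a list), cycle_length, block_length,
      //   sloppy, buffer_output_elements, prefetch_input_elements.
      // The index maps below therefore place the single-tensor inputs at
      // positions 0 and 2-6 and the variadic `other_arguments` list at
      // position 1.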
      TF_RETURN_IF_ERROR(b->AddDataset(
          this,
          {{0, input_node},
           {2, cycle_length_node},
           {3, block_length_node},
           {4, sloppy_node},
           {5, buffer_output_elements_node},
           {6, prefetch_input_elements_node}},
          {{1, other_arguments}},
          {{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
      return Status::OK();
    }

   private:
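    // One worker thread is allocated per interleaved iterator
    // (`cycle_length_`) plus one per prefetched-but-not-yet-interleaved
    // iterator (`prefetch_input_elements_`).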
    int64 num_threads() const {
      return cycle_length_ + prefetch_input_elements_;
    }

    // Parallel interleave's implementation is designed around a few principles:
    //  1. Thread creation is relatively expensive. (Not reusing
    //     threads causes a number of indirect costs such as poorer tcmalloc
    //     performance due to thread-local caches, etc.) We allocate a fixed
    //     number of threads at the start and never change that number. This
    //     is why we've fused functionality that is theoretically orthogonal
    //     (e.g. `.prefetch()`) into this implementation.
    //  2. Drop-in replacement for the standard interleave. The goal is to
    //     auto-opt users into an optimized implementation without any work on
    //     their part. We thus go to great pains to maintain identical
    //     iteration orders and full determinism (which can be disabled only
    //     via a flag).
    //  3. Performance across a variety of environments and I/O envelopes.
    //
    // The actual implementation centers around a collection of worker threads
    // and their corresponding worker state (tracked in the `workers_` vector).
    // Worker threads repeatedly receive a vector of Tensors that are used as
    // input to the flat-map function (`captured_func_`). The output of this
    // function must be a dataset. The worker thread then repeatedly calls
    // `GetNext()`, maintaining a buffer of elements to minimize the likelihood
    // that a caller will block waiting for an element to be produced.
    //
    // Pointers to these worker states are kept in two disjoint data structures:
    //  1. `interleave_indices_` is a vector containing indices of WorkerStates
    //     in `workers_` that we are interleaving. Worker threads backing these
    //     WorkerStates should be regularly producing values.
    //  2. `staging_indices_` is a deque containing indices of WorkerStates in
    //     `workers_` that we will move to `interleave_indices_` when an
    //     iterator in `interleave_indices_` is exhausted.
    //
    // The client calls `GetNext[Internal]()` to retrieve an output element. The
    // internal implementation updates the state of `interleave_indices_` and
    // `staging_indices_` as output iterators (run by the worker threads) are
    // exhausted.
    //
    // `input_impl_` is the input iterator that generates arguments for the
    // flat-map function (`captured_func_`). It is set to an iterator at
    // Iterator construction, and is fixed until we consume all input elements.
    // Once it is exhausted, we reset the unique_ptr to eagerly deallocate
    // memory.
    //
    // A few invariants are maintained:
    //  1. No element in `interleave_indices_` should be -1 unless
    //     `staging_indices_` is empty and `input_impl_` is empty.
    //  2. Every `workers_` element is pointed to by at most one element of the
    //     union of `interleave_indices_` and `staging_indices_`.
    //  3. Unless `input_impl_` is empty, every `workers_` element must be
    //     pointed to by an element in `interleave_indices_` or
    //     `staging_indices_`.
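    //
    // As a rough worked example (assuming a healthy, non-sloppy pipeline):
    // with cycle_length = 2 and block_length = 2, if input element A maps to
    // a dataset producing a0, a1, a2, ... and input element B maps to one
    // producing b0, b1, b2, ..., the output order is
    //   a0, a1, b0, b1, a2, a3, b2, b3, ...
    // i.e. `block_length` elements are taken from an iterator before the
    // cycle advances to the next slot.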
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params),
            workers_(dataset()->num_threads()),
            worker_thread_states_(dataset()->num_threads()) {}

      ~Iterator() override {
        mutex_lock l(mu_);
        cancelled_ = true;
        // Notify all workers in case they are blocked.
        for (auto& worker : workers_) {
          worker.cond_var.notify_all();
        }
      }

      Status Initialize(IteratorContext* ctx) override {
        TF_RETURN_IF_ERROR(
            dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
        return dataset()->captured_func_->Instantiate(
            ctx, &instantiated_captured_func_);
      }

      // This implementation matches the deterministic interleave order unless
      // getting the next element would block and we are allowed to be sloppy.
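      //
      // Roughly: scan the cycle starting at `next_index_` and return the first
      // buffered element found (respecting `block_length_` in the
      // deterministic case). When a worker is exhausted, pull the next input
      // element to start prefetching a replacement and promote a staged
      // worker into the vacated slot. If nothing is ready yet, block on the
      // appropriate condition variable.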
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
        while (!cancelled_) {
          // Wait for an item to become available, blocking if necessary. If we
          // are allowed to be sloppy, we can skip over input datasets that do
          // not have an item readily available.
          bool can_produce_elements = false;
          bool must_wait_for_input = true;
          for (int64 i = 0; i < interleave_indices_.size(); ++i) {
            int64 index = (next_index_ + i) % interleave_indices_.size();
            int64 current_worker_index = interleave_indices_[index];
            if (current_worker_index < 0) {
              continue;  // Empty interleave elements.
            }
            WorkerState* current_worker = &workers_[current_worker_index];
            can_produce_elements |= current_worker->MayHaveElements();
            if (!current_worker->outputs.empty()) {
              // We have an element!
              next_index_ = index;
              const bool element_acquired_sloppily =
                  dataset()->sloppy_ && i > 1;
              if (!element_acquired_sloppily) {
                // If the element was acquired in the regular (non-sloppy)
                // order, then advance the current block and cycle pointers to
                // the next element in the regular order.
                block_count_++;
                if (block_count_ == dataset()->block_length_) {
                  next_index_ = (index + 1) % interleave_indices_.size();
                  block_count_ = 0;
                }
              } else {
                block_count_ = 0;
              }
              *end_of_sequence = false;
              Status s = current_worker->outputs.front().status;
              current_worker->outputs.front().output.swap(*out_tensors);
              current_worker->outputs.pop_front();
              current_worker->cond_var.notify_one();
              return s;
            } else if (current_worker->is_producing && !dataset()->sloppy_) {
              // `current_worker->outputs` is empty, so we must wait for this
              // iterator.
              if (next_index_ != index) {
                // We have advanced to a new iterator; reset block counts.
                next_index_ = index;
                block_count_ = 0;
              }
              break;
            } else if (!current_worker->is_producing) {
              // This iterator has reached end of input.
              interleave_indices_[index] = -1;
              if (input_impl_) {
                // Start prefetching a new iterator.
                std::vector<Tensor> args;
                bool end_of_input = false;
                Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
                if (end_of_input) {
                  input_impl_.reset();
                } else {
                  current_worker->SetInputs(s, std::move(args));
                  staging_indices_.emplace_back(current_worker_index);
                }
              }

              if (!staging_indices_.empty()) {
                // Move a worker from `staging_indices_` to
                // `interleave_indices_`.
                interleave_indices_[index] = staging_indices_.front();
                staging_indices_.pop_front();

                next_index_ = (index + 1) % interleave_indices_.size();
                block_count_ = 0;
                // Restart the inner [for] loop
                can_produce_elements = true;
                must_wait_for_input = false;
                break;
              }
            }
          }

          if (!can_produce_elements && !input_impl_) {
            // No potential for future values.
            *end_of_sequence = true;
            return Status::OK();
          }

          if (must_wait_for_input) {
            // Wait for elements to become available.
            RecordStop(ctx);
            if (dataset()->sloppy_) {
              sloppy_cond_var_.wait(l);
            } else {
              workers_[interleave_indices_[next_index_]].cond_var.wait(l);
            }
            RecordStart(ctx);
          }
        }
        return errors::Cancelled(
            "ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeAsyncInterleaveManyNode(std::move(args),
                                                  /*parameters=*/{});
      }

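      // A rough sketch of the state persisted by `SaveInternal` (and read
      // back by `RestoreInternal`) below: whether the input is exhausted,
      // `next_index_` / `block_count_`, each `WorkerState` (its inputs,
      // buffered outputs, and `is_producing` flag), each `WorkerThreadState`
      // (its in-flight iterator, inputs, and pending output element), the
      // interleave/staging index containers, and whether worker threads were
      // running at checkpoint time.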
      Status SaveInternal(IteratorStateWriter* writer) override {
        // The order of locking is important here to avoid deadlock.
        mutex_lock l(mu_);
        mutex_lock ckpt_l(ckpt_mu_);
        if (input_impl_) {
          TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
        } else {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("input_exhausted"), ""));
        }
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("next_index"), next_index_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("block_count"), block_count_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("workers_size"), workers_.size()));
        for (int i = 0; i < workers_.size(); ++i) {
          TF_RETURN_IF_ERROR(WriteWorkerStateLocked(writer, i));
        }
        for (int i = 0; i < worker_thread_states_.size(); ++i) {
          TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(writer, i));
        }
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("interleave_size"),
                                               interleave_indices_.size()));
        for (int i = 0; i < interleave_indices_.size(); ++i) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat("interleave_indices_", i)),
              interleave_indices_[i]));
        }
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("staging_size"),
                                               staging_indices_.size()));
        for (int i = 0; i < staging_indices_.size(); ++i) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat("staging_indices_", i)),
              staging_indices_[i]));
        }
        if (!worker_threads_.empty()) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("worker_threads_running"), ""));
        }
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        // The order of locking is important here to avoid deadlock.
        mutex_lock l(mu_);
        mutex_lock ckpt_l(ckpt_mu_);
        if (!reader->Contains(full_name("input_exhausted"))) {
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        } else {
          input_impl_.reset();
        }
        int64 temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("next_index"), &temp));
        next_index_ = size_t(temp);
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("block_count"), &temp));
        block_count_ = size_t(temp);

        // Restore WorkerStates.
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name("workers_size"), &temp));
        if (temp != dataset()->num_threads()) {
          return errors::Internal("Expected ", dataset()->num_threads(),
                                  " worker states but found ", temp, ".");
        }
        for (size_t i = 0; i < dataset()->num_threads(); ++i) {
          TF_RETURN_IF_ERROR(ReadWorkerStateLocked(reader, i, ctx));
        }
        for (size_t i = 0; i < dataset()->num_threads(); ++i) {
          TF_RETURN_IF_ERROR(ReadWorkerThreadStateLocked(reader, i, ctx));
        }

        // Restore `interleave_indices_`.
        std::set<int64> all_indices;
        {
          int64 interleave_size;
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("interleave_size"),
                                                &interleave_size));
          interleave_indices_.reserve(interleave_size);
          for (int64 i = 0; i < interleave_size; ++i) {
            int64 temp;
            TF_RETURN_IF_ERROR(reader->ReadScalar(
                full_name(strings::StrCat("interleave_indices_", i)), &temp));
            if (temp >= 0 && all_indices.find(temp) != all_indices.end()) {
              return errors::Internal(
                  "Duplicate entry for ", temp,
                  " found when reading interleave and staging indices.");
            }
            if (temp >= 0) {
              all_indices.insert(temp);
            }
            interleave_indices_.emplace_back(temp);
          }
        }

        // Restore `staging_indices_`.
        {
          int64 staging_size;
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(full_name("staging_size"), &staging_size));
          for (int i = 0; i < staging_size; ++i) {
            int64 temp;
            TF_RETURN_IF_ERROR(reader->ReadScalar(
                full_name(strings::StrCat("staging_indices_", i)), &temp));
            if (all_indices.find(temp) != all_indices.end()) {
              return errors::Internal(
                  "Duplicate entry for ", temp,
                  " found when reading interleave and staging indices.");
            }
            if (temp >= 0) {
              all_indices.insert(temp);
            }
            staging_indices_.emplace_back(temp);
          }
        }

        // Start Worker threads.
        if (reader->Contains(full_name("worker_threads_running"))) {
          worker_threads_.reserve(dataset()->num_threads());
          for (size_t i = 0; i < dataset()->num_threads(); ++i) {
            std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
            worker_threads_.emplace_back(ctx->StartThread(
                strings::StrCat("tf_data_parallel_interleave_worker_", i),
                [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
          }
        }
        return Status::OK();
      }

     private:
      // OutputElem contains the information from a call to GetNext by an output
      // iterator.
      struct OutputElem {
        // The output iterator sets `status` if getting the output element
        // fails.
        Status status;
        // The buffered data element.
        std::vector<Tensor> output;

        explicit OutputElem(const Status& s) : status(s) {}
      };

      // Worker threads operate on their relevant WorkerState structs.
      //
      // WorkerState's fields are all protected by `mu_`.
      struct WorkerState {
        // The arguments to be used to construct an output iterator.
        std::vector<Tensor> input;
        // The buffered output elements.
        std::deque<OutputElem> outputs;
        // Set to true iff the worker thread expects to append more elements to
        // outputs. is_producing can be false despite !outputs.empty().
        // Concretely, all output elements will have been consumed only when:
        // is_producing == false && outputs.empty();
        bool is_producing = false;
        // Condition variable used to coordinate between threads. The worker
        // thread waits on this condition variable when it is either (1) waiting
        // for the main thread to add arguments to `input`, or (2) waiting for
        // the main thread to consume an element of `outputs`. The main thread
        // waits on cond_var if it is waiting for the worker thread to produce
        // an element into `outputs` (this implies sloppy_==false).
        condition_variable cond_var;

        inline bool MayHaveElements() const {
          return is_producing || !outputs.empty();
        }

        // Sets inputs for a worker thread and notifies it to start processing.
        void SetInputs(const Status& s, std::vector<Tensor> input_arguments) {
          if (s.ok()) {
            DCHECK(!MayHaveElements())
                << "Tried to start inputs, despite already producing!";
            input = std::move(input_arguments);
            is_producing = true;
            cond_var.notify_one();
          } else {
            outputs.emplace_back(s);
          }
        }
      };
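      // The main thread and a worker thread hand off work through
      // `WorkerState`: the main thread supplies arguments via `SetInputs()`
      // and drains `outputs` in `GetNextInternal()`, while the worker thread
      // appends to `outputs`, bounded by `buffer_output_elements_`; both
      // sides signal `cond_var` after changing state.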

      // The internal state of a worker thread that is not already captured
      // in its `WorkerState`.
      //
      // This is needed only for checkpointing purposes. We keep this
      // separate from `WorkerState` and guard its fields using a separate
      // lock `ckpt_mu_` so as to not affect the performance of the main
      // pipeline.
      struct WorkerThreadState {
        // The output element that has been produced from the input iterator
        // and is waiting to be added to `WorkerState.outputs`.
        OutputElem output_elem;

        // Whether the input iterator returned an `end_of_sequence`.
        bool end_of_sequence = false;

        // Status returned from `MakeIteratorFromInputElement`.
        Status iterator_creation_status;

        // The arguments to be used to construct `iterator`.
        std::vector<Tensor> input;

        std::unique_ptr<IteratorBase> iterator;

        WorkerThreadState() : output_elem(Status::OK()) {}
      };

      Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (worker_threads_.empty()) {
          worker_threads_.reserve(dataset()->num_threads());
          for (int64 i = 0; i < dataset()->num_threads(); ++i) {
            std::vector<Tensor> args;
            bool end_of_input = false;
            Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
            if (end_of_input) {
              input_impl_.reset();
              return Status::OK();
            }
            workers_[i].SetInputs(s, std::move(args));
            std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
            worker_threads_.push_back(ctx->StartThread(
                strings::StrCat("tf_data_parallel_interleave_worker_", i),
                [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
            if (i < dataset()->cycle_length_) {
              interleave_indices_.push_back(i);
            } else {
              staging_indices_.push_back(i);
            }
          }
          DCHECK(interleave_indices_.size() == dataset()->cycle_length_);
          DCHECK(staging_indices_.size() ==
                 dataset()->prefetch_input_elements_);
        }
        return Status::OK();
      }

      // Produces elements into the worker's output buffers.
      void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
                        const int64 thread_index) {
        // Notes on checkpointing thread-local state, i.e., `WorkerThreadState`:
        //
        // 1. Any local state that may need to be checkpointed should be kept
        //    in `worker_thread_states_[thread_index]`.
        // 2. `WorkerThreadState` should contain state that is needed only for
        //    checkpointing, i.e., if we were to remove checkpointing support,
        //    we could keep that state as local variables in this thread.
        // 3. This thread should only read/write state at `thread_index`
        //    and should not access other thread states.
        // 4. When restoring from a checkpoint, threads are started only after
        //    the restore is complete.
        // 5. Once restored from a checkpoint, the local state is edited only
        //    by this thread. Notes 3 & 4 allow us to make assumptions such as
        //    temporarily caching local state in this thread and using it
        //    outside a lock, e.g. `make_new_iterator`.
        // 6. `ckpt_mu_` should be used judiciously to create *consistent*
        //    checkpoint markers.
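        //
        // For orientation, the CHECKPOINT_MARKER_* comments below mark points
        // where a checkpoint may observe partially-processed state:
        //   A) input tensors received but no iterator built yet,
        //   B) iterator built, or its creation failed,
        //   C) a creation error has been delivered to the client,
        //   D) an element (or error / end_of_sequence) has been read but not
        //      yet handed to the client,
        //   E) the element has been moved into `WorkerState.outputs`.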

        // std::function arguments are copy-constructible, so we pass raw
        // pointers, and then immediately wrap them to ensure correct ownership.
        RecordStart(ctx.get());
        auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
          mutex_lock l(mu_);
          workers_[thread_index].cond_var.notify_all();
          RecordStop(ctx.get());
        });
        bool make_new_iterator;
        {
          tf_shared_lock l(ckpt_mu_);
          // Decide whether a new iterator should be built.
          // 1. If there is an existing iterator, we use it.
          // 2. If there was an error in iterator creation that could not be
          //    notified to the client, we attempt to send it to the client
          //    first.
          make_new_iterator =
              worker_thread_states_[thread_index].iterator == nullptr &&
              worker_thread_states_[thread_index].iterator_creation_status.ok();
        }
        // Even though `make_new_iterator` has cached values from
        // `worker_thread_states_[thread_index]`, which is guarded by
        // `ckpt_mu_`, it is safe to *read* `make_new_iterator` outside of a
        // lock without worrying about concurrent changes to values in
        // `worker_thread_states_[thread_index]`. See the comment at the start
        // of this function for details.
        while (true) {
          // Whether creation of the iterator succeeded.
          Status iterator_creation_status;
          // 1. Build a new iterator or use the existing one.
          if (make_new_iterator) {
            // 1a. Get new input tensors or use the existing ones.
            bool read_new_input;
            {
              tf_shared_lock l(ckpt_mu_);
              // worker_thread_states_[thread_index].input will be non-empty
              // if checkpointing happened at CHECKPOINT_MARKER_A.
              read_new_input =
                  worker_thread_states_[thread_index].input.empty();
            }

            if (read_new_input) {
              mutex_lock l(mu_);
              while (!cancelled_ && !workers_[thread_index].is_producing) {
                RecordStop(ctx.get());
                workers_[thread_index].cond_var.wait(l);
                RecordStart(ctx.get());
              }
              if (cancelled_) return;
              // Copy the input tensors so that we do not need to block on `mu_`
              // when building the iterator.
              // We keep a copy of the input tensors in
              // `WorkerThreadState.input` till the iterator is in use. This is
              // used in `RestoreInternal` to re-build the iterator.
              // TODO(b/78046638): Explore ways to avoid tracking the input
              // tensors.
              tf_shared_lock ckpt_l(ckpt_mu_);
              worker_thread_states_[thread_index].input.swap(
                  workers_[thread_index].input);
              // CHECKPOINT_MARKER_A
              // We have the input tensors but have not built the iterator yet.
            }

            // 1b. Run the user defined function to produce a new iterator.
            {
              tf_shared_lock l(ckpt_mu_);
              worker_thread_states_[thread_index].iterator_creation_status =
                  MakeIteratorFromInputElement(
                      ctx.get(), worker_thread_states_[thread_index].input,
                      thread_index, *instantiated_captured_func_, prefix(),
                      &worker_thread_states_[thread_index].iterator);
              iterator_creation_status =
                  worker_thread_states_[thread_index].iterator_creation_status;
              if (!iterator_creation_status.ok()) {
                worker_thread_states_[thread_index].input.clear();
              }
              // CHECKPOINT_MARKER_B
              // Either an iterator has been successfully built and placed in
              // `worker_thread_states_[thread_index].iterator` or it failed and
              // a non-OK status has been put in
              // `worker_thread_states_[thread_index].iterator_creation_status`.
            }
          } else {
            tf_shared_lock l(ckpt_mu_);
            iterator_creation_status =
                worker_thread_states_[thread_index].iterator_creation_status;
            // Mark that we have used up the restored iterator.
            make_new_iterator = true;
          }
          // 2. Start producing elements or send error state to client if
          //    iterator creation failed.
          if (!iterator_creation_status.ok()) {
            mutex_lock l(mu_);
            // Wait for space in the prefetch queue.
            while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                      dataset()->buffer_output_elements_) {
              RecordStop(ctx.get());
              workers_[thread_index].cond_var.wait(l);
              RecordStart(ctx.get());
            }
            if (cancelled_) return;
            tf_shared_lock ckpt_l(ckpt_mu_);
            workers_[thread_index].outputs.emplace_back(
                iterator_creation_status);
            workers_[thread_index].is_producing = false;
            worker_thread_states_[thread_index].iterator_creation_status =
                Status::OK();
            // CHECKPOINT_MARKER_C
            // Non-OK iterator creation status has been notified to the
            // client.
            workers_[thread_index].cond_var.notify_one();
          } else {
            bool end_of_sequence = false;
            while (!end_of_sequence) {
              // 3.a Produce an element!
              {
                tf_shared_lock ckpt_l(ckpt_mu_);
                if (worker_thread_states_[thread_index]
                        .output_elem.status.ok() &&
                    worker_thread_states_[thread_index]
                        .output_elem.output.empty() &&
                    !worker_thread_states_[thread_index].end_of_sequence) {
                  worker_thread_states_[thread_index].output_elem.status =
                      worker_thread_states_[thread_index].iterator->GetNext(
                          ctx.get(),
                          &worker_thread_states_[thread_index]
                               .output_elem.output,
                          &worker_thread_states_[thread_index].end_of_sequence);
                  end_of_sequence =
                      worker_thread_states_[thread_index].end_of_sequence;
                } else {
                  end_of_sequence =
                      worker_thread_states_[thread_index].end_of_sequence;
                }
                // CHECKPOINT_MARKER_D
                // An element has been read or an error or end_of_sequence has
                // been received from the input iterator and is waiting to be
                // sent to client.
              }

              // 3.b Make it available to the client.
              {
                mutex_lock l(mu_);

                // Wait for space in the prefetch queue.
                while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                          dataset()->buffer_output_elements_) {
                  RecordStop(ctx.get());
                  workers_[thread_index].cond_var.wait(l);
                  RecordStart(ctx.get());
                }
                if (cancelled_) return;

                tf_shared_lock ckpt_l(ckpt_mu_);
                workers_[thread_index].is_producing = !end_of_sequence;

                // Output the element.

                // Move the temporary state in WorkerThreadState to WorkerState
                // and mark it as used.
                if (end_of_sequence) {
                  worker_thread_states_[thread_index].iterator.reset();
                  worker_thread_states_[thread_index].input.clear();
                  worker_thread_states_[thread_index].end_of_sequence = false;
                } else {
                  workers_[thread_index].outputs.emplace_back(
                      worker_thread_states_[thread_index].output_elem.status);
                  workers_[thread_index].outputs.back().output.swap(
                      worker_thread_states_[thread_index].output_elem.output);
                }
                worker_thread_states_[thread_index].output_elem.status =
                    Status::OK();
                if (dataset()->sloppy_) {
                  sloppy_cond_var_.notify_one();
                } else {
                  workers_[thread_index].cond_var.notify_one();
                }
                // CHECKPOINT_MARKER_E
                // Output element or iterator status has been sent to the
                // client.
              }
            }
          }
        }
      }

      Status WriteWorkerStateLocked(IteratorStateWriter* writer, int index)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        string prefix = strings::StrCat("worker_", index);
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(prefix, "_input_size")),
            workers_[index].input.size()));
        for (int i = 0; i < workers_[index].input.size(); ++i) {
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              full_name(strings::StrCat(prefix, "_input_", i)),
              workers_[index].input[i]));
        }
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(prefix, "_outputs_size")),
            workers_[index].outputs.size()));
        for (int i = 0; i < workers_[index].outputs.size(); ++i) {
          TF_RETURN_IF_ERROR(WriteOutputElemLocked(
              writer, workers_[index].outputs[i],
              full_name(strings::StrCat(prefix, "_outputs_", i))));
        }
        if (workers_[index].is_producing) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat(prefix, "_is_producing")), ""));
        }
        return Status::OK();
      }

      Status ReadWorkerStateLocked(IteratorStateReader* reader, int index,
                                   IteratorContext* ctx)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        string worker_prefix = strings::StrCat("worker_", index);
        // Restore inputs.
        int64 input_size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(strings::StrCat(worker_prefix, "_input_size")),
            &input_size));
        workers_[index].input.reserve(input_size);
        for (int i = 0; i < input_size; ++i) {
          workers_[index].input.emplace_back();
          TF_RETURN_IF_ERROR(reader->ReadTensor(
              full_name(strings::StrCat(worker_prefix, "_input_", i)),
              &workers_[index].input.back()));
        }
        int64 outputs_size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(strings::StrCat(worker_prefix, "_outputs_size")),
            &outputs_size));
        for (int i = 0; i < outputs_size; ++i) {
          workers_[index].outputs.emplace_back(Status::OK());
          TF_RETURN_IF_ERROR(ReadOutputElemLocked(
              reader, &workers_[index].outputs.back(),
              full_name(strings::StrCat(worker_prefix, "_outputs_", i))));
        }
        if (reader->Contains(
                full_name(strings::StrCat(worker_prefix, "_is_producing")))) {
          workers_[index].is_producing = true;
        } else {
          workers_[index].is_producing = false;
        }
        return Status::OK();
      }

      Status WriteWorkerThreadStateLocked(IteratorStateWriter* writer,
                                          int index)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        string prefix = strings::StrCat("worker_thread_", index);
        if (worker_thread_states_[index].iterator != nullptr) {
          TF_RETURN_IF_ERROR(
              SaveInput(writer, worker_thread_states_[index].iterator));
        } else {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat(prefix, "_iterator_exhausted")), ""));
        }
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(prefix, "_input_size")),
            worker_thread_states_[index].input.size()));
        for (int i = 0; i < worker_thread_states_[index].input.size(); ++i) {
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              full_name(strings::StrCat(prefix, "_input_", i)),
              worker_thread_states_[index].input[i]));
        }
        TF_RETURN_IF_ERROR(WriteStatusLocked(
            writer, strings::StrCat(prefix, "_iterator_creation_status"),
            worker_thread_states_[index].iterator_creation_status));
        TF_RETURN_IF_ERROR(WriteOutputElemLocked(
            writer, worker_thread_states_[index].output_elem,
            full_name(strings::StrCat(prefix, "_output"))));
        if (worker_thread_states_[index].end_of_sequence) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat(prefix, "_end_of_sequence")), ""));
        }
        return Status::OK();
      }

      Status ReadWorkerThreadStateLocked(IteratorStateReader* reader, int index,
                                         IteratorContext* ctx)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        string worker_prefix = strings::StrCat("worker_thread_", index);
        // Restore inputs.
        int64 input_size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(strings::StrCat(worker_prefix, "_input_size")),
            &input_size));
        worker_thread_states_[index].input.reserve(input_size);
        for (int i = 0; i < input_size; ++i) {
          worker_thread_states_[index].input.emplace_back();
          TF_RETURN_IF_ERROR(reader->ReadTensor(
              full_name(strings::StrCat(worker_prefix, "_input_", i)),
              &worker_thread_states_[index].input.back()));
        }
        // Restore iterator.
        if (reader->Contains(full_name(
                strings::StrCat(worker_prefix, "_iterator_exhausted")))) {
          worker_thread_states_[index].iterator.reset();
        } else {
          std::unique_ptr<IteratorBase> iterator;
          Status s = MakeIteratorFromInputElement(
              ctx, worker_thread_states_[index].input, index,
              *instantiated_captured_func_, prefix(), &iterator);
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
          worker_thread_states_[index].iterator.swap(iterator);
        }
        TF_RETURN_IF_ERROR(ReadStatusLocked(
            reader, strings::StrCat(worker_prefix, "_iterator_creation_status"),
            &worker_thread_states_[index].iterator_creation_status));
        TF_RETURN_IF_ERROR(ReadOutputElemLocked(
            reader, &worker_thread_states_[index].output_elem,
            full_name(strings::StrCat(worker_prefix, "_output"))));
        if (reader->Contains(full_name(
                strings::StrCat(worker_prefix, "_end_of_sequence")))) {
          worker_thread_states_[index].end_of_sequence = true;
        } else {
          worker_thread_states_[index].end_of_sequence = false;
        }
        return Status::OK();
      }

      Status WriteOutputElemLocked(IteratorStateWriter* writer,
                                   const OutputElem& output_elem,
                                   const string& prefix)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        TF_RETURN_IF_ERROR(WriteStatusLocked(
            writer, strings::StrCat(prefix, "_status"), output_elem.status));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(strings::StrCat(prefix, "_output_size"),
                                output_elem.output.size()));
        for (int i = 0; i < output_elem.output.size(); ++i) {
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              strings::StrCat(prefix, "_output_", i), output_elem.output[i]));
        }
        return Status::OK();
      }

      Status ReadOutputElemLocked(IteratorStateReader* reader,
                                  OutputElem* output_elem, const string& prefix)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        TF_RETURN_IF_ERROR(ReadStatusLocked(
            reader, strings::StrCat(prefix, "_status"), &output_elem->status));
        int64 output_size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            strings::StrCat(prefix, "_output_size"), &output_size));
        output_elem->output.reserve(output_size);
        for (int i = 0; i < output_size; ++i) {
          output_elem->output.emplace_back();
          TF_RETURN_IF_ERROR(
              reader->ReadTensor(strings::StrCat(prefix, "_output_", i),
                                 &output_elem->output.back()));
        }
        return Status::OK();
      }

      Status WriteStatusLocked(IteratorStateWriter* writer,
                               const string& prefix, const Status& status)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")),
                                static_cast<int64>(status.code())));
        if (!status.ok()) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(strings::StrCat(prefix, "_msg")),
                                  status.error_message()));
        }
        return Status::OK();
      }

      Status ReadStatusLocked(IteratorStateReader* reader, const string& prefix,
                              Status* status)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        int64 code_int;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(strings::StrCat(prefix, "_code")), &code_int));
        error::Code code = static_cast<error::Code>(code_int);

        if (code != error::Code::OK) {
          string error_message;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              full_name(strings::StrCat(prefix, "_msg")), &error_message));
          *status = Status(code, error_message);
        } else {
          *status = Status::OK();
        }
        return Status::OK();
      }

      // Mutex & condition variable to guard mutable iterator internals and
      // coordinate among worker threads and client thread[s].
      mutex mu_ ACQUIRED_BEFORE(ckpt_mu_);
      // The main thread waits on this condition variable if running in sloppy
      // mode and no values are available.
      condition_variable sloppy_cond_var_;
      // Mutex used to wait for a consistent state while checkpointing.
      // Only Save and Restore require an exclusive lock on this mutex. In
      // other scenarios we just acquire a shared lock so the pipeline's
      // performance should not be affected in the absence of checkpointing.
      // A thread must not wait on any condition variable while holding
      // `ckpt_mu_` in either shared or exclusive modes.
      mutex ckpt_mu_;
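      // For example, worker threads take `tf_shared_lock(ckpt_mu_)` around
      // updates to their `WorkerThreadState`, whereas `SaveInternal` /
      // `RestoreInternal` acquire `mu_` and then `ckpt_mu_` exclusively to
      // observe a consistent snapshot.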

      // The iterator producing the elements that are converted to datasets by
      // `dataset()->captured_func_` and then interleaved together.
      // `input_impl_` is reset when we have exhausted its input.
      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);

      std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;

      // The WorkerState structs the worker threads operate on.
      // `workers_` elements are in at most one of `interleave_indices_` and
      // `staging_indices_`.
      std::vector<WorkerState> workers_ GUARDED_BY(mu_);

      // Stores the temporary state of WorkerThreads which is not stored in
      // WorkerState. This is used for checkpointing purposes only.
      std::vector<WorkerThreadState> worker_thread_states_ GUARDED_BY(ckpt_mu_);

      // Indices in `workers_` of iterators to interleave.
      std::vector<int64> interleave_indices_ GUARDED_BY(mu_);
      // Indices in `workers_` of prefetched iterators.
      std::deque<int64> staging_indices_ GUARDED_BY(mu_);

      // The index into `interleave_indices_` of the next iterator to pull an
      // element from.
      size_t next_index_ GUARDED_BY(mu_) = 0;
      // The number of items produced so far within the current block.
      size_t block_count_ GUARDED_BY(mu_) = 0;
      // Flag to instruct the worker threads to exit.
      bool cancelled_ GUARDED_BY(mu_) = false;
      // The worker threads. This must be last to ensure the
      // threads have exited before any other members are deallocated.
      // TODO(b/65178177): Avoid allocating additional threads.
      std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
    };

    const DatasetBase* const input_;
    const NameAttrList interleave_func_;
    const std::unique_ptr<CapturedFunction> captured_func_;
    const int64 cycle_length_;
    const int64 block_length_;
    const bool sloppy_;
    const int64 buffer_output_elements_;
    const int64 prefetch_input_elements_;
    const DataTypeVector output_types_;
    const std::vector<PartialTensorShape> output_shapes_;
  };

  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
  NameAttrList interleave_func_;
};

REGISTER_KERNEL_BUILDER(
    Name("ExperimentalParallelInterleaveDataset").Device(DEVICE_CPU),
    ParallelInterleaveDatasetOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow