      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #define EIGEN_USE_THREADS
     16 
     17 #include <atomic>
     18 #include <utility>
     19 
     20 #include "tensorflow/core/common_runtime/function.h"
     21 #include "tensorflow/core/framework/allocator.h"
     22 #include "tensorflow/core/framework/dataset.h"
     23 #include "tensorflow/core/framework/partial_tensor_shape.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/kernels/data/captured_function.h"
     26 #include "tensorflow/core/kernels/inplace_ops_functor.h"
     27 #include "tensorflow/core/lib/core/blocking_counter.h"
     28 #include "tensorflow/core/lib/core/errors.h"
     29 #include "tensorflow/core/lib/gtl/cleanup.h"
     30 #include "tensorflow/core/lib/random/random.h"
     31 #include "tensorflow/core/lib/strings/strcat.h"
     32 #include "tensorflow/core/platform/cpu_info.h"
     33 #include "tensorflow/core/platform/numa.h"
     34 #include "tensorflow/core/platform/tracing.h"
     35 
     36 namespace tensorflow {
     37 namespace data {
     38 namespace {
     39 
      40 // kWindowSize controls the number of batch outputs each NumaWorkerBlock may
      41 // be processing at a time. It is deliberately a constant (rather than user
      42 // configurable) to enable future performance optimizations in the
      43 // implementation.
     44 const int64 kWindowSize = 10;
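// For intuition: each NumaWorkerBlock below keeps a circular buffer of
// kWindowSize batches, so up to 10 batches per NUMA domain may be buffered or
// in flight at any given time.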
     45 
     46 // Define a helper for more consistent logging.
     47 #define WORKER_VLOG(verbose_level)                                           \
     48   VLOG(verbose_level) << "WorkerThread (" << numa_node << ", " << thread_num \
     49                       << "): "
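// For example, inside WorkerThread(ctx, /*numa_node=*/0, /*thread_num=*/2),
//   WORKER_VLOG(3) << "started.";
// expands (roughly) to
//   VLOG(3) << "WorkerThread (" << numa_node << ", " << thread_num << "): "
//           << "started.";
// which logs "WorkerThread (0, 2): started." at verbosity level 3.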
     50 
     51 // See documentation in ../ops/dataset_ops.cc for a high-level
     52 // description of the following op.
     53 
     54 class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     55  public:
     56   explicit NumaMapAndBatchDatasetOp(OpKernelConstruction* ctx)
     57       : UnaryDatasetOpKernel(ctx) {
     58     OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
     59     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
     60     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
     61     // TODO(saeta): Implement support for preserve_cardinality logic.
     62     OP_REQUIRES_OK(
     63         ctx, ctx->GetAttr("preserve_cardinality", &preserve_cardinality_));
     64   }
     65 
     66  protected:
     67   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
     68                    DatasetBase** output) override {
     69     int64 batch_size;
     70     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size));
     71     OP_REQUIRES(
     72         ctx, batch_size > 0,
     73         errors::InvalidArgument("batch_size must be greater than zero."));
     74 
     75     int64 num_parallel_calls;
     76     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
     77                                             &num_parallel_calls));
     78     OP_REQUIRES(
     79         ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutoTune,
     80         errors::InvalidArgument(
     81             "num_parallel_calls must be greater than zero."));
     82 
     83     bool drop_remainder;
     84     OP_REQUIRES_OK(ctx,
     85                    ParseScalarArgument(ctx, "drop_remainder", &drop_remainder));
     86 
     87     std::unique_ptr<CapturedFunction> captured_func;
     88     OP_REQUIRES_OK(
     89         ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
     90                                       /* use_inter_op_parallelism = */ false,
     91                                       &captured_func));
     92 
     93     *output = new Dataset(ctx, input, batch_size, num_parallel_calls,
     94                           drop_remainder, output_types_, output_shapes_, func_,
     95                           std::move(captured_func));
     96   }
     97 
     98  private:
     99   class Dataset : public DatasetBase {
    100    public:
    101     Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size,
    102             int64 num_parallel_calls, bool drop_remainder,
    103             const DataTypeVector& output_types,
    104             const std::vector<PartialTensorShape>& output_shapes,
    105             const NameAttrList& func,
    106             std::unique_ptr<CapturedFunction> captured_func)
    107         : DatasetBase(DatasetContext(ctx)),
    108           input_(input),
    109           batch_size_(batch_size),
    110           num_parallel_calls_(num_parallel_calls),
    111           drop_remainder_(drop_remainder),
    112           output_types_(output_types),
    113           output_shapes_(output_shapes),
    114           func_(func),
    115           captured_func_(std::move(captured_func)) {
    116       input_->Ref();
    117     }
    118 
    119     ~Dataset() override { input_->Unref(); }
    120 
    121     std::unique_ptr<IteratorBase> MakeIteratorInternal(
    122         const string& prefix) const override {
    123       return absl::make_unique<Iterator>(
    124           Iterator::Params{this, strings::StrCat(prefix, "::NumaMapAndBatch")});
    125     }
    126 
    127     const DataTypeVector& output_dtypes() const override {
    128       return output_types_;
    129     }
    130 
    131     const std::vector<PartialTensorShape>& output_shapes() const override {
    132       return output_shapes_;
    133     }
    134 
    135     string DebugString() const override {
    136       return "NumaMapAndBatchDatasetOp::Dataset";
    137     }
    138 
     139     // TODO(b/120482302): Note that this is inaccurate until
     140     // NumaMapAndBatchMapDataset is modified to preserve cardinality.
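    // For the formula below (illustrative values): with batch_size_ = 3 and
    // an input cardinality of n = 10, drop_remainder_ = false yields
    // 10 / 3 + 1 = 4 batches (the final batch holds one element), while
    // drop_remainder_ = true yields 10 / 3 = 3 batches.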
    141     int64 Cardinality() const override {
    142       int64 n = input_->Cardinality();
    143       if (n == kInfiniteCardinality || n == kUnknownCardinality) {
    144         return n;
    145       }
    146       return n / batch_size_ +
    147              (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
    148     }
    149 
    150    protected:
    151     Status AsGraphDefInternal(SerializationContext* ctx,
    152                               DatasetGraphDefBuilder* b,
    153                               Node** output) const override {
    154       TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
    155       Node* input_graph_node = nullptr;
    156       TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    157       Node* batch_size_node;
    158       TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node));
    159       Node* num_parallel_calls_node;
    160       TF_RETURN_IF_ERROR(
    161           b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
    162       Node* drop_remainder_node;
    163       TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
    164 
    165       DataTypeVector other_arguments_types;
    166       other_arguments_types.reserve(captured_func_->captured_inputs().size());
    167       std::vector<Node*> other_arguments;
    168       other_arguments.reserve(captured_func_->captured_inputs().size());
    169       for (const Tensor& t : captured_func_->captured_inputs()) {
    170         Node* node;
    171         DatasetBase* input;
    172         Status s = GetDatasetFromVariantTensor(t, &input);
    173         if (s.ok()) {
    174           TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
    175         } else {
    176           TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
    177         }
    178         other_arguments.emplace_back(node);
    179         other_arguments_types.emplace_back(t.dtype());
    180       }
    181       AttrValue f;
    182       b->BuildAttrValue(func_, &f);
    183       AttrValue other_arguments_types_attr;
    184       b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
    185 
    186       TF_RETURN_IF_ERROR(b->AddDataset(
    187           this,
    188           {std::make_pair(0, input_graph_node),
    189            std::make_pair(2, batch_size_node),
    190            std::make_pair(3, num_parallel_calls_node),
    191            std::make_pair(4, drop_remainder_node)},  // Single tensor inputs.
    192           {std::make_pair(1, other_arguments)},      // Tensor list inputs.
    193           {std::make_pair("f", f),
    194            std::make_pair("Targuments", other_arguments_types_attr)},  // Attrs
    195           output));
    196       return Status::OK();
    197     }
    198 
    199    private:
    200     class Iterator : public DatasetIterator<Dataset> {
    201      public:
    202       explicit Iterator(const Params& params)
    203           : DatasetIterator<Dataset>(params),
    204             mu_(std::make_shared<mutex>()),
    205             autotune_cond_var_(std::make_shared<condition_variable>()),
    206             num_parallel_calls_(std::make_shared<model::SharedState>(
    207                 params.dataset->num_parallel_calls_, mu_, autotune_cond_var_)) {
    208       }
    209 
    210       ~Iterator() override {
    211         mutex_lock l(*mu_);
    212         cancelled_ = true;
    213         VLOG(3) << "NumaMapAndBatchIterator::~Iterator: cancelling operations.";
    214         for (size_t i = 0; i < workers_.size(); ++i) {
    215           workers_[i]->manager.Cancel();
    216         }
    217         VLOG(3) << "NumaMapAndBatchIterator::~Iterator: waiting for threads to "
    218                    "shut down.";
    219       }
    220 
    221       Status Initialize(IteratorContext* ctx) override {
    222         mutex_lock l(*mu_);
    223         if (num_parallel_calls_->value == model::kAutoTune) {
    224           num_parallel_calls_->value = ctx->runner_threadpool_size();
    225         }
    226         TF_RETURN_IF_ERROR(
    227             dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
    228         TF_RETURN_IF_ERROR(dataset()->captured_func_->Instantiate(
    229             ctx, &instantiated_captured_func_));
    230         return Status::OK();
    231       }
    232 
    233       Status GetNextInternal(IteratorContext* ctx,
    234                              std::vector<Tensor>* out_tensors,
    235                              bool* end_of_sequence) override {
    236         auto cleanup = gtl::MakeCleanup(
    237             [] { VLOG(3) << "GetNextInternal call returning."; });
    238         NumaWorkerBlock* worker = nullptr;
    239         {
    240           mutex_lock l(*mu_);
    241           VLOG(3) << "GetNextInternal call; current block: " << cur_block_;
    242           if (global_end_of_input_) {
    243             *end_of_sequence = true;
    244             return Status::OK();
    245           }
    246           TF_RETURN_IF_ERROR(EnsureBackgroundThreadsStarted(ctx));
    247           worker = workers_[cur_block_].get();
    248           cur_block_ = (cur_block_ + 1) % workers_.size();
    249         }
    250         bool global_end_of_input_local = false;
    251         Status s = worker->manager.GetBatch(ctx, dataset()->drop_remainder_,
    252                                             &global_end_of_input_local,
    253                                             out_tensors, end_of_sequence);
    254         if (global_end_of_input_local) {
    255           mutex_lock l(*mu_);
    256           global_end_of_input_ = global_end_of_input_local;
    257         }
    258         return s;
    259       }
    260 
    261      protected:
    262       std::shared_ptr<model::Node> CreateNode(
    263           IteratorContext* ctx, model::Node::Args args) const override {
    264         return model::MakeAsyncKnownRatioNode(
    265             std::move(args), dataset()->batch_size_,
    266             {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
    267                                   /*max=*/ctx->runner_threadpool_size())});
    268       }
    269 
    270       Status SaveInternal(IteratorStateWriter* writer) override {
    271         mutex_lock l(*mu_);
    272         for (size_t i = 0; i < workers_.size(); ++i) {
    273           if (!workers_[i]->manager.Quiesce()) {
    274             return errors::Cancelled(
    275                 "The iterator was deleted before it could reach a "
    276                 "checkpointable state.");
    277           }
    278         }
    279 
    280         TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
    281         TF_RETURN_IF_ERROR(
    282             writer->WriteScalar(full_name("num_workers"), workers_.size()));
    283 
    284         for (size_t i = 0; i < workers_.size(); ++i) {
    285           size_t index = (cur_block_ + i) % workers_.size();
    286           TF_RETURN_IF_ERROR(workers_[index]->manager.Save(writer, this, i));
    287         }
    288         return Status::OK();
    289       }
    290 
    291       Status RestoreInternal(IteratorContext* ctx,
    292                              IteratorStateReader* reader) override {
    293         mutex_lock l(*mu_);
    294         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
    295         int64 num_workers = -1;
    296         TF_RETURN_IF_ERROR(
    297             reader->ReadScalar(full_name("num_workers"), &num_workers));
    298         // Note: num_workers can be 0 if the iterator wasn't started when
    299         // first checkpointed.
    300         if (num_workers < 0) {
    301           return errors::DataLoss(
    302               "When restoring from checkpoint, we encountered a data "
    303               "consistency error: num_workers has an invalid value: ",
    304               num_workers);
    305         }
    306         if (port::NUMAEnabled()) {
    307           int actual_numa_domains = port::NUMANumNodes();
    308           if (actual_numa_domains != num_workers && num_workers > 0) {
    309             LOG(WARNING) << "# NUMA domains mismatch when restoring from "
    310                             "checkpoint: checkpoint has "
    311                          << num_workers
    312                          << " NUMA domains, while this host has: "
    313                          << actual_numa_domains << " NUMA domains.";
    314           }
    315         }
    316         if (num_workers > 1 && !port::NUMAEnabled()) {
    317           LOG(WARNING) << "NUMA is not enabled for this process, but restoring "
    318                           "a checkpoint that assumes "
    319                        << num_workers << " NUMA domains.";
    320         }
    321         workers_.resize(num_workers);
    322         for (size_t i = 0; i < num_workers; ++i) {
    323           workers_[i] = absl::make_unique<NumaWorkerBlock>(this);
    324           TF_RETURN_IF_ERROR(
    325               workers_[i]->manager.Restore(ctx, reader, this, i));
    326         }
    327         cur_block_ = 0;
    328         return Status::OK();
    329       }
    330 
    331      private:
    332       // NumaBlockManager manages all the state for a set of threads pinned to a
    333       // single NUMA domain.
    334       //
    335       // The methods can be divided into 3 categories based on who should call
    336       // them:
    337       //
    338       //  (1) RunnerThread: WaitForInputSpace, PushInputs, SetEndOfInput.
    339       //  (2) WorkerThread: RetrieveInput, GetBatchTensors.
    340       //      RecordBatchEntryComplete
    341       //  (3) Client threads: GetBatch, Cancel, Save, Restore.
    342       //
    343       // Internally, we manage state in a circular buffer of size `kWindowSize`.
     344       // There are 3 pointers into the circular buffer, which must maintain the
     345       // following order: (1) next_input_batch_ (corresponding to the next input
     346       // batch to be pulled from the input iterator), (2) next_input_
     347       // (corresponding to the batch the WorkerThreads should pull from for
     348       // their next inputs), and (3) next_output_ (corresponding to the next
     349       // value to be consumed by the output iterator).
    350       //
    351       // Methods return errors::Cancelled if the iteration is cancelled before
    352       // completing.
    353       //
    354       // NumaBlockManager is thread safe.
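      //
      // Illustrative snapshot of the window (assume kWindowSize = 10): if
      // next_output_ = 2, next_input_ = 5, and next_input_batch_ = 7, then
      // slots [2, 5) hold batches whose map calls have all been handed out
      // (or are already complete), slots [5, 7] hold batches whose inputs
      // are filled but not yet fully handed to workers, and the remaining
      // slots are empty and available to the RunnerThread.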
    355       class NumaBlockManager {
    356        public:
    357         explicit NumaBlockManager(Iterator* itr) : itr_(itr) {}
    358 
    359         // WaitForInputSpace blocks until there is space in the circular buffer
    360         // to begin processing a new batch of elements.
    361         //
    362         // Returns true when there is space, false if the Iterator is cancelled.
    363         bool WaitForInputSpace(IteratorContext* ctx) {
    364           mutex_lock l(mu_);
    365 
    366           size_t next = (next_input_batch_ + 1) % kWindowSize;
    367           DCHECK(next < kWindowSize) << next;
    368 
    369           // Wait for space in the circular buffer.
    370           while (!cancelled_ && batches_[next].state != BatchState::kEmpty) {
    371             VLOG(3) << "Waiting for input space; next: " << next
    372                     << ", next_output_: " << next_output_
    373                     << ", next_input_batch_: " << next_input_batch_;
    374             itr_->RecordStop(ctx);
    375             runner_cond_var_.wait(l);
    376             itr_->RecordStart(ctx);
    377           }
    378           if (cancelled_) {
    379             VLOG(3) << "WaitForInputSpace cancelled.";
    380             return false;
    381           }
    382 
    383           DCHECK(batches_[next].state == BatchState::kEmpty);
    384 
    385           next_input_batch_ = next;
    386           return true;
    387         }
    388 
    389         // PushInputs sets the inputs for the next batch as retrieved from the
    390         // input iterator.
    391         void PushInputs(const Status& status,
    392                         std::vector<std::vector<Tensor>> inputs) {
    393           mutex_lock l(mu_);
    394 
    395           DCHECK(next_input_ < kWindowSize) << next_input_;
    396           DCHECK(batches_[next_input_batch_].state == BatchState::kEmpty);
    397           DCHECK(batches_[next_input_batch_].next_input_to_process == 0)
    398               << batches_[next_input_batch_].next_input_to_process;
    399           DCHECK(batches_[next_input_batch_].status.ok())
    400               << batches_[next_input_batch_].status;
    401 
    402           batches_[next_input_batch_].inputs.swap(inputs);
    403           batches_[next_input_batch_].state = BatchState::kInputsFilled;
    404           batches_[next_input_batch_].status.Update(status);
    405           if (batches_[next_input_batch_].status.ok()) {
    406             worker_cond_var_.notify_all();
    407           } else {
    408             client_cond_var_.notify_all();
    409             batches_[next_input_batch_].error_index = 0;
    410           }
    411         }
    412 
     413         // SetEndOfInput records that we have reached the end of the input
     414         // iterator, and that we should return end_of_sequence = true once we
     415         // have exhausted all buffered batches.
    416         void SetEndOfInput() {
    417           mutex_lock l(mu_);
    418           reached_eof_ = true;
    419           worker_cond_var_.notify_all();
    420           client_cond_var_.notify_all();
    421         }
    422 
    423         // RetrieveInput gets the next input tuple to be mapped by a worker
    424         // thread.
    425         //
    426         // Returns true if an input was retrieved, false if the iterator has
    427         // been cancelled.
    428         bool RetrieveInput(IteratorContext* ctx, std::vector<Tensor>* input,
    429                            uint64* index, size_t* sequence_number) {
    430           mutex_lock l(mu_);
    431 
    432           // Wait for inputs to be ready.
    433           while (!cancelled_ &&
    434                  batches_[next_input_].state != BatchState::kInputsFilled) {
    435             itr_->RecordStop(ctx);
    436             worker_cond_var_.wait(l);
    437             itr_->RecordStart(ctx);
    438           }
    439 
    440           if (cancelled_) {
    441             return false;
    442           }
    443 
    444           DCHECK(batches_[next_input_].next_input_to_process <
    445                  batches_[next_input_].inputs.size())
    446               << "next_input_: " << next_input_ << ", next_input_to_process: "
    447               << batches_[next_input_].next_input_to_process
    448               << ", inputs.size(): " << batches_[next_input_].inputs.size()
    449               << ", state: " << static_cast<int32>(batches_[next_input_].state)
    450               << ", this: " << this;
    451           *index = batches_[next_input_].next_input_to_process;
    452           *sequence_number = next_input_;
    453           input->swap(batches_[next_input_]
    454                           .inputs[batches_[next_input_].next_input_to_process]);
    455           // Increment pointers.
    456           batches_[next_input_].next_input_to_process++;
    457 
    458           if (batches_[next_input_].next_input_to_process ==
    459               batches_[next_input_].inputs.size()) {
    460             batches_[next_input_].state = BatchState::kAllMapsStarted;
    461             next_input_ = (next_input_ + 1) % kWindowSize;
    462           }
    463           return true;
    464         }
    465 
    466         // GetBatchTensors returns a pointer to the output batch tensors for the
    467         // worker thread to copy into.
    468         //
     469         // allocate_output is a function that takes a batch size and a pointer
     470         // to the output tuple of Tensors, and allocates those Tensors. It is
     471         // called at most once per output batch.
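        //
        // A sketch of the worker-side usage pattern (mirroring WorkerThread
        // below):
        //
        //   std::vector<Tensor>* output = manager.GetBatchTensors(
        //       sequence_number,
        //       [&](size_t batch_size, std::vector<Tensor>* out) {
        //         AllocateOutput(ctx, batch_size, return_values, out);
        //       });
        //   // ... copy this element's tensors into *output at `index` ...
        //   manager.RecordBatchEntryComplete(sequence_number, index, status);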
    472         std::vector<Tensor>* GetBatchTensors(
    473             size_t sequence_number,
    474             std::function<void(size_t, std::vector<Tensor>*)> allocate_output) {
    475           mutex_lock l(mu_);
    476           DCHECK(sequence_number < kWindowSize) << sequence_number;
    477           DCHECK(batches_[sequence_number].state == BatchState::kInputsFilled ||
    478                  batches_[sequence_number].state == BatchState::kAllMapsStarted)
    479               << sequence_number;
    480 
    481           if (batches_[sequence_number].outputs.empty()) {
    482             allocate_output(batches_[sequence_number].inputs.size(),
    483                             &batches_[sequence_number].outputs);
    484           }
    485           return &batches_[sequence_number].outputs;
    486         }
    487 
     488         // RecordBatchEntryComplete records that an element of the batch has
     489         // finished being copied into the output tensors.
    490         void RecordBatchEntryComplete(size_t sequence_number, uint64 index,
    491                                       Status s) {
    492           mutex_lock l(mu_);
    493           DCHECK(sequence_number < kWindowSize) << sequence_number;
    494           DCHECK(batches_[sequence_number].state == BatchState::kInputsFilled ||
    495                  batches_[sequence_number].state == BatchState::kAllMapsStarted)
    496               << sequence_number;
    497 
    498           batches_[sequence_number].num_outputs_complete++;
    499           if (!s.ok() && batches_[sequence_number].error_index > index) {
    500             batches_[sequence_number].status = s;
    501             batches_[sequence_number].error_index = index;
    502           }
    503 
    504           if (batches_[sequence_number].num_outputs_complete ==
    505               batches_[sequence_number].inputs.size()) {
    506             DCHECK(batches_[sequence_number].state ==
    507                    BatchState::kAllMapsStarted);
    508             batches_[sequence_number].state = BatchState::kOutputsComplete;
    509             batches_[sequence_number].inputs.clear();  // Eagerly save memory.
    510             batches_[sequence_number].inputs.shrink_to_fit();
    511             client_cond_var_.notify_all();
    512           }
    513         }
    514 
    515         // GetBatch retrieves the next output batch tensors.
    516         Status GetBatch(IteratorContext* ctx, bool drop_remainder,
    517                         bool* global_eof, std::vector<Tensor>* out_tensor,
    518                         bool* end_of_sequence) {
    519           mutex_lock l(mu_);
    520           // Wait until one of 3 conditions occurs:
    521           //  (1) we're cancelled.
    522           //  (2) the state becomes kOutputsComplete
    523           //  (3) state is empty && reached_eof.
    524           while (!cancelled_ &&
    525                  batches_[next_output_].state != BatchState::kOutputsComplete &&
    526                  !(reached_eof_ &&
    527                    batches_[next_output_].state == BatchState::kEmpty)) {
    528             VLOG(3) << "Waiting in GetBatch.";
    529             itr_->RecordStop(ctx);
    530             client_cond_var_.wait(l);
    531             itr_->RecordStart(ctx);
    532           }
    533 
    534           if (cancelled_) {
    535             return errors::Cancelled(
    536                 "Cancelled in NumaMapAndBatch::GetNext call.");
    537           }
    538 
    539           if (reached_eof_ &&
    540               batches_[next_output_].state == BatchState::kEmpty) {
    541             VLOG(4) << "GetBatch returning end of sequence.";
    542             *end_of_sequence = true;
    543             *global_eof = true;
    544             return Status::OK();
    545           }
    546 
    547           VLOG(3) << "Returning output index: " << next_output_
    548                   << ", this: " << this;
    549 
    550           *end_of_sequence = false;
    551           Status s = batches_[next_output_].status;
    552           if (s.ok()) {
    553             out_tensor->swap(batches_[next_output_].outputs);
    554           }
    555           // Handle early termination.
    556           if (errors::IsOutOfRange(s)) {
    557             *global_eof = true;
    558             s = Status::OK();
    559             if (drop_remainder || batches_[next_output_].error_index == 0) {
    560               *end_of_sequence = true;
    561             } else {
    562               std::vector<Tensor> true_outputs;
    563               for (size_t i = 0; i < batches_[next_output_].outputs.size();
    564                    ++i) {
    565                 TensorShape component_shape(
    566                     batches_[next_output_].outputs[i].shape());
    567                 component_shape.set_dim(0, batches_[next_output_].error_index);
    568                 AllocatorAttributes attr;
    569                 attr.set_gpu_compatible(true);
    570                 true_outputs.emplace_back(
    571                     ctx->allocator(attr),
    572                     batches_[next_output_].outputs[i].dtype(), component_shape);
    573                 TF_RETURN_IF_ERROR(CopyPartialBatch(
    574                     &true_outputs.back(), batches_[next_output_].outputs[i],
    575                     batches_[next_output_].error_index));
    576               }
    577               out_tensor->swap(true_outputs);
    578             }
    579           }
    580 
    581           batches_[next_output_].Reset();
    582           next_output_ = (next_output_ + 1) % kWindowSize;
    583           runner_cond_var_.notify_all();
    584 
    585           return s;
    586         }
    587 
    588         void Cancel() {
    589           mutex_lock l(mu_);
    590           VLOG(3) << "Cancelling NUMA block.";
    591           cancelled_ = true;
    592           runner_cond_var_.notify_all();
    593           worker_cond_var_.notify_all();
    594           client_cond_var_.notify_all();
    595         }
    596 
    597         // Waits until all the worker threads have completed their work and all
    598         // internal state has reached a "safe-point" where we can safely
    599         // checkpoint.
    600         //
    601         // Returns true if completed successfully, false if cancelled while
    602         // waiting.
    603         bool Quiesce() {
    604           mutex_lock l(mu_);
    605           VLOG(3) << "Waiting until the operations have quiesced.";
    606           while (!cancelled_ && !AllMapOperationsFinished()) {
    607             client_cond_var_.wait(l);
    608           }
    609           if (cancelled_) {
    610             return false;
    611           }
    612           return true;
    613         }
    614 
    615         Status Save(IteratorStateWriter* writer, Iterator* itr, size_t index) {
    616           mutex_lock l(mu_);
    617           string prefix = itr->full_name(strings::StrCat("numa_block_", index));
    618           if (reached_eof_) {
    619             TF_RETURN_IF_ERROR(writer->WriteScalar(
    620                 strings::StrCat(prefix, "_end_of_input"), ""));
    621           }
    622           for (size_t i = 0; i < kWindowSize; ++i) {
    623             size_t index = (next_output_ + i) % kWindowSize;
    624             if (batches_[index].state == BatchState::kEmpty) {
    625               break;
    626             }
    627             string batch_prefix = strings::StrCat(prefix, "_batch_", i);
    628             TF_RETURN_IF_ERROR(writer->WriteScalar(
    629                 strings::StrCat(batch_prefix, "_code"),
    630                 static_cast<int64>(batches_[index].status.code())));
    631             if (!batches_[index].status.ok()) {
    632               TF_RETURN_IF_ERROR(
    633                   writer->WriteScalar(strings::StrCat(batch_prefix, "_msg"),
    634                                       batches_[index].status.error_message()));
    635               TF_RETURN_IF_ERROR(writer->WriteScalar(
    636                   strings::StrCat(batch_prefix, "_error_index"),
    637                   batches_[index].error_index));
    638             }
    639 
    640             TF_RETURN_IF_ERROR(writer->WriteScalar(
    641                 strings::StrCat(batch_prefix, "_output_size"),
    642                 batches_[index].outputs.size()));
    643             for (size_t j = 0; j < batches_[index].outputs.size(); ++j) {
    644               string tensor_prefix =
    645                   strings::StrCat(batch_prefix, "_output_", j);
    646               if (!batches_[index].status.ok()) {
    647                 DCHECK(batches_[index].error_index >= 0 &&
    648                        batches_[index].error_index <
    649                            itr_->dataset()->batch_size_);
    650                 // If the batch is not full, we only store the first
    651                 // `error_index` values. The rest of the batch tensor might not
     652                 // be initialized, and accessing it would raise MSAN errors.
    653                 TF_RETURN_IF_ERROR(writer->WriteTensor(
    654                     tensor_prefix, batches_[index].outputs[j].Slice(
    655                                        0, batches_[index].error_index)));
    656               } else {
    657                 TF_RETURN_IF_ERROR(writer->WriteTensor(
    658                     tensor_prefix, batches_[index].outputs[j]));
    659               }
    660             }
    661           }
    662           return Status::OK();
    663         }
    664 
    665         Status Restore(IteratorContext* ctx, IteratorStateReader* reader,
    666                        Iterator* itr, size_t index) {
    667           mutex_lock l(mu_);
    668           if (reached_eof_) {
    669             return errors::FailedPrecondition(
    670                 "Already reached the end of the sequence.");
    671           }
    672           string prefix = itr->full_name(strings::StrCat("numa_block_", index));
    673           reached_eof_ =
    674               reader->Contains(strings::StrCat(prefix, "_end_of_input"));
    675           for (size_t i = 0; i < kWindowSize; ++i) {
    676             string batch_prefix = strings::StrCat(prefix, "_batch_", i);
    677             if (!reader->Contains(strings::StrCat(batch_prefix, "_code"))) {
    678               break;
    679             }
    680             Batch batch;
    681             batch.state = BatchState::kOutputsComplete;
    682             int64 code_int;
    683             TF_RETURN_IF_ERROR(reader->ReadScalar(
    684                 strings::StrCat(batch_prefix, "_code"), &code_int));
    685             error::Code code = static_cast<error::Code>(code_int);
    686             if (code != error::Code::OK) {
    687               string error_message;
    688               TF_RETURN_IF_ERROR(reader->ReadScalar(
    689                   strings::StrCat(batch_prefix, "_msg"), &error_message));
    690               batch.status = Status(code, error_message);
    691               int64 error_index_int = -1;
    692               TF_RETURN_IF_ERROR(reader->ReadScalar(
    693                   strings::StrCat(batch_prefix, "_error_index"),
    694                   &error_index_int));
    695               if (error_index_int < 0 ||
    696                   error_index_int > itr->dataset()->batch_size_) {
    697                 return errors::FailedPrecondition(
    698                     "Error index out of bounds when restoring from checkpoint; "
    699                     "error index: ",
    700                     error_index_int);
    701               }
    702               batch.error_index = static_cast<size_t>(error_index_int);
    703             }
    704             int64 output_size = -1;
    705             TF_RETURN_IF_ERROR(reader->ReadScalar(
    706                 strings::StrCat(batch_prefix, "_output_size"), &output_size));
    707             batch.outputs.reserve(output_size);
    708             for (size_t j = 0; j < output_size; ++j) {
    709               string tensor_name = strings::StrCat(batch_prefix, "_output_", j);
    710               Tensor t;
    711               TF_RETURN_IF_ERROR(reader->ReadTensor(tensor_name, &t));
    712               batch.outputs.emplace_back(std::move(t));
    713             }
    714             batches_[i] = std::move(batch);
    715           }
    716           return Status::OK();
    717         }
    718 
    719        private:
    720         bool AllMapOperationsFinished() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    721           for (size_t i = 0; i < kWindowSize; ++i) {
    722             if (batches_[i].state == BatchState::kInputsFilled ||
    723                 batches_[i].state == BatchState::kAllMapsStarted) {
    724               return false;
    725             }
    726             if (batches_[i].state != BatchState::kOutputsComplete &&
    727                 !reached_eof_) {
    728               return false;
    729             }
    730           }
    731           return true;
    732         }
    733 
     734         // Batches begin in the `kEmpty` state. Once the RunnerThread has
     735         // filled the `inputs` of a `Batch`, it transitions to the
     736         // `kInputsFilled` state. At this point, the worker threads run the map
     737         // function and copy the outputs appropriately. Once every input of the
     738         // batch has been handed off to a worker, it transitions to
     739         // `kAllMapsStarted`. After the outputs are complete, the GetNext call
     740         // can consume the outputs and return the batch to the `kEmpty` state.
    741         enum class BatchState {
    742           kEmpty,
    743           kInputsFilled,
    744           kAllMapsStarted,
    745           kOutputsComplete,
    746         };
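        // Summary of the state transitions implemented below, and the thread
        // that drives each one:
        //   kEmpty --PushInputs (runner)--> kInputsFilled
        //   kInputsFilled --last RetrieveInput (worker)--> kAllMapsStarted
        //   kAllMapsStarted --last RecordBatchEntryComplete (worker)-->
        //       kOutputsComplete
        //   kOutputsComplete --GetBatch/Reset (client)--> kEmpty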
    747 
    748         // Batch captures all the state of an output batch as it progresses
    749         // through the machinery. Once the RunnerThread fills inputs, it
    750         // transitions to `kInputsFilled`. At this point, the worker threads can
     751         // work on it, incrementing num_outputs_complete for every element of the
    752         // input set that is copied into the output Tensors. Once all the input
    753         // tuples have been processed (i.e. num_outputs_complete ==
    754         // inputs.size()), it transitions to the `kOutputsComplete` stage, where
    755         // it is ready to be returned by a `GetBatch` call (called from
    756         // `GetNextInternal`).
    757         struct Batch {
    758           BatchState state;
    759           // Aggregates the Status of the input iterator's GetNext
    760           // calls, in addition to the Status of the map function invocations.
    761           //
     762         // In the case where multiple non-OK statuses are encountered, we keep
     763         // the one with the smallest element index.
    764           Status status;
    765           // In order to return the correct error status, we keep track of the
    766           // error_index.
    767           size_t error_index;
    768           // The batch_size input tuples (or fewer in the case of the last
    769           // batch).
    770           // TODO(saeta): Avoid re-allocating vectors all the time!
    771           std::vector<std::vector<Tensor>> inputs;
    772           std::vector<Tensor> outputs;
    773           size_t next_input_to_process;
    774           size_t num_outputs_complete;
    775 
    776           Batch() { Reset(); }
    777 
    778           // Resets the Batch state (e.g. after consuming the outputs).
    779           void Reset() {
    780             state = BatchState::kEmpty;
    781             status = Status::OK();
    782             inputs.clear();
    783             inputs.shrink_to_fit();
    784             outputs.clear();
    785             outputs.shrink_to_fit();
    786             next_input_to_process = 0;
    787             num_outputs_complete = 0;
    788             error_index = -1;
    789           }
    790         };
    791 
    792         Iterator* itr_;  // Not owned.
    793         mutex mu_;
    794         Batch batches_[kWindowSize] GUARDED_BY(mu_);
    795         size_t next_input_batch_ GUARDED_BY(mu_) = -1;
    796         size_t next_input_ GUARDED_BY(mu_) = 0;
    797         size_t next_output_ GUARDED_BY(mu_) = 0;
    798         bool cancelled_ GUARDED_BY(mu_) = false;
    799         bool reached_eof_ GUARDED_BY(mu_) = false;
    800 
    801         // The runner thread waits on this condition variable for space to be
    802         // available. When the client thread takes a value out of the circular
    803         // buffer, it notifies this condition variable that space is now
    804         // available.
    805         condition_variable runner_cond_var_ GUARDED_BY(mu_);
    806         // The worker threads wait on this condition variable for available
    807         // inputs. When the runner thread makes new inputs available, it
    808         // notifies this condition variable.
    809         condition_variable worker_cond_var_ GUARDED_BY(mu_);
    810         // The client threads wait on this condition variable for available
    811         // batched outputs. When worker threads complete a batch, they notify
    812         // this condition variable.
    813         condition_variable client_cond_var_ GUARDED_BY(mu_);
    814       };
    815       // Mark NumaBlockManager as a friend of Iterator in order to call
    816       // protected Iterator methods during checkpointing.
    817       friend NumaBlockManager;
    818 
    819       struct NumaWorkerBlock {
    820         NumaBlockManager manager;
    821         // TODO(saeta): Migrate to BackgroundWorker.
    822         std::vector<std::unique_ptr<Thread>> threads;
    823 
    824         explicit NumaWorkerBlock(Iterator* itr) : manager(itr) {}
    825       };
    826 
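      // Note on allocation: worker blocks may be placement-new'd into memory
      // obtained from port::NUMAMalloc (see EnsureBackgroundThreadsStarted
      // below), in which case they must be destroyed explicitly and freed
      // with port::NUMAFree; heap-allocated blocks are simply deleted.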
    827       static void CustomNumaWorkerBlockDeleter(NumaWorkerBlock* ptr) {
    828         ptr->~NumaWorkerBlock();
    829         port::NUMAFree(ptr, sizeof(NumaWorkerBlock));
    830       }
    831       static void DefaultNumaWorkerBlockDeleter(NumaWorkerBlock* ptr) {
    832         delete ptr;
    833       }
    834 
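      // Copies the first `num_elements` rows (along dimension 0) of `value`
      // into `output`. Used by GetBatch to trim a partially filled final
      // batch when the remainder is kept.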
    835       static Status CopyPartialBatch(Tensor* output, const Tensor& value,
    836                                      int64 num_elements) {
    837         switch (value.dtype()) {
    838 #define HANDLE_TYPE(type)                                         \
    839   case DataTypeToEnum<type>::value: {                             \
    840     auto output_t = output->flat_outer_dims<type>();              \
    841     auto value_t = value.flat_outer_dims<type>();                 \
    842     for (size_t i = 0; i < num_elements; i++) {                   \
    843       output_t.template chip<0>(i) = value_t.template chip<0>(i); \
    844     }                                                             \
    845     return Status::OK();                                          \
    846   }
    847           TF_CALL_DATASET_TYPES(HANDLE_TYPE);
    848 #undef HANDLE_TYPE
    849           default:
    850             return errors::InvalidArgument("Unsupported data type: ",
    851                                            DataTypeString(value.dtype()));
    852         }
    853         return Status::OK();
    854       }
    855 
    856       Status EnsureBackgroundThreadsStarted(IteratorContext* ctx)
    857           EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    858         if (curr_num_parallel_calls_ >= num_parallel_calls_->value) {
    859           // All necessary threads have been started.
    860           curr_num_parallel_calls_ = num_parallel_calls_->value;
    861           return Status::OK();
    862         }
    863 
    864         VLOG(4) << "Starting workers";
    865         bool numa_enabled = port::NUMAEnabled();
    866 
    867         if (!numa_enabled) {
    868           LOG(INFO) << "NUMA not enabled on this host.";
    869         }
    870 
    871         int num_numa_nodes = port::NUMANumNodes();
    872         if (num_numa_nodes < 1) {
    873           return errors::Internal("The number of NUMA nodes is invalid: ",
    874                                   num_numa_nodes);
    875         }
    876 
    877         // Only resize when empty to support restoring from checkpoints.
    878         if (workers_.empty()) {
    879           VLOG(3) << "# NUMA Nodes: " << num_numa_nodes
    880                   << ", # Parallel Calls: " << num_parallel_calls_->value;
    881           workers_.resize(num_numa_nodes);
    882         } else {
    883           num_numa_nodes = workers_.size();
    884         }
    885 
     886         // Split num_parallel_calls across blocks, rounding up (minimum of 1).
    887         const size_t num_threads_per_block =
    888             std::max(1LL, (num_parallel_calls_->value + num_numa_nodes - 1) /
    889                               num_numa_nodes);
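        // Worked example (illustrative): num_parallel_calls = 10 on a host
        // with 4 NUMA nodes gives (10 + 4 - 1) / 4 = 3 threads per block, so
        // 12 worker threads are created in total; the per-thread gating in
        // WorkerThread below puts some of them back to sleep if the autotuner
        // later lowers num_parallel_calls.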
    890 
    891         VLOG(3) << "Starting " << num_threads_per_block * num_numa_nodes
    892                 << " worker threads, with " << num_threads_per_block
    893                 << " threads per block.";
    894 
    895         // Only allocate new_ctx if required.
    896         std::shared_ptr<IteratorContext> new_ctx;
    897 
    898         for (int i = 0; i < num_numa_nodes; ++i) {
    899           if (!workers_[i]) {
    900             if (numa_enabled) {
    901               // Allocate in appropriate NUMA domain.
    902               // 4k page align.
    903               void* ptr = port::NUMAMalloc(i, sizeof(NumaWorkerBlock), 0);
    904               if (ptr != nullptr) {
    905                 NumaWorkerBlock* block = new (ptr) NumaWorkerBlock(this);
    906                 workers_[i] =
    907                     std::unique_ptr<NumaWorkerBlock,
    908                                     std::function<void(NumaWorkerBlock*)>>(
    909                         block, CustomNumaWorkerBlockDeleter);
    910               } else {
    911                 LOG(ERROR) << "Could not NUMA-allocate worker block: " << i;
    912               }
    913             }
    914             // If the NUMA allocation fails, or NUMA is not enabled.
    915             if (!workers_[i]) {
    916               workers_[i] =
    917                   std::unique_ptr<NumaWorkerBlock,
    918                                   std::function<void(NumaWorkerBlock*)>>(
    919                       new NumaWorkerBlock(this), DefaultNumaWorkerBlockDeleter);
    920             }
    921           }
    922           // Be sure to start threads if num_parallel_calls_ has changed.
    923           for (size_t j = workers_[i]->threads.size();
    924                j < num_threads_per_block; ++j) {
    925             VLOG(3) << "Starting worker " << i << ", " << j;
    926             if (!new_ctx) {
    927               new_ctx = std::make_shared<IteratorContext>(*ctx);
    928             }
    929             workers_[i]->threads.emplace_back(ctx->StartThread(
    930                 strings::StrCat("tf_data_numa_map_and_batch_", i, "_", j),
    931                 [this, new_ctx, i, j]() { WorkerThread(new_ctx, i, j); }));
    932             VLOG(3) << "Worker " << i << ", " << j << " successfully started.";
    933           }
    934         }
    935         if (!runner_thread_) {
    936           if (!new_ctx) {
    937             new_ctx = std::make_shared<IteratorContext>(*ctx);
    938           }
    939           runner_thread_ =
    940               ctx->StartThread("tf_data_numa_map_and_batch",
    941                                [this, new_ctx] { RunnerThread(new_ctx); });
    942         }
    943         VLOG(3) << "All workers & runner thread started.";
    944         return Status::OK();
    945       }
    946 
    947       void AllocateOutput(IteratorContext* ctx, size_t batch_size,
    948                           const std::vector<Tensor>& map_fn_outputs,
    949                           std::vector<Tensor>* batch_outputs) {
    950         DCHECK(dataset()->output_dtypes().size() ==
    951                dataset()->output_shapes().size());
    952         DCHECK(map_fn_outputs.size() == dataset()->output_dtypes().size());
    953         for (size_t i = 0; i < dataset()->output_dtypes().size(); ++i) {
    954           TensorShape component_shape({static_cast<uint32>(batch_size)});
    955           component_shape.AppendShape(map_fn_outputs.at(i).shape());
    956           AllocatorAttributes attr;
    957           attr.set_gpu_compatible(true);
    958           batch_outputs->emplace_back(ctx->allocator(attr),
    959                                       map_fn_outputs.at(i).dtype(),
    960                                       component_shape);
    961         }
    962       }
    963 
    964       void RunnerThread(std::shared_ptr<IteratorContext> ctx)
    965           LOCKS_EXCLUDED(mu_) {
    966         RecordStart(ctx.get());
    967         auto cleanup = gtl::MakeCleanup([this, &ctx] {
    968           // Set end of input on all the managers in order to clean up in an
    969           // orderly fashion.
    970           VLOG(3) << "Setting End of Input on workers_[*]->manager";
    971           for (size_t i = 0; i < workers_.size(); ++i) {
    972             workers_[i]->manager.SetEndOfInput();
    973           }
    974           RecordStop(ctx.get());
    975         });
    976 
    977         const size_t num_blocks = workers_.size();
    978 
    979         while (true) {
    980           for (size_t block = 0; block < num_blocks; ++block) {
    981             VLOG(4) << "RunnerThread waiting for input space in block: "
    982                     << block;
    983             if (TF_PREDICT_FALSE(
    984                     !workers_[block]->manager.WaitForInputSpace(ctx.get()))) {
    985               VLOG(3) << "RunnerThread exiting due to cancellation.";
    986               return;
    987             }
    988             VLOG(4) << "RunnerThread has space; pulling on upstream for block "
    989                     << block;
    990 
    991             Status s;
    992             std::vector<std::vector<Tensor>> inputs;
    993             bool end_of_sequence = false;
    994             for (size_t i = 0; i < dataset()->batch_size_; ++i) {
    995               std::vector<Tensor> tuple;
    996               s.Update(
    997                   input_impl_->GetNext(ctx.get(), &tuple, &end_of_sequence));
    998               if (!s.ok()) {
    999                 break;
   1000               }
   1001               if (end_of_sequence) {
   1002                 VLOG(4) << "Runner thread encountered end of sequence.";
   1003                 if (dataset()->drop_remainder_) {
   1004                   return;
   1005                 }
   1006                 break;
   1007               }
   1008               inputs.push_back(std::move(tuple));
   1009             }
   1010 
   1011             VLOG(4) << "Moving inputs to block " << block
   1012                     << ", which has size: " << inputs.size();
   1013             if (!s.ok() || !inputs.empty()) {
   1014               workers_[block]->manager.PushInputs(s, std::move(inputs));
   1015               VLOG(4) << "Inputs moved into block " << block;
   1016             }
   1017             if (end_of_sequence) {
   1018               return;
   1019             }
   1020           }
   1021         }
   1022       }
   1023 
   1024       void WorkerThread(std::shared_ptr<IteratorContext> ctx,
   1025                         const int numa_node, const int thread_num) {
   1026         RecordStart(ctx.get());
   1027         WORKER_VLOG(3) << "started.";
   1028         auto stop_cleanup =
   1029             gtl::MakeCleanup([this, numa_node, thread_num, &ctx]() {
   1030               RecordStop(ctx.get());
   1031               WORKER_VLOG(3) << "exiting.";
   1032             });
   1033 
   1034         NumaWorkerBlock* block = workers_[numa_node].get();
   1035         port::NUMASetThreadNodeAffinity(numa_node);
   1036         const int num_numa_nodes = port::NUMANumNodes();
   1037         const int minimum_num_parallel_calls = thread_num * num_numa_nodes;
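        // Illustrative example of the gating below: with 4 NUMA nodes, the
        // thread with thread_num = 2 has minimum_num_parallel_calls = 8, so
        // it only does work while num_parallel_calls_->value exceeds 8; if
        // the autotuner lowers the value to 6, this thread sleeps on
        // autotune_cond_var_ until the value rises again (or the iterator is
        // cancelled).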
   1038 
   1039         while (true) {
   1040           // Put threads to sleep based on autotuner.
   1041           {
   1042             mutex_lock l(*mu_);
   1043             while (minimum_num_parallel_calls >= num_parallel_calls_->value &&
   1044                    !cancelled_) {
   1045               RecordStop(ctx.get());
   1046               autotune_cond_var_->wait(l);
   1047               RecordStart(ctx.get());
   1048             }
   1049             if (cancelled_) {
   1050               return;
   1051             }
   1052           }
   1053 
   1054           std::vector<Tensor> input;
   1055           uint64 index = 0;
   1056           size_t sequence_number = 0;
   1057           WORKER_VLOG(4) << "retrieving input.";
   1058           {
   1059             tracing::ScopedActivity trace(
   1060                 "NumaMapAndBatch::Iterator::Worker::RetrieveInput");
   1061             if (!block->manager.RetrieveInput(ctx.get(), &input, &index,
   1062                                               &sequence_number)) {
   1063               return;
   1064             }
   1065           }
   1066 
   1067           WORKER_VLOG(4) << "retrieved input; index: " << index
   1068                          << ", sequence_number: " << sequence_number;
   1069 
   1070           std::vector<Tensor> return_values;
   1071           Status s;
   1072           {
   1073             tracing::ScopedActivity trace(
   1074                 "NumaMapAndBatch::Iterator::Worker::FunctionExecution");
   1075             s = instantiated_captured_func_->Run(ctx.get(), std::move(input),
   1076                                                  &return_values);
   1077           }
   1078           WORKER_VLOG(4) << "ran function for index: " << index
   1079                          << ", sequence_number: " << sequence_number;
   1080 
   1081           if (s.ok()) {
   1082             std::vector<Tensor>* output = block->manager.GetBatchTensors(
   1083                 sequence_number,
   1084                 [this, ctx, &return_values](size_t batch_size,
   1085                                             std::vector<Tensor>* output) {
   1086                   AllocateOutput(ctx.get(), batch_size, return_values, output);
   1087                 });
   1088             WORKER_VLOG(4) << "copying tensors to batch output.";
   1089             {
   1090               tracing::ScopedActivity trace(
   1091                   "NumaMapAndBatch::Iterator::Worker::BatchCopy");
   1092               for (size_t i = 0; i < return_values.size() && s.ok(); ++i) {
   1093                 Tensor& tensor = return_values.at(i);
   1094                 Tensor* batch = &output->at(i);
   1095                 if (tensor.NumElements() !=
   1096                     (batch->NumElements() / batch->dim_size(0))) {
   1097                   s.Update(errors::InvalidArgument(
   1098                       "Cannot add tensor to the batch: number of elements does "
   1099                       "not match. Shapes are: [tensor]: ",
   1100                       tensor.shape().DebugString(),
   1101                       ", [batch]: ", batch->shape().DebugString()));
   1102                   break;
   1103                 }
   1104                 s.Update(batch_util::CopyElementToSlice(std::move(tensor),
   1105                                                         batch, index));
   1106               }
   1107             }
   1108           }
   1109 
   1110           block->manager.RecordBatchEntryComplete(sequence_number, index, s);
   1111           WORKER_VLOG(4) << "finished index: " << index
   1112                          << ", sequence_number: " << sequence_number;
   1113         }
   1114       }
   1115 
   1116       // mu_ protects shared internal state and is used to coordinate between
   1117       // the auto-tuner, client threads, worker threads, and the runner thread.
   1118       const std::shared_ptr<mutex> mu_;
   1119       const std::shared_ptr<condition_variable> autotune_cond_var_;
   1120       // The maximum number of parallel calls (can be auto-tuned).
   1121       const std::shared_ptr<model::SharedState> num_parallel_calls_;
   1122       std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
   1123 
   1124       // Caches the last-seen value of num_parallel_calls_->value to
   1125       // short-circuit starting workers.
   1126       int64 curr_num_parallel_calls_ GUARDED_BY(*mu_) = 0;
   1127 
   1128       std::unique_ptr<IteratorBase> input_impl_;
   1129       int64 cur_block_ GUARDED_BY(*mu_) = 0;
   1130       bool global_end_of_input_ GUARDED_BY(*mu_) = false;
   1131       bool cancelled_ GUARDED_BY(*mu_) = false;
   1132       std::vector<std::unique_ptr<NumaWorkerBlock,
   1133                                   std::function<void(NumaWorkerBlock*)>>>
   1134           workers_;  // Const after initialization.
   1135       std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
   1136     };
   1137 
   1138     const DatasetBase* const input_;
   1139     const int64 batch_size_;
   1140     const int64 num_parallel_calls_;
   1141     const bool drop_remainder_;
   1142     const DataTypeVector output_types_;
   1143     const std::vector<PartialTensorShape> output_shapes_;
   1144     const NameAttrList func_;
   1145     const std::unique_ptr<CapturedFunction> captured_func_;
   1146   };
   1147 
   1148   DataTypeVector output_types_;
   1149   std::vector<PartialTensorShape> output_shapes_;
   1150   NameAttrList func_;
   1151   bool preserve_cardinality_;
   1152 };
   1153 
   1154 REGISTER_KERNEL_BUILDER(
   1155     Name("ExperimentalNumaMapAndBatchDataset").Device(DEVICE_CPU),
   1156     NumaMapAndBatchDatasetOp);
   1157 
   1158 }  // namespace
   1159 }  // namespace data
   1160 }  // namespace tensorflow
   1161