/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <atomic>
#include <deque>
#include <utility>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    int64 cycle_length = 0;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument(ctx, "cycle_length", &cycle_length));
    OP_REQUIRES(ctx, cycle_length > 0,
                errors::InvalidArgument("`cycle_length` must be > 0"));

    int64 block_length = 0;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument(ctx, "block_length", &block_length));
    OP_REQUIRES(ctx, block_length > 0,
                errors::InvalidArgument("`block_length` must be > 0"));

    bool sloppy = false;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "sloppy", &sloppy));

    int64 buffer_output_elements = 0;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "buffer_output_elements",
                                            &buffer_output_elements));
    OP_REQUIRES(
        ctx, buffer_output_elements > 0,
        errors::InvalidArgument("`buffer_output_elements` must be > 0"));

    int64 prefetch_input_elements = 0;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefetch_input_elements",
                                            &prefetch_input_elements));
    OP_REQUIRES(
        ctx, prefetch_input_elements >= 0,
        errors::InvalidArgument("`prefetch_input_elements` must be >= 0"));

    std::unique_ptr<CapturedFunction> captured_func;
    OP_REQUIRES_OK(
        ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
                                      &captured_func));

    *output =
        new Dataset(ctx, input, interleave_func_, std::move(captured_func),
                    cycle_length, block_length, sloppy, buffer_output_elements,
                    prefetch_input_elements, output_types_, output_shapes_);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const DatasetBase* input,
            const NameAttrList& func,
            std::unique_ptr<CapturedFunction> captured_func,
            int64 cycle_length, int64 block_length, bool sloppy,
            int64 buffer_output_elements, int64 prefetch_input_elements,
            const DataTypeVector& output_types,
            const std::vector<PartialTensorShape>& output_shapes)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          interleave_func_(func),
          captured_func_(std::move(captured_func)),
          cycle_length_(cycle_length),
          block_length_(block_length),
          sloppy_(sloppy),
          buffer_output_elements_(buffer_output_elements),
          prefetch_input_elements_(prefetch_input_elements),
          output_types_(output_types),
          output_shapes_(output_shapes) {
      input_->Ref();
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<Iterator>(Iterator::Params{
          this, strings::StrCat(prefix, "::ParallelInterleave")});
    }

    const DataTypeVector& output_dtypes() const override {
      return output_types_;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return output_shapes_;
    }

    string DebugString() const override {
      return "ParallelInterleaveDatasetOp::Dataset";
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
      Node* input_node;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
      Node* cycle_length_node;
      TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
      Node* block_length_node;
      TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
      Node* sloppy_node;
      TF_RETURN_IF_ERROR(b->AddScalar(sloppy_, &sloppy_node));
      Node* buffer_output_elements_node;
      TF_RETURN_IF_ERROR(
          b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
      Node* prefetch_input_elements_node;
      TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_,
                                      &prefetch_input_elements_node));
      DataTypeVector other_arguments_types;
      other_arguments_types.reserve(captured_func_->captured_inputs().size());
      std::vector<Node*> other_arguments;
      other_arguments.reserve(captured_func_->captured_inputs().size());
      for (const Tensor& t : captured_func_->captured_inputs()) {
        Node* node;
        DatasetBase* input;
        Status s = GetDatasetFromVariantTensor(t, &input);
        if (s.ok()) {
          TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input, &node));
        } else {
          TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
        }
        other_arguments.emplace_back(node);
        other_arguments_types.emplace_back(t.dtype());
      }
      AttrValue f;
      b->BuildAttrValue(interleave_func_, &f);
      AttrValue other_arguments_types_attr;
      b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);

      TF_RETURN_IF_ERROR(b->AddDataset(
          this,
          {{0, input_node},
           {2, cycle_length_node},
           {3, block_length_node},
           {4, sloppy_node},
           {5, buffer_output_elements_node},
           {6, prefetch_input_elements_node}},
          {{1, other_arguments}},
          {{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
      return Status::OK();
    }

   private:
    int64 num_threads() const {
      return cycle_length_ + prefetch_input_elements_;
    }

    // Parallel interleave's implementation is designed around a few
    // principles:
    // 1. Thread creation is relatively expensive. (Not reusing
    //    threads causes a number of indirect costs such as poorer tcmalloc
    //    performance due to thread-local caches, etc.) We allocate a fixed
    //    number of threads at the start and never change that number. This is
    //    why we've fused functionality that is theoretically orthogonal (i.e.
    //    .prefetch()) into the implementation.
    // 2. Drop-in replacement for standard interleave. The goal is to auto-opt
    //    users into an optimized implementation without any work on their
    //    part. We thus take great pains to maintain identical iteration
    //    orders and full determinism (which can be disabled only via a flag).
    // 3. Performance across a variety of environments and I/O envelopes.
    //
    // The actual implementation centers around a collection of worker threads
    // and their corresponding worker state (tracked in the `workers_` vector).
    // Worker threads repeatedly receive a vector of Tensors that is used as
    // input to the flat-map function (`captured_func_`). The output of this
    // function must be a dataset. The worker thread then repeatedly calls
    // `GetNext()`, maintaining a buffer of elements to minimize the likelihood
    // that a caller will block waiting for an element to be produced.
    //
    // Pointers to these worker states are kept in two disjoint data
    // structures:
    // 1. `interleave_indices_` is a vector containing indices of WorkerStates
    //    in `workers_` that we are interleaving. Worker threads backing these
    //    WorkerStates should be regularly producing values.
    // 2. `staging_indices_` is a deque containing indices of WorkerStates in
    //    `workers_` that we will move to `interleave_indices_` when an
    //    iterator in `interleave_indices_` is exhausted.
    //
    // The client calls `GetNext[Internal]()` to retrieve an output element.
    // The internal implementation updates the state of `interleave_indices_`
    // and `staging_indices_` as output iterators (run by the worker threads)
    // are exhausted.
    //
    // `input_impl_` is the input iterator that generates arguments for the
    // flat-map function (`captured_func_`). It is set to an iterator at
    // Iterator construction, and is fixed until we consume all input elements.
    // Once it is exhausted, we reset the unique_ptr to eagerly deallocate
    // memory.
    //
    // A few invariants are maintained:
    // 1. No element of `interleave_indices_` should be -1 unless
    //    `staging_indices_` is empty and `input_impl_` is empty.
    // 2. Every `workers_` element is pointed to by at most one element of the
    //    union of `interleave_indices_` and `staging_indices_`.
    // 3. Unless `input_impl_` is empty, every `workers_` element must be
    //    pointed to by an element in `interleave_indices_` or
    //    `staging_indices_`.
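    //
    // As a concrete (hypothetical) illustration, suppose cycle_length_ == 2
    // and prefetch_input_elements_ == 1, so num_threads() == 3. After
    // `EnsureWorkerThreadsStarted()` the state is:
    //
    //   interleave_indices_ == {0, 1}   // actively interleaved
    //   staging_indices_    == {2}      // prefetching, not yet interleaved
    //
    // When the iterator backing worker 1 is exhausted, `GetNextInternal()`
    // (a) hands worker 1 a new input element (if `input_impl_` still has one)
    // and pushes index 1 onto `staging_indices_`, and (b) promotes the front
    // of `staging_indices_` into the vacated slot:
    //
    //   interleave_indices_ == {0, 2}
    //   staging_indices_    == {1}
    //
    // Only when both `staging_indices_` and `input_impl_` are exhausted does
    // a slot stay at -1, which is what invariant 1 above expresses.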
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params),
            workers_(dataset()->num_threads()),
            worker_thread_states_(dataset()->num_threads()) {}

      ~Iterator() override {
        mutex_lock l(mu_);
        cancelled_ = true;
        // Notify all workers in case they are blocked.
        for (auto& worker : workers_) {
          worker.cond_var.notify_all();
        }
      }

      Status Initialize(IteratorContext* ctx) override {
        TF_RETURN_IF_ERROR(
            dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
        return dataset()->captured_func_->Instantiate(
            ctx, &instantiated_captured_func_);
      }

      // The implementation matches the deterministic interleave order unless
      // getting the next element would block and we are allowed to be sloppy.
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
        while (!cancelled_) {
          // Wait for an item to become available, blocking if necessary. If we
          // are allowed to be sloppy, we can skip over input datasets that do
          // not have an item readily available.
          bool can_produce_elements = false;
          bool must_wait_for_input = true;
          for (int64 i = 0; i < interleave_indices_.size(); ++i) {
            int64 index = (next_index_ + i) % interleave_indices_.size();
            int64 current_worker_index = interleave_indices_[index];
            if (current_worker_index < 0) {
              continue;  // Empty interleave elements.
            }
            WorkerState* current_worker = &workers_[current_worker_index];
            can_produce_elements |= current_worker->MayHaveElements();
            if (!current_worker->outputs.empty()) {
              // We have an element!
              next_index_ = index;
              const bool element_acquired_sloppily =
                  dataset()->sloppy_ && i > 1;
              if (!element_acquired_sloppily) {
                // If the element was acquired in the regular (non-sloppy)
                // order, then advance the current block and cycle pointers to
                // the next element in the regular order.
                block_count_++;
                if (block_count_ == dataset()->block_length_) {
                  next_index_ = (index + 1) % interleave_indices_.size();
                  block_count_ = 0;
                }
              } else {
                block_count_ = 0;
              }
              *end_of_sequence = false;
              Status s = current_worker->outputs.front().status;
              current_worker->outputs.front().output.swap(*out_tensors);
              current_worker->outputs.pop_front();
              current_worker->cond_var.notify_one();
              return s;
            } else if (current_worker->is_producing && !dataset()->sloppy_) {
              // current_worker->outputs is empty, and we must wait for this
              // iterator.
              if (next_index_ != index) {
                // We have advanced to a new iterator; reset block counts.
                next_index_ = index;
                block_count_ = 0;
              }
              break;
            } else if (!current_worker->is_producing) {
              // This iterator has reached end of input.
              interleave_indices_[index] = -1;
              if (input_impl_) {
                // Start prefetching a new iterator.
                std::vector<Tensor> args;
                bool end_of_input = false;
                Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
                if (end_of_input) {
                  input_impl_.reset();
                } else {
                  current_worker->SetInputs(s, std::move(args));
                  staging_indices_.emplace_back(current_worker_index);
                }
              }

              if (!staging_indices_.empty()) {
                // Move a worker from `staging_indices_` to
                // `interleave_indices_`.
                interleave_indices_[index] = staging_indices_.front();
                staging_indices_.pop_front();

                next_index_ = (index + 1) % interleave_indices_.size();
                block_count_ = 0;
                // Restart the inner [for] loop.
                can_produce_elements = true;
                must_wait_for_input = false;
                break;
              }
            }
          }

          if (!can_produce_elements && !input_impl_) {
            // No potential for future values.
            *end_of_sequence = true;
            return Status::OK();
          }

          if (must_wait_for_input) {
            // Wait for elements to become available.
            RecordStop(ctx);
            if (dataset()->sloppy_) {
              sloppy_cond_var_.wait(l);
            } else {
              workers_[interleave_indices_[next_index_]].cond_var.wait(l);
            }
            RecordStart(ctx);
          }
        }
        return errors::Cancelled(
            "ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeAsyncInterleaveManyNode(std::move(args),
                                                  /*parameters=*/{});
      }

      Status SaveInternal(IteratorStateWriter* writer) override {
        // The order of locking is important here to avoid deadlock.
        mutex_lock l(mu_);
        mutex_lock ckpt_l(ckpt_mu_);
        if (input_impl_) {
          TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
        } else {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("input_exhausted"), ""));
        }
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("next_index"), next_index_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("block_count"), block_count_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("workers_size"), workers_.size()));
        for (int i = 0; i < workers_.size(); ++i) {
          TF_RETURN_IF_ERROR(WriteWorkerStateLocked(writer, i));
        }
        for (int i = 0; i < worker_thread_states_.size(); ++i) {
          TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(writer, i));
        }
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("interleave_size"),
                                               interleave_indices_.size()));
        for (int i = 0; i < interleave_indices_.size(); ++i) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat("interleave_indices_", i)),
              interleave_indices_[i]));
        }
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("staging_size"),
                                               staging_indices_.size()));
        for (int i = 0; i < staging_indices_.size(); ++i) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat("staging_indices_", i)),
              staging_indices_[i]));
        }
        if (!worker_threads_.empty()) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("worker_threads_running"), ""));
        }
        return Status::OK();
      }

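      // Informally, the checkpoint written by `SaveInternal()` above and read
      // back by `RestoreInternal()` below uses roughly the following keys
      // (the code is authoritative):
      //
      //   input_exhausted           -- present iff `input_impl_` is null
      //   next_index, block_count   -- position within the current cycle/block
      //   workers_size              -- must equal num_threads() on restore
      //   worker_<i>_*              -- per-WorkerState inputs and outputs
      //   worker_thread_<i>_*       -- per-WorkerThreadState iterator, inputs,
      //                                pending output element, and status
      //   interleave_size, interleave_indices_<i>
      //   staging_size, staging_indices_<i>
      //   worker_threads_running    -- present iff worker threads were started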
      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        // The order of locking is important here to avoid deadlock.
        mutex_lock l(mu_);
        mutex_lock ckpt_l(ckpt_mu_);
        if (!reader->Contains(full_name("input_exhausted"))) {
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        } else {
          input_impl_.reset();
        }
        int64 temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("next_index"), &temp));
        next_index_ = size_t(temp);
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("block_count"), &temp));
        block_count_ = size_t(temp);

        // Restore WorkerStates.
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name("workers_size"), &temp));
        if (temp != dataset()->num_threads()) {
          return errors::Internal("Expected ", dataset()->num_threads(),
                                  " worker states but found ", temp, ".");
        }
        for (size_t i = 0; i < dataset()->num_threads(); ++i) {
          TF_RETURN_IF_ERROR(ReadWorkerStateLocked(reader, i, ctx));
        }
        for (size_t i = 0; i < dataset()->num_threads(); ++i) {
          TF_RETURN_IF_ERROR(ReadWorkerThreadStateLocked(reader, i, ctx));
        }

        // Restore `interleave_indices_`.
        std::set<int64> all_indices;
        {
          int64 interleave_size;
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("interleave_size"),
                                                &interleave_size));
          interleave_indices_.reserve(interleave_size);
          for (int64 i = 0; i < interleave_size; ++i) {
            int64 temp;
            TF_RETURN_IF_ERROR(reader->ReadScalar(
                full_name(strings::StrCat("interleave_indices_", i)), &temp));
            if (temp >= 0 && all_indices.find(temp) != all_indices.end()) {
              return errors::Internal(
                  "Duplicate entry for ", temp,
                  " found when reading interleave and staging indices.");
            }
            if (temp >= 0) {
              all_indices.insert(temp);
            }
            interleave_indices_.emplace_back(temp);
          }
        }

        // Restore `staging_indices_`.
        {
          int64 staging_size;
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(full_name("staging_size"), &staging_size));
          for (int i = 0; i < staging_size; ++i) {
            int64 temp;
            TF_RETURN_IF_ERROR(reader->ReadScalar(
                full_name(strings::StrCat("staging_indices_", i)), &temp));
            if (all_indices.find(temp) != all_indices.end()) {
              return errors::Internal(
                  "Duplicate entry for ", temp,
                  " found when reading interleave and staging indices.");
            }
            if (temp >= 0) {
              all_indices.insert(temp);
            }
            staging_indices_.emplace_back(temp);
          }
        }

        // Start worker threads.
        if (reader->Contains(full_name("worker_threads_running"))) {
          worker_threads_.reserve(dataset()->num_threads());
          for (size_t i = 0; i < dataset()->num_threads(); ++i) {
            std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
            worker_threads_.emplace_back(ctx->StartThread(
                strings::StrCat("tf_data_parallel_interleave_worker_", i),
                [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
          }
        }
        return Status::OK();
      }

     private:
      // OutputElem contains the information from a call to GetNext by an
      // output iterator.
      struct OutputElem {
        // The output iterator sets `status` if getting the output element
        // fails.
        Status status;
        // The buffered data element.
        std::vector<Tensor> output;

        explicit OutputElem(const Status& s) : status(s) {}
      };

      // Worker threads operate on their relevant WorkerState structs.
      //
      // WorkerState's fields are all protected by `mu_`.
      struct WorkerState {
        // The arguments to be used to construct an output iterator.
        std::vector<Tensor> input;
        // The buffered output elements.
        std::deque<OutputElem> outputs;
        // Set to true iff the worker thread expects to append more elements to
        // `outputs`. `is_producing` can be false despite !outputs.empty().
        // Concretely, all output elements will have been consumed only when:
        // is_producing == false && outputs.empty().
        bool is_producing = false;
        // Condition variable used to coordinate between threads. The worker
        // thread waits on this condition variable when it is either (1)
        // waiting for the main thread to add arguments to `input`, or (2)
        // waiting for the main thread to consume an element of `outputs`. The
        // main thread waits on cond_var if it is waiting for the worker thread
        // to produce an element into `outputs` (this implies sloppy_==false).
        condition_variable cond_var;

        inline bool MayHaveElements() const {
          return is_producing || !outputs.empty();
        }

        // Sets inputs for a worker thread and notifies it to start processing.
        void SetInputs(const Status& s, std::vector<Tensor> input_arguments) {
          if (s.ok()) {
            DCHECK(!MayHaveElements())
                << "Tried to start inputs, despite already producing!";
            input = std::move(input_arguments);
            is_producing = true;
            cond_var.notify_one();
          } else {
            outputs.emplace_back(s);
          }
        }
      };
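
      // Informally, the handshake implied by `WorkerState` is: the main thread
      // hands a worker new arguments via SetInputs(), which notifies
      // `cond_var`; the worker builds an iterator from `input`, pushes
      // elements into `outputs`, and blocks on `cond_var` whenever
      // outputs.size() == buffer_output_elements_; the main thread pops from
      // `outputs` in GetNextInternal() and notifies `cond_var` so the worker
      // can refill the buffer. When the worker's iterator is exhausted it sets
      // is_producing = false, which is how GetNextInternal() knows to recycle
      // that slot.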

      // The internal state of a worker thread that is not already captured
      // in its `WorkerState`.
      //
      // This is needed only for checkpointing purposes. We keep this
      // separate from `WorkerState` and guard its fields using a separate
      // lock `ckpt_mu_` so as to not affect the performance of the main
      // pipeline.
      struct WorkerThreadState {
        // The output element that has been produced from the input iterator
        // and is waiting to be added to `WorkerState.outputs`.
        OutputElem output_elem;

        // Whether the input iterator returned an `end_of_sequence`.
        bool end_of_sequence = false;

        // Status returned from `MakeIteratorFromInputElement`.
        Status iterator_creation_status;

        // The arguments to be used to construct `iterator`.
        std::vector<Tensor> input;

        std::unique_ptr<IteratorBase> iterator;

        WorkerThreadState() : output_elem(Status::OK()) {}
      };

      Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (worker_threads_.empty()) {
          worker_threads_.reserve(dataset()->num_threads());
          for (int64 i = 0; i < dataset()->num_threads(); ++i) {
            std::vector<Tensor> args;
            bool end_of_input = false;
            Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
            if (end_of_input) {
              input_impl_.reset();
              return Status::OK();
            }
            workers_[i].SetInputs(s, std::move(args));
            std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
            worker_threads_.push_back(ctx->StartThread(
                strings::StrCat("tf_data_parallel_interleave_worker_", i),
                [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
            if (i < dataset()->cycle_length_) {
              interleave_indices_.push_back(i);
            } else {
              staging_indices_.push_back(i);
            }
          }
          DCHECK(interleave_indices_.size() == dataset()->cycle_length_);
          DCHECK(staging_indices_.size() ==
                 dataset()->prefetch_input_elements_);
        }
        return Status::OK();
      }

      // Produces elements into the worker's output buffers.
      void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
                        const int64 thread_index) {
        // Notes on checkpointing thread-local state, i.e. `WorkerThreadState`:
        //
        // 1. Any local state that may need to be checkpointed should be kept
        //    in `worker_thread_states_[thread_index]`.
        // 2. `WorkerThreadState` should contain state that is needed only for
        //    checkpointing, i.e., if we were to remove checkpointing support,
        //    we could keep that state as local variables in this thread.
        // 3. This thread should only read/write state at `thread_index`
        //    and should not access other thread states.
        // 4. When restoring from checkpoint, threads are started only after
        //    the restore is complete.
        // 5. Once restored from a checkpoint, the local state is edited only
        //    by this thread. Notes 3 & 4 allow assumptions such as temporarily
        //    caching local state in this thread and using it outside a lock,
        //    e.g. `make_new_iterator`.
        // 6. `ckpt_mu_` should be used judiciously to create *consistent*
        //    checkpoint markers (summarized below).
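        //
        // Roughly, the CHECKPOINT_MARKER_* comments below correspond to the
        // following consistent states that a checkpoint may observe for this
        // thread:
        //   A: input tensors acquired, iterator not yet built.
        //   B: iterator built, or a non-OK creation status recorded.
        //   C: a non-OK creation status has been handed to the client.
        //   D: an element, error, or end_of_sequence has been read from the
        //      iterator but not yet handed to the client.
        //   E: the element or status has been handed to the client.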

        // `ctx` is wrapped in a std::shared_ptr at the call site so that the
        // closure passed to StartThread is copy-constructible and the context
        // outlives this thread.
        RecordStart(ctx.get());
        auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
          mutex_lock l(mu_);
          workers_[thread_index].cond_var.notify_all();
          RecordStop(ctx.get());
        });
        bool make_new_iterator;
        {
          tf_shared_lock l(ckpt_mu_);
          // Decide whether a new iterator should be built.
          // 1. If there is an existing iterator, we use it.
          // 2. If there was an error in iterator creation that could not be
          //    notified to the client, we attempt to send that to the client
          //    first.
          make_new_iterator =
              worker_thread_states_[thread_index].iterator == nullptr &&
              worker_thread_states_[thread_index].iterator_creation_status.ok();
        }
        // Even though `make_new_iterator` caches values from
        // `worker_thread_states_[thread_index]`, which is guarded by ckpt_mu_,
        // it is safe to *read* `make_new_iterator` outside of a lock without
        // worrying about concurrent changes to values in
        // `worker_thread_states_[thread_index]`. See the comment at the start
        // of this function for details.
        while (true) {
          // Status of iterator creation.
          Status iterator_creation_status;
          // 1. Build a new iterator or use the existing one.
          if (make_new_iterator) {
            // 1a. Get new input tensors or use the existing ones.
            bool read_new_input;
            {
              tf_shared_lock l(ckpt_mu_);
              // worker_thread_states_[thread_index].input will be non-empty
              // if checkpointing happened at CHECKPOINT_MARKER_A.
              read_new_input =
                  worker_thread_states_[thread_index].input.empty();
            }

            if (read_new_input) {
              mutex_lock l(mu_);
              while (!cancelled_ && !workers_[thread_index].is_producing) {
                RecordStop(ctx.get());
                workers_[thread_index].cond_var.wait(l);
                RecordStart(ctx.get());
              }
              if (cancelled_) return;
              // Copy the input tensors so that we do not need to block on
              // `mu_` when building the iterator.
              // We keep a copy of the input tensors in
              // `WorkerThreadState.input` till the iterator is in use. This is
              // used in `RestoreInternal` to re-build the iterator.
              // TODO(b/78046638): Explore ways to avoid tracking the input
              // tensors.
              tf_shared_lock ckpt_l(ckpt_mu_);
              worker_thread_states_[thread_index].input.swap(
                  workers_[thread_index].input);
              // CHECKPOINT_MARKER_A
              // We have the input tensors but have not built the iterator yet.
            }

            // 1b. Run the user-defined function to produce a new iterator.
            {
              tf_shared_lock l(ckpt_mu_);
              worker_thread_states_[thread_index].iterator_creation_status =
                  MakeIteratorFromInputElement(
                      ctx.get(), worker_thread_states_[thread_index].input,
                      thread_index, *instantiated_captured_func_, prefix(),
                      &worker_thread_states_[thread_index].iterator);
              iterator_creation_status =
                  worker_thread_states_[thread_index].iterator_creation_status;
              if (!iterator_creation_status.ok()) {
                worker_thread_states_[thread_index].input.clear();
              }
              // CHECKPOINT_MARKER_B
              // Either an iterator has been successfully built and placed in
              // `worker_thread_states_[thread_index].iterator` or it failed
              // and a non-OK status has been put in
              // `worker_thread_states_[thread_index].iterator_creation_status`.
            }
          } else {
            tf_shared_lock l(ckpt_mu_);
            iterator_creation_status =
                worker_thread_states_[thread_index].iterator_creation_status;
            // Mark that we have used up the restored iterator.
            make_new_iterator = true;
          }
          // 2. Start producing elements, or send the error state to the
          // client if iterator creation failed.
          if (!iterator_creation_status.ok()) {
            mutex_lock l(mu_);
            // Wait for space in the prefetch queue.
            while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                      dataset()->buffer_output_elements_) {
              RecordStop(ctx.get());
              workers_[thread_index].cond_var.wait(l);
              RecordStart(ctx.get());
            }
            if (cancelled_) return;
            tf_shared_lock ckpt_l(ckpt_mu_);
            workers_[thread_index].outputs.emplace_back(
                iterator_creation_status);
            workers_[thread_index].is_producing = false;
            worker_thread_states_[thread_index].iterator_creation_status =
                Status::OK();
            // CHECKPOINT_MARKER_C
            // The non-OK iterator creation status has been handed to the
            // client.
            workers_[thread_index].cond_var.notify_one();
          } else {
            bool end_of_sequence = false;
            while (!end_of_sequence) {
              // 3a. Produce an element!
              {
                tf_shared_lock ckpt_l(ckpt_mu_);
                if (worker_thread_states_[thread_index]
                        .output_elem.status.ok() &&
                    worker_thread_states_[thread_index]
                        .output_elem.output.empty() &&
                    !worker_thread_states_[thread_index].end_of_sequence) {
                  worker_thread_states_[thread_index].output_elem.status =
                      worker_thread_states_[thread_index].iterator->GetNext(
                          ctx.get(),
                          &worker_thread_states_[thread_index]
                               .output_elem.output,
                          &worker_thread_states_[thread_index].end_of_sequence);
                  end_of_sequence =
                      worker_thread_states_[thread_index].end_of_sequence;
                } else {
                  end_of_sequence =
                      worker_thread_states_[thread_index].end_of_sequence;
                }
                // CHECKPOINT_MARKER_D
                // An element has been read, or an error or end_of_sequence
                // has been received from the input iterator, and is waiting
                // to be sent to the client.
              }

              // 3b. Make it available to the client.
              {
                mutex_lock l(mu_);

                // Wait for space in the prefetch queue.
                while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                          dataset()->buffer_output_elements_) {
                  RecordStop(ctx.get());
                  workers_[thread_index].cond_var.wait(l);
                  RecordStart(ctx.get());
                }
                if (cancelled_) return;

                tf_shared_lock ckpt_l(ckpt_mu_);
                workers_[thread_index].is_producing = !end_of_sequence;

                // Output the element.

                // Move the temporary state in WorkerThreadState to WorkerState
                // and mark it as used.
                if (end_of_sequence) {
                  worker_thread_states_[thread_index].iterator.reset();
                  worker_thread_states_[thread_index].input.clear();
                  worker_thread_states_[thread_index].end_of_sequence = false;
                } else {
                  workers_[thread_index].outputs.emplace_back(
                      worker_thread_states_[thread_index].output_elem.status);
                  workers_[thread_index].outputs.back().output.swap(
                      worker_thread_states_[thread_index].output_elem.output);
                }
                worker_thread_states_[thread_index].output_elem.status =
                    Status::OK();
                if (dataset()->sloppy_) {
                  sloppy_cond_var_.notify_one();
                } else {
                  workers_[thread_index].cond_var.notify_one();
                }
                // CHECKPOINT_MARKER_E
                // Output element or iterator status has been sent to the
                // client.
              }
            }
          }
        }
      }

      Status WriteWorkerStateLocked(IteratorStateWriter* writer, int index)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        string prefix = strings::StrCat("worker_", index);
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(prefix, "_input_size")),
            workers_[index].input.size()));
        for (int i = 0; i < workers_[index].input.size(); ++i) {
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              full_name(strings::StrCat(prefix, "_input_", i)),
              workers_[index].input[i]));
        }
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(prefix, "_outputs_size")),
            workers_[index].outputs.size()));
        for (int i = 0; i < workers_[index].outputs.size(); ++i) {
          TF_RETURN_IF_ERROR(WriteOutputElemLocked(
              writer, workers_[index].outputs[i],
              full_name(strings::StrCat(prefix, "_outputs_", i))));
        }
        if (workers_[index].is_producing) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat(prefix, "_is_producing")), ""));
        }
        return Status::OK();
      }

      Status ReadWorkerStateLocked(IteratorStateReader* reader, int index,
                                   IteratorContext* ctx)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        string worker_prefix = strings::StrCat("worker_", index);
        // Restore inputs.
        int64 input_size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(strings::StrCat(worker_prefix, "_input_size")),
            &input_size));
        workers_[index].input.reserve(input_size);
        for (int i = 0; i < input_size; ++i) {
          workers_[index].input.emplace_back();
          TF_RETURN_IF_ERROR(reader->ReadTensor(
              full_name(strings::StrCat(worker_prefix, "_input_", i)),
              &workers_[index].input.back()));
        }
        int64 outputs_size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(strings::StrCat(worker_prefix, "_outputs_size")),
            &outputs_size));
        for (int i = 0; i < outputs_size; ++i) {
          workers_[index].outputs.emplace_back(Status::OK());
          TF_RETURN_IF_ERROR(ReadOutputElemLocked(
              reader, &workers_[index].outputs.back(),
              full_name(strings::StrCat(worker_prefix, "_outputs_", i))));
        }
        if (reader->Contains(
                full_name(strings::StrCat(worker_prefix, "_is_producing")))) {
          workers_[index].is_producing = true;
        } else {
          workers_[index].is_producing = false;
        }
        return Status::OK();
      }

      Status WriteWorkerThreadStateLocked(IteratorStateWriter* writer,
                                          int index)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        string prefix = strings::StrCat("worker_thread_", index);
        if (worker_thread_states_[index].iterator != nullptr) {
          TF_RETURN_IF_ERROR(
              SaveInput(writer, worker_thread_states_[index].iterator));
        } else {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat(prefix, "_iterator_exhausted")), ""));
        }
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(prefix, "_input_size")),
            worker_thread_states_[index].input.size()));
        for (int i = 0; i < worker_thread_states_[index].input.size(); ++i) {
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              full_name(strings::StrCat(prefix, "_input_", i)),
              worker_thread_states_[index].input[i]));
        }
        TF_RETURN_IF_ERROR(WriteStatusLocked(
            writer, strings::StrCat(prefix, "_iterator_creation_status"),
            worker_thread_states_[index].iterator_creation_status));
        TF_RETURN_IF_ERROR(WriteOutputElemLocked(
            writer, worker_thread_states_[index].output_elem,
            full_name(strings::StrCat(prefix, "_output"))));
        if (worker_thread_states_[index].end_of_sequence) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat(prefix, "_end_of_sequence")), ""));
        }
        return Status::OK();
      }

      Status ReadWorkerThreadStateLocked(IteratorStateReader* reader, int index,
                                         IteratorContext* ctx)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        string worker_prefix = strings::StrCat("worker_thread_", index);
        // Restore inputs.
        int64 input_size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(strings::StrCat(worker_prefix, "_input_size")),
            &input_size));
        worker_thread_states_[index].input.reserve(input_size);
        for (int i = 0; i < input_size; ++i) {
          worker_thread_states_[index].input.emplace_back();
          TF_RETURN_IF_ERROR(reader->ReadTensor(
              full_name(strings::StrCat(worker_prefix, "_input_", i)),
              &worker_thread_states_[index].input.back()));
        }
        // Restore the iterator.
        if (reader->Contains(full_name(
                strings::StrCat(worker_prefix, "_iterator_exhausted")))) {
          worker_thread_states_[index].iterator.reset();
        } else {
          std::unique_ptr<IteratorBase> iterator;
          Status s = MakeIteratorFromInputElement(
              ctx, worker_thread_states_[index].input, index,
              *instantiated_captured_func_, prefix(), &iterator);
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
          worker_thread_states_[index].iterator.swap(iterator);
        }
        TF_RETURN_IF_ERROR(ReadStatusLocked(
            reader, strings::StrCat(worker_prefix, "_iterator_creation_status"),
            &worker_thread_states_[index].iterator_creation_status));
        TF_RETURN_IF_ERROR(ReadOutputElemLocked(
            reader, &worker_thread_states_[index].output_elem,
            full_name(strings::StrCat(worker_prefix, "_output"))));
        if (reader->Contains(full_name(
                strings::StrCat(worker_prefix, "_end_of_sequence")))) {
          worker_thread_states_[index].end_of_sequence = true;
        } else {
          worker_thread_states_[index].end_of_sequence = false;
        }
        return Status::OK();
      }

      Status WriteOutputElemLocked(IteratorStateWriter* writer,
                                   const OutputElem& output_elem,
                                   const string& prefix)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        TF_RETURN_IF_ERROR(WriteStatusLocked(
            writer, strings::StrCat(prefix, "_status"), output_elem.status));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(strings::StrCat(prefix, "_output_size"),
                                output_elem.output.size()));
        for (int i = 0; i < output_elem.output.size(); ++i) {
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              strings::StrCat(prefix, "_output_", i), output_elem.output[i]));
        }
        return Status::OK();
      }

      Status ReadOutputElemLocked(IteratorStateReader* reader,
                                  OutputElem* output_elem, const string& prefix)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        TF_RETURN_IF_ERROR(ReadStatusLocked(
            reader, strings::StrCat(prefix, "_status"), &output_elem->status));
        int64 output_size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            strings::StrCat(prefix, "_output_size"), &output_size));
        output_elem->output.reserve(output_size);
        for (int i = 0; i < output_size; ++i) {
          output_elem->output.emplace_back();
          TF_RETURN_IF_ERROR(
              reader->ReadTensor(strings::StrCat(prefix, "_output_", i),
                                 &output_elem->output.back()));
        }
        return Status::OK();
      }

      Status WriteStatusLocked(IteratorStateWriter* writer,
                               const string& prefix, const Status& status)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")),
                                static_cast<int64>(status.code())));
        if (!status.ok()) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(strings::StrCat(prefix, "_msg")),
                                  status.error_message()));
        }
        return Status::OK();
      }

      Status ReadStatusLocked(IteratorStateReader* reader, const string& prefix,
                              Status* status)
          EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
        int64 code_int;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(strings::StrCat(prefix, "_code")), &code_int));
        error::Code code = static_cast<error::Code>(code_int);

        if (code != error::Code::OK) {
          string error_message;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              full_name(strings::StrCat(prefix, "_msg")), &error_message));
          *status = Status(code, error_message);
        } else {
          *status = Status::OK();
        }
        return Status::OK();
      }

      // Mutex and condition variables to guard mutable iterator internals and
      // coordinate among worker threads and the client thread(s).
      mutex mu_ ACQUIRED_BEFORE(ckpt_mu_);
      // The main thread waits on this condition variable if running in sloppy
      // mode and no values are available.
      condition_variable sloppy_cond_var_;
      // Mutex used to wait for a consistent state while checkpointing.
      // Only Save and Restore require an exclusive lock on this mutex. In
      // other scenarios we just acquire a shared lock, so the pipeline's
      // performance should not be affected in the absence of checkpointing.
      // A thread must not wait on any condition variable while holding
      // `ckpt_mu_` in either shared or exclusive mode.
      mutex ckpt_mu_;
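
      // Informally, the locking discipline that follows from the annotations
      // above is:
      //
      //   // Checkpointing (SaveInternal / RestoreInternal): exclusive locks,
      //   // in ACQUIRED_BEFORE order.
      //   mutex_lock l(mu_);
      //   mutex_lock ckpt_l(ckpt_mu_);
      //
      //   // Hot path (WorkerThread; GetNextInternal never touches ckpt_mu_):
      //   // cond_var waits happen while holding only `mu_`; worker threads
      //   // then take `ckpt_mu_` in shared mode just long enough to update
      //   // their WorkerThreadState.
      //   mutex_lock l(mu_);
      //   ... cond_var.wait(l) as needed ...
      //   tf_shared_lock ckpt_l(ckpt_mu_);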

      // The iterator producing elements which are converted to datasets by
      // the dataset()->captured_func_ and then interleaved together.
      // input_impl_ is reset when we have exhausted its input.
      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);

      std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;

      // The WorkerState structs the worker threads operate on.
      // workers_ elements are in at most one of interleave_indices_ and
      // staging_indices_.
      std::vector<WorkerState> workers_ GUARDED_BY(mu_);

      // Stores the temporary state of WorkerThreads which is not stored in
      // WorkerState. This is used for checkpointing purposes only.
      std::vector<WorkerThreadState> worker_thread_states_ GUARDED_BY(ckpt_mu_);

      // Indices in `workers_` of iterators to interleave.
      std::vector<int64> interleave_indices_ GUARDED_BY(mu_);
      // Indices in `workers_` of prefetched iterators.
      std::deque<int64> staging_indices_ GUARDED_BY(mu_);

      // The index into `interleave_indices_` of the next iterator from which
      // to produce an element.
      size_t next_index_ GUARDED_BY(mu_) = 0;
      // The number of items produced so far within the current block.
      size_t block_count_ GUARDED_BY(mu_) = 0;
      // Flag to instruct the worker threads to exit.
      bool cancelled_ GUARDED_BY(mu_) = false;
      // The worker threads. This must be last to ensure the
      // threads have exited before any other members are deallocated.
      // TODO(b/65178177): Avoid allocating additional threads.
      std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
    };

    const DatasetBase* const input_;
    const NameAttrList interleave_func_;
    const std::unique_ptr<CapturedFunction> captured_func_;
    const int64 cycle_length_;
    const int64 block_length_;
    const bool sloppy_;
    const int64 buffer_output_elements_;
    const int64 prefetch_input_elements_;
    const DataTypeVector output_types_;
    const std::vector<PartialTensorShape> output_shapes_;
  };

  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
  NameAttrList interleave_func_;
};

REGISTER_KERNEL_BUILDER(
    Name("ExperimentalParallelInterleaveDataset").Device(DEVICE_CPU),
    ParallelInterleaveDatasetOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow