1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <deque> 16 17 #include "tensorflow/core/framework/partial_tensor_shape.h" 18 #include "tensorflow/core/framework/tensor.h" 19 #include "tensorflow/core/kernels/data/dataset.h" 20 #include "tensorflow/core/lib/core/error_codes.pb.h" 21 22 namespace tensorflow { 23 24 namespace { 25 26 // See documentation in ../ops/dataset_ops.cc for a high-level 27 // description of the following op. 28 29 class PrefetchDatasetOp : public UnaryDatasetOpKernel { 30 public: 31 explicit PrefetchDatasetOp(OpKernelConstruction* ctx) 32 : UnaryDatasetOpKernel(ctx) {} 33 34 protected: 35 void MakeDataset(OpKernelContext* ctx, DatasetBase* input, 36 DatasetBase** output) override { 37 int64 buffer_size; 38 OP_REQUIRES_OK( 39 ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); 40 OP_REQUIRES(ctx, buffer_size > 0, 41 errors::InvalidArgument("buffer_size must be > 0")); 42 43 *output = new Dataset(ctx, input, buffer_size); 44 } 45 46 private: 47 class Dataset : public GraphDatasetBase { 48 public: 49 Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size) 50 : GraphDatasetBase(ctx), input_(input), buffer_size_(buffer_size) { 51 input_->Ref(); 52 } 53 54 ~Dataset() override { input_->Unref(); } 55 56 std::unique_ptr<IteratorBase> MakeIterator( 57 const string& prefix) const override { 58 return std::unique_ptr<IteratorBase>( 59 new Iterator({this, strings::StrCat(prefix, "::Prefetch")})); 60 } 61 62 const DataTypeVector& output_dtypes() const override { 63 return input_->output_dtypes(); 64 } 65 const std::vector<PartialTensorShape>& output_shapes() const override { 66 return input_->output_shapes(); 67 } 68 69 string DebugString() override { return "PrefetchDatasetOp::Dataset"; } 70 71 protected: 72 Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, 73 Node** output) const override { 74 Node* input_graph_node = nullptr; 75 TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); 76 Node* buffer_size = nullptr; 77 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); 78 TF_RETURN_IF_ERROR( 79 b->AddDataset(this, {input_graph_node, buffer_size}, output)); 80 return Status::OK(); 81 } 82 83 private: 84 class Iterator : public DatasetIterator<Dataset> { 85 public: 86 explicit Iterator(const Params& params) 87 : DatasetIterator<Dataset>(params), 88 input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} 89 90 ~Iterator() override { 91 // Signal the prefetch thread to terminate it. We will then 92 // join that thread when we delete `this->prefetch_thread_`. 93 // 94 // TODO(mrry): Replace this cancellation logic with a 95 // CancellationManager. The syntax would be more heavyweight, 96 // but it would be possible to thread a cancellation manager 97 // through the IteratorContext to upstream, 98 // potentially-blocking iterators, when we add these. 99 { 100 mutex_lock l(mu_); 101 cancelled_ = true; 102 cond_var_.notify_all(); 103 } 104 } 105 106 Status GetNextInternal(IteratorContext* ctx, 107 std::vector<Tensor>* out_tensors, 108 bool* end_of_sequence) override { 109 mutex_lock l(mu_); 110 TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx)); 111 112 while (true) { 113 // Wait until the next element in the buffer has been 114 // produced, or we are shutting down. 115 while (!cancelled_ && !prefetch_thread_finished_ && buffer_.empty()) { 116 cond_var_.wait(l); 117 } 118 119 if (cancelled_) { 120 return errors::Cancelled( 121 "PrefetchDatasetOp::Dataset::Iterator::GetNext"); 122 } 123 124 if (!buffer_.empty()) { 125 // A new element is available. Forward the status from 126 // computing it, and (if we successfully got an element) 127 // the output values. 128 Status s = buffer_.front().status; 129 if (s.ok()) { 130 *out_tensors = std::move(buffer_.front().value); 131 } 132 buffer_.pop_front(); 133 *end_of_sequence = false; 134 135 // Wake the prefetch thread, in case it has been waiting 136 // for space in the buffer. 137 // Also wake up threads from other calls to GetNext. 138 // TODO(mrry): Consider using different condition variables 139 // for GetNext and Prefetch. 140 cond_var_.notify_all(); 141 return s; 142 } else if (prefetch_thread_finished_) { 143 *end_of_sequence = true; 144 return Status::OK(); 145 } 146 } 147 } 148 149 protected: 150 Status SaveInternal(IteratorStateWriter* writer) override { 151 // Acquire both locks to ensure that the prefetch thread and 152 // all GetNext threads are blocked. 153 mutex_lock parent_l(parent_mu_); 154 mutex_lock l(mu_); 155 TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); 156 TF_RETURN_IF_ERROR( 157 writer->WriteScalar(full_name("buffer_size"), buffer_.size())); 158 for (size_t i = 0; i < buffer_.size(); i++) { 159 auto& buffer_element = buffer_[i]; 160 TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status)); 161 if (buffer_element.status.ok()) { 162 TF_RETURN_IF_ERROR(writer->WriteScalar( 163 full_name(strings::StrCat("buffer[", i, "].size")), 164 buffer_element.value.size())); 165 for (size_t j = 0; j < buffer_element.value.size(); j++) { 166 TF_RETURN_IF_ERROR(writer->WriteTensor( 167 full_name(strings::StrCat("buffer[", i, "][", j, "]")), 168 buffer_element.value[j])); 169 } 170 } 171 } 172 return Status::OK(); 173 } 174 175 Status RestoreInternal(IteratorContext* ctx, 176 IteratorStateReader* reader) override { 177 mutex_lock parent_l(parent_mu_); 178 mutex_lock l(mu_); 179 buffer_.clear(); 180 TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); 181 size_t buffer_size; 182 { 183 int64 temp; 184 TF_RETURN_IF_ERROR( 185 reader->ReadScalar(full_name("buffer_size"), &temp)); 186 buffer_size = static_cast<size_t>(temp); 187 } 188 for (size_t i = 0; i < buffer_size; i++) { 189 buffer_.emplace_back(); 190 auto& buffer_element = buffer_.back(); 191 TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status)); 192 if (buffer_element.status.ok()) { 193 size_t value_size; 194 { 195 int64 temp; 196 TF_RETURN_IF_ERROR(reader->ReadScalar( 197 full_name(strings::StrCat("buffer[", i, "].size")), &temp)); 198 value_size = static_cast<size_t>(temp); 199 } 200 buffer_element.value.reserve(value_size); 201 for (size_t j = 0; j < value_size; j++) { 202 buffer_element.value.emplace_back(); 203 TF_RETURN_IF_ERROR(reader->ReadTensor( 204 full_name(strings::StrCat("buffer[", i, "][", j, "]")), 205 &buffer_element.value.back())); 206 } 207 } 208 } 209 return Status::OK(); 210 } 211 212 private: 213 // A buffer element comprises a status and (if that status is 214 // OK) a vector of tensors, representing an element of the input dataset. 215 struct BufferElement { 216 // The producer sets `status` if getting the input element fails. 217 Status status; 218 // The buffered data element. 219 std::vector<Tensor> value; 220 }; 221 222 Status EnsurePrefetchThreadStarted(IteratorContext* ctx) 223 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 224 if (!prefetch_thread_) { 225 prefetch_thread_.reset( 226 ctx->env()->StartThread({}, "prefetch_thread", 227 std::bind(&Iterator::PrefetchThread, this, 228 new IteratorContext(*ctx)))); 229 } 230 return Status::OK(); 231 } 232 233 // Prefetches elements of the input, storing results in an internal 234 // buffer. 235 // 236 // It owns the iterator context passed to it. 237 void PrefetchThread(IteratorContext* ctx) { 238 std::unique_ptr<IteratorContext> cleanup(ctx); 239 while (true) { 240 std::vector<Tensor> value; 241 242 // 1. Wait for a slot in the buffer. 243 { 244 mutex_lock l(mu_); 245 while (!cancelled_ && buffer_.size() == dataset()->buffer_size_) { 246 cond_var_.wait(l); 247 } 248 249 if (cancelled_) { 250 return; 251 } 252 } 253 254 // 2. Read the next element. 255 // Acquire the parent lock since we will be reading an element 256 // from the input iterator. Note that we do not wish to release 257 // this lock till we have added the fetched element to the 258 // `buffer_` else there will be local state that may be missed 259 // by SaveInternal. 260 mutex_lock parent_l(parent_mu_); 261 bool end_of_sequence; 262 BufferElement buffer_element; 263 buffer_element.status = input_impl_->GetNext( 264 ctx, &buffer_element.value, &end_of_sequence); 265 if (buffer_element.status.ok() && end_of_sequence) { 266 mutex_lock l(mu_); 267 prefetch_thread_finished_ = true; 268 cond_var_.notify_all(); 269 return; 270 } 271 272 // 3. Signal that the element has been produced. 273 { 274 mutex_lock l(mu_); 275 buffer_.push_back(std::move(buffer_element)); 276 cond_var_.notify_all(); 277 } 278 } 279 } 280 281 Status WriteStatus(IteratorStateWriter* writer, size_t index, 282 const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 283 TF_RETURN_IF_ERROR(writer->WriteScalar( 284 CodeKey(index), static_cast<int64>(status.code()))); 285 if (!status.ok()) { 286 TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index), 287 status.error_message())); 288 } 289 return Status::OK(); 290 } 291 292 Status ReadStatus(IteratorStateReader* reader, size_t index, 293 Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 294 int64 code_int; 295 TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); 296 error::Code code = static_cast<error::Code>(code_int); 297 298 if (code != error::Code::OK) { 299 string error_message; 300 TF_RETURN_IF_ERROR( 301 reader->ReadScalar(ErrorMessageKey(index), &error_message)); 302 *status = Status(code, error_message); 303 } else { 304 *status = Status::OK(); 305 } 306 return Status::OK(); 307 } 308 309 string CodeKey(size_t index) { 310 return full_name(strings::StrCat("status[", index, "].code")); 311 } 312 313 string ErrorMessageKey(size_t index) { 314 return full_name(strings::StrCat("status[", index, "].error_message")); 315 } 316 317 // This mutex is used to ensure exclusivity between multiple threads 318 // reading/writing this iterator's local state. 319 mutex mu_; 320 // This mutex is used to ensure exclusivity between multiple threads 321 // accessing the parent iterator. We keep this separate from `mu_` to 322 // allow prefetching to run in parallel with GetNext calls. 323 mutex parent_mu_ ACQUIRED_BEFORE(mu_); 324 const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_); 325 condition_variable cond_var_; 326 std::deque<BufferElement> buffer_ GUARDED_BY(mu_); 327 std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_); 328 bool cancelled_ GUARDED_BY(mu_) = false; 329 bool prefetch_thread_finished_ GUARDED_BY(mu_) = false; 330 }; 331 332 const DatasetBase* const input_; 333 const int64 buffer_size_; 334 }; 335 }; 336 337 REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU), 338 PrefetchDatasetOp); 339 340 } // namespace 341 342 } // namespace tensorflow 343