/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include <deque>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batch_util.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/typed_queue.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class RandomShuffleQueue : public TypedQueue<std::vector<PersistentTensor> > {
 public:
  RandomShuffleQueue(int32 capacity, int32 min_after_dequeue, int64 seed,
                     int64 seed2, const DataTypeVector& component_dtypes,
                     const std::vector<TensorShape>& component_shapes,
                     const string& name);

  Status Initialize() override;  // Must be called before any other method.

  // Implementations of QueueInterface methods --------------------------------
  void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                  DoneCallback callback) override;
  void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
                      DoneCallback callback) override;
  void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
  void TryDequeueMany(int num_elements, OpKernelContext* ctx,
                      bool allow_small_batch,
                      CallbackWithTuple callback) override;
  Status MatchesNodeDef(const NodeDef& node_def) override;

  int32 size() override {
    mutex_lock lock(mu_);
    return queues_[0].size();
  }

 private:
  ~RandomShuffleQueue() override {}
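  // Note: queues are reference-counted resources, so the destructor above
  // stays private; a queue is destroyed via Unref() when the last reference
  // is released, never deleted directly.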
  // Helper for dequeuing a single random element from queues_.
  void DequeueLocked(OpKernelContext* ctx, Tuple* tuple)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);

  static Status GetElementComponentFromBatch(const Tuple& tuple, int64 index,
                                             int component,
                                             OpKernelContext* ctx,
                                             PersistentTensor* out_tensor);

  const int32 min_after_dequeue_;
  const int64 original_seed_;
  const int64 original_seed2_;

  random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
  random::SingleSampleAdapter<random::PhiloxRandom> generator_ GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueue);
};

RandomShuffleQueue::RandomShuffleQueue(
    int32 capacity, int32 min_after_dequeue, int64 seed, int64 seed2,
    const DataTypeVector& component_dtypes,
    const std::vector<TensorShape>& component_shapes, const string& name)
    : TypedQueue(capacity, component_dtypes, component_shapes, name),
      min_after_dequeue_(min_after_dequeue),
      original_seed_(seed),
      original_seed2_(seed2),
      generator_(&parent_generator_) {
  if (seed == 0 && seed2 == 0) {
    // If both seeds are unspecified, use completely random seeds.
    seed = random::New64();
    seed2 = random::New64();
  }
  parent_generator_ = random::PhiloxRandom(seed, seed2);
}

Status RandomShuffleQueue::Initialize() {
  TF_RETURN_IF_ERROR(TypedQueue::Initialize());

  mutex_lock lock(mu_);
  for (int i = 0; i < num_components(); ++i) {
    queues_[i].reserve(min_after_dequeue_);
  }
  return Status::OK();
}

void RandomShuffleQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
  DCHECK_GT(queues_[0].size(), size_t{0});
  int64 index = generator_() % queues_[0].size();
  (*tuple).reserve(num_components());
  for (int i = 0; i < num_components(); ++i) {
    (*tuple).push_back(*queues_[i][index].AccessTensor(ctx));
    queues_[i][index] = queues_[i].back();
    queues_[i].pop_back();
  }
}

void RandomShuffleQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                                    DoneCallback callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          1, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(errors::Cancelled(
                  "RandomShuffleQueue '", name_, "' is closed."));
              return kComplete;
            }
            if (queues_[0].size() < static_cast<size_t>(capacity_)) {
              for (int i = 0; i < num_components(); ++i) {
                queues_[i].push_back(PersistentTensor(tuple[i]));
              }
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

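// A minimal sketch (not used by the kernel) of the constant-time random
// removal that DequeueLocked, above, performs on each component vector: draw
// a uniform index, move the last element into that slot, and pop the back.
// Element order is not preserved, which is exactly what a shuffling queue
// wants.
//
//   int64 index = generator() % v.size();  // v must be non-empty.
//   T element = v[index];
//   v[index] = v.back();
//   v.pop_back();
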
/* static */
Status RandomShuffleQueue::GetElementComponentFromBatch(
    const Tuple& tuple, int64 index, int component, OpKernelContext* ctx,
    PersistentTensor* out_tensor) {
  TensorShape element_shape(tuple[component].shape());
  element_shape.RemoveDim(0);
  Tensor* element_access = nullptr;
  TF_RETURN_IF_ERROR(ctx->allocate_persistent(
      tuple[component].dtype(), element_shape, out_tensor, &element_access));
  TF_RETURN_IF_ERROR(batch_util::CopySliceToElement(tuple[component],
                                                    element_access, index));
  return Status::OK();
}

void RandomShuffleQueue::TryEnqueueMany(const Tuple& tuple,
                                        OpKernelContext* ctx,
                                        DoneCallback callback) {
  const int64 batch_size = tuple[0].dim_size(0);
  if (batch_size == 0) {
    callback();
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          batch_size, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(errors::Cancelled(
                  "RandomShuffleQueue '", name_, "' is closed."));
              return kComplete;
            }
            RunResult result = kNoProgress;
            while (queues_[0].size() < static_cast<size_t>(capacity_)) {
              result = kProgress;
              const int index =
                  tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                PersistentTensor element;
                attempt->context->SetStatus(GetElementComponentFromBatch(
                    tuple, index, i, attempt->context, &element));
                if (!attempt->context->status().ok()) return kComplete;
                queues_[i].push_back(element);
              }
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

void RandomShuffleQueue::TryDequeue(OpKernelContext* ctx,
                                    CallbackWithTuple callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
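      // The closure below is re-run under mu_ whenever the queue's state
      // changes. Note that unless the queue is closed, a dequeue proceeds
      // only while strictly more than min_after_dequeue_ elements remain:
      // e.g. with min_after_dequeue = 10, a queue holding 11 elements can
      // yield one, while a queue holding 10 blocks.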
      dequeue_attempts_.emplace_back(
          1, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32 queue_size = queues_[0].size();
            if (closed_ && queue_size == 0) {
              attempt->context->SetStatus(errors::OutOfRange(
                  "RandomShuffleQueue '", name_, "' is closed and has ",
                  "insufficient elements (requested ", 1, ", current size ",
                  queue_size, ")"));
              return kComplete;
            }
            if (!closed_) queue_size -= min_after_dequeue_;
            if (queue_size > 0) {
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->done_callback = [callback, tuple]() { callback(tuple); };
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                        bool allow_small_batch,
                                        CallbackWithTuple callback) {
  if (!specified_shapes()) {
    ctx->SetStatus(errors::InvalidArgument(
        "RandomShuffleQueue's DequeueMany and DequeueUpTo require the "
        "components to have specified shapes."));
    callback(Tuple());
    return;
  }
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output(). Problem is
      // this breaks the abstraction boundary since we don't *really*
      // know if and how the Tensors in the tuple we pass to callback
      // correspond to the outputs of *ctx. For example, the
      // ReaderRead Op uses TryDequeue() to get a filename out of a
      // queue that is used internally by the reader and is not
      // associated with any output of the ReaderRead.
      // mrry@ adds:
      // Maybe we need to pass a std::function<Tensor*(...)> (or
      // better signature) that calls the appropriate allocator
      // function in addition to ctx? (Or support a shim Allocator
      // that has an internal OpKernelContext*, and dispatches to the
      // appropriate method?)
      // misard@ adds:
      // I don't see that a std::function would help. The problem is
      // that at this point (allocation time) the system doesn't know
      // what is going to happen to the element read out of the
      // queue. As long as we keep the generality that TensorFlow Ops
      // do their own dynamic allocation in arbitrary C++ code, we
      // need to preserve robustness to allocating output Tensors with
      // the 'wrong' attributes, and fixing up with a copy. The only
      // improvement I can see here in the future would be to support
      // an optimized case where the queue 'knows' what attributes to
      // use, and plumbs them through here.
      Tensor element;
      Status s = ctx->allocate_temp(component_dtypes_[i], ManyOutShape(i, 0),
                                    &element);
      if (!s.ok()) {
        ctx->SetStatus(s);
        callback(Tuple());
        return;
      }
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
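      // The closure below accumulates dequeued elements into attempt->tuple
      // across invocations. If, say, 64 elements are requested and the queue
      // is closed after yielding only 10, DequeueUpTo (allow_small_batch)
      // returns the remaining elements as a short batch, while DequeueMany
      // restores them to the queue and reports OutOfRange.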
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, allow_small_batch,
           this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32 queue_size = queues_[0].size();
            if (closed_ && queue_size < attempt->elements_requested) {
              // If we don't have enough for a full dequeue, we have
              // to reset the attempt tuple.
              if (!attempt->tuple.empty()) {
                // Restore already-dequeued elements to the queue.
                for (int64 i = attempt->tuple[0].dim_size(0) -
                               attempt->elements_requested - 1;
                     i >= 0; --i) {
                  for (int j = 0; j < num_components(); ++j) {
                    PersistentTensor element;
                    Status s = GetElementComponentFromBatch(
                        attempt->tuple, i, j, attempt->context, &element);
                    if (!s.ok()) {
                      attempt->context->SetStatus(
                          errors::DataLoss("Failed to restore element from "
                                           "partially-dequeued batch "
                                           "to RandomShuffleQueue: ",
                                           s.error_message()));
                    }
                    queues_[j].push_back(element);
                  }
                }
              }
              if (allow_small_batch && !queues_[0].empty()) {
                // Request all remaining elements in the queue.
                queue_size = queues_[0].size();
                attempt->tuple.clear();
                attempt->elements_requested = queue_size;
              } else {
                if (allow_small_batch) {
                  // There may be some other attempts containing
                  // values. If so, we'll yield and wait for them
                  // to add elements to the queue.
                  if (!enqueue_attempts_.empty()) return kProgress;
                }
                if (attempt->context->status().ok()) {
                  attempt->context->SetStatus(errors::OutOfRange(
                      "RandomShuffleQueue '", name_, "' is closed and has ",
                      "insufficient elements (requested ",
                      attempt->elements_requested, ", current size ",
                      queue_size, ")"));
                }
                return kComplete;
              }
            }

            RunResult result = kNoProgress;
            if (!closed_) queue_size -= min_after_dequeue_;
            for (; queue_size > 0; --queue_size) {
              if (attempt->tuple.empty()) {
                // Only allocate tuple when we have something to dequeue
                // so we don't use excessive memory when there are many
                // blocked dequeue attempts waiting.
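                // The batch is materialized as one preallocated tensor per
                // component (shape [elements_requested, element_shape]) and
                // filled in place below, one row per dequeued element.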
                attempt->tuple.reserve(num_components());
                for (int i = 0; i < num_components(); ++i) {
                  const TensorShape shape =
                      ManyOutShape(i, attempt->elements_requested);
                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;
                  attempt->tuple.emplace_back(element);
                }
              }
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              const int index =
                  attempt->tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                attempt->context->SetStatus(batch_util::CopyElementToSlice(
                    std::move(tuple[i]), &attempt->tuple[i], index));
                if (!attempt->context->status().ok()) return kComplete;
              }
              tuple.clear();
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                tuple = attempt->tuple;
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "RandomShuffleQueue").ok() &&
      !MatchesNodeDefOp(node_def, "RandomShuffleQueueV2").ok()) {
    return errors::InvalidArgument("Expected RandomShuffleQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));

  int32 min_after_dequeue = -1;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(node_def, "min_after_dequeue", &min_after_dequeue));
  if (min_after_dequeue != min_after_dequeue_) {
    return errors::InvalidArgument(
        "Shared queue '", name_, "' has min_after_dequeue ",
        min_after_dequeue_, " but requested min_after_dequeue was ",
        min_after_dequeue, ".");
  }

  int64 seed = -1;
  int64 seed2 = -1;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "seed", &seed));
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "seed2", &seed2));
  if ((seed != 0 || seed2 != 0) &&
      (seed != original_seed_ || seed2 != original_seed2_)) {
    return errors::InvalidArgument(
        "Shared queue '", name_, "' has random seeds (", original_seed_, ", ",
        original_seed2_, ") but requested seeds are (", seed, ", ", seed2,
        ").");
  }

  TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(MatchesNodeDefShapes(node_def));

  return Status::OK();
}

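// Note on MatchesNodeDef above: requested seeds of (0, 0) mean "unspecified"
// and match a shared queue regardless of the seeds it was created with; any
// other pair must match the creating node's seeds exactly. For example, a
// queue created with seeds (1, 2) can be shared by a node requesting (0, 0)
// or (1, 2), but not by one requesting (3, 4).
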
// Defines a RandomShuffleQueueOp, which produces a Queue (specifically, one
// backed by RandomShuffleQueue) that persists across different graph
// executions, and sessions. Running this op produces a single-element
// tensor of handles to Queues in the corresponding device.
class RandomShuffleQueueOp : public TypedQueueOp {
 public:
  explicit RandomShuffleQueueOp(OpKernelConstruction* context)
      : TypedQueueOp(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("min_after_dequeue", &min_after_dequeue_));
    OP_REQUIRES(context, min_after_dequeue_ >= 0,
                errors::InvalidArgument("min_after_dequeue ",
                                        min_after_dequeue_, " must be >= 0"));
    OP_REQUIRES(
        context, min_after_dequeue_ < capacity_,
        errors::InvalidArgument("min_after_dequeue ", min_after_dequeue_,
                                " must be < capacity ", capacity_));
    OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
    OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));

    OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
  }

 private:
  Status CreateResource(QueueInterface** ret) override
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    RandomShuffleQueue* queue = new RandomShuffleQueue(
        capacity_, min_after_dequeue_, seed_, seed2_, component_types_,
        component_shapes_, cinfo_.name());
    return CreateTypedQueue(queue, ret);
  }

  int32 min_after_dequeue_;
  int64 seed_;
  int64 seed2_;
  std::vector<TensorShape> component_shapes_;

  TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueueOp);
};

REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueue").Device(DEVICE_CPU),
                        RandomShuffleQueueOp);
REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueueV2").Device(DEVICE_CPU),
                        RandomShuffleQueueOp);

}  // namespace tensorflow