/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include <deque>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batch_util.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/typed_queue.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class RandomShuffleQueue : public TypedQueue<std::vector<PersistentTensor> > {
 public:
  RandomShuffleQueue(int32 capacity, int32 min_after_dequeue, int64 seed,
                     int64 seed2, const DataTypeVector& component_dtypes,
                     const std::vector<TensorShape>& component_shapes,
                     const string& name);

  Status Initialize() override;  // Must be called before any other method.

  // Implementations of QueueInterface methods --------------------------------
  void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                  DoneCallback callback) override;
  void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
                      DoneCallback callback) override;
  void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
  void TryDequeueMany(int num_elements, OpKernelContext* ctx,
                      bool allow_small_batch,
                      CallbackWithTuple callback) override;
  Status MatchesNodeDef(const NodeDef& node_def) override;

  int32 size() override {
    mutex_lock lock(mu_);
    return queues_[0].size();
  }

 private:
  ~RandomShuffleQueue() override {}

  // Helper for dequeuing a single random element from queues_.
  void DequeueLocked(OpKernelContext* ctx, Tuple* tuple)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);

  static Status GetElementComponentFromBatch(const Tuple& tuple, int64 index,
                                             int component,
                                             OpKernelContext* ctx,
                                             PersistentTensor* out_tensor);

  const int32 min_after_dequeue_;
  const int64 original_seed_;
  const int64 original_seed2_;

  random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
  random::SingleSampleAdapter<random::PhiloxRandom> generator_ GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueue);
};

RandomShuffleQueue::RandomShuffleQueue(
    int32 capacity, int32 min_after_dequeue, int64 seed, int64 seed2,
    const DataTypeVector& component_dtypes,
    const std::vector<TensorShape>& component_shapes, const string& name)
    : TypedQueue(capacity, component_dtypes, component_shapes, name),
      min_after_dequeue_(min_after_dequeue),
      original_seed_(seed),
      original_seed2_(seed2),
      generator_(&parent_generator_) {
  if (seed == 0 && seed2 == 0) {
    // If both seeds are unspecified, use completely random seeds.
    seed = random::New64();
    seed2 = random::New64();
  }
  parent_generator_ = random::PhiloxRandom(seed, seed2);
}

Status RandomShuffleQueue::Initialize() {
  TF_RETURN_IF_ERROR(TypedQueue::Initialize());

  mutex_lock lock(mu_);
  for (int i = 0; i < num_components(); ++i) {
    queues_[i].reserve(min_after_dequeue_);
  }
  return Status::OK();
}

void RandomShuffleQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
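  // Pick a random slot, then remove it in O(1) per component by swapping the
  // chosen slot with the last element and popping the back.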
  DCHECK_GT(queues_[0].size(), size_t{0});
  int64 index = generator_() % queues_[0].size();
  (*tuple).reserve(num_components());
  for (int i = 0; i < num_components(); ++i) {
    (*tuple).push_back(*queues_[i][index].AccessTensor(ctx));
    queues_[i][index] = queues_[i].back();
    queues_[i].pop_back();
  }
}

void RandomShuffleQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                                    DoneCallback callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
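      // The attempt closure runs with mu_ held whenever the queue is flushed:
      // it returns kComplete once the element has been enqueued (or the queue
      // has been closed), and kNoProgress while the queue is at capacity.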
      enqueue_attempts_.emplace_back(
          1, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(errors::Cancelled(
                  "RandomShuffleQueue '", name_, "' is closed."));
              return kComplete;
            }
            if (queues_[0].size() < static_cast<size_t>(capacity_)) {
              for (int i = 0; i < num_components(); ++i) {
                queues_[i].push_back(PersistentTensor(tuple[i]));
              }
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

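// Copies element `index` of component `component` out of a batched input
// tuple into a newly allocated PersistentTensor.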
/* static */
Status RandomShuffleQueue::GetElementComponentFromBatch(
    const Tuple& tuple, int64 index, int component, OpKernelContext* ctx,
    PersistentTensor* out_tensor) {
  TensorShape element_shape(tuple[component].shape());
  element_shape.RemoveDim(0);
  Tensor* element_access = nullptr;
  TF_RETURN_IF_ERROR(ctx->allocate_persistent(
      tuple[component].dtype(), element_shape, out_tensor, &element_access));
  TF_RETURN_IF_ERROR(
      batch_util::CopySliceToElement(tuple[component], element_access, index));
  return Status::OK();
}

void RandomShuffleQueue::TryEnqueueMany(const Tuple& tuple,
                                        OpKernelContext* ctx,
                                        DoneCallback callback) {
  const int64 batch_size = tuple[0].dim_size(0);
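  // Enqueueing an empty batch is a no-op that completes immediately.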
  if (batch_size == 0) {
    callback();
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          batch_size, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(errors::Cancelled(
                  "RandomShuffleQueue '", name_, "' is closed."));
              return kComplete;
            }
            RunResult result = kNoProgress;
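            // Move one element of the input batch into the queue per
            // iteration while there is spare capacity; if the queue fills up
            // before the batch is exhausted, report kProgress so the attempt
            // is retried later.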
            while (queues_[0].size() < static_cast<size_t>(capacity_)) {
              result = kProgress;
              const int index =
                  tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                PersistentTensor element;
                attempt->context->SetStatus(GetElementComponentFromBatch(
                    tuple, index, i, attempt->context, &element));
                if (!attempt->context->status().ok()) return kComplete;
                queues_[i].push_back(element);
              }
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

void RandomShuffleQueue::TryDequeue(OpKernelContext* ctx,
                                    CallbackWithTuple callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          1, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32 queue_size = queues_[0].size();
            if (closed_ && queue_size == 0) {
              attempt->context->SetStatus(errors::OutOfRange(
                  "RandomShuffleQueue '", name_, "' is closed and has ",
                  "insufficient elements (requested ", 1, ", current size ",
                  queue_size, ")"));
              return kComplete;
            }
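            // While the queue is still open, keep at least min_after_dequeue_
            // elements buffered so dequeued elements stay well shuffled.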
            if (!closed_) queue_size -= min_after_dequeue_;
            if (queue_size > 0) {
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->done_callback = [callback, tuple]() { callback(tuple); };
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                        bool allow_small_batch,
                                        CallbackWithTuple callback) {
  if (!specified_shapes()) {
    ctx->SetStatus(errors::InvalidArgument(
        "RandomShuffleQueue's DequeueMany and DequeueUpTo require the "
        "components to have specified shapes."));
    callback(Tuple());
    return;
  }
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output().  Problem is
      // this breaks the abstraction boundary since we don't *really*
      // know if and how the Tensors in the tuple we pass to callback
      // correspond to the outputs of *ctx.  For example, the
      // ReaderRead Op uses TryDequeue() to get a filename out of a
      // queue that is used internally by the reader and is not
      // associated with any output of the ReaderRead.
      // mrry@ adds:
      // Maybe we need to pass a std::function<Tensor*(...)> (or
      // better signature) that calls the appropriate allocator
      // function in addition to ctx?  (Or support a shim Allocator
      // that has an internal OpKernelContext*, and dispatches to the
      // appropriate method?)
      // misard@ adds:
      // I don't see that a std::function would help. The problem is
      // that at this point (allocation time) the system doesn't know
      // what is going to happen to the element read out of the
      // queue. As long as we keep the generality that TensorFlow Ops
      // do their own dynamic allocation in arbitrary C++ code, we
      // need to preserve robustness to allocating output Tensors with
      // the 'wrong' attributes, and fixing up with a copy. The only
      // improvement I can see here in the future would be to support
      // an optimized case where the queue 'knows' what attributes to
      // use, and plumbs them through here.
      Tensor element;
      Status s = ctx->allocate_temp(component_dtypes_[i], ManyOutShape(i, 0),
                                    &element);
      if (!s.ok()) {
        ctx->SetStatus(s);
        callback(Tuple());
        return;
      }
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, allow_small_batch,
           this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32 queue_size = queues_[0].size();
            if (closed_ && queue_size < attempt->elements_requested) {
              // If we don't have enough for a full dequeue, we have
              // to reset the attempt tuple.
              if (!attempt->tuple.empty()) {
                // Restore already-dequeued elements to the queue.
                for (int64 i = attempt->tuple[0].dim_size(0) -
                               attempt->elements_requested - 1;
                     i >= 0; --i) {
                  for (int j = 0; j < num_components(); ++j) {
                    PersistentTensor element;
                    Status s = GetElementComponentFromBatch(
                        attempt->tuple, i, j, attempt->context, &element);
                    if (!s.ok()) {
                      attempt->context->SetStatus(
                          errors::DataLoss("Failed to restore element from "
                                           "partially-dequeued batch "
                                           "to RandomShuffleQueue: ",
                                           s.error_message()));
                    }
                    queues_[j].push_back(element);
                  }
                }
              }
              if (allow_small_batch && !queues_[0].empty()) {
                // Request all remaining elements in the queue.
                queue_size = queues_[0].size();
                attempt->tuple.clear();
                attempt->elements_requested = queue_size;
              } else {
                if (allow_small_batch) {
                  // There may be some other attempts containing
                  // values.  If so, we'll yield and wait for them
                  // to add elements to the queue.
                  if (!enqueue_attempts_.empty()) return kProgress;
                }
                if (attempt->context->status().ok()) {
                  attempt->context->SetStatus(errors::OutOfRange(
                      "RandomShuffleQueue '", name_, "' is closed and has ",
                      "insufficient elements (requested ",
                      attempt->elements_requested, ", current size ",
                      queue_size, ")"));
                }
                return kComplete;
              }
            }

            RunResult result = kNoProgress;
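            // While the queue is still open, keep min_after_dequeue_ elements
            // in reserve to preserve shuffling quality; otherwise dequeue one
            // random element per iteration into the output batch.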
            if (!closed_) queue_size -= min_after_dequeue_;
            for (; queue_size > 0; --queue_size) {
              if (attempt->tuple.empty()) {
                // Only allocate tuple when we have something to dequeue
                // so we don't use excessive memory when there are many
                // blocked dequeue attempts waiting.
                attempt->tuple.reserve(num_components());
                for (int i = 0; i < num_components(); ++i) {
                  const TensorShape shape =
                      ManyOutShape(i, attempt->elements_requested);
                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;
                  attempt->tuple.emplace_back(element);
                }
              }
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              const int index =
                  attempt->tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                attempt->context->SetStatus(batch_util::CopyElementToSlice(
                    std::move(tuple[i]), &attempt->tuple[i], index));
                if (!attempt->context->status().ok()) return kComplete;
              }
              tuple.clear();
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                tuple = attempt->tuple;
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

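// Verifies that a NodeDef that attempts to share this queue was created with
// attributes (capacity, min_after_dequeue, seeds, dtypes, shapes) compatible
// with this queue's.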
Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "RandomShuffleQueue").ok() &&
      !MatchesNodeDefOp(node_def, "RandomShuffleQueueV2").ok()) {
    return errors::InvalidArgument("Expected RandomShuffleQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));

  int32 min_after_dequeue = -1;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(node_def, "min_after_dequeue", &min_after_dequeue));
  if (min_after_dequeue != min_after_dequeue_) {
    return errors::InvalidArgument(
        "Shared queue '", name_, "' has min_after_dequeue ", min_after_dequeue_,
        " but requested min_after_dequeue was ", min_after_dequeue, ".");
  }

  int64 seed = -1;
  int64 seed2 = -1;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "seed", &seed));
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "seed2", &seed2));
  if ((seed != 0 || seed2 != 0) &&
      (seed != original_seed_ || seed2 != original_seed2_)) {
    return errors::InvalidArgument(
        "Shared queue '", name_, "' has random seeds (", original_seed_, ", ",
        original_seed2_, ") but requested seeds are (", seed, ", ", seed2,
        ").");
  }

  TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(MatchesNodeDefShapes(node_def));

  return Status::OK();
}

// Defines a RandomShuffleQueueOp, which produces a Queue (specifically, one
// backed by RandomShuffleQueue) that persists across different graph
// executions and sessions. Running this op produces a single-element
// tensor of handles to Queues in the corresponding device.
class RandomShuffleQueueOp : public TypedQueueOp {
 public:
  explicit RandomShuffleQueueOp(OpKernelConstruction* context)
      : TypedQueueOp(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("min_after_dequeue", &min_after_dequeue_));
    OP_REQUIRES(context, min_after_dequeue_ >= 0,
                errors::InvalidArgument("min_after_dequeue ",
                                        min_after_dequeue_, " must be >= 0"));
    OP_REQUIRES(
        context, min_after_dequeue_ < capacity_,
        errors::InvalidArgument("min_after_dequeue ", min_after_dequeue_,
                                " must be < capacity ", capacity_));
    OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
    OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));

    OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
  }

 private:
  Status CreateResource(QueueInterface** ret) override
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    RandomShuffleQueue* queue = new RandomShuffleQueue(
        capacity_, min_after_dequeue_, seed_, seed2_, component_types_,
        component_shapes_, cinfo_.name());
    return CreateTypedQueue(queue, ret);
  }

  int32 min_after_dequeue_;
  int64 seed_;
  int64 seed2_;
  std::vector<TensorShape> component_shapes_;

  TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueueOp);
};

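// The reference-typed RandomShuffleQueue op and the resource-based
// RandomShuffleQueueV2 op are both backed by the same kernel on CPU.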
REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueue").Device(DEVICE_CPU),
                        RandomShuffleQueueOp);
REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueueV2").Device(DEVICE_CPU),
                        RandomShuffleQueueOp);

}  // namespace tensorflow