/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.

#include <deque>
#include <queue>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batch_util.h"
#include "tensorflow/core/kernels/priority_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/priority_queue_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

PriorityQueue::PriorityQueue(int32 capacity,
                             const DataTypeVector& component_dtypes,
                             const std::vector<TensorShape>& component_shapes,
                             const string& name)
    : TypedQueue(capacity, component_dtypes, component_shapes, name) {}

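// Validates the queue configuration after TypedQueue::Initialize(): component
// 0 holds the priority index, so it must have dtype DT_INT64 and, when
// component shapes are specified, must be a scalar.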
Status PriorityQueue::Initialize() {
  Status s = TypedQueue::Initialize();
  if (!s.ok()) return s;

  mutex_lock lock(mu_);
  if (component_dtypes_[0] != DT_INT64) {
    return errors::InvalidArgument(
        "PriorityQueue priority index component must be type int64, but "
        "dtype is: ",
        DataTypeString(component_dtypes_[0]));
  }
  if (specified_shapes() && !TensorShapeUtils::IsScalar(component_shapes_[0])) {
    return errors::InvalidArgument(
        "PriorityQueue priority index component must be a scalar, but shape "
        "is: ",
        component_shapes_[0].DebugString());
  }
  return Status::OK();
}

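// Pops the top (highest-priority) entry of every component queue into *tuple.
// Callers must already hold mu_ and must guarantee the queue is non-empty.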
void PriorityQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
  DCHECK_GT(queues_[0].size(), 0);
  (*tuple).reserve(num_components());
  for (int i = 0; i < num_components(); ++i) {
    PersistentTensor persistent_tensor = gtl::ConsumeTop(&queues_[i]).second;
    (*tuple).push_back(*persistent_tensor.AccessTensor(ctx));
  }
}

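// Attempts to enqueue a single element.  Component 0 of `tuple` must be a
// scalar int64 priority.  The work is registered as a cancellable Attempt
// that completes once there is spare capacity, the queue is closed, or the
// operation is cancelled.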
void PriorityQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                               DoneCallback callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          1, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("PriorityQueue '", name_, "' is closed."));
              return kComplete;
            }
            if (queues_[0].size() < static_cast<size_t>(capacity_)) {
              if (!TensorShapeUtils::IsScalar(tuple[0].shape())) {
                attempt->context->SetStatus(errors::InvalidArgument(
                    "Expected the priority element to be a scalar, but "
                    "received shape: ",
                    tuple[0].shape().DebugString()));
                return kComplete;
              }
              const int64 priority = tuple[0].scalar<int64>()();
              for (int i = 0; i < num_components(); ++i) {
                queues_[i].emplace(priority, PersistentTensor(tuple[i]));
              }
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

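// Copies element `index` of component `component` out of a batched tuple into
// a freshly allocated PersistentTensor whose shape drops the batch dimension.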
/* static */
Status PriorityQueue::GetElementComponentFromBatch(
    const PriorityQueue::Tuple& tuple, int index, int component,
    OpKernelContext* ctx, PersistentTensor* out_tensor) {
  TensorShape element_shape(tuple[component].shape());
  element_shape.RemoveDim(0);
  Tensor* element_access = nullptr;
  TF_RETURN_IF_ERROR(ctx->allocate_persistent(
      tuple[component].dtype(), element_shape, out_tensor, &element_access));
  TF_RETURN_IF_ERROR(
      batch_util::CopySliceToElement(tuple[component], element_access, index));
  return Status::OK();
}

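// Attempts to enqueue a batch of elements.  The batch is unpacked and inserted
// one element at a time as capacity becomes available; component 0 of each
// element must be a scalar int64 priority.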
void PriorityQueue::TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
                                   DoneCallback callback) {
  const int64 batch_size = tuple[0].dim_size(0);
  if (batch_size == 0) {
    callback();
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          batch_size, callback, ctx, cm, token,
          [tuple, this, ctx](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("PriorityQueue '", name_, "' is closed."));
              return kComplete;
            }
            RunResult result = kNoProgress;
            while (queues_[0].size() < static_cast<size_t>(capacity_)) {
              result = kProgress;
              const int index =
                  tuple[0].dim_size(0) - attempt->elements_requested;

              PersistentTensor priority_element;
              attempt->context->SetStatus(GetElementComponentFromBatch(
                  tuple, index, 0, attempt->context, &priority_element));
              if (!attempt->context->status().ok()) return kComplete;
              Tensor* priority_tensor = priority_element.AccessTensor(ctx);
              if (!TensorShapeUtils::IsScalar(priority_tensor->shape())) {
                attempt->context->SetStatus(errors::InvalidArgument(
                    "Expected the priority element to be a scalar, but "
                    "received shape: ",
                    priority_tensor->shape().DebugString()));
                return kComplete;
              }
              const int64 priority = priority_tensor->scalar<int64>()();
              for (int i = 0; i < num_components(); ++i) {
                PersistentTensor element;
                attempt->context->SetStatus(GetElementComponentFromBatch(
                    tuple, index, i, attempt->context, &element));
                if (!attempt->context->status().ok()) return kComplete;
                queues_[i].emplace(priority, element);
              }
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}

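// Attempts to dequeue a single element.  The callback fires with the dequeued
// tuple once an element is available, or with an empty tuple on cancellation;
// the attempt fails with OutOfRange if the queue is closed and empty.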
void PriorityQueue::TryDequeue(OpKernelContext* ctx,
                               CallbackWithTuple callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          1, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            const int32 s = queues_[0].size();
            if (closed_ && s == 0) {
              attempt->context->SetStatus(errors::OutOfRange(
                  "PriorityQueue '", name_, "' is closed and has ",
                  "insufficient elements (requested ", 1, ", current size ", s,
                  ")"));
              return kComplete;
            }
            if (s > 0) {
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->done_callback = [callback, tuple]() { callback(tuple); };
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

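// Attempts to dequeue `num_elements` elements into a single batched tuple.
// Requires fully specified component shapes.  To keep the output sorted by
// priority, the attempt waits until the whole request can be satisfied at
// once, unless the queue is closed and `allow_small_batch` is true.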
void PriorityQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                   bool allow_small_batch,
                                   CallbackWithTuple callback) {
  if (!specified_shapes()) {
    ctx->SetStatus(
        errors::InvalidArgument("PriorityQueue's DequeueMany requires the "
                                "components to have specified shapes."));
    callback(Tuple());
    return;
  }
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output().  Problem is
      // this breaks the abstraction boundary since we don't *really*
      // know if and how the Tensors in the tuple we pass to callback
      // correspond to the outputs of *ctx.  For example, the
      // ReaderRead Op uses TryDequeue() to get a filename out of a
      // queue that is used internally by the reader and is not
      // associated with any output of the ReaderRead.
      // mrry@ adds:
      // Maybe we need to pass a std::function<Tensor*(...)> (or
      // better signature) that calls the appropriate allocator
      // function in addition to ctx?  (Or support a shim Allocator
      // that has an internal OpKernelContext*, and dispatches to the
      // appropriate method?)
      // misard@ adds:
      // I don't see that a std::function would help. The problem is
      // that at this point (allocation time) the system doesn't know
      // what is going to happen to the element read out of the
      // queue. As long as we keep the generality that TensorFlow Ops
      // do their own dynamic allocation in arbitrary C++ code, we
      // need to preserve robustness to allocating output Tensors with
      // the 'wrong' attributes, and fixing up with a copy. The only
      // improvement I can see here in the future would be to support
      // an optimized case where the queue 'knows' what attributes to
      // use, and plumbs them through here.
      Tensor element;
      Status status = ctx->allocate_temp(component_dtypes_[i],
                                         ManyOutShape(i, 0), &element);
      if (!status.ok()) {
        ctx->SetStatus(status);
        callback(Tuple());
        return;
      }
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this,
           allow_small_batch](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32 s = queues_[0].size();
            // Return OutOfRange if closed and there are fewer elements
            // available than requested.  *Unless* allow_small_batch
            // is true, in which case we return as many elements as
            // possible.
            if (closed_) {
              if (s == 0 ||
                  (!allow_small_batch && s < attempt->elements_requested)) {
                attempt->context->SetStatus(errors::OutOfRange(
                    "PriorityQueue '", name_, "' is closed and has ",
                    "insufficient elements (requested ",
                    attempt->elements_requested, ", current size ", s, ")"));
                return kComplete;
              }
            }

            // The PriorityQueue is expected to always return a
            // sorted set of entries.  In order to do this, the underlying
            // queue must have at least this many entries already.
            // Doing the dynamic thing and pulling out a portion at a
            // time leads to unordered output in calls to DequeueMany.
            //
            // An alternative solution is to store the attempt tuple
            // entries in an identical priority_queue and push onto
            // this queue dynamically, then when it is full, do all
            // the Tensor concatenation at the very end.
            // TODO(ebrevdo): Change approach if this leads to locking issues.
            if (s < attempt->elements_requested) {
              // If we have no elements at all, then wait.
              // Otherwise proceed if closed and allow small batch is true.
              // Otherwise wait until we have more enqueued elements.
              if (s == 0 || !(closed_ && allow_small_batch)) {
                return kNoProgress;
              }
            }

            RunResult result = kNoProgress;
            for (; s > 0; --s) {
              if (attempt->tuple.empty()) {
                // Only allocate tuple when we have something to dequeue
                // so we don't use excessive memory when there are many
                // blocked dequeue attempts waiting.
                attempt->tuple.reserve(num_components());
                for (int i = 0; i < num_components(); ++i) {
                  const TensorShape shape =
                      ManyOutShape(i, attempt->elements_requested);
                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;
                  attempt->tuple.emplace_back(element);
                }
              }
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              const int index =
                  attempt->tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                attempt->context->SetStatus(batch_util::CopyElementToSlice(
                    std::move(tuple[i]), &attempt->tuple[i], index));
                if (!attempt->context->status().ok()) return kComplete;
              }
              tuple.clear();
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                tuple = attempt->tuple;
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

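// Checks that `node_def` describes this queue: the op must be PriorityQueue or
// PriorityQueueV2 with matching capacity, component types, and shapes.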
Status PriorityQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "PriorityQueue").ok() &&
      !MatchesNodeDefOp(node_def, "PriorityQueueV2").ok()) {
    return errors::InvalidArgument("Expected PriorityQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
  TF_RETURN_IF_ERROR(MatchesPriorityNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(MatchesPriorityNodeDefShapes(node_def));
  return Status::OK();
}

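// Compares the requested component_types attr against component_dtypes_, after
// prepending the implicit DT_INT64 priority component that the attr omits.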
Status PriorityQueue::MatchesPriorityNodeDefTypes(
    const NodeDef& node_def) const {
  DataTypeVector requested_dtypes;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(node_def, "component_types", &requested_dtypes));
  requested_dtypes.insert(requested_dtypes.begin(), DT_INT64);
  if (requested_dtypes != component_dtypes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component types ",
                                   DataTypeSliceString(component_dtypes_),
                                   " but requested component types were ",
                                   DataTypeSliceString(requested_dtypes));
  }
  return Status::OK();
}

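// Compares the requested shapes attr against component_shapes_, after
// prepending the implicit scalar shape of the priority component.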
Status PriorityQueue::MatchesPriorityNodeDefShapes(
    const NodeDef& node_def) const {
  std::vector<TensorShape> requested_shapes;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
  requested_shapes.insert(requested_shapes.begin(), TensorShape({}));
  if (requested_shapes != component_shapes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component shapes ",
                                   ShapeListString(component_shapes_),
                                   " but requested component shapes were ",
                                   ShapeListString(requested_shapes));
  }
  return Status::OK();
}

}  // namespace tensorflow