Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/kernels/queue_base.h"
     17 
     18 #include <vector>
     19 #include "tensorflow/core/framework/node_def.pb.h"
     20 #include "tensorflow/core/framework/tensor_shape.h"
     21 #include "tensorflow/core/kernels/batch_util.h"
     22 #include "tensorflow/core/lib/core/errors.h"
     23 #include "tensorflow/core/platform/mutex.h"
     24 #include "tensorflow/core/platform/types.h"
     25 
     26 namespace tensorflow {
     27 
     28 namespace {
     29 
     30 template <DataType DT>
     31 Status HandleSliceToElement(const Tensor& parent, Tensor* element,
     32                             int64 index) {
     33   typedef typename EnumToDataType<DT>::Type T;
     34   DCHECK_NE(parent.dim_size(0), 0);
     35   DCHECK_GE(index, 0);
     36   if (element->NumElements() != (parent.NumElements() / parent.dim_size(0))) {
     37     TensorShape chip_shape = parent.shape();
     38     chip_shape.RemoveDim(0);
     39     return errors::Internal(
     40         "HandleSliceToElement Cannot copy slice: number of elements does not "
     41         "match.  Shapes are: [element]: ",
     42         element->shape().DebugString(),
     43         ", [parent slice]: ", chip_shape.DebugString());
     44   }
     45   auto parent_as_matrix = parent.flat_outer_dims<T>();
     46   element->flat<T>() = parent_as_matrix.chip(index, 0);
     47   return Status::OK();
     48 }
     49 
     50 }  // namespace
     51 
     52 QueueBase::QueueBase(int32 capacity, const DataTypeVector& component_dtypes,
     53                      const std::vector<TensorShape>& component_shapes,
     54                      const string& name)
     55     : capacity_(capacity),
     56       component_dtypes_(component_dtypes),
     57       component_shapes_(component_shapes),
     58       name_(name),
     59       closed_(false) {}
     60 
     61 QueueBase::~QueueBase() {}
     62 
     63 Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const {
     64   if (tuple.size() != static_cast<size_t>(num_components())) {
     65     return errors::InvalidArgument(
     66         "Wrong number of components in tuple. Expected ", num_components(),
     67         ", got ", tuple.size());
     68   }
     69   for (size_t i = 0; i < tuple.size(); ++i) {
     70     if (tuple[i].dtype() != component_dtypes_[i]) {
     71       return errors::InvalidArgument(
     72           "Type mismatch in tuple component ", i, ". Expected ",
     73           DataTypeString(component_dtypes_[i]), ", got ",
     74           DataTypeString(tuple[i].dtype()));
     75     }
     76   }
     77   return Status::OK();
     78 }
     79 
     80 // static
     81 string QueueBase::ShapeListString(const gtl::ArraySlice<TensorShape>& shapes) {
     82   string result = "[";
     83   bool first = true;
     84   for (const TensorShape& shape : shapes) {
     85     strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
     86     first = false;
     87   }
     88   strings::StrAppend(&result, "]");
     89   return result;
     90 }
     91 
     92 Status QueueBase::MatchesNodeDefOp(const NodeDef& node_def,
     93                                    const string& op) const {
     94   if (node_def.op() != op) {
     95     return errors::InvalidArgument("Shared queue '", name_, "' has type '", op,
     96                                    "' that does not match type of Node '",
     97                                    node_def.name(), "': ", node_def.op());
     98   }
     99   return Status::OK();
    100 }
    101 
    102 Status QueueBase::MatchesNodeDefCapacity(const NodeDef& node_def,
    103                                          int32 capacity) const {
    104   int32 requested_capacity = -1;
    105   TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "capacity", &requested_capacity));
    106   if (requested_capacity < 0) requested_capacity = kUnbounded;
    107   if (requested_capacity != capacity) {
    108     return errors::InvalidArgument("Shared queue '", name_, "' has capacity ",
    109                                    capacity, " but requested capacity was ",
    110                                    requested_capacity);
    111   }
    112   return Status::OK();
    113 }
    114 
    115 Status QueueBase::MatchesNodeDefTypes(const NodeDef& node_def) const {
    116   DataTypeVector requested_dtypes;
    117   TF_RETURN_IF_ERROR(
    118       GetNodeAttr(node_def, "component_types", &requested_dtypes));
    119   if (requested_dtypes != component_dtypes_) {
    120     return errors::InvalidArgument("Shared queue '", name_,
    121                                    "' has component types ",
    122                                    DataTypeSliceString(component_dtypes_),
    123                                    " but requested component types were ",
    124                                    DataTypeSliceString(requested_dtypes));
    125   }
    126   return Status::OK();
    127 }
    128 
    129 Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const {
    130   std::vector<TensorShape> requested_shapes;
    131   TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
    132   if (requested_shapes != component_shapes_) {
    133     return errors::InvalidArgument("Shared queue '", name_,
    134                                    "' has component shapes ",
    135                                    ShapeListString(component_shapes_),
    136                                    " but requested component shapes were ",
    137                                    ShapeListString(requested_shapes));
    138   }
    139   return Status::OK();
    140 }
    141 
    142 // TODO(mrry): If these checks become a bottleneck, find a way to
    143 //   reduce the number of times that they are called.
    144 Status QueueBase::ValidateTuple(const Tuple& tuple) {
    145   TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
    146   if (specified_shapes()) {
    147     for (size_t i = 0; i < tuple.size(); ++i) {
    148       if (!component_shapes_[i].IsSameSize(tuple[i].shape())) {
    149         return errors::InvalidArgument(
    150             "Shape mismatch in tuple component ", i, ". Expected ",
    151             component_shapes_[i].DebugString(), ", got ",
    152             tuple[i].shape().DebugString());
    153       }
    154     }
    155   }
    156   return Status::OK();
    157 }
    158 
    159 // TODO(mrry): If these checks become a bottleneck, find a way to
    160 //   reduce the number of times that they are called.
    161 Status QueueBase::ValidateManyTuple(const Tuple& tuple) {
    162   TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
    163   const int64 batch_size = tuple[0].dim_size(0);
    164   if (specified_shapes()) {
    165     for (size_t i = 0; i < tuple.size(); ++i) {
    166       // Expected shape is [batch_size] + component_shapes_[i]
    167       const TensorShape expected_shape = ManyOutShape(i, batch_size);
    168       if (!expected_shape.IsSameSize(tuple[i].shape())) {
    169         return errors::InvalidArgument("Shape mismatch in tuple component ", i,
    170                                        ". Expected ",
    171                                        expected_shape.DebugString(), ", got ",
    172                                        tuple[i].shape().DebugString());
    173       }
    174     }
    175   } else {
    176     for (size_t i = 1; i < tuple.size(); ++i) {
    177       if (tuple[i].dim_size(0) != batch_size) {
    178         return errors::InvalidArgument(
    179             "All input tensors must have the same size in the 0th ",
    180             "dimension. Component ", i, " has ", tuple[i].dim_size(0),
    181             ", and should have ", batch_size);
    182       }
    183     }
    184   }
    185   return Status::OK();
    186 }
    187 
    188 void QueueBase::Cancel(Action action, CancellationManager* cancellation_manager,
    189                        CancellationToken token) {
    190   DoneCallback callback = nullptr;
    191   {
    192     mutex_lock lock(mu_);
    193     std::deque<Attempt>* attempts =
    194         action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
    195 
    196     for (Attempt& attempt : *attempts) {
    197       if (attempt.cancellation_manager == cancellation_manager &&
    198           attempt.cancellation_token == token) {
    199         if (!attempt.is_cancelled) {
    200           attempt.is_cancelled = true;
    201           if (action == kEnqueue) {
    202             attempt.context->SetStatus(
    203                 errors::Cancelled("Enqueue operation was cancelled"));
    204           } else {
    205             attempt.context->SetStatus(
    206                 errors::Cancelled("Dequeue operation was cancelled"));
    207           }
    208           std::swap(callback, attempt.done_callback);
    209         }
    210         break;
    211       }
    212     }
    213   }
    214   if (callback) {
    215     callback();
    216     FlushUnlocked();
    217   }
    218 }
    219 
    220 void QueueBase::CloseAndCancel() {
    221   std::vector<DoneCallback> callbacks;
    222   {
    223     mutex_lock lock(mu_);
    224     closed_ = true;
    225     for (Attempt& attempt : enqueue_attempts_) {
    226       if (!attempt.is_cancelled) {
    227         attempt.is_cancelled = true;
    228         attempt.context->SetStatus(
    229             errors::Cancelled("Enqueue operation was cancelled"));
    230         callbacks.emplace_back(std::move(attempt.done_callback));
    231       }
    232     }
    233   }
    234   for (const DoneCallback& callback : callbacks) {
    235     callback();
    236   }
    237   FlushUnlocked();
    238 }
    239 
    240 void QueueBase::Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
    241                       DoneCallback callback) {
    242   if (cancel_pending_enqueues) {
    243     CloseAndCancel();
    244     callback();
    245   } else {
    246     {
    247       mutex_lock lock(mu_);
    248       enqueue_attempts_.emplace_back(
    249           0, callback, ctx, nullptr, CancellationManager::kInvalidToken,
    250           [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    251             if (closed_) {
    252               attempt->context->SetStatus(
    253                   errors::Cancelled("Queue '", name_, "' is already closed."));
    254             } else {
    255               closed_ = true;
    256             }
    257             return kComplete;
    258           });
    259     }
    260     FlushUnlocked();
    261   }
    262 }
    263 
    264 bool QueueBase::TryAttemptLocked(Action action,
    265                                  std::vector<CleanUp>* clean_up) {
    266   std::deque<Attempt>* attempts =
    267       action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
    268 
    269   bool progress = false;
    270   bool done = false;
    271   while (!done && !attempts->empty()) {
    272     if (attempts->front().is_cancelled) {
    273       if (action == kEnqueue) {
    274         if (closed_) {
    275           VLOG(1) << "Skipping cancelled enqueue attempt";
    276         } else {
    277           LOG(WARNING)
    278               << name_
    279               << ": Skipping cancelled enqueue attempt with queue not closed";
    280         }
    281       } else {
    282         if (closed_) {
    283           VLOG(1) << "Skipping cancelled dequeue attempt";
    284         } else {
    285           LOG(WARNING)
    286               << name_
    287               << ": Skipping cancelled dequeue attempt with queue not closed";
    288         }
    289       }
    290       attempts->pop_front();
    291     } else {
    292       Attempt* cur_attempt = &attempts->front();
    293       switch (cur_attempt->run_callback(cur_attempt)) {
    294         case kNoProgress:
    295           done = true;
    296           break;
    297         case kProgress:
    298           done = true;
    299           progress = true;
    300           break;
    301         case kComplete:
    302           progress = true;
    303           clean_up->emplace_back(std::move(cur_attempt->done_callback),
    304                                  cur_attempt->cancellation_token,
    305                                  cur_attempt->context->cancellation_manager());
    306           attempts->pop_front();
    307           break;
    308       }
    309     }
    310   }
    311   return progress;
    312 }
    313 
    314 void QueueBase::FlushUnlocked() {
    315   std::vector<CleanUp> clean_up;
    316   Ref();
    317   {
    318     mutex_lock lock(mu_);
    319     bool changed;
    320     do {
    321       changed = TryAttemptLocked(kEnqueue, &clean_up);
    322       changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
    323     } while (changed);
    324   }
    325   Unref();
    326   for (const auto& to_clean : clean_up) {
    327     if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
    328       // NOTE(mrry): We can safely ignore the return value of
    329       // DeregisterCallback because the mutex mu_ ensures that the
    330       // cleanup action only executes once.
    331       to_clean.cm->DeregisterCallback(to_clean.to_deregister);
    332     }
    333     to_clean.finished();
    334   }
    335 }
    336 
    337 Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
    338                                      int64 index) {
    339   return batch_util::CopySliceToElement(parent, element, index);
    340 }
    341 
    342 /* static */
    343 Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
    344                                      int64 index) {
    345   return batch_util::CopyElementToSlice(element, parent, index);
    346 }
    347 
    348 }  // namespace tensorflow
    349