/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include <deque>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batch_util.h"
#include "tensorflow/core/kernels/padding_fifo_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

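// The base FIFOQueue stores fully defined shapes, so unknown (-1) dimensions
// in the requested shapes are mapped to 0 before being passed to its
// constructor (see ConvertShapesPartialDimensionsToZero below).  The original
// partial shapes are retained in partial_shapes_ and drive the padding logic
// during dequeue.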
PaddingFIFOQueue::PaddingFIFOQueue(
    int capacity, const DataTypeVector& component_dtypes,
    const std::vector<PartialTensorShape>& partial_shapes, const string& name)
    : FIFOQueue(capacity, component_dtypes,
                ConvertShapesPartialDimensionsToZero(partial_shapes), name),
      partial_shapes_(partial_shapes) {}

Status PaddingFIFOQueue::Initialize() {
  Status s = FIFOQueue::Initialize();
  if (!s.ok()) return s;

  if (component_dtypes_.size() != partial_shapes_.size()) {
    return errors::InvalidArgument(
        "Shapes must be provided for all components, but received ",
        component_dtypes_.size(), " dtypes and ", partial_shapes_.size(),
        " shapes.");
  }

  return Status::OK();
}

/* static */
Status PaddingFIFOQueue::GetElementComponent(
    const PaddingFIFOQueue::Tuple& tuple, int component, OpKernelContext* ctx,
    PersistentTensor* out_tensor) {
  TensorShape element_shape(tuple[component].shape());
  Tensor* element_access = nullptr;
  TF_RETURN_IF_ERROR(ctx->allocate_persistent(
      tuple[component].dtype(), element_shape, out_tensor, &element_access));
  *element_access = tuple[component];
  return Status::OK();
}

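// Dequeues num_elements tuples and hands the caller a single tuple of batched
// components via the callback.  Components whose declared shapes contain
// unknown dimensions are zero-padded up to the largest size seen in the batch.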
void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                      bool allow_small_batch,
                                      CallbackWithTuple callback) {
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output().
      // See similar comment in fifo_queue.cc
      Tensor element;
      // Here, ManyOutShape returns zeros for undetermined shapes,
      // which is exactly what we want to use.
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(component_dtypes_[i],
                                             ManyOutShape(i, 0), &element));
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, allow_small_batch,
           this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32 queue_size = queues_[0].size();
            if (closed_ && queue_size < attempt->elements_requested) {
              // If we don't have enough for a full dequeue, we have
              // to reset the attempt tuple.
              if (!attempt->tuples.empty()) {
                // Restore already-dequeued elements to the front of the queue.
                for (int64 i = attempt->tuples.size() - 1; i >= 0; --i) {
                  for (int j = 0; j < num_components(); ++j) {
                    PersistentTensor element;
                    Status s = GetElementComponent(attempt->tuples[i], j,
                                                   attempt->context, &element);
                    if (!s.ok()) {
                      attempt->context->SetStatus(
                          errors::DataLoss("Failed to restore element from "
                                           "partially-dequeued batch "
                                           "to PaddingFIFOQueue: ",
                                           s.error_message()));
                    }
                    queues_[j].push_front(element);
                  }
                }
              }
              if (allow_small_batch && !queues_[0].empty()) {
                // Request all remaining elements in the queue.
                queue_size = queues_[0].size();
                attempt->tuples.clear();
                attempt->elements_requested = queue_size;
              } else {
                if (allow_small_batch) {
                  // There may be some enqueue attempts containing
                  // values.  If so, we'll yield and wait for them
                  // to add elements to the queue.
                  if (!enqueue_attempts_.empty()) return kProgress;
                }
                if (attempt->context->status().ok()) {
                  attempt->context->SetStatus(errors::OutOfRange(
                      "PaddingFIFOQueue '", name_, "' is closed and has ",
                      "insufficient elements (requested ",
                      attempt->elements_requested, ", current size ",
                      queue_size, ")"));
                }
                return kComplete;
              }
            }

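            // Dequeue until the request is satisfied.  Once the batch is
            // complete, each component is padded: known dimensions come from
            // the declared shape, and every unknown dimension is padded to the
            // maximum size observed across the batch.  For example, elements
            // of shapes [3, 2] and [5, 2] under partial shape [-1, 2] yield an
            // output of shape [batch, 5, 2], zero-filled where not written.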
            RunResult result = kNoProgress;
            for (; queue_size > 0; --queue_size) {
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->tuples.push_back(tuple);
              tuple.clear();
              --attempt->elements_requested;

              if (attempt->elements_requested == 0) {
                // Finished.  Allocate attempt->tuple and
                // copy from attempt->tuples to attempt->tuple.
                attempt->tuple.reserve(num_components());
                std::vector<Tuple>& tuples = attempt->tuples;

                std::vector<bool> dynamic_shape;
                const int64 batch_size = tuples.size();

                for (int i = 0; i < num_components(); ++i) {
                  const PartialTensorShape partial_shape =
                      PartialTensorShape({batch_size})
                          .Concatenate(partial_shapes_[i]);
                  TensorShape shape({batch_size});

                  for (int j = 0; j < partial_shape.dims() - 1; ++j) {
                    if (partial_shape.dim_size(j + 1) > -1) {
                      shape.AddDim(partial_shape.dim_size(j + 1));
                    } else {
                      // Unknown dimension: pad to the maximum size seen
                      // across the batch.
                      int64 max_val = 0;
                      for (const Tuple& t : tuples) {
                        max_val = std::max(max_val, t[i].shape().dim_size(j));
                      }
                      shape.AddDim(max_val);
                    }
                  }

                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;

                  bool has_dynamic_shape = !partial_shape.IsFullyDefined();
                  if (has_dynamic_shape) {
                    // Zero-fill the padded output; the copies below may not
                    // overwrite every entry.
                    attempt->context->SetStatus(SetElementZero(&element));
                    if (!attempt->context->status().ok()) return kComplete;
                  }

                  dynamic_shape.push_back(has_dynamic_shape);

                  // TODO(ebrevdo): should this be a persistent tensor?
                  attempt->tuple.emplace_back(element);
                }

                for (size_t index = 0; index < tuples.size(); ++index) {
                  for (int i = 0; i < num_components(); ++i) {
                    if (dynamic_shape[i]) {
                      // Slower path: copy the element into a larger,
                      // zero-padded slice.
                      attempt->context->SetStatus(CopyElementToLargerSlice(
                          tuples[index][i], &attempt->tuple[i], index));
                    } else {
                      attempt->context->SetStatus(
                          batch_util::CopyElementToSlice(
                              std::move(tuples[index][i]), &attempt->tuple[i],
                              index));
                    }
                    if (!attempt->context->status().ok()) return kComplete;
                  }
                }
                tuple = attempt->tuple;
                attempt->tuples.clear();
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

Status PaddingFIFOQueue::ValidateTuple(const Tuple& tuple) {
  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
  for (size_t i = 0; i < tuple.size(); ++i) {
    if (!partial_shapes_[i].IsCompatibleWith(tuple[i].shape())) {
      return errors::InvalidArgument("Shape mismatch in tuple component ", i,
                                     ". Expected ",
                                     partial_shapes_[i].DebugString(), ", got ",
                                     tuple[i].shape().DebugString());
    }
  }
  return Status::OK();
}

Status PaddingFIFOQueue::ValidateManyTuple(const Tuple& tuple) {
  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
  const int64 batch_size = tuple[0].dim_size(0);
  for (size_t i = 0; i < tuple.size(); ++i) {
    // Expected shape is [batch_size] + partial_shapes_[i]
    const PartialTensorShape expected_shape =
        PartialTensorShape({batch_size}).Concatenate(partial_shapes_[i]);
    if (!expected_shape.IsCompatibleWith(tuple[i].shape())) {
      return errors::InvalidArgument("Shape mismatch in tuple component ", i,
                                     ". Expected ",
                                     expected_shape.DebugString(), ", got ",
                                     tuple[i].shape().DebugString());
    }
  }
  return Status::OK();
}

Status PaddingFIFOQueue::CompatibleNodeDefShapes(
    const NodeDef& node_def) const {
  std::vector<PartialTensorShape> requested_shapes;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
  if (!PartialTensorShapeUtils::AreCompatible(requested_shapes,
                                              partial_shapes_)) {
    return errors::InvalidArgument(
        "Shared queue '", name_, "' has component shapes ",
        PartialTensorShapeUtils::PartialShapeListString(partial_shapes_),
        " but requested component shapes were ",
        PartialTensorShapeUtils::PartialShapeListString(requested_shapes));
  } else {
    return Status::OK();
  }
}

Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "PaddingFIFOQueue").ok() &&
      !MatchesNodeDefOp(node_def, "PaddingFIFOQueueV2").ok()) {
    return errors::InvalidArgument("Expected PaddingFIFOQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
  TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(CompatibleNodeDefShapes(node_def));
  return Status::OK();
}

static Status ValidateElementToLargerSlice(const Tensor& element,
                                           Tensor* parent) {
  DCHECK_NE(parent->dim_size(0), 0);
  if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
    TensorShape chip_shape = parent->shape();
    chip_shape.RemoveDim(0);
    return errors::Internal(
        "HandleElementToLargerSlice Cannot copy slice: number of entries in "
        "element is greater than number of elements in parent slice.  ",
        "Shapes are: [element]: ", element.shape().DebugString(),
        ", [parent slice]: ", chip_shape.DebugString());
  }
  return Status::OK();
}

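// Copies `element` into slice `index` along dimension 0 of `parent`.  The
// element may be smaller than the slice in the padded dimensions; entries
// that are not overwritten keep whatever `parent` already holds (zeros, as
// written by SetElementZero in TryDequeueMany).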
template <typename T, int NDIMS>
Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
                                  int index) {
  Status s = ValidateElementToLargerSlice(element, parent);
  if (!s.ok()) {
    return s;
  }
  if (element.NumElements() == 0) {
    return Status::OK();
  }
  auto element_t = element.tensor<T, NDIMS>();
  auto parent_t = parent->tensor<T, NDIMS + 1>();
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
  slice_indices[0] = index;
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
  slice_size[0] = 1;
  for (size_t i = 1; i < slice_size.size(); ++i) {
    slice_size[i] = element_t.dimension(i - 1);
  }
  parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
  return Status::OK();
}

namespace {

template <int NDIMS>
Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
                                          int index) {
#define HANDLE_TYPE(T)                                                   \
  case DataTypeToEnum<T>::value: {                                       \
    return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
  }

  switch (element.dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented(
          "HandleElementToLargerSliceWithRank Unhandled data type: ",
          element.dtype());
  }
}

}  // namespace

Status PaddingFIFOQueue::CopyElementToLargerSlice(const Tensor& element,
                                                  Tensor* parent, int index) {
  if (parent->dims() != element.dims() + 1) {
    return errors::Internal(
        "Mismatched ranks.  Element's rank is: ", element.dims(),
        " but element is meant to be a slice in output Tensor having rank: ",
        parent->dims(), " (should be: ", element.dims() + 1, ")");
  }

#define HANDLE_DIMS(NDIMS)                                                  \
  case NDIMS: {                                                             \
    TF_RETURN_IF_ERROR(                                                     \
        HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
    return Status::OK();                                                    \
  }

  switch (element.dims()) {
    HANDLE_DIMS(0);
    HANDLE_DIMS(1);
    HANDLE_DIMS(2);
    HANDLE_DIMS(3);
    HANDLE_DIMS(4);
#undef HANDLE_DIMS
    default:
      return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
                                   element.dims());
  }
}

// Static method
Status PaddingFIFOQueue::SetElementZero(Tensor* element) {
#define HANDLE_TYPE(T)                                \
  if (element->dtype() == DataTypeToEnum<T>::value) { \
    element->flat<T>().setConstant(T());              \
    return Status::OK();                              \
  }
  TF_CALL_ALL_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
  return errors::Unimplemented("SetElementZero Unhandled data type: ",
                               element->dtype());
}

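// Maps every unknown (-1) dimension to 0 so the shapes can be represented as
// fully defined TensorShapes by the base FIFOQueue constructor.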
std::vector<TensorShape> PaddingFIFOQueue::ConvertShapesPartialDimensionsToZero(
    const gtl::ArraySlice<PartialTensorShape>& partial_shapes) {
  std::vector<TensorShape> shapes(partial_shapes.size());
  for (size_t i = 0; i < shapes.size(); ++i) {
    const PartialTensorShape& partial = partial_shapes[i];
    TensorShape& shape = shapes[i];
    for (int64 s : partial.dim_sizes()) shape.AddDim(s < 0 ? 0 : s);
  }
  return shapes;
}

}  // namespace tensorflow