Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_
     17 #define TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_
     18 
     19 #include <deque>
     20 #include <vector>
     21 
     22 #include "tensorflow/core/framework/op_kernel.h"
     23 #include "tensorflow/core/framework/partial_tensor_shape.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/framework/tensor_shape.h"
     26 #include "tensorflow/core/framework/types.h"
     27 #include "tensorflow/core/kernels/fifo_queue.h"
     28 #include "tensorflow/core/kernels/typed_queue.h"
     29 #include "tensorflow/core/platform/macros.h"
     30 #include "tensorflow/core/platform/mutex.h"
     31 #include "tensorflow/core/platform/types.h"
     32 
     33 namespace tensorflow {
     34 
     35 class PaddingFIFOQueue : public FIFOQueue {
     36  public:
     37   PaddingFIFOQueue(int32 capacity, const DataTypeVector& component_dtypes,
     38                    const std::vector<PartialTensorShape>& component_shapes,
     39                    const string& name);
     40 
     41   Status Initialize() override;
     42 
     43   // Implementations of QueueInterface methods --------------------------------
     44 
     45   void TryDequeueMany(int num_elements, OpKernelContext* ctx,
     46                       bool allow_small_batch,
     47                       CallbackWithTuple callback) override;
     48   Status MatchesNodeDef(const NodeDef& node_def) override;
     49 
     50  protected:
     51   Status ValidateManyTuple(const Tuple& tuple) override;
     52   Status ValidateTuple(const Tuple& tuple) override;
     53   Status CompatibleNodeDefShapes(const NodeDef& node_def) const;
     54 
     55   // Convert a list of PartialTensorShape to a list of
     56   // TensorShape.
     57   // Any unknown dimension sizes are converted to 0.
     58   // REQUIRED: All the input shapes have well defined rank.
     59   static std::vector<TensorShape> ConvertShapesPartialDimensionsToZero(
     60       const gtl::ArraySlice<PartialTensorShape>& partial_shapes);
     61 
     62   // Sets the values in the given element to zero.
     63   static Status SetElementZero(Tensor* element);
     64 
     65   // Copies element into the index^th slice (in the first dimension)
     66   // of parent.  Allows for the parent's slice to have a larger size
     67   // than the element, and copies the element into the upper left hand
     68   // corner of the slice.
     69   static Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
     70                                          int index);
     71 
     72   std::vector<PartialTensorShape> partial_shapes_;
     73 
     74  private:
     75   ~PaddingFIFOQueue() override {}
     76 
     77   static Status GetElementComponent(const PaddingFIFOQueue::Tuple& tuple,
     78                                     int component, OpKernelContext* ctx,
     79                                     PersistentTensor* out_tensor);
     80 
     81   static Status IsSameSizeExceptZerosInFirst(const TensorShape& first,
     82                                              const TensorShape& second);
     83 
     84   TF_DISALLOW_COPY_AND_ASSIGN(PaddingFIFOQueue);
     85 };
     86 
     87 }  // namespace tensorflow
     88 
     89 #endif  // TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_
     90