Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
     17 #define TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
     18 
#include <climits>
#include <deque>
#include <functional>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
     31 
     32 namespace tensorflow {
     33 
     34 namespace barrier {
     35 class Barrier;
     36 }  // namespace barrier
     37 
     38 // Functionality common to asynchronous QueueInterface implementations.
     39 class QueueBase : public QueueInterface {
     40  public:
     41   // As a possible value of 'capacity'.
     42   static const int32 kUnbounded = INT_MAX;
     43 
     44   // Args:
     45   //   component_dtypes: The types of each component in a queue-element tuple.
     46   //   component_shapes: The shapes of each component in a queue-element tuple,
     47   //     which must either be empty (if the shapes are not specified) or
     48   //     or have the same size as component_dtypes.
     49   //   name: A name to use for the queue.
     50   QueueBase(int32 capacity, const DataTypeVector& component_dtypes,
     51             const std::vector<TensorShape>& component_shapes,
     52             const string& name);
     53 
     54   // Implementations of QueueInterface methods --------------------------------
     55   const DataTypeVector& component_dtypes() const override {
     56     return component_dtypes_;
     57   }
     58 
     59   Status ValidateTuple(const Tuple& tuple) override;
     60   Status ValidateManyTuple(const Tuple& tuple) override;
     61 
     62   void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
     63              DoneCallback callback) override;
     64 
     65   // Other public methods -----------------------------------------------------
     66   const std::vector<TensorShape>& component_shapes() const {
     67     return component_shapes_;
     68   }
     69 
     70   int32 capacity() const { return capacity_; }
     71 
     72   bool is_closed() const override {
     73     mutex_lock lock(mu_);
     74     return closed_;
     75   }
     76 
     77   // Copies the index^th slice (in the first dimension) of parent into element.
     78   static Status CopySliceToElement(const Tensor& parent, Tensor* element,
     79                                    int64 index);
     80 
     81   // Copies element into the index^th slice (in the first dimension) of parent.
     82   // NOTE(mrry): This method is deprecated. Use
     83   // `tensorflow::batch_util::CopySliceToElement()` defined in
     84   // "./batch_util.h" instead.
     85   static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
     86                                    int64 index);
     87 
     88  protected:
     89   enum Action { kEnqueue, kDequeue };
     90   enum RunResult { kNoProgress, kProgress, kComplete };
     91 
     92   // Tries to enqueue/dequeue (or close) based on whatever is at the
     93   // front of enqueue_attempts_/dequeue_attempts_.  Appends to
     94   // *finished the callback for any finished attempt (so it may be
     95   // called once mu_ is released).  Returns true if any progress was
     96   // made.
     97   struct CleanUp {
     98     CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
     99         : finished(f), to_deregister(ct), cm(cm) {}
    100     DoneCallback finished;
    101     CancellationToken to_deregister;
    102     CancellationManager* cm;
    103   };
    104 
    105   // Returns the number of components in a queue-element tuple.
    106   int32 num_components() const { return component_dtypes_.size(); }
    107 
    108   // True if shapes were specified.  If so, inputs will be validated
    109   // against them, etc.
    110   bool specified_shapes() const { return component_shapes_.size() > 0; }
    111 
    112   // Code common to Validate*Tuple().
    113   Status ValidateTupleCommon(const Tuple& tuple) const;
    114 
    115   TensorShape ManyOutShape(int i, int64 batch_size) {
    116     TensorShape shape({batch_size});
    117     shape.AppendShape(component_shapes_[i]);
    118     return shape;
    119   }
    120 
    121   void Cancel(Action action, CancellationManager* cancellation_manager,
    122               CancellationToken token);
    123 
    124   // Helper for cancelling all pending Enqueue(Many) operations when
    125   // Close is called with cancel_pending_enqueues.
    126   void CloseAndCancel();
    127 
    128   bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
    129       EXCLUSIVE_LOCKS_REQUIRED(mu_);
    130 
    131   // Tries to make progress on the enqueues or dequeues at the front
    132   // of the *_attempts_ queues.
    133   void FlushUnlocked();
    134 
    135   ~QueueBase() override;
    136 
    137   // Helpers for implementing MatchesNodeDef().
    138   static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes);
    139   Status MatchesNodeDefOp(const NodeDef& node_def, const string& op) const;
    140   Status MatchesNodeDefCapacity(const NodeDef& node_def, int32 capacity) const;
    141   Status MatchesNodeDefTypes(const NodeDef& node_def) const;
    142   Status MatchesNodeDefShapes(const NodeDef& node_def) const;
    143 
    144  protected:
    145   const int32 capacity_;
    146   const DataTypeVector component_dtypes_;
    147   const std::vector<TensorShape> component_shapes_;
    148   const string name_;
    149   mutable mutex mu_;
    150   bool closed_ GUARDED_BY(mu_);
    151 
    152   struct Attempt;
    153   typedef std::function<RunResult(Attempt*)> RunCallback;
    154   struct Attempt {
    155     int32 elements_requested;
    156     DoneCallback done_callback;  // must be run outside mu_
    157     OpKernelContext* context;
    158     CancellationManager* cancellation_manager;  // not owned
    159     CancellationToken cancellation_token;
    160     RunCallback run_callback;  // must be run while holding mu_
    161     bool is_cancelled;
    162     Tuple tuple;
    163     // tuples is used by some implementations allowing dynamic shapes.
    164     std::vector<Tuple> tuples;
    165 
    166     Attempt(int32 elements_requested, DoneCallback done_callback,
    167             OpKernelContext* context, CancellationManager* cancellation_manager,
    168             CancellationToken cancellation_token, RunCallback run_callback)
    169         : elements_requested(elements_requested),
    170           done_callback(done_callback),
    171           context(context),
    172           cancellation_manager(cancellation_manager),
    173           cancellation_token(cancellation_token),
    174           run_callback(run_callback),
    175           is_cancelled(false) {}
    176   };
    177   std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
    178   std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);
    179 
    180   TF_DISALLOW_COPY_AND_ASSIGN(QueueBase);
    181 };
    182 
    183 }  // namespace tensorflow
    184 
    185 #endif  // TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
    186