Home | History | Annotate | Download | only in training
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
     17 #define TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
     18 
#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"
     33 
     34 namespace tensorflow {
     35 
     36 /// QueueRunner class imitates the behavior of the python version of QueueRunner
     37 /// which creates a thread for each enqueue op, runs close op on completion.
     38 class QueueRunner : public RunnerInterface {
     39  public:
     40   /// Creates a new QueueRunner from proto.
     41   // TODO(yuefengz): we may want to initialize from queues and ops in the
     42   // future.
     43   static Status New(const QueueRunnerDef& queue_runner_def,
     44                     std::unique_ptr<QueueRunner>* result);
     45 
     46   /// Creates a new QueueRunner with a coordinator, see coordinator.h for usage.
     47   static Status New(const QueueRunnerDef& queue_runner_def, Coordinator* coord,
     48                     std::unique_ptr<QueueRunner>* result);
     49 
     50   /// Adds a callback that the queue runner will call when it detects an error.
     51   void AddErrorCallback(const std::function<void(Status)>& cb);
     52 
     53   /// Delete the previously registered callbacks.
     54   void ClearErrorCallbacks();
     55 
     56   /// The destructor would join all the threads.
     57   ~QueueRunner();
     58 
     59   /// Starts the queue runner with the given session.
     60   Status Start(Session* sess);
     61 
     62   /// Starts the queue runner with the given session and sets the run arguments
     63   /// for sess->Run. It also collects and stores the cost model.
     64   Status StartAndCollectCostGraph(Session* sess,
     65                                   const RunOptions& run_options = RunOptions());
     66 
     67   /// Starts the queue runner with the given session, and wait for up to the
     68   /// specified time (in milliseconds) for the queues to start to fill up.
     69   Status Start(Session* sess, int wait_for_ms);
     70   Status StartAndCollectCostGraph(Session* session, int wait_for_ms,
     71                                   const RunOptions& run_options = RunOptions());
     72 
     73   /// Requests to stop and runs the cancel op. It would be called in a separate
     74   /// thread when coordinator is set. If there is no coordinator it should be
     75   /// called before calling Join.
     76   void Stop(Session* sess);
     77 
     78   /// Joins all the threads. Returns okay if all threads run successfully;
     79   /// otherwise returns the first captured failure status.
     80   Status Join() final;
     81 
     82   /// Returns the latest status.
     83   Status GetStatus();
     84 
     85   // Returns the stored cost model.
     86   Status ExportCostGraph(CostGraphDef* cost_graph) const override;
     87 
     88  private:
     89   QueueRunner() : coord_(nullptr), stopped_(false), cg_mu_(nullptr) {}
     90 
     91   // Initializes the instance with the QueueRunnerDef proto.
     92   Status Init(const QueueRunnerDef& queue_runner_def);
     93 
     94   // The Run function for each thread.
     95   void Run(Session* sess, const string& enqueue_op);
     96 
     97   // Updates the internal status; it only keeps OK or the first unexpected error
     98   // status.
     99   void UpdateStatus(const Status& status);
    100 
    101   bool IsQueueClosed(Status status) const {
    102     return queue_closed_exception_types_.count(
    103                static_cast<int>(status.code())) > 0;
    104   }
    105 
    106   bool IsRunning() const override { return !stopped_; }
    107 
    108   void SetRunArgumentsAndCostGraph(const RunOptions& run_options);
    109 
    110   Status RealRun(Session* sess, const string& op, bool update_costs);
    111 
    112   string queue_name_;
    113   std::vector<string> enqueue_op_names_;
    114   string close_op_name_;
    115   string cancel_op_name_;
    116   // code::Code casted to int to avoid a hash function.
    117   std::unordered_set<int> queue_closed_exception_types_;
    118 
    119   std::unique_ptr<thread::ThreadPool> thread_pool_;
    120   mutex mu_;
    121   int runs_ = 0;
    122   Status status_ GUARDED_BY(mu_);
    123   Status enqueue_status_ GUARDED_BY(mu_);
    124   std::unique_ptr<BlockingCounter> counter_;
    125 
    126   Coordinator* coord_;
    127 
    128   std::atomic<bool> stopped_;
    129 
    130   mutex cb_mu_;
    131   std::vector<std::function<void(Status)>> callbacks_;
    132 
    133   mutable std::unique_ptr<mutex> cg_mu_;
    134   std::unique_ptr<CostGraphDef> cost_graph_ GUARDED_BY(cg_mu_);
    135   RunOptions run_options_;
    136 };
    137 
    138 }  // namespace tensorflow
    139 
    140 #endif  // TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_
    141