Home | History | Annotate | Download | only in training
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/cc/training/queue_runner.h"
     17 #include "tensorflow/core/kernels/ops_util.h"
     18 #include "tensorflow/core/platform/env.h"
     19 
     20 namespace tensorflow {
     21 
     22 Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
     23                         std::unique_ptr<QueueRunner>* result) {
     24   result->reset(new QueueRunner());
     25   return (*result)->Init(queue_runner_def);
     26 }
     27 
     28 Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
     29                         Coordinator* coord,
     30                         std::unique_ptr<QueueRunner>* result) {
     31   result->reset(new QueueRunner());
     32   (*result)->coord_ = coord;
     33   return (*result)->Init(queue_runner_def);
     34 }
     35 
     36 void QueueRunner::AddErrorCallback(const std::function<void(Status)>& cb) {
     37   mutex_lock l(cb_mu_);
     38   callbacks_.push_back(cb);
     39 }
     40 
     41 void QueueRunner::ClearErrorCallbacks() {
     42   mutex_lock l(cb_mu_);
     43   callbacks_.clear();
     44 }
     45 
     46 Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
     47   queue_name_ = queue_runner_def.queue_name();
     48   enqueue_op_names_.clear();
     49   enqueue_op_names_.insert(enqueue_op_names_.end(),
     50                            queue_runner_def.enqueue_op_name().begin(),
     51                            queue_runner_def.enqueue_op_name().end());
     52   size_t op_names_size = enqueue_op_names_.size();
     53   if (op_names_size > kint32max) {
     54     return Status(error::INVALID_ARGUMENT,
     55                   "Enqueue ops to run cannot exceed kint32max");
     56   }
     57   runs_ = static_cast<int>(op_names_size);
     58   if (runs_ == 0) {
     59     return Status(error::INVALID_ARGUMENT, "Empty enqueue ops to run.");
     60   }
     61   close_op_name_ = queue_runner_def.close_op_name();
     62   cancel_op_name_ = queue_runner_def.cancel_op_name();
     63   if (queue_runner_def.queue_closed_exception_types_size() == 0) {
     64     queue_closed_exception_types_.insert(error::OUT_OF_RANGE);
     65   } else {
     66     for (const auto& code : queue_runner_def.queue_closed_exception_types()) {
     67       queue_closed_exception_types_.insert(static_cast<int>(code));
     68     }
     69   }
     70 
     71   int nthreads = runs_;
     72   if (coord_) {
     73     // One more thread to call Stop()
     74     nthreads++;
     75   }
     76   thread_pool_.reset(new thread::ThreadPool(
     77       Env::Default(), SanitizeThreadSuffix(queue_name_), nthreads));
     78 
     79   return Status::OK();
     80 }
     81 
     82 QueueRunner::~QueueRunner() {
     83   // Cannot run Stop() here because the session might already be closed or
     84   // destroyed.
     85   Join().IgnoreError();
     86 }
     87 
     88 Status QueueRunner::Start(Session* sess) { return Start(sess, 0); }
     89 
     90 Status QueueRunner::StartAndCollectCostGraph(Session* sess,
     91                                              const RunOptions& run_options) {
     92   SetRunArgumentsAndCostGraph(run_options);
     93   return Start(sess, 0);
     94 }
     95 
     96 Status QueueRunner::Start(Session* sess, int wait_for) {
     97   counter_.reset(new BlockingCounter(runs_));
     98   for (const string& enqueue_op : enqueue_op_names_) {
     99     thread_pool_->Schedule(
    100         std::bind(&QueueRunner::Run, this, sess, enqueue_op));
    101   }
    102   if (coord_) {
    103     thread_pool_->Schedule(std::bind(&QueueRunner::Stop, this, sess));
    104   }
    105   // Wait for up to 'wait_for' milliseconds.
    106   if (wait_for > 0) {
    107     if (!counter_->WaitFor(std::chrono::milliseconds(wait_for))) {
    108       return Status(error::DEADLINE_EXCEEDED,
    109                     "Queues not fed before the timeout");
    110     }
    111     // Check the status of the queue runner as well as the result of the enqueue
    112     // operations.
    113     mutex_lock l(mu_);
    114     if (!enqueue_status_.ok()) {
    115       return enqueue_status_;
    116     } else {
    117       return status_;
    118     }
    119   }
    120   return Status::OK();
    121 }
    122 
    123 Status QueueRunner::StartAndCollectCostGraph(Session* session, int wait_for_ms,
    124                                              const RunOptions& run_options) {
    125   SetRunArgumentsAndCostGraph(run_options);
    126   return Start(session, wait_for_ms);
    127 }
    128 
    129 void QueueRunner::Stop(Session* sess) {
    130   if (coord_ != nullptr) {
    131     coord_->WaitForStop();
    132   }
    133   if (!cancel_op_name_.empty()) {
    134     UpdateStatus(RealRun(sess, cancel_op_name_, false));
    135   }
    136   stopped_ = true;
    137 }
    138 
    139 Status QueueRunner::Join() {
    140   thread_pool_.reset();
    141   mutex_lock l(mu_);
    142   return status_;
    143 }
    144 
    145 void QueueRunner::UpdateStatus(const Status& status) {
    146   {
    147     mutex_lock l(mu_);
    148     if (!status_.ok() || status.ok() || IsQueueClosed(status)) {
    149       return;
    150     }
    151     status_ = status;
    152   }
    153   if (coord_) {
    154     coord_->ReportStatus(status);
    155   }
    156   mutex_lock l(cb_mu_);
    157   for (auto& cb : callbacks_) {
    158     cb(status);
    159   }
    160 }
    161 
    162 void QueueRunner::Run(Session* sess, const string& enqueue_op) {
    163   bool first_iteration = true;
    164   Status status;
    165   while (status.ok()) {
    166     if (coord_ && coord_->ShouldStop()) {
    167       break;
    168     }
    169     status = RealRun(sess, enqueue_op, true);
    170     if (first_iteration) {
    171       if (!status.ok()) {
    172         mutex_lock l(mu_);
    173         enqueue_status_ = status;
    174       }
    175       counter_->DecrementCount();
    176       first_iteration = false;
    177     }
    178   }
    179   bool last_run = false;
    180   {
    181     mutex_lock l(mu_);
    182     runs_--;
    183     last_run = (runs_ == 0);
    184   }
    185 
    186   // Close the queue unless the coordinator is shutting down since the cancel op
    187   // will be run anway in this case.
    188   if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) {
    189     if (last_run && !close_op_name_.empty()) {
    190       UpdateStatus(RealRun(sess, close_op_name_, false));
    191     }
    192   } else if (!status.ok()) {
    193     LOG(ERROR) << "Queue runner thread got a failure status: "
    194                << status.ToString();
    195     UpdateStatus(status);
    196     if (coord_) {
    197       coord_->RequestStop().IgnoreError();
    198     }
    199   }
    200 }
    201 
    202 Status QueueRunner::GetStatus() {
    203   mutex_lock l(mu_);
    204   return status_;
    205 }
    206 
    207 Status QueueRunner::ExportCostGraph(CostGraphDef* cost_graph) const {
    208   if (!cg_mu_) {
    209     return Status(error::FAILED_PRECONDITION,
    210                   "This QueueRunner doesn't collect a cost graph.");
    211   }
    212   mutex_lock l(*cg_mu_);
    213   cost_graph->MergeFrom(*cost_graph_);
    214   return Status::OK();
    215 }
    216 
    217 void QueueRunner::SetRunArgumentsAndCostGraph(const RunOptions& run_options) {
    218   cg_mu_.reset(new mutex());
    219   {
    220     mutex_lock l(*cg_mu_);
    221     cost_graph_.reset(new CostGraphDef());
    222   }
    223   run_options_ = run_options;
    224 }
    225 
    226 Status QueueRunner::RealRun(Session* sess, const string& op,
    227                             bool update_costs) {
    228   Status s;
    229   if (update_costs && cg_mu_) {
    230     RunMetadata metadata;
    231     s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata);
    232     mutex_lock l(*cg_mu_);
    233     cost_graph_->Swap(metadata.mutable_cost_graph());
    234   } else {
    235     s = sess->Run({}, {}, {op}, nullptr);
    236   }
    237   return s;
    238 }
    239 
    240 }  // namespace tensorflow
    241