1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/cc/training/queue_runner.h" 17 #include "tensorflow/core/kernels/ops_util.h" 18 #include "tensorflow/core/platform/env.h" 19 20 namespace tensorflow { 21 22 Status QueueRunner::New(const QueueRunnerDef& queue_runner_def, 23 std::unique_ptr<QueueRunner>* result) { 24 result->reset(new QueueRunner()); 25 return (*result)->Init(queue_runner_def); 26 } 27 28 Status QueueRunner::New(const QueueRunnerDef& queue_runner_def, 29 Coordinator* coord, 30 std::unique_ptr<QueueRunner>* result) { 31 result->reset(new QueueRunner()); 32 (*result)->coord_ = coord; 33 return (*result)->Init(queue_runner_def); 34 } 35 36 void QueueRunner::AddErrorCallback(const std::function<void(Status)>& cb) { 37 mutex_lock l(cb_mu_); 38 callbacks_.push_back(cb); 39 } 40 41 void QueueRunner::ClearErrorCallbacks() { 42 mutex_lock l(cb_mu_); 43 callbacks_.clear(); 44 } 45 46 Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { 47 queue_name_ = queue_runner_def.queue_name(); 48 enqueue_op_names_.clear(); 49 enqueue_op_names_.insert(enqueue_op_names_.end(), 50 queue_runner_def.enqueue_op_name().begin(), 51 queue_runner_def.enqueue_op_name().end()); 52 size_t op_names_size = enqueue_op_names_.size(); 53 if (op_names_size > kint32max) { 54 return Status(error::INVALID_ARGUMENT, 55 "Enqueue ops to run cannot exceed kint32max"); 56 } 57 runs_ = static_cast<int>(op_names_size); 58 if (runs_ == 0) { 59 return Status(error::INVALID_ARGUMENT, "Empty enqueue ops to run."); 60 } 61 close_op_name_ = queue_runner_def.close_op_name(); 62 cancel_op_name_ = queue_runner_def.cancel_op_name(); 63 if (queue_runner_def.queue_closed_exception_types_size() == 0) { 64 queue_closed_exception_types_.insert(error::OUT_OF_RANGE); 65 } else { 66 for (const auto& code : queue_runner_def.queue_closed_exception_types()) { 67 queue_closed_exception_types_.insert(static_cast<int>(code)); 68 } 69 } 70 71 int nthreads = runs_; 72 if (coord_) { 73 // One more thread to call Stop() 74 nthreads++; 75 } 76 thread_pool_.reset(new thread::ThreadPool( 77 Env::Default(), SanitizeThreadSuffix(queue_name_), nthreads)); 78 79 return Status::OK(); 80 } 81 82 QueueRunner::~QueueRunner() { 83 // Cannot run Stop() here because the session might already be closed or 84 // destroyed. 85 Join().IgnoreError(); 86 } 87 88 Status QueueRunner::Start(Session* sess) { return Start(sess, 0); } 89 90 Status QueueRunner::StartAndCollectCostGraph(Session* sess, 91 const RunOptions& run_options) { 92 SetRunArgumentsAndCostGraph(run_options); 93 return Start(sess, 0); 94 } 95 96 Status QueueRunner::Start(Session* sess, int wait_for) { 97 counter_.reset(new BlockingCounter(runs_)); 98 for (const string& enqueue_op : enqueue_op_names_) { 99 thread_pool_->Schedule( 100 std::bind(&QueueRunner::Run, this, sess, enqueue_op)); 101 } 102 if (coord_) { 103 thread_pool_->Schedule(std::bind(&QueueRunner::Stop, this, sess)); 104 } 105 // Wait for up to 'wait_for' milliseconds. 106 if (wait_for > 0) { 107 if (!counter_->WaitFor(std::chrono::milliseconds(wait_for))) { 108 return Status(error::DEADLINE_EXCEEDED, 109 "Queues not fed before the timeout"); 110 } 111 // Check the status of the queue runner as well as the result of the enqueue 112 // operations. 113 mutex_lock l(mu_); 114 if (!enqueue_status_.ok()) { 115 return enqueue_status_; 116 } else { 117 return status_; 118 } 119 } 120 return Status::OK(); 121 } 122 123 Status QueueRunner::StartAndCollectCostGraph(Session* session, int wait_for_ms, 124 const RunOptions& run_options) { 125 SetRunArgumentsAndCostGraph(run_options); 126 return Start(session, wait_for_ms); 127 } 128 129 void QueueRunner::Stop(Session* sess) { 130 if (coord_ != nullptr) { 131 coord_->WaitForStop(); 132 } 133 if (!cancel_op_name_.empty()) { 134 UpdateStatus(RealRun(sess, cancel_op_name_, false)); 135 } 136 stopped_ = true; 137 } 138 139 Status QueueRunner::Join() { 140 thread_pool_.reset(); 141 mutex_lock l(mu_); 142 return status_; 143 } 144 145 void QueueRunner::UpdateStatus(const Status& status) { 146 { 147 mutex_lock l(mu_); 148 if (!status_.ok() || status.ok() || IsQueueClosed(status)) { 149 return; 150 } 151 status_ = status; 152 } 153 if (coord_) { 154 coord_->ReportStatus(status); 155 } 156 mutex_lock l(cb_mu_); 157 for (auto& cb : callbacks_) { 158 cb(status); 159 } 160 } 161 162 void QueueRunner::Run(Session* sess, const string& enqueue_op) { 163 bool first_iteration = true; 164 Status status; 165 while (status.ok()) { 166 if (coord_ && coord_->ShouldStop()) { 167 break; 168 } 169 status = RealRun(sess, enqueue_op, true); 170 if (first_iteration) { 171 if (!status.ok()) { 172 mutex_lock l(mu_); 173 enqueue_status_ = status; 174 } 175 counter_->DecrementCount(); 176 first_iteration = false; 177 } 178 } 179 bool last_run = false; 180 { 181 mutex_lock l(mu_); 182 runs_--; 183 last_run = (runs_ == 0); 184 } 185 186 // Close the queue unless the coordinator is shutting down since the cancel op 187 // will be run anway in this case. 188 if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) { 189 if (last_run && !close_op_name_.empty()) { 190 UpdateStatus(RealRun(sess, close_op_name_, false)); 191 } 192 } else if (!status.ok()) { 193 LOG(ERROR) << "Queue runner thread got a failure status: " 194 << status.ToString(); 195 UpdateStatus(status); 196 if (coord_) { 197 coord_->RequestStop().IgnoreError(); 198 } 199 } 200 } 201 202 Status QueueRunner::GetStatus() { 203 mutex_lock l(mu_); 204 return status_; 205 } 206 207 Status QueueRunner::ExportCostGraph(CostGraphDef* cost_graph) const { 208 if (!cg_mu_) { 209 return Status(error::FAILED_PRECONDITION, 210 "This QueueRunner doesn't collect a cost graph."); 211 } 212 mutex_lock l(*cg_mu_); 213 cost_graph->MergeFrom(*cost_graph_); 214 return Status::OK(); 215 } 216 217 void QueueRunner::SetRunArgumentsAndCostGraph(const RunOptions& run_options) { 218 cg_mu_.reset(new mutex()); 219 { 220 mutex_lock l(*cg_mu_); 221 cost_graph_.reset(new CostGraphDef()); 222 } 223 run_options_ = run_options; 224 } 225 226 Status QueueRunner::RealRun(Session* sess, const string& op, 227 bool update_costs) { 228 Status s; 229 if (update_costs && cg_mu_) { 230 RunMetadata metadata; 231 s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata); 232 mutex_lock l(*cg_mu_); 233 cost_graph_->Swap(metadata.mutable_cost_graph()); 234 } else { 235 s = sess->Run({}, {}, {op}, nullptr); 236 } 237 return s; 238 } 239 240 } // namespace tensorflow 241