1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/distributed_runtime/worker.h" 17 18 #include "tensorflow/core/common_runtime/device_mgr.h" 19 #include "tensorflow/core/common_runtime/process_util.h" 20 #include "tensorflow/core/common_runtime/step_stats_collector.h" 21 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" 22 #include "tensorflow/core/distributed_runtime/tensor_coding.h" 23 #include "tensorflow/core/distributed_runtime/worker_session.h" 24 #include "tensorflow/core/platform/tracing.h" 25 26 namespace tensorflow { 27 28 Worker::Worker(WorkerEnv* env) 29 : env_(env), cancellation_manager_(new CancellationManager) {} 30 31 void Worker::GetStatusAsync(const GetStatusRequest* request, 32 GetStatusResponse* response, StatusCallback done) { 33 DeviceMgr* dm = env_->device_mgr; 34 std::vector<DeviceAttributes> devices; 35 dm->ListDeviceAttributes(&devices); 36 response->mutable_device_attributes()->Reserve(devices.size()); 37 for (auto& d : devices) { 38 response->add_device_attributes()->Swap(&d); 39 } 40 done(Status::OK()); 41 } 42 43 void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request, 44 CreateWorkerSessionResponse* response, 45 StatusCallback done) { 46 Status s = env_->session_mgr->CreateSession(request->session_handle(), 47 request->server_def(), 48 request->isolate_session_state()); 49 done(s); 50 } 51 52 void Worker::DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request, 53 DeleteWorkerSessionResponse* response, 54 StatusCallback done) { 55 Status s = env_->session_mgr->DeleteSession(request->session_handle()); 56 done(s); 57 } 58 59 void Worker::RegisterGraphAsync(const RegisterGraphRequest* request, 60 RegisterGraphResponse* response, 61 StatusCallback done) { 62 auto session = 63 env_->session_mgr->WorkerSessionForSession(request->session_handle()); 64 Status s = session->graph_mgr->Register( 65 request->session_handle(), request->graph_def(), request->graph_options(), 66 request->debug_options(), session->cluster_flr.get(), 67 response->mutable_graph_handle()); 68 done(s); 69 } 70 71 void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request, 72 DeregisterGraphResponse* response, 73 StatusCallback done) { 74 auto session = 75 env_->session_mgr->WorkerSessionForSession(request->session_handle()); 76 Status s = session->graph_mgr->Deregister(request->graph_handle()); 77 78 done(s); 79 } 80 81 void Worker::AbortStep(int64 step_id) { 82 Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id); 83 SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() { 84 // Delay a bit before aborting the step. This way, the root 85 // cause may return first back to the client instead of this 86 // cancellation generated abort error. 87 rendez->StartAbort(errors::Aborted("Step ", step_id)); 88 rendez->Unref(); 89 }); 90 } 91 92 Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req, 93 GraphMgr::NamedTensors* in, 94 GraphMgr::NamedTensors* out) { 95 static Tensor empty_tensor(DT_FLOAT); 96 if (req->num_sends() > 0) { 97 Tensor val; 98 for (size_t i = 0; i < req->num_sends(); ++i) { 99 TF_RETURN_IF_ERROR(req->SendValue(i, &val)); 100 in->insert({req->send_key(i), val}); 101 } 102 } 103 for (size_t i = 0; i < req->num_recvs(); ++i) { 104 out->insert({req->recv_key(i), empty_tensor}); 105 } 106 return Status::OK(); 107 } 108 109 void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, 110 MutableRunGraphResponseWrapper* response, 111 StatusCallback done) { 112 if (request->store_errors_in_response_body()) { 113 done = [response, done](const Status& status) { 114 response->set_status(status); 115 done(Status::OK()); 116 }; 117 } 118 if (request->is_partial()) { 119 DoPartialRunGraph(opts, request, response, std::move(done)); 120 } else { 121 DoRunGraph(opts, request, response, std::move(done)); 122 } 123 } 124 125 MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() { 126 return new InMemoryRunGraphRequest; 127 } 128 129 MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() { 130 return new InMemoryRunGraphResponse; 131 } 132 133 void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, 134 MutableRunGraphResponseWrapper* response, 135 StatusCallback done) { 136 const int64 step_id = request->step_id(); 137 TRACEPRINTF("RunGraph: %lld", step_id); 138 auto session = 139 env_->session_mgr->WorkerSessionForSession(request->session_handle()); 140 GraphMgr::NamedTensors in; 141 GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; 142 Status s = PrepareRunGraph(request, &in, out); 143 if (!s.ok()) { 144 delete out; 145 done(s); 146 return; 147 } 148 StepStatsCollector* collector = nullptr; 149 if (request->exec_opts().report_tensor_allocations_upon_oom() || 150 request->exec_opts().record_timeline() || 151 request->exec_opts().record_costs()) { 152 collector = new StepStatsCollector(response->mutable_step_stats()); 153 // TODO(mrry,pbar): GPU tracing for distributed steps. 154 } 155 CancellationManager* cm = new CancellationManager; 156 opts->SetCancelCallback([this, cm, step_id]() { 157 cm->StartCancel(); 158 AbortStep(step_id); 159 }); 160 CancellationToken token; 161 { 162 mutex_lock l(mu_); 163 token = cancellation_manager_->get_cancellation_token(); 164 bool already_cancelled = !cancellation_manager_->RegisterCallback( 165 token, [cm]() { cm->StartCancel(); }); 166 if (already_cancelled) { 167 opts->ClearCancelCallback(); 168 delete cm; 169 delete collector; 170 delete out; 171 done(errors::Aborted("Call was aborted")); 172 return; 173 } 174 } 175 session->graph_mgr->ExecuteAsync( 176 request->graph_handle(), step_id, session.get(), request->exec_opts(), 177 collector, response, cm, in, 178 [this, step_id, response, session, cm, out, token, collector, opts, 179 done](Status s) { 180 if (s.ok()) { 181 s = session->graph_mgr->RecvOutputs(step_id, out); 182 } 183 opts->ClearCancelCallback(); 184 { 185 mutex_lock l(mu_); 186 cancellation_manager_->DeregisterCallback(token); 187 } 188 delete cm; 189 190 if (s.ok()) { 191 for (const auto& p : *out) { 192 const string& key = p.first; 193 const Tensor& val = p.second; 194 response->AddRecv(key, val); 195 } 196 } 197 if (collector) collector->Finalize(); 198 delete collector; 199 delete out; 200 done(s); 201 }); 202 } 203 204 // TODO(suharshs): Add stats collection support to partial run. 205 void Worker::DoPartialRunGraph(CallOptions* opts, 206 RunGraphRequestWrapper* request, 207 MutableRunGraphResponseWrapper* response, 208 StatusCallback done) { 209 const int64 step_id = request->step_id(); 210 const string& graph_handle = request->graph_handle(); 211 TRACEPRINTF("PartialRunGraph: %lld", step_id); 212 auto session = 213 env_->session_mgr->WorkerSessionForSession(request->session_handle()); 214 215 GraphMgr::NamedTensors in; 216 GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; 217 Status s = PrepareRunGraph(request, &in, out); 218 auto finish = [this, done, out, opts](const Status& s) { 219 opts->ClearCancelCallback(); 220 delete out; 221 done(s); 222 }; 223 if (!s.ok()) { 224 finish(s); 225 return; 226 } 227 228 CancellationManager* cm = nullptr; 229 bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm); 230 231 // Before we start doing anything, we set the RPC cancellation. 232 opts->SetCancelCallback([this, cm, step_id]() { 233 cm->StartCancel(); 234 AbortStep(step_id); 235 }); 236 237 // If this is a new partial run request, the request will need to start the 238 // executors. 239 if (is_new_partial_run) { 240 CancellationToken token; 241 { 242 mutex_lock l(mu_); 243 token = cancellation_manager_->get_cancellation_token(); 244 cancellation_manager_->RegisterCallback(token, 245 [cm]() { cm->StartCancel(); }); 246 } 247 session->graph_mgr->ExecuteAsync( 248 graph_handle, step_id, session.get(), request->exec_opts(), 249 nullptr /* collector */, nullptr /* response */, cm, in, 250 [this, token, step_id, session, cm](Status s) { 251 { 252 mutex_lock l(mu_); 253 cancellation_manager_->DeregisterCallback(token); 254 } 255 partial_run_mgr_.ExecutorDone(step_id, s); 256 }); 257 } else { 258 // Send the partial run's new inputs. 259 s = session->graph_mgr->SendInputs(step_id, in); 260 if (!s.ok()) { 261 finish(s); 262 return; 263 } 264 } 265 266 session->graph_mgr->RecvOutputsAsync( 267 step_id, out, [this, out, request, response, step_id, finish](Status s) { 268 if (s.ok()) { 269 // Construct and return the resp. 270 for (const auto& p : *out) { 271 const string& key = p.first; 272 const Tensor& val = p.second; 273 response->AddRecv(key, val); 274 } 275 } 276 if (request->is_last_partial_run()) { 277 partial_run_mgr_.PartialRunDone(step_id, finish, s); 278 } else { 279 finish(s); 280 } 281 }); 282 } 283 284 void Worker::CleanupGraphAsync(const CleanupGraphRequest* request, 285 CleanupGraphResponse* response, 286 StatusCallback done) { 287 const int64 step_id = request->step_id(); 288 env_->rendezvous_mgr->Cleanup(step_id); 289 done(Status::OK()); 290 } 291 292 void Worker::CleanupAllAsync(const CleanupAllRequest* request, 293 CleanupAllResponse* response, 294 StatusCallback done) { 295 std::vector<string> containers; 296 for (const auto& c : request->container()) containers.push_back(c); 297 env_->device_mgr->ClearContainers(containers); 298 done(Status::OK()); 299 } 300 301 void Worker::LoggingAsync(const LoggingRequest* request, 302 LoggingResponse* response, StatusCallback done) { 303 done(errors::Unimplemented("Logging")); 304 } 305 306 void Worker::TracingAsync(const TracingRequest* request, 307 TracingResponse* response, StatusCallback done) { 308 done(errors::Unimplemented("Tracing")); 309 } 310 311 // Helper for RecvTensor. Validates "key" and returns the source 312 // device in "*src_dev". 313 Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, 314 Device** src_dev) { 315 // Figures out which device the tensor is hosted on. 316 string local_name = DeviceNameUtils::LocalName(parsed.src_device); 317 TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev)); 318 319 // Does the device have the right incarnation number we expect? 320 if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) { 321 return errors::Aborted( 322 "RecvTensor expects a different device incarnation: ", 323 parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(), 324 ". Your worker job was probably restarted. Check your " 325 "worker job for the reason why it was restarted."); 326 } 327 328 return Status::OK(); 329 } 330 331 void Worker::RecvTensorAsync(CallOptions* opts, 332 const RecvTensorRequest* request, 333 TensorResponse* response, StatusCallback done) { 334 // The base Worker class does not implement RecvTensorAsync, because 335 // it is not currently used for worker-to-worker communication. Use a 336 // transport-specific implementation (such as `GrpcWorker::RecvTensorAsync()`) 337 // instead. 338 done(errors::Unimplemented("Worker::RecvTensorAsync()")); 339 } 340 341 } // namespace tensorflow 342