1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" 17 18 #include <cstring> 19 #include <limits> 20 #include <memory> 21 22 #include "grpc++/grpc++.h" 23 #include "grpc++/security/credentials.h" 24 #include "grpc++/server_builder.h" 25 #include "grpc/support/alloc.h" 26 27 #include "tensorflow/core/common_runtime/device_factory.h" 28 #include "tensorflow/core/common_runtime/device_mgr.h" 29 #include "tensorflow/core/common_runtime/process_util.h" 30 #include "tensorflow/core/distributed_runtime/graph_mgr.h" 31 #include "tensorflow/core/distributed_runtime/local_master.h" 32 #include "tensorflow/core/distributed_runtime/master.h" 33 #include "tensorflow/core/distributed_runtime/master_env.h" 34 #include "tensorflow/core/distributed_runtime/master_session.h" 35 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" 36 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" 37 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h" 38 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" 39 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" 40 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" 41 #include "tensorflow/core/distributed_runtime/server_lib.h" 42 #include "tensorflow/core/distributed_runtime/worker_env.h" 43 #include "tensorflow/core/framework/op.h" 44 #include "tensorflow/core/lib/strings/strcat.h" 45 #include "tensorflow/core/platform/env.h" 46 #include "tensorflow/core/platform/mem.h" 47 #include "tensorflow/core/public/session_options.h" 48 49 namespace tensorflow { 50 51 namespace { 52 53 // Define an option subclass in order to disable SO_REUSEPORT for the 54 // server socket. 55 class NoReusePortOption : public ::grpc::ServerBuilderOption { 56 public: 57 void UpdateArguments(::grpc::ChannelArguments* args) override { 58 args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0); 59 } 60 61 void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>* 62 plugins) override {} 63 }; 64 65 // static utility function 66 RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) { 67 return new RpcRendezvousMgr(env); 68 } 69 70 } // namespace 71 72 GrpcServer::GrpcServer(const ServerDef& server_def, Env* env) 73 : server_def_(server_def), env_(env), state_(NEW) {} 74 75 GrpcServer::~GrpcServer() { 76 TF_CHECK_OK(Stop()); 77 TF_CHECK_OK(Join()); 78 79 delete master_service_; 80 delete worker_service_; 81 82 // TODO(mrry): Refactor the *Env classes so that it is less fiddly 83 // to destroy them. 84 85 // Shut down all outstanding rendezvous. 86 delete worker_env_.rendezvous_mgr; 87 88 // We must delete graph_mgr before device_mgr, due to shared 89 // ownership of OpKernels in the executors. (The graph_mgr will 90 // free all stateless OpKernels, and pass over borrowed stateful 91 // OpKernels, which are also held in their respective devices' 92 // OpSegments.) 93 if (worker_env_.session_mgr != nullptr) { 94 delete worker_env_.session_mgr; // Deletes graph_mgr's. 95 } else { 96 // Note: session_mgr's legacy_session_ deletes device_mgr now. 97 delete worker_env_.device_mgr; 98 } 99 100 // Do not delete (as these are not owned by the server): 101 // - master_env_.env 102 // - worker_env_.env 103 // - worker_env_.compute_pool 104 } 105 106 Status GrpcServer::Init( 107 ServiceInitFunction service_func, 108 const RendezvousMgrCreationFunction& rendezvous_mgr_func, 109 const WorkerCreationFunction& worker_func) { 110 mutex_lock l(mu_); 111 CHECK_EQ(state_, NEW); 112 master_env_.env = env_; 113 worker_env_.env = env_; 114 115 SessionOptions sess_opts; 116 ConfigProto config = server_def_.default_session_config(); 117 sess_opts.config = config; 118 119 // Configure shared devices between master and worker. 120 string name_prefix = 121 strings::StrCat("/job:", server_def_.job_name(), "/replica:0", 122 "/task:", server_def_.task_index()); 123 TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix, 124 &master_env_.local_devices)); 125 worker_env_.local_devices = master_env_.local_devices; 126 worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices); 127 worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr 128 ? new RpcRendezvousMgr(&worker_env_) 129 : rendezvous_mgr_func(&worker_env_); 130 string unused; 131 string default_worker_name; 132 if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(), 133 &default_worker_name, &unused)) { 134 return errors::Internal("Could not parse worker name."); 135 } 136 137 // Look up the port that has been requested for this task in `server_def_`. 138 int requested_port = -1; 139 for (const auto& job : server_def_.cluster().job()) { 140 if (job.name() == server_def_.job_name()) { 141 auto iter = job.tasks().find(server_def_.task_index()); 142 if (iter == job.tasks().end()) { 143 return errors::InvalidArgument("Task ", server_def_.task_index(), 144 " was not defined in job \"", 145 server_def_.job_name(), "\""); 146 } 147 const std::vector<string> hostname_port = 148 str_util::Split(iter->second, ':'); 149 if (hostname_port.size() != 2 || 150 !strings::safe_strto32(hostname_port[1], &requested_port)) { 151 return errors::InvalidArgument( 152 "Could not parse port for local server from \"", iter->second, 153 "\""); 154 } else { 155 break; 156 } 157 } 158 } 159 if (requested_port == -1) { 160 return errors::Internal("Job \"", server_def_.job_name(), 161 "\" was not defined in cluster"); 162 } 163 164 // N.B. The order of initialization here is intricate, because we 165 // wish to allow `requested_port == 0` (for choosing any port, 166 // mostly for testing). Therefore, the construction of the channel 167 // and worker caches depends on `bound_port_`, which is not set 168 // until we call `builder.BuildAndStart()`. We must create the 169 // service objects before calling `builder.BuildAndStart()`, but 170 // `master_env_` and `worker_env_` are only partially 171 // configured. However, this is not dangerous, because we do not 172 // start serving requests until `this->Start()` is called, which 173 // happens after this method returns. 174 // 175 // TODO(mrry): Provide a general mechanism for dynamically setting 176 // the identities of tasks in the worker pool after the service is 177 // running. 178 ::grpc::ServerBuilder builder; 179 builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port), 180 GetServerCredentials(server_def_), &bound_port_); 181 builder.SetMaxMessageSize(std::numeric_limits<int32>::max()); 182 builder.SetOption( 183 std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption)); 184 master_impl_ = CreateMaster(&master_env_); 185 master_service_ = NewGrpcMasterService( 186 master_impl_.get(), config.operation_timeout_in_ms(), &builder); 187 worker_impl_ = 188 worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_); 189 worker_service_ = 190 NewGrpcWorkerService(worker_impl_.get(), &builder).release(); 191 // extra service: 192 if (service_func != nullptr) { 193 service_func(&worker_env_, &builder); 194 } 195 server_ = builder.BuildAndStart(); 196 197 if (!server_) { 198 return errors::Unknown("Could not start gRPC server"); 199 } 200 201 WorkerCacheInterface* worker_cache; 202 WorkerCacheFactoryOptions worker_cache_factory_options(server_def_); 203 TF_RETURN_IF_ERROR( 204 WorkerCacheFactory(worker_cache_factory_options, &worker_cache)); 205 CHECK_NE(nullptr, worker_cache); 206 207 // Set up worker environment. 208 worker_env_.session_mgr = new SessionMgr( 209 &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_), 210 std::unique_ptr<WorkerCacheInterface>(worker_cache), 211 [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) { 212 WorkerCacheFactoryOptions options(server_def); 213 return WorkerCacheFactory(options, worker_cache); 214 }); 215 worker_env_.compute_pool = ComputePool(sess_opts); 216 217 // Finish setting up master environment. 218 master_env_.ops = OpRegistry::Global(); 219 master_env_.worker_cache = worker_cache; 220 master_env_.master_session_factory = 221 [config]( 222 SessionOptions options, const MasterEnv* env, 223 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs, 224 std::unique_ptr<WorkerCacheInterface> worker_cache, 225 std::unique_ptr<DeviceSet> device_set) { 226 options.config.MergeFrom(config); 227 return new MasterSession(options, env, std::move(remote_devs), 228 std::move(worker_cache), std::move(device_set), 229 CreateNoOpStatsPublisher); 230 }; 231 master_env_.worker_cache_factory = 232 [this](const WorkerCacheFactoryOptions& options, 233 WorkerCacheInterface** worker_cache) { 234 return WorkerCacheFactory(options, worker_cache); 235 }; 236 237 // Provide direct access to the master from in-process clients. 238 LocalMaster::Register(target(), master_impl_.get(), 239 config.operation_timeout_in_ms()); 240 241 return Status::OK(); 242 } 243 244 Status GrpcServer::Init( 245 ServiceInitFunction service_func, 246 const RendezvousMgrCreationFunction& rendezvous_mgr_func) { 247 return Init(service_func, rendezvous_mgr_func, nullptr); 248 } 249 250 Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr); } 251 252 Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, 253 GrpcChannelSpec* channel_spec) { 254 for (const auto& job : options.cluster_def->job()) { 255 std::map<int, string> host_ports; 256 for (const auto& task : job.tasks()) { 257 string& host_port = host_ports[task.first]; 258 if (!host_port.empty()) { 259 return errors::InvalidArgument("JobDef for job \"", job.name(), 260 "\" specified two addresses for task \"", 261 task.first, "\": ", host_port, " and ", 262 task.second); 263 } 264 if (job.name() == *options.job_name && task.first == options.task_index) { 265 host_port = strings::StrCat("localhost:", bound_port_); 266 } else { 267 host_port = task.second; 268 } 269 } 270 TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports)); 271 } 272 return Status::OK(); 273 } 274 275 Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options, 276 WorkerCacheInterface** worker_cache) { 277 if (options.job_name == nullptr || options.job_name->empty()) { 278 Status s = errors::InvalidArgument( 279 "The master (current machine) is not included in the provided " 280 "cluster_def. ", 281 options.cluster_def->DebugString()); 282 LOG(WARNING) << s; 283 return s; 284 } 285 286 GrpcChannelSpec channel_spec; 287 TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec)); 288 289 std::unique_ptr<GrpcChannelCache> channel_cache( 290 NewGrpcChannelCache(channel_spec, GetChannelCreationFunction())); 291 292 string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0", 293 "/task:", options.task_index); 294 295 const string host_port = channel_cache->TranslateTask(name_prefix); 296 int requested_port; 297 298 if (!strings::safe_strto32(str_util::Split(host_port, ':')[1], 299 &requested_port)) { 300 return errors::Internal("Could not parse port for local server from \"", 301 channel_cache->TranslateTask(name_prefix), "\"."); 302 } 303 if (requested_port != bound_port_) { 304 return errors::InvalidArgument("Requested port ", requested_port, 305 " differs from expected port ", bound_port_); 306 } 307 308 *worker_cache = NewGrpcWorkerCacheWithLocalWorker( 309 channel_cache.release(), worker_impl_.get(), name_prefix); 310 return Status::OK(); 311 } 312 313 Status GrpcServer::Start() { 314 mutex_lock l(mu_); 315 switch (state_) { 316 case NEW: { 317 master_thread_.reset( 318 env_->StartThread(ThreadOptions(), "TF_master_service", 319 [this] { master_service_->HandleRPCsLoop(); })); 320 worker_thread_.reset( 321 env_->StartThread(ThreadOptions(), "TF_worker_service", 322 [this] { worker_service_->HandleRPCsLoop(); })); 323 state_ = STARTED; 324 LOG(INFO) << "Started server with target: " << target(); 325 return Status::OK(); 326 } 327 case STARTED: 328 LOG(INFO) << "Server already started (target: " << target() << ")"; 329 return Status::OK(); 330 case STOPPED: 331 return errors::FailedPrecondition("Server has stopped."); 332 default: 333 LOG(FATAL); 334 } 335 } 336 337 Status GrpcServer::Stop() { 338 mutex_lock l(mu_); 339 switch (state_) { 340 case NEW: 341 state_ = STOPPED; 342 return Status::OK(); 343 case STARTED: 344 return errors::Unimplemented( 345 "Clean shutdown is not currently implemented"); 346 case STOPPED: 347 LOG(INFO) << "Server already stopped (target: " << target() << ")"; 348 return Status::OK(); 349 default: 350 LOG(FATAL); 351 } 352 } 353 354 Status GrpcServer::Join() { 355 mutex_lock l(mu_); 356 switch (state_) { 357 case NEW: 358 // Prevent the server from being started subsequently. 359 state_ = STOPPED; 360 return Status::OK(); 361 case STARTED: 362 case STOPPED: 363 master_thread_.reset(); 364 worker_thread_.reset(); 365 return Status::OK(); 366 default: 367 LOG(FATAL); 368 } 369 } 370 371 const string GrpcServer::target() const { 372 return strings::StrCat("grpc://localhost:", bound_port_); 373 } 374 375 std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials( 376 const ServerDef& server_def) const { 377 return ::grpc::InsecureServerCredentials(); 378 } 379 380 ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const { 381 // We can do this because SparseGrpcChannelCache is robust to nullptr being 382 // returned by the channel creation function 383 return ConvertToChannelCreationFunction(NewHostPortGrpcChannel); 384 } 385 386 std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) { 387 return std::unique_ptr<Master>(new Master(master_env, 0.0)); 388 } 389 390 /* static */ 391 Status GrpcServer::Create(const ServerDef& server_def, Env* env, 392 std::unique_ptr<ServerInterface>* out_server) { 393 std::unique_ptr<GrpcServer> ret( 394 new GrpcServer(server_def, env == nullptr ? Env::Default() : env)); 395 ServiceInitFunction service_func = nullptr; 396 TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr)); 397 *out_server = std::move(ret); 398 return Status::OK(); 399 } 400 401 namespace { 402 403 class GrpcServerFactory : public ServerFactory { 404 public: 405 bool AcceptsOptions(const ServerDef& server_def) override { 406 return server_def.protocol() == "grpc"; 407 } 408 409 Status NewServer(const ServerDef& server_def, 410 std::unique_ptr<ServerInterface>* out_server) override { 411 return GrpcServer::Create(server_def, Env::Default(), out_server); 412 } 413 }; 414 415 // Registers a `ServerFactory` for `GrpcServer` instances. 416 class GrpcServerRegistrar { 417 public: 418 GrpcServerRegistrar() { 419 gpr_allocation_functions alloc_fns; 420 memset(&alloc_fns, 0, sizeof(alloc_fns)); 421 alloc_fns.malloc_fn = port::Malloc; 422 alloc_fns.realloc_fn = port::Realloc; 423 alloc_fns.free_fn = port::Free; 424 gpr_set_allocation_functions(alloc_fns); 425 ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory()); 426 } 427 }; 428 static GrpcServerRegistrar registrar; 429 430 } // namespace 431 } // namespace tensorflow 432