/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"

#include <cstring>
#include <limits>
#include <memory>

#include "grpc++/grpc++.h"
#include "grpc++/security/credentials.h"
#include "grpc++/server_builder.h"
#include "grpc/support/alloc.h"

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/master_session.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

namespace {

// Define an option subclass in order to disable SO_REUSEPORT for the
// server socket.
class NoReusePortOption : public ::grpc::ServerBuilderOption {
 public:
  void UpdateArguments(::grpc::ChannelArguments* args) override {
    args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
  }

  void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
                         plugins) override {}
};

// static utility function
RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
  return new RpcRendezvousMgr(env);
}

}  // namespace

GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
    : server_def_(server_def), env_(env), state_(NEW) {}

GrpcServer::~GrpcServer() {
  TF_CHECK_OK(Stop());
  TF_CHECK_OK(Join());

  delete master_service_;
  delete worker_service_;

  // TODO(mrry): Refactor the *Env classes so that it is less fiddly
  // to destroy them.

  // Shut down all outstanding rendezvous.
  delete worker_env_.rendezvous_mgr;

  // We must delete graph_mgr before device_mgr, due to shared
  // ownership of OpKernels in the executors. (The graph_mgr will
  // free all stateless OpKernels, and pass over borrowed stateful
  // OpKernels, which are also held in their respective devices'
  // OpSegments.)
  if (worker_env_.session_mgr != nullptr) {
    // Deleting the session_mgr also deletes the graph_mgrs it owns, and its
    // legacy_session_ now deletes the device_mgr.
    delete worker_env_.session_mgr;
  } else {
    delete worker_env_.device_mgr;
  }

  // Do not delete (as these are not owned by the server):
  // - master_env_.env
  // - worker_env_.env
  // - worker_env_.compute_pool
}

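// Initializes the server: creates the local devices and the rendezvous
// manager, registers the master and worker services, builds and starts the
// underlying gRPC server on the port requested for this task in
// `server_def_`, and then finishes configuring the master and worker
// environments. Request handling does not begin until Start() is called.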
Status GrpcServer::Init(
    ServiceInitFunction service_func,
    const RendezvousMgrCreationFunction& rendezvous_mgr_func,
    const WorkerCreationFunction& worker_func) {
  mutex_lock l(mu_);
  CHECK_EQ(state_, NEW);
  master_env_.env = env_;
  worker_env_.env = env_;

  SessionOptions sess_opts;
  ConfigProto config = server_def_.default_session_config();
  sess_opts.config = config;

  // Configure shared devices between master and worker.
  string name_prefix =
      strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
                      "/task:", server_def_.task_index());
  TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
                                               &master_env_.local_devices));
  worker_env_.local_devices = master_env_.local_devices;
  worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices);
  worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
                                   ? new RpcRendezvousMgr(&worker_env_)
                                   : rendezvous_mgr_func(&worker_env_);
  string unused;
  string default_worker_name;
  if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
                                        &default_worker_name, &unused)) {
    return errors::Internal("Could not parse worker name.");
  }

  // Look up the port that has been requested for this task in `server_def_`.
  int requested_port = -1;
  for (const auto& job : server_def_.cluster().job()) {
    if (job.name() == server_def_.job_name()) {
      auto iter = job.tasks().find(server_def_.task_index());
      if (iter == job.tasks().end()) {
        return errors::InvalidArgument("Task ", server_def_.task_index(),
                                       " was not defined in job \"",
                                       server_def_.job_name(), "\"");
      }
      const std::vector<string> hostname_port =
          str_util::Split(iter->second, ':');
      if (hostname_port.size() != 2 ||
          !strings::safe_strto32(hostname_port[1], &requested_port)) {
        return errors::InvalidArgument(
            "Could not parse port for local server from \"", iter->second,
            "\"");
      } else {
        break;
      }
    }
  }
  if (requested_port == -1) {
    return errors::Internal("Job \"", server_def_.job_name(),
                            "\" was not defined in cluster");
  }

  // N.B. The order of initialization here is intricate, because we
  // wish to allow `requested_port == 0` (for choosing any port,
  // mostly for testing). Therefore, the construction of the channel
  // and worker caches depends on `bound_port_`, which is not set
  // until we call `builder.BuildAndStart()`. We must create the
  // service objects before calling `builder.BuildAndStart()`, but
  // `master_env_` and `worker_env_` are only partially
  // configured. However, this is not dangerous, because we do not
  // start serving requests until `this->Start()` is called, which
  // happens after this method returns.
  //
  // TODO(mrry): Provide a general mechanism for dynamically setting
  // the identities of tasks in the worker pool after the service is
  // running.
  ::grpc::ServerBuilder builder;
  builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port),
                           GetServerCredentials(server_def_), &bound_port_);
  builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
  builder.SetOption(
      std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
  master_impl_ = CreateMaster(&master_env_);
  master_service_ = NewGrpcMasterService(
      master_impl_.get(), config.operation_timeout_in_ms(), &builder);
  worker_impl_ =
      worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_);
  worker_service_ =
      NewGrpcWorkerService(worker_impl_.get(), &builder).release();
  // Register any additional service requested by the caller.
  if (service_func != nullptr) {
    service_func(&worker_env_, &builder);
  }
  server_ = builder.BuildAndStart();

  if (!server_) {
    return errors::Unknown("Could not start gRPC server");
  }

  WorkerCacheInterface* worker_cache;
  WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
  TF_RETURN_IF_ERROR(
      WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
  CHECK_NE(nullptr, worker_cache);

  // Set up worker environment.
  worker_env_.session_mgr = new SessionMgr(
      &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
      std::unique_ptr<WorkerCacheInterface>(worker_cache),
      [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
        WorkerCacheFactoryOptions options(server_def);
        return WorkerCacheFactory(options, worker_cache);
      });
  worker_env_.compute_pool = ComputePool(sess_opts);

  // Finish setting up master environment.
  master_env_.ops = OpRegistry::Global();
  master_env_.worker_cache = worker_cache;
  master_env_.master_session_factory =
      [config](
          SessionOptions options, const MasterEnv* env,
          std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
          std::unique_ptr<WorkerCacheInterface> worker_cache,
          std::unique_ptr<DeviceSet> device_set) {
        options.config.MergeFrom(config);
        return new MasterSession(options, env, std::move(remote_devs),
                                 std::move(worker_cache), std::move(device_set),
                                 CreateNoOpStatsPublisher);
      };
  master_env_.worker_cache_factory =
      [this](const WorkerCacheFactoryOptions& options,
             WorkerCacheInterface** worker_cache) {
        return WorkerCacheFactory(options, worker_cache);
      };

  // Provide direct access to the master from in-process clients.
  LocalMaster::Register(target(), master_impl_.get(),
                        config.operation_timeout_in_ms());

  return Status::OK();
}

Status GrpcServer::Init(
    ServiceInitFunction service_func,
    const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
  return Init(service_func, rendezvous_mgr_func, nullptr);
}

Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr); }

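// Builds a GrpcChannelSpec from `options.cluster_def`, substituting
// "localhost:<bound_port_>" for this server's own task so that in-process
// traffic does not depend on the externally advertised address. Fails if a
// job lists two addresses for the same task index.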
Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
                                    GrpcChannelSpec* channel_spec) {
  for (const auto& job : options.cluster_def->job()) {
    std::map<int, string> host_ports;
    for (const auto& task : job.tasks()) {
      string& host_port = host_ports[task.first];
      if (!host_port.empty()) {
        return errors::InvalidArgument("JobDef for job \"", job.name(),
                                       "\" specified two addresses for task \"",
                                       task.first, "\": ", host_port, " and ",
                                       task.second);
      }
      if (job.name() == *options.job_name && task.first == options.task_index) {
        host_port = strings::StrCat("localhost:", bound_port_);
      } else {
        host_port = task.second;
      }
    }
    TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
  }
  return Status::OK();
}

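// Creates the worker cache shared by the master and the session manager. The
// cache resolves task names to gRPC channels built from the cluster
// definition, but serves requests addressed to this task with the in-process
// `worker_impl_`. Also checks that the port the cluster definition assigns to
// this task matches `bound_port_`.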
Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
                                      WorkerCacheInterface** worker_cache) {
  if (options.job_name == nullptr || options.job_name->empty()) {
    Status s = errors::InvalidArgument(
        "The master (current machine) is not included in the provided "
        "cluster_def. ",
        options.cluster_def->DebugString());
    LOG(WARNING) << s;
    return s;
  }

  GrpcChannelSpec channel_spec;
  TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));

  std::unique_ptr<GrpcChannelCache> channel_cache(
      NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));

  string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
                                       "/task:", options.task_index);

  const string host_port = channel_cache->TranslateTask(name_prefix);
  int requested_port;

  if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
                             &requested_port)) {
    return errors::Internal("Could not parse port for local server from \"",
                            channel_cache->TranslateTask(name_prefix), "\".");
  }
  if (requested_port != bound_port_) {
    return errors::InvalidArgument("Requested port ", requested_port,
                                   " differs from expected port ", bound_port_);
  }

  *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
      channel_cache.release(), worker_impl_.get(), name_prefix);
  return Status::OK();
}

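// Starts one thread apiece for the master and worker services, each running
// its RPC handling loop. Calling Start() on an already-started server is a
// no-op; calling it after Stop() returns an error.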
Status GrpcServer::Start() {
  mutex_lock l(mu_);
  switch (state_) {
    case NEW: {
      master_thread_.reset(
          env_->StartThread(ThreadOptions(), "TF_master_service",
                            [this] { master_service_->HandleRPCsLoop(); }));
      worker_thread_.reset(
          env_->StartThread(ThreadOptions(), "TF_worker_service",
                            [this] { worker_service_->HandleRPCsLoop(); }));
      state_ = STARTED;
      LOG(INFO) << "Started server with target: " << target();
      return Status::OK();
    }
    case STARTED:
      LOG(INFO) << "Server already started (target: " << target() << ")";
      return Status::OK();
    case STOPPED:
      return errors::FailedPrecondition("Server has stopped.");
    default:
      LOG(FATAL);
  }
}

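// Requests that the server stop. A server that was never started is simply
// marked STOPPED; clean shutdown of a running server is not yet implemented.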
Status GrpcServer::Stop() {
  mutex_lock l(mu_);
  switch (state_) {
    case NEW:
      state_ = STOPPED;
      return Status::OK();
    case STARTED:
      return errors::Unimplemented(
          "Clean shutdown is not currently implemented");
    case STOPPED:
      LOG(INFO) << "Server already stopped (target: " << target() << ")";
      return Status::OK();
    default:
      LOG(FATAL);
  }
}

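// Blocks until the master and worker service threads exit, by resetting the
// thread handles (which joins them). For a server still in the NEW state this
// only marks it STOPPED so that it cannot be started later.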
Status GrpcServer::Join() {
  mutex_lock l(mu_);
  switch (state_) {
    case NEW:
      // Prevent the server from being started subsequently.
      state_ = STOPPED;
      return Status::OK();
    case STARTED:
    case STOPPED:
      master_thread_.reset();
      worker_thread_.reset();
      return Status::OK();
    default:
      LOG(FATAL);
  }
}

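// The target string that clients use to connect to this server, e.g.
// "grpc://localhost:2222" once port 2222 has been bound.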
const string GrpcServer::target() const {
  return strings::StrCat("grpc://localhost:", bound_port_);
}

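// Returns the credentials used when binding the listening port. This
// implementation uses insecure (plaintext) credentials.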
std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
    const ServerDef& server_def) const {
  return ::grpc::InsecureServerCredentials();
}

ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
  // We can do this because SparseGrpcChannelCache is robust to nullptr being
  // returned by the channel creation function.
  return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
}

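// Constructs the Master implementation that services master RPCs for this
// server. The second constructor argument is the session garbage-collection
// timeout in seconds; 0.0 leaves idle sessions in place.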
std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
  return std::unique_ptr<Master>(new Master(master_env, 0.0));
}

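// Creates and initializes a GrpcServer for `server_def`, using the default
// RPC rendezvous manager and no extra services. A typical caller proceeds
// roughly as in the following sketch:
//
//   std::unique_ptr<ServerInterface> server;
//   TF_CHECK_OK(GrpcServer::Create(server_def, Env::Default(), &server));
//   TF_CHECK_OK(server->Start());
//   TF_CHECK_OK(server->Join());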
/* static */
Status GrpcServer::Create(const ServerDef& server_def, Env* env,
                          std::unique_ptr<ServerInterface>* out_server) {
  std::unique_ptr<GrpcServer> ret(
      new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
  ServiceInitFunction service_func = nullptr;
  TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr));
  *out_server = std::move(ret);
  return Status::OK();
}

namespace {

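// Factory that builds GrpcServer instances for ServerDefs whose protocol is
// "grpc".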
class GrpcServerFactory : public ServerFactory {
 public:
  bool AcceptsOptions(const ServerDef& server_def) override {
    return server_def.protocol() == "grpc";
  }

  Status NewServer(const ServerDef& server_def,
                   std::unique_ptr<ServerInterface>* out_server) override {
    return GrpcServer::Create(server_def, Env::Default(), out_server);
  }
};

// Registers a `ServerFactory` for `GrpcServer` instances, and installs
// TensorFlow's port-layer allocation functions as gRPC's memory allocator.
class GrpcServerRegistrar {
 public:
  GrpcServerRegistrar() {
    gpr_allocation_functions alloc_fns;
    memset(&alloc_fns, 0, sizeof(alloc_fns));
    alloc_fns.malloc_fn = port::Malloc;
    alloc_fns.realloc_fn = port::Realloc;
    alloc_fns.free_fn = port::Free;
    gpr_set_allocation_functions(alloc_fns);
    ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
  }
};
static GrpcServerRegistrar registrar;

}  // namespace
}  // namespace tensorflow