Home | History | Annotate | Download | only in verbs
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifdef TENSORFLOW_USE_VERBS
     17 
     18 #include "tensorflow/contrib/verbs/verbs_server_lib.h"
     19 
     20 #include "grpc/support/alloc.h"
     21 
     22 #include "tensorflow/contrib/verbs/rdma_mgr.h"
     23 #include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
     24 #include "tensorflow/core/distributed_runtime/server_lib.h"
     25 #include "tensorflow/core/lib/core/status.h"
     26 #include "tensorflow/core/platform/env.h"
     27 
     28 namespace tensorflow {
     29 
     30 namespace {
     31 // static utility function
     32 RendezvousMgrInterface* NewRdmaRendezvousMgr(const WorkerEnv* env) {
     33   return new RdmaRendezvousMgr(env);
     34 }
     35 
     36 }  // namespace
     37 
     38 VerbsServer::VerbsServer(const ServerDef& server_def, Env* env)
     39     : GrpcServer(server_def, env), verbs_state_(DISCONNECTED) {}
     40 
// Stops and joins the underlying server. Join() also releases the verbs
// polling thread if Start() brought it up. Aborts the process (TF_CHECK_OK)
// if either call fails. Then frees the raw-owned members.
// NOTE(review): rdma_mgr_ is deleted before channel_cache_, and Init()
// constructs the RdmaMgr with a pointer to that cache. Presumably the order
// is intentional; keep it.
VerbsServer::~VerbsServer() {
  TF_CHECK_OK(Stop());
  TF_CHECK_OK(Join());
  delete rdma_mgr_;
  delete verbs_service_;
  delete channel_cache_;
}
     48 
     49 Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
     50                                         GrpcChannelCache** channel_cache) {
     51   string name_prefix =
     52       strings::StrCat("/job:", server_def.job_name(), "/replica:0",
     53                       "/task:", server_def.task_index());
     54 
     55   GrpcChannelSpec channel_spec;
     56   TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
     57 
     58   *channel_cache =
     59       NewGrpcChannelCache(channel_spec, GetChannelCreationFunction());
     60 
     61   const string host_port = (*channel_cache)->TranslateTask(name_prefix);
     62   int requested_port;
     63 
     64   if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
     65                              &requested_port)) {
     66     return errors::Internal("Could not parse port for local server from \"",
     67                             (*channel_cache)->TranslateTask(name_prefix),
     68                             "\".");
     69   }
     70   if (requested_port != bound_port()) {
     71     return errors::InvalidArgument("Requested port ", requested_port,
     72                                    " differs from expected port ",
     73                                    bound_port());
     74   }
     75 
     76   return Status::OK();
     77 }
     78 
     79 Status VerbsServer::Init(ServiceInitFunction service_func,
     80                          RendezvousMgrCreationFunction rendezvous_mgr_func) {
     81   Status s = GrpcServer::Init(service_func, rendezvous_mgr_func);
     82   {
     83     mutex_lock l(mu_);
     84     CHECK_EQ(verbs_state_, DISCONNECTED);
     85     CHECK(ChannelCacheFactory(server_def(), &channel_cache_).ok());
     86     rdma_mgr_ = new RdmaMgr(worker_env(), channel_cache_);
     87     // set rdma_mgr for verbs_service and rdma_rendezvous_mgr
     88     verbs_service_->SetRdmaMgr(rdma_mgr_);
     89     dynamic_cast<RdmaRendezvousMgr*>(worker_env()->rendezvous_mgr)
     90         ->SetRdmaMgr(rdma_mgr_);
     91   }
     92   return s;
     93 }
     94 
     95 Status VerbsServer::Start() {
     96   Status s = GrpcServer::Start();
     97   {
     98     mutex_lock l(mu_);
     99     if (verbs_state_ == DISCONNECTED) {
    100       // verbs_thread needs to be initiated
    101       // before rdma_mgr sets up the rdma channels.
    102       verbs_thread_.reset(worker_env()->env->StartThread(
    103           ThreadOptions(), "TF_verbs_service",
    104           [this] { verbs_service_->HandleRPCsLoop(); }));
    105       rdma_mgr_->SetupChannels();
    106       CHECK(rdma_mgr_->ConnectivityCheck()) << "Connectivity check failed!";
    107       rdma_mgr_->InitAllocators();
    108       verbs_state_ = CONNECTED;
    109     }
    110   }
    111   return s;
    112 }
    113 
    114 Status VerbsServer::Join() {
    115   Status s = GrpcServer::Join();
    116   {
    117     mutex_lock l(mu_);
    118     if (verbs_state_ == CONNECTED) {
    119       verbs_state_ = DISCONNECTED;
    120       verbs_thread_.reset();
    121     }
    122   }
    123   return s;
    124 }
    125 
    126 /* static */
    127 Status VerbsServer::Create(const ServerDef& server_def, Env* env,
    128                            std::unique_ptr<ServerInterface>* out_server) {
    129   std::unique_ptr<VerbsServer> ret(new VerbsServer(server_def, Env::Default()));
    130   ServiceInitFunction service_func = [&ret](const WorkerEnv* worker_env,
    131                                             ::grpc::ServerBuilder* builder) {
    132     return SetNewVerbsService(&ret->verbs_service_, worker_env, builder);
    133   };
    134   TF_RETURN_IF_ERROR(ret->Init(service_func, NewRdmaRendezvousMgr));
    135   *out_server = std::move(ret);
    136   return Status::OK();
    137 }
    138 
    139 namespace {
    140 
    141 class VerbsServerFactory : public ServerFactory {
    142  public:
    143   bool AcceptsOptions(const ServerDef& server_def) override {
    144     return server_def.protocol() == "grpc+verbs";
    145   }
    146 
    147   Status NewServer(const ServerDef& server_def,
    148                    std::unique_ptr<ServerInterface>* out_server) override {
    149     return VerbsServer::Create(server_def, Env::Default(), out_server);
    150   }
    151 };
    152 
    153 // Registers a `ServerFactory` for `VerbsServer` instances.
    154 class VerbsServerRegistrar {
    155  public:
    156   VerbsServerRegistrar() {
    157     gpr_allocation_functions alloc_fns;
    158     alloc_fns.malloc_fn = port::Malloc;
    159     alloc_fns.realloc_fn = port::Realloc;
    160     alloc_fns.free_fn = port::Free;
    161     gpr_set_allocation_functions(alloc_fns);
    162     ServerFactory::Register("VERBS_SERVER", new VerbsServerFactory());
    163   }
    164 };
    165 static VerbsServerRegistrar registrar;
    166 
    167 }  // namespace
    168 }  // namespace tensorflow
    169 
    170 #endif
    171