/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_VERBS

#include "tensorflow/contrib/verbs/verbs_server_lib.h"

#include "grpc/support/alloc.h"

#include "tensorflow/contrib/verbs/rdma_mgr.h"
#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {

namespace {
// Rendezvous-manager factory handed to GrpcServer::Init by Create(). It makes
// the worker environment's rendezvous manager an RdmaRendezvousMgr, which
// VerbsServer::Init later downcasts to in order to attach the RdmaMgr.
RendezvousMgrInterface* NewRdmaRendezvousMgr(const WorkerEnv* env) {
  return new RdmaRendezvousMgr(env);
}

}  // namespace

// The raw pointers released in the destructor are explicitly nulled here so
// that destroying a server whose Init() returned early (with an error) never
// deletes an indeterminate pointer.
VerbsServer::VerbsServer(const ServerDef& server_def, Env* env)
    : GrpcServer(server_def, env),
      rdma_mgr_(nullptr),
      verbs_state_(DISCONNECTED),
      verbs_service_(nullptr),
      channel_cache_(nullptr) {}

// Stops and joins the server (which also tears down the verbs thread, see
// Join()), then releases the RDMA manager, the verbs service and the channel
// cache. rdma_mgr_ is deleted first because it was constructed with
// channel_cache_.
VerbsServer::~VerbsServer() {
  TF_CHECK_OK(Stop());
  TF_CHECK_OK(Join());
  delete rdma_mgr_;
  delete verbs_service_;
  delete channel_cache_;
}

// Builds the gRPC channel cache for the cluster described by `server_def` and
// stores it in `*channel_cache`; the caller owns it (Init() stores it in
// channel_cache_, which the destructor deletes).
//
// Note that `*channel_cache` is assigned BEFORE the port validation below, so
// the caller must release it even when an error status is returned.
//
// Returns:
//   - any error from ParseChannelSpec();
//   - Internal if no port can be parsed from this task's "host:port";
//   - InvalidArgument if that port differs from bound_port().
Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
                                        GrpcChannelCache** channel_cache) {
  const string name_prefix =
      strings::StrCat("/job:", server_def.job_name(), "/replica:0",
                      "/task:", server_def.task_index());

  GrpcChannelSpec channel_spec;
  TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));

  *channel_cache =
      NewGrpcChannelCache(channel_spec, GetChannelCreationFunction());

  const string host_port = (*channel_cache)->TranslateTask(name_prefix);

  // Use the text after the LAST ':' as the port. This reports an error for a
  // value without any ':' (instead of indexing past the end of a one-element
  // split result) and still works when the host part itself contains colons,
  // e.g. "[::1]:2222".
  const size_t colon = host_port.rfind(':');
  int requested_port;
  if (colon == string::npos ||
      !strings::safe_strto32(host_port.substr(colon + 1), &requested_port)) {
    return errors::Internal("Could not parse port for local server from \"",
                            host_port, "\".");
  }
  if (requested_port != bound_port()) {
    return errors::InvalidArgument("Requested port ", requested_port,
                                   " differs from expected port ",
                                   bound_port());
  }

  return Status::OK();
}

// Initializes the underlying gRPC server and then wires the RDMA layer:
// builds the channel cache, creates the RdmaMgr, and hands it to both the
// verbs service and the RDMA rendezvous manager.
//
// `service_func` is expected to populate verbs_service_ (Create() passes one
// that calls SetNewVerbsService(&verbs_service_, ...)), and
// `rendezvous_mgr_func` is expected to produce an RdmaRendezvousMgr.
//
// Every failure is returned as a Status rather than continuing with (and
// dereferencing) half-initialized state.
Status VerbsServer::Init(ServiceInitFunction service_func,
                         RendezvousMgrCreationFunction rendezvous_mgr_func) {
  TF_RETURN_IF_ERROR(GrpcServer::Init(service_func, rendezvous_mgr_func));
  {
    mutex_lock l(mu_);
    CHECK_EQ(verbs_state_, DISCONNECTED);
    TF_RETURN_IF_ERROR(ChannelCacheFactory(server_def(), &channel_cache_));

    if (verbs_service_ == nullptr) {
      return errors::Internal(
          "VerbsServer: the verbs service was not created during Init.");
    }
    RdmaRendezvousMgr* rdma_rendezvous_mgr =
        dynamic_cast<RdmaRendezvousMgr*>(worker_env()->rendezvous_mgr);
    if (rdma_rendezvous_mgr == nullptr) {
      return errors::Internal(
          "VerbsServer requires an RdmaRendezvousMgr rendezvous manager.");
    }

    rdma_mgr_ = new RdmaMgr(worker_env(), channel_cache_);
    // Share the RdmaMgr with the verbs service and the rendezvous manager.
    verbs_service_->SetRdmaMgr(rdma_mgr_);
    rdma_rendezvous_mgr->SetRdmaMgr(rdma_mgr_);
  }
  return Status::OK();
}

// Starts the gRPC server and, if the verbs layer is DISCONNECTED, launches the
// verbs RPC thread, sets up the RDMA channels, verifies connectivity (crashing
// via CHECK if it fails), initializes the RDMA allocators, and marks the verbs
// layer CONNECTED.
Status VerbsServer::Start() {
  // If gRPC could not be started (for example, the server has already been
  // stopped), do not bring up a verbs thread and RDMA channels for it.
  TF_RETURN_IF_ERROR(GrpcServer::Start());
  {
    mutex_lock l(mu_);
    if (verbs_state_ == DISCONNECTED) {
      // verbs_thread needs to be initiated
      // before rdma_mgr sets up the rdma channels.
      verbs_thread_.reset(worker_env()->env->StartThread(
          ThreadOptions(), "TF_verbs_service",
          [this] { verbs_service_->HandleRPCsLoop(); }));
      rdma_mgr_->SetupChannels();
      CHECK(rdma_mgr_->ConnectivityCheck()) << "Connectivity check failed!";
      rdma_mgr_->InitAllocators();
      verbs_state_ = CONNECTED;
    }
  }
  return Status::OK();
}

// Waits on the gRPC server via GrpcServer::Join(); then, if the verbs layer is
// CONNECTED, marks it DISCONNECTED and destroys the verbs thread object
// (presumably blocking until HandleRPCsLoop() returns — confirm against
// tensorflow::Thread). Returns GrpcServer::Join()'s status.
Status VerbsServer::Join() {
  Status s = GrpcServer::Join();
  {
    mutex_lock l(mu_);
    if (verbs_state_ == CONNECTED) {
      verbs_state_ = DISCONNECTED;
      verbs_thread_.reset();
    }
  }
  return s;
}

/* static */
// Creates and initializes a VerbsServer for `server_def`, running in `env`
// (Env::Default() when `env` is null), and stores it in `*out_server`.
// Returns any error from Init(); `*out_server` is untouched in that case.
Status VerbsServer::Create(const ServerDef& server_def, Env* env,
                           std::unique_ptr<ServerInterface>* out_server) {
  std::unique_ptr<VerbsServer> ret(
      new VerbsServer(server_def, env == nullptr ? Env::Default() : env));
  // Capturing `ret` by reference assumes service_func is only invoked during
  // the Init() call below (i.e. while `ret` is still alive in this frame).
  ServiceInitFunction service_func = [&ret](const WorkerEnv* worker_env,
                                            ::grpc::ServerBuilder* builder) {
    return SetNewVerbsService(&ret->verbs_service_, worker_env, builder);
  };
  TF_RETURN_IF_ERROR(ret->Init(service_func, NewRdmaRendezvousMgr));
  *out_server = std::move(ret);
  return Status::OK();
}

namespace {

// ServerFactory that builds a VerbsServer for ServerDefs whose protocol is
// "grpc+verbs".
class VerbsServerFactory : public ServerFactory {
 public:
  bool AcceptsOptions(const ServerDef& server_def) override {
    return server_def.protocol() == "grpc+verbs";
  }

  Status NewServer(const ServerDef& server_def,
                   std::unique_ptr<ServerInterface>* out_server) override {
    return VerbsServer::Create(server_def, Env::Default(), out_server);
  }
};

// Registers a `ServerFactory` for `VerbsServer` instances.
class VerbsServerRegistrar {
 public:
  VerbsServerRegistrar() {
    // Value-initialize the struct so every field not assigned below (e.g. an
    // optional zero-allocation hook in newer gRPC versions) is null instead of
    // indeterminate when it is handed to gRPC.
    gpr_allocation_functions alloc_fns = {};
    alloc_fns.malloc_fn = port::Malloc;
    alloc_fns.realloc_fn = port::Realloc;
    alloc_fns.free_fn = port::Free;
    gpr_set_allocation_functions(alloc_fns);
    // The factory object is handed to the ServerFactory registry and is never
    // deleted here.
    ServerFactory::Register("VERBS_SERVER", new VerbsServerFactory());
  }
};
static VerbsServerRegistrar registrar;

}  // namespace
}  // namespace tensorflow

#endif