Home | History | Annotate | Download | only in verbs
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifdef TENSORFLOW_USE_VERBS
     17 
     18 #include "grpc++/alarm.h"
     19 #include "grpc++/grpc++.h"
     20 #include "grpc++/server_builder.h"
     21 
     22 #include "tensorflow/contrib/verbs/grpc_verbs_service.h"
     23 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
     24 #include "tensorflow/core/distributed_runtime/session_mgr.h"
     25 
     26 namespace tensorflow {
     27 
     28 GrpcVerbsService::GrpcVerbsService(const WorkerEnv* worker_env,
     29                                    ::grpc::ServerBuilder* builder)
     30     : is_shutdown_(false), worker_env_(worker_env) {
     31   builder->RegisterService(&verbs_service_);
     32   cq_ = builder->AddCompletionQueue().release();
     33 }
     34 
     35 GrpcVerbsService::~GrpcVerbsService() {
     36   delete shutdown_alarm_;
     37   delete cq_;
     38 }
     39 
     40 void GrpcVerbsService::Shutdown() {
     41   bool did_shutdown = false;
     42   {
     43     mutex_lock l(shutdown_mu_);
     44     if (!is_shutdown_) {
     45       LOG(INFO) << "Shutting down GrpcWorkerService.";
     46       is_shutdown_ = true;
     47       did_shutdown = true;
     48     }
     49   }
     50   if (did_shutdown) {
     51     shutdown_alarm_ =
     52         new ::grpc::Alarm(cq_, gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
     53   }
     54 }
     55 
     56 // This macro creates a new request for the given RPC method name
     57 // (e.g., `ENQUEUE_REQUEST(GetRemoteAddress, false);`), and enqueues it on
     58 // `this->cq_`.
     59 //
     60 // This macro is invoked one or more times for each RPC method to
     61 // ensure that there are sufficient completion queue entries to
     62 // handle incoming requests without blocking.
     63 //
     64 // The implementation of the request handler for each RPC method
     65 // must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
     66 // to keep accepting new requests.
     67 #define ENQUEUE_REQUEST(method, supports_cancel)                             \
     68   do {                                                                       \
     69     mutex_lock l(shutdown_mu_);                                              \
     70     if (!is_shutdown_) {                                                     \
     71       Call<GrpcVerbsService, grpc::VerbsService::AsyncService,               \
     72            method##Request, method##Response>::                              \
     73           EnqueueRequest(&verbs_service_, cq_,                               \
     74                          &grpc::VerbsService::AsyncService::Request##method, \
     75                          &GrpcVerbsService::method##Handler,                 \
     76                          (supports_cancel));                                 \
     77     }                                                                        \
     78   } while (0)
     79 
     80 // This method blocks forever handling requests from the completion queue.
     81 void GrpcVerbsService::HandleRPCsLoop() {
     82   for (int i = 0; i < 10; ++i) {
     83     ENQUEUE_REQUEST(GetRemoteAddress, false);
     84   }
     85 
     86   void* tag;
     87   bool ok;
     88 
     89   while (cq_->Next(&tag, &ok)) {
     90     UntypedCall<GrpcVerbsService>::Tag* callback_tag =
     91         static_cast<UntypedCall<GrpcVerbsService>::Tag*>(tag);
     92     if (callback_tag) {
     93       callback_tag->OnCompleted(this, ok);
     94     } else {
     95       cq_->Shutdown();
     96     }
     97   }
     98 }
     99 
    100 void GrpcVerbsService::GetRemoteAddressHandler(
    101     WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call) {
    102   Status s = GetRemoteAddressSync(&call->request, &call->response);
    103   call->SendResponse(ToGrpcStatus(s));
    104   ENQUEUE_REQUEST(GetRemoteAddress, false);
    105 }
    106 
    107 // synchronous method
    108 Status GrpcVerbsService::GetRemoteAddressSync(
    109     const GetRemoteAddressRequest* request,
    110     GetRemoteAddressResponse* response) {
    111   // analyzing request
    112   // the channel setting part is redundant.
    113   const string remote_host_name = request->host_name();
    114   RdmaChannel* rc = rdma_mgr_->FindChannel(remote_host_name);
    115   CHECK(rc);
    116   RdmaAddress ra;
    117   ra.lid = request->channel().lid();
    118   ra.qpn = request->channel().qpn();
    119   ra.psn = request->channel().psn();
    120   ra.snp = request->channel().snp();
    121   ra.iid = request->channel().iid();
    122   rc->SetRemoteAddress(ra, false);
    123   rc->Connect();
    124   int i = 0;
    125   int idx[] = {1, 0};
    126   std::vector<RdmaMessageBuffer*> mb(rc->message_buffers());
    127   CHECK_EQ(request->mr_size(), RdmaChannel::kNumMessageBuffers);
    128   for (const auto& mr : request->mr()) {
    129     // the connections are crossed, i.e.
    130     // local tx_message_buffer <---> remote rx_message_buffer_
    131     // local rx_message_buffer <---> remote tx_message_buffer_
    132     // hence idx[] = {1, 0}.
    133     RdmaMessageBuffer* rb = mb[idx[i]];
    134     RemoteMR rmr;
    135     rmr.remote_addr = mr.remote_addr();
    136     rmr.rkey = mr.rkey();
    137     rb->SetRemoteMR(rmr, false);
    138     i++;
    139   }
    140   CHECK(i == RdmaChannel::kNumMessageBuffers);
    141 
    142   // setting up response
    143   response->set_host_name(
    144       worker_env_->session_mgr->LegacySession()->worker_name);
    145   Channel* channel_info = response->mutable_channel();
    146   channel_info->set_lid(rc->self().lid);
    147   channel_info->set_qpn(rc->self().qpn);
    148   channel_info->set_psn(rc->self().psn);
    149   channel_info->set_snp(rc->self().snp);
    150   channel_info->set_iid(rc->self().iid);
    151   for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
    152     MemoryRegion* mr = response->add_mr();
    153     mr->set_remote_addr(reinterpret_cast<uint64>(mb[i]->buffer()));
    154     mr->set_rkey(mb[i]->self()->rkey);
    155   }
    156   return Status::OK();
    157 }
    158 
    159 // Create a GrpcVerbsService, then assign it to a given handle.
    160 void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
    161                         ::grpc::ServerBuilder* builder) {
    162   *handle = new GrpcVerbsService(worker_env, builder);
    163 }
    164 
    165 }  // namespace tensorflow
    166 
    167 #endif  // TENSORFLOW_USE_VERBS
    168