Home | History | Annotate | Download | only in rpc
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // GrpcMasterService implements the RPC service MasterService.
     17 //
     18 // A GrpcMasterService maintains the state of live graph computation
     19 // sessions, each session orchestrates both local and remote devices
     20 // to carry out the graph computation.
     21 //
     22 // A GrpcMasterService knows ahead of time local devices available as
     23 // client devices.
     24 //
     25 // A GrpcMasterService discovers remote devices in the background and
     26 // keeps track of statistics of those remote devices.
     27 //
     28 // Each session analyzes the graph, places nodes across available
     29 // devices, and ultimately drives the graph computation by initiating
     30 // RunGraph on workers.
     31 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
     32 
     33 #include "grpc++/alarm.h"
     34 #include "grpc++/server_builder.h"
     35 
     36 #include "tensorflow/core/distributed_runtime/master.h"
     37 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
     38 #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
     39 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
     40 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
     41 #include "tensorflow/core/platform/logging.h"
     42 #include "tensorflow/core/platform/macros.h"
     43 #include "tensorflow/core/platform/tracing.h"
     44 #include "tensorflow/core/protobuf/master.pb.h"
     45 
     46 namespace tensorflow {
     47 
     48 class GrpcMasterService : public AsyncServiceInterface {
     49  public:
     50   GrpcMasterService(Master* master, int64 default_timeout_in_ms,
     51                     ::grpc::ServerBuilder* builder)
     52       : master_impl_(master),
     53         default_timeout_in_ms_(default_timeout_in_ms),
     54         is_shutdown_(false) {
     55     builder->RegisterService(&master_service_);
     56     cq_ = builder->AddCompletionQueue();
     57   }
     58 
     59   ~GrpcMasterService() override { delete shutdown_alarm_; }
     60 
     61   void Shutdown() override {
     62     bool did_shutdown = false;
     63     {
     64       mutex_lock l(mu_);
     65       if (!is_shutdown_) {
     66         LOG(INFO) << "Shutting down GrpcMasterService.";
     67         is_shutdown_ = true;
     68         did_shutdown = true;
     69       }
     70     }
     71     if (did_shutdown) {
     72       // NOTE(mrry): This enqueues a special event (with a null tag)
     73       // that causes the completion queue to be shut down on the
     74       // polling thread.
     75       shutdown_alarm_ =
     76           new ::grpc::Alarm(cq_.get(), gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
     77     }
     78   }
     79 
     80 // This macro creates a new request for the given RPC method name
     81 // (e.g., `ENQUEUE_REQUEST(RunStep);`), and enqueues it on
     82 // `this->cq_`.
     83 //
     84 // This macro is invoked one or more times for each RPC method to
     85 // ensure that there are sufficient completion queue entries to
     86 // handle incoming requests without blocking.
     87 //
     88 // The implementation of the request handler for each RPC method
     89 // must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
     90 // to keep accepting new requests.
     91 #define ENQUEUE_REQUEST(method, supports_cancel)                              \
     92   do {                                                                        \
     93     mutex_lock l(mu_);                                                        \
     94     if (!is_shutdown_) {                                                      \
     95       Call<GrpcMasterService, grpc::MasterService::AsyncService,              \
     96            method##Request, method##Response>::                               \
     97           EnqueueRequest(&master_service_, cq_.get(),                         \
     98                          &grpc::MasterService::AsyncService::Request##method, \
     99                          &GrpcMasterService::method##Handler,                 \
    100                          (supports_cancel));                                  \
    101     }                                                                         \
    102   } while (0)
    103 
    104   void HandleRPCsLoop() override {
    105     ENQUEUE_REQUEST(CreateSession, true);
    106     ENQUEUE_REQUEST(ExtendSession, false);
    107     for (int i = 0; i < 100; ++i) {
    108       ENQUEUE_REQUEST(PartialRunSetup, false);
    109       ENQUEUE_REQUEST(RunStep, true);
    110     }
    111     ENQUEUE_REQUEST(CloseSession, false);
    112     ENQUEUE_REQUEST(ListDevices, false);
    113     ENQUEUE_REQUEST(Reset, false);
    114 
    115     void* tag;
    116     bool ok;
    117     while (cq_->Next(&tag, &ok)) {
    118       UntypedCall<GrpcMasterService>::Tag* callback_tag =
    119           static_cast<UntypedCall<GrpcMasterService>::Tag*>(tag);
    120       if (callback_tag) {
    121         callback_tag->OnCompleted(this, ok);
    122       } else {
    123         // NOTE(mrry): A null `callback_tag` indicates that this is
    124         // the shutdown alarm.
    125         cq_->Shutdown();
    126       }
    127     }
    128   }
    129 
    130  private:
    131   Master* master_impl_ = nullptr;  // Not owned.
    132   const int64 default_timeout_in_ms_;
    133   std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
    134   grpc::MasterService::AsyncService master_service_;
    135 
    136   mutex mu_;
    137   bool is_shutdown_ GUARDED_BY(mu_);
    138   ::grpc::Alarm* shutdown_alarm_ = nullptr;
    139 
    140   template <class RequestMessage, class ResponseMessage>
    141   using MasterCall = Call<GrpcMasterService, grpc::MasterService::AsyncService,
    142                           RequestMessage, ResponseMessage>;
    143 
    144   // RPC handler for creating a session.
    145   void CreateSessionHandler(
    146       MasterCall<CreateSessionRequest, CreateSessionResponse>* call) {
    147     master_impl_->CreateSession(&call->request, &call->response,
    148                                 [call](const Status& status) {
    149                                   call->SendResponse(ToGrpcStatus(status));
    150                                 });
    151     ENQUEUE_REQUEST(CreateSession, true);
    152   }
    153 
    154   // RPC handler for extending a session.
    155   void ExtendSessionHandler(
    156       MasterCall<ExtendSessionRequest, ExtendSessionResponse>* call) {
    157     master_impl_->ExtendSession(&call->request, &call->response,
    158                                 [call](const Status& status) {
    159                                   call->SendResponse(ToGrpcStatus(status));
    160                                 });
    161     ENQUEUE_REQUEST(ExtendSession, false);
    162   }
    163 
    164   // RPC handler for setting up a partial run call.
    165   void PartialRunSetupHandler(
    166       MasterCall<PartialRunSetupRequest, PartialRunSetupResponse>* call) {
    167     master_impl_->PartialRunSetup(&call->request, &call->response,
    168                                   [call](const Status& status) {
    169                                     call->SendResponse(ToGrpcStatus(status));
    170                                   });
    171     ENQUEUE_REQUEST(PartialRunSetup, false);
    172   }
    173 
    174   // RPC handler for running one step in a session.
    175   void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
    176     auto* trace = TraceRpc("RunStep/Server", call->client_metadata());
    177     CallOptions* call_opts = new CallOptions;
    178     if (call->request.options().timeout_in_ms() > 0) {
    179       call_opts->SetTimeout(call->request.options().timeout_in_ms());
    180     } else {
    181       call_opts->SetTimeout(default_timeout_in_ms_);
    182     }
    183     RunStepRequestWrapper* wrapped_request =
    184         new ProtoRunStepRequest(&call->request);
    185     MutableRunStepResponseWrapper* wrapped_response =
    186         new NonOwnedProtoRunStepResponse(&call->response);
    187     call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
    188     master_impl_->RunStep(
    189         call_opts, wrapped_request, wrapped_response,
    190         [call, call_opts, wrapped_request, wrapped_response,
    191          trace](const Status& status) {
    192           call->ClearCancelCallback();
    193           delete call_opts;
    194           delete wrapped_request;
    195           delete trace;
    196           if (call->request.store_errors_in_response_body() && !status.ok()) {
    197             call->response.set_status_code(status.code());
    198             call->response.set_status_error_message(status.error_message());
    199             call->SendResponse(ToGrpcStatus(Status::OK()));
    200           } else {
    201             call->SendResponse(ToGrpcStatus(status));
    202           }
    203         });
    204     ENQUEUE_REQUEST(RunStep, true);
    205   }
    206 
    207   // RPC handler for deleting a session.
    208   void CloseSessionHandler(
    209       MasterCall<CloseSessionRequest, CloseSessionResponse>* call) {
    210     master_impl_->CloseSession(&call->request, &call->response,
    211                                [call](const Status& status) {
    212                                  call->SendResponse(ToGrpcStatus(status));
    213                                });
    214     ENQUEUE_REQUEST(CloseSession, false);
    215   }
    216 
    217   // RPC handler for listing devices.
    218   void ListDevicesHandler(
    219       MasterCall<ListDevicesRequest, ListDevicesResponse>* call) {
    220     master_impl_->ListDevices(&call->request, &call->response,
    221                               [call](const Status& status) {
    222                                 call->SendResponse(ToGrpcStatus(status));
    223                               });
    224     ENQUEUE_REQUEST(ListDevices, false);
    225   }
    226 
    227   // RPC handler for resetting all sessions.
    228   void ResetHandler(MasterCall<ResetRequest, ResetResponse>* call) {
    229     master_impl_->Reset(&call->request, &call->response,
    230                         [call](const Status& status) {
    231                           call->SendResponse(ToGrpcStatus(status));
    232                         });
    233     ENQUEUE_REQUEST(Reset, false);
    234   }
    235 #undef ENQUEUE_REQUEST
    236 
    237   // Start tracing, including the ID attached to the RPC.
    238   port::Tracing::TraceMe* TraceRpc(
    239       StringPiece name,
    240       const std::multimap<::grpc::string_ref, ::grpc::string_ref>& metadata) {
    241     StringPiece id;
    242     auto it = metadata.find(GrpcIdKey());
    243     if (it != metadata.end()) {
    244       id = StringPiece(it->second.data(), it->second.size());
    245     }
    246     return new port::Tracing::TraceMe(name, id);
    247   }
    248 
    249   TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterService);
    250 };
    251 
    252 AsyncServiceInterface* NewGrpcMasterService(Master* master,
    253                                             int64 default_timeout_in_ms,
    254                                             ::grpc::ServerBuilder* builder) {
    255   return new GrpcMasterService(master, default_timeout_in_ms, builder);
    256 }
    257 
    258 }  // end namespace tensorflow
    259