/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// GrpcMasterService implements the RPC service MasterService.
//
// A GrpcMasterService maintains the state of live graph computation
// sessions, each session orchestrates both local and remote devices
// to carry out the graph computation.
//
// A GrpcMasterService knows ahead of time which local devices are
// available as client devices.
//
// A GrpcMasterService discovers remote devices in the background and
// keeps track of statistics of those remote devices.
//
// Each session analyzes the graph, places nodes across available
// devices, and ultimately drives the graph computation by initiating
// RunGraph on workers.
31 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h" 32 33 #include "grpc++/alarm.h" 34 #include "grpc++/server_builder.h" 35 36 #include "tensorflow/core/distributed_runtime/master.h" 37 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" 38 #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h" 39 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h" 40 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" 41 #include "tensorflow/core/platform/logging.h" 42 #include "tensorflow/core/platform/macros.h" 43 #include "tensorflow/core/platform/tracing.h" 44 #include "tensorflow/core/protobuf/master.pb.h" 45 46 namespace tensorflow { 47 48 class GrpcMasterService : public AsyncServiceInterface { 49 public: 50 GrpcMasterService(Master* master, int64 default_timeout_in_ms, 51 ::grpc::ServerBuilder* builder) 52 : master_impl_(master), 53 default_timeout_in_ms_(default_timeout_in_ms), 54 is_shutdown_(false) { 55 builder->RegisterService(&master_service_); 56 cq_ = builder->AddCompletionQueue(); 57 } 58 59 ~GrpcMasterService() override { delete shutdown_alarm_; } 60 61 void Shutdown() override { 62 bool did_shutdown = false; 63 { 64 mutex_lock l(mu_); 65 if (!is_shutdown_) { 66 LOG(INFO) << "Shutting down GrpcMasterService."; 67 is_shutdown_ = true; 68 did_shutdown = true; 69 } 70 } 71 if (did_shutdown) { 72 // NOTE(mrry): This enqueues a special event (with a null tag) 73 // that causes the completion queue to be shut down on the 74 // polling thread. 75 shutdown_alarm_ = 76 new ::grpc::Alarm(cq_.get(), gpr_now(GPR_CLOCK_MONOTONIC), nullptr); 77 } 78 } 79 80 // This macro creates a new request for the given RPC method name 81 // (e.g., `ENQUEUE_REQUEST(RunStep);`), and enqueues it on 82 // `this->cq_`. 
83 // 84 // This macro is invoked one or more times for each RPC method to 85 // ensure that there are sufficient completion queue entries to 86 // handle incoming requests without blocking. 87 // 88 // The implementation of the request handler for each RPC method 89 // must ensure that it calls ENQUEUE_REQUEST() for that RPC method, 90 // to keep accepting new requests. 91 #define ENQUEUE_REQUEST(method, supports_cancel) \ 92 do { \ 93 mutex_lock l(mu_); \ 94 if (!is_shutdown_) { \ 95 Call<GrpcMasterService, grpc::MasterService::AsyncService, \ 96 method##Request, method##Response>:: \ 97 EnqueueRequest(&master_service_, cq_.get(), \ 98 &grpc::MasterService::AsyncService::Request##method, \ 99 &GrpcMasterService::method##Handler, \ 100 (supports_cancel)); \ 101 } \ 102 } while (0) 103 104 void HandleRPCsLoop() override { 105 ENQUEUE_REQUEST(CreateSession, true); 106 ENQUEUE_REQUEST(ExtendSession, false); 107 for (int i = 0; i < 100; ++i) { 108 ENQUEUE_REQUEST(PartialRunSetup, false); 109 ENQUEUE_REQUEST(RunStep, true); 110 } 111 ENQUEUE_REQUEST(CloseSession, false); 112 ENQUEUE_REQUEST(ListDevices, false); 113 ENQUEUE_REQUEST(Reset, false); 114 115 void* tag; 116 bool ok; 117 while (cq_->Next(&tag, &ok)) { 118 UntypedCall<GrpcMasterService>::Tag* callback_tag = 119 static_cast<UntypedCall<GrpcMasterService>::Tag*>(tag); 120 if (callback_tag) { 121 callback_tag->OnCompleted(this, ok); 122 } else { 123 // NOTE(mrry): A null `callback_tag` indicates that this is 124 // the shutdown alarm. 125 cq_->Shutdown(); 126 } 127 } 128 } 129 130 private: 131 Master* master_impl_ = nullptr; // Not owned. 
132 const int64 default_timeout_in_ms_; 133 std::unique_ptr<::grpc::ServerCompletionQueue> cq_; 134 grpc::MasterService::AsyncService master_service_; 135 136 mutex mu_; 137 bool is_shutdown_ GUARDED_BY(mu_); 138 ::grpc::Alarm* shutdown_alarm_ = nullptr; 139 140 template <class RequestMessage, class ResponseMessage> 141 using MasterCall = Call<GrpcMasterService, grpc::MasterService::AsyncService, 142 RequestMessage, ResponseMessage>; 143 144 // RPC handler for creating a session. 145 void CreateSessionHandler( 146 MasterCall<CreateSessionRequest, CreateSessionResponse>* call) { 147 master_impl_->CreateSession(&call->request, &call->response, 148 [call](const Status& status) { 149 call->SendResponse(ToGrpcStatus(status)); 150 }); 151 ENQUEUE_REQUEST(CreateSession, true); 152 } 153 154 // RPC handler for extending a session. 155 void ExtendSessionHandler( 156 MasterCall<ExtendSessionRequest, ExtendSessionResponse>* call) { 157 master_impl_->ExtendSession(&call->request, &call->response, 158 [call](const Status& status) { 159 call->SendResponse(ToGrpcStatus(status)); 160 }); 161 ENQUEUE_REQUEST(ExtendSession, false); 162 } 163 164 // RPC handler for setting up a partial run call. 165 void PartialRunSetupHandler( 166 MasterCall<PartialRunSetupRequest, PartialRunSetupResponse>* call) { 167 master_impl_->PartialRunSetup(&call->request, &call->response, 168 [call](const Status& status) { 169 call->SendResponse(ToGrpcStatus(status)); 170 }); 171 ENQUEUE_REQUEST(PartialRunSetup, false); 172 } 173 174 // RPC handler for running one step in a session. 
175 void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) { 176 auto* trace = TraceRpc("RunStep/Server", call->client_metadata()); 177 CallOptions* call_opts = new CallOptions; 178 if (call->request.options().timeout_in_ms() > 0) { 179 call_opts->SetTimeout(call->request.options().timeout_in_ms()); 180 } else { 181 call_opts->SetTimeout(default_timeout_in_ms_); 182 } 183 RunStepRequestWrapper* wrapped_request = 184 new ProtoRunStepRequest(&call->request); 185 MutableRunStepResponseWrapper* wrapped_response = 186 new NonOwnedProtoRunStepResponse(&call->response); 187 call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); 188 master_impl_->RunStep( 189 call_opts, wrapped_request, wrapped_response, 190 [call, call_opts, wrapped_request, wrapped_response, 191 trace](const Status& status) { 192 call->ClearCancelCallback(); 193 delete call_opts; 194 delete wrapped_request; 195 delete trace; 196 if (call->request.store_errors_in_response_body() && !status.ok()) { 197 call->response.set_status_code(status.code()); 198 call->response.set_status_error_message(status.error_message()); 199 call->SendResponse(ToGrpcStatus(Status::OK())); 200 } else { 201 call->SendResponse(ToGrpcStatus(status)); 202 } 203 }); 204 ENQUEUE_REQUEST(RunStep, true); 205 } 206 207 // RPC handler for deleting a session. 208 void CloseSessionHandler( 209 MasterCall<CloseSessionRequest, CloseSessionResponse>* call) { 210 master_impl_->CloseSession(&call->request, &call->response, 211 [call](const Status& status) { 212 call->SendResponse(ToGrpcStatus(status)); 213 }); 214 ENQUEUE_REQUEST(CloseSession, false); 215 } 216 217 // RPC handler for listing devices. 
218 void ListDevicesHandler( 219 MasterCall<ListDevicesRequest, ListDevicesResponse>* call) { 220 master_impl_->ListDevices(&call->request, &call->response, 221 [call](const Status& status) { 222 call->SendResponse(ToGrpcStatus(status)); 223 }); 224 ENQUEUE_REQUEST(ListDevices, false); 225 } 226 227 // RPC handler for resetting all sessions. 228 void ResetHandler(MasterCall<ResetRequest, ResetResponse>* call) { 229 master_impl_->Reset(&call->request, &call->response, 230 [call](const Status& status) { 231 call->SendResponse(ToGrpcStatus(status)); 232 }); 233 ENQUEUE_REQUEST(Reset, false); 234 } 235 #undef ENQUEUE_REQUEST 236 237 // Start tracing, including the ID attached to the RPC. 238 port::Tracing::TraceMe* TraceRpc( 239 StringPiece name, 240 const std::multimap<::grpc::string_ref, ::grpc::string_ref>& metadata) { 241 StringPiece id; 242 auto it = metadata.find(GrpcIdKey()); 243 if (it != metadata.end()) { 244 id = StringPiece(it->second.data(), it->second.size()); 245 } 246 return new port::Tracing::TraceMe(name, id); 247 } 248 249 TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterService); 250 }; 251 252 AsyncServiceInterface* NewGrpcMasterService(Master* master, 253 int64 default_timeout_in_ms, 254 ::grpc::ServerBuilder* builder) { 255 return new GrpcMasterService(master, default_timeout_in_ms, builder); 256 } 257 258 } // end namespace tensorflow 259