1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h" 17 18 #include <utility> 19 20 #include "grpc++/generic/generic_stub.h" 21 #include "grpc++/grpc++.h" 22 23 #include "tensorflow/core/common_runtime/process_util.h" 24 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h" 25 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h" 26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" 27 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h" 28 #include "tensorflow/core/distributed_runtime/tensor_coding.h" 29 #include "tensorflow/core/distributed_runtime/worker_cache_logger.h" 30 #include "tensorflow/core/distributed_runtime/worker_interface.h" 31 #include "tensorflow/core/lib/core/errors.h" 32 #include "tensorflow/core/lib/core/status.h" 33 #include "tensorflow/core/lib/strings/str_util.h" 34 #include "tensorflow/core/platform/logging.h" 35 #include "tensorflow/core/platform/tracing.h" 36 #include "tensorflow/core/protobuf/worker.pb.h" 37 38 namespace tensorflow { 39 40 class GrpcRemoteWorker : public WorkerInterface { 41 public: 42 explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel, 43 ::grpc::CompletionQueue* completion_queue, 44 WorkerCacheLogger* logger) 45 : channel_(std::move(channel)), 46 stub_(channel_), 47 cq_(completion_queue), 48 getstatus_(Method(GrpcWorkerMethod::kGetStatus)), 49 createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)), 50 deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)), 51 registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)), 52 deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)), 53 rungraph_(Method(GrpcWorkerMethod::kRunGraph)), 54 cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)), 55 cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)), 56 recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)), 57 logging_(Method(GrpcWorkerMethod::kLogging)), 58 tracing_(Method(GrpcWorkerMethod::kTracing)), 59 logger_(logger) {} 60 61 ~GrpcRemoteWorker() override {} 62 63 void GetStatusAsync(const GetStatusRequest* request, 64 GetStatusResponse* response, 65 StatusCallback done) override { 66 IssueRequest(request, response, getstatus_, std::move(done)); 67 } 68 69 void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request, 70 CreateWorkerSessionResponse* response, 71 StatusCallback done) override { 72 IssueRequest(request, response, createworkersession_, std::move(done)); 73 } 74 75 void DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request, 76 DeleteWorkerSessionResponse* response, 77 StatusCallback done) override { 78 IssueRequest(request, response, deleteworkersession_, std::move(done)); 79 } 80 81 void RegisterGraphAsync(const RegisterGraphRequest* request, 82 RegisterGraphResponse* response, 83 StatusCallback done) override { 84 IssueRequest(request, response, registergraph_, std::move(done)); 85 } 86 87 void DeregisterGraphAsync(const DeregisterGraphRequest* request, 88 DeregisterGraphResponse* response, 89 StatusCallback done) override { 90 IssueRequest(request, response, deregistergraph_, std::move(done)); 91 } 92 93 void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request, 94 RunGraphResponse* response, StatusCallback done) override { 95 IssueRequest(request, response, rungraph_, std::move(done), call_opts); 96 } 97 void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request, 98 MutableRunGraphResponseWrapper* response, 99 StatusCallback done) override { 100 IssueRequest(&request->ToProto(), get_proto_from_wrapper(response), 101 rungraph_, std::move(done), call_opts); 102 } 103 104 void CleanupGraphAsync(const CleanupGraphRequest* request, 105 CleanupGraphResponse* response, 106 StatusCallback done) override { 107 IssueRequest(request, response, cleanupgraph_, std::move(done)); 108 } 109 110 void CleanupAllAsync(const CleanupAllRequest* request, 111 CleanupAllResponse* response, 112 StatusCallback done) override { 113 IssueRequest(request, response, cleanupall_, std::move(done)); 114 } 115 116 void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request, 117 TensorResponse* response, StatusCallback done) override { 118 VLOG(1) << "RecvTensorAsync req: " << request->DebugString(); 119 int64 start_usec = Env::Default()->NowMicros(); 120 // Type-specialized logging for this method. 121 bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2); 122 StatusCallback wrapper_done; 123 const StatusCallback* cb_to_use; 124 if (!logging_active) { 125 cb_to_use = &done; // No additional work to do, so just use done directly 126 } else { 127 wrapper_done = [this, request, response, done, start_usec](Status s) { 128 if (logger_->LoggingActive()) { 129 int64 end_usec = Env::Default()->NowMicros(); 130 int64 step_id = request->step_id(); 131 int64 bytes = response->tensor().TotalBytes(); 132 int64 send_start_usec = start_usec; 133 // If a send start time was reported by the other side, use 134 // that instead. Maybe we should mark the display if we're using 135 // our local time instead of the remote start time? 136 if (response->metadata().send_start_micros()) { 137 // send_start_micros is the timestamp taken when the 138 // remote machine began to send the RecvTensor response. 139 // Due to clock skew between source and dest machines, it 140 // is possible that send_start_micros can be larger than 141 // end_usec or less than start_usec. 142 // 143 // To respect causality, we enforce the invariants that 144 // the RecvTensor response can not have been sent before 145 // the RecvTensor request, and must have been sent before 146 // it was received. 147 send_start_usec = std::max( 148 start_usec, 149 static_cast<int64>(response->metadata().send_start_micros())); 150 send_start_usec = std::min(send_start_usec, end_usec - 1); 151 } 152 const string& key = request->rendezvous_key(); 153 std::vector<string> key_parts = str_util::Split(key, ';'); 154 if (key_parts.size() != 5) { 155 LOG(WARNING) << "Bad key: " << key; 156 } else { 157 logger_->RecordRecvTensor(step_id, send_start_usec, end_usec, 158 key_parts[3], // tensor name 159 key_parts[0], // src_device 160 key_parts[2], // dst_device 161 bytes); 162 } 163 } 164 VLOG(2) << "done callback, req: " << request->DebugString() 165 << " response " << response->metadata().DebugString(); 166 done(s); 167 }; 168 cb_to_use = &wrapper_done; 169 } 170 171 IssueRequest(request, response, recvtensor_, *cb_to_use, call_opts); 172 } 173 174 void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, 175 StatusCallback done) override { 176 IssueRequest(request, response, logging_, done); 177 } 178 179 void TracingAsync(const TracingRequest* request, TracingResponse* response, 180 StatusCallback done) override { 181 IssueRequest(request, response, tracing_, done); 182 } 183 184 private: 185 // Utility method for issuing a generic asynchronous request. The 186 // given callback, `done`, will be called when the RPC completes. 187 void IssueRequest(const protobuf::Message* request, 188 protobuf::Message* response, const ::grpc::string& method, 189 StatusCallback done, CallOptions* call_opts = nullptr) { 190 new RPCState<protobuf::Message>(&stub_, cq_, method, *request, response, 191 std::move(done), call_opts); 192 } 193 void IssueRequest(const protobuf::Message* request, TensorResponse* response, 194 const ::grpc::string& method, StatusCallback done, 195 CallOptions* call_opts = nullptr) { 196 new RPCState<TensorResponse>(&stub_, cq_, method, *request, response, 197 std::move(done), call_opts); 198 } 199 200 // Helper function for initializing the RpcMethod objects below. 201 const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); } 202 203 SharedGrpcChannelPtr channel_; 204 ::grpc::GenericStub stub_; 205 ::grpc::CompletionQueue* cq_; 206 207 const ::grpc::string getstatus_; 208 const ::grpc::string createworkersession_; 209 const ::grpc::string deleteworkersession_; 210 const ::grpc::string registergraph_; 211 const ::grpc::string deregistergraph_; 212 const ::grpc::string rungraph_; 213 const ::grpc::string cleanupgraph_; 214 const ::grpc::string cleanupall_; 215 const ::grpc::string recvtensor_; 216 const ::grpc::string logging_; 217 const ::grpc::string tracing_; 218 219 // Support for logging. 220 WorkerCacheLogger* logger_; 221 222 TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker); 223 }; 224 225 WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel, 226 ::grpc::CompletionQueue* completion_queue, 227 WorkerCacheLogger* logger) { 228 return new GrpcRemoteWorker(std::move(channel), completion_queue, logger); 229 } 230 231 } // namespace tensorflow 232