Home | History | Annotate | Download | only in rpc
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
     17 
     18 #include <utility>
     19 
     20 #include "grpc++/generic/generic_stub.h"
     21 #include "grpc++/grpc++.h"
     22 
     23 #include "tensorflow/core/common_runtime/process_util.h"
     24 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
     25 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
     26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
     27 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
     28 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
     29 #include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
     30 #include "tensorflow/core/distributed_runtime/worker_interface.h"
     31 #include "tensorflow/core/lib/core/errors.h"
     32 #include "tensorflow/core/lib/core/status.h"
     33 #include "tensorflow/core/lib/strings/str_util.h"
     34 #include "tensorflow/core/platform/logging.h"
     35 #include "tensorflow/core/platform/tracing.h"
     36 #include "tensorflow/core/protobuf/worker.pb.h"
     37 
     38 namespace tensorflow {
     39 
     40 class GrpcRemoteWorker : public WorkerInterface {
     41  public:
     42   explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
     43                             ::grpc::CompletionQueue* completion_queue,
     44                             WorkerCacheLogger* logger)
     45       : channel_(std::move(channel)),
     46         stub_(channel_),
     47         cq_(completion_queue),
     48         getstatus_(Method(GrpcWorkerMethod::kGetStatus)),
     49         createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),
     50         deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)),
     51         registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)),
     52         deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)),
     53         rungraph_(Method(GrpcWorkerMethod::kRunGraph)),
     54         cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
     55         cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
     56         recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
     57         logging_(Method(GrpcWorkerMethod::kLogging)),
     58         tracing_(Method(GrpcWorkerMethod::kTracing)),
     59         logger_(logger) {}
     60 
     61   ~GrpcRemoteWorker() override {}
     62 
     63   void GetStatusAsync(const GetStatusRequest* request,
     64                       GetStatusResponse* response,
     65                       StatusCallback done) override {
     66     IssueRequest(request, response, getstatus_, std::move(done));
     67   }
     68 
     69   void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
     70                                 CreateWorkerSessionResponse* response,
     71                                 StatusCallback done) override {
     72     IssueRequest(request, response, createworkersession_, std::move(done));
     73   }
     74 
     75   void DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request,
     76                                 DeleteWorkerSessionResponse* response,
     77                                 StatusCallback done) override {
     78     IssueRequest(request, response, deleteworkersession_, std::move(done));
     79   }
     80 
     81   void RegisterGraphAsync(const RegisterGraphRequest* request,
     82                           RegisterGraphResponse* response,
     83                           StatusCallback done) override {
     84     IssueRequest(request, response, registergraph_, std::move(done));
     85   }
     86 
     87   void DeregisterGraphAsync(const DeregisterGraphRequest* request,
     88                             DeregisterGraphResponse* response,
     89                             StatusCallback done) override {
     90     IssueRequest(request, response, deregistergraph_, std::move(done));
     91   }
     92 
     93   void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
     94                      RunGraphResponse* response, StatusCallback done) override {
     95     IssueRequest(request, response, rungraph_, std::move(done), call_opts);
     96   }
     97   void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request,
     98                      MutableRunGraphResponseWrapper* response,
     99                      StatusCallback done) override {
    100     IssueRequest(&request->ToProto(), get_proto_from_wrapper(response),
    101                  rungraph_, std::move(done), call_opts);
    102   }
    103 
    104   void CleanupGraphAsync(const CleanupGraphRequest* request,
    105                          CleanupGraphResponse* response,
    106                          StatusCallback done) override {
    107     IssueRequest(request, response, cleanupgraph_, std::move(done));
    108   }
    109 
    110   void CleanupAllAsync(const CleanupAllRequest* request,
    111                        CleanupAllResponse* response,
    112                        StatusCallback done) override {
    113     IssueRequest(request, response, cleanupall_, std::move(done));
    114   }
    115 
    116   void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request,
    117                        TensorResponse* response, StatusCallback done) override {
    118     VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
    119     int64 start_usec = Env::Default()->NowMicros();
    120     // Type-specialized logging for this method.
    121     bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
    122     StatusCallback wrapper_done;
    123     const StatusCallback* cb_to_use;
    124     if (!logging_active) {
    125       cb_to_use = &done;  // No additional work to do, so just use done directly
    126     } else {
    127       wrapper_done = [this, request, response, done, start_usec](Status s) {
    128         if (logger_->LoggingActive()) {
    129           int64 end_usec = Env::Default()->NowMicros();
    130           int64 step_id = request->step_id();
    131           int64 bytes = response->tensor().TotalBytes();
    132           int64 send_start_usec = start_usec;
    133           // If a send start time was reported by the other side, use
    134           // that instead.  Maybe we should mark the display if we're using
    135           // our local time instead of the remote start time?
    136           if (response->metadata().send_start_micros()) {
    137             // send_start_micros is the timestamp taken when the
    138             // remote machine began to send the RecvTensor response.
    139             // Due to clock skew between source and dest machines, it
    140             // is possible that send_start_micros can be larger than
    141             // end_usec or less than start_usec.
    142             //
    143             // To respect causality, we enforce the invariants that
    144             // the RecvTensor response can not have been sent before
    145             // the RecvTensor request, and must have been sent before
    146             // it was received.
    147             send_start_usec = std::max(
    148                 start_usec,
    149                 static_cast<int64>(response->metadata().send_start_micros()));
    150             send_start_usec = std::min(send_start_usec, end_usec - 1);
    151           }
    152           const string& key = request->rendezvous_key();
    153           std::vector<string> key_parts = str_util::Split(key, ';');
    154           if (key_parts.size() != 5) {
    155             LOG(WARNING) << "Bad key: " << key;
    156           } else {
    157             logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
    158                                       key_parts[3],  // tensor name
    159                                       key_parts[0],  // src_device
    160                                       key_parts[2],  // dst_device
    161                                       bytes);
    162           }
    163         }
    164         VLOG(2) << "done callback, req: " << request->DebugString()
    165                 << " response " << response->metadata().DebugString();
    166         done(s);
    167       };
    168       cb_to_use = &wrapper_done;
    169     }
    170 
    171     IssueRequest(request, response, recvtensor_, *cb_to_use, call_opts);
    172   }
    173 
    174   void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
    175                     StatusCallback done) override {
    176     IssueRequest(request, response, logging_, done);
    177   }
    178 
    179   void TracingAsync(const TracingRequest* request, TracingResponse* response,
    180                     StatusCallback done) override {
    181     IssueRequest(request, response, tracing_, done);
    182   }
    183 
    184  private:
    185   // Utility method for issuing a generic asynchronous request. The
    186   // given callback, `done`, will be called when the RPC completes.
    187   void IssueRequest(const protobuf::Message* request,
    188                     protobuf::Message* response, const ::grpc::string& method,
    189                     StatusCallback done, CallOptions* call_opts = nullptr) {
    190     new RPCState<protobuf::Message>(&stub_, cq_, method, *request, response,
    191                                     std::move(done), call_opts);
    192   }
    193   void IssueRequest(const protobuf::Message* request, TensorResponse* response,
    194                     const ::grpc::string& method, StatusCallback done,
    195                     CallOptions* call_opts = nullptr) {
    196     new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
    197                                  std::move(done), call_opts);
    198   }
    199 
    200   // Helper function for initializing the RpcMethod objects below.
    201   const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }
    202 
    203   SharedGrpcChannelPtr channel_;
    204   ::grpc::GenericStub stub_;
    205   ::grpc::CompletionQueue* cq_;
    206 
    207   const ::grpc::string getstatus_;
    208   const ::grpc::string createworkersession_;
    209   const ::grpc::string deleteworkersession_;
    210   const ::grpc::string registergraph_;
    211   const ::grpc::string deregistergraph_;
    212   const ::grpc::string rungraph_;
    213   const ::grpc::string cleanupgraph_;
    214   const ::grpc::string cleanupall_;
    215   const ::grpc::string recvtensor_;
    216   const ::grpc::string logging_;
    217   const ::grpc::string tracing_;
    218 
    219   // Support for logging.
    220   WorkerCacheLogger* logger_;
    221 
    222   TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
    223 };
    224 
    225 WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
    226                                      ::grpc::CompletionQueue* completion_queue,
    227                                      WorkerCacheLogger* logger) {
    228   return new GrpcRemoteWorker(std::move(channel), completion_queue, logger);
    229 }
    230 
    231 }  // namespace tensorflow
    232