/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"

#include <unordered_map>

#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/master.pb.h"

namespace tensorflow {

const char* const kSchemePrefix = "grpc://";
const size_t kSchemePrefixLength = strlen(kSchemePrefix);

GrpcSession::GrpcSession(const SessionOptions& options)
    : options_(options), current_graph_version_(-1) {}

GrpcSession::~GrpcSession() {}

/* static */
Status GrpcSession::Create(const SessionOptions& options,
                           std::unique_ptr<GrpcSession>* out_session) {
  std::unique_ptr<GrpcSession> session(new GrpcSession(options));
  std::unique_ptr<MasterInterface> master;
  // For testing, the client can disable the use of the local master
  // registry so that the RPC stack is exercised.
  if (!options.config.rpc_options().use_rpc_for_inprocess_master()) {
    master = LocalMaster::Lookup(options.target);
  }
  if (!master) {
    SharedGrpcChannelPtr master_channel;
    TF_RETURN_IF_ERROR(NewHostPortGrpcChannel(
        options.target.substr(kSchemePrefixLength), &master_channel));
    master.reset(NewGrpcMaster(master_channel));
  }
  session->SetRemoteMaster(std::move(master));
  *out_session = std::move(session);
  return Status::OK();
}

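// Example (a minimal usage sketch, not part of the original file; the target
// address and `graph_def` are hypothetical):
//
//   SessionOptions options;
//   options.target = "grpc://localhost:2222";
//   std::unique_ptr<GrpcSession> session;
//   TF_CHECK_OK(GrpcSession::Create(options, &session));
//   TF_CHECK_OK(session->Create(graph_def));  // Registers the graph.
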
namespace {
// Re-encodes constants represented in the tensor proto into
// tensor_content, which is slightly better (fewer copies and lower peak
// memory usage) when used with RPC subsystems.
void ReEncodeConsts(GraphDef* gdef) {
  for (NodeDef& ndef : *(gdef->mutable_node())) {
    if (ndef.op() == "Const") {
      TensorProto* proto = nullptr;
      for (auto& attr : *ndef.mutable_attr()) {
        if (attr.first == "value") {
          proto = attr.second.mutable_tensor();
        }
      }
      if (proto != nullptr && proto->tensor_content().empty() &&
          proto->ByteSizeLong() > 64) {
        // If the constant is encoded with repeated proto fields and
        // it is moderately large, we re-encode it in tensor_content as
        // a Cord. This is mildly helpful for reducing the peak memory
        // usage on the server side, where GraphDef/NodeDef are copied
        // quite often.
        Tensor parsed(proto->dtype());
        if (parsed.FromProto(*proto)) {
          parsed.AsProtoTensorContent(proto);
        }
      }
    }
  }
}
}  // namespace

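// Sends a CreateSession RPC carrying the session config and `graph`, then
// records the returned session handle and initial graph version. Fails if
// this GrpcSession already owns a live handle.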
Status GrpcSession::CreateImpl(CallOptions* call_options,
                               const GraphDef& graph) {
  {
    mutex_lock l(mu_);
    if (!handle_.empty()) {
      return errors::InvalidArgument("A session is alive.");
    }
  }
  CreateSessionRequest req;
  *req.mutable_config() = options_.config;
  *req.mutable_graph_def() = graph;
  req.set_target(options_.target);
  ReEncodeConsts(req.mutable_graph_def());
  CreateSessionResponse resp;
  Status s = master_->CreateSession(call_options, &req, &resp);
  if (s.ok()) {
    mutex_lock l(mu_);
    swap(handle_, *(resp.mutable_session_handle()));
    current_graph_version_ = resp.graph_version();
  }
  return s;
}

Status GrpcSession::Create(const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return CreateImpl(&call_options, graph);
}

Status GrpcSession::Create(const RunOptions& run_options,
                           const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return CreateImpl(&call_options, graph);
}

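// Sends an ExtendSession RPC to append `graph` to the graph registered with
// the master, advancing current_graph_version_ on success. If the session
// has not been created yet, this degenerates to Create(graph).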
Status GrpcSession::ExtendImpl(CallOptions* call_options,
                               const GraphDef& graph) {
  bool handle_is_empty;
  {
    mutex_lock l(mu_);
    handle_is_empty = handle_.empty();
  }
  if (handle_is_empty) {
    // Session was uninitialized, so simply initialize the session with
    // 'graph'.
    return Create(graph);
  }
  mutex_lock l(mu_);
  ExtendSessionRequest req;
  req.set_session_handle(handle_);
  *req.mutable_graph_def() = graph;
  req.set_current_graph_version(current_graph_version_);
  ExtendSessionResponse resp;
  Status s = master_->ExtendSession(call_options, &req, &resp);
  if (s.ok()) {
    current_graph_version_ = resp.new_graph_version();
  }
  return s;
}

Status GrpcSession::Extend(const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return ExtendImpl(&call_options, graph);
}

Status GrpcSession::Extend(const RunOptions& run_options,
                           const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return ExtendImpl(&call_options, graph);
}

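// Shared implementation of Run() and PRun(): builds a RunStepRequest from
// the feeds, fetches, and targets, issues the RPC, and unpacks the fetched
// tensors into `outputs` in the order given by `output_tensor_names`. A
// non-empty `prun_handle` marks the step as part of a partial run.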
Status GrpcSession::RunHelper(
    const RunOptions& run_options,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& output_tensor_names,
    const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata, const string& prun_handle) {
  // Convert to proto
  std::unique_ptr<MutableRunStepRequestWrapper> req(
      master_->CreateRunStepRequest());
  std::unique_ptr<MutableRunStepResponseWrapper> resp(
      master_->CreateRunStepResponse());

  *req->mutable_options() = run_options;

  if (run_options.timeout_in_ms() == 0) {
    req->mutable_options()->set_timeout_in_ms(
        options_.config.operation_timeout_in_ms());
  }

  if (!prun_handle.empty()) {
    req->set_partial_run_handle(prun_handle);
  }

  for (const auto& it : inputs) {
    req->add_feed(it.first, it.second);
  }

  // Support long error messages by storing the error code in the response
  // body.
  req->set_store_errors_in_response_body(true);

  // Build an index from fetch tensor name to first index in
  // output_tensor_names.
  std::unordered_map<string, int> output_name_to_offset;
  for (int i = 0; i < output_tensor_names.size(); ++i) {
    const string& name = output_tensor_names[i];
    if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
      req->add_fetch(name);
    }
  }
  for (const string& target : target_node_names) {
    req->add_target(target);
  }

  CallOptions call_options;
  call_options.SetTimeout(req->options().timeout_in_ms());
  TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), resp.get()));

  // Look for an extended error returned in the response body.
  if (resp->status_code() != error::Code::OK) {
    return Status(resp->status_code(), resp->status_error_message());
  }

  if (!output_tensor_names.empty()) {
    outputs->resize(output_tensor_names.size());
  }

  // Convert response back to Tensors in the correct order.
  for (size_t i = 0; i < resp->num_tensors(); ++i) {
    auto fetch_it = output_name_to_offset.find(resp->tensor_name(i));
    if (fetch_it == output_name_to_offset.end()) {
      return errors::Internal("Received response for unrequested fetch: ",
                              resp->tensor_name(i));
    }

    Tensor output;
    TF_RETURN_IF_ERROR(resp->TensorValue(i, &output));
    (*outputs)[fetch_it->second] = output;
  }
  // In the unlikely event that output_tensor_names contains duplicates, fill
  // in the duplicate values.
  if (output_name_to_offset.size() != output_tensor_names.size()) {
    for (int i = 0; i < output_tensor_names.size(); ++i) {
      const string& name = output_tensor_names[i];
      int offset = output_name_to_offset[name];
      if (offset != i) {
        (*outputs)[i] = (*outputs)[offset];
      }
    }
  }

  if (run_metadata) {
    run_metadata->Swap(resp->mutable_metadata());
  }

  return Status::OK();
}

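// Feeds `inputs`, runs `target_node_names`, and fetches
// `output_tensor_names` into `outputs` for a single step. The overload
// without RunOptions falls back to the session-wide operation timeout.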
Status GrpcSession::Run(const RunOptions& run_options,
                        const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs,
                        RunMetadata* run_metadata) {
  return RunHelper(run_options, inputs, output_tensor_names, target_node_names,
                   outputs, run_metadata, /* prun_handle */ "");
}

Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return Run(run_options, inputs, output_tensor_names, target_node_names,
             outputs, nullptr);
}

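// Attaches the session handle to `req` under the lock, then forwards the
// request to the master's RunStep RPC.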
Status GrpcSession::RunProto(CallOptions* call_options,
                             MutableRunStepRequestWrapper* req,
                             MutableRunStepResponseWrapper* resp) {
  {
    mutex_lock l(mu_);
    if (handle_.empty()) {
      return errors::InvalidArgument("A session is not created yet.");
    }

    req->set_session_handle(handle_);
  }
  return master_->RunStep(call_options, req, resp);
}

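// Registers a partial run with the master: declares the full set of feeds,
// fetches, and targets up front and returns a handle that subsequent PRun()
// calls use to feed and fetch incrementally.
//
// Example (a minimal sketch, not part of the original file; tensor names and
// `tensor_a` are hypothetical):
//
//   string prun_handle;
//   TF_CHECK_OK(session->PRunSetup({"a:0"}, {"c:0"}, {}, &prun_handle));
//   std::vector<Tensor> outputs;
//   TF_CHECK_OK(
//       session->PRun(prun_handle, {{"a:0", tensor_a}}, {"c:0"}, &outputs));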
Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
                              const std::vector<string>& output_names,
                              const std::vector<string>& target_nodes,
                              string* handle) {
  // Convert to proto
  PartialRunSetupRequest req;
  PartialRunSetupResponse resp;
  CallOptions call_options;
  {
    mutex_lock l(mu_);
    if (handle_.empty()) {
      return errors::InvalidArgument("A session is not created yet.");
    }

    req.set_session_handle(handle_);
  }
  for (const string& feed : input_names) {
    req.add_feed(feed);
  }
  for (const string& fetch : output_names) {
    req.add_fetch(fetch);
  }
  for (const string& target : target_nodes) {
    req.add_target(target);
  }
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
  *handle = resp.partial_run_handle();
  return Status::OK();
}

Status GrpcSession::PRun(const string& handle,
                         const std::vector<std::pair<string, Tensor>>& inputs,
                         const std::vector<string>& output_names,
                         std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return RunHelper(run_options, inputs, output_names, /* targets */ {}, outputs,
                   /* run_metadata */ nullptr, handle);
}

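// Closes the session on the master and clears the local handle. Closing an
// uncreated (or already closed) session is a no-op.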
Status GrpcSession::Close() {
  CloseSessionRequest req;
  {
    mutex_lock l(mu_);
    if (handle_.empty()) {
      return Status::OK();
    }
    req.set_session_handle(handle_);
    handle_.clear();
  }
  CloseSessionResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return master_->CloseSession(&call_options, &req, &resp);
}

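// Fills `response` with the attributes of all local and remote devices known
// to the master. If the session has not been created yet, it is first
// initialized with an empty graph (see the warning below).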
Status GrpcSession::ListDevices(std::vector<DeviceAttributes>* response) {
  ListDevicesRequest req;
  {
    mutex_lock l(mu_);
    req.set_session_handle(handle_);
  }
  if (req.session_handle().empty()) {
    LOG(WARNING) << "GrpcSession::ListDevices will initialize the session with "
                    "an empty graph and other defaults because the session has "
                    "not yet been created.";
    GraphDef graph_def;
    TF_RETURN_IF_ERROR(Create(graph_def));
    {
      mutex_lock l(mu_);
      req.set_session_handle(handle_);
    }
  }
  ListDevicesResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  Status s = master_->ListDevices(&call_options, &req, &resp);
  if (!s.ok()) {
    LOG(ERROR) << "Could not list devices: " << s;
    return s;
  }

  response->clear();
  response->reserve(resp.local_device_size() + resp.remote_device_size());
  for (const auto& device_attr : resp.local_device()) {
    response->emplace_back(device_attr);
  }
  for (const auto& device_attr : resp.remote_device()) {
    response->emplace_back(device_attr);
  }
  return Status::OK();
}

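// Takes ownership of the master implementation that all subsequent RPCs go
// through; Create() uses this to install either the in-process master or a
// gRPC-backed one.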
void GrpcSession::SetRemoteMaster(std::unique_ptr<MasterInterface> master) {
  master_ = std::move(master);
}

// Static method. Resets the given resource containers on the target named in
// `options`, over a freshly created gRPC channel; no existing session is
// required.
Status GrpcSession::Reset(const SessionOptions& options,
                          const std::vector<string>& containers) {
  SharedGrpcChannelPtr master_channel;
  TF_RETURN_IF_ERROR(NewHostPortGrpcChannel(
      options.target.substr(kSchemePrefixLength), &master_channel));
  auto master = NewGrpcMaster(master_channel);
  ResetRequest req;
  for (const auto& c : containers) req.add_container(c);
  ResetResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options.config.operation_timeout_in_ms());
  Status ret = master->Reset(&call_options, &req, &resp);
  delete master;
  return ret;
}

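// A SessionFactory that makes GrpcSession the session implementation for any
// target starting with "grpc://".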
class GrpcSessionFactory : public SessionFactory {
 public:
  bool AcceptsOptions(const SessionOptions& options) override {
    return StringPiece(options.target).starts_with(kSchemePrefix);
  }

  Session* NewSession(const SessionOptions& options) override {
    std::unique_ptr<GrpcSession> ret;
    Status s = GrpcSession::Create(options, &ret);
    if (s.ok()) {
      return ret.release();
    } else {
      LOG(ERROR) << "Error during session construction: " << s.ToString();
      return nullptr;
    }
  }

  // Invokes the session-specific static method to reset containers.
  Status Reset(const SessionOptions& options,
               const std::vector<string>& containers) override {
    return GrpcSession::Reset(options, containers);
  }
};

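// Registers the factory under "GRPC_SESSION" at static-initialization time,
// so linking in this file is enough to make grpc:// targets usable through
// the generic NewSession() entry point.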
class GrpcSessionRegistrar {
 public:
  GrpcSessionRegistrar() {
    SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
  }
};
static GrpcSessionRegistrar registrar;

}  // namespace tensorflow
    435