/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"

#include <map>

#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"

namespace tensorflow {

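// Builds a standalone GraphDef that wraps the function `sig` so it can be run
// on a remote worker: a client-terminated _Recv node feeds each input
// argument, a single call node invokes the function, and a _Send node returns
// each output. The rendezvous keys used to feed (`send_keys`) and fetch
// (`recv_keys`) tensors are returned to the caller.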
/* static */
Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
    const OpDef& sig, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
    std::vector<string>* send_keys, std::vector<string>* recv_keys) {
  const string& target = options.target;
  // Construct recv nodes for each input argument.
  int i = 0;
  for (const auto& in : sig.input_arg()) {
    // Resolve the input type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, in, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Input arg: ", in.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    NodeDef* input_node = g->add_node();
    TF_RETURN_IF_ERROR(
        NodeDefBuilder(strings::StrCat("_recv_", in.name(), "_", i), "_Recv")
            .Attr("tensor_type", dtypes[0])
            .Attr("tensor_name", in.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target)
            .Finalize(input_node));
    // src_incarnation = 1 works because the transfer is across the same device.
    // TODO(rohanj): Find the src_incarnation for the remote device and set it.
    const string& key = Rendezvous::CreateKey(
        target, 1 /* src_incarnation */, target, in.name(), FrameAndIter(0, 0));
    send_keys->push_back(key);
    ++i;
  }

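  // Add the call node itself: it is named after the function, consumes the
  // _Recv outputs as its inputs, and carries the instantiation attrs.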
  NodeDef* function_node = g->add_node();
  function_node->set_name(sig.name());
  function_node->set_op(sig.name());
  i = 0;
  for (const auto& in : sig.input_arg()) {
    function_node->add_input(strings::StrCat("_recv_", in.name(), "_", i));
    ++i;
  }
  function_node->set_device(target);
  for (const auto& p : attrs) {
    (*function_node->mutable_attr())[p.first] = p.second;
  }

  // Construct output nodes for each output.
  i = 0;
  for (const auto& out : sig.output_arg()) {
    // Resolve the output type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, out, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Output arg: ", out.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    NodeDef* output_node = g->add_node();
    TF_RETURN_IF_ERROR(
        NodeDefBuilder(strings::StrCat("_send_", out.name(), "_", i), "_Send")
            .Input(sig.name(), i, dtypes[0])
            .Attr("tensor_name", out.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target)
            .Finalize(output_node));
    const string& key =
        Rendezvous::CreateKey(target, 1 /* src_incarnation */, target,
                              out.name(), FrameAndIter(0, 0));
    recv_keys->push_back(key);
    ++i;
  }
  return Status::OK();
}

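// Returns every worker interface obtained in Instantiate() back to the
// worker cache.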
ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() {
  for (auto& function_data : function_data_) {
    worker_session_->worker_cache->ReleaseWorker(function_data.target,
                                                 function_data.wi);
  }
}

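// Instantiates `function_name` on the worker named by `options.target`:
// builds the wrapper graph via ConstructFunctionGraph(), registers it with
// the remote worker through a RegisterGraph RPC, and records the resulting
// graph handle (plus send/recv keys) in function_data_. The returned local
// handle is the index of that entry in function_data_.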
Status ClusterFunctionLibraryRuntime::Instantiate(
    const string& function_name, const FunctionLibraryDefinition& lib_def,
    AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::LocalHandle* handle) {
  WorkerInterface* wi =
      worker_session_->worker_cache->CreateWorker(options.target);

  if (wi == nullptr) {
    std::vector<string> workers;
    worker_session_->worker_cache->ListWorkers(&workers);
    return errors::InvalidArgument(
        "Could not find worker with target: ", options.target,
        " Available workers: ", str_util::Join(workers, ", "));
  }

  // Make RPC and obtain a graph handle.
  const FunctionDef* fdef = lib_def.Find(function_name);
  if (fdef == nullptr) {
    return errors::InvalidArgument("Function ", function_name,
                                   " not found in the function library.");
  }
  const OpDef& sig = fdef->signature();
  GraphDef gdef;
  std::vector<string> send_keys, recv_keys;
  TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, &gdef,
                                            &send_keys, &recv_keys));
  *gdef.mutable_library() = lib_def.ToProto();

  RegisterGraphRequest req;
  req.set_session_handle(worker_session_->session_name);
  *req.mutable_graph_def() = gdef;
  req.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);
  RegisterGraphResponse resp;
  TF_RETURN_IF_ERROR(wi->RegisterGraph(&req, &resp));

  mutex_lock l(mu_);
  *handle = function_data_.size();
  function_data_.push_back(FunctionData(resp.graph_handle(), options.target, wi,
                                        send_keys, recv_keys));
  return Status::OK();
}

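// Runs a previously instantiated function. The call is translated into a
// RunGraph RPC against the registered wrapper graph: `args` are fed under the
// stored send keys, the stored recv keys are fetched, and the returned tensor
// protos are converted back into Tensors before `done` is invoked.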
void ClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    CHECK_LT(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  WorkerInterface* wi = function_data->wi;

  if (wi == nullptr) {
    done(errors::Internal("Could not find worker"));
    return;
  }

  RunGraphRequest req;
  req.set_session_handle(worker_session_->session_name);
  req.set_graph_handle(function_data->graph_handle);
  // Borrowed from master_session.cc
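  // Generate a random step id in the range [2^56, 2^57), i.e. a random 56-bit
  // value with bit 56 set (presumably to keep these ids disjoint from the ids
  // used by regular master-driven steps).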
  const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
  req.set_step_id(step_id);
  int i = 0;
  for (const auto& send_key : function_data->send_keys) {
    NamedTensorProto* send = req.add_send();
    send->set_name(send_key);
    args[i].AsProtoTensorContent(send->mutable_tensor());
    i++;
  }
  const std::vector<string>& recv_keys = function_data->recv_keys;
  for (const auto& recv_key : recv_keys) {
    req.add_recv_key(recv_key);
  }

  RunGraphResponse* resp = new RunGraphResponse();
  CallOptions* call_options = new CallOptions();
  wi->RunGraphAsync(
      call_options, &req, resp,
      [call_options, resp, rets, recv_keys, done](const Status& status) {
        if (!status.ok()) {
          done(status);
          delete call_options;
          delete resp;
          return;
        }
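        // Index the received tensors by rendezvous key so the outputs can be
        // emitted in the same order as recv_keys.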
        std::map<string, TensorProto*> mapped_recvs;
        for (auto& recv : *resp->mutable_recv()) {
          mapped_recvs[recv.name()] = recv.mutable_tensor();
        }

        for (const auto& recv_key : recv_keys) {
          TensorProto* tp = mapped_recvs[recv_key];
          if (tp == nullptr) {
            delete call_options;
            delete resp;
            done(errors::Internal("Could not find key: ", recv_key));
            return;
          }
          Tensor t;
          if (t.FromProto(*tp)) {
            rets->push_back(t);
          } else {
            delete call_options;
            delete resp;
            done(errors::Internal("Could not convert tensor proto: ",
                                  tp->DebugString()));
            return;
          }
        }
        delete call_options;
        delete resp;
        done(status);
      });
}

}  // namespace tensorflow