/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/cluster_function_library_runtime.h"

#include <map>

#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"

namespace tensorflow {

/* static */
Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
    const OpDef& sig, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
    std::vector<string>* send_keys, std::vector<string>* recv_keys) {
  const string& target = options.target;
  // Construct a _Recv node for each input argument.
  int i = 0;
  for (const auto& in : sig.input_arg()) {
    // Resolve the input type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, in, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs. Here and below.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Input arg: ", in.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    NodeDef* input_node = g->add_node();
    TF_RETURN_IF_ERROR(
        NodeDefBuilder(strings::StrCat("_recv_", in.name(), "_", i), "_Recv")
            .Attr("tensor_type", dtypes[0])
            .Attr("tensor_name", in.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target)
            .Finalize(input_node));
    // src_incarnation = 1 works because the transfer is across the same
    // device.
    // TODO(rohanj): Find the src_incarnation for the remote device and set it.
    const string key = Rendezvous::CreateKey(target, 1 /* src_incarnation */,
                                             target, in.name(),
                                             FrameAndIter(0, 0));
    send_keys->push_back(key);
    ++i;
  }

  // Add the function node itself, fed by the _Recv nodes above.
  NodeDef* function_node = g->add_node();
  function_node->set_name(sig.name());
  function_node->set_op(sig.name());
  i = 0;
  for (const auto& in : sig.input_arg()) {
    function_node->add_input(strings::StrCat("_recv_", in.name(), "_", i));
    ++i;
  }
  function_node->set_device(target);
  for (const auto& p : attrs) {
    (*function_node->mutable_attr())[p.first] = p.second;
  }

  // Construct a _Send node for each output.
  i = 0;
  for (const auto& out : sig.output_arg()) {
    // Resolve the output type.
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, out, &is_type_list, &dtypes));
    // TODO(rohanj): Handle list and variadic number of attrs.
    if (is_type_list || dtypes.size() > 1) {
      return errors::Unimplemented("Output arg: ", out.name(),
                                   " has a list type or variadic number of "
                                   "attrs. Currently unsupported.");
    }

    NodeDef* output_node = g->add_node();
    TF_RETURN_IF_ERROR(
        NodeDefBuilder(strings::StrCat("_send_", out.name(), "_", i), "_Send")
            .Input(sig.name(), i, dtypes[0])
            .Attr("tensor_name", out.name())
            .Attr("send_device", target)
            .Attr("recv_device", target)
            .Attr("send_device_incarnation", 1)
            .Attr("client_terminated", true)
            .Device(target)
            .Finalize(output_node));
    const string key = Rendezvous::CreateKey(target, 1 /* src_incarnation */,
                                             target, out.name(),
                                             FrameAndIter(0, 0));
    recv_keys->push_back(key);
    ++i;
  }
  return Status::OK();
}
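// For illustration (an assumed example, not part of this runtime):
// instantiating a function
//   f(x: float) -> (y: float)
// on target T produces a graph of roughly this shape:
//
//   _recv_x_0 (_Recv, on T) --> f (on T) --> _send_y_0 (_Send, on T)
//
// with one rendezvous key per _Recv appended to *send_keys (the caller feeds
// these) and one per _Send appended to *recv_keys (the caller fetches these).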
ClusterFunctionLibraryRuntime::~ClusterFunctionLibraryRuntime() {
  // Return every per-function worker interface to the cache.
  for (auto& function_data : function_data_) {
    worker_session_->worker_cache->ReleaseWorker(function_data.target,
                                                 function_data.wi);
  }
}

Status ClusterFunctionLibraryRuntime::Instantiate(
    const string& function_name, const FunctionLibraryDefinition& lib_def,
    AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::LocalHandle* handle) {
  WorkerInterface* wi =
      worker_session_->worker_cache->CreateWorker(options.target);

  if (wi == nullptr) {
    std::vector<string> workers;
    worker_session_->worker_cache->ListWorkers(&workers);
    return errors::InvalidArgument(
        "Could not find worker with target: ", options.target,
        ". Available workers: ", str_util::Join(workers, ", "));
  }

  // Construct the function graph, register it with the remote worker via RPC,
  // and obtain a graph handle.
  const FunctionDef* fdef = lib_def.Find(function_name);
  if (fdef == nullptr) {
    return errors::NotFound("Failed to find function ", function_name,
                            " in the function library.");
  }
  const OpDef& sig = fdef->signature();
  GraphDef gdef;
  std::vector<string> send_keys, recv_keys;
  TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, &gdef,
                                            &send_keys, &recv_keys));
  *gdef.mutable_library() = lib_def.ToProto();

  RegisterGraphRequest req;
  req.set_session_handle(worker_session_->session_name);
  *req.mutable_graph_def() = gdef;
  req.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_do_function_inlining(true);
  RegisterGraphResponse resp;
  TF_RETURN_IF_ERROR(wi->RegisterGraph(&req, &resp));

  mutex_lock l(mu_);
  *handle = function_data_.size();
  function_data_.push_back(FunctionData(resp.graph_handle(), options.target, wi,
                                        send_keys, recv_keys));
  return Status::OK();
}
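// Run feeds one tensor per recorded send key (consumed by the _Recv nodes
// built in ConstructFunctionGraph), executes the registered graph on the
// remote worker, and reads one tensor back per recorded recv key (produced
// by the _Send nodes). A rough sketch of the request for a unary function:
//
//   RunGraphRequest
//     send:     { name: send_keys[0], tensor: args[0] }
//     recv_key: recv_keys[0]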
void ClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
  FunctionData* function_data = nullptr;
  {
    mutex_lock l(mu_);
    CHECK_LT(handle, function_data_.size());
    function_data = &function_data_[handle];
  }

  WorkerInterface* wi = function_data->wi;

  if (wi == nullptr) {
    done(errors::Internal("Could not find worker"));
    return;
  }

  if (args.size() != function_data->send_keys.size()) {
    done(errors::InvalidArgument("Expected ", function_data->send_keys.size(),
                                 " arguments, got ", args.size(), "."));
    return;
  }

  RunGraphRequest req;
  req.set_session_handle(worker_session_->session_name);
  req.set_graph_handle(function_data->graph_handle);
  // Pick a step_id the same way master_session.cc does: 56 random low bits
  // with bit 56 set, to avoid colliding with step ids chosen elsewhere.
  const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
  req.set_step_id(step_id);
  int i = 0;
  for (const auto& send_key : function_data->send_keys) {
    NamedTensorProto* send = req.add_send();
    send->set_name(send_key);
    args[i].AsProtoTensorContent(send->mutable_tensor());
    i++;
  }
  const std::vector<string>& recv_keys = function_data->recv_keys;
  for (const auto& recv_key : recv_keys) {
    req.add_recv_key(recv_key);
  }

  RunGraphResponse* resp = new RunGraphResponse();
  CallOptions* call_options = new CallOptions();
  wi->RunGraphAsync(
      call_options, &req, resp,
      [call_options, resp, rets, recv_keys, done](const Status& status) {
        if (!status.ok()) {
          done(status);
          delete call_options;
          delete resp;
          return;
        }
        // Index the received tensors by rendezvous key.
        std::map<string, TensorProto*> mapped_recvs;
        for (auto& recv : *resp->mutable_recv()) {
          mapped_recvs[recv.name()] = recv.mutable_tensor();
        }

        // Emit the outputs in the order the recv keys were recorded.
        for (const auto& recv_key : recv_keys) {
          TensorProto* tp = mapped_recvs[recv_key];
          if (tp == nullptr) {
            delete call_options;
            delete resp;
            done(errors::Internal("Could not find key: ", recv_key));
            return;
          }
          Tensor t;
          if (t.FromProto(*tp)) {
            rets->push_back(t);
          } else {
            delete call_options;
            delete resp;
            done(errors::Internal("Could not convert tensor proto: ",
                                  tp->DebugString()));
            return;
          }
        }
        delete call_options;
        delete resp;
        done(status);
      });
}

}  // namespace tensorflow