/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"

#include <unordered_map>

#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/master.pb.h"

namespace tensorflow {

const char* const kSchemePrefix = "grpc://";
const size_t kSchemePrefixLength = strlen(kSchemePrefix);

GrpcSession::GrpcSession(const SessionOptions& options)
    : options_(options), current_graph_version_(-1) {}

GrpcSession::~GrpcSession() {}

/* static */
Status GrpcSession::Create(const SessionOptions& options,
                           std::unique_ptr<GrpcSession>* out_session) {
  std::unique_ptr<GrpcSession> session(new GrpcSession(options));
  std::unique_ptr<MasterInterface> master;
  // For testing, we enable the client to disable the use of the local
  // master registry, so that the RPC stack is exercised.
  if (!options.config.rpc_options().use_rpc_for_inprocess_master()) {
    master = LocalMaster::Lookup(options.target);
  }
  if (!master) {
    SharedGrpcChannelPtr master_channel;
    TF_RETURN_IF_ERROR(NewHostPortGrpcChannel(
        options.target.substr(kSchemePrefixLength), &master_channel));
    master.reset(NewGrpcMaster(master_channel));
  }
  session->SetRemoteMaster(std::move(master));
  *out_session = std::move(session);
  return Status::OK();
}
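
// Illustrative usage (a sketch, not part of this file; assumes a master
// service, e.g. one started by a tf.train.Server, is already listening at
// the hypothetical address localhost:2222):
//
//   SessionOptions options;
//   options.target = "grpc://localhost:2222";
//   std::unique_ptr<GrpcSession> session;
//   TF_RETURN_IF_ERROR(GrpcSession::Create(options, &session));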

namespace {
// Re-encodes constants represented in tensor protos into tensor_content,
// which is slightly better (fewer copies and lower peak memory usage) when
// used with RPC subsystems.
void ReEncodeConsts(GraphDef* gdef) {
  for (NodeDef& ndef : *(gdef->mutable_node())) {
    if (ndef.op() == "Const") {
      TensorProto* proto = nullptr;
      for (auto& attr : *ndef.mutable_attr()) {
        if (attr.first == "value") {
          proto = attr.second.mutable_tensor();
        }
      }
      if (proto != nullptr && proto->tensor_content().empty() &&
          proto->ByteSizeLong() > 64) {
        // If the constant is encoded with repeated proto fields and it is
        // moderately large, we re-encode it in tensor_content as a Cord.
        // This is mildly helpful for reducing the peak memory usage on the
        // server side, where GraphDef/NodeDef are copied quite often.
        Tensor parsed(proto->dtype());
        if (parsed.FromProto(*proto)) {
          parsed.AsProtoTensorContent(proto);
        }
      }
    }
  }
}
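
// For illustration (hypothetical values, smaller than the 64-byte threshold
// the code actually applies): the pass above turns a value attr serialized
// with repeated fields,
//
//   tensor { dtype: DT_FLOAT tensor_shape { dim { size: 3 } }
//            float_val: 1 float_val: 2 float_val: 3 }
//
// into the equivalent single-bytes encoding
//
//   tensor { dtype: DT_FLOAT tensor_shape { dim { size: 3 } }
//            tensor_content: "\000\000\200?\000\000\000@\000\000@@" }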
}  // namespace

Status GrpcSession::CreateImpl(CallOptions* call_options,
                               const GraphDef& graph) {
  {
    mutex_lock l(mu_);
    if (!handle_.empty()) {
      return errors::InvalidArgument("A session is alive.");
    }
  }
  CreateSessionRequest req;
  *req.mutable_config() = options_.config;
  *req.mutable_graph_def() = graph;
  req.set_target(options_.target);
  ReEncodeConsts(req.mutable_graph_def());
  CreateSessionResponse resp;
  Status s = master_->CreateSession(call_options, &req, &resp);
  if (s.ok()) {
    mutex_lock l(mu_);
    swap(handle_, *(resp.mutable_session_handle()));
    current_graph_version_ = resp.graph_version();
  }
  return s;
}

Status GrpcSession::Create(const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return CreateImpl(&call_options, graph);
}

Status GrpcSession::Create(const RunOptions& run_options,
                           const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return CreateImpl(&call_options, graph);
}

Status GrpcSession::ExtendImpl(CallOptions* call_options,
                               const GraphDef& graph) {
  bool handle_is_empty;
  {
    mutex_lock l(mu_);
    handle_is_empty = handle_.empty();
  }
  if (handle_is_empty) {
    // Session was uninitialized, so simply initialize the session with
    // 'graph'.
    return Create(graph);
  }
  mutex_lock l(mu_);
  ExtendSessionRequest req;
  req.set_session_handle(handle_);
  *req.mutable_graph_def() = graph;
  req.set_current_graph_version(current_graph_version_);
  ExtendSessionResponse resp;
  Status s = master_->ExtendSession(call_options, &req, &resp);
  if (s.ok()) {
    current_graph_version_ = resp.new_graph_version();
  }
  return s;
}

Status GrpcSession::Extend(const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return ExtendImpl(&call_options, graph);
}

Status GrpcSession::Extend(const RunOptions& run_options,
                           const GraphDef& graph) {
  CallOptions call_options;
  call_options.SetTimeout(run_options.timeout_in_ms());
  return ExtendImpl(&call_options, graph);
}

Status GrpcSession::RunHelper(
    const RunOptions& run_options,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& output_tensor_names,
    const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata, const string& prun_handle) {
  // Convert to proto
  std::unique_ptr<MutableRunStepRequestWrapper> req(
      master_->CreateRunStepRequest());
  std::unique_ptr<MutableRunStepResponseWrapper> resp(
      master_->CreateRunStepResponse());

  *req->mutable_options() = run_options;

  if (run_options.timeout_in_ms() == 0) {
    req->mutable_options()->set_timeout_in_ms(
        options_.config.operation_timeout_in_ms());
  }

  if (!prun_handle.empty()) {
    req->set_partial_run_handle(prun_handle);
  }

  for (const auto& it : inputs) {
    req->add_feed(it.first, it.second);
  }

  // Support long error messages by storing the error code in the response
  // body.
  req->set_store_errors_in_response_body(true);

  // Build an index from fetch tensor name to first index in
  // output_tensor_names.
  std::unordered_map<string, int> output_name_to_offset;
  for (int i = 0; i < output_tensor_names.size(); ++i) {
    const string& name = output_tensor_names[i];
    if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
      req->add_fetch(name);
    }
  }
  for (const string& target : target_node_names) {
    req->add_target(target);
  }

  CallOptions call_options;
  call_options.SetTimeout(req->options().timeout_in_ms());
  TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), resp.get()));

  // Look for an extended error returned in the response body.
  if (resp->status_code() != error::Code::OK) {
    return Status(resp->status_code(), resp->status_error_message());
  }

  if (!output_tensor_names.empty()) {
    outputs->resize(output_tensor_names.size());
  }

  // Convert response back to Tensors in the correct order.
  for (size_t i = 0; i < resp->num_tensors(); ++i) {
    auto fetch_it = output_name_to_offset.find(resp->tensor_name(i));
    if (fetch_it == output_name_to_offset.end()) {
      return errors::Internal("Received response for unrequested fetch: ",
                              resp->tensor_name(i));
    }

    Tensor output;
    TF_RETURN_IF_ERROR(resp->TensorValue(i, &output));
    (*outputs)[fetch_it->second] = output;
  }
  // In the unlikely event that output_tensor_names contains duplicates, fill
  // in the duplicate values.
  if (output_name_to_offset.size() != output_tensor_names.size()) {
    for (int i = 0; i < output_tensor_names.size(); ++i) {
      const string& name = output_tensor_names[i];
      int offset = output_name_to_offset[name];
      if (offset != i) {
        (*outputs)[i] = (*outputs)[offset];
      }
    }
  }

  if (run_metadata) {
    run_metadata->Swap(resp->mutable_metadata());
  }

  return Status::OK();
}

Status GrpcSession::Run(const RunOptions& run_options,
                        const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs,
                        RunMetadata* run_metadata) {
  return RunHelper(run_options, inputs, output_tensor_names, target_node_names,
                   outputs, run_metadata, /* prun_handle */ "");
}

Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return Run(run_options, inputs, output_tensor_names, target_node_names,
             outputs, nullptr);
}
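
// Illustrative call sequence (a sketch; "x:0" and "y:0" are hypothetical
// tensor names from the graph previously passed to Create()):
//
//   std::vector<std::pair<string, Tensor>> inputs = {{"x:0", x_tensor}};
//   std::vector<Tensor> outputs;
//   TF_RETURN_IF_ERROR(
//       session->Run(inputs, {"y:0"}, /* target_node_names */ {}, &outputs));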

Status GrpcSession::RunProto(CallOptions* call_options,
                             MutableRunStepRequestWrapper* req,
                             MutableRunStepResponseWrapper* resp) {
  {
    mutex_lock l(mu_);
    if (handle_.empty()) {
      return errors::InvalidArgument("A session is not created yet.");
    }

    req->set_session_handle(handle_);
  }
  return master_->RunStep(call_options, req, resp);
}

Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
                              const std::vector<string>& output_names,
                              const std::vector<string>& target_nodes,
                              string* handle) {
  // Convert to proto
  PartialRunSetupRequest req;
  PartialRunSetupResponse resp;
  CallOptions call_options;
  {
    mutex_lock l(mu_);
    if (handle_.empty()) {
      return errors::InvalidArgument("A session is not created yet.");
    }

    req.set_session_handle(handle_);
  }
  for (const string& feed : input_names) {
    req.add_feed(feed);
  }
  for (const string& fetch : output_names) {
    req.add_fetch(fetch);
  }
  for (const string& target : target_nodes) {
    req.add_target(target);
  }
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
  *handle = resp.partial_run_handle();
  return Status::OK();
}

Status GrpcSession::PRun(const string& handle,
                         const std::vector<std::pair<string, Tensor>>& inputs,
                         const std::vector<string>& output_names,
                         std::vector<Tensor>* outputs) {
  RunOptions run_options;
  run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
  return RunHelper(run_options, inputs, output_names, /* targets */ {},
                   outputs, /* run_metadata */ nullptr, handle);
}
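
// Illustrative partial-run sequence (a sketch; the feed and fetch names are
// hypothetical):
//
//   string prun_handle;
//   TF_RETURN_IF_ERROR(session->PRunSetup({"x:0"}, {"y:0"}, /* targets */ {},
//                                         &prun_handle));
//   std::vector<Tensor> outputs;
//   TF_RETURN_IF_ERROR(
//       session->PRun(prun_handle, {{"x:0", x_tensor}}, {"y:0"}, &outputs));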

Status GrpcSession::Close() {
  CloseSessionRequest req;
  {
    mutex_lock l(mu_);
    if (handle_.empty()) {
      return Status::OK();
    }
    req.set_session_handle(handle_);
    handle_.clear();
  }
  CloseSessionResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return master_->CloseSession(&call_options, &req, &resp);
}

Status GrpcSession::ListDevices(std::vector<DeviceAttributes>* response) {
  ListDevicesRequest req;
  {
    mutex_lock l(mu_);
    req.set_session_handle(handle_);
  }
  if (req.session_handle().empty()) {
    LOG(WARNING) << "GrpcSession::ListDevices will initialize the session "
                    "with an empty graph and other defaults because the "
                    "session has not yet been created.";
    GraphDef graph_def;
    TF_RETURN_IF_ERROR(Create(graph_def));
    {
      mutex_lock l(mu_);
      req.set_session_handle(handle_);
    }
  }
  ListDevicesResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  Status s = master_->ListDevices(&call_options, &req, &resp);
  if (!s.ok()) {
    LOG(ERROR) << "Could not list devices: " << s;
    return s;
  }

  response->clear();
  response->reserve(resp.local_device_size() + resp.remote_device_size());
  for (const auto& device_attr : resp.local_device()) {
    response->emplace_back(device_attr);
  }
  for (const auto& device_attr : resp.remote_device()) {
    response->emplace_back(device_attr);
  }
  return Status::OK();
}

void GrpcSession::SetRemoteMaster(std::unique_ptr<MasterInterface> master) {
  master_ = std::move(master);
}

/* static */
Status GrpcSession::Reset(const SessionOptions& options,
                          const std::vector<string>& containers) {
  SharedGrpcChannelPtr master_channel;
  TF_RETURN_IF_ERROR(NewHostPortGrpcChannel(
      options.target.substr(kSchemePrefixLength), &master_channel));
  auto master = NewGrpcMaster(master_channel);
  ResetRequest req;
  for (const auto& c : containers) req.add_container(c);
  ResetResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options.config.operation_timeout_in_ms());
  Status ret = master->Reset(&call_options, &req, &resp);
  delete master;
  return ret;
}

class GrpcSessionFactory : public SessionFactory {
 public:
  bool AcceptsOptions(const SessionOptions& options) override {
    return StringPiece(options.target).starts_with(kSchemePrefix);
  }

  Session* NewSession(const SessionOptions& options) override {
    std::unique_ptr<GrpcSession> ret;
    Status s = GrpcSession::Create(options, &ret);
    if (s.ok()) {
      return ret.release();
    } else {
      LOG(ERROR) << "Error during session construction: " << s.ToString();
      return nullptr;
    }
  }

  // Invokes the session-specific static method to reset containers.
  Status Reset(const SessionOptions& options,
               const std::vector<string>& containers) override {
    return GrpcSession::Reset(options, containers);
  }
};

class GrpcSessionRegistrar {
 public:
  GrpcSessionRegistrar() {
    SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
  }
};
static GrpcSessionRegistrar registrar;

}  // namespace tensorflow
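
// Illustrative end-to-end usage through the public Session API (a sketch;
// the registrar above routes any target with the "grpc://" scheme to this
// factory, and the address is hypothetical):
//
//   tensorflow::SessionOptions options;
//   options.target = "grpc://worker0:2222";
//   std::unique_ptr<tensorflow::Session> session(
//       tensorflow::NewSession(options));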