1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/distributed_runtime/message_wrappers.h" 17 #include "tensorflow/core/framework/cost_graph.pb.h" 18 #include "tensorflow/core/framework/step_stats.pb.h" 19 #include "tensorflow/core/protobuf/config.pb.h" 20 #include "tensorflow/core/protobuf/named_tensor.pb.h" 21 22 namespace tensorflow { 23 24 namespace { 25 26 bool ParseTensorProtoToTensor(const TensorProto& tensor_proto, 27 Tensor* out_tensor) { 28 if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) { 29 Tensor parsed(tensor_proto.dtype()); 30 if (parsed.FromProto(cpu_allocator(), tensor_proto)) { 31 *out_tensor = parsed; 32 return true; 33 } 34 } 35 return false; 36 } 37 38 } // namespace 39 40 const string& InMemoryRunStepRequest::session_handle() const { 41 return session_handle_; 42 } 43 44 void InMemoryRunStepRequest::set_session_handle(const string& handle) { 45 session_handle_ = handle; 46 } 47 48 const string& InMemoryRunStepRequest::partial_run_handle() const { 49 return partial_run_handle_; 50 } 51 52 void InMemoryRunStepRequest::set_partial_run_handle(const string& handle) { 53 partial_run_handle_ = handle; 54 } 55 56 size_t InMemoryRunStepRequest::num_feeds() const { return feeds_.size(); } 57 const string& InMemoryRunStepRequest::feed_name(size_t i) const { 58 return feeds_[i].first; 59 } 60 61 Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const { 62 *out_tensor = feeds_[i].second; 63 return Status::OK(); 64 } 65 66 Status InMemoryRunStepRequest::FeedValue(size_t i, 67 TensorProto* out_tensor) const { 68 feeds_[i].second.AsProtoTensorContent(out_tensor); 69 return Status::OK(); 70 } 71 72 void InMemoryRunStepRequest::add_feed(const string& name, const Tensor& value) { 73 feeds_.emplace_back(name, value); 74 } 75 76 size_t InMemoryRunStepRequest::num_fetches() const { return fetches_.size(); } 77 const string& InMemoryRunStepRequest::fetch_name(size_t i) const { 78 return fetches_[i]; 79 } 80 void InMemoryRunStepRequest::add_fetch(const string& name) { 81 fetches_.push_back(name); 82 } 83 84 size_t InMemoryRunStepRequest::num_targets() const { return targets_.size(); } 85 const string& InMemoryRunStepRequest::target_name(size_t i) const { 86 return targets_[i]; 87 } 88 void InMemoryRunStepRequest::add_target(const string& name) { 89 targets_.push_back(name); 90 } 91 92 const RunOptions& InMemoryRunStepRequest::options() const { return options_; } 93 94 RunOptions* InMemoryRunStepRequest::mutable_options() { return &options_; } 95 96 bool InMemoryRunStepRequest::store_errors_in_response_body() const { 97 return store_errors_in_response_body_; 98 } 99 100 int64 InMemoryRunStepRequest::request_id() const { 101 return 0; // no need to track request id for local version. 102 } 103 104 void InMemoryRunStepRequest::set_store_errors_in_response_body( 105 bool store_errors) { 106 store_errors_in_response_body_ = store_errors; 107 } 108 109 string InMemoryRunStepRequest::DebugString() const { 110 return ToProto().DebugString(); 111 } 112 113 const RunStepRequest& InMemoryRunStepRequest::ToProto() const { 114 if (!proto_version_) { 115 proto_version_.reset(new RunStepRequest); 116 proto_version_->set_session_handle(session_handle()); 117 proto_version_->set_partial_run_handle(partial_run_handle()); 118 for (size_t i = 0; i < num_feeds(); ++i) { 119 auto feed = proto_version_->add_feed(); 120 feed->set_name(feed_name(i)); 121 feeds_[i].second.AsProtoTensorContent(feed->mutable_tensor()); 122 } 123 for (size_t i = 0; i < num_fetches(); ++i) { 124 proto_version_->add_fetch(fetch_name(i)); 125 } 126 for (size_t i = 0; i < num_targets(); ++i) { 127 proto_version_->add_target(target_name(i)); 128 } 129 *proto_version_->mutable_options() = options(); 130 } 131 return *proto_version_; 132 } 133 134 const string& MutableProtoRunStepRequest::session_handle() const { 135 return request_.session_handle(); 136 } 137 void MutableProtoRunStepRequest::set_session_handle(const string& handle) { 138 request_.set_session_handle(handle); 139 } 140 141 const string& MutableProtoRunStepRequest::partial_run_handle() const { 142 return request_.partial_run_handle(); 143 } 144 void MutableProtoRunStepRequest::set_partial_run_handle(const string& handle) { 145 request_.set_partial_run_handle(handle); 146 } 147 148 size_t MutableProtoRunStepRequest::num_feeds() const { 149 return request_.feed_size(); 150 } 151 const string& MutableProtoRunStepRequest::feed_name(size_t i) const { 152 return request_.feed(i).name(); 153 } 154 Status MutableProtoRunStepRequest::FeedValue(size_t i, 155 Tensor* out_tensor) const { 156 if (!ParseTensorProtoToTensor(request_.feed(i).tensor(), out_tensor)) { 157 return errors::InvalidArgument("Invalid TensorProto for feed value ", i); 158 } else { 159 return Status::OK(); 160 } 161 } 162 163 Status MutableProtoRunStepRequest::FeedValue(size_t i, 164 TensorProto* out_tensor) const { 165 *out_tensor = request_.feed(i).tensor(); 166 return Status::OK(); 167 } 168 169 void MutableProtoRunStepRequest::add_feed(const string& name, 170 const Tensor& value) { 171 NamedTensorProto* feed = request_.add_feed(); 172 feed->set_name(name); 173 TensorProto* value_proto = feed->mutable_tensor(); 174 value.AsProtoTensorContent(value_proto); 175 } 176 177 size_t MutableProtoRunStepRequest::num_fetches() const { 178 return request_.fetch_size(); 179 } 180 181 const string& MutableProtoRunStepRequest::fetch_name(size_t i) const { 182 return request_.fetch(i); 183 } 184 void MutableProtoRunStepRequest::add_fetch(const string& name) { 185 request_.add_fetch(name); 186 } 187 188 size_t MutableProtoRunStepRequest::num_targets() const { 189 return request_.target_size(); 190 } 191 192 const string& MutableProtoRunStepRequest::target_name(size_t i) const { 193 return request_.target(i); 194 } 195 196 void MutableProtoRunStepRequest::add_target(const string& name) { 197 request_.add_target(name); 198 } 199 200 const RunOptions& MutableProtoRunStepRequest::options() const { 201 return request_.options(); 202 } 203 204 RunOptions* MutableProtoRunStepRequest::mutable_options() { 205 return request_.mutable_options(); 206 } 207 208 bool MutableProtoRunStepRequest::store_errors_in_response_body() const { 209 return request_.store_errors_in_response_body(); 210 } 211 212 void MutableProtoRunStepRequest::set_store_errors_in_response_body( 213 bool store_errors) { 214 request_.set_store_errors_in_response_body(store_errors); 215 } 216 217 int64 MutableProtoRunStepRequest::request_id() const { 218 return request_.request_id(); 219 } 220 221 string MutableProtoRunStepRequest::DebugString() const { 222 return request_.DebugString(); 223 } 224 225 const RunStepRequest& MutableProtoRunStepRequest::ToProto() const { 226 return request_; 227 } 228 229 ProtoRunStepRequest::ProtoRunStepRequest(const RunStepRequest* request) 230 : request_(request) {} 231 232 const string& ProtoRunStepRequest::session_handle() const { 233 return request_->session_handle(); 234 } 235 236 const string& ProtoRunStepRequest::partial_run_handle() const { 237 return request_->partial_run_handle(); 238 } 239 240 size_t ProtoRunStepRequest::num_feeds() const { return request_->feed_size(); } 241 242 const string& ProtoRunStepRequest::feed_name(size_t i) const { 243 return request_->feed(i).name(); 244 } 245 246 Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const { 247 if (!ParseTensorProtoToTensor(request_->feed(i).tensor(), out_tensor)) { 248 return errors::InvalidArgument("Invalid TensorProto for feed value ", i); 249 } else { 250 return Status::OK(); 251 } 252 } 253 254 Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const { 255 *out_tensor = request_->feed(i).tensor(); 256 return Status::OK(); 257 } 258 259 size_t ProtoRunStepRequest::num_fetches() const { 260 return request_->fetch_size(); 261 } 262 263 const string& ProtoRunStepRequest::fetch_name(size_t i) const { 264 return request_->fetch(i); 265 } 266 267 size_t ProtoRunStepRequest::num_targets() const { 268 return request_->target_size(); 269 } 270 271 const string& ProtoRunStepRequest::target_name(size_t i) const { 272 return request_->target(i); 273 } 274 275 const RunOptions& ProtoRunStepRequest::options() const { 276 return request_->options(); 277 } 278 279 bool ProtoRunStepRequest::store_errors_in_response_body() const { 280 return request_->store_errors_in_response_body(); 281 } 282 283 int64 ProtoRunStepRequest::request_id() const { return request_->request_id(); } 284 285 string ProtoRunStepRequest::DebugString() const { 286 return request_->DebugString(); 287 } 288 289 const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; } 290 291 const string& InMemoryRunGraphRequest::session_handle() const { 292 return session_handle_; 293 } 294 295 bool InMemoryRunGraphRequest::create_worker_session_called() const { 296 return create_worker_session_called_; 297 } 298 299 void InMemoryRunGraphRequest::set_session_handle(const string& handle) { 300 session_handle_ = handle; 301 } 302 303 void InMemoryRunGraphRequest::set_create_worker_session_called(bool called) { 304 create_worker_session_called_ = called; 305 } 306 307 const string& InMemoryRunGraphRequest::graph_handle() const { 308 return graph_handle_; 309 } 310 311 void InMemoryRunGraphRequest::set_graph_handle(const string& handle) { 312 graph_handle_ = handle; 313 } 314 315 int64 InMemoryRunGraphRequest::step_id() const { return step_id_; } 316 317 void InMemoryRunGraphRequest::set_step_id(int64 step_id) { step_id_ = step_id; } 318 319 const ExecutorOpts& InMemoryRunGraphRequest::exec_opts() const { 320 return exec_opts_; 321 } 322 323 ExecutorOpts* InMemoryRunGraphRequest::mutable_exec_opts() { 324 return &exec_opts_; 325 } 326 327 size_t InMemoryRunGraphRequest::num_sends() const { return sends_.size(); } 328 329 const string& InMemoryRunGraphRequest::send_key(size_t i) const { 330 return sends_[i].first; 331 } 332 333 Status InMemoryRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const { 334 *out_tensor = sends_[i].second; 335 return Status::OK(); 336 } 337 338 Status InMemoryRunGraphRequest::AddSendFromRunStepRequest( 339 const RunStepRequestWrapper& run_step_request, size_t i, 340 const string& send_key) { 341 Tensor tensor; 342 TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, &tensor)); 343 sends_.emplace_back(send_key, std::move(tensor)); 344 return Status::OK(); 345 } 346 347 // TODO(b/74355905): Add a specialized implementation that avoids 348 // copying the tensor when at least two of the {client, master, 349 // worker} are in the same process. 350 Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest( 351 const RunCallableRequest& run_callable_request, size_t i, 352 const string& send_key) { 353 Tensor tensor; 354 if (!ParseTensorProtoToTensor(run_callable_request.feed(i), &tensor)) { 355 return errors::InvalidArgument("Invalid TensorProto for feed value ", i); 356 } 357 sends_.emplace_back(send_key, std::move(tensor)); 358 return Status::OK(); 359 } 360 361 size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); } 362 363 const string& InMemoryRunGraphRequest::recv_key(size_t i) const { 364 return recvs_[i]; 365 } 366 367 void InMemoryRunGraphRequest::add_recv_key(const string& recv_key) { 368 recvs_.push_back(recv_key); 369 } 370 371 bool InMemoryRunGraphRequest::is_partial() const { return is_partial_; } 372 373 void InMemoryRunGraphRequest::set_is_partial(bool is_partial) { 374 is_partial_ = is_partial; 375 } 376 377 bool InMemoryRunGraphRequest::is_last_partial_run() const { 378 return is_last_partial_run_; 379 } 380 381 void InMemoryRunGraphRequest::set_is_last_partial_run( 382 bool is_last_partial_run) { 383 is_last_partial_run_ = is_last_partial_run; 384 } 385 386 bool InMemoryRunGraphRequest::store_errors_in_response_body() const { 387 return store_errors_in_response_body_; 388 } 389 390 void InMemoryRunGraphRequest::set_store_errors_in_response_body( 391 bool store_errors) { 392 store_errors_in_response_body_ = store_errors; 393 } 394 395 const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const { 396 if (!proto_version_) { 397 proto_version_.reset(new RunGraphRequest); 398 proto_version_->set_session_handle(session_handle()); 399 proto_version_->set_create_worker_session_called( 400 create_worker_session_called()); 401 proto_version_->set_graph_handle(graph_handle()); 402 proto_version_->set_step_id(step_id()); 403 *proto_version_->mutable_exec_opts() = exec_opts(); 404 for (size_t i = 0; i < num_sends(); ++i) { 405 auto send = proto_version_->add_send(); 406 send->set_name(send_key(i)); 407 sends_[i].second.AsProtoTensorContent(send->mutable_tensor()); 408 } 409 for (size_t i = 0; i < num_recvs(); ++i) { 410 proto_version_->add_recv_key(recv_key(i)); 411 } 412 proto_version_->set_is_partial(is_partial()); 413 proto_version_->set_is_last_partial_run(is_last_partial_run()); 414 } 415 return *proto_version_; 416 } 417 418 const string& MutableProtoRunGraphRequest::session_handle() const { 419 return request_.session_handle(); 420 } 421 422 void MutableProtoRunGraphRequest::set_session_handle(const string& handle) { 423 request_.set_session_handle(handle); 424 } 425 426 bool MutableProtoRunGraphRequest::create_worker_session_called() const { 427 return request_.create_worker_session_called(); 428 } 429 430 void MutableProtoRunGraphRequest::set_create_worker_session_called( 431 bool called) { 432 request_.set_create_worker_session_called(called); 433 } 434 435 const string& MutableProtoRunGraphRequest::graph_handle() const { 436 return request_.graph_handle(); 437 } 438 439 void MutableProtoRunGraphRequest::set_graph_handle(const string& handle) { 440 request_.set_graph_handle(handle); 441 } 442 443 int64 MutableProtoRunGraphRequest::step_id() const { 444 return request_.step_id(); 445 } 446 447 void MutableProtoRunGraphRequest::set_step_id(int64 step_id) { 448 request_.set_step_id(step_id); 449 } 450 451 const ExecutorOpts& MutableProtoRunGraphRequest::exec_opts() const { 452 return request_.exec_opts(); 453 } 454 455 ExecutorOpts* MutableProtoRunGraphRequest::mutable_exec_opts() { 456 return request_.mutable_exec_opts(); 457 } 458 459 size_t MutableProtoRunGraphRequest::num_sends() const { 460 return request_.send_size(); 461 } 462 463 const string& MutableProtoRunGraphRequest::send_key(size_t i) const { 464 return request_.send(i).name(); 465 } 466 467 Status MutableProtoRunGraphRequest::SendValue(size_t i, 468 Tensor* out_tensor) const { 469 if (!ParseTensorProtoToTensor(request_.send(i).tensor(), out_tensor)) { 470 return errors::InvalidArgument("Invalid TensorProto for feed value ", i); 471 } else { 472 return Status::OK(); 473 } 474 } 475 476 Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest( 477 const RunStepRequestWrapper& run_step_request, size_t i, 478 const string& send_key) { 479 NamedTensorProto* send = request_.add_send(); 480 send->set_name(send_key); 481 TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, send->mutable_tensor())); 482 return Status::OK(); 483 } 484 485 // TODO(b/74355905): Add a specialized implementation that avoids 486 // copying the tensor when at least two of the {client, master, 487 // worker} are in the same process. 488 Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest( 489 const RunCallableRequest& run_callable_request, size_t i, 490 const string& send_key) { 491 NamedTensorProto* send = request_.add_send(); 492 send->set_name(send_key); 493 *send->mutable_tensor() = run_callable_request.feed(i); 494 return Status::OK(); 495 } 496 497 size_t MutableProtoRunGraphRequest::num_recvs() const { 498 return request_.recv_key_size(); 499 } 500 501 const string& MutableProtoRunGraphRequest::recv_key(size_t i) const { 502 return request_.recv_key(i); 503 } 504 505 void MutableProtoRunGraphRequest::add_recv_key(const string& recv_key) { 506 request_.add_recv_key(recv_key); 507 } 508 509 bool MutableProtoRunGraphRequest::is_partial() const { 510 return request_.is_partial(); 511 } 512 513 void MutableProtoRunGraphRequest::set_is_partial(bool is_partial) { 514 request_.set_is_partial(is_partial); 515 } 516 517 bool MutableProtoRunGraphRequest::is_last_partial_run() const { 518 return request_.is_last_partial_run(); 519 } 520 521 void MutableProtoRunGraphRequest::set_is_last_partial_run( 522 bool is_last_partial_run) { 523 request_.set_is_last_partial_run(is_last_partial_run); 524 } 525 526 bool MutableProtoRunGraphRequest::store_errors_in_response_body() const { 527 return request_.store_errors_in_response_body(); 528 } 529 530 void MutableProtoRunGraphRequest::set_store_errors_in_response_body( 531 bool store_errors) { 532 request_.set_store_errors_in_response_body(store_errors); 533 } 534 535 const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const { 536 return request_; 537 } 538 539 ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request) 540 : request_(request) {} 541 542 const string& ProtoRunGraphRequest::session_handle() const { 543 return request_->session_handle(); 544 } 545 546 bool ProtoRunGraphRequest::create_worker_session_called() const { 547 return request_->create_worker_session_called(); 548 } 549 550 const string& ProtoRunGraphRequest::graph_handle() const { 551 return request_->graph_handle(); 552 } 553 554 int64 ProtoRunGraphRequest::step_id() const { return request_->step_id(); } 555 556 const ExecutorOpts& ProtoRunGraphRequest::exec_opts() const { 557 return request_->exec_opts(); 558 } 559 560 size_t ProtoRunGraphRequest::num_sends() const { return request_->send_size(); } 561 562 const string& ProtoRunGraphRequest::send_key(size_t i) const { 563 return request_->send(i).name(); 564 } 565 566 Status ProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const { 567 if (!ParseTensorProtoToTensor(request_->send(i).tensor(), out_tensor)) { 568 return errors::InvalidArgument("Invalid TensorProto for feed value ", i); 569 } else { 570 return Status::OK(); 571 } 572 } 573 574 size_t ProtoRunGraphRequest::num_recvs() const { 575 return request_->recv_key_size(); 576 } 577 578 const string& ProtoRunGraphRequest::recv_key(size_t i) const { 579 return request_->recv_key(i); 580 } 581 582 bool ProtoRunGraphRequest::is_partial() const { return request_->is_partial(); } 583 584 bool ProtoRunGraphRequest::is_last_partial_run() const { 585 return request_->is_last_partial_run(); 586 } 587 588 bool ProtoRunGraphRequest::store_errors_in_response_body() const { 589 return request_->store_errors_in_response_body(); 590 } 591 592 const RunGraphRequest& ProtoRunGraphRequest::ToProto() const { 593 return *request_; 594 } 595 596 size_t InMemoryRunGraphResponse::num_recvs() const { return recvs_.size(); } 597 598 const string& InMemoryRunGraphResponse::recv_key(size_t i) const { 599 return recvs_[i].first; 600 } 601 602 Status InMemoryRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) { 603 recvs_[i].second.AsProtoTensorContent(out_tensor); 604 return Status::OK(); 605 } 606 607 Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { 608 *out_tensor = recvs_[i].second; 609 return Status::OK(); 610 } 611 612 void InMemoryRunGraphResponse::AddRecv(const string& key, const Tensor& value) { 613 recvs_.emplace_back(key, value); 614 } 615 616 StepStats* InMemoryRunGraphResponse::mutable_step_stats() { 617 return &step_stats_; 618 } 619 620 CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() { 621 return &cost_graph_; 622 } 623 624 errors::Code InMemoryRunGraphResponse::status_code() const { 625 return status_.code(); 626 } 627 628 const string& InMemoryRunGraphResponse::status_error_message() const { 629 return status_.error_message(); 630 } 631 632 void InMemoryRunGraphResponse::set_status(const Status& status) { 633 status_ = status; 634 } 635 636 RunGraphResponse* InMemoryRunGraphResponse::get_proto() { 637 LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse"; 638 return nullptr; 639 } 640 641 size_t InMemoryRunGraphResponse::num_partition_graphs() const { 642 return partition_graphs_.size(); 643 } 644 645 GraphDef* InMemoryRunGraphResponse::mutable_partition_graph(size_t i) { 646 return &partition_graphs_[i]; 647 } 648 649 void InMemoryRunGraphResponse::AddPartitionGraph( 650 const GraphDef& partition_graph) { 651 partition_graphs_.push_back(partition_graph); 652 } 653 654 size_t OwnedProtoRunGraphResponse::num_recvs() const { 655 return response_.recv_size(); 656 } 657 658 const string& OwnedProtoRunGraphResponse::recv_key(size_t i) const { 659 return response_.recv(i).name(); 660 } 661 662 Status OwnedProtoRunGraphResponse::RecvValue(size_t i, 663 TensorProto* out_tensor) { 664 out_tensor->Swap(response_.mutable_recv(i)->mutable_tensor()); 665 return Status::OK(); 666 } 667 668 Status OwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { 669 if (!ParseTensorProtoToTensor(response_.recv(i).tensor(), out_tensor)) { 670 return errors::InvalidArgument("Invalid TensorProto for recv value ", i); 671 } else { 672 return Status::OK(); 673 } 674 } 675 676 void OwnedProtoRunGraphResponse::AddRecv(const string& key, 677 const Tensor& value) { 678 NamedTensorProto* recv = response_.add_recv(); 679 recv->set_name(key); 680 TensorProto* value_proto = recv->mutable_tensor(); 681 value.AsProtoTensorContent(value_proto); 682 } 683 684 StepStats* OwnedProtoRunGraphResponse::mutable_step_stats() { 685 return response_.mutable_step_stats(); 686 } 687 688 CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() { 689 return response_.mutable_cost_graph(); 690 } 691 692 errors::Code OwnedProtoRunGraphResponse::status_code() const { 693 return response_.status_code(); 694 } 695 696 const string& OwnedProtoRunGraphResponse::status_error_message() const { 697 return response_.status_error_message(); 698 } 699 700 void OwnedProtoRunGraphResponse::set_status(const Status& status) { 701 response_.set_status_code(status.code()); 702 response_.set_status_error_message(status.error_message()); 703 } 704 705 RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; } 706 707 size_t OwnedProtoRunGraphResponse::num_partition_graphs() const { 708 return response_.partition_graph_size(); 709 } 710 711 GraphDef* OwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) { 712 return response_.mutable_partition_graph(i); 713 } 714 715 void OwnedProtoRunGraphResponse::AddPartitionGraph( 716 const GraphDef& partition_graph) { 717 GraphDef* graph_def = response_.mutable_partition_graph()->Add(); 718 *graph_def = partition_graph; 719 } 720 721 NonOwnedProtoRunGraphResponse::NonOwnedProtoRunGraphResponse( 722 RunGraphResponse* response) 723 : response_(response) {} 724 725 size_t NonOwnedProtoRunGraphResponse::num_recvs() const { 726 return response_->recv_size(); 727 } 728 729 const string& NonOwnedProtoRunGraphResponse::recv_key(size_t i) const { 730 return response_->recv(i).name(); 731 } 732 733 Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, 734 TensorProto* out_tensor) { 735 out_tensor->Swap(response_->mutable_recv(i)->mutable_tensor()); 736 return Status::OK(); 737 } 738 739 Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { 740 if (!ParseTensorProtoToTensor(response_->recv(i).tensor(), out_tensor)) { 741 return errors::InvalidArgument("Invalid TensorProto for recv value ", i); 742 } else { 743 return Status::OK(); 744 } 745 } 746 747 void NonOwnedProtoRunGraphResponse::AddRecv(const string& key, 748 const Tensor& value) { 749 NamedTensorProto* recv = response_->add_recv(); 750 recv->set_name(key); 751 TensorProto* value_proto = recv->mutable_tensor(); 752 value.AsProtoTensorContent(value_proto); 753 } 754 755 StepStats* NonOwnedProtoRunGraphResponse::mutable_step_stats() { 756 return response_->mutable_step_stats(); 757 } 758 759 CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() { 760 return response_->mutable_cost_graph(); 761 } 762 763 errors::Code NonOwnedProtoRunGraphResponse::status_code() const { 764 return response_->status_code(); 765 } 766 767 const string& NonOwnedProtoRunGraphResponse::status_error_message() const { 768 return response_->status_error_message(); 769 } 770 771 void NonOwnedProtoRunGraphResponse::set_status(const Status& status) { 772 response_->set_status_code(status.code()); 773 response_->set_status_error_message(status.error_message()); 774 } 775 776 RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() { 777 return response_; 778 } 779 780 size_t NonOwnedProtoRunGraphResponse::num_partition_graphs() const { 781 return response_->partition_graph_size(); 782 } 783 784 GraphDef* NonOwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) { 785 return response_->mutable_partition_graph(i); 786 } 787 788 void NonOwnedProtoRunGraphResponse::AddPartitionGraph( 789 const GraphDef& partition_graph) { 790 GraphDef* graph_def = response_->add_partition_graph(); 791 *graph_def = partition_graph; 792 } 793 794 MutableRunStepResponseWrapper::~MutableRunStepResponseWrapper() {} 795 796 size_t InMemoryRunStepResponse::num_tensors() const { return tensors_.size(); } 797 798 const string& InMemoryRunStepResponse::tensor_name(size_t i) const { 799 return tensors_[i].first; 800 } 801 802 Status InMemoryRunStepResponse::TensorValue(size_t i, 803 Tensor* out_tensor) const { 804 *out_tensor = tensors_[i].second; 805 return Status::OK(); 806 } 807 808 const RunMetadata& InMemoryRunStepResponse::metadata() const { 809 return metadata_; 810 } 811 812 Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse( 813 const string& name, MutableRunGraphResponseWrapper* wrapper, size_t i) { 814 Tensor tensor; 815 TF_RETURN_IF_ERROR(wrapper->RecvValue(i, &tensor)); 816 tensors_.emplace_back(name, tensor); 817 return Status::OK(); 818 } 819 820 RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; } 821 822 errors::Code InMemoryRunStepResponse::status_code() const { 823 return status_.code(); 824 } 825 826 const string& InMemoryRunStepResponse::status_error_message() const { 827 return status_.error_message(); 828 } 829 830 void InMemoryRunStepResponse::set_status(const Status& status) { 831 status_ = status; 832 } 833 834 RunStepResponse* InMemoryRunStepResponse::get_proto() { 835 LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse"; 836 return nullptr; 837 } 838 839 size_t OwnedProtoRunStepResponse::num_tensors() const { 840 return response_.tensor_size(); 841 } 842 843 const string& OwnedProtoRunStepResponse::tensor_name(size_t i) const { 844 return response_.tensor(i).name(); 845 } 846 847 Status OwnedProtoRunStepResponse::TensorValue(size_t i, 848 Tensor* out_tensor) const { 849 if (!ParseTensorProtoToTensor(response_.tensor(i).tensor(), out_tensor)) { 850 return errors::InvalidArgument("Invalid TensorProto for fetch value ", i); 851 } else { 852 return Status::OK(); 853 } 854 } 855 856 const RunMetadata& OwnedProtoRunStepResponse::metadata() const { 857 return response_.metadata(); 858 } 859 860 Status OwnedProtoRunStepResponse::AddTensorFromRunGraphResponse( 861 const string& name, MutableRunGraphResponseWrapper* run_graph_response, 862 size_t i) { 863 NamedTensorProto* response_tensor = response_.add_tensor(); 864 response_tensor->set_name(name); 865 return run_graph_response->RecvValue(i, response_tensor->mutable_tensor()); 866 } 867 868 RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() { 869 return response_.mutable_metadata(); 870 } 871 872 errors::Code OwnedProtoRunStepResponse::status_code() const { 873 return response_.status_code(); 874 } 875 876 const string& OwnedProtoRunStepResponse::status_error_message() const { 877 return response_.status_error_message(); 878 } 879 880 void OwnedProtoRunStepResponse::set_status(const Status& status) { 881 response_.set_status_code(status.code()); 882 response_.set_status_error_message(status.error_message()); 883 } 884 885 RunStepResponse* OwnedProtoRunStepResponse::get_proto() { return &response_; } 886 887 NonOwnedProtoRunStepResponse::NonOwnedProtoRunStepResponse( 888 RunStepResponse* response) 889 : response_(response) {} 890 891 size_t NonOwnedProtoRunStepResponse::num_tensors() const { 892 return response_->tensor_size(); 893 } 894 895 const string& NonOwnedProtoRunStepResponse::tensor_name(size_t i) const { 896 return response_->tensor(i).name(); 897 } 898 899 Status NonOwnedProtoRunStepResponse::TensorValue(size_t i, 900 Tensor* out_tensor) const { 901 if (!ParseTensorProtoToTensor(response_->tensor(i).tensor(), out_tensor)) { 902 return errors::InvalidArgument("Invalid TensorProto for fetch value ", i); 903 } else { 904 return Status::OK(); 905 } 906 } 907 908 const RunMetadata& NonOwnedProtoRunStepResponse::metadata() const { 909 return response_->metadata(); 910 } 911 912 Status NonOwnedProtoRunStepResponse::AddTensorFromRunGraphResponse( 913 const string& name, MutableRunGraphResponseWrapper* run_graph_response, 914 size_t i) { 915 NamedTensorProto* response_tensor = response_->add_tensor(); 916 response_tensor->set_name(name); 917 return run_graph_response->RecvValue(i, response_tensor->mutable_tensor()); 918 } 919 920 RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() { 921 return response_->mutable_metadata(); 922 } 923 924 errors::Code NonOwnedProtoRunStepResponse::status_code() const { 925 return response_->status_code(); 926 } 927 928 const string& NonOwnedProtoRunStepResponse::status_error_message() const { 929 return response_->status_error_message(); 930 } 931 932 void NonOwnedProtoRunStepResponse::set_status(const Status& status) { 933 response_->set_status_code(status.code()); 934 response_->set_status_error_message(status.error_message()); 935 } 936 937 RunStepResponse* NonOwnedProtoRunStepResponse::get_proto() { return response_; } 938 939 } // namespace tensorflow 940