Home | History | Annotate | Download | only in distributed_runtime
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
     17 #include "tensorflow/core/framework/cost_graph.pb.h"
     18 #include "tensorflow/core/framework/step_stats.pb.h"
     19 #include "tensorflow/core/protobuf/config.pb.h"
     20 #include "tensorflow/core/protobuf/named_tensor.pb.h"
     21 
     22 namespace tensorflow {
     23 
     24 namespace {
     25 
     26 bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
     27                               Tensor* out_tensor) {
     28   if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
     29     Tensor parsed(tensor_proto.dtype());
     30     if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
     31       *out_tensor = parsed;
     32       return true;
     33     }
     34   }
     35   return false;
     36 }
     37 
     38 }  // namespace
     39 
     40 const string& InMemoryRunStepRequest::session_handle() const {
     41   return session_handle_;
     42 }
     43 
     44 void InMemoryRunStepRequest::set_session_handle(const string& handle) {
     45   session_handle_ = handle;
     46 }
     47 
     48 const string& InMemoryRunStepRequest::partial_run_handle() const {
     49   return partial_run_handle_;
     50 }
     51 
     52 void InMemoryRunStepRequest::set_partial_run_handle(const string& handle) {
     53   partial_run_handle_ = handle;
     54 }
     55 
     56 size_t InMemoryRunStepRequest::num_feeds() const { return feeds_.size(); }
     57 const string& InMemoryRunStepRequest::feed_name(size_t i) const {
     58   return feeds_[i].first;
     59 }
     60 
     61 Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
     62   *out_tensor = feeds_[i].second;
     63   return Status::OK();
     64 }
     65 
     66 Status InMemoryRunStepRequest::FeedValue(size_t i,
     67                                          TensorProto* out_tensor) const {
     68   feeds_[i].second.AsProtoTensorContent(out_tensor);
     69   return Status::OK();
     70 }
     71 
     72 void InMemoryRunStepRequest::add_feed(const string& name, const Tensor& value) {
     73   feeds_.emplace_back(name, value);
     74 }
     75 
     76 size_t InMemoryRunStepRequest::num_fetches() const { return fetches_.size(); }
     77 const string& InMemoryRunStepRequest::fetch_name(size_t i) const {
     78   return fetches_[i];
     79 }
     80 void InMemoryRunStepRequest::add_fetch(const string& name) {
     81   fetches_.push_back(name);
     82 }
     83 
     84 size_t InMemoryRunStepRequest::num_targets() const { return targets_.size(); }
     85 const string& InMemoryRunStepRequest::target_name(size_t i) const {
     86   return targets_[i];
     87 }
     88 void InMemoryRunStepRequest::add_target(const string& name) {
     89   targets_.push_back(name);
     90 }
     91 
     92 const RunOptions& InMemoryRunStepRequest::options() const { return options_; }
     93 
     94 RunOptions* InMemoryRunStepRequest::mutable_options() { return &options_; }
     95 
     96 bool InMemoryRunStepRequest::store_errors_in_response_body() const {
     97   return store_errors_in_response_body_;
     98 }
     99 
    100 int64 InMemoryRunStepRequest::request_id() const {
    101   return 0;  // no need to track request id for local version.
    102 }
    103 
    104 void InMemoryRunStepRequest::set_store_errors_in_response_body(
    105     bool store_errors) {
    106   store_errors_in_response_body_ = store_errors;
    107 }
    108 
    109 string InMemoryRunStepRequest::DebugString() const {
    110   return ToProto().DebugString();
    111 }
    112 
    113 const RunStepRequest& InMemoryRunStepRequest::ToProto() const {
    114   if (!proto_version_) {
    115     proto_version_.reset(new RunStepRequest);
    116     proto_version_->set_session_handle(session_handle());
    117     proto_version_->set_partial_run_handle(partial_run_handle());
    118     for (size_t i = 0; i < num_feeds(); ++i) {
    119       auto feed = proto_version_->add_feed();
    120       feed->set_name(feed_name(i));
    121       feeds_[i].second.AsProtoTensorContent(feed->mutable_tensor());
    122     }
    123     for (size_t i = 0; i < num_fetches(); ++i) {
    124       proto_version_->add_fetch(fetch_name(i));
    125     }
    126     for (size_t i = 0; i < num_targets(); ++i) {
    127       proto_version_->add_target(target_name(i));
    128     }
    129     *proto_version_->mutable_options() = options();
    130   }
    131   return *proto_version_;
    132 }
    133 
    134 const string& MutableProtoRunStepRequest::session_handle() const {
    135   return request_.session_handle();
    136 }
    137 void MutableProtoRunStepRequest::set_session_handle(const string& handle) {
    138   request_.set_session_handle(handle);
    139 }
    140 
    141 const string& MutableProtoRunStepRequest::partial_run_handle() const {
    142   return request_.partial_run_handle();
    143 }
    144 void MutableProtoRunStepRequest::set_partial_run_handle(const string& handle) {
    145   request_.set_partial_run_handle(handle);
    146 }
    147 
    148 size_t MutableProtoRunStepRequest::num_feeds() const {
    149   return request_.feed_size();
    150 }
    151 const string& MutableProtoRunStepRequest::feed_name(size_t i) const {
    152   return request_.feed(i).name();
    153 }
    154 Status MutableProtoRunStepRequest::FeedValue(size_t i,
    155                                              Tensor* out_tensor) const {
    156   if (!ParseTensorProtoToTensor(request_.feed(i).tensor(), out_tensor)) {
    157     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
    158   } else {
    159     return Status::OK();
    160   }
    161 }
    162 
    163 Status MutableProtoRunStepRequest::FeedValue(size_t i,
    164                                              TensorProto* out_tensor) const {
    165   *out_tensor = request_.feed(i).tensor();
    166   return Status::OK();
    167 }
    168 
    169 void MutableProtoRunStepRequest::add_feed(const string& name,
    170                                           const Tensor& value) {
    171   NamedTensorProto* feed = request_.add_feed();
    172   feed->set_name(name);
    173   TensorProto* value_proto = feed->mutable_tensor();
    174   value.AsProtoTensorContent(value_proto);
    175 }
    176 
    177 size_t MutableProtoRunStepRequest::num_fetches() const {
    178   return request_.fetch_size();
    179 }
    180 
    181 const string& MutableProtoRunStepRequest::fetch_name(size_t i) const {
    182   return request_.fetch(i);
    183 }
    184 void MutableProtoRunStepRequest::add_fetch(const string& name) {
    185   request_.add_fetch(name);
    186 }
    187 
    188 size_t MutableProtoRunStepRequest::num_targets() const {
    189   return request_.target_size();
    190 }
    191 
    192 const string& MutableProtoRunStepRequest::target_name(size_t i) const {
    193   return request_.target(i);
    194 }
    195 
    196 void MutableProtoRunStepRequest::add_target(const string& name) {
    197   request_.add_target(name);
    198 }
    199 
    200 const RunOptions& MutableProtoRunStepRequest::options() const {
    201   return request_.options();
    202 }
    203 
    204 RunOptions* MutableProtoRunStepRequest::mutable_options() {
    205   return request_.mutable_options();
    206 }
    207 
    208 bool MutableProtoRunStepRequest::store_errors_in_response_body() const {
    209   return request_.store_errors_in_response_body();
    210 }
    211 
    212 void MutableProtoRunStepRequest::set_store_errors_in_response_body(
    213     bool store_errors) {
    214   request_.set_store_errors_in_response_body(store_errors);
    215 }
    216 
    217 int64 MutableProtoRunStepRequest::request_id() const {
    218   return request_.request_id();
    219 }
    220 
    221 string MutableProtoRunStepRequest::DebugString() const {
    222   return request_.DebugString();
    223 }
    224 
    225 const RunStepRequest& MutableProtoRunStepRequest::ToProto() const {
    226   return request_;
    227 }
    228 
    229 ProtoRunStepRequest::ProtoRunStepRequest(const RunStepRequest* request)
    230     : request_(request) {}
    231 
    232 const string& ProtoRunStepRequest::session_handle() const {
    233   return request_->session_handle();
    234 }
    235 
    236 const string& ProtoRunStepRequest::partial_run_handle() const {
    237   return request_->partial_run_handle();
    238 }
    239 
    240 size_t ProtoRunStepRequest::num_feeds() const { return request_->feed_size(); }
    241 
    242 const string& ProtoRunStepRequest::feed_name(size_t i) const {
    243   return request_->feed(i).name();
    244 }
    245 
    246 Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
    247   if (!ParseTensorProtoToTensor(request_->feed(i).tensor(), out_tensor)) {
    248     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
    249   } else {
    250     return Status::OK();
    251   }
    252 }
    253 
    254 Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const {
    255   *out_tensor = request_->feed(i).tensor();
    256   return Status::OK();
    257 }
    258 
    259 size_t ProtoRunStepRequest::num_fetches() const {
    260   return request_->fetch_size();
    261 }
    262 
    263 const string& ProtoRunStepRequest::fetch_name(size_t i) const {
    264   return request_->fetch(i);
    265 }
    266 
    267 size_t ProtoRunStepRequest::num_targets() const {
    268   return request_->target_size();
    269 }
    270 
    271 const string& ProtoRunStepRequest::target_name(size_t i) const {
    272   return request_->target(i);
    273 }
    274 
    275 const RunOptions& ProtoRunStepRequest::options() const {
    276   return request_->options();
    277 }
    278 
    279 bool ProtoRunStepRequest::store_errors_in_response_body() const {
    280   return request_->store_errors_in_response_body();
    281 }
    282 
    283 int64 ProtoRunStepRequest::request_id() const { return request_->request_id(); }
    284 
    285 string ProtoRunStepRequest::DebugString() const {
    286   return request_->DebugString();
    287 }
    288 
    289 const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
    290 
    291 const string& InMemoryRunGraphRequest::session_handle() const {
    292   return session_handle_;
    293 }
    294 
    295 bool InMemoryRunGraphRequest::create_worker_session_called() const {
    296   return create_worker_session_called_;
    297 }
    298 
    299 void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
    300   session_handle_ = handle;
    301 }
    302 
    303 void InMemoryRunGraphRequest::set_create_worker_session_called(bool called) {
    304   create_worker_session_called_ = called;
    305 }
    306 
    307 const string& InMemoryRunGraphRequest::graph_handle() const {
    308   return graph_handle_;
    309 }
    310 
    311 void InMemoryRunGraphRequest::set_graph_handle(const string& handle) {
    312   graph_handle_ = handle;
    313 }
    314 
    315 int64 InMemoryRunGraphRequest::step_id() const { return step_id_; }
    316 
    317 void InMemoryRunGraphRequest::set_step_id(int64 step_id) { step_id_ = step_id; }
    318 
    319 const ExecutorOpts& InMemoryRunGraphRequest::exec_opts() const {
    320   return exec_opts_;
    321 }
    322 
    323 ExecutorOpts* InMemoryRunGraphRequest::mutable_exec_opts() {
    324   return &exec_opts_;
    325 }
    326 
    327 size_t InMemoryRunGraphRequest::num_sends() const { return sends_.size(); }
    328 
    329 const string& InMemoryRunGraphRequest::send_key(size_t i) const {
    330   return sends_[i].first;
    331 }
    332 
    333 Status InMemoryRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
    334   *out_tensor = sends_[i].second;
    335   return Status::OK();
    336 }
    337 
    338 Status InMemoryRunGraphRequest::AddSendFromRunStepRequest(
    339     const RunStepRequestWrapper& run_step_request, size_t i,
    340     const string& send_key) {
    341   Tensor tensor;
    342   TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, &tensor));
    343   sends_.emplace_back(send_key, std::move(tensor));
    344   return Status::OK();
    345 }
    346 
    347 // TODO(b/74355905): Add a specialized implementation that avoids
    348 // copying the tensor when at least two of the {client, master,
    349 // worker} are in the same process.
    350 Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest(
    351     const RunCallableRequest& run_callable_request, size_t i,
    352     const string& send_key) {
    353   Tensor tensor;
    354   if (!ParseTensorProtoToTensor(run_callable_request.feed(i), &tensor)) {
    355     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
    356   }
    357   sends_.emplace_back(send_key, std::move(tensor));
    358   return Status::OK();
    359 }
    360 
    361 size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); }
    362 
    363 const string& InMemoryRunGraphRequest::recv_key(size_t i) const {
    364   return recvs_[i];
    365 }
    366 
    367 void InMemoryRunGraphRequest::add_recv_key(const string& recv_key) {
    368   recvs_.push_back(recv_key);
    369 }
    370 
    371 bool InMemoryRunGraphRequest::is_partial() const { return is_partial_; }
    372 
    373 void InMemoryRunGraphRequest::set_is_partial(bool is_partial) {
    374   is_partial_ = is_partial;
    375 }
    376 
    377 bool InMemoryRunGraphRequest::is_last_partial_run() const {
    378   return is_last_partial_run_;
    379 }
    380 
    381 void InMemoryRunGraphRequest::set_is_last_partial_run(
    382     bool is_last_partial_run) {
    383   is_last_partial_run_ = is_last_partial_run;
    384 }
    385 
    386 bool InMemoryRunGraphRequest::store_errors_in_response_body() const {
    387   return store_errors_in_response_body_;
    388 }
    389 
    390 void InMemoryRunGraphRequest::set_store_errors_in_response_body(
    391     bool store_errors) {
    392   store_errors_in_response_body_ = store_errors;
    393 }
    394 
    395 const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
    396   if (!proto_version_) {
    397     proto_version_.reset(new RunGraphRequest);
    398     proto_version_->set_session_handle(session_handle());
    399     proto_version_->set_create_worker_session_called(
    400         create_worker_session_called());
    401     proto_version_->set_graph_handle(graph_handle());
    402     proto_version_->set_step_id(step_id());
    403     *proto_version_->mutable_exec_opts() = exec_opts();
    404     for (size_t i = 0; i < num_sends(); ++i) {
    405       auto send = proto_version_->add_send();
    406       send->set_name(send_key(i));
    407       sends_[i].second.AsProtoTensorContent(send->mutable_tensor());
    408     }
    409     for (size_t i = 0; i < num_recvs(); ++i) {
    410       proto_version_->add_recv_key(recv_key(i));
    411     }
    412     proto_version_->set_is_partial(is_partial());
    413     proto_version_->set_is_last_partial_run(is_last_partial_run());
    414   }
    415   return *proto_version_;
    416 }
    417 
    418 const string& MutableProtoRunGraphRequest::session_handle() const {
    419   return request_.session_handle();
    420 }
    421 
    422 void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
    423   request_.set_session_handle(handle);
    424 }
    425 
    426 bool MutableProtoRunGraphRequest::create_worker_session_called() const {
    427   return request_.create_worker_session_called();
    428 }
    429 
    430 void MutableProtoRunGraphRequest::set_create_worker_session_called(
    431     bool called) {
    432   request_.set_create_worker_session_called(called);
    433 }
    434 
    435 const string& MutableProtoRunGraphRequest::graph_handle() const {
    436   return request_.graph_handle();
    437 }
    438 
    439 void MutableProtoRunGraphRequest::set_graph_handle(const string& handle) {
    440   request_.set_graph_handle(handle);
    441 }
    442 
    443 int64 MutableProtoRunGraphRequest::step_id() const {
    444   return request_.step_id();
    445 }
    446 
    447 void MutableProtoRunGraphRequest::set_step_id(int64 step_id) {
    448   request_.set_step_id(step_id);
    449 }
    450 
    451 const ExecutorOpts& MutableProtoRunGraphRequest::exec_opts() const {
    452   return request_.exec_opts();
    453 }
    454 
    455 ExecutorOpts* MutableProtoRunGraphRequest::mutable_exec_opts() {
    456   return request_.mutable_exec_opts();
    457 }
    458 
    459 size_t MutableProtoRunGraphRequest::num_sends() const {
    460   return request_.send_size();
    461 }
    462 
    463 const string& MutableProtoRunGraphRequest::send_key(size_t i) const {
    464   return request_.send(i).name();
    465 }
    466 
    467 Status MutableProtoRunGraphRequest::SendValue(size_t i,
    468                                               Tensor* out_tensor) const {
    469   if (!ParseTensorProtoToTensor(request_.send(i).tensor(), out_tensor)) {
    470     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
    471   } else {
    472     return Status::OK();
    473   }
    474 }
    475 
    476 Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest(
    477     const RunStepRequestWrapper& run_step_request, size_t i,
    478     const string& send_key) {
    479   NamedTensorProto* send = request_.add_send();
    480   send->set_name(send_key);
    481   TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, send->mutable_tensor()));
    482   return Status::OK();
    483 }
    484 
    485 // TODO(b/74355905): Add a specialized implementation that avoids
    486 // copying the tensor when at least two of the {client, master,
    487 // worker} are in the same process.
    488 Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest(
    489     const RunCallableRequest& run_callable_request, size_t i,
    490     const string& send_key) {
    491   NamedTensorProto* send = request_.add_send();
    492   send->set_name(send_key);
    493   *send->mutable_tensor() = run_callable_request.feed(i);
    494   return Status::OK();
    495 }
    496 
    497 size_t MutableProtoRunGraphRequest::num_recvs() const {
    498   return request_.recv_key_size();
    499 }
    500 
    501 const string& MutableProtoRunGraphRequest::recv_key(size_t i) const {
    502   return request_.recv_key(i);
    503 }
    504 
    505 void MutableProtoRunGraphRequest::add_recv_key(const string& recv_key) {
    506   request_.add_recv_key(recv_key);
    507 }
    508 
    509 bool MutableProtoRunGraphRequest::is_partial() const {
    510   return request_.is_partial();
    511 }
    512 
    513 void MutableProtoRunGraphRequest::set_is_partial(bool is_partial) {
    514   request_.set_is_partial(is_partial);
    515 }
    516 
    517 bool MutableProtoRunGraphRequest::is_last_partial_run() const {
    518   return request_.is_last_partial_run();
    519 }
    520 
    521 void MutableProtoRunGraphRequest::set_is_last_partial_run(
    522     bool is_last_partial_run) {
    523   request_.set_is_last_partial_run(is_last_partial_run);
    524 }
    525 
    526 bool MutableProtoRunGraphRequest::store_errors_in_response_body() const {
    527   return request_.store_errors_in_response_body();
    528 }
    529 
    530 void MutableProtoRunGraphRequest::set_store_errors_in_response_body(
    531     bool store_errors) {
    532   request_.set_store_errors_in_response_body(store_errors);
    533 }
    534 
    535 const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
    536   return request_;
    537 }
    538 
    539 ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
    540     : request_(request) {}
    541 
    542 const string& ProtoRunGraphRequest::session_handle() const {
    543   return request_->session_handle();
    544 }
    545 
    546 bool ProtoRunGraphRequest::create_worker_session_called() const {
    547   return request_->create_worker_session_called();
    548 }
    549 
    550 const string& ProtoRunGraphRequest::graph_handle() const {
    551   return request_->graph_handle();
    552 }
    553 
    554 int64 ProtoRunGraphRequest::step_id() const { return request_->step_id(); }
    555 
    556 const ExecutorOpts& ProtoRunGraphRequest::exec_opts() const {
    557   return request_->exec_opts();
    558 }
    559 
    560 size_t ProtoRunGraphRequest::num_sends() const { return request_->send_size(); }
    561 
    562 const string& ProtoRunGraphRequest::send_key(size_t i) const {
    563   return request_->send(i).name();
    564 }
    565 
    566 Status ProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
    567   if (!ParseTensorProtoToTensor(request_->send(i).tensor(), out_tensor)) {
    568     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
    569   } else {
    570     return Status::OK();
    571   }
    572 }
    573 
    574 size_t ProtoRunGraphRequest::num_recvs() const {
    575   return request_->recv_key_size();
    576 }
    577 
    578 const string& ProtoRunGraphRequest::recv_key(size_t i) const {
    579   return request_->recv_key(i);
    580 }
    581 
    582 bool ProtoRunGraphRequest::is_partial() const { return request_->is_partial(); }
    583 
    584 bool ProtoRunGraphRequest::is_last_partial_run() const {
    585   return request_->is_last_partial_run();
    586 }
    587 
    588 bool ProtoRunGraphRequest::store_errors_in_response_body() const {
    589   return request_->store_errors_in_response_body();
    590 }
    591 
    592 const RunGraphRequest& ProtoRunGraphRequest::ToProto() const {
    593   return *request_;
    594 }
    595 
    596 size_t InMemoryRunGraphResponse::num_recvs() const { return recvs_.size(); }
    597 
    598 const string& InMemoryRunGraphResponse::recv_key(size_t i) const {
    599   return recvs_[i].first;
    600 }
    601 
    602 Status InMemoryRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) {
    603   recvs_[i].second.AsProtoTensorContent(out_tensor);
    604   return Status::OK();
    605 }
    606 
    607 Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
    608   *out_tensor = recvs_[i].second;
    609   return Status::OK();
    610 }
    611 
    612 void InMemoryRunGraphResponse::AddRecv(const string& key, const Tensor& value) {
    613   recvs_.emplace_back(key, value);
    614 }
    615 
    616 StepStats* InMemoryRunGraphResponse::mutable_step_stats() {
    617   return &step_stats_;
    618 }
    619 
    620 CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() {
    621   return &cost_graph_;
    622 }
    623 
    624 errors::Code InMemoryRunGraphResponse::status_code() const {
    625   return status_.code();
    626 }
    627 
    628 const string& InMemoryRunGraphResponse::status_error_message() const {
    629   return status_.error_message();
    630 }
    631 
    632 void InMemoryRunGraphResponse::set_status(const Status& status) {
    633   status_ = status;
    634 }
    635 
    636 RunGraphResponse* InMemoryRunGraphResponse::get_proto() {
    637   LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse";
    638   return nullptr;
    639 }
    640 
    641 size_t InMemoryRunGraphResponse::num_partition_graphs() const {
    642   return partition_graphs_.size();
    643 }
    644 
    645 GraphDef* InMemoryRunGraphResponse::mutable_partition_graph(size_t i) {
    646   return &partition_graphs_[i];
    647 }
    648 
    649 void InMemoryRunGraphResponse::AddPartitionGraph(
    650     const GraphDef& partition_graph) {
    651   partition_graphs_.push_back(partition_graph);
    652 }
    653 
    654 size_t OwnedProtoRunGraphResponse::num_recvs() const {
    655   return response_.recv_size();
    656 }
    657 
    658 const string& OwnedProtoRunGraphResponse::recv_key(size_t i) const {
    659   return response_.recv(i).name();
    660 }
    661 
    662 Status OwnedProtoRunGraphResponse::RecvValue(size_t i,
    663                                              TensorProto* out_tensor) {
    664   out_tensor->Swap(response_.mutable_recv(i)->mutable_tensor());
    665   return Status::OK();
    666 }
    667 
    668 Status OwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
    669   if (!ParseTensorProtoToTensor(response_.recv(i).tensor(), out_tensor)) {
    670     return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
    671   } else {
    672     return Status::OK();
    673   }
    674 }
    675 
    676 void OwnedProtoRunGraphResponse::AddRecv(const string& key,
    677                                          const Tensor& value) {
    678   NamedTensorProto* recv = response_.add_recv();
    679   recv->set_name(key);
    680   TensorProto* value_proto = recv->mutable_tensor();
    681   value.AsProtoTensorContent(value_proto);
    682 }
    683 
    684 StepStats* OwnedProtoRunGraphResponse::mutable_step_stats() {
    685   return response_.mutable_step_stats();
    686 }
    687 
    688 CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() {
    689   return response_.mutable_cost_graph();
    690 }
    691 
    692 errors::Code OwnedProtoRunGraphResponse::status_code() const {
    693   return response_.status_code();
    694 }
    695 
    696 const string& OwnedProtoRunGraphResponse::status_error_message() const {
    697   return response_.status_error_message();
    698 }
    699 
    700 void OwnedProtoRunGraphResponse::set_status(const Status& status) {
    701   response_.set_status_code(status.code());
    702   response_.set_status_error_message(status.error_message());
    703 }
    704 
    705 RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; }
    706 
    707 size_t OwnedProtoRunGraphResponse::num_partition_graphs() const {
    708   return response_.partition_graph_size();
    709 }
    710 
    711 GraphDef* OwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
    712   return response_.mutable_partition_graph(i);
    713 }
    714 
    715 void OwnedProtoRunGraphResponse::AddPartitionGraph(
    716     const GraphDef& partition_graph) {
    717   GraphDef* graph_def = response_.mutable_partition_graph()->Add();
    718   *graph_def = partition_graph;
    719 }
    720 
    721 NonOwnedProtoRunGraphResponse::NonOwnedProtoRunGraphResponse(
    722     RunGraphResponse* response)
    723     : response_(response) {}
    724 
    725 size_t NonOwnedProtoRunGraphResponse::num_recvs() const {
    726   return response_->recv_size();
    727 }
    728 
    729 const string& NonOwnedProtoRunGraphResponse::recv_key(size_t i) const {
    730   return response_->recv(i).name();
    731 }
    732 
    733 Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i,
    734                                                 TensorProto* out_tensor) {
    735   out_tensor->Swap(response_->mutable_recv(i)->mutable_tensor());
    736   return Status::OK();
    737 }
    738 
    739 Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
    740   if (!ParseTensorProtoToTensor(response_->recv(i).tensor(), out_tensor)) {
    741     return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
    742   } else {
    743     return Status::OK();
    744   }
    745 }
    746 
    747 void NonOwnedProtoRunGraphResponse::AddRecv(const string& key,
    748                                             const Tensor& value) {
    749   NamedTensorProto* recv = response_->add_recv();
    750   recv->set_name(key);
    751   TensorProto* value_proto = recv->mutable_tensor();
    752   value.AsProtoTensorContent(value_proto);
    753 }
    754 
    755 StepStats* NonOwnedProtoRunGraphResponse::mutable_step_stats() {
    756   return response_->mutable_step_stats();
    757 }
    758 
    759 CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() {
    760   return response_->mutable_cost_graph();
    761 }
    762 
    763 errors::Code NonOwnedProtoRunGraphResponse::status_code() const {
    764   return response_->status_code();
    765 }
    766 
    767 const string& NonOwnedProtoRunGraphResponse::status_error_message() const {
    768   return response_->status_error_message();
    769 }
    770 
    771 void NonOwnedProtoRunGraphResponse::set_status(const Status& status) {
    772   response_->set_status_code(status.code());
    773   response_->set_status_error_message(status.error_message());
    774 }
    775 
    776 RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() {
    777   return response_;
    778 }
    779 
    780 size_t NonOwnedProtoRunGraphResponse::num_partition_graphs() const {
    781   return response_->partition_graph_size();
    782 }
    783 
    784 GraphDef* NonOwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
    785   return response_->mutable_partition_graph(i);
    786 }
    787 
    788 void NonOwnedProtoRunGraphResponse::AddPartitionGraph(
    789     const GraphDef& partition_graph) {
    790   GraphDef* graph_def = response_->add_partition_graph();
    791   *graph_def = partition_graph;
    792 }
    793 
    794 MutableRunStepResponseWrapper::~MutableRunStepResponseWrapper() {}
    795 
    796 size_t InMemoryRunStepResponse::num_tensors() const { return tensors_.size(); }
    797 
    798 const string& InMemoryRunStepResponse::tensor_name(size_t i) const {
    799   return tensors_[i].first;
    800 }
    801 
    802 Status InMemoryRunStepResponse::TensorValue(size_t i,
    803                                             Tensor* out_tensor) const {
    804   *out_tensor = tensors_[i].second;
    805   return Status::OK();
    806 }
    807 
    808 const RunMetadata& InMemoryRunStepResponse::metadata() const {
    809   return metadata_;
    810 }
    811 
    812 Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse(
    813     const string& name, MutableRunGraphResponseWrapper* wrapper, size_t i) {
    814   Tensor tensor;
    815   TF_RETURN_IF_ERROR(wrapper->RecvValue(i, &tensor));
    816   tensors_.emplace_back(name, tensor);
    817   return Status::OK();
    818 }
    819 
    820 RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; }
    821 
    822 errors::Code InMemoryRunStepResponse::status_code() const {
    823   return status_.code();
    824 }
    825 
    826 const string& InMemoryRunStepResponse::status_error_message() const {
    827   return status_.error_message();
    828 }
    829 
    830 void InMemoryRunStepResponse::set_status(const Status& status) {
    831   status_ = status;
    832 }
    833 
// The in-memory wrapper has no backing RunStepResponse proto, so requesting
// one is a programming error; LOG(FATAL) aborts the process and the return
// statement is unreachable.
RunStepResponse* InMemoryRunStepResponse::get_proto() {
  LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse";
  return nullptr;
}
    838 
    839 size_t OwnedProtoRunStepResponse::num_tensors() const {
    840   return response_.tensor_size();
    841 }
    842 
    843 const string& OwnedProtoRunStepResponse::tensor_name(size_t i) const {
    844   return response_.tensor(i).name();
    845 }
    846 
    847 Status OwnedProtoRunStepResponse::TensorValue(size_t i,
    848                                               Tensor* out_tensor) const {
    849   if (!ParseTensorProtoToTensor(response_.tensor(i).tensor(), out_tensor)) {
    850     return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
    851   } else {
    852     return Status::OK();
    853   }
    854 }
    855 
// Read-only view of the metadata stored in the owned proto.
const RunMetadata& OwnedProtoRunStepResponse::metadata() const {
  return response_.metadata();
}
    859 
    860 Status OwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
    861     const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    862     size_t i) {
    863   NamedTensorProto* response_tensor = response_.add_tensor();
    864   response_tensor->set_name(name);
    865   return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
    866 }
    867 
// Mutable access to the metadata field of the owned proto.
RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() {
  return response_.mutable_metadata();
}
    871 
// Status code recorded in the owned proto.
errors::Code OwnedProtoRunStepResponse::status_code() const {
  return response_.status_code();
}
    875 
// Status error message recorded in the owned proto.
const string& OwnedProtoRunStepResponse::status_error_message() const {
  return response_.status_error_message();
}
    879 
    880 void OwnedProtoRunStepResponse::set_status(const Status& status) {
    881   response_.set_status_code(status.code());
    882   response_.set_status_error_message(status.error_message());
    883 }
    884 
    885 RunStepResponse* OwnedProtoRunStepResponse::get_proto() { return &response_; }
    886 
// Wraps a caller-owned RunStepResponse: stores the pointer without taking
// ownership, so `response` must outlive this wrapper.
NonOwnedProtoRunStepResponse::NonOwnedProtoRunStepResponse(
    RunStepResponse* response)
    : response_(response) {}
    890 
    891 size_t NonOwnedProtoRunStepResponse::num_tensors() const {
    892   return response_->tensor_size();
    893 }
    894 
    895 const string& NonOwnedProtoRunStepResponse::tensor_name(size_t i) const {
    896   return response_->tensor(i).name();
    897 }
    898 
    899 Status NonOwnedProtoRunStepResponse::TensorValue(size_t i,
    900                                                  Tensor* out_tensor) const {
    901   if (!ParseTensorProtoToTensor(response_->tensor(i).tensor(), out_tensor)) {
    902     return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
    903   } else {
    904     return Status::OK();
    905   }
    906 }
    907 
// Read-only view of the metadata stored in the borrowed proto.
const RunMetadata& NonOwnedProtoRunStepResponse::metadata() const {
  return response_->metadata();
}
    911 
    912 Status NonOwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
    913     const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    914     size_t i) {
    915   NamedTensorProto* response_tensor = response_->add_tensor();
    916   response_tensor->set_name(name);
    917   return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
    918 }
    919 
// Mutable access to the metadata field of the borrowed proto.
RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() {
  return response_->mutable_metadata();
}
    923 
// Status code recorded in the borrowed proto.
errors::Code NonOwnedProtoRunStepResponse::status_code() const {
  return response_->status_code();
}
    927 
// Status error message recorded in the borrowed proto.
const string& NonOwnedProtoRunStepResponse::status_error_message() const {
  return response_->status_error_message();
}
    931 
    932 void NonOwnedProtoRunStepResponse::set_status(const Status& status) {
    933   response_->set_status_code(status.code());
    934   response_->set_status_error_message(status.error_message());
    935 }
    936 
    937 RunStepResponse* NonOwnedProtoRunStepResponse::get_proto() { return response_; }
    938 
    939 }  // namespace tensorflow
    940