1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ 18 19 #include "tensorflow/core/framework/allocator.h" 20 #include "tensorflow/core/framework/cost_graph.pb.h" 21 #include "tensorflow/core/framework/graph.pb.h" 22 #include "tensorflow/core/framework/step_stats.pb.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/framework/tensor.pb_text.h" 25 #include "tensorflow/core/framework/versions.pb.h" 26 #include "tensorflow/core/protobuf/config.pb.h" 27 #include "tensorflow/core/protobuf/master.pb.h" 28 #include "tensorflow/core/protobuf/worker.pb.h" 29 30 namespace tensorflow { 31 32 //////////////////////////////////////////////////////////////////////////////// 33 // 34 // Wrapper classes for the `MasterService.RunStep` request message. 35 // 36 // The `RunStepRequest` message can contain potentially large tensor 37 // data as part of its `feed` submessages. Here we provide specialized 38 // wrappers that avoid copying the tensor data wherever possible. 39 // 40 // See `RunStepRequest` in tensorflow/core/protobuf/master.proto for the 41 // protocol buffer definition. 42 // 43 //////////////////////////////////////////////////////////////////////////////// 44 45 // Abstract interface for an immutable RunStepRequest message. 46 // 47 // This interface is typically used by server-side components in the 48 // TensorFlow master. 49 class RunStepRequestWrapper { 50 public: 51 virtual ~RunStepRequestWrapper() {} 52 53 // REQUIRED: session_handle must be returned by a CreateSession call 54 // to the same master service. 55 virtual const string& session_handle() const = 0; 56 57 // Partial run handle (optional). If specified, this will be a partial run 58 // execution, run up to the specified fetches. 59 virtual const string& partial_run_handle() const = 0; 60 61 // Tensors to be fed in the step. Each feed is a named tensor. 62 virtual size_t num_feeds() const = 0; 63 virtual const string& feed_name(size_t i) const = 0; 64 65 // Stores the content of the feed value at index `i` in `tensor`. 66 virtual Status FeedValue(size_t i, Tensor* out_tensor) const = 0; 67 virtual Status FeedValue(size_t i, TensorProto* out_tensor) const = 0; 68 69 // Fetches. A list of tensor names. The caller expects a tensor to 70 // be returned for each fetch[i] (see RunStepResponse.tensor). The 71 // order of specified fetches does not change the execution order. 72 virtual size_t num_fetches() const = 0; 73 virtual const string& fetch_name(size_t i) const = 0; 74 75 // Target Nodes. A list of node names. The named nodes will be run 76 // to but their outputs will not be fetched. 77 virtual size_t num_targets() const = 0; 78 virtual const string& target_name(size_t i) const = 0; 79 80 // Options for the run call. 81 virtual const RunOptions& options() const = 0; 82 83 // If true then some errors, e.g., execution errors that have long 84 // error messages, may return an OK RunStepResponse with the actual 85 // error saved in the status_code/status_error_message fields of the 86 // response body. This is a workaround since the RPC subsystem may 87 // truncate long metadata messages. 88 virtual bool store_errors_in_response_body() const = 0; 89 90 // Returns a human-readable representation of this message for debugging. 91 virtual string DebugString() const = 0; 92 93 // Returns the wrapped data as a protocol buffer message. 94 virtual const RunStepRequest& ToProto() const = 0; 95 }; 96 97 // Abstract interface for a mutable RunStepRequest message. 98 // 99 // See `RunStepRequestWrapper` above for a description of the fields. 100 class MutableRunStepRequestWrapper : public RunStepRequestWrapper { 101 public: 102 virtual void set_session_handle(const string& handle) = 0; 103 virtual void set_partial_run_handle(const string& handle) = 0; 104 virtual void add_feed(const string& name, const Tensor& value) = 0; 105 virtual void add_fetch(const string& name) = 0; 106 virtual void add_target(const string& name) = 0; 107 virtual RunOptions* mutable_options() = 0; 108 virtual void set_store_errors_in_response_body(bool store_errors) = 0; 109 }; 110 111 // Specialized (and mutable) wrapper for RunStep requests between a client and 112 // master in the same address space. 113 class InMemoryRunStepRequest : public MutableRunStepRequestWrapper { 114 public: 115 // RunStepRequestWrapper methods. 116 const string& session_handle() const override; 117 const string& partial_run_handle() const override; 118 size_t num_feeds() const override; 119 const string& feed_name(size_t i) const override; 120 Status FeedValue(size_t i, Tensor* out_tensor) const override; 121 Status FeedValue(size_t i, TensorProto* out_tensor) const override; 122 size_t num_fetches() const override; 123 const string& fetch_name(size_t i) const override; 124 size_t num_targets() const override; 125 const string& target_name(size_t i) const override; 126 const RunOptions& options() const override; 127 string DebugString() const override; 128 const RunStepRequest& ToProto() const override; 129 bool store_errors_in_response_body() const override; 130 131 // MutableRunStepRequestWrapper methods. 132 void set_session_handle(const string& handle) override; 133 void set_partial_run_handle(const string& handle) override; 134 void add_feed(const string& name, const Tensor& value) override; 135 void add_fetch(const string& name) override; 136 void add_target(const string& name) override; 137 RunOptions* mutable_options() override; 138 void set_store_errors_in_response_body(bool store_errors) override; 139 140 private: 141 string session_handle_; 142 string partial_run_handle_; 143 gtl::InlinedVector<std::pair<string, Tensor>, 4> feeds_; 144 gtl::InlinedVector<string, 4> fetches_; 145 gtl::InlinedVector<string, 4> targets_; 146 RunOptions options_; 147 bool store_errors_in_response_body_ = false; 148 149 // Holds a cached and owned representation of the proto 150 // representation of this request, if needed, so that `ToProto()` 151 // can return a const RunStepRequest&. 152 // NOTE(mrry): Although calls to `ToProto()` on this class are 153 // expected to be rare, retaining ownership of the returned message 154 // makes it easier to return a reference from the proto-backed 155 // representations. 156 mutable std::unique_ptr<RunStepRequest> proto_version_; 157 }; 158 159 // Wrapper for mutable RunStep requests that uses a protobuf message. 160 // 161 // This wrapper class should be used for RunStep requests between a 162 // client and master in different address spaces. 163 class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper { 164 public: 165 // RunStepRequestWrapper methods. 166 const string& session_handle() const override; 167 const string& partial_run_handle() const override; 168 size_t num_feeds() const override; 169 const string& feed_name(size_t i) const override; 170 Status FeedValue(size_t i, Tensor* out_tensor) const override; 171 Status FeedValue(size_t i, TensorProto* out_tensor) const override; 172 size_t num_fetches() const override; 173 const string& fetch_name(size_t i) const override; 174 size_t num_targets() const override; 175 const string& target_name(size_t i) const override; 176 const RunOptions& options() const override; 177 string DebugString() const override; 178 const RunStepRequest& ToProto() const override; 179 bool store_errors_in_response_body() const override; 180 181 // MutableRunStepRequestWrapper methods. 182 void set_session_handle(const string& handle) override; 183 void set_partial_run_handle(const string& handle) override; 184 void add_feed(const string& name, const Tensor& value) override; 185 void add_fetch(const string& name) override; 186 void add_target(const string& name) override; 187 RunOptions* mutable_options() override; 188 void set_store_errors_in_response_body(bool store_errors) override; 189 190 private: 191 RunStepRequest request_; 192 }; 193 194 // Wrapper for immutable RunStep requests that use a non-owned 195 // protobuf message. 196 // 197 // This interface is typically used by server-side components in the 198 // TensorFlow master, where the incoming message is a (possibly const) 199 // `RunStepRequest*`. 200 class ProtoRunStepRequest : public RunStepRequestWrapper { 201 public: 202 ProtoRunStepRequest(const RunStepRequest* request); 203 204 // RunStepRequestWrapper methods. 205 const string& session_handle() const override; 206 const string& partial_run_handle() const override; 207 size_t num_feeds() const override; 208 const string& feed_name(size_t i) const override; 209 Status FeedValue(size_t i, Tensor* out_tensor) const override; 210 Status FeedValue(size_t i, TensorProto* out_tensor) const override; 211 size_t num_fetches() const override; 212 const string& fetch_name(size_t i) const override; 213 size_t num_targets() const override; 214 const string& target_name(size_t i) const override; 215 const RunOptions& options() const override; 216 string DebugString() const override; 217 const RunStepRequest& ToProto() const override; 218 bool store_errors_in_response_body() const override; 219 220 private: 221 const RunStepRequest* const request_; // Not owned. 222 }; 223 224 //////////////////////////////////////////////////////////////////////////////// 225 // 226 // Wrapper classes for the `WorkerService.RunGraph` request message. 227 // 228 // The `RunGraphRequest` message can contain potentially large tensor 229 // data as part of its `send` submessages. Here we provide specialized 230 // wrappers that avoid copying the tensor data wherever possible. 231 // 232 // See `RunGraphRequest` in tensorflow/core/protobuf/worker.proto for the 233 // protocol buffer definition. 234 // 235 //////////////////////////////////////////////////////////////////////////////// 236 237 // Abstract interface for an immutable RunStepRequest message. 238 // 239 // This interface is typically used by server-side components in the 240 // TensorFlow worker. 241 class RunGraphRequestWrapper { 242 public: 243 virtual ~RunGraphRequestWrapper() {} 244 245 // The session handle used to register the graph. If empty, a single global 246 // namespace is used. 247 virtual const string& session_handle() const = 0; 248 249 // REQUIRED: graph_handle must be returned by a RegisterGraph call 250 // to the same WorkerService. 251 virtual const string& graph_handle() const = 0; 252 253 // A unique ID to distinguish different runs of the same graph. 254 // 255 // The master generates a global unique `step_id` to distinguish 256 // different runs of the graph computation. Subgraphs communicate 257 // (e.g., send/recv ops) with each other using `step_id` to 258 // distinguish tensors generated by different runs. 259 virtual int64 step_id() const = 0; 260 261 // Options for this step. 262 virtual const ExecutorOpts& exec_opts() const = 0; 263 264 // Sends the tensors in "send" into the graph before the run. 265 virtual size_t num_sends() const = 0; 266 virtual const string& send_key(size_t i) const = 0; 267 virtual Status SendValue(size_t i, Tensor* out_tensor) const = 0; 268 269 // Fetches the keys into `RunGraphResponse.recv` after the run. 270 virtual size_t num_recvs() const = 0; 271 virtual const string& recv_key(size_t i) const = 0; 272 273 // True if the RunGraphRequest is a partial run request. 274 virtual bool is_partial() const = 0; 275 276 // True if this is the last partial run request in a sequence of requests. 277 virtual bool is_last_partial_run() const = 0; 278 279 // If true then some errors, e.g., execution errors that have long 280 // error messages, may return an OK RunStepResponse with the actual 281 // error saved in the status_code/status_error_message fields of the 282 // response body. This is a workaround since the RPC subsystem may 283 // truncate long metadata messages. 284 virtual bool store_errors_in_response_body() const = 0; 285 286 // Returns the wrapped data as a protocol buffer message. 287 virtual const RunGraphRequest& ToProto() const = 0; 288 }; 289 290 // Abstract interface for a mutable RunGraphRequest message. 291 // 292 // See `RunGraphRequestWrapper` above for a description of the fields. 293 class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { 294 public: 295 virtual void set_session_handle(const string& handle) = 0; 296 virtual void set_graph_handle(const string& handle) = 0; 297 virtual void set_step_id(int64 step_id) = 0; 298 virtual ExecutorOpts* mutable_exec_opts() = 0; 299 300 // Stores the i^{th} feed value in `run_step_request` in this 301 // request with the given `send_key`. 302 virtual Status AddSendFromRunStepRequest( 303 const RunStepRequestWrapper& run_step_request, size_t i, 304 const string& send_key) = 0; 305 306 virtual void add_recv_key(const string& recv_key) = 0; 307 virtual void set_is_partial(bool is_partial) = 0; 308 virtual void set_is_last_partial_run(bool is_last_partial_run) = 0; 309 virtual void set_store_errors_in_response_body(bool store_errors) = 0; 310 }; 311 312 class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { 313 public: 314 // RunGraphRequestWrapper methods. 315 const string& session_handle() const override; 316 const string& graph_handle() const override; 317 int64 step_id() const override; 318 const ExecutorOpts& exec_opts() const override; 319 size_t num_sends() const override; 320 const string& send_key(size_t i) const override; 321 Status SendValue(size_t i, Tensor* out_tensor) const override; 322 size_t num_recvs() const override; 323 const string& recv_key(size_t i) const override; 324 bool is_partial() const override; 325 bool is_last_partial_run() const override; 326 const RunGraphRequest& ToProto() const override; 327 bool store_errors_in_response_body() const override; 328 329 // MutableRunGraphRequestWrapper methods. 330 void set_session_handle(const string& handle) override; 331 void set_graph_handle(const string& handle) override; 332 void set_step_id(int64 step_id) override; 333 ExecutorOpts* mutable_exec_opts() override; 334 Status AddSendFromRunStepRequest( 335 const RunStepRequestWrapper& run_step_request, size_t i, 336 const string& send_key) override; 337 void add_recv_key(const string& recv_key) override; 338 void set_is_partial(bool is_partial) override; 339 void set_is_last_partial_run(bool is_last_partial_run) override; 340 void set_store_errors_in_response_body(bool store_errors) override; 341 342 private: 343 string session_handle_; 344 string graph_handle_; 345 int64 step_id_; 346 ExecutorOpts exec_opts_; 347 gtl::InlinedVector<std::pair<string, Tensor>, 4> sends_; 348 gtl::InlinedVector<string, 4> recvs_; 349 bool is_partial_ = false; 350 bool is_last_partial_run_ = false; 351 bool store_errors_in_response_body_ = false; 352 353 // Holds a cached and owned representation of the proto 354 // representation of this request, if needed, so that `ToProto()` 355 // can return a const RunGraphRequest&. 356 // NOTE(mrry): Although calls to `ToProto()` on this class are 357 // expected to be rare, retaining ownership of the returned message 358 // makes it easier to return a reference from the proto-backed 359 // representations. 360 mutable std::unique_ptr<RunGraphRequest> proto_version_; 361 }; 362 363 class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { 364 public: 365 // RunGraphRequestWrapper methods. 366 const string& session_handle() const override; 367 const string& graph_handle() const override; 368 int64 step_id() const override; 369 const ExecutorOpts& exec_opts() const override; 370 size_t num_sends() const override; 371 const string& send_key(size_t i) const override; 372 Status SendValue(size_t i, Tensor* out_tensor) const override; 373 size_t num_recvs() const override; 374 const string& recv_key(size_t i) const override; 375 bool is_partial() const override; 376 bool is_last_partial_run() const override; 377 bool store_errors_in_response_body() const override; 378 const RunGraphRequest& ToProto() const override; 379 380 // MutableRunGraphRequestWrapper methods. 381 void set_session_handle(const string& handle) override; 382 void set_graph_handle(const string& handle) override; 383 void set_step_id(int64 step_id) override; 384 ExecutorOpts* mutable_exec_opts() override; 385 Status AddSendFromRunStepRequest( 386 const RunStepRequestWrapper& run_step_request, size_t i, 387 const string& send_key) override; 388 void add_recv_key(const string& recv_key) override; 389 void set_is_partial(bool is_partial) override; 390 void set_is_last_partial_run(bool is_last_partial_run) override; 391 void set_store_errors_in_response_body(bool store_errors) override; 392 393 private: 394 RunGraphRequest request_; 395 }; 396 397 class ProtoRunGraphRequest : public RunGraphRequestWrapper { 398 public: 399 ProtoRunGraphRequest(const RunGraphRequest* request); 400 401 // RunGraphRequestWrapper methods. 402 const string& session_handle() const override; 403 const string& graph_handle() const override; 404 int64 step_id() const override; 405 const ExecutorOpts& exec_opts() const override; 406 size_t num_sends() const override; 407 const string& send_key(size_t i) const override; 408 Status SendValue(size_t i, Tensor* out_tensor) const override; 409 size_t num_recvs() const override; 410 const string& recv_key(size_t i) const override; 411 bool is_partial() const override; 412 bool is_last_partial_run() const override; 413 bool store_errors_in_response_body() const override; 414 const RunGraphRequest& ToProto() const override; 415 416 private: 417 const RunGraphRequest* const request_; // Not owned. 418 }; 419 420 //////////////////////////////////////////////////////////////////////////////// 421 // 422 // Wrapper classes for the `WorkerService.RunGraph` response message. 423 // 424 // The `RunGraphResponse` message can contain potentially large tensor 425 // data as part of its `recv` submessages. Here we provide specialized 426 // wrappers that avoid copying the tensor data wherever possible. 427 // 428 // See `RunGraphResponse` in tensorflow/core/protobuf/worker.proto for the 429 // protocol buffer definition. 430 // 431 //////////////////////////////////////////////////////////////////////////////// 432 433 // Abstract interface for a mutable RunGraphResponse message. 434 // 435 // Note that there is no corresponding (immutable) 436 // RunGraphResponseWrapper class, because the RunGraphResponse object 437 // is always used as a mutable pointer. 438 class MutableRunGraphResponseWrapper { 439 public: 440 virtual ~MutableRunGraphResponseWrapper() {} 441 442 // A list of tensors corresponding to those requested by 443 // `RunGraphRequest.recv_key`. 444 virtual size_t num_recvs() const = 0; 445 virtual const string& recv_key(size_t i) const = 0; 446 // NOTE: The following methods may perform a destructive read, for 447 // efficiency. 448 virtual Status RecvValue(size_t i, TensorProto* out_tensor) = 0; 449 virtual Status RecvValue(size_t i, Tensor* out_tensor) = 0; 450 virtual void AddRecv(const string& key, const Tensor& value) = 0; 451 452 // Submessages that store performance statistics about the subgraph 453 // execution, if necessary. 454 virtual StepStats* mutable_step_stats() = 0; 455 virtual CostGraphDef* mutable_cost_graph() = 0; 456 virtual size_t num_partition_graphs() const = 0; 457 virtual GraphDef* mutable_partition_graph(size_t i) = 0; 458 virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0; 459 460 // Returned status if requested. 461 virtual errors::Code status_code() const = 0; 462 virtual const string& status_error_message() const = 0; 463 virtual void set_status(const Status& status) = 0; 464 465 protected: 466 // Returns a mutable protobuf message that represents the contents of 467 // this wrapper, for passing to an RPC subsystem that will populate 468 // the message. 469 // 470 // NOTE: Only `WorkerInterface` subclasses may call this method. The 471 // `InMemoryRunGraphResponse` subclass does not implement this 472 // method, and attempts to call it will fail with a fatal 473 // error. However, as long as callers always call 474 // `WorkerInterface::RunGraphAsync()` with a wrapper object returned 475 // from `WorkerInterface::CreateRunGraphResponse()` called on the 476 // *same* WorkerInterface object, this error will never trigger. 477 virtual RunGraphResponse* get_proto() = 0; 478 friend class WorkerInterface; 479 }; 480 481 class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper { 482 public: 483 // MutableRunGraphResponseWrapper methods. 484 size_t num_recvs() const override; 485 const string& recv_key(size_t i) const override; 486 Status RecvValue(size_t i, TensorProto* out_tensor) override; 487 Status RecvValue(size_t i, Tensor* out_tensor) override; 488 void AddRecv(const string& key, const Tensor& value) override; 489 StepStats* mutable_step_stats() override; 490 CostGraphDef* mutable_cost_graph() override; 491 size_t num_partition_graphs() const override; 492 GraphDef* mutable_partition_graph(size_t i) override; 493 void AddPartitionGraph(const GraphDef& partition_graph) override; 494 errors::Code status_code() const override; 495 const string& status_error_message() const override; 496 void set_status(const Status& status) override; 497 498 protected: 499 // NOTE: This method is not implemented. See 500 // MutableRunGraphResponseWrapper for an explanation. 501 RunGraphResponse* get_proto() override; 502 503 private: 504 gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_; 505 StepStats step_stats_; 506 CostGraphDef cost_graph_; 507 std::vector<GraphDef> partition_graphs_; 508 // Store the code and message separately so that they can be updated 509 // independently by setters. 510 Status status_; 511 }; 512 513 // Proto-based message wrapper for use on the client side of the RunGraph RPC. 514 class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { 515 public: 516 // MutableRunGraphResponseWrapper methods. 517 size_t num_recvs() const override; 518 const string& recv_key(size_t i) const override; 519 Status RecvValue(size_t i, TensorProto* out_tensor) override; 520 Status RecvValue(size_t i, Tensor* out_tensor) override; 521 void AddRecv(const string& key, const Tensor& value) override; 522 StepStats* mutable_step_stats() override; 523 CostGraphDef* mutable_cost_graph() override; 524 size_t num_partition_graphs() const override; 525 GraphDef* mutable_partition_graph(size_t i) override; 526 void AddPartitionGraph(const GraphDef& partition_graph) override; 527 errors::Code status_code() const override; 528 const string& status_error_message() const override; 529 void set_status(const Status& status) override; 530 531 protected: 532 RunGraphResponse* get_proto() override; 533 534 private: 535 RunGraphResponse response_; 536 }; 537 538 // Proto-based message wrapper for use on the server side of the RunGraph RPC. 539 class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { 540 public: 541 NonOwnedProtoRunGraphResponse(RunGraphResponse* response); 542 543 // MutableRunGraphResponseWrapper methods. 544 size_t num_recvs() const override; 545 const string& recv_key(size_t i) const override; 546 Status RecvValue(size_t i, TensorProto* out_tensor) override; 547 Status RecvValue(size_t i, Tensor* out_tensor) override; 548 void AddRecv(const string& key, const Tensor& value) override; 549 StepStats* mutable_step_stats() override; 550 CostGraphDef* mutable_cost_graph() override; 551 size_t num_partition_graphs() const override; 552 GraphDef* mutable_partition_graph(size_t i) override; 553 void AddPartitionGraph(const GraphDef& partition_graph) override; 554 errors::Code status_code() const override; 555 const string& status_error_message() const override; 556 void set_status(const Status& status) override; 557 558 protected: 559 RunGraphResponse* get_proto() override; 560 561 private: 562 RunGraphResponse* const response_; 563 }; 564 565 //////////////////////////////////////////////////////////////////////////////// 566 // 567 // Wrapper classes for the `MasterService.RunStep` response message. 568 // 569 // The `RunStepResponse` message can contain potentially large tensor 570 // data as part of its `tensor` submessages. Here we provide specialized 571 // wrappers that avoid copying the tensor data wherever possible. 572 // 573 // See `RunStepResponse` in tensorflow/core/protobuf/master.proto for the 574 // protocol buffer definition. 575 // 576 //////////////////////////////////////////////////////////////////////////////// 577 578 // Abstract interface for a mutable RunStepResponse message. 579 // 580 // Note that there is no corresponding (immutable) 581 // RunStepResponseWrapper class, because the RunStepResponse object is 582 // always used as a mutable pointer. 583 class MutableRunStepResponseWrapper { 584 public: 585 virtual ~MutableRunStepResponseWrapper(); 586 587 // The values of the tensors whose fetching was requested in the 588 // RunStep call. 589 // 590 // NOTE: The order of the returned tensors may or may not match 591 // the fetch order specified in RunStepRequest. 592 virtual size_t num_tensors() const = 0; 593 virtual const string& tensor_name(size_t i) const = 0; 594 virtual Status TensorValue(size_t i, Tensor* out_tensor) const = 0; 595 596 // Stores the i^{th} recv value in `run_graph_response` in this 597 // response with the given `name`. 598 virtual Status AddTensorFromRunGraphResponse( 599 const string& name, MutableRunGraphResponseWrapper* run_graph_response, 600 size_t i) = 0; 601 602 // Returned metadata if requested in the options. 603 virtual const RunMetadata& metadata() const = 0; 604 virtual RunMetadata* mutable_metadata() = 0; 605 606 // Returned status if requested. 607 virtual errors::Code status_code() const = 0; 608 virtual const string& status_error_message() const = 0; 609 virtual void set_status(const Status& status) = 0; 610 611 protected: 612 // Returns a mutable protobuf message that represents the contents of 613 // this wrapper, for passing to an RPC subsystem that will populate 614 // the message. 615 // 616 // NOTE: Only `MasterInterface` subclasses may call this method. The 617 // `InMemoryRunStepResponse` subclass does not implement this 618 // method, and attempts to call it will fail with a fatal 619 // error. However, as long as callers always call 620 // `MasterInterface::RunStep()` with a wrapper object returned 621 // from `MasterInterface::CreateRunStepResponse()` called on the 622 // *same* MasterInterface object, this error will never trigger. 623 virtual RunStepResponse* get_proto() = 0; 624 friend class MasterInterface; 625 }; 626 627 class InMemoryRunStepResponse : public MutableRunStepResponseWrapper { 628 public: 629 // MutableRunStepResponseWrapper methods. 630 size_t num_tensors() const override; 631 const string& tensor_name(size_t i) const override; 632 Status TensorValue(size_t i, Tensor* out_tensor) const override; 633 Status AddTensorFromRunGraphResponse( 634 const string& name, MutableRunGraphResponseWrapper* run_graph_response, 635 size_t i) override; 636 const RunMetadata& metadata() const override; 637 RunMetadata* mutable_metadata() override; 638 errors::Code status_code() const override; 639 const string& status_error_message() const override; 640 void set_status(const Status& status) override; 641 642 protected: 643 // NOTE: This method is not implemented. See 644 // MutableRunGraphResponseWrapper for an explanation. 645 RunStepResponse* get_proto() override; 646 647 private: 648 gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_; 649 RunMetadata metadata_; 650 // Store the code and message separately so that they can be updated 651 // independently by setters. 652 Status status_; 653 }; 654 655 // Proto-based message wrapper for use on the client side of the RunStep RPC. 656 class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { 657 public: 658 // MutableRunStepResponseWrapper methods. 659 size_t num_tensors() const override; 660 const string& tensor_name(size_t i) const override; 661 Status TensorValue(size_t i, Tensor* out_tensor) const override; 662 Status AddTensorFromRunGraphResponse( 663 const string& name, MutableRunGraphResponseWrapper* run_graph_response, 664 size_t i) override; 665 const RunMetadata& metadata() const override; 666 RunMetadata* mutable_metadata() override; 667 errors::Code status_code() const override; 668 const string& status_error_message() const override; 669 void set_status(const Status& status) override; 670 671 protected: 672 RunStepResponse* get_proto() override; 673 674 private: 675 RunStepResponse response_; 676 }; 677 678 // Proto-based message wrapper for use on the server side of the RunStep RPC. 679 class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { 680 public: 681 NonOwnedProtoRunStepResponse(RunStepResponse* response); 682 683 // MutableRunStepResponseWrapper methods. 684 size_t num_tensors() const override; 685 const string& tensor_name(size_t i) const override; 686 Status TensorValue(size_t i, Tensor* out_tensor) const override; 687 Status AddTensorFromRunGraphResponse( 688 const string& name, MutableRunGraphResponseWrapper* run_graph_response, 689 size_t i) override; 690 const RunMetadata& metadata() const override; 691 RunMetadata* mutable_metadata() override; 692 errors::Code status_code() const override; 693 const string& status_error_message() const override; 694 void set_status(const Status& status) override; 695 696 protected: 697 RunStepResponse* get_proto() override; 698 699 private: 700 RunStepResponse* response_; // Not owned. 701 }; 702 703 } // namespace tensorflow 704 705 #endif // TENSORFLOW 706