Home | History | Annotate | Download | only in distributed_runtime
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
     17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
     18 
     19 #include "tensorflow/core/framework/allocator.h"
     20 #include "tensorflow/core/framework/cost_graph.pb.h"
     21 #include "tensorflow/core/framework/graph.pb.h"
     22 #include "tensorflow/core/framework/step_stats.pb.h"
     23 #include "tensorflow/core/framework/tensor.h"
     24 #include "tensorflow/core/framework/tensor.pb_text.h"
     25 #include "tensorflow/core/framework/versions.pb.h"
     26 #include "tensorflow/core/protobuf/config.pb.h"
     27 #include "tensorflow/core/protobuf/master.pb.h"
     28 #include "tensorflow/core/protobuf/worker.pb.h"
     29 
     30 namespace tensorflow {
     31 
     32 ////////////////////////////////////////////////////////////////////////////////
     33 //
     34 // Wrapper classes for the `MasterService.RunStep` request message.
     35 //
     36 // The `RunStepRequest` message can contain potentially large tensor
     37 // data as part of its `feed` submessages. Here we provide specialized
     38 // wrappers that avoid copying the tensor data wherever possible.
     39 //
     40 // See `RunStepRequest` in tensorflow/core/protobuf/master.proto for the
     41 // protocol buffer definition.
     42 //
     43 ////////////////////////////////////////////////////////////////////////////////
     44 
     45 // Abstract interface for an immutable RunStepRequest message.
     46 //
     47 // This interface is typically used by server-side components in the
     48 // TensorFlow master.
        //
        // Every accessor below is a const pure virtual; concrete subclasses
        // decide how the request fields are actually stored.
     49 class RunStepRequestWrapper {
     50  public:
     51   virtual ~RunStepRequestWrapper() {}
     52 
     53   // REQUIRED: session_handle must be returned by a CreateSession call
     54   // to the same master service.
     55   virtual const string& session_handle() const = 0;
     56 
     57   // Partial run handle (optional). If specified, this will be a partial run
     58   // execution, run up to the specified fetches.
     59   virtual const string& partial_run_handle() const = 0;
     60 
     61   // Tensors to be fed in the step. Each feed is a named tensor.
     62   virtual size_t num_feeds() const = 0;
     63   virtual const string& feed_name(size_t i) const = 0;
     64 
     65   // Stores the content of the feed value at index `i` in `out_tensor`.
          // Two overloads are provided: one writes a `Tensor`, the other a
          // `TensorProto`. Both report their outcome through the returned `Status`.
     66   virtual Status FeedValue(size_t i, Tensor* out_tensor) const = 0;
     67   virtual Status FeedValue(size_t i, TensorProto* out_tensor) const = 0;
     68 
     69   // Fetches. A list of tensor names. The caller expects a tensor to
     70   // be returned for each fetch[i] (see RunStepResponse.tensor). The
     71   // order of specified fetches does not change the execution order.
     72   virtual size_t num_fetches() const = 0;
     73   virtual const string& fetch_name(size_t i) const = 0;
     74 
     75   // Target Nodes. A list of node names. The named nodes will be run
     76   // to but their outputs will not be fetched.
     77   virtual size_t num_targets() const = 0;
     78   virtual const string& target_name(size_t i) const = 0;
     79 
     80   // Options for the run call.
     81   virtual const RunOptions& options() const = 0;
     82 
     83   // If true then some errors, e.g., execution errors that have long
     84   // error messages, may return an OK RunStepResponse with the actual
     85   // error saved in the status_code/status_error_message fields of the
     86   // response body. This is a workaround since the RPC subsystem may
     87   // truncate long metadata messages.
     88   virtual bool store_errors_in_response_body() const = 0;
     89 
          // NOTE(review): undocumented in this header; presumably mirrors the
          // `request_id` field of `RunStepRequest` -- confirm against master.proto.
     90   virtual int64 request_id() const = 0;
     91 
     92   // Returns a human-readable representation of this message for debugging.
     93   virtual string DebugString() const = 0;
     94 
     95   // Returns the wrapped data as a protocol buffer message.
     96   virtual const RunStepRequest& ToProto() const = 0;
     97 };
     98 
     99 // Abstract interface for a mutable RunStepRequest message.
    100 //
    101 // See `RunStepRequestWrapper` above for a description of the fields.
        //
        // This interface inherits every read accessor from `RunStepRequestWrapper`
        // and adds pure-virtual setters/adders for the same fields.
    102 class MutableRunStepRequestWrapper : public RunStepRequestWrapper {
    103  public:
    104   virtual void set_session_handle(const string& handle) = 0;
    105   virtual void set_partial_run_handle(const string& handle) = 0;
          // Takes the feed as a `Tensor`; how the value is stored (in memory or
          // as a proto) is left to the implementing subclass.
    106   virtual void add_feed(const string& name, const Tensor& value) = 0;
    107   virtual void add_fetch(const string& name) = 0;
    108   virtual void add_target(const string& name) = 0;
          // Returns a mutable pointer to the run options so callers can edit them
          // in place (there is no `set_options`).
    109   virtual RunOptions* mutable_options() = 0;
    110   virtual void set_store_errors_in_response_body(bool store_errors) = 0;
    111 };
    112 
    113 // Specialized (and mutable) wrapper for RunStep requests between a client and
    114 // master in the same address space.
        //
        // Request fields are held as ordinary members (strings, `Tensor`s, flags)
        // rather than inside a `RunStepRequest` proto; see `proto_version_` for how
        // `ToProto()` is supported.
    115 class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
    116  public:
    117   // RunStepRequestWrapper methods.
    118   const string& session_handle() const override;
    119   const string& partial_run_handle() const override;
    120   size_t num_feeds() const override;
    121   const string& feed_name(size_t i) const override;
    122   Status FeedValue(size_t i, Tensor* out_tensor) const override;
    123   Status FeedValue(size_t i, TensorProto* out_tensor) const override;
    124   size_t num_fetches() const override;
    125   const string& fetch_name(size_t i) const override;
    126   size_t num_targets() const override;
    127   const string& target_name(size_t i) const override;
    128   const RunOptions& options() const override;
    129   string DebugString() const override;
    130   const RunStepRequest& ToProto() const override;
    131   bool store_errors_in_response_body() const override;
          // NOTE(review): no data member below backs this accessor; its value is
          // determined entirely by the out-of-view definition.
    132   int64 request_id() const override;
    133 
    134   // MutableRunStepRequestWrapper methods.
    135   void set_session_handle(const string& handle) override;
    136   void set_partial_run_handle(const string& handle) override;
    137   void add_feed(const string& name, const Tensor& value) override;
    138   void add_fetch(const string& name) override;
    139   void add_target(const string& name) override;
    140   RunOptions* mutable_options() override;
    141   void set_store_errors_in_response_body(bool store_errors) override;
    142 
    143  private:
    144   string session_handle_;
    145   string partial_run_handle_;
          // Feeds as (string, Tensor) pairs -- presumably the (name, value) given
          // to add_feed(). Inline capacity is 4 elements for each vector below.
    146   gtl::InlinedVector<std::pair<string, Tensor>, 4> feeds_;
    147   gtl::InlinedVector<string, 4> fetches_;
    148   gtl::InlinedVector<string, 4> targets_;
    149   RunOptions options_;
    150   bool store_errors_in_response_body_ = false;
    151 
    152   // Holds a cached and owned representation of the proto
    153   // representation of this request, if needed, so that `ToProto()`
    154   // can return a const RunStepRequest&.
    155   // NOTE(mrry): Although calls to `ToProto()` on this class are
    156   // expected to be rare, retaining ownership of the returned message
    157   // makes it easier to return a reference from the proto-backed
    158   // representations.
          // Declared `mutable` so the const `ToProto()` can populate it.
    159   mutable std::unique_ptr<RunStepRequest> proto_version_;
    160 };
    161 
    162 // Wrapper for mutable RunStep requests that uses a protobuf message.
    163 //
    164 // This wrapper class should be used for RunStep requests between a
    165 // client and master in different address spaces.
        //
        // The underlying `RunStepRequest` is held by value in `request_`.
    166 class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
    167  public:
    168   // RunStepRequestWrapper methods.
    169   const string& session_handle() const override;
    170   const string& partial_run_handle() const override;
    171   size_t num_feeds() const override;
    172   const string& feed_name(size_t i) const override;
    173   Status FeedValue(size_t i, Tensor* out_tensor) const override;
    174   Status FeedValue(size_t i, TensorProto* out_tensor) const override;
    175   size_t num_fetches() const override;
    176   const string& fetch_name(size_t i) const override;
    177   size_t num_targets() const override;
    178   const string& target_name(size_t i) const override;
    179   const RunOptions& options() const override;
    180   string DebugString() const override;
    181   const RunStepRequest& ToProto() const override;
    182   bool store_errors_in_response_body() const override;
    183   int64 request_id() const override;
    184 
    185   // MutableRunStepRequestWrapper methods.
    186   void set_session_handle(const string& handle) override;
    187   void set_partial_run_handle(const string& handle) override;
    188   void add_feed(const string& name, const Tensor& value) override;
    189   void add_fetch(const string& name) override;
    190   void add_target(const string& name) override;
    191   RunOptions* mutable_options() override;
    192   void set_store_errors_in_response_body(bool store_errors) override;
    193 
    194  private:
          // The wrapped proto message, owned by (stored inside) this object.
    195   RunStepRequest request_;
          // Grants `MasterInterface` access to the private `request_` above.
    196   friend class MasterInterface;
    197 };
    198 
    199 // Wrapper for immutable RunStep requests that use a non-owned
    200 // protobuf message.
    201 //
    202 // This interface is typically used by server-side components in the
    203 // TensorFlow master, where the incoming message is a (possibly const)
    204 // `RunStepRequest*`.
    205 class ProtoRunStepRequest : public RunStepRequestWrapper {
    206  public:
          // Wraps `request` without taking ownership (see `request_`); the message
          // presumably has to outlive this wrapper -- confirm with callers.
          // NOTE(review): this single-argument constructor is not `explicit`, so a
          // `const RunStepRequest*` converts implicitly to a ProtoRunStepRequest.
    207   ProtoRunStepRequest(const RunStepRequest* request);
    208 
    209   // RunStepRequestWrapper methods.
    210   const string& session_handle() const override;
    211   const string& partial_run_handle() const override;
    212   size_t num_feeds() const override;
    213   const string& feed_name(size_t i) const override;
    214   Status FeedValue(size_t i, Tensor* out_tensor) const override;
    215   Status FeedValue(size_t i, TensorProto* out_tensor) const override;
    216   size_t num_fetches() const override;
    217   const string& fetch_name(size_t i) const override;
    218   size_t num_targets() const override;
    219   const string& target_name(size_t i) const override;
    220   const RunOptions& options() const override;
    221   string DebugString() const override;
    222   const RunStepRequest& ToProto() const override;
    223   bool store_errors_in_response_body() const override;
    224   int64 request_id() const override;
    225 
    226  private:
          // Const pointer to a const message: neither the pointer nor the pointee
          // can be modified through this member.
    227   const RunStepRequest* const request_;  // Not owned.
    228 };
    229 
    230 ////////////////////////////////////////////////////////////////////////////////
    231 //
    232 // Wrapper classes for the `WorkerService.RunGraph` request message.
    233 //
    234 // The `RunGraphRequest` message can contain potentially large tensor
    235 // data as part of its `send` submessages. Here we provide specialized
    236 // wrappers that avoid copying the tensor data wherever possible.
    237 //
    238 // See `RunGraphRequest` in tensorflow/core/protobuf/worker.proto for the
    239 // protocol buffer definition.
    240 //
    241 ////////////////////////////////////////////////////////////////////////////////
    242 
    243 // Abstract interface for an immutable RunGraphRequest message.
    244 //
    245 // This interface is typically used by server-side components in the
    246 // TensorFlow worker.
        //
        // Every accessor below is a const pure virtual; concrete subclasses
        // decide how the request fields are actually stored.
    247 class RunGraphRequestWrapper {
    248  public:
    249   virtual ~RunGraphRequestWrapper() {}
    250 
    251   // The session handle used to register the graph. If empty, a single global
    252   // namespace is used.
    253   virtual const string& session_handle() const = 0;
    254 
    255   // Set to true if `CreateWorkerSession` was called for `session_handle`.
    256   virtual bool create_worker_session_called() const = 0;
    257 
    258   // REQUIRED: graph_handle must be returned by a RegisterGraph call
    259   // to the same WorkerService.
    260   virtual const string& graph_handle() const = 0;
    261 
    262   // A unique ID to distinguish different runs of the same graph.
    263   //
    264   // The master generates a global unique `step_id` to distinguish
    265   // different runs of the graph computation. Subgraphs communicate
    266   // (e.g., send/recv ops) with each other using `step_id` to
    267   // distinguish tensors generated by different runs.
    268   virtual int64 step_id() const = 0;
    269 
    270   // Options for this step.
    271   virtual const ExecutorOpts& exec_opts() const = 0;
    272 
    273   // Sends the tensors in "send" into the graph before the run.
          // `SendValue` writes the i-th send tensor into `out_tensor` and reports
          // its outcome through the returned `Status`.
    274   virtual size_t num_sends() const = 0;
    275   virtual const string& send_key(size_t i) const = 0;
    276   virtual Status SendValue(size_t i, Tensor* out_tensor) const = 0;
    277 
    278   // Fetches the keys into `RunGraphResponse.recv` after the run.
    279   virtual size_t num_recvs() const = 0;
    280   virtual const string& recv_key(size_t i) const = 0;
    281 
    282   // True if the RunGraphRequest is a partial run request.
    283   virtual bool is_partial() const = 0;
    284 
    285   // True if this is the last partial run request in a sequence of requests.
    286   virtual bool is_last_partial_run() const = 0;
    287 
    288   // If true then some errors, e.g., execution errors that have long
    289   // error messages, may return an OK RunGraphResponse with the actual
    290   // error saved in the status_code/status_error_message fields of the
    291   // response body. This is a workaround since the RPC subsystem may
    292   // truncate long metadata messages.
    293   virtual bool store_errors_in_response_body() const = 0;
    294 
    295   // Returns the wrapped data as a protocol buffer message.
    296   virtual const RunGraphRequest& ToProto() const = 0;
    297 };
    298 
    299 // Abstract interface for a mutable RunGraphRequest message.
    300 //
    301 // See `RunGraphRequestWrapper` above for a description of the fields.
        //
        // This interface inherits every read accessor from `RunGraphRequestWrapper`
        // and adds pure-virtual setters/adders for the same fields.
    302 class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
    303  public:
    304   virtual void set_session_handle(const string& handle) = 0;
    305   virtual void set_create_worker_session_called(bool called) = 0;
    306   virtual void set_graph_handle(const string& handle) = 0;
    307   virtual void set_step_id(int64 step_id) = 0;
    308   virtual ExecutorOpts* mutable_exec_opts() = 0;
    309 
    310   // Stores the i^{th} feed value in `run_step_request` in this
    311   // request with the given `send_key`.
    312   virtual Status AddSendFromRunStepRequest(
    313       const RunStepRequestWrapper& run_step_request, size_t i,
    314       const string& send_key) = 0;
          // Same signature shape as above, but sourced from a `RunCallableRequest`.
          // NOTE(review): presumably stores its i^{th} feed under `send_key` in the
          // same way -- confirm against the implementations.
    315   virtual Status AddSendFromRunCallableRequest(
    316       const RunCallableRequest& run_callable_request, size_t i,
    317       const string& send_key) = 0;
    318 
    319   virtual void add_recv_key(const string& recv_key) = 0;
    320   virtual void set_is_partial(bool is_partial) = 0;
    321   virtual void set_is_last_partial_run(bool is_last_partial_run) = 0;
    322   virtual void set_store_errors_in_response_body(bool store_errors) = 0;
    323 };
    324 
        // Mutable RunGraph request wrapper that keeps the request fields as ordinary
        // members (strings, `Tensor`s, flags) rather than inside a `RunGraphRequest`
        // proto; see `proto_version_` for how `ToProto()` is supported.
        //
        // NOTE(review): by analogy with `InMemoryRunStepRequest`, presumably meant
        // for a master and worker in the same address space -- confirm.
    325 class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
    326  public:
    327   // RunGraphRequestWrapper methods.
    328   const string& session_handle() const override;
    329   const string& graph_handle() const override;
    330   bool create_worker_session_called() const override;
    331   int64 step_id() const override;
    332   const ExecutorOpts& exec_opts() const override;
    333   size_t num_sends() const override;
    334   const string& send_key(size_t i) const override;
    335   Status SendValue(size_t i, Tensor* out_tensor) const override;
    336   size_t num_recvs() const override;
    337   const string& recv_key(size_t i) const override;
    338   bool is_partial() const override;
    339   bool is_last_partial_run() const override;
    340   const RunGraphRequest& ToProto() const override;
    341   bool store_errors_in_response_body() const override;
    342 
    343   // MutableRunGraphRequestWrapper methods.
    344   void set_session_handle(const string& handle) override;
    345   void set_create_worker_session_called(bool called) override;
    346   void set_graph_handle(const string& handle) override;
    347   void set_step_id(int64 step_id) override;
    348   ExecutorOpts* mutable_exec_opts() override;
    349   Status AddSendFromRunStepRequest(
    350       const RunStepRequestWrapper& run_step_request, size_t i,
    351       const string& send_key) override;
    352   Status AddSendFromRunCallableRequest(
    353       const RunCallableRequest& run_callable_request, size_t i,
    354       const string& send_key) override;
    355   void add_recv_key(const string& recv_key) override;
    356   void set_is_partial(bool is_partial) override;
    357   void set_is_last_partial_run(bool is_last_partial_run) override;
    358   void set_store_errors_in_response_body(bool store_errors) override;
    359 
    360  private:
    361   string session_handle_;
    362   bool create_worker_session_called_ = false;
    363   string graph_handle_;
          // Zero-initialized like the other scalar members, so that reading
          // `step_id()` before `set_step_id()` is well-defined. 0 is also the value
          // a default-constructed proto3 message reports for an unset int64 field.
    364   int64 step_id_ = 0;
    365   ExecutorOpts exec_opts_;
          // Sends as (string, Tensor) pairs -- presumably (send_key, value).
    366   gtl::InlinedVector<std::pair<string, Tensor>, 4> sends_;
    367   gtl::InlinedVector<string, 4> recvs_;
    368   bool is_partial_ = false;
    369   bool is_last_partial_run_ = false;
    370   bool store_errors_in_response_body_ = false;
    371 
    372   // Holds a cached and owned representation of the proto
    373   // representation of this request, if needed, so that `ToProto()`
    374   // can return a const RunGraphRequest&.
    375   // NOTE(mrry): Although calls to `ToProto()` on this class are
    376   // expected to be rare, retaining ownership of the returned message
    377   // makes it easier to return a reference from the proto-backed
    378   // representations.
          // Declared `mutable` so the const `ToProto()` can populate it.
    379   mutable std::unique_ptr<RunGraphRequest> proto_version_;
    380 };
    381 
        // Mutable RunGraph request wrapper backed by a `RunGraphRequest` protobuf
        // message that this object holds by value (`request_`).
    382 class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
    383  public:
    384   // RunGraphRequestWrapper methods.
    385   const string& session_handle() const override;
    386   bool create_worker_session_called() const override;
    387   const string& graph_handle() const override;
    388   int64 step_id() const override;
    389   const ExecutorOpts& exec_opts() const override;
    390   size_t num_sends() const override;
    391   const string& send_key(size_t i) const override;
    392   Status SendValue(size_t i, Tensor* out_tensor) const override;
    393   size_t num_recvs() const override;
    394   const string& recv_key(size_t i) const override;
    395   bool is_partial() const override;
    396   bool is_last_partial_run() const override;
    397   bool store_errors_in_response_body() const override;
    398   const RunGraphRequest& ToProto() const override;
    399 
    400   // MutableRunGraphRequestWrapper methods.
    401   void set_session_handle(const string& handle) override;
    402   void set_create_worker_session_called(bool called) override;
    403   void set_graph_handle(const string& handle) override;
    404   void set_step_id(int64 step_id) override;
    405   ExecutorOpts* mutable_exec_opts() override;
    406   Status AddSendFromRunStepRequest(
    407       const RunStepRequestWrapper& run_step_request, size_t i,
    408       const string& send_key) override;
    409   Status AddSendFromRunCallableRequest(
    410       const RunCallableRequest& run_callable_request, size_t i,
    411       const string& send_key) override;
    412   void add_recv_key(const string& recv_key) override;
    413   void set_is_partial(bool is_partial) override;
    414   void set_is_last_partial_run(bool is_last_partial_run) override;
    415   void set_store_errors_in_response_body(bool store_errors) override;
    416 
    417  private:
          // The wrapped proto message, owned by (stored inside) this object.
    418   RunGraphRequest request_;
    419 };
    420 
        // Immutable RunGraph request wrapper around a `RunGraphRequest` protobuf
        // message that this object does not own (see `request_`).
    421 class ProtoRunGraphRequest : public RunGraphRequestWrapper {
    422  public:
          // Wraps `request` without taking ownership; the message presumably has to
          // outlive this wrapper -- confirm with callers.
          // NOTE(review): this single-argument constructor is not `explicit`, so a
          // `const RunGraphRequest*` converts implicitly to a ProtoRunGraphRequest.
    423   ProtoRunGraphRequest(const RunGraphRequest* request);
    424 
    425   // RunGraphRequestWrapper methods.
    426   const string& session_handle() const override;
    427   bool create_worker_session_called() const override;
    428   const string& graph_handle() const override;
    429   int64 step_id() const override;
    430   const ExecutorOpts& exec_opts() const override;
    431   size_t num_sends() const override;
    432   const string& send_key(size_t i) const override;
    433   Status SendValue(size_t i, Tensor* out_tensor) const override;
    434   size_t num_recvs() const override;
    435   const string& recv_key(size_t i) const override;
    436   bool is_partial() const override;
    437   bool is_last_partial_run() const override;
    438   bool store_errors_in_response_body() const override;
    439   const RunGraphRequest& ToProto() const override;
    440 
    441  private:
          // Const pointer to a const message: neither the pointer nor the pointee
          // can be modified through this member.
    442   const RunGraphRequest* const request_;  // Not owned.
    443 };
    444 
    445 ////////////////////////////////////////////////////////////////////////////////
    446 //
    447 // Wrapper classes for the `WorkerService.RunGraph` response message.
    448 //
    449 // The `RunGraphResponse` message can contain potentially large tensor
    450 // data as part of its `recv` submessages. Here we provide specialized
    451 // wrappers that avoid copying the tensor data wherever possible.
    452 //
    453 // See `RunGraphResponse` in tensorflow/core/protobuf/worker.proto for the
    454 // protocol buffer definition.
    455 //
    456 ////////////////////////////////////////////////////////////////////////////////
    457 
    458 // Abstract interface for a mutable RunGraphResponse message.
    459 //
    460 // Note that there is no corresponding (immutable)
    461 // RunGraphResponseWrapper class, because the RunGraphResponse object
    462 // is always used as a mutable pointer.
    463 class MutableRunGraphResponseWrapper {
    464  public:
    465   virtual ~MutableRunGraphResponseWrapper() {}
    466 
    467   // A list of tensors corresponding to those requested by
    468   // `RunGraphRequest.recv_key`.
    469   virtual size_t num_recvs() const = 0;
    470   virtual const string& recv_key(size_t i) const = 0;
    471   // NOTE: The following methods may perform a destructive read, for
    472   // efficiency.
          // Because of that, both `RecvValue` overloads are non-const.
    473   virtual Status RecvValue(size_t i, TensorProto* out_tensor) = 0;
    474   virtual Status RecvValue(size_t i, Tensor* out_tensor) = 0;
          // NOTE(review): presumably appends `value` under `key` to the recv list
          // -- confirm against the implementations.
    475   virtual void AddRecv(const string& key, const Tensor& value) = 0;
    476 
    477   // Submessages that store performance statistics about the subgraph
    478   // execution, if necessary.
    479   virtual StepStats* mutable_step_stats() = 0;
    480   virtual CostGraphDef* mutable_cost_graph() = 0;
    481   virtual size_t num_partition_graphs() const = 0;
    482   virtual GraphDef* mutable_partition_graph(size_t i) = 0;
    483   virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0;
    484 
    485   // Returned status if requested.
          // The status is read as a separate code and message but written as a
          // single `Status` through `set_status`.
    486   virtual errors::Code status_code() const = 0;
    487   virtual const string& status_error_message() const = 0;
    488   virtual void set_status(const Status& status) = 0;
    489 
    490  protected:
    491   // Returns a mutable protobuf message that represents the contents of
    492   // this wrapper, for passing to an RPC subsystem that will populate
    493   // the message.
    494   //
    495   // NOTE: Only `WorkerInterface` subclasses may call this method. The
    496   // `InMemoryRunGraphResponse` subclass does not implement this
    497   // method, and attempts to call it will fail with a fatal
    498   // error. However, as long as callers always call
    499   // `WorkerInterface::RunGraphAsync()` with a wrapper object returned
    500   // from `WorkerInterface::CreateRunGraphResponse()` called on the
    501   // *same* WorkerInterface object, this error will never trigger.
    502   virtual RunGraphResponse* get_proto() = 0;
          // Lets `WorkerInterface` reach the protected `get_proto()` above.
    503   friend class WorkerInterface;
    504 };
    505 
        // RunGraph response wrapper that keeps received tensors, statistics,
        // partition graphs, and status as ordinary members rather than inside a
        // `RunGraphResponse` proto.
    506 class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper {
    507  public:
    508   // MutableRunGraphResponseWrapper methods.
    509   size_t num_recvs() const override;
    510   const string& recv_key(size_t i) const override;
    511   Status RecvValue(size_t i, TensorProto* out_tensor) override;
    512   Status RecvValue(size_t i, Tensor* out_tensor) override;
    513   void AddRecv(const string& key, const Tensor& value) override;
    514   StepStats* mutable_step_stats() override;
    515   CostGraphDef* mutable_cost_graph() override;
    516   size_t num_partition_graphs() const override;
    517   GraphDef* mutable_partition_graph(size_t i) override;
    518   void AddPartitionGraph(const GraphDef& partition_graph) override;
    519   errors::Code status_code() const override;
    520   const string& status_error_message() const override;
    521   void set_status(const Status& status) override;
    522 
    523  protected:
    524   // NOTE: This method is not implemented. See
    525   // MutableRunGraphResponseWrapper for an explanation.
    526   RunGraphResponse* get_proto() override;
    527 
    528  private:
          // Recvs as (string, Tensor) pairs -- presumably (key, value) per AddRecv().
    529   gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_;
    530   StepStats step_stats_;
    531   CostGraphDef cost_graph_;
    532   std::vector<GraphDef> partition_graphs_;
    533   // The status for this response, held as a single `Status` object (code and
    534   // message are not stored as separate fields here).
    535   Status status_;
    536 };
    537 
    538 // Proto-based message wrapper for use on the client side of the RunGraph RPC.
        //
        // The `RunGraphResponse` message is held by value in `response_`.
    539 class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
    540  public:
    541   // MutableRunGraphResponseWrapper methods.
    542   size_t num_recvs() const override;
    543   const string& recv_key(size_t i) const override;
    544   Status RecvValue(size_t i, TensorProto* out_tensor) override;
    545   Status RecvValue(size_t i, Tensor* out_tensor) override;
    546   void AddRecv(const string& key, const Tensor& value) override;
    547   StepStats* mutable_step_stats() override;
    548   CostGraphDef* mutable_cost_graph() override;
    549   size_t num_partition_graphs() const override;
    550   GraphDef* mutable_partition_graph(size_t i) override;
    551   void AddPartitionGraph(const GraphDef& partition_graph) override;
    552   errors::Code status_code() const override;
    553   const string& status_error_message() const override;
    554   void set_status(const Status& status) override;
    555 
    556  protected:
          // Overrides the base-class hook that exposes the underlying proto;
          // presumably returns a pointer to `response_` -- confirm in the .cc.
    557   RunGraphResponse* get_proto() override;
    558 
    559  private:
          // The wrapped proto message, owned by (stored inside) this object.
    560   RunGraphResponse response_;
    561 };
    562 
    563 // Proto-based message wrapper for use on the server side of the RunGraph RPC.
        //
        // Wraps a `RunGraphResponse*` supplied at construction instead of holding
        // its own message.
    564 class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
    565  public:
          // NOTE(review): this single-argument constructor is not `explicit`, so a
          // `RunGraphResponse*` converts implicitly to this wrapper type.
    566   NonOwnedProtoRunGraphResponse(RunGraphResponse* response);
    567 
    568   // MutableRunGraphResponseWrapper methods.
    569   size_t num_recvs() const override;
    570   const string& recv_key(size_t i) const override;
    571   Status RecvValue(size_t i, TensorProto* out_tensor) override;
    572   Status RecvValue(size_t i, Tensor* out_tensor) override;
    573   void AddRecv(const string& key, const Tensor& value) override;
    574   StepStats* mutable_step_stats() override;
    575   CostGraphDef* mutable_cost_graph() override;
    576   size_t num_partition_graphs() const override;
    577   GraphDef* mutable_partition_graph(size_t i) override;
    578   void AddPartitionGraph(const GraphDef& partition_graph) override;
    579   errors::Code status_code() const override;
    580   const string& status_error_message() const override;
    581   void set_status(const Status& status) override;
    582 
    583  protected:
    584   RunGraphResponse* get_proto() override;
    585 
    586  private:
          // Const pointer to a mutable message. Going by the class name the message
          // is not owned here; presumably it must outlive this wrapper -- confirm.
    587   RunGraphResponse* const response_;
    588 };
    589 
    590 ////////////////////////////////////////////////////////////////////////////////
    591 //
    592 // Wrapper classes for the `MasterService.RunStep` response message.
    593 //
    594 // The `RunStepResponse` message can contain potentially large tensor
    595 // data as part of its `tensor` submessages. Here we provide specialized
    596 // wrappers that avoid copying the tensor data wherever possible.
    597 //
    598 // See `RunStepResponse` in tensorflow/core/protobuf/master.proto for the
    599 // protocol buffer definition.
    600 //
    601 ////////////////////////////////////////////////////////////////////////////////
    602 
    603 // Abstract interface for a mutable RunStepResponse message.
    604 //
    605 // Note that there is no corresponding (immutable)
    606 // RunStepResponseWrapper class, because the RunStepResponse object is
    607 // always used as a mutable pointer.
    608 class MutableRunStepResponseWrapper {
    609  public:
          // Declared without an inline body, so the definition has to be provided
          // outside this class declaration.
    610   virtual ~MutableRunStepResponseWrapper();
    611 
    612   // The values of the tensors whose fetching was requested in the
    613   // RunStep call.
    614   //
    615   // NOTE: The order of the returned tensors may or may not match
    616   // the fetch order specified in RunStepRequest.
    617   virtual size_t num_tensors() const = 0;
    618   virtual const string& tensor_name(size_t i) const = 0;
    619   virtual Status TensorValue(size_t i, Tensor* out_tensor) const = 0;
    620 
    621   // Stores the i^{th} recv value in `run_graph_response` in this
    622   // response with the given `name`.
          // `run_graph_response` is taken as a mutable pointer; its `RecvValue`
          // methods are documented as possibly performing a destructive read.
    623   virtual Status AddTensorFromRunGraphResponse(
    624       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    625       size_t i) = 0;
    626 
    627   // Returned metadata if requested in the options.
    628   virtual const RunMetadata& metadata() const = 0;
    629   virtual RunMetadata* mutable_metadata() = 0;
    630 
    631   // Returned status if requested.
          // The status is read as a separate code and message but written as a
          // single `Status` through `set_status`.
    632   virtual errors::Code status_code() const = 0;
    633   virtual const string& status_error_message() const = 0;
    634   virtual void set_status(const Status& status) = 0;
    635 
    636  protected:
    637   // Returns a mutable protobuf message that represents the contents of
    638   // this wrapper, for passing to an RPC subsystem that will populate
    639   // the message.
    640   //
    641   // NOTE: Only `MasterInterface` subclasses may call this method. The
    642   // `InMemoryRunStepResponse` subclass does not implement this
    643   // method, and attempts to call it will fail with a fatal
    644   // error. However, as long as callers always call
    645   // `MasterInterface::RunStep()` with a wrapper object returned
    646   // from `MasterInterface::CreateRunStepResponse()` called on the
    647   // *same* MasterInterface object, this error will never trigger.
    648   virtual RunStepResponse* get_proto() = 0;
          // Lets `MasterInterface` reach the protected `get_proto()` above.
    649   friend class MasterInterface;
    650 };
    651 
        // RunStep response wrapper that keeps fetched tensors, metadata, and status
        // as ordinary members rather than inside a `RunStepResponse` proto.
    652 class InMemoryRunStepResponse : public MutableRunStepResponseWrapper {
    653  public:
    654   // MutableRunStepResponseWrapper methods.
    655   size_t num_tensors() const override;
    656   const string& tensor_name(size_t i) const override;
    657   Status TensorValue(size_t i, Tensor* out_tensor) const override;
    658   Status AddTensorFromRunGraphResponse(
    659       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    660       size_t i) override;
    661   const RunMetadata& metadata() const override;
    662   RunMetadata* mutable_metadata() override;
    663   errors::Code status_code() const override;
    664   const string& status_error_message() const override;
    665   void set_status(const Status& status) override;
    666 
    667  protected:
    668   // NOTE: This method is not implemented. See
    669   // MutableRunStepResponseWrapper for an explanation.
    670   RunStepResponse* get_proto() override;
    671 
    672  private:
          // Tensors as (string, Tensor) pairs -- presumably (name, value).
    673   gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_;
    674   RunMetadata metadata_;
    675   // The status for this response, held as a single `Status` object (code and
    676   // message are not stored as separate fields here).
    677   Status status_;
    678 };
    679 
    680 // Proto-based message wrapper for use on the client side of the RunStep RPC.
        //
        // The `RunStepResponse` message is held by value in `response_`.
    681 class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
    682  public:
    683   // MutableRunStepResponseWrapper methods.
    684   size_t num_tensors() const override;
    685   const string& tensor_name(size_t i) const override;
    686   Status TensorValue(size_t i, Tensor* out_tensor) const override;
    687   Status AddTensorFromRunGraphResponse(
    688       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    689       size_t i) override;
    690   const RunMetadata& metadata() const override;
    691   RunMetadata* mutable_metadata() override;
    692   errors::Code status_code() const override;
    693   const string& status_error_message() const override;
    694   void set_status(const Status& status) override;
    695 
    696  protected:
          // Overrides the base-class hook that exposes the underlying proto;
          // presumably returns a pointer to `response_` -- confirm in the .cc.
    697   RunStepResponse* get_proto() override;
    698 
    699  private:
          // The wrapped proto message, owned by (stored inside) this object.
    700   RunStepResponse response_;
    701 };
    702 
    703 // Proto-based message wrapper for use on the server side of the RunStep RPC.
        //
        // Wraps a `RunStepResponse*` supplied at construction instead of holding
        // its own message.
    704 class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
    705  public:
          // NOTE(review): this single-argument constructor is not `explicit`, so a
          // `RunStepResponse*` converts implicitly to this wrapper type.
    706   NonOwnedProtoRunStepResponse(RunStepResponse* response);
    707 
    708   // MutableRunStepResponseWrapper methods.
    709   size_t num_tensors() const override;
    710   const string& tensor_name(size_t i) const override;
    711   Status TensorValue(size_t i, Tensor* out_tensor) const override;
    712   Status AddTensorFromRunGraphResponse(
    713       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    714       size_t i) override;
    715   const RunMetadata& metadata() const override;
    716   RunMetadata* mutable_metadata() override;
    717   errors::Code status_code() const override;
    718   const string& status_error_message() const override;
    719   void set_status(const Status& status) override;
    720 
    721  protected:
    722   RunStepResponse* get_proto() override;
    723 
    724  private:
          // Plain (non-const) pointer to a mutable message; the message presumably
          // must outlive this wrapper -- confirm with callers.
    725   RunStepResponse* response_;  // Not owned.
    726 };
    727 
    728 }  // namespace tensorflow
    729 
    730 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
    731