Home | History | Annotate | Download | only in distributed_runtime
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
     17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
     18 
     19 #include "tensorflow/core/framework/allocator.h"
     20 #include "tensorflow/core/framework/cost_graph.pb.h"
     21 #include "tensorflow/core/framework/graph.pb.h"
     22 #include "tensorflow/core/framework/step_stats.pb.h"
     23 #include "tensorflow/core/framework/tensor.h"
     24 #include "tensorflow/core/framework/tensor.pb_text.h"
     25 #include "tensorflow/core/framework/versions.pb.h"
     26 #include "tensorflow/core/protobuf/config.pb.h"
     27 #include "tensorflow/core/protobuf/master.pb.h"
     28 #include "tensorflow/core/protobuf/worker.pb.h"
     29 
     30 namespace tensorflow {
     31 
     32 ////////////////////////////////////////////////////////////////////////////////
     33 //
     34 // Wrapper classes for the `MasterService.RunStep` request message.
     35 //
     36 // The `RunStepRequest` message can contain potentially large tensor
     37 // data as part of its `feed` submessages. Here we provide specialized
     38 // wrappers that avoid copying the tensor data wherever possible.
     39 //
     40 // See `RunStepRequest` in tensorflow/core/protobuf/master.proto for the
     41 // protocol buffer definition.
     42 //
     43 ////////////////////////////////////////////////////////////////////////////////
     44 
     45 // Abstract interface for an immutable RunStepRequest message.
     46 //
     47 // This interface is typically used by server-side components in the
     48 // TensorFlow master.
     49 class RunStepRequestWrapper {
     50  public:
     51   virtual ~RunStepRequestWrapper() {}
     52 
     53   // REQUIRED: session_handle must be returned by a CreateSession call
     54   // to the same master service.
     55   virtual const string& session_handle() const = 0;
     56 
     57   // Partial run handle (optional). If specified, this will be a partial run
     58   // execution, run up to the specified fetches.
     59   virtual const string& partial_run_handle() const = 0;
     60 
     61   // Tensors to be fed in the step. Each feed is a named tensor.
     62   virtual size_t num_feeds() const = 0;
     63   virtual const string& feed_name(size_t i) const = 0;
     64 
     65   // Stores the content of the feed value at index `i` in `tensor`.
     66   virtual Status FeedValue(size_t i, Tensor* out_tensor) const = 0;
     67   virtual Status FeedValue(size_t i, TensorProto* out_tensor) const = 0;
     68 
     69   // Fetches. A list of tensor names. The caller expects a tensor to
     70   // be returned for each fetch[i] (see RunStepResponse.tensor). The
     71   // order of specified fetches does not change the execution order.
     72   virtual size_t num_fetches() const = 0;
     73   virtual const string& fetch_name(size_t i) const = 0;
     74 
     75   // Target Nodes. A list of node names. The named nodes will be run
     76   // to but their outputs will not be fetched.
     77   virtual size_t num_targets() const = 0;
     78   virtual const string& target_name(size_t i) const = 0;
     79 
     80   // Options for the run call.
     81   virtual const RunOptions& options() const = 0;
     82 
     83   // If true then some errors, e.g., execution errors that have long
     84   // error messages, may return an OK RunStepResponse with the actual
     85   // error saved in the status_code/status_error_message fields of the
     86   // response body. This is a workaround since the RPC subsystem may
     87   // truncate long metadata messages.
     88   virtual bool store_errors_in_response_body() const = 0;
     89 
     90   // Returns a human-readable representation of this message for debugging.
     91   virtual string DebugString() const = 0;
     92 
     93   // Returns the wrapped data as a protocol buffer message.
     94   virtual const RunStepRequest& ToProto() const = 0;
     95 };
     96 
     97 // Abstract interface for a mutable RunStepRequest message.
     98 //
     99 // See `RunStepRequestWrapper` above for a description of the fields.
    100 class MutableRunStepRequestWrapper : public RunStepRequestWrapper {
    101  public:
    102   virtual void set_session_handle(const string& handle) = 0;
    103   virtual void set_partial_run_handle(const string& handle) = 0;
    104   virtual void add_feed(const string& name, const Tensor& value) = 0;
    105   virtual void add_fetch(const string& name) = 0;
    106   virtual void add_target(const string& name) = 0;
    107   virtual RunOptions* mutable_options() = 0;
    108   virtual void set_store_errors_in_response_body(bool store_errors) = 0;
    109 };
    110 
    111 // Specialized (and mutable) wrapper for RunStep requests between a client and
    112 // master in the same address space.
    113 class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
    114  public:
    115   // RunStepRequestWrapper methods.
    116   const string& session_handle() const override;
    117   const string& partial_run_handle() const override;
    118   size_t num_feeds() const override;
    119   const string& feed_name(size_t i) const override;
    120   Status FeedValue(size_t i, Tensor* out_tensor) const override;
    121   Status FeedValue(size_t i, TensorProto* out_tensor) const override;
    122   size_t num_fetches() const override;
    123   const string& fetch_name(size_t i) const override;
    124   size_t num_targets() const override;
    125   const string& target_name(size_t i) const override;
    126   const RunOptions& options() const override;
    127   string DebugString() const override;
    128   const RunStepRequest& ToProto() const override;
    129   bool store_errors_in_response_body() const override;
    130 
    131   // MutableRunStepRequestWrapper methods.
    132   void set_session_handle(const string& handle) override;
    133   void set_partial_run_handle(const string& handle) override;
    134   void add_feed(const string& name, const Tensor& value) override;
    135   void add_fetch(const string& name) override;
    136   void add_target(const string& name) override;
    137   RunOptions* mutable_options() override;
    138   void set_store_errors_in_response_body(bool store_errors) override;
    139 
    140  private:
    141   string session_handle_;
    142   string partial_run_handle_;
    143   gtl::InlinedVector<std::pair<string, Tensor>, 4> feeds_;
    144   gtl::InlinedVector<string, 4> fetches_;
    145   gtl::InlinedVector<string, 4> targets_;
    146   RunOptions options_;
    147   bool store_errors_in_response_body_ = false;
    148 
    149   // Holds a cached and owned representation of the proto
    150   // representation of this request, if needed, so that `ToProto()`
    151   // can return a const RunStepRequest&.
    152   // NOTE(mrry): Although calls to `ToProto()` on this class are
    153   // expected to be rare, retaining ownership of the returned message
    154   // makes it easier to return a reference from the proto-backed
    155   // representations.
    156   mutable std::unique_ptr<RunStepRequest> proto_version_;
    157 };
    158 
    159 // Wrapper for mutable RunStep requests that uses a protobuf message.
    160 //
    161 // This wrapper class should be used for RunStep requests between a
    162 // client and master in different address spaces.
    163 class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
    164  public:
    165   // RunStepRequestWrapper methods.
    166   const string& session_handle() const override;
    167   const string& partial_run_handle() const override;
    168   size_t num_feeds() const override;
    169   const string& feed_name(size_t i) const override;
    170   Status FeedValue(size_t i, Tensor* out_tensor) const override;
    171   Status FeedValue(size_t i, TensorProto* out_tensor) const override;
    172   size_t num_fetches() const override;
    173   const string& fetch_name(size_t i) const override;
    174   size_t num_targets() const override;
    175   const string& target_name(size_t i) const override;
    176   const RunOptions& options() const override;
    177   string DebugString() const override;
    178   const RunStepRequest& ToProto() const override;
    179   bool store_errors_in_response_body() const override;
    180 
    181   // MutableRunStepRequestWrapper methods.
    182   void set_session_handle(const string& handle) override;
    183   void set_partial_run_handle(const string& handle) override;
    184   void add_feed(const string& name, const Tensor& value) override;
    185   void add_fetch(const string& name) override;
    186   void add_target(const string& name) override;
    187   RunOptions* mutable_options() override;
    188   void set_store_errors_in_response_body(bool store_errors) override;
    189 
    190  private:
    191   RunStepRequest request_;
    192 };
    193 
    194 // Wrapper for immutable RunStep requests that use a non-owned
    195 // protobuf message.
    196 //
    197 // This interface is typically used by server-side components in the
    198 // TensorFlow master, where the incoming message is a (possibly const)
    199 // `RunStepRequest*`.
    200 class ProtoRunStepRequest : public RunStepRequestWrapper {
    201  public:
    202   ProtoRunStepRequest(const RunStepRequest* request);
    203 
    204   // RunStepRequestWrapper methods.
    205   const string& session_handle() const override;
    206   const string& partial_run_handle() const override;
    207   size_t num_feeds() const override;
    208   const string& feed_name(size_t i) const override;
    209   Status FeedValue(size_t i, Tensor* out_tensor) const override;
    210   Status FeedValue(size_t i, TensorProto* out_tensor) const override;
    211   size_t num_fetches() const override;
    212   const string& fetch_name(size_t i) const override;
    213   size_t num_targets() const override;
    214   const string& target_name(size_t i) const override;
    215   const RunOptions& options() const override;
    216   string DebugString() const override;
    217   const RunStepRequest& ToProto() const override;
    218   bool store_errors_in_response_body() const override;
    219 
    220  private:
    221   const RunStepRequest* const request_;  // Not owned.
    222 };
    223 
    224 ////////////////////////////////////////////////////////////////////////////////
    225 //
    226 // Wrapper classes for the `WorkerService.RunGraph` request message.
    227 //
    228 // The `RunGraphRequest` message can contain potentially large tensor
    229 // data as part of its `send` submessages. Here we provide specialized
    230 // wrappers that avoid copying the tensor data wherever possible.
    231 //
    232 // See `RunGraphRequest` in tensorflow/core/protobuf/worker.proto for the
    233 // protocol buffer definition.
    234 //
    235 ////////////////////////////////////////////////////////////////////////////////
    236 
    237 // Abstract interface for an immutable RunStepRequest message.
    238 //
    239 // This interface is typically used by server-side components in the
    240 // TensorFlow worker.
    241 class RunGraphRequestWrapper {
    242  public:
    243   virtual ~RunGraphRequestWrapper() {}
    244 
    245   // The session handle used to register the graph. If empty, a single global
    246   // namespace is used.
    247   virtual const string& session_handle() const = 0;
    248 
    249   // REQUIRED: graph_handle must be returned by a RegisterGraph call
    250   // to the same WorkerService.
    251   virtual const string& graph_handle() const = 0;
    252 
    253   // A unique ID to distinguish different runs of the same graph.
    254   //
    255   // The master generates a global unique `step_id` to distinguish
    256   // different runs of the graph computation. Subgraphs communicate
    257   // (e.g., send/recv ops) with each other using `step_id` to
    258   // distinguish tensors generated by different runs.
    259   virtual int64 step_id() const = 0;
    260 
    261   // Options for this step.
    262   virtual const ExecutorOpts& exec_opts() const = 0;
    263 
    264   // Sends the tensors in "send" into the graph before the run.
    265   virtual size_t num_sends() const = 0;
    266   virtual const string& send_key(size_t i) const = 0;
    267   virtual Status SendValue(size_t i, Tensor* out_tensor) const = 0;
    268 
    269   // Fetches the keys into `RunGraphResponse.recv` after the run.
    270   virtual size_t num_recvs() const = 0;
    271   virtual const string& recv_key(size_t i) const = 0;
    272 
    273   // True if the RunGraphRequest is a partial run request.
    274   virtual bool is_partial() const = 0;
    275 
    276   // True if this is the last partial run request in a sequence of requests.
    277   virtual bool is_last_partial_run() const = 0;
    278 
    279   // If true then some errors, e.g., execution errors that have long
    280   // error messages, may return an OK RunStepResponse with the actual
    281   // error saved in the status_code/status_error_message fields of the
    282   // response body. This is a workaround since the RPC subsystem may
    283   // truncate long metadata messages.
    284   virtual bool store_errors_in_response_body() const = 0;
    285 
    286   // Returns the wrapped data as a protocol buffer message.
    287   virtual const RunGraphRequest& ToProto() const = 0;
    288 };
    289 
    290 // Abstract interface for a mutable RunGraphRequest message.
    291 //
    292 // See `RunGraphRequestWrapper` above for a description of the fields.
    293 class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
    294  public:
    295   virtual void set_session_handle(const string& handle) = 0;
    296   virtual void set_graph_handle(const string& handle) = 0;
    297   virtual void set_step_id(int64 step_id) = 0;
    298   virtual ExecutorOpts* mutable_exec_opts() = 0;
    299 
    300   // Stores the i^{th} feed value in `run_step_request` in this
    301   // request with the given `send_key`.
    302   virtual Status AddSendFromRunStepRequest(
    303       const RunStepRequestWrapper& run_step_request, size_t i,
    304       const string& send_key) = 0;
    305 
    306   virtual void add_recv_key(const string& recv_key) = 0;
    307   virtual void set_is_partial(bool is_partial) = 0;
    308   virtual void set_is_last_partial_run(bool is_last_partial_run) = 0;
    309   virtual void set_store_errors_in_response_body(bool store_errors) = 0;
    310 };
    311 
    312 class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
    313  public:
    314   // RunGraphRequestWrapper methods.
    315   const string& session_handle() const override;
    316   const string& graph_handle() const override;
    317   int64 step_id() const override;
    318   const ExecutorOpts& exec_opts() const override;
    319   size_t num_sends() const override;
    320   const string& send_key(size_t i) const override;
    321   Status SendValue(size_t i, Tensor* out_tensor) const override;
    322   size_t num_recvs() const override;
    323   const string& recv_key(size_t i) const override;
    324   bool is_partial() const override;
    325   bool is_last_partial_run() const override;
    326   const RunGraphRequest& ToProto() const override;
    327   bool store_errors_in_response_body() const override;
    328 
    329   // MutableRunGraphRequestWrapper methods.
    330   void set_session_handle(const string& handle) override;
    331   void set_graph_handle(const string& handle) override;
    332   void set_step_id(int64 step_id) override;
    333   ExecutorOpts* mutable_exec_opts() override;
    334   Status AddSendFromRunStepRequest(
    335       const RunStepRequestWrapper& run_step_request, size_t i,
    336       const string& send_key) override;
    337   void add_recv_key(const string& recv_key) override;
    338   void set_is_partial(bool is_partial) override;
    339   void set_is_last_partial_run(bool is_last_partial_run) override;
    340   void set_store_errors_in_response_body(bool store_errors) override;
    341 
    342  private:
    343   string session_handle_;
    344   string graph_handle_;
    345   int64 step_id_;
    346   ExecutorOpts exec_opts_;
    347   gtl::InlinedVector<std::pair<string, Tensor>, 4> sends_;
    348   gtl::InlinedVector<string, 4> recvs_;
    349   bool is_partial_ = false;
    350   bool is_last_partial_run_ = false;
    351   bool store_errors_in_response_body_ = false;
    352 
    353   // Holds a cached and owned representation of the proto
    354   // representation of this request, if needed, so that `ToProto()`
    355   // can return a const RunGraphRequest&.
    356   // NOTE(mrry): Although calls to `ToProto()` on this class are
    357   // expected to be rare, retaining ownership of the returned message
    358   // makes it easier to return a reference from the proto-backed
    359   // representations.
    360   mutable std::unique_ptr<RunGraphRequest> proto_version_;
    361 };
    362 
    363 class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
    364  public:
    365   // RunGraphRequestWrapper methods.
    366   const string& session_handle() const override;
    367   const string& graph_handle() const override;
    368   int64 step_id() const override;
    369   const ExecutorOpts& exec_opts() const override;
    370   size_t num_sends() const override;
    371   const string& send_key(size_t i) const override;
    372   Status SendValue(size_t i, Tensor* out_tensor) const override;
    373   size_t num_recvs() const override;
    374   const string& recv_key(size_t i) const override;
    375   bool is_partial() const override;
    376   bool is_last_partial_run() const override;
    377   bool store_errors_in_response_body() const override;
    378   const RunGraphRequest& ToProto() const override;
    379 
    380   // MutableRunGraphRequestWrapper methods.
    381   void set_session_handle(const string& handle) override;
    382   void set_graph_handle(const string& handle) override;
    383   void set_step_id(int64 step_id) override;
    384   ExecutorOpts* mutable_exec_opts() override;
    385   Status AddSendFromRunStepRequest(
    386       const RunStepRequestWrapper& run_step_request, size_t i,
    387       const string& send_key) override;
    388   void add_recv_key(const string& recv_key) override;
    389   void set_is_partial(bool is_partial) override;
    390   void set_is_last_partial_run(bool is_last_partial_run) override;
    391   void set_store_errors_in_response_body(bool store_errors) override;
    392 
    393  private:
    394   RunGraphRequest request_;
    395 };
    396 
    397 class ProtoRunGraphRequest : public RunGraphRequestWrapper {
    398  public:
    399   ProtoRunGraphRequest(const RunGraphRequest* request);
    400 
    401   // RunGraphRequestWrapper methods.
    402   const string& session_handle() const override;
    403   const string& graph_handle() const override;
    404   int64 step_id() const override;
    405   const ExecutorOpts& exec_opts() const override;
    406   size_t num_sends() const override;
    407   const string& send_key(size_t i) const override;
    408   Status SendValue(size_t i, Tensor* out_tensor) const override;
    409   size_t num_recvs() const override;
    410   const string& recv_key(size_t i) const override;
    411   bool is_partial() const override;
    412   bool is_last_partial_run() const override;
    413   bool store_errors_in_response_body() const override;
    414   const RunGraphRequest& ToProto() const override;
    415 
    416  private:
    417   const RunGraphRequest* const request_;  // Not owned.
    418 };
    419 
    420 ////////////////////////////////////////////////////////////////////////////////
    421 //
    422 // Wrapper classes for the `WorkerService.RunGraph` response message.
    423 //
    424 // The `RunGraphResponse` message can contain potentially large tensor
    425 // data as part of its `recv` submessages. Here we provide specialized
    426 // wrappers that avoid copying the tensor data wherever possible.
    427 //
    428 // See `RunGraphResponse` in tensorflow/core/protobuf/worker.proto for the
    429 // protocol buffer definition.
    430 //
    431 ////////////////////////////////////////////////////////////////////////////////
    432 
    433 // Abstract interface for a mutable RunGraphResponse message.
    434 //
    435 // Note that there is no corresponding (immutable)
    436 // RunGraphResponseWrapper class, because the RunGraphResponse object
    437 // is always used as a mutable pointer.
    438 class MutableRunGraphResponseWrapper {
    439  public:
    440   virtual ~MutableRunGraphResponseWrapper() {}
    441 
    442   // A list of tensors corresponding to those requested by
    443   // `RunGraphRequest.recv_key`.
    444   virtual size_t num_recvs() const = 0;
    445   virtual const string& recv_key(size_t i) const = 0;
    446   // NOTE: The following methods may perform a destructive read, for
    447   // efficiency.
    448   virtual Status RecvValue(size_t i, TensorProto* out_tensor) = 0;
    449   virtual Status RecvValue(size_t i, Tensor* out_tensor) = 0;
    450   virtual void AddRecv(const string& key, const Tensor& value) = 0;
    451 
    452   // Submessages that store performance statistics about the subgraph
    453   // execution, if necessary.
    454   virtual StepStats* mutable_step_stats() = 0;
    455   virtual CostGraphDef* mutable_cost_graph() = 0;
    456   virtual size_t num_partition_graphs() const = 0;
    457   virtual GraphDef* mutable_partition_graph(size_t i) = 0;
    458   virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0;
    459 
    460   // Returned status if requested.
    461   virtual errors::Code status_code() const = 0;
    462   virtual const string& status_error_message() const = 0;
    463   virtual void set_status(const Status& status) = 0;
    464 
    465  protected:
    466   // Returns a mutable protobuf message that represents the contents of
    467   // this wrapper, for passing to an RPC subsystem that will populate
    468   // the message.
    469   //
    470   // NOTE: Only `WorkerInterface` subclasses may call this method. The
    471   // `InMemoryRunGraphResponse` subclass does not implement this
    472   // method, and attempts to call it will fail with a fatal
    473   // error. However, as long as callers always call
    474   // `WorkerInterface::RunGraphAsync()` with a wrapper object returned
    475   // from `WorkerInterface::CreateRunGraphResponse()` called on the
    476   // *same* WorkerInterface object, this error will never trigger.
    477   virtual RunGraphResponse* get_proto() = 0;
    478   friend class WorkerInterface;
    479 };
    480 
    481 class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper {
    482  public:
    483   // MutableRunGraphResponseWrapper methods.
    484   size_t num_recvs() const override;
    485   const string& recv_key(size_t i) const override;
    486   Status RecvValue(size_t i, TensorProto* out_tensor) override;
    487   Status RecvValue(size_t i, Tensor* out_tensor) override;
    488   void AddRecv(const string& key, const Tensor& value) override;
    489   StepStats* mutable_step_stats() override;
    490   CostGraphDef* mutable_cost_graph() override;
    491   size_t num_partition_graphs() const override;
    492   GraphDef* mutable_partition_graph(size_t i) override;
    493   void AddPartitionGraph(const GraphDef& partition_graph) override;
    494   errors::Code status_code() const override;
    495   const string& status_error_message() const override;
    496   void set_status(const Status& status) override;
    497 
    498  protected:
    499   // NOTE: This method is not implemented. See
    500   // MutableRunGraphResponseWrapper for an explanation.
    501   RunGraphResponse* get_proto() override;
    502 
    503  private:
    504   gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_;
    505   StepStats step_stats_;
    506   CostGraphDef cost_graph_;
    507   std::vector<GraphDef> partition_graphs_;
    508   // Store the code and message separately so that they can be updated
    509   // independently by setters.
    510   Status status_;
    511 };
    512 
    513 // Proto-based message wrapper for use on the client side of the RunGraph RPC.
    514 class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
    515  public:
    516   // MutableRunGraphResponseWrapper methods.
    517   size_t num_recvs() const override;
    518   const string& recv_key(size_t i) const override;
    519   Status RecvValue(size_t i, TensorProto* out_tensor) override;
    520   Status RecvValue(size_t i, Tensor* out_tensor) override;
    521   void AddRecv(const string& key, const Tensor& value) override;
    522   StepStats* mutable_step_stats() override;
    523   CostGraphDef* mutable_cost_graph() override;
    524   size_t num_partition_graphs() const override;
    525   GraphDef* mutable_partition_graph(size_t i) override;
    526   void AddPartitionGraph(const GraphDef& partition_graph) override;
    527   errors::Code status_code() const override;
    528   const string& status_error_message() const override;
    529   void set_status(const Status& status) override;
    530 
    531  protected:
    532   RunGraphResponse* get_proto() override;
    533 
    534  private:
    535   RunGraphResponse response_;
    536 };
    537 
    538 // Proto-based message wrapper for use on the server side of the RunGraph RPC.
    539 class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
    540  public:
    541   NonOwnedProtoRunGraphResponse(RunGraphResponse* response);
    542 
    543   // MutableRunGraphResponseWrapper methods.
    544   size_t num_recvs() const override;
    545   const string& recv_key(size_t i) const override;
    546   Status RecvValue(size_t i, TensorProto* out_tensor) override;
    547   Status RecvValue(size_t i, Tensor* out_tensor) override;
    548   void AddRecv(const string& key, const Tensor& value) override;
    549   StepStats* mutable_step_stats() override;
    550   CostGraphDef* mutable_cost_graph() override;
    551   size_t num_partition_graphs() const override;
    552   GraphDef* mutable_partition_graph(size_t i) override;
    553   void AddPartitionGraph(const GraphDef& partition_graph) override;
    554   errors::Code status_code() const override;
    555   const string& status_error_message() const override;
    556   void set_status(const Status& status) override;
    557 
    558  protected:
    559   RunGraphResponse* get_proto() override;
    560 
    561  private:
    562   RunGraphResponse* const response_;
    563 };
    564 
    565 ////////////////////////////////////////////////////////////////////////////////
    566 //
    567 // Wrapper classes for the `MasterService.RunStep` response message.
    568 //
    569 // The `RunStepResponse` message can contain potentially large tensor
    570 // data as part of its `tensor` submessages. Here we provide specialized
    571 // wrappers that avoid copying the tensor data wherever possible.
    572 //
    573 // See `RunStepResponse` in tensorflow/core/protobuf/master.proto for the
    574 // protocol buffer definition.
    575 //
    576 ////////////////////////////////////////////////////////////////////////////////
    577 
    578 // Abstract interface for a mutable RunStepResponse message.
    579 //
    580 // Note that there is no corresponding (immutable)
    581 // RunStepResponseWrapper class, because the RunStepResponse object is
    582 // always used as a mutable pointer.
    583 class MutableRunStepResponseWrapper {
    584  public:
    585   virtual ~MutableRunStepResponseWrapper();
    586 
    587   // The values of the tensors whose fetching was requested in the
    588   // RunStep call.
    589   //
    590   // NOTE: The order of the returned tensors may or may not match
    591   // the fetch order specified in RunStepRequest.
    592   virtual size_t num_tensors() const = 0;
    593   virtual const string& tensor_name(size_t i) const = 0;
    594   virtual Status TensorValue(size_t i, Tensor* out_tensor) const = 0;
    595 
    596   // Stores the i^{th} recv value in `run_graph_response` in this
    597   // response with the given `name`.
    598   virtual Status AddTensorFromRunGraphResponse(
    599       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    600       size_t i) = 0;
    601 
    602   // Returned metadata if requested in the options.
    603   virtual const RunMetadata& metadata() const = 0;
    604   virtual RunMetadata* mutable_metadata() = 0;
    605 
    606   // Returned status if requested.
    607   virtual errors::Code status_code() const = 0;
    608   virtual const string& status_error_message() const = 0;
    609   virtual void set_status(const Status& status) = 0;
    610 
    611  protected:
    612   // Returns a mutable protobuf message that represents the contents of
    613   // this wrapper, for passing to an RPC subsystem that will populate
    614   // the message.
    615   //
    616   // NOTE: Only `MasterInterface` subclasses may call this method. The
    617   // `InMemoryRunStepResponse` subclass does not implement this
    618   // method, and attempts to call it will fail with a fatal
    619   // error. However, as long as callers always call
    620   // `MasterInterface::RunStep()` with a wrapper object returned
    621   // from `MasterInterface::CreateRunStepResponse()` called on the
    622   // *same* MasterInterface object, this error will never trigger.
    623   virtual RunStepResponse* get_proto() = 0;
    624   friend class MasterInterface;
    625 };
    626 
    627 class InMemoryRunStepResponse : public MutableRunStepResponseWrapper {
    628  public:
    629   // MutableRunStepResponseWrapper methods.
    630   size_t num_tensors() const override;
    631   const string& tensor_name(size_t i) const override;
    632   Status TensorValue(size_t i, Tensor* out_tensor) const override;
    633   Status AddTensorFromRunGraphResponse(
    634       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    635       size_t i) override;
    636   const RunMetadata& metadata() const override;
    637   RunMetadata* mutable_metadata() override;
    638   errors::Code status_code() const override;
    639   const string& status_error_message() const override;
    640   void set_status(const Status& status) override;
    641 
    642  protected:
    643   // NOTE: This method is not implemented. See
    644   // MutableRunGraphResponseWrapper for an explanation.
    645   RunStepResponse* get_proto() override;
    646 
    647  private:
    648   gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_;
    649   RunMetadata metadata_;
    650   // Store the code and message separately so that they can be updated
    651   // independently by setters.
    652   Status status_;
    653 };
    654 
    655 // Proto-based message wrapper for use on the client side of the RunStep RPC.
    656 class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
    657  public:
    658   // MutableRunStepResponseWrapper methods.
    659   size_t num_tensors() const override;
    660   const string& tensor_name(size_t i) const override;
    661   Status TensorValue(size_t i, Tensor* out_tensor) const override;
    662   Status AddTensorFromRunGraphResponse(
    663       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    664       size_t i) override;
    665   const RunMetadata& metadata() const override;
    666   RunMetadata* mutable_metadata() override;
    667   errors::Code status_code() const override;
    668   const string& status_error_message() const override;
    669   void set_status(const Status& status) override;
    670 
    671  protected:
    672   RunStepResponse* get_proto() override;
    673 
    674  private:
    675   RunStepResponse response_;
    676 };
    677 
    678 // Proto-based message wrapper for use on the server side of the RunStep RPC.
    679 class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
    680  public:
    681   NonOwnedProtoRunStepResponse(RunStepResponse* response);
    682 
    683   // MutableRunStepResponseWrapper methods.
    684   size_t num_tensors() const override;
    685   const string& tensor_name(size_t i) const override;
    686   Status TensorValue(size_t i, Tensor* out_tensor) const override;
    687   Status AddTensorFromRunGraphResponse(
    688       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
    689       size_t i) override;
    690   const RunMetadata& metadata() const override;
    691   RunMetadata* mutable_metadata() override;
    692   errors::Code status_code() const override;
    693   const string& status_error_message() const override;
    694   void set_status(const Status& status) override;
    695 
    696  protected:
    697   RunStepResponse* get_proto() override;
    698 
    699  private:
    700   RunStepResponse* response_;  // Not owned.
    701 };
    702 
    703 }  // namespace tensorflow
    704 
#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
    706