Home | History | Annotate | Download | only in rpc
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
     17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
     18 
     19 #include <memory>
     20 #include <string>
     21 #include <vector>
     22 
     23 #include "tensorflow/core/distributed_runtime/call_options.h"
     24 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
     25 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
     26 #include "tensorflow/core/framework/graph.pb.h"
     27 #include "tensorflow/core/framework/tensor.h"
     28 #include "tensorflow/core/lib/core/errors.h"
     29 #include "tensorflow/core/lib/core/status.h"
     30 #include "tensorflow/core/platform/logging.h"
     31 #include "tensorflow/core/platform/macros.h"
     32 #include "tensorflow/core/platform/mutex.h"
     33 #include "tensorflow/core/platform/thread_annotations.h"
     34 #include "tensorflow/core/protobuf/config.pb.h"
     35 #include "tensorflow/core/protobuf/master.pb.h"
     36 #include "tensorflow/core/public/session.h"
     37 #include "tensorflow/core/public/session_options.h"
     38 
     39 namespace tensorflow {
     40 
// Forward declaration: within this header MasterInterface is only named as
// the pointee of std::unique_ptr<MasterInterface>.
class MasterInterface;
     42 
     43 // A Session instance lets the caller drive a TensorFlow graph
     44 // computation on potentially remote sets of devices. This is a thin
     45 // wrapper around tensorflow::grpc::MasterService.
     46 //
     47 // Multiple threads must synchronize their accesses to a single
     48 // session.
     49 class GrpcSession : public Session {
     50  protected:
     51   explicit GrpcSession(const SessionOptions& options);
     52 
     53  public:
     54   static Status Create(const SessionOptions& options,
     55                        std::unique_ptr<GrpcSession>* out_session);
     56   // Resets the resource containers.
     57   static Status Reset(const SessionOptions& options,
     58                       const std::vector<string>& containers);
     59 
     60   ~GrpcSession() override;
     61 
     62   // Creates a session with the "target". The session carries out
     63   // the graph computation defined by "graph", and will have version
     64   // number "initial_version".
     65   Status Create(const GraphDef& graph) override;
     66   Status Create(const RunOptions& run_options, const GraphDef& graph) override;
     67 
     68   // Runs with and without RunOptions.
     69   Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
     70              const std::vector<string>& output_tensor_names,
     71              const std::vector<string>& target_node_names,
     72              std::vector<Tensor>* outputs) override;
     73   Status Run(const RunOptions& run_options,
     74              const std::vector<std::pair<string, Tensor> >& inputs,
     75              const std::vector<string>& output_tensor_names,
     76              const std::vector<string>& target_node_names,
     77              std::vector<Tensor>* outputs, RunMetadata* run_metadata) override;
     78 
     79   Status Extend(const GraphDef& graph) override;
     80   Status Extend(const RunOptions& run_options, const GraphDef& graph) override;
     81 
     82   Status Close() override;
     83 
     84   // NOTE: This API is still experimental and may change.
     85   ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
     86                                  const std::vector<string>& output_names,
     87                                  const std::vector<string>& target_nodes,
     88                                  string* handle) override;
     89 
     90   // NOTE: This API is still experimental and may change.
     91   ::tensorflow::Status PRun(
     92       const string& handle,
     93       const std::vector<std::pair<string, Tensor> >& inputs,
     94       const std::vector<string>& output_names,
     95       std::vector<Tensor>* outputs) override;
     96 
     97   Status ListDevices(std::vector<DeviceAttributes>* response) override;
     98 
     99  protected:
    100   // Takes ownership of `*master`.
    101   void SetRemoteMaster(std::unique_ptr<MasterInterface> master);
    102 
    103  private:
    104   SessionOptions options_;
    105   std::unique_ptr<MasterInterface> master_;
    106   mutex mu_;
    107 
    108   // handle_ returned by the master to identify this session.
    109   string handle_ GUARDED_BY(mu_);
    110 
    111   // The current version of the graph.
    112   int64 current_graph_version_ GUARDED_BY(mu_);
    113 
    114   Status RunHelper(const RunOptions& run_options,
    115                    const std::vector<std::pair<string, Tensor> >& inputs,
    116                    const std::vector<string>& output_tensor_names,
    117                    const std::vector<string>& target_node_names,
    118                    std::vector<Tensor>* outputs, RunMetadata* run_metadata,
    119                    const string& prun_handle);
    120 
    121   Status RunProto(CallOptions* call_options, MutableRunStepRequestWrapper* req,
    122                   MutableRunStepResponseWrapper* resp);
    123 
    124   // Implementations for all the public interfaces.
    125   Status CreateImpl(CallOptions* call_options, const GraphDef& graph);
    126   Status ExtendImpl(CallOptions* call_options, const GraphDef& graph);
    127 
    128   TF_DISALLOW_COPY_AND_ASSIGN(GrpcSession);
    129 };
    130 
    131 }  // namespace tensorflow
    132 
    133 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
    134