1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "tensorflow/core/distributed_runtime/call_options.h" 24 #include "tensorflow/core/distributed_runtime/message_wrappers.h" 25 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" 26 #include "tensorflow/core/framework/graph.pb.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/lib/core/errors.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/thread_annotations.h" 34 #include "tensorflow/core/protobuf/config.pb.h" 35 #include "tensorflow/core/protobuf/master.pb.h" 36 #include "tensorflow/core/public/session.h" 37 #include "tensorflow/core/public/session_options.h" 38 39 namespace tensorflow { 40 41 class MasterInterface; 42 43 // A Session instance lets the caller drive a TensorFlow graph 44 // computation on potentially remote sets of devices. This is a thin 45 // wrapper around tensorflow::grpc::MasterService. 46 // 47 // Multiple threads must synchronize their accesses to a single 48 // session. 49 class GrpcSession : public Session { 50 protected: 51 explicit GrpcSession(const SessionOptions& options); 52 53 public: 54 static Status Create(const SessionOptions& options, 55 std::unique_ptr<GrpcSession>* out_session); 56 // Resets the resource containers. 57 static Status Reset(const SessionOptions& options, 58 const std::vector<string>& containers); 59 60 ~GrpcSession() override; 61 62 // Creates a session with the "target". The session carries out 63 // the graph computation defined by "graph", and will have version 64 // number "initial_version". 65 Status Create(const GraphDef& graph) override; 66 Status Create(const RunOptions& run_options, const GraphDef& graph) override; 67 68 // Runs with and without RunOptions. 69 Status Run(const std::vector<std::pair<string, Tensor> >& inputs, 70 const std::vector<string>& output_tensor_names, 71 const std::vector<string>& target_node_names, 72 std::vector<Tensor>* outputs) override; 73 Status Run(const RunOptions& run_options, 74 const std::vector<std::pair<string, Tensor> >& inputs, 75 const std::vector<string>& output_tensor_names, 76 const std::vector<string>& target_node_names, 77 std::vector<Tensor>* outputs, RunMetadata* run_metadata) override; 78 79 Status Extend(const GraphDef& graph) override; 80 Status Extend(const RunOptions& run_options, const GraphDef& graph) override; 81 82 Status Close() override; 83 84 // NOTE: This API is still experimental and may change. 85 ::tensorflow::Status PRunSetup(const std::vector<string>& input_names, 86 const std::vector<string>& output_names, 87 const std::vector<string>& target_nodes, 88 string* handle) override; 89 90 // NOTE: This API is still experimental and may change. 91 ::tensorflow::Status PRun( 92 const string& handle, 93 const std::vector<std::pair<string, Tensor> >& inputs, 94 const std::vector<string>& output_names, 95 std::vector<Tensor>* outputs) override; 96 97 Status ListDevices(std::vector<DeviceAttributes>* response) override; 98 99 protected: 100 // Takes ownership of `*master`. 101 void SetRemoteMaster(std::unique_ptr<MasterInterface> master); 102 103 private: 104 SessionOptions options_; 105 std::unique_ptr<MasterInterface> master_; 106 mutex mu_; 107 108 // handle_ returned by the master to identify this session. 109 string handle_ GUARDED_BY(mu_); 110 111 // The current version of the graph. 112 int64 current_graph_version_ GUARDED_BY(mu_); 113 114 Status RunHelper(const RunOptions& run_options, 115 const std::vector<std::pair<string, Tensor> >& inputs, 116 const std::vector<string>& output_tensor_names, 117 const std::vector<string>& target_node_names, 118 std::vector<Tensor>* outputs, RunMetadata* run_metadata, 119 const string& prun_handle); 120 121 Status RunProto(CallOptions* call_options, MutableRunStepRequestWrapper* req, 122 MutableRunStepResponseWrapper* resp); 123 124 // Implementations for all the public interfaces. 125 Status CreateImpl(CallOptions* call_options, const GraphDef& graph); 126 Status ExtendImpl(CallOptions* call_options, const GraphDef& graph); 127 128 TF_DISALLOW_COPY_AND_ASSIGN(GrpcSession); 129 }; 130 131 } // namespace tensorflow 132 133 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ 134