1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ 18 19 #include <memory> 20 21 #include "grpc++/grpc++.h" 22 #include "grpc++/security/credentials.h" 23 24 #include "tensorflow/core/common_runtime/process_util.h" 25 #include "tensorflow/core/distributed_runtime/master_env.h" 26 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" 27 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" 28 #include "tensorflow/core/distributed_runtime/server_lib.h" 29 #include "tensorflow/core/distributed_runtime/session_mgr.h" 30 #include "tensorflow/core/distributed_runtime/worker_env.h" 31 #include "tensorflow/core/framework/op.h" 32 #include "tensorflow/core/platform/env.h" 33 34 namespace tensorflow { 35 36 class GrpcWorker; 37 class Master; 38 39 // function that creates a RendezvousMgr. 40 typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)> 41 RendezvousMgrCreationFunction; 42 43 // function that registers a service to the server. The service needs to 44 // be registered before builder.BuildAndStart(). 45 typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)> 46 ServiceInitFunction; 47 48 // function that creates a grpc based worker implementation. 49 typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*)> 50 WorkerCreationFunction; 51 52 class GrpcServer : public ServerInterface { 53 protected: 54 GrpcServer(const ServerDef& server_def, Env* env); 55 56 public: 57 static Status Create(const ServerDef& server_def, Env* env, 58 std::unique_ptr<ServerInterface>* out_server); 59 60 // Destruction is only supported in the factory method. Clean 61 // shutdown is not currently implemented for this server type. 62 virtual ~GrpcServer(); 63 64 // Implementations of ServerInterface methods. 65 Status Start() override; 66 Status Stop() override; 67 Status Join() override; 68 const string target() const override; 69 70 protected: 71 Status Init(ServiceInitFunction service_func, 72 const RendezvousMgrCreationFunction& rendezvous_mgr_func, 73 const WorkerCreationFunction& worker_func); 74 75 Status Init(ServiceInitFunction service_func, 76 const RendezvousMgrCreationFunction& rendezvous_mgr_func); 77 78 Status Init(); 79 80 // A subclass can override this method to support secure credentials. 81 virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( 82 const ServerDef& server_def) const; 83 84 virtual ChannelCreationFunction GetChannelCreationFunction() const; 85 86 virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env); 87 88 // Creates a WorkerCacheInterface for a session. 89 Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options, 90 WorkerCacheInterface** worker_cache); 91 92 // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec. 93 Status ParseChannelSpec(const WorkerCacheFactoryOptions& options, 94 GrpcChannelSpec* channel_spec); 95 96 // Returns the port to which this server is bound. 97 // This method may only be called after `this->Init()` returns successfully. 98 int bound_port() const { return bound_port_; } 99 100 WorkerEnv* worker_env() { return &worker_env_; } 101 102 const ServerDef& server_def() const { return server_def_; } 103 104 private: 105 // The overall server configuration. 106 const ServerDef server_def_; 107 Env* env_; 108 109 // The port to which this server is bound. 110 int bound_port_ = 0; 111 112 // Guards state transitions. 113 mutex mu_; 114 115 // Represents the current state of the server, which changes as follows: 116 // 117 // Join() Join() 118 // ___ ___ 119 // Start() \ / Stop() \ / 120 // NEW ---------> STARTED --------> STOPPED 121 // \ / 122 // \________________________/ 123 // Stop(), Join() 124 enum State { NEW, STARTED, STOPPED }; 125 State state_ GUARDED_BY(mu_); 126 127 // Implementation of a TensorFlow master, and RPC polling thread. 128 MasterEnv master_env_; 129 std::unique_ptr<Master> master_impl_; 130 AsyncServiceInterface* master_service_ = nullptr; 131 std::unique_ptr<Thread> master_thread_ GUARDED_BY(mu_); 132 133 // Implementation of a TensorFlow worker, and RPC polling thread. 134 WorkerEnv worker_env_; 135 std::unique_ptr<GrpcWorker> worker_impl_; 136 AsyncServiceInterface* worker_service_ = nullptr; 137 std::unique_ptr<Thread> worker_thread_ GUARDED_BY(mu_); 138 139 std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_); 140 }; 141 142 } // namespace tensorflow 143 144 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ 145