1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_ 18 19 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h" 20 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/framework/tensor_types.h" 23 #include "tensorflow/core/util/rpc/call_container.h" 24 #include "tensorflow/core/util/rpc/rpc_factory.h" 25 26 namespace tensorflow { 27 28 // Forward declaration of GrpcCall. 29 namespace internal { 30 class GrpcCall; 31 } // namespace internal 32 33 class GrpcRPCFactory : public RPCFactory { 34 public: 35 explicit GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast, 36 int64 timeout_in_ms); 37 38 // Explicit destructor to control destruction order. 39 ~GrpcRPCFactory() override; 40 41 void Call(OpKernelContext* ctx, int64 num_elements, const Tensor& address_t, 42 const Tensor& method_t, const Tensor& request_t, const bool try_rpc, 43 Tensor* response_t, Tensor* status_code_t, Tensor* status_message_t, 44 AsyncOpKernel::DoneCallback done) override; 45 46 protected: 47 typedef std::shared_ptr<::grpc::Channel> ChannelPtr; 48 virtual ChannelPtr CreateChannelForAddress(const string& address); 49 50 private: 51 // Creates a call and registers it with given `container`. The `index` is used 52 // to index into the tensor arguments. 53 void CreateCall(const Tensor& request_t, const bool try_rpc, int index, 54 CallContainer<internal::GrpcCall>* container, 55 Tensor* response_t, Tensor* status_code_t, 56 Tensor* status_message_t); 57 58 // Asynchronously invokes the given `call`. The call completion is handled 59 // by the call container the call was previously registered with. 60 void StartCall(const Tensor& address_t, const Tensor& method_t, 61 internal::GrpcCall* call); 62 63 ::grpc::GenericStub* GetOrCreateStubForAddress(const string& address); 64 65 bool fail_fast_; 66 int64 timeout_in_ms_; 67 ::grpc::CompletionQueue completion_queue_; 68 Thread* polling_thread_; // Owned. 69 70 mutex mu_; 71 typedef std::unique_ptr<::grpc::GenericStub> StubPtr; 72 std::unordered_map<string, StubPtr> stubs_ GUARDED_BY(mu_); 73 }; 74 75 } // namespace tensorflow 76 77 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_ 78