Home | History | Annotate | Download | only in rpc
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
     17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
     18 
     19 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
     20 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
     21 #include "tensorflow/core/framework/op_kernel.h"
     22 #include "tensorflow/core/framework/tensor_types.h"
     23 #include "tensorflow/core/util/rpc/call_container.h"
     24 #include "tensorflow/core/util/rpc/rpc_factory.h"
     25 
     26 namespace tensorflow {
     27 
     28 // Forward declaration of GrpcCall.
     29 namespace internal {
     30 class GrpcCall;
     31 }  // namespace internal
     32 
     33 class GrpcRPCFactory : public RPCFactory {
     34  public:
     35   explicit GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
     36                           int64 timeout_in_ms);
     37 
     38   // Explicit destructor to control destruction order.
     39   ~GrpcRPCFactory() override;
     40 
     41   void Call(OpKernelContext* ctx, int64 num_elements, const Tensor& address_t,
     42             const Tensor& method_t, const Tensor& request_t, const bool try_rpc,
     43             Tensor* response_t, Tensor* status_code_t, Tensor* status_message_t,
     44             AsyncOpKernel::DoneCallback done) override;
     45 
     46  protected:
     47   typedef std::shared_ptr<::grpc::Channel> ChannelPtr;
     48   virtual ChannelPtr CreateChannelForAddress(const string& address);
     49 
     50  private:
     51   // Creates a call and registers it with given `container`. The `index` is used
     52   // to index into the tensor arguments.
     53   void CreateCall(const Tensor& request_t, const bool try_rpc, int index,
     54                   CallContainer<internal::GrpcCall>* container,
     55                   Tensor* response_t, Tensor* status_code_t,
     56                   Tensor* status_message_t);
     57 
     58   // Asynchronously invokes the given `call`. The call completion is handled
     59   // by the call container the call was previously registered with.
     60   void StartCall(const Tensor& address_t, const Tensor& method_t,
     61                  internal::GrpcCall* call);
     62 
     63   ::grpc::GenericStub* GetOrCreateStubForAddress(const string& address);
     64 
     65   bool fail_fast_;
     66   int64 timeout_in_ms_;
     67   ::grpc::CompletionQueue completion_queue_;
     68   Thread* polling_thread_;  // Owned.
     69 
     70   mutex mu_;
     71   typedef std::unique_ptr<::grpc::GenericStub> StubPtr;
     72   std::unordered_map<string, StubPtr> stubs_ GUARDED_BY(mu_);
     73 };
     74 
     75 }  // namespace tensorflow
     76 
     77 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
     78