Home | History | Annotate | Download | only in rpc
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_
     17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_
     18 
     19 #include "grpc++/impl/codegen/async_stream.h"
     20 #include "grpc++/impl/codegen/async_unary_call.h"
     21 #include "grpc++/impl/codegen/proto_utils.h"
     22 #include "grpc++/impl/codegen/rpc_method.h"
     23 #include "grpc++/impl/codegen/service_type.h"
     24 #include "grpc++/impl/codegen/status.h"
     25 #include "grpc++/impl/codegen/stub_options.h"
     26 #include "grpc++/impl/codegen/sync_stream.h"
     27 #include "grpc++/support/byte_buffer.h"
     28 
     29 #include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h"
     30 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
     31 #include "tensorflow/core/protobuf/worker.pb.h"
     32 
     33 // Contains potentially large GraphDef.
     34 TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RegisterGraphRequest);
     35 // Contains potentially large TensorProto.
     36 TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunGraphRequest);
     37 // Contains potentially large StepStats, TensorProto.
     38 TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunGraphResponse);
     39 
     40 namespace tensorflow {
     41 class GrpcByteSource : public TensorResponse::Source {
     42  public:
     43   explicit GrpcByteSource(grpc_byte_buffer* buffer) : buffer_(buffer) {}
     44   ~GrpcByteSource() override { DeleteStream(); }
     45 
     46   typedef ::grpc::tensorflow_helper::GrpcBufferReader Reader;
     47 
     48   protobuf::io::ZeroCopyInputStream* contents() override {
     49     DeleteStream();
     50     stream_ = new (&space_) Reader(buffer_);
     51     return stream_;
     52   }
     53 
     54  private:
     55   void DeleteStream() {
     56     if (stream_) {
     57       stream_->~Reader();
     58     }
     59   }
     60 
     61   grpc_byte_buffer* buffer_;  // Not owned
     62   Reader* stream_ = nullptr;  // Points into space_ if non-nullptr
     63   char space_[sizeof(Reader)];
     64 };
     65 }  // namespace tensorflow
     66 
     67 namespace grpc {
     68 class CompletionQueue;
     69 class Channel;
     70 class RpcService;
     71 class ServerCompletionQueue;
     72 class ServerContext;
     73 
     74 // Support parsing/unparsing of tensorflow::TensorResponse.
     75 // Wire-format is identical to RecvTensorResponse.
     76 template <>
     77 class SerializationTraits<tensorflow::TensorResponse>
     78     : public UnlimitedSizeProtoSerializationTraits<tensorflow::TensorResponse> {
     79  public:
     80   static Status Serialize(const tensorflow::TensorResponse& msg,
     81                           grpc_byte_buffer** bp, bool* own_buffer) {
     82     LOG(FATAL) << "TODO(sanjay,jeff): Implement";
     83     return Status();
     84   }
     85   static Status Deserialize(grpc_byte_buffer* buffer,
     86                             tensorflow::TensorResponse* msg,
     87                             int max_message_size = INT_MAX) {
     88     if (buffer == nullptr) {
     89       return Status(StatusCode::INTERNAL, "No payload");
     90     }
     91     Status result = g_core_codegen_interface->ok();
     92     if (result.ok()) {
     93       ::tensorflow::GrpcByteSource source(buffer);
     94       auto s = msg->ParseFrom(&source);
     95       if (!s.ok()) {
     96         result = Status(StatusCode::INTERNAL,
     97                         ::tensorflow::strings::StrCat(
     98                             "TensorResponse parse error", s.ToString()));
     99       }
    100     }
    101     g_core_codegen_interface->grpc_byte_buffer_destroy(buffer);
    102     return result;
    103   }
    104 };
    105 }  // namespace grpc
    106 
    107 namespace tensorflow {
    108 
    109 // Names of worker methods.
    110 enum class GrpcWorkerMethod {
    111   kGetStatus,
    112   kCreateWorkerSession,
    113   kDeleteWorkerSession,
    114   kRegisterGraph,
    115   kDeregisterGraph,
    116   kRunGraph,
    117   kCleanupGraph,
    118   kCleanupAll,
    119   kRecvTensor,
    120   kLogging,
    121   kTracing,
    122 };
    123 static const int kGrpcNumWorkerMethods =
    124     static_cast<int>(GrpcWorkerMethod::kTracing) + 1;
    125 
    126 const char* GrpcWorkerMethodName(GrpcWorkerMethod id);
    127 
    128 namespace grpc {
    129 
    130 // Implementation of `tensorflow.WorkerService`, based on the
    131 // definition in "//tensorflow/core/protobuf/worker_service.proto",
    132 // and the gRPC generated stub and service classes.
    133 // See the proto file for the definition of methods and messages.
    134 class WorkerService final {
    135  public:
    136   class AsyncService : public ::grpc::Service {
    137    public:
    138     AsyncService();
    139     virtual ~AsyncService();
    140 
    141     // Make RequestAsyncUnary public for grpc_call.h
    142     using ::grpc::Service::RequestAsyncUnary;
    143   };
    144 };
    145 
    146 }  // namespace grpc
    147 
    148 }  // namespace tensorflow
    149 
    150 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_
    151