Home | History | Annotate | Download | only in rpc
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_
     17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_
     18 
     19 #include "grpc++/impl/codegen/proto_utils.h"
     20 #include "grpc++/support/slice.h"
     21 
     22 namespace grpc {
     23 
     24 namespace tensorflow_helper {
     25 
     26 const int kGrpcBufferWriterMaxBufferLength = 8192;
     27 
     28 class GrpcBufferWriter final
     29     : public ::grpc::protobuf::io::ZeroCopyOutputStream {
     30  public:
     31   explicit GrpcBufferWriter(grpc_byte_buffer** bp, int block_size)
     32       : block_size_(block_size), byte_count_(0), have_backup_(false) {
     33     *bp = g_core_codegen_interface->grpc_raw_byte_buffer_create(NULL, 0);
     34     slice_buffer_ = &(*bp)->data.raw.slice_buffer;
     35   }
     36 
     37   ~GrpcBufferWriter() override {
     38     if (have_backup_) {
     39       g_core_codegen_interface->grpc_slice_unref(backup_slice_);
     40     }
     41   }
     42 
     43   bool Next(void** data, int* size) override {
     44     if (have_backup_) {
     45       slice_ = backup_slice_;
     46       have_backup_ = false;
     47     } else {
     48       slice_ = g_core_codegen_interface->grpc_slice_malloc(block_size_);
     49     }
     50     *data = GRPC_SLICE_START_PTR(slice_);
     51     // On win x64, int is only 32bit
     52     GPR_CODEGEN_ASSERT(GRPC_SLICE_LENGTH(slice_) <= INT_MAX);
     53     byte_count_ += * size = (int)GRPC_SLICE_LENGTH(slice_);
     54     g_core_codegen_interface->grpc_slice_buffer_add(slice_buffer_, slice_);
     55     return true;
     56   }
     57 
     58   void BackUp(int count) override {
     59     g_core_codegen_interface->grpc_slice_buffer_pop(slice_buffer_);
     60     if (count == block_size_) {
     61       backup_slice_ = slice_;
     62     } else {
     63       backup_slice_ = g_core_codegen_interface->grpc_slice_split_tail(
     64           &slice_, GRPC_SLICE_LENGTH(slice_) - count);
     65       g_core_codegen_interface->grpc_slice_buffer_add(slice_buffer_, slice_);
     66     }
     67     // It's dangerous to keep an inlined grpc_slice as the backup slice, since
     68     // on a following Next() call, a reference will be returned to this slice
     69     // via GRPC_SLICE_START_PTR, which will not be an address held by
     70     // slice_buffer_.
     71     have_backup_ = backup_slice_.refcount != NULL;
     72     byte_count_ -= count;
     73   }
     74 
     75   grpc::protobuf::int64 ByteCount() const override { return byte_count_; }
     76 
     77  private:
     78   const int block_size_;
     79   int64_t byte_count_;
     80   grpc_slice_buffer* slice_buffer_;
     81   bool have_backup_;
     82   grpc_slice backup_slice_;
     83   grpc_slice slice_;
     84 };
     85 
     86 class GrpcBufferReader final
     87     : public ::grpc::protobuf::io::ZeroCopyInputStream {
     88   typedef void (CoreCodegenInterface::*OldReaderInitAPI)(
     89       grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer);
     90   typedef int (CoreCodegenInterface::*NewReaderInitAPI)(
     91       grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer);
     92   void ReaderInit(OldReaderInitAPI ptr, grpc_byte_buffer_reader* reader,
     93                   grpc_byte_buffer* buffer) {
     94     (g_core_codegen_interface->*ptr)(reader, buffer);
     95   }
     96   void ReaderInit(NewReaderInitAPI ptr, grpc_byte_buffer_reader* reader,
     97                   grpc_byte_buffer* buffer) {
     98     int result = (g_core_codegen_interface->*ptr)(reader, buffer);
     99     (void)result;
    100   }
    101 
    102  public:
    103   explicit GrpcBufferReader(grpc_byte_buffer* buffer)
    104       : byte_count_(0), backup_count_(0) {
    105     ReaderInit(&CoreCodegenInterface::grpc_byte_buffer_reader_init, &reader_,
    106                buffer);
    107   }
    108   ~GrpcBufferReader() override {
    109     g_core_codegen_interface->grpc_byte_buffer_reader_destroy(&reader_);
    110   }
    111 
    112   bool Next(const void** data, int* size) override {
    113     if (backup_count_ > 0) {
    114       *data = GRPC_SLICE_START_PTR(slice_) + GRPC_SLICE_LENGTH(slice_) -
    115               backup_count_;
    116       GPR_CODEGEN_ASSERT(backup_count_ <= INT_MAX);
    117       *size = (int)backup_count_;
    118       backup_count_ = 0;
    119       return true;
    120     }
    121     if (!g_core_codegen_interface->grpc_byte_buffer_reader_next(&reader_,
    122                                                                 &slice_)) {
    123       return false;
    124     }
    125     g_core_codegen_interface->grpc_slice_unref(slice_);
    126     *data = GRPC_SLICE_START_PTR(slice_);
    127     // On win x64, int is only 32bit
    128     GPR_CODEGEN_ASSERT(GRPC_SLICE_LENGTH(slice_) <= INT_MAX);
    129     byte_count_ += * size = (int)GRPC_SLICE_LENGTH(slice_);
    130     return true;
    131   }
    132 
    133   void BackUp(int count) override { backup_count_ = count; }
    134 
    135   bool Skip(int count) override {
    136     const void* data;
    137     int size;
    138     while (Next(&data, &size)) {
    139       if (size >= count) {
    140         BackUp(size - count);
    141         return true;
    142       }
    143       // size < count;
    144       count -= size;
    145     }
    146     // error or we have too large count;
    147     return false;
    148   }
    149 
    150   grpc::protobuf::int64 ByteCount() const override {
    151     return byte_count_ - backup_count_;
    152   }
    153 
    154  private:
    155   int64_t byte_count_;
    156   int64_t backup_count_;
    157   grpc_byte_buffer_reader reader_;
    158   grpc_slice slice_;
    159 };
    160 
    161 }  // namespace tensorflow_helper
    162 
    163 // Defines specialized serialization/deserialization routines that
    164 // default to allowing a 2GB max message size.
    165 //
    166 // To instantiate this template for a particular type `T`, use
    167 // `TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(T)`, as defined below.
    168 template <typename T>
    169 class UnlimitedSizeProtoSerializationTraits {
    170  public:
    171   static Status Serialize(const T& msg, grpc_byte_buffer** bp,
    172                           bool* own_buffer) {
    173     *own_buffer = true;
    174     int byte_size = msg.ByteSize();
    175     if (byte_size < 0) {
    176       return Status(StatusCode::INTERNAL, "Message length was negative");
    177     } else if (byte_size <=
    178                tensorflow_helper::kGrpcBufferWriterMaxBufferLength) {
    179       grpc_slice slice = g_core_codegen_interface->grpc_slice_malloc(byte_size);
    180       GPR_CODEGEN_ASSERT(
    181           GRPC_SLICE_END_PTR(slice) ==
    182           msg.SerializeWithCachedSizesToArray(GRPC_SLICE_START_PTR(slice)));
    183       *bp = g_core_codegen_interface->grpc_raw_byte_buffer_create(&slice, 1);
    184       g_core_codegen_interface->grpc_slice_unref(slice);
    185       return g_core_codegen_interface->ok();
    186     } else {
    187       tensorflow_helper::GrpcBufferWriter writer(
    188           bp, tensorflow_helper::kGrpcBufferWriterMaxBufferLength);
    189       return msg.SerializeToZeroCopyStream(&writer)
    190                  ? g_core_codegen_interface->ok()
    191                  : Status(StatusCode::INTERNAL, "Failed to serialize message");
    192     }
    193   }
    194 
    195   static Status Deserialize(grpc_byte_buffer* buffer, T* msg,
    196                             int max_message_size = INT_MAX) {
    197     if (buffer == nullptr) {
    198       return Status(StatusCode::INTERNAL, "No payload");
    199     }
    200     Status result = g_core_codegen_interface->ok();
    201     {
    202       tensorflow_helper::GrpcBufferReader reader(buffer);
    203       ::grpc::protobuf::io::CodedInputStream decoder(&reader);
    204       if (max_message_size == 0) {
    205         // NOTE(mrry): Override maximum message size to 2GB.
    206         decoder.SetTotalBytesLimit(INT_MAX, INT_MAX);
    207       } else {
    208         decoder.SetTotalBytesLimit(max_message_size, max_message_size);
    209       }
    210       if (!msg->ParseFromCodedStream(&decoder)) {
    211         result = Status(StatusCode::INTERNAL, msg->InitializationErrorString());
    212       }
    213       if (!decoder.ConsumedEntireMessage()) {
    214         result = Status(StatusCode::INTERNAL, "Did not read entire message");
    215       }
    216     }
    217     g_core_codegen_interface->grpc_byte_buffer_destroy(buffer);
    218     return result;
    219   }
    220 };
    221 
    222 }  // namespace grpc
    223 
    224 // For the given protobuf message type `MessageType`, specializes the
    225 // gRPC serialization and deserialization such that the default
    226 // maximum message size is 2GB.
    227 #define TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(MessageType)             \
    228   namespace grpc {                                                    \
    229   template <>                                                         \
    230   class SerializationTraits<MessageType>                              \
    231       : public UnlimitedSizeProtoSerializationTraits<MessageType> {}; \
    232   }  // namespace grpc
    233 
    234 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_
    235