1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_ 18 19 #include "grpc++/impl/codegen/proto_utils.h" 20 #include "grpc++/support/slice.h" 21 22 namespace grpc { 23 24 namespace tensorflow_helper { 25 26 const int kGrpcBufferWriterMaxBufferLength = 8192; 27 28 class GrpcBufferWriter final 29 : public ::grpc::protobuf::io::ZeroCopyOutputStream { 30 public: 31 explicit GrpcBufferWriter(grpc_byte_buffer** bp, int block_size) 32 : block_size_(block_size), byte_count_(0), have_backup_(false) { 33 *bp = g_core_codegen_interface->grpc_raw_byte_buffer_create(NULL, 0); 34 slice_buffer_ = &(*bp)->data.raw.slice_buffer; 35 } 36 37 ~GrpcBufferWriter() override { 38 if (have_backup_) { 39 g_core_codegen_interface->grpc_slice_unref(backup_slice_); 40 } 41 } 42 43 bool Next(void** data, int* size) override { 44 if (have_backup_) { 45 slice_ = backup_slice_; 46 have_backup_ = false; 47 } else { 48 slice_ = g_core_codegen_interface->grpc_slice_malloc(block_size_); 49 } 50 *data = GRPC_SLICE_START_PTR(slice_); 51 // On win x64, int is only 32bit 52 GPR_CODEGEN_ASSERT(GRPC_SLICE_LENGTH(slice_) <= INT_MAX); 53 byte_count_ += * size = (int)GRPC_SLICE_LENGTH(slice_); 54 g_core_codegen_interface->grpc_slice_buffer_add(slice_buffer_, slice_); 55 return true; 56 } 57 58 void BackUp(int count) override { 59 g_core_codegen_interface->grpc_slice_buffer_pop(slice_buffer_); 60 if (count == block_size_) { 61 backup_slice_ = slice_; 62 } else { 63 backup_slice_ = g_core_codegen_interface->grpc_slice_split_tail( 64 &slice_, GRPC_SLICE_LENGTH(slice_) - count); 65 g_core_codegen_interface->grpc_slice_buffer_add(slice_buffer_, slice_); 66 } 67 // It's dangerous to keep an inlined grpc_slice as the backup slice, since 68 // on a following Next() call, a reference will be returned to this slice 69 // via GRPC_SLICE_START_PTR, which will not be an address held by 70 // slice_buffer_. 71 have_backup_ = backup_slice_.refcount != NULL; 72 byte_count_ -= count; 73 } 74 75 grpc::protobuf::int64 ByteCount() const override { return byte_count_; } 76 77 private: 78 const int block_size_; 79 int64_t byte_count_; 80 grpc_slice_buffer* slice_buffer_; 81 bool have_backup_; 82 grpc_slice backup_slice_; 83 grpc_slice slice_; 84 }; 85 86 class GrpcBufferReader final 87 : public ::grpc::protobuf::io::ZeroCopyInputStream { 88 typedef void (CoreCodegenInterface::*OldReaderInitAPI)( 89 grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer); 90 typedef int (CoreCodegenInterface::*NewReaderInitAPI)( 91 grpc_byte_buffer_reader* reader, grpc_byte_buffer* buffer); 92 void ReaderInit(OldReaderInitAPI ptr, grpc_byte_buffer_reader* reader, 93 grpc_byte_buffer* buffer) { 94 (g_core_codegen_interface->*ptr)(reader, buffer); 95 } 96 void ReaderInit(NewReaderInitAPI ptr, grpc_byte_buffer_reader* reader, 97 grpc_byte_buffer* buffer) { 98 int result = (g_core_codegen_interface->*ptr)(reader, buffer); 99 (void)result; 100 } 101 102 public: 103 explicit GrpcBufferReader(grpc_byte_buffer* buffer) 104 : byte_count_(0), backup_count_(0) { 105 ReaderInit(&CoreCodegenInterface::grpc_byte_buffer_reader_init, &reader_, 106 buffer); 107 } 108 ~GrpcBufferReader() override { 109 g_core_codegen_interface->grpc_byte_buffer_reader_destroy(&reader_); 110 } 111 112 bool Next(const void** data, int* size) override { 113 if (backup_count_ > 0) { 114 *data = GRPC_SLICE_START_PTR(slice_) + GRPC_SLICE_LENGTH(slice_) - 115 backup_count_; 116 GPR_CODEGEN_ASSERT(backup_count_ <= INT_MAX); 117 *size = (int)backup_count_; 118 backup_count_ = 0; 119 return true; 120 } 121 if (!g_core_codegen_interface->grpc_byte_buffer_reader_next(&reader_, 122 &slice_)) { 123 return false; 124 } 125 g_core_codegen_interface->grpc_slice_unref(slice_); 126 *data = GRPC_SLICE_START_PTR(slice_); 127 // On win x64, int is only 32bit 128 GPR_CODEGEN_ASSERT(GRPC_SLICE_LENGTH(slice_) <= INT_MAX); 129 byte_count_ += * size = (int)GRPC_SLICE_LENGTH(slice_); 130 return true; 131 } 132 133 void BackUp(int count) override { backup_count_ = count; } 134 135 bool Skip(int count) override { 136 const void* data; 137 int size; 138 while (Next(&data, &size)) { 139 if (size >= count) { 140 BackUp(size - count); 141 return true; 142 } 143 // size < count; 144 count -= size; 145 } 146 // error or we have too large count; 147 return false; 148 } 149 150 grpc::protobuf::int64 ByteCount() const override { 151 return byte_count_ - backup_count_; 152 } 153 154 private: 155 int64_t byte_count_; 156 int64_t backup_count_; 157 grpc_byte_buffer_reader reader_; 158 grpc_slice slice_; 159 }; 160 161 } // namespace tensorflow_helper 162 163 // Defines specialized serialization/deserialization routines that 164 // default to allowing a 2GB max message size. 165 // 166 // To instantiate this template for a particular type `T`, use 167 // `TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(T)`, as defined below. 168 template <typename T> 169 class UnlimitedSizeProtoSerializationTraits { 170 public: 171 static Status Serialize(const T& msg, grpc_byte_buffer** bp, 172 bool* own_buffer) { 173 *own_buffer = true; 174 int byte_size = msg.ByteSize(); 175 if (byte_size < 0) { 176 return Status(StatusCode::INTERNAL, "Message length was negative"); 177 } else if (byte_size <= 178 tensorflow_helper::kGrpcBufferWriterMaxBufferLength) { 179 grpc_slice slice = g_core_codegen_interface->grpc_slice_malloc(byte_size); 180 GPR_CODEGEN_ASSERT( 181 GRPC_SLICE_END_PTR(slice) == 182 msg.SerializeWithCachedSizesToArray(GRPC_SLICE_START_PTR(slice))); 183 *bp = g_core_codegen_interface->grpc_raw_byte_buffer_create(&slice, 1); 184 g_core_codegen_interface->grpc_slice_unref(slice); 185 return g_core_codegen_interface->ok(); 186 } else { 187 tensorflow_helper::GrpcBufferWriter writer( 188 bp, tensorflow_helper::kGrpcBufferWriterMaxBufferLength); 189 return msg.SerializeToZeroCopyStream(&writer) 190 ? g_core_codegen_interface->ok() 191 : Status(StatusCode::INTERNAL, "Failed to serialize message"); 192 } 193 } 194 195 static Status Deserialize(grpc_byte_buffer* buffer, T* msg, 196 int max_message_size = INT_MAX) { 197 if (buffer == nullptr) { 198 return Status(StatusCode::INTERNAL, "No payload"); 199 } 200 Status result = g_core_codegen_interface->ok(); 201 { 202 tensorflow_helper::GrpcBufferReader reader(buffer); 203 ::grpc::protobuf::io::CodedInputStream decoder(&reader); 204 if (max_message_size == 0) { 205 // NOTE(mrry): Override maximum message size to 2GB. 206 decoder.SetTotalBytesLimit(INT_MAX, INT_MAX); 207 } else { 208 decoder.SetTotalBytesLimit(max_message_size, max_message_size); 209 } 210 if (!msg->ParseFromCodedStream(&decoder)) { 211 result = Status(StatusCode::INTERNAL, msg->InitializationErrorString()); 212 } 213 if (!decoder.ConsumedEntireMessage()) { 214 result = Status(StatusCode::INTERNAL, "Did not read entire message"); 215 } 216 } 217 g_core_codegen_interface->grpc_byte_buffer_destroy(buffer); 218 return result; 219 } 220 }; 221 222 } // namespace grpc 223 224 // For the given protobuf message type `MessageType`, specializes the 225 // gRPC serialization and deserialization such that the default 226 // maximum message size is 2GB. 227 #define TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(MessageType) \ 228 namespace grpc { \ 229 template <> \ 230 class SerializationTraits<MessageType> \ 231 : public UnlimitedSizeProtoSerializationTraits<MessageType> {}; \ 232 } // namespace grpc 233 234 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERIALIZATION_TRAITS_H_ 235