1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h" 17 #include "grpc++/support/byte_buffer.h" 18 #include "grpc++/support/slice.h" 19 #include "tensorflow/core/common_runtime/dma_helper.h" 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/framework/tensor.pb.h" 22 #include "tensorflow/core/framework/tensor_reference.h" 23 #include "tensorflow/core/framework/tensor_shape.pb.h" 24 #include "tensorflow/core/lib/gtl/inlined_vector.h" 25 #include "tensorflow/core/lib/io/proto_encode_helper.h" 26 #include "tensorflow/core/platform/env.h" 27 #include "tensorflow/core/protobuf/worker.pb.h" 28 29 namespace tensorflow { 30 namespace grpc { 31 32 void EncodeRecvTensorResponseToByteBuffer(const RecvTensorResponse& proto, 33 ::grpc::ByteBuffer* result) { 34 ::grpc::Slice slice(proto.ByteSizeLong()); 35 proto.SerializeWithCachedSizesToArray( 36 const_cast<uint8*>(reinterpret_cast<const uint8*>(slice.begin()))); 37 ::grpc::ByteBuffer tmp(&slice, 1); 38 result->Swap(&tmp); 39 } 40 41 // We generate a RecvTensorResponse protocol buffer encoding into "*result", 42 // but where possible, we share the underlying Tensor buffer for "val", to 43 // avoid an extra copy. 44 // 45 // We hand-encode the protocol buffer data in the following order, as follows: 46 // 47 // Let R be a RecvTensorResponse object we want to encode, logically 48 // constructed by filling in data from "is_dead" and "val" and filling 49 // in a few other fields as well. 50 // 51 // (Letters here are used in the code to refer back to which part of the 52 // encoding the code is generating). 53 // 54 // A: <protocol buffer encoding of fields except R.tensor()> 55 // B1: <tag encoding for RecvTensorResponse::tensor> 56 // B2: <varint32 length of R.tensor() sub message> 57 // C: <protocol buffer encoding of R.tensor() except for 58 // R.tensor().tensor_content()> 59 // D1: <tag encoding for TensorProto::tensor_content> 60 // D2: <varint32 length of R.tensor().tensor_content() data> 61 // E: <actual data for val's representation> 62 // 63 // If the tensor data is up to "kLargeTensorBytes", then A 64 // through E will all be encoded into "*result" in a single grpc::Slice. 65 // 66 // If the tensor data is larger than "kLargeTensorBytes", then A through 67 // D2 will be encoded in one grpc::Slice, and E will be encoded in a second 68 // grpc::Slice that points to the backing store for the tensor data, to avoid 69 // copying the tensor data (and the grpc::Slice setup will be arrange so as 70 // to dereference the underlying tensor data buffer when it is no longer 71 // needed in the "*result" ByteBuffer). 72 static int VarLengthEncodingSize(uint32 tag, size_t bytes) { 73 return core::VarintLength(tag << 3) + core::VarintLength(bytes) + bytes; 74 } 75 76 // Returns an upper bound in bytes of the protocol buffer encoding of 77 // the "skeleton" of "val" (all the data needed for dtype and the shape, 78 // but not the actual contents of "val"). 79 static int SkeletonEncodingSizeUpperBound(const Tensor& val) { 80 static const int kVarintMax64 = 10; // Max length of varint64 encoding 81 const int ndims = val.shape().dims(); 82 return (2 * kVarintMax64) + // dtype 83 (ndims * (4 * kVarintMax64)); // Shape: 4 varints per dim 84 } 85 86 // Encode the skeleton for "val" (the encoded TensorProto contents 87 // (dtype and shape, but not the actual data) into "*e". The backing 88 // store for "*e" must be of appropriate size to hold this encoding. 89 static void EncodeSkeleton(const Tensor& val, io::ProtoEncodeHelper* e) { 90 // Encode val.dtype() 91 e->WriteUint64(TensorProto::kDtypeFieldNumber, val.dtype()); 92 93 // Compute length of val.shape() proto encoding 94 const int ndims = val.shape().dims(); 95 int tensor_shape_bytes = 0; 96 for (int d = 0; d < ndims; d++) { 97 int64 dim_size = val.shape().dim_size(d); 98 tensor_shape_bytes += 99 2 + // TensorShapeProto dim tag + varintlength of submessage 100 1 + // TensorShapeProto_Dim::kSizeFieldNumber 101 core::VarintLength(dim_size); 102 } 103 104 if (tensor_shape_bytes > 0) { 105 e->WriteVarlengthBeginning(TensorProto::kTensorShapeFieldNumber, 106 tensor_shape_bytes); 107 // Encode val.shape() 108 for (int d = 0; d < ndims; d++) { 109 int64 dim_size = val.shape().dim_size(d); 110 int64 dim_varlen = 1 + // TensorShapeProto_Dim::kSizeFieldNumber 111 core::VarintLength(dim_size); 112 e->WriteVarlengthBeginning(TensorShapeProto::kDimFieldNumber, dim_varlen); 113 e->WriteUint64(TensorShapeProto_Dim::kSizeFieldNumber, dim_size); 114 } 115 } 116 117 #ifndef NDEBUG 118 { 119 // Debug-mode only check to make sure the encoding above is 120 // identical to the auto-generated protocol buffer encoding. 121 TensorProto skeleton; 122 skeleton.set_dtype(val.dtype()); 123 val.shape().AsProto(skeleton.mutable_tensor_shape()); 124 string tensor_except_contents; // tensor() field except contents 125 skeleton.AppendToString(&tensor_except_contents); 126 TensorProto skeleton2; 127 skeleton2.ParseFromString(string(e->data(), e->size())); 128 string out; 129 skeleton.AppendToString(&out); 130 DCHECK_EQ(tensor_except_contents, out) << skeleton.DebugString() << " vs\n" 131 << skeleton2.DebugString(); 132 } 133 #endif 134 } 135 136 void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, 137 ::grpc::ByteBuffer* result) { 138 const int kLargeTensorBytes = 1024; 139 RecvTensorResponse response; 140 if (is_dead) { 141 response.set_is_dead(is_dead); 142 } 143 response.set_send_start_micros(Env::Default()->NowMicros()); 144 if (!DataTypeCanUseMemcpy(val.dtype())) { 145 // Straightforward but slow path for complicated kinds of tensor data 146 // TODO(jeff,sanjay): If this becomes an issue, we could 147 // go directly from val -> ByteBuffer, with some effort. 148 val.AsProtoTensorContent(response.mutable_tensor()); 149 150 // Encode full protocol buffer to a ByteBuffer 151 EncodeRecvTensorResponseToByteBuffer(response, result); 152 } else { 153 // skeleton is the encoded TensorProto contents (dtype and shape), but 154 // not the actual data 155 gtl::InlinedVector<char, 128> skeleton(SkeletonEncodingSizeUpperBound(val)); 156 io::ProtoEncodeHelper e_skeleton(skeleton.data(), skeleton.size()); 157 EncodeSkeleton(val, &e_skeleton); 158 159 StringPiece tdata = val.tensor_data(); 160 uint32 overall_tensor_proto_bytesize = 161 (e_skeleton.size() + 162 VarLengthEncodingSize(TensorProto::kTensorContentFieldNumber, 163 tdata.size())); 164 string header; // All of RecvTensorResponse except the tensor() field 165 response.AppendToString(&header); 166 167 size_t expected_size = 168 (header.size() + 169 VarLengthEncodingSize(RecvTensorResponse::kTensorFieldNumber, 170 overall_tensor_proto_bytesize)); 171 // If "tensor_data_is_large == false", we copy the tensor data to the 172 // end of the buffer we are preparing that holds the rest of the 173 // RecvTensorResponse protocol buffer. 174 // 175 // If "tensor_data_is_large == true", we arrange to share the backing 176 // store of the data by creating a slice that also points to the 177 // backing store, with appropriate reference counts to keep the 178 // backing store alive as needed. 179 bool tensor_data_is_large = (tdata.size() > kLargeTensorBytes); 180 size_t encoder_size = expected_size - tdata.size(); 181 182 // Encode all but the actual "tdata", but including the tag and 183 // varlength header for the "tdata" 184 gtl::InlinedVector<char, 1024> space(encoder_size); 185 io::ProtoEncodeHelper e(space.data(), space.size()); 186 // (A) 187 e.WriteRawBytes(header); 188 189 // (B1) & (B2) 190 e.WriteVarlengthBeginning(RecvTensorResponse::kTensorFieldNumber, 191 overall_tensor_proto_bytesize); 192 // (C) 193 e.WriteRawBytes(StringPiece(e_skeleton.data(), e_skeleton.size())); 194 // (D1) & (D2) 195 e.WriteVarlengthBeginning(TensorProto::kTensorContentFieldNumber, 196 tdata.size()); 197 198 // All but the tensor backing store are serialized now 199 200 // Now allocate memory and put into the ByteBuffer 201 ::grpc::Slice slices[2]; 202 int num_slices = 0; 203 { 204 size_t slice_len = e.size() + (tensor_data_is_large ? 0 : tdata.size()); 205 slices[0] = ::grpc::Slice(slice_len); 206 memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size()); 207 if (!tensor_data_is_large) { 208 // (E) 209 memcpy(const_cast<uint8_t*>(slices[0].begin()) + e.size(), tdata.data(), 210 tdata.size()); 211 } 212 num_slices += 1; 213 } 214 215 if (tensor_data_is_large) { 216 // (E) Encode tensor data, but by sharing backing store 217 const TensorBuffer* buf = DMAHelper::buffer(&val); 218 buf->Ref(); 219 slices[1] = ::grpc::Slice( 220 const_cast<void*>(static_cast<const void*>(tdata.data())), 221 tdata.size(), 222 [](void* backing) { static_cast<TensorBuffer*>(backing)->Unref(); }, 223 const_cast<TensorBuffer*>(buf)); 224 num_slices += 1; 225 } 226 size_t total_bytes = 0; 227 for (int i = 0; i < num_slices; i++) { 228 total_bytes += slices[i].size(); 229 } 230 CHECK_EQ(total_bytes, expected_size); 231 232 ::grpc::ByteBuffer tmp(&slices[0], num_slices); 233 result->Swap(&tmp); 234 } 235 } 236 237 } // namespace grpc 238 } // namespace tensorflow 239