// tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
     17 #include "grpc++/support/byte_buffer.h"
     18 #include "grpc++/support/slice.h"
     19 #include "tensorflow/core/common_runtime/dma_helper.h"
     20 #include "tensorflow/core/framework/tensor.h"
     21 #include "tensorflow/core/framework/tensor.pb.h"
     22 #include "tensorflow/core/framework/tensor_reference.h"
     23 #include "tensorflow/core/framework/tensor_shape.pb.h"
     24 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     25 #include "tensorflow/core/lib/io/proto_encode_helper.h"
     26 #include "tensorflow/core/platform/env.h"
     27 #include "tensorflow/core/protobuf/worker.pb.h"
     28 
     29 namespace tensorflow {
     30 namespace grpc {
     31 
     32 void EncodeRecvTensorResponseToByteBuffer(const RecvTensorResponse& proto,
     33                                           ::grpc::ByteBuffer* result) {
     34   ::grpc::Slice slice(proto.ByteSizeLong());
     35   proto.SerializeWithCachedSizesToArray(
     36       const_cast<uint8*>(reinterpret_cast<const uint8*>(slice.begin())));
     37   ::grpc::ByteBuffer tmp(&slice, 1);
     38   result->Swap(&tmp);
     39 }
     40 
     41 // We generate a RecvTensorResponse protocol buffer encoding into "*result",
     42 // but where possible, we share the underlying Tensor buffer for "val", to
     43 // avoid an extra copy.
     44 //
// We hand-encode the protocol buffer data in the following order:
     46 //
     47 // Let R be a RecvTensorResponse object we want to encode, logically
     48 // constructed by filling in data from "is_dead" and "val" and filling
     49 // in a few other fields as well.
     50 //
     51 // (Letters here are used in the code to refer back to which part of the
     52 //  encoding the code is generating).
     53 //
     54 // A:   <protocol buffer encoding of fields except R.tensor()>
     55 // B1:  <tag encoding for RecvTensorResponse::tensor>
     56 // B2:  <varint32 length of R.tensor() sub message>
     57 // C:   <protocol buffer encoding of R.tensor() except for
     58 //          R.tensor().tensor_content()>
     59 // D1:  <tag encoding for TensorProto::tensor_content>
     60 // D2:  <varint32 length of R.tensor().tensor_content() data>
     61 // E:   <actual data for val's representation>
     62 //
     63 // If the tensor data is up to "kLargeTensorBytes", then A
     64 // through E will all be encoded into "*result" in a single grpc::Slice.
     65 //
     66 // If the tensor data is larger than "kLargeTensorBytes", then A through
     67 // D2 will be encoded in one grpc::Slice, and E will be encoded in a second
     68 // grpc::Slice that points to the backing store for the tensor data, to avoid
// copying the tensor data (and the grpc::Slice setup will be arranged so as
     70 // to dereference the underlying tensor data buffer when it is no longer
     71 // needed in the "*result" ByteBuffer).
     72 static int VarLengthEncodingSize(uint32 tag, size_t bytes) {
     73   return core::VarintLength(tag << 3) + core::VarintLength(bytes) + bytes;
     74 }
     75 
     76 // Returns an upper bound in bytes of the protocol buffer encoding of
     77 // the "skeleton" of "val" (all the data needed for dtype and the shape,
     78 // but not the actual contents of "val").
     79 static int SkeletonEncodingSizeUpperBound(const Tensor& val) {
     80   static const int kVarintMax64 = 10;  // Max length of varint64 encoding
     81   const int ndims = val.shape().dims();
     82   return (2 * kVarintMax64) +           // dtype
     83          (ndims * (4 * kVarintMax64));  // Shape: 4 varints per dim
     84 }
     85 
     86 // Encode the skeleton for "val" (the encoded TensorProto contents
     87 // (dtype and shape, but not the actual data) into "*e".  The backing
     88 // store for "*e" must be of appropriate size to hold this encoding.
     89 static void EncodeSkeleton(const Tensor& val, io::ProtoEncodeHelper* e) {
     90   // Encode val.dtype()
     91   e->WriteUint64(TensorProto::kDtypeFieldNumber, val.dtype());
     92 
     93   // Compute length of val.shape() proto encoding
     94   const int ndims = val.shape().dims();
     95   int tensor_shape_bytes = 0;
     96   for (int d = 0; d < ndims; d++) {
     97     int64 dim_size = val.shape().dim_size(d);
     98     tensor_shape_bytes +=
     99         2 +  // TensorShapeProto dim tag + varintlength of submessage
    100         1 +  // TensorShapeProto_Dim::kSizeFieldNumber
    101         core::VarintLength(dim_size);
    102   }
    103 
    104   if (tensor_shape_bytes > 0) {
    105     e->WriteVarlengthBeginning(TensorProto::kTensorShapeFieldNumber,
    106                                tensor_shape_bytes);
    107     // Encode val.shape()
    108     for (int d = 0; d < ndims; d++) {
    109       int64 dim_size = val.shape().dim_size(d);
    110       int64 dim_varlen = 1 +  // TensorShapeProto_Dim::kSizeFieldNumber
    111                          core::VarintLength(dim_size);
    112       e->WriteVarlengthBeginning(TensorShapeProto::kDimFieldNumber, dim_varlen);
    113       e->WriteUint64(TensorShapeProto_Dim::kSizeFieldNumber, dim_size);
    114     }
    115   }
    116 
    117 #ifndef NDEBUG
    118   {
    119     // Debug-mode only check to make sure the encoding above is
    120     // identical to the auto-generated protocol buffer encoding.
    121     TensorProto skeleton;
    122     skeleton.set_dtype(val.dtype());
    123     val.shape().AsProto(skeleton.mutable_tensor_shape());
    124     string tensor_except_contents;  // tensor() field except contents
    125     skeleton.AppendToString(&tensor_except_contents);
    126     TensorProto skeleton2;
    127     skeleton2.ParseFromString(string(e->data(), e->size()));
    128     string out;
    129     skeleton.AppendToString(&out);
    130     DCHECK_EQ(tensor_except_contents, out) << skeleton.DebugString() << " vs\n"
    131                                            << skeleton2.DebugString();
    132   }
    133 #endif
    134 }
    135 
// Encodes (is_dead, val) as a RecvTensorResponse into "*result", following
// the A..E layout documented in the large comment above
// VarLengthEncodingSize().  For memcpy-able dtypes the TensorProto is
// hand-encoded; large tensor data is shared zero-copy via a second slice.
void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
                              ::grpc::ByteBuffer* result) {
  // Threshold above which tensor data is shared (two-slice encoding) rather
  // than copied into the single encoded slice.
  const int kLargeTensorBytes = 1024;
  RecvTensorResponse response;
  if (is_dead) {
    response.set_is_dead(is_dead);
  }
  response.set_send_start_micros(Env::Default()->NowMicros());
  if (!DataTypeCanUseMemcpy(val.dtype())) {
    // Straightforward but slow path for complicated kinds of tensor data
    // TODO(jeff,sanjay): If this becomes an issue, we could
    // go directly from val -> ByteBuffer, with some effort.
    val.AsProtoTensorContent(response.mutable_tensor());

    // Encode full protocol buffer to a ByteBuffer
    EncodeRecvTensorResponseToByteBuffer(response, result);
  } else {
    // skeleton is the encoded TensorProto contents (dtype and shape), but
    // not the actual data
    gtl::InlinedVector<char, 128> skeleton(SkeletonEncodingSizeUpperBound(val));
    io::ProtoEncodeHelper e_skeleton(skeleton.data(), skeleton.size());
    EncodeSkeleton(val, &e_skeleton);

    StringPiece tdata = val.tensor_data();
    // Size of the R.tensor() submessage: the skeleton (dtype + shape) plus
    // the tensor_content field (tag + varint length + payload bytes).
    uint32 overall_tensor_proto_bytesize =
        (e_skeleton.size() +
         VarLengthEncodingSize(TensorProto::kTensorContentFieldNumber,
                               tdata.size()));
    string header;  // All of RecvTensorResponse except the tensor() field
    response.AppendToString(&header);

    // Total size of the final encoding: header (A), the tensor field's
    // tag + length (B1/B2), and the submessage itself (C + D1 + D2 + E).
    size_t expected_size =
        (header.size() +
         VarLengthEncodingSize(RecvTensorResponse::kTensorFieldNumber,
                               overall_tensor_proto_bytesize));
    // If "tensor_data_is_large == false", we copy the tensor data to the
    // end of the buffer we are preparing that holds the rest of the
    // RecvTensorResponse protocol buffer.
    //
    // If "tensor_data_is_large == true", we arrange to share the backing
    // store of the data by creating a slice that also points to the
    // backing store, with appropriate reference counts to keep the
    // backing store alive as needed.
    bool tensor_data_is_large = (tdata.size() > kLargeTensorBytes);
    // Bytes the ProtoEncodeHelper must hold: everything except E.
    size_t encoder_size = expected_size - tdata.size();

    // Encode all but the actual "tdata", but including the tag and
    // varlength header for the "tdata"
    gtl::InlinedVector<char, 1024> space(encoder_size);
    io::ProtoEncodeHelper e(space.data(), space.size());
    // (A)
    e.WriteRawBytes(header);

    // (B1) & (B2)
    e.WriteVarlengthBeginning(RecvTensorResponse::kTensorFieldNumber,
                              overall_tensor_proto_bytesize);
    // (C)
    e.WriteRawBytes(StringPiece(e_skeleton.data(), e_skeleton.size()));
    // (D1) & (D2)
    e.WriteVarlengthBeginning(TensorProto::kTensorContentFieldNumber,
                              tdata.size());

    // All but the tensor backing store are serialized now

    // Now allocate memory and put into the ByteBuffer
    // slices[0] always holds A..D2 (plus E for small tensors); slices[1] is
    // only used for the shared-backing-store case.
    ::grpc::Slice slices[2];
    int num_slices = 0;
    {
      size_t slice_len = e.size() + (tensor_data_is_large ? 0 : tdata.size());
      slices[0] = ::grpc::Slice(slice_len);
      memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
      if (!tensor_data_is_large) {
        // (E)
        memcpy(const_cast<uint8_t*>(slices[0].begin()) + e.size(), tdata.data(),
               tdata.size());
      }
      num_slices += 1;
    }

    if (tensor_data_is_large) {
      // (E) Encode tensor data, but by sharing backing store
      // Take a reference on the tensor's buffer; the slice's destructor
      // callback below releases it once gRPC no longer needs the bytes.
      const TensorBuffer* buf = DMAHelper::buffer(&val);
      buf->Ref();
      slices[1] = ::grpc::Slice(
          const_cast<void*>(static_cast<const void*>(tdata.data())),
          tdata.size(),
          [](void* backing) { static_cast<TensorBuffer*>(backing)->Unref(); },
          const_cast<TensorBuffer*>(buf));
      num_slices += 1;
    }
    // Sanity check: the slices together must total exactly the size we
    // computed up front.
    size_t total_bytes = 0;
    for (int i = 0; i < num_slices; i++) {
      total_bytes += slices[i].size();
    }
    CHECK_EQ(total_bytes, expected_size);

    ::grpc::ByteBuffer tmp(&slices[0], num_slices);
    result->Swap(&tmp);
  }
}
    236 
    237 }  // namespace grpc
    238 }  // namespace tensorflow
    239