/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/tensor_coding.h"

#include "google/protobuf/any.pb.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"

namespace tensorflow {

TensorResponse::Source::~Source() {}

void TensorResponse::Clear() {
  on_host_ = false;
  device_ = nullptr;
  alloc_attrs_ = AllocatorAttributes();
  allocator_ = nullptr;
  already_used_ = false;
  ClearTensor();
}

void TensorResponse::ClearTensor() {
  meta_.Clear();
  tensor_ = Tensor();
}

void TensorResponse::InitAlloc(DeviceBase* d, const AllocatorAttributes& aa) {
  Clear();
  device_ = d;
  alloc_attrs_ = aa;
  const DeviceAttributes& da = d->attributes();
  if (alloc_attrs_.on_host() || da.device_type() == "CPU") {
    on_host_ = true;
  }
  allocator_ = device_->GetAllocator(alloc_attrs_);
}
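
// A typical receive-side lifecycle (a sketch only; the real call sites live
// in the RPC layer rather than in this file, and `cpu_device` / `source` are
// hypothetical names):
//
//   TensorResponse response;
//   response.InitAlloc(cpu_device, AllocatorAttributes());
//   TF_RETURN_IF_ERROR(response.ParseFrom(&source));  // consume RPC bytes
//   const Tensor& result = response.tensor();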

Status TensorResponse::InitFrom(RecvTensorResponse* response) {
  Status s;
  meta_.Swap(response);
  if (on_host_) {
    if (!tensor_.FromProto(allocator_, meta_.tensor())) {
      s = errors::InvalidArgument("Cannot parse tensor from response");
    }
  } else {
    s = device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_);
  }
  // Reduce memory usage for big tensors: drop the (possibly large)
  // tensor_content bytes now that they have been copied into tensor_.
  {
    TensorProto empty;
    meta_.mutable_tensor()->Swap(&empty);
  }
  meta_.clear_tensor();
  return s;
}

void TensorResponse::InitPartial(const RecvTensorResponse& response) {
  // Everything except content is present in *response.  Content will
  // arrive later; allocate a Tensor with appropriate storage for that
  // content.
  meta_ = response;
  TensorShape shape(meta_.tensor().tensor_shape());
  Tensor t(allocator_, meta_.tensor().dtype(), shape);
  tensor_ = std::move(t);
}

Status TensorResponse::ParseFrom(Source* source) {
  if (!on_host_) {
    protobuf::io::CodedInputStream input(source->contents());
    input.SetTotalBytesLimit(INT_MAX);  // Unlimited

    // Pre-parse into local storage, then delegate to device.
    if (!meta_.ParseFromCodedStream(&input) || !input.ConsumedEntireMessage()) {
      return errors::InvalidArgument("Cannot parse tensor from response");
    }
    Status s =
        device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_);
    // Reduce memory usage for big tensors.
    {
      TensorProto empty;
      meta_.mutable_tensor()->Swap(&empty);
    }
    meta_.clear_tensor();
    return s;
  }
  if (already_used_) {
    ClearTensor();
  }
  already_used_ = true;
  if (ParseFast(source)) return Status::OK();
  meta_.Clear();
  if (ParseSlow(source)) return Status::OK();
  return errors::InvalidArgument("Cannot parse tensor from response");
}
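
// Note on the fast/slow split above: the generic protobuf parser used by
// ParseSlow first materializes the (possibly huge) tensor_content bytes
// inside meta_ and then copies them into tensor_, which can roughly double
// peak memory for large tensors. ParseFast instead streams the content
// directly into tensor_'s buffer; ParseSlow remains as a fallback for
// encodings the fast path rejects (non-memcpy dtypes, out-of-order fields,
// unknown tags, and so on).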

// Define some helper routines for decoding protocol buffer wire format data
namespace {
// We only need some of the wiretype values for this code
enum WireType {
  WIRETYPE_VARINT = 0,
  WIRETYPE_LENGTH_DELIMITED = 2,
};
inline int GetTagFieldNumber(uint32 tag) { return tag >> 3; }
inline WireType GetTagWireType(uint32 tag) {
  return static_cast<WireType>(tag & 0x7);
}
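
// A proto tag packs the field number and wire type into one varint:
// tag = (field_number << 3) | wire_type. For example, TensorProto's
// tensor_content field (number 4, length-delimited) encodes as
// (4 << 3) | 2 = 0x22. Fields numbered 1-15 always fit in a single tag
// byte (<= 127), which is why the parsers below can use
// ReadTagWithCutoff(127) for every field they care about.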

bool ReadVarintSizeAsInt(protobuf::io::CodedInputStream* input, int* result) {
  protobuf_uint64 v;
  if (input->ReadVarint64(&v) && v <= static_cast<uint64>(INT_MAX)) {
    *result = static_cast<int>(v);
    return true;
  } else {
    return false;
  }
}

bool ReadNestedMessage(protobuf::io::CodedInputStream* input,
                       protobuf::Message* value) {
  int length;
  if (!ReadVarintSizeAsInt(input, &length)) return false;
  std::pair<protobuf::io::CodedInputStream::Limit, int> p =
      input->IncrementRecursionDepthAndPushLimit(length);
  if (p.second < 0 || !value->MergePartialFromCodedStream(input)) return false;
  // Make sure that parsing stopped when the limit was hit, not at an endgroup
  // tag.
  return input->DecrementRecursionDepthAndPopLimit(p.first);
}

}  // namespace

bool TensorResponse::ParseTensorSubmessage(
    protobuf::io::CodedInputStream* input, TensorProto* tensor_meta) {
  bool seen_tensor_content = false;
  while (true) {
    auto p = input->ReadTagWithCutoff(127);
    int tag = GetTagFieldNumber(p.first);
    WireType wt = GetTagWireType(p.first);
    if (!p.second) {
      bool ok = (tag == 0);
      if (ok && !seen_tensor_content) {
        // No tensor content: could be because it's a zero-length tensor
        TensorShape shape(tensor_meta->tensor_shape());
        Tensor t(allocator_, tensor_meta->dtype(), shape);
        tensor_ = std::move(t);
      }
      return ok;
    }
    switch (tag) {
      case TensorProto::kDtypeFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
        if (seen_tensor_content) return false;
        tensor_meta->set_dtype(static_cast<DataType>(static_cast<int>(v)));
        if (!DataTypeCanUseMemcpy(tensor_meta->dtype())) return false;
        break;
      }
      case TensorProto::kTensorShapeFieldNumber: {
        if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
            !ReadNestedMessage(input, tensor_meta->mutable_tensor_shape()))
          return false;
        if (seen_tensor_content) return false;
        break;
      }
      case TensorProto::kVersionNumberFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
        if (seen_tensor_content) return false;
        tensor_meta->set_version_number(static_cast<int32>(v));
        break;
      }
      case TensorProto::kTensorContentFieldNumber: {
        // If we haven't seen the dtype and tensor_shape data first, we can't
        // deal with this in the fast path.
        if (seen_tensor_content) return false;
        if (wt != WIRETYPE_LENGTH_DELIMITED ||
            !tensor_meta->has_tensor_shape()) {
          return false;
        }
        int num_bytes;
        if (!ReadVarintSizeAsInt(input, &num_bytes)) return false;
        seen_tensor_content = true;
        TensorShape shape(tensor_meta->tensor_shape());
        Tensor t(allocator_, tensor_meta->dtype(), shape);
        StringPiece buf = t.tensor_data();
        if (static_cast<size_t>(num_bytes) != buf.size()) return false;
        // TODO(jeff,sanjay): Figure out a way to avoid this copy if
        // the underlying ZeroCopyInputStream data is properly aligned
        // and compatible with what allocator_ wants.
        if (!input->ReadRaw(const_cast<char*>(buf.data()), num_bytes))
          return false;
        tensor_ = std::move(t);
        break;
      }
      default: {
        // Some other tag that our fast path code is not prepared to handle.
        return false;
      }
    }
  }
}
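
// For reference, the wire bytes the fast path accepts for a scalar float
// tensor would look like this (a worked example, not taken from a real
// trace):
//
//   0x08 0x01              dtype field (1, varint): DT_FLOAT
//   0x12 0x00              tensor_shape field (2, length 0): scalar shape
//   0x22 0x04 b0 b1 b2 b3  tensor_content field (4, length 4): raw float bits
//
// Note that dtype and tensor_shape precede tensor_content, as the fast path
// requires.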

bool TensorResponse::ParseFast(Source* source) {
  protobuf::io::CodedInputStream input(source->contents());
  input.SetTotalBytesLimit(INT_MAX);  // Unlimited
  while (true) {
    auto p = input.ReadTagWithCutoff(127);
    int tag = GetTagFieldNumber(p.first);
    WireType wt = GetTagWireType(p.first);
    if (!p.second) {
      return (tag == 0);
    }
    switch (tag) {
      case RecvTensorResponse::kTensorFieldNumber: {
        if (wt != WIRETYPE_LENGTH_DELIMITED) return false;

        int length;
        if (!ReadVarintSizeAsInt(&input, &length)) return false;
        std::pair<protobuf::io::CodedInputStream::Limit, int> p =
            input.IncrementRecursionDepthAndPushLimit(length);
        if (p.second < 0 ||
            !ParseTensorSubmessage(&input, meta_.mutable_tensor())) {
          return false;
        }
        if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
          return false;
        }
        break;
      }
      case RecvTensorResponse::kIsDeadFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
        meta_.set_is_dead(v != 0);
        break;
      }
      case RecvTensorResponse::kSendStartMicrosFieldNumber: {
        protobuf_uint64 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) return false;
        meta_.set_send_start_micros(static_cast<int64>(v));
        break;
      }
      case RecvTensorResponse::kTransportOptionsFieldNumber: {
        if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
            !ReadNestedMessage(&input, meta_.mutable_transport_options()))
          return false;
        break;
      }
      default: {
        // Unknown tag that we can't handle on the fast path.
        return false;
      }
    }
  }

  return false;
}
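
// A minimal Source for exercising ParseFast/ParseSlow, e.g. from a test
// (a sketch, assuming the Source::contents() interface declared in
// tensor_coding.h; `encoded` holds a serialized RecvTensorResponse and is a
// hypothetical name):
//
//   class ArraySource : public TensorResponse::Source {
//    public:
//     ArraySource(const void* data, int len) : stream_(data, len) {}
//     protobuf::io::ZeroCopyInputStream* contents() override {
//       return &stream_;
//     }
//    private:
//     protobuf::io::ArrayInputStream stream_;
//   };
//
//   TensorResponse response;
//   response.InitAlloc(cpu_device, AllocatorAttributes());
//   ArraySource source(encoded.data(), encoded.size());
//   Status s = response.ParseFrom(&source);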

bool TensorResponse::ParseSlow(Source* source) {
  if (!meta_.ParseFromZeroCopyStream(source->contents())) {
    return false;
  }

  Tensor parsed(meta_.tensor().dtype());
  if (!parsed.FromProto(allocator_, meta_.tensor())) {
    return false;
  }
  tensor_ = std::move(parsed);

  // Reduce memory usage for big tensors.
  {
    TensorProto empty;
    meta_.mutable_tensor()->Swap(&empty);
  }
  meta_.clear_tensor();

  return true;
}

}  // namespace tensorflow