1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/distributed_runtime/tensor_coding.h" 17 18 #include "google/protobuf/any.pb.h" 19 20 #include "tensorflow/core/common_runtime/device.h" 21 #include "tensorflow/core/framework/tensor.pb.h" 22 #include "tensorflow/core/framework/tensor_shape.pb.h" 23 24 namespace tensorflow { 25 26 TensorResponse::Source::~Source() {} 27 28 void TensorResponse::Clear() { 29 on_host_ = false; 30 device_ = nullptr; 31 alloc_attrs_ = AllocatorAttributes(); 32 allocator_ = nullptr; 33 already_used_ = false; 34 ClearTensor(); 35 } 36 37 void TensorResponse::ClearTensor() { 38 meta_.Clear(); 39 tensor_ = Tensor(); 40 } 41 42 void TensorResponse::InitAlloc(DeviceBase* d, const AllocatorAttributes& aa) { 43 Clear(); 44 device_ = d; 45 alloc_attrs_ = aa; 46 const DeviceAttributes& da = d->attributes(); 47 if (alloc_attrs_.on_host() || da.device_type() == "CPU") { 48 on_host_ = true; 49 } 50 allocator_ = device_->GetAllocator(alloc_attrs_); 51 } 52 53 Status TensorResponse::InitFrom(RecvTensorResponse* response) { 54 Status s; 55 meta_.Swap(response); 56 if (on_host_) { 57 if (!tensor_.FromProto(allocator_, meta_.tensor())) { 58 s = errors::InvalidArgument("Cannot parse tensor from response"); 59 } 60 } else { 61 s = device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_); 62 } 63 { 64 TensorProto empty; 65 meta_.mutable_tensor()->Swap(&empty); 66 } 67 meta_.clear_tensor(); 68 return s; 69 } 70 71 void TensorResponse::InitPartial(const RecvTensorResponse& response) { 72 // Everything except content is present in *response. Content will 73 // arrive later; allocate a Tensor with appropriate storage for that 74 // content. 75 meta_ = response; 76 TensorShape shape(meta_.tensor().tensor_shape()); 77 Tensor t(allocator_, meta_.tensor().dtype(), shape); 78 tensor_ = std::move(t); 79 } 80 81 Status TensorResponse::ParseFrom(Source* source) { 82 if (!on_host_) { 83 protobuf::io::CodedInputStream input(source->contents()); 84 input.SetTotalBytesLimit(INT_MAX); // Unlimited 85 86 // Pre-parse into local storage, then delegate to device. 87 if (!meta_.ParseFromCodedStream(&input) || !input.ConsumedEntireMessage()) { 88 return errors::InvalidArgument("Cannot parse tensor from response"); 89 } 90 Status s = 91 device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_); 92 // Reduce memory usage for big tensors. 93 { 94 TensorProto empty; 95 meta_.mutable_tensor()->Swap(&empty); 96 } 97 meta_.clear_tensor(); 98 return s; 99 } 100 if (already_used_) { 101 ClearTensor(); 102 } 103 already_used_ = true; 104 if (ParseFast(source)) return Status::OK(); 105 meta_.Clear(); 106 if (ParseSlow(source)) return Status::OK(); 107 return errors::InvalidArgument("Cannot parse tensor from response"); 108 } 109 110 // Define some helper routines for decoding protocol buffer wire format data 111 namespace { 112 // We only need some of the wiretype values for this code 113 enum WireType { 114 WIRETYPE_VARINT = 0, 115 WIRETYPE_LENGTH_DELIMITED = 2, 116 }; 117 inline int GetTagFieldNumber(uint32 tag) { return tag >> 3; } 118 inline WireType GetTagWireType(uint32 tag) { 119 return static_cast<WireType>(tag & 0x7); 120 } 121 122 bool ReadVarintSizeAsInt(protobuf::io::CodedInputStream* input, int* result) { 123 protobuf_uint64 v; 124 if (input->ReadVarint64(&v) && v <= static_cast<uint64>(INT_MAX)) { 125 *result = static_cast<int>(v); 126 return true; 127 } else { 128 return false; 129 } 130 } 131 132 bool ReadNestedMessage(protobuf::io::CodedInputStream* input, 133 protobuf::Message* value) { 134 int length; 135 if (!ReadVarintSizeAsInt(input, &length)) return false; 136 std::pair<protobuf::io::CodedInputStream::Limit, int> p = 137 input->IncrementRecursionDepthAndPushLimit(length); 138 if (p.second < 0 || !value->MergePartialFromCodedStream(input)) return false; 139 // Make sure that parsing stopped when the limit was hit, not at an endgroup 140 // tag. 141 return input->DecrementRecursionDepthAndPopLimit(p.first); 142 } 143 144 } // namespace 145 146 bool TensorResponse::ParseTensorSubmessage( 147 protobuf::io::CodedInputStream* input, TensorProto* tensor_meta) { 148 bool seen_tensor_content = false; 149 while (true) { 150 auto p = input->ReadTagWithCutoff(127); 151 int tag = GetTagFieldNumber(p.first); 152 WireType wt = GetTagWireType(p.first); 153 if (!p.second) { 154 bool ok = (tag == 0); 155 if (ok && !seen_tensor_content) { 156 // No tensor content: could be because it's a zero-length tensor 157 TensorShape shape(tensor_meta->tensor_shape()); 158 Tensor t(allocator_, tensor_meta->dtype(), shape); 159 tensor_ = std::move(t); 160 } 161 return ok; 162 } 163 switch (tag) { 164 case TensorProto::kDtypeFieldNumber: { 165 uint32 v; 166 if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false; 167 if (seen_tensor_content) return false; 168 tensor_meta->set_dtype(static_cast<DataType>(static_cast<int>(v))); 169 if (!DataTypeCanUseMemcpy(tensor_meta->dtype())) return false; 170 break; 171 } 172 case TensorProto::kTensorShapeFieldNumber: { 173 if ((wt != WIRETYPE_LENGTH_DELIMITED) || 174 !ReadNestedMessage(input, tensor_meta->mutable_tensor_shape())) 175 return false; 176 if (seen_tensor_content) return false; 177 break; 178 } 179 case TensorProto::kVersionNumberFieldNumber: { 180 uint32 v; 181 if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false; 182 if (seen_tensor_content) return false; 183 tensor_meta->set_version_number(static_cast<int32>(v)); 184 break; 185 } 186 case TensorProto::kTensorContentFieldNumber: { 187 // If we haven't seen the dtype and tensor_shape data first, we can't 188 // deal with this in the fast path. 189 if (seen_tensor_content) return false; 190 if (wt != WIRETYPE_LENGTH_DELIMITED || 191 !tensor_meta->has_tensor_shape()) { 192 return false; 193 } 194 int num_bytes; 195 if (!ReadVarintSizeAsInt(input, &num_bytes)) return false; 196 seen_tensor_content = true; 197 TensorShape shape(tensor_meta->tensor_shape()); 198 Tensor t(allocator_, tensor_meta->dtype(), shape); 199 StringPiece buf = t.tensor_data(); 200 if (static_cast<size_t>(num_bytes) != buf.size()) return false; 201 // TODO(jeff,sanjay): Figure out a way to avoid this copy if 202 // the underlying ZeroCopyInputStream data is properly aligned 203 // and compatible with what allocator_ wants. 204 if (!input->ReadRaw(const_cast<char*>(buf.data()), num_bytes)) 205 return false; 206 tensor_ = std::move(t); 207 break; 208 } 209 default: { 210 // Some other tag our fast path code is not prepared to handle. 211 // return false. 212 return false; 213 } 214 } 215 } 216 } 217 218 bool TensorResponse::ParseFast(Source* source) { 219 protobuf::io::CodedInputStream input(source->contents()); 220 input.SetTotalBytesLimit(INT_MAX); // Unlimited 221 while (true) { 222 auto p = input.ReadTagWithCutoff(127); 223 int tag = GetTagFieldNumber(p.first); 224 WireType wt = GetTagWireType(p.first); 225 if (!p.second) { 226 return (tag == 0); 227 } 228 switch (tag) { 229 case RecvTensorResponse::kTensorFieldNumber: { 230 if (wt != WIRETYPE_LENGTH_DELIMITED) return false; 231 232 int length; 233 if (!ReadVarintSizeAsInt(&input, &length)) return false; 234 std::pair<protobuf::io::CodedInputStream::Limit, int> p = 235 input.IncrementRecursionDepthAndPushLimit(length); 236 if (p.second < 0 || 237 !ParseTensorSubmessage(&input, meta_.mutable_tensor())) { 238 return false; 239 } 240 if (!input.DecrementRecursionDepthAndPopLimit(p.first)) { 241 return false; 242 } 243 break; 244 } 245 case RecvTensorResponse::kIsDeadFieldNumber: { 246 uint32 v; 247 if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false; 248 meta_.set_is_dead((v != 0) ? true : false); 249 break; 250 } 251 case RecvTensorResponse::kSendStartMicrosFieldNumber: { 252 protobuf_uint64 v; 253 if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) return false; 254 meta_.set_send_start_micros(static_cast<int64>(v)); 255 break; 256 } 257 case RecvTensorResponse::kTransportOptionsFieldNumber: { 258 if ((wt != WIRETYPE_LENGTH_DELIMITED) || 259 !ReadNestedMessage(&input, meta_.mutable_transport_options())) 260 return false; 261 break; 262 } 263 default: { 264 // Unknown tag, so don't handle we can't handle on the fast path 265 return false; 266 } 267 } 268 } 269 270 return false; 271 } 272 273 bool TensorResponse::ParseSlow(Source* source) { 274 if (!meta_.ParseFromZeroCopyStream(source->contents())) { 275 return false; 276 } 277 278 Tensor parsed(meta_.tensor().dtype()); 279 if (!parsed.FromProto(allocator_, meta_.tensor())) { 280 return false; 281 } 282 tensor_ = std::move(parsed); 283 284 // Reduce memory usage for big tensors. 285 { 286 TensorProto empty; 287 meta_.mutable_tensor()->Swap(&empty); 288 } 289 meta_.clear_tensor(); 290 291 return true; 292 } 293 294 } // namespace tensorflow 295