1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" 17 18 #include <algorithm> 19 #include <cstdlib> 20 #include <cstring> 21 #include <memory> 22 #include <utility> 23 24 #include "tensorflow/core/framework/register_types.h" 25 #include "tensorflow/core/framework/tensor.pb.h" 26 #include "tensorflow/core/framework/tensor_shape.pb_text.h" 27 #include "tensorflow/core/framework/tensor_shape.pb.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/framework/types.pb_text.h" 30 #include "tensorflow/core/framework/variant.h" 31 #include "tensorflow/core/framework/variant_op_registry.h" 32 #include "tensorflow/core/framework/variant_tensor_data.h" 33 #include "tensorflow/core/framework/versions.h" 34 #include "tensorflow/core/framework/versions.pb.h" 35 #include "tensorflow/core/lib/core/coding.h" 36 #include "tensorflow/core/lib/core/errors.h" 37 #include "tensorflow/core/lib/gtl/map_util.h" 38 #include "tensorflow/core/lib/gtl/stl_util.h" 39 #include "tensorflow/core/lib/hash/crc32c.h" 40 #include "tensorflow/core/lib/io/path.h" 41 #include "tensorflow/core/lib/io/table_builder.h" 42 #include "tensorflow/core/lib/random/random.h" 43 #include "tensorflow/core/lib/strings/stringprintf.h" 44 #include "tensorflow/core/util/saved_tensor_slice_util.h" 45 #include "tensorflow/core/util/tensor_slice_util.h" 46 47 namespace tensorflow { 48 49 // Versioning of the tensor bundle format. 50 const int kTensorBundleMinProducer = 0; 51 const int kTensorBundleMinConsumer = 0; 52 const int kTensorBundleVersion = 1; 53 54 // Size of our input buffer for streaming reads 55 static const int kBufferSize = 1024 * 1024; 56 57 // Key to the special BundleHeaderProto entry. Do not change this, as clients 58 // can make the assumption that the header is always the first entry in the 59 // bundle. 60 const char* const kHeaderEntryKey = ""; 61 62 namespace { 63 64 // Reads "num_elements" string elements from file[offset, offset+size) into the 65 // length-N "destination". Discards the original content of "destination". 66 // 67 // Checksums the string lengths (as restored uint32, not varint32 bytes) and 68 // string bytes, and stores it into "actual_crc32c". 69 Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements, 70 size_t offset, size_t size, string* destination, 71 uint32* actual_crc32c) { 72 if (size == 0) return Status::OK(); 73 CHECK_GT(size, 0); 74 75 // Reads "num_elements" varint32's from "buffered_file". 76 TF_RETURN_IF_ERROR(buffered_file->Seek(offset)); 77 std::vector<uint32> string_lengths(num_elements); 78 for (size_t i = 0; i < num_elements; ++i) { 79 TF_RETURN_IF_ERROR(buffered_file->ReadVarint32(&string_lengths[i])); 80 } 81 if (offset + size < buffered_file->Tell()) { 82 return errors::DataLoss("String lengths longer than expected offset ", 83 offset + size); 84 } 85 *actual_crc32c = 86 crc32c::Value(reinterpret_cast<const char*>(string_lengths.data()), 87 sizeof(uint32) * num_elements); 88 89 // Reads the length-checksum. 90 uint32 length_checksum = 0; 91 size_t unused_bytes_read = 0; 92 TF_RETURN_IF_ERROR(buffered_file->ReadNBytes( 93 sizeof(uint32), reinterpret_cast<char*>(&length_checksum), 94 &unused_bytes_read)); 95 if (crc32c::Unmask(length_checksum) != *actual_crc32c) { 96 return errors::DataLoss( 97 "The length checksum does not match: expected ", 98 strings::Printf("%08u", crc32c::Unmask(length_checksum)), 99 " but actual is ", strings::Printf("%08u", *actual_crc32c)); 100 } 101 *actual_crc32c = 102 crc32c::Extend(*actual_crc32c, reinterpret_cast<char*>(&length_checksum), 103 sizeof(uint32)); 104 105 // Reads the actual string bytes. 106 for (size_t i = 0; i < num_elements; ++i) { 107 const uint32 string_length = string_lengths[i]; 108 string* buffer = &destination[i]; 109 110 buffer->resize(string_length); 111 size_t bytes_read = 0; 112 TF_RETURN_IF_ERROR( 113 buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read)); 114 *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer->data(), bytes_read); 115 } 116 return Status::OK(); 117 } 118 119 Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret, 120 size_t offset, size_t size, uint32* actual_crc32c) { 121 // On-disk format: 122 // [varint64 len1][bytes variant1][4 byte checksum] 123 // .. 124 // [varint64 lenN][bytes variantN][4 byte checksum] 125 // Var "crc32c" checksums all the lens, variant bytes, individual variant 126 // checksums (as uint32, not varint32 bytes). 127 if (size == 0) return Status::OK(); 128 size_t num_elements = ret->NumElements(); 129 130 // Reads the actual string bytes. 131 TF_RETURN_IF_ERROR(buffered_file->Seek(offset)); 132 for (size_t i = 0; i < num_elements; ++i) { 133 // Read the serialized variant length. 134 uint64 string_length = 0; 135 TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_length)); 136 *actual_crc32c = crc32c::Extend( 137 *actual_crc32c, reinterpret_cast<const char*>(&string_length), 138 sizeof(uint64)); 139 // Read the actual serialized variant. 140 string buffer; 141 buffer.resize(string_length); 142 size_t bytes_read = 0; 143 TF_RETURN_IF_ERROR( 144 buffered_file->ReadNBytes(string_length, &buffer[0], &bytes_read)); 145 *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer.data(), bytes_read); 146 VariantTensorDataProto proto; 147 if (!proto.ParseFromString(buffer)) { 148 return errors::DataLoss("Unable to parse VariantTensorDataProto from ", 149 "buffer of size ", string_length, ". ", 150 "Bundle entry offset: ", offset, " size: ", size); 151 } 152 Variant v = proto; 153 if (!DecodeUnaryVariant(&v)) { 154 return errors::Internal("Could not decode variant with type_name: \"", 155 v.TypeName(), "\". Perhaps you forgot to ", 156 "register a decoder via ", 157 "REGISTER_UNARY_VARIANT_DECODE_FUNCTION?"); 158 } 159 160 // Read the checksum. 161 uint32 checksum = 0; 162 size_t unused_bytes_read = 0; 163 TF_RETURN_IF_ERROR(buffered_file->ReadNBytes( 164 sizeof(uint32), reinterpret_cast<char*>(&checksum), 165 &unused_bytes_read)); 166 if (crc32c::Unmask(checksum) != *actual_crc32c) { 167 return errors::DataLoss( 168 "The checksum after Variant ", i, " does not match.", 169 " Expected: ", strings::Printf("%08u", crc32c::Unmask(checksum)), 170 " Actual: ", strings::Printf("%08u", *actual_crc32c)); 171 } 172 *actual_crc32c = crc32c::Extend( 173 *actual_crc32c, reinterpret_cast<char*>(&checksum), sizeof(uint32)); 174 175 ret->flat<Variant>()(i) = std::move(v); 176 } 177 178 return Status::OK(); 179 } 180 181 char* GetBackingBuffer(const Tensor& val) { 182 CHECK(DataTypeCanUseMemcpy(val.dtype())) << val.dtype(); 183 return const_cast<char*>(val.tensor_data().data()); 184 } 185 186 string* GetStringBackingBuffer(const Tensor& val) { 187 CHECK_EQ(DT_STRING, val.dtype()); 188 return const_cast<string*>(val.flat<string>().data()); 189 } 190 191 Status ParseEntryProto(StringPiece key, StringPiece value, 192 protobuf::MessageLite* out) { 193 if (!out->ParseFromArray(value.data(), value.size())) { 194 return errors::DataLoss("Entry for key ", key, " not parseable."); 195 } 196 return Status::OK(); 197 } 198 199 // Serializes the data bytes of the non-string tensor "val". Discards the 200 // original content of "bytes_written", and on OK updates it with number of 201 // bytes written. 202 // REQUIRES: val.dtype() != DT_STRING 203 Status WriteTensor(const Tensor& val, FileOutputBuffer* out, 204 size_t* bytes_written) { 205 DCHECK_NE(val.dtype(), DT_STRING); 206 DCHECK_NE(val.dtype(), DT_VARIANT); 207 *bytes_written = val.TotalBytes(); 208 char* buf = GetBackingBuffer(val); 209 VLOG(1) << "Appending " << *bytes_written << " bytes to file"; 210 return out->Append(StringPiece(buf, *bytes_written)); 211 } 212 213 // Serializes string tensor "val". "bytes_written" is treated in the same 214 // fashion as WriteTensor(). 215 // 216 // Checksums all bytes written and stores it into "crc32c". 217 // REQUIRES: val.dtype() == DT_STRING 218 Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out, 219 size_t* bytes_written, uint32* crc32c) { 220 // On-disk format: 221 // [varint32 len0]..[varint32 lenL][4 byte cksum on lengths][string bytes] 222 // Var "crc32c" checksums the string lengths (as uint32, not varint32 bytes), 223 // the length-checksum, and all the string bytes. 224 DCHECK_EQ(val.dtype(), DT_STRING); 225 const string* strings = GetStringBackingBuffer(val); 226 227 // Writes the varint lengths. 228 string lengths; 229 lengths.reserve(val.NumElements()); // At least 1 byte per element. 230 *crc32c = 0; 231 for (int64 i = 0; i < val.NumElements(); ++i) { 232 const string* elem = &strings[i]; 233 DCHECK_EQ(elem->size(), static_cast<uint32>(elem->size())); 234 const uint32 elem_size = static_cast<uint32>(elem->size()); 235 236 core::PutVarint32(&lengths, elem_size); 237 *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size), 238 sizeof(uint32)); 239 } 240 TF_RETURN_IF_ERROR(out->Append(lengths)); 241 *bytes_written = lengths.size(); 242 243 // Writes the length checksum. 244 const uint32 length_checksum = crc32c::Mask(*crc32c); 245 TF_RETURN_IF_ERROR(out->Append(StringPiece( 246 reinterpret_cast<const char*>(&length_checksum), sizeof(uint32)))); 247 *crc32c = crc32c::Extend( 248 *crc32c, reinterpret_cast<const char*>(&length_checksum), sizeof(uint32)); 249 *bytes_written += sizeof(uint32); 250 251 // Writes all the string bytes out. 252 for (int64 i = 0; i < val.NumElements(); ++i) { 253 const string* string = &strings[i]; 254 TF_RETURN_IF_ERROR(out->Append(*string)); 255 *bytes_written += string->size(); 256 *crc32c = crc32c::Extend(*crc32c, string->data(), string->size()); 257 } 258 return Status::OK(); 259 } 260 261 Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out, 262 size_t* bytes_written, uint32* crc32c) { 263 // On-disk format: 264 // [varint64 len1][bytes variant1][4 byte checksum] 265 // .. 266 // [varint64 lenN][bytes variantN][4 byte checksum] 267 // Var "crc32c" checksums all the lens, variant bytes, individual variant 268 // checksums (as uint32, not varint32 bytes). 269 DCHECK_EQ(val.dtype(), DT_VARIANT); 270 271 *crc32c = 0; 272 *bytes_written = 0; 273 for (int64 i = 0; i < val.NumElements(); ++i) { 274 VariantTensorData data; 275 val.flat<Variant>()(i).Encode(&data); 276 VariantTensorDataProto proto; 277 data.ToProto(&proto); 278 string elem; 279 proto.SerializeToString(&elem); 280 281 // Write the length of the serialized variant. 282 DCHECK_EQ(elem.size(), static_cast<uint64>(elem.size())); 283 const auto elem_size = static_cast<uint64>(elem.size()); 284 string len; 285 core::PutVarint64(&len, elem_size); 286 TF_RETURN_IF_ERROR(out->Append(len)); 287 *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size), 288 sizeof(uint64)); 289 *bytes_written += len.size(); 290 291 // Write the serialized variant. 292 TF_RETURN_IF_ERROR(out->Append(elem)); 293 *crc32c = crc32c::Extend(*crc32c, elem.data(), elem.size()); 294 *bytes_written += elem.size(); 295 296 // Write the checksum. 297 const uint32 length_checksum = crc32c::Mask(*crc32c); 298 TF_RETURN_IF_ERROR(out->Append(StringPiece( 299 reinterpret_cast<const char*>(&length_checksum), sizeof(uint32)))); 300 *crc32c = 301 crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&length_checksum), 302 sizeof(uint32)); 303 *bytes_written += sizeof(uint32); 304 } 305 306 return Status::OK(); 307 } 308 309 // Returns whether "slice_spec" is a full slice, with respect to the full shape. 310 // 311 // This can happen say, when "slice_spec" is 312 // "TensorSlice(full_tensor_shape.dims())", or when it is "TensorSlice({{0, 313 // dim(0)}, ..., {0, dim(N)}})" -- a degenerate case we need to guard against. 314 bool IsFullSlice(const TensorSlice& slice_spec, 315 const TensorShape& full_tensor_shape) { 316 if (slice_spec.IsFull()) { 317 return true; 318 } else { 319 TensorShape sliced_shape; 320 slice_spec.SliceTensorShape(full_tensor_shape, &sliced_shape).IgnoreError(); 321 return sliced_shape == full_tensor_shape; 322 } 323 } 324 325 Status CorruptFileError(const Status& in_status, const string& filename, 326 const string& detail) { 327 if (in_status.ok()) { 328 return errors::Internal("Unable to read file (", filename, 329 "). Perhaps the file is corrupt or was produced by " 330 "a newer version of TensorFlow with format changes " 331 "(", 332 detail, ")"); 333 } 334 return Status( 335 in_status.code(), 336 strings::StrCat("Unable to read file (", filename, 337 "). Perhaps the file is corrupt or was produced by a " 338 "newer version of TensorFlow with format changes (", 339 detail, "): ", in_status.error_message())); 340 } 341 342 table::Options TableBuilderOptions() { 343 table::Options o; 344 // Compressed tables cannot be read by TensorFlow releases prior to 1.1. 345 // To smoothen the transition, compressed writes are disabled for now 346 // (version 1.2) with the intention that they will be enabled again at 347 // some point (perhaps the 1.3 release?). 348 o.compression = table::kNoCompression; 349 return o; 350 } 351 352 // Writes zeros to output buffer to align the next write to the requested 353 // alignment. "size" is the current size of the buffer and is updated to the 354 // new size. 355 Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) { 356 int bytes_over = *size % alignment; 357 if (bytes_over == 0) { 358 return Status::OK(); 359 } 360 int bytes_to_write = alignment - bytes_over; 361 Status status = out->Append(string(bytes_to_write, '\0')); 362 if (status.ok()) { 363 *size += bytes_to_write; 364 } 365 return status; 366 } 367 368 } // namespace 369 370 BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options) 371 : env_(env), 372 options_(options), 373 prefix_(prefix.ToString()), 374 tmp_metadata_path_(strings::StrCat(MetaFilename(prefix_), ".tempstate", 375 random::New64())), 376 tmp_data_path_(strings::StrCat(DataFilename(prefix_, 0, 1), ".tempstate", 377 random::New64())), 378 out_(nullptr), 379 size_(0) { 380 status_ = env_->CreateDir(io::Dirname(prefix_).ToString()); 381 if (!status_.ok() && !errors::IsAlreadyExists(status_)) { 382 return; 383 } 384 const string filename = DataFilename(prefix_, 0, 1); 385 std::unique_ptr<WritableFile> wrapper; 386 status_ = env_->NewWritableFile(tmp_data_path_, &wrapper); 387 if (!status_.ok()) return; 388 out_ = std::unique_ptr<FileOutputBuffer>( 389 new FileOutputBuffer(wrapper.release(), 8 << 20 /* 8MB write buffer */)); 390 391 VLOG(1) << "Writing to file " << tmp_data_path_; 392 } 393 394 Status BundleWriter::Add(StringPiece key, const Tensor& val) { 395 if (!status_.ok()) return status_; 396 CHECK_NE(key, kHeaderEntryKey); 397 const string key_string = key.ToString(); 398 if (entries_.find(key_string) != entries_.end()) { 399 status_ = errors::InvalidArgument("Adding duplicate key: ", key); 400 return status_; 401 } 402 403 BundleEntryProto* entry = &entries_[key_string]; 404 entry->set_dtype(val.dtype()); 405 val.shape().AsProto(entry->mutable_shape()); 406 entry->set_shard_id(0); 407 entry->set_offset(size_); 408 409 // Updates the data file. 410 size_t data_bytes_written = 0; 411 uint32 crc32c = 0; 412 out_->clear_crc32c(); 413 if (val.dtype() == DT_STRING) { 414 status_ = WriteStringTensor(val, out_.get(), &data_bytes_written, &crc32c); 415 } else if (val.dtype() == DT_VARIANT) { 416 status_ = WriteVariantTensor(val, out_.get(), &data_bytes_written, &crc32c); 417 } else { 418 status_ = WriteTensor(val, out_.get(), &data_bytes_written); 419 crc32c = out_->crc32c(); 420 } 421 422 if (status_.ok()) { 423 entry->set_size(data_bytes_written); 424 entry->set_crc32c(crc32c::Mask(crc32c)); 425 size_ += data_bytes_written; 426 status_ = PadAlignment(out_.get(), options_.data_alignment, &size_); 427 } 428 return status_; 429 } 430 431 Status BundleWriter::AddSlice(StringPiece full_tensor_key, 432 const TensorShape& full_tensor_shape, 433 const TensorSlice& slice_spec, 434 const Tensor& slice_tensor) { 435 if (!status_.ok()) return status_; 436 CHECK_NE(full_tensor_key, kHeaderEntryKey); 437 438 // If just a singleton full slice, use the regular Add() to be more efficient. 439 if (IsFullSlice(slice_spec, full_tensor_shape)) { 440 return Add(full_tensor_key, slice_tensor); 441 } 442 443 // Inserts/updates the full tensor's metadata entry. 444 // 445 // In the case of a sharded save, MergeBundles() is responsible for merging 446 // the "slices" field of multiple metadata entries corresponding to the same 447 // full tensor. 448 const string full_tensor_key_string = full_tensor_key.ToString(); 449 BundleEntryProto* full_entry = &entries_[full_tensor_key_string]; 450 if (full_entry->dtype() != DT_INVALID) { 451 CHECK_EQ(full_entry->dtype(), slice_tensor.dtype()); 452 } 453 if (full_entry->has_shape()) { 454 CHECK(TensorShape(full_entry->shape()) == full_tensor_shape); 455 } 456 457 // Populates dtype, shape, and slices. Intentionally leaving out shard_id and 458 // offset, which do not make sense for this full tensor entry. 459 full_entry->set_dtype(slice_tensor.dtype()); 460 full_tensor_shape.AsProto(full_entry->mutable_shape()); 461 TensorSliceProto* slice_proto = full_entry->add_slices(); 462 slice_spec.AsProto(slice_proto); 463 464 // The slice itself is handled by a regular Add(), which includes adding its 465 // own metadata entry, and writing out the slice's values. 466 const string slice_name = 467 checkpoint::EncodeTensorNameSlice(full_tensor_key_string, slice_spec); 468 status_ = Add(slice_name, slice_tensor); 469 return status_; 470 } 471 472 // TODO(zongheng): on metadata write failure or !status_.ok(), consider removing 473 // the orphaned data file. 474 Status BundleWriter::Finish() { 475 if (out_) { 476 status_.Update(out_->Close()); 477 out_ = nullptr; 478 if (status_.ok()) { 479 status_ = Env::Default()->RenameFile(tmp_data_path_, 480 DataFilename(prefix_, 0, 1)); 481 } else { 482 Env::Default()->DeleteFile(tmp_data_path_).IgnoreError(); 483 } 484 } 485 if (!status_.ok()) return status_; 486 // Build key -> BundleEntryProto table. 487 std::unique_ptr<WritableFile> file; 488 status_ = env_->NewWritableFile(tmp_metadata_path_, &file); 489 if (!status_.ok()) return status_; 490 { 491 // N.B.: the default use of Snappy compression may not be supported on all 492 // platforms (e.g. Android). The metadata file is small, so this is fine. 493 table::Options options; 494 options.compression = table::kNoCompression; 495 table::TableBuilder builder(options, file.get()); 496 // Header entry. 497 BundleHeaderProto header; 498 header.set_num_shards(1); 499 header.set_endianness(BundleHeaderProto::LITTLE); 500 if (!port::kLittleEndian) header.set_endianness(BundleHeaderProto::BIG); 501 VersionDef* version = header.mutable_version(); 502 version->set_producer(kTensorBundleVersion); 503 version->set_min_consumer(kTensorBundleMinConsumer); 504 505 builder.Add(kHeaderEntryKey, header.SerializeAsString()); 506 507 // All others. 508 for (const auto& p : entries_) { 509 builder.Add(p.first, p.second.SerializeAsString()); 510 } 511 status_ = builder.Finish(); 512 } 513 status_.Update(file->Close()); 514 if (!status_.ok()) { 515 Env::Default()->DeleteFile(tmp_metadata_path_).IgnoreError(); 516 return status_; 517 } else { 518 status_ = 519 Env::Default()->RenameFile(tmp_metadata_path_, MetaFilename(prefix_)); 520 if (!status_.ok()) return status_; 521 } 522 status_ = errors::Internal("BundleWriter is closed"); 523 return Status::OK(); 524 } 525 526 // Merging tensor bundles. 527 528 // Accumulator of metadata states during a merge. 529 struct MergeState { 530 // Accumulated from the header entries. 531 int num_shards = 0; 532 533 // Derives "endianness" and "version" from the first bundle merged (hence the 534 // "seen_first_bundle" guard). The two fields must be the same for all 535 // bundles in a merge. 536 bool seen_first_bundle = false; 537 BundleHeaderProto_Endianness endianness; 538 VersionDef version; 539 540 // Tensor key -> BundleEntryProto. 541 std::map<string, BundleEntryProto> entries; 542 // Data file path -> new shard id in the final merged bundle. 543 std::unordered_map<string, int32> shard_ids; 544 }; 545 546 // Merges entries of "prefix" into the accumulator state "merge". 547 // Returns OK iff the merge succeeds. 548 static Status MergeOneBundle(Env* env, StringPiece prefix, 549 MergeState* merge_state) { 550 VLOG(1) << "Merging bundle:" << prefix; 551 const string filename = MetaFilename(prefix); 552 uint64 file_size; 553 TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size)); 554 std::unique_ptr<RandomAccessFile> file; 555 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file)); 556 557 table::Table* table = nullptr; 558 TF_RETURN_IF_ERROR( 559 table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table)); 560 std::unique_ptr<table::Table> table_deleter(table); 561 std::unique_ptr<table::Iterator> iter(table->NewIterator()); 562 563 int num_shards; 564 // Process header. 565 { 566 iter->Seek(kHeaderEntryKey); 567 if (!iter->Valid()) { 568 return CorruptFileError(iter->status(), filename, 569 "failed to seek to header entry"); 570 } 571 BundleHeaderProto header; 572 Status s = ParseEntryProto(iter->key(), iter->value(), &header); 573 if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header"); 574 575 merge_state->num_shards += header.num_shards(); 576 if (!merge_state->seen_first_bundle) { 577 merge_state->seen_first_bundle = true; 578 merge_state->endianness = header.endianness(); 579 merge_state->version = header.version(); 580 } else { 581 // Validates "endianness". 582 if (merge_state->endianness != header.endianness()) { 583 return errors::InvalidArgument( 584 "Merging bundles with conflicting endianness; inputs corrupted?"); 585 } 586 // Validates "version". 587 string curr_version, merge_version; 588 header.version().SerializeToString(&curr_version); 589 merge_state->version.SerializeToString(&merge_version); 590 if (curr_version != merge_version) { 591 return errors::InvalidArgument( 592 "Merging bundles with different format versions: merged ", 593 merge_version, " vs. curr ", curr_version); 594 } 595 } 596 num_shards = header.num_shards(); 597 iter->Next(); 598 } 599 600 // Loops through the non-header to-merge entries. 601 BundleEntryProto to_merge_entry; 602 for (; iter->Valid(); iter->Next()) { 603 const string key = iter->key().ToString(); 604 const auto entry_iter = merge_state->entries.find(key); 605 606 // Illegal: the duplicated entry is a non-slice tensor. 607 if (entry_iter != merge_state->entries.end() && 608 entry_iter->second.slices().empty()) { 609 return errors::InvalidArgument( 610 "Duplicate tensor keyed by ", key, 611 " encountered, when merging prefix: ", prefix); 612 } 613 614 TF_RETURN_IF_ERROR( 615 ParseEntryProto(iter->key(), iter->value(), &to_merge_entry)); 616 617 // The duplicated entry holds metadata for a sliced full tensor. 618 // Allows the duplication and merges "slices". 619 if (entry_iter != merge_state->entries.end()) { 620 BundleEntryProto& existing_entry = entry_iter->second; 621 if (to_merge_entry.slices().empty()) { 622 return errors::Internal( 623 "Duplicate tensor keyed by ", key, 624 "; attempting to merge in a non-slice bundle entry"); 625 } 626 // Only needs merge the "slices" field (and validate dtype/shape). 627 for (int i = 0; i < to_merge_entry.slices_size(); ++i) { 628 TensorSliceProto* slot = existing_entry.add_slices(); 629 *slot = to_merge_entry.slices(i); 630 } 631 CHECK_EQ(existing_entry.dtype(), to_merge_entry.dtype()); 632 CHECK(TensorShape(existing_entry.shape()) == 633 TensorShape(to_merge_entry.shape())); 634 continue; 635 } 636 637 // Key doesn't duplicate: a fresh tensor/slice entry. 638 auto result = merge_state->shard_ids.insert( 639 {DataFilename(prefix, to_merge_entry.shard_id(), num_shards), 640 merge_state->shard_ids.size()}); 641 to_merge_entry.set_shard_id(result.first->second); 642 merge_state->entries[key] = to_merge_entry; 643 } 644 return Status::OK(); 645 } 646 647 Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes, 648 StringPiece merged_prefix) { 649 // Merges all metadata tables. 650 // TODO(zhifengc): KeyValue sorter if it becomes too big. 651 MergeState merge; 652 Status status = env->CreateDir(io::Dirname(merged_prefix).ToString()); 653 if (!status.ok() && !errors::IsAlreadyExists(status)) return status; 654 for (int i = 0; i < prefixes.size(); ++i) { 655 TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge)); 656 } 657 658 // Renames data files to contain the merged bundle prefix. 659 for (const auto& p : merge.shard_ids) { 660 VLOG(1) << "Renaming " << p.first << " to " 661 << DataFilename(merged_prefix, p.second, merge.shard_ids.size()); 662 TF_RETURN_IF_ERROR(env->RenameFile( 663 p.first, 664 DataFilename(merged_prefix, p.second, merge.shard_ids.size()))); 665 } 666 667 // Writes the final metadata table under the merged prefix. 668 std::unique_ptr<WritableFile> merged_metadata; 669 TF_RETURN_IF_ERROR( 670 env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata)); 671 { 672 table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get()); 673 // Header entry. 674 BundleHeaderProto header; 675 header.set_num_shards(merge.num_shards); 676 header.set_endianness(merge.endianness); 677 *header.mutable_version() = merge.version; 678 builder.Add(kHeaderEntryKey, header.SerializeAsString()); 679 // All others. 680 for (const auto& p : merge.entries) { 681 builder.Add(p.first, p.second.SerializeAsString()); 682 } 683 status = builder.Finish(); 684 } 685 status.Update(merged_metadata->Close()); 686 if (!status.ok()) return status; 687 VLOG(1) << "Merged bundles to:" << merged_prefix; 688 689 // Cleanup: best effort based and ignores errors. 690 for (const string& prefix : prefixes) { 691 env->DeleteFile(MetaFilename(prefix)).IgnoreError(); 692 } 693 return status; 694 } 695 696 // Interface for reading a tensor bundle. 697 698 BundleReader::BundleReader(Env* env, StringPiece prefix) 699 : env_(env), 700 prefix_(prefix.ToString()), 701 metadata_(nullptr), 702 table_(nullptr), 703 iter_(nullptr) { 704 const string filename = MetaFilename(prefix_); 705 uint64 file_size; 706 status_ = env_->GetFileSize(filename, &file_size); 707 if (!status_.ok()) return; 708 709 // Opens the metadata table. 710 std::unique_ptr<RandomAccessFile> wrapper; 711 status_ = env_->NewRandomAccessFile(filename, &wrapper); 712 if (!status_.ok()) return; 713 metadata_ = wrapper.release(); 714 status_ = table::Table::Open(table::Options(), metadata_, file_size, &table_); 715 if (!status_.ok()) return; 716 iter_ = table_->NewIterator(); 717 718 // Reads "num_shards_" from the first entry. 719 iter_->Seek(kHeaderEntryKey); 720 if (!iter_->Valid()) { 721 status_ = CorruptFileError(iter_->status(), filename, 722 "failed to seek to header entry"); 723 return; 724 } 725 BundleHeaderProto header; 726 status_ = ParseEntryProto(iter_->key(), iter_->value(), &header); 727 if (!status_.ok()) { 728 status_ = CorruptFileError(status_, filename, "unable to parse header"); 729 return; 730 } 731 num_shards_ = header.num_shards(); 732 if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) || 733 (header.endianness() == BundleHeaderProto::LITTLE && 734 !port::kLittleEndian)) { 735 status_ = errors::Unimplemented( 736 "Reading a bundle with different endianness from the reader"); 737 return; 738 } 739 status_ = CheckVersions(header.version(), kTensorBundleVersion, 740 kTensorBundleMinProducer, "Checkpoint", "checkpoint"); 741 } 742 743 BundleReader::~BundleReader() { 744 delete metadata_; 745 delete iter_; 746 delete table_; 747 // InputBuffer does not own the underlying RandomAccessFile. 748 for (auto pair : data_) { 749 if (pair.second != nullptr && pair.second->file() != nullptr) { 750 delete pair.second->file(); 751 } 752 } 753 gtl::STLDeleteValues(&data_); 754 gtl::STLDeleteValues(&tensor_slices_); 755 } 756 757 Status BundleReader::GetBundleEntryProto(StringPiece key, 758 BundleEntryProto* entry) { 759 entry->Clear(); 760 TF_CHECK_OK(status_); 761 Seek(key); 762 if (!iter_->Valid() || iter_->key() != key) { 763 return errors::NotFound("Key ", key, " not found in checkpoint"); 764 } 765 766 BundleEntryProto entry_copy; 767 TF_RETURN_IF_ERROR( 768 ParseEntryProto(iter_->key(), iter_->value(), &entry_copy)); 769 if (!TensorShape::IsValid(entry_copy.shape())) { 770 return errors::DataLoss("Invaid tensor shape: ", key, " ", 771 ProtoShortDebugString(entry_copy.shape())); 772 } 773 774 *entry = entry_copy; 775 return Status::OK(); 776 } 777 778 Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) { 779 Tensor* ret = val; 780 const TensorShape stored_shape(TensorShape(entry.shape())); 781 if (val->NumElements() == 0) { 782 ret = new Tensor(entry.dtype(), stored_shape); 783 } 784 785 // Validates the "size" field. 786 if (entry.dtype() != DT_STRING && entry.dtype() != DT_VARIANT) { 787 if (entry.size() != ret->TotalBytes()) { 788 return errors::DataLoss("Invalid size in bundle entry: key ", key(), 789 "; stored size ", entry.size(), 790 "; expected size ", ret->TotalBytes()); 791 } 792 } else if (entry.dtype() == DT_STRING) { 793 // Relaxes the check for string tensors as follows: 794 // entry.size() == bytes(varint lengths) + bytes(data) 795 // >= NumElems + bytes(data), since size bytes(varint) >= 1. 796 // TotalBytes() == sizeof(string) * NumElems + bytes(data) 797 // Since we don't know bytes(varint lengths), we just check an inequality. 798 const size_t lower_bound = ret->NumElements() + ret->TotalBytes() - 799 sizeof(string) * ret->NumElements(); 800 if (entry.size() < lower_bound) { 801 return errors::DataLoss("Invalid size in bundle entry: key ", key(), 802 "; stored size ", entry.size(), 803 "; expected size is at least ", lower_bound); 804 } 805 } 806 807 // Open the data file if it has not been opened. 808 io::InputBuffer* buffered_file = data_[entry.shard_id()]; 809 if (buffered_file == nullptr) { 810 std::unique_ptr<RandomAccessFile> file = nullptr; 811 TF_RETURN_IF_ERROR(env_->NewRandomAccessFile( 812 DataFilename(prefix_, entry.shard_id(), num_shards_), &file)); 813 buffered_file = new io::InputBuffer(file.release(), kBufferSize); 814 // The InputBuffer and RandomAccessFile objects are both released in dtor. 815 data_[entry.shard_id()] = buffered_file; 816 } 817 CHECK(buffered_file != nullptr); 818 819 TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset())); 820 uint32 actual_crc32c = 0; 821 822 if (DataTypeCanUseMemcpy(entry.dtype())) { 823 char* backing_buffer = const_cast<char*>((ret->tensor_data().data())); 824 size_t unused_bytes_read; 825 if (entry.size() > kBufferSize) { 826 StringPiece sp; 827 TF_RETURN_IF_ERROR(buffered_file->file()->Read( 828 entry.offset(), entry.size(), &sp, backing_buffer)); 829 if (sp.data() != backing_buffer) { 830 memmove(backing_buffer, sp.data(), entry.size()); 831 } 832 } else { 833 TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer, 834 &unused_bytes_read)); 835 } 836 actual_crc32c = crc32c::Value(backing_buffer, entry.size()); 837 } else if (entry.dtype() == DT_VARIANT) { 838 // Relies on io::InputBuffer's buffering, because we issue many neighboring 839 // reads for a single string tensor. 840 TF_RETURN_IF_ERROR(ReadVariantTensor(buffered_file, ret, entry.offset(), 841 entry.size(), &actual_crc32c)); 842 } else { 843 // Relies on io::InputBuffer's buffering, because we issue many neighboring 844 // reads for a single string tensor. 845 TF_RETURN_IF_ERROR(ReadStringTensor( 846 buffered_file, ret->NumElements(), entry.offset(), entry.size(), 847 GetStringBackingBuffer(*ret), &actual_crc32c)); 848 } 849 if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) { 850 return errors::DataLoss( 851 "Checksum does not match: stored ", 852 strings::Printf("%08u", crc32c::Unmask(entry.crc32c())), 853 " vs. calculated on the restored bytes ", actual_crc32c); 854 } 855 856 *val = *ret; 857 if (ret != val) delete ret; 858 return Status::OK(); 859 } 860 861 Status BundleReader::Lookup(StringPiece key, Tensor* val) { 862 CHECK(val != nullptr); 863 BundleEntryProto entry; 864 TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry)); 865 866 if (entry.slices().empty()) { 867 return GetValue(entry, val); 868 } else { 869 return GetSliceValue( 870 key, entry, 871 /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val); 872 } 873 } 874 875 Status BundleReader::ReadCurrent(Tensor* val) { 876 CHECK(val != nullptr); 877 BundleEntryProto entry; 878 TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry)); 879 if (!TensorShape::IsValid(entry.shape())) { 880 return errors::DataLoss("Invaid tensor shape: ", iter_->key(), " ", 881 ProtoShortDebugString(entry.shape())); 882 } 883 884 if (entry.slices().empty()) { 885 return GetValue(entry, val); 886 } else { 887 return GetSliceValue( 888 iter_->key(), entry, 889 /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val); 890 } 891 } 892 893 Status BundleReader::LookupTensorSlices(StringPiece key, 894 std::vector<TensorSlice>* slices) { 895 slices->clear(); 896 BundleEntryProto entry; 897 TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry)); 898 slices->reserve(entry.slices_size()); 899 for (const auto& slice : entry.slices()) { 900 slices->emplace_back(slice); 901 } 902 return Status::OK(); 903 } 904 905 Status BundleReader::LookupSlice(StringPiece full_tensor_key, 906 const TensorSlice& slice_spec, Tensor* val) { 907 CHECK(val != nullptr); 908 BundleEntryProto entry; 909 TF_RETURN_IF_ERROR(GetBundleEntryProto(full_tensor_key, &entry)); 910 return GetSliceValue(full_tensor_key, entry, slice_spec, val); 911 } 912 913 Status BundleReader::GetSliceValue(StringPiece full_tensor_key, 914 const BundleEntryProto& full_tensor_entry, 915 const TensorSlice& slice_spec, Tensor* val) { 916 using checkpoint::RegisterTensorSlice; 917 using checkpoint::TensorSliceSet; 918 DCHECK_GE(full_tensor_entry.slices_size(), 0); 919 920 const TensorShape full_shape(TensorShape(full_tensor_entry.shape())); 921 std::vector<std::pair<TensorSlice, string>> details; 922 const string full_tensor_key_string = full_tensor_key.ToString(); 923 const TensorSliceSet* tss = 924 gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string); 925 926 // Populates the "full tensor key -> TensorSliceSet" cache. 927 if (tss == nullptr) { 928 if (full_tensor_entry.slices().empty()) { 929 // Special case: a writer has saved a tensor fully, but the reader wants 930 // to read in slices. We therefore register the full slice on-demand here 931 // without further complicating the on-disk bundle format. 932 TF_RETURN_IF_ERROR(RegisterTensorSlice( 933 full_tensor_key_string, full_shape, full_tensor_entry.dtype(), 934 /* tag */ "", 935 /* full slice */ TensorSlice(full_shape.dims()), &tensor_slices_)); 936 } 937 for (const TensorSliceProto& slice : full_tensor_entry.slices()) { 938 TF_RETURN_IF_ERROR(RegisterTensorSlice( 939 full_tensor_key_string, full_shape, full_tensor_entry.dtype(), 940 /* tag */ "", TensorSlice(slice), &tensor_slices_)); 941 } 942 tss = gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string); 943 CHECK_NE(tss, nullptr); 944 } 945 if (!tss->QueryMeta(slice_spec, &details)) { 946 return errors::InvalidArgument( 947 "Does not have sufficient slices for partitioned tensor ", 948 full_tensor_key, 949 " to restore in slice_spec: ", slice_spec.DebugString()); 950 } 951 952 // The union of the slices in "details" covers "slice_spec". Performs the 953 // copies from each. 954 BundleEntryProto stored_slice_entry = full_tensor_entry; 955 for (const auto& slice_tag_pair : details) { 956 // Seeks for the stored slice. 957 const TensorSlice& stored_slice = slice_tag_pair.first; 958 959 // We already have the entry for the full tensor, so don't query again if 960 // the slice is full. 961 if (!stored_slice.IsFull()) { 962 const string encoded_stored_slice_name = 963 checkpoint::EncodeTensorNameSlice(full_tensor_key_string, 964 stored_slice); 965 status_ = 966 GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry); 967 if (!status_.ok()) return status_; 968 } 969 970 // TODO(zongheng): should we take an OpKernelContext, so that we can call 971 // allocate_temp()? Note that without major refactorings to Saver, it's 972 // hard for the caller of the tensor bundle module to allocate these 973 // precisely-shaped scratch storage. 974 975 // Optimization for the common case: the stored slice can be directly 976 // copied to the destination without additional slicing. This is true when 977 // either the slices are equal or when they are both full slices having the 978 // same shape. 979 TensorShape stored_slice_shape(stored_slice_entry.shape()); 980 if (stored_slice == slice_spec || 981 (stored_slice_shape == val->shape() && 982 IsFullSlice(stored_slice, stored_slice_shape) && 983 IsFullSlice(slice_spec, stored_slice_shape))) { 984 VLOG(1) << "Optimized for common case: directly copying into " 985 "pre-allocated buffer; spec: " 986 << slice_spec.DebugString(); 987 status_ = GetValue(stored_slice_entry, val); 988 return status_; 989 } 990 991 Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape); 992 status_ = GetValue(stored_slice_entry, &stored_slice_tensor); 993 if (!status_.ok()) return status_; 994 995 // Copies the intersection over. 996 const DataType common_dtype = full_tensor_entry.dtype(); 997 switch (common_dtype) { 998 #define HANDLE_COPY(T) \ 999 case DataTypeToEnum<T>::value: \ 1000 CHECK(CopyDataFromTensorSliceToTensorSlice( \ 1001 full_shape, stored_slice, slice_spec, \ 1002 stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \ 1003 break; 1004 1005 HANDLE_COPY(float) 1006 HANDLE_COPY(double) 1007 HANDLE_COPY(int32) 1008 HANDLE_COPY(uint8) 1009 HANDLE_COPY(int16) 1010 HANDLE_COPY(int8) 1011 HANDLE_COPY(complex64) 1012 HANDLE_COPY(complex128) 1013 HANDLE_COPY(int64) 1014 HANDLE_COPY(bool) 1015 HANDLE_COPY(qint32) 1016 HANDLE_COPY(quint8) 1017 HANDLE_COPY(qint8) 1018 default: 1019 return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype), 1020 " not supported."); 1021 } 1022 #undef HANDLE_COPY 1023 } 1024 return Status::OK(); 1025 } 1026 1027 bool BundleReader::Contains(StringPiece key) { 1028 Seek(key); 1029 return Valid() && (this->key() == key); 1030 } 1031 1032 Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype, 1033 TensorShape* shape) { 1034 BundleEntryProto entry; 1035 TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry)); 1036 *dtype = entry.dtype(); 1037 *shape = TensorShape(entry.shape()); 1038 return Status::OK(); 1039 } 1040 1041 Status BundleReader::LookupTensorShape(StringPiece key, TensorShape* shape) { 1042 DataType ignored; 1043 return LookupDtypeAndShape(key, &ignored, shape); 1044 } 1045 1046 string BundleReader::DebugString() { 1047 // Format used below emulates that of TensorSliceReader::DebugString(). 1048 string shape_str; 1049 BundleEntryProto entry; 1050 Seek(kHeaderEntryKey); 1051 for (Next(); Valid(); Next()) { 1052 CHECK(entry.ParseFromArray(value().data(), value().size())); 1053 if (entry.slices_size() > 0) continue; // Slice of some partitioned var. 1054 1055 strings::StrAppend(&shape_str, key(), " (", 1056 EnumName_DataType(entry.dtype()), ") ", 1057 TensorShape(entry.shape()).DebugString()); 1058 strings::StrAppend(&shape_str, "\n"); 1059 } 1060 return shape_str; 1061 } 1062 1063 FileOutputBuffer::~FileOutputBuffer() { delete file_; } 1064 1065 Status FileOutputBuffer::Append(StringPiece data) { 1066 // In the below, it is critical to calculate the checksum on the actually 1067 // copied bytes, not the source bytes. This is because "data" typically 1068 // points to tensor buffers, which may be concurrently written. 1069 if (data.size() + position_ <= buffer_size_) { 1070 // Can fit into the current buffer. 1071 memcpy(&buffer_[position_], data.data(), data.size()); 1072 crc32c_ = crc32c::Extend(crc32c_, &buffer_[position_], data.size()); 1073 } else if (data.size() <= buffer_size_) { 1074 // Cannot fit, but can fit after flushing. 1075 TF_RETURN_IF_ERROR(FlushBuffer()); 1076 memcpy(&buffer_[0], data.data(), data.size()); 1077 crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], data.size()); 1078 } else { 1079 // Cannot fit even after flushing. So we break down "data" by chunk, and 1080 // flush/checksum each chunk. 1081 TF_RETURN_IF_ERROR(FlushBuffer()); 1082 for (size_t i = 0; i < data.size(); i += buffer_size_) { 1083 const size_t nbytes = std::min(data.size() - i, buffer_size_); 1084 memcpy(&buffer_[0], data.data() + i, nbytes); 1085 crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], nbytes); 1086 position_ = nbytes; 1087 TF_RETURN_IF_ERROR(FlushBuffer()); 1088 } 1089 return Status::OK(); 1090 } 1091 position_ += data.size(); 1092 return Status::OK(); 1093 } 1094 1095 Status FileOutputBuffer::Close() { 1096 TF_RETURN_IF_ERROR(FlushBuffer()); 1097 return file_->Close(); 1098 } 1099 1100 Status FileOutputBuffer::FlushBuffer() { 1101 if (position_ > 0) { 1102 TF_RETURN_IF_ERROR(file_->Append(StringPiece(&buffer_[0], position_))); 1103 position_ = 0; 1104 } 1105 return Status::OK(); 1106 } 1107 1108 } // namespace tensorflow 1109