Home | History | Annotate | Download | only in tensor_bundle
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7 http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
     17 
     18 #include <algorithm>
     19 #include <cstdlib>
     20 #include <cstring>
     21 #include <memory>
     22 #include <utility>
     23 
     24 #include "tensorflow/core/framework/register_types.h"
     25 #include "tensorflow/core/framework/tensor.pb.h"
     26 #include "tensorflow/core/framework/tensor_shape.pb_text.h"
     27 #include "tensorflow/core/framework/tensor_shape.pb.h"
     28 #include "tensorflow/core/framework/types.h"
     29 #include "tensorflow/core/framework/types.pb_text.h"
     30 #include "tensorflow/core/framework/variant.h"
     31 #include "tensorflow/core/framework/variant_op_registry.h"
     32 #include "tensorflow/core/framework/variant_tensor_data.h"
     33 #include "tensorflow/core/framework/versions.h"
     34 #include "tensorflow/core/framework/versions.pb.h"
     35 #include "tensorflow/core/lib/core/coding.h"
     36 #include "tensorflow/core/lib/core/errors.h"
     37 #include "tensorflow/core/lib/gtl/map_util.h"
     38 #include "tensorflow/core/lib/gtl/stl_util.h"
     39 #include "tensorflow/core/lib/hash/crc32c.h"
     40 #include "tensorflow/core/lib/io/path.h"
     41 #include "tensorflow/core/lib/io/table_builder.h"
     42 #include "tensorflow/core/lib/random/random.h"
     43 #include "tensorflow/core/lib/strings/stringprintf.h"
     44 #include "tensorflow/core/util/saved_tensor_slice_util.h"
     45 #include "tensorflow/core/util/tensor_slice_util.h"
     46 
     47 namespace tensorflow {
     48 
     49 // Versioning of the tensor bundle format.
     50 const int kTensorBundleMinProducer = 0;
     51 const int kTensorBundleMinConsumer = 0;
     52 const int kTensorBundleVersion = 1;
     53 
     54 // Size of our input buffer for streaming reads
     55 static const int kBufferSize = 1024 * 1024;
     56 
     57 // Key to the special BundleHeaderProto entry.  Do not change this, as clients
     58 // can make the assumption that the header is always the first entry in the
     59 // bundle.
     60 const char* const kHeaderEntryKey = "";
     61 
     62 namespace {
     63 
     64 // Reads "num_elements" string elements from file[offset, offset+size) into the
     65 // length-N "destination".  Discards the original content of "destination".
     66 //
     67 // Checksums the string lengths (as restored uint32, not varint32 bytes) and
     68 // string bytes, and stores it into "actual_crc32c".
     69 Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
     70                         size_t offset, size_t size, string* destination,
     71                         uint32* actual_crc32c) {
     72   if (size == 0) return Status::OK();
     73   CHECK_GT(size, 0);
     74 
     75   // Reads "num_elements" varint32's from "buffered_file".
     76   TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
     77   std::vector<uint32> string_lengths(num_elements);
     78   for (size_t i = 0; i < num_elements; ++i) {
     79     TF_RETURN_IF_ERROR(buffered_file->ReadVarint32(&string_lengths[i]));
     80   }
     81   if (offset + size < buffered_file->Tell()) {
     82     return errors::DataLoss("String lengths longer than expected offset ",
     83                             offset + size);
     84   }
     85   *actual_crc32c =
     86       crc32c::Value(reinterpret_cast<const char*>(string_lengths.data()),
     87                     sizeof(uint32) * num_elements);
     88 
     89   // Reads the length-checksum.
     90   uint32 length_checksum = 0;
     91   size_t unused_bytes_read = 0;
     92   TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
     93       sizeof(uint32), reinterpret_cast<char*>(&length_checksum),
     94       &unused_bytes_read));
     95   if (crc32c::Unmask(length_checksum) != *actual_crc32c) {
     96     return errors::DataLoss(
     97         "The length checksum does not match: expected ",
     98         strings::Printf("%08u", crc32c::Unmask(length_checksum)),
     99         " but actual is ", strings::Printf("%08u", *actual_crc32c));
    100   }
    101   *actual_crc32c =
    102       crc32c::Extend(*actual_crc32c, reinterpret_cast<char*>(&length_checksum),
    103                      sizeof(uint32));
    104 
    105   // Reads the actual string bytes.
    106   for (size_t i = 0; i < num_elements; ++i) {
    107     const uint32 string_length = string_lengths[i];
    108     string* buffer = &destination[i];
    109 
    110     buffer->resize(string_length);
    111     size_t bytes_read = 0;
    112     TF_RETURN_IF_ERROR(
    113         buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
    114     *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer->data(), bytes_read);
    115   }
    116   return Status::OK();
    117 }
    118 
    119 Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret,
    120                          size_t offset, size_t size, uint32* actual_crc32c) {
    121   // On-disk format:
    122   //   [varint64 len1][bytes variant1][4 byte checksum]
    123   //   ..
    124   //   [varint64 lenN][bytes variantN][4 byte checksum]
    125   // Var "crc32c" checksums all the lens, variant bytes, individual variant
    126   // checksums (as uint32, not varint32 bytes).
    127   if (size == 0) return Status::OK();
    128   size_t num_elements = ret->NumElements();
    129 
    130   // Reads the actual string bytes.
    131   TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
    132   for (size_t i = 0; i < num_elements; ++i) {
    133     // Read the serialized variant length.
    134     uint64 string_length = 0;
    135     TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_length));
    136     *actual_crc32c = crc32c::Extend(
    137         *actual_crc32c, reinterpret_cast<const char*>(&string_length),
    138         sizeof(uint64));
    139     // Read the actual serialized variant.
    140     string buffer;
    141     buffer.resize(string_length);
    142     size_t bytes_read = 0;
    143     TF_RETURN_IF_ERROR(
    144         buffered_file->ReadNBytes(string_length, &buffer[0], &bytes_read));
    145     *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer.data(), bytes_read);
    146     VariantTensorDataProto proto;
    147     if (!proto.ParseFromString(buffer)) {
    148       return errors::DataLoss("Unable to parse VariantTensorDataProto from ",
    149                               "buffer of size ", string_length, ". ",
    150                               "Bundle entry offset: ", offset, " size: ", size);
    151     }
    152     Variant v = proto;
    153     if (!DecodeUnaryVariant(&v)) {
    154       return errors::Internal("Could not decode variant with type_name: \"",
    155                               v.TypeName(), "\".  Perhaps you forgot to ",
    156                               "register a decoder via ",
    157                               "REGISTER_UNARY_VARIANT_DECODE_FUNCTION?");
    158     }
    159 
    160     // Read the checksum.
    161     uint32 checksum = 0;
    162     size_t unused_bytes_read = 0;
    163     TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
    164         sizeof(uint32), reinterpret_cast<char*>(&checksum),
    165         &unused_bytes_read));
    166     if (crc32c::Unmask(checksum) != *actual_crc32c) {
    167       return errors::DataLoss(
    168           "The checksum after Variant ", i, " does not match.",
    169           " Expected: ", strings::Printf("%08u", crc32c::Unmask(checksum)),
    170           " Actual: ", strings::Printf("%08u", *actual_crc32c));
    171     }
    172     *actual_crc32c = crc32c::Extend(
    173         *actual_crc32c, reinterpret_cast<char*>(&checksum), sizeof(uint32));
    174 
    175     ret->flat<Variant>()(i) = std::move(v);
    176   }
    177 
    178   return Status::OK();
    179 }
    180 
    181 char* GetBackingBuffer(const Tensor& val) {
    182   CHECK(DataTypeCanUseMemcpy(val.dtype())) << val.dtype();
    183   return const_cast<char*>(val.tensor_data().data());
    184 }
    185 
    186 string* GetStringBackingBuffer(const Tensor& val) {
    187   CHECK_EQ(DT_STRING, val.dtype());
    188   return const_cast<string*>(val.flat<string>().data());
    189 }
    190 
    191 Status ParseEntryProto(StringPiece key, StringPiece value,
    192                        protobuf::MessageLite* out) {
    193   if (!out->ParseFromArray(value.data(), value.size())) {
    194     return errors::DataLoss("Entry for key ", key, " not parseable.");
    195   }
    196   return Status::OK();
    197 }
    198 
    199 // Serializes the data bytes of the non-string tensor "val".  Discards the
    200 // original content of "bytes_written", and on OK updates it with number of
    201 // bytes written.
    202 // REQUIRES: val.dtype() != DT_STRING
    203 Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
    204                    size_t* bytes_written) {
    205   DCHECK_NE(val.dtype(), DT_STRING);
    206   DCHECK_NE(val.dtype(), DT_VARIANT);
    207   *bytes_written = val.TotalBytes();
    208   char* buf = GetBackingBuffer(val);
    209   VLOG(1) << "Appending " << *bytes_written << " bytes to file";
    210   return out->Append(StringPiece(buf, *bytes_written));
    211 }
    212 
    213 // Serializes string tensor "val".  "bytes_written" is treated in the same
    214 // fashion as WriteTensor().
    215 //
    216 // Checksums all bytes written and stores it into "crc32c".
    217 // REQUIRES: val.dtype() == DT_STRING
    218 Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
    219                          size_t* bytes_written, uint32* crc32c) {
    220   // On-disk format:
    221   //   [varint32 len0]..[varint32 lenL][4 byte cksum on lengths][string bytes]
    222   // Var "crc32c" checksums the string lengths (as uint32, not varint32 bytes),
    223   // the length-checksum, and all the string bytes.
    224   DCHECK_EQ(val.dtype(), DT_STRING);
    225   const string* strings = GetStringBackingBuffer(val);
    226 
    227   // Writes the varint lengths.
    228   string lengths;
    229   lengths.reserve(val.NumElements());  // At least 1 byte per element.
    230   *crc32c = 0;
    231   for (int64 i = 0; i < val.NumElements(); ++i) {
    232     const string* elem = &strings[i];
    233     DCHECK_EQ(elem->size(), static_cast<uint32>(elem->size()));
    234     const uint32 elem_size = static_cast<uint32>(elem->size());
    235 
    236     core::PutVarint32(&lengths, elem_size);
    237     *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
    238                              sizeof(uint32));
    239   }
    240   TF_RETURN_IF_ERROR(out->Append(lengths));
    241   *bytes_written = lengths.size();
    242 
    243   // Writes the length checksum.
    244   const uint32 length_checksum = crc32c::Mask(*crc32c);
    245   TF_RETURN_IF_ERROR(out->Append(StringPiece(
    246       reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
    247   *crc32c = crc32c::Extend(
    248       *crc32c, reinterpret_cast<const char*>(&length_checksum), sizeof(uint32));
    249   *bytes_written += sizeof(uint32);
    250 
    251   // Writes all the string bytes out.
    252   for (int64 i = 0; i < val.NumElements(); ++i) {
    253     const string* string = &strings[i];
    254     TF_RETURN_IF_ERROR(out->Append(*string));
    255     *bytes_written += string->size();
    256     *crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
    257   }
    258   return Status::OK();
    259 }
    260 
    261 Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out,
    262                           size_t* bytes_written, uint32* crc32c) {
    263   // On-disk format:
    264   //   [varint64 len1][bytes variant1][4 byte checksum]
    265   //   ..
    266   //   [varint64 lenN][bytes variantN][4 byte checksum]
    267   // Var "crc32c" checksums all the lens, variant bytes, individual variant
    268   // checksums (as uint32, not varint32 bytes).
    269   DCHECK_EQ(val.dtype(), DT_VARIANT);
    270 
    271   *crc32c = 0;
    272   *bytes_written = 0;
    273   for (int64 i = 0; i < val.NumElements(); ++i) {
    274     VariantTensorData data;
    275     val.flat<Variant>()(i).Encode(&data);
    276     VariantTensorDataProto proto;
    277     data.ToProto(&proto);
    278     string elem;
    279     proto.SerializeToString(&elem);
    280 
    281     // Write the length of the serialized variant.
    282     DCHECK_EQ(elem.size(), static_cast<uint64>(elem.size()));
    283     const auto elem_size = static_cast<uint64>(elem.size());
    284     string len;
    285     core::PutVarint64(&len, elem_size);
    286     TF_RETURN_IF_ERROR(out->Append(len));
    287     *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
    288                              sizeof(uint64));
    289     *bytes_written += len.size();
    290 
    291     // Write the serialized variant.
    292     TF_RETURN_IF_ERROR(out->Append(elem));
    293     *crc32c = crc32c::Extend(*crc32c, elem.data(), elem.size());
    294     *bytes_written += elem.size();
    295 
    296     // Write the checksum.
    297     const uint32 length_checksum = crc32c::Mask(*crc32c);
    298     TF_RETURN_IF_ERROR(out->Append(StringPiece(
    299         reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
    300     *crc32c =
    301         crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&length_checksum),
    302                        sizeof(uint32));
    303     *bytes_written += sizeof(uint32);
    304   }
    305 
    306   return Status::OK();
    307 }
    308 
    309 // Returns whether "slice_spec" is a full slice, with respect to the full shape.
    310 //
    311 // This can happen say, when "slice_spec" is
    312 // "TensorSlice(full_tensor_shape.dims())", or when it is "TensorSlice({{0,
    313 // dim(0)}, ..., {0, dim(N)}})" -- a degenerate case we need to guard against.
    314 bool IsFullSlice(const TensorSlice& slice_spec,
    315                  const TensorShape& full_tensor_shape) {
    316   if (slice_spec.IsFull()) {
    317     return true;
    318   } else {
    319     TensorShape sliced_shape;
    320     slice_spec.SliceTensorShape(full_tensor_shape, &sliced_shape).IgnoreError();
    321     return sliced_shape == full_tensor_shape;
    322   }
    323 }
    324 
    325 Status CorruptFileError(const Status& in_status, const string& filename,
    326                         const string& detail) {
    327   if (in_status.ok()) {
    328     return errors::Internal("Unable to read file (", filename,
    329                             "). Perhaps the file is corrupt or was produced by "
    330                             "a newer version of TensorFlow with format changes "
    331                             "(",
    332                             detail, ")");
    333   }
    334   return Status(
    335       in_status.code(),
    336       strings::StrCat("Unable to read file (", filename,
    337                       "). Perhaps the file is corrupt or was produced by a "
    338                       "newer version of TensorFlow with format changes (",
    339                       detail, "): ", in_status.error_message()));
    340 }
    341 
    342 table::Options TableBuilderOptions() {
    343   table::Options o;
    344   // Compressed tables cannot be read by TensorFlow releases prior to 1.1.
    345   // To smoothen the transition, compressed writes are disabled for now
    346   // (version 1.2) with the intention that they will be enabled again at
    347   // some point (perhaps the 1.3 release?).
    348   o.compression = table::kNoCompression;
    349   return o;
    350 }
    351 
    352 // Writes zeros to output buffer to align the next write to the requested
    353 // alignment. "size" is the current size of the buffer and is updated to the
    354 // new size.
    355 Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) {
    356   int bytes_over = *size % alignment;
    357   if (bytes_over == 0) {
    358     return Status::OK();
    359   }
    360   int bytes_to_write = alignment - bytes_over;
    361   Status status = out->Append(string(bytes_to_write, '\0'));
    362   if (status.ok()) {
    363     *size += bytes_to_write;
    364   }
    365   return status;
    366 }
    367 
    368 }  // namespace
    369 
    370 BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
    371     : env_(env),
    372       options_(options),
    373       prefix_(prefix.ToString()),
    374       tmp_metadata_path_(strings::StrCat(MetaFilename(prefix_), ".tempstate",
    375                                          random::New64())),
    376       tmp_data_path_(strings::StrCat(DataFilename(prefix_, 0, 1), ".tempstate",
    377                                      random::New64())),
    378       out_(nullptr),
    379       size_(0) {
    380   status_ = env_->CreateDir(io::Dirname(prefix_).ToString());
    381   if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
    382     return;
    383   }
    384   const string filename = DataFilename(prefix_, 0, 1);
    385   std::unique_ptr<WritableFile> wrapper;
    386   status_ = env_->NewWritableFile(tmp_data_path_, &wrapper);
    387   if (!status_.ok()) return;
    388   out_ = std::unique_ptr<FileOutputBuffer>(
    389       new FileOutputBuffer(wrapper.release(), 8 << 20 /* 8MB write buffer */));
    390 
    391   VLOG(1) << "Writing to file " << tmp_data_path_;
    392 }
    393 
    394 Status BundleWriter::Add(StringPiece key, const Tensor& val) {
    395   if (!status_.ok()) return status_;
    396   CHECK_NE(key, kHeaderEntryKey);
    397   const string key_string = key.ToString();
    398   if (entries_.find(key_string) != entries_.end()) {
    399     status_ = errors::InvalidArgument("Adding duplicate key: ", key);
    400     return status_;
    401   }
    402 
    403   BundleEntryProto* entry = &entries_[key_string];
    404   entry->set_dtype(val.dtype());
    405   val.shape().AsProto(entry->mutable_shape());
    406   entry->set_shard_id(0);
    407   entry->set_offset(size_);
    408 
    409   // Updates the data file.
    410   size_t data_bytes_written = 0;
    411   uint32 crc32c = 0;
    412   out_->clear_crc32c();
    413   if (val.dtype() == DT_STRING) {
    414     status_ = WriteStringTensor(val, out_.get(), &data_bytes_written, &crc32c);
    415   } else if (val.dtype() == DT_VARIANT) {
    416     status_ = WriteVariantTensor(val, out_.get(), &data_bytes_written, &crc32c);
    417   } else {
    418     status_ = WriteTensor(val, out_.get(), &data_bytes_written);
    419     crc32c = out_->crc32c();
    420   }
    421 
    422   if (status_.ok()) {
    423     entry->set_size(data_bytes_written);
    424     entry->set_crc32c(crc32c::Mask(crc32c));
    425     size_ += data_bytes_written;
    426     status_ = PadAlignment(out_.get(), options_.data_alignment, &size_);
    427   }
    428   return status_;
    429 }
    430 
    431 Status BundleWriter::AddSlice(StringPiece full_tensor_key,
    432                               const TensorShape& full_tensor_shape,
    433                               const TensorSlice& slice_spec,
    434                               const Tensor& slice_tensor) {
    435   if (!status_.ok()) return status_;
    436   CHECK_NE(full_tensor_key, kHeaderEntryKey);
    437 
    438   // If just a singleton full slice, use the regular Add() to be more efficient.
    439   if (IsFullSlice(slice_spec, full_tensor_shape)) {
    440     return Add(full_tensor_key, slice_tensor);
    441   }
    442 
    443   // Inserts/updates the full tensor's metadata entry.
    444   //
    445   // In the case of a sharded save, MergeBundles() is responsible for merging
    446   // the "slices" field of multiple metadata entries corresponding to the same
    447   // full tensor.
    448   const string full_tensor_key_string = full_tensor_key.ToString();
    449   BundleEntryProto* full_entry = &entries_[full_tensor_key_string];
    450   if (full_entry->dtype() != DT_INVALID) {
    451     CHECK_EQ(full_entry->dtype(), slice_tensor.dtype());
    452   }
    453   if (full_entry->has_shape()) {
    454     CHECK(TensorShape(full_entry->shape()) == full_tensor_shape);
    455   }
    456 
    457   // Populates dtype, shape, and slices.  Intentionally leaving out shard_id and
    458   // offset, which do not make sense for this full tensor entry.
    459   full_entry->set_dtype(slice_tensor.dtype());
    460   full_tensor_shape.AsProto(full_entry->mutable_shape());
    461   TensorSliceProto* slice_proto = full_entry->add_slices();
    462   slice_spec.AsProto(slice_proto);
    463 
    464   // The slice itself is handled by a regular Add(), which includes adding its
    465   // own metadata entry, and writing out the slice's values.
    466   const string slice_name =
    467       checkpoint::EncodeTensorNameSlice(full_tensor_key_string, slice_spec);
    468   status_ = Add(slice_name, slice_tensor);
    469   return status_;
    470 }
    471 
    472 // TODO(zongheng): on metadata write failure or !status_.ok(), consider removing
    473 // the orphaned data file.
    474 Status BundleWriter::Finish() {
    475   if (out_) {
    476     status_.Update(out_->Close());
    477     out_ = nullptr;
    478     if (status_.ok()) {
    479       status_ = Env::Default()->RenameFile(tmp_data_path_,
    480                                            DataFilename(prefix_, 0, 1));
    481     } else {
    482       Env::Default()->DeleteFile(tmp_data_path_).IgnoreError();
    483     }
    484   }
    485   if (!status_.ok()) return status_;
    486   // Build key -> BundleEntryProto table.
    487   std::unique_ptr<WritableFile> file;
    488   status_ = env_->NewWritableFile(tmp_metadata_path_, &file);
    489   if (!status_.ok()) return status_;
    490   {
    491     // N.B.: the default use of Snappy compression may not be supported on all
    492     // platforms (e.g. Android).  The metadata file is small, so this is fine.
    493     table::Options options;
    494     options.compression = table::kNoCompression;
    495     table::TableBuilder builder(options, file.get());
    496     // Header entry.
    497     BundleHeaderProto header;
    498     header.set_num_shards(1);
    499     header.set_endianness(BundleHeaderProto::LITTLE);
    500     if (!port::kLittleEndian) header.set_endianness(BundleHeaderProto::BIG);
    501     VersionDef* version = header.mutable_version();
    502     version->set_producer(kTensorBundleVersion);
    503     version->set_min_consumer(kTensorBundleMinConsumer);
    504 
    505     builder.Add(kHeaderEntryKey, header.SerializeAsString());
    506 
    507     // All others.
    508     for (const auto& p : entries_) {
    509       builder.Add(p.first, p.second.SerializeAsString());
    510     }
    511     status_ = builder.Finish();
    512   }
    513   status_.Update(file->Close());
    514   if (!status_.ok()) {
    515     Env::Default()->DeleteFile(tmp_metadata_path_).IgnoreError();
    516     return status_;
    517   } else {
    518     status_ =
    519         Env::Default()->RenameFile(tmp_metadata_path_, MetaFilename(prefix_));
    520     if (!status_.ok()) return status_;
    521   }
    522   status_ = errors::Internal("BundleWriter is closed");
    523   return Status::OK();
    524 }
    525 
    526 // Merging tensor bundles.
    527 
    528 // Accumulator of metadata states during a merge.
    529 struct MergeState {
    530   // Accumulated from the header entries.
    531   int num_shards = 0;
    532 
    533   // Derives "endianness" and "version" from the first bundle merged (hence the
    534   // "seen_first_bundle" guard).  The two fields must be the same for all
    535   // bundles in a merge.
    536   bool seen_first_bundle = false;
    537   BundleHeaderProto_Endianness endianness;
    538   VersionDef version;
    539 
    540   // Tensor key -> BundleEntryProto.
    541   std::map<string, BundleEntryProto> entries;
    542   // Data file path -> new shard id in the final merged bundle.
    543   std::unordered_map<string, int32> shard_ids;
    544 };
    545 
    546 // Merges entries of "prefix" into the accumulator state "merge".
    547 // Returns OK iff the merge succeeds.
    548 static Status MergeOneBundle(Env* env, StringPiece prefix,
    549                              MergeState* merge_state) {
    550   VLOG(1) << "Merging bundle:" << prefix;
    551   const string filename = MetaFilename(prefix);
    552   uint64 file_size;
    553   TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
    554   std::unique_ptr<RandomAccessFile> file;
    555   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
    556 
    557   table::Table* table = nullptr;
    558   TF_RETURN_IF_ERROR(
    559       table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table));
    560   std::unique_ptr<table::Table> table_deleter(table);
    561   std::unique_ptr<table::Iterator> iter(table->NewIterator());
    562 
    563   int num_shards;
    564   // Process header.
    565   {
    566     iter->Seek(kHeaderEntryKey);
    567     if (!iter->Valid()) {
    568       return CorruptFileError(iter->status(), filename,
    569                               "failed to seek to header entry");
    570     }
    571     BundleHeaderProto header;
    572     Status s = ParseEntryProto(iter->key(), iter->value(), &header);
    573     if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");
    574 
    575     merge_state->num_shards += header.num_shards();
    576     if (!merge_state->seen_first_bundle) {
    577       merge_state->seen_first_bundle = true;
    578       merge_state->endianness = header.endianness();
    579       merge_state->version = header.version();
    580     } else {
    581       // Validates "endianness".
    582       if (merge_state->endianness != header.endianness()) {
    583         return errors::InvalidArgument(
    584             "Merging bundles with conflicting endianness; inputs corrupted?");
    585       }
    586       // Validates "version".
    587       string curr_version, merge_version;
    588       header.version().SerializeToString(&curr_version);
    589       merge_state->version.SerializeToString(&merge_version);
    590       if (curr_version != merge_version) {
    591         return errors::InvalidArgument(
    592             "Merging bundles with different format versions: merged ",
    593             merge_version, " vs. curr ", curr_version);
    594       }
    595     }
    596     num_shards = header.num_shards();
    597     iter->Next();
    598   }
    599 
    600   // Loops through the non-header to-merge entries.
    601   BundleEntryProto to_merge_entry;
    602   for (; iter->Valid(); iter->Next()) {
    603     const string key = iter->key().ToString();
    604     const auto entry_iter = merge_state->entries.find(key);
    605 
    606     // Illegal: the duplicated entry is a non-slice tensor.
    607     if (entry_iter != merge_state->entries.end() &&
    608         entry_iter->second.slices().empty()) {
    609       return errors::InvalidArgument(
    610           "Duplicate tensor keyed by ", key,
    611           " encountered, when merging prefix: ", prefix);
    612     }
    613 
    614     TF_RETURN_IF_ERROR(
    615         ParseEntryProto(iter->key(), iter->value(), &to_merge_entry));
    616 
    617     // The duplicated entry holds metadata for a sliced full tensor.
    618     // Allows the duplication and merges "slices".
    619     if (entry_iter != merge_state->entries.end()) {
    620       BundleEntryProto& existing_entry = entry_iter->second;
    621       if (to_merge_entry.slices().empty()) {
    622         return errors::Internal(
    623             "Duplicate tensor keyed by ", key,
    624             "; attempting to merge in a non-slice bundle entry");
    625       }
    626       // Only needs merge the "slices" field (and validate dtype/shape).
    627       for (int i = 0; i < to_merge_entry.slices_size(); ++i) {
    628         TensorSliceProto* slot = existing_entry.add_slices();
    629         *slot = to_merge_entry.slices(i);
    630       }
    631       CHECK_EQ(existing_entry.dtype(), to_merge_entry.dtype());
    632       CHECK(TensorShape(existing_entry.shape()) ==
    633             TensorShape(to_merge_entry.shape()));
    634       continue;
    635     }
    636 
    637     // Key doesn't duplicate: a fresh tensor/slice entry.
    638     auto result = merge_state->shard_ids.insert(
    639         {DataFilename(prefix, to_merge_entry.shard_id(), num_shards),
    640          merge_state->shard_ids.size()});
    641     to_merge_entry.set_shard_id(result.first->second);
    642     merge_state->entries[key] = to_merge_entry;
    643   }
    644   return Status::OK();
    645 }
    646 
    647 Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
    648                     StringPiece merged_prefix) {
    649   // Merges all metadata tables.
    650   // TODO(zhifengc): KeyValue sorter if it becomes too big.
    651   MergeState merge;
    652   Status status = env->CreateDir(io::Dirname(merged_prefix).ToString());
    653   if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
    654   for (int i = 0; i < prefixes.size(); ++i) {
    655     TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge));
    656   }
    657 
    658   // Renames data files to contain the merged bundle prefix.
    659   for (const auto& p : merge.shard_ids) {
    660     VLOG(1) << "Renaming " << p.first << " to "
    661             << DataFilename(merged_prefix, p.second, merge.shard_ids.size());
    662     TF_RETURN_IF_ERROR(env->RenameFile(
    663         p.first,
    664         DataFilename(merged_prefix, p.second, merge.shard_ids.size())));
    665   }
    666 
    667   // Writes the final metadata table under the merged prefix.
    668   std::unique_ptr<WritableFile> merged_metadata;
    669   TF_RETURN_IF_ERROR(
    670       env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata));
    671   {
    672     table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get());
    673     // Header entry.
    674     BundleHeaderProto header;
    675     header.set_num_shards(merge.num_shards);
    676     header.set_endianness(merge.endianness);
    677     *header.mutable_version() = merge.version;
    678     builder.Add(kHeaderEntryKey, header.SerializeAsString());
    679     // All others.
    680     for (const auto& p : merge.entries) {
    681       builder.Add(p.first, p.second.SerializeAsString());
    682     }
    683     status = builder.Finish();
    684   }
    685   status.Update(merged_metadata->Close());
    686   if (!status.ok()) return status;
    687   VLOG(1) << "Merged bundles to:" << merged_prefix;
    688 
    689   // Cleanup: best effort based and ignores errors.
    690   for (const string& prefix : prefixes) {
    691     env->DeleteFile(MetaFilename(prefix)).IgnoreError();
    692   }
    693   return status;
    694 }
    695 
    696 // Interface for reading a tensor bundle.
    697 
    698 BundleReader::BundleReader(Env* env, StringPiece prefix)
    699     : env_(env),
    700       prefix_(prefix.ToString()),
    701       metadata_(nullptr),
    702       table_(nullptr),
    703       iter_(nullptr) {
    704   const string filename = MetaFilename(prefix_);
    705   uint64 file_size;
    706   status_ = env_->GetFileSize(filename, &file_size);
    707   if (!status_.ok()) return;
    708 
    709   // Opens the metadata table.
    710   std::unique_ptr<RandomAccessFile> wrapper;
    711   status_ = env_->NewRandomAccessFile(filename, &wrapper);
    712   if (!status_.ok()) return;
    713   metadata_ = wrapper.release();
    714   status_ = table::Table::Open(table::Options(), metadata_, file_size, &table_);
    715   if (!status_.ok()) return;
    716   iter_ = table_->NewIterator();
    717 
    718   // Reads "num_shards_" from the first entry.
    719   iter_->Seek(kHeaderEntryKey);
    720   if (!iter_->Valid()) {
    721     status_ = CorruptFileError(iter_->status(), filename,
    722                                "failed to seek to header entry");
    723     return;
    724   }
    725   BundleHeaderProto header;
    726   status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
    727   if (!status_.ok()) {
    728     status_ = CorruptFileError(status_, filename, "unable to parse header");
    729     return;
    730   }
    731   num_shards_ = header.num_shards();
    732   if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
    733       (header.endianness() == BundleHeaderProto::LITTLE &&
    734        !port::kLittleEndian)) {
    735     status_ = errors::Unimplemented(
    736         "Reading a bundle with different endianness from the reader");
    737     return;
    738   }
    739   status_ = CheckVersions(header.version(), kTensorBundleVersion,
    740                           kTensorBundleMinProducer, "Checkpoint", "checkpoint");
    741 }
    742 
    743 BundleReader::~BundleReader() {
    744   delete metadata_;
    745   delete iter_;
    746   delete table_;
    747   // InputBuffer does not own the underlying RandomAccessFile.
    748   for (auto pair : data_) {
    749     if (pair.second != nullptr && pair.second->file() != nullptr) {
    750       delete pair.second->file();
    751     }
    752   }
    753   gtl::STLDeleteValues(&data_);
    754   gtl::STLDeleteValues(&tensor_slices_);
    755 }
    756 
    757 Status BundleReader::GetBundleEntryProto(StringPiece key,
    758                                          BundleEntryProto* entry) {
    759   entry->Clear();
    760   TF_CHECK_OK(status_);
    761   Seek(key);
    762   if (!iter_->Valid() || iter_->key() != key) {
    763     return errors::NotFound("Key ", key, " not found in checkpoint");
    764   }
    765 
    766   BundleEntryProto entry_copy;
    767   TF_RETURN_IF_ERROR(
    768       ParseEntryProto(iter_->key(), iter_->value(), &entry_copy));
    769   if (!TensorShape::IsValid(entry_copy.shape())) {
    770     return errors::DataLoss("Invaid tensor shape: ", key, " ",
    771                             ProtoShortDebugString(entry_copy.shape()));
    772   }
    773 
    774   *entry = entry_copy;
    775   return Status::OK();
    776 }
    777 
    778 Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
    779   Tensor* ret = val;
    780   const TensorShape stored_shape(TensorShape(entry.shape()));
    781   if (val->NumElements() == 0) {
    782     ret = new Tensor(entry.dtype(), stored_shape);
    783   }
    784 
    785   // Validates the "size" field.
    786   if (entry.dtype() != DT_STRING && entry.dtype() != DT_VARIANT) {
    787     if (entry.size() != ret->TotalBytes()) {
    788       return errors::DataLoss("Invalid size in bundle entry: key ", key(),
    789                               "; stored size ", entry.size(),
    790                               "; expected size ", ret->TotalBytes());
    791     }
    792   } else if (entry.dtype() == DT_STRING) {
    793     // Relaxes the check for string tensors as follows:
    794     //   entry.size() == bytes(varint lengths) + bytes(data)
    795     //                >= NumElems + bytes(data), since size bytes(varint) >= 1.
    796     //   TotalBytes() == sizeof(string) * NumElems + bytes(data)
    797     // Since we don't know bytes(varint lengths), we just check an inequality.
    798     const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
    799                                sizeof(string) * ret->NumElements();
    800     if (entry.size() < lower_bound) {
    801       return errors::DataLoss("Invalid size in bundle entry: key ", key(),
    802                               "; stored size ", entry.size(),
    803                               "; expected size is at least ", lower_bound);
    804     }
    805   }
    806 
    807   // Open the data file if it has not been opened.
    808   io::InputBuffer* buffered_file = data_[entry.shard_id()];
    809   if (buffered_file == nullptr) {
    810     std::unique_ptr<RandomAccessFile> file = nullptr;
    811     TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
    812         DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
    813     buffered_file = new io::InputBuffer(file.release(), kBufferSize);
    814     // The InputBuffer and RandomAccessFile objects are both released in dtor.
    815     data_[entry.shard_id()] = buffered_file;
    816   }
    817   CHECK(buffered_file != nullptr);
    818 
    819   TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
    820   uint32 actual_crc32c = 0;
    821 
    822   if (DataTypeCanUseMemcpy(entry.dtype())) {
    823     char* backing_buffer = const_cast<char*>((ret->tensor_data().data()));
    824     size_t unused_bytes_read;
    825     if (entry.size() > kBufferSize) {
    826       StringPiece sp;
    827       TF_RETURN_IF_ERROR(buffered_file->file()->Read(
    828           entry.offset(), entry.size(), &sp, backing_buffer));
    829       if (sp.data() != backing_buffer) {
    830         memmove(backing_buffer, sp.data(), entry.size());
    831       }
    832     } else {
    833       TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer,
    834                                                    &unused_bytes_read));
    835     }
    836     actual_crc32c = crc32c::Value(backing_buffer, entry.size());
    837   } else if (entry.dtype() == DT_VARIANT) {
    838     // Relies on io::InputBuffer's buffering, because we issue many neighboring
    839     // reads for a single string tensor.
    840     TF_RETURN_IF_ERROR(ReadVariantTensor(buffered_file, ret, entry.offset(),
    841                                          entry.size(), &actual_crc32c));
    842   } else {
    843     // Relies on io::InputBuffer's buffering, because we issue many neighboring
    844     // reads for a single string tensor.
    845     TF_RETURN_IF_ERROR(ReadStringTensor(
    846         buffered_file, ret->NumElements(), entry.offset(), entry.size(),
    847         GetStringBackingBuffer(*ret), &actual_crc32c));
    848   }
    849   if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) {
    850     return errors::DataLoss(
    851         "Checksum does not match: stored ",
    852         strings::Printf("%08u", crc32c::Unmask(entry.crc32c())),
    853         " vs. calculated on the restored bytes ", actual_crc32c);
    854   }
    855 
    856   *val = *ret;
    857   if (ret != val) delete ret;
    858   return Status::OK();
    859 }
    860 
    861 Status BundleReader::Lookup(StringPiece key, Tensor* val) {
    862   CHECK(val != nullptr);
    863   BundleEntryProto entry;
    864   TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
    865 
    866   if (entry.slices().empty()) {
    867     return GetValue(entry, val);
    868   } else {
    869     return GetSliceValue(
    870         key, entry,
    871         /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
    872   }
    873 }
    874 
    875 Status BundleReader::ReadCurrent(Tensor* val) {
    876   CHECK(val != nullptr);
    877   BundleEntryProto entry;
    878   TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry));
    879   if (!TensorShape::IsValid(entry.shape())) {
    880     return errors::DataLoss("Invaid tensor shape: ", iter_->key(), " ",
    881                             ProtoShortDebugString(entry.shape()));
    882   }
    883 
    884   if (entry.slices().empty()) {
    885     return GetValue(entry, val);
    886   } else {
    887     return GetSliceValue(
    888         iter_->key(), entry,
    889         /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
    890   }
    891 }
    892 
    893 Status BundleReader::LookupTensorSlices(StringPiece key,
    894                                         std::vector<TensorSlice>* slices) {
    895   slices->clear();
    896   BundleEntryProto entry;
    897   TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
    898   slices->reserve(entry.slices_size());
    899   for (const auto& slice : entry.slices()) {
    900     slices->emplace_back(slice);
    901   }
    902   return Status::OK();
    903 }
    904 
    905 Status BundleReader::LookupSlice(StringPiece full_tensor_key,
    906                                  const TensorSlice& slice_spec, Tensor* val) {
    907   CHECK(val != nullptr);
    908   BundleEntryProto entry;
    909   TF_RETURN_IF_ERROR(GetBundleEntryProto(full_tensor_key, &entry));
    910   return GetSliceValue(full_tensor_key, entry, slice_spec, val);
    911 }
    912 
    913 Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
    914                                    const BundleEntryProto& full_tensor_entry,
    915                                    const TensorSlice& slice_spec, Tensor* val) {
    916   using checkpoint::RegisterTensorSlice;
    917   using checkpoint::TensorSliceSet;
    918   DCHECK_GE(full_tensor_entry.slices_size(), 0);
    919 
    920   const TensorShape full_shape(TensorShape(full_tensor_entry.shape()));
    921   std::vector<std::pair<TensorSlice, string>> details;
    922   const string full_tensor_key_string = full_tensor_key.ToString();
    923   const TensorSliceSet* tss =
    924       gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
    925 
    926   // Populates the "full tensor key -> TensorSliceSet" cache.
    927   if (tss == nullptr) {
    928     if (full_tensor_entry.slices().empty()) {
    929       // Special case: a writer has saved a tensor fully, but the reader wants
    930       // to read in slices.  We therefore register the full slice on-demand here
    931       // without further complicating the on-disk bundle format.
    932       TF_RETURN_IF_ERROR(RegisterTensorSlice(
    933           full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
    934           /* tag */ "",
    935           /* full slice */ TensorSlice(full_shape.dims()), &tensor_slices_));
    936     }
    937     for (const TensorSliceProto& slice : full_tensor_entry.slices()) {
    938       TF_RETURN_IF_ERROR(RegisterTensorSlice(
    939           full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
    940           /* tag */ "", TensorSlice(slice), &tensor_slices_));
    941     }
    942     tss = gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
    943     CHECK_NE(tss, nullptr);
    944   }
    945   if (!tss->QueryMeta(slice_spec, &details)) {
    946     return errors::InvalidArgument(
    947         "Does not have sufficient slices for partitioned tensor ",
    948         full_tensor_key,
    949         " to restore in slice_spec: ", slice_spec.DebugString());
    950   }
    951 
    952   // The union of the slices in "details" covers "slice_spec".  Performs the
    953   // copies from each.
    954   BundleEntryProto stored_slice_entry = full_tensor_entry;
    955   for (const auto& slice_tag_pair : details) {
    956     // Seeks for the stored slice.
    957     const TensorSlice& stored_slice = slice_tag_pair.first;
    958 
    959     // We already have the entry for the full tensor, so don't query again if
    960     // the slice is full.
    961     if (!stored_slice.IsFull()) {
    962       const string encoded_stored_slice_name =
    963           checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
    964                                             stored_slice);
    965       status_ =
    966           GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
    967       if (!status_.ok()) return status_;
    968     }
    969 
    970     // TODO(zongheng): should we take an OpKernelContext, so that we can call
    971     // allocate_temp()?  Note that without major refactorings to Saver, it's
    972     // hard for the caller of the tensor bundle module to allocate these
    973     // precisely-shaped scratch storage.
    974 
    975     // Optimization for the common case: the stored slice can be directly
    976     // copied to the destination without additional slicing. This is true when
    977     // either the slices are equal or when they are both full slices having the
    978     // same shape.
    979     TensorShape stored_slice_shape(stored_slice_entry.shape());
    980     if (stored_slice == slice_spec ||
    981         (stored_slice_shape == val->shape() &&
    982          IsFullSlice(stored_slice, stored_slice_shape) &&
    983          IsFullSlice(slice_spec, stored_slice_shape))) {
    984       VLOG(1) << "Optimized for common case: directly copying into "
    985                  "pre-allocated buffer; spec: "
    986               << slice_spec.DebugString();
    987       status_ = GetValue(stored_slice_entry, val);
    988       return status_;
    989     }
    990 
    991     Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
    992     status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
    993     if (!status_.ok()) return status_;
    994 
    995     // Copies the intersection over.
    996     const DataType common_dtype = full_tensor_entry.dtype();
    997     switch (common_dtype) {
    998 #define HANDLE_COPY(T)                                                 \
    999   case DataTypeToEnum<T>::value:                                       \
   1000     CHECK(CopyDataFromTensorSliceToTensorSlice(                        \
   1001         full_shape, stored_slice, slice_spec,                          \
   1002         stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
   1003     break;
   1004 
   1005       HANDLE_COPY(float)
   1006       HANDLE_COPY(double)
   1007       HANDLE_COPY(int32)
   1008       HANDLE_COPY(uint8)
   1009       HANDLE_COPY(int16)
   1010       HANDLE_COPY(int8)
   1011       HANDLE_COPY(complex64)
   1012       HANDLE_COPY(complex128)
   1013       HANDLE_COPY(int64)
   1014       HANDLE_COPY(bool)
   1015       HANDLE_COPY(qint32)
   1016       HANDLE_COPY(quint8)
   1017       HANDLE_COPY(qint8)
   1018       default:
   1019         return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
   1020                                        " not supported.");
   1021     }
   1022 #undef HANDLE_COPY
   1023   }
   1024   return Status::OK();
   1025 }
   1026 
   1027 bool BundleReader::Contains(StringPiece key) {
   1028   Seek(key);
   1029   return Valid() && (this->key() == key);
   1030 }
   1031 
   1032 Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype,
   1033                                          TensorShape* shape) {
   1034   BundleEntryProto entry;
   1035   TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
   1036   *dtype = entry.dtype();
   1037   *shape = TensorShape(entry.shape());
   1038   return Status::OK();
   1039 }
   1040 
   1041 Status BundleReader::LookupTensorShape(StringPiece key, TensorShape* shape) {
   1042   DataType ignored;
   1043   return LookupDtypeAndShape(key, &ignored, shape);
   1044 }
   1045 
   1046 string BundleReader::DebugString() {
   1047   // Format used below emulates that of TensorSliceReader::DebugString().
   1048   string shape_str;
   1049   BundleEntryProto entry;
   1050   Seek(kHeaderEntryKey);
   1051   for (Next(); Valid(); Next()) {
   1052     CHECK(entry.ParseFromArray(value().data(), value().size()));
   1053     if (entry.slices_size() > 0) continue;  // Slice of some partitioned var.
   1054 
   1055     strings::StrAppend(&shape_str, key(), " (",
   1056                        EnumName_DataType(entry.dtype()), ") ",
   1057                        TensorShape(entry.shape()).DebugString());
   1058     strings::StrAppend(&shape_str, "\n");
   1059   }
   1060   return shape_str;
   1061 }
   1062 
   1063 FileOutputBuffer::~FileOutputBuffer() { delete file_; }
   1064 
   1065 Status FileOutputBuffer::Append(StringPiece data) {
   1066   // In the below, it is critical to calculate the checksum on the actually
   1067   // copied bytes, not the source bytes.  This is because "data" typically
   1068   // points to tensor buffers, which may be concurrently written.
   1069   if (data.size() + position_ <= buffer_size_) {
   1070     // Can fit into the current buffer.
   1071     memcpy(&buffer_[position_], data.data(), data.size());
   1072     crc32c_ = crc32c::Extend(crc32c_, &buffer_[position_], data.size());
   1073   } else if (data.size() <= buffer_size_) {
   1074     // Cannot fit, but can fit after flushing.
   1075     TF_RETURN_IF_ERROR(FlushBuffer());
   1076     memcpy(&buffer_[0], data.data(), data.size());
   1077     crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], data.size());
   1078   } else {
   1079     // Cannot fit even after flushing.  So we break down "data" by chunk, and
   1080     // flush/checksum each chunk.
   1081     TF_RETURN_IF_ERROR(FlushBuffer());
   1082     for (size_t i = 0; i < data.size(); i += buffer_size_) {
   1083       const size_t nbytes = std::min(data.size() - i, buffer_size_);
   1084       memcpy(&buffer_[0], data.data() + i, nbytes);
   1085       crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], nbytes);
   1086       position_ = nbytes;
   1087       TF_RETURN_IF_ERROR(FlushBuffer());
   1088     }
   1089     return Status::OK();
   1090   }
   1091   position_ += data.size();
   1092   return Status::OK();
   1093 }
   1094 
   1095 Status FileOutputBuffer::Close() {
   1096   TF_RETURN_IF_ERROR(FlushBuffer());
   1097   return file_->Close();
   1098 }
   1099 
   1100 Status FileOutputBuffer::FlushBuffer() {
   1101   if (position_ > 0) {
   1102     TF_RETURN_IF_ERROR(file_->Append(StringPiece(&buffer_[0], position_)));
   1103     position_ = 0;
   1104   }
   1105   return Status::OK();
   1106 }
   1107 
   1108 }  // namespace tensorflow
   1109