Home | History | Annotate | Download | only in tensor_bundle
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7 http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // A tensor bundle is a set of immutable persistent files storing a set of named
     17 // tensors.  It is designed for checkpointing TensorFlow tensors.
     18 //
     19 // The paths of the managed files share a common prefix; e.g., with the prefix:
     20 //   /fs/model/train/ckpt-step/ckpt
     21 //
     22 // the bundle may contain a metadata file, and sharded data files:
     23 //   /fs/model/train/ckpt-step/
     24 //       ckpt.index
     25 //       ckpt.data-00000-of-00020
     26 //       ckpt.data-00001-of-00020
     27 //       ...
     28 //       ckpt.data-00019-of-00020
     29 //
     30 // The ".index" file is a string-string immutable table
     31 // (tensorflow::table::Table).  Each key is a name of a tensor and its value is
     32 // a serialized BundleEntryProto.  Each BundleEntryProto describes the metadata
     33 // of a tensor: which of the "data" files contains the content of a tensor, the
     34 // offset into that file, checksum, some auxiliary data, etc.
     35 //
     36 // A tensor bundle can be accessed randomly using a BundleReader.  Usage:
     37 //
//   BundleReader reader(env, "/fs/model/train/ckpt-step/ckpt");
//   TF_CHECK_OK(reader.status());  // Construction fails silently otherwise.
//   TF_CHECK_OK(reader.Lookup("name", &tensor));
     40 //
     41 // A tensor bundle can be built using BundleWriter.  Each BundleWriter builds a
     42 // single data file bundle.  Multiple bundles can then be merged by
// MergeBundles() without reading and writing large chunks of data: it reads the
// metadata files and outputs a single merged metadata file.  Typical usage:
     45 //
     46 //   worker 0:
     47 //     BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker0-step");
     48 //     writer.Add(...);  // Adds the tensors on this worker.
     49 //     writer.Finish();  // Flushes.
     50 //   worker 1:
     51 //     BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker1-step");
     52 //     writer.Add(...);
     53 //     writer.Finish();
     54 //   worker 2:
     55 //     MergeBundles(env,
     56 //       {"/fs/model/train/ckpt-step/tmp/worker0-step",
     57 //        "/fs/model/train/ckpt-step/tmp/worker1-step"},
     58 //       "/fs/model/train/ckpt-step/ckpt" /* merged prefix */);
     59 //
     60 
     61 #ifndef TENSORFLOW_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
     62 #define TENSORFLOW_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
     63 
#include "tensorflow/core/protobuf/tensor_bundle.pb.h"

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/table.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"
#include "tensorflow/core/util/tensor_slice_set.h"
     83 
     84 namespace tensorflow {
     85 
     86 class FileOutputBuffer;
     87 
     88 // Versioning of the tensor bundle format.
     89 // Follows the same rules as 3p/tf/core/public/version.h.
     90 //
     91 // History:
     92 // 0. Any tensor bundles produced before this field was added.
     93 // 1. Added this field (2016-09-14).
     94 extern const int kTensorBundleMinProducer;
     95 extern const int kTensorBundleMinConsumer;
     96 extern const int kTensorBundleVersion;
     97 
     98 // The empty string, hence always the first key in the metadata table.  Its
     99 // corresponding value is a BundleHeaderProto.
    100 extern const char* const kHeaderEntryKey;
    101 
    102 // Builds a string-string table of tensor names to BundleEntryProto (metadata).
    103 //
    104 // On construction, attempts to create a directory given by the dirname of
    105 // "prefix", so "status()" must be checked before calling any member functions.
    106 //
    107 // All threads accessing the same BundleWriter must synchronize.
    108 class BundleWriter {
    109  public:
    110   struct Options {
    111     Options() {}
    112     // Alignment, in bytes, for tensor data.
    113     // Must be >= 1. The default size of 1 densely packs tensors.
    114     int data_alignment{1};
    115   };
    116   BundleWriter(Env* env, StringPiece prefix,
    117                const Options& options = Options());
    118 
    119   // Adds the tensor "val" under key "key".
    120   // Across calls "key" must be unique but can be added in any order.
    121   Status Add(StringPiece key, const Tensor& val);
    122 
    123   // Partitioned variables support.
    124   // A slice of a full tensor is stored in two entries in the metadata table:
    125   //
    126   //   full_tensor_key   -> BundleEntryProto, describing all stored slices
    127   //                        of this full tensor.  Does not append to the data
    128   //                        file.
    129   //   encoded slice key -> BundleEntryProto, describing one particular slice.
    130   //                        Appends values of this slice to the data file.
    131   //
    132   // Slices of a full tensor can be added in any order.
    133   //
    134   // If a full tensor has slices placed on N devices and N BundleWriter's are
    135   // concurrently used, the caller must use MergeBundles() to ensure that a
    136   // consistent entry for "full_tensor_key" is produced.
    137   //
    138   // Returns an error if the same slice is added the second time.
    139   Status AddSlice(StringPiece full_tensor_key,
    140                   const TensorShape& full_tensor_shape,
    141                   const TensorSlice& slice_spec, const Tensor& slice_tensor);
    142 
    143   // Finishes the writer and flushes.
    144   Status Finish() TF_MUST_USE_RESULT;
    145 
    146   Status status() const { return status_; }
    147 
    148  private:
    149   Env* const env_;  // Not owned.
    150   const Options options_;
    151   const string prefix_;
    152   const string tmp_metadata_path_;
    153   const string tmp_data_path_;
    154   std::unique_ptr<FileOutputBuffer> out_;
    155   int64 size_;  // Number of bytes written into out_.
    156   std::map<string, BundleEntryProto> entries_;
    157   Status status_;
    158 
    159   TF_DISALLOW_COPY_AND_ASSIGN(BundleWriter);
    160 };
    161 
    162 // Merges a set of bundles (given their prefixes) into a single bundle with the
    163 // given "merged_prefix".  The merged metadata is guaranteed to be consistent.
    164 //
    165 // If there are N bundles in "prefixes", during the merge the data files will be
    166 // renamed to contain a proper sharded file spec, with num_shards set to the sum
    167 // of num_shards across the N input bundles.
    168 //
    169 // The caller should only rely on the metadata file of the merged bundle to
    170 // query information about a tensor.  In particular, this function does not
    171 // guarantee not to re-order the input data files.
    172 //
    173 // Once merged, makes a best effort to delete the old metadata files.
    174 // Returns OK iff all bundles are successfully merged.
    175 Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
    176                     StringPiece merged_prefix);
    177 
    178 // On construction, silently attempts to read the metadata associated with
    179 // "prefix".  If caller intends to call any function afterwards, "status()"
    180 // must be checked.
    181 // All threads accessing the same BundleReader must synchronize.
    182 class BundleReader {
    183  public:
    184   BundleReader(Env* const env, StringPiece prefix);
    185   ~BundleReader();
    186 
    187   // Is ok() iff the reader construction is successful (completed the read of
    188   // the metadata).
    189   Status status() const { return status_; }
    190 
    191   // Queries whether the bundle contains an entry keyed by "key".  Calls Seek()
    192   // internally, so this call invalidates the reader's current position.
    193   // REQUIRES: status().ok()
    194   bool Contains(StringPiece key);
    195 
    196   // Looks up the dtype and the shape of the tensor keyed by "key".
    197   // REQUIRES: status().ok()
    198   Status LookupDtypeAndShape(StringPiece key, DataType* dtype,
    199                              TensorShape* shape) TF_MUST_USE_RESULT;
    200 
    201   // Looks up the shape of the tensor keyed by "key".
    202   // Clears "shape" if not found.
    203   // REQUIRES: status().ok()
    204   Status LookupTensorShape(StringPiece key,
    205                            TensorShape* shape) TF_MUST_USE_RESULT;
    206 
    207   // Looks up the tensor keyed by "key".  If "key" refers to a partitioned
    208   // tensor, attempts to look up the full contents using all stored slices.
    209   //
    210   // Caller must make sure "val" has the same shape and dtype as the
    211   // corresponding contents, so that its buffer can be filled without needing
    212   // extra allocation.  These can be queried via "LookupDtypeAndShape()".
    213   //
    214   // On error, "val" may contain nonsense data.  Returns a NotFound error if
    215   // tensor keyed by "key" does not exist in this bundle.
    216   //
    217   // Validates the stored crc32c checksum against the restored bytes.
    218   // REQUIRES: status().ok()
    219   Status Lookup(StringPiece key, Tensor* val) TF_MUST_USE_RESULT;
    220 
    221   // Looks up the tensor pointed to by the internal iterator.
    222   //
    223   // On error, "val" may contain nonsense data.
    224   //
    225   // Validates the stored crc32c checksum against the restored bytes.
    226   // REQUIRES: status().ok() && Valid()
    227   Status ReadCurrent(Tensor* val) TF_MUST_USE_RESULT;
    228 
    229   // Looks up the slices of the tensor keyed by "key".  On OK, "slices"
    230   // is non-empty if and only if the tensor is a partitioned tensor.
    231   //
    232   // Warning - there is no guaranteed ordering for the returned slices, so
    233   // a slice with a larger start index in some dimension could come before
    234   // another slice with a smaller start index in the same dimension.
    235   // REQUIRES: status().ok()
    236   Status LookupTensorSlices(StringPiece key, std::vector<TensorSlice>* slices)
    237       TF_MUST_USE_RESULT;
    238 
    239   // Looks up a specific slice of a partitioned tensor.
    240   // It is only required that the stored slices cover the requested slice,
    241   // namely "slice_spec" is a subset of the union of the stored slices.
    242   // REQUIRES: status().ok()
    243   Status LookupSlice(StringPiece full_tensor_key, const TensorSlice& slice_spec,
    244                      Tensor* val) TF_MUST_USE_RESULT;
    245 
    246   // Seeks to the first position in the bundle whose key is no less than "key".
    247   // REQUIRES: status().ok()
    248   void Seek(StringPiece key) { return iter_->Seek(key); }
    249   // Moves to the next position in the bundle.
    250   // REQUIRES: status().ok()
    251   void Next() const { iter_->Next(); }
    252   // Returns true iff the reader is positioned to a key/val pair.
    253   // REQUIRES: status().ok()
    254   bool Valid() const { return iter_->Valid(); }
    255 
    256   // Returns the key at the current position.
    257   // REQUIRES: status().ok() && Valid()
    258   StringPiece key() const { return iter_->key(); }
    259   // Returns the raw value at the current position.
    260   // REQUIRES: status().ok() && Valid()
    261   StringPiece value() const { return iter_->value(); }
    262 
    263   string DebugString();
    264 
    265  private:
    266   // Seeks for "key" and reads the metadata proto.
    267   // On non-OK return, clears "entry" for the caller.
    268   // REQUIRES: status().ok()
    269   Status GetBundleEntryProto(StringPiece key,
    270                              BundleEntryProto* entry) TF_MUST_USE_RESULT;
    271 
    272   // Reads the tensor value described by the metadata proto "entry".
    273   // Usage for "val" follows the comment of "Lookup()".
    274   Status GetValue(const BundleEntryProto& entry,
    275                   Tensor* val) TF_MUST_USE_RESULT;
    276 
    277   // Reads the slice described by "slice_spec".  The corresponding full tensor
    278   // has key "ful_tensor_key" and metadata proto "full_tensor_entry".
    279   // REQUIRES: full_tensor_entry.slices_size() > 0
    280   Status GetSliceValue(StringPiece full_tensor_key,
    281                        const BundleEntryProto& full_tensor_entry,
    282                        const TensorSlice& slice_spec,
    283                        Tensor* val) TF_MUST_USE_RESULT;
    284 
    285   Env* env_;  // Not owned.
    286   const string prefix_;
    287 
    288   Status status_;
    289   RandomAccessFile* metadata_;  // Owned.
    290   table::Table* table_;
    291   table::Iterator* iter_;
    292   // Owned the InputBuffer objects and their underlying RandomAccessFile's.
    293   std::unordered_map<int32, io::InputBuffer*> data_;
    294 
    295   // Maps each partitioned tensor's key to its stored slices (represented in a
    296   // TensorSliceSet).  Populated on-demand.
    297   std::unordered_map<string, checkpoint::TensorSliceSet*> tensor_slices_;
    298 
    299   // Expected number of data file shards in the bundle.  Extracted by reading
    300   // the header entry in the metadata table.
    301   int num_shards_;
    302 
    303   friend class TensorBundleAlignmentTest;  // For testing data alignment.
    304 
    305   TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);
    306 };
    307 
    308 // A buffering wrapper for a WritableFile.  Useful if the caller wishes to issue
    309 // small writes to a file (e.g. writing out a list of small varints).
    310 // External synchronization must be used in the presence of concurrent callers.
    311 class FileOutputBuffer {
    312  public:
    313   FileOutputBuffer(WritableFile* file, size_t buffer_size)
    314       : file_(file), position_(0), buffer_size_(buffer_size) {
    315     DCHECK_GT(buffer_size, 0);
    316     buffer_.resize(buffer_size);
    317   }
    318   ~FileOutputBuffer();
    319 
    320   // Buffered append.
    321   Status Append(StringPiece data);
    322 
    323   // Returns the running crc32c checksum of all currently appended bytes.
    324   uint32 crc32c() { return crc32c_; }
    325   // Clears the running crc32c checksum.
    326   void clear_crc32c() { crc32c_ = 0; }
    327 
    328   // Appends the buffered data, then closes the underlying file.
    329   Status Close();
    330 
    331  private:
    332   // Appends the buffered data to the underlying file. Does NOT flush the file.
    333   Status FlushBuffer();
    334 
    335   WritableFile* file_;  // Owned.
    336 
    337   // buffer_[0, position_) holds the buffered data not yet appended to the
    338   // underlying file.
    339   size_t position_;
    340   const size_t buffer_size_;
    341   std::vector<char> buffer_;
    342 
    343   // Checksum of all appended bytes since construction or last clear_crc32c().
    344   uint32 crc32c_ = 0;
    345 };
    346 
    347 }  // namespace tensorflow
    348 
    349 #endif  // TENSORFLOW_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
    350