Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/util/tensor_slice_writer.h"
     17 
     18 #include <utility>
     19 
     20 #include "tensorflow/core/framework/versions.pb.h"
     21 #include "tensorflow/core/lib/core/errors.h"
     22 #include "tensorflow/core/lib/io/table_builder.h"
     23 #include "tensorflow/core/lib/random/random.h"
     24 #include "tensorflow/core/lib/strings/strcat.h"
     25 #include "tensorflow/core/platform/env.h"
     26 #include "tensorflow/core/platform/logging.h"
     27 #include "tensorflow/core/public/version.h"
     28 #include "tensorflow/core/util/saved_tensor_slice_util.h"
     29 
     30 namespace tensorflow {
     31 
     32 namespace checkpoint {
     33 
     34 namespace {
     35 
     36 class TableBuilder : public TensorSliceWriter::Builder {
     37  public:
     38   TableBuilder(const string& name, WritableFile* f) : name_(name), file_(f) {
     39     table::Options option;
     40     option.compression = table::kNoCompression;
     41     builder_.reset(new table::TableBuilder(option, f));
     42   }
     43   void Add(StringPiece key, StringPiece val) override {
     44     builder_->Add(key, val);
     45   }
     46   Status Finish(int64* file_size) override {
     47     *file_size = -1;
     48     Status s = builder_->Finish();
     49     if (s.ok()) {
     50       s = file_->Close();
     51       if (s.ok()) {
     52         *file_size = builder_->FileSize();
     53       }
     54     }
     55     if (!s.ok()) {
     56       s = errors::Internal("Error writing (tmp) checkpoint file: ", name_, ": ",
     57                            s.ToString());
     58     }
     59     builder_.reset();
     60     file_.reset();
     61     return s;
     62   }
     63 
     64  private:
     65   string name_;
     66   std::unique_ptr<WritableFile> file_;
     67   std::unique_ptr<table::TableBuilder> builder_;
     68 };
     69 }  // anonymous namespace
     70 
     71 Status CreateTableTensorSliceBuilder(const string& name,
     72                                      TensorSliceWriter::Builder** builder) {
     73   *builder = nullptr;
     74   std::unique_ptr<WritableFile> f;
     75   Status s = Env::Default()->NewWritableFile(name, &f);
     76   if (s.ok()) {
     77     *builder = new TableBuilder(name, f.release());
     78     return Status::OK();
     79   } else {
     80     return s;
     81   }
     82 }
     83 
     84 TensorSliceWriter::TensorSliceWriter(const string& filename,
     85                                      CreateBuilderFunction create_builder)
     86     : filename_(filename),
     87       create_builder_(std::move(create_builder)),
     88       tmpname_(strings::StrCat(filename, ".tempstate", random::New64())),
     89       slices_(0) {
     90   VersionDef* versions = sts_.mutable_meta()->mutable_versions();
     91   versions->set_producer(TF_CHECKPOINT_VERSION);
     92   versions->set_min_consumer(TF_CHECKPOINT_VERSION_MIN_CONSUMER);
     93 }
     94 
     95 Status TensorSliceWriter::Finish() {
     96   Builder* b;
     97   Status s = create_builder_(tmpname_, &b);
     98   if (!s.ok()) {
     99     delete b;
    100     return s;
    101   }
    102   std::unique_ptr<Builder> builder(b);
    103 
    104   // We save the saved tensor slice metadata as the first element.
    105   string meta;
    106   sts_.AppendToString(&meta);
    107   builder->Add(kSavedTensorSlicesKey, meta);
    108 
    109   // Go through all the data and add them
    110   for (const auto& x : data_) {
    111     builder->Add(x.first, x.second);
    112   }
    113 
    114   int64 file_size;
    115   s = builder->Finish(&file_size);
    116   // We need to rename the file to the proper name
    117   if (s.ok()) {
    118     s = Env::Default()->RenameFile(tmpname_, filename_);
    119     if (s.ok()) {
    120       VLOG(1) << "Written " << slices_ << " slices for "
    121               << sts_.meta().tensor_size() << " tensors (" << file_size
    122               << " bytes) to " << filename_;
    123     } else {
    124       LOG(ERROR) << "Failed to rename file " << tmpname_ << " to " << filename_;
    125     }
    126   } else {
    127     Env::Default()->DeleteFile(tmpname_).IgnoreError();
    128   }
    129   return s;
    130 }
    131 
    132 /* static */
    133 size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) {
    134   switch (dt) {
    135     case DT_FLOAT:
    136       return 4;
    137     case DT_DOUBLE:
    138       return 8;
    139     case DT_INT32:
    140       return 10;
    141     case DT_UINT8:
    142       return 2;
    143     case DT_INT16:
    144       return 10;
    145     case DT_INT8:
    146       return 10;
    147     case DT_COMPLEX64:
    148       return 8;
    149     case DT_INT64:
    150       return 10;
    151     case DT_BOOL:
    152       return 1;
    153     case DT_QINT8:
    154       return 10;
    155     case DT_QUINT8:
    156       return 2;
    157     case DT_QINT32:
    158       return 10;
    159     case DT_QINT16:
    160       return 10;
    161     case DT_QUINT16:
    162       return 3;
    163     case DT_UINT16:
    164       return 3;
    165     case DT_COMPLEX128:
    166       return 16;
    167     case DT_HALF:
    168       return 3;
    169     case DT_INVALID:
    170     case DT_STRING:
    171     case DT_BFLOAT16:
    172     default:
    173       LOG(FATAL) << "MaxBytesPerElement not implemented for dtype: " << dt;
    174   }
    175   return 0;
    176 }
    177 
    178 template <>
    179 Status TensorSliceWriter::SaveData(const string* data, int64 num_elements,
    180                                    SavedSlice* ss) {
    181   size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes +
    182                       (num_elements * MaxBytesPerElement(DT_INT32));
    183   for (int64 i = 0; i < num_elements; ++i) {
    184     size_bound += data[i].size();
    185   }
    186   if (size_bound > kMaxMessageBytes) {
    187     return errors::InvalidArgument(
    188         "Tensor slice is too large to serialize (conservative estimate: ",
    189         size_bound, " bytes)");
    190   }
    191   Fill(data, num_elements, ss->mutable_data());
    192   DCHECK_GE(ss->ByteSize(), 0);
    193   DCHECK_LE(ss->ByteSize(), size_bound);
    194   return Status::OK();
    195 }
    196 
    197 }  // namespace checkpoint
    198 
    199 }  // namespace tensorflow
    200