Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // The utility to read checkpoints for google brain tensor ops and v3
     17 // checkpoints for dist_belief.
     18 
     19 #ifndef TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
     20 #define TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
     21 
     22 #include <unordered_map>
     23 
     24 #include <vector>
     25 #include "tensorflow/core/framework/tensor.h"
     26 #include "tensorflow/core/framework/tensor_shape.h"
     27 #include "tensorflow/core/framework/tensor_slice.h"
     28 #include "tensorflow/core/framework/types.pb.h"
     29 #include "tensorflow/core/lib/core/status.h"
     30 #include "tensorflow/core/lib/core/stringpiece.h"
     31 #include "tensorflow/core/lib/gtl/map_util.h"
     32 #include "tensorflow/core/platform/logging.h"
     33 #include "tensorflow/core/platform/macros.h"
     34 #include "tensorflow/core/platform/mutex.h"
     35 #include "tensorflow/core/platform/protobuf.h"
     36 #include "tensorflow/core/platform/types.h"
     37 #include "tensorflow/core/util/saved_tensor_slice.pb.h"
     38 #include "tensorflow/core/util/saved_tensor_slice_util.h"
     39 #include "tensorflow/core/util/tensor_slice_set.h"
     40 #include "tensorflow/core/util/tensor_slice_util.h"
     41 
     42 namespace tensorflow {
     43 
     44 namespace checkpoint {
     45 
     46 // The reader reads in all the meta data about all the tensor slices. Then it
     47 // will try to read the relevant data on-demand to produce the data for the
     48 // slices needed.
     49 // NOTE(yangke): another way to do this is to first load a list of the tensor
     50 // slices needed and then just selectively read some of the meta data. That
     51 // might optimize the loading but makes the logic a bit more complicated. We
     52 // might want to revisit that.
     53 // TODO(yangke): consider moving to TensorProto.
     54 class TensorSliceReader {
     55  public:
     56   // Abstract interface for reading data out of a tensor slice checkpoint file
     57   class Table {
     58    public:
     59     virtual ~Table();
     60     virtual bool Get(const string& key, string* value) = 0;
     61   };
     62   typedef std::function<Status(const string&, Table**)> OpenTableFunction;
     63 
     64   static const int kLoadAllShards = -1;
     65   TensorSliceReader(const string& filepattern);
     66   TensorSliceReader(const string& filepattern, OpenTableFunction open_function);
     67   TensorSliceReader(const string& filepattern, OpenTableFunction open_function,
     68                     int preferred_shard);
     69   virtual ~TensorSliceReader();
     70 
     71   // Get the filename this reader is attached to.
     72   const string& filepattern() const { return filepattern_; }
     73 
     74   // Get the number of files matched.
     75   int num_files() const { return sss_.size(); }
     76 
     77   // Get the status of the reader.
     78   const Status status() const { return status_; }
     79 
     80   // Checks if the reader contains any slice of a tensor. In case the reader
     81   // does contain the tensor, if "shape" is not nullptr, fill "shape" with the
     82   // shape of the tensor; if "type" is not nullptr, fill "type" with the type
     83   // of the tensor.
     84   bool HasTensor(const string& name, TensorShape* shape, DataType* type) const;
     85 
     86   // Checks if the reader contains all the data about a tensor slice, and if
     87   // yes, copies the data of the slice to "data". The caller needs to make sure
     88   // that "data" points to a buffer that holds enough data.
     89   // This is a slow function since it needs to read sstables.
     90   template <typename T>
     91   bool CopySliceData(const string& name, const TensorSlice& slice,
     92                      T* data) const;
     93 
     94   // Get the tensors.
     95   const std::unordered_map<string, TensorSliceSet*>& Tensors() const {
     96     return tensors_;
     97   }
     98 
     99   // Returns value for one tensor. Only single slice checkpoints are supported
    100   // at the moment.
    101   Status GetTensor(const string& name,
    102                    std::unique_ptr<tensorflow::Tensor>* out_tensor) const;
    103 
    104   typedef std::unordered_map<string, TensorShape> VarToShapeMap;
    105   typedef std::unordered_map<string, DataType> VarToDataTypeMap;
    106 
    107   // Returns a map from tensor name to shape.
    108   VarToShapeMap GetVariableToShapeMap() const;
    109 
    110   // Returns a map from tensor name to data type.
    111   VarToDataTypeMap GetVariableToDataTypeMap() const;
    112 
    113   // Returns a string containing names and shapes of all the tensors.
    114   const string DebugString() const;
    115 
    116  private:
    117   friend class TensorSliceWriteTestHelper;
    118 
    119   void LoadShard(int shard) const;
    120   void LoadAllShards() const;
    121 
    122   const TensorSliceSet* FindTensorSlice(
    123       const string& name, const TensorSlice& slice,
    124       std::vector<std::pair<TensorSlice, string>>* details) const;
    125 
    126   const string filepattern_;
    127   const OpenTableFunction open_function_;
    128   std::vector<string> fnames_;
    129   std::unordered_map<string, int> fname_to_index_;
    130 
    131   // Guards the attributes below.
    132   mutable mutex mu_;
    133   mutable bool all_shards_loaded_ = false;
    134   mutable std::vector<std::unique_ptr<Table>> sss_;
    135   mutable std::unordered_map<string, TensorSliceSet*> tensors_;
    136   mutable Status status_;
    137 
    138   TF_DISALLOW_COPY_AND_ASSIGN(TensorSliceReader);
    139 };
    140 
    141 Status OpenTableTensorSliceReader(const string& fname,
    142                                   TensorSliceReader::Table** table);
    143 
    144 template <typename T>
    145 bool TensorSliceReader::CopySliceData(const string& name,
    146                                       const TensorSlice& slice, T* data) const {
    147   std::vector<std::pair<TensorSlice, string>> details;
    148   const TensorSliceSet* tss;
    149   {
    150     mutex_lock l(mu_);
    151     tss = FindTensorSlice(name, slice, &details);
    152     if (!tss && !all_shards_loaded_) {
    153       VLOG(1) << "Did not find slice in preferred shard, loading all shards."
    154               << name << ": " << slice.DebugString();
    155       LoadAllShards();
    156       tss = FindTensorSlice(name, slice, &details);
    157     }
    158     if (!tss) {
    159       // No such tensor
    160       return false;
    161     }
    162   }
    163   // We have the data -- copy it over.
    164   string value;
    165   for (const auto& x : details) {
    166     const TensorSlice& slice_s = x.first;
    167     const string& fname = x.second;
    168     int idx = gtl::FindWithDefault(fname_to_index_, fname, -1);
    169     CHECK_GE(idx, 0) << "Failed to find the index for filename " << fname;
    170     // We read a record in the corresponding sstable
    171     const string key = EncodeTensorNameSlice(name, slice_s);
    172     if (!sss_[idx]->Get(key, &value)) {
    173       VLOG(1) << "Failed to seek to the record for tensor " << name
    174               << ", slice " << slice_s.DebugString()
    175               << ": computed key = " << key;
    176       return false;
    177     }
    178     SavedTensorSlices sts;
    179     if (!ParseProtoUnlimited(&sts, value)) {
    180       VLOG(1) << "Failed to parse the record for tensor " << name << ", slice "
    181               << slice_s.DebugString() << ": computed key = " << key;
    182       return false;
    183     }
    184     CopyDataFromTensorSliceToTensorSlice(
    185         tss->shape(), slice_s, slice,
    186         checkpoint::TensorProtoData<T>(sts.data().data()), data);
    187   }
    188   return true;
    189 }
    190 
    191 }  // namespace checkpoint
    192 
    193 }  // namespace tensorflow
    194 
    195 #endif  // TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
    196