Home | History | Annotate | Download | only in c
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/c/checkpoint_reader.h"
     17 
     18 #include <unordered_set>
     19 #include <utility>
     20 
     21 #include "tensorflow/core/lib/core/status.h"
     22 #include "tensorflow/core/lib/core/stringpiece.h"
     23 #include "tensorflow/core/platform/env.h"
     24 #include "tensorflow/core/platform/types.h"
     25 #include "tensorflow/core/util/saved_tensor_slice_util.h"
     26 
     27 namespace tensorflow {
     28 namespace checkpoint {
     29 
     30 class TensorSliceReader;
     31 
     32 CheckpointReader::CheckpointReader(const string& filename,
     33                                    TF_Status* out_status)
     34     : reader_(nullptr),
     35       v2_reader_(nullptr),
     36       var_to_shape_map_(nullptr),
     37       var_to_data_type_map_(nullptr) {
     38   // Depending on whether this is a V2 ckpt, initializes "reader_" or
     39   // "v2_reader_".
     40   std::vector<string> v2_path;
     41   if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() &&
     42       !v2_path.empty()) {
     43     v2_reader_.reset(
     44         new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
     45     if (!v2_reader_->status().ok()) {
     46       Set_TF_Status_from_Status(out_status, v2_reader_->status());
     47       return;
     48     }
     49     auto result = BuildV2VarMaps();
     50     var_to_shape_map_.swap(result.first);
     51     var_to_data_type_map_.swap(result.second);
     52   } else {
     53     reader_.reset(new TensorSliceReader(filename));
     54     if (!reader_->status().ok()) {
     55       Set_TF_Status_from_Status(out_status, reader_->status());
     56       return;
     57     }
     58     var_to_shape_map_.reset(
     59         new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()));
     60     var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap(
     61         reader_->GetVariableToDataTypeMap()));
     62   }
     63 }
     64 
     65 bool CheckpointReader::HasTensor(const string& name) const {
     66   if (reader_ != nullptr) {
     67     return reader_->HasTensor(name, nullptr, nullptr);
     68   }
     69   return v2_reader_->Contains(name);
     70 }
     71 
     72 const TensorSliceReader::VarToShapeMap&
     73 CheckpointReader::GetVariableToShapeMap() const {
     74   CHECK(var_to_shape_map_);
     75   return *var_to_shape_map_;
     76 }
     77 
     78 const TensorSliceReader::VarToDataTypeMap&
     79 CheckpointReader::GetVariableToDataTypeMap() const {
     80   CHECK(var_to_data_type_map_);
     81   return *var_to_data_type_map_;
     82 }
     83 
     84 const string CheckpointReader::DebugString() const {
     85   if (reader_ != nullptr) return reader_->DebugString();
     86   return v2_reader_->DebugString();
     87 }
     88 
     89 void CheckpointReader::GetTensor(
     90     const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor,
     91     TF_Status* out_status) const {
     92   Status status;
     93   if (reader_ != nullptr) {
     94     status = reader_->GetTensor(name, out_tensor);
     95   } else {
     96     tensorflow::DataType dtype;
     97     tensorflow::TensorShape shape;
     98     status = v2_reader_->LookupDtypeAndShape(name, &dtype, &shape);
     99     if (status.ok()) {
    100       out_tensor->reset(new Tensor(dtype, shape));
    101       status = v2_reader_->Lookup(name, out_tensor->get());
    102       if (!status.ok()) out_tensor->reset();
    103     }
    104   }
    105   if (!status.ok()) {
    106     Set_TF_Status_from_Status(out_status, status);
    107   }
    108 }
    109 
    110 std::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>,
    111           std::unique_ptr<TensorSliceReader::VarToDataTypeMap>>
    112 CheckpointReader::BuildV2VarMaps() {
    113   CHECK(v2_reader_ != nullptr);
    114   CHECK(v2_reader_->status().ok());
    115 
    116   // First pass: filters out the entries of the slices.
    117   std::unordered_set<string> filtered_keys;
    118   BundleEntryProto entry;
    119   v2_reader_->Seek(kHeaderEntryKey);
    120   for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
    121     CHECK(entry.ParseFromArray(v2_reader_->value().data(),
    122                                v2_reader_->value().size()))
    123         << entry.InitializationErrorString();
    124     for (int i = 0; i < entry.slices_size(); ++i) {
    125       const auto& slice_proto = entry.slices(i);
    126       CHECK(filtered_keys
    127                 .insert(EncodeTensorNameSlice(
    128                     string(v2_reader_->key()) /* full var's name */,
    129                     TensorSlice(slice_proto)))
    130                 .second);
    131     }
    132   }
    133 
    134   // Second pass: adds the entries, ignoring the filtered keys.
    135   std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map(
    136       new TensorSliceReader::VarToShapeMap);
    137   std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map(
    138       new TensorSliceReader::VarToDataTypeMap);
    139   v2_reader_->Seek(kHeaderEntryKey);
    140   for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
    141     if (filtered_keys.count(string(v2_reader_->key())) > 0) continue;
    142     CHECK(entry.ParseFromArray(v2_reader_->value().data(),
    143                                v2_reader_->value().size()))
    144         << entry.InitializationErrorString();
    145     string key(v2_reader_->key());
    146     (*var_to_shape_map)[key] = TensorShape(entry.shape());
    147     (*var_to_data_type_map)[key] = DataType(entry.dtype());
    148   }
    149   // The returned pointers are owned by the caller.
    150   return std::make_pair(std::move(var_to_shape_map),
    151                         std::move(var_to_data_type_map));
    152 }
    153 
    154 }  // namespace checkpoint
    155 }  // namespace tensorflow
    156