/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/saved_model/loader.h"

#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"

namespace tensorflow {
namespace {

auto* load_attempt_count = monitoring::Counter<2>::New(
    "/tensorflow/cc/saved_model/load_attempt_count",
    "The number of times a SavedModel load was attempted, partitioned by "
    "status (success or failure).",
    "model_path", "status");
auto* load_latency = monitoring::Counter<1>::New(
    "/tensorflow/cc/saved_model/load_latency",
    "Latency in microseconds for SavedModels that were successfully loaded.",
    "model_path");
constexpr char kLoadAttemptFail[] = "fail";
constexpr char kLoadAttemptSuccess[] = "success";

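// Reads the SavedModel proto from `export_dir`, preferring the binary
// saved_model.pb file and falling back to the text saved_model.pbtxt.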
Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
  const string saved_model_pb_path =
      io::JoinPath(export_dir, kSavedModelFilenamePb);
  if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
    return ReadBinaryProto(Env::Default(), saved_model_pb_path,
                           saved_model_proto);
  }
  const string saved_model_pbtxt_path =
      io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
  if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
    return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
                         saved_model_proto);
  }
  return Status(error::Code::NOT_FOUND,
                "Could not find SavedModel .pb or .pbtxt at supplied export "
                "directory path: " +
                    export_dir);
}

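// Formats a set of tags as "{ tag1 tag2 }" for log and error messages.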
string GetTagsAsString(const std::unordered_set<string>& tags) {
  string tags_as_string = "{ ";
  for (const string& tag : tags) {
    tags_as_string = strings::StrCat(tags_as_string, tag, " ");
  }
  tags_as_string = strings::StrCat(tags_as_string, "}");
  return tags_as_string;
}

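// Searches the SavedModel for a MetaGraphDef whose tag set matches `tags`
// exactly and copies it into `meta_graph_def_to_load`. Returns NOT_FOUND if
// no such MetaGraphDef exists.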
Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto,
                              const std::unordered_set<string>& tags,
                              MetaGraphDef* meta_graph_def_to_load) {
  for (const MetaGraphDef& meta_graph_def : saved_model_proto.meta_graphs()) {
    // Get tags from the meta_graph_def.
    std::unordered_set<string> graph_tags;
    for (const string& tag : meta_graph_def.meta_info_def().tags()) {
      graph_tags.insert(tag);
    }
    // Match with the set of tags provided.
    if (graph_tags == tags) {
      *meta_graph_def_to_load = meta_graph_def;
      return Status::OK();
    }
  }
  return Status(error::Code::NOT_FOUND,
                "Could not find meta graph def matching supplied tags: " +
                    GetTagsAsString(tags) +
                    ". To inspect available tag-sets in the SavedModel, please "
                    "use the SavedModel CLI: `saved_model_cli`");
}

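// Creates a new Session from `session_options` and initializes it with the
// graph from `meta_graph_def`.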
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
                                const SessionOptions& session_options,
                                std::unique_ptr<Session>* session) {
  Session* session_p = nullptr;
  TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
  session->reset(session_p);
  return (*session)->Create(meta_graph_def.graph_def());
}

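// Wraps `value` in a scalar DT_STRING tensor.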
Tensor CreateStringTensor(const string& value) {
  Tensor tensor(DT_STRING, TensorShape({}));
  tensor.scalar<string>()() = value;
  return tensor;
}

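// Appends one feed per AssetFileDef to `inputs`, mapping each asset's tensor
// name to the full path of the asset file under the export directory.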
void AddAssetsTensorsToInputs(const StringPiece export_dir,
                              const std::vector<AssetFileDef>& asset_file_defs,
                              std::vector<std::pair<string, Tensor>>* inputs) {
  if (asset_file_defs.empty()) {
    return;
  }
  for (auto& asset_file_def : asset_file_defs) {
    Tensor assets_file_path_tensor = CreateStringTensor(io::JoinPath(
        export_dir, kSavedModelAssetsDirectory, asset_file_def.filename()));
    inputs->push_back(
        {asset_file_def.tensor_info().name(), assets_file_path_tensor});
  }
}

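// Returns true if the MetaGraphDef records a main op in its collection_def.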
bool HasMainOp(const MetaGraphDef& meta_graph_def) {
  const auto& collection_def_map = meta_graph_def.collection_def();
  return collection_def_map.find(kSavedModelMainOpKey) !=
         collection_def_map.end();
}

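// Runs the main op recorded in the MetaGraphDef (if present), feeding asset
// file path tensors as inputs.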
Status RunMainOp(const RunOptions& run_options, const string& export_dir,
                 const MetaGraphDef& meta_graph_def,
                 const std::vector<AssetFileDef>& asset_file_defs,
                 Session* session) {
  LOG(INFO) << "Running MainOp on SavedModel bundle.";
  const auto& collection_def_map = meta_graph_def.collection_def();
  const auto main_op_it = collection_def_map.find(kSavedModelMainOpKey);
  if (main_op_it != collection_def_map.end()) {
    if (main_op_it->second.node_list().value_size() != 1) {
      return errors::FailedPrecondition(
          strings::StrCat("Expected exactly one main op in: ", export_dir));
    }
    std::vector<std::pair<string, Tensor>> inputs;
    AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
    RunMetadata run_metadata;
    const StringPiece main_op_name = main_op_it->second.node_list().value(0);
    return session->Run(run_options, inputs, {}, {main_op_name.ToString()},
                        nullptr /* outputs */, &run_metadata);
  }
  return Status::OK();
}

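// Restores variables from the checkpoint under the variables/ subdirectory of
// `export_dir` by running the saver's restore op. A missing variables.index
// file means the model has no variables, which is not an error.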
Status RunRestore(const RunOptions& run_options, const string& export_dir,
                  const StringPiece restore_op_name,
                  const StringPiece variable_filename_const_op_name,
                  const std::vector<AssetFileDef>& asset_file_defs,
                  Session* session) {
  LOG(INFO) << "Restoring SavedModel bundle.";
  // Find path to variables to be restored in export directory.
  const string variables_directory =
      io::JoinPath(export_dir, kSavedModelVariablesDirectory);
  // Check for saver checkpoints in v2 format. Models exported in the
  // checkpoint v2 format will have a variables.index file. The corresponding
  // variables are stored in the variables.data-?????-of-????? files.
  const string variables_index_path = io::JoinPath(
      variables_directory, MetaFilename(kSavedModelVariablesFilename));
  if (!Env::Default()->FileExists(variables_index_path).ok()) {
    LOG(INFO) << "The specified SavedModel has no variables; no checkpoints "
                 "were restored.";
    return Status::OK();
  }
  const string variables_path =
      io::JoinPath(variables_directory, kSavedModelVariablesFilename);

  // Add variables to the graph.
  Tensor variables_path_tensor(DT_STRING, TensorShape({}));
  variables_path_tensor.scalar<string>()() = variables_path;

  std::vector<std::pair<string, Tensor>> inputs = {
      {variable_filename_const_op_name.ToString(), variables_path_tensor}};

  AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);

  RunMetadata run_metadata;
  return session->Run(run_options, inputs, {}, {restore_op_name.ToString()},
                      nullptr /* outputs */, &run_metadata);
}

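// Runs the legacy init op recorded in the MetaGraphDef (if present), feeding
// asset file path tensors as inputs.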
Status RunLegacyInitOp(const RunOptions& run_options, const string& export_dir,
                       const MetaGraphDef& meta_graph_def,
                       const std::vector<AssetFileDef>& asset_file_defs,
                       Session* session) {
  LOG(INFO) << "Running LegacyInitOp on SavedModel bundle.";
  const auto& collection_def_map = meta_graph_def.collection_def();
  const auto init_op_it = collection_def_map.find(kSavedModelLegacyInitOpKey);
  if (init_op_it != collection_def_map.end()) {
    if (init_op_it->second.node_list().value_size() != 1) {
      return errors::FailedPrecondition(strings::StrCat(
          "Expected exactly one serving init op in: ", export_dir));
    }
    std::vector<std::pair<string, Tensor>> inputs;
    AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
    RunMetadata run_metadata;
    const StringPiece legacy_init_op_name =
        init_op_it->second.node_list().value(0);
    return session->Run(run_options, inputs, {},
                        {legacy_init_op_name.ToString()}, nullptr /* outputs */,
                        &run_metadata);
  }
  return Status::OK();
}

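// Extracts the AssetFileDef protos stored in the assets collection of the
// MetaGraphDef into `asset_file_defs`.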
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
                        std::vector<AssetFileDef>* asset_file_defs) {
  const auto& collection_def_map = meta_graph_def.collection_def();
  const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
  if (assets_it == collection_def_map.end()) {
    return Status::OK();
  }
  const auto& any_assets = assets_it->second.any_list().value();
  for (const auto& any_asset : any_assets) {
    AssetFileDef asset_file_def;
    TF_RETURN_IF_ERROR(
        ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
    asset_file_defs->push_back(asset_file_def);
  }
  return Status::OK();
}

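// Performs the actual load: reads the SavedModel proto, selects the
// MetaGraphDef matching `tags`, creates a session, restores variables, and
// runs the main op (or, if absent, the legacy init op).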
Status LoadSavedModelInternal(const SessionOptions& session_options,
                              const RunOptions& run_options,
                              const string& export_dir,
                              const std::unordered_set<string>& tags,
                              SavedModelBundle* const bundle) {
  if (!MaybeSavedModelDirectory(export_dir)) {
    return Status(error::Code::NOT_FOUND,
                  "SavedModel not found in export directory: " + export_dir);
  }
  LOG(INFO) << "Loading SavedModel with tags: " << GetTagsAsString(tags)
            << "; from: " << export_dir;

  SavedModel saved_model_proto;
  TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));

  TF_RETURN_IF_ERROR(
      FindMetaGraphDefToLoad(saved_model_proto, tags, &bundle->meta_graph_def));

  TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
      bundle->meta_graph_def, session_options, &bundle->session));

  std::vector<AssetFileDef> asset_file_defs;
  TF_RETURN_IF_ERROR(
      GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
  TF_RETURN_IF_ERROR(
      RunRestore(run_options, export_dir,
                 bundle->meta_graph_def.saver_def().restore_op_name(),
                 bundle->meta_graph_def.saver_def().filename_tensor_name(),
                 asset_file_defs, bundle->session.get()));
  if (HasMainOp(bundle->meta_graph_def)) {
    TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir,
                                 bundle->meta_graph_def, asset_file_defs,
                                 bundle->session.get()));
  } else {
    TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir,
                                       bundle->meta_graph_def, asset_file_defs,
                                       bundle->session.get()));
  }
  return Status::OK();
}

}  // namespace

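// Example usage (a minimal sketch; the export path and the tensor names
// "input:0"/"output:0" are illustrative and depend on the exported graph):
//
//   SavedModelBundle bundle;
//   TF_CHECK_OK(LoadSavedModel(SessionOptions(), RunOptions(),
//                              "/path/to/export_dir", {"serve"}, &bundle));
//   std::vector<Tensor> outputs;
//   TF_CHECK_OK(bundle.session->Run({{"input:0", input_tensor}},
//                                   {"output:0"}, /*target_node_names=*/{},
//                                   &outputs));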
Status LoadSavedModel(const SessionOptions& session_options,
                      const RunOptions& run_options, const string& export_dir,
                      const std::unordered_set<string>& tags,
                      SavedModelBundle* const bundle) {
  // TODO(robson): Add tests for the counters.
  const uint64 start_microseconds = Env::Default()->NowMicros();
  const Status status = LoadSavedModelInternal(session_options, run_options,
                                               export_dir, tags, bundle);
  const uint64 load_latency_microsecs = [&]() -> uint64 {
    const uint64 end_microseconds = Env::Default()->NowMicros();
    // Avoid clock skew.
    if (end_microseconds < start_microseconds) return 0;
    return end_microseconds - start_microseconds;
  }();
  auto log_and_count = [&](const string& status_str) {
    LOG(INFO) << "SavedModel load for tags " << GetTagsAsString(tags)
              << "; Status: " << status_str << ". Took "
              << load_latency_microsecs << " microseconds.";
    load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
  };
  if (status.ok()) {
    log_and_count(kLoadAttemptSuccess);
  } else {
    log_and_count(kLoadAttemptFail);
  }
  load_latency->GetCell(export_dir)->IncrementBy(load_latency_microsecs);
  return status;
}

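// Heuristically checks whether `export_dir` contains a SavedModel by testing
// for the presence of a saved_model.pb or saved_model.pbtxt file.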
bool MaybeSavedModelDirectory(const string& export_dir) {
  const string saved_model_pb_path =
      io::JoinPath(export_dir, kSavedModelFilenamePb);
  const string saved_model_pbtxt_path =
      io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
  return Env::Default()->FileExists(saved_model_pb_path).ok() ||
         Env::Default()->FileExists(saved_model_pbtxt_path).ok();
}

}  // namespace tensorflow