Home | History | Annotate | Download | only in session_bundle
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/session_bundle/session_bundle.h"
     17 
     18 #include <string>
     19 #include <utility>
     20 #include <vector>
     21 
     22 #include "google/protobuf/any.pb.h"
     23 #include "tensorflow/contrib/session_bundle/manifest.pb.h"
     24 #include "tensorflow/core/framework/graph.pb.h"
     25 #include "tensorflow/core/framework/graph_def_util.h"
     26 #include "tensorflow/core/framework/tensor.h"
     27 #include "tensorflow/core/framework/tensor_shape.h"
     28 #include "tensorflow/core/framework/tensor_types.h"
     29 #include "tensorflow/core/framework/types.pb.h"
     30 #include "tensorflow/core/lib/core/errors.h"
     31 #include "tensorflow/core/lib/core/status.h"
     32 #include "tensorflow/core/lib/io/path.h"
     33 #include "tensorflow/core/lib/monitoring/counter.h"
     34 #include "tensorflow/core/platform/env.h"
     35 #include "tensorflow/core/platform/protobuf_internal.h"
     36 #include "tensorflow/core/platform/types.h"
     37 #include "tensorflow/core/protobuf/meta_graph.pb.h"
     38 #include "tensorflow/core/protobuf/saver.pb.h"
     39 #include "tensorflow/core/public/session_options.h"
     40 #include "tensorflow/core/util/tensor_bundle/naming.h"
     41 
     42 namespace tensorflow {
     43 namespace serving {
     44 namespace {
     45 
     46 auto* load_attempt_count = monitoring::Counter<2>::New(
     47     "/tensorflow/contrib/session_bundle/load_attempt_count",
     48     "The number of times a SessionBundle was requested to be loaded.",
     49     "model_path", "status");
     50 auto* load_latency = monitoring::Counter<1>::New(
     51     "/tensorflow/contrib/session_bundle/load_latency",
     52     "Latency in microseconds for SessionBundles that were successfully loaded.",
     53     "model_path");
     54 constexpr char kLoadAttemptFail[] = "fail";
     55 constexpr char kLoadAttemptSuccess[] = "success";
     56 
     57 // Create a session using the given options and load the graph.
     58 Status CreateSessionFromGraphDef(const SessionOptions& options,
     59                                  const GraphDef& graph,
     60                                  std::unique_ptr<Session>* session) {
     61   session->reset(NewSession(options));
     62   return (*session)->Create(graph);
     63 }
     64 
     65 Status GetMetaGraphDefFromExport(const StringPiece export_dir,
     66                                  MetaGraphDef* meta_graph_def) {
     67   const string meta_graph_def_path =
     68       io::JoinPath(export_dir, kMetaGraphDefFilename);
     69   return ReadBinaryProto(Env::Default(), meta_graph_def_path, meta_graph_def);
     70 }
     71 
     72 // Creates a string tensor.
     73 Tensor CreateStringTensor(const string& value) {
     74   Tensor tensor(DT_STRING, TensorShape({}));
     75   tensor.scalar<string>()() = value;
     76   return tensor;
     77 }
     78 
     79 // Adds Assets related tensors (assets_dir and asset files) to the inputs.
     80 void AddAssetsTensorsToInputs(const StringPiece export_dir,
     81                               const std::vector<AssetFile>& asset_files,
     82                               std::vector<std::pair<string, Tensor>>* inputs) {
     83   if (asset_files.empty()) {
     84     return;
     85   }
     86   for (auto& asset : asset_files) {
     87     Tensor assets_file_tensor = CreateStringTensor(
     88         io::JoinPath(export_dir, kAssetsDirectory, asset.filename()));
     89     inputs->push_back(
     90         {asset.tensor_binding().tensor_name(), assets_file_tensor});
     91   }
     92 }
     93 
     94 // Historically, model exporter(exporter.py) takes only saver with sharded=True,
     95 // and therefore always exports checkpoint in pattern file names.  In practice,
     96 // instead of training from scratch and export directly, we usually want to
     97 // restore from existing checkpoints and then export directly.  To support such
     98 // case, model exporter now supports reusing saver object restored from existing
     99 // checkpoint, that may have sharded=False - it will then export checkpoint file
    100 // in plain file name.  This method is to support models exported by both types
    101 // of saver object.  The change is backward-compatible, therefore no changes are
    102 // needed for existing model exports.
    103 //
    104 // Checkpoint v2 support: Variables exported using tf-exporter in the checkpoint
    105 // v2 format will have export.index and export.data-?????-of-????? files as
    106 // opposed to just an export or export-?????-of-????? file. The V2 save/restore
    107 // code accepts a filename prefix and assumes both prefix.index and
    108 // prefix.data-* are present in the filesystem. So if we see export.index
    109 // present in the export_dir, we know the export is in V2 format and we return
    110 // <export_dir>/export as this prefix.
    111 string GetVariablesFilename(const StringPiece export_dir) {
    112   const char kVariablesFilename[] = "export";
    113   const string kVariablesIndexFilename = MetaFilename("export");  // V2 ckpts
    114   const char kVariablesFilenamePattern[] = "export-\?\?\?\?\?-of-\?\?\?\?\?";
    115   if (Env::Default()
    116           ->FileExists(io::JoinPath(export_dir, kVariablesFilename))
    117           .ok() ||
    118       // This works for the case of V2 because the variables filename is taken
    119       // as a prefix in the save/restore abstraction, and the index and actual
    120       // variables are meant to be present as prefix.index and
    121       // prefix.data-?????-of-?????.
    122       Env::Default()
    123           ->FileExists(io::JoinPath(export_dir, kVariablesIndexFilename))
    124           .ok()) {
    125     return io::JoinPath(export_dir, kVariablesFilename);
    126   } else {
    127     return io::JoinPath(export_dir, kVariablesFilenamePattern);
    128   }
    129 }
    130 
    131 Status RunRestoreOp(const RunOptions& run_options, const StringPiece export_dir,
    132                     const std::vector<AssetFile>& asset_files,
    133                     const StringPiece restore_op_name,
    134                     const StringPiece variables_filename_const_op_name,
    135                     Session* session) {
    136   LOG(INFO) << "Running restore op for SessionBundle: " << restore_op_name
    137             << ", " << variables_filename_const_op_name;
    138   Tensor variables_tensor =
    139       CreateStringTensor(GetVariablesFilename(export_dir));
    140   std::vector<std::pair<string, Tensor>> inputs = {
    141       {variables_filename_const_op_name.ToString(), variables_tensor}};
    142   AddAssetsTensorsToInputs(export_dir, asset_files, &inputs);
    143   RunMetadata run_metadata;
    144   return session->Run(run_options, inputs, {}, {restore_op_name.ToString()},
    145                       nullptr /* outputs */, &run_metadata);
    146 }
    147 
    148 Status RunInitOp(const RunOptions& run_options, const StringPiece export_dir,
    149                  const std::vector<AssetFile>& asset_files,
    150                  const StringPiece init_op_name, Session* session) {
    151   LOG(INFO) << "Running init op for SessionBundle";
    152   std::vector<std::pair<string, Tensor>> inputs;
    153   AddAssetsTensorsToInputs(export_dir, asset_files, &inputs);
    154   RunMetadata run_metadata;
    155   return session->Run(run_options, inputs, {}, {init_op_name.ToString()},
    156                       nullptr /* outputs */, &run_metadata);
    157 }
    158 
    159 Status LoadSessionBundleFromPathUsingRunOptionsInternal(
    160     const SessionOptions& options, const RunOptions& run_options,
    161     const StringPiece export_dir, SessionBundle* const bundle) {
    162   LOG(INFO) << "Attempting to load a SessionBundle from: " << export_dir;
    163   LOG(INFO) << "Using RunOptions: " << DebugStringIfAvailable(run_options);
    164   TF_RETURN_IF_ERROR(
    165       GetMetaGraphDefFromExport(export_dir, &(bundle->meta_graph_def)));
    166 
    167   // Deprecated SessionBundle models may fail to load because newly added
    168   // attributes are not added to the Graph in the default Session initialization
    169   // flow. Add an explicit call here when first loading the graph from disk.
    170   TF_RETURN_IF_ERROR(
    171       AddDefaultAttrsToGraphDef(bundle->meta_graph_def.mutable_graph_def(),
    172                                 *OpRegistry::Global(), 0 /* node_offset */));
    173 
    174   const auto& collection_def_map = bundle->meta_graph_def.collection_def();
    175   const auto graph_it = bundle->meta_graph_def.collection_def().find(kGraphKey);
    176   if (graph_it != collection_def_map.end()) {
    177     const CollectionDef& graph_collection_def = graph_it->second;
    178     // Use serving graph_def in MetaGraphDef collection_def.
    179     if (graph_collection_def.any_list().value_size() != 1) {
    180       return errors::FailedPrecondition(
    181           "Expected exactly one serving GraphDef in : ", export_dir);
    182     }
    183     const auto& any = graph_collection_def.any_list().value(0);
    184     GraphDef graph_def;
    185     TF_RETURN_IF_ERROR(ParseAny(any, &graph_def, "tensorflow.GraphDef"));
    186     TF_RETURN_IF_ERROR(
    187         CreateSessionFromGraphDef(options, graph_def, &bundle->session));
    188   } else {
    189     // Fallback to use the graph_def in the MetaGraphDef.
    190     const GraphDef& graph_def = bundle->meta_graph_def.graph_def();
    191     TF_RETURN_IF_ERROR(
    192         CreateSessionFromGraphDef(options, graph_def, &bundle->session));
    193   }
    194 
    195   std::vector<AssetFile> asset_files;
    196   const auto assets_it = collection_def_map.find(kAssetsKey);
    197   if (assets_it != collection_def_map.end()) {
    198     const auto& any_assets = assets_it->second.any_list().value();
    199     for (const auto& any_asset : any_assets) {
    200       AssetFile asset_file;
    201       TF_RETURN_IF_ERROR(
    202           ParseAny(any_asset, &asset_file, "tensorflow.serving.AssetFile"));
    203       asset_files.push_back(asset_file);
    204     }
    205   }
    206 
    207   TF_RETURN_IF_ERROR(
    208       RunRestoreOp(run_options, export_dir, asset_files,
    209                    bundle->meta_graph_def.saver_def().restore_op_name(),
    210                    bundle->meta_graph_def.saver_def().filename_tensor_name(),
    211                    bundle->session.get()));
    212 
    213   const auto init_op_it = collection_def_map.find(kInitOpKey);
    214   if (init_op_it != collection_def_map.end()) {
    215     if (init_op_it->second.node_list().value_size() != 1) {
    216       return errors::FailedPrecondition(strings::StrCat(
    217           "Expected exactly one serving init op in : ", export_dir));
    218     }
    219     TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, asset_files,
    220                                  init_op_it->second.node_list().value(0),
    221                                  bundle->session.get()));
    222   }
    223 
    224   return Status::OK();
    225 }
    226 
    227 }  // namespace
    228 
    229 Status LoadSessionBundleFromPath(const SessionOptions& options,
    230                                  const StringPiece export_dir,
    231                                  SessionBundle* const bundle) {
    232   TF_RETURN_IF_ERROR(LoadSessionBundleFromPathUsingRunOptions(
    233       options, RunOptions(), export_dir, bundle));
    234   return Status::OK();
    235 }
    236 
    237 Status LoadSessionBundleFromPathUsingRunOptions(const SessionOptions& options,
    238                                                 const RunOptions& run_options,
    239                                                 const StringPiece export_dir,
    240                                                 SessionBundle* const bundle) {
    241   const uint64 start_microseconds = Env::Default()->NowMicros();
    242   const Status status = LoadSessionBundleFromPathUsingRunOptionsInternal(
    243       options, run_options, export_dir, bundle);
    244 
    245   const uint64 load_latency_microsecs = [&]() -> uint64 {
    246     const uint64 end_microseconds = Env::Default()->NowMicros();
    247     // Avoid clock skew.
    248     if (end_microseconds < start_microseconds) return 0;
    249     return end_microseconds - start_microseconds;
    250   }();
    251   auto log_and_count = [&](const string& status_str) {
    252     LOG(INFO) << "Loading SessionBundle: " << status_str << ". Took "
    253               << load_latency_microsecs << " microseconds.";
    254     load_attempt_count->GetCell(export_dir.ToString(), status_str)
    255         ->IncrementBy(1);
    256   };
    257   if (status.ok()) {
    258     log_and_count(kLoadAttemptSuccess);
    259   } else {
    260     log_and_count(kLoadAttemptFail);
    261   }
    262   load_latency->GetCell(export_dir.ToString())
    263       ->IncrementBy(load_latency_microsecs);
    264   return status;
    265 }
    266 
    267 bool IsPossibleExportDirectory(const StringPiece directory) {
    268   const string meta_graph_def_path =
    269       io::JoinPath(directory, kMetaGraphDefFilename);
    270   return Env::Default()->FileExists(meta_graph_def_path).ok();
    271 }
    272 
    273 }  // namespace serving
    274 }  // namespace tensorflow
    275