1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/cc/saved_model/loader.h" 17 18 #include <unordered_set> 19 20 #include "tensorflow/cc/saved_model/constants.h" 21 #include "tensorflow/core/lib/io/path.h" 22 #include "tensorflow/core/lib/monitoring/counter.h" 23 #include "tensorflow/core/lib/strings/strcat.h" 24 #include "tensorflow/core/platform/env.h" 25 #include "tensorflow/core/platform/protobuf_internal.h" 26 #include "tensorflow/core/protobuf/saved_model.pb.h" 27 #include "tensorflow/core/protobuf/saver.pb.h" 28 #include "tensorflow/core/public/session.h" 29 #include "tensorflow/core/public/session_options.h" 30 #include "tensorflow/core/util/tensor_bundle/naming.h" 31 32 namespace tensorflow { 33 namespace { 34 35 auto* load_attempt_count = monitoring::Counter<2>::New( 36 "/tensorflow/cc/saved_model/load_attempt_count", 37 "The number of times a SavedModel was successfully loaded.", "model_path", 38 "status"); 39 auto* load_latency = monitoring::Counter<1>::New( 40 "/tensorflow/cc/saved_model/load_latency", 41 "Latency in microseconds for SavedModels that were successfully loaded.", 42 "model_path"); 43 constexpr char kLoadAttemptFail[] = "fail"; 44 constexpr char kLoadAttemptSuccess[] = "success"; 45 46 Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) { 47 const string saved_model_pb_path = 48 io::JoinPath(export_dir, kSavedModelFilenamePb); 49 if (Env::Default()->FileExists(saved_model_pb_path).ok()) { 50 return ReadBinaryProto(Env::Default(), saved_model_pb_path, 51 saved_model_proto); 52 } 53 const string saved_model_pbtxt_path = 54 io::JoinPath(export_dir, kSavedModelFilenamePbTxt); 55 if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) { 56 return ReadTextProto(Env::Default(), saved_model_pbtxt_path, 57 saved_model_proto); 58 } 59 return Status(error::Code::NOT_FOUND, 60 "Could not find SavedModel .pb or .pbtxt at supplied export " 61 "directory path: " + 62 export_dir); 63 } 64 65 string GetTagsAsString(const std::unordered_set<string>& tags) { 66 string tags_as_string = "{ "; 67 for (const string& tag : tags) { 68 tags_as_string = strings::StrCat(tags_as_string, tag, " "); 69 } 70 tags_as_string = strings::StrCat(tags_as_string, "}"); 71 return tags_as_string; 72 } 73 74 Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, 75 const std::unordered_set<string>& tags, 76 MetaGraphDef* meta_graph_def_to_load) { 77 for (const MetaGraphDef& meta_graph_def : saved_model_proto.meta_graphs()) { 78 // Get tags from the meta_graph_def. 79 std::unordered_set<string> graph_tags; 80 for (const string& tag : meta_graph_def.meta_info_def().tags()) { 81 graph_tags.insert(tag); 82 } 83 // Match with the set of tags provided. 84 if (graph_tags == tags) { 85 *meta_graph_def_to_load = meta_graph_def; 86 return Status::OK(); 87 } 88 } 89 return Status(error::Code::NOT_FOUND, 90 "Could not find meta graph def matching supplied tags: " + 91 GetTagsAsString(tags) + 92 ". To inspect available tag-sets in the SavedModel, please " 93 "use the SavedModel CLI: `saved_model_cli`"); 94 } 95 96 Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, 97 const SessionOptions& session_options, 98 std::unique_ptr<Session>* session) { 99 Session* session_p = nullptr; 100 TF_RETURN_IF_ERROR(NewSession(session_options, &session_p)); 101 session->reset(session_p); 102 return (*session)->Create(meta_graph_def.graph_def()); 103 } 104 105 Tensor CreateStringTensor(const string& value) { 106 Tensor tensor(DT_STRING, TensorShape({})); 107 tensor.scalar<string>()() = value; 108 return tensor; 109 } 110 111 void AddAssetsTensorsToInputs(const StringPiece export_dir, 112 const std::vector<AssetFileDef>& asset_file_defs, 113 std::vector<std::pair<string, Tensor>>* inputs) { 114 if (asset_file_defs.empty()) { 115 return; 116 } 117 for (auto& asset_file_def : asset_file_defs) { 118 Tensor assets_file_path_tensor = CreateStringTensor(io::JoinPath( 119 export_dir, kSavedModelAssetsDirectory, asset_file_def.filename())); 120 inputs->push_back( 121 {asset_file_def.tensor_info().name(), assets_file_path_tensor}); 122 } 123 } 124 125 bool HasMainOp(const MetaGraphDef& meta_graph_def) { 126 const auto& collection_def_map = meta_graph_def.collection_def(); 127 if (collection_def_map.find(kSavedModelMainOpKey) != 128 collection_def_map.end()) { 129 return true; 130 } 131 return false; 132 } 133 134 Status RunMainOp(const RunOptions& run_options, const string& export_dir, 135 const MetaGraphDef& meta_graph_def, 136 const std::vector<AssetFileDef>& asset_file_defs, 137 Session* session) { 138 LOG(INFO) << "Running MainOp on SavedModel bundle."; 139 const auto& collection_def_map = meta_graph_def.collection_def(); 140 const auto main_op_it = collection_def_map.find(kSavedModelMainOpKey); 141 if (main_op_it != collection_def_map.end()) { 142 if (main_op_it->second.node_list().value_size() != 1) { 143 return errors::FailedPrecondition( 144 strings::StrCat("Expected exactly one main op in : ", export_dir)); 145 } 146 std::vector<std::pair<string, Tensor>> inputs; 147 AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); 148 RunMetadata run_metadata; 149 const StringPiece main_op_name = main_op_it->second.node_list().value(0); 150 return session->Run(run_options, inputs, {}, {main_op_name.ToString()}, 151 nullptr /* outputs */, &run_metadata); 152 } 153 return Status::OK(); 154 } 155 156 Status RunRestore(const RunOptions& run_options, const string& export_dir, 157 const StringPiece restore_op_name, 158 const StringPiece variable_filename_const_op_name, 159 const std::vector<AssetFileDef>& asset_file_defs, 160 Session* session) { 161 LOG(INFO) << "Restoring SavedModel bundle."; 162 // Find path to variables to be restored in export directory. 163 const string variables_directory = 164 io::JoinPath(export_dir, kSavedModelVariablesDirectory); 165 // Check for saver checkpoints in v2 format. Models exported in the checkpoint 166 // v2 format will have a variables.index file. The corresponding 167 // variables are stored in the variables.data-?????-of-????? files. 168 const string variables_index_path = io::JoinPath( 169 variables_directory, MetaFilename(kSavedModelVariablesFilename)); 170 if (!Env::Default()->FileExists(variables_index_path).ok()) { 171 LOG(INFO) << "The specified SavedModel has no variables; no checkpoints " 172 "were restored."; 173 return Status::OK(); 174 } 175 const string variables_path = 176 io::JoinPath(variables_directory, kSavedModelVariablesFilename); 177 178 // Add variables to the graph. 179 Tensor variables_path_tensor(DT_STRING, TensorShape({})); 180 variables_path_tensor.scalar<string>()() = variables_path; 181 182 std::vector<std::pair<string, Tensor>> inputs = { 183 {variable_filename_const_op_name.ToString(), variables_path_tensor}}; 184 185 AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); 186 187 RunMetadata run_metadata; 188 return session->Run(run_options, inputs, {}, {restore_op_name.ToString()}, 189 nullptr /* outputs */, &run_metadata); 190 } 191 192 Status RunLegacyInitOp(const RunOptions& run_options, const string& export_dir, 193 const MetaGraphDef& meta_graph_def, 194 const std::vector<AssetFileDef>& asset_file_defs, 195 Session* session) { 196 LOG(INFO) << "Running LegacyInitOp on SavedModel bundle."; 197 const auto& collection_def_map = meta_graph_def.collection_def(); 198 const auto init_op_it = collection_def_map.find(kSavedModelLegacyInitOpKey); 199 if (init_op_it != collection_def_map.end()) { 200 if (init_op_it->second.node_list().value_size() != 1) { 201 return errors::FailedPrecondition(strings::StrCat( 202 "Expected exactly one serving init op in : ", export_dir)); 203 } 204 std::vector<std::pair<string, Tensor>> inputs; 205 AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); 206 RunMetadata run_metadata; 207 const StringPiece legacy_init_op_name = 208 init_op_it->second.node_list().value(0); 209 return session->Run(run_options, inputs, {}, 210 {legacy_init_op_name.ToString()}, nullptr /* outputs */, 211 &run_metadata); 212 } 213 return Status::OK(); 214 } 215 216 Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, 217 std::vector<AssetFileDef>* asset_file_defs) { 218 const auto& collection_def_map = meta_graph_def.collection_def(); 219 const auto assets_it = collection_def_map.find(kSavedModelAssetsKey); 220 if (assets_it == collection_def_map.end()) { 221 return Status::OK(); 222 } 223 const auto& any_assets = assets_it->second.any_list().value(); 224 for (const auto& any_asset : any_assets) { 225 AssetFileDef asset_file_def; 226 TF_RETURN_IF_ERROR( 227 ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef")); 228 asset_file_defs->push_back(asset_file_def); 229 } 230 return Status::OK(); 231 } 232 233 Status LoadSavedModelInternal(const SessionOptions& session_options, 234 const RunOptions& run_options, 235 const string& export_dir, 236 const std::unordered_set<string>& tags, 237 SavedModelBundle* const bundle) { 238 if (!MaybeSavedModelDirectory(export_dir)) { 239 return Status(error::Code::NOT_FOUND, 240 "SavedModel not found in export directory: " + export_dir); 241 } 242 LOG(INFO) << "Loading SavedModel with tags: " << GetTagsAsString(tags) 243 << "; from: " << export_dir; 244 245 SavedModel saved_model_proto; 246 TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); 247 248 TF_RETURN_IF_ERROR( 249 FindMetaGraphDefToLoad(saved_model_proto, tags, &bundle->meta_graph_def)); 250 251 TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession( 252 bundle->meta_graph_def, session_options, &bundle->session)); 253 254 std::vector<AssetFileDef> asset_file_defs; 255 TF_RETURN_IF_ERROR( 256 GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs)); 257 TF_RETURN_IF_ERROR( 258 RunRestore(run_options, export_dir, 259 bundle->meta_graph_def.saver_def().restore_op_name(), 260 bundle->meta_graph_def.saver_def().filename_tensor_name(), 261 asset_file_defs, bundle->session.get())); 262 if (HasMainOp(bundle->meta_graph_def)) { 263 TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir, 264 bundle->meta_graph_def, asset_file_defs, 265 bundle->session.get())); 266 } else { 267 TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir, 268 bundle->meta_graph_def, asset_file_defs, 269 bundle->session.get())); 270 } 271 return Status::OK(); 272 } 273 274 } // namespace 275 276 Status LoadSavedModel(const SessionOptions& session_options, 277 const RunOptions& run_options, const string& export_dir, 278 const std::unordered_set<string>& tags, 279 SavedModelBundle* const bundle) { 280 // TODO(robson): Add tests for the counters. 281 const uint64 start_microseconds = Env::Default()->NowMicros(); 282 const Status status = LoadSavedModelInternal(session_options, run_options, 283 export_dir, tags, bundle); 284 const uint64 load_latency_microsecs = [&]() -> uint64 { 285 const uint64 end_microseconds = Env::Default()->NowMicros(); 286 // Avoid clock skew. 287 if (end_microseconds < start_microseconds) return 0; 288 return end_microseconds - start_microseconds; 289 }(); 290 auto log_and_count = [&](const string& status_str) { 291 LOG(INFO) << "SavedModel load for tags " << GetTagsAsString(tags) 292 << "; Status: " << status_str << ". Took " 293 << load_latency_microsecs << " microseconds."; 294 load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); 295 }; 296 if (status.ok()) { 297 log_and_count(kLoadAttemptSuccess); 298 } else { 299 log_and_count(kLoadAttemptFail); 300 } 301 load_latency->GetCell(export_dir)->IncrementBy(load_latency_microsecs); 302 return status; 303 } 304 305 bool MaybeSavedModelDirectory(const string& export_dir) { 306 const string saved_model_pb_path = 307 io::JoinPath(export_dir, kSavedModelFilenamePb); 308 const string saved_model_pbtxt_path = 309 io::JoinPath(export_dir, kSavedModelFilenamePbTxt); 310 return Env::Default()->FileExists(saved_model_pb_path).ok() || 311 Env::Default()->FileExists(saved_model_pbtxt_path).ok(); 312 } 313 314 } // namespace tensorflow 315