1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/contrib/session_bundle/bundle_shim.h" 17 18 #include "tensorflow/cc/saved_model/loader.h" 19 #include "tensorflow/cc/saved_model/signature_constants.h" 20 #include "tensorflow/contrib/session_bundle/manifest.pb.h" 21 #include "tensorflow/contrib/session_bundle/session_bundle.h" 22 #include "tensorflow/contrib/session_bundle/signature.h" 23 #include "tensorflow/core/graph/graph_constructor.h" 24 #include "tensorflow/core/lib/core/errors.h" 25 #include "tensorflow/core/lib/core/status.h" 26 #include "tensorflow/core/lib/core/stringpiece.h" 27 #include "tensorflow/core/protobuf/meta_graph.pb.h" 28 #include "tensorflow/core/public/session.h" 29 #include "tensorflow/core/public/session_options.h" 30 31 namespace tensorflow { 32 namespace serving { 33 namespace { 34 /////////////////////////////////////////////////////////////////////////////// 35 // Helper functions to check Signature type. 36 37 bool IsClassificationSignature(const Signature& signature) { 38 return signature.type_case() == Signature::kClassificationSignature; 39 } 40 41 bool IsRegressionSignature(const Signature& signature) { 42 return signature.type_case() == Signature::kRegressionSignature; 43 } 44 45 /////////////////////////////////////////////////////////////////////////////// 46 // Helper functions to build `Classification`, `Regression` and `Predict` 47 // SignatureDefs. 48 49 SignatureDef BuildRegressionSignatureDef( 50 const RegressionSignature& regression_signature, 51 const std::unordered_map<string, DataType>& tensor_name_to_dtype) { 52 SignatureDef signature_def; 53 signature_def.set_method_name(kRegressMethodName); 54 internal::AddInputToSignatureDef(regression_signature.input().tensor_name(), 55 tensor_name_to_dtype, kRegressInputs, 56 &signature_def); 57 internal::AddOutputToSignatureDef(regression_signature.output().tensor_name(), 58 tensor_name_to_dtype, kRegressOutputs, 59 &signature_def); 60 return signature_def; 61 } 62 63 SignatureDef BuildClassificationSignatureDef( 64 const ClassificationSignature& classification_signature, 65 const std::unordered_map<string, DataType>& tensor_name_to_dtype) { 66 SignatureDef signature_def; 67 signature_def.set_method_name(kClassifyMethodName); 68 internal::AddInputToSignatureDef( 69 classification_signature.input().tensor_name(), tensor_name_to_dtype, 70 kClassifyInputs, &signature_def); 71 internal::AddOutputToSignatureDef( 72 classification_signature.classes().tensor_name(), tensor_name_to_dtype, 73 kClassifyOutputClasses, &signature_def); 74 internal::AddOutputToSignatureDef( 75 classification_signature.scores().tensor_name(), tensor_name_to_dtype, 76 kClassifyOutputScores, &signature_def); 77 return signature_def; 78 } 79 80 Status MaybeBuildPredictSignatureDef( 81 const std::unordered_map<string, DataType>& tensor_name_to_dtype, 82 MetaGraphDef* meta_graph_def) { 83 Signature input_signature, output_signature; 84 // Ensure that named signatures corresponding to `inputs` and `outputs` keys 85 // exist. 86 if (!GetNamedSignature(kPredictInputs, *meta_graph_def, &input_signature) 87 .ok() || 88 !GetNamedSignature(kPredictOutputs, *meta_graph_def, &output_signature) 89 .ok()) { 90 return Status(error::Code::INVALID_ARGUMENT, 91 "Named signatures can only be up-converted if entries " 92 "corresponding to both `inputs` and `outputs` exist."); 93 } 94 // Ensure the `inputs` and `outputs` named signatures are generic signatures. 95 if (input_signature.type_case() != Signature::TypeCase::kGenericSignature || 96 output_signature.type_case() != Signature::TypeCase::kGenericSignature) { 97 return Status(error::Code::INVALID_ARGUMENT, 98 "Named signatures corresponding to `inputs` and `outputs` " 99 "can only be up-converted if they are GenericSignatures."); 100 } 101 SignatureDef signature_def; 102 signature_def.set_method_name(kPredictMethodName); 103 // Add map entries from the `inputs` generic signature to the input map in the 104 // signature def. 105 for (const auto& map_entry : input_signature.generic_signature().map()) { 106 internal::AddInputToSignatureDef(map_entry.second.tensor_name(), 107 tensor_name_to_dtype, map_entry.first, 108 &signature_def); 109 } 110 // Add map entries from the `outputs` generic signature to the output map in 111 // the signature def. 112 for (const auto& map_entry : output_signature.generic_signature().map()) { 113 internal::AddOutputToSignatureDef(map_entry.second.tensor_name(), 114 tensor_name_to_dtype, map_entry.first, 115 &signature_def); 116 } 117 // Add the constructed signature def to the signature def map of the meta 118 // graph def. Use the default key if it isn't already in use. 119 const bool already_has_default_signature = 120 meta_graph_def->signature_def().find(kDefaultServingSignatureDefKey) != 121 meta_graph_def->signature_def().end(); 122 const string signature_def_key = 123 already_has_default_signature 124 ? strings::StrCat(kDefaultServingSignatureDefKey, "_from_named") 125 : kDefaultServingSignatureDefKey; 126 (*meta_graph_def->mutable_signature_def())[signature_def_key] = signature_def; 127 return Status::OK(); 128 } 129 130 Status LoadSavedModelFromLegacySessionBundlePath( 131 const SessionOptions& session_options, const RunOptions& run_options, 132 const StringPiece session_bundle_export_dir, 133 SavedModelBundle* saved_model_bundle) { 134 if (session_bundle_export_dir.empty()) { 135 return Status(error::Code::NOT_FOUND, "Export directory path is empty."); 136 } 137 if (!IsPossibleExportDirectory(session_bundle_export_dir)) { 138 return Status( 139 error::Code::NOT_FOUND, 140 "Export directory does not contain a valid SessionBundle export."); 141 } 142 143 // Build the session-bundle. 144 SessionBundle session_bundle; 145 TF_RETURN_IF_ERROR(LoadSessionBundleFromPathUsingRunOptions( 146 session_options, run_options, session_bundle_export_dir, 147 &session_bundle)); 148 149 // Convert the session-bundle to a saved-model-bundle. 150 return internal::ConvertSessionBundleToSavedModelBundle(session_bundle, 151 saved_model_bundle); 152 } 153 154 /////////////////////////////////////////////////////////////////////////////// 155 // Helper functions to convert `Default` and `Named` signatures to 156 // SignatureDefs. 157 158 // Up-conversion of default signatures is supported for classification and 159 // regression. 160 Status ConvertDefaultSignatureToSignatureDef( 161 const Signatures& signatures, 162 const std::unordered_map<string, DataType>& tensor_name_to_dtype, 163 MetaGraphDef* meta_graph_def) { 164 if (!signatures.has_default_signature()) { 165 return Status::OK(); 166 } 167 const bool already_has_default_signature = 168 meta_graph_def->signature_def().find(kDefaultServingSignatureDefKey) != 169 meta_graph_def->signature_def().end(); 170 if (already_has_default_signature) { 171 return Status(error::Code::ALREADY_EXISTS, 172 strings::StrCat( 173 "Default signature cannot be up-converted since ", 174 kDefaultServingSignatureDefKey, " key already exists.")); 175 } 176 const Signature& signature = signatures.default_signature(); 177 if (IsRegressionSignature(signature)) { 178 (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] = 179 BuildRegressionSignatureDef(signature.regression_signature(), 180 tensor_name_to_dtype); 181 } else if (IsClassificationSignature(signature)) { 182 (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] = 183 BuildClassificationSignatureDef(signature.classification_signature(), 184 tensor_name_to_dtype); 185 } else { 186 LOG(WARNING) << "Default signature up-conversion to SignatureDef is only " 187 "supported for `Classification` and `Regression`. Could " 188 "not up-convert signature: " 189 << signature.DebugString() 190 << ". (If using SessionRun with the SessionBundle export " 191 "format please ignore this warning.)"; 192 } 193 return Status::OK(); 194 } 195 196 Status ConvertNamedSignaturesToSignatureDef( 197 const Signatures& signatures, 198 const std::unordered_map<string, DataType>& tensor_name_to_dtype, 199 MetaGraphDef* meta_graph_def) { 200 if (signatures.named_signatures().empty()) { 201 return Status::OK(); 202 } 203 // Check for a Predict signature for up-conversion. 204 Status predict_signature_def_status = 205 MaybeBuildPredictSignatureDef(tensor_name_to_dtype, meta_graph_def); 206 for (const auto& it_named_signature : signatures.named_signatures()) { 207 const string key = it_named_signature.first; 208 // If a Predict SignatureDef was successfully constructed, skip the entries 209 // corresponding to `inputs` and `outputs`. 210 if (predict_signature_def_status.ok()) { 211 if (key == kPredictInputs || key == kPredictOutputs) { 212 continue; 213 } 214 } 215 const Signature signature = it_named_signature.second; 216 if (IsRegressionSignature(signature)) { 217 (*meta_graph_def->mutable_signature_def())[key] = 218 BuildRegressionSignatureDef(signature.regression_signature(), 219 tensor_name_to_dtype); 220 } else if (IsClassificationSignature(signature)) { 221 (*meta_graph_def->mutable_signature_def())[key] = 222 BuildClassificationSignatureDef(signature.classification_signature(), 223 tensor_name_to_dtype); 224 } else { 225 LOG(WARNING) 226 << "Named signature up-conversion to SignatureDef is only supported " 227 "for `Classification`, `Regression` or if two `GenericSignatures` " 228 "signatures called `inputs` and `outputs` exist, corresponding " 229 "to the `Prediction` API. Could not up-convert signature: " 230 << signature.DebugString(); 231 } 232 } 233 return Status::OK(); 234 } 235 236 } // namespace 237 238 namespace internal { 239 /////////////////////////////////////////////////////////////////////////////// 240 // Helper functions to populate SignatureDef fields. 241 242 // Adds an entry to the `inputs` map of the supplied SignatureDef. 243 void AddInputToSignatureDef( 244 const string& tensor_name, 245 const std::unordered_map<string, DataType>& tensor_name_to_dtype, 246 const string& input_key, SignatureDef* signature_def) { 247 if (tensor_name.empty()) { 248 LOG(WARNING) << "Tensor name not provided. Cannot add TensorInfo to " 249 "SignatureDef inputs."; 250 return; 251 } 252 // Extract the tensor-name in case the supplied string is a tensor-reference. 253 // Example: Extract "x" from "x:0". 254 std::size_t pos = tensor_name.find(":"); 255 const string key = 256 (pos != std::string::npos) ? tensor_name.substr(0, pos) : tensor_name; 257 const auto it_tensor_info = tensor_name_to_dtype.find(key); 258 TensorInfo tensor_info; 259 tensor_info.set_name(tensor_name); 260 if (it_tensor_info != tensor_name_to_dtype.end()) { 261 tensor_info.set_dtype(it_tensor_info->second); 262 } else { 263 LOG(WARNING) 264 << "No dtype found for tensor with name: " << tensor_name << ". " 265 << "Building TensorInfo with only name for SignatureDef inputs. " 266 << "Downstream functionality including validation may be " 267 << "impacted."; 268 } 269 (*signature_def->mutable_inputs())[input_key] = tensor_info; 270 } 271 272 // Adds an entry to the `outputs` map of the supplied SignatureDef. 273 void AddOutputToSignatureDef( 274 const string& tensor_name, 275 const std::unordered_map<string, DataType>& tensor_name_to_dtype, 276 const string& output_key, SignatureDef* signature_def) { 277 if (tensor_name.empty()) { 278 LOG(WARNING) << "Tensor name not provided. Cannot add TensorInfo to " 279 "SignatureDef outputs."; 280 return; 281 } 282 // Extract the tensor-name in case the supplied string is a tensor-reference. 283 // Example: Extract "x" from "x:0". 284 std::size_t pos = tensor_name.find(":"); 285 const string key = 286 (pos != std::string::npos) ? tensor_name.substr(0, pos) : tensor_name; 287 const auto it_tensor_info = tensor_name_to_dtype.find(key); 288 TensorInfo tensor_info; 289 tensor_info.set_name(tensor_name); 290 if (it_tensor_info != tensor_name_to_dtype.end()) { 291 tensor_info.set_dtype(it_tensor_info->second); 292 } else { 293 LOG(WARNING) 294 << "No dtype found for tensor with name: " << tensor_name << ". " 295 << "Building TensorInfo with only name for SignatureDef outputs." 296 << " Downstream functionality including validation may be " 297 << "impacted."; 298 } 299 (*signature_def->mutable_outputs())[output_key] = tensor_info; 300 } 301 302 // Builds a map from tensor name to the corresponding datatype, by parsing the 303 // MetaGraphDef. 304 Status BuildTensorNameToDtypeMap( 305 const MetaGraphDef& meta_graph_def, 306 std::unordered_map<string, DataType>* tensor_name_to_dtype) { 307 GraphConstructorOptions opts; 308 Graph graph(OpRegistry::Global()); 309 TF_RETURN_IF_ERROR( 310 ConvertGraphDefToGraph(opts, meta_graph_def.graph_def(), &graph)); 311 for (Node* node : graph.nodes()) { 312 for (auto dt : node->output_types()) { 313 tensor_name_to_dtype->insert(std::make_pair(node->name(), dt)); 314 } 315 } 316 return Status::OK(); 317 } 318 319 // Converts SessionBundle signatures to SavedModel signature-defs. 320 Status ConvertSignaturesToSignatureDefs(MetaGraphDef* meta_graph_def) { 321 Signatures signatures; 322 GetSignatures(*meta_graph_def, &signatures).IgnoreError(); 323 324 // Build a map of tensor-names to the corresponding tensor-info with `name` 325 // and `dtype` fields. 326 std::unordered_map<string, DataType> tensor_name_to_dtype; 327 TF_RETURN_IF_ERROR( 328 BuildTensorNameToDtypeMap(*meta_graph_def, &tensor_name_to_dtype)); 329 330 TF_RETURN_IF_ERROR(ConvertDefaultSignatureToSignatureDef( 331 signatures, tensor_name_to_dtype, meta_graph_def)); 332 TF_RETURN_IF_ERROR(ConvertNamedSignaturesToSignatureDef( 333 signatures, tensor_name_to_dtype, meta_graph_def)); 334 return Status::OK(); 335 } 336 337 // Converts a SessionBundle to a SavedModelBundle. 338 Status ConvertSessionBundleToSavedModelBundle( 339 SessionBundle& session_bundle, SavedModelBundle* saved_model_bundle) { 340 // Transfer ownership of the session from old to new. 341 saved_model_bundle->session = std::move(session_bundle.session); 342 343 // Copy the meta graph def from the SessionBundle to the SavedModelBundle. 344 saved_model_bundle->meta_graph_def = session_bundle.meta_graph_def; 345 346 // Convert signatures from session-bundle to signature-defs in 347 // saved-model-bundle. 348 return internal::ConvertSignaturesToSignatureDefs( 349 &saved_model_bundle->meta_graph_def); 350 } 351 352 } // namespace internal 353 354 Status LoadSessionBundleOrSavedModelBundle( 355 const SessionOptions& session_options, const RunOptions& run_options, 356 const string& export_dir, 357 const std::unordered_set<string>& saved_model_tags, 358 SavedModelBundle* saved_model_bundle, bool* is_session_bundle) { 359 if (is_session_bundle != nullptr) { 360 *is_session_bundle = false; 361 } 362 if (MaybeSavedModelDirectory(export_dir)) { 363 LOG(INFO) 364 << "Attempting to load native SavedModelBundle in bundle-shim from: " 365 << export_dir; 366 367 return LoadSavedModel(session_options, run_options, export_dir, 368 saved_model_tags, saved_model_bundle); 369 } else if (IsPossibleExportDirectory(export_dir)) { 370 LOG(ERROR) << "Found possible SessionBundle in export directory. " 371 "SessionBundle is deprecated. Use SavedModel instead."; 372 LOG(INFO) << "Attempting to up-convert SessionBundle to SavedModelBundle " 373 "in bundle-shim from: " 374 << export_dir; 375 if (is_session_bundle != nullptr) { 376 *is_session_bundle = true; 377 } 378 return LoadSavedModelFromLegacySessionBundlePath( 379 session_options, run_options, export_dir, saved_model_bundle); 380 } 381 return Status( 382 error::Code::NOT_FOUND, 383 strings::StrCat( 384 "Specified file path does not appear to contain a:\n" 385 "- Session bundle (should have a file called `export.meta`)\n" 386 "- or, SavedModel bundle (should have a file called " 387 "`saved_model.pb`)\n" 388 "Specified file path: ", 389 export_dir)); 390 } 391 392 } // namespace serving 393 } // namespace tensorflow 394