Home | History | Annotate | Download | only in session_bundle
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/session_bundle/bundle_shim.h"
     17 
     18 #include "tensorflow/cc/saved_model/loader.h"
     19 #include "tensorflow/cc/saved_model/signature_constants.h"
     20 #include "tensorflow/contrib/session_bundle/manifest.pb.h"
     21 #include "tensorflow/contrib/session_bundle/session_bundle.h"
     22 #include "tensorflow/contrib/session_bundle/signature.h"
     23 #include "tensorflow/core/graph/graph_constructor.h"
     24 #include "tensorflow/core/lib/core/errors.h"
     25 #include "tensorflow/core/lib/core/status.h"
     26 #include "tensorflow/core/lib/core/stringpiece.h"
     27 #include "tensorflow/core/protobuf/meta_graph.pb.h"
     28 #include "tensorflow/core/public/session.h"
     29 #include "tensorflow/core/public/session_options.h"
     30 
     31 namespace tensorflow {
     32 namespace serving {
     33 namespace {
     34 ///////////////////////////////////////////////////////////////////////////////
     35 // Helper functions to check Signature type.
     36 
     37 bool IsClassificationSignature(const Signature& signature) {
     38   return signature.type_case() == Signature::kClassificationSignature;
     39 }
     40 
     41 bool IsRegressionSignature(const Signature& signature) {
     42   return signature.type_case() == Signature::kRegressionSignature;
     43 }
     44 
     45 ///////////////////////////////////////////////////////////////////////////////
     46 // Helper functions to build `Classification`, `Regression` and `Predict`
     47 // SignatureDefs.
     48 
     49 SignatureDef BuildRegressionSignatureDef(
     50     const RegressionSignature& regression_signature,
     51     const std::unordered_map<string, DataType>& tensor_name_to_dtype) {
     52   SignatureDef signature_def;
     53   signature_def.set_method_name(kRegressMethodName);
     54   internal::AddInputToSignatureDef(regression_signature.input().tensor_name(),
     55                                    tensor_name_to_dtype, kRegressInputs,
     56                                    &signature_def);
     57   internal::AddOutputToSignatureDef(regression_signature.output().tensor_name(),
     58                                     tensor_name_to_dtype, kRegressOutputs,
     59                                     &signature_def);
     60   return signature_def;
     61 }
     62 
     63 SignatureDef BuildClassificationSignatureDef(
     64     const ClassificationSignature& classification_signature,
     65     const std::unordered_map<string, DataType>& tensor_name_to_dtype) {
     66   SignatureDef signature_def;
     67   signature_def.set_method_name(kClassifyMethodName);
     68   internal::AddInputToSignatureDef(
     69       classification_signature.input().tensor_name(), tensor_name_to_dtype,
     70       kClassifyInputs, &signature_def);
     71   internal::AddOutputToSignatureDef(
     72       classification_signature.classes().tensor_name(), tensor_name_to_dtype,
     73       kClassifyOutputClasses, &signature_def);
     74   internal::AddOutputToSignatureDef(
     75       classification_signature.scores().tensor_name(), tensor_name_to_dtype,
     76       kClassifyOutputScores, &signature_def);
     77   return signature_def;
     78 }
     79 
     80 Status MaybeBuildPredictSignatureDef(
     81     const std::unordered_map<string, DataType>& tensor_name_to_dtype,
     82     MetaGraphDef* meta_graph_def) {
     83   Signature input_signature, output_signature;
     84   // Ensure that named signatures corresponding to `inputs` and `outputs` keys
     85   // exist.
     86   if (!GetNamedSignature(kPredictInputs, *meta_graph_def, &input_signature)
     87            .ok() ||
     88       !GetNamedSignature(kPredictOutputs, *meta_graph_def, &output_signature)
     89            .ok()) {
     90     return Status(error::Code::INVALID_ARGUMENT,
     91                   "Named signatures can only be up-converted if entries "
     92                   "corresponding to both `inputs` and `outputs` exist.");
     93   }
     94   // Ensure the `inputs` and `outputs` named signatures are generic signatures.
     95   if (input_signature.type_case() != Signature::TypeCase::kGenericSignature ||
     96       output_signature.type_case() != Signature::TypeCase::kGenericSignature) {
     97     return Status(error::Code::INVALID_ARGUMENT,
     98                   "Named signatures corresponding to `inputs` and `outputs` "
     99                   "can only be up-converted if they are GenericSignatures.");
    100   }
    101   SignatureDef signature_def;
    102   signature_def.set_method_name(kPredictMethodName);
    103   // Add map entries from the `inputs` generic signature to the input map in the
    104   // signature def.
    105   for (const auto& map_entry : input_signature.generic_signature().map()) {
    106     internal::AddInputToSignatureDef(map_entry.second.tensor_name(),
    107                                      tensor_name_to_dtype, map_entry.first,
    108                                      &signature_def);
    109   }
    110   // Add map entries from the `outputs` generic signature to the output map in
    111   // the signature def.
    112   for (const auto& map_entry : output_signature.generic_signature().map()) {
    113     internal::AddOutputToSignatureDef(map_entry.second.tensor_name(),
    114                                       tensor_name_to_dtype, map_entry.first,
    115                                       &signature_def);
    116   }
    117   // Add the constructed signature def to the signature def map of the meta
    118   // graph def. Use the default key if it isn't already in use.
    119   const bool already_has_default_signature =
    120       meta_graph_def->signature_def().find(kDefaultServingSignatureDefKey) !=
    121       meta_graph_def->signature_def().end();
    122   const string signature_def_key =
    123       already_has_default_signature
    124           ? strings::StrCat(kDefaultServingSignatureDefKey, "_from_named")
    125           : kDefaultServingSignatureDefKey;
    126   (*meta_graph_def->mutable_signature_def())[signature_def_key] = signature_def;
    127   return Status::OK();
    128 }
    129 
    130 Status LoadSavedModelFromLegacySessionBundlePath(
    131     const SessionOptions& session_options, const RunOptions& run_options,
    132     const StringPiece session_bundle_export_dir,
    133     SavedModelBundle* saved_model_bundle) {
    134   if (session_bundle_export_dir.empty()) {
    135     return Status(error::Code::NOT_FOUND, "Export directory path is empty.");
    136   }
    137   if (!IsPossibleExportDirectory(session_bundle_export_dir)) {
    138     return Status(
    139         error::Code::NOT_FOUND,
    140         "Export directory does not contain a valid SessionBundle export.");
    141   }
    142 
    143   // Build the session-bundle.
    144   SessionBundle session_bundle;
    145   TF_RETURN_IF_ERROR(LoadSessionBundleFromPathUsingRunOptions(
    146       session_options, run_options, session_bundle_export_dir,
    147       &session_bundle));
    148 
    149   // Convert the session-bundle to a saved-model-bundle.
    150   return internal::ConvertSessionBundleToSavedModelBundle(session_bundle,
    151                                                           saved_model_bundle);
    152 }
    153 
    154 ///////////////////////////////////////////////////////////////////////////////
    155 // Helper functions to convert `Default` and `Named` signatures to
    156 // SignatureDefs.
    157 
    158 // Up-conversion of default signatures is supported for classification and
    159 // regression.
    160 Status ConvertDefaultSignatureToSignatureDef(
    161     const Signatures& signatures,
    162     const std::unordered_map<string, DataType>& tensor_name_to_dtype,
    163     MetaGraphDef* meta_graph_def) {
    164   if (!signatures.has_default_signature()) {
    165     return Status::OK();
    166   }
    167   const bool already_has_default_signature =
    168       meta_graph_def->signature_def().find(kDefaultServingSignatureDefKey) !=
    169       meta_graph_def->signature_def().end();
    170   if (already_has_default_signature) {
    171     return Status(error::Code::ALREADY_EXISTS,
    172                   strings::StrCat(
    173                       "Default signature cannot be up-converted since ",
    174                       kDefaultServingSignatureDefKey, " key already exists."));
    175   }
    176   const Signature& signature = signatures.default_signature();
    177   if (IsRegressionSignature(signature)) {
    178     (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] =
    179         BuildRegressionSignatureDef(signature.regression_signature(),
    180                                     tensor_name_to_dtype);
    181   } else if (IsClassificationSignature(signature)) {
    182     (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] =
    183         BuildClassificationSignatureDef(signature.classification_signature(),
    184                                         tensor_name_to_dtype);
    185   } else {
    186     LOG(WARNING) << "Default signature up-conversion to SignatureDef is only "
    187                     "supported for `Classification` and `Regression`. Could "
    188                     "not up-convert signature: "
    189                  << signature.DebugString()
    190                  << ". (If using SessionRun with the SessionBundle export "
    191                     "format please ignore this warning.)";
    192   }
    193   return Status::OK();
    194 }
    195 
    196 Status ConvertNamedSignaturesToSignatureDef(
    197     const Signatures& signatures,
    198     const std::unordered_map<string, DataType>& tensor_name_to_dtype,
    199     MetaGraphDef* meta_graph_def) {
    200   if (signatures.named_signatures().empty()) {
    201     return Status::OK();
    202   }
    203   // Check for a Predict signature for up-conversion.
    204   Status predict_signature_def_status =
    205       MaybeBuildPredictSignatureDef(tensor_name_to_dtype, meta_graph_def);
    206   for (const auto& it_named_signature : signatures.named_signatures()) {
    207     const string key = it_named_signature.first;
    208     // If a Predict SignatureDef was successfully constructed, skip the entries
    209     // corresponding to `inputs` and `outputs`.
    210     if (predict_signature_def_status.ok()) {
    211       if (key == kPredictInputs || key == kPredictOutputs) {
    212         continue;
    213       }
    214     }
    215     const Signature signature = it_named_signature.second;
    216     if (IsRegressionSignature(signature)) {
    217       (*meta_graph_def->mutable_signature_def())[key] =
    218           BuildRegressionSignatureDef(signature.regression_signature(),
    219                                       tensor_name_to_dtype);
    220     } else if (IsClassificationSignature(signature)) {
    221       (*meta_graph_def->mutable_signature_def())[key] =
    222           BuildClassificationSignatureDef(signature.classification_signature(),
    223                                           tensor_name_to_dtype);
    224     } else {
    225       LOG(WARNING)
    226           << "Named signature up-conversion to SignatureDef is only supported "
    227              "for `Classification`, `Regression` or if two `GenericSignatures` "
    228              "signatures  called `inputs` and `outputs` exist, corresponding "
    229              "to the `Prediction` API. Could not up-convert signature: "
    230           << signature.DebugString();
    231     }
    232   }
    233   return Status::OK();
    234 }
    235 
    236 }  // namespace
    237 
    238 namespace internal {
    239 ///////////////////////////////////////////////////////////////////////////////
    240 // Helper functions to populate SignatureDef fields.
    241 
    242 // Adds an entry to the `inputs` map of the supplied SignatureDef.
    243 void AddInputToSignatureDef(
    244     const string& tensor_name,
    245     const std::unordered_map<string, DataType>& tensor_name_to_dtype,
    246     const string& input_key, SignatureDef* signature_def) {
    247   if (tensor_name.empty()) {
    248     LOG(WARNING) << "Tensor name not provided. Cannot add TensorInfo to "
    249                     "SignatureDef inputs.";
    250     return;
    251   }
    252   // Extract the tensor-name in case the supplied string is a tensor-reference.
    253   // Example: Extract "x" from "x:0".
    254   std::size_t pos = tensor_name.find(":");
    255   const string key =
    256       (pos != std::string::npos) ? tensor_name.substr(0, pos) : tensor_name;
    257   const auto it_tensor_info = tensor_name_to_dtype.find(key);
    258   TensorInfo tensor_info;
    259   tensor_info.set_name(tensor_name);
    260   if (it_tensor_info != tensor_name_to_dtype.end()) {
    261     tensor_info.set_dtype(it_tensor_info->second);
    262   } else {
    263     LOG(WARNING)
    264         << "No dtype found for tensor with name: " << tensor_name << ". "
    265         << "Building TensorInfo with only name for SignatureDef inputs. "
    266         << "Downstream functionality including validation may be "
    267         << "impacted.";
    268   }
    269   (*signature_def->mutable_inputs())[input_key] = tensor_info;
    270 }
    271 
    272 // Adds an entry to the `outputs` map of the supplied SignatureDef.
    273 void AddOutputToSignatureDef(
    274     const string& tensor_name,
    275     const std::unordered_map<string, DataType>& tensor_name_to_dtype,
    276     const string& output_key, SignatureDef* signature_def) {
    277   if (tensor_name.empty()) {
    278     LOG(WARNING) << "Tensor name not provided. Cannot add TensorInfo to "
    279                     "SignatureDef outputs.";
    280     return;
    281   }
    282   // Extract the tensor-name in case the supplied string is a tensor-reference.
    283   // Example: Extract "x" from "x:0".
    284   std::size_t pos = tensor_name.find(":");
    285   const string key =
    286       (pos != std::string::npos) ? tensor_name.substr(0, pos) : tensor_name;
    287   const auto it_tensor_info = tensor_name_to_dtype.find(key);
    288   TensorInfo tensor_info;
    289   tensor_info.set_name(tensor_name);
    290   if (it_tensor_info != tensor_name_to_dtype.end()) {
    291     tensor_info.set_dtype(it_tensor_info->second);
    292   } else {
    293     LOG(WARNING)
    294         << "No dtype found for tensor with name: " << tensor_name << ". "
    295         << "Building TensorInfo with only name for SignatureDef outputs."
    296         << " Downstream functionality including validation may be "
    297         << "impacted.";
    298   }
    299   (*signature_def->mutable_outputs())[output_key] = tensor_info;
    300 }
    301 
    302 // Builds a map from tensor name to the corresponding datatype, by parsing the
    303 // MetaGraphDef.
    304 Status BuildTensorNameToDtypeMap(
    305     const MetaGraphDef& meta_graph_def,
    306     std::unordered_map<string, DataType>* tensor_name_to_dtype) {
    307   GraphConstructorOptions opts;
    308   Graph graph(OpRegistry::Global());
    309   TF_RETURN_IF_ERROR(
    310       ConvertGraphDefToGraph(opts, meta_graph_def.graph_def(), &graph));
    311   for (Node* node : graph.nodes()) {
    312     for (auto dt : node->output_types()) {
    313       tensor_name_to_dtype->insert(std::make_pair(node->name(), dt));
    314     }
    315   }
    316   return Status::OK();
    317 }
    318 
    319 // Converts SessionBundle signatures to SavedModel signature-defs.
    320 Status ConvertSignaturesToSignatureDefs(MetaGraphDef* meta_graph_def) {
    321   Signatures signatures;
    322   GetSignatures(*meta_graph_def, &signatures).IgnoreError();
    323 
    324   // Build a map of tensor-names to the corresponding tensor-info with `name`
    325   // and `dtype` fields.
    326   std::unordered_map<string, DataType> tensor_name_to_dtype;
    327   TF_RETURN_IF_ERROR(
    328       BuildTensorNameToDtypeMap(*meta_graph_def, &tensor_name_to_dtype));
    329 
    330   TF_RETURN_IF_ERROR(ConvertDefaultSignatureToSignatureDef(
    331       signatures, tensor_name_to_dtype, meta_graph_def));
    332   TF_RETURN_IF_ERROR(ConvertNamedSignaturesToSignatureDef(
    333       signatures, tensor_name_to_dtype, meta_graph_def));
    334   return Status::OK();
    335 }
    336 
    337 // Converts a SessionBundle to a SavedModelBundle.
    338 Status ConvertSessionBundleToSavedModelBundle(
    339     SessionBundle& session_bundle, SavedModelBundle* saved_model_bundle) {
    340   // Transfer ownership of the session from old to new.
    341   saved_model_bundle->session = std::move(session_bundle.session);
    342 
    343   // Copy the meta graph def from the SessionBundle to the SavedModelBundle.
    344   saved_model_bundle->meta_graph_def = session_bundle.meta_graph_def;
    345 
    346   // Convert signatures from session-bundle to signature-defs in
    347   // saved-model-bundle.
    348   return internal::ConvertSignaturesToSignatureDefs(
    349       &saved_model_bundle->meta_graph_def);
    350 }
    351 
    352 }  // namespace internal
    353 
    354 Status LoadSessionBundleOrSavedModelBundle(
    355     const SessionOptions& session_options, const RunOptions& run_options,
    356     const string& export_dir,
    357     const std::unordered_set<string>& saved_model_tags,
    358     SavedModelBundle* saved_model_bundle) {
    359   if (MaybeSavedModelDirectory(export_dir)) {
    360     LOG(INFO)
    361         << "Attempting to load native SavedModelBundle in bundle-shim from: "
    362         << export_dir;
    363     return LoadSavedModel(session_options, run_options, export_dir,
    364                           saved_model_tags, saved_model_bundle);
    365   } else if (IsPossibleExportDirectory(export_dir)) {
    366     LOG(ERROR) << "Found possible SessionBundle in export directory. "
    367                   "SessionBundle is deprecated. Use SavedModel instead.";
    368     LOG(INFO) << "Attempting to up-convert SessionBundle to SavedModelBundle "
    369                  "in bundle-shim from: "
    370               << export_dir;
    371     return LoadSavedModelFromLegacySessionBundlePath(
    372         session_options, run_options, export_dir, saved_model_bundle);
    373   }
    374   return Status(
    375       error::Code::NOT_FOUND,
    376       strings::StrCat(
    377           "Specified file path does not appear to contain a:\n"
    378           "- Session bundle (should have a file called `export.meta`)\n"
    379           "- or, SavedModel bundle (should have a file called "
    380           "`saved_model.pb`)\n"
    381           "Specified file path: ",
    382           export_dir));
    383 }
    384 
    385 }  // namespace serving
    386 }  // namespace tensorflow
    387