Home | History | Annotate | Download | only in session_bundle
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/session_bundle/bundle_shim.h"
     17 
     18 #include "google/protobuf/any.pb.h"
     19 #include "tensorflow/cc/saved_model/signature_constants.h"
     20 #include "tensorflow/cc/saved_model/tag_constants.h"
     21 #include "tensorflow/contrib/session_bundle/test_util.h"
     22 #include "tensorflow/core/example/example.pb.h"
     23 #include "tensorflow/core/example/feature.pb.h"
     24 #include "tensorflow/core/framework/tensor_testutil.h"
     25 #include "tensorflow/core/lib/core/status_test_util.h"
     26 #include "tensorflow/core/lib/io/path.h"
     27 #include "tensorflow/core/protobuf/meta_graph.pb.h"
     28 
     29 namespace tensorflow {
     30 namespace serving {
     31 namespace internal {
     32 namespace {
     33 
     34 constexpr char kSessionBundlePath[] =
     35     "session_bundle/testdata/half_plus_two/00000123";
     36 constexpr char kSavedModelBundlePath[] =
     37     "cc/saved_model/testdata/half_plus_two/00000123";
     38 
     39 string MakeSerializedExample(float x) {
     40   tensorflow::Example example;
     41   auto* feature_map = example.mutable_features()->mutable_feature();
     42   (*feature_map)["x"].mutable_float_list()->add_value(x);
     43   return example.SerializeAsString();
     44 }
     45 
     46 void ValidateHalfPlusTwo(const SavedModelBundle& saved_model_bundle,
     47                          const string& input_tensor_name,
     48                          const string& output_tensor_name) {
     49   // Validate the half plus two behavior.
     50   std::vector<string> serialized_examples;
     51   for (float x : {0, 1, 2, 3}) {
     52     serialized_examples.push_back(MakeSerializedExample(x));
     53   }
     54   Tensor input = test::AsTensor<string>(serialized_examples, TensorShape({4}));
     55 
     56   std::vector<Tensor> outputs;
     57   TF_ASSERT_OK(saved_model_bundle.session->Run(
     58       {{input_tensor_name, input}}, {output_tensor_name}, {}, &outputs));
     59   ASSERT_EQ(outputs.size(), 1);
     60   test::ExpectTensorEqual<float>(
     61       outputs[0], test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
     62 }
     63 
     64 void LoadAndValidateSavedModelBundle(const string& export_dir,
     65                                      const std::unordered_set<string>& tags,
     66                                      const string& signature_def_key,
     67                                      bool expect_session_bundle) {
     68   SessionOptions session_options;
     69   RunOptions run_options;
     70   SavedModelBundle saved_model_bundle;
     71   bool is_session_bundle = false;
     72   TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(
     73       session_options, run_options, export_dir, tags, &saved_model_bundle,
     74       &is_session_bundle));
     75   EXPECT_EQ(expect_session_bundle, is_session_bundle);
     76   const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
     77   const auto& signature_def_map = meta_graph_def.signature_def();
     78 
     79   const auto& regression_entry = signature_def_map.find(signature_def_key);
     80   ASSERT_FALSE(regression_entry == signature_def_map.end());
     81   SignatureDef regression_signature_def = regression_entry->second;
     82 
     83   EXPECT_EQ(1, regression_signature_def.inputs_size());
     84   ASSERT_FALSE(regression_signature_def.inputs().find(kRegressInputs) ==
     85                regression_signature_def.inputs().end());
     86   TensorInfo input_tensor_info =
     87       regression_signature_def.inputs().find(kRegressInputs)->second;
     88   EXPECT_EQ(1, regression_signature_def.outputs_size());
     89   // Ensure the TensorInfo has dtype populated.
     90   EXPECT_EQ(DT_STRING, input_tensor_info.dtype());
     91 
     92   ASSERT_FALSE(regression_signature_def.outputs().find(kRegressOutputs) ==
     93                regression_signature_def.outputs().end());
     94   TensorInfo output_tensor_info =
     95       regression_signature_def.outputs().find(kRegressOutputs)->second;
     96   // Ensure the TensorInfo has dtype populated.
     97   EXPECT_EQ(DT_FLOAT, output_tensor_info.dtype());
     98   ValidateHalfPlusTwo(saved_model_bundle, input_tensor_info.name(),
     99                       output_tensor_info.name());
    100 }
    101 
    102 // Helper function to validate that the SignatureDef found in the MetaGraphDef
    103 // with the provided key has the expected string representation.
    104 void ValidateSignatureDef(const MetaGraphDef& meta_graph_def, const string& key,
    105                           const string& expected_string_signature_def) {
    106   tensorflow::SignatureDef expected_signature;
    107   CHECK(protobuf::TextFormat::ParseFromString(expected_string_signature_def,
    108                                               &expected_signature));
    109   auto iter = meta_graph_def.signature_def().find(key);
    110   ASSERT_TRUE(iter != meta_graph_def.signature_def().end());
    111   EXPECT_EQ(expected_signature.DebugString(), iter->second.DebugString());
    112 }
    113 
    114 // Checks that the input map in a signature def is populated correctly.
    115 TEST(BundleShimTest, AddInputToSignatureDef) {
    116   SignatureDef signature_def;
    117   const string tensor_name = "foo_tensor";
    118   const string map_key = "foo_key";
    119 
    120   // Build a map of tensor-name to dtype, for the unit-test.
    121   std::unordered_map<string, DataType> tensor_name_to_dtype;
    122   tensor_name_to_dtype[tensor_name] = tensorflow::DT_STRING;
    123 
    124   AddInputToSignatureDef(tensor_name, tensor_name_to_dtype, map_key,
    125                          &signature_def);
    126   EXPECT_EQ(1, signature_def.inputs_size());
    127   EXPECT_EQ(tensor_name, signature_def.inputs().find(map_key)->second.name());
    128 }
    129 
    130 // Checks that the output map in a signature def is populated correctly.
    131 TEST(BundleShimTest, AddOutputToSignatureDef) {
    132   SignatureDef signature_def;
    133   const string tensor_name = "foo_tensor";
    134   const string map_key = "foo_key";
    135 
    136   // Build a map of tensor-name to dtype, for the unit-test.
    137   std::unordered_map<string, DataType> tensor_name_to_dtype;
    138   tensor_name_to_dtype[tensor_name] = tensorflow::DT_STRING;
    139 
    140   AddOutputToSignatureDef(tensor_name, tensor_name_to_dtype, map_key,
    141                           &signature_def);
    142   EXPECT_EQ(1, signature_def.outputs_size());
    143   EXPECT_EQ(tensor_name, signature_def.outputs().find(map_key)->second.name());
    144 }
    145 
    146 // Checks that no signature defs are added if the default signature is missing.
    147 TEST(BundleShimTest, DefaultSignatureMissing) {
    148   MetaGraphDef meta_graph_def;
    149   // Signatures signatures;
    150   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
    151   EXPECT_EQ(0, meta_graph_def.signature_def_size());
    152 }
    153 
    154 // Checks that no signature defs are added if the default signature is empty.
    155 TEST(BundleShimTest, DefaultSignatureEmpty) {
    156   Signatures signatures;
    157   signatures.mutable_default_signature();
    158 
    159   MetaGraphDef meta_graph_def;
    160   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    161       .mutable_any_list()
    162       ->add_value()
    163       ->PackFrom(signatures);
    164   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
    165   EXPECT_EQ(0, meta_graph_def.signature_def_size());
    166 }
    167 
    168 // Checks the conversion to signature def for a regression default signature.
    169 TEST(BundleShimTest, DefaultSignatureRegression) {
    170   Signatures signatures;
    171   RegressionSignature* regression_signature =
    172       signatures.mutable_default_signature()->mutable_regression_signature();
    173   regression_signature->mutable_input()->set_tensor_name("foo-input");
    174   regression_signature->mutable_output()->set_tensor_name("foo-output");
    175   MetaGraphDef meta_graph_def;
    176   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    177       .mutable_any_list()
    178       ->add_value()
    179       ->PackFrom(signatures);
    180   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
    181   EXPECT_EQ(1, meta_graph_def.signature_def_size());
    182   const auto actual_signature_def =
    183       meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
    184   EXPECT_EQ("foo-input", actual_signature_def->second.inputs()
    185                              .find(kRegressInputs)
    186                              ->second.name());
    187   EXPECT_EQ("foo-output", actual_signature_def->second.outputs()
    188                               .find(kRegressOutputs)
    189                               ->second.name());
    190   EXPECT_EQ(kRegressMethodName, actual_signature_def->second.method_name());
    191 }
    192 
    193 // Checks the conversion to signature def for a classification default
    194 // signature.
    195 TEST(BundleShimTest, DefaultSignatureClassification) {
    196   Signatures signatures;
    197   ClassificationSignature* classification_signature =
    198       signatures.mutable_default_signature()
    199           ->mutable_classification_signature();
    200   classification_signature->mutable_input()->set_tensor_name("foo-input");
    201   classification_signature->mutable_classes()->set_tensor_name("foo-classes");
    202   classification_signature->mutable_scores()->set_tensor_name("foo-scores");
    203   MetaGraphDef meta_graph_def;
    204   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    205       .mutable_any_list()
    206       ->add_value()
    207       ->PackFrom(signatures);
    208   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
    209   EXPECT_EQ(1, meta_graph_def.signature_def_size());
    210   const auto actual_signature_def =
    211       meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
    212   EXPECT_EQ("foo-input", actual_signature_def->second.inputs()
    213                              .find(kClassifyInputs)
    214                              ->second.name());
    215   EXPECT_EQ("foo-classes", actual_signature_def->second.outputs()
    216                                .find(kClassifyOutputClasses)
    217                                ->second.name());
    218   EXPECT_EQ("foo-scores", actual_signature_def->second.outputs()
    219                               .find(kClassifyOutputScores)
    220                               ->second.name());
    221   EXPECT_EQ(kClassifyMethodName, actual_signature_def->second.method_name());
    222 }
    223 
    224 // Checks that generic default signatures are not up converted.
    225 TEST(BundleShimTest, DefaultSignatureGeneric) {
    226   TensorBinding input_binding;
    227   input_binding.set_tensor_name("foo-input");
    228 
    229   TensorBinding output_binding;
    230   output_binding.set_tensor_name("foo-output");
    231 
    232   Signatures signatures;
    233   GenericSignature* generic_signature =
    234       signatures.mutable_default_signature()->mutable_generic_signature();
    235   generic_signature->mutable_map()->insert({kPredictInputs, input_binding});
    236   generic_signature->mutable_map()->insert({kPredictOutputs, output_binding});
    237 
    238   MetaGraphDef meta_graph_def;
    239   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    240       .mutable_any_list()
    241       ->add_value()
    242       ->PackFrom(signatures);
    243   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
    244   EXPECT_EQ(0, meta_graph_def.signature_def_size());
    245 }
    246 
    247 TEST(BundleShimTest, NamedRegressionSignatures) {
    248   Signatures signatures;
    249 
    250   RegressionSignature* foo_regression_signature =
    251       (*signatures.mutable_named_signatures())["foo"]
    252           .mutable_regression_signature();
    253   foo_regression_signature->mutable_input()->set_tensor_name("foo-input");
    254   foo_regression_signature->mutable_output()->set_tensor_name("foo-output");
    255 
    256   RegressionSignature* bar_regression_signature =
    257       (*signatures.mutable_named_signatures())["bar"]
    258           .mutable_regression_signature();
    259   bar_regression_signature->mutable_input()->set_tensor_name("bar-input");
    260   bar_regression_signature->mutable_output()->set_tensor_name("bar-output");
    261 
    262   MetaGraphDef meta_graph_def;
    263   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    264       .mutable_any_list()
    265       ->add_value()
    266       ->PackFrom(signatures);
    267   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
    268   ASSERT_EQ(2, meta_graph_def.signature_def_size());
    269 
    270   ValidateSignatureDef(meta_graph_def, "foo",
    271                        "inputs { "
    272                        "  key: \"inputs\" "
    273                        "  value { "
    274                        "name: \"foo-input\" "
    275                        "  } "
    276                        "} "
    277                        "outputs { "
    278                        "  key: \"outputs\" "
    279                        "  value { "
    280                        "    name: \"foo-output\" "
    281                        "  } "
    282                        "} "
    283                        "method_name: \"tensorflow/serving/regress\" ");
    284   ValidateSignatureDef(meta_graph_def, "bar",
    285                        "inputs { "
    286                        "  key: \"inputs\" "
    287                        "  value { "
    288                        "name: \"bar-input\" "
    289                        "  } "
    290                        "} "
    291                        "outputs { "
    292                        "  key: \"outputs\" "
    293                        "  value { "
    294                        "    name: \"bar-output\" "
    295                        "  } "
    296                        "} "
    297                        "method_name: \"tensorflow/serving/regress\" ");
    298 }
    299 
    300 TEST(BundleShimTest, NamedClassificationSignatures) {
    301   Signatures signatures;
    302 
    303   ClassificationSignature* foo_classification_signature =
    304       (*signatures.mutable_named_signatures())["foo"]
    305           .mutable_classification_signature();
    306   foo_classification_signature->mutable_input()->set_tensor_name("foo-input");
    307   foo_classification_signature->mutable_classes()->set_tensor_name(
    308       "foo-classes");
    309 
    310   ClassificationSignature* bar_classification_signature =
    311       (*signatures.mutable_named_signatures())["bar"]
    312           .mutable_classification_signature();
    313   bar_classification_signature->mutable_input()->set_tensor_name("bar-input");
    314   bar_classification_signature->mutable_scores()->set_tensor_name("bar-scores");
    315 
    316   MetaGraphDef meta_graph_def;
    317   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    318       .mutable_any_list()
    319       ->add_value()
    320       ->PackFrom(signatures);
    321   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
    322   ASSERT_EQ(2, meta_graph_def.signature_def_size());
    323 
    324   ValidateSignatureDef(meta_graph_def, "foo",
    325                        "inputs { "
    326                        "  key: \"inputs\" "
    327                        "  value { "
    328                        "name: \"foo-input\" "
    329                        "  } "
    330                        "} "
    331                        "outputs { "
    332                        "  key: \"classes\" "
    333                        "  value { "
    334                        "    name: \"foo-classes\" "
    335                        "  } "
    336                        "} "
    337                        "method_name: \"tensorflow/serving/classify\" ");
    338   ValidateSignatureDef(meta_graph_def, "bar",
    339                        "inputs { "
    340                        "  key: \"inputs\" "
    341                        "  value { "
    342                        "name: \"bar-input\" "
    343                        "  } "
    344                        "} "
    345                        "outputs { "
    346                        "  key: \"scores\" "
    347                        "  value { "
    348                        "    name: \"bar-scores\" "
    349                        "  } "
    350                        "} "
    351                        "method_name: \"tensorflow/serving/classify\" ");
    352 }
    353 
    354 // Checks the Predict SignatureDef created when the named signatures have
    355 // `inputs` and `outputs`.
    356 TEST(BundleShimTest, NamedSignatureGenericInputsAndOutputs) {
    357   TensorBinding input_binding;
    358   input_binding.set_tensor_name("foo-input");
    359 
    360   TensorBinding output_binding;
    361   output_binding.set_tensor_name("foo-output");
    362 
    363   Signatures signatures;
    364   GenericSignature* input_generic_signature =
    365       (*signatures.mutable_named_signatures())[kPredictInputs]
    366           .mutable_generic_signature();
    367   input_generic_signature->mutable_map()->insert({"foo-input", input_binding});
    368 
    369   GenericSignature* output_generic_signature =
    370       (*signatures.mutable_named_signatures())[kPredictOutputs]
    371           .mutable_generic_signature();
    372   output_generic_signature->mutable_map()->insert(
    373       {"foo-output", output_binding});
    374 
    375   MetaGraphDef meta_graph_def;
    376   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    377       .mutable_any_list()
    378       ->add_value()
    379       ->PackFrom(signatures);
    380   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
    381   EXPECT_EQ(1, meta_graph_def.signature_def_size());
    382   const auto actual_signature_def =
    383       meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
    384   ASSERT_FALSE(actual_signature_def == meta_graph_def.signature_def().end());
    385   ASSERT_FALSE(actual_signature_def->second.inputs().find("foo-input") ==
    386                actual_signature_def->second.inputs().end());
    387   EXPECT_EQ(
    388       "foo-input",
    389       actual_signature_def->second.inputs().find("foo-input")->second.name());
    390   ASSERT_FALSE(actual_signature_def->second.outputs().find("foo-output") ==
    391                actual_signature_def->second.outputs().end());
    392   EXPECT_EQ(
    393       "foo-output",
    394       actual_signature_def->second.outputs().find("foo-output")->second.name());
    395   EXPECT_EQ(kPredictMethodName, actual_signature_def->second.method_name());
    396 }
    397 
    398 // Checks that a signature def is not added if the named signatures is generic
    399 // but does not have `inputs` and `outputs`.
    400 TEST(BundleShimTest, NamedSignatureGenericNoInputsOrOutputs) {
    401   TensorBinding input_binding;
    402   input_binding.set_tensor_name("foo-input");
    403 
    404   TensorBinding output_binding;
    405   output_binding.set_tensor_name("foo-output");
    406 
    407   Signatures signatures;
    408   GenericSignature* generic_signature =
    409       (*signatures.mutable_named_signatures())["unknown"]
    410           .mutable_generic_signature();
    411   generic_signature->mutable_map()->insert({kPredictInputs, input_binding});
    412   generic_signature->mutable_map()->insert({kPredictOutputs, output_binding});
    413 
    414   MetaGraphDef meta_graph_def;
    415   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    416       .mutable_any_list()
    417       ->add_value()
    418       ->PackFrom(signatures);
    419   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
    420   EXPECT_EQ(0, meta_graph_def.signature_def_size());
    421 }
    422 
    423 // Checks that a signature def is not added when the named signatures have only
    424 // one of `inputs` and `outputs`.
    425 TEST(BundleShimTest, NamedSignatureGenericOnlyInput) {
    426   TensorBinding input_binding;
    427   input_binding.set_tensor_name("foo-input");
    428 
    429   Signatures signatures;
    430   GenericSignature* input_generic_signature =
    431       (*signatures.mutable_named_signatures())[kPredictInputs]
    432           .mutable_generic_signature();
    433   input_generic_signature->mutable_map()->insert({"foo-input", input_binding});
    434 
    435   MetaGraphDef meta_graph_def;
    436   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    437       .mutable_any_list()
    438       ->add_value()
    439       ->PackFrom(signatures);
    440   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
    441   EXPECT_EQ(0, meta_graph_def.signature_def_size());
    442 }
    443 
    444 // Tests up-conversion of Signatures to SignatureDefs when both `default` and
    445 // `named` signatures are present.
    446 TEST(BundleShimTest, DefaultAndNamedSignatureWithPredict) {
    447   Signatures signatures;
    448 
    449   // Build a generic signature corresponding to `inputs` and add it to the
    450   // Signatures to up-convert.
    451   TensorBinding input_binding;
    452   input_binding.set_tensor_name("foo-input");
    453   GenericSignature* input_generic_signature =
    454       (*signatures.mutable_named_signatures())[kPredictInputs]
    455           .mutable_generic_signature();
    456   input_generic_signature->mutable_map()->insert({"foo-input", input_binding});
    457 
    458   // Build a generic signature corresponding to `outputs` and add it to the
    459   // Signatures to up-convert.
    460   TensorBinding output_binding;
    461   output_binding.set_tensor_name("foo-output");
    462   GenericSignature* output_generic_signature =
    463       (*signatures.mutable_named_signatures())[kPredictOutputs]
    464           .mutable_generic_signature();
    465   output_generic_signature->mutable_map()->insert(
    466       {"foo-output", output_binding});
    467 
    468   // Build a regression signature and set it as the default signature.
    469   RegressionSignature* inputs_regression_signature =
    470       (*signatures.mutable_default_signature()).mutable_regression_signature();
    471   inputs_regression_signature->mutable_input()->set_tensor_name("bar-input");
    472 
    473   // Up-convert the available signatures to SignatureDefs.
    474   MetaGraphDef meta_graph_def;
    475   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    476       .mutable_any_list()
    477       ->add_value()
    478       ->PackFrom(signatures);
    479   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
    480   EXPECT_EQ(2, meta_graph_def.signature_def_size());
    481 
    482   // Verify that the default regression signature is converted to a
    483   // SignatureDef that corresponds to the kDefaultServingSignatureDefKey.
    484   const auto actual_signature_def_regress =
    485       meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
    486   ASSERT_FALSE(actual_signature_def_regress ==
    487                meta_graph_def.signature_def().end());
    488   ASSERT_FALSE(
    489       actual_signature_def_regress->second.inputs().find(kRegressInputs) ==
    490       actual_signature_def_regress->second.inputs().end());
    491 
    492   // Verify that the `Predict` SignatureDef is created under a different key.
    493   const auto actual_signature_def_predict = meta_graph_def.signature_def().find(
    494       strings::StrCat(kDefaultServingSignatureDefKey, "_from_named"));
    495   ASSERT_FALSE(actual_signature_def_predict ==
    496                meta_graph_def.signature_def().end());
    497   ASSERT_FALSE(
    498       actual_signature_def_predict->second.inputs().find("foo-input") ==
    499       actual_signature_def_predict->second.inputs().end());
    500   EXPECT_EQ("foo-input", actual_signature_def_predict->second.inputs()
    501                              .find("foo-input")
    502                              ->second.name());
    503   ASSERT_FALSE(
    504       actual_signature_def_predict->second.outputs().find("foo-output") ==
    505       actual_signature_def_predict->second.outputs().end());
    506   EXPECT_EQ("foo-output", actual_signature_def_predict->second.outputs()
    507                               .find("foo-output")
    508                               ->second.name());
    509   EXPECT_EQ(kPredictMethodName,
    510             actual_signature_def_predict->second.method_name());
    511 }
    512 
    513 // Checks a basic up conversion for half plus two for SessionBundle.
    514 TEST(BundleShimTest, BasicExportSessionBundle) {
    515   const std::unordered_set<string> tags = {"tag"};
    516   const string session_bundle_export_dir =
    517       test_util::TestSrcDirPath(kSessionBundlePath);
    518   LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags,
    519                                   kDefaultServingSignatureDefKey,
    520                                   /*expect_session_bundle=*/true);
    521 
    522   // Verify that the named signature is also present.
    523   SessionOptions session_options;
    524   RunOptions run_options;
    525   SavedModelBundle saved_model_bundle;
    526   TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(session_options, run_options,
    527                                                    session_bundle_export_dir,
    528                                                    tags, &saved_model_bundle));
    529   const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
    530   const auto& signature_def_map = meta_graph_def.signature_def();
    531   bool found_named_signature = false;
    532   for (const auto& entry : signature_def_map) {
    533     const string& key = entry.first;
    534     const SignatureDef& signature_def = entry.second;
    535 
    536     // We're looking for the key that is *not* kDefaultServingSignatureDefKey.
    537     if (key == kDefaultServingSignatureDefKey) {
    538       continue;
    539     }
    540     found_named_signature = true;
    541 
    542     EXPECT_EQ(1, signature_def.inputs_size());
    543     const auto it_inputs_x = signature_def.inputs().find("x");
    544     EXPECT_FALSE(it_inputs_x == signature_def.inputs().end());
    545     // Ensure the TensorInfo has name and dtype populated.
    546     const TensorInfo& tensor_info_x = it_inputs_x->second;
    547     EXPECT_EQ("x:0", tensor_info_x.name());
    548     EXPECT_EQ(DT_FLOAT, tensor_info_x.dtype());
    549 
    550     EXPECT_EQ(1, signature_def.outputs_size());
    551     const auto it_outputs_y = signature_def.outputs().find("y");
    552     EXPECT_FALSE(it_outputs_y == signature_def.outputs().end());
    553     // Ensure the TensorInfo has name and dtype populated.
    554     const TensorInfo& tensor_info_y = it_outputs_y->second;
    555     EXPECT_EQ("y:0", tensor_info_y.name());
    556     EXPECT_EQ(DT_FLOAT, tensor_info_y.dtype());
    557   }
    558   EXPECT_TRUE(found_named_signature);
    559 }
    560 
    561 // Checks a basic load for half plus two for SavedModelBundle.
    562 TEST(BundleShimTest, BasicExportSavedModel) {
    563   const string saved_model_bundle_export_dir =
    564       io::JoinPath(testing::TensorFlowSrcRoot(), kSavedModelBundlePath);
    565   LoadAndValidateSavedModelBundle(saved_model_bundle_export_dir,
    566                                   {kSavedModelTagServe}, "regress_x_to_y",
    567                                   /*expect_session_bundle=*/false);
    568 }
    569 
    570 // Checks a basic load fails with an invalid export path.
    571 TEST(BundleShimTest, InvalidPath) {
    572   const string invalid_export_dir = testing::TensorFlowSrcRoot();
    573   SessionOptions session_options;
    574   RunOptions run_options;
    575   SavedModelBundle saved_model_bundle;
    576   Status status = LoadSessionBundleOrSavedModelBundle(
    577       session_options, run_options, invalid_export_dir, {kSavedModelTagServe},
    578       &saved_model_bundle);
    579   EXPECT_EQ(error::Code::NOT_FOUND, status.code());
    580 }
    581 
    582 // Checks that if loading a session bundle fails, the error is propagated to
    583 // LoadSessionBundleOrSavedModelBundle().
    584 TEST(BundleShimTest, LoadSessionBundleError) {
    585   const string session_bundle_export_dir =
    586       test_util::TestSrcDirPath(kSessionBundlePath);
    587   SessionOptions session_options;
    588   RunOptions run_options;
    589   // Invalid threadpool index to use for session-run calls.
    590   run_options.set_inter_op_thread_pool(100);
    591   SavedModelBundle saved_model_bundle;
    592   EXPECT_FALSE(LoadSessionBundleOrSavedModelBundle(session_options, run_options,
    593                                                    session_bundle_export_dir,
    594                                                    {"tag"}, &saved_model_bundle)
    595                    .ok());
    596 }
    597 
    598 }  // namespace
    599 }  // namespace internal
    600 }  // namespace serving
    601 }  // namespace tensorflow
    602