Home | History | Annotate | Download | only in session_bundle
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/session_bundle/signature.h"
     17 
     18 #include <memory>
     19 
     20 #include "google/protobuf/any.pb.h"
     21 #include "tensorflow/contrib/session_bundle/manifest.pb.h"
     22 #include "tensorflow/core/framework/graph.pb.h"
     23 #include "tensorflow/core/framework/tensor.h"
     24 #include "tensorflow/core/framework/tensor_testutil.h"
     25 #include "tensorflow/core/lib/core/errors.h"
     26 #include "tensorflow/core/lib/core/status.h"
     27 #include "tensorflow/core/lib/core/status_test_util.h"
     28 #include "tensorflow/core/lib/core/stringpiece.h"
     29 #include "tensorflow/core/platform/test.h"
     30 #include "tensorflow/core/public/session.h"
     31 
     32 namespace tensorflow {
     33 namespace serving {
     34 namespace {
     35 
     36 static bool HasSubstr(const string& base, const string& substr) {
     37   bool ok = StringPiece(base).contains(substr);
     38   EXPECT_TRUE(ok) << base << ", expected substring " << substr;
     39   return ok;
     40 }
     41 
     42 TEST(GetClassificationSignature, Basic) {
     43   tensorflow::MetaGraphDef meta_graph_def;
     44   Signatures signatures;
     45   ClassificationSignature* input_signature =
     46       signatures.mutable_default_signature()
     47           ->mutable_classification_signature();
     48   input_signature->mutable_input()->set_tensor_name("flow");
     49   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
     50       .mutable_any_list()
     51       ->add_value()
     52       ->PackFrom(signatures);
     53 
     54   ClassificationSignature signature;
     55   const Status status = GetClassificationSignature(meta_graph_def, &signature);
     56   TF_ASSERT_OK(status);
     57   EXPECT_EQ(signature.input().tensor_name(), "flow");
     58 }
     59 
     60 TEST(GetClassificationSignature, MissingSignature) {
     61   tensorflow::MetaGraphDef meta_graph_def;
     62   Signatures signatures;
     63   signatures.mutable_default_signature();
     64   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
     65       .mutable_any_list()
     66       ->add_value()
     67       ->PackFrom(signatures);
     68 
     69   ClassificationSignature signature;
     70   const Status status = GetClassificationSignature(meta_graph_def, &signature);
     71   ASSERT_FALSE(status.ok());
     72   EXPECT_TRUE(StringPiece(status.error_message())
     73                   .contains("Expected a classification signature"))
     74       << status.error_message();
     75 }
     76 
     77 TEST(GetClassificationSignature, WrongSignatureType) {
     78   tensorflow::MetaGraphDef meta_graph_def;
     79   Signatures signatures;
     80   signatures.mutable_default_signature()->mutable_regression_signature();
     81   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
     82       .mutable_any_list()
     83       ->add_value()
     84       ->PackFrom(signatures);
     85 
     86   ClassificationSignature signature;
     87   const Status status = GetClassificationSignature(meta_graph_def, &signature);
     88   ASSERT_FALSE(status.ok());
     89   EXPECT_TRUE(StringPiece(status.error_message())
     90                   .contains("Expected a classification signature"))
     91       << status.error_message();
     92 }
     93 
     94 TEST(GetNamedClassificationSignature, Basic) {
     95   tensorflow::MetaGraphDef meta_graph_def;
     96   Signatures signatures;
     97   ClassificationSignature* input_signature =
     98       (*signatures.mutable_named_signatures())["foo"]
     99           .mutable_classification_signature();
    100   input_signature->mutable_input()->set_tensor_name("flow");
    101   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    102       .mutable_any_list()
    103       ->add_value()
    104       ->PackFrom(signatures);
    105 
    106   ClassificationSignature signature;
    107   const Status status =
    108       GetNamedClassificationSignature("foo", meta_graph_def, &signature);
    109   TF_ASSERT_OK(status);
    110   EXPECT_EQ(signature.input().tensor_name(), "flow");
    111 }
    112 
    113 TEST(GetNamedClassificationSignature, MissingSignature) {
    114   tensorflow::MetaGraphDef meta_graph_def;
    115   Signatures signatures;
    116   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    117       .mutable_any_list()
    118       ->add_value()
    119       ->PackFrom(signatures);
    120 
    121   ClassificationSignature signature;
    122   const Status status =
    123       GetNamedClassificationSignature("foo", meta_graph_def, &signature);
    124   ASSERT_FALSE(status.ok());
    125   EXPECT_TRUE(StringPiece(status.error_message())
    126                   .contains("Missing signature named \"foo\""))
    127       << status.error_message();
    128 }
    129 
    130 TEST(GetNamedClassificationSignature, WrongSignatureType) {
    131   tensorflow::MetaGraphDef meta_graph_def;
    132   Signatures signatures;
    133   (*signatures.mutable_named_signatures())["foo"]
    134       .mutable_regression_signature();
    135   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    136       .mutable_any_list()
    137       ->add_value()
    138       ->PackFrom(signatures);
    139 
    140   ClassificationSignature signature;
    141   const Status status =
    142       GetNamedClassificationSignature("foo", meta_graph_def, &signature);
    143   ASSERT_FALSE(status.ok());
    144   EXPECT_TRUE(
    145       StringPiece(status.error_message())
    146           .contains("Expected a classification signature for name \"foo\""))
    147       << status.error_message();
    148 }
    149 
    150 TEST(GetRegressionSignature, Basic) {
    151   tensorflow::MetaGraphDef meta_graph_def;
    152   Signatures signatures;
    153   RegressionSignature* input_signature =
    154       signatures.mutable_default_signature()->mutable_regression_signature();
    155   input_signature->mutable_input()->set_tensor_name("flow");
    156   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    157       .mutable_any_list()
    158       ->add_value()
    159       ->PackFrom(signatures);
    160 
    161   RegressionSignature signature;
    162   const Status status = GetRegressionSignature(meta_graph_def, &signature);
    163   TF_ASSERT_OK(status);
    164   EXPECT_EQ(signature.input().tensor_name(), "flow");
    165 }
    166 
    167 TEST(GetRegressionSignature, MissingSignature) {
    168   tensorflow::MetaGraphDef meta_graph_def;
    169   Signatures signatures;
    170   signatures.mutable_default_signature();
    171   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    172       .mutable_any_list()
    173       ->add_value()
    174       ->PackFrom(signatures);
    175 
    176   RegressionSignature signature;
    177   const Status status = GetRegressionSignature(meta_graph_def, &signature);
    178   ASSERT_FALSE(status.ok());
    179   EXPECT_TRUE(StringPiece(status.error_message())
    180                   .contains("Expected a regression signature"))
    181       << status.error_message();
    182 }
    183 
    184 TEST(GetRegressionSignature, WrongSignatureType) {
    185   tensorflow::MetaGraphDef meta_graph_def;
    186   Signatures signatures;
    187   signatures.mutable_default_signature()->mutable_classification_signature();
    188   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    189       .mutable_any_list()
    190       ->add_value()
    191       ->PackFrom(signatures);
    192 
    193   RegressionSignature signature;
    194   const Status status = GetRegressionSignature(meta_graph_def, &signature);
    195   ASSERT_FALSE(status.ok());
    196   EXPECT_TRUE(StringPiece(status.error_message())
    197                   .contains("Expected a regression signature"))
    198       << status.error_message();
    199 }
    200 
    201 TEST(GetNamedSignature, Basic) {
    202   tensorflow::MetaGraphDef meta_graph_def;
    203   Signatures signatures;
    204   ClassificationSignature* input_signature =
    205       (*signatures.mutable_named_signatures())["foo"]
    206           .mutable_classification_signature();
    207   input_signature->mutable_input()->set_tensor_name("flow");
    208   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    209       .mutable_any_list()
    210       ->add_value()
    211       ->PackFrom(signatures);
    212 
    213   Signature signature;
    214   const Status status = GetNamedSignature("foo", meta_graph_def, &signature);
    215   TF_ASSERT_OK(status);
    216   EXPECT_EQ(signature.classification_signature().input().tensor_name(), "flow");
    217 }
    218 
    219 TEST(GetNamedSignature, MissingSignature) {
    220   tensorflow::MetaGraphDef meta_graph_def;
    221   Signatures signatures;
    222   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    223       .mutable_any_list()
    224       ->add_value()
    225       ->PackFrom(signatures);
    226 
    227   Signature signature;
    228   const Status status = GetNamedSignature("foo", meta_graph_def, &signature);
    229   ASSERT_FALSE(status.ok());
    230   EXPECT_TRUE(StringPiece(status.error_message())
    231                   .contains("Missing signature named \"foo\""))
    232       << status.error_message();
    233 }
    234 
    235 // MockSession used to test input and output interactions with a
    236 // tensorflow::Session.
    237 struct MockSession : public tensorflow::Session {
    238   ~MockSession() override = default;
    239 
    240   Status Create(const GraphDef& graph) override {
    241     return errors::Unimplemented("Not implemented for mock.");
    242   }
    243 
    244   Status Extend(const GraphDef& graph) override {
    245     return errors::Unimplemented("Not implemented for mock.");
    246   }
    247 
    248   // Sets the input and output arguments.
    249   Status Run(const std::vector<std::pair<string, Tensor>>& inputs_arg,
    250              const std::vector<string>& output_tensor_names_arg,
    251              const std::vector<string>& target_node_names_arg,
    252              std::vector<Tensor>* outputs_arg) override {
    253     inputs = inputs_arg;
    254     output_tensor_names = output_tensor_names_arg;
    255     target_node_names = target_node_names_arg;
    256     *outputs_arg = outputs;
    257     return status;
    258   }
    259 
    260   Status Close() override {
    261     return errors::Unimplemented("Not implemented for mock.");
    262   }
    263 
    264   Status ListDevices(std::vector<DeviceAttributes>* response) override {
    265     return errors::Unimplemented("Not implemented for mock.");
    266   }
    267 
    268   // Arguments stored on a Run call.
    269   std::vector<std::pair<string, Tensor>> inputs;
    270   std::vector<string> output_tensor_names;
    271   std::vector<string> target_node_names;
    272 
    273   // Output argument set by Run; should be set before calling.
    274   std::vector<Tensor> outputs;
    275 
    276   // Return value for Run; should be set before calling.
    277   Status status;
    278 };
    279 
    280 constexpr char kInputName[] = "in:0";
    281 constexpr char kClassesName[] = "classes:0";
    282 constexpr char kScoresName[] = "scores:0";
    283 
    284 class RunClassificationTest : public ::testing::Test {
    285  public:
    286   void SetUp() override {
    287     signature_.mutable_input()->set_tensor_name(kInputName);
    288     signature_.mutable_classes()->set_tensor_name(kClassesName);
    289     signature_.mutable_scores()->set_tensor_name(kScoresName);
    290   }
    291 
    292  protected:
    293   ClassificationSignature signature_;
    294   Tensor input_tensor_;
    295   Tensor classes_tensor_;
    296   Tensor scores_tensor_;
    297   MockSession session_;
    298 };
    299 
    300 TEST_F(RunClassificationTest, Basic) {
    301   input_tensor_ = test::AsTensor<int>({99});
    302   session_.outputs = {test::AsTensor<int>({3}), test::AsTensor<int>({2})};
    303   const Status status = RunClassification(signature_, input_tensor_, &session_,
    304                                           &classes_tensor_, &scores_tensor_);
    305 
    306   // Validate outputs.
    307   TF_ASSERT_OK(status);
    308   test::ExpectTensorEqual<int>(test::AsTensor<int>({3}), classes_tensor_);
    309   test::ExpectTensorEqual<int>(test::AsTensor<int>({2}), scores_tensor_);
    310 
    311   // Validate inputs.
    312   ASSERT_EQ(1, session_.inputs.size());
    313   EXPECT_EQ(kInputName, session_.inputs[0].first);
    314   test::ExpectTensorEqual<int>(test::AsTensor<int>({99}),
    315                                session_.inputs[0].second);
    316 
    317   ASSERT_EQ(2, session_.output_tensor_names.size());
    318   EXPECT_EQ(kClassesName, session_.output_tensor_names[0]);
    319   EXPECT_EQ(kScoresName, session_.output_tensor_names[1]);
    320 }
    321 
    322 TEST_F(RunClassificationTest, ClassesOnly) {
    323   input_tensor_ = test::AsTensor<int>({99});
    324   session_.outputs = {test::AsTensor<int>({3})};
    325   const Status status = RunClassification(signature_, input_tensor_, &session_,
    326                                           &classes_tensor_, nullptr);
    327 
    328   // Validate outputs.
    329   TF_ASSERT_OK(status);
    330   test::ExpectTensorEqual<int>(test::AsTensor<int>({3}), classes_tensor_);
    331 
    332   // Validate inputs.
    333   ASSERT_EQ(1, session_.inputs.size());
    334   EXPECT_EQ(kInputName, session_.inputs[0].first);
    335   test::ExpectTensorEqual<int>(test::AsTensor<int>({99}),
    336                                session_.inputs[0].second);
    337 
    338   ASSERT_EQ(1, session_.output_tensor_names.size());
    339   EXPECT_EQ(kClassesName, session_.output_tensor_names[0]);
    340 }
    341 
    342 TEST_F(RunClassificationTest, ScoresOnly) {
    343   input_tensor_ = test::AsTensor<int>({99});
    344   session_.outputs = {test::AsTensor<int>({2})};
    345   const Status status = RunClassification(signature_, input_tensor_, &session_,
    346                                           nullptr, &scores_tensor_);
    347 
    348   // Validate outputs.
    349   TF_ASSERT_OK(status);
    350   test::ExpectTensorEqual<int>(test::AsTensor<int>({2}), scores_tensor_);
    351 
    352   // Validate inputs.
    353   ASSERT_EQ(1, session_.inputs.size());
    354   EXPECT_EQ(kInputName, session_.inputs[0].first);
    355   test::ExpectTensorEqual<int>(test::AsTensor<int>({99}),
    356                                session_.inputs[0].second);
    357 
    358   ASSERT_EQ(1, session_.output_tensor_names.size());
    359   EXPECT_EQ(kScoresName, session_.output_tensor_names[0]);
    360 }
    361 
    362 TEST(RunClassification, RunNotOk) {
    363   ClassificationSignature signature;
    364   signature.mutable_input()->set_tensor_name("in:0");
    365   signature.mutable_classes()->set_tensor_name("classes:0");
    366   Tensor input_tensor = test::AsTensor<int>({99});
    367   MockSession session;
    368   session.status = errors::DataLoss("Data is gone");
    369   Tensor classes_tensor;
    370   const Status status = RunClassification(signature, input_tensor, &session,
    371                                           &classes_tensor, nullptr);
    372   ASSERT_FALSE(status.ok());
    373   EXPECT_TRUE(StringPiece(status.error_message()).contains("Data is gone"))
    374       << status.error_message();
    375 }
    376 
    377 TEST(RunClassification, TooManyOutputs) {
    378   ClassificationSignature signature;
    379   signature.mutable_input()->set_tensor_name("in:0");
    380   signature.mutable_classes()->set_tensor_name("classes:0");
    381   Tensor input_tensor = test::AsTensor<int>({99});
    382   MockSession session;
    383   session.outputs = {test::AsTensor<int>({3}), test::AsTensor<int>({4})};
    384 
    385   Tensor classes_tensor;
    386   const Status status = RunClassification(signature, input_tensor, &session,
    387                                           &classes_tensor, nullptr);
    388   ASSERT_FALSE(status.ok());
    389   EXPECT_TRUE(StringPiece(status.error_message()).contains("Expected 1 output"))
    390       << status.error_message();
    391 }
    392 
    393 TEST(RunClassification, WrongBatchOutputs) {
    394   ClassificationSignature signature;
    395   signature.mutable_input()->set_tensor_name("in:0");
    396   signature.mutable_classes()->set_tensor_name("classes:0");
    397   Tensor input_tensor = test::AsTensor<int>({99, 100});
    398   MockSession session;
    399   session.outputs = {test::AsTensor<int>({3})};
    400 
    401   Tensor classes_tensor;
    402   const Status status = RunClassification(signature, input_tensor, &session,
    403                                           &classes_tensor, nullptr);
    404   ASSERT_FALSE(status.ok());
    405   EXPECT_TRUE(StringPiece(status.error_message())
    406                   .contains("Input batch size did not match output batch size"))
    407       << status.error_message();
    408 }
    409 
    410 constexpr char kRegressionsName[] = "regressions:0";
    411 
    412 class RunRegressionTest : public ::testing::Test {
    413  public:
    414   void SetUp() override {
    415     signature_.mutable_input()->set_tensor_name(kInputName);
    416     signature_.mutable_output()->set_tensor_name(kRegressionsName);
    417   }
    418 
    419  protected:
    420   RegressionSignature signature_;
    421   Tensor input_tensor_;
    422   Tensor output_tensor_;
    423   MockSession session_;
    424 };
    425 
    426 TEST_F(RunRegressionTest, Basic) {
    427   input_tensor_ = test::AsTensor<int>({99, 100});
    428   session_.outputs = {test::AsTensor<float>({1, 2})};
    429   const Status status =
    430       RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
    431 
    432   // Validate outputs.
    433   TF_ASSERT_OK(status);
    434   test::ExpectTensorEqual<float>(test::AsTensor<float>({1, 2}), output_tensor_);
    435 
    436   // Validate inputs.
    437   ASSERT_EQ(1, session_.inputs.size());
    438   EXPECT_EQ(kInputName, session_.inputs[0].first);
    439   test::ExpectTensorEqual<int>(test::AsTensor<int>({99, 100}),
    440                                session_.inputs[0].second);
    441 
    442   ASSERT_EQ(1, session_.output_tensor_names.size());
    443   EXPECT_EQ(kRegressionsName, session_.output_tensor_names[0]);
    444 }
    445 
    446 TEST_F(RunRegressionTest, RunNotOk) {
    447   input_tensor_ = test::AsTensor<int>({99});
    448   session_.status = errors::DataLoss("Data is gone");
    449   const Status status =
    450       RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
    451   ASSERT_FALSE(status.ok());
    452   EXPECT_TRUE(StringPiece(status.error_message()).contains("Data is gone"))
    453       << status.error_message();
    454 }
    455 
    456 TEST_F(RunRegressionTest, MismatchedSizeForBatchInputAndOutput) {
    457   input_tensor_ = test::AsTensor<int>({99, 100});
    458   session_.outputs = {test::AsTensor<float>({3})};
    459 
    460   const Status status =
    461       RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
    462   ASSERT_FALSE(status.ok());
    463   EXPECT_TRUE(StringPiece(status.error_message())
    464                   .contains("Input batch size did not match output batch size"))
    465       << status.error_message();
    466 }
    467 
    468 TEST(SetAndGetSignatures, RoundTrip) {
    469   tensorflow::MetaGraphDef meta_graph_def;
    470   Signatures signatures;
    471   signatures.mutable_default_signature()
    472       ->mutable_classification_signature()
    473       ->mutable_input()
    474       ->set_tensor_name("in:0");
    475   TF_ASSERT_OK(SetSignatures(signatures, &meta_graph_def));
    476   Signatures read_signatures;
    477   TF_ASSERT_OK(GetSignatures(meta_graph_def, &read_signatures));
    478 
    479   EXPECT_EQ("in:0", read_signatures.default_signature()
    480                         .classification_signature()
    481                         .input()
    482                         .tensor_name());
    483 }
    484 
    485 TEST(GetSignatures, MissingSignature) {
    486   tensorflow::MetaGraphDef meta_graph_def;
    487   Signatures read_signatures;
    488   const auto status = GetSignatures(meta_graph_def, &read_signatures);
    489   EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
    490   EXPECT_TRUE(
    491       StringPiece(status.error_message()).contains("Expected exactly one"))
    492       << status.error_message();
    493 }
    494 
    495 TEST(GetSignatures, WrongProtoInAny) {
    496   tensorflow::MetaGraphDef meta_graph_def;
    497   auto& collection_def = *(meta_graph_def.mutable_collection_def());
    498   auto* any =
    499       collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add();
    500   // Put an unexpected type into the Signatures Any.
    501   any->PackFrom(TensorBinding());
    502   Signatures read_signatures;
    503   const auto status = GetSignatures(meta_graph_def, &read_signatures);
    504   EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
    505   EXPECT_TRUE(StringPiece(status.error_message())
    506                   .contains("Expected Any type_url for: "
    507                             "tensorflow.serving.Signatures"))
    508       << status.error_message();
    509 }
    510 
    511 TEST(GetSignatures, JunkInAny) {
    512   tensorflow::MetaGraphDef meta_graph_def;
    513   auto& collection_def = *(meta_graph_def.mutable_collection_def());
    514   auto* any =
    515       collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add();
    516   // Create a valid Any then corrupt it.
    517   any->PackFrom(Signatures());
    518   any->set_value("junk junk");
    519   Signatures read_signatures;
    520   const auto status = GetSignatures(meta_graph_def, &read_signatures);
    521   EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
    522   EXPECT_TRUE(StringPiece(status.error_message()).contains("Failed to unpack"))
    523       << status.error_message();
    524 }
    525 
    526 TEST(GetSignatures, DefaultAndNamedTogetherOK) {
    527   tensorflow::MetaGraphDef meta_graph_def;
    528   auto& collection_def = *(meta_graph_def.mutable_collection_def());
    529   auto* any =
    530       collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add();
    531   Signatures signatures;
    532   signatures.mutable_default_signature()
    533       ->mutable_classification_signature()
    534       ->mutable_input()
    535       ->set_tensor_name("in:0");
    536   ClassificationSignature* input_signature =
    537       (*signatures.mutable_named_signatures())["foo"]
    538           .mutable_classification_signature();
    539   input_signature->mutable_input()->set_tensor_name("flow");
    540 
    541   any->PackFrom(signatures);
    542   Signatures read_signatures;
    543   const auto status = GetSignatures(meta_graph_def, &read_signatures);
    544 
    545   EXPECT_TRUE(status.ok());
    546 }
    547 
    548 // Check that we only have one 'Signatures' entry in the collection_def map.
    549 // Note that each such object can have multiple named_signatures inside of it.
    550 TEST(GetSignatures, MultipleSignaturesNotOK) {
    551   tensorflow::MetaGraphDef meta_graph_def;
    552   auto& collection_def = *(meta_graph_def.mutable_collection_def());
    553   auto* any =
    554       collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add();
    555   Signatures signatures;
    556   signatures.mutable_default_signature()
    557       ->mutable_classification_signature()
    558       ->mutable_input()
    559       ->set_tensor_name("in:0");
    560   any->PackFrom(signatures);
    561 
    562   // Add another signatures object.
    563   any =
    564       collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add();
    565   any->PackFrom(signatures);
    566   Signatures read_signatures;
    567   const auto status = GetSignatures(meta_graph_def, &read_signatures);
    568   EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
    569   EXPECT_TRUE(
    570       StringPiece(status.error_message()).contains("Expected exactly one"))
    571       << status.error_message();
    572 }
    573 
    574 // GenericSignature test fixture that contains a signature initialized with two
    575 // bound Tensors.
    576 class GenericSignatureTest : public ::testing::Test {
    577  protected:
    578   GenericSignatureTest() {
    579     TensorBinding binding;
    580     binding.set_tensor_name("graph_A");
    581     signature_.mutable_map()->insert({"logical_A", binding});
    582 
    583     binding.set_tensor_name("graph_B");
    584     signature_.mutable_map()->insert({"logical_B", binding});
    585   }
    586 
    587   // GenericSignature that contains two bound Tensors.
    588   GenericSignature signature_;
    589 };
    590 
    591 // GenericSignature tests.
    592 
    593 TEST_F(GenericSignatureTest, GetGenericSignatureBasic) {
    594   Signature expected_signature;
    595   expected_signature.mutable_generic_signature()->MergeFrom(signature_);
    596 
    597   tensorflow::MetaGraphDef meta_graph_def;
    598   Signatures signatures;
    599   signatures.mutable_named_signatures()->insert(
    600       {"generic_bindings", expected_signature});
    601   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    602       .mutable_any_list()
    603       ->add_value()
    604       ->PackFrom(signatures);
    605 
    606   GenericSignature actual_signature;
    607   TF_ASSERT_OK(GetGenericSignature("generic_bindings", meta_graph_def,
    608                                    &actual_signature));
    609   ASSERT_EQ("graph_A", actual_signature.map().at("logical_A").tensor_name());
    610   ASSERT_EQ("graph_B", actual_signature.map().at("logical_B").tensor_name());
    611 }
    612 
    613 TEST(GetGenericSignature, MissingSignature) {
    614   tensorflow::MetaGraphDef meta_graph_def;
    615   Signatures signatures;
    616   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    617       .mutable_any_list()
    618       ->add_value()
    619       ->PackFrom(signatures);
    620 
    621   GenericSignature signature;
    622   const Status status =
    623       GetGenericSignature("generic_bindings", meta_graph_def, &signature);
    624   ASSERT_FALSE(status.ok());
    625   EXPECT_TRUE(HasSubstr(status.error_message(),
    626                         "Missing generic signature named \"generic_bindings\""))
    627       << status.error_message();
    628 }
    629 
    630 TEST(GetGenericSignature, WrongSignatureType) {
    631   tensorflow::MetaGraphDef meta_graph_def;
    632   Signatures signatures;
    633   (*signatures.mutable_named_signatures())["generic_bindings"]
    634       .mutable_regression_signature();
    635   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
    636       .mutable_any_list()
    637       ->add_value()
    638       ->PackFrom(signatures);
    639 
    640   GenericSignature signature;
    641   const Status status =
    642       GetGenericSignature("generic_bindings", meta_graph_def, &signature);
    643   ASSERT_FALSE(status.ok());
    644   EXPECT_TRUE(StringPiece(status.error_message())
    645                   .contains("Expected a generic signature:"))
    646       << status.error_message();
    647 }
    648 
    649 // BindGeneric Tests.
    650 
    651 TEST_F(GenericSignatureTest, BindGenericInputsBasic) {
    652   const std::vector<std::pair<string, Tensor>> inputs = {
    653       {"logical_A", test::AsTensor<float>({-1.0})},
    654       {"logical_B", test::AsTensor<float>({-2.0})}};
    655 
    656   std::vector<std::pair<string, Tensor>> bound_inputs;
    657   TF_ASSERT_OK(BindGenericInputs(signature_, inputs, &bound_inputs));
    658 
    659   EXPECT_EQ("graph_A", bound_inputs[0].first);
    660   EXPECT_EQ("graph_B", bound_inputs[1].first);
    661   test::ExpectTensorEqual<float>(test::AsTensor<float>({-1.0}),
    662                                  bound_inputs[0].second);
    663   test::ExpectTensorEqual<float>(test::AsTensor<float>({-2.0}),
    664                                  bound_inputs[1].second);
    665 }
    666 
    667 TEST_F(GenericSignatureTest, BindGenericInputsMissingBinding) {
    668   const std::vector<std::pair<string, Tensor>> inputs = {
    669       {"logical_A", test::AsTensor<float>({-42.0})},
    670       {"logical_MISSING", test::AsTensor<float>({-43.0})}};
    671 
    672   std::vector<std::pair<string, Tensor>> bound_inputs;
    673   const Status status = BindGenericInputs(signature_, inputs, &bound_inputs);
    674   ASSERT_FALSE(status.ok());
    675 }
    676 
    677 TEST_F(GenericSignatureTest, BindGenericNamesBasic) {
    678   const std::vector<string> input_names = {"logical_B", "logical_A"};
    679   std::vector<string> bound_names;
    680   TF_ASSERT_OK(BindGenericNames(signature_, input_names, &bound_names));
    681 
    682   EXPECT_EQ("graph_B", bound_names[0]);
    683   EXPECT_EQ("graph_A", bound_names[1]);
    684 }
    685 
    686 TEST_F(GenericSignatureTest, BindGenericNamesMissingBinding) {
    687   const std::vector<string> input_names = {"logical_B", "logical_MISSING"};
    688   std::vector<string> bound_names;
    689   const Status status = BindGenericNames(signature_, input_names, &bound_names);
    690   ASSERT_FALSE(status.ok());
    691 }
    692 
    693 }  // namespace
    694 }  // namespace serving
    695 }  // namespace tensorflow
    696