1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/contrib/session_bundle/signature.h" 17 18 #include <memory> 19 20 #include "google/protobuf/any.pb.h" 21 #include "tensorflow/contrib/session_bundle/manifest.pb.h" 22 #include "tensorflow/core/framework/graph.pb.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/framework/tensor_testutil.h" 25 #include "tensorflow/core/lib/core/errors.h" 26 #include "tensorflow/core/lib/core/status.h" 27 #include "tensorflow/core/lib/core/status_test_util.h" 28 #include "tensorflow/core/lib/core/stringpiece.h" 29 #include "tensorflow/core/platform/test.h" 30 #include "tensorflow/core/public/session.h" 31 32 namespace tensorflow { 33 namespace serving { 34 namespace { 35 36 static bool HasSubstr(const string& base, const string& substr) { 37 bool ok = StringPiece(base).contains(substr); 38 EXPECT_TRUE(ok) << base << ", expected substring " << substr; 39 return ok; 40 } 41 42 TEST(GetClassificationSignature, Basic) { 43 tensorflow::MetaGraphDef meta_graph_def; 44 Signatures signatures; 45 ClassificationSignature* input_signature = 46 signatures.mutable_default_signature() 47 ->mutable_classification_signature(); 48 input_signature->mutable_input()->set_tensor_name("flow"); 49 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 50 .mutable_any_list() 51 ->add_value() 52 ->PackFrom(signatures); 53 54 ClassificationSignature signature; 55 const Status status = GetClassificationSignature(meta_graph_def, &signature); 56 TF_ASSERT_OK(status); 57 EXPECT_EQ(signature.input().tensor_name(), "flow"); 58 } 59 60 TEST(GetClassificationSignature, MissingSignature) { 61 tensorflow::MetaGraphDef meta_graph_def; 62 Signatures signatures; 63 signatures.mutable_default_signature(); 64 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 65 .mutable_any_list() 66 ->add_value() 67 ->PackFrom(signatures); 68 69 ClassificationSignature signature; 70 const Status status = GetClassificationSignature(meta_graph_def, &signature); 71 ASSERT_FALSE(status.ok()); 72 EXPECT_TRUE(StringPiece(status.error_message()) 73 .contains("Expected a classification signature")) 74 << status.error_message(); 75 } 76 77 TEST(GetClassificationSignature, WrongSignatureType) { 78 tensorflow::MetaGraphDef meta_graph_def; 79 Signatures signatures; 80 signatures.mutable_default_signature()->mutable_regression_signature(); 81 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 82 .mutable_any_list() 83 ->add_value() 84 ->PackFrom(signatures); 85 86 ClassificationSignature signature; 87 const Status status = GetClassificationSignature(meta_graph_def, &signature); 88 ASSERT_FALSE(status.ok()); 89 EXPECT_TRUE(StringPiece(status.error_message()) 90 .contains("Expected a classification signature")) 91 << status.error_message(); 92 } 93 94 TEST(GetNamedClassificationSignature, Basic) { 95 tensorflow::MetaGraphDef meta_graph_def; 96 Signatures signatures; 97 ClassificationSignature* input_signature = 98 (*signatures.mutable_named_signatures())["foo"] 99 .mutable_classification_signature(); 100 input_signature->mutable_input()->set_tensor_name("flow"); 101 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 102 .mutable_any_list() 103 ->add_value() 104 ->PackFrom(signatures); 105 106 ClassificationSignature signature; 107 const Status status = 108 GetNamedClassificationSignature("foo", meta_graph_def, &signature); 109 TF_ASSERT_OK(status); 110 EXPECT_EQ(signature.input().tensor_name(), "flow"); 111 } 112 113 TEST(GetNamedClassificationSignature, MissingSignature) { 114 tensorflow::MetaGraphDef meta_graph_def; 115 Signatures signatures; 116 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 117 .mutable_any_list() 118 ->add_value() 119 ->PackFrom(signatures); 120 121 ClassificationSignature signature; 122 const Status status = 123 GetNamedClassificationSignature("foo", meta_graph_def, &signature); 124 ASSERT_FALSE(status.ok()); 125 EXPECT_TRUE(StringPiece(status.error_message()) 126 .contains("Missing signature named \"foo\"")) 127 << status.error_message(); 128 } 129 130 TEST(GetNamedClassificationSignature, WrongSignatureType) { 131 tensorflow::MetaGraphDef meta_graph_def; 132 Signatures signatures; 133 (*signatures.mutable_named_signatures())["foo"] 134 .mutable_regression_signature(); 135 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 136 .mutable_any_list() 137 ->add_value() 138 ->PackFrom(signatures); 139 140 ClassificationSignature signature; 141 const Status status = 142 GetNamedClassificationSignature("foo", meta_graph_def, &signature); 143 ASSERT_FALSE(status.ok()); 144 EXPECT_TRUE( 145 StringPiece(status.error_message()) 146 .contains("Expected a classification signature for name \"foo\"")) 147 << status.error_message(); 148 } 149 150 TEST(GetRegressionSignature, Basic) { 151 tensorflow::MetaGraphDef meta_graph_def; 152 Signatures signatures; 153 RegressionSignature* input_signature = 154 signatures.mutable_default_signature()->mutable_regression_signature(); 155 input_signature->mutable_input()->set_tensor_name("flow"); 156 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 157 .mutable_any_list() 158 ->add_value() 159 ->PackFrom(signatures); 160 161 RegressionSignature signature; 162 const Status status = GetRegressionSignature(meta_graph_def, &signature); 163 TF_ASSERT_OK(status); 164 EXPECT_EQ(signature.input().tensor_name(), "flow"); 165 } 166 167 TEST(GetRegressionSignature, MissingSignature) { 168 tensorflow::MetaGraphDef meta_graph_def; 169 Signatures signatures; 170 signatures.mutable_default_signature(); 171 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 172 .mutable_any_list() 173 ->add_value() 174 ->PackFrom(signatures); 175 176 RegressionSignature signature; 177 const Status status = GetRegressionSignature(meta_graph_def, &signature); 178 ASSERT_FALSE(status.ok()); 179 EXPECT_TRUE(StringPiece(status.error_message()) 180 .contains("Expected a regression signature")) 181 << status.error_message(); 182 } 183 184 TEST(GetRegressionSignature, WrongSignatureType) { 185 tensorflow::MetaGraphDef meta_graph_def; 186 Signatures signatures; 187 signatures.mutable_default_signature()->mutable_classification_signature(); 188 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 189 .mutable_any_list() 190 ->add_value() 191 ->PackFrom(signatures); 192 193 RegressionSignature signature; 194 const Status status = GetRegressionSignature(meta_graph_def, &signature); 195 ASSERT_FALSE(status.ok()); 196 EXPECT_TRUE(StringPiece(status.error_message()) 197 .contains("Expected a regression signature")) 198 << status.error_message(); 199 } 200 201 TEST(GetNamedSignature, Basic) { 202 tensorflow::MetaGraphDef meta_graph_def; 203 Signatures signatures; 204 ClassificationSignature* input_signature = 205 (*signatures.mutable_named_signatures())["foo"] 206 .mutable_classification_signature(); 207 input_signature->mutable_input()->set_tensor_name("flow"); 208 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 209 .mutable_any_list() 210 ->add_value() 211 ->PackFrom(signatures); 212 213 Signature signature; 214 const Status status = GetNamedSignature("foo", meta_graph_def, &signature); 215 TF_ASSERT_OK(status); 216 EXPECT_EQ(signature.classification_signature().input().tensor_name(), "flow"); 217 } 218 219 TEST(GetNamedSignature, MissingSignature) { 220 tensorflow::MetaGraphDef meta_graph_def; 221 Signatures signatures; 222 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 223 .mutable_any_list() 224 ->add_value() 225 ->PackFrom(signatures); 226 227 Signature signature; 228 const Status status = GetNamedSignature("foo", meta_graph_def, &signature); 229 ASSERT_FALSE(status.ok()); 230 EXPECT_TRUE(StringPiece(status.error_message()) 231 .contains("Missing signature named \"foo\"")) 232 << status.error_message(); 233 } 234 235 // MockSession used to test input and output interactions with a 236 // tensorflow::Session. 237 struct MockSession : public tensorflow::Session { 238 ~MockSession() override = default; 239 240 Status Create(const GraphDef& graph) override { 241 return errors::Unimplemented("Not implemented for mock."); 242 } 243 244 Status Extend(const GraphDef& graph) override { 245 return errors::Unimplemented("Not implemented for mock."); 246 } 247 248 // Sets the input and output arguments. 249 Status Run(const std::vector<std::pair<string, Tensor>>& inputs_arg, 250 const std::vector<string>& output_tensor_names_arg, 251 const std::vector<string>& target_node_names_arg, 252 std::vector<Tensor>* outputs_arg) override { 253 inputs = inputs_arg; 254 output_tensor_names = output_tensor_names_arg; 255 target_node_names = target_node_names_arg; 256 *outputs_arg = outputs; 257 return status; 258 } 259 260 Status Close() override { 261 return errors::Unimplemented("Not implemented for mock."); 262 } 263 264 Status ListDevices(std::vector<DeviceAttributes>* response) override { 265 return errors::Unimplemented("Not implemented for mock."); 266 } 267 268 // Arguments stored on a Run call. 269 std::vector<std::pair<string, Tensor>> inputs; 270 std::vector<string> output_tensor_names; 271 std::vector<string> target_node_names; 272 273 // Output argument set by Run; should be set before calling. 274 std::vector<Tensor> outputs; 275 276 // Return value for Run; should be set before calling. 277 Status status; 278 }; 279 280 constexpr char kInputName[] = "in:0"; 281 constexpr char kClassesName[] = "classes:0"; 282 constexpr char kScoresName[] = "scores:0"; 283 284 class RunClassificationTest : public ::testing::Test { 285 public: 286 void SetUp() override { 287 signature_.mutable_input()->set_tensor_name(kInputName); 288 signature_.mutable_classes()->set_tensor_name(kClassesName); 289 signature_.mutable_scores()->set_tensor_name(kScoresName); 290 } 291 292 protected: 293 ClassificationSignature signature_; 294 Tensor input_tensor_; 295 Tensor classes_tensor_; 296 Tensor scores_tensor_; 297 MockSession session_; 298 }; 299 300 TEST_F(RunClassificationTest, Basic) { 301 input_tensor_ = test::AsTensor<int>({99}); 302 session_.outputs = {test::AsTensor<int>({3}), test::AsTensor<int>({2})}; 303 const Status status = RunClassification(signature_, input_tensor_, &session_, 304 &classes_tensor_, &scores_tensor_); 305 306 // Validate outputs. 307 TF_ASSERT_OK(status); 308 test::ExpectTensorEqual<int>(test::AsTensor<int>({3}), classes_tensor_); 309 test::ExpectTensorEqual<int>(test::AsTensor<int>({2}), scores_tensor_); 310 311 // Validate inputs. 312 ASSERT_EQ(1, session_.inputs.size()); 313 EXPECT_EQ(kInputName, session_.inputs[0].first); 314 test::ExpectTensorEqual<int>(test::AsTensor<int>({99}), 315 session_.inputs[0].second); 316 317 ASSERT_EQ(2, session_.output_tensor_names.size()); 318 EXPECT_EQ(kClassesName, session_.output_tensor_names[0]); 319 EXPECT_EQ(kScoresName, session_.output_tensor_names[1]); 320 } 321 322 TEST_F(RunClassificationTest, ClassesOnly) { 323 input_tensor_ = test::AsTensor<int>({99}); 324 session_.outputs = {test::AsTensor<int>({3})}; 325 const Status status = RunClassification(signature_, input_tensor_, &session_, 326 &classes_tensor_, nullptr); 327 328 // Validate outputs. 329 TF_ASSERT_OK(status); 330 test::ExpectTensorEqual<int>(test::AsTensor<int>({3}), classes_tensor_); 331 332 // Validate inputs. 333 ASSERT_EQ(1, session_.inputs.size()); 334 EXPECT_EQ(kInputName, session_.inputs[0].first); 335 test::ExpectTensorEqual<int>(test::AsTensor<int>({99}), 336 session_.inputs[0].second); 337 338 ASSERT_EQ(1, session_.output_tensor_names.size()); 339 EXPECT_EQ(kClassesName, session_.output_tensor_names[0]); 340 } 341 342 TEST_F(RunClassificationTest, ScoresOnly) { 343 input_tensor_ = test::AsTensor<int>({99}); 344 session_.outputs = {test::AsTensor<int>({2})}; 345 const Status status = RunClassification(signature_, input_tensor_, &session_, 346 nullptr, &scores_tensor_); 347 348 // Validate outputs. 349 TF_ASSERT_OK(status); 350 test::ExpectTensorEqual<int>(test::AsTensor<int>({2}), scores_tensor_); 351 352 // Validate inputs. 353 ASSERT_EQ(1, session_.inputs.size()); 354 EXPECT_EQ(kInputName, session_.inputs[0].first); 355 test::ExpectTensorEqual<int>(test::AsTensor<int>({99}), 356 session_.inputs[0].second); 357 358 ASSERT_EQ(1, session_.output_tensor_names.size()); 359 EXPECT_EQ(kScoresName, session_.output_tensor_names[0]); 360 } 361 362 TEST(RunClassification, RunNotOk) { 363 ClassificationSignature signature; 364 signature.mutable_input()->set_tensor_name("in:0"); 365 signature.mutable_classes()->set_tensor_name("classes:0"); 366 Tensor input_tensor = test::AsTensor<int>({99}); 367 MockSession session; 368 session.status = errors::DataLoss("Data is gone"); 369 Tensor classes_tensor; 370 const Status status = RunClassification(signature, input_tensor, &session, 371 &classes_tensor, nullptr); 372 ASSERT_FALSE(status.ok()); 373 EXPECT_TRUE(StringPiece(status.error_message()).contains("Data is gone")) 374 << status.error_message(); 375 } 376 377 TEST(RunClassification, TooManyOutputs) { 378 ClassificationSignature signature; 379 signature.mutable_input()->set_tensor_name("in:0"); 380 signature.mutable_classes()->set_tensor_name("classes:0"); 381 Tensor input_tensor = test::AsTensor<int>({99}); 382 MockSession session; 383 session.outputs = {test::AsTensor<int>({3}), test::AsTensor<int>({4})}; 384 385 Tensor classes_tensor; 386 const Status status = RunClassification(signature, input_tensor, &session, 387 &classes_tensor, nullptr); 388 ASSERT_FALSE(status.ok()); 389 EXPECT_TRUE(StringPiece(status.error_message()).contains("Expected 1 output")) 390 << status.error_message(); 391 } 392 393 TEST(RunClassification, WrongBatchOutputs) { 394 ClassificationSignature signature; 395 signature.mutable_input()->set_tensor_name("in:0"); 396 signature.mutable_classes()->set_tensor_name("classes:0"); 397 Tensor input_tensor = test::AsTensor<int>({99, 100}); 398 MockSession session; 399 session.outputs = {test::AsTensor<int>({3})}; 400 401 Tensor classes_tensor; 402 const Status status = RunClassification(signature, input_tensor, &session, 403 &classes_tensor, nullptr); 404 ASSERT_FALSE(status.ok()); 405 EXPECT_TRUE(StringPiece(status.error_message()) 406 .contains("Input batch size did not match output batch size")) 407 << status.error_message(); 408 } 409 410 constexpr char kRegressionsName[] = "regressions:0"; 411 412 class RunRegressionTest : public ::testing::Test { 413 public: 414 void SetUp() override { 415 signature_.mutable_input()->set_tensor_name(kInputName); 416 signature_.mutable_output()->set_tensor_name(kRegressionsName); 417 } 418 419 protected: 420 RegressionSignature signature_; 421 Tensor input_tensor_; 422 Tensor output_tensor_; 423 MockSession session_; 424 }; 425 426 TEST_F(RunRegressionTest, Basic) { 427 input_tensor_ = test::AsTensor<int>({99, 100}); 428 session_.outputs = {test::AsTensor<float>({1, 2})}; 429 const Status status = 430 RunRegression(signature_, input_tensor_, &session_, &output_tensor_); 431 432 // Validate outputs. 433 TF_ASSERT_OK(status); 434 test::ExpectTensorEqual<float>(test::AsTensor<float>({1, 2}), output_tensor_); 435 436 // Validate inputs. 437 ASSERT_EQ(1, session_.inputs.size()); 438 EXPECT_EQ(kInputName, session_.inputs[0].first); 439 test::ExpectTensorEqual<int>(test::AsTensor<int>({99, 100}), 440 session_.inputs[0].second); 441 442 ASSERT_EQ(1, session_.output_tensor_names.size()); 443 EXPECT_EQ(kRegressionsName, session_.output_tensor_names[0]); 444 } 445 446 TEST_F(RunRegressionTest, RunNotOk) { 447 input_tensor_ = test::AsTensor<int>({99}); 448 session_.status = errors::DataLoss("Data is gone"); 449 const Status status = 450 RunRegression(signature_, input_tensor_, &session_, &output_tensor_); 451 ASSERT_FALSE(status.ok()); 452 EXPECT_TRUE(StringPiece(status.error_message()).contains("Data is gone")) 453 << status.error_message(); 454 } 455 456 TEST_F(RunRegressionTest, MismatchedSizeForBatchInputAndOutput) { 457 input_tensor_ = test::AsTensor<int>({99, 100}); 458 session_.outputs = {test::AsTensor<float>({3})}; 459 460 const Status status = 461 RunRegression(signature_, input_tensor_, &session_, &output_tensor_); 462 ASSERT_FALSE(status.ok()); 463 EXPECT_TRUE(StringPiece(status.error_message()) 464 .contains("Input batch size did not match output batch size")) 465 << status.error_message(); 466 } 467 468 TEST(SetAndGetSignatures, RoundTrip) { 469 tensorflow::MetaGraphDef meta_graph_def; 470 Signatures signatures; 471 signatures.mutable_default_signature() 472 ->mutable_classification_signature() 473 ->mutable_input() 474 ->set_tensor_name("in:0"); 475 TF_ASSERT_OK(SetSignatures(signatures, &meta_graph_def)); 476 Signatures read_signatures; 477 TF_ASSERT_OK(GetSignatures(meta_graph_def, &read_signatures)); 478 479 EXPECT_EQ("in:0", read_signatures.default_signature() 480 .classification_signature() 481 .input() 482 .tensor_name()); 483 } 484 485 TEST(GetSignatures, MissingSignature) { 486 tensorflow::MetaGraphDef meta_graph_def; 487 Signatures read_signatures; 488 const auto status = GetSignatures(meta_graph_def, &read_signatures); 489 EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); 490 EXPECT_TRUE( 491 StringPiece(status.error_message()).contains("Expected exactly one")) 492 << status.error_message(); 493 } 494 495 TEST(GetSignatures, WrongProtoInAny) { 496 tensorflow::MetaGraphDef meta_graph_def; 497 auto& collection_def = *(meta_graph_def.mutable_collection_def()); 498 auto* any = 499 collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add(); 500 // Put an unexpected type into the Signatures Any. 501 any->PackFrom(TensorBinding()); 502 Signatures read_signatures; 503 const auto status = GetSignatures(meta_graph_def, &read_signatures); 504 EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); 505 EXPECT_TRUE(StringPiece(status.error_message()) 506 .contains("Expected Any type_url for: " 507 "tensorflow.serving.Signatures")) 508 << status.error_message(); 509 } 510 511 TEST(GetSignatures, JunkInAny) { 512 tensorflow::MetaGraphDef meta_graph_def; 513 auto& collection_def = *(meta_graph_def.mutable_collection_def()); 514 auto* any = 515 collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add(); 516 // Create a valid Any then corrupt it. 517 any->PackFrom(Signatures()); 518 any->set_value("junk junk"); 519 Signatures read_signatures; 520 const auto status = GetSignatures(meta_graph_def, &read_signatures); 521 EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); 522 EXPECT_TRUE(StringPiece(status.error_message()).contains("Failed to unpack")) 523 << status.error_message(); 524 } 525 526 TEST(GetSignatures, DefaultAndNamedTogetherOK) { 527 tensorflow::MetaGraphDef meta_graph_def; 528 auto& collection_def = *(meta_graph_def.mutable_collection_def()); 529 auto* any = 530 collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add(); 531 Signatures signatures; 532 signatures.mutable_default_signature() 533 ->mutable_classification_signature() 534 ->mutable_input() 535 ->set_tensor_name("in:0"); 536 ClassificationSignature* input_signature = 537 (*signatures.mutable_named_signatures())["foo"] 538 .mutable_classification_signature(); 539 input_signature->mutable_input()->set_tensor_name("flow"); 540 541 any->PackFrom(signatures); 542 Signatures read_signatures; 543 const auto status = GetSignatures(meta_graph_def, &read_signatures); 544 545 EXPECT_TRUE(status.ok()); 546 } 547 548 // Check that we only have one 'Signatures' entry in the collection_def map. 549 // Note that each such object can have multiple named_signatures inside of it. 550 TEST(GetSignatures, MultipleSignaturesNotOK) { 551 tensorflow::MetaGraphDef meta_graph_def; 552 auto& collection_def = *(meta_graph_def.mutable_collection_def()); 553 auto* any = 554 collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add(); 555 Signatures signatures; 556 signatures.mutable_default_signature() 557 ->mutable_classification_signature() 558 ->mutable_input() 559 ->set_tensor_name("in:0"); 560 any->PackFrom(signatures); 561 562 // Add another signatures object. 563 any = 564 collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add(); 565 any->PackFrom(signatures); 566 Signatures read_signatures; 567 const auto status = GetSignatures(meta_graph_def, &read_signatures); 568 EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code()); 569 EXPECT_TRUE( 570 StringPiece(status.error_message()).contains("Expected exactly one")) 571 << status.error_message(); 572 } 573 574 // GenericSignature test fixture that contains a signature initialized with two 575 // bound Tensors. 576 class GenericSignatureTest : public ::testing::Test { 577 protected: 578 GenericSignatureTest() { 579 TensorBinding binding; 580 binding.set_tensor_name("graph_A"); 581 signature_.mutable_map()->insert({"logical_A", binding}); 582 583 binding.set_tensor_name("graph_B"); 584 signature_.mutable_map()->insert({"logical_B", binding}); 585 } 586 587 // GenericSignature that contains two bound Tensors. 588 GenericSignature signature_; 589 }; 590 591 // GenericSignature tests. 592 593 TEST_F(GenericSignatureTest, GetGenericSignatureBasic) { 594 Signature expected_signature; 595 expected_signature.mutable_generic_signature()->MergeFrom(signature_); 596 597 tensorflow::MetaGraphDef meta_graph_def; 598 Signatures signatures; 599 signatures.mutable_named_signatures()->insert( 600 {"generic_bindings", expected_signature}); 601 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 602 .mutable_any_list() 603 ->add_value() 604 ->PackFrom(signatures); 605 606 GenericSignature actual_signature; 607 TF_ASSERT_OK(GetGenericSignature("generic_bindings", meta_graph_def, 608 &actual_signature)); 609 ASSERT_EQ("graph_A", actual_signature.map().at("logical_A").tensor_name()); 610 ASSERT_EQ("graph_B", actual_signature.map().at("logical_B").tensor_name()); 611 } 612 613 TEST(GetGenericSignature, MissingSignature) { 614 tensorflow::MetaGraphDef meta_graph_def; 615 Signatures signatures; 616 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 617 .mutable_any_list() 618 ->add_value() 619 ->PackFrom(signatures); 620 621 GenericSignature signature; 622 const Status status = 623 GetGenericSignature("generic_bindings", meta_graph_def, &signature); 624 ASSERT_FALSE(status.ok()); 625 EXPECT_TRUE(HasSubstr(status.error_message(), 626 "Missing generic signature named \"generic_bindings\"")) 627 << status.error_message(); 628 } 629 630 TEST(GetGenericSignature, WrongSignatureType) { 631 tensorflow::MetaGraphDef meta_graph_def; 632 Signatures signatures; 633 (*signatures.mutable_named_signatures())["generic_bindings"] 634 .mutable_regression_signature(); 635 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 636 .mutable_any_list() 637 ->add_value() 638 ->PackFrom(signatures); 639 640 GenericSignature signature; 641 const Status status = 642 GetGenericSignature("generic_bindings", meta_graph_def, &signature); 643 ASSERT_FALSE(status.ok()); 644 EXPECT_TRUE(StringPiece(status.error_message()) 645 .contains("Expected a generic signature:")) 646 << status.error_message(); 647 } 648 649 // BindGeneric Tests. 650 651 TEST_F(GenericSignatureTest, BindGenericInputsBasic) { 652 const std::vector<std::pair<string, Tensor>> inputs = { 653 {"logical_A", test::AsTensor<float>({-1.0})}, 654 {"logical_B", test::AsTensor<float>({-2.0})}}; 655 656 std::vector<std::pair<string, Tensor>> bound_inputs; 657 TF_ASSERT_OK(BindGenericInputs(signature_, inputs, &bound_inputs)); 658 659 EXPECT_EQ("graph_A", bound_inputs[0].first); 660 EXPECT_EQ("graph_B", bound_inputs[1].first); 661 test::ExpectTensorEqual<float>(test::AsTensor<float>({-1.0}), 662 bound_inputs[0].second); 663 test::ExpectTensorEqual<float>(test::AsTensor<float>({-2.0}), 664 bound_inputs[1].second); 665 } 666 667 TEST_F(GenericSignatureTest, BindGenericInputsMissingBinding) { 668 const std::vector<std::pair<string, Tensor>> inputs = { 669 {"logical_A", test::AsTensor<float>({-42.0})}, 670 {"logical_MISSING", test::AsTensor<float>({-43.0})}}; 671 672 std::vector<std::pair<string, Tensor>> bound_inputs; 673 const Status status = BindGenericInputs(signature_, inputs, &bound_inputs); 674 ASSERT_FALSE(status.ok()); 675 } 676 677 TEST_F(GenericSignatureTest, BindGenericNamesBasic) { 678 const std::vector<string> input_names = {"logical_B", "logical_A"}; 679 std::vector<string> bound_names; 680 TF_ASSERT_OK(BindGenericNames(signature_, input_names, &bound_names)); 681 682 EXPECT_EQ("graph_B", bound_names[0]); 683 EXPECT_EQ("graph_A", bound_names[1]); 684 } 685 686 TEST_F(GenericSignatureTest, BindGenericNamesMissingBinding) { 687 const std::vector<string> input_names = {"logical_B", "logical_MISSING"}; 688 std::vector<string> bound_names; 689 const Status status = BindGenericNames(signature_, input_names, &bound_names); 690 ASSERT_FALSE(status.ok()); 691 } 692 693 } // namespace 694 } // namespace serving 695 } // namespace tensorflow 696