Home | History | Annotate | Download | only in session_bundle
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/session_bundle/session_bundle.h"
     17 
     18 #include <string>
     19 #include <utility>
     20 #include <vector>
     21 
     22 #include "google/protobuf/any.pb.h"
     23 #include "tensorflow/contrib/session_bundle/signature.h"
     24 #include "tensorflow/contrib/session_bundle/test_util.h"
     25 #include "tensorflow/core/example/example.pb.h"
     26 #include "tensorflow/core/example/feature.pb.h"
     27 #include "tensorflow/core/framework/tensor.h"
     28 #include "tensorflow/core/framework/tensor_shape.h"
     29 #include "tensorflow/core/framework/tensor_testutil.h"
     30 #include "tensorflow/core/lib/core/status.h"
     31 #include "tensorflow/core/lib/core/status_test_util.h"
     32 #include "tensorflow/core/lib/io/path.h"
     33 #include "tensorflow/core/platform/test.h"
     34 #include "tensorflow/core/platform/types.h"
     35 #include "tensorflow/core/public/session.h"
     36 #include "tensorflow/core/public/session_options.h"
     37 
     38 namespace tensorflow {
     39 namespace serving {
     40 namespace {
     41 
     42 // Constants for the export path and file-names.
     43 const char kExportPath[] = "session_bundle/testdata/half_plus_two/00000123";
     44 const char kExportCheckpointV2Path[] =
     45     "session_bundle/testdata/half_plus_two_ckpt_v2/00000123";
     46 const char kMetaGraphDefFilename[] = "export.meta";
     47 const char kVariablesFilename[] = "export-00000-of-00001";
     48 
// Callback used to rewrite a MetaGraphDef in place. CopyExport() invokes it on
// the MetaGraphDef read from the source export, before the result is written
// to the destination directory.
using MetaGraphDefTwiddler = std::function<void(MetaGraphDef*)>;
     51 
     52 // Copy the base half_plus_two to `export_path`.
     53 // Outputs the files using the passed names (typically the constants above).
     54 // The Twiddler can be used to update the MetaGraphDef before output.
     55 Status CopyExport(const string& export_path, const string& variables_filename,
     56                   const string& meta_graph_def_filename,
     57                   const MetaGraphDefTwiddler& twiddler) {
     58   TF_RETURN_IF_ERROR(Env::Default()->CreateDir(export_path));
     59   const string orig_path = test_util::TestSrcDirPath(kExportPath);
     60   {
     61     const string source = io::JoinPath(orig_path, kVariablesFilename);
     62     const string sink = io::JoinPath(export_path, variables_filename);
     63 
     64     string data;
     65     TF_RETURN_IF_ERROR(ReadFileToString(Env::Default(), source, &data));
     66     TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), sink, data));
     67   }
     68   {
     69     const string source = io::JoinPath(orig_path, kMetaGraphDefFilename);
     70     const string sink = io::JoinPath(export_path, meta_graph_def_filename);
     71 
     72     MetaGraphDef graph_def;
     73     TF_RETURN_IF_ERROR(ReadBinaryProto(Env::Default(), source, &graph_def));
     74     twiddler(&graph_def);
     75     TF_RETURN_IF_ERROR(
     76         WriteStringToFile(Env::Default(), sink, graph_def.SerializeAsString()));
     77   }
     78   return Status::OK();
     79 }
     80 
     81 string MakeSerializedExample(float x) {
     82   tensorflow::Example example;
     83   auto* feature_map = example.mutable_features()->mutable_feature();
     84   (*feature_map)["x"].mutable_float_list()->add_value(x);
     85   return example.SerializeAsString();
     86 }
     87 
     88 void CheckRegressionSignature(const Signatures& signatures,
     89                               const SessionBundle& bundle) {
     90   // Recover the Tensor names of our inputs and outputs.
     91   ASSERT_TRUE(signatures.default_signature().has_regression_signature());
     92   const RegressionSignature regression_signature =
     93       signatures.default_signature().regression_signature();
     94 
     95   const string input_name = regression_signature.input().tensor_name();
     96   const string output_name = regression_signature.output().tensor_name();
     97 
     98   // Validate the half plus two behavior.
     99   std::vector<string> serialized_examples;
    100   for (float x : {0, 1, 2, 3}) {
    101     serialized_examples.push_back(MakeSerializedExample(x));
    102   }
    103   Tensor input = test::AsTensor<string>(serialized_examples, TensorShape({4}));
    104   std::vector<Tensor> outputs;
    105   TF_ASSERT_OK(
    106       bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs));
    107   ASSERT_EQ(outputs.size(), 1);
    108   test::ExpectTensorEqual<float>(
    109       outputs[0], test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
    110 }
    111 
    112 void CheckNamedSignatures(const Signatures& signatures,
    113                           const SessionBundle& bundle) {
    114   // Recover the Tensor names of our inputs and outputs.
    115   const string input_name = signatures.named_signatures()
    116                                 .at("inputs")
    117                                 .generic_signature()
    118                                 .map()
    119                                 .at("x")
    120                                 .tensor_name();
    121   const string output_name = signatures.named_signatures()
    122                                  .at("outputs")
    123                                  .generic_signature()
    124                                  .map()
    125                                  .at("y")
    126                                  .tensor_name();
    127 
    128   // Validate the half plus two behavior.
    129   Tensor input = test::AsTensor<float>({0, 1, 2, 3}, TensorShape({4, 1}));
    130   std::vector<Tensor> outputs;
    131   TF_ASSERT_OK(
    132       bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs));
    133   ASSERT_EQ(outputs.size(), 1);
    134   test::ExpectTensorEqual<float>(
    135       outputs[0], test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
    136 }
    137 
    138 void CheckSessionBundle(const string& export_path,
    139                         const SessionBundle& bundle) {
    140   const string asset_path = io::JoinPath(export_path, kAssetsDirectory);
    141   // Validate the assets behavior.
    142   std::vector<Tensor> path_outputs;
    143   TF_ASSERT_OK(bundle.session->Run({}, {"filename1:0", "filename2:0"}, {},
    144                                    &path_outputs));
    145   ASSERT_EQ(2, path_outputs.size());
    146   // Validate the two asset file tensors are set by the init_op and include the
    147   // base_path and asset directory.
    148   test::ExpectTensorEqual<string>(
    149       test::AsTensor<string>({io::JoinPath(asset_path, "hello1.txt")},
    150                              TensorShape({})),
    151       path_outputs[0]);
    152   test::ExpectTensorEqual<string>(
    153       test::AsTensor<string>({io::JoinPath(asset_path, "hello2.txt")},
    154                              TensorShape({})),
    155       path_outputs[1]);
    156 
    157   Signatures signatures;
    158   TF_ASSERT_OK(GetSignatures(bundle.meta_graph_def, &signatures));
    159   CheckRegressionSignature(signatures, bundle);
    160   CheckNamedSignatures(signatures, bundle);
    161 }
    162 
    163 void BasicTest(const string& export_path) {
    164   SessionOptions options;
    165   SessionBundle bundle;
    166   TF_ASSERT_OK(LoadSessionBundleFromPath(options, export_path, &bundle));
    167   CheckSessionBundle(export_path, bundle);
    168 }
    169 
    170 // Test for resource leaks when loading and unloading large numbers of
    171 // SessionBundles. Concurrent with adding this test, we had a leak where the
    172 // TensorFlow Session was not being closed, which leaked memory.
    173 // TODO(b/31711147): Increase the SessionBundle ResourceLeakTest iterations and
    174 // move outside of the test suite; decrease test size back to small at the same
    175 // time.
    176 TEST(LoadSessionBundleFromPath, ResourceLeakTest) {
    177   const string export_path = test_util::TestSrcDirPath(kExportPath);
    178   for (int i = 0; i < 100; i++) {
    179     BasicTest(export_path);
    180   }
    181 }
    182 
    183 TEST(LoadSessionBundleFromPath, BasicTensorFlowContrib) {
    184   const string export_path = test_util::TestSrcDirPath(kExportPath);
    185   BasicTest(export_path);
    186 }
    187 
    188 TEST(LoadSessionBundleFromPath, BasicTestRunOptions) {
    189   const string export_path = test_util::TestSrcDirPath(kExportPath);
    190 
    191   // Use default session-options.
    192   SessionOptions session_options;
    193 
    194   // Setup run-options with full-traces.
    195   RunOptions run_options;
    196   run_options.set_trace_level(RunOptions::FULL_TRACE);
    197 
    198   SessionBundle bundle;
    199   TF_ASSERT_OK(LoadSessionBundleFromPathUsingRunOptions(
    200       session_options, run_options, export_path, &bundle));
    201   CheckSessionBundle(export_path, bundle);
    202 }
    203 
    204 TEST(LoadSessionBundleFromPath, BasicTestRunOptionsThreadPool) {
    205   const string export_path = test_util::TestSrcDirPath(kExportPath);
    206   const int32 threadpool_index = 1;
    207 
    208   // Setup session-options with separate thread-pools.
    209   SessionOptions session_options;
    210   session_options.config.add_session_inter_op_thread_pool();
    211   session_options.config.add_session_inter_op_thread_pool()->set_num_threads(2);
    212 
    213   // Setup run-options with the threadpool index to use.
    214   RunOptions run_options;
    215   run_options.set_inter_op_thread_pool(threadpool_index);
    216 
    217   SessionBundle bundle;
    218   TF_ASSERT_OK(LoadSessionBundleFromPathUsingRunOptions(
    219       session_options, run_options, export_path, &bundle));
    220   CheckSessionBundle(export_path, bundle);
    221 }
    222 
    223 TEST(LoadSessionBundleFromPath, BasicTestRunOptionsThreadPoolInvalid) {
    224   const string export_path = test_util::TestSrcDirPath(kExportPath);
    225   const int32 invalid_threadpool_index = 2;
    226 
    227   // Setup session-options with separate thread-pools.
    228   SessionOptions session_options;
    229   session_options.config.add_session_inter_op_thread_pool();
    230   session_options.config.add_session_inter_op_thread_pool()->set_num_threads(2);
    231 
    232   // Setup run-options with an invalid threadpool index.
    233   RunOptions run_options;
    234   run_options.set_inter_op_thread_pool(invalid_threadpool_index);
    235 
    236   SessionBundle bundle;
    237   Status status = LoadSessionBundleFromPathUsingRunOptions(
    238       session_options, run_options, export_path, &bundle);
    239 
    240   // Expect failed session run calls with invalid run-options.
    241   EXPECT_FALSE(status.ok());
    242   EXPECT_TRUE(StringPiece(status.error_message())
    243                   .contains("Invalid inter_op_thread_pool: 2"))
    244       << status.error_message();
    245 }
    246 
    247 TEST(LoadSessionBundleFromPath, BadExportPath) {
    248   const string export_path = test_util::TestSrcDirPath("/tmp/bigfoot");
    249   SessionOptions options;
    250   options.target = "local";
    251   SessionBundle bundle;
    252   const auto status = LoadSessionBundleFromPath(options, export_path, &bundle);
    253   ASSERT_FALSE(status.ok());
    254   const string msg = status.ToString();
    255   EXPECT_TRUE(msg.find("Not found") != std::string::npos) << msg;
    256 }
    257 
    258 TEST(CheckpointV2Test, LoadSessionBundleFromPath) {
    259   const string export_path = test_util::TestSrcDirPath(kExportCheckpointV2Path);
    260   BasicTest(export_path);
    261 }
    262 
    263 TEST(CheckpointV2Test, IsPossibleExportDirectory) {
    264   const string export_path = test_util::TestSrcDirPath(kExportCheckpointV2Path);
    265   EXPECT_TRUE(IsPossibleExportDirectory(export_path));
    266 }
    267 
    268 class SessionBundleTest : public ::testing::Test {
    269  protected:
    270   // Copy the half_plus_two graph and apply the twiddler to rewrite the
    271   // MetaGraphDef.
    272   // Returns the path of the export.
    273   // ** Should only be called once per test **
    274   string SetupExport(const MetaGraphDefTwiddler& twiddler) {
    275     return SetupExport(twiddler, kVariablesFilename, kMetaGraphDefFilename);
    276   }
    277   // SetupExport that allows for the variables and meta_graph_def filenames
    278   // to be overridden.
    279   string SetupExport(const MetaGraphDefTwiddler& twiddler,
    280                      const string& variables_filename,
    281                      const string& meta_graph_def_filename) {
    282     // Construct a unique path name based on the test name.
    283     const ::testing::TestInfo* const test_info =
    284         ::testing::UnitTest::GetInstance()->current_test_info();
    285     const string export_path = io::JoinPath(
    286         testing::TmpDir(),
    287         strings::StrCat(test_info->test_case_name(), test_info->name()));
    288     TF_CHECK_OK(CopyExport(export_path, variables_filename,
    289                            meta_graph_def_filename, twiddler));
    290     return export_path;
    291   }
    292 
    293   SessionOptions options_;
    294   SessionBundle bundle_;
    295   Status status_;
    296 };
    297 
    298 TEST_F(SessionBundleTest, Basic) {
    299   const string export_path = SetupExport([](MetaGraphDef*) {});
    300   BasicTest(export_path);
    301 }
    302 
    303 TEST_F(SessionBundleTest, UnshardedVariableFile) {
    304   // Test that we can properly read the variables when exported
    305   // without sharding.
    306   const string export_path =
    307       SetupExport([](MetaGraphDef*) {}, "export", kMetaGraphDefFilename);
    308   BasicTest(export_path);
    309 }
    310 
    311 TEST_F(SessionBundleTest, ServingGraphEmpty) {
    312   const string path = SetupExport([](MetaGraphDef* def) {
    313     (*def->mutable_collection_def())[kGraphKey].clear_any_list();
    314   });
    315   status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
    316   EXPECT_FALSE(status_.ok());
    317   EXPECT_TRUE(StringPiece(status_.error_message())
    318                   .contains("Expected exactly one serving GraphDef"))
    319       << status_.error_message();
    320 }
    321 
    322 TEST_F(SessionBundleTest, ServingGraphAnyIncorrectType) {
    323   const string path = SetupExport([](MetaGraphDef* def) {
    324     // Pack an unexpected type in the GraphDef Any.
    325     (*def->mutable_collection_def())[kGraphKey].clear_any_list();
    326     auto* any = (*def->mutable_collection_def())[kGraphKey]
    327                     .mutable_any_list()
    328                     ->add_value();
    329     any->PackFrom(AssetFile());
    330   });
    331   status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
    332   EXPECT_FALSE(status_.ok());
    333   EXPECT_TRUE(StringPiece(status_.error_message())
    334                   .contains("Expected Any type_url for: tensorflow.GraphDef"))
    335       << status_.error_message();
    336 }
    337 
    338 TEST_F(SessionBundleTest, ServingGraphAnyValueCorrupted) {
    339   const string path = SetupExport([](MetaGraphDef* def) {
    340     // Pack an unexpected type in the GraphDef Any.
    341     (*def->mutable_collection_def())[kGraphKey].clear_any_list();
    342     auto* any = (*def->mutable_collection_def())[kGraphKey]
    343                     .mutable_any_list()
    344                     ->add_value();
    345     any->PackFrom(GraphDef());
    346     any->set_value("junk junk");
    347   });
    348   status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
    349   EXPECT_FALSE(status_.ok());
    350   EXPECT_TRUE(StringPiece(status_.error_message()).contains("Failed to unpack"))
    351       << status_.error_message();
    352 }
    353 
    354 TEST_F(SessionBundleTest, AssetFileAnyIncorrectType) {
    355   const string path = SetupExport([](MetaGraphDef* def) {
    356     // Pack an unexpected type in the AssetFile Any.
    357     (*def->mutable_collection_def())[kAssetsKey].clear_any_list();
    358     auto* any = (*def->mutable_collection_def())[kAssetsKey]
    359                     .mutable_any_list()
    360                     ->add_value();
    361     any->PackFrom(GraphDef());
    362   });
    363   status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
    364   EXPECT_FALSE(status_.ok());
    365   EXPECT_TRUE(
    366       StringPiece(status_.error_message())
    367           .contains("Expected Any type_url for: tensorflow.serving.AssetFile"))
    368       << status_.error_message();
    369 }
    370 
    371 TEST_F(SessionBundleTest, AssetFileAnyValueCorrupted) {
    372   const string path = SetupExport([](MetaGraphDef* def) {
    373     // Pack an unexpected type in the AssetFile Any.
    374     (*def->mutable_collection_def())[kAssetsKey].clear_any_list();
    375     auto* any = (*def->mutable_collection_def())[kAssetsKey]
    376                     .mutable_any_list()
    377                     ->add_value();
    378     any->PackFrom(AssetFile());
    379     any->set_value("junk junk");
    380   });
    381   status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
    382   EXPECT_FALSE(status_.ok());
    383   EXPECT_TRUE(StringPiece(status_.error_message()).contains("Failed to unpack"))
    384       << status_.error_message();
    385 }
    386 
    387 TEST_F(SessionBundleTest, InitOpTooManyValues) {
    388   const string path = SetupExport([](MetaGraphDef* def) {
    389     // Pack multiple init ops in to the collection.
    390     (*def->mutable_collection_def())[kInitOpKey].clear_node_list();
    391     auto* node_list =
    392         (*def->mutable_collection_def())[kInitOpKey].mutable_node_list();
    393     node_list->add_value("foo");
    394     node_list->add_value("bar");
    395   });
    396   status_ = LoadSessionBundleFromPath(options_, path, &bundle_);
    397   EXPECT_FALSE(status_.ok());
    398   EXPECT_TRUE(StringPiece(status_.error_message())
    399                   .contains("Expected exactly one serving init op"))
    400       << status_.error_message();
    401 }
    402 
    403 TEST_F(SessionBundleTest, PossibleExportDirectory) {
    404   const string export_path = SetupExport([](MetaGraphDef*) {});
    405   EXPECT_TRUE(IsPossibleExportDirectory(export_path));
    406 
    407   EXPECT_FALSE(
    408       IsPossibleExportDirectory(io::JoinPath(export_path, kAssetsDirectory)));
    409 }
    410 
    411 }  // namespace
    412 }  // namespace serving
    413 }  // namespace tensorflow
    414