1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/cc/ops/standard_ops.h" 16 #include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h" 17 #include "tensorflow/core/framework/tensor_testutil.h" 18 #include "tensorflow/core/graph/graph_def_builder.h" 19 #include "tensorflow/core/lib/core/status_test_util.h" 20 #include "tensorflow/core/lib/io/path.h" 21 #include "tensorflow/core/platform/test.h" 22 #include "tensorflow/core/platform/test_benchmark.h" 23 #include "tensorflow/core/public/session.h" 24 #include "tensorflow/core/util/memmapped_file_system.h" 25 26 namespace tensorflow { 27 namespace { 28 29 bool GraphHasImmutableConstNodes(const GraphDef& graph_def) { 30 for (const auto& node : graph_def.node()) { 31 if (node.op() == "ImmutableConst") { 32 return true; 33 } 34 } 35 return false; 36 } 37 38 TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) { 39 const string dir = testing::TmpDir(); 40 const string filename_pb = io::JoinPath(dir, "graphdef.pb"); 41 42 // Create a simple graph and write it to filename_pb. 43 constexpr int kTensorWidth = 4000; 44 constexpr int kTensorHeight = 100; 45 const TensorShape kTestTensorShape({kTensorWidth, kTensorHeight}); 46 const TensorShape kTestTensorShapeT({kTensorHeight, kTensorWidth}); 47 48 Tensor test_tensor1(DT_FLOAT, kTestTensorShape); 49 test::FillFn<float>(&test_tensor1, [](int) -> float { return 2.0; }); 50 51 Tensor test_tensor2(DT_FLOAT, kTestTensorShapeT); 52 test::FillFn<float>(&test_tensor2, [](int) -> float { return 3.0; }); 53 54 auto root = Scope::NewRootScope().ExitOnError(); 55 Output m = ops::MatMul(root, test_tensor1, test_tensor2); 56 const string result_name = m.node()->name(); 57 58 GraphDef graph_def; 59 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 60 string graph_def_serialized; 61 graph_def.SerializeToString(&graph_def_serialized); 62 TF_ASSERT_OK( 63 WriteStringToFile(Env::Default(), filename_pb, graph_def_serialized)); 64 65 const string filename_mmap = io::JoinPath(dir, "graphdef.mmap"); 66 TF_ASSERT_OK(ConvertConstantsToImmutable(filename_pb, filename_mmap, 10000)); 67 68 // Create and initialize MemmappedEnv from the converted file. 69 MemmappedEnv memmapped_env(Env::Default()); 70 TF_ASSERT_OK(memmapped_env.InitializeFromFile(filename_mmap)); 71 72 // Load the graph and run calculations. 73 SessionOptions session_options; 74 session_options.env = &memmapped_env; 75 std::unique_ptr<Session> session(NewSession(session_options)); 76 ASSERT_TRUE(session != nullptr) << "Failed to create session"; 77 GraphDef loaded_graph_def; 78 TF_ASSERT_OK(ReadBinaryProto( 79 &memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, 80 &loaded_graph_def)); 81 ASSERT_TRUE(GraphHasImmutableConstNodes(loaded_graph_def)); 82 83 TF_ASSERT_OK(session->Create(loaded_graph_def)) << "Can't create test graph"; 84 std::vector<Tensor> outputs; 85 TF_ASSERT_OK(session->Run({}, {result_name + ":0"}, {}, &outputs)); 86 ASSERT_EQ(outputs.size(), 1); 87 EXPECT_EQ(outputs.front().flat<float>()(0), 2.0f * 3.0f * kTensorHeight); 88 EXPECT_EQ(outputs.front().flat<float>()(1), 2.0f * 3.0f * kTensorHeight); 89 EXPECT_EQ(outputs.front().flat<float>()(2), 2.0f * 3.0f * kTensorHeight); 90 } 91 92 TEST(ConvertGraphdefMemmappedFormatTest, NotSupportedTypesConvert) { 93 // Create a graph with strings. 94 const string dir = testing::TmpDir(); 95 const string filename_pb = io::JoinPath(dir, "string_graphdef.pb"); 96 97 constexpr int kTensorWidth = 4000; 98 constexpr int kTensorHeight = 100; 99 const TensorShape kTestTensorShape({kTensorWidth, kTensorHeight}); 100 Tensor test_tensor1(DT_STRING, kTestTensorShape); 101 test::FillFn<string>(&test_tensor1, [](int) -> string { return "ABC"; }); 102 103 Tensor test_tensor2(DT_STRING, kTestTensorShape); 104 test::FillFn<string>(&test_tensor2, [](int) -> string { return "XYZ"; }); 105 auto root = Scope::NewRootScope().ExitOnError(); 106 Output m = ops::Add(root, test_tensor1, test_tensor2); 107 const string result_name = m.node()->name(); 108 109 GraphDef graph_def; 110 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 111 string graph_def_serialized; 112 graph_def.SerializeToString(&graph_def_serialized); 113 TF_ASSERT_OK( 114 WriteStringToFile(Env::Default(), filename_pb, graph_def_serialized)); 115 116 const string filename_mmap = io::JoinPath(dir, "string_graphdef.mmap"); 117 TF_ASSERT_OK(ConvertConstantsToImmutable(filename_pb, filename_mmap, 1000)); 118 119 // Create and initialize MemmappedEnv from the converted file. 120 MemmappedEnv memmapped_env(Env::Default()); 121 TF_ASSERT_OK(memmapped_env.InitializeFromFile(filename_mmap)); 122 123 // Load the graph and run calculations. 124 SessionOptions session_options; 125 session_options.env = &memmapped_env; 126 std::unique_ptr<Session> session(NewSession(session_options)); 127 ASSERT_TRUE(session != nullptr) << "Failed to create session"; 128 GraphDef loaded_graph_def; 129 TF_ASSERT_OK(ReadBinaryProto( 130 &memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, 131 &loaded_graph_def)); 132 ASSERT_FALSE(GraphHasImmutableConstNodes(loaded_graph_def)); 133 } 134 135 } // namespace 136 } // namespace tensorflow 137