/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {
namespace {

class GraphMemoryTest : public ::testing::Test {
 protected:
  std::unordered_map<string, DeviceProperties> devices_;

 public:
  GraphMemoryTest() {
    devices_["/CPU:0"].set_type("CPU");
    devices_["/CPU:0"].set_num_cores(1);
    devices_["/CPU:0"].set_frequency(1);
    devices_["/CPU:0"].set_bandwidth(1);

    devices_["/GPU:0"].set_type("GPU");
    devices_["/GPU:0"].set_num_cores(1);
    devices_["/GPU:0"].set_frequency(1);
    devices_["/GPU:0"].set_bandwidth(1);
    (*devices_["/GPU:0"].mutable_environment())["architecture"] = "3";
  }
};

TEST_F(GraphMemoryTest, Basic) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"/CPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));
  item.feed.clear();

  GraphMemory memory(item);
  Status s = memory.InferStatically(devices_);
  TF_CHECK_OK(s);
  const GraphMemory::MemoryUsage& mem_usage =
      memory.GetPeakMemoryUsage("/CPU:0");
  // Peak: 3 live tensors of 10 floats each: 3 * 10 * 4 = 120 bytes.
  EXPECT_EQ(120, mem_usage.used_memory);

  std::set<string> tensors;
  for (const auto& t : mem_usage.live_tensors) {
    tensors.insert(strings::StrCat(t.node, ":", t.output_id));
  }
  // When the execution of the 'Square' node completes, TF can start executing
  // 'Square_1' and release the memory used by 'x'. Since we can't be sure of
  // the order in which this takes place, in the worst case all 3 tensors are
  // in memory at the same time.
  std::set<string> expected;
  expected.insert("Square:0");
  expected.insert("Square_1:0");
  expected.insert("x:0");
  EXPECT_EQ(expected, tensors);
}

TEST_F(GraphMemoryTest, UnknownBatchSize) {
  TrivialTestGraphInputYielder fake_input(4, 1, -1, false, {"/CPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));
  item.feed.clear();

  GraphMemory memory(item);
  Status s = memory.InferStatically(devices_);
  TF_CHECK_OK(s);
  // Same math as in the Basic test, except that the batch size is unknown and
  // therefore assumed to be 1.
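  // With batch size 1, 'x:0' and 'Square:0' account for 4 bytes each; the
  // remaining 8 bytes of the expected 16-byte peak presumably come from
  // 'Const/Const:0', a two-element int32 tensor.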
  const GraphMemory::MemoryUsage& mem_usage =
      memory.GetPeakMemoryUsage("/CPU:0");
  EXPECT_EQ(16, mem_usage.used_memory);

  std::set<string> tensors;
  for (const auto& t : mem_usage.live_tensors) {
    tensors.insert(strings::StrCat(t.node, ":", t.output_id));
  }
  std::set<string> expected;
  expected.insert("Const/Const:0");
  expected.insert("Square:0");
  expected.insert("x:0");
  EXPECT_EQ(expected, tensors);
}

TEST_F(GraphMemoryTest, MultiDevice) {
  TrivialTestGraphInputYielder fake_input(4, 2, 1024 * 1024, false,
                                          {"/CPU:0", "/GPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));
  item.feed.clear();

  GraphMemory memory(item);
  Status s = memory.InferStatically(devices_);
  TF_CHECK_OK(s);

  const GraphMemory::MemoryUsage& cpu_mem = memory.GetPeakMemoryUsage("/CPU:0");
  // Peak: 4 live tensors of 1024 * 1024 floats each, i.e. 4 * 4MB = 16MB.
  EXPECT_EQ(16777216, cpu_mem.used_memory);
  std::set<string> cpu_tensors;
  for (const auto& t : cpu_mem.live_tensors) {
    cpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
  }
  std::set<string> cpu_expected;
  cpu_expected.insert("Recv_Square_1_0_on_/CPU_0:0");
  cpu_expected.insert("Square:0");
  cpu_expected.insert("x:0");
  cpu_expected.insert("AddN:0");
  EXPECT_EQ(cpu_expected, cpu_tensors);

  const GraphMemory::MemoryUsage& gpu_mem = memory.GetPeakMemoryUsage("/GPU:0");
  // Same accounting on the GPU side: 4 live tensors of 4MB each.
  EXPECT_EQ(16777216, gpu_mem.used_memory);
  std::set<string> gpu_tensors;
  for (const auto& t : gpu_mem.live_tensors) {
    gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
  }
  std::set<string> gpu_expected;
  gpu_expected.insert("Recv_AddN_0_on_/GPU_0:0");
  gpu_expected.insert("Square_1:0");
  gpu_expected.insert("AddN_1:0");
  gpu_expected.insert("AddN_3:0");
  EXPECT_EQ(gpu_expected, gpu_tensors);
}

TEST_F(GraphMemoryTest, GpuSwapping) {
  TrivialTestGraphInputYielder fake_input(4, 2, 1024 * 1024, false, {"/GPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));
  item.feed.clear();

  {
    // Estimate the max memory usage for the graph.
    GraphMemory memory(item);
    Status s = memory.InferStatically(devices_);
    TF_CHECK_OK(s);

    const GraphMemory::MemoryUsage& gpu_mem =
        memory.GetPeakMemoryUsage("/GPU:0");
    // Peak: 5 live tensors of 4MB each = 20MB.
    EXPECT_EQ(20971520, gpu_mem.used_memory);
    std::set<string> gpu_tensors;
    for (const auto& t : gpu_mem.live_tensors) {
      gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
    }
    std::set<string> gpu_expected;
    gpu_expected.insert("Square:0");
    gpu_expected.insert("Square_1:0");
    gpu_expected.insert("AddN:0");
    gpu_expected.insert("AddN_1:0");
    gpu_expected.insert("AddN_2:0");
    EXPECT_EQ(gpu_expected, gpu_tensors);
  }

  {
    // Swap the first input to node AddN_1: its fanin (the Square nodes) should
    // no longer appear in the max cut.
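    // The '_swap_to_host' attribute lists the input indices whose tensors are
    // to be swapped out to host memory; adding index 0 tells the static memory
    // model that AddN_1's first input no longer occupies GPU memory until
    // AddN_1 actually runs.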
    for (auto& node : *item.graph.mutable_node()) {
      if (node.name() == "AddN_1") {
        (*node.mutable_attr())["_swap_to_host"].mutable_list()->add_i(0);
      }
    }
    GraphMemory memory(item);
    Status s = memory.InferStatically(devices_);
    TF_CHECK_OK(s);
    const GraphMemory::MemoryUsage& new_gpu_mem =
        memory.GetPeakMemoryUsage("/GPU:0");
    // The peak is still 5 live tensors of 4MB each, but the Square outputs
    // have been replaced by AddN outputs in the cut.
    EXPECT_EQ(20971520, new_gpu_mem.used_memory);
    std::set<string> new_gpu_tensors;
    for (const auto& t : new_gpu_mem.live_tensors) {
      new_gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
    }
    std::set<string> new_gpu_expected;
    new_gpu_expected.insert("AddN:0");
    new_gpu_expected.insert("AddN_1:0");
    new_gpu_expected.insert("AddN_2:0");
    new_gpu_expected.insert("AddN_3:0");
    new_gpu_expected.insert("AddN_4:0");
    EXPECT_EQ(new_gpu_expected, new_gpu_tensors);
  }
}

TEST_F(GraphMemoryTest, CtrlDependencies) {
  // Build a simple graph with a control dependency.
  Scope s = Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a").WithDevice("/CPU:0"), 10.0f, {3});
  Output v =
      ops::Variable(s.WithOpName("v").WithDevice("/CPU:0"), {3}, DT_FLOAT);
  Output assign =
      ops::Assign(s.WithOpName("assign").WithDevice("/CPU:0"), v, a);
  ops::NoOp init(
      s.WithOpName("init").WithDevice("/CPU:0").WithControlDependencies(
          assign));

  GrapplerItem item;
  item.fetch.push_back("init");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphMemory memory(item);
  Status status = memory.InferStatically(devices_);
  TF_CHECK_OK(status);

  const GraphMemory::MemoryUsage& mem = memory.GetPeakMemoryUsage("/CPU:0");
  // Peak: 3 live tensors of 3 floats each: 3 * 3 * 4 = 36 bytes.
  EXPECT_EQ(36, mem.used_memory);
  std::set<string> tensors;
  for (const auto& t : mem.live_tensors) {
    tensors.insert(strings::StrCat(t.node, ":", t.output_id));
  }
  std::set<string> expected;
  expected.insert("a:0");
  expected.insert("v:0");
  expected.insert("assign:0");
  EXPECT_EQ(expected, tensors);
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow