1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/grappler/optimizers/static_schedule.h" 17 #include "tensorflow/cc/ops/standard_ops.h" 18 #include "tensorflow/core/framework/node_def.pb.h" 19 #include "tensorflow/core/grappler/clusters/virtual_cluster.h" 20 #include "tensorflow/core/grappler/grappler_item.h" 21 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" 22 #include "tensorflow/core/grappler/utils.h" 23 #include "tensorflow/core/lib/core/status_test_util.h" 24 #include "tensorflow/core/platform/test.h" 25 26 namespace tensorflow { 27 namespace grappler { 28 namespace { 29 30 class StaticScheduleTest : public ::testing::Test { 31 public: 32 std::unique_ptr<VirtualCluster> CreateVirtualCluster() const { 33 // Invent a CPU so that predictions remain the same from machine to machine. 34 DeviceProperties cpu_device; 35 cpu_device.set_type("CPU"); 36 cpu_device.set_frequency(1000); 37 cpu_device.set_num_cores(4); 38 cpu_device.set_bandwidth(32); 39 cpu_device.set_l1_cache_size(32 * 1024); 40 cpu_device.set_l2_cache_size(256 * 1024); 41 cpu_device.set_l3_cache_size(4 * 1024 * 1024); 42 std::unordered_map<string, DeviceProperties> devices; 43 devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device; 44 return std::unique_ptr<VirtualCluster>(new VirtualCluster(devices)); 45 } 46 }; 47 48 TEST_F(StaticScheduleTest, BasicGraph) { 49 // This trivial graph is so basic there's nothing to prune. 50 TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); 51 GrapplerItem item; 52 CHECK(fake_input.NextItem(&item)); 53 54 std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster()); 55 56 std::unordered_map<const NodeDef*, Costs::NanoSeconds> completion_times; 57 Status status = 58 EstimateEarliestExecutionTimes(item, cluster.get(), &completion_times); 59 TF_EXPECT_OK(status); 60 61 EXPECT_EQ(item.graph.node_size(), completion_times.size()); 62 63 for (auto time : completion_times) { 64 if (time.first->name() == "Const/Const") { 65 EXPECT_EQ(Costs::NanoSeconds(1), time.second); 66 } else if (time.first->name() == "x") { 67 EXPECT_EQ(Costs::NanoSeconds(250001), time.second); 68 } else if (time.first->name() == "Square") { 69 EXPECT_EQ(Costs::NanoSeconds(1500004), time.second); 70 } else if (time.first->name() == "Square_1") { 71 EXPECT_EQ(Costs::NanoSeconds(2750007), time.second); 72 } else if (time.first->name() == "Square_2") { 73 EXPECT_EQ(Costs::NanoSeconds(4000010), time.second); 74 } else if (time.first->name() == "Square_3") { 75 EXPECT_EQ(Costs::NanoSeconds(5250013), time.second); 76 } else if (time.first->name() == "y") { 77 EXPECT_EQ(Costs::NanoSeconds(6500013), time.second); 78 } 79 } 80 } 81 82 TEST_F(StaticScheduleTest, BasicGraphWithCtrlDependencies) { 83 // Build a simple graph with a control dependency. 84 tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 85 86 Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); 87 Output b = ops::AddN(s.WithOpName("b"), {a}); 88 Output c = ops::Identity(s.WithOpName("c"), b); 89 Output d = ops::Identity(s.WithOpName("d"), c); 90 Output e = ops::AddN(s.WithOpName("e"), {d}); 91 92 GrapplerItem item; 93 TF_CHECK_OK(s.ToGraphDef(&item.graph)); 94 95 // Add a control dependency between c and e. 96 EXPECT_EQ("c", item.graph.node(2).name()); 97 EXPECT_EQ("e", item.graph.node(4).name()); 98 *item.graph.mutable_node(4)->add_input() = "^c"; 99 100 std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster()); 101 102 std::unordered_map<const NodeDef*, Costs::NanoSeconds> completion_times; 103 Status status = 104 EstimateEarliestExecutionTimes(item, cluster.get(), &completion_times); 105 TF_EXPECT_OK(status); 106 107 EXPECT_EQ(item.graph.node_size(), completion_times.size()); 108 109 for (auto time : completion_times) { 110 if (time.first->name() == "a") { 111 EXPECT_EQ(Costs::NanoSeconds(1), time.second); 112 } else if (time.first->name() == "b") { 113 EXPECT_EQ(Costs::NanoSeconds(12500001), time.second); 114 } else if (time.first->name() == "c") { 115 EXPECT_EQ(Costs::NanoSeconds(12500002), time.second); 116 } else if (time.first->name() == "d") { 117 EXPECT_EQ(Costs::NanoSeconds(12500003), time.second); 118 } else if (time.first->name() == "e") { 119 EXPECT_EQ(Costs::NanoSeconds(25000003), time.second); 120 } 121 } 122 } 123 124 TEST_F(StaticScheduleTest, RequiredTimes) { 125 // This trivial graph is so basic there's nothing to prune. 126 TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); 127 GrapplerItem item; 128 CHECK(fake_input.NextItem(&item)); 129 130 std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster()); 131 132 std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times; 133 for (const NodeDef& node : item.graph.node()) { 134 execution_times[&node] = 0; 135 } 136 std::unordered_map<const NodeDef*, Costs::NanoSeconds> required_times; 137 Status status = EstimateRequiredTimes(item, cluster.get(), execution_times, 138 &required_times); 139 TF_EXPECT_OK(status); 140 141 EXPECT_EQ(item.graph.node_size(), required_times.size()); 142 143 for (auto time : required_times) { 144 if (time.first->name() == "Const/Const") { 145 EXPECT_EQ(Costs::NanoSeconds(-6500012), time.second); 146 } else if (time.first->name() == "x") { 147 EXPECT_EQ(Costs::NanoSeconds(-6250012), time.second); 148 } else if (time.first->name() == "Square") { 149 EXPECT_EQ(Costs::NanoSeconds(-5000009), time.second); 150 } else if (time.first->name() == "Square_1") { 151 EXPECT_EQ(Costs::NanoSeconds(-3750006), time.second); 152 } else if (time.first->name() == "Square_2") { 153 EXPECT_EQ(Costs::NanoSeconds(-2500003), time.second); 154 } else if (time.first->name() == "Square_3") { 155 EXPECT_EQ(Costs::NanoSeconds(-1250000), time.second); 156 } else if (time.first->name() == "y") { 157 EXPECT_EQ(Costs::NanoSeconds(0), time.second); 158 } 159 } 160 } 161 162 } // namespace 163 } // namespace grappler 164 } // namespace tensorflow 165