Home | History | Annotate | Download | only in optimizers
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/optimizers/static_schedule.h"
     17 #include "tensorflow/cc/ops/standard_ops.h"
     18 #include "tensorflow/core/framework/node_def.pb.h"
     19 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
     20 #include "tensorflow/core/grappler/grappler_item.h"
     21 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
     22 #include "tensorflow/core/grappler/utils.h"
     23 #include "tensorflow/core/lib/core/status_test_util.h"
     24 #include "tensorflow/core/platform/test.h"
     25 
     26 namespace tensorflow {
     27 namespace grappler {
     28 namespace {
     29 
     30 class StaticScheduleTest : public ::testing::Test {
     31  public:
     32   std::unique_ptr<VirtualCluster> CreateVirtualCluster() const {
     33     // Invent a CPU so that predictions remain the same from machine to machine.
     34     DeviceProperties cpu_device;
     35     cpu_device.set_type("CPU");
     36     cpu_device.set_frequency(1000);
     37     cpu_device.set_num_cores(4);
     38     cpu_device.set_bandwidth(32);
     39     cpu_device.set_l1_cache_size(32 * 1024);
     40     cpu_device.set_l2_cache_size(256 * 1024);
     41     cpu_device.set_l3_cache_size(4 * 1024 * 1024);
     42     std::unordered_map<string, DeviceProperties> devices;
     43     devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device;
     44     return std::unique_ptr<VirtualCluster>(new VirtualCluster(devices));
     45   }
     46 };
     47 
     48 TEST_F(StaticScheduleTest, BasicGraph) {
     49   // This trivial graph is so basic there's nothing to prune.
     50   TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
     51   GrapplerItem item;
     52   CHECK(fake_input.NextItem(&item));
     53 
     54   std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
     55 
     56   std::unordered_map<const NodeDef*, Costs::NanoSeconds> completion_times;
     57   Status status =
     58       EstimateEarliestExecutionTimes(item, cluster.get(), &completion_times);
     59   TF_EXPECT_OK(status);
     60 
     61   EXPECT_EQ(item.graph.node_size(), completion_times.size());
     62 
     63   for (auto time : completion_times) {
     64     if (time.first->name() == "Const/Const") {
     65       EXPECT_EQ(Costs::NanoSeconds(1), time.second);
     66     } else if (time.first->name() == "x") {
     67       EXPECT_EQ(Costs::NanoSeconds(250001), time.second);
     68     } else if (time.first->name() == "Square") {
     69       EXPECT_EQ(Costs::NanoSeconds(1500004), time.second);
     70     } else if (time.first->name() == "Square_1") {
     71       EXPECT_EQ(Costs::NanoSeconds(2750007), time.second);
     72     } else if (time.first->name() == "Square_2") {
     73       EXPECT_EQ(Costs::NanoSeconds(4000010), time.second);
     74     } else if (time.first->name() == "Square_3") {
     75       EXPECT_EQ(Costs::NanoSeconds(5250013), time.second);
     76     } else if (time.first->name() == "y") {
     77       EXPECT_EQ(Costs::NanoSeconds(6500013), time.second);
     78     }
     79   }
     80 }
     81 
     82 TEST_F(StaticScheduleTest, BasicGraphWithCtrlDependencies) {
     83   // Build a simple graph with a control dependency.
     84   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
     85 
     86   Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
     87   Output b = ops::AddN(s.WithOpName("b"), {a});
     88   Output c = ops::Identity(s.WithOpName("c"), b);
     89   Output d = ops::Identity(s.WithOpName("d"), c);
     90   Output e = ops::AddN(s.WithOpName("e"), {d});
     91 
     92   GrapplerItem item;
     93   TF_CHECK_OK(s.ToGraphDef(&item.graph));
     94 
     95   // Add a control dependency between c and e.
     96   EXPECT_EQ("c", item.graph.node(2).name());
     97   EXPECT_EQ("e", item.graph.node(4).name());
     98   *item.graph.mutable_node(4)->add_input() = "^c";
     99 
    100   std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
    101 
    102   std::unordered_map<const NodeDef*, Costs::NanoSeconds> completion_times;
    103   Status status =
    104       EstimateEarliestExecutionTimes(item, cluster.get(), &completion_times);
    105   TF_EXPECT_OK(status);
    106 
    107   EXPECT_EQ(item.graph.node_size(), completion_times.size());
    108 
    109   for (auto time : completion_times) {
    110     if (time.first->name() == "a") {
    111       EXPECT_EQ(Costs::NanoSeconds(1), time.second);
    112     } else if (time.first->name() == "b") {
    113       EXPECT_EQ(Costs::NanoSeconds(12500001), time.second);
    114     } else if (time.first->name() == "c") {
    115       EXPECT_EQ(Costs::NanoSeconds(12500002), time.second);
    116     } else if (time.first->name() == "d") {
    117       EXPECT_EQ(Costs::NanoSeconds(12500003), time.second);
    118     } else if (time.first->name() == "e") {
    119       EXPECT_EQ(Costs::NanoSeconds(25000003), time.second);
    120     }
    121   }
    122 }
    123 
    124 TEST_F(StaticScheduleTest, RequiredTimes) {
    125   // This trivial graph is so basic there's nothing to prune.
    126   TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
    127   GrapplerItem item;
    128   CHECK(fake_input.NextItem(&item));
    129 
    130   std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
    131 
    132   std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
    133   for (const NodeDef& node : item.graph.node()) {
    134     execution_times[&node] = 0;
    135   }
    136   std::unordered_map<const NodeDef*, Costs::NanoSeconds> required_times;
    137   Status status = EstimateRequiredTimes(item, cluster.get(), execution_times,
    138                                         &required_times);
    139   TF_EXPECT_OK(status);
    140 
    141   EXPECT_EQ(item.graph.node_size(), required_times.size());
    142 
    143   for (auto time : required_times) {
    144     if (time.first->name() == "Const/Const") {
    145       EXPECT_EQ(Costs::NanoSeconds(-6500012), time.second);
    146     } else if (time.first->name() == "x") {
    147       EXPECT_EQ(Costs::NanoSeconds(-6250012), time.second);
    148     } else if (time.first->name() == "Square") {
    149       EXPECT_EQ(Costs::NanoSeconds(-5000009), time.second);
    150     } else if (time.first->name() == "Square_1") {
    151       EXPECT_EQ(Costs::NanoSeconds(-3750006), time.second);
    152     } else if (time.first->name() == "Square_2") {
    153       EXPECT_EQ(Costs::NanoSeconds(-2500003), time.second);
    154     } else if (time.first->name() == "Square_3") {
    155       EXPECT_EQ(Costs::NanoSeconds(-1250000), time.second);
    156     } else if (time.first->name() == "y") {
    157       EXPECT_EQ(Costs::NanoSeconds(0), time.second);
    158     }
    159   }
    160 }
    161 
    162 }  // namespace
    163 }  // namespace grappler
    164 }  // namespace tensorflow
    165