/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/graph_memory.h"

#include <set>
#include <string>
#include <unordered_map>

#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {
namespace {

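// Test fixture that registers a minimal CPU and a minimal GPU in the device
// map handed to GraphMemory::InferStatically(). The core count, frequency,
// and bandwidth values are trivial placeholders: peak memory is derived from
// tensor shapes, not device speed.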
class GraphMemoryTest : public ::testing::Test {
 protected:
  std::unordered_map<string, DeviceProperties> devices_;

 public:
  GraphMemoryTest() {
    devices_["/CPU:0"].set_type("CPU");
    devices_["/CPU:0"].set_num_cores(1);
    devices_["/CPU:0"].set_frequency(1);
    devices_["/CPU:0"].set_bandwidth(1);

    devices_["/GPU:0"].set_type("GPU");
    devices_["/GPU:0"].set_num_cores(1);
    devices_["/GPU:0"].set_frequency(1);
    devices_["/GPU:0"].set_bandwidth(1);
    (*devices_["/GPU:0"].mutable_environment())["architecture"] = "3";
  }
};

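// The TrivialTestGraphInputYielder arguments used below are, per the
// constructor in trivial_test_graph_input_yielder.h: num_stages, width,
// tensor_size, insert_queue, and the device names to place nodes on.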
TEST_F(GraphMemoryTest, Basic) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"/CPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));
  item.feed.clear();

  GraphMemory memory(item);
  Status s = memory.InferStatically(devices_);
  TF_CHECK_OK(s);
  const GraphMemory::MemoryUsage& mem_usage =
      memory.GetPeakMemoryUsage("/CPU:0");
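  // 3 live tensors of 10 floats each: 3 * 10 * 4 bytes = 120 bytes.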
  EXPECT_EQ(120, mem_usage.used_memory);

  std::set<string> tensors;
  for (const auto& t : mem_usage.live_tensors) {
    tensors.insert(strings::StrCat(t.node, ":", t.output_id));
  }
  // Once the 'Square' node finishes executing, TF can start executing
  // 'Square_1' and release the memory used by 'x'. Since we can't be sure of
  // the order in which this happens, in the worst case all 3 tensors are in
  // memory simultaneously.
  std::set<string> expected;
  expected.insert("Square:0");
  expected.insert("Square_1:0");
  expected.insert("x:0");
  EXPECT_EQ(expected, tensors);
}

TEST_F(GraphMemoryTest, UnknownBatchSize) {
  TrivialTestGraphInputYielder fake_input(4, 1, -1, false, {"/CPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));
  item.feed.clear();

  GraphMemory memory(item);
  Status s = memory.InferStatically(devices_);
  TF_CHECK_OK(s);
  // Same math as in the Basic test, except that the batch size is unknown
  // and therefore assumed to be one.
  const GraphMemory::MemoryUsage& mem_usage =
      memory.GetPeakMemoryUsage("/CPU:0");
  EXPECT_EQ(16, mem_usage.used_memory);

  std::set<string> tensors;
  for (const auto& t : mem_usage.live_tensors) {
    tensors.insert(strings::StrCat(t.node, ":", t.output_id));
  }
  std::set<string> expected;
  expected.insert("Const/Const:0");
  expected.insert("Square:0");
  expected.insert("x:0");
  EXPECT_EQ(expected, tensors);
}

TEST_F(GraphMemoryTest, MultiDevice) {
  TrivialTestGraphInputYielder fake_input(4, 2, 1024 * 1024, false,
                                          {"/CPU:0", "/GPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));
  item.feed.clear();

  GraphMemory memory(item);
  Status s = memory.InferStatically(devices_);
  TF_CHECK_OK(s);

  const GraphMemory::MemoryUsage& cpu_mem = memory.GetPeakMemoryUsage("/CPU:0");
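  // Peak on each device is 4 live tensors of 1024 * 1024 floats:
  // 4 * 1024 * 1024 * 4 bytes = 16777216 bytes (16 MiB).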
  EXPECT_EQ(16777216, cpu_mem.used_memory);
  std::set<string> cpu_tensors;
  for (const auto& t : cpu_mem.live_tensors) {
    cpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
  }
  std::set<string> cpu_expected;
  cpu_expected.insert("Recv_Square_1_0_on_/CPU_0:0");
  cpu_expected.insert("Square:0");
  cpu_expected.insert("x:0");
  cpu_expected.insert("AddN:0");
  EXPECT_EQ(cpu_expected, cpu_tensors);

  const GraphMemory::MemoryUsage& gpu_mem = memory.GetPeakMemoryUsage("/GPU:0");
  EXPECT_EQ(16777216, gpu_mem.used_memory);
  std::set<string> gpu_tensors;
  for (const auto& t : gpu_mem.live_tensors) {
    gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
  }
  std::set<string> gpu_expected;
  gpu_expected.insert("Recv_AddN_0_on_/GPU_0:0");
  gpu_expected.insert("Square_1:0");
  gpu_expected.insert("AddN_1:0");
  gpu_expected.insert("AddN_3:0");
  EXPECT_EQ(gpu_expected, gpu_tensors);
}

TEST_F(GraphMemoryTest, GpuSwapping) {
  TrivialTestGraphInputYielder fake_input(4, 2, 1024 * 1024, false, {"/GPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));
  item.feed.clear();

  {
    // Estimate the max memory usage for the graph.
    GraphMemory memory(item);
    Status s = memory.InferStatically(devices_);
    TF_CHECK_OK(s);

    const GraphMemory::MemoryUsage& gpu_mem =
        memory.GetPeakMemoryUsage("/GPU:0");
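    // 5 live tensors of 1024 * 1024 floats: 5 * 4 MiB = 20971520 bytes.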
    EXPECT_EQ(20971520, gpu_mem.used_memory);
    std::set<string> gpu_tensors;
    for (const auto& t : gpu_mem.live_tensors) {
      gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
    }
    std::set<string> gpu_expected;
    gpu_expected.insert("Square:0");
    gpu_expected.insert("Square_1:0");
    gpu_expected.insert("AddN:0");
    gpu_expected.insert("AddN_1:0");
    gpu_expected.insert("AddN_2:0");
    EXPECT_EQ(gpu_expected, gpu_tensors);
  }

  {
    // Swap the first input to node AddN_1: its fanin (the square nodes) should
    // not appear in the max cut anymore.
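    // The `_swap_to_host` attribute lists the input indices of the node whose
    // tensors should be swapped out to host memory; here we mark input 0 of
    // AddN_1.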
    for (auto& node : *item.graph.mutable_node()) {
      if (node.name() == "AddN_1") {
        (*node.mutable_attr())["_swap_to_host"].mutable_list()->add_i(0);
      }
    }
    GraphMemory memory(item);
    Status s = memory.InferStatically(devices_);
    TF_CHECK_OK(s);
    const GraphMemory::MemoryUsage& new_gpu_mem =
        memory.GetPeakMemoryUsage("/GPU:0");
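    // The peak is still 5 tensors * 4 MiB, but the live set has shifted from
    // the Square outputs to later AddN outputs.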
    EXPECT_EQ(20971520, new_gpu_mem.used_memory);
    std::set<string> new_gpu_tensors;
    for (const auto& t : new_gpu_mem.live_tensors) {
      new_gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
    }
    std::set<string> new_gpu_expected;
    new_gpu_expected.insert("AddN:0");
    new_gpu_expected.insert("AddN_1:0");
    new_gpu_expected.insert("AddN_2:0");
    new_gpu_expected.insert("AddN_3:0");
    new_gpu_expected.insert("AddN_4:0");
    EXPECT_EQ(new_gpu_expected, new_gpu_tensors);
  }
}

TEST_F(GraphMemoryTest, CtrlDependencies) {
  // Build a simple graph with a control dependency.
  Scope s = Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a").WithDevice("/CPU:0"), 10.0f, {3});
  Output v =
      ops::Variable(s.WithOpName("v").WithDevice("/CPU:0"), {3}, DT_FLOAT);
  Output assign =
      ops::Assign(s.WithOpName("assign").WithDevice("/CPU:0"), v, a);
  ops::NoOp init(
      s.WithOpName("init").WithDevice("/CPU:0").WithControlDependencies(
          assign));
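  // `init` is connected to `assign` only through a control edge, which
  // carries no data, so it contributes nothing to the live tensor set.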

  GrapplerItem item;
  item.fetch.push_back("init");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphMemory memory(item);
  Status status = memory.InferStatically(devices_);
  TF_CHECK_OK(status);

  const GraphMemory::MemoryUsage& mem = memory.GetPeakMemoryUsage("/CPU:0");
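  // 3 live tensors of 3 floats each: 3 * 3 * 4 bytes = 36 bytes.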
  EXPECT_EQ(36, mem.used_memory);
  std::set<string> tensors;
  for (const auto& t : mem.live_tensors) {
    tensors.insert(strings::StrCat(t.node, ":", t.output_id));
  }
  std::set<string> expected;
  expected.insert("a:0");
  expected.insert("v:0");
  expected.insert("assign:0");
  EXPECT_EQ(expected, tensors);
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow