Home | History | Annotate | Download | only in grappler
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/utils.h"
     17 #include "tensorflow/cc/ops/standard_ops.h"
     18 #include "tensorflow/core/framework/node_def.pb.h"
     19 #include "tensorflow/core/lib/core/status.h"
     20 #include "tensorflow/core/lib/core/threadpool.h"
     21 #include "tensorflow/core/platform/env.h"
     22 #include "tensorflow/core/platform/notification.h"
     23 #include "tensorflow/core/platform/test.h"
     24 
     25 namespace tensorflow {
     26 namespace grappler {
     27 namespace {
     28 
     29 class UtilsTest : public ::testing::Test {
     30  protected:
     31   NodeDef CreateConcatOffsetNode() const {
     32     const string gdef_ascii =
     33         " name: 'gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/"
     34         "ConcatOffset'"
     35         " op: 'ConcatOffset'"
     36         " input: 'InceptionV3/Mixed_7c/Branch_1/concat_v2/axis'"
     37         " input: 'gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/Shape'"
     38         " input: "
     39         " 'gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/Shape_1'"
     40         " attr {"
     41         "  key: 'N'"
     42         "  value {"
     43         "    i: 2"
     44         "  }"
     45         " }";
     46     NodeDef node;
     47     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &node));
     48     return node;
     49   }
     50 
     51   NodeDef CreateDequeueNode() const {
     52     const string gdef_ascii =
     53         " name: 'Train/TrainInput/input_producer_Dequeue'"
     54         " op: 'QueueDequeueV2'"
     55         " input: 'Train/TrainInput/input_producer'"
     56         " attr {"
     57         "  key: 'component_types'"
     58         "   value {"
     59         "     list {"
     60         "       type: DT_INT32"
     61         "     }"
     62         "   }"
     63         " }"
     64         " attr {"
     65         "   key: 'timeout_ms'"
     66         "   value {"
     67         "     i: -1"
     68         "   }"
     69         " }";
     70 
     71     NodeDef node;
     72     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &node));
     73     return node;
     74   }
     75 
     76   NodeDef CreateFusedBatchNormNode() const {
     77     const string gdef_ascii =
     78         " name: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/FusedBatchNorm'"
     79         " op: 'FusedBatchNorm'"
     80         " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/FusedBatchNorm'"
     81         " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/gamma/read'"
     82         " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/beta/read'"
     83         " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/Const'"
     84         " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/Const_1'"
     85         " attr {"
     86         "   key: 'T'"
     87         "   value {"
     88         "     type: DT_FLOAT"
     89         "   }"
     90         " }"
     91         " attr {"
     92         "   key: 'data_format'"
     93         "   value {"
     94         "     s: 'NHWC'"
     95         "   }"
     96         " }"
     97         " attr {"
     98         "   key: 'epsilon'"
     99         "   value {"
    100         "     f: 0.001"
    101         "   }"
    102         " }"
    103         " attr {"
    104         "   key: 'is_training'"
    105         "   value {"
    106         "     b: true"
    107         "   }"
    108         " }";
    109 
    110     NodeDef node;
    111     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &node));
    112     return node;
    113   }
    114 };
    115 
    116 TEST_F(UtilsTest, NodeName) {
    117   EXPECT_EQ("abc", NodeName("abc"));
    118   EXPECT_EQ("abc", NodeName("^abc"));
    119   EXPECT_EQ("abc", NodeName("abc:0"));
    120   EXPECT_EQ("abc", NodeName("^abc:0"));
    121 
    122   EXPECT_EQ("abc/def", NodeName("abc/def"));
    123   EXPECT_EQ("abc/def", NodeName("^abc/def"));
    124   EXPECT_EQ("abc/def", NodeName("abc/def:1"));
    125   EXPECT_EQ("abc/def", NodeName("^abc/def:1"));
    126 
    127   EXPECT_EQ("abc/def0", NodeName("abc/def0"));
    128   EXPECT_EQ("abc/def0", NodeName("^abc/def0"));
    129   EXPECT_EQ("abc/def0", NodeName("abc/def0:0"));
    130   EXPECT_EQ("abc/def0", NodeName("^abc/def0:0"));
    131 
    132   EXPECT_EQ("abc/def_0", NodeName("abc/def_0"));
    133   EXPECT_EQ("abc/def_0", NodeName("^abc/def_0"));
    134   EXPECT_EQ("abc/def_0", NodeName("abc/def_0:3"));
    135   EXPECT_EQ("abc/def_0", NodeName("^abc/def_0:3"));
    136 
    137   EXPECT_EQ("abc/def_0", NodeName("^abc/def_0:3214"));
    138 }
    139 
    140 TEST_F(UtilsTest, NodePosition) {
    141   EXPECT_EQ(2, NodePosition("abc:2"));
    142   EXPECT_EQ(123, NodePosition("abc:123"));
    143   EXPECT_EQ(-1, NodePosition("^abc:123"));
    144   EXPECT_EQ(-1, NodePosition("^abc"));
    145   EXPECT_EQ(0, NodePosition(""));
    146 }
    147 
    148 TEST_F(UtilsTest, AddNodeNamePrefix) {
    149   EXPECT_EQ("OPTIMIZED/abc", AddPrefixToNodeName("abc", "OPTIMIZED"));
    150   EXPECT_EQ("^OPTIMIZED/abc", AddPrefixToNodeName("^abc", "OPTIMIZED"));
    151   EXPECT_EQ("OPTIMIZED/", AddPrefixToNodeName("", "OPTIMIZED"));
    152 }
    153 
    154 TEST_F(UtilsTest, ExecuteWithTimeout) {
    155   std::unique_ptr<thread::ThreadPool> thread_pool(
    156       new thread::ThreadPool(Env::Default(), "ExecuteWithTimeout", 2));
    157 
    158   // This should run till the end.
    159   ASSERT_TRUE(ExecuteWithTimeout(
    160       []() {  // Do nothing.
    161       },
    162       1000 /* timeout_in_ms */, thread_pool.get()));
    163 
    164   // This should time out.
    165   Notification notification;
    166   ASSERT_FALSE(ExecuteWithTimeout(
    167       [&notification]() { notification.WaitForNotification(); },
    168       1 /* timeout_in_ms */, thread_pool.get()));
    169   // Make sure to unblock the thread.
    170   notification.Notify();
    171 
    172   // This should run till the end.
    173   ASSERT_TRUE(ExecuteWithTimeout([]() { sleep(1); }, 0 /* timeout_in_ms */,
    174                                  thread_pool.get()));
    175 
    176   // Deleting before local variables go off the stack.
    177   thread_pool.reset();
    178 }
    179 
    180 TEST_F(UtilsTest, NumOutputs) {
    181   GraphDef graph;
    182   EXPECT_EQ(2, NumOutputs(CreateConcatOffsetNode(), &graph));
    183   EXPECT_EQ(5, NumOutputs(CreateFusedBatchNormNode(), &graph));
    184   EXPECT_EQ(1, NumOutputs(CreateDequeueNode(), &graph));
    185 }
    186 
    187 TEST_F(UtilsTest, AsControlDependency) {
    188   NodeDef node;
    189   node.set_name("foo");
    190   EXPECT_EQ("^foo", AsControlDependency(node));
    191   EXPECT_EQ("^foo", AsControlDependency(node.name()));
    192   EXPECT_EQ("^foo", AsControlDependency("^foo"));
    193 }
    194 
    195 TEST_F(UtilsTest, GetTailOfChain) {
    196   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    197   Output c0 = ops::Const(s.WithOpName("c0"), {1.0f, 2.0f}, {1, 2});
    198   Output c1 = ops::Const(s.WithOpName("c1"), {3.0f, 4.0f}, {1, 2});
    199   // Add a node with only connected by control output.
    200   Output neg0 = ops::Neg(s.WithOpName("neg0"), c1);
    201   // Add a node with two outputs.
    202   Output neg1 =
    203       ops::Neg(s.WithControlDependencies(neg0).WithOpName("neg1"), c0);
    204   Output neg2 = ops::Neg(s.WithOpName("neg2"), neg1);
    205   Output id1 = ops::Identity(s.WithOpName("id1"), neg2);
    206   Output id2 = ops::Identity(s.WithOpName("id2"), neg1);
    207   auto noop = ops::NoOp(s.WithControlDependencies(neg0).WithOpName("noop"));
    208   GraphDef graph;
    209   TF_CHECK_OK(s.ToGraphDef(&graph));
    210   LOG(INFO) << graph.DebugString();
    211 
    212   ASSERT_EQ("c0", graph.node(0).name());
    213   ASSERT_EQ("c1", graph.node(1).name());
    214   ASSERT_EQ("neg0", graph.node(2).name());
    215   ASSERT_EQ("neg1", graph.node(3).name());
    216   ASSERT_EQ("neg2", graph.node(4).name());
    217   ASSERT_EQ("id1", graph.node(5).name());
    218   ASSERT_EQ("id2", graph.node(6).name());
    219   ASSERT_EQ("noop", graph.node(7).name());
    220 
    221   NodeMap node_map(&graph);
    222   auto is_neg = [&](const NodeDef& node) { return node.op() == "Neg"; };
    223   // We walk backwards, starting as "id1", so tail should be "neg1".
    224   NodeDef* tail = GetTailOfChain(graph.node(5), node_map,
    225                                  /*follow_control_input=*/false, is_neg);
    226   EXPECT_NE(tail, nullptr);
    227   EXPECT_EQ("neg1", tail->name());
    228 
    229   // We stop at branching nodes, so tail should be "neg2".
    230   auto is_neg_and_non_branching = [&](const NodeDef& node) {
    231     return node.op() == "Neg" && NumNonControlOutputs(node, node_map) == 1;
    232   };
    233   tail =
    234       GetTailOfChain(graph.node(5), node_map,
    235                      /*follow_control_input=*/false, is_neg_and_non_branching);
    236   EXPECT_NE(tail, nullptr);
    237   EXPECT_EQ("neg2", tail->name());
    238 
    239   // We walk backwards, starting from "noop", also following control inputs,
    240   // so tail should be "neg0".
    241   tail = GetTailOfChain(graph.node(7), node_map,
    242                         /*follow_control_input=*/true, is_neg);
    243   EXPECT_NE(tail, nullptr);
    244   EXPECT_EQ("neg0", tail->name());
    245 
    246   // We walk backwards, starting from "noop", not following control inputs,
    247   // so tail should be "noop" itself.
    248   tail = GetTailOfChain(graph.node(7), node_map,
    249                         /*follow_control_input=*/false, is_neg);
    250   EXPECT_NE(tail, nullptr);
    251   EXPECT_EQ("noop", tail->name());
    252 }
    253 
    254 TEST_F(UtilsTest, DedupControlInputs) {
    255   NodeDef foo;
    256   foo.set_name("foo");
    257   foo.add_input("bar");
    258   DedupControlInputs(&foo);
    259   EXPECT_EQ(1, foo.input_size());
    260   EXPECT_EQ("bar", foo.input(0));
    261 
    262   foo.set_input(0, "^bar");
    263   DedupControlInputs(&foo);
    264   EXPECT_EQ(1, foo.input_size());
    265   EXPECT_EQ("^bar", foo.input(0));
    266 
    267   foo.set_input(0, "bar");
    268   foo.add_input("bar");
    269   DedupControlInputs(&foo);
    270   EXPECT_EQ(2, foo.input_size());
    271   EXPECT_EQ("bar", foo.input(0));
    272   EXPECT_EQ("bar", foo.input(1));
    273 
    274   foo.set_input(1, "^bar");
    275   DedupControlInputs(&foo);
    276   EXPECT_EQ(1, foo.input_size());
    277   EXPECT_EQ("bar", foo.input(0));
    278 
    279   foo.set_input(0, "^bar");
    280   foo.add_input("^bar");
    281   DedupControlInputs(&foo);
    282   EXPECT_EQ(1, foo.input_size());
    283   EXPECT_EQ("^bar", foo.input(0));
    284 
    285   foo.set_input(0, "bar");
    286   foo.add_input("gnu");
    287   foo.add_input("^bar");
    288   foo.add_input("^gnu");
    289   DedupControlInputs(&foo);
    290   EXPECT_EQ(2, foo.input_size());
    291   EXPECT_EQ("bar", foo.input(0));
    292   EXPECT_EQ("gnu", foo.input(1));
    293 }
    294 
    295 TEST_F(UtilsTest, DeleteNodes) {}
    296 
    297 }  // namespace
    298 }  // namespace grappler
    299 }  // namespace tensorflow
    300