1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/grappler/utils.h" 17 #include "tensorflow/cc/ops/standard_ops.h" 18 #include "tensorflow/core/framework/node_def.pb.h" 19 #include "tensorflow/core/lib/core/status.h" 20 #include "tensorflow/core/lib/core/threadpool.h" 21 #include "tensorflow/core/platform/env.h" 22 #include "tensorflow/core/platform/notification.h" 23 #include "tensorflow/core/platform/test.h" 24 25 namespace tensorflow { 26 namespace grappler { 27 namespace { 28 29 class UtilsTest : public ::testing::Test { 30 protected: 31 NodeDef CreateConcatOffsetNode() const { 32 const string gdef_ascii = 33 " name: 'gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/" 34 "ConcatOffset'" 35 " op: 'ConcatOffset'" 36 " input: 'InceptionV3/Mixed_7c/Branch_1/concat_v2/axis'" 37 " input: 'gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/Shape'" 38 " input: " 39 " 'gradients/InceptionV3/Mixed_7c/Branch_1/concat_v2_grad/Shape_1'" 40 " attr {" 41 " key: 'N'" 42 " value {" 43 " i: 2" 44 " }" 45 " }"; 46 NodeDef node; 47 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &node)); 48 return node; 49 } 50 51 NodeDef CreateDequeueNode() const { 52 const string gdef_ascii = 53 " name: 'Train/TrainInput/input_producer_Dequeue'" 54 " op: 'QueueDequeueV2'" 55 " input: 'Train/TrainInput/input_producer'" 56 " attr {" 57 " key: 'component_types'" 58 " value {" 59 " list {" 60 " type: DT_INT32" 61 " }" 62 " }" 63 " }" 64 " attr {" 65 " key: 'timeout_ms'" 66 " value {" 67 " i: -1" 68 " }" 69 " }"; 70 71 NodeDef node; 72 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &node)); 73 return node; 74 } 75 76 NodeDef CreateFusedBatchNormNode() const { 77 const string gdef_ascii = 78 " name: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/FusedBatchNorm'" 79 " op: 'FusedBatchNorm'" 80 " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/FusedBatchNorm'" 81 " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/gamma/read'" 82 " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/beta/read'" 83 " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/Const'" 84 " input: 'InceptionV3/Conv2d_1a_3x3/BatchNorm/Const_1'" 85 " attr {" 86 " key: 'T'" 87 " value {" 88 " type: DT_FLOAT" 89 " }" 90 " }" 91 " attr {" 92 " key: 'data_format'" 93 " value {" 94 " s: 'NHWC'" 95 " }" 96 " }" 97 " attr {" 98 " key: 'epsilon'" 99 " value {" 100 " f: 0.001" 101 " }" 102 " }" 103 " attr {" 104 " key: 'is_training'" 105 " value {" 106 " b: true" 107 " }" 108 " }"; 109 110 NodeDef node; 111 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &node)); 112 return node; 113 } 114 }; 115 116 TEST_F(UtilsTest, NodeName) { 117 EXPECT_EQ("abc", NodeName("abc")); 118 EXPECT_EQ("abc", NodeName("^abc")); 119 EXPECT_EQ("abc", NodeName("abc:0")); 120 EXPECT_EQ("abc", NodeName("^abc:0")); 121 122 EXPECT_EQ("abc/def", NodeName("abc/def")); 123 EXPECT_EQ("abc/def", NodeName("^abc/def")); 124 EXPECT_EQ("abc/def", NodeName("abc/def:1")); 125 EXPECT_EQ("abc/def", NodeName("^abc/def:1")); 126 127 EXPECT_EQ("abc/def0", NodeName("abc/def0")); 128 EXPECT_EQ("abc/def0", NodeName("^abc/def0")); 129 EXPECT_EQ("abc/def0", NodeName("abc/def0:0")); 130 EXPECT_EQ("abc/def0", NodeName("^abc/def0:0")); 131 132 EXPECT_EQ("abc/def_0", NodeName("abc/def_0")); 133 EXPECT_EQ("abc/def_0", NodeName("^abc/def_0")); 134 EXPECT_EQ("abc/def_0", NodeName("abc/def_0:3")); 135 EXPECT_EQ("abc/def_0", NodeName("^abc/def_0:3")); 136 137 EXPECT_EQ("abc/def_0", NodeName("^abc/def_0:3214")); 138 } 139 140 TEST_F(UtilsTest, NodePosition) { 141 EXPECT_EQ(2, NodePosition("abc:2")); 142 EXPECT_EQ(123, NodePosition("abc:123")); 143 EXPECT_EQ(-1, NodePosition("^abc:123")); 144 EXPECT_EQ(-1, NodePosition("^abc")); 145 EXPECT_EQ(0, NodePosition("")); 146 } 147 148 TEST_F(UtilsTest, AddNodeNamePrefix) { 149 EXPECT_EQ("OPTIMIZED/abc", AddPrefixToNodeName("abc", "OPTIMIZED")); 150 EXPECT_EQ("^OPTIMIZED/abc", AddPrefixToNodeName("^abc", "OPTIMIZED")); 151 EXPECT_EQ("OPTIMIZED/", AddPrefixToNodeName("", "OPTIMIZED")); 152 } 153 154 TEST_F(UtilsTest, ExecuteWithTimeout) { 155 std::unique_ptr<thread::ThreadPool> thread_pool( 156 new thread::ThreadPool(Env::Default(), "ExecuteWithTimeout", 2)); 157 158 // This should run till the end. 159 ASSERT_TRUE(ExecuteWithTimeout( 160 []() { // Do nothing. 161 }, 162 1000 /* timeout_in_ms */, thread_pool.get())); 163 164 // This should time out. 165 Notification notification; 166 ASSERT_FALSE(ExecuteWithTimeout( 167 [¬ification]() { notification.WaitForNotification(); }, 168 1 /* timeout_in_ms */, thread_pool.get())); 169 // Make sure to unblock the thread. 170 notification.Notify(); 171 172 // This should run till the end. 173 ASSERT_TRUE(ExecuteWithTimeout([]() { sleep(1); }, 0 /* timeout_in_ms */, 174 thread_pool.get())); 175 176 // Deleting before local variables go off the stack. 177 thread_pool.reset(); 178 } 179 180 TEST_F(UtilsTest, NumOutputs) { 181 GraphDef graph; 182 EXPECT_EQ(2, NumOutputs(CreateConcatOffsetNode(), &graph)); 183 EXPECT_EQ(5, NumOutputs(CreateFusedBatchNormNode(), &graph)); 184 EXPECT_EQ(1, NumOutputs(CreateDequeueNode(), &graph)); 185 } 186 187 TEST_F(UtilsTest, AsControlDependency) { 188 NodeDef node; 189 node.set_name("foo"); 190 EXPECT_EQ("^foo", AsControlDependency(node)); 191 EXPECT_EQ("^foo", AsControlDependency(node.name())); 192 EXPECT_EQ("^foo", AsControlDependency("^foo")); 193 } 194 195 TEST_F(UtilsTest, GetTailOfChain) { 196 tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 197 Output c0 = ops::Const(s.WithOpName("c0"), {1.0f, 2.0f}, {1, 2}); 198 Output c1 = ops::Const(s.WithOpName("c1"), {3.0f, 4.0f}, {1, 2}); 199 // Add a node with only connected by control output. 200 Output neg0 = ops::Neg(s.WithOpName("neg0"), c1); 201 // Add a node with two outputs. 202 Output neg1 = 203 ops::Neg(s.WithControlDependencies(neg0).WithOpName("neg1"), c0); 204 Output neg2 = ops::Neg(s.WithOpName("neg2"), neg1); 205 Output id1 = ops::Identity(s.WithOpName("id1"), neg2); 206 Output id2 = ops::Identity(s.WithOpName("id2"), neg1); 207 auto noop = ops::NoOp(s.WithControlDependencies(neg0).WithOpName("noop")); 208 GraphDef graph; 209 TF_CHECK_OK(s.ToGraphDef(&graph)); 210 LOG(INFO) << graph.DebugString(); 211 212 ASSERT_EQ("c0", graph.node(0).name()); 213 ASSERT_EQ("c1", graph.node(1).name()); 214 ASSERT_EQ("neg0", graph.node(2).name()); 215 ASSERT_EQ("neg1", graph.node(3).name()); 216 ASSERT_EQ("neg2", graph.node(4).name()); 217 ASSERT_EQ("id1", graph.node(5).name()); 218 ASSERT_EQ("id2", graph.node(6).name()); 219 ASSERT_EQ("noop", graph.node(7).name()); 220 221 NodeMap node_map(&graph); 222 auto is_neg = [&](const NodeDef& node) { return node.op() == "Neg"; }; 223 // We walk backwards, starting as "id1", so tail should be "neg1". 224 NodeDef* tail = GetTailOfChain(graph.node(5), node_map, 225 /*follow_control_input=*/false, is_neg); 226 EXPECT_NE(tail, nullptr); 227 EXPECT_EQ("neg1", tail->name()); 228 229 // We stop at branching nodes, so tail should be "neg2". 230 auto is_neg_and_non_branching = [&](const NodeDef& node) { 231 return node.op() == "Neg" && NumNonControlOutputs(node, node_map) == 1; 232 }; 233 tail = 234 GetTailOfChain(graph.node(5), node_map, 235 /*follow_control_input=*/false, is_neg_and_non_branching); 236 EXPECT_NE(tail, nullptr); 237 EXPECT_EQ("neg2", tail->name()); 238 239 // We walk backwards, starting from "noop", also following control inputs, 240 // so tail should be "neg0". 241 tail = GetTailOfChain(graph.node(7), node_map, 242 /*follow_control_input=*/true, is_neg); 243 EXPECT_NE(tail, nullptr); 244 EXPECT_EQ("neg0", tail->name()); 245 246 // We walk backwards, starting from "noop", not following control inputs, 247 // so tail should be "noop" itself. 248 tail = GetTailOfChain(graph.node(7), node_map, 249 /*follow_control_input=*/false, is_neg); 250 EXPECT_NE(tail, nullptr); 251 EXPECT_EQ("noop", tail->name()); 252 } 253 254 TEST_F(UtilsTest, DedupControlInputs) { 255 NodeDef foo; 256 foo.set_name("foo"); 257 foo.add_input("bar"); 258 DedupControlInputs(&foo); 259 EXPECT_EQ(1, foo.input_size()); 260 EXPECT_EQ("bar", foo.input(0)); 261 262 foo.set_input(0, "^bar"); 263 DedupControlInputs(&foo); 264 EXPECT_EQ(1, foo.input_size()); 265 EXPECT_EQ("^bar", foo.input(0)); 266 267 foo.set_input(0, "bar"); 268 foo.add_input("bar"); 269 DedupControlInputs(&foo); 270 EXPECT_EQ(2, foo.input_size()); 271 EXPECT_EQ("bar", foo.input(0)); 272 EXPECT_EQ("bar", foo.input(1)); 273 274 foo.set_input(1, "^bar"); 275 DedupControlInputs(&foo); 276 EXPECT_EQ(1, foo.input_size()); 277 EXPECT_EQ("bar", foo.input(0)); 278 279 foo.set_input(0, "^bar"); 280 foo.add_input("^bar"); 281 DedupControlInputs(&foo); 282 EXPECT_EQ(1, foo.input_size()); 283 EXPECT_EQ("^bar", foo.input(0)); 284 285 foo.set_input(0, "bar"); 286 foo.add_input("gnu"); 287 foo.add_input("^bar"); 288 foo.add_input("^gnu"); 289 DedupControlInputs(&foo); 290 EXPECT_EQ(2, foo.input_size()); 291 EXPECT_EQ("bar", foo.input(0)); 292 EXPECT_EQ("gnu", foo.input(1)); 293 } 294 295 TEST_F(UtilsTest, DeleteNodes) {} 296 297 } // namespace 298 } // namespace grappler 299 } // namespace tensorflow 300