1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/grappler/costs/graph_properties.h" 17 #include "tensorflow/cc/framework/scope.h" 18 #include "tensorflow/cc/ops/standard_ops.h" 19 #include "tensorflow/core/framework/node_def_builder.h" 20 #include "tensorflow/core/framework/tensor_shape.pb.h" 21 #include "tensorflow/core/framework/tensor_testutil.h" 22 #include "tensorflow/core/grappler/clusters/single_machine.h" 23 #include "tensorflow/core/grappler/grappler_item.h" 24 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" 25 #include "tensorflow/core/grappler/inputs/utils.h" 26 #include "tensorflow/core/lib/core/status_test_util.h" 27 #include "tensorflow/core/lib/io/path.h" 28 #include "tensorflow/core/lib/strings/strcat.h" 29 #include "tensorflow/core/platform/protobuf.h" 30 #include "tensorflow/core/platform/test.h" 31 32 namespace tensorflow { 33 namespace grappler { 34 namespace { 35 36 const char kTestDataPath[] = "core/grappler/costs/graph_properties_testdata"; 37 38 class GraphPropertiesTest : public ::testing::Test { 39 public: 40 void SetUp() override { 41 // Provision a single machine with 3 cpu cores 42 cluster_.reset(new SingleMachine(5 * 60, 3, 0)); 43 TF_CHECK_OK(cluster_->Provision()); 44 } 45 46 void TearDown() override { 47 TF_CHECK_OK(cluster_->Shutdown()); 48 cluster_.reset(); 49 } 50 51 protected: 52 // Returns a string form of <p>, suitable for comparing type and shape. 53 // Example output for 4-d float tensor: "float: [10,2,30,4]" 54 string PropToString(const OpInfo::TensorProperties& p) { 55 string s = strings::StrCat(DataTypeString(p.dtype()), ": "); 56 if (p.shape().unknown_rank()) { 57 strings::StrAppend(&s, "?"); 58 } else { 59 strings::StrAppend(&s, "["); 60 for (int i = 0; i < p.shape().dim_size(); ++i) { 61 strings::StrAppend(&s, i == 0 ? "" : ",", 62 std::max<int64>(p.shape().dim(i).size(), -1)); 63 } 64 strings::StrAppend(&s, "]"); 65 } 66 return s; 67 } 68 69 std::unique_ptr<SingleMachine> cluster_; 70 }; 71 72 TEST_F(GraphPropertiesTest, StaticProperties) { 73 TrivialTestGraphInputYielder fake_input(4, 1, 10, false, 74 cluster_->GetDeviceNames()); 75 GrapplerItem item; 76 CHECK(fake_input.NextItem(&item)); 77 78 GraphProperties properties(item); 79 Status s = properties.InferStatically(true); 80 TF_CHECK_OK(s); 81 82 for (const auto& node : item.graph.node()) { 83 if (node.op() == "RandomStandardNormal") { 84 // The node has one input (the shape of the tensor to generate). 85 EXPECT_EQ(1, properties.GetInputProperties(node.name()).size()); 86 // The const node has one output. 87 const auto props = properties.GetOutputProperties(node.name()); 88 EXPECT_EQ(1, props.size()); 89 const OpInfo::TensorProperties& prop = props[0]; 90 EXPECT_EQ(DT_FLOAT, prop.dtype()); 91 EXPECT_FALSE(prop.shape().unknown_rank()); 92 EXPECT_EQ(2, prop.shape().dim_size()); 93 EXPECT_EQ(10, prop.shape().dim(0).size()); 94 EXPECT_EQ(1, prop.shape().dim(1).size()); 95 } else if (node.op() == "AddN") { 96 const auto in_props = properties.GetInputProperties(node.name()); 97 EXPECT_EQ(1, in_props.size()); 98 const OpInfo::TensorProperties& in_prop = in_props[0]; 99 EXPECT_EQ(DT_FLOAT, in_prop.dtype()); 100 EXPECT_FALSE(in_prop.shape().unknown_rank()); 101 EXPECT_EQ(2, in_prop.shape().dim_size()); 102 EXPECT_EQ(10, in_prop.shape().dim(0).size()); 103 EXPECT_EQ(1, in_prop.shape().dim(1).size()); 104 const auto out_props = properties.GetOutputProperties(node.name()); 105 EXPECT_EQ(1, out_props.size()); 106 string in_prop_str; 107 ::tensorflow::protobuf::TextFormat::PrintToString(in_prop, &in_prop_str); 108 string out_prop_str; 109 ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0], 110 &out_prop_str); 111 EXPECT_EQ(in_prop_str, out_prop_str); 112 } 113 } 114 } 115 116 TEST_F(GraphPropertiesTest, DynamicProperties) { 117 TrivialTestGraphInputYielder fake_input(4, 1, 10, false, 118 cluster_->GetDeviceNames()); 119 GrapplerItem item; 120 CHECK(fake_input.NextItem(&item)); 121 122 GraphProperties properties(item); 123 TF_CHECK_OK(cluster_->Initialize(item)); 124 Status s = properties.InferDynamically(cluster_.get()); 125 TF_CHECK_OK(s); 126 127 for (const auto& node : item.graph.node()) { 128 if (node.op() == "RandomStandardNormal") { 129 // The random node is missing from the cost graph (why ?) 130 EXPECT_EQ(0, properties.GetInputProperties(node.name()).size()); 131 } else if (node.op() == "AddN") { 132 // Since the random node is missing, we can't infer the input properties 133 // of the first AddN node. The other AddN nodes have the expected 134 // properties. 135 if (node.name() == "AddN") { 136 const auto props = properties.GetInputProperties(node.name()); 137 EXPECT_EQ(1, props.size()); 138 const OpInfo::TensorProperties& prop = props[0]; 139 EXPECT_EQ(DT_INVALID, prop.dtype()); 140 EXPECT_TRUE(prop.shape().unknown_rank()); 141 } else { 142 const auto props = properties.GetInputProperties(node.name()); 143 EXPECT_EQ(1, props.size()); 144 const OpInfo::TensorProperties& prop = props[0]; 145 EXPECT_EQ(DT_FLOAT, prop.dtype()); 146 EXPECT_FALSE(prop.shape().unknown_rank()); 147 EXPECT_EQ(2, prop.shape().dim_size()); 148 EXPECT_EQ(10, prop.shape().dim(0).size()); 149 EXPECT_EQ(1, prop.shape().dim(1).size()); 150 const auto out_props = properties.GetOutputProperties(node.name()); 151 EXPECT_EQ(1, out_props.size()); 152 string prop_str; 153 ::tensorflow::protobuf::TextFormat::PrintToString(prop, &prop_str); 154 string out_prop_str; 155 ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0], 156 &out_prop_str); 157 EXPECT_EQ(prop_str, out_prop_str); 158 } 159 } 160 } 161 } 162 163 TEST_F(GraphPropertiesTest, Variables) { 164 GrapplerItem item; 165 TF_CHECK_OK(NodeDefBuilder("Var", "Variable") 166 .Attr("dtype", DT_FLOAT) 167 .Attr("shape", TensorShape({3, 7})) 168 .Finalize(item.graph.add_node())); 169 item.fetch.push_back("Var"); 170 171 Tensor initial_val(DT_FLOAT, TensorShape({3, 7})); 172 test::FillIota<float>(&initial_val, 0); 173 TF_CHECK_OK(NodeDefBuilder("InitialVal", "Const") 174 .Attr("dtype", DT_FLOAT) 175 .Attr("value", initial_val) 176 .Finalize(item.graph.add_node())); 177 TF_CHECK_OK(NodeDefBuilder("InitVar", "Assign") 178 .Input("Var", 0, DT_FLOAT_REF) 179 .Input("InitialVal", 0, DT_FLOAT) 180 .Finalize(item.graph.add_node())); 181 item.init_ops.push_back("InitVar"); 182 183 { 184 GraphProperties static_properties(item); 185 TF_CHECK_OK(static_properties.InferStatically(false)); 186 187 const auto props = static_properties.GetOutputProperties("Var"); 188 EXPECT_EQ(1, props.size()); 189 const OpInfo::TensorProperties& prop = props[0]; 190 EXPECT_EQ(DT_FLOAT_REF, prop.dtype()); 191 EXPECT_FALSE(prop.shape().unknown_rank()); 192 EXPECT_EQ(2, prop.shape().dim_size()); 193 EXPECT_EQ(3, prop.shape().dim(0).size()); 194 EXPECT_EQ(7, prop.shape().dim(1).size()); 195 } 196 { 197 TF_CHECK_OK(cluster_->Initialize(item)); 198 GraphProperties dynamic_properties(item); 199 TF_CHECK_OK(dynamic_properties.InferDynamically(cluster_.get())); 200 201 const auto props = dynamic_properties.GetOutputProperties("Var"); 202 EXPECT_EQ(1, props.size()); 203 const OpInfo::TensorProperties& prop = props[0]; 204 EXPECT_EQ(DT_FLOAT_REF, prop.dtype()); 205 EXPECT_FALSE(prop.shape().unknown_rank()); 206 EXPECT_EQ(2, prop.shape().dim_size()); 207 EXPECT_EQ(3, prop.shape().dim(0).size()); 208 EXPECT_EQ(7, prop.shape().dim(1).size()); 209 } 210 } 211 212 TEST_F(GraphPropertiesTest, VarHandles) { 213 GrapplerItem item; 214 TF_CHECK_OK(NodeDefBuilder("Var", "VarHandleOp") 215 .Attr("dtype", DT_FLOAT) 216 .Attr("shape", TensorShape({3, 7})) 217 .Finalize(item.graph.add_node())); 218 219 TF_CHECK_OK(NodeDefBuilder("VarRead", "ReadVariableOp") 220 .Attr("dtype", DT_FLOAT) 221 .Input("Var", 0, DT_RESOURCE) 222 .Finalize(item.graph.add_node())); 223 224 GraphProperties properties(item); 225 TF_CHECK_OK(properties.InferStatically(false)); 226 227 const auto props = properties.GetOutputProperties("VarRead"); 228 EXPECT_EQ(1, props.size()); 229 const OpInfo::TensorProperties& prop = props[0]; 230 EXPECT_EQ(DT_FLOAT, prop.dtype()); 231 EXPECT_FALSE(prop.shape().unknown_rank()); 232 EXPECT_EQ(2, prop.shape().dim_size()); 233 EXPECT_EQ(3, prop.shape().dim(0).size()); 234 EXPECT_EQ(7, prop.shape().dim(1).size()); 235 } 236 237 TEST_F(GraphPropertiesTest, Queues) { 238 // Create a graph with known input shapes, and propagate the shapes through a 239 // couple of queues. 240 tensorflow::Scope root = tensorflow::Scope::NewRootScope(); 241 242 auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT}); 243 Output rnd = 244 ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT); 245 Output square1 = ops::Square(root.WithOpName("Square1"), rnd); 246 auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), q1, {square1}); 247 auto dequeue1 = 248 ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT}); 249 250 auto q2 = 251 ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT}); 252 Output square2 = ops::Square(root.WithOpName("Square2"), dequeue1[0]); 253 auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), q2, {square2}); 254 auto dequeue2 = 255 ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT}); 256 257 // Create a queue that feeds itself. 258 auto q3 = 259 ops::RandomShuffleQueue(root.WithOpName("Queue3"), {DataType::DT_FLOAT}); 260 auto dequeue3 = 261 ops::QueueDequeue(root.WithOpName("Dequeue3"), q3, {DataType::DT_FLOAT}); 262 auto merge3 = ops::Merge(root.WithOpName("Merge3"), {dequeue3[0], square2}); 263 auto enqueue3 = 264 ops::QueueEnqueue(root.WithOpName("Enqueue3"), q3, {merge3.output}); 265 266 auto q4 = 267 ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT}); 268 auto enqueue4 = ops::QueueEnqueue(root.WithOpName("Enqueue4"), q4, {square2}); 269 auto enqueue4_2 = 270 ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue3[0]}); 271 auto dequeue4 = 272 ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT}); 273 274 // Create a queue that takes in three tensors. 275 auto q5 = ops::RandomShuffleQueue( 276 root.WithOpName("Queue5"), 277 {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT}); 278 Output rnd2 = 279 ops::RandomNormal(root.WithOpName("rnd"), {10}, DataType::DT_DOUBLE); 280 Output rnd3 = 281 ops::RandomNormal(root.WithOpName("rnd"), {1, 2, 3}, DataType::DT_FLOAT); 282 auto enqueue5 = 283 ops::QueueEnqueue(root.WithOpName("Enqueue5"), q5, {rnd, rnd2, rnd3}); 284 auto dequeue5 = ops::QueueDequeue( 285 root.WithOpName("Dequeue5"), q5, 286 {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT}); 287 288 GrapplerItem item; 289 TF_CHECK_OK(root.ToGraphDef(&item.graph)); 290 291 GraphProperties properties(item); 292 TF_CHECK_OK(properties.InferStatically(false)); 293 294 const auto props1 = properties.GetOutputProperties("Dequeue1"); 295 ASSERT_EQ(1, props1.size()); 296 EXPECT_EQ("float: [3,7]", PropToString(props1[0])); 297 298 const auto props2 = properties.GetOutputProperties("Dequeue2"); 299 ASSERT_EQ(1, props2.size()); 300 EXPECT_EQ("float: [3,7]", PropToString(props2[0])); 301 302 const auto props3 = properties.GetOutputProperties("Dequeue3"); 303 ASSERT_EQ(1, props3.size()); 304 EXPECT_EQ("float: [3,7]", PropToString(props3[0])); 305 306 // The dequeue3 op shape is unknown. The square2 op shape is known. Verify 307 // that we merge the 2 properly to determine the shape of the data coming out 308 // of the queue. 309 const auto props4 = properties.GetOutputProperties("Dequeue4"); 310 ASSERT_EQ(1, props4.size()); 311 EXPECT_EQ("float: [3,7]", PropToString(props4[0])); 312 313 // The dequeue5 op shape is known. 314 const auto props5 = properties.GetOutputProperties("Dequeue5"); 315 ASSERT_EQ(3, props5.size()); 316 EXPECT_EQ("float: [3,7]", PropToString(props5[0])); 317 EXPECT_EQ("double: [10]", PropToString(props5[1])); 318 EXPECT_EQ("float: [1,2,3]", PropToString(props5[2])); 319 } 320 321 TEST_F(GraphPropertiesTest, MergeWithoutLoops) { 322 // Test graph produced in python using: 323 /* 324 with tf.Graph().as_default(): 325 x = tf.constant(2) 326 y = tf.constant(5) 327 z = tf.ones([1,1,1]) 328 def f1(): return tf.concat([z, z], axis=0) 329 def f2(): return tf.concat([z, z], axis=1) 330 r = tf.cond(tf.less(x, y), f1, f2) 331 tf.concat([r, r], axis=2) 332 with open('/tmp/graph.pbtxt', 'w') as f: 333 f.write(str(tf.get_default_graph().as_graph_def())) 334 */ 335 336 GrapplerItem item; 337 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath, 338 "merge_without_loops.pbtxt"); 339 TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); 340 GraphProperties properties(item); 341 TF_CHECK_OK(properties.InferStatically(false)); 342 343 std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"}; 344 std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]", 345 "float: [1,2,1]"}; 346 for (int i = 0; i < nodes.size(); i++) { 347 const auto props = properties.GetOutputProperties(nodes[i]); 348 const OpInfo::TensorProperties& prop = props[0]; 349 EXPECT_EQ(DT_FLOAT, prop.dtype()); 350 EXPECT_EQ(expected_outputs[i], PropToString(prop)); 351 } 352 353 // The "Less" node should be fed by 2 int32 scalar constant values. 354 const auto props = properties.GetInputProperties("Less"); 355 EXPECT_EQ(2, props.size()); 356 for (int i = 0; i < props.size(); ++i) { 357 EXPECT_EQ(DT_INT32, props[i].dtype()); 358 EXPECT_TRUE(props[i].has_value()); 359 EXPECT_EQ("int32: []", PropToString(props[i])); 360 } 361 } 362 363 TEST_F(GraphPropertiesTest, WhileLoop) { 364 // Test graph produced in python using: 365 /* 366 with tf.Graph().as_default(): 367 i0 = tf.constant(0) 368 m0 = tf.placeholder([-1, 2]) 369 c = lambda i, m: i < 10 370 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] 371 r = tf.while_loop( 372 c, b, loop_vars=[i0, m0], 373 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) 374 with open('/tmp/graph.pbtxt', 'w') as f: 375 f.write(str(tf.get_default_graph().as_graph_def())) 376 */ 377 378 GrapplerItem item; 379 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath, 380 "while_loop.pbtxt"); 381 TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); 382 GraphProperties properties(item); 383 TF_CHECK_OK(properties.InferStatically(false)); 384 385 std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1", 386 "while/Exit_1"}; 387 for (const string& node : nodes) { 388 const auto props = properties.GetOutputProperties(node); 389 const OpInfo::TensorProperties& prop = props[0]; 390 EXPECT_EQ(DT_FLOAT, prop.dtype()); 391 EXPECT_EQ("float: [-1,2]", PropToString(prop)); 392 } 393 394 // The loop outputs batch dim should be different from the input batch dim 395 // since we concatenated along the batch dim. 396 auto shape_in = properties.GetOutputProperties("ones").at(0).shape(); 397 auto shape_out = properties.GetOutputProperties("while/Exit_1").at(0).shape(); 398 EXPECT_GE(-2, shape_in.dim(0).size()); 399 EXPECT_GE(-2, shape_out.dim(0).size()); 400 EXPECT_NE(shape_in.dim(0).size(), shape_out.dim(0).size()); 401 } 402 403 TEST_F(GraphPropertiesTest, NestedLoop) { 404 // Test graph produced in python using: 405 /* 406 with tf.Graph().as_default(): 407 i0 = tf.constant(0) 408 409 def inner(j, y): 410 def inner_cond(j, y): 411 return j < 3 412 413 def inner_body(j, y): 414 return j+1, tf.concat([y, y], axis=2) 415 416 return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y], 417 shape_invariants=[i0.get_shape(), 418 tf.TensorShape([None, 1, None])]) 419 420 def outer_cond(i, x): 421 return i < 3 422 423 def outer_body(i, x): 424 j, y = inner(0, x) 425 return i+1, tf.concat([x, x], axis=0) 426 427 r = tf.while_loop(outer_cond, outer_body, 428 loop_vars=[i0, tf.ones([1, 1, 1])], 429 shape_invariants=[i0.get_shape(), 430 tf.TensorShape([None, 1, None])]) 431 432 with open('/tmp/graph.pbtxt', 'w') as f: 433 f.write(str(tf.get_default_graph().as_graph_def())) 434 */ 435 436 GrapplerItem item; 437 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath, 438 "nested_loop.pbtxt"); 439 TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); 440 GraphProperties properties(item); 441 TF_CHECK_OK(properties.InferStatically(false)); 442 443 std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1", 444 "while/Exit_1"}; 445 std::vector<string> inner_nodes{"while/while/Merge_1", 446 "while/while/NextIteration_1", 447 "while/while/Exit_1"}; 448 for (const string& node : outer_nodes) { 449 const auto props = properties.GetOutputProperties(node); 450 const OpInfo::TensorProperties& prop = props[0]; 451 EXPECT_EQ(DT_FLOAT, prop.dtype()); 452 EXPECT_EQ("float: [-1,1,1]", PropToString(prop)); 453 } 454 for (const string& node : inner_nodes) { 455 const auto props = properties.GetOutputProperties(node); 456 const OpInfo::TensorProperties& prop = props[0]; 457 EXPECT_EQ(DT_FLOAT, prop.dtype()); 458 EXPECT_EQ("float: [-1,1,-1]", PropToString(prop)); 459 } 460 } 461 462 TEST_F(GraphPropertiesTest, LoopsAndQueues) { 463 // Test graph produced in python using: 464 /* 465 with tf.Graph().as_default(): 466 i0 = tf.constant(0) 467 q = tf.FIFOQueue(1, "float") 468 469 def inner(j, y): 470 def inner_cond(j, y): 471 return j < 3 472 473 def inner_body(j, y): 474 return j+1, tf.concat([y, y], axis=0) 475 476 return tf.while_loop(inner_cond, inner_body, 477 loop_vars=[j, y], 478 shape_invariants=[i0.get_shape(), 479 tf.TensorShape(None)]) 480 481 def outer_cond(i, x): 482 return i < 3 483 484 def outer_body(i, x): 485 q.enqueue(x) 486 y = tf.concat([x, x], axis=2) 487 inner(0, q.dequeue()) 488 return i+1, y 489 490 i, z = tf.while_loop(outer_cond, outer_body, 491 loop_vars=[i0, tf.ones([1, 1, 1])], 492 shape_invariants=[i0.get_shape(), 493 tf.TensorShape([None, 1, None])]) 494 495 with open('/tmp/graph.pbtxt', 'w') as f: 496 f.write(str(tf.get_default_graph().as_graph_def())) 497 */ 498 499 GrapplerItem item; 500 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath, 501 "loops_and_queues.pbtxt"); 502 TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); 503 GraphProperties properties(item); 504 TF_CHECK_OK(properties.InferStatically(false)); 505 506 std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1", 507 "while/Exit_1"}; 508 std::vector<string> inner_nodes{"while/while/Merge_1", 509 "while/while/NextIteration_1", 510 "while/while/Exit_1"}; 511 for (const string& node : outer_nodes) { 512 const auto props = properties.GetOutputProperties(node); 513 const OpInfo::TensorProperties& prop = props[0]; 514 EXPECT_EQ(DT_FLOAT, prop.dtype()); 515 EXPECT_EQ("float: [1,1,-1]", PropToString(prop)); 516 } 517 for (const string& node : inner_nodes) { 518 const auto props = properties.GetOutputProperties(node); 519 const OpInfo::TensorProperties& prop = props[0]; 520 EXPECT_EQ(DT_FLOAT, prop.dtype()); 521 EXPECT_EQ("float: [-1,1,-1]", PropToString(prop)); 522 } 523 } 524 525 TEST_F(GraphPropertiesTest, LoopsAndResourceVars) { 526 // Test graph produced in python using: 527 /* 528 with tf.Graph().as_default(): 529 i0 = tf.constant(0) 530 with tf.variable_scope(VariableScope(reuse=None, use_resource=True)): 531 v = tf.get_variable(initializer=i0, name='loop_var') 532 533 def inner(j, y): 534 def inner_cond(j, y): 535 return j < 3 536 537 def inner_body(j, y): 538 return j + 1, y + y 539 540 return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y]) 541 542 def outer_cond(i, x): 543 return i < 3 544 545 def outer_body(i, x): 546 y = x + x 547 inner(0, v) 548 return i + 1, y 549 550 v, z = tf.while_loop(outer_cond, outer_body, 551 loop_vars=[v, tf.constant(1)]) 552 553 with open('/tmp/graph.pbtxt', 'w') as f: 554 f.write(str(tf.get_default_graph().as_graph_def())) 555 */ 556 557 GrapplerItem item; 558 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath, 559 "loops_and_resource_vars.pbtxt"); 560 TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); 561 GraphProperties properties(item); 562 TF_CHECK_OK(properties.InferStatically(false)); 563 564 std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1", 565 "while/Exit_1"}; 566 std::vector<string> inner_nodes{"while/while/Merge_1", 567 "while/while/NextIteration_1", 568 "while/while/Exit_1"}; 569 for (const string& node : outer_nodes) { 570 const auto props = properties.GetOutputProperties(node); 571 const OpInfo::TensorProperties& prop = props[0]; 572 EXPECT_EQ(DT_INT32, prop.dtype()); 573 EXPECT_EQ("int32: []", PropToString(prop)); 574 } 575 for (const string& node : inner_nodes) { 576 const auto props = properties.GetOutputProperties(node); 577 const OpInfo::TensorProperties& prop = props[0]; 578 EXPECT_EQ(DT_INT32, prop.dtype()); 579 EXPECT_EQ("int32: []", PropToString(prop)); 580 } 581 } 582 583 TEST_F(GraphPropertiesTest, QueuesAndLoops) { 584 // Test graph produced in python using: 585 /* 586 with tf.Graph().as_default(): 587 i0 = tf.constant(0) 588 q0 = tf.FIFOQueue(1, "float") 589 q0.enqueue(tf.ones([2, 2])) 590 q1 = tf.FIFOQueue(1, "float") 591 592 def c(i, m): 593 return i < 10 594 595 def b(i, m): 596 return i+1, tf.concat([m, m], axis=0) 597 598 i, m = tf.while_loop( 599 c, b, loop_vars=[i0, q0.dequeue()], 600 shape_invariants=[i0.get_shape(), tf.TensorShape(None)]) 601 602 q1.enqueue(m) 603 v = q1.dequeue(); 604 tf.concat([v, v], axis=1) 605 with open('/tmp/graph.pbtxt', 'w') as f: 606 f.write(str(tf.get_default_graph().as_graph_def())) 607 */ 608 609 GrapplerItem item; 610 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath, 611 "queues_and_loops.pbtxt"); 612 TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); 613 GraphProperties properties(item); 614 TF_CHECK_OK(properties.InferStatically(false)); 615 616 std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1", 617 "while/Exit_1"}; 618 619 for (const string& node : nodes) { 620 const auto props = properties.GetOutputProperties(node); 621 const OpInfo::TensorProperties& prop = props[0]; 622 EXPECT_EQ(DT_FLOAT, prop.dtype()); 623 EXPECT_EQ("float: [-1,2]", PropToString(prop)); 624 } 625 626 const auto props = properties.GetOutputProperties("concat"); 627 const OpInfo::TensorProperties& prop = props[0]; 628 EXPECT_EQ(DT_FLOAT, prop.dtype()); 629 EXPECT_EQ("float: [-1,4]", PropToString(prop)); 630 } 631 632 TEST_F(GraphPropertiesTest, InferRestoreOpShape) { 633 tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 634 Output var = ops::Variable(s.WithOpName("var"), TensorShape({128, 256}), 635 DataType::DT_FLOAT); 636 Output filename = 637 ops::Const(s.WithOpName("filename"), string("model"), TensorShape()); 638 Output tensor_name = 639 ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape()); 640 Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name, 641 DataType::DT_FLOAT); 642 Output init_restore = ops::Assign(s.WithOpName("init_restore"), var, restore); 643 644 Output shape_and_slice = ops::Const(s.WithOpName("shape_and_slice"), 645 string("256 256 0,128:-"), TensorShape()); 646 Output restore_slice = 647 ops::RestoreSlice(s.WithOpName("restore_slice"), filename, tensor_name, 648 shape_and_slice, DataType::DT_FLOAT); 649 Output init_restore_slice = 650 ops::Assign(s.WithOpName("init_restore_slice"), var, restore_slice); 651 652 Output restore_v2 = 653 ops::RestoreSlice(s.WithOpName("restore_v2"), filename, tensor_name, 654 shape_and_slice, DataType::DT_FLOAT); 655 Output init_restore_v2 = 656 ops::Assign(s.WithOpName("init_restore_v2"), var, restore_v2); 657 658 GrapplerItem item; 659 TF_CHECK_OK(s.ToGraphDef(&item.graph)); 660 item.fetch.push_back("init_restore"); 661 662 GraphProperties properties(item); 663 TF_CHECK_OK(properties.InferStatically(false)); 664 665 const auto restore_props = properties.GetOutputProperties("restore"); 666 const OpInfo::TensorProperties& restore_prop = restore_props[0]; 667 EXPECT_EQ(DT_FLOAT, restore_prop.dtype()); 668 EXPECT_EQ("float: [128,256]", PropToString(restore_prop)); 669 670 const auto restore_slice_props = 671 properties.GetOutputProperties("restore_slice"); 672 const OpInfo::TensorProperties& restore_slice_prop = restore_slice_props[0]; 673 EXPECT_EQ(DT_FLOAT, restore_slice_prop.dtype()); 674 EXPECT_EQ("float: [128,256]", PropToString(restore_slice_prop)); 675 676 const auto restorev2_props = properties.GetOutputProperties("restore_v2"); 677 const OpInfo::TensorProperties& restorev2_prop = restorev2_props[0]; 678 EXPECT_EQ(DT_FLOAT, restorev2_prop.dtype()); 679 EXPECT_EQ("float: [128,256]", PropToString(restorev2_prop)); 680 681 // Check input shapes of assign op are propagted correctly. 682 const auto input_props = properties.GetInputProperties("init_restore"); 683 ASSERT_EQ(2, input_props.size()); 684 const OpInfo::TensorProperties& input_prop = input_props[1]; 685 EXPECT_EQ(DT_FLOAT, input_prop.dtype()); 686 EXPECT_EQ("float: [128,256]", PropToString(input_prop)); 687 } 688 689 TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) { 690 tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 691 Output var = ops::Variable(s.WithOpName("var"), PartialTensorShape(), 692 DataType::DT_FLOAT); 693 Output var2 = ops::Variable(s.WithOpName("var2"), TensorShape({128, 256}), 694 DataType::DT_FLOAT); 695 Output filename = 696 ops::Const(s.WithOpName("filename"), string("model"), TensorShape()); 697 Output tensor_name = 698 ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape()); 699 Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name, 700 DataType::DT_FLOAT); 701 Output init = ops::Assign(s.WithOpName("init"), var, restore); 702 Output init2 = ops::Assign(s.WithOpName("init2"), var2, restore); 703 704 GrapplerItem item; 705 TF_CHECK_OK(s.ToGraphDef(&item.graph)); 706 item.fetch.push_back("init"); 707 item.fetch.push_back("init2"); 708 709 GraphProperties properties(item); 710 TF_CHECK_OK(properties.InferStatically(false)); 711 712 const auto props = properties.GetOutputProperties("restore"); 713 const OpInfo::TensorProperties& prop = props[0]; 714 EXPECT_EQ(DT_FLOAT, prop.dtype()); 715 EXPECT_EQ("float: [128,256]", PropToString(prop)); 716 } 717 718 TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) { 719 // Test graph produced in python using: 720 /* 721 @function.Defun(*[tf.float32] * 2, noinline=True) 722 def MyAdd(x, y): 723 return tf.add(x,y) 724 725 with tf.Graph().as_default(): 726 x = tf.constant(2.0, shape=[1, 2], dtype=tf.float32) 727 y = tf.constant(2.0, shape=[1, 2], dtype=tf.float32) 728 z = MyAdd(x, y) 729 z = MyAdd(x, z) 730 */ 731 // Check that the shape of the second MyAdd node propagates 732 // correctly. 733 GrapplerItem item; 734 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath, 735 "simple_function.pbtxt"); 736 TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); 737 GraphProperties properties(item); 738 TF_CHECK_OK(properties.InferStatically(false)); 739 const auto props = properties.GetOutputProperties("MyAdd_55e046a8_1"); 740 const OpInfo::TensorProperties& prop = props[0]; 741 EXPECT_EQ(DT_FLOAT, prop.dtype()); 742 EXPECT_FALSE(prop.shape().unknown_rank()); 743 EXPECT_EQ(2, prop.shape().dim_size()); 744 EXPECT_EQ(1, prop.shape().dim(0).size()); 745 EXPECT_EQ(2, prop.shape().dim(1).size()); 746 747 PartialTensorShape shape(prop.shape()); 748 EXPECT_TRUE(shape.IsFullyDefined()); 749 EXPECT_FALSE(shape.unknown_rank()); 750 } 751 752 TEST_F(GraphPropertiesTest, SymbolicShapes) { 753 // Build a simple graph with placeholders of unknown dimensions. These 754 // dimensions will be encoded symbolically. 755 tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 756 757 Output a = 758 ops::Placeholder(s.WithOpName("a"), DT_FLOAT, 759 ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); 760 Output b = 761 ops::Placeholder(s.WithOpName("b"), DT_FLOAT, 762 ops::Placeholder::Shape(PartialTensorShape({-1}))); 763 Output c = ops::Identity(s.WithOpName("c"), a); 764 Output d = ops::Identity(s.WithOpName("d"), b); 765 Output e = ops::Add(s.WithOpName("e"), c, d); 766 Output f = ops::Add(s.WithOpName("f"), a, c); 767 768 Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {}); 769 Output g = ops::Shape(s.WithOpName("g"), c); 770 Output h = ops::Fill(s.WithOpName("h"), g, zero); 771 772 GrapplerItem item; 773 TF_CHECK_OK(s.ToGraphDef(&item.graph)); 774 775 GraphProperties properties(item); 776 TF_CHECK_OK(properties.InferStatically(false)); 777 const auto shape_a = properties.GetOutputProperties("a").at(0).shape(); 778 const auto shape_c = properties.GetOutputProperties("c").at(0).shape(); 779 EXPECT_EQ(2, shape_a.dim_size()); 780 EXPECT_EQ(shape_a.dim_size(), shape_c.dim_size()); 781 EXPECT_GE(-2, shape_a.dim(0).size()); 782 EXPECT_EQ(shape_a.dim(0).size(), shape_c.dim(0).size()); 783 EXPECT_GE(-2, shape_a.dim(1).size()); 784 EXPECT_EQ(shape_a.dim(1).size(), shape_c.dim(1).size()); 785 786 PartialTensorShape shape(shape_a); 787 EXPECT_FALSE(shape.IsFullyDefined()); 788 EXPECT_FALSE(shape.unknown_rank()); 789 790 const auto shape_b = properties.GetOutputProperties("b").at(0).shape(); 791 const auto shape_d = properties.GetOutputProperties("d").at(0).shape(); 792 EXPECT_EQ(1, shape_b.dim_size()); 793 EXPECT_EQ(shape_b.dim_size(), shape_d.dim_size()); 794 EXPECT_GE(-2, shape_b.dim(0).size()); 795 EXPECT_NE(shape_a.dim(0).size(), shape_b.dim(0).size()); 796 EXPECT_EQ(shape_b.dim(0).size(), shape_d.dim(0).size()); 797 798 const auto shape_e = properties.GetOutputProperties("e").at(0).shape(); 799 ASSERT_EQ(2, shape_e.dim_size()); 800 EXPECT_EQ(shape_e.dim(0).size(), shape_c.dim(0).size()); 801 EXPECT_NE(shape_e.dim(1).size(), shape_c.dim(1).size()); 802 EXPECT_NE(shape_e.dim(0).size(), shape_d.dim(0).size()); 803 804 const auto shape_f = properties.GetOutputProperties("f").at(0).shape(); 805 ASSERT_EQ(2, shape_f.dim_size()); 806 EXPECT_EQ(shape_f.dim(0).size(), shape_a.dim(0).size()); 807 EXPECT_EQ(shape_f.dim(1).size(), shape_a.dim(1).size()); 808 809 const auto shape_h = properties.GetOutputProperties("h").at(0).shape(); 810 ASSERT_EQ(2, shape_f.dim_size()); 811 EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size()); 812 EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size()); 813 } 814 815 TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) { 816 tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 817 Output a = ops::Const(s.WithOpName("a"), 1.0f, {1}); 818 Output b = ops::Const(s.WithOpName("b"), 2.0f, {1}); 819 Output c = ops::Const(s.WithOpName("c").ColocateWith(a), 3.0f, {1}); 820 GrapplerItem item; 821 TF_CHECK_OK(s.ToGraphDef(&item.graph)); 822 // Create a graph with node a removed (say by some graph optimization 823 // pass), noting that node c is colocated with a. This is fine as it 824 // is in the late stage of graph execution, the colocation constraints have 825 // been validated previously and the device placement of nodes has completed. 826 GraphDef optimized_graph; 827 for (const auto& node : item.graph.node()) { 828 if (node.name() != "a") { 829 *optimized_graph.add_node() = node; 830 } 831 } 832 item.graph.Swap(&optimized_graph); 833 GraphProperties properties(item); 834 // This function should return OK, since it doesn't validate the colocation 835 // constraints internally. 836 TF_EXPECT_OK(properties.InferStatically(false)); 837 } 838 839 TEST_F(GraphPropertiesTest, ShapeTracking) { 840 tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 841 Output a = 842 ops::Placeholder(s.WithOpName("a"), DT_FLOAT, 843 ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); 844 Output b = 845 ops::Placeholder(s.WithOpName("b"), DT_FLOAT, 846 ops::Placeholder::Shape(PartialTensorShape({-1}))); 847 Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {}); 848 auto shp = ops::ShapeN(s.WithOpName("shapes"), {a, b}); 849 Output o1 = ops::Fill(s.WithOpName("o1"), shp[0], zero); 850 Output o2 = ops::Fill(s.WithOpName("o2"), shp[1], zero); 851 852 GrapplerItem item; 853 TF_CHECK_OK(s.ToGraphDef(&item.graph)); 854 855 GraphProperties properties(item); 856 TF_CHECK_OK(properties.InferStatically(false)); 857 const auto shape_a = properties.GetOutputProperties("a").at(0).shape(); 858 const auto shape_b = properties.GetOutputProperties("b").at(0).shape(); 859 const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape(); 860 const auto shape_o2 = properties.GetOutputProperties("o2").at(0).shape(); 861 EXPECT_EQ(shape_a.DebugString(), shape_o1.DebugString()); 862 EXPECT_EQ(shape_b.DebugString(), shape_o2.DebugString()); 863 } 864 865 TEST_F(GraphPropertiesTest, FedNodes) { 866 TrivialTestGraphInputYielder fake_input(4, 1, 10, false, 867 cluster_->GetDeviceNames()); 868 GrapplerItem item; 869 CHECK(fake_input.NextItem(&item)); 870 871 { 872 // Conservative shape analysis: the shape of fed ports should be unknown 873 GraphProperties properties(item); 874 Status s = properties.InferStatically(false); 875 TF_CHECK_OK(s); 876 for (const auto& node : item.graph.node()) { 877 if (node.op() == "Const") { 878 continue; 879 } 880 const auto in_props = properties.GetInputProperties(node.name()); 881 EXPECT_EQ(1, in_props.size()); 882 const OpInfo::TensorProperties& in_prop = in_props[0]; 883 const auto out_props = properties.GetOutputProperties(node.name()); 884 EXPECT_EQ(1, out_props.size()); 885 const OpInfo::TensorProperties& out_prop = out_props[0]; 886 887 if (node.name() == "x") { 888 // x is fed: its input should have a known shape, while its output 889 // doesn't 890 EXPECT_FALSE(in_prop.shape().unknown_rank()); 891 EXPECT_EQ(1, in_prop.shape().dim_size()); 892 EXPECT_EQ(2, in_prop.shape().dim(0).size()); 893 EXPECT_TRUE(out_prop.shape().unknown_rank()); 894 } else if (node.op() == "Square" || node.op() == "AddN") { 895 // These nodes are in the fanout of x: their shapes should be unknown. 896 EXPECT_TRUE(in_prop.shape().unknown_rank()); 897 EXPECT_TRUE(out_prop.shape().unknown_rank()); 898 } 899 } 900 } 901 { 902 // Optimistic shape analysis: the shape of fed ports should be derived from 903 // the shape of the fanin. 904 GraphProperties properties(item); 905 Status s = properties.InferStatically(true); 906 TF_CHECK_OK(s); 907 for (const auto& node : item.graph.node()) { 908 if (node.op() == "Square" || node.op() == "AddN") { 909 const auto in_props = properties.GetInputProperties(node.name()); 910 EXPECT_EQ(1, in_props.size()); 911 const OpInfo::TensorProperties& in_prop = in_props[0]; 912 EXPECT_EQ(DT_FLOAT, in_prop.dtype()); 913 EXPECT_FALSE(in_prop.shape().unknown_rank()); 914 EXPECT_EQ(2, in_prop.shape().dim_size()); 915 const auto out_props = properties.GetOutputProperties(node.name()); 916 EXPECT_EQ(1, out_props.size()); 917 const OpInfo::TensorProperties& out_prop = out_props[0]; 918 EXPECT_EQ(in_prop.DebugString(), out_prop.DebugString()); 919 } 920 } 921 } 922 } 923 924 TEST_F(GraphPropertiesTest, Performance) { 925 // Load a large graph with many nested loops to make sure we can infer shapes 926 // quickly. 927 GrapplerItem item; 928 string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath, 929 "large_graph.pbtxt.html"); 930 TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); 931 GraphProperties properties(item); 932 TF_CHECK_OK(properties.InferStatically(false)); 933 } 934 935 } // namespace 936 } // namespace grappler 937 } // namespace tensorflow 938