Home | History | Annotate | Download | only in costs
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/costs/graph_properties.h"
     17 #include "tensorflow/cc/framework/scope.h"
     18 #include "tensorflow/cc/ops/standard_ops.h"
     19 #include "tensorflow/core/framework/node_def_builder.h"
     20 #include "tensorflow/core/framework/tensor_shape.pb.h"
     21 #include "tensorflow/core/framework/tensor_testutil.h"
     22 #include "tensorflow/core/grappler/clusters/single_machine.h"
     23 #include "tensorflow/core/grappler/grappler_item.h"
     24 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
     25 #include "tensorflow/core/grappler/inputs/utils.h"
     26 #include "tensorflow/core/lib/core/status_test_util.h"
     27 #include "tensorflow/core/lib/io/path.h"
     28 #include "tensorflow/core/lib/strings/strcat.h"
     29 #include "tensorflow/core/platform/protobuf.h"
     30 #include "tensorflow/core/platform/test.h"
     31 
     32 namespace tensorflow {
     33 namespace grappler {
     34 namespace {
     35 
// Directory (relative to the TensorFlow source root) holding the serialized
// .pbtxt graphs that the file-based tests in this file load via io::JoinPath.
const char kTestDataPath[] = "core/grappler/costs/graph_properties_testdata";
     37 
     38 class GraphPropertiesTest : public ::testing::Test {
     39  public:
     40   void SetUp() override {
     41     // Provision a single machine with 3 cpu cores
     42     cluster_.reset(new SingleMachine(5 * 60, 3, 0));
     43     TF_CHECK_OK(cluster_->Provision());
     44   }
     45 
     46   void TearDown() override {
     47     TF_CHECK_OK(cluster_->Shutdown());
     48     cluster_.reset();
     49   }
     50 
     51  protected:
     52   // Returns a string form of <p>, suitable for comparing type and shape.
     53   // Example output for 4-d float tensor: "float: [10,2,30,4]"
     54   string PropToString(const OpInfo::TensorProperties& p) {
     55     string s = strings::StrCat(DataTypeString(p.dtype()), ": ");
     56     if (p.shape().unknown_rank()) {
     57       strings::StrAppend(&s, "?");
     58     } else {
     59       strings::StrAppend(&s, "[");
     60       for (int i = 0; i < p.shape().dim_size(); ++i) {
     61         strings::StrAppend(&s, i == 0 ? "" : ",",
     62                            std::max<int64>(p.shape().dim(i).size(), -1));
     63       }
     64       strings::StrAppend(&s, "]");
     65     }
     66     return s;
     67   }
     68 
     69   std::unique_ptr<SingleMachine> cluster_;
     70 };
     71 
     72 TEST_F(GraphPropertiesTest, StaticProperties) {
     73   TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
     74                                           cluster_->GetDeviceNames());
     75   GrapplerItem item;
     76   CHECK(fake_input.NextItem(&item));
     77 
     78   GraphProperties properties(item);
     79   Status s = properties.InferStatically(true);
     80   TF_CHECK_OK(s);
     81 
     82   for (const auto& node : item.graph.node()) {
     83     if (node.op() == "RandomStandardNormal") {
     84       // The node has one input (the shape of the tensor to generate).
     85       EXPECT_EQ(1, properties.GetInputProperties(node.name()).size());
     86       // The const node has one output.
     87       const auto props = properties.GetOutputProperties(node.name());
     88       EXPECT_EQ(1, props.size());
     89       const OpInfo::TensorProperties& prop = props[0];
     90       EXPECT_EQ(DT_FLOAT, prop.dtype());
     91       EXPECT_FALSE(prop.shape().unknown_rank());
     92       EXPECT_EQ(2, prop.shape().dim_size());
     93       EXPECT_EQ(10, prop.shape().dim(0).size());
     94       EXPECT_EQ(1, prop.shape().dim(1).size());
     95     } else if (node.op() == "AddN") {
     96       const auto in_props = properties.GetInputProperties(node.name());
     97       EXPECT_EQ(1, in_props.size());
     98       const OpInfo::TensorProperties& in_prop = in_props[0];
     99       EXPECT_EQ(DT_FLOAT, in_prop.dtype());
    100       EXPECT_FALSE(in_prop.shape().unknown_rank());
    101       EXPECT_EQ(2, in_prop.shape().dim_size());
    102       EXPECT_EQ(10, in_prop.shape().dim(0).size());
    103       EXPECT_EQ(1, in_prop.shape().dim(1).size());
    104       const auto out_props = properties.GetOutputProperties(node.name());
    105       EXPECT_EQ(1, out_props.size());
    106       string in_prop_str;
    107       ::tensorflow::protobuf::TextFormat::PrintToString(in_prop, &in_prop_str);
    108       string out_prop_str;
    109       ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0],
    110                                                         &out_prop_str);
    111       EXPECT_EQ(in_prop_str, out_prop_str);
    112     }
    113   }
    114 }
    115 
    116 TEST_F(GraphPropertiesTest, DynamicProperties) {
    117   TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
    118                                           cluster_->GetDeviceNames());
    119   GrapplerItem item;
    120   CHECK(fake_input.NextItem(&item));
    121 
    122   GraphProperties properties(item);
    123   TF_CHECK_OK(cluster_->Initialize(item));
    124   Status s = properties.InferDynamically(cluster_.get());
    125   TF_CHECK_OK(s);
    126 
    127   for (const auto& node : item.graph.node()) {
    128     if (node.op() == "RandomStandardNormal") {
    129       // The random node is missing from the cost graph (why ?)
    130       EXPECT_EQ(0, properties.GetInputProperties(node.name()).size());
    131     } else if (node.op() == "AddN") {
    132       // Since the random node is missing, we can't infer the input properties
    133       // of the first AddN node. The other AddN nodes have the expected
    134       // properties.
    135       if (node.name() == "AddN") {
    136         const auto props = properties.GetInputProperties(node.name());
    137         EXPECT_EQ(1, props.size());
    138         const OpInfo::TensorProperties& prop = props[0];
    139         EXPECT_EQ(DT_INVALID, prop.dtype());
    140         EXPECT_TRUE(prop.shape().unknown_rank());
    141       } else {
    142         const auto props = properties.GetInputProperties(node.name());
    143         EXPECT_EQ(1, props.size());
    144         const OpInfo::TensorProperties& prop = props[0];
    145         EXPECT_EQ(DT_FLOAT, prop.dtype());
    146         EXPECT_FALSE(prop.shape().unknown_rank());
    147         EXPECT_EQ(2, prop.shape().dim_size());
    148         EXPECT_EQ(10, prop.shape().dim(0).size());
    149         EXPECT_EQ(1, prop.shape().dim(1).size());
    150         const auto out_props = properties.GetOutputProperties(node.name());
    151         EXPECT_EQ(1, out_props.size());
    152         string prop_str;
    153         ::tensorflow::protobuf::TextFormat::PrintToString(prop, &prop_str);
    154         string out_prop_str;
    155         ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0],
    156                                                           &out_prop_str);
    157         EXPECT_EQ(prop_str, out_prop_str);
    158       }
    159     }
    160   }
    161 }
    162 
    163 TEST_F(GraphPropertiesTest, Variables) {
    164   GrapplerItem item;
    165   TF_CHECK_OK(NodeDefBuilder("Var", "Variable")
    166                   .Attr("dtype", DT_FLOAT)
    167                   .Attr("shape", TensorShape({3, 7}))
    168                   .Finalize(item.graph.add_node()));
    169   item.fetch.push_back("Var");
    170 
    171   Tensor initial_val(DT_FLOAT, TensorShape({3, 7}));
    172   test::FillIota<float>(&initial_val, 0);
    173   TF_CHECK_OK(NodeDefBuilder("InitialVal", "Const")
    174                   .Attr("dtype", DT_FLOAT)
    175                   .Attr("value", initial_val)
    176                   .Finalize(item.graph.add_node()));
    177   TF_CHECK_OK(NodeDefBuilder("InitVar", "Assign")
    178                   .Input("Var", 0, DT_FLOAT_REF)
    179                   .Input("InitialVal", 0, DT_FLOAT)
    180                   .Finalize(item.graph.add_node()));
    181   item.init_ops.push_back("InitVar");
    182 
    183   {
    184     GraphProperties static_properties(item);
    185     TF_CHECK_OK(static_properties.InferStatically(false));
    186 
    187     const auto props = static_properties.GetOutputProperties("Var");
    188     EXPECT_EQ(1, props.size());
    189     const OpInfo::TensorProperties& prop = props[0];
    190     EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
    191     EXPECT_FALSE(prop.shape().unknown_rank());
    192     EXPECT_EQ(2, prop.shape().dim_size());
    193     EXPECT_EQ(3, prop.shape().dim(0).size());
    194     EXPECT_EQ(7, prop.shape().dim(1).size());
    195   }
    196   {
    197     TF_CHECK_OK(cluster_->Initialize(item));
    198     GraphProperties dynamic_properties(item);
    199     TF_CHECK_OK(dynamic_properties.InferDynamically(cluster_.get()));
    200 
    201     const auto props = dynamic_properties.GetOutputProperties("Var");
    202     EXPECT_EQ(1, props.size());
    203     const OpInfo::TensorProperties& prop = props[0];
    204     EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
    205     EXPECT_FALSE(prop.shape().unknown_rank());
    206     EXPECT_EQ(2, prop.shape().dim_size());
    207     EXPECT_EQ(3, prop.shape().dim(0).size());
    208     EXPECT_EQ(7, prop.shape().dim(1).size());
    209   }
    210 }
    211 
    212 TEST_F(GraphPropertiesTest, VarHandles) {
    213   GrapplerItem item;
    214   TF_CHECK_OK(NodeDefBuilder("Var", "VarHandleOp")
    215                   .Attr("dtype", DT_FLOAT)
    216                   .Attr("shape", TensorShape({3, 7}))
    217                   .Finalize(item.graph.add_node()));
    218 
    219   TF_CHECK_OK(NodeDefBuilder("VarRead", "ReadVariableOp")
    220                   .Attr("dtype", DT_FLOAT)
    221                   .Input("Var", 0, DT_RESOURCE)
    222                   .Finalize(item.graph.add_node()));
    223 
    224   GraphProperties properties(item);
    225   TF_CHECK_OK(properties.InferStatically(false));
    226 
    227   const auto props = properties.GetOutputProperties("VarRead");
    228   EXPECT_EQ(1, props.size());
    229   const OpInfo::TensorProperties& prop = props[0];
    230   EXPECT_EQ(DT_FLOAT, prop.dtype());
    231   EXPECT_FALSE(prop.shape().unknown_rank());
    232   EXPECT_EQ(2, prop.shape().dim_size());
    233   EXPECT_EQ(3, prop.shape().dim(0).size());
    234   EXPECT_EQ(7, prop.shape().dim(1).size());
    235 }
    236 
    237 TEST_F(GraphPropertiesTest, Queues) {
    238   // Create a graph with known input shapes, and propagate the shapes through a
    239   // couple of queues.
    240   tensorflow::Scope root = tensorflow::Scope::NewRootScope();
    241 
    242   auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
    243   Output rnd =
    244       ops::RandomNormal(root.WithOpName("rnd"), {3, 7}, DataType::DT_FLOAT);
    245   Output square1 = ops::Square(root.WithOpName("Square1"), rnd);
    246   auto enqueue1 = ops::QueueEnqueue(root.WithOpName("Enqueue1"), q1, {square1});
    247   auto dequeue1 =
    248       ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
    249 
    250   auto q2 =
    251       ops::RandomShuffleQueue(root.WithOpName("Queue2"), {DataType::DT_FLOAT});
    252   Output square2 = ops::Square(root.WithOpName("Square2"), dequeue1[0]);
    253   auto enqueue2 = ops::QueueEnqueue(root.WithOpName("Enqueue2"), q2, {square2});
    254   auto dequeue2 =
    255       ops::QueueDequeue(root.WithOpName("Dequeue2"), q2, {DataType::DT_FLOAT});
    256 
    257   // Create a queue that feeds itself.
    258   auto q3 =
    259       ops::RandomShuffleQueue(root.WithOpName("Queue3"), {DataType::DT_FLOAT});
    260   auto dequeue3 =
    261       ops::QueueDequeue(root.WithOpName("Dequeue3"), q3, {DataType::DT_FLOAT});
    262   auto merge3 = ops::Merge(root.WithOpName("Merge3"), {dequeue3[0], square2});
    263   auto enqueue3 =
    264       ops::QueueEnqueue(root.WithOpName("Enqueue3"), q3, {merge3.output});
    265 
    266   auto q4 =
    267       ops::RandomShuffleQueue(root.WithOpName("Queue4"), {DataType::DT_FLOAT});
    268   auto enqueue4 = ops::QueueEnqueue(root.WithOpName("Enqueue4"), q4, {square2});
    269   auto enqueue4_2 =
    270       ops::QueueEnqueue(root.WithOpName("Enqueue4_2"), q4, {dequeue3[0]});
    271   auto dequeue4 =
    272       ops::QueueDequeue(root.WithOpName("Dequeue4"), q4, {DataType::DT_FLOAT});
    273 
    274   // Create a queue that takes in three tensors.
    275   auto q5 = ops::RandomShuffleQueue(
    276       root.WithOpName("Queue5"),
    277       {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
    278   Output rnd2 =
    279       ops::RandomNormal(root.WithOpName("rnd"), {10}, DataType::DT_DOUBLE);
    280   Output rnd3 =
    281       ops::RandomNormal(root.WithOpName("rnd"), {1, 2, 3}, DataType::DT_FLOAT);
    282   auto enqueue5 =
    283       ops::QueueEnqueue(root.WithOpName("Enqueue5"), q5, {rnd, rnd2, rnd3});
    284   auto dequeue5 = ops::QueueDequeue(
    285       root.WithOpName("Dequeue5"), q5,
    286       {DataType::DT_FLOAT, DataType::DT_DOUBLE, DataType::DT_FLOAT});
    287 
    288   GrapplerItem item;
    289   TF_CHECK_OK(root.ToGraphDef(&item.graph));
    290 
    291   GraphProperties properties(item);
    292   TF_CHECK_OK(properties.InferStatically(false));
    293 
    294   const auto props1 = properties.GetOutputProperties("Dequeue1");
    295   ASSERT_EQ(1, props1.size());
    296   EXPECT_EQ("float: [3,7]", PropToString(props1[0]));
    297 
    298   const auto props2 = properties.GetOutputProperties("Dequeue2");
    299   ASSERT_EQ(1, props2.size());
    300   EXPECT_EQ("float: [3,7]", PropToString(props2[0]));
    301 
    302   const auto props3 = properties.GetOutputProperties("Dequeue3");
    303   ASSERT_EQ(1, props3.size());
    304   EXPECT_EQ("float: [3,7]", PropToString(props3[0]));
    305 
    306   // The dequeue3 op shape is unknown. The square2 op shape is known. Verify
    307   // that we merge the 2 properly to determine the shape of the data coming out
    308   // of the queue.
    309   const auto props4 = properties.GetOutputProperties("Dequeue4");
    310   ASSERT_EQ(1, props4.size());
    311   EXPECT_EQ("float: [3,7]", PropToString(props4[0]));
    312 
    313   // The dequeue5 op shape is known.
    314   const auto props5 = properties.GetOutputProperties("Dequeue5");
    315   ASSERT_EQ(3, props5.size());
    316   EXPECT_EQ("float: [3,7]", PropToString(props5[0]));
    317   EXPECT_EQ("double: [10]", PropToString(props5[1]));
    318   EXPECT_EQ("float: [1,2,3]", PropToString(props5[2]));
    319 }
    320 
    321 TEST_F(GraphPropertiesTest, MergeWithoutLoops) {
    322   // Test graph produced in python using:
    323   /*
    324     with tf.Graph().as_default():
    325       x = tf.constant(2)
    326       y = tf.constant(5)
    327       z = tf.ones([1,1,1])
    328       def f1(): return tf.concat([z, z], axis=0)
    329       def f2(): return tf.concat([z, z], axis=1)
    330       r = tf.cond(tf.less(x, y), f1, f2)
    331       tf.concat([r, r], axis=2)
    332       with open('/tmp/graph.pbtxt', 'w') as f:
    333         f.write(str(tf.get_default_graph().as_graph_def()))
    334    */
    335 
    336   GrapplerItem item;
    337   string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
    338                                  "merge_without_loops.pbtxt");
    339   TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
    340   GraphProperties properties(item);
    341   TF_CHECK_OK(properties.InferStatically(false));
    342 
    343   std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"};
    344   std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]",
    345                                        "float: [1,2,1]"};
    346   for (int i = 0; i < nodes.size(); i++) {
    347     const auto props = properties.GetOutputProperties(nodes[i]);
    348     const OpInfo::TensorProperties& prop = props[0];
    349     EXPECT_EQ(DT_FLOAT, prop.dtype());
    350     EXPECT_EQ(expected_outputs[i], PropToString(prop));
    351   }
    352 
    353   // The "Less" node should be fed by 2 int32 scalar constant values.
    354   const auto props = properties.GetInputProperties("Less");
    355   EXPECT_EQ(2, props.size());
    356   for (int i = 0; i < props.size(); ++i) {
    357     EXPECT_EQ(DT_INT32, props[i].dtype());
    358     EXPECT_TRUE(props[i].has_value());
    359     EXPECT_EQ("int32: []", PropToString(props[i]));
    360   }
    361 }
    362 
    363 TEST_F(GraphPropertiesTest, WhileLoop) {
    364   // Test graph produced in python using:
    365   /*
    366      with tf.Graph().as_default():
    367        i0 = tf.constant(0)
    368        m0 = tf.placeholder([-1, 2])
    369        c = lambda i, m: i < 10
    370        b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
    371        r = tf.while_loop(
    372               c, b, loop_vars=[i0, m0],
    373               shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
    374        with open('/tmp/graph.pbtxt', 'w') as f:
    375          f.write(str(tf.get_default_graph().as_graph_def()))
    376   */
    377 
    378   GrapplerItem item;
    379   string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
    380                                  "while_loop.pbtxt");
    381   TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
    382   GraphProperties properties(item);
    383   TF_CHECK_OK(properties.InferStatically(false));
    384 
    385   std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
    386                             "while/Exit_1"};
    387   for (const string& node : nodes) {
    388     const auto props = properties.GetOutputProperties(node);
    389     const OpInfo::TensorProperties& prop = props[0];
    390     EXPECT_EQ(DT_FLOAT, prop.dtype());
    391     EXPECT_EQ("float: [-1,2]", PropToString(prop));
    392   }
    393 
    394   // The loop outputs batch dim should be different from the input batch dim
    395   // since we concatenated along the batch dim.
    396   auto shape_in = properties.GetOutputProperties("ones").at(0).shape();
    397   auto shape_out = properties.GetOutputProperties("while/Exit_1").at(0).shape();
    398   EXPECT_GE(-2, shape_in.dim(0).size());
    399   EXPECT_GE(-2, shape_out.dim(0).size());
    400   EXPECT_NE(shape_in.dim(0).size(), shape_out.dim(0).size());
    401 }
    402 
    403 TEST_F(GraphPropertiesTest, NestedLoop) {
    404   // Test graph produced in python using:
    405   /*
    406     with tf.Graph().as_default():
    407       i0 = tf.constant(0)
    408 
    409       def inner(j, y):
    410         def inner_cond(j, y):
    411           return j < 3
    412 
    413         def inner_body(j, y):
    414           return j+1, tf.concat([y, y], axis=2)
    415 
    416         return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y],
    417                              shape_invariants=[i0.get_shape(),
    418                                               tf.TensorShape([None, 1, None])])
    419 
    420       def outer_cond(i, x):
    421         return i < 3
    422 
    423       def outer_body(i, x):
    424         j, y = inner(0, x)
    425         return i+1, tf.concat([x, x], axis=0)
    426 
    427       r = tf.while_loop(outer_cond, outer_body,
    428                         loop_vars=[i0, tf.ones([1, 1, 1])],
    429                         shape_invariants=[i0.get_shape(),
    430                                           tf.TensorShape([None, 1, None])])
    431 
    432       with open('/tmp/graph.pbtxt', 'w') as f:
    433         f.write(str(tf.get_default_graph().as_graph_def()))
    434   */
    435 
    436   GrapplerItem item;
    437   string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
    438                                  "nested_loop.pbtxt");
    439   TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
    440   GraphProperties properties(item);
    441   TF_CHECK_OK(properties.InferStatically(false));
    442 
    443   std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
    444                                   "while/Exit_1"};
    445   std::vector<string> inner_nodes{"while/while/Merge_1",
    446                                   "while/while/NextIteration_1",
    447                                   "while/while/Exit_1"};
    448   for (const string& node : outer_nodes) {
    449     const auto props = properties.GetOutputProperties(node);
    450     const OpInfo::TensorProperties& prop = props[0];
    451     EXPECT_EQ(DT_FLOAT, prop.dtype());
    452     EXPECT_EQ("float: [-1,1,1]", PropToString(prop));
    453   }
    454   for (const string& node : inner_nodes) {
    455     const auto props = properties.GetOutputProperties(node);
    456     const OpInfo::TensorProperties& prop = props[0];
    457     EXPECT_EQ(DT_FLOAT, prop.dtype());
    458     EXPECT_EQ("float: [-1,1,-1]", PropToString(prop));
    459   }
    460 }
    461 
    462 TEST_F(GraphPropertiesTest, LoopsAndQueues) {
    463   // Test graph produced in python using:
    464   /*
    465     with tf.Graph().as_default():
    466       i0 = tf.constant(0)
    467       q = tf.FIFOQueue(1, "float")
    468 
    469       def inner(j, y):
    470         def inner_cond(j, y):
    471           return j < 3
    472 
    473         def inner_body(j, y):
    474           return j+1, tf.concat([y, y], axis=0)
    475 
    476         return tf.while_loop(inner_cond, inner_body,
    477                              loop_vars=[j, y],
    478                              shape_invariants=[i0.get_shape(),
    479                                                tf.TensorShape(None)])
    480 
    481       def outer_cond(i, x):
    482         return i < 3
    483 
    484       def outer_body(i, x):
    485         q.enqueue(x)
    486         y = tf.concat([x, x], axis=2)
    487         inner(0, q.dequeue())
    488         return i+1, y
    489 
    490       i, z = tf.while_loop(outer_cond, outer_body,
    491                            loop_vars=[i0, tf.ones([1, 1, 1])],
    492                            shape_invariants=[i0.get_shape(),
    493                                              tf.TensorShape([None, 1, None])])
    494 
    495       with open('/tmp/graph.pbtxt', 'w') as f:
    496         f.write(str(tf.get_default_graph().as_graph_def()))
    497    */
    498 
    499   GrapplerItem item;
    500   string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
    501                                  "loops_and_queues.pbtxt");
    502   TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
    503   GraphProperties properties(item);
    504   TF_CHECK_OK(properties.InferStatically(false));
    505 
    506   std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
    507                                   "while/Exit_1"};
    508   std::vector<string> inner_nodes{"while/while/Merge_1",
    509                                   "while/while/NextIteration_1",
    510                                   "while/while/Exit_1"};
    511   for (const string& node : outer_nodes) {
    512     const auto props = properties.GetOutputProperties(node);
    513     const OpInfo::TensorProperties& prop = props[0];
    514     EXPECT_EQ(DT_FLOAT, prop.dtype());
    515     EXPECT_EQ("float: [1,1,-1]", PropToString(prop));
    516   }
    517   for (const string& node : inner_nodes) {
    518     const auto props = properties.GetOutputProperties(node);
    519     const OpInfo::TensorProperties& prop = props[0];
    520     EXPECT_EQ(DT_FLOAT, prop.dtype());
    521     EXPECT_EQ("float: [-1,1,-1]", PropToString(prop));
    522   }
    523 }
    524 
    525 TEST_F(GraphPropertiesTest, LoopsAndResourceVars) {
    526   // Test graph produced in python using:
    527   /*
    528     with tf.Graph().as_default():
    529       i0 = tf.constant(0)
    530       with tf.variable_scope(VariableScope(reuse=None, use_resource=True)):
    531         v = tf.get_variable(initializer=i0, name='loop_var')
    532 
    533       def inner(j, y):
    534         def inner_cond(j, y):
    535           return j < 3
    536 
    537         def inner_body(j, y):
    538           return j + 1, y + y
    539 
    540         return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y])
    541 
    542       def outer_cond(i, x):
    543         return i < 3
    544 
    545       def outer_body(i, x):
    546         y = x + x
    547         inner(0, v)
    548         return i + 1, y
    549 
    550       v, z = tf.while_loop(outer_cond, outer_body,
    551                            loop_vars=[v, tf.constant(1)])
    552 
    553       with open('/tmp/graph.pbtxt', 'w') as f:
    554         f.write(str(tf.get_default_graph().as_graph_def()))
    555   */
    556 
    557   GrapplerItem item;
    558   string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
    559                                  "loops_and_resource_vars.pbtxt");
    560   TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
    561   GraphProperties properties(item);
    562   TF_CHECK_OK(properties.InferStatically(false));
    563 
    564   std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
    565                                   "while/Exit_1"};
    566   std::vector<string> inner_nodes{"while/while/Merge_1",
    567                                   "while/while/NextIteration_1",
    568                                   "while/while/Exit_1"};
    569   for (const string& node : outer_nodes) {
    570     const auto props = properties.GetOutputProperties(node);
    571     const OpInfo::TensorProperties& prop = props[0];
    572     EXPECT_EQ(DT_INT32, prop.dtype());
    573     EXPECT_EQ("int32: []", PropToString(prop));
    574   }
    575   for (const string& node : inner_nodes) {
    576     const auto props = properties.GetOutputProperties(node);
    577     const OpInfo::TensorProperties& prop = props[0];
    578     EXPECT_EQ(DT_INT32, prop.dtype());
    579     EXPECT_EQ("int32: []", PropToString(prop));
    580   }
    581 }
    582 
    583 TEST_F(GraphPropertiesTest, QueuesAndLoops) {
    584   // Test graph produced in python using:
    585   /*
    586     with tf.Graph().as_default():
    587       i0 = tf.constant(0)
    588       q0 = tf.FIFOQueue(1, "float")
    589       q0.enqueue(tf.ones([2, 2]))
    590       q1 = tf.FIFOQueue(1, "float")
    591 
    592       def c(i, m):
    593         return i < 10
    594 
    595       def b(i, m):
    596         return i+1, tf.concat([m, m], axis=0)
    597 
    598       i, m = tf.while_loop(
    599           c, b, loop_vars=[i0,  q0.dequeue()],
    600           shape_invariants=[i0.get_shape(), tf.TensorShape(None)])
    601 
    602       q1.enqueue(m)
    603       v = q1.dequeue();
    604       tf.concat([v, v], axis=1)
    605       with open('/tmp/graph.pbtxt', 'w') as f:
    606         f.write(str(tf.get_default_graph().as_graph_def()))
    607   */
    608 
    609   GrapplerItem item;
    610   string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
    611                                  "queues_and_loops.pbtxt");
    612   TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
    613   GraphProperties properties(item);
    614   TF_CHECK_OK(properties.InferStatically(false));
    615 
    616   std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
    617                             "while/Exit_1"};
    618 
    619   for (const string& node : nodes) {
    620     const auto props = properties.GetOutputProperties(node);
    621     const OpInfo::TensorProperties& prop = props[0];
    622     EXPECT_EQ(DT_FLOAT, prop.dtype());
    623     EXPECT_EQ("float: [-1,2]", PropToString(prop));
    624   }
    625 
    626   const auto props = properties.GetOutputProperties("concat");
    627   const OpInfo::TensorProperties& prop = props[0];
    628   EXPECT_EQ(DT_FLOAT, prop.dtype());
    629   EXPECT_EQ("float: [-1,4]", PropToString(prop));
    630 }
    631 
    632 TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
    633   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    634   Output var = ops::Variable(s.WithOpName("var"), TensorShape({128, 256}),
    635                              DataType::DT_FLOAT);
    636   Output filename =
    637       ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
    638   Output tensor_name =
    639       ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
    640   Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
    641                                 DataType::DT_FLOAT);
    642   Output init_restore = ops::Assign(s.WithOpName("init_restore"), var, restore);
    643 
    644   Output shape_and_slice = ops::Const(s.WithOpName("shape_and_slice"),
    645                                       string("256 256 0,128:-"), TensorShape());
    646   Output restore_slice =
    647       ops::RestoreSlice(s.WithOpName("restore_slice"), filename, tensor_name,
    648                         shape_and_slice, DataType::DT_FLOAT);
    649   Output init_restore_slice =
    650       ops::Assign(s.WithOpName("init_restore_slice"), var, restore_slice);
    651 
    652   Output restore_v2 =
    653       ops::RestoreSlice(s.WithOpName("restore_v2"), filename, tensor_name,
    654                         shape_and_slice, DataType::DT_FLOAT);
    655   Output init_restore_v2 =
    656       ops::Assign(s.WithOpName("init_restore_v2"), var, restore_v2);
    657 
    658   GrapplerItem item;
    659   TF_CHECK_OK(s.ToGraphDef(&item.graph));
    660   item.fetch.push_back("init_restore");
    661 
    662   GraphProperties properties(item);
    663   TF_CHECK_OK(properties.InferStatically(false));
    664 
    665   const auto restore_props = properties.GetOutputProperties("restore");
    666   const OpInfo::TensorProperties& restore_prop = restore_props[0];
    667   EXPECT_EQ(DT_FLOAT, restore_prop.dtype());
    668   EXPECT_EQ("float: [128,256]", PropToString(restore_prop));
    669 
    670   const auto restore_slice_props =
    671       properties.GetOutputProperties("restore_slice");
    672   const OpInfo::TensorProperties& restore_slice_prop = restore_slice_props[0];
    673   EXPECT_EQ(DT_FLOAT, restore_slice_prop.dtype());
    674   EXPECT_EQ("float: [128,256]", PropToString(restore_slice_prop));
    675 
    676   const auto restorev2_props = properties.GetOutputProperties("restore_v2");
    677   const OpInfo::TensorProperties& restorev2_prop = restorev2_props[0];
    678   EXPECT_EQ(DT_FLOAT, restorev2_prop.dtype());
    679   EXPECT_EQ("float: [128,256]", PropToString(restorev2_prop));
    680 
    681   // Check input shapes of assign op are propagted correctly.
    682   const auto input_props = properties.GetInputProperties("init_restore");
    683   ASSERT_EQ(2, input_props.size());
    684   const OpInfo::TensorProperties& input_prop = input_props[1];
    685   EXPECT_EQ(DT_FLOAT, input_prop.dtype());
    686   EXPECT_EQ("float: [128,256]", PropToString(input_prop));
    687 }
    688 
    689 TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
    690   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    691   Output var = ops::Variable(s.WithOpName("var"), PartialTensorShape(),
    692                              DataType::DT_FLOAT);
    693   Output var2 = ops::Variable(s.WithOpName("var2"), TensorShape({128, 256}),
    694                               DataType::DT_FLOAT);
    695   Output filename =
    696       ops::Const(s.WithOpName("filename"), string("model"), TensorShape());
    697   Output tensor_name =
    698       ops::Const(s.WithOpName("tensorname"), string("a"), TensorShape());
    699   Output restore = ops::Restore(s.WithOpName("restore"), filename, tensor_name,
    700                                 DataType::DT_FLOAT);
    701   Output init = ops::Assign(s.WithOpName("init"), var, restore);
    702   Output init2 = ops::Assign(s.WithOpName("init2"), var2, restore);
    703 
    704   GrapplerItem item;
    705   TF_CHECK_OK(s.ToGraphDef(&item.graph));
    706   item.fetch.push_back("init");
    707   item.fetch.push_back("init2");
    708 
    709   GraphProperties properties(item);
    710   TF_CHECK_OK(properties.InferStatically(false));
    711 
    712   const auto props = properties.GetOutputProperties("restore");
    713   const OpInfo::TensorProperties& prop = props[0];
    714   EXPECT_EQ(DT_FLOAT, prop.dtype());
    715   EXPECT_EQ("float: [128,256]", PropToString(prop));
    716 }
    717 
    718 TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) {
    719   // Test graph produced in python using:
    720   /*
    721     @function.Defun(*[tf.float32] * 2, noinline=True)
    722     def MyAdd(x, y):
    723       return tf.add(x,y)
    724 
    725     with tf.Graph().as_default():
    726       x = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
    727       y = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
    728       z = MyAdd(x, y)
    729       z = MyAdd(x, z)
    730   */
    731   // Check that the shape of the second MyAdd node propagates
    732   // correctly.
    733   GrapplerItem item;
    734   string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
    735                                  "simple_function.pbtxt");
    736   TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
    737   GraphProperties properties(item);
    738   TF_CHECK_OK(properties.InferStatically(false));
    739   const auto props = properties.GetOutputProperties("MyAdd_55e046a8_1");
    740   const OpInfo::TensorProperties& prop = props[0];
    741   EXPECT_EQ(DT_FLOAT, prop.dtype());
    742   EXPECT_FALSE(prop.shape().unknown_rank());
    743   EXPECT_EQ(2, prop.shape().dim_size());
    744   EXPECT_EQ(1, prop.shape().dim(0).size());
    745   EXPECT_EQ(2, prop.shape().dim(1).size());
    746 
    747   PartialTensorShape shape(prop.shape());
    748   EXPECT_TRUE(shape.IsFullyDefined());
    749   EXPECT_FALSE(shape.unknown_rank());
    750 }
    751 
    752 TEST_F(GraphPropertiesTest, SymbolicShapes) {
    753   // Build a simple graph with placeholders of unknown dimensions. These
    754   // dimensions will be encoded symbolically.
    755   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    756 
    757   Output a =
    758       ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
    759                        ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
    760   Output b =
    761       ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
    762                        ops::Placeholder::Shape(PartialTensorShape({-1})));
    763   Output c = ops::Identity(s.WithOpName("c"), a);
    764   Output d = ops::Identity(s.WithOpName("d"), b);
    765   Output e = ops::Add(s.WithOpName("e"), c, d);
    766   Output f = ops::Add(s.WithOpName("f"), a, c);
    767 
    768   Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
    769   Output g = ops::Shape(s.WithOpName("g"), c);
    770   Output h = ops::Fill(s.WithOpName("h"), g, zero);
    771 
    772   GrapplerItem item;
    773   TF_CHECK_OK(s.ToGraphDef(&item.graph));
    774 
    775   GraphProperties properties(item);
    776   TF_CHECK_OK(properties.InferStatically(false));
    777   const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
    778   const auto shape_c = properties.GetOutputProperties("c").at(0).shape();
    779   EXPECT_EQ(2, shape_a.dim_size());
    780   EXPECT_EQ(shape_a.dim_size(), shape_c.dim_size());
    781   EXPECT_GE(-2, shape_a.dim(0).size());
    782   EXPECT_EQ(shape_a.dim(0).size(), shape_c.dim(0).size());
    783   EXPECT_GE(-2, shape_a.dim(1).size());
    784   EXPECT_EQ(shape_a.dim(1).size(), shape_c.dim(1).size());
    785 
    786   PartialTensorShape shape(shape_a);
    787   EXPECT_FALSE(shape.IsFullyDefined());
    788   EXPECT_FALSE(shape.unknown_rank());
    789 
    790   const auto shape_b = properties.GetOutputProperties("b").at(0).shape();
    791   const auto shape_d = properties.GetOutputProperties("d").at(0).shape();
    792   EXPECT_EQ(1, shape_b.dim_size());
    793   EXPECT_EQ(shape_b.dim_size(), shape_d.dim_size());
    794   EXPECT_GE(-2, shape_b.dim(0).size());
    795   EXPECT_NE(shape_a.dim(0).size(), shape_b.dim(0).size());
    796   EXPECT_EQ(shape_b.dim(0).size(), shape_d.dim(0).size());
    797 
    798   const auto shape_e = properties.GetOutputProperties("e").at(0).shape();
    799   ASSERT_EQ(2, shape_e.dim_size());
    800   EXPECT_EQ(shape_e.dim(0).size(), shape_c.dim(0).size());
    801   EXPECT_NE(shape_e.dim(1).size(), shape_c.dim(1).size());
    802   EXPECT_NE(shape_e.dim(0).size(), shape_d.dim(0).size());
    803 
    804   const auto shape_f = properties.GetOutputProperties("f").at(0).shape();
    805   ASSERT_EQ(2, shape_f.dim_size());
    806   EXPECT_EQ(shape_f.dim(0).size(), shape_a.dim(0).size());
    807   EXPECT_EQ(shape_f.dim(1).size(), shape_a.dim(1).size());
    808 
    809   const auto shape_h = properties.GetOutputProperties("h").at(0).shape();
    810   ASSERT_EQ(2, shape_f.dim_size());
    811   EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
    812   EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());
    813 }
    814 
    815 TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
    816   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    817   Output a = ops::Const(s.WithOpName("a"), 1.0f, {1});
    818   Output b = ops::Const(s.WithOpName("b"), 2.0f, {1});
    819   Output c = ops::Const(s.WithOpName("c").ColocateWith(a), 3.0f, {1});
    820   GrapplerItem item;
    821   TF_CHECK_OK(s.ToGraphDef(&item.graph));
    822   // Create a graph with node a removed (say by some graph optimization
    823   // pass), noting that node c is colocated with a. This is fine as it
    824   // is in the late stage of graph execution, the colocation constraints have
    825   // been validated previously and the device placement of nodes has completed.
    826   GraphDef optimized_graph;
    827   for (const auto& node : item.graph.node()) {
    828     if (node.name() != "a") {
    829       *optimized_graph.add_node() = node;
    830     }
    831   }
    832   item.graph.Swap(&optimized_graph);
    833   GraphProperties properties(item);
    834   // This function should return OK, since it doesn't validate the colocation
    835   // constraints internally.
    836   TF_EXPECT_OK(properties.InferStatically(false));
    837 }
    838 
    839 TEST_F(GraphPropertiesTest, ShapeTracking) {
    840   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    841   Output a =
    842       ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
    843                        ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
    844   Output b =
    845       ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
    846                        ops::Placeholder::Shape(PartialTensorShape({-1})));
    847   Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
    848   auto shp = ops::ShapeN(s.WithOpName("shapes"), {a, b});
    849   Output o1 = ops::Fill(s.WithOpName("o1"), shp[0], zero);
    850   Output o2 = ops::Fill(s.WithOpName("o2"), shp[1], zero);
    851 
    852   GrapplerItem item;
    853   TF_CHECK_OK(s.ToGraphDef(&item.graph));
    854 
    855   GraphProperties properties(item);
    856   TF_CHECK_OK(properties.InferStatically(false));
    857   const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
    858   const auto shape_b = properties.GetOutputProperties("b").at(0).shape();
    859   const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
    860   const auto shape_o2 = properties.GetOutputProperties("o2").at(0).shape();
    861   EXPECT_EQ(shape_a.DebugString(), shape_o1.DebugString());
    862   EXPECT_EQ(shape_b.DebugString(), shape_o2.DebugString());
    863 }
    864 
    865 TEST_F(GraphPropertiesTest, FedNodes) {
    866   TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
    867                                           cluster_->GetDeviceNames());
    868   GrapplerItem item;
    869   CHECK(fake_input.NextItem(&item));
    870 
    871   {
    872     // Conservative shape analysis: the shape of fed ports should be unknown
    873     GraphProperties properties(item);
    874     Status s = properties.InferStatically(false);
    875     TF_CHECK_OK(s);
    876     for (const auto& node : item.graph.node()) {
    877       if (node.op() == "Const") {
    878         continue;
    879       }
    880       const auto in_props = properties.GetInputProperties(node.name());
    881       EXPECT_EQ(1, in_props.size());
    882       const OpInfo::TensorProperties& in_prop = in_props[0];
    883       const auto out_props = properties.GetOutputProperties(node.name());
    884       EXPECT_EQ(1, out_props.size());
    885       const OpInfo::TensorProperties& out_prop = out_props[0];
    886 
    887       if (node.name() == "x") {
    888         // x is fed: its input should have a known shape, while its output
    889         // doesn't
    890         EXPECT_FALSE(in_prop.shape().unknown_rank());
    891         EXPECT_EQ(1, in_prop.shape().dim_size());
    892         EXPECT_EQ(2, in_prop.shape().dim(0).size());
    893         EXPECT_TRUE(out_prop.shape().unknown_rank());
    894       } else if (node.op() == "Square" || node.op() == "AddN") {
    895         // These nodes are in the fanout of x: their shapes should be unknown.
    896         EXPECT_TRUE(in_prop.shape().unknown_rank());
    897         EXPECT_TRUE(out_prop.shape().unknown_rank());
    898       }
    899     }
    900   }
    901   {
    902     // Optimistic shape analysis: the shape of fed ports should be derived from
    903     // the shape of the fanin.
    904     GraphProperties properties(item);
    905     Status s = properties.InferStatically(true);
    906     TF_CHECK_OK(s);
    907     for (const auto& node : item.graph.node()) {
    908       if (node.op() == "Square" || node.op() == "AddN") {
    909         const auto in_props = properties.GetInputProperties(node.name());
    910         EXPECT_EQ(1, in_props.size());
    911         const OpInfo::TensorProperties& in_prop = in_props[0];
    912         EXPECT_EQ(DT_FLOAT, in_prop.dtype());
    913         EXPECT_FALSE(in_prop.shape().unknown_rank());
    914         EXPECT_EQ(2, in_prop.shape().dim_size());
    915         const auto out_props = properties.GetOutputProperties(node.name());
    916         EXPECT_EQ(1, out_props.size());
    917         const OpInfo::TensorProperties& out_prop = out_props[0];
    918         EXPECT_EQ(in_prop.DebugString(), out_prop.DebugString());
    919       }
    920     }
    921   }
    922 }
    923 
    924 TEST_F(GraphPropertiesTest, Performance) {
    925   // Load a large graph with many nested loops to make sure we can infer shapes
    926   // quickly.
    927   GrapplerItem item;
    928   string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
    929                                  "large_graph.pbtxt.html");
    930   TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
    931   GraphProperties properties(item);
    932   TF_CHECK_OK(properties.InferStatically(false));
    933 }
    934 
    935 }  // namespace
    936 }  // namespace grappler
    937 }  // namespace tensorflow
    938