Home | History | Annotate | Download | only in optimizers
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
     17 #include "tensorflow/cc/ops/array_ops_internal.h"
     18 #include "tensorflow/cc/ops/standard_ops.h"
     19 #include "tensorflow/core/framework/node_def.pb.h"
     20 #include "tensorflow/core/framework/tensor_testutil.h"
     21 #include "tensorflow/core/grappler/grappler_item.h"
     22 #include "tensorflow/core/grappler/utils.h"
     23 #include "tensorflow/core/grappler/utils/grappler_test.h"
     24 #include "tensorflow/core/lib/core/status_test_util.h"
     25 #include "tensorflow/core/lib/strings/strcat.h"
     26 
     27 namespace tensorflow {
     28 namespace grappler {
     29 namespace {
     30 
     31 class ConstantFoldingTest : public GrapplerTest {};
     32 
     33 TEST_F(ConstantFoldingTest, SimpleFolding) {
     34   // Build a simple graph with a few trivially prunable ops.
     35   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
     36 
     37   Output a = ops::Const(s.WithOpName("a"), 1.0f, {1});
     38   Output b = ops::Const(s.WithOpName("b"), 2.0f, {1});
     39   Output c = ops::AddN(s.WithOpName("c").WithDevice("/CPU:0"), {a, b});
     40   Output d = ops::AddN(s.WithOpName("d"), {b, c});
     41 
     42   GrapplerItem item;
     43   item.fetch.push_back("d");
     44   TF_CHECK_OK(s.ToGraphDef(&item.graph));
     45 
     46   ConstantFolding fold(nullptr /* cpu_device */);
     47   GraphDef output;
     48   Status status = fold.Optimize(nullptr, item, &output);
     49   TF_EXPECT_OK(status);
     50 
     51   EXPECT_EQ(1, output.node_size());
     52 
     53   const NodeDef& node_d = output.node(0);
     54   EXPECT_EQ("d", node_d.name());
     55   EXPECT_EQ("Const", node_d.op());
     56 
     57   std::vector<string> fetch = {"d"};
     58   auto tensors_expected = EvaluateNodes(item.graph, fetch);
     59   auto tensors = EvaluateNodes(output, fetch);
     60   EXPECT_EQ(1, tensors_expected.size());
     61   EXPECT_EQ(1, tensors.size());
     62   test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
     63 }
     64 
     65 TEST_F(ConstantFoldingTest, AddTree) {
     66   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
     67 
     68   Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2});
     69   Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2});
     70   Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
     71                               ops::Placeholder::Shape(TensorShape({2, 2})));
     72   Output add_child = ops::Add(s.WithOpName("add_child"), c2, x);
     73   Output c1 = ops::Const(s.WithOpName("c1").WithControlDependencies(add_child),
     74                          1.0f, {1});
     75   Output add_parent = ops::Add(s.WithOpName("add_parent"), c1, add_child);
     76 
     77   Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
     78                               ops::Placeholder::Shape(TensorShape({2, 2})));
     79   Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2});
     80   Output c5 = ops::Const(s.WithOpName("c5"), 5.0f, {2});
     81   Output c20 = ops::Const(s.WithOpName("c20"), 20.0f, {2});
     82   Output mul_child = ops::Mul(s.WithOpName("mul_child"), c4, y);
     83   Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), c5, mul_child);
     84   Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c4, x);
     85   Output addmul_parent =
     86       ops::Mul(s.WithOpName("addmul_parent"), c5, addmul_child);
     87 
     88   GrapplerItem item;
     89   item.fetch = {"add_parent", "mul_parent", "addmul_parent"};
     90   TF_CHECK_OK(s.ToGraphDef(&item.graph));
     91 
     92   ConstantFolding fold(nullptr /* cpu_device */);
     93   GraphDef output;
     94   Status status = fold.Optimize(nullptr, item, &output);
     95   TF_EXPECT_OK(status);
     96 
     97   // We expect the following rewrite(s) to occur:
     98   //
     99   //    +                +             +
    100   //   / \              / \           / \
    101   // 1.0  +     -->    x   +    -->  x  3.0
    102   //     / \              / \
    103   //   2.0  x           1.0 2.0
    104   //
    105   //    *                *             *
    106   //   / \              / \           / \
    107   // 4.0  *     -->    y   *    -->  y  20.0
    108   //     / \              / \
    109   //   5.0  y           4.0 5.0
    110 
    111   EXPECT_EQ(11, output.node_size());
    112   for (const auto& node : output.node()) {
    113     if (node.name() == "add_child") {
    114       EXPECT_EQ("Const", node.op());
    115       TensorProto t = node.attr().at("value").tensor();
    116       EXPECT_EQ(1, t.tensor_shape().dim_size());
    117       EXPECT_EQ(2, t.tensor_shape().dim(0).size());
    118     } else if (node.name() == "add_parent") {
    119       EXPECT_EQ("Add", node.op());
    120       EXPECT_EQ(2, node.input_size());
    121       EXPECT_EQ("x", node.input(0));
    122       EXPECT_EQ("add_child", node.input(1));
    123     } else if (node.name() == "mul_child") {
    124       EXPECT_EQ("Const", node.op());
    125       TensorProto t = node.attr().at("value").tensor();
    126       EXPECT_EQ(1, t.tensor_shape().dim_size());
    127       EXPECT_EQ(2, t.tensor_shape().dim(0).size());
    128     } else if (node.name() == "mul_parent") {
    129       EXPECT_EQ("Mul", node.op());
    130       EXPECT_EQ(2, node.input_size());
    131       EXPECT_EQ("y", node.input(0));
    132       EXPECT_EQ("mul_child", node.input(1));
    133     } else if (node.name() == "addmul_child") {
    134       // Unchanged.
    135       EXPECT_EQ("Add", node.op());
    136       EXPECT_EQ(2, node.input_size());
    137       EXPECT_EQ("c4", node.input(0));
    138       EXPECT_EQ("x", node.input(1));
    139     }
    140   }
    141 
    142   // Check that the result nodes have the expected value.
    143   std::vector<string> fetch = {"c3", "c20"};
    144   auto tensor_expected = EvaluateNodes(item.graph, fetch);
    145   EXPECT_EQ(fetch.size(), tensor_expected.size());
    146   fetch = {"add_child", "mul_child"};
    147   auto tensors = EvaluateNodes(output, fetch);
    148   EXPECT_EQ(fetch.size(), tensors.size());
    149   for (int i = 0; i < fetch.size(); i++) {
    150     test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
    151   }
    152 }
    153 
// Verifies simplification of ops with a neutral/absorbing constant operand:
// x*0 -> 0, x*1 -> x (Snapshot), x/1 -> x, 1/y -> Reciprocal(y),
// x+0 -> x, x-0 -> x, MatMul with a zero operand -> zero Const.
// Runs twice: once with Const zeros/ones and once with ZerosLike/OnesLike.
TEST_F(ConstantFoldingTest, NeutralElement) {
  for (bool use_const : {true, false}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 2})));
    Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({3, 2})));
    Output b = ops::Placeholder(s.WithOpName("b"), DT_FLOAT,
                                ops::Placeholder::Shape(TensorShape({2, 3})));
    Output bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT,
                                   ops::Placeholder::Shape(TensorShape({2})));
    Output zeros = !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x)
                              : ops::Const(s.WithOpName("zeros"), 0.0f, {2, 2});
    Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
    Output ones = !use_const ? ops::OnesLike(s.WithOpName("ones"), x)
                             : ops::Const(s.WithOpName("ones"), 1.0f, {2, 2});
    Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
    Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
    Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
    Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
    Output mul5 = ops::Mul(s.WithOpName("mul5"), x, zeros_1d);
    Output mul6 = ops::Mul(s.WithOpName("mul6"), zeros_1d, y);
    Output div1 = ops::Div(s.WithOpName("div1"), x, ones);
    Output div2 = ops::Div(s.WithOpName("div2"), ones, y);
    Output matmul1 = ops::MatMul(s.WithOpName("matmul1"), x, zeros);
    Output matmul2 = ops::MatMul(s.WithOpName("matmul2"), zeros, y);
    Output matmul3 = ops::MatMul(s.WithOpName("matmul3"), a, zeros);
    Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b);
    Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
    Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
    Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d);
    Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias);
    Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
    Output sub2 = ops::Sub(s.WithOpName("sub2"), zeros, y);
    // AddN keeps every simplified op alive so none are pruned before the
    // per-node checks below.
    Output addn =
        ops::AddN(s.WithOpName("addn"),
                  {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1,
                   matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"addn", "matmul3", "matmul4"};

    ConstantFolding optimizer(nullptr /* cpu_device */);
    GraphDef output;
    Status status = optimizer.Optimize(nullptr, item, &output);
    TF_EXPECT_OK(status);

    EXPECT_EQ(27, output.node_size());
    for (int i = 0; i < output.node_size(); ++i) {
      const NodeDef& node = output.node(i);
      const string& name = node.name();
      // Replaced nodes keep their former data inputs as control inputs
      // (the "^" prefix) so execution ordering is preserved.
      if (name == "mul1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ("^zeros", node.input(1));
      } else if (name == "mul2") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^zeros", node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "mul3") {
        EXPECT_EQ("Snapshot", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^ones", node.input(1));
      } else if (name == "mul4") {
        EXPECT_EQ("Snapshot", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^ones", node.input(1));
      } else if (name == "mul5") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "mul6") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^zeros_1d", node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "div1") {
        EXPECT_EQ("Snapshot", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^ones", node.input(1));
      } else if (name == "div2") {
        // 1/y is strength-reduced to Reciprocal(y).
        EXPECT_EQ("Reciprocal", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^ones", node.input(1));
      } else if (name == "matmul1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ("^zeros", node.input(1));
      } else if (name == "matmul2") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^zeros", node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "matmul3") {
        // Zero result takes the MatMul's output shape {3, 2}, not the
        // shape of either input.
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^a", node.input(0));
        EXPECT_EQ("^zeros", node.input(1));
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(3, t.tensor_shape().dim(0).size());
        EXPECT_EQ(2, t.tensor_shape().dim(1).size());
      } else if (name == "matmul4") {
        // Same as matmul3 but with output shape {2, 3}.
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^zeros", node.input(0));
        EXPECT_EQ("^b", node.input(1));
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
        EXPECT_EQ(3, t.tensor_shape().dim(1).size());
      } else if (name == "add1") {
        EXPECT_EQ("Snapshot", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros", node.input(1));
      } else if (name == "add2") {
        EXPECT_EQ("Snapshot", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^zeros", node.input(1));
      } else if (name == "bias_add1") {
        EXPECT_EQ("Snapshot", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "bias_add2") {
        // We don't eliminate this one, because it requires broadcasting.
        EXPECT_EQ("BiasAdd", node.op());
        EXPECT_EQ("zeros", node.input(0));
        EXPECT_EQ("bias", node.input(1));
      } else if (name == "sub1") {
        EXPECT_EQ("Snapshot", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros", node.input(1));
      } else if (name == "sub2") {
        // We don't handle this case yet.
        EXPECT_EQ("Sub", node.op());
        EXPECT_EQ("zeros", node.input(0));
        EXPECT_EQ("y", node.input(1));
      }
      // All of these should have become a scalar-valued zero Const with the
      // common {2, 2} output shape.
      const std::set<string> square_zero_const{"mul1", "mul2",    "mul5",
                                               "mul6", "matmul1", "matmul2"};
      if (square_zero_const.count(name) > 0) {
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
        EXPECT_EQ(2, t.tensor_shape().dim(1).size());
      }
    }
  }
}
    307 
    308 TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
    309   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    310   Output cf_half = ops::Const(s.WithOpName("cf_half"), 0.5f, {1});
    311   Output xf = ops::Placeholder(s.WithOpName("xf"), DT_FLOAT,
    312                                ops::Placeholder::Shape(TensorShape({2, 2})));
    313   Output xi = ops::Placeholder(s.WithOpName("xi"), DT_INT32,
    314                                ops::Placeholder::Shape(TensorShape({2, 2})));
    315   Output ci = ops::Const(s.WithOpName("ci"), 2, {1});
    316   Output cf = ops::Const(s.WithOpName("cf"), 2.0f, {1});
    317   Output div_i = ops::Div(s.WithOpName("div_i"), xi, ci);
    318   Output div_f = ops::Div(s.WithOpName("div_f"), xf, cf);
    319   Output realdiv = ops::RealDiv(s.WithOpName("realdiv"), xf, cf);
    320 
    321   GrapplerItem item;
    322   TF_CHECK_OK(s.ToGraphDef(&item.graph));
    323   item.fetch = {"div_f", "div_i", "realdiv"};
    324   ConstantFolding optimizer(nullptr /* cpu_device */);
    325   GraphDef output;
    326   Status status = optimizer.Optimize(nullptr, item, &output);
    327   TF_EXPECT_OK(status);
    328 
    329   EXPECT_EQ(8, output.node_size());
    330   for (int i = 0; i < output.node_size(); ++i) {
    331     const NodeDef& node = output.node(i);
    332     const string& name = node.name();
    333     if (name == "div_i") {
    334       // Integer division is unchanged.
    335       EXPECT_EQ("Div", node.op());
    336       EXPECT_EQ("xi", node.input(0));
    337       EXPECT_EQ("ci", node.input(1));
    338     } else if (name == "div_f") {
    339       EXPECT_EQ("Mul", node.op());
    340       EXPECT_EQ("xf", node.input(0));
    341       EXPECT_EQ("ConstantFolding/div_f_recip", node.input(1));
    342     } else if (name == "realdiv") {
    343       EXPECT_EQ("Mul", node.op());
    344       EXPECT_EQ("xf", node.input(0));
    345       EXPECT_EQ("ConstantFolding/realdiv_recip", node.input(1));
    346     } else if (name == "ConstantFolding/div_f_recip") {
    347       EXPECT_EQ("Const", node.op());
    348       EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
    349       TensorProto t = node.attr().at("value").tensor();
    350       EXPECT_EQ(DT_FLOAT, t.dtype());
    351       EXPECT_EQ(1, t.tensor_shape().dim_size());
    352       EXPECT_EQ(1, t.tensor_shape().dim(0).size());
    353     } else if (name == "ConstantFolding/realdiv_recip") {
    354       EXPECT_EQ("Const", node.op());
    355       EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
    356       TensorProto t = node.attr().at("value").tensor();
    357       EXPECT_EQ(DT_FLOAT, t.dtype());
    358       EXPECT_EQ(1, t.tensor_shape().dim_size());
    359       EXPECT_EQ(1, t.tensor_shape().dim(0).size());
    360     }
    361   }
    362 
    363   // Check that the reciprocals have the expected value.
    364   std::vector<string> fetch = {"cf_half"};
    365   auto tensor_expected = EvaluateNodes(item.graph, fetch);
    366   EXPECT_EQ(fetch.size(), tensor_expected.size());
    367   fetch = {"ConstantFolding/div_f_recip", "ConstantFolding/realdiv_recip"};
    368   auto tensors = EvaluateNodes(output, fetch);
    369   EXPECT_EQ(fetch.size(), tensors.size());
    370   for (int i = 0; i < fetch.size(); i++) {
    371     test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]);
    372   }
    373 }
    374 
// Checks how x*0 simplification degrades as shape information is lost:
// fully known shapes fold to a Const, partially known shapes (known rank,
// unknown dims) only become an Identity of the zeros input, and fully
// unknown shapes are left as Mul.
TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x_known =
      ops::Placeholder(s.WithOpName("x_known"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({2, 2})));
  Output x_partially_known =
      ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
  Output zeros_known = ops::ZerosLike(s.WithOpName("zeros_known"), x_known);
  Output zeros_partially_known =
      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
  Output zeros_unknown =
      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);

  // Multiplies without any additional ops to supply the output shape.
  // Builds the 3x3 cross product of inputs and classifies each mul by the
  // expected rewrite, keyed on pointer identity of the operands.
  int count = 0;
  std::vector<Output> muls;
  std::unordered_set<string> not_converted;
  std::unordered_set<string> to_const;
  std::unordered_set<string> to_identity;
  for (const auto* x : {&x_known, &x_partially_known, &x_unknown}) {
    for (const auto* zeros :
         {&zeros_known, &zeros_partially_known, &zeros_unknown}) {
      const string name = strings::StrCat("mul_", count++);
      muls.push_back(ops::Mul(s.WithOpName(name), *x, *zeros));
      if (x == &x_partially_known && zeros == &zeros_partially_known) {
        to_identity.insert(name);
      } else if (x == &x_unknown || zeros == &zeros_unknown) {
        not_converted.insert(name);
      } else {
        to_const.insert(name);
      }
    }
  }

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(nullptr /* cpu_device */);
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  LOG(INFO) << output.DebugString();

  EXPECT_EQ(15, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (to_const.count(name) > 0) {
      EXPECT_EQ("Const", node.op()) << node.name();
    } else if (to_identity.count(name) > 0) {
      EXPECT_EQ("Identity", node.op()) << node.name();
    } else if (not_converted.count(name) > 0) {
      EXPECT_EQ("Mul", node.op()) << node.name();
    }
  }
}
    433 
// Companion to the UnknownOutputShape test: when shape inference can deduce
// the output shape of each Mul (here via an AddN sibling with a fully known
// shape), every x*0 folds to a Const even with partially/fully unknown
// input shapes.
TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output known_shape = ops::Const(s.WithOpName("known_shape"), 0.0f, {2, 2});
  Output x_partially_known =
      ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
  Output zeros_partially_known =
      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
  Output zeros_unknown =
      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);

  // If at least one of the inputs to AddN has a known shape, shape inference
  // will propagate the shape back to the inputs of AddN, making the
  // output shapes of all its inputs known
  std::vector<Output> muls_deduced_output_shape;
  std::unordered_set<string> to_const;
  int count = 0;
  for (const auto& x : {x_partially_known, x_unknown}) {
    for (const auto& zeros : {zeros_partially_known, zeros_unknown}) {
      const string name = strings::StrCat("mul_", count++);
      muls_deduced_output_shape.push_back(
          ops::Mul(s.WithOpName(name), x, zeros));
      to_const.insert(name);
    }
  }
  // We add a known shape as input to AddN to propagate it back to the
  // multiplies above, which means they can all be turned into Const nodes.
  muls_deduced_output_shape.push_back(known_shape);
  Output addn1 = ops::AddN(s.WithOpName("addn1"), muls_deduced_output_shape);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(nullptr /* cpu_device */);
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  LOG(INFO) << output.DebugString();

  EXPECT_EQ(10, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (to_const.count(name) > 0) {
      EXPECT_EQ("Const", node.op()) << node.name();
      // Former data inputs are retained as control inputs ("^" prefix).
      EXPECT_EQ(2, node.input_size());
      EXPECT_TRUE(IsControlInput(node.input(0)));
      EXPECT_TRUE(IsControlInput(node.input(1)));
    }
  }
}
    486 
// Verifies that folding produces Const nodes with the correct typed value
// field (float_val, double_val, int_val, ...) for every numeric dtype, plus
// bool. Each sub-graph is const * const -> identity, so the mul must fold
// to a Const holding 10*10 in the dtype-appropriate proto field.
TEST_F(ConstantFoldingTest, CreateConstNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

// Builds: <TYPE>_const = Const(10, shape {5}); <TYPE>_mul = square of it;
// <TYPE>_id keeps the mul alive as a graph output.
#define MAKE_TEST_GRAPH(TYPE)                                               \
  Output TYPE##_const =                                                     \
      ops::Const(s.WithOpName(#TYPE "_const"), static_cast<TYPE>(10), {5}); \
  Output TYPE##_mul =                                                       \
      ops::Mul(s.WithOpName(#TYPE "_mul"), TYPE##_const, TYPE##_const);     \
  Output TYPE##_id = ops::Identity(s.WithOpName(#TYPE "_id"), TYPE##_mul)

  MAKE_TEST_GRAPH(float);
  MAKE_TEST_GRAPH(double);
  MAKE_TEST_GRAPH(int64);
  MAKE_TEST_GRAPH(int32);
  MAKE_TEST_GRAPH(int16);
  MAKE_TEST_GRAPH(int8);
  MAKE_TEST_GRAPH(uint8);
#undef MAKE_TEST_GRAPH

  // bool has no Mul kernel; use LogicalAnd for the same structure.
  Output bool_const = ops::Const(s.WithOpName("bool_const"), true, {5});
  Output bool_and =
      ops::LogicalAnd(s.WithOpName("bool_and"), bool_const, bool_const);
  Output bool_id = ops::Identity(s.WithOpName("bool_id"), bool_and);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(24, output.node_size());
  for (const NodeDef& node : output.node()) {
// Checks the folded <TYPE>_mul Const: shape {5}, a single splat entry in
// the FIELD##_val repeated field, with value 10*10.
#define CHECK_RESULT(TYPE, FIELD)                                             \
  if (node.name() == #TYPE "_mul") {                                          \
    EXPECT_EQ(5,                                                              \
              node.attr().at("value").tensor().tensor_shape().dim(0).size()); \
    EXPECT_EQ(1, node.attr().at("value").tensor().FIELD##_val_size());        \
    EXPECT_EQ(10 * 10, node.attr().at("value").tensor().FIELD##_val(0));      \
  }

    CHECK_RESULT(float, float);
    CHECK_RESULT(double, double);
    CHECK_RESULT(int64, int64);
    CHECK_RESULT(int32, int);
    CHECK_RESULT(int16, int);
    CHECK_RESULT(int8, int);
    CHECK_RESULT(uint8, int);
#undef CHECK_RESULT

    if (node.name() == "bool_and") {
      EXPECT_EQ(5,
                node.attr().at("value").tensor().tensor_shape().dim(0).size());
      EXPECT_EQ(1, node.attr().at("value").tensor().bool_val_size());
      EXPECT_EQ(true && true, node.attr().at("value").tensor().bool_val(0));
    }
  }
}
    545 
    546 TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
    547   // Build a simple graph with a few trivially prunable ops.
    548   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    549 
    550   Output a = ops::Const(s.WithOpName("a"), 10, {5});
    551   auto b = ops::Unique(s.WithOpName("b"), {a});
    552   Output c = ops::Identity(s.WithOpName("c"), {b.y});
    553   Output d = ops::Identity(s.WithOpName("d"), {b.idx});
    554   Output e = ops::Identity(s.WithOpName("e"), {c});
    555   Output f = ops::Identity(s.WithOpName("f"), {d});
    556 
    557   GrapplerItem item;
    558   item.fetch.push_back("e");
    559   item.fetch.push_back("f");
    560   TF_CHECK_OK(s.ToGraphDef(&item.graph));
    561 
    562   ConstantFolding fold(nullptr /* cpu_device */);
    563   GraphDef output;
    564   Status status = fold.Optimize(nullptr, item, &output);
    565   TF_EXPECT_OK(status);
    566 
    567   EXPECT_EQ(2, output.node_size());
    568 
    569   const NodeDef& new_c = output.node(0);
    570   EXPECT_EQ("e", new_c.name());
    571   EXPECT_EQ("Const", new_c.op());
    572 
    573   const NodeDef& new_d = output.node(1);
    574   EXPECT_EQ("f", new_d.name());
    575   EXPECT_EQ("Const", new_d.op());
    576 
    577   std::vector<string> fetch = {"e", "f"};
    578   auto tensors_expected = EvaluateNodes(item.graph, fetch);
    579   auto tensors = EvaluateNodes(output, fetch);
    580   EXPECT_EQ(fetch.size(), tensors_expected.size());
    581   EXPECT_EQ(fetch.size(), tensors.size());
    582   for (int i = 0; i < fetch.size(); i++) {
    583     test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
    584   }
    585 }
    586 
    587 TEST_F(ConstantFoldingTest, ControlDependencies) {
    588   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    589   Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
    590   Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
    591   Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
    592   Output c =
    593       ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
    594   Output i1 = ops::Identity(scope.WithOpName("i1"), {c});
    595   Output i2 =
    596       ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1});
    597   Output i3 = ops::Identity(scope.WithOpName("e"), {i2});
    598 
    599   GrapplerItem item;
    600   item.fetch.push_back("e");
    601   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
    602 
    603   ConstantFolding fold(nullptr /* cpu_device */);
    604   GraphDef output;
    605   Status status = fold.Optimize(nullptr, item, &output);
    606   TF_EXPECT_OK(status);
    607 
    608   std::vector<string> expected_nodes = {"dflt", "p1", "p2", "e"};
    609   EXPECT_EQ(output.node_size(), expected_nodes.size());
    610   int i = 0;
    611   int found = 0;
    612   for (const auto& node : output.node()) {
    613     EXPECT_EQ(expected_nodes[i], output.node(i).name());
    614     i++;
    615     if (node.name() == "e") {
    616       EXPECT_EQ("Const", node.op());
    617       ++found;
    618       auto folded = EvaluateNodes(output, {"e"});
    619       auto expected = EvaluateNodes(item.graph, {"e"});
    620       EXPECT_EQ(1, expected.size());
    621       EXPECT_EQ(1, folded.size());
    622       test::ExpectTensorEqual<int>(folded[0], expected[0]);
    623       EXPECT_EQ(2, node.input_size());
    624       EXPECT_EQ("^p1", node.input(0));
    625       EXPECT_EQ("^p2", node.input(1));
    626     }
    627   }
    628   EXPECT_EQ(1, found);
    629 }
    630 
// Same graph as the ControlDependencies test but with no fetch nodes: no
// pruning happens, so every node survives, and the identities i1/i2 are
// folded in place to Consts that accumulate the control dependencies.
TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
  Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
  Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
  Output i1 = ops::Identity(scope.WithOpName("i1"), {c});
  Output i2 =
      ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1});
  Output i3 = ops::Identity(scope.WithOpName("e"), {i2});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Without a fetch set nothing is pruned; node order is also preserved.
  std::vector<string> expected_nodes = {"dflt", "p1", "p2", "c",
                                        "i1",   "i2", "e"};
  EXPECT_EQ(output.node_size(), expected_nodes.size());
  int i = 0;
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_EQ(expected_nodes[i], output.node(i).name());
    i++;
    if (node.name() == "i1") {
      // Folded; inherits c's control dependency on p1.
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"i1"});
      auto expected = EvaluateNodes(item.graph, {"i1"});
      EXPECT_EQ(1, expected.size());
      EXPECT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
    }
    if (node.name() == "i2") {
      // Folded; carries both the inherited (^p1) and its own (^p2) deps.
      EXPECT_EQ("Const", node.op());
      ++found;
      auto folded = EvaluateNodes(output, {"i2"});
      auto expected = EvaluateNodes(item.graph, {"i2"});
      EXPECT_EQ(1, expected.size());
      EXPECT_EQ(1, folded.size());
      test::ExpectTensorEqual<int>(folded[0], expected[0]);
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("^p1", node.input(0));
      EXPECT_EQ("^p2", node.input(1));
    }
  }
  EXPECT_EQ(2, found);
}
    685 
    686 TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) {
    687   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    688   Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1});
    689   Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1});
    690   Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1});
    691   Output c =
    692       ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3});
    693   Output i1 = ops::Identity(scope.WithOpName("i1")
    694                                 .WithControlDependencies(p2)
    695                                 .WithControlDependencies(p1),
    696                             {c});
    697   Output i2 = ops::Identity(scope.WithOpName("i2"), {i1});
    698 
    699   GrapplerItem item;
    700   item.fetch.push_back("i2");
    701   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
    702 
    703   ConstantFolding fold(nullptr /* cpu_device */);
    704   GraphDef output;
    705   Status status = fold.Optimize(nullptr, item, &output);
    706   TF_EXPECT_OK(status);
    707 
    708   std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i2"};
    709   EXPECT_EQ(output.node_size(), expected_nodes.size());
    710   int i = 0;
    711   for (const auto& node : output.node()) {
    712     EXPECT_EQ(expected_nodes[i], output.node(i).name());
    713     i++;
    714     if (node.name() == "i2") {
    715       EXPECT_EQ("Const", node.op());
    716       EXPECT_EQ(2, node.input_size());
    717       EXPECT_EQ("^p1", node.input(0));
    718       EXPECT_EQ("^p2", node.input(1));
    719     }
    720   }
    721 }
    722 
    723 TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
    724   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    725   // Add a DynamicPartition node to the graph
    726   Output input = ops::Const(scope.WithOpName("in0"), 314, {3, 4, 5});
    727   Output indices = ops::Const(scope.WithOpName("indices"), 1, {3, 4});
    728   int num_partitions = 4;
    729   ops::DynamicPartition part(scope.WithOpName("partition"), input, indices,
    730                              num_partitions);
    731 
    732   std::vector<string> outputs;
    733   for (int i = 0; i < num_partitions; ++i) {
    734     string part_out_name = strings::StrCat("part_out", i);
    735     ops::Identity partition_out(scope.WithOpName(part_out_name),
    736                                 {part.outputs[i]});
    737     outputs.push_back(part_out_name);
    738   }
    739 
    740   GrapplerItem item;
    741   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
    742 
    743   // Add a ConcatOffset node to the graph
    744   Tensor initial_val(DT_INT32, TensorShape({3}));
    745   test::FillIota<int>(&initial_val, 7);
    746   for (int i = 1; i < 5; ++i) {
    747     TF_CHECK_OK(NodeDefBuilder(strings::StrCat("in", i), "Const")
    748                     .Attr("dtype", DT_INT32)
    749                     .Attr("value", initial_val)
    750                     .Finalize(item.graph.add_node()));
    751   }
    752   Tensor concat_dim(DT_INT32, TensorShape({}));
    753   test::FillIota<int>(&concat_dim, 0);
    754   TF_CHECK_OK(NodeDefBuilder("concat_dim", "Const")
    755                   .Attr("dtype", DT_INT32)
    756                   .Attr("value", concat_dim)
    757                   .Finalize(item.graph.add_node()));
    758 
    759   TF_CHECK_OK(NodeDefBuilder("concat_offsets", "ConcatOffset")
    760                   .Input("concat_dim", 0, DT_INT32)
    761                   .Input({NodeDefBuilder::NodeOut("in1", 0, DT_INT32),
    762                           NodeDefBuilder::NodeOut("in2", 0, DT_INT32),
    763                           NodeDefBuilder::NodeOut("in3", 0, DT_INT32),
    764                           NodeDefBuilder::NodeOut("in4", 0, DT_INT32)})
    765                   .Finalize(item.graph.add_node()));
    766 
    767   for (int i = 0; i < 4; ++i) {
    768     string concat_offset_out_name = strings::StrCat("concat_offset_out", i);
    769     TF_CHECK_OK(NodeDefBuilder(concat_offset_out_name, "Identity")
    770                     .Attr("T", DT_INT32)
    771                     .Input("concat_offsets", i, DT_INT32)
    772                     .Finalize(item.graph.add_node()));
    773     outputs.push_back(concat_offset_out_name);
    774   }
    775 
    776   item.fetch = outputs;
    777   ConstantFolding fold(nullptr /* cpu_device */);
    778   GraphDef output;
    779   Status status = fold.Optimize(nullptr, item, &output);
    780   TF_EXPECT_OK(status);
    781 
    782   int constant_folded = 0;
    783   for (const auto& node : output.node()) {
    784     if (node.name().find("part_out") != string::npos ||
    785         node.name().find("concat_offset_out") != string::npos) {
    786       ++constant_folded;
    787       EXPECT_EQ("Const", node.op());
    788     }
    789   }
    790   EXPECT_EQ(8, constant_folded);
    791 
    792   auto expected = EvaluateNodes(item.graph, outputs);
    793   auto optimized = EvaluateNodes(output, outputs);
    794   ASSERT_EQ(expected.size(), optimized.size());
    795   for (int i = 0; i < expected.size(); ++i) {
    796     test::ExpectTensorEqual<int>(expected[i], optimized[i]);
    797   }
    798 }
    799 
TEST_F(ConstantFoldingTest, ShapeMaterialization) {
  // Rank/Shape/Size of variables with fully known shapes can be materialized
  // as constants even though the variables have no value, letting the whole
  // arithmetic chain fold down to a single fetched constant.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
  Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
  Output rank = ops::Rank(scope.WithOpName("rank"), v1);
  Output shape = ops::Shape(scope.WithOpName("shape"), v2);
  Output size = ops::Size(scope.WithOpName("size"), v3);
  Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
  Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);

  GrapplerItem item;
  item.fetch.push_back("p2");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "p2") {
      ++found;
      EXPECT_EQ("Const", node.op());
      // The folded constant keeps control edges on the variables whose
      // shapes were materialized, preserving execution-order semantics.
      EXPECT_EQ(3, node.input_size());
      EXPECT_EQ("^v3", node.input(0));
      EXPECT_EQ("^v1", node.input(1));
      EXPECT_EQ("^v2", node.input(2));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      // rank = 1, shape = (5, 7), size = 143 = 11*13
      // p2 = (715, 1001) = (5*143, 7*143)
      EXPECT_EQ(715, value.flat<int>()(0));
      EXPECT_EQ(1001, value.flat<int>()(1));
    }
  }
  // The fetched node must have been inspected.
  EXPECT_EQ(1, found);
}
    839 
TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) {
  // Same graph as ShapeMaterialization but with no fetch set: the shape ops
  // themselves (rank/shape/size) are rewritten as constants in place, each
  // keeping a control edge on the variable whose shape it reports.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
  Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
  Output rank = ops::Rank(scope.WithOpName("rank"), v1);
  Output shape = ops::Shape(scope.WithOpName("shape"), v2);
  Output size = ops::Size(scope.WithOpName("size"), v3);
  Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
  Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);

  GrapplerItem item;
  // Note: item.fetch is deliberately left empty.
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "size") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v3", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      // Size(v3) = number of elements = 11 * 13.
      EXPECT_EQ(11 * 13, value.flat<int>()(0));
    } else if (node.name() == "rank") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v1", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      // Rank(v1) = 1 since v1 is a vector.
      EXPECT_EQ(1, value.flat<int>()(0));
    } else if (node.name() == "shape") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v2", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      // Shape(v2) = (5, 7).
      EXPECT_EQ(5, value.flat<int>()(0));
      EXPECT_EQ(7, value.flat<int>()(1));
    }
  }
  // All three shape ops must have been rewritten and inspected.
  EXPECT_EQ(3, found);
}
    890 
TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN) {
  // For a multi-output ShapeN, only the outputs whose input shapes are fully
  // known (here v3 = {4, 6}) get a materialized constant; consumers of the
  // other outputs keep reading from ShapeN directly. v1 is partially unknown
  // ({3, -1}) and v2 is scalar, so no constants are created for ports 0/1.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {}, DT_FLOAT);
  Output v3 = ops::Variable(scope.WithOpName("v3"), {4, 6}, DT_FLOAT);
  auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2, v3});
  Output i1a = ops::Identity(scope.WithOpName("i1a"), s[0]);
  Output i1b = ops::Identity(scope.WithOpName("i1b"), s[0]);
  Output i2a = ops::Identity(scope.WithOpName("i2a"), s[1]);
  Output i2b = ops::Identity(scope.WithOpName("i2b"), s[1]);
  Output i2c = ops::Identity(scope.WithOpName("i2c"), s[1]);
  Output i3a = ops::Identity(scope.WithOpName("i3a"), s[2]);
  Output i3b = ops::Identity(scope.WithOpName("i3b"), s[2]);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  int found = 0;
  for (const auto& node : output.node()) {
    // No materialized constant may exist for ports 0 and 1.
    EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
              node.name());
    EXPECT_NE(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
              node.name());
    if (node.name() == "i1a" || node.name() == "i1b") {
      ++found;
      // Still reads ShapeN's first output directly.
      EXPECT_EQ("s", node.input(0));
    }
    if (node.name() == "i2a" || node.name() == "i2b" || node.name() == "i2c") {
      ++found;
      // Still reads ShapeN's second output directly.
      EXPECT_EQ("s:1", node.input(0));
    }
    if (node.name() == "i3a" || node.name() == "i3b") {
      ++found;
      // Rewired to the materialized constant for port 2.
      EXPECT_EQ(AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst),
                node.input(0));
    }
    if (node.name() == "s") {
      ++found;
      // The ShapeN node itself is kept intact.
      EXPECT_EQ("ShapeN", node.op());
      EXPECT_EQ("v1", node.input(0));
      EXPECT_EQ("v2", node.input(1));
      EXPECT_EQ("v3", node.input(2));
    }
    if (node.name() ==
        AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst)) {
      ++found;
      EXPECT_EQ("Const", node.op());
      // Control edge on ShapeN preserves execution ordering.
      EXPECT_EQ("^s", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      // Materialized Shape(v3) = (4, 6).
      EXPECT_EQ(4, value.flat<int>()(0));
      EXPECT_EQ(6, value.flat<int>()(1));
    }
  }
  // 2 (i1*) + 3 (i2*) + 2 (i3*) + ShapeN + materialized const = 9 checks.
  EXPECT_EQ(9, found);
}
    951 
TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
  // Folding around Switch nodes with no fetch set:
  //  - shape ops downstream of a data-dependent Switch (s1) fold using shape
  //    information, anchored by control edges;
  //  - a Switch with a constant predicate (s2) lets the taken branch (i2)
  //    fold to a constant while the dead branch (i3) is left untouched.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
  ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
  ops::Identity i(scope.WithOpName("i"), s1.output_true);
  ops::Size size(scope.WithOpName("size"), i);
  ops::Square p1(scope.WithOpName("p1"), rank);
  ops::Square p2(scope.WithOpName("p2"), size);
  ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});

  Output predicate =
      ops::Const(scope.WithOpName("false"), false, TensorShape({}));
  Output constant =
      ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
  ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
  ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
  ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
  ops::Merge m2(scope.WithOpName("m2"),
                {statically_known.output, never_generated.output});

  GrapplerItem item;
  // Note: item.fetch is deliberately left empty, so nothing may be pruned.
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // "ConstantFoldingCtrl/switch_0" is the control anchor the optimizer adds
  // for constants folded on a Switch branch.
  std::set<string> present_nodes = {"v_in",     "v_ctrl",
                                    "switch",   "i",
                                    "p1",       "p2",
                                    "m",        "false",
                                    "constant", "switch2",
                                    "i2",       "i3",
                                    "m2",       "ConstantFoldingCtrl/switch_0",
                                    "rank",     "size"};
  std::set<string> not_present_nodes = {"ConstantFolding/switch2-0"};
  EXPECT_EQ(present_nodes.size(), output.node_size());
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end());
    EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end());
    present_nodes.erase(node.name());
    not_present_nodes.erase(node.name());
    if (node.name() == "rank") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      // Anchored on the control node added for the false branch of "switch".
      EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
    }
    if (node.name() == "size") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      // Anchored on the identity that fed it.
      EXPECT_EQ("^i", node.input(0));
    }
    if (node.name() == "i2") {
      ++found;
      // Predicate is constant false, so the false branch value is known.
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(0, node.input_size());
    }
    if (node.name() == "i3") {
      ++found;
      // The true branch is never generated; left as-is.
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("switch2:1", node.input(0));
    }
  }
  EXPECT_EQ(4, found);
}
   1024 
TEST_F(ConstantFoldingTest, SwitchNodes) {
  // Same graph as SwitchNodesEmptyFetch, but with "m" and "m2" fetched: the
  // folded shape ops "rank" and "size" may now be pruned away entirely, and
  // the statically-known Switch branch (i2) still folds to a constant.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
  ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
  ops::Identity i(scope.WithOpName("i"), s1.output_true);
  ops::Size size(scope.WithOpName("size"), i);
  ops::Square p1(scope.WithOpName("p1"), rank);
  ops::Square p2(scope.WithOpName("p2"), size);
  ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});

  Output predicate =
      ops::Const(scope.WithOpName("false"), false, TensorShape({}));
  Output constant =
      ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
  ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
  ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
  ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
  ops::Merge m2(scope.WithOpName("m2"),
                {statically_known.output, never_generated.output});

  GrapplerItem item;
  item.fetch.push_back("m");
  item.fetch.push_back("m2");

  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // "rank" and "size" were folded and then pruned (they are not fetched).
  std::set<string> present_nodes = {"v_in",     "v_ctrl",
                                    "switch",   "i",
                                    "p1",       "p2",
                                    "m",        "false",
                                    "constant", "switch2",
                                    "i2",       "i3",
                                    "m2",       "ConstantFoldingCtrl/switch_0"};
  std::set<string> not_present_nodes = {"rank", "size",
                                        "ConstantFolding/switch2-0"};
  EXPECT_EQ(present_nodes.size(), output.node_size());

  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end());
    EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end());
    present_nodes.erase(node.name());
    not_present_nodes.erase(node.name());
    if (node.name() == "i2") {
      ++found;
      // Constant-false predicate makes the false branch statically known.
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(0, node.input_size());
    }
    if (node.name() == "i3") {
      ++found;
      // The dead (true) branch must not be rewritten.
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("switch2:1", node.input(0));
    }
  }
  EXPECT_EQ(2, found);
}
   1088 
TEST_F(ConstantFoldingTest, MergeNodes) {
  // A Merge is foldable when at least one of its inputs is a constant that is
  // guaranteed to be produced (const2 has no control dependencies). m1 is
  // such a case; m2's inputs all hang off control deps and m3's inputs are
  // random, so neither may fold.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output x =
      ops::RandomNormal(scope.WithOpName("x"), {3, 5}, DataType::DT_FLOAT);
  Output y =
      ops::RandomNormal(scope.WithOpName("y"), {3, 5}, DataType::DT_FLOAT);
  Output const1 =
      ops::Const(scope.WithOpName("const1").WithControlDependencies(x), 2.7f,
                 TensorShape({3, 5}));
  Output const2 =
      ops::Const(scope.WithOpName("const2"), 3.14f, TensorShape({3, 5}));
  Output const3 =
      ops::Const(scope.WithOpName("const3").WithControlDependencies(x), 3.14f,
                 TensorShape({3, 5}));

  // Create 3 merge nodes: m1 is foldable, m2 and m3 aren't.
  ops::Merge m1(scope.WithOpName("m1"), {x, const1, const2});
  ops::Merge m2(scope.WithOpName("m2"), {const1, const3});
  ops::Merge m3(scope.WithOpName("m3"), {x, y});

  // Fetch both outputs of every merge: the value and the chosen input index.
  ops::Identity out1(scope.WithOpName("out1"), m1.output);
  ops::Identity idx1(scope.WithOpName("idx1"), m1.value_index);
  ops::Identity out2(scope.WithOpName("out2"), m2.output);
  ops::Identity idx2(scope.WithOpName("idx2"), m2.value_index);
  ops::Identity out3(scope.WithOpName("out3"), m3.output);
  ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index);

  GrapplerItem item;
  item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found_nodes = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "out1") {
      // m1's consumers now depend on it only through a control edge; the
      // value comes from the materialized "ConstantFolding/m1" constant.
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "idx1") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "ConstantFolding/m1") {
      // Materialized constant carrying m1's folded output value.
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "ConstantFolding/m1_index") {
      // Materialized constant carrying m1's chosen input index.
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "out2") {
      // m2 and m3 are untouched: consumers still read their outputs directly.
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m2", node.input(0));
      ++found_nodes;
    } else if (node.name() == "idx2") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m2:1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "out3") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m3", node.input(0));
      ++found_nodes;
    } else if (node.name() == "idx3") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m3:1", node.input(0));
      ++found_nodes;
    }
  }
  // Make sure the graph contains all the nodes we're expecting.
  EXPECT_EQ(6, found_nodes);

  std::vector<string> fetch = {"out1", "idx1"};
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(2, tensors.size());
  const Tensor& out_value = tensors[0];
  // Folded value is const2 (3.14f broadcast over 3x5), the input that is
  // guaranteed to be available.
  EXPECT_EQ(3 * 5, out_value.NumElements());
  for (int i = 0; i < 3 * 5; ++i) {
    EXPECT_EQ(3.14f, out_value.flat<float>()(i));
  }
  const Tensor& out_idx = tensors[1];
  // const2 is m1's third input, hence index 2.
  EXPECT_EQ(1, out_idx.NumElements());
  EXPECT_EQ(2, out_idx.flat<int32>()(0));
}
   1179 
   1180 TEST_F(ConstantFoldingTest, NoOpReduction) {
   1181   // Build a simple graph with a reduction that can be reduced to the identity.
   1182   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
   1183 
   1184   Output v = ops::Variable(scope.WithOpName("v"), {3, 5, 7}, DT_FLOAT);
   1185   Output c =
   1186       ops::Const(scope.WithOpName("c").WithControlDependencies(v), 0, {0});
   1187   Output i = ops::Identity(scope.WithOpName("i"), c);
   1188   Output p = ops::Prod(scope.WithOpName("p"), v, i);
   1189   Output s = ops::Square(scope.WithOpName("s"), p);
   1190 
   1191   GrapplerItem item;
   1192   item.fetch.push_back("s");
   1193   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
   1194 
   1195   ConstantFolding fold(nullptr /* cpu_device */);
   1196   GraphDef output;
   1197   Status status = fold.Optimize(nullptr, item, &output);
   1198   TF_EXPECT_OK(status);
   1199 
   1200   bool found = false;
   1201   for (const auto& node : output.node()) {
   1202     if (node.name() == "p") {
   1203       found = true;
   1204       EXPECT_EQ("Identity", node.op());
   1205       EXPECT_EQ(2, node.input_size());
   1206       EXPECT_EQ("v", node.input(0));
   1207       EXPECT_EQ("^i", node.input(1));
   1208     }
   1209   }
   1210   EXPECT_TRUE(found);
   1211 }
   1212 
TEST_F(ConstantFoldingTest, NoOpReshape) {
  // Build a simple graph with a reshape that can be reduced to the identity.
  // A Reshape whose target shape provably matches the input's shape (even
  // with a -1 wildcard) should be rewritten as an Identity.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  // A reshape than can be optimized
  Output d1 = ops::Const(scope.WithOpName("d1"), 3.14f, {17});
  Output v1 = ops::Variable(scope.WithOpName("v1"), {17}, DT_FLOAT);
  Output c1 =
      ops::Const(scope.WithOpName("c1").WithControlDependencies(v1), 17, {1});
  Output i1 = ops::Identity(scope.WithOpName("i1"), c1);
  Output r1 =
      ops::Reshape(scope.WithOpName("r1").WithControlDependencies(d1), v1, i1);
  Output s1 = ops::Square(scope.WithOpName("s1"), r1);

  // A multi dimensional reshape than can be optimized
  Output v3 = ops::Variable(scope.WithOpName("v3"), {5, 5, 5}, DT_FLOAT);
  Output c3 =
      ops::Const(scope.WithOpName("c3").WithControlDependencies(v3), 5, {3});
  Output i3 = ops::Identity(scope.WithOpName("i3"), c3);
  Output r3 = ops::Reshape(scope.WithOpName("r3"), v3, i3);
  Output s3 = ops::Square(scope.WithOpName("s3"), r3);

  // A multi dimensional partially defined reshape than can be optimized
  Output v4 = ops::Variable(scope.WithOpName("v4"), {5, 5, 5}, DT_FLOAT);
  Output c4 = ops::Const(scope.WithOpName("c4").WithControlDependencies(v4),
                         {5, -1, 5}, {3});
  Output i4 = ops::Identity(scope.WithOpName("i4"), c4);
  Output r4 = ops::Reshape(scope.WithOpName("r4"), v4, i4);
  Output s4 = ops::Square(scope.WithOpName("s4"), r4);

  // A reshape that can't be optimized: {17, 1} -> {17} changes the rank.
  Output v2 = ops::Variable(scope.WithOpName("v2"), {17, 1}, DT_FLOAT);
  Output c2 =
      ops::Const(scope.WithOpName("c2").WithControlDependencies(v2), 17, {1});
  Output r2 = ops::Reshape(scope.WithOpName("r2"), v2, c2);
  Output s2 = ops::Square(scope.WithOpName("s2"), r2);

  GrapplerItem item;
  item.fetch = {"s1", "s2", "s3", "s4"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "r1") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("v1", node.input(0));
      // Control edges preserve the order implied by the dropped shape input
      // and the original control dependency.
      EXPECT_EQ("^i1", node.input(1));
      EXPECT_EQ("^d1", node.input(2));
    } else if (node.name() == "r3") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("v3", node.input(0));
      EXPECT_EQ("^i3", node.input(1));
    } else if (node.name() == "r4") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("v4", node.input(0));
      EXPECT_EQ("^i4", node.input(1));
    } else if (node.name() == "r2") {
      ++found;
      // The rank-changing reshape must remain a real Reshape.
      EXPECT_EQ("Reshape", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("v2", node.input(0));
      EXPECT_EQ("c2", node.input(1));
    }
  }
  EXPECT_EQ(4, found);
}
   1290 
   1291 TEST_F(ConstantFoldingTest, Packing) {
   1292   // Build a simple graph with a large constant that can be folded.
   1293   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
   1294   Output c = ops::Const(scope.WithOpName("c"), 3.14f, {1000});
   1295   Output i1 = ops::Identity(scope.WithOpName("i1"), c);
   1296   Output i2 = ops::Identity(scope.WithOpName("i2"), c);
   1297 
   1298   GrapplerItem item;
   1299   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
   1300 
   1301   ConstantFolding fold(nullptr /* cpu_device */);
   1302   GraphDef output;
   1303   Status status = fold.Optimize(nullptr, item, &output);
   1304   TF_EXPECT_OK(status);
   1305 
   1306   // Make sure that the representation of the folded constant is space
   1307   // efficient: in particular, the whole message should be smaller than 8k (the
   1308   // size needed to naively encode 1000 floats folded twice).
   1309   EXPECT_GT(8000, output.ByteSizeLong());
   1310 }
   1311 
TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
  // When both inputs of BroadcastGradientArgs are the shapes of tensors known
  // to have identical (even if unknown) dimensions, the reduction index lists
  // are provably empty and can be materialized as empty constants ("f").
  // When the broadcast outcome depends on unknown dimensions ("i"), nothing
  // may be materialized.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output b = ops::Square(s.WithOpName("b"), a);
  Output c = ops::Mul(s.WithOpName("c"), a, b);
  Output d = ops::Shape(s.WithOpName("d"), a);
  Output e = ops::Shape(s.WithOpName("e"), b);

  // d and e are shapes of tensors with matching symbolic dimensions.
  auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
  Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
  Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);

  // h is the shape of a rank-1 placeholder; broadcasting against the unknown
  // {-1, -1} shape cannot be resolved statically.
  Output g = ops::Placeholder(s.WithOpName("g"), DT_FLOAT,
                              ops::Placeholder::Shape(PartialTensorShape({1})));
  Output h = ops::Shape(s.WithOpName("h"), g);
  auto i = ops::internal::BroadcastGradientArgs(s.WithOpName("i"), d, h);
  Output p1 = ops::Identity(s.WithOpName("p1"), i.r0);
  Output p2 = ops::Identity(s.WithOpName("p2"), i.r1);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Run a second time to make sure the optimization is idempotent.
  item.graph.Swap(&output);
  status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "o1") {
      ++found;
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("ConstantFolding/f-bcastargs-0", node.input(0));
    } else if (node.name() == "o2") {
      ++found;
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("ConstantFolding/f-bcastargs-1", node.input(0));
    } else if (node.name() == "ConstantFolding/f-bcastargs-0") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^f", node.input(0));
      // The materialized reduction index list is empty.
      EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
                       .num_elements());
    } else if (node.name() == "ConstantFolding/f-bcastargs-1") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^f", node.input(0));
      EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
                       .num_elements());
    } else if (node.name() == "p1") {
      ++found;
      // Consumers of the unresolvable case still read "i" directly.
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("i", node.input(0));
    } else if (node.name() == "p2") {
      ++found;
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("i:1", node.input(0));
    }
  }
  EXPECT_EQ(6, found);
}
   1382 
   1383 TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
   1384   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   1385   Output input =
   1386       ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
   1387                        ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
   1388   Output indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
   1389   Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
   1390   Output size = ops::Const(s.WithOpName("size"), 1, {1});
   1391   Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
   1392 
   1393   GrapplerItem item;
   1394   TF_CHECK_OK(s.ToGraphDef(&item.graph));
   1395   item.fetch.push_back("reshape");
   1396 
   1397   ConstantFolding fold(nullptr /* cpu_device */);
   1398   GraphDef output;
   1399   Status status = fold.Optimize(nullptr, item, &output);
   1400   TF_EXPECT_OK(status);
   1401 
   1402   // Run a second time to make sure the optimization is idempotent.
   1403   item.graph.Swap(&output);
   1404   status = fold.Optimize(nullptr, item, &output);
   1405   TF_EXPECT_OK(status);
   1406 
   1407   int found = 0;
   1408   for (const auto& node : output.node()) {
   1409     if (node.name() == "ConstantFolding/sum-reduction_indices") {
   1410       ++found;
   1411       EXPECT_EQ("Const", node.op());
   1412       EXPECT_EQ("^indices", node.input(0));
   1413       EXPECT_EQ(2, TensorShape(node.attr().at("value").tensor().tensor_shape())
   1414                        .num_elements());
   1415     } else if (node.name() == "sum") {
   1416       ++found;
   1417       EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
   1418     } else if (node.name() == "indices") {
   1419       ++found;
   1420     }
   1421   }
   1422   EXPECT_EQ(3, found);
   1423 }
   1424 
   1425 }  // namespace
   1426 }  // namespace grappler
   1427 }  // namespace tensorflow
   1428 
   1429 //  LocalWords:  NewRootScope
   1430