/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace grappler {
namespace {

// Test fixture for the ConstantFolding graph optimizer. GrapplerTest supplies
// graph-evaluation helpers (e.g. EvaluateNodes) used to compare the optimized
// graph's outputs against the original graph's outputs.
class ConstantFoldingTest : public GrapplerTest {};

TEST_F(ConstantFoldingTest, SimpleFolding) {
  // Build a simple graph with a few trivially prunable ops.
35 tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 36 37 Output a = ops::Const(s.WithOpName("a"), 1.0f, {1}); 38 Output b = ops::Const(s.WithOpName("b"), 2.0f, {1}); 39 Output c = ops::AddN(s.WithOpName("c").WithDevice("/CPU:0"), {a, b}); 40 Output d = ops::AddN(s.WithOpName("d"), {b, c}); 41 42 GrapplerItem item; 43 item.fetch.push_back("d"); 44 TF_CHECK_OK(s.ToGraphDef(&item.graph)); 45 46 ConstantFolding fold(nullptr /* cpu_device */); 47 GraphDef output; 48 Status status = fold.Optimize(nullptr, item, &output); 49 TF_EXPECT_OK(status); 50 51 EXPECT_EQ(1, output.node_size()); 52 53 const NodeDef& node_d = output.node(0); 54 EXPECT_EQ("d", node_d.name()); 55 EXPECT_EQ("Const", node_d.op()); 56 57 std::vector<string> fetch = {"d"}; 58 auto tensors_expected = EvaluateNodes(item.graph, fetch); 59 auto tensors = EvaluateNodes(output, fetch); 60 EXPECT_EQ(1, tensors_expected.size()); 61 EXPECT_EQ(1, tensors.size()); 62 test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]); 63 } 64 65 TEST_F(ConstantFoldingTest, AddTree) { 66 tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 67 68 Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2}); 69 Output c3 = ops::Const(s.WithOpName("c3"), 3.0f, {2}); 70 Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, 71 ops::Placeholder::Shape(TensorShape({2, 2}))); 72 Output add_child = ops::Add(s.WithOpName("add_child"), c2, x); 73 Output c1 = ops::Const(s.WithOpName("c1").WithControlDependencies(add_child), 74 1.0f, {1}); 75 Output add_parent = ops::Add(s.WithOpName("add_parent"), c1, add_child); 76 77 Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT, 78 ops::Placeholder::Shape(TensorShape({2, 2}))); 79 Output c4 = ops::Const(s.WithOpName("c4"), 4.0f, {2}); 80 Output c5 = ops::Const(s.WithOpName("c5"), 5.0f, {2}); 81 Output c20 = ops::Const(s.WithOpName("c20"), 20.0f, {2}); 82 Output mul_child = ops::Mul(s.WithOpName("mul_child"), c4, y); 83 Output mul_parent = ops::Mul(s.WithOpName("mul_parent"), 
c5, mul_child); 84 Output addmul_child = ops::Add(s.WithOpName("addmul_child"), c4, x); 85 Output addmul_parent = 86 ops::Mul(s.WithOpName("addmul_parent"), c5, addmul_child); 87 88 GrapplerItem item; 89 item.fetch = {"add_parent", "mul_parent", "addmul_parent"}; 90 TF_CHECK_OK(s.ToGraphDef(&item.graph)); 91 92 ConstantFolding fold(nullptr /* cpu_device */); 93 GraphDef output; 94 Status status = fold.Optimize(nullptr, item, &output); 95 TF_EXPECT_OK(status); 96 97 // We expect the following rewrite(s) to occur: 98 // 99 // + + + 100 // / \ / \ / \ 101 // 1.0 + --> x + --> x 3.0 102 // / \ / \ 103 // 2.0 x 1.0 2.0 104 // 105 // * * * 106 // / \ / \ / \ 107 // 4.0 * --> y * --> y 20.0 108 // / \ / \ 109 // 5.0 y 4.0 5.0 110 111 EXPECT_EQ(11, output.node_size()); 112 for (const auto& node : output.node()) { 113 if (node.name() == "add_child") { 114 EXPECT_EQ("Const", node.op()); 115 TensorProto t = node.attr().at("value").tensor(); 116 EXPECT_EQ(1, t.tensor_shape().dim_size()); 117 EXPECT_EQ(2, t.tensor_shape().dim(0).size()); 118 } else if (node.name() == "add_parent") { 119 EXPECT_EQ("Add", node.op()); 120 EXPECT_EQ(2, node.input_size()); 121 EXPECT_EQ("x", node.input(0)); 122 EXPECT_EQ("add_child", node.input(1)); 123 } else if (node.name() == "mul_child") { 124 EXPECT_EQ("Const", node.op()); 125 TensorProto t = node.attr().at("value").tensor(); 126 EXPECT_EQ(1, t.tensor_shape().dim_size()); 127 EXPECT_EQ(2, t.tensor_shape().dim(0).size()); 128 } else if (node.name() == "mul_parent") { 129 EXPECT_EQ("Mul", node.op()); 130 EXPECT_EQ(2, node.input_size()); 131 EXPECT_EQ("y", node.input(0)); 132 EXPECT_EQ("mul_child", node.input(1)); 133 } else if (node.name() == "addmul_child") { 134 // Unchanged. 135 EXPECT_EQ("Add", node.op()); 136 EXPECT_EQ(2, node.input_size()); 137 EXPECT_EQ("c4", node.input(0)); 138 EXPECT_EQ("x", node.input(1)); 139 } 140 } 141 142 // Check that the result nodes have the expected value. 
143 std::vector<string> fetch = {"c3", "c20"}; 144 auto tensor_expected = EvaluateNodes(item.graph, fetch); 145 EXPECT_EQ(fetch.size(), tensor_expected.size()); 146 fetch = {"add_child", "mul_child"}; 147 auto tensors = EvaluateNodes(output, fetch); 148 EXPECT_EQ(fetch.size(), tensors.size()); 149 for (int i = 0; i < fetch.size(); i++) { 150 test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]); 151 } 152 } 153 154 TEST_F(ConstantFoldingTest, NeutralElement) { 155 for (bool use_const : {true, false}) { 156 tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 157 Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, 158 ops::Placeholder::Shape(TensorShape({2, 2}))); 159 Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT, 160 ops::Placeholder::Shape(TensorShape({2, 2}))); 161 Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT, 162 ops::Placeholder::Shape(TensorShape({3, 2}))); 163 Output b = ops::Placeholder(s.WithOpName("b"), DT_FLOAT, 164 ops::Placeholder::Shape(TensorShape({2, 3}))); 165 Output bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT, 166 ops::Placeholder::Shape(TensorShape({2}))); 167 Output zeros = !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x) 168 : ops::Const(s.WithOpName("zeros"), 0.0f, {2, 2}); 169 Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2}); 170 Output ones = !use_const ? 
                        ops::OnesLike(s.WithOpName("ones"), x)
                              : ops::Const(s.WithOpName("ones"), 1.0f, {2, 2});
    // Ops that should be simplified because one operand is 0 or 1.
    Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
    Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
    Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
    Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
    Output mul5 = ops::Mul(s.WithOpName("mul5"), x, zeros_1d);
    Output mul6 = ops::Mul(s.WithOpName("mul6"), zeros_1d, y);
    Output div1 = ops::Div(s.WithOpName("div1"), x, ones);
    Output div2 = ops::Div(s.WithOpName("div2"), ones, y);
    Output matmul1 = ops::MatMul(s.WithOpName("matmul1"), x, zeros);
    Output matmul2 = ops::MatMul(s.WithOpName("matmul2"), zeros, y);
    // Non-square matmuls: folded result must take the matmul OUTPUT shape,
    // not the shape of the zero operand.
    Output matmul3 = ops::MatMul(s.WithOpName("matmul3"), a, zeros);
    Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b);
    Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
    Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
    Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d);
    Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias);
    Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
    Output sub2 = ops::Sub(s.WithOpName("sub2"), zeros, y);
    // AddN keeps all of the above alive in the fetch set.
    Output addn =
        ops::AddN(s.WithOpName("addn"),
                  {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1,
                   matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"addn", "matmul3", "matmul4"};

    ConstantFolding optimizer(nullptr /* cpu_device */);
    GraphDef output;
    Status status = optimizer.Optimize(nullptr, item, &output);
    TF_EXPECT_OK(status);

    EXPECT_EQ(27, output.node_size());
    for (int i = 0; i < output.node_size(); ++i) {
      const NodeDef& node = output.node(i);
      const string& name = node.name();
      if (name == "mul1") {
        // x * 0 -> all-zeros Const; originals are kept as control deps.
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ("^zeros", node.input(1));
      } else if (name == "mul2") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^zeros", node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "mul3") {
        // x * 1 -> Snapshot of x.
        EXPECT_EQ("Snapshot", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^ones", node.input(1));
      } else if (name == "mul4") {
        EXPECT_EQ("Snapshot", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^ones", node.input(1));
      } else if (name == "mul5") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "mul6") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^zeros_1d", node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "div1") {
        // x / 1 -> Snapshot of x.
        EXPECT_EQ("Snapshot", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^ones", node.input(1));
      } else if (name == "div2") {
        // 1 / y -> Reciprocal(y).
        EXPECT_EQ("Reciprocal", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^ones", node.input(1));
      } else if (name == "matmul1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^x", node.input(0));
        EXPECT_EQ("^zeros", node.input(1));
      } else if (name == "matmul2") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^zeros", node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "matmul3") {
        // (3,2) x (2,2) -> zero Const of shape (3,2).
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^a", node.input(0));
        EXPECT_EQ("^zeros", node.input(1));
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(3, t.tensor_shape().dim(0).size());
        EXPECT_EQ(2, t.tensor_shape().dim(1).size());
      } else if (name == "matmul4") {
        // (2,2) x (2,3) -> zero Const of shape (2,3).
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ("^zeros", node.input(0));
        EXPECT_EQ("^b", node.input(1));
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
        EXPECT_EQ(3, t.tensor_shape().dim(1).size());
      } else if (name == "add1") {
        // x + 0 -> Snapshot of x.
        EXPECT_EQ("Snapshot", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros", node.input(1));
      } else if (name == "add2") {
        EXPECT_EQ("Snapshot", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^zeros", node.input(1));
      } else if (name == "bias_add1") {
        EXPECT_EQ("Snapshot", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "bias_add2") {
        // We don't eliminate this one, because it requires broadcasting.
        EXPECT_EQ("BiasAdd", node.op());
        EXPECT_EQ("zeros", node.input(0));
        EXPECT_EQ("bias", node.input(1));
      } else if (name == "sub1") {
        EXPECT_EQ("Snapshot", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros", node.input(1));
      } else if (name == "sub2") {
        // We don't handle this case yet.
        EXPECT_EQ("Sub", node.op());
        EXPECT_EQ("zeros", node.input(0));
        EXPECT_EQ("y", node.input(1));
      }
      // Every op folded to a 2x2 zero constant carries a single splatted
      // float 0 plus the (2, 2) shape.
      const std::set<string> square_zero_const{"mul1", "mul2",    "mul5",
                                               "mul6", "matmul1", "matmul2"};
      if (square_zero_const.count(name) > 0) {
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(1, t.float_val_size());
        EXPECT_EQ(0, t.float_val(0));
        EXPECT_EQ(2, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
        EXPECT_EQ(2, t.tensor_shape().dim(1).size());
      }
    }
  }
}

TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output cf_half = ops::Const(s.WithOpName("cf_half"), 0.5f, {1});
  Output xf = ops::Placeholder(s.WithOpName("xf"), DT_FLOAT,
                               ops::Placeholder::Shape(TensorShape({2, 2})));
  Output xi = ops::Placeholder(s.WithOpName("xi"), DT_INT32,
                               ops::Placeholder::Shape(TensorShape({2, 2})));
  Output ci = ops::Const(s.WithOpName("ci"), 2, {1});
  Output cf = ops::Const(s.WithOpName("cf"), 2.0f, {1});
  Output div_i = ops::Div(s.WithOpName("div_i"), xi, ci);
  Output div_f = ops::Div(s.WithOpName("div_f"), xf, cf);
  Output realdiv = ops::RealDiv(s.WithOpName("realdiv"), xf, cf);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"div_f", "div_i", "realdiv"};
  ConstantFolding optimizer(nullptr /* cpu_device */);
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(8, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (name == "div_i") {
      // Integer division is unchanged.
      EXPECT_EQ("Div", node.op());
      EXPECT_EQ("xi", node.input(0));
      EXPECT_EQ("ci", node.input(1));
    } else if (name == "div_f") {
      // Float x / const -> x * (1 / const), with the reciprocal materialized
      // as a new Const node.
      EXPECT_EQ("Mul", node.op());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("ConstantFolding/div_f_recip", node.input(1));
    } else if (name == "realdiv") {
      EXPECT_EQ("Mul", node.op());
      EXPECT_EQ("xf", node.input(0));
      EXPECT_EQ("ConstantFolding/realdiv_recip", node.input(1));
    } else if (name == "ConstantFolding/div_f_recip") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(DT_FLOAT, t.dtype());
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(1, t.tensor_shape().dim(0).size());
    } else if (name == "ConstantFolding/realdiv_recip") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
      TensorProto t = node.attr().at("value").tensor();
      EXPECT_EQ(DT_FLOAT, t.dtype());
      EXPECT_EQ(1, t.tensor_shape().dim_size());
      EXPECT_EQ(1, t.tensor_shape().dim(0).size());
    }
  }

  // Check that the reciprocals have the expected value.
364 std::vector<string> fetch = {"cf_half"}; 365 auto tensor_expected = EvaluateNodes(item.graph, fetch); 366 EXPECT_EQ(fetch.size(), tensor_expected.size()); 367 fetch = {"ConstantFolding/div_f_recip", "ConstantFolding/realdiv_recip"}; 368 auto tensors = EvaluateNodes(output, fetch); 369 EXPECT_EQ(fetch.size(), tensors.size()); 370 for (int i = 0; i < fetch.size(); i++) { 371 test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]); 372 } 373 } 374 375 TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) { 376 tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 377 Output x_known = 378 ops::Placeholder(s.WithOpName("x_known"), DT_FLOAT, 379 ops::Placeholder::Shape(TensorShape({2, 2}))); 380 Output x_partially_known = 381 ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT, 382 ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); 383 Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT); 384 Output zeros_known = ops::ZerosLike(s.WithOpName("zeros_known"), x_known); 385 Output zeros_partially_known = 386 ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known); 387 Output zeros_unknown = 388 ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown); 389 390 // Multiplies without any additional ops to supply the output shape. 
391 int count = 0; 392 std::vector<Output> muls; 393 std::unordered_set<string> not_converted; 394 std::unordered_set<string> to_const; 395 std::unordered_set<string> to_identity; 396 for (const auto* x : {&x_known, &x_partially_known, &x_unknown}) { 397 for (const auto* zeros : 398 {&zeros_known, &zeros_partially_known, &zeros_unknown}) { 399 const string name = strings::StrCat("mul_", count++); 400 muls.push_back(ops::Mul(s.WithOpName(name), *x, *zeros)); 401 if (x == &x_partially_known && zeros == &zeros_partially_known) { 402 to_identity.insert(name); 403 } else if (x == &x_unknown || zeros == &zeros_unknown) { 404 not_converted.insert(name); 405 } else { 406 to_const.insert(name); 407 } 408 } 409 } 410 411 GrapplerItem item; 412 TF_CHECK_OK(s.ToGraphDef(&item.graph)); 413 414 ConstantFolding optimizer(nullptr /* cpu_device */); 415 GraphDef output; 416 Status status = optimizer.Optimize(nullptr, item, &output); 417 TF_EXPECT_OK(status); 418 LOG(INFO) << output.DebugString(); 419 420 EXPECT_EQ(15, output.node_size()); 421 for (int i = 0; i < output.node_size(); ++i) { 422 const NodeDef& node = output.node(i); 423 const string& name = node.name(); 424 if (to_const.count(name) > 0) { 425 EXPECT_EQ("Const", node.op()) << node.name(); 426 } else if (to_identity.count(name) > 0) { 427 EXPECT_EQ("Identity", node.op()) << node.name(); 428 } else if (not_converted.count(name) > 0) { 429 EXPECT_EQ("Mul", node.op()) << node.name(); 430 } 431 } 432 } 433 434 TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) { 435 tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 436 Output known_shape = ops::Const(s.WithOpName("known_shape"), 0.0f, {2, 2}); 437 Output x_partially_known = 438 ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT, 439 ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); 440 Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT); 441 Output zeros_partially_known = 442 
      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
  Output zeros_unknown =
      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);

  // If at least one of the inputs to AddN has a known shape, shape inference
  // will propagate the shape back to the inputs of AddN, making the
  // output shapes of all its inputs known
  std::vector<Output> muls_deduced_output_shape;
  std::unordered_set<string> to_const;
  int count = 0;
  for (const auto& x : {x_partially_known, x_unknown}) {
    for (const auto& zeros : {zeros_partially_known, zeros_unknown}) {
      const string name = strings::StrCat("mul_", count++);
      muls_deduced_output_shape.push_back(
          ops::Mul(s.WithOpName(name), x, zeros));
      to_const.insert(name);
    }
  }
  // We add a known shape as input to AddN to propagate it back to the
  // multiplies above, which means they can all be turned into Const nodes.
  muls_deduced_output_shape.push_back(known_shape);
  Output addn1 = ops::AddN(s.WithOpName("addn1"), muls_deduced_output_shape);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding optimizer(nullptr /* cpu_device */);
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  LOG(INFO) << output.DebugString();

  EXPECT_EQ(10, output.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (to_const.count(name) > 0) {
      // Folded to a Const; both original inputs survive as control deps.
      EXPECT_EQ("Const", node.op()) << node.name();
      EXPECT_EQ(2, node.input_size());
      EXPECT_TRUE(IsControlInput(node.input(0)));
      EXPECT_TRUE(IsControlInput(node.input(1)));
    }
  }
}

TEST_F(ConstantFoldingTest, CreateConstNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // For each numeric TYPE, build const(10) -> mul -> identity so that the
  // Mul can be folded into a Const of the right dtype.
#define MAKE_TEST_GRAPH(TYPE)                                               \
  Output TYPE##_const =                                                     \
      ops::Const(s.WithOpName(#TYPE "_const"), static_cast<TYPE>(10), {5}); \
  Output TYPE##_mul =                                                       \
      ops::Mul(s.WithOpName(#TYPE "_mul"), TYPE##_const, TYPE##_const);     \
  Output TYPE##_id = ops::Identity(s.WithOpName(#TYPE "_id"), TYPE##_mul)

  MAKE_TEST_GRAPH(float);
  MAKE_TEST_GRAPH(double);
  MAKE_TEST_GRAPH(int64);
  MAKE_TEST_GRAPH(int32);
  MAKE_TEST_GRAPH(int16);
  MAKE_TEST_GRAPH(int8);
  MAKE_TEST_GRAPH(uint8);
#undef MAKE_TEST_GRAPH

  Output bool_const = ops::Const(s.WithOpName("bool_const"), true, {5});
  Output bool_and =
      ops::LogicalAnd(s.WithOpName("bool_and"), bool_const, bool_const);
  Output bool_id = ops::Identity(s.WithOpName("bool_id"), bool_and);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(24, output.node_size());
  for (const NodeDef& node : output.node()) {
    // Each folded TYPE_mul must be a splatted 10*10 constant of shape {5},
    // stored in the TensorProto field matching its dtype.
#define CHECK_RESULT(TYPE, FIELD)                                             \
  if (node.name() == #TYPE "_mul") {                                          \
    EXPECT_EQ(5,                                                              \
              node.attr().at("value").tensor().tensor_shape().dim(0).size()); \
    EXPECT_EQ(1, node.attr().at("value").tensor().FIELD##_val_size());        \
    EXPECT_EQ(10 * 10, node.attr().at("value").tensor().FIELD##_val(0));      \
  }

    CHECK_RESULT(float, float);
    CHECK_RESULT(double, double);
    CHECK_RESULT(int64, int64);
    CHECK_RESULT(int32, int);
    CHECK_RESULT(int16, int);
    CHECK_RESULT(int8, int);
    CHECK_RESULT(uint8, int);
#undef CHECK_RESULT

    if (node.name() == "bool_and") {
      EXPECT_EQ(5,
                node.attr().at("value").tensor().tensor_shape().dim(0).size());
      EXPECT_EQ(1, node.attr().at("value").tensor().bool_val_size());
      EXPECT_EQ(true && true, node.attr().at("value").tensor().bool_val(0));
    }
  }
}

TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
  // Build a simple graph with a few trivially prunable ops.
548 tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 549 550 Output a = ops::Const(s.WithOpName("a"), 10, {5}); 551 auto b = ops::Unique(s.WithOpName("b"), {a}); 552 Output c = ops::Identity(s.WithOpName("c"), {b.y}); 553 Output d = ops::Identity(s.WithOpName("d"), {b.idx}); 554 Output e = ops::Identity(s.WithOpName("e"), {c}); 555 Output f = ops::Identity(s.WithOpName("f"), {d}); 556 557 GrapplerItem item; 558 item.fetch.push_back("e"); 559 item.fetch.push_back("f"); 560 TF_CHECK_OK(s.ToGraphDef(&item.graph)); 561 562 ConstantFolding fold(nullptr /* cpu_device */); 563 GraphDef output; 564 Status status = fold.Optimize(nullptr, item, &output); 565 TF_EXPECT_OK(status); 566 567 EXPECT_EQ(2, output.node_size()); 568 569 const NodeDef& new_c = output.node(0); 570 EXPECT_EQ("e", new_c.name()); 571 EXPECT_EQ("Const", new_c.op()); 572 573 const NodeDef& new_d = output.node(1); 574 EXPECT_EQ("f", new_d.name()); 575 EXPECT_EQ("Const", new_d.op()); 576 577 std::vector<string> fetch = {"e", "f"}; 578 auto tensors_expected = EvaluateNodes(item.graph, fetch); 579 auto tensors = EvaluateNodes(output, fetch); 580 EXPECT_EQ(fetch.size(), tensors_expected.size()); 581 EXPECT_EQ(fetch.size(), tensors.size()); 582 for (int i = 0; i < fetch.size(); i++) { 583 test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]); 584 } 585 } 586 587 TEST_F(ConstantFoldingTest, ControlDependencies) { 588 tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); 589 Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1}); 590 Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1}); 591 Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1}); 592 Output c = 593 ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3}); 594 Output i1 = ops::Identity(scope.WithOpName("i1"), {c}); 595 Output i2 = 596 ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1}); 597 Output i3 = ops::Identity(scope.WithOpName("e"), 
{i2}); 598 599 GrapplerItem item; 600 item.fetch.push_back("e"); 601 TF_CHECK_OK(scope.ToGraphDef(&item.graph)); 602 603 ConstantFolding fold(nullptr /* cpu_device */); 604 GraphDef output; 605 Status status = fold.Optimize(nullptr, item, &output); 606 TF_EXPECT_OK(status); 607 608 std::vector<string> expected_nodes = {"dflt", "p1", "p2", "e"}; 609 EXPECT_EQ(output.node_size(), expected_nodes.size()); 610 int i = 0; 611 int found = 0; 612 for (const auto& node : output.node()) { 613 EXPECT_EQ(expected_nodes[i], output.node(i).name()); 614 i++; 615 if (node.name() == "e") { 616 EXPECT_EQ("Const", node.op()); 617 ++found; 618 auto folded = EvaluateNodes(output, {"e"}); 619 auto expected = EvaluateNodes(item.graph, {"e"}); 620 EXPECT_EQ(1, expected.size()); 621 EXPECT_EQ(1, folded.size()); 622 test::ExpectTensorEqual<int>(folded[0], expected[0]); 623 EXPECT_EQ(2, node.input_size()); 624 EXPECT_EQ("^p1", node.input(0)); 625 EXPECT_EQ("^p2", node.input(1)); 626 } 627 } 628 EXPECT_EQ(1, found); 629 } 630 631 TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) { 632 tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); 633 Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1}); 634 Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1}); 635 Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1}); 636 Output c = 637 ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3}); 638 Output i1 = ops::Identity(scope.WithOpName("i1"), {c}); 639 Output i2 = 640 ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1}); 641 Output i3 = ops::Identity(scope.WithOpName("e"), {i2}); 642 643 GrapplerItem item; 644 TF_CHECK_OK(scope.ToGraphDef(&item.graph)); 645 646 ConstantFolding fold(nullptr /* cpu_device */); 647 GraphDef output; 648 Status status = fold.Optimize(nullptr, item, &output); 649 TF_EXPECT_OK(status); 650 651 std::vector<string> expected_nodes = {"dflt", "p1", "p2", "c", 652 "i1", "i2", 
"e"}; 653 EXPECT_EQ(output.node_size(), expected_nodes.size()); 654 int i = 0; 655 int found = 0; 656 for (const auto& node : output.node()) { 657 EXPECT_EQ(expected_nodes[i], output.node(i).name()); 658 i++; 659 if (node.name() == "i1") { 660 EXPECT_EQ("Const", node.op()); 661 ++found; 662 auto folded = EvaluateNodes(output, {"i1"}); 663 auto expected = EvaluateNodes(item.graph, {"i1"}); 664 EXPECT_EQ(1, expected.size()); 665 EXPECT_EQ(1, folded.size()); 666 test::ExpectTensorEqual<int>(folded[0], expected[0]); 667 EXPECT_EQ(1, node.input_size()); 668 EXPECT_EQ("^p1", node.input(0)); 669 } 670 if (node.name() == "i2") { 671 EXPECT_EQ("Const", node.op()); 672 ++found; 673 auto folded = EvaluateNodes(output, {"i2"}); 674 auto expected = EvaluateNodes(item.graph, {"i2"}); 675 EXPECT_EQ(1, expected.size()); 676 EXPECT_EQ(1, folded.size()); 677 test::ExpectTensorEqual<int>(folded[0], expected[0]); 678 EXPECT_EQ(2, node.input_size()); 679 EXPECT_EQ("^p1", node.input(0)); 680 EXPECT_EQ("^p2", node.input(1)); 681 } 682 } 683 EXPECT_EQ(2, found); 684 } 685 686 TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) { 687 tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); 688 Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1}); 689 Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1}); 690 Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1}); 691 Output c = 692 ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3}); 693 Output i1 = ops::Identity(scope.WithOpName("i1") 694 .WithControlDependencies(p2) 695 .WithControlDependencies(p1), 696 {c}); 697 Output i2 = ops::Identity(scope.WithOpName("i2"), {i1}); 698 699 GrapplerItem item; 700 item.fetch.push_back("i2"); 701 TF_CHECK_OK(scope.ToGraphDef(&item.graph)); 702 703 ConstantFolding fold(nullptr /* cpu_device */); 704 GraphDef output; 705 Status status = fold.Optimize(nullptr, item, &output); 706 TF_EXPECT_OK(status); 707 708 
std::vector<string> expected_nodes = {"dflt", "p1", "p2", "i2"}; 709 EXPECT_EQ(output.node_size(), expected_nodes.size()); 710 int i = 0; 711 for (const auto& node : output.node()) { 712 EXPECT_EQ(expected_nodes[i], output.node(i).name()); 713 i++; 714 if (node.name() == "i2") { 715 EXPECT_EQ("Const", node.op()); 716 EXPECT_EQ(2, node.input_size()); 717 EXPECT_EQ("^p1", node.input(0)); 718 EXPECT_EQ("^p2", node.input(1)); 719 } 720 } 721 } 722 723 TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) { 724 tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); 725 // Add a DynamicPartition node to the graph 726 Output input = ops::Const(scope.WithOpName("in0"), 314, {3, 4, 5}); 727 Output indices = ops::Const(scope.WithOpName("indices"), 1, {3, 4}); 728 int num_partitions = 4; 729 ops::DynamicPartition part(scope.WithOpName("partition"), input, indices, 730 num_partitions); 731 732 std::vector<string> outputs; 733 for (int i = 0; i < num_partitions; ++i) { 734 string part_out_name = strings::StrCat("part_out", i); 735 ops::Identity partition_out(scope.WithOpName(part_out_name), 736 {part.outputs[i]}); 737 outputs.push_back(part_out_name); 738 } 739 740 GrapplerItem item; 741 TF_CHECK_OK(scope.ToGraphDef(&item.graph)); 742 743 // Add a ConcatOffset node to the graph 744 Tensor initial_val(DT_INT32, TensorShape({3})); 745 test::FillIota<int>(&initial_val, 7); 746 for (int i = 1; i < 5; ++i) { 747 TF_CHECK_OK(NodeDefBuilder(strings::StrCat("in", i), "Const") 748 .Attr("dtype", DT_INT32) 749 .Attr("value", initial_val) 750 .Finalize(item.graph.add_node())); 751 } 752 Tensor concat_dim(DT_INT32, TensorShape({})); 753 test::FillIota<int>(&concat_dim, 0); 754 TF_CHECK_OK(NodeDefBuilder("concat_dim", "Const") 755 .Attr("dtype", DT_INT32) 756 .Attr("value", concat_dim) 757 .Finalize(item.graph.add_node())); 758 759 TF_CHECK_OK(NodeDefBuilder("concat_offsets", "ConcatOffset") 760 .Input("concat_dim", 0, DT_INT32) 761 .Input({NodeDefBuilder::NodeOut("in1", 0, DT_INT32), 
762 NodeDefBuilder::NodeOut("in2", 0, DT_INT32), 763 NodeDefBuilder::NodeOut("in3", 0, DT_INT32), 764 NodeDefBuilder::NodeOut("in4", 0, DT_INT32)}) 765 .Finalize(item.graph.add_node())); 766 767 for (int i = 0; i < 4; ++i) { 768 string concat_offset_out_name = strings::StrCat("concat_offset_out", i); 769 TF_CHECK_OK(NodeDefBuilder(concat_offset_out_name, "Identity") 770 .Attr("T", DT_INT32) 771 .Input("concat_offsets", i, DT_INT32) 772 .Finalize(item.graph.add_node())); 773 outputs.push_back(concat_offset_out_name); 774 } 775 776 item.fetch = outputs; 777 ConstantFolding fold(nullptr /* cpu_device */); 778 GraphDef output; 779 Status status = fold.Optimize(nullptr, item, &output); 780 TF_EXPECT_OK(status); 781 782 int constant_folded = 0; 783 for (const auto& node : output.node()) { 784 if (node.name().find("part_out") != string::npos || 785 node.name().find("concat_offset_out") != string::npos) { 786 ++constant_folded; 787 EXPECT_EQ("Const", node.op()); 788 } 789 } 790 EXPECT_EQ(8, constant_folded); 791 792 auto expected = EvaluateNodes(item.graph, outputs); 793 auto optimized = EvaluateNodes(output, outputs); 794 ASSERT_EQ(expected.size(), optimized.size()); 795 for (int i = 0; i < expected.size(); ++i) { 796 test::ExpectTensorEqual<int>(expected[i], optimized[i]); 797 } 798 } 799 800 TEST_F(ConstantFoldingTest, ShapeMaterialization) { 801 tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); 802 Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT); 803 Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT); 804 Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT); 805 Output rank = ops::Rank(scope.WithOpName("rank"), v1); 806 Output shape = ops::Shape(scope.WithOpName("shape"), v2); 807 Output size = ops::Size(scope.WithOpName("size"), v3); 808 Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank); 809 Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape); 810 811 GrapplerItem item; 812 
  item.fetch.push_back("p2");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // The fetched node "p2" must have been folded into a single constant that
  // keeps control dependencies on the variables whose shapes were used.
  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "p2") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(3, node.input_size());
      EXPECT_EQ("^v3", node.input(0));
      EXPECT_EQ("^v1", node.input(1));
      EXPECT_EQ("^v2", node.input(2));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      // rank = 1, shape = (5, 7), size = 143 = 11*13
      // p2 = (715, 1001) = (5*143, 7*143)
      EXPECT_EQ(715, value.flat<int>()(0));
      EXPECT_EQ(1001, value.flat<int>()(1));
    }
  }
  EXPECT_EQ(1, found);
}

// Same graph as ShapeMaterialization, but with no fetch nodes: the
// Rank/Shape/Size nodes themselves should be rewritten in place as constants.
TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
  Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
  Output rank = ops::Rank(scope.WithOpName("rank"), v1);
  Output shape = ops::Shape(scope.WithOpName("shape"), v2);
  Output size = ops::Size(scope.WithOpName("size"), v3);
  Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
  Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Each shape-like node becomes a Const carrying a control dependency on
  // the variable it was derived from.
  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "size") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v3", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(11 * 13, value.flat<int>()(0));
    } else if (node.name() == "rank") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v1", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(1, value.flat<int>()(0));
    } else if (node.name() == "shape") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^v2", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(5, value.flat<int>()(0));
      EXPECT_EQ(7, value.flat<int>()(1));
    }
  }
  EXPECT_EQ(3, found);
}

// ShapeN has one output per input; only the outputs whose input shapes are
// fully known (here v3's) should be materialized as constants.
TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // v1 has an unknown dimension and v2 has an unknown rank, so neither
  // s[0] nor s[1] can be materialized; v3 is fully known.
  Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
  Output v2 = ops::Variable(scope.WithOpName("v2"), {}, DT_FLOAT);
  Output v3 = ops::Variable(scope.WithOpName("v3"), {4, 6}, DT_FLOAT);
  auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2, v3});
  Output i1a = ops::Identity(scope.WithOpName("i1a"), s[0]);
  Output i1b = ops::Identity(scope.WithOpName("i1b"), s[0]);
  Output i2a = ops::Identity(scope.WithOpName("i2a"), s[1]);
  Output i2b = ops::Identity(scope.WithOpName("i2b"), s[1]);
  Output i2c = ops::Identity(scope.WithOpName("i2c"), s[1]);
  Output i3a = ops::Identity(scope.WithOpName("i3a"), s[2]);
  Output i3b = ops::Identity(scope.WithOpName("i3b"), s[2]);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  int found = 0;
  for (const auto& node : output.node()) {
    // No constants must be created for the two unknown-shape outputs.
    EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
              node.name());
    EXPECT_NE(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
              node.name());
    // Consumers of the non-materialized outputs still read from "s" directly.
    if (node.name() == "i1a" || node.name() == "i1b") {
      ++found;
      EXPECT_EQ("s", node.input(0));
    }
    if (node.name() == "i2a" || node.name() == "i2b" || node.name() == "i2c") {
      ++found;
      EXPECT_EQ("s:1", node.input(0));
    }
    // Consumers of the fully-known output are rewired to the new constant.
    if (node.name() == "i3a" || node.name() == "i3b") {
      ++found;
      EXPECT_EQ(AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst),
                node.input(0));
    }
    if (node.name() == "s") {
      ++found;
      EXPECT_EQ("ShapeN", node.op());
      EXPECT_EQ("v1", node.input(0));
      EXPECT_EQ("v2", node.input(1));
      EXPECT_EQ("v3", node.input(2));
    }
    if (node.name() ==
        AddPrefixToNodeName("s-matshapes-2", kConstantFoldingConst)) {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ("^s", node.input(0));
      Tensor value;
      CHECK(value.FromProto(node.attr().at("value").tensor()));
      EXPECT_EQ(4, value.flat<int>()(0));
      EXPECT_EQ(6, value.flat<int>()(1));
    }
  }
  EXPECT_EQ(9, found);
}

// Folding through Switch nodes with no fetch nodes: shape ops downstream of a
// Switch can be folded, and a Switch fed by constants folds its live branch.
TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
  ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
  ops::Identity i(scope.WithOpName("i"), s1.output_true);
  ops::Size size(scope.WithOpName("size"), i);
  ops::Square p1(scope.WithOpName("p1"), rank);
  ops::Square p2(scope.WithOpName("p2"), size);
  ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});

  // Second subgraph: a Switch with a constant predicate, so only the
  // false branch ("i2") can ever produce a value.
  Output predicate =
      ops::Const(scope.WithOpName("false"), false, TensorShape({}));
  Output constant =
      ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
  ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
  ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
  ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
  ops::Merge m2(scope.WithOpName("m2"),
                {statically_known.output, never_generated.output});

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // With no fetch nodes all original nodes must survive (plus the anchor
  // node the optimizer adds for the switch's control dependency).
  std::set<string> present_nodes = {"v_in",     "v_ctrl",
                                    "switch",   "i",
                                    "p1",       "p2",
                                    "m",        "false",
                                    "constant", "switch2",
                                    "i2",       "i3",
                                    "m2",       "ConstantFoldingCtrl/switch_0",
                                    "rank",     "size"};
  std::set<string> not_present_nodes = {"ConstantFolding/switch2-0"};
  EXPECT_EQ(present_nodes.size(), output.node_size());
  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end());
    EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end());
    present_nodes.erase(node.name());
    not_present_nodes.erase(node.name());
    if (node.name() == "rank") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0));
    }
    if (node.name() == "size") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^i", node.input(0));
    }
    // "i2" is reachable (false predicate, false output) so it folds to a
    // constant; "i3" can never fire so it is left untouched.
    if (node.name() == "i2") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(0, node.input_size());
    }
    if (node.name() == "i3") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("switch2:1", node.input(0));
    }
  }
  EXPECT_EQ(4, found);
}

// Same graphs as SwitchNodesEmptyFetch, but fetching "m" and "m2": folded
// intermediates that feed only the fetched nodes may be pruned.
TEST_F(ConstantFoldingTest, SwitchNodes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
  ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL);
  ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl);
  ops::Rank rank(scope.WithOpName("rank"), s1.output_false);
  ops::Identity i(scope.WithOpName("i"), s1.output_true);
  ops::Size size(scope.WithOpName("size"), i);
  ops::Square p1(scope.WithOpName("p1"), rank);
  ops::Square p2(scope.WithOpName("p2"), size);
  ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y});

  // Second subgraph: Switch with a constant false predicate, so only the
  // false branch ("i2") can ever produce a value.
  Output predicate =
      ops::Const(scope.WithOpName("false"), false, TensorShape({}));
  Output constant =
      ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1}));
  ops::Switch s2(scope.WithOpName("switch2"), constant, predicate);
  ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false);
  ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true);
  ops::Merge m2(scope.WithOpName("m2"),
                {statically_known.output, never_generated.output});

  GrapplerItem item;
  item.fetch.push_back("m");
  item.fetch.push_back("m2");

  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // Unlike the empty-fetch variant, the folded "rank" and "size" nodes are
  // pruned because only "m"/"m2" are fetched.
  std::set<string> present_nodes = {"v_in",     "v_ctrl",
                                    "switch",   "i",
                                    "p1",       "p2",
                                    "m",        "false",
                                    "constant", "switch2",
                                    "i2",       "i3",
                                    "m2",       "ConstantFoldingCtrl/switch_0"};
  std::set<string> not_present_nodes = {"rank", "size",
                                        "ConstantFolding/switch2-0"};
  EXPECT_EQ(present_nodes.size(), output.node_size());

  int found = 0;
  for (const auto& node : output.node()) {
    EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end());
    EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end());
    present_nodes.erase(node.name());
    not_present_nodes.erase(node.name());
    // "i2" folds to a constant (live branch); "i3" is dead and untouched.
    if (node.name() == "i2") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(0, node.input_size());
    }
    if (node.name() == "i3") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("switch2:1", node.input(0));
    }
  }
  EXPECT_EQ(2, found);
}

// A Merge node whose first constant input is known can be folded to that
// constant plus a constant value_index; Merges fed by non-constants (or by
// constants with pending control dependencies) must be left alone.
TEST_F(ConstantFoldingTest, MergeNodes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output x =
      ops::RandomNormal(scope.WithOpName("x"), {3, 5}, DataType::DT_FLOAT);
  Output y =
      ops::RandomNormal(scope.WithOpName("y"), {3, 5}, DataType::DT_FLOAT);
  // const1 and const3 carry a control dependency on the non-constant "x".
  Output const1 =
      ops::Const(scope.WithOpName("const1").WithControlDependencies(x), 2.7f,
                 TensorShape({3, 5}));
  Output const2 =
      ops::Const(scope.WithOpName("const2"), 3.14f, TensorShape({3, 5}));
  Output const3 =
      ops::Const(scope.WithOpName("const3").WithControlDependencies(x), 3.14f,
                 TensorShape({3, 5}));

  // Create 3 merge nodes: m1 is foldable, m2 and m3 aren't.
  ops::Merge m1(scope.WithOpName("m1"), {x, const1, const2});
  ops::Merge m2(scope.WithOpName("m2"), {const1, const3});
  ops::Merge m3(scope.WithOpName("m3"), {x, y});

  // Consumers for both outputs (value and value_index) of each Merge.
  ops::Identity out1(scope.WithOpName("out1"), m1.output);
  ops::Identity idx1(scope.WithOpName("idx1"), m1.value_index);
  ops::Identity out2(scope.WithOpName("out2"), m2.output);
  ops::Identity idx2(scope.WithOpName("idx2"), m2.value_index);
  ops::Identity out3(scope.WithOpName("out3"), m3.output);
  ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index);

  GrapplerItem item;
  item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found_nodes = 0;
  for (const auto& node : output.node()) {
    // m1's consumers are rewired to the folded constants, keeping a control
    // dependency ("^m1") on the original Merge.
    if (node.name() == "out1") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "idx1") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "ConstantFolding/m1") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "ConstantFolding/m1_index") {
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^m1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "out2") {
      // m2 and m3 are not folded; their consumers still read them directly.
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m2", node.input(0));
      ++found_nodes;
    } else if (node.name() == "idx2") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m2:1", node.input(0));
      ++found_nodes;
    } else if (node.name() == "out3") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m3", node.input(0));
      ++found_nodes;
    } else if (node.name() == "idx3") {
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("m3:1", node.input(0));
      ++found_nodes;
    }
  }
  // Make sure the graph contains all the nodes we're expecting.
  EXPECT_EQ(6, found_nodes);

  // m1's folded value is const2 (3.14, the first control-free constant
  // input) and its value_index is 2, const2's input position.
  std::vector<string> fetch = {"out1", "idx1"};
  auto tensors = EvaluateNodes(output, fetch);
  EXPECT_EQ(2, tensors.size());
  const Tensor& out_value = tensors[0];
  EXPECT_EQ(3 * 5, out_value.NumElements());
  for (int i = 0; i < 3 * 5; ++i) {
    EXPECT_EQ(3.14f, out_value.flat<float>()(i));
  }
  const Tensor& out_idx = tensors[1];
  EXPECT_EQ(1, out_idx.NumElements());
  EXPECT_EQ(2, out_idx.flat<int32>()(0));
}

TEST_F(ConstantFoldingTest, NoOpReduction) {
  // Build a simple graph with a reduction that can be reduced to the
  // identity: reducing over an empty axis list leaves the input unchanged.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output v = ops::Variable(scope.WithOpName("v"), {3, 5, 7}, DT_FLOAT);
  Output c =
      ops::Const(scope.WithOpName("c").WithControlDependencies(v), 0, {0});
  Output i = ops::Identity(scope.WithOpName("i"), c);
  Output p = ops::Prod(scope.WithOpName("p"), v, i);
  Output s = ops::Square(scope.WithOpName("s"), p);

  GrapplerItem item;
  item.fetch.push_back("s");
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // The no-op Prod is rewritten to Identity(v), with its former reduction
  // indices input demoted to a control dependency.
  bool found = false;
  for (const auto& node : output.node()) {
    if (node.name() == "p") {
      found = true;
      EXPECT_EQ("Identity", node.op());
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("v", node.input(0));
      EXPECT_EQ("^i", node.input(1));
    }
  }
  EXPECT_TRUE(found);
}

TEST_F(ConstantFoldingTest, NoOpReshape) {
  // Build a simple graph with a reshape that can be reduced to
  // the identity.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  // A reshape that can be optimized: the target shape {17} equals the
  // input's shape. r1 also carries a control dependency on d1.
  Output d1 = ops::Const(scope.WithOpName("d1"), 3.14f, {17});
  Output v1 = ops::Variable(scope.WithOpName("v1"), {17}, DT_FLOAT);
  Output c1 =
      ops::Const(scope.WithOpName("c1").WithControlDependencies(v1), 17, {1});
  Output i1 = ops::Identity(scope.WithOpName("i1"), c1);
  Output r1 =
      ops::Reshape(scope.WithOpName("r1").WithControlDependencies(d1), v1, i1);
  Output s1 = ops::Square(scope.WithOpName("s1"), r1);

  // A multi dimensional reshape that can be optimized
  Output v3 = ops::Variable(scope.WithOpName("v3"), {5, 5, 5}, DT_FLOAT);
  Output c3 =
      ops::Const(scope.WithOpName("c3").WithControlDependencies(v3), 5, {3});
  Output i3 = ops::Identity(scope.WithOpName("i3"), c3);
  Output r3 = ops::Reshape(scope.WithOpName("r3"), v3, i3);
  Output s3 = ops::Square(scope.WithOpName("s3"), r3);

  // A multi dimensional partially defined reshape ({5, -1, 5}) that can
  // also be optimized.
  Output v4 = ops::Variable(scope.WithOpName("v4"), {5, 5, 5}, DT_FLOAT);
  Output c4 = ops::Const(scope.WithOpName("c4").WithControlDependencies(v4),
                         {5, -1, 5}, {3});
  Output i4 = ops::Identity(scope.WithOpName("i4"), c4);
  Output r4 = ops::Reshape(scope.WithOpName("r4"), v4, i4);
  Output s4 = ops::Square(scope.WithOpName("s4"), r4);

  // A reshape that can't be optimized: {17, 1} -> {17} changes the rank.
  Output v2 = ops::Variable(scope.WithOpName("v2"), {17, 1}, DT_FLOAT);
  Output c2 =
      ops::Const(scope.WithOpName("c2").WithControlDependencies(v2), 17, {1});
  Output r2 = ops::Reshape(scope.WithOpName("r2"), v2, c2);
  Output s2 = ops::Square(scope.WithOpName("s2"), r2);

  GrapplerItem item;
  item.fetch = {"s1", "s2", "s3", "s4"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // No-op reshapes become Identity; their shape inputs (and any existing
  // control dependencies) are preserved as control inputs.
  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "r1") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(3, node.input_size());
      EXPECT_EQ("v1", node.input(0));
      EXPECT_EQ("^i1", node.input(1));
      EXPECT_EQ("^d1", node.input(2));
    } else if (node.name() == "r3") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("v3", node.input(0));
      EXPECT_EQ("^i3", node.input(1));
    } else if (node.name() == "r4") {
      ++found;
      EXPECT_EQ("Identity", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("v4", node.input(0));
      EXPECT_EQ("^i4", node.input(1));
    } else if (node.name() == "r2") {
      ++found;
      EXPECT_EQ("Reshape", node.op());
      ASSERT_EQ(2, node.input_size());
      EXPECT_EQ("v2", node.input(0));
      EXPECT_EQ("c2", node.input(1));
    }
  }
  EXPECT_EQ(4, found);
}

TEST_F(ConstantFoldingTest, Packing) {
  // Build a simple graph with a large constant that can be folded.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(scope.WithOpName("c"), 3.14f, {1000});
  Output i1 = ops::Identity(scope.WithOpName("i1"), c);
  Output i2 = ops::Identity(scope.WithOpName("i2"), c);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Make sure that the representation of the folded constant is space
  // efficient: in particular, the whole message should be smaller than 8k (the
  // size needed to naively encode 1000 floats folded twice).
  EXPECT_GT(8000, output.ByteSizeLong());
}

// BroadcastGradientArgs outputs can be materialized as constants when shape
// inference proves the two input shapes are identical (so both reduction
// index lists are empty), even if the shapes themselves are unknown.
TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a =
      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output b = ops::Square(s.WithOpName("b"), a);
  Output c = ops::Mul(s.WithOpName("c"), a, b);
  // d and e have the same (unknown) shape, so f's outputs are known to be
  // empty and can be materialized.
  Output d = ops::Shape(s.WithOpName("d"), a);
  Output e = ops::Shape(s.WithOpName("e"), b);

  auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
  Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
  Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);

  // d and h have different shapes, so i cannot be materialized.
  Output g = ops::Placeholder(s.WithOpName("g"), DT_FLOAT,
                              ops::Placeholder::Shape(PartialTensorShape({1})));
  Output h = ops::Shape(s.WithOpName("h"), g);
  auto i = ops::internal::BroadcastGradientArgs(s.WithOpName("i"), d, h);
  Output p1 = ops::Identity(s.WithOpName("p1"), i.r0);
  Output p2 = ops::Identity(s.WithOpName("p2"), i.r1);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Run a second time to make sure the optimization is idempotent.
  item.graph.Swap(&output);
  status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "o1") {
      ++found;
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("ConstantFolding/f-bcastargs-0", node.input(0));
    } else if (node.name() == "o2") {
      ++found;
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("ConstantFolding/f-bcastargs-1", node.input(0));
    } else if (node.name() == "ConstantFolding/f-bcastargs-0") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^f", node.input(0));
      // Empty reduction index list: the materialized constant has 0 elements.
      EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
                       .num_elements());
    } else if (node.name() == "ConstantFolding/f-bcastargs-1") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("^f", node.input(0));
      EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
                       .num_elements());
    } else if (node.name() == "p1") {
      // The non-materializable BroadcastGradientArgs is left untouched.
      ++found;
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("i", node.input(0));
    } else if (node.name() == "p2") {
      ++found;
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("i:1", node.input(0));
    }
  }
  EXPECT_EQ(6, found);
}

// When a reduction's output is reshaped to a known 1-element shape, the
// (unknown) reduction indices can be materialized as a constant covering all
// input dimensions.
TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
  Output indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
  Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
  Output size = ops::Const(s.WithOpName("size"), 1, {1});
  Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch.push_back("reshape");

  ConstantFolding fold(nullptr /* cpu_device */);
  GraphDef output;
  Status status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Run a second time to make sure the optimization is idempotent.
  item.graph.Swap(&output);
  status = fold.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  int found = 0;
  for (const auto& node : output.node()) {
    if (node.name() == "ConstantFolding/sum-reduction_indices") {
      ++found;
      EXPECT_EQ("Const", node.op());
      EXPECT_EQ("^indices", node.input(0));
      // The input is rank 2, so the materialized indices cover 2 dimensions.
      EXPECT_EQ(2, TensorShape(node.attr().at("value").tensor().tensor_shape())
                       .num_elements());
    } else if (node.name() == "sum") {
      ++found;
      EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
    } else if (node.name() == "indices") {
      ++found;
    }
  }
  EXPECT_EQ(3, found);
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow

//  LocalWords:  NewRootScope