1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/tools/graph_transforms/transform_utils.h" 17 #include "tensorflow/cc/ops/const_op.h" 18 #include "tensorflow/cc/ops/image_ops.h" 19 #include "tensorflow/cc/ops/nn_ops.h" 20 #include "tensorflow/cc/ops/standard_ops.h" 21 #include "tensorflow/core/framework/tensor_testutil.h" 22 #include "tensorflow/core/lib/core/status_test_util.h" 23 #include "tensorflow/core/lib/io/path.h" 24 #include "tensorflow/core/platform/test.h" 25 #include "tensorflow/core/platform/test_benchmark.h" 26 27 namespace tensorflow { 28 namespace graph_transforms { 29 30 class TransformUtilsTest : public ::testing::Test { 31 protected: 32 void TestMapNamesToNodes() { 33 auto root = tensorflow::Scope::NewRootScope(); 34 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 35 36 const int width = 100; 37 38 Tensor a_data(DT_FLOAT, TensorShape({width})); 39 test::FillIota<float>(&a_data, 1.0f); 40 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); 41 42 Tensor b_data(DT_FLOAT, TensorShape({width})); 43 test::FillIota<float>(&b_data, 1.0f); 44 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); 45 46 Output add = Add(root.WithOpName("add"), a_const, b_const); 47 48 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); 49 50 Output mul = Mul(root.WithOpName("output"), add, placeholder); 51 52 GraphDef graph_def; 53 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 54 std::map<string, const NodeDef*> node_map; 55 MapNamesToNodes(graph_def, &node_map); 56 57 EXPECT_EQ(1, node_map.count("a")); 58 EXPECT_EQ(1, node_map.count("b")); 59 EXPECT_EQ(1, node_map.count("add")); 60 EXPECT_EQ(1, node_map.count("placeholder")); 61 EXPECT_EQ(1, node_map.count("output")); 62 EXPECT_EQ(0, node_map.count("no_such_node")); 63 } 64 65 void TestMapNodesToOutputs() { 66 auto root = tensorflow::Scope::NewRootScope(); 67 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 68 69 const int width = 100; 70 71 Tensor a_data(DT_FLOAT, TensorShape({width})); 72 test::FillIota<float>(&a_data, 1.0f); 73 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); 74 75 Tensor b_data(DT_FLOAT, TensorShape({width})); 76 test::FillIota<float>(&b_data, 1.0f); 77 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); 78 79 Output add = Add(root.WithOpName("add"), a_const, b_const); 80 81 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); 82 83 Output mul = Mul(root.WithOpName("output"), add, placeholder); 84 85 GraphDef graph_def; 86 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 87 88 std::map<string, std::vector<const NodeDef*>> outputs_map; 89 MapNodesToOutputs(graph_def, &outputs_map); 90 91 EXPECT_EQ(1, outputs_map.count("a")); 92 EXPECT_EQ(1, outputs_map["a"].size()); 93 EXPECT_EQ("add", outputs_map["a"][0]->name()); 94 95 EXPECT_EQ(1, outputs_map.count("b")); 96 EXPECT_EQ(1, outputs_map["b"].size()); 97 EXPECT_EQ("add", outputs_map["b"][0]->name()); 98 99 EXPECT_EQ(1, outputs_map.count("add")); 100 EXPECT_EQ(1, outputs_map["add"].size()); 101 EXPECT_EQ("output", outputs_map["add"][0]->name()); 102 103 EXPECT_EQ(1, outputs_map.count("placeholder")); 104 EXPECT_EQ(1, outputs_map["placeholder"].size()); 105 EXPECT_EQ("output", outputs_map["placeholder"][0]->name()); 106 107 EXPECT_EQ(0, outputs_map.count("output")); 108 EXPECT_EQ(0, outputs_map.count("no_such_node")); 109 } 110 111 void TestNodeNamePartsFromInput() { 112 string prefix; 113 string node_name; 114 string suffix; 115 116 NodeNamePartsFromInput("some_node_name", &prefix, &node_name, &suffix); 117 EXPECT_EQ("", prefix); 118 EXPECT_EQ("some_node_name", node_name); 119 EXPECT_EQ("", suffix); 120 121 NodeNamePartsFromInput("some_node_name/with/slashes", &prefix, &node_name, 122 &suffix); 123 EXPECT_EQ("", prefix); 124 EXPECT_EQ("some_node_name/with/slashes", node_name); 125 EXPECT_EQ("", suffix); 126 127 NodeNamePartsFromInput("some_node_name:0", &prefix, &node_name, &suffix); 128 EXPECT_EQ("", prefix); 129 EXPECT_EQ("some_node_name", node_name); 130 EXPECT_EQ(":0", suffix); 131 132 NodeNamePartsFromInput("^some_node_name", &prefix, &node_name, &suffix); 133 EXPECT_EQ("^", prefix); 134 EXPECT_EQ("some_node_name", node_name); 135 EXPECT_EQ("", suffix); 136 137 NodeNamePartsFromInput("^some_node_name:99", &prefix, &node_name, &suffix); 138 EXPECT_EQ("^", prefix); 139 EXPECT_EQ("some_node_name", node_name); 140 EXPECT_EQ(":99", suffix); 141 } 142 143 void TestNodeNameFromInput() { 144 EXPECT_EQ("node_name", NodeNameFromInput("node_name")); 145 EXPECT_EQ("node_name", NodeNameFromInput("node_name:0")); 146 EXPECT_EQ("node_name", NodeNameFromInput("^node_name")); 147 EXPECT_EQ("node_name", NodeNameFromInput("^node_name:42")); 148 } 149 150 void TestCanonicalInputName() { 151 EXPECT_EQ("node_name:0", CanonicalInputName("node_name")); 152 EXPECT_EQ("node_name:0", CanonicalInputName("node_name:0")); 153 EXPECT_EQ("^node_name:0", CanonicalInputName("^node_name")); 154 EXPECT_EQ("^node_name:42", CanonicalInputName("^node_name:42")); 155 } 156 157 void TestAddNodeInput() { 158 NodeDef node; 159 AddNodeInput("foo", &node); 160 EXPECT_EQ("foo", node.input(0)); 161 } 162 163 void TestCopyNodeAttr() { 164 NodeDef node; 165 auto mutable_attr = node.mutable_attr(); 166 (*mutable_attr)["foo"].set_i(3); 167 168 NodeDef copied_node; 169 CopyNodeAttr(node, "foo", "bar", &copied_node); 170 EXPECT_EQ(3, copied_node.attr().at("bar").i()); 171 } 172 173 void TestSetNodeAttr() { 174 NodeDef node; 175 int32 value_i = 32; 176 SetNodeAttr("foo", value_i, &node); 177 EXPECT_EQ(32, node.attr().at("foo").i()); 178 string value_s = "some_value"; 179 SetNodeAttr("bar", value_s, &node); 180 EXPECT_EQ("some_value", node.attr().at("bar").s()); 181 } 182 183 void TestSetNodeTensorAttr() { 184 NodeDef node; 185 SetNodeTensorAttr<int32>("foo", {3, 1}, {1, 2, 3}, &node); 186 TensorProto tensor_proto = node.attr().at("foo").tensor(); 187 Tensor tensor; 188 CHECK(tensor.FromProto(tensor_proto)); 189 EXPECT_EQ(DT_INT32, tensor.dtype()); 190 EXPECT_EQ(3, tensor.shape().dim_size(0)); 191 EXPECT_EQ(1, tensor.shape().dim_size(1)); 192 EXPECT_EQ(1, tensor.flat<int32>()(0)); 193 EXPECT_EQ(2, tensor.flat<int32>()(1)); 194 EXPECT_EQ(3, tensor.flat<int32>()(2)); 195 } 196 197 void TestSetNodeTensorAttrWithTensor() { 198 NodeDef node; 199 Tensor input_tensor(DT_INT32, {4, 5}); 200 test::FillIota<int32>(&input_tensor, 1); 201 SetNodeTensorAttr<int32>("foo", input_tensor, &node); 202 TensorProto tensor_proto = node.attr().at("foo").tensor(); 203 Tensor tensor; 204 CHECK(tensor.FromProto(tensor_proto)); 205 test::ExpectTensorEqual<int32>(input_tensor, tensor); 206 } 207 208 void TestGetNodeTensorAttr() { 209 NodeDef node; 210 Tensor input_tensor(DT_INT32, {4, 5}); 211 test::FillIota<int32>(&input_tensor, 1); 212 TensorProto tensor_proto; 213 input_tensor.AsProtoTensorContent(&tensor_proto); 214 SetNodeAttr("foo", tensor_proto, &node); 215 Tensor result = GetNodeTensorAttr(node, "foo"); 216 test::ExpectTensorEqual<int32>(input_tensor, result); 217 } 218 219 void TestFilterGraphDef() { 220 auto root = tensorflow::Scope::NewRootScope(); 221 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 222 223 const int width = 100; 224 225 Tensor a_data(DT_FLOAT, TensorShape({width})); 226 test::FillIota<float>(&a_data, 1.0f); 227 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); 228 229 Tensor b_data(DT_FLOAT, TensorShape({width})); 230 test::FillIota<float>(&b_data, 1.0f); 231 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); 232 233 Output add = Add(root.WithOpName("add"), a_const, b_const); 234 235 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); 236 237 Output mul = Mul(root.WithOpName("output"), add, placeholder); 238 239 Output remove_me = Add(root.WithOpName("remove_me"), mul, add); 240 241 GraphDef graph_def; 242 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 243 244 GraphDef result_graph_def; 245 FilterGraphDef( 246 graph_def, 247 [](const NodeDef& node) { return (node.name() != "remove_me"); }, 248 &result_graph_def); 249 250 std::map<string, const NodeDef*> node_map; 251 MapNamesToNodes(result_graph_def, &node_map); 252 EXPECT_EQ(1, node_map.count("a")); 253 EXPECT_EQ(1, node_map.count("b")); 254 EXPECT_EQ(1, node_map.count("add")); 255 EXPECT_EQ(1, node_map.count("placeholder")); 256 EXPECT_EQ(1, node_map.count("output")); 257 EXPECT_EQ(0, node_map.count("remove_me")); 258 } 259 260 void TestRemoveAttributes() { 261 auto root = tensorflow::Scope::NewRootScope(); 262 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 263 264 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); 265 266 GraphDef graph_def; 267 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 268 269 GraphDef result_graph_def; 270 RemoveAttributes(graph_def, {"dtype"}, &result_graph_def); 271 272 std::map<string, const NodeDef*> node_map; 273 MapNamesToNodes(result_graph_def, &node_map); 274 const NodeDef* removed_placeholder = node_map["placeholder"]; 275 EXPECT_EQ(nullptr, 276 tensorflow::AttrSlice(*removed_placeholder).Find("dtype")); 277 } 278 279 void TestGetOpTypeMatches() { 280 auto root = tensorflow::Scope::NewRootScope(); 281 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 282 283 const int width = 100; 284 285 Tensor a_data(DT_FLOAT, TensorShape({width})); 286 test::FillIota<float>(&a_data, 1.0f); 287 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); 288 289 Tensor b_data(DT_FLOAT, TensorShape({width})); 290 test::FillIota<float>(&b_data, 1.0f); 291 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); 292 293 Output add = Add(root.WithOpName("add"), a_const, b_const); 294 295 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); 296 297 Output mul = Mul(root.WithOpName("output"), add, placeholder); 298 299 GraphDef graph_def; 300 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 301 302 GraphMatcher matcher(graph_def); 303 304 std::vector<NodeMatch> const_matches; 305 TF_ASSERT_OK(matcher.GetOpTypeMatches({"Const"}, &const_matches)); 306 EXPECT_EQ(2, const_matches.size()); 307 for (const NodeMatch& match : const_matches) { 308 EXPECT_EQ("Const", match.node.op()); 309 EXPECT_TRUE(("a" == match.node.name()) || ("b" == match.node.name())) 310 << "match.node.name()=" << match.node.name(); 311 } 312 313 std::vector<NodeMatch> add_matches; 314 TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add"}, &add_matches)); 315 EXPECT_EQ(1, add_matches.size()); 316 EXPECT_EQ("Add", add_matches[0].node.op()); 317 EXPECT_EQ("add", add_matches[0].node.name()); 318 319 std::vector<NodeMatch> add_child_matches; 320 TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add", {{"Const"}, {"Const"}}}, 321 &add_child_matches)); 322 EXPECT_EQ(1, add_child_matches.size()); 323 EXPECT_EQ("Add", add_child_matches[0].node.op()); 324 EXPECT_EQ("add", add_child_matches[0].node.name()); 325 EXPECT_EQ(2, add_child_matches[0].inputs.size()); 326 for (const NodeMatch& match : add_child_matches[0].inputs) { 327 EXPECT_EQ("Const", match.node.op()); 328 EXPECT_TRUE(("a" == match.node.name()) || ("b" == match.node.name())) 329 << "match.node.name()=" << match.node.name(); 330 } 331 332 std::vector<NodeMatch> no_such_matches; 333 TF_ASSERT_OK(matcher.GetOpTypeMatches({"NoSuch"}, &no_such_matches)); 334 EXPECT_EQ(0, no_such_matches.size()); 335 336 std::vector<NodeMatch> all_matches; 337 TF_ASSERT_OK(matcher.GetOpTypeMatches( 338 {"Mul", {{"Add", {{"Const"}, {"Const"}}}, {"Placeholder"}}}, 339 &all_matches)); 340 EXPECT_EQ(1, all_matches.size()); 341 EXPECT_EQ("Mul", all_matches[0].node.op()); 342 EXPECT_EQ("output", all_matches[0].node.name()); 343 EXPECT_EQ(2, all_matches[0].inputs.size()); 344 EXPECT_EQ("Add", all_matches[0].inputs[0].node.op()); 345 EXPECT_EQ("add", all_matches[0].inputs[0].node.name()); 346 EXPECT_EQ(2, all_matches[0].inputs[0].inputs.size()); 347 EXPECT_EQ("Const", all_matches[0].inputs[0].inputs[0].node.op()); 348 EXPECT_EQ("a", all_matches[0].inputs[0].inputs[0].node.name()); 349 EXPECT_EQ(0, all_matches[0].inputs[0].inputs[0].inputs.size()); 350 EXPECT_EQ("Const", all_matches[0].inputs[0].inputs[1].node.op()); 351 EXPECT_EQ("b", all_matches[0].inputs[0].inputs[1].node.name()); 352 EXPECT_EQ(0, all_matches[0].inputs[0].inputs[1].inputs.size()); 353 EXPECT_EQ("Placeholder", all_matches[0].inputs[1].node.op()); 354 EXPECT_EQ("placeholder", all_matches[0].inputs[1].node.name()); 355 EXPECT_EQ(0, all_matches[0].inputs[1].inputs.size()); 356 357 std::vector<NodeMatch> wildcard_matches; 358 TF_ASSERT_OK( 359 matcher.GetOpTypeMatches({"*", {{"*"}, {"*"}}}, &wildcard_matches)); 360 EXPECT_EQ(1, wildcard_matches.size()); 361 EXPECT_EQ("Add", wildcard_matches[0].node.op()); 362 EXPECT_EQ("Const", wildcard_matches[0].inputs[0].node.op()); 363 EXPECT_EQ("a", wildcard_matches[0].inputs[0].node.name()); 364 EXPECT_EQ("Const", wildcard_matches[0].inputs[1].node.op()); 365 EXPECT_EQ("b", wildcard_matches[0].inputs[1].node.name()); 366 367 std::vector<NodeMatch> or_matches; 368 TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add|Mul"}, &or_matches)); 369 EXPECT_EQ(2, or_matches.size()); 370 EXPECT_EQ("Add", or_matches[0].node.op()); 371 EXPECT_EQ("add", or_matches[0].node.name()); 372 EXPECT_EQ("Mul", or_matches[1].node.op()); 373 EXPECT_EQ("output", or_matches[1].node.name()); 374 } 375 376 void TestGetOpTypeMatchesDAG() { 377 auto root = tensorflow::Scope::NewRootScope(); 378 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 379 380 const int width = 100; 381 382 Tensor a_data(DT_FLOAT, TensorShape({width})); 383 test::FillIota<float>(&a_data, 1.0f); 384 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); 385 386 Output add = Add(root.WithOpName("add"), a_const, a_const); 387 388 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); 389 390 Output mul = Mul(root.WithOpName("output"), add, placeholder); 391 392 GraphDef graph_def; 393 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 394 395 GraphMatcher matcher(graph_def); 396 397 std::vector<NodeMatch> add_matches; 398 TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add", {{"Const"}, {"Const"}}}, 399 &add_matches)); 400 EXPECT_EQ(1, add_matches.size()); 401 EXPECT_EQ("Add", add_matches[0].node.op()); 402 EXPECT_EQ("add", add_matches[0].node.name()); 403 EXPECT_EQ("Const", add_matches[0].inputs[0].node.op()); 404 EXPECT_EQ("a", add_matches[0].inputs[0].node.name()); 405 EXPECT_EQ("Const", add_matches[0].inputs[1].node.op()); 406 EXPECT_EQ("a", add_matches[0].inputs[1].node.name()); 407 } 408 409 void TestReplaceMatchingOpTypes() { 410 auto root = tensorflow::Scope::NewRootScope(); 411 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 412 413 const int width = 10; 414 415 Tensor a_data(DT_FLOAT, TensorShape({width})); 416 test::FillIota<float>(&a_data, 1.0f); 417 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); 418 419 Tensor b_data(DT_FLOAT, TensorShape({width})); 420 test::FillIota<float>(&b_data, 1.0f); 421 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); 422 423 Output add = Add(root.WithOpName("add"), a_const, b_const); 424 425 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); 426 427 Output mul = Mul(root.WithOpName("output"), add, placeholder); 428 429 GraphDef graph_def; 430 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 431 432 GraphDef replaced_graph_def; 433 TF_ASSERT_OK(ReplaceMatchingOpTypes( 434 graph_def, {"*"}, 435 [](const NodeMatch& match, const std::set<string>& input_nodes, 436 const std::set<string>& output_nodes, 437 std::vector<NodeDef>* new_nodes) { 438 NodeDef original_copy; 439 original_copy = match.node; 440 const string original_name = match.node.name(); 441 original_copy.set_name(original_name + "_before_identity"); 442 new_nodes->push_back(original_copy); 443 444 NodeDef identity_node; 445 identity_node.set_op("Identity"); 446 identity_node.set_name(original_name); 447 *(identity_node.mutable_input()->Add()) = original_copy.name(); 448 new_nodes->push_back(identity_node); 449 450 return Status::OK(); 451 }, 452 {}, &replaced_graph_def)); 453 454 EXPECT_EQ(10, replaced_graph_def.node_size()); 455 for (const NodeDef& node : replaced_graph_def.node()) { 456 if (node.name() == "output") { 457 EXPECT_EQ("Identity", node.op()); 458 EXPECT_EQ("output_before_identity", node.input(0)); 459 } else if (node.name() == "output_before_identity") { 460 EXPECT_EQ("Mul", node.op()); 461 EXPECT_EQ("add", node.input(0)); 462 EXPECT_EQ("placeholder", node.input(1)); 463 } else if (node.name() == "placeholder") { 464 EXPECT_EQ("Identity", node.op()); 465 EXPECT_EQ("placeholder_before_identity", node.input(0)); 466 } else if (node.name() == "placeholder_before_identity") { 467 EXPECT_EQ("Placeholder", node.op()); 468 } else if (node.name() == "add") { 469 EXPECT_EQ("Identity", node.op()); 470 EXPECT_EQ("add_before_identity", node.input(0)); 471 } else if (node.name() == "add_before_identity") { 472 EXPECT_EQ("Add", node.op()); 473 EXPECT_EQ("a", node.input(0)); 474 EXPECT_EQ("b", node.input(1)); 475 } else if (node.name() == "a") { 476 EXPECT_EQ("Identity", node.op()); 477 EXPECT_EQ("a_before_identity", node.input(0)); 478 } else if (node.name() == "a_before_identity") { 479 EXPECT_EQ("Const", node.op()); 480 } else if (node.name() == "b") { 481 EXPECT_EQ("Identity", node.op()); 482 EXPECT_EQ("b_before_identity", node.input(0)); 483 } else if (node.name() == "b_before_identity") { 484 EXPECT_EQ("Const", node.op()); 485 } else { 486 EXPECT_EQ(true, false) << "Unexpected node name found: " << node.name(); 487 } 488 } 489 } 490 491 void TestMatchedNodesAsArray() { 492 NodeMatch fourth; 493 fourth.node.set_name("fourth"); 494 495 NodeMatch second; 496 second.node.set_name("second"); 497 second.inputs.push_back(fourth); 498 499 NodeMatch third; 500 third.node.set_name("third"); 501 third.inputs.push_back(fourth); 502 503 NodeMatch first; 504 first.node.set_name("first"); 505 first.inputs.push_back(second); 506 first.inputs.push_back(third); 507 508 std::vector<NodeDef> result; 509 MatchedNodesAsArray(first, &result); 510 511 EXPECT_EQ(4, result.size()); 512 EXPECT_EQ("first", result[0].name()); 513 EXPECT_EQ("second", result[1].name()); 514 EXPECT_EQ("third", result[2].name()); 515 EXPECT_EQ("fourth", result[3].name()); 516 } 517 518 void TestRenameNodeInputs() { 519 auto root = tensorflow::Scope::NewRootScope(); 520 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 521 522 const int width = 10; 523 524 Tensor a_data(DT_FLOAT, TensorShape({width})); 525 test::FillIota<float>(&a_data, 1.0f); 526 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); 527 528 Tensor b_data(DT_FLOAT, TensorShape({width})); 529 test::FillIota<float>(&b_data, 1.0f); 530 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); 531 532 Output add = Add(root.WithOpName("add"), a_const, a_const); 533 534 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); 535 536 Output mul = Mul(root.WithOpName("output"), add, placeholder); 537 538 GraphDef graph_def; 539 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 540 541 GraphDef renamed_graph_def; 542 TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, 543 std::unordered_set<string>(), 544 &renamed_graph_def)); 545 546 std::map<string, const NodeDef*> node_map; 547 MapNamesToNodes(renamed_graph_def, &node_map); 548 EXPECT_EQ("b", node_map.at("add")->input(0)); 549 EXPECT_EQ("b", node_map.at("add")->input(1)); 550 } 551 552 void TestRenameNodeInputsWithRedirects() { 553 auto root = tensorflow::Scope::NewRootScope(); 554 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 555 556 const int width = 10; 557 558 Tensor a_data(DT_FLOAT, TensorShape({width})); 559 test::FillIota<float>(&a_data, 1.0f); 560 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); 561 562 Tensor b_data(DT_FLOAT, TensorShape({width})); 563 test::FillIota<float>(&b_data, 1.0f); 564 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); 565 566 Tensor c_data(DT_FLOAT, TensorShape({width})); 567 test::FillIota<float>(&c_data, 1.0f); 568 Output c_const = Const(root.WithOpName("c"), Input::Initializer(c_data)); 569 570 Output add = Add(root.WithOpName("add"), a_const, b_const); 571 572 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); 573 574 Output mul = Mul(root.WithOpName("output"), add, placeholder); 575 576 GraphDef graph_def; 577 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 578 579 GraphDef renamed_graph_def; 580 TF_ASSERT_OK(RenameNodeInputs( 581 graph_def, {{"a", "f"}, {"f", "e"}, {"e", "d"}, {"d", "c"}}, 582 std::unordered_set<string>(), &renamed_graph_def)); 583 584 std::map<string, const NodeDef*> node_map; 585 MapNamesToNodes(renamed_graph_def, &node_map); 586 EXPECT_EQ("c", node_map.at("add")->input(0)); 587 EXPECT_EQ("b", node_map.at("add")->input(1)); 588 } 589 590 void TestRenameNodeInputsWithCycle() { 591 auto root = tensorflow::Scope::NewRootScope(); 592 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 593 594 const int width = 10; 595 596 Tensor a_data(DT_FLOAT, TensorShape({width})); 597 test::FillIota<float>(&a_data, 1.0f); 598 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); 599 600 Tensor b_data(DT_FLOAT, TensorShape({width})); 601 test::FillIota<float>(&b_data, 1.0f); 602 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); 603 604 Tensor c_data(DT_FLOAT, TensorShape({width})); 605 test::FillIota<float>(&c_data, 1.0f); 606 Output c_const = Const(root.WithOpName("c"), Input::Initializer(c_data)); 607 608 Output add = Add(root.WithOpName("add"), a_const, b_const); 609 610 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); 611 612 Output mul = Mul(root.WithOpName("output"), add, placeholder); 613 614 GraphDef graph_def; 615 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 616 617 GraphDef renamed_graph_def; 618 Status rename_status = 619 RenameNodeInputs(graph_def, {{"a", "d"}, {"d", "a"}}, 620 std::unordered_set<string>(), &renamed_graph_def); 621 EXPECT_FALSE(rename_status.ok()); 622 } 623 624 void TestRenameNodeInputsWithWildcard() { 625 auto root = tensorflow::Scope::DisabledShapeInferenceScope(); 626 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 627 628 const int width = 10; 629 630 Tensor a_data(DT_FLOAT, TensorShape({width})); 631 test::FillIota<float>(&a_data, 1.0f); 632 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); 633 634 QuantizeV2 quantize_a(root.WithOpName("quantize_a"), a_const, a_const, 635 a_const, DT_QUINT8, 636 QuantizeV2::Attrs().Mode("MIN_FIRST")); 637 638 Tensor b_data(DT_FLOAT, TensorShape({width})); 639 test::FillIota<float>(&b_data, 1.0f); 640 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); 641 642 QuantizeV2 quantize_b(root.WithOpName("quantize_b"), b_const, b_const, 643 b_const, DT_QUINT8, 644 QuantizeV2::Attrs().Mode("MIN_FIRST")); 645 646 Output add = Add(root.WithOpName("add"), quantize_a.output_min, 647 quantize_a.output_max); 648 649 GraphDef graph_def; 650 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 651 652 GraphDef renamed_graph_def; 653 TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"quantize_a:*", "quantize_b"}}, 654 std::unordered_set<string>(), 655 &renamed_graph_def)); 656 657 std::map<string, const NodeDef*> node_map; 658 MapNamesToNodes(renamed_graph_def, &node_map); 659 EXPECT_EQ("quantize_b:1", node_map.at("add")->input(0)); 660 EXPECT_EQ("quantize_b:2", node_map.at("add")->input(1)); 661 } 662 663 void TestRenameNodeInputsWithIgnores() { 664 auto root = tensorflow::Scope::NewRootScope(); 665 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 666 667 const int width = 10; 668 669 Tensor a_data(DT_FLOAT, TensorShape({width})); 670 test::FillIota<float>(&a_data, 1.0f); 671 Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); 672 673 Tensor b_data(DT_FLOAT, TensorShape({width})); 674 test::FillIota<float>(&b_data, 1.0f); 675 Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); 676 677 Output add = Add(root.WithOpName("add"), a_const, a_const); 678 679 Output add2 = Add(root.WithOpName("add2"), a_const, a_const); 680 681 Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); 682 683 Output mul = Mul(root.WithOpName("mul"), add, placeholder); 684 685 Output mul2 = Mul(root.WithOpName("output"), mul, add2); 686 687 GraphDef graph_def; 688 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 689 690 GraphDef renamed_graph_def; 691 TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, {"add2"}, 692 &renamed_graph_def)); 693 694 std::map<string, const NodeDef*> node_map; 695 MapNamesToNodes(renamed_graph_def, &node_map); 696 EXPECT_EQ("b", node_map.at("add")->input(0)); 697 EXPECT_EQ("b", node_map.at("add")->input(1)); 698 EXPECT_EQ("a", node_map.at("add2")->input(0)); 699 EXPECT_EQ("a", node_map.at("add2")->input(1)); 700 } 701 702 void TestFindInvalidInputs() { 703 GraphDef graph_def; 704 705 NodeDef* mul_node = graph_def.mutable_node()->Add(); 706 mul_node->set_op("Mul"); 707 mul_node->set_name("mul_node"); 708 *(mul_node->mutable_input()->Add()) = "add_node1"; 709 *(mul_node->mutable_input()->Add()) = "add_node2:0"; 710 *(mul_node->mutable_input()->Add()) = "^const_node1:0"; 711 712 NodeDef* add_node1 = graph_def.mutable_node()->Add(); 713 add_node1->set_op("Add"); 714 add_node1->set_name("add_node1"); 715 *(add_node1->mutable_input()->Add()) = "missing_input1"; 716 *(add_node1->mutable_input()->Add()) = "const_node1:0"; 717 *(add_node1->mutable_input()->Add()) = "missing_input2"; 718 719 NodeDef* add_node2 = graph_def.mutable_node()->Add(); 720 add_node2->set_op("Add"); 721 add_node2->set_name("add_node2"); 722 *(add_node2->mutable_input()->Add()) = "missing_input3"; 723 *(add_node2->mutable_input()->Add()) = "const_node1:0"; 724 *(add_node2->mutable_input()->Add()) = "^const_node2"; 725 726 NodeDef* const_node1 = graph_def.mutable_node()->Add(); 727 const_node1->set_op("Const"); 728 const_node1->set_name("const_node1"); 729 730 NodeDef* const_node2 = graph_def.mutable_node()->Add(); 731 const_node2->set_op("Const"); 732 const_node2->set_name("const_node2"); 733 734 std::vector<std::pair<string, string>> invalid_inputs; 735 FindInvalidInputs(graph_def, &invalid_inputs); 736 EXPECT_EQ(3, invalid_inputs.size()); 737 for (const std::pair<string, string>& invalid_input : invalid_inputs) { 738 EXPECT_TRUE((invalid_input.first == "add_node1") || 739 (invalid_input.first == "add_node2")); 740 if (invalid_input.first == "add_node1") { 741 EXPECT_TRUE((invalid_input.second == "missing_input1") || 742 (invalid_input.second == "missing_input2")) 743 << invalid_input.second; 744 } else if (invalid_input.first == "add_node2") { 745 EXPECT_EQ("missing_input3", invalid_input.second); 746 } 747 } 748 } 749 750 void TestIsGraphValid() { 751 GraphDef invalid_graph_def; 752 753 NodeDef* mul_node = invalid_graph_def.mutable_node()->Add(); 754 mul_node->set_op("Mul"); 755 mul_node->set_name("mul_node"); 756 *(mul_node->mutable_input()->Add()) = "add_node1"; 757 *(mul_node->mutable_input()->Add()) = "add_node2:0"; 758 *(mul_node->mutable_input()->Add()) = "^const_node1:0"; 759 760 NodeDef* add_node1 = invalid_graph_def.mutable_node()->Add(); 761 add_node1->set_op("Add"); 762 add_node1->set_name("add_node1"); 763 *(add_node1->mutable_input()->Add()) = "missing_input1"; 764 *(add_node1->mutable_input()->Add()) = "const_node1:0"; 765 *(add_node1->mutable_input()->Add()) = "missing_input2"; 766 767 NodeDef* add_node2 = invalid_graph_def.mutable_node()->Add(); 768 add_node2->set_op("Add"); 769 add_node2->set_name("add_node2"); 770 *(add_node2->mutable_input()->Add()) = "missing_input3"; 771 *(add_node2->mutable_input()->Add()) = "const_node1:0"; 772 *(add_node2->mutable_input()->Add()) = "^const_node2"; 773 774 NodeDef* const_node1 = invalid_graph_def.mutable_node()->Add(); 775 const_node1->set_op("Const"); 776 const_node1->set_name("const_node1"); 777 778 NodeDef* const_node2 = invalid_graph_def.mutable_node()->Add(); 779 const_node2->set_op("Const"); 780 const_node2->set_name("const_node2"); 781 782 EXPECT_FALSE(IsGraphValid(invalid_graph_def).ok()); 783 784 GraphDef valid_graph_def; 785 786 NodeDef* const_node3 = valid_graph_def.mutable_node()->Add(); 787 const_node3->set_op("Const"); 788 const_node3->set_name("const_node2"); 789 790 EXPECT_TRUE(IsGraphValid(valid_graph_def).ok()); 791 } 792 793 void TestGetInOutTypes() { 794 auto root = tensorflow::Scope::NewRootScope(); 795 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 796 797 const int width = 20; 798 799 Tensor float_data(DT_FLOAT, TensorShape({width})); 800 test::FillIota<float>(&float_data, 1.0f); 801 Output float_const = 802 Const(root.WithOpName("float_const"), Input::Initializer(float_data)); 803 804 Tensor int_data(DT_INT32, TensorShape({width})); 805 test::FillIota<int32>(&int_data, 1); 806 Output int_const = 807 Const(root.WithOpName("int_const"), Input::Initializer(int_data)); 808 809 Output float_relu = Relu(root.WithOpName("float_relu"), float_const); 810 811 Output int_relu = Relu(root.WithOpName("int_relu"), int_const); 812 813 GraphDef graph_def; 814 TF_ASSERT_OK(root.ToGraphDef(&graph_def)); 815 816 std::map<string, const NodeDef*> node_map; 817 MapNamesToNodes(graph_def, &node_map); 818 819 const NodeDef* float_const_def = node_map.at("float_const"); 820 DataTypeVector float_const_inputs; 821 DataTypeVector float_const_outputs; 822 TF_EXPECT_OK(GetInOutTypes(*float_const_def, &float_const_inputs, 823 &float_const_outputs)); 824 ASSERT_EQ(0, float_const_inputs.size()); 825 ASSERT_EQ(1, float_const_outputs.size()); 826 EXPECT_EQ(DT_FLOAT, float_const_outputs[0]); 827 828 const NodeDef* int_const_def = node_map.at("int_const"); 829 DataTypeVector int_const_inputs; 830 DataTypeVector int_const_outputs; 831 TF_EXPECT_OK( 832 GetInOutTypes(*int_const_def, &int_const_inputs, &int_const_outputs)); 833 ASSERT_EQ(0, int_const_inputs.size()); 834 ASSERT_EQ(1, int_const_outputs.size()); 835 EXPECT_EQ(DT_INT32, int_const_outputs[0]); 836 837 const NodeDef* float_relu_def = node_map.at("float_relu"); 838 DataTypeVector float_relu_inputs; 839 DataTypeVector float_relu_outputs; 840 TF_EXPECT_OK(GetInOutTypes(*float_relu_def, &float_relu_inputs, 841 &float_relu_outputs)); 842 ASSERT_EQ(1, float_relu_inputs.size()); 843 EXPECT_EQ(DT_FLOAT, float_relu_inputs[0]); 844 ASSERT_EQ(1, float_relu_outputs.size()); 845 EXPECT_EQ(DT_FLOAT, float_relu_outputs[0]); 846 847 const NodeDef* int_relu_def = node_map.at("int_relu"); 848 DataTypeVector int_relu_inputs; 849 DataTypeVector int_relu_outputs; 850 TF_EXPECT_OK( 851 GetInOutTypes(*int_relu_def, &int_relu_inputs, &int_relu_outputs)); 852 ASSERT_EQ(1, int_relu_inputs.size()); 853 EXPECT_EQ(DT_INT32, int_relu_inputs[0]); 854 ASSERT_EQ(1, int_relu_outputs.size()); 855 EXPECT_EQ(DT_INT32, int_relu_outputs[0]); 856 } 857 858 void TestCopyOriginalMatch() { 859 NodeDef a; 860 a.set_op("Relu"); 861 a.set_name("a"); 862 AddNodeInput("b", &a); 863 864 NodeDef b; 865 b.set_op("Const"); 866 b.set_name("b"); 867 868 NodeMatch b_match; 869 b_match.node = b; 870 871 NodeMatch a_match; 872 a_match.node = a; 873 a_match.inputs.push_back(b_match); 874 875 std::vector<NodeDef> new_nodes; 876 CopyOriginalMatch(a_match, &new_nodes); 877 EXPECT_EQ(2, new_nodes.size()); 878 EXPECT_EQ("a", new_nodes[0].name()); 879 EXPECT_EQ("Relu", new_nodes[0].op()); 880 EXPECT_EQ("b", new_nodes[1].name()); 881 EXPECT_EQ("Const", new_nodes[1].op()); 882 } 883 884 void TestHashNodeDef() { 885 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 886 887 const int width = 10; 888 889 auto a_root = tensorflow::Scope::NewRootScope(); 890 Tensor a_data(DT_FLOAT, TensorShape({width})); 891 test::FillIota<float>(&a_data, 1.0f); 892 Output a_const = Const(a_root.WithOpName("a"), Input::Initializer(a_data)); 893 GraphDef a_graph_def; 894 TF_ASSERT_OK(a_root.ToGraphDef(&a_graph_def)); 895 const NodeDef& a_node_def = a_graph_def.node(0); 896 897 auto b_root = tensorflow::Scope::NewRootScope(); 898 Tensor b_data(DT_FLOAT, TensorShape({width})); 899 test::FillIota<float>(&b_data, 1.0f); 900 Output b_const = Const(b_root.WithOpName("a"), Input::Initializer(b_data)); 901 GraphDef b_graph_def; 902 TF_ASSERT_OK(b_root.ToGraphDef(&b_graph_def)); 903 const NodeDef& b_node_def = b_graph_def.node(0); 904 905 auto c_root = tensorflow::Scope::NewRootScope(); 906 Tensor c_data(DT_FLOAT, TensorShape({width})); 907 test::FillIota<float>(&c_data, 2.0f); 908 Output c_const = Const(c_root.WithOpName("a"), Input::Initializer(c_data)); 909 GraphDef c_graph_def; 910 TF_ASSERT_OK(c_root.ToGraphDef(&c_graph_def)); 911 const NodeDef& c_node_def = c_graph_def.node(0); 912 913 auto d_root = tensorflow::Scope::NewRootScope(); 914 Tensor d_data(DT_FLOAT, TensorShape({width})); 915 test::FillIota<float>(&d_data, 1.0f); 916 Output d_const = Const(d_root.WithOpName("d"), Input::Initializer(d_data)); 917 GraphDef d_graph_def; 918 TF_ASSERT_OK(d_root.ToGraphDef(&d_graph_def)); 919 const NodeDef& d_node_def = d_graph_def.node(0); 920 921 auto e_root = tensorflow::Scope::NewRootScope(); 922 Tensor e_data(DT_INT32, TensorShape({width})); 923 test::FillIota<int32>(&e_data, 1); 924 Output e_const = Const(e_root.WithOpName("a"), Input::Initializer(e_data)); 925 GraphDef e_graph_def; 926 TF_ASSERT_OK(e_root.ToGraphDef(&e_graph_def)); 927 const NodeDef& e_node_def = e_graph_def.node(0); 928 929 auto f_root = tensorflow::Scope::NewRootScope(); 930 Tensor f_data(DT_FLOAT, TensorShape({width - 1})); 931 test::FillIota<float>(&f_data, 1.0f); 932 Output f_const = Const(f_root.WithOpName("a"), Input::Initializer(f_data)); 933 GraphDef f_graph_def; 934 TF_ASSERT_OK(f_root.ToGraphDef(&f_graph_def)); 935 const NodeDef& f_node_def = f_graph_def.node(0); 936 937 auto g_root = tensorflow::Scope::NewRootScope(); 938 Tensor g_data(DT_FLOAT, TensorShape({width})); 939 test::FillIota<float>(&g_data, 1); 940 Output g_const = Const(g_root.WithOpName("a").WithDevice("some_device"), 941 Input::Initializer(g_data)); 942 GraphDef g_graph_def; 943 TF_ASSERT_OK(g_root.ToGraphDef(&g_graph_def)); 944 const NodeDef& g_node_def = g_graph_def.node(0); 945 946 NodeDef relu1_node_def; 947 relu1_node_def.set_op("Relu"); 948 relu1_node_def.set_name("a"); 949 relu1_node_def.add_input("foo"); 950 951 NodeDef relu2_node_def; 952 relu2_node_def.set_op("Relu"); 953 relu2_node_def.set_name("a"); 954 relu2_node_def.add_input("bar"); 955 956 EXPECT_EQ(HashNodeDef(a_node_def), HashNodeDef(b_node_def)); 957 EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(c_node_def)); 958 EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(d_node_def)); 959 EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(e_node_def)); 960 EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(f_node_def)); 961 EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(g_node_def)); 962 EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(relu1_node_def)); 963 EXPECT_NE(HashNodeDef(relu1_node_def), HashNodeDef(relu2_node_def)); 964 } 965 966 void TestCountParameters() { 967 TransformFuncContext context; 968 context.params.insert({"foo", {"a", "b"}}); 969 context.params.insert({"bar", {"c"}}); 970 EXPECT_EQ(2, context.CountParameters("foo")); 971 EXPECT_EQ(1, context.CountParameters("bar")); 972 EXPECT_EQ(0, context.CountParameters("not_present")); 973 } 974 975 void TestGetOneStringParameter() { 976 TransformFuncContext context; 977 context.params.insert({"foo", {"a", "b"}}); 978 context.params.insert({"bar", {"c"}}); 979 string value; 980 TF_EXPECT_OK(context.GetOneStringParameter("bar", "d", &value)); 981 EXPECT_EQ("c", value); 982 EXPECT_FALSE(context.GetOneStringParameter("foo", "d", &value).ok()); 983 TF_EXPECT_OK(context.GetOneStringParameter("not_present", "d", &value)); 984 EXPECT_EQ("d", value); 985 } 986 987 void TestGetOneInt32Parameter() { 988 TransformFuncContext context; 989 context.params.insert({"foo", {"10", "20"}}); 990 context.params.insert({"bar", {"-23"}}); 991 context.params.insert({"not_a_number", {"not_numerical"}}); 992 context.params.insert({"float", {"-23.232323"}}); 993 int32 value; 994 TF_EXPECT_OK(context.GetOneInt32Parameter("bar", 0, &value)); 995 EXPECT_EQ(-23, value); 996 EXPECT_FALSE(context.GetOneInt32Parameter("foo", 0, &value).ok()); 997 TF_EXPECT_OK(context.GetOneInt32Parameter("not_present", 10, &value)); 998 EXPECT_EQ(10, value); 999 EXPECT_FALSE(context.GetOneInt32Parameter("not_a_number", 0, &value).ok()); 1000 EXPECT_FALSE(context.GetOneInt32Parameter("float", 0, &value).ok()); 1001 } 1002 1003 void TestGetOneInt64Parameter() { 1004 TransformFuncContext context; 1005 context.params.insert({"foo", {"10", "20"}}); 1006 context.params.insert({"bar", {"-23"}}); 1007 context.params.insert({"not_a_number", {"not_numerical"}}); 1008 context.params.insert({"float", {"-23.232323"}}); 1009 int64 value; 1010 TF_EXPECT_OK(context.GetOneInt64Parameter("bar", 0, &value)); 1011 EXPECT_EQ(-23, value); 1012 EXPECT_FALSE(context.GetOneInt64Parameter("foo", 0, &value).ok()); 1013 TF_EXPECT_OK(context.GetOneInt64Parameter("not_present", 10, &value)); 1014 EXPECT_EQ(10, value); 1015 EXPECT_FALSE(context.GetOneInt64Parameter("not_a_number", 0, &value).ok()); 1016 EXPECT_FALSE(context.GetOneInt64Parameter("float", 0, &value).ok()); 1017 } 1018 1019 void TestGetOneFloatParameter() { 1020 TransformFuncContext context; 1021 context.params.insert({"foo", {"10.0", "20.0"}}); 1022 context.params.insert({"bar", {"-23.2323"}}); 1023 context.params.insert({"not_a_number", {"not_numerical"}}); 1024 float value; 1025 TF_EXPECT_OK(context.GetOneFloatParameter("bar", 0, &value)); 1026 EXPECT_NEAR(-23.2323f, value, 1e-5f); 1027 EXPECT_FALSE(context.GetOneFloatParameter("foo", 0, &value).ok()); 1028 TF_EXPECT_OK(context.GetOneFloatParameter("not_present", 10.5f, &value)); 1029 EXPECT_NEAR(10.5f, value, 1e-5f); 1030 EXPECT_FALSE(context.GetOneFloatParameter("not_a_number", 0, &value).ok()); 1031 } 1032 1033 void TestGetOneBoolParameter() { 1034 TransformFuncContext context; 1035 context.params.insert({"foo", {"true", "false"}}); 1036 context.params.insert({"true", {"true"}}); 1037 context.params.insert({"false", {"false"}}); 1038 context.params.insert({"one", {"1"}}); 1039 context.params.insert({"zero", {"0"}}); 1040 context.params.insert({"not_a_bool", {"not_boolean"}}); 1041 1042 bool value; 1043 EXPECT_FALSE(context.GetOneBoolParameter("foo", 0, &value).ok()); 1044 1045 value = false; 1046 TF_EXPECT_OK(context.GetOneBoolParameter("true", false, &value)); 1047 EXPECT_TRUE(value); 1048 1049 value = true; 1050 TF_EXPECT_OK(context.GetOneBoolParameter("false", true, &value)); 1051 EXPECT_FALSE(value); 1052 1053 value = false; 1054 TF_EXPECT_OK(context.GetOneBoolParameter("one", false, &value)); 1055 EXPECT_TRUE(value); 1056 1057 value = true; 1058 TF_EXPECT_OK(context.GetOneBoolParameter("zero", true, &value)); 1059 EXPECT_FALSE(value); 1060 1061 EXPECT_FALSE(context.GetOneBoolParameter("not_a_bool", false, &value).ok()); 1062 1063 value = false; 1064 TF_EXPECT_OK(context.GetOneBoolParameter("not_present", true, &value)); 1065 EXPECT_TRUE(value); 1066 } 1067 }; 1068 1069 TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); } 1070 1071 TEST_F(TransformUtilsTest, TestMapNodesToOutputs) { TestMapNodesToOutputs(); } 1072 1073 TEST_F(TransformUtilsTest, TestNodeNamePartsFromInput) { 1074 TestNodeNamePartsFromInput(); 1075 } 1076 1077 TEST_F(TransformUtilsTest, TestCanonicalInputName) { TestCanonicalInputName(); } 1078 1079 TEST_F(TransformUtilsTest, TestAddNodeInput) { TestAddNodeInput(); } 1080 1081 TEST_F(TransformUtilsTest, TestCopyNodeAttr) { TestCopyNodeAttr(); } 1082 1083 TEST_F(TransformUtilsTest, TestSetNodeAttr) { TestSetNodeAttr(); } 1084 1085 TEST_F(TransformUtilsTest, TestSetNodeTensorAttr) { TestSetNodeTensorAttr(); } 1086 1087 TEST_F(TransformUtilsTest, TestSetNodeTensorAttrWithTensor) { 1088 TestSetNodeTensorAttrWithTensor(); 1089 } 1090 1091 TEST_F(TransformUtilsTest, TestGetNodeTensorAttr) { TestGetNodeTensorAttr(); } 1092 1093 TEST_F(TransformUtilsTest, TestNodeNameFromInput) { TestNodeNameFromInput(); } 1094 1095 TEST_F(TransformUtilsTest, TestFilterGraphDef) { TestFilterGraphDef(); } 1096 1097 TEST_F(TransformUtilsTest, TestRemoveAttributes) { TestRemoveAttributes(); } 1098 1099 TEST_F(TransformUtilsTest, TestGetOpTypeMatches) { TestGetOpTypeMatches(); } 1100 1101 TEST_F(TransformUtilsTest, TestGetOpTypeMatchesDAG) { 1102 TestGetOpTypeMatchesDAG(); 1103 } 1104 1105 TEST_F(TransformUtilsTest, TestReplaceMatchingOpTypes) { 1106 TestReplaceMatchingOpTypes(); 1107 } 1108 1109 TEST_F(TransformUtilsTest, TestMatchedNodesAsArray) { 1110 TestMatchedNodesAsArray(); 1111 } 1112 1113 TEST_F(TransformUtilsTest, TestRenameNodeInputs) { TestRenameNodeInputs(); } 1114 1115 TEST_F(TransformUtilsTest, TestRenameNodeInputsWithRedirects) { 1116 TestRenameNodeInputsWithRedirects(); 1117 } 1118 1119 TEST_F(TransformUtilsTest, TestRenameNodeInputsWithCycle) { 1120 TestRenameNodeInputsWithCycle(); 1121 } 1122 1123 TEST_F(TransformUtilsTest, TestRenameNodeInputsWithWildcard) { 1124 TestRenameNodeInputsWithWildcard(); 1125 } 1126 1127 TEST_F(TransformUtilsTest, TestRenameNodeInputsWithIgnores) { 1128 TestRenameNodeInputsWithIgnores(); 1129 } 1130 1131 TEST_F(TransformUtilsTest, TestFindInvalidInputs) { TestFindInvalidInputs(); } 1132 1133 TEST_F(TransformUtilsTest, TestIsGraphValid) { TestIsGraphValid(); } 1134 1135 TEST_F(TransformUtilsTest, TestGetInOutTypes) { TestGetInOutTypes(); } 1136 1137 TEST_F(TransformUtilsTest, TestCopyOriginalMatch) { TestCopyOriginalMatch(); } 1138 1139 TEST_F(TransformUtilsTest, TestHashNodeDef) { TestHashNodeDef(); } 1140 1141 TEST_F(TransformUtilsTest, TestCountParameters) { TestCountParameters(); } 1142 1143 TEST_F(TransformUtilsTest, TestGetOneStringParameter) { 1144 TestGetOneStringParameter(); 1145 } 1146 1147 TEST_F(TransformUtilsTest, TestGetOneInt32Parameter) { 1148 TestGetOneInt32Parameter(); 1149 } 1150 1151 TEST_F(TransformUtilsTest, TestGetOneInt64Parameter) { 1152 TestGetOneInt64Parameter(); 1153 } 1154 1155 TEST_F(TransformUtilsTest, TestGetOneFloatParameter) { 1156 TestGetOneFloatParameter(); 1157 } 1158 1159 TEST_F(TransformUtilsTest, TestGetOneBoolParameter) { 1160 TestGetOneBoolParameter(); 1161 } 1162 1163 } // namespace graph_transforms 1164 } // namespace tensorflow 1165