Home | History | Annotate | Download | only in graph_transforms
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/tools/graph_transforms/transform_utils.h"
     17 #include "tensorflow/cc/ops/const_op.h"
     18 #include "tensorflow/cc/ops/image_ops.h"
     19 #include "tensorflow/cc/ops/nn_ops.h"
     20 #include "tensorflow/cc/ops/standard_ops.h"
     21 #include "tensorflow/core/framework/tensor_testutil.h"
     22 #include "tensorflow/core/lib/core/status_test_util.h"
     23 #include "tensorflow/core/lib/io/path.h"
     24 #include "tensorflow/core/platform/test.h"
     25 #include "tensorflow/core/platform/test_benchmark.h"
     26 
     27 namespace tensorflow {
     28 namespace graph_transforms {
     29 
     30 class TransformUtilsTest : public ::testing::Test {
     31  protected:
     32   void TestMapNamesToNodes() {
     33     auto root = tensorflow::Scope::NewRootScope();
     34     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
     35 
     36     const int width = 100;
     37 
     38     Tensor a_data(DT_FLOAT, TensorShape({width}));
     39     test::FillIota<float>(&a_data, 1.0f);
     40     Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
     41 
     42     Tensor b_data(DT_FLOAT, TensorShape({width}));
     43     test::FillIota<float>(&b_data, 1.0f);
     44     Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
     45 
     46     Output add = Add(root.WithOpName("add"), a_const, b_const);
     47 
     48     Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
     49 
     50     Output mul = Mul(root.WithOpName("output"), add, placeholder);
     51 
     52     GraphDef graph_def;
     53     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
     54     std::map<string, const NodeDef*> node_map;
     55     MapNamesToNodes(graph_def, &node_map);
     56 
     57     EXPECT_EQ(1, node_map.count("a"));
     58     EXPECT_EQ(1, node_map.count("b"));
     59     EXPECT_EQ(1, node_map.count("add"));
     60     EXPECT_EQ(1, node_map.count("placeholder"));
     61     EXPECT_EQ(1, node_map.count("output"));
     62     EXPECT_EQ(0, node_map.count("no_such_node"));
     63   }
     64 
     65   void TestMapNodesToOutputs() {
     66     auto root = tensorflow::Scope::NewRootScope();
     67     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
     68 
     69     const int width = 100;
     70 
     71     Tensor a_data(DT_FLOAT, TensorShape({width}));
     72     test::FillIota<float>(&a_data, 1.0f);
     73     Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
     74 
     75     Tensor b_data(DT_FLOAT, TensorShape({width}));
     76     test::FillIota<float>(&b_data, 1.0f);
     77     Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
     78 
     79     Output add = Add(root.WithOpName("add"), a_const, b_const);
     80 
     81     Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
     82 
     83     Output mul = Mul(root.WithOpName("output"), add, placeholder);
     84 
     85     GraphDef graph_def;
     86     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
     87 
     88     std::map<string, std::vector<const NodeDef*>> outputs_map;
     89     MapNodesToOutputs(graph_def, &outputs_map);
     90 
     91     EXPECT_EQ(1, outputs_map.count("a"));
     92     EXPECT_EQ(1, outputs_map["a"].size());
     93     EXPECT_EQ("add", outputs_map["a"][0]->name());
     94 
     95     EXPECT_EQ(1, outputs_map.count("b"));
     96     EXPECT_EQ(1, outputs_map["b"].size());
     97     EXPECT_EQ("add", outputs_map["b"][0]->name());
     98 
     99     EXPECT_EQ(1, outputs_map.count("add"));
    100     EXPECT_EQ(1, outputs_map["add"].size());
    101     EXPECT_EQ("output", outputs_map["add"][0]->name());
    102 
    103     EXPECT_EQ(1, outputs_map.count("placeholder"));
    104     EXPECT_EQ(1, outputs_map["placeholder"].size());
    105     EXPECT_EQ("output", outputs_map["placeholder"][0]->name());
    106 
    107     EXPECT_EQ(0, outputs_map.count("output"));
    108     EXPECT_EQ(0, outputs_map.count("no_such_node"));
    109   }
    110 
    111   void TestNodeNamePartsFromInput() {
    112     string prefix;
    113     string node_name;
    114     string suffix;
    115 
    116     NodeNamePartsFromInput("some_node_name", &prefix, &node_name, &suffix);
    117     EXPECT_EQ("", prefix);
    118     EXPECT_EQ("some_node_name", node_name);
    119     EXPECT_EQ("", suffix);
    120 
    121     NodeNamePartsFromInput("some_node_name/with/slashes", &prefix, &node_name,
    122                            &suffix);
    123     EXPECT_EQ("", prefix);
    124     EXPECT_EQ("some_node_name/with/slashes", node_name);
    125     EXPECT_EQ("", suffix);
    126 
    127     NodeNamePartsFromInput("some_node_name:0", &prefix, &node_name, &suffix);
    128     EXPECT_EQ("", prefix);
    129     EXPECT_EQ("some_node_name", node_name);
    130     EXPECT_EQ(":0", suffix);
    131 
    132     NodeNamePartsFromInput("^some_node_name", &prefix, &node_name, &suffix);
    133     EXPECT_EQ("^", prefix);
    134     EXPECT_EQ("some_node_name", node_name);
    135     EXPECT_EQ("", suffix);
    136 
    137     NodeNamePartsFromInput("^some_node_name:99", &prefix, &node_name, &suffix);
    138     EXPECT_EQ("^", prefix);
    139     EXPECT_EQ("some_node_name", node_name);
    140     EXPECT_EQ(":99", suffix);
    141   }
    142 
    143   void TestNodeNameFromInput() {
    144     EXPECT_EQ("node_name", NodeNameFromInput("node_name"));
    145     EXPECT_EQ("node_name", NodeNameFromInput("node_name:0"));
    146     EXPECT_EQ("node_name", NodeNameFromInput("^node_name"));
    147     EXPECT_EQ("node_name", NodeNameFromInput("^node_name:42"));
    148   }
    149 
    150   void TestCanonicalInputName() {
    151     EXPECT_EQ("node_name:0", CanonicalInputName("node_name"));
    152     EXPECT_EQ("node_name:0", CanonicalInputName("node_name:0"));
    153     EXPECT_EQ("^node_name:0", CanonicalInputName("^node_name"));
    154     EXPECT_EQ("^node_name:42", CanonicalInputName("^node_name:42"));
    155   }
    156 
    157   void TestAddNodeInput() {
    158     NodeDef node;
    159     AddNodeInput("foo", &node);
    160     EXPECT_EQ("foo", node.input(0));
    161   }
    162 
    163   void TestCopyNodeAttr() {
    164     NodeDef node;
    165     auto mutable_attr = node.mutable_attr();
    166     (*mutable_attr)["foo"].set_i(3);
    167 
    168     NodeDef copied_node;
    169     CopyNodeAttr(node, "foo", "bar", &copied_node);
    170     EXPECT_EQ(3, copied_node.attr().at("bar").i());
    171   }
    172 
    173   void TestSetNodeAttr() {
    174     NodeDef node;
    175     int32 value_i = 32;
    176     SetNodeAttr("foo", value_i, &node);
    177     EXPECT_EQ(32, node.attr().at("foo").i());
    178     string value_s = "some_value";
    179     SetNodeAttr("bar", value_s, &node);
    180     EXPECT_EQ("some_value", node.attr().at("bar").s());
    181   }
    182 
    183   void TestSetNodeTensorAttr() {
    184     NodeDef node;
    185     SetNodeTensorAttr<int32>("foo", {3, 1}, {1, 2, 3}, &node);
    186     TensorProto tensor_proto = node.attr().at("foo").tensor();
    187     Tensor tensor;
    188     CHECK(tensor.FromProto(tensor_proto));
    189     EXPECT_EQ(DT_INT32, tensor.dtype());
    190     EXPECT_EQ(3, tensor.shape().dim_size(0));
    191     EXPECT_EQ(1, tensor.shape().dim_size(1));
    192     EXPECT_EQ(1, tensor.flat<int32>()(0));
    193     EXPECT_EQ(2, tensor.flat<int32>()(1));
    194     EXPECT_EQ(3, tensor.flat<int32>()(2));
    195   }
    196 
    197   void TestSetNodeTensorAttrWithTensor() {
    198     NodeDef node;
    199     Tensor input_tensor(DT_INT32, {4, 5});
    200     test::FillIota<int32>(&input_tensor, 1);
    201     SetNodeTensorAttr<int32>("foo", input_tensor, &node);
    202     TensorProto tensor_proto = node.attr().at("foo").tensor();
    203     Tensor tensor;
    204     CHECK(tensor.FromProto(tensor_proto));
    205     test::ExpectTensorEqual<int32>(input_tensor, tensor);
    206   }
    207 
    208   void TestGetNodeTensorAttr() {
    209     NodeDef node;
    210     Tensor input_tensor(DT_INT32, {4, 5});
    211     test::FillIota<int32>(&input_tensor, 1);
    212     TensorProto tensor_proto;
    213     input_tensor.AsProtoTensorContent(&tensor_proto);
    214     SetNodeAttr("foo", tensor_proto, &node);
    215     Tensor result = GetNodeTensorAttr(node, "foo");
    216     test::ExpectTensorEqual<int32>(input_tensor, result);
    217   }
    218 
    219   void TestFilterGraphDef() {
    220     auto root = tensorflow::Scope::NewRootScope();
    221     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    222 
    223     const int width = 100;
    224 
    225     Tensor a_data(DT_FLOAT, TensorShape({width}));
    226     test::FillIota<float>(&a_data, 1.0f);
    227     Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
    228 
    229     Tensor b_data(DT_FLOAT, TensorShape({width}));
    230     test::FillIota<float>(&b_data, 1.0f);
    231     Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
    232 
    233     Output add = Add(root.WithOpName("add"), a_const, b_const);
    234 
    235     Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
    236 
    237     Output mul = Mul(root.WithOpName("output"), add, placeholder);
    238 
    239     Output remove_me = Add(root.WithOpName("remove_me"), mul, add);
    240 
    241     GraphDef graph_def;
    242     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    243 
    244     GraphDef result_graph_def;
    245     FilterGraphDef(
    246         graph_def,
    247         [](const NodeDef& node) { return (node.name() != "remove_me"); },
    248         &result_graph_def);
    249 
    250     std::map<string, const NodeDef*> node_map;
    251     MapNamesToNodes(result_graph_def, &node_map);
    252     EXPECT_EQ(1, node_map.count("a"));
    253     EXPECT_EQ(1, node_map.count("b"));
    254     EXPECT_EQ(1, node_map.count("add"));
    255     EXPECT_EQ(1, node_map.count("placeholder"));
    256     EXPECT_EQ(1, node_map.count("output"));
    257     EXPECT_EQ(0, node_map.count("remove_me"));
    258   }
    259 
    260   void TestRemoveAttributes() {
    261     auto root = tensorflow::Scope::NewRootScope();
    262     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    263 
    264     Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
    265 
    266     GraphDef graph_def;
    267     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    268 
    269     GraphDef result_graph_def;
    270     RemoveAttributes(graph_def, {"dtype"}, &result_graph_def);
    271 
    272     std::map<string, const NodeDef*> node_map;
    273     MapNamesToNodes(result_graph_def, &node_map);
    274     const NodeDef* removed_placeholder = node_map["placeholder"];
    275     EXPECT_EQ(nullptr,
    276               tensorflow::AttrSlice(*removed_placeholder).Find("dtype"));
    277   }
    278 
    279   void TestGetOpTypeMatches() {
    280     auto root = tensorflow::Scope::NewRootScope();
    281     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    282 
    283     const int width = 100;
    284 
    285     Tensor a_data(DT_FLOAT, TensorShape({width}));
    286     test::FillIota<float>(&a_data, 1.0f);
    287     Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
    288 
    289     Tensor b_data(DT_FLOAT, TensorShape({width}));
    290     test::FillIota<float>(&b_data, 1.0f);
    291     Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
    292 
    293     Output add = Add(root.WithOpName("add"), a_const, b_const);
    294 
    295     Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
    296 
    297     Output mul = Mul(root.WithOpName("output"), add, placeholder);
    298 
    299     GraphDef graph_def;
    300     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    301 
    302     GraphMatcher matcher(graph_def);
    303 
    304     std::vector<NodeMatch> const_matches;
    305     TF_ASSERT_OK(matcher.GetOpTypeMatches({"Const"}, &const_matches));
    306     EXPECT_EQ(2, const_matches.size());
    307     for (const NodeMatch& match : const_matches) {
    308       EXPECT_EQ("Const", match.node.op());
    309       EXPECT_TRUE(("a" == match.node.name()) || ("b" == match.node.name()))
    310           << "match.node.name()=" << match.node.name();
    311     }
    312 
    313     std::vector<NodeMatch> add_matches;
    314     TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add"}, &add_matches));
    315     EXPECT_EQ(1, add_matches.size());
    316     EXPECT_EQ("Add", add_matches[0].node.op());
    317     EXPECT_EQ("add", add_matches[0].node.name());
    318 
    319     std::vector<NodeMatch> add_child_matches;
    320     TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add", {{"Const"}, {"Const"}}},
    321                                           &add_child_matches));
    322     EXPECT_EQ(1, add_child_matches.size());
    323     EXPECT_EQ("Add", add_child_matches[0].node.op());
    324     EXPECT_EQ("add", add_child_matches[0].node.name());
    325     EXPECT_EQ(2, add_child_matches[0].inputs.size());
    326     for (const NodeMatch& match : add_child_matches[0].inputs) {
    327       EXPECT_EQ("Const", match.node.op());
    328       EXPECT_TRUE(("a" == match.node.name()) || ("b" == match.node.name()))
    329           << "match.node.name()=" << match.node.name();
    330     }
    331 
    332     std::vector<NodeMatch> no_such_matches;
    333     TF_ASSERT_OK(matcher.GetOpTypeMatches({"NoSuch"}, &no_such_matches));
    334     EXPECT_EQ(0, no_such_matches.size());
    335 
    336     std::vector<NodeMatch> all_matches;
    337     TF_ASSERT_OK(matcher.GetOpTypeMatches(
    338         {"Mul", {{"Add", {{"Const"}, {"Const"}}}, {"Placeholder"}}},
    339         &all_matches));
    340     EXPECT_EQ(1, all_matches.size());
    341     EXPECT_EQ("Mul", all_matches[0].node.op());
    342     EXPECT_EQ("output", all_matches[0].node.name());
    343     EXPECT_EQ(2, all_matches[0].inputs.size());
    344     EXPECT_EQ("Add", all_matches[0].inputs[0].node.op());
    345     EXPECT_EQ("add", all_matches[0].inputs[0].node.name());
    346     EXPECT_EQ(2, all_matches[0].inputs[0].inputs.size());
    347     EXPECT_EQ("Const", all_matches[0].inputs[0].inputs[0].node.op());
    348     EXPECT_EQ("a", all_matches[0].inputs[0].inputs[0].node.name());
    349     EXPECT_EQ(0, all_matches[0].inputs[0].inputs[0].inputs.size());
    350     EXPECT_EQ("Const", all_matches[0].inputs[0].inputs[1].node.op());
    351     EXPECT_EQ("b", all_matches[0].inputs[0].inputs[1].node.name());
    352     EXPECT_EQ(0, all_matches[0].inputs[0].inputs[1].inputs.size());
    353     EXPECT_EQ("Placeholder", all_matches[0].inputs[1].node.op());
    354     EXPECT_EQ("placeholder", all_matches[0].inputs[1].node.name());
    355     EXPECT_EQ(0, all_matches[0].inputs[1].inputs.size());
    356 
    357     std::vector<NodeMatch> wildcard_matches;
    358     TF_ASSERT_OK(
    359         matcher.GetOpTypeMatches({"*", {{"*"}, {"*"}}}, &wildcard_matches));
    360     EXPECT_EQ(1, wildcard_matches.size());
    361     EXPECT_EQ("Add", wildcard_matches[0].node.op());
    362     EXPECT_EQ("Const", wildcard_matches[0].inputs[0].node.op());
    363     EXPECT_EQ("a", wildcard_matches[0].inputs[0].node.name());
    364     EXPECT_EQ("Const", wildcard_matches[0].inputs[1].node.op());
    365     EXPECT_EQ("b", wildcard_matches[0].inputs[1].node.name());
    366 
    367     std::vector<NodeMatch> or_matches;
    368     TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add|Mul"}, &or_matches));
    369     EXPECT_EQ(2, or_matches.size());
    370     EXPECT_EQ("Add", or_matches[0].node.op());
    371     EXPECT_EQ("add", or_matches[0].node.name());
    372     EXPECT_EQ("Mul", or_matches[1].node.op());
    373     EXPECT_EQ("output", or_matches[1].node.name());
    374   }
    375 
    376   void TestGetOpTypeMatchesDAG() {
    377     auto root = tensorflow::Scope::NewRootScope();
    378     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    379 
    380     const int width = 100;
    381 
    382     Tensor a_data(DT_FLOAT, TensorShape({width}));
    383     test::FillIota<float>(&a_data, 1.0f);
    384     Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
    385 
    386     Output add = Add(root.WithOpName("add"), a_const, a_const);
    387 
    388     Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
    389 
    390     Output mul = Mul(root.WithOpName("output"), add, placeholder);
    391 
    392     GraphDef graph_def;
    393     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    394 
    395     GraphMatcher matcher(graph_def);
    396 
    397     std::vector<NodeMatch> add_matches;
    398     TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add", {{"Const"}, {"Const"}}},
    399                                           &add_matches));
    400     EXPECT_EQ(1, add_matches.size());
    401     EXPECT_EQ("Add", add_matches[0].node.op());
    402     EXPECT_EQ("add", add_matches[0].node.name());
    403     EXPECT_EQ("Const", add_matches[0].inputs[0].node.op());
    404     EXPECT_EQ("a", add_matches[0].inputs[0].node.name());
    405     EXPECT_EQ("Const", add_matches[0].inputs[1].node.op());
    406     EXPECT_EQ("a", add_matches[0].inputs[1].node.name());
    407   }
    408 
    409   void TestReplaceMatchingOpTypes() {
    410     auto root = tensorflow::Scope::NewRootScope();
    411     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    412 
    413     const int width = 10;
    414 
    415     Tensor a_data(DT_FLOAT, TensorShape({width}));
    416     test::FillIota<float>(&a_data, 1.0f);
    417     Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
    418 
    419     Tensor b_data(DT_FLOAT, TensorShape({width}));
    420     test::FillIota<float>(&b_data, 1.0f);
    421     Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
    422 
    423     Output add = Add(root.WithOpName("add"), a_const, b_const);
    424 
    425     Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
    426 
    427     Output mul = Mul(root.WithOpName("output"), add, placeholder);
    428 
    429     GraphDef graph_def;
    430     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    431 
    432     GraphDef replaced_graph_def;
    433     TF_ASSERT_OK(ReplaceMatchingOpTypes(
    434         graph_def, {"*"},
    435         [](const NodeMatch& match, const std::set<string>& input_nodes,
    436            const std::set<string>& output_nodes,
    437            std::vector<NodeDef>* new_nodes) {
    438           NodeDef original_copy;
    439           original_copy = match.node;
    440           const string original_name = match.node.name();
    441           original_copy.set_name(original_name + "_before_identity");
    442           new_nodes->push_back(original_copy);
    443 
    444           NodeDef identity_node;
    445           identity_node.set_op("Identity");
    446           identity_node.set_name(original_name);
    447           *(identity_node.mutable_input()->Add()) = original_copy.name();
    448           new_nodes->push_back(identity_node);
    449 
    450           return Status::OK();
    451         },
    452         {}, &replaced_graph_def));
    453 
    454     EXPECT_EQ(10, replaced_graph_def.node_size());
    455     for (const NodeDef& node : replaced_graph_def.node()) {
    456       if (node.name() == "output") {
    457         EXPECT_EQ("Identity", node.op());
    458         EXPECT_EQ("output_before_identity", node.input(0));
    459       } else if (node.name() == "output_before_identity") {
    460         EXPECT_EQ("Mul", node.op());
    461         EXPECT_EQ("add", node.input(0));
    462         EXPECT_EQ("placeholder", node.input(1));
    463       } else if (node.name() == "placeholder") {
    464         EXPECT_EQ("Identity", node.op());
    465         EXPECT_EQ("placeholder_before_identity", node.input(0));
    466       } else if (node.name() == "placeholder_before_identity") {
    467         EXPECT_EQ("Placeholder", node.op());
    468       } else if (node.name() == "add") {
    469         EXPECT_EQ("Identity", node.op());
    470         EXPECT_EQ("add_before_identity", node.input(0));
    471       } else if (node.name() == "add_before_identity") {
    472         EXPECT_EQ("Add", node.op());
    473         EXPECT_EQ("a", node.input(0));
    474         EXPECT_EQ("b", node.input(1));
    475       } else if (node.name() == "a") {
    476         EXPECT_EQ("Identity", node.op());
    477         EXPECT_EQ("a_before_identity", node.input(0));
    478       } else if (node.name() == "a_before_identity") {
    479         EXPECT_EQ("Const", node.op());
    480       } else if (node.name() == "b") {
    481         EXPECT_EQ("Identity", node.op());
    482         EXPECT_EQ("b_before_identity", node.input(0));
    483       } else if (node.name() == "b_before_identity") {
    484         EXPECT_EQ("Const", node.op());
    485       } else {
    486         EXPECT_EQ(true, false) << "Unexpected node name found: " << node.name();
    487       }
    488     }
    489   }
    490 
    491   void TestMatchedNodesAsArray() {
    492     NodeMatch fourth;
    493     fourth.node.set_name("fourth");
    494 
    495     NodeMatch second;
    496     second.node.set_name("second");
    497     second.inputs.push_back(fourth);
    498 
    499     NodeMatch third;
    500     third.node.set_name("third");
    501     third.inputs.push_back(fourth);
    502 
    503     NodeMatch first;
    504     first.node.set_name("first");
    505     first.inputs.push_back(second);
    506     first.inputs.push_back(third);
    507 
    508     std::vector<NodeDef> result;
    509     MatchedNodesAsArray(first, &result);
    510 
    511     EXPECT_EQ(4, result.size());
    512     EXPECT_EQ("first", result[0].name());
    513     EXPECT_EQ("second", result[1].name());
    514     EXPECT_EQ("third", result[2].name());
    515     EXPECT_EQ("fourth", result[3].name());
    516   }
    517 
    518   void TestRenameNodeInputs() {
    519     auto root = tensorflow::Scope::NewRootScope();
    520     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    521 
    522     const int width = 10;
    523 
    524     Tensor a_data(DT_FLOAT, TensorShape({width}));
    525     test::FillIota<float>(&a_data, 1.0f);
    526     Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
    527 
    528     Tensor b_data(DT_FLOAT, TensorShape({width}));
    529     test::FillIota<float>(&b_data, 1.0f);
    530     Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
    531 
    532     Output add = Add(root.WithOpName("add"), a_const, a_const);
    533 
    534     Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
    535 
    536     Output mul = Mul(root.WithOpName("output"), add, placeholder);
    537 
    538     GraphDef graph_def;
    539     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    540 
    541     GraphDef renamed_graph_def;
    542     TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}},
    543                                   std::unordered_set<string>(),
    544                                   &renamed_graph_def));
    545 
    546     std::map<string, const NodeDef*> node_map;
    547     MapNamesToNodes(renamed_graph_def, &node_map);
    548     EXPECT_EQ("b", node_map.at("add")->input(0));
    549     EXPECT_EQ("b", node_map.at("add")->input(1));
    550   }
    551 
    552   void TestRenameNodeInputsWithRedirects() {
    553     auto root = tensorflow::Scope::NewRootScope();
    554     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    555 
    556     const int width = 10;
    557 
    558     Tensor a_data(DT_FLOAT, TensorShape({width}));
    559     test::FillIota<float>(&a_data, 1.0f);
    560     Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
    561 
    562     Tensor b_data(DT_FLOAT, TensorShape({width}));
    563     test::FillIota<float>(&b_data, 1.0f);
    564     Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
    565 
    566     Tensor c_data(DT_FLOAT, TensorShape({width}));
    567     test::FillIota<float>(&c_data, 1.0f);
    568     Output c_const = Const(root.WithOpName("c"), Input::Initializer(c_data));
    569 
    570     Output add = Add(root.WithOpName("add"), a_const, b_const);
    571 
    572     Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
    573 
    574     Output mul = Mul(root.WithOpName("output"), add, placeholder);
    575 
    576     GraphDef graph_def;
    577     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    578 
    579     GraphDef renamed_graph_def;
    580     TF_ASSERT_OK(RenameNodeInputs(
    581         graph_def, {{"a", "f"}, {"f", "e"}, {"e", "d"}, {"d", "c"}},
    582         std::unordered_set<string>(), &renamed_graph_def));
    583 
    584     std::map<string, const NodeDef*> node_map;
    585     MapNamesToNodes(renamed_graph_def, &node_map);
    586     EXPECT_EQ("c", node_map.at("add")->input(0));
    587     EXPECT_EQ("b", node_map.at("add")->input(1));
    588   }
    589 
    590   void TestRenameNodeInputsWithCycle() {
    591     auto root = tensorflow::Scope::NewRootScope();
    592     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    593 
    594     const int width = 10;
    595 
    596     Tensor a_data(DT_FLOAT, TensorShape({width}));
    597     test::FillIota<float>(&a_data, 1.0f);
    598     Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
    599 
    600     Tensor b_data(DT_FLOAT, TensorShape({width}));
    601     test::FillIota<float>(&b_data, 1.0f);
    602     Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
    603 
    604     Tensor c_data(DT_FLOAT, TensorShape({width}));
    605     test::FillIota<float>(&c_data, 1.0f);
    606     Output c_const = Const(root.WithOpName("c"), Input::Initializer(c_data));
    607 
    608     Output add = Add(root.WithOpName("add"), a_const, b_const);
    609 
    610     Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
    611 
    612     Output mul = Mul(root.WithOpName("output"), add, placeholder);
    613 
    614     GraphDef graph_def;
    615     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    616 
    617     GraphDef renamed_graph_def;
    618     Status rename_status =
    619         RenameNodeInputs(graph_def, {{"a", "d"}, {"d", "a"}},
    620                          std::unordered_set<string>(), &renamed_graph_def);
    621     EXPECT_FALSE(rename_status.ok());
    622   }
    623 
    624   void TestRenameNodeInputsWithWildcard() {
    625     auto root = tensorflow::Scope::DisabledShapeInferenceScope();
    626     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    627 
    628     const int width = 10;
    629 
    630     Tensor a_data(DT_FLOAT, TensorShape({width}));
    631     test::FillIota<float>(&a_data, 1.0f);
    632     Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
    633 
    634     QuantizeV2 quantize_a(root.WithOpName("quantize_a"), a_const, a_const,
    635                           a_const, DT_QUINT8,
    636                           QuantizeV2::Attrs().Mode("MIN_FIRST"));
    637 
    638     Tensor b_data(DT_FLOAT, TensorShape({width}));
    639     test::FillIota<float>(&b_data, 1.0f);
    640     Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
    641 
    642     QuantizeV2 quantize_b(root.WithOpName("quantize_b"), b_const, b_const,
    643                           b_const, DT_QUINT8,
    644                           QuantizeV2::Attrs().Mode("MIN_FIRST"));
    645 
    646     Output add = Add(root.WithOpName("add"), quantize_a.output_min,
    647                      quantize_a.output_max);
    648 
    649     GraphDef graph_def;
    650     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    651 
    652     GraphDef renamed_graph_def;
    653     TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"quantize_a:*", "quantize_b"}},
    654                                   std::unordered_set<string>(),
    655                                   &renamed_graph_def));
    656 
    657     std::map<string, const NodeDef*> node_map;
    658     MapNamesToNodes(renamed_graph_def, &node_map);
    659     EXPECT_EQ("quantize_b:1", node_map.at("add")->input(0));
    660     EXPECT_EQ("quantize_b:2", node_map.at("add")->input(1));
    661   }
    662 
    663   void TestRenameNodeInputsWithIgnores() {
    664     auto root = tensorflow::Scope::NewRootScope();
    665     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    666 
    667     const int width = 10;
    668 
    669     Tensor a_data(DT_FLOAT, TensorShape({width}));
    670     test::FillIota<float>(&a_data, 1.0f);
    671     Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
    672 
    673     Tensor b_data(DT_FLOAT, TensorShape({width}));
    674     test::FillIota<float>(&b_data, 1.0f);
    675     Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
    676 
    677     Output add = Add(root.WithOpName("add"), a_const, a_const);
    678 
    679     Output add2 = Add(root.WithOpName("add2"), a_const, a_const);
    680 
    681     Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
    682 
    683     Output mul = Mul(root.WithOpName("mul"), add, placeholder);
    684 
    685     Output mul2 = Mul(root.WithOpName("output"), mul, add2);
    686 
    687     GraphDef graph_def;
    688     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    689 
    690     GraphDef renamed_graph_def;
    691     TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, {"add2"},
    692                                   &renamed_graph_def));
    693 
    694     std::map<string, const NodeDef*> node_map;
    695     MapNamesToNodes(renamed_graph_def, &node_map);
    696     EXPECT_EQ("b", node_map.at("add")->input(0));
    697     EXPECT_EQ("b", node_map.at("add")->input(1));
    698     EXPECT_EQ("a", node_map.at("add2")->input(0));
    699     EXPECT_EQ("a", node_map.at("add2")->input(1));
    700   }
    701 
    702   void TestFindInvalidInputs() {
    703     GraphDef graph_def;
    704 
    705     NodeDef* mul_node = graph_def.mutable_node()->Add();
    706     mul_node->set_op("Mul");
    707     mul_node->set_name("mul_node");
    708     *(mul_node->mutable_input()->Add()) = "add_node1";
    709     *(mul_node->mutable_input()->Add()) = "add_node2:0";
    710     *(mul_node->mutable_input()->Add()) = "^const_node1:0";
    711 
    712     NodeDef* add_node1 = graph_def.mutable_node()->Add();
    713     add_node1->set_op("Add");
    714     add_node1->set_name("add_node1");
    715     *(add_node1->mutable_input()->Add()) = "missing_input1";
    716     *(add_node1->mutable_input()->Add()) = "const_node1:0";
    717     *(add_node1->mutable_input()->Add()) = "missing_input2";
    718 
    719     NodeDef* add_node2 = graph_def.mutable_node()->Add();
    720     add_node2->set_op("Add");
    721     add_node2->set_name("add_node2");
    722     *(add_node2->mutable_input()->Add()) = "missing_input3";
    723     *(add_node2->mutable_input()->Add()) = "const_node1:0";
    724     *(add_node2->mutable_input()->Add()) = "^const_node2";
    725 
    726     NodeDef* const_node1 = graph_def.mutable_node()->Add();
    727     const_node1->set_op("Const");
    728     const_node1->set_name("const_node1");
    729 
    730     NodeDef* const_node2 = graph_def.mutable_node()->Add();
    731     const_node2->set_op("Const");
    732     const_node2->set_name("const_node2");
    733 
    734     std::vector<std::pair<string, string>> invalid_inputs;
    735     FindInvalidInputs(graph_def, &invalid_inputs);
    736     EXPECT_EQ(3, invalid_inputs.size());
    737     for (const std::pair<string, string>& invalid_input : invalid_inputs) {
    738       EXPECT_TRUE((invalid_input.first == "add_node1") ||
    739                   (invalid_input.first == "add_node2"));
    740       if (invalid_input.first == "add_node1") {
    741         EXPECT_TRUE((invalid_input.second == "missing_input1") ||
    742                     (invalid_input.second == "missing_input2"))
    743             << invalid_input.second;
    744       } else if (invalid_input.first == "add_node2") {
    745         EXPECT_EQ("missing_input3", invalid_input.second);
    746       }
    747     }
    748   }
    749 
    750   void TestIsGraphValid() {
    751     GraphDef invalid_graph_def;
    752 
    753     NodeDef* mul_node = invalid_graph_def.mutable_node()->Add();
    754     mul_node->set_op("Mul");
    755     mul_node->set_name("mul_node");
    756     *(mul_node->mutable_input()->Add()) = "add_node1";
    757     *(mul_node->mutable_input()->Add()) = "add_node2:0";
    758     *(mul_node->mutable_input()->Add()) = "^const_node1:0";
    759 
    760     NodeDef* add_node1 = invalid_graph_def.mutable_node()->Add();
    761     add_node1->set_op("Add");
    762     add_node1->set_name("add_node1");
    763     *(add_node1->mutable_input()->Add()) = "missing_input1";
    764     *(add_node1->mutable_input()->Add()) = "const_node1:0";
    765     *(add_node1->mutable_input()->Add()) = "missing_input2";
    766 
    767     NodeDef* add_node2 = invalid_graph_def.mutable_node()->Add();
    768     add_node2->set_op("Add");
    769     add_node2->set_name("add_node2");
    770     *(add_node2->mutable_input()->Add()) = "missing_input3";
    771     *(add_node2->mutable_input()->Add()) = "const_node1:0";
    772     *(add_node2->mutable_input()->Add()) = "^const_node2";
    773 
    774     NodeDef* const_node1 = invalid_graph_def.mutable_node()->Add();
    775     const_node1->set_op("Const");
    776     const_node1->set_name("const_node1");
    777 
    778     NodeDef* const_node2 = invalid_graph_def.mutable_node()->Add();
    779     const_node2->set_op("Const");
    780     const_node2->set_name("const_node2");
    781 
    782     EXPECT_FALSE(IsGraphValid(invalid_graph_def).ok());
    783 
    784     GraphDef valid_graph_def;
    785 
    786     NodeDef* const_node3 = valid_graph_def.mutable_node()->Add();
    787     const_node3->set_op("Const");
    788     const_node3->set_name("const_node2");
    789 
    790     EXPECT_TRUE(IsGraphValid(valid_graph_def).ok());
    791   }
    792 
    793   void TestGetInOutTypes() {
    794     auto root = tensorflow::Scope::NewRootScope();
    795     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    796 
    797     const int width = 20;
    798 
    799     Tensor float_data(DT_FLOAT, TensorShape({width}));
    800     test::FillIota<float>(&float_data, 1.0f);
    801     Output float_const =
    802         Const(root.WithOpName("float_const"), Input::Initializer(float_data));
    803 
    804     Tensor int_data(DT_INT32, TensorShape({width}));
    805     test::FillIota<int32>(&int_data, 1);
    806     Output int_const =
    807         Const(root.WithOpName("int_const"), Input::Initializer(int_data));
    808 
    809     Output float_relu = Relu(root.WithOpName("float_relu"), float_const);
    810 
    811     Output int_relu = Relu(root.WithOpName("int_relu"), int_const);
    812 
    813     GraphDef graph_def;
    814     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    815 
    816     std::map<string, const NodeDef*> node_map;
    817     MapNamesToNodes(graph_def, &node_map);
    818 
    819     const NodeDef* float_const_def = node_map.at("float_const");
    820     DataTypeVector float_const_inputs;
    821     DataTypeVector float_const_outputs;
    822     TF_EXPECT_OK(GetInOutTypes(*float_const_def, &float_const_inputs,
    823                                &float_const_outputs));
    824     ASSERT_EQ(0, float_const_inputs.size());
    825     ASSERT_EQ(1, float_const_outputs.size());
    826     EXPECT_EQ(DT_FLOAT, float_const_outputs[0]);
    827 
    828     const NodeDef* int_const_def = node_map.at("int_const");
    829     DataTypeVector int_const_inputs;
    830     DataTypeVector int_const_outputs;
    831     TF_EXPECT_OK(
    832         GetInOutTypes(*int_const_def, &int_const_inputs, &int_const_outputs));
    833     ASSERT_EQ(0, int_const_inputs.size());
    834     ASSERT_EQ(1, int_const_outputs.size());
    835     EXPECT_EQ(DT_INT32, int_const_outputs[0]);
    836 
    837     const NodeDef* float_relu_def = node_map.at("float_relu");
    838     DataTypeVector float_relu_inputs;
    839     DataTypeVector float_relu_outputs;
    840     TF_EXPECT_OK(GetInOutTypes(*float_relu_def, &float_relu_inputs,
    841                                &float_relu_outputs));
    842     ASSERT_EQ(1, float_relu_inputs.size());
    843     EXPECT_EQ(DT_FLOAT, float_relu_inputs[0]);
    844     ASSERT_EQ(1, float_relu_outputs.size());
    845     EXPECT_EQ(DT_FLOAT, float_relu_outputs[0]);
    846 
    847     const NodeDef* int_relu_def = node_map.at("int_relu");
    848     DataTypeVector int_relu_inputs;
    849     DataTypeVector int_relu_outputs;
    850     TF_EXPECT_OK(
    851         GetInOutTypes(*int_relu_def, &int_relu_inputs, &int_relu_outputs));
    852     ASSERT_EQ(1, int_relu_inputs.size());
    853     EXPECT_EQ(DT_INT32, int_relu_inputs[0]);
    854     ASSERT_EQ(1, int_relu_outputs.size());
    855     EXPECT_EQ(DT_INT32, int_relu_outputs[0]);
    856   }
    857 
    858   void TestCopyOriginalMatch() {
    859     NodeDef a;
    860     a.set_op("Relu");
    861     a.set_name("a");
    862     AddNodeInput("b", &a);
    863 
    864     NodeDef b;
    865     b.set_op("Const");
    866     b.set_name("b");
    867 
    868     NodeMatch b_match;
    869     b_match.node = b;
    870 
    871     NodeMatch a_match;
    872     a_match.node = a;
    873     a_match.inputs.push_back(b_match);
    874 
    875     std::vector<NodeDef> new_nodes;
    876     CopyOriginalMatch(a_match, &new_nodes);
    877     EXPECT_EQ(2, new_nodes.size());
    878     EXPECT_EQ("a", new_nodes[0].name());
    879     EXPECT_EQ("Relu", new_nodes[0].op());
    880     EXPECT_EQ("b", new_nodes[1].name());
    881     EXPECT_EQ("Const", new_nodes[1].op());
    882   }
    883 
    884   void TestHashNodeDef() {
    885     using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    886 
    887     const int width = 10;
    888 
    889     auto a_root = tensorflow::Scope::NewRootScope();
    890     Tensor a_data(DT_FLOAT, TensorShape({width}));
    891     test::FillIota<float>(&a_data, 1.0f);
    892     Output a_const = Const(a_root.WithOpName("a"), Input::Initializer(a_data));
    893     GraphDef a_graph_def;
    894     TF_ASSERT_OK(a_root.ToGraphDef(&a_graph_def));
    895     const NodeDef& a_node_def = a_graph_def.node(0);
    896 
    897     auto b_root = tensorflow::Scope::NewRootScope();
    898     Tensor b_data(DT_FLOAT, TensorShape({width}));
    899     test::FillIota<float>(&b_data, 1.0f);
    900     Output b_const = Const(b_root.WithOpName("a"), Input::Initializer(b_data));
    901     GraphDef b_graph_def;
    902     TF_ASSERT_OK(b_root.ToGraphDef(&b_graph_def));
    903     const NodeDef& b_node_def = b_graph_def.node(0);
    904 
    905     auto c_root = tensorflow::Scope::NewRootScope();
    906     Tensor c_data(DT_FLOAT, TensorShape({width}));
    907     test::FillIota<float>(&c_data, 2.0f);
    908     Output c_const = Const(c_root.WithOpName("a"), Input::Initializer(c_data));
    909     GraphDef c_graph_def;
    910     TF_ASSERT_OK(c_root.ToGraphDef(&c_graph_def));
    911     const NodeDef& c_node_def = c_graph_def.node(0);
    912 
    913     auto d_root = tensorflow::Scope::NewRootScope();
    914     Tensor d_data(DT_FLOAT, TensorShape({width}));
    915     test::FillIota<float>(&d_data, 1.0f);
    916     Output d_const = Const(d_root.WithOpName("d"), Input::Initializer(d_data));
    917     GraphDef d_graph_def;
    918     TF_ASSERT_OK(d_root.ToGraphDef(&d_graph_def));
    919     const NodeDef& d_node_def = d_graph_def.node(0);
    920 
    921     auto e_root = tensorflow::Scope::NewRootScope();
    922     Tensor e_data(DT_INT32, TensorShape({width}));
    923     test::FillIota<int32>(&e_data, 1);
    924     Output e_const = Const(e_root.WithOpName("a"), Input::Initializer(e_data));
    925     GraphDef e_graph_def;
    926     TF_ASSERT_OK(e_root.ToGraphDef(&e_graph_def));
    927     const NodeDef& e_node_def = e_graph_def.node(0);
    928 
    929     auto f_root = tensorflow::Scope::NewRootScope();
    930     Tensor f_data(DT_FLOAT, TensorShape({width - 1}));
    931     test::FillIota<float>(&f_data, 1.0f);
    932     Output f_const = Const(f_root.WithOpName("a"), Input::Initializer(f_data));
    933     GraphDef f_graph_def;
    934     TF_ASSERT_OK(f_root.ToGraphDef(&f_graph_def));
    935     const NodeDef& f_node_def = f_graph_def.node(0);
    936 
    937     auto g_root = tensorflow::Scope::NewRootScope();
    938     Tensor g_data(DT_FLOAT, TensorShape({width}));
    939     test::FillIota<float>(&g_data, 1);
    940     Output g_const = Const(g_root.WithOpName("a").WithDevice("some_device"),
    941                            Input::Initializer(g_data));
    942     GraphDef g_graph_def;
    943     TF_ASSERT_OK(g_root.ToGraphDef(&g_graph_def));
    944     const NodeDef& g_node_def = g_graph_def.node(0);
    945 
    946     NodeDef relu1_node_def;
    947     relu1_node_def.set_op("Relu");
    948     relu1_node_def.set_name("a");
    949     relu1_node_def.add_input("foo");
    950 
    951     NodeDef relu2_node_def;
    952     relu2_node_def.set_op("Relu");
    953     relu2_node_def.set_name("a");
    954     relu2_node_def.add_input("bar");
    955 
    956     EXPECT_EQ(HashNodeDef(a_node_def), HashNodeDef(b_node_def));
    957     EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(c_node_def));
    958     EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(d_node_def));
    959     EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(e_node_def));
    960     EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(f_node_def));
    961     EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(g_node_def));
    962     EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(relu1_node_def));
    963     EXPECT_NE(HashNodeDef(relu1_node_def), HashNodeDef(relu2_node_def));
    964   }
    965 
    966   void TestCountParameters() {
    967     TransformFuncContext context;
    968     context.params.insert({"foo", {"a", "b"}});
    969     context.params.insert({"bar", {"c"}});
    970     EXPECT_EQ(2, context.CountParameters("foo"));
    971     EXPECT_EQ(1, context.CountParameters("bar"));
    972     EXPECT_EQ(0, context.CountParameters("not_present"));
    973   }
    974 
    975   void TestGetOneStringParameter() {
    976     TransformFuncContext context;
    977     context.params.insert({"foo", {"a", "b"}});
    978     context.params.insert({"bar", {"c"}});
    979     string value;
    980     TF_EXPECT_OK(context.GetOneStringParameter("bar", "d", &value));
    981     EXPECT_EQ("c", value);
    982     EXPECT_FALSE(context.GetOneStringParameter("foo", "d", &value).ok());
    983     TF_EXPECT_OK(context.GetOneStringParameter("not_present", "d", &value));
    984     EXPECT_EQ("d", value);
    985   }
    986 
    987   void TestGetOneInt32Parameter() {
    988     TransformFuncContext context;
    989     context.params.insert({"foo", {"10", "20"}});
    990     context.params.insert({"bar", {"-23"}});
    991     context.params.insert({"not_a_number", {"not_numerical"}});
    992     context.params.insert({"float", {"-23.232323"}});
    993     int32 value;
    994     TF_EXPECT_OK(context.GetOneInt32Parameter("bar", 0, &value));
    995     EXPECT_EQ(-23, value);
    996     EXPECT_FALSE(context.GetOneInt32Parameter("foo", 0, &value).ok());
    997     TF_EXPECT_OK(context.GetOneInt32Parameter("not_present", 10, &value));
    998     EXPECT_EQ(10, value);
    999     EXPECT_FALSE(context.GetOneInt32Parameter("not_a_number", 0, &value).ok());
   1000     EXPECT_FALSE(context.GetOneInt32Parameter("float", 0, &value).ok());
   1001   }
   1002 
   1003   void TestGetOneInt64Parameter() {
   1004     TransformFuncContext context;
   1005     context.params.insert({"foo", {"10", "20"}});
   1006     context.params.insert({"bar", {"-23"}});
   1007     context.params.insert({"not_a_number", {"not_numerical"}});
   1008     context.params.insert({"float", {"-23.232323"}});
   1009     int64 value;
   1010     TF_EXPECT_OK(context.GetOneInt64Parameter("bar", 0, &value));
   1011     EXPECT_EQ(-23, value);
   1012     EXPECT_FALSE(context.GetOneInt64Parameter("foo", 0, &value).ok());
   1013     TF_EXPECT_OK(context.GetOneInt64Parameter("not_present", 10, &value));
   1014     EXPECT_EQ(10, value);
   1015     EXPECT_FALSE(context.GetOneInt64Parameter("not_a_number", 0, &value).ok());
   1016     EXPECT_FALSE(context.GetOneInt64Parameter("float", 0, &value).ok());
   1017   }
   1018 
   1019   void TestGetOneFloatParameter() {
   1020     TransformFuncContext context;
   1021     context.params.insert({"foo", {"10.0", "20.0"}});
   1022     context.params.insert({"bar", {"-23.2323"}});
   1023     context.params.insert({"not_a_number", {"not_numerical"}});
   1024     float value;
   1025     TF_EXPECT_OK(context.GetOneFloatParameter("bar", 0, &value));
   1026     EXPECT_NEAR(-23.2323f, value, 1e-5f);
   1027     EXPECT_FALSE(context.GetOneFloatParameter("foo", 0, &value).ok());
   1028     TF_EXPECT_OK(context.GetOneFloatParameter("not_present", 10.5f, &value));
   1029     EXPECT_NEAR(10.5f, value, 1e-5f);
   1030     EXPECT_FALSE(context.GetOneFloatParameter("not_a_number", 0, &value).ok());
   1031   }
   1032 
   1033   void TestGetOneBoolParameter() {
   1034     TransformFuncContext context;
   1035     context.params.insert({"foo", {"true", "false"}});
   1036     context.params.insert({"true", {"true"}});
   1037     context.params.insert({"false", {"false"}});
   1038     context.params.insert({"one", {"1"}});
   1039     context.params.insert({"zero", {"0"}});
   1040     context.params.insert({"not_a_bool", {"not_boolean"}});
   1041 
   1042     bool value;
   1043     EXPECT_FALSE(context.GetOneBoolParameter("foo", 0, &value).ok());
   1044 
   1045     value = false;
   1046     TF_EXPECT_OK(context.GetOneBoolParameter("true", false, &value));
   1047     EXPECT_TRUE(value);
   1048 
   1049     value = true;
   1050     TF_EXPECT_OK(context.GetOneBoolParameter("false", true, &value));
   1051     EXPECT_FALSE(value);
   1052 
   1053     value = false;
   1054     TF_EXPECT_OK(context.GetOneBoolParameter("one", false, &value));
   1055     EXPECT_TRUE(value);
   1056 
   1057     value = true;
   1058     TF_EXPECT_OK(context.GetOneBoolParameter("zero", true, &value));
   1059     EXPECT_FALSE(value);
   1060 
   1061     EXPECT_FALSE(context.GetOneBoolParameter("not_a_bool", false, &value).ok());
   1062 
   1063     value = false;
   1064     TF_EXPECT_OK(context.GetOneBoolParameter("not_present", true, &value));
   1065     EXPECT_TRUE(value);
   1066   }
   1067 };
   1068 
   1069 TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); }
   1070 
   1071 TEST_F(TransformUtilsTest, TestMapNodesToOutputs) { TestMapNodesToOutputs(); }
   1072 
   1073 TEST_F(TransformUtilsTest, TestNodeNamePartsFromInput) {
   1074   TestNodeNamePartsFromInput();
   1075 }
   1076 
   1077 TEST_F(TransformUtilsTest, TestCanonicalInputName) { TestCanonicalInputName(); }
   1078 
   1079 TEST_F(TransformUtilsTest, TestAddNodeInput) { TestAddNodeInput(); }
   1080 
   1081 TEST_F(TransformUtilsTest, TestCopyNodeAttr) { TestCopyNodeAttr(); }
   1082 
   1083 TEST_F(TransformUtilsTest, TestSetNodeAttr) { TestSetNodeAttr(); }
   1084 
   1085 TEST_F(TransformUtilsTest, TestSetNodeTensorAttr) { TestSetNodeTensorAttr(); }
   1086 
   1087 TEST_F(TransformUtilsTest, TestSetNodeTensorAttrWithTensor) {
   1088   TestSetNodeTensorAttrWithTensor();
   1089 }
   1090 
   1091 TEST_F(TransformUtilsTest, TestGetNodeTensorAttr) { TestGetNodeTensorAttr(); }
   1092 
   1093 TEST_F(TransformUtilsTest, TestNodeNameFromInput) { TestNodeNameFromInput(); }
   1094 
   1095 TEST_F(TransformUtilsTest, TestFilterGraphDef) { TestFilterGraphDef(); }
   1096 
   1097 TEST_F(TransformUtilsTest, TestRemoveAttributes) { TestRemoveAttributes(); }
   1098 
   1099 TEST_F(TransformUtilsTest, TestGetOpTypeMatches) { TestGetOpTypeMatches(); }
   1100 
   1101 TEST_F(TransformUtilsTest, TestGetOpTypeMatchesDAG) {
   1102   TestGetOpTypeMatchesDAG();
   1103 }
   1104 
   1105 TEST_F(TransformUtilsTest, TestReplaceMatchingOpTypes) {
   1106   TestReplaceMatchingOpTypes();
   1107 }
   1108 
   1109 TEST_F(TransformUtilsTest, TestMatchedNodesAsArray) {
   1110   TestMatchedNodesAsArray();
   1111 }
   1112 
   1113 TEST_F(TransformUtilsTest, TestRenameNodeInputs) { TestRenameNodeInputs(); }
   1114 
   1115 TEST_F(TransformUtilsTest, TestRenameNodeInputsWithRedirects) {
   1116   TestRenameNodeInputsWithRedirects();
   1117 }
   1118 
   1119 TEST_F(TransformUtilsTest, TestRenameNodeInputsWithCycle) {
   1120   TestRenameNodeInputsWithCycle();
   1121 }
   1122 
   1123 TEST_F(TransformUtilsTest, TestRenameNodeInputsWithWildcard) {
   1124   TestRenameNodeInputsWithWildcard();
   1125 }
   1126 
   1127 TEST_F(TransformUtilsTest, TestRenameNodeInputsWithIgnores) {
   1128   TestRenameNodeInputsWithIgnores();
   1129 }
   1130 
   1131 TEST_F(TransformUtilsTest, TestFindInvalidInputs) { TestFindInvalidInputs(); }
   1132 
   1133 TEST_F(TransformUtilsTest, TestIsGraphValid) { TestIsGraphValid(); }
   1134 
   1135 TEST_F(TransformUtilsTest, TestGetInOutTypes) { TestGetInOutTypes(); }
   1136 
   1137 TEST_F(TransformUtilsTest, TestCopyOriginalMatch) { TestCopyOriginalMatch(); }
   1138 
   1139 TEST_F(TransformUtilsTest, TestHashNodeDef) { TestHashNodeDef(); }
   1140 
   1141 TEST_F(TransformUtilsTest, TestCountParameters) { TestCountParameters(); }
   1142 
   1143 TEST_F(TransformUtilsTest, TestGetOneStringParameter) {
   1144   TestGetOneStringParameter();
   1145 }
   1146 
   1147 TEST_F(TransformUtilsTest, TestGetOneInt32Parameter) {
   1148   TestGetOneInt32Parameter();
   1149 }
   1150 
   1151 TEST_F(TransformUtilsTest, TestGetOneInt64Parameter) {
   1152   TestGetOneInt64Parameter();
   1153 }
   1154 
   1155 TEST_F(TransformUtilsTest, TestGetOneFloatParameter) {
   1156   TestGetOneFloatParameter();
   1157 }
   1158 
   1159 TEST_F(TransformUtilsTest, TestGetOneBoolParameter) {
   1160   TestGetOneBoolParameter();
   1161 }
   1162 
   1163 }  // namespace graph_transforms
   1164 }  // namespace tensorflow
   1165