Home | History | Annotate | Download | only in graph_transforms
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/cc/ops/const_op.h"
     17 #include "tensorflow/cc/ops/image_ops.h"
     18 #include "tensorflow/cc/ops/nn_ops.h"
     19 #include "tensorflow/cc/ops/sendrecv_ops.h"
     20 #include "tensorflow/cc/ops/standard_ops.h"
     21 #include "tensorflow/core/framework/tensor_testutil.h"
     22 #include "tensorflow/core/lib/core/status_test_util.h"
     23 #include "tensorflow/core/platform/test.h"
     24 #include "tensorflow/core/platform/test_benchmark.h"
     25 #include "tensorflow/core/public/session.h"
     26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
     27 
     28 namespace tensorflow {
     29 namespace graph_transforms {
     30 
// Graph transform under test. It is declared here so that no public header is
// needed just to call it from this test.
//
// Behavior exercised by the tests in this file:
//  - nodes whose op type is listed in context.params["op"] are removed, and
//    their consumers are rewired to the removed node's first input;
//  - nodes named in context.output_names are kept even if their op matches;
//  - context.params["max_inputs"] lets nodes with more than one input match
//    (the multi-input test passes "3" for a three-input node).
Status RemoveNodes(const GraphDef& input_graph_def,
                   const TransformFuncContext& context,
                   GraphDef* output_graph_def);
     35 
     36 class RemoveNodesTest : public ::testing::Test {
     37  protected:
     38   void TestRemoveNodes() {
     39     GraphDef graph_def;
     40 
     41     NodeDef* add_node1 = graph_def.add_node();
     42     add_node1->set_name("add_node1");
     43     add_node1->set_op("Add");
     44     add_node1->add_input("add_node2");
     45     add_node1->add_input("add_node3");
     46 
     47     NodeDef* add_node2 = graph_def.add_node();
     48     add_node2->set_name("add_node2");
     49     add_node2->set_op("Add");
     50     add_node2->add_input("identity_node1");
     51     add_node2->add_input("identity_node2");
     52 
     53     NodeDef* add_node3 = graph_def.add_node();
     54     add_node3->set_name("add_node3");
     55     add_node3->set_op("Add");
     56     add_node3->add_input("identity_node1");
     57     add_node3->add_input("const_node3");
     58 
     59     NodeDef* identity_node1 = graph_def.add_node();
     60     identity_node1->set_name("identity_node1");
     61     identity_node1->set_op("Identity");
     62     identity_node1->add_input("const_node1");
     63 
     64     NodeDef* identity_node2 = graph_def.add_node();
     65     identity_node2->set_name("identity_node2");
     66     identity_node2->set_op("Identity");
     67     identity_node2->add_input("const_node2");
     68 
     69     NodeDef* identity_node3 = graph_def.add_node();
     70     identity_node3->set_name("identity_node3");
     71     identity_node3->set_op("Identity");
     72     identity_node3->add_input("const_node3");
     73 
     74     NodeDef* const_node1 = graph_def.add_node();
     75     const_node1->set_name("const_node1");
     76     const_node1->set_op("Const");
     77 
     78     NodeDef* const_node2 = graph_def.add_node();
     79     const_node2->set_name("const_node2");
     80     const_node2->set_op("Const");
     81 
     82     NodeDef* const_node3 = graph_def.add_node();
     83     const_node3->set_name("const_node3");
     84     const_node3->set_op("Const");
     85 
     86     NodeDef* add_node4 = graph_def.add_node();
     87     add_node4->set_name("add_node4");
     88     add_node4->set_op("Add");
     89     add_node4->add_input("add_node2");
     90     add_node4->add_input("add_node3");
     91 
     92     GraphDef result;
     93     TransformFuncContext context;
     94     context.input_names = {};
     95     context.output_names = {"add_node1"};
     96     context.params.insert(
     97         std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
     98     TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
     99 
    100     std::map<string, const NodeDef*> node_lookup;
    101     MapNamesToNodes(result, &node_lookup);
    102     EXPECT_EQ(1, node_lookup.count("add_node1"));
    103     EXPECT_EQ("add_node2", node_lookup.at("add_node1")->input(0));
    104     EXPECT_EQ("add_node3", node_lookup.at("add_node1")->input(1));
    105     EXPECT_EQ(1, node_lookup.count("add_node2"));
    106     EXPECT_EQ("const_node1", node_lookup.at("add_node2")->input(0));
    107     EXPECT_EQ("const_node2", node_lookup.at("add_node2")->input(1));
    108     EXPECT_EQ(1, node_lookup.count("add_node3"));
    109     EXPECT_EQ("const_node1", node_lookup.at("add_node3")->input(0));
    110     EXPECT_EQ("const_node3", node_lookup.at("add_node3")->input(1));
    111     EXPECT_EQ(1, node_lookup.count("add_node4"));
    112     EXPECT_EQ("add_node2", node_lookup.at("add_node4")->input(0));
    113     EXPECT_EQ("add_node3", node_lookup.at("add_node4")->input(1));
    114     EXPECT_EQ(0, node_lookup.count("identity_node1"));
    115     EXPECT_EQ(0, node_lookup.count("identity_node2"));
    116     EXPECT_EQ(0, node_lookup.count("identity_node3"));
    117     EXPECT_EQ(1, node_lookup.count("const_node1"));
    118     EXPECT_EQ("Const", node_lookup.at("const_node1")->op());
    119     EXPECT_EQ(1, node_lookup.count("const_node2"));
    120     EXPECT_EQ("Const", node_lookup.at("const_node2")->op());
    121     EXPECT_EQ(1, node_lookup.count("const_node3"));
    122     EXPECT_EQ("Const", node_lookup.at("const_node3")->op());
    123   }
    124 
    125   void TestRemoveOutputNodes() {
    126     GraphDef graph_def;
    127 
    128     NodeDef* const_node1 = graph_def.add_node();
    129     const_node1->set_name("const_node1");
    130     const_node1->set_op("Const");
    131 
    132     NodeDef* const_node2 = graph_def.add_node();
    133     const_node2->set_name("const_node2");
    134     const_node2->set_op("Const");
    135 
    136     NodeDef* add_node = graph_def.add_node();
    137     add_node->set_name("add_node");
    138     add_node->set_op("Add");
    139     add_node->add_input("const_node1");
    140     add_node->add_input("const_node2");
    141 
    142     NodeDef* identity_node = graph_def.add_node();
    143     identity_node->set_name("identity_node");
    144     identity_node->set_op("Identity");
    145     identity_node->add_input("add_node");
    146 
    147     GraphDef result;
    148     TransformFuncContext context;
    149     context.input_names = {};
    150     context.output_names = {"identity_node"};
    151     context.params.insert(
    152         std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
    153     TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
    154 
    155     std::map<string, const NodeDef*> node_lookup;
    156     MapNamesToNodes(result, &node_lookup);
    157     EXPECT_EQ(1, node_lookup.count("add_node"));
    158     EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
    159     EXPECT_EQ("const_node2", node_lookup.at("add_node")->input(1));
    160     EXPECT_EQ(1, node_lookup.count("identity_node"));
    161     EXPECT_EQ("add_node", node_lookup.at("identity_node")->input(0));
    162   }
    163 
    164   void TestRemoveChainedNodes() {
    165     GraphDef graph_def;
    166 
    167     NodeDef* const_node1 = graph_def.add_node();
    168     const_node1->set_name("const_node1");
    169     const_node1->set_op("Const");
    170 
    171     NodeDef* identity_node1 = graph_def.add_node();
    172     identity_node1->set_name("identity_node1");
    173     identity_node1->set_op("Identity");
    174     identity_node1->add_input("const_node1");
    175 
    176     NodeDef* identity_node2 = graph_def.add_node();
    177     identity_node2->set_name("identity_node2");
    178     identity_node2->set_op("Identity");
    179     identity_node2->add_input("identity_node1");
    180 
    181     NodeDef* identity_node3 = graph_def.add_node();
    182     identity_node3->set_name("identity_node3");
    183     identity_node3->set_op("Identity");
    184     identity_node3->add_input("identity_node2");
    185 
    186     NodeDef* const_node2 = graph_def.add_node();
    187     const_node2->set_name("const_node2");
    188     const_node2->set_op("Const");
    189 
    190     NodeDef* add_node = graph_def.add_node();
    191     add_node->set_name("add_node");
    192     add_node->set_op("Add");
    193     add_node->add_input("identity_node3");
    194     add_node->add_input("const_node2");
    195 
    196     GraphDef result;
    197     TransformFuncContext context;
    198     context.input_names = {};
    199     context.output_names = {"identity_node"};
    200     context.params.insert(
    201         std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
    202     TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
    203 
    204     std::map<string, const NodeDef*> node_lookup;
    205     MapNamesToNodes(result, &node_lookup);
    206     EXPECT_EQ(1, node_lookup.count("add_node"));
    207     EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
    208     EXPECT_EQ("const_node2", node_lookup.at("add_node")->input(1));
    209     EXPECT_EQ(0, node_lookup.count("identity_node1"));
    210     EXPECT_EQ(0, node_lookup.count("identity_node2"));
    211     EXPECT_EQ(0, node_lookup.count("identity_node3"));
    212   }
    213 
    214   void TestRemoveMultipleInputs() {
    215     GraphDef graph_def;
    216 
    217     NodeDef* const_node1 = graph_def.add_node();
    218     const_node1->set_name("const_node1");
    219     const_node1->set_op("Const");
    220 
    221     NodeDef* const_node2 = graph_def.add_node();
    222     const_node2->set_name("const_node2");
    223     const_node2->set_op("Const");
    224 
    225     NodeDef* const_node3 = graph_def.add_node();
    226     const_node3->set_name("const_node3");
    227     const_node3->set_op("Const");
    228 
    229     NodeDef* const_node4 = graph_def.add_node();
    230     const_node4->set_name("const_node4");
    231     const_node4->set_op("Const");
    232 
    233     NodeDef* fake_quant_node = graph_def.add_node();
    234     fake_quant_node->set_name("fake_quant_node");
    235     fake_quant_node->set_op("FakeQuantWithMinMaxVars");
    236     fake_quant_node->add_input("const_node1");
    237     fake_quant_node->add_input("const_node2");
    238     fake_quant_node->add_input("const_node3");
    239 
    240     NodeDef* add_node = graph_def.add_node();
    241     add_node->set_name("add_node");
    242     add_node->set_op("Add");
    243     add_node->add_input("fake_quant_node");
    244     add_node->add_input("const_node4");
    245 
    246     GraphDef result;
    247     TransformFuncContext context;
    248     context.input_names = {};
    249     context.output_names = {"add_node"};
    250     context.params.insert(std::pair<string, std::vector<string>>(
    251         {"op", {string("FakeQuantWithMinMaxVars")}}));
    252     context.params.insert(
    253         std::pair<string, std::vector<string>>({"max_inputs", {string("3")}}));
    254     TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
    255 
    256     std::map<string, const NodeDef*> node_lookup;
    257     MapNamesToNodes(result, &node_lookup);
    258     ASSERT_EQ(1, node_lookup.count("const_node1"));
    259     ASSERT_EQ(1, node_lookup.count("const_node4"));
    260     ASSERT_EQ(0, node_lookup.count("fake_quant_node"));
    261     ASSERT_EQ(1, node_lookup.count("add_node"));
    262     EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
    263     EXPECT_EQ("const_node4", node_lookup.at("add_node")->input(1));
    264   }
    265 };
    266 
    267 TEST_F(RemoveNodesTest, TestRemoveNodes) { TestRemoveNodes(); }
    268 
    269 TEST_F(RemoveNodesTest, TestRemoveOutputNodes) { TestRemoveOutputNodes(); }
    270 
    271 TEST_F(RemoveNodesTest, TestRemoveChainedNodes) { TestRemoveChainedNodes(); }
    272 
    273 TEST_F(RemoveNodesTest, TestRemoveMultipleInputs) {
    274   TestRemoveMultipleInputs();
    275 }
    276 
    277 }  // namespace graph_transforms
    278 }  // namespace tensorflow
    279