Home | History | Annotate | Download | only in python
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 # http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for TransformGraph Python wrapper."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.core.framework import attr_value_pb2
     22 from tensorflow.core.framework import graph_pb2
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.framework import tensor_util
     25 from tensorflow.python.platform import test
     26 from tensorflow.tools.graph_transforms import TransformGraph
     27 
     28 
     29 class TransformGraphTest(test.TestCase):
     30 
     31   # This test constructs a graph with a relu op that's not used by the normal
     32   # inference path, and then tests that the strip_unused transform removes it as
     33   # expected.
     34   def testTransformGraph(self):
     35     input_graph_def = graph_pb2.GraphDef()
     36 
     37     const_op1 = input_graph_def.node.add()
     38     const_op1.op = "Const"
     39     const_op1.name = "const_op1"
     40     const_op1.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(
     41         type=dtypes.float32.as_datatype_enum))
     42     const_op1.attr["value"].CopyFrom(
     43         attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
     44             [1, 2], dtypes.float32, [1, 2])))
     45 
     46     const_op2 = input_graph_def.node.add()
     47     const_op2.op = "Const"
     48     const_op2.name = "const_op2"
     49     const_op2.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(
     50         type=dtypes.float32.as_datatype_enum))
     51     const_op2.attr["value"].CopyFrom(
     52         attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
     53             [3, 4], dtypes.float32, [1, 2])))
     54 
     55     # Create an add that has two constants as inputs.
     56     add_op = input_graph_def.node.add()
     57     add_op.op = "Add"
     58     add_op.attr["T"].CopyFrom(attr_value_pb2.AttrValue(
     59         type=dtypes.float32.as_datatype_enum))
     60     add_op.name = "add_op"
     61     add_op.input.extend(["const_op1", "const_op2"])
     62 
     63     # Create a relu that reads from the add.
     64     relu_op = input_graph_def.node.add()
     65     relu_op.op = "Relu"
     66     relu_op.attr["T"].CopyFrom(attr_value_pb2.AttrValue(
     67         type=dtypes.float32.as_datatype_enum))
     68     relu_op.name = "relu_op"
     69     relu_op.input.extend(["add_op"])
     70 
     71     # We're specifying that add_op is the final output, and so the relu isn't
     72     # needed.
     73     input_names = []
     74     output_names = ["add_op"]
     75     transforms = ["strip_unused_nodes"]
     76     transformed_graph_def = TransformGraph(input_graph_def, input_names,
     77                                            output_names, transforms)
     78 
     79     # We expect that the relu is no longer present after running the transform.
     80     for node in transformed_graph_def.node:
     81       self.assertNotEqual("Relu", node.op)
     82 
     83 
     84 if __name__ == "__main__":
     85   test.main()
     86