# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the TransformGraph Python wrapper."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import test
from tensorflow.tools.graph_transforms import TransformGraph


class TransformGraphTest(test.TestCase):

  def testTransformGraph(self):
    """Checks that strip_unused_nodes drops a relu not needed for the output.

    The graph is const_op1, const_op2 -> add_op -> relu_op. Only add_op is
    requested as an output, so the relu is not on the inference path and
    should be removed, while add_op itself must be kept.
    """
    input_graph_def = graph_pb2.GraphDef()

    # Shared float32 type attr; CopyFrom copies it, so reusing it is safe.
    float_type_attr = attr_value_pb2.AttrValue(
        type=dtypes.float32.as_datatype_enum)

    const_op1 = input_graph_def.node.add()
    const_op1.op = "Const"
    const_op1.name = "const_op1"
    const_op1.attr["dtype"].CopyFrom(float_type_attr)
    const_op1.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            [1, 2], dtypes.float32, [1, 2])))

    const_op2 = input_graph_def.node.add()
    const_op2.op = "Const"
    const_op2.name = "const_op2"
    const_op2.attr["dtype"].CopyFrom(float_type_attr)
    const_op2.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            [3, 4], dtypes.float32, [1, 2])))

    # Create an add that has two constants as inputs.
    add_op = input_graph_def.node.add()
    add_op.op = "Add"
    add_op.attr["T"].CopyFrom(float_type_attr)
    add_op.name = "add_op"
    add_op.input.extend(["const_op1", "const_op2"])

    # Create a relu that reads from the add.
    relu_op = input_graph_def.node.add()
    relu_op.op = "Relu"
    relu_op.attr["T"].CopyFrom(float_type_attr)
    relu_op.name = "relu_op"
    relu_op.input.extend(["add_op"])

    # We're specifying that add_op is the final output, and so the relu isn't
    # needed.
    input_names = []
    output_names = ["add_op"]
    transforms = ["strip_unused_nodes"]
    transformed_graph_def = TransformGraph(input_graph_def, input_names,
                                           output_names, transforms)

    transformed_names = [node.name for node in transformed_graph_def.node]
    transformed_ops = [node.op for node in transformed_graph_def.node]

    # We expect that the relu is no longer present after running the transform.
    self.assertNotIn("Relu", transformed_ops)
    self.assertNotIn("relu_op", transformed_names)
    # The requested output must survive; otherwise a transform that stripped
    # everything would also pass the checks above.
    self.assertIn("add_op", transformed_names)


if __name__ == "__main__":
  test.main()