# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graph_matcher."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python import ops as contrib_ops
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest


class GraphMatcherTest(test_util.TensorFlowTestCase):
  """Tests for graph_matcher's OpTypePattern / OneofPattern / GraphMatcher."""

  def test_conv_layer(self):
    """Matches Relu(FusedBatchNorm(Conv2D(inputs, *), ...)) in a conv layer."""
    g = ops.Graph()
    with g.as_default():
      inputs = array_ops.placeholder(dtypes.float32, shape=[8, 5, 5, 3])

      with contrib_ops.arg_scope(
          [layers.batch_norm], fused=True, is_training=True, trainable=True):
        # BUG FIX: this was `return layers.convolution(...)`, which exited the
        # test right after building the graph — the pattern construction and
        # every assertion below were unreachable, so the test passed vacuously.
        # The convolution only needs to be added to `g`; its return value is
        # not used by the matcher.
        layers.convolution(
            inputs,
            num_outputs=16,
            kernel_size=3,
            stride=1,
            padding='VALID',
            activation_fn=nn_ops.relu,
            normalizer_fn=layers.batch_norm,
            normalizer_params={},
            weights_initializer=initializers.xavier_initializer(),
            weights_regularizer=None,
            biases_initializer=init_ops.zeros_initializer(),
            biases_regularizer=None,
            reuse=None,
            trainable=True,
            scope=None)

    # Wildcard leaf, named so the match result can be queried by name too.
    inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs')
    relu_pattern = graph_matcher.OpTypePattern(
        'Relu',
        name='relu',
        inputs=[
            graph_matcher.OpTypePattern(
                'FusedBatchNorm',
                inputs=[
                    graph_matcher.OpTypePattern(
                        'Conv2D', inputs=[inputs_pattern, '*']), '*', '*', '*',
                    '*'
                ])
        ])
    matcher = graph_matcher.GraphMatcher(relu_pattern)
    match_results = list(matcher.match_graph(g))
    self.assertEqual(1, len(match_results))
    match_result = match_results[0]
    # The matched tensor is retrievable both by pattern object and by name.
    self.assertEqual(match_result.get_tensor(inputs_pattern), inputs)
    self.assertEqual(match_result.get_tensor('inputs'), inputs)

  def test_multiple_outputs(self):
    """A multi-output op (split) still yields the correct per-output tensor."""
    #   -         +
    #  / \y0  y1/ \
    # x    split    z
    #       |
    #       y           (nodes are ops; edges are going up)
    g = ops.Graph()
    with g.as_default():
      x = array_ops.placeholder(dtypes.float32, shape=[1], name='x')
      y = array_ops.placeholder(dtypes.float32, shape=[2], name='y')
      y0, y1 = array_ops.split(y, num_or_size_splits=2, axis=0)
      z = array_ops.placeholder(dtypes.float32, shape=[1], name='z')
      math_ops.add(x, y0)
      math_ops.subtract(y1, z)

    y1_pattern = graph_matcher.OpTypePattern('*')
    minus_pattern = graph_matcher.OpTypePattern('Sub', inputs=[y1_pattern, '*'])
    matcher = graph_matcher.GraphMatcher(minus_pattern)

    match_results = list(matcher.match_graph(g))
    self.assertEqual(1, len(match_results))
    match_result = match_results[0]

    # y0 and y1 come from the same Split op, but get_tensor must return the
    # specific output (y1) that fed the Sub, not just the producing op.
    self.assertEqual(y0.op, y1.op)
    self.assertEqual(match_result.get_op(y1_pattern), y1.op)
    self.assertEqual(match_result.get_tensor(y1_pattern), y1)

  def test_oneof_type_pattern(self):
    """An 'A|B' op-type pattern matches ops of either type."""
    #   -   +
    #  / \ / \
    # x   y   z
    g = ops.Graph()
    with g.as_default():
      x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
      y = array_ops.placeholder(dtypes.float32, shape=[], name='y')
      z = array_ops.placeholder(dtypes.float32, shape=[], name='z')
      plus = x + y
      minus = y - z

    add_or_sub_pattern = graph_matcher.OpTypePattern(
        'Add|Sub', inputs=['*', '*'])
    matcher = graph_matcher.GraphMatcher(add_or_sub_pattern)
    self.assertEqual([
        match_result.get_op(add_or_sub_pattern)
        for match_result in matcher.match_graph(g)
    ], [plus.op, minus.op])

  def test_oneof_pattern(self):
    """OneofPattern tries alternatives; unmatched named sub-patterns are None."""
    reshape_pattern = graph_matcher.OpTypePattern('Reshape')
    transpose_pattern = graph_matcher.OneofPattern([
        graph_matcher.OpTypePattern(
            'Transpose',
            name='transpose',
            inputs=[
                graph_matcher.OpTypePattern(
                    'Slice', name='slice', inputs=[reshape_pattern, '*', '*']),
                '*'
            ]),
        graph_matcher.OpTypePattern(
            'Transpose', name='transpose', inputs=[reshape_pattern, '*'])
    ])

    matcher = graph_matcher.GraphMatcher(transpose_pattern)

    # Graph without a Slice: the second alternative matches, so the named
    # 'slice' sub-pattern resolves to None.
    g = ops.Graph()
    with g.as_default():
      inputs = array_ops.placeholder(dtypes.float32, shape=[6])
      reshape = array_ops.reshape(inputs, [2, 3])
      transpose = array_ops.transpose(reshape)
      [match_result] = list(matcher.match_graph(g))
      self.assertEqual(match_result.get_tensor(reshape_pattern), reshape)
      self.assertEqual(match_result.get_tensor('slice'), None)
      self.assertEqual(match_result.get_op('transpose'), transpose.op)

    # Graph with a Slice: the first (more specific) alternative matches and
    # binds 'slice' to the slicing tensor.
    g = ops.Graph()
    with g.as_default():
      inputs = array_ops.placeholder(dtypes.float32, shape=[6])
      reshape = array_ops.reshape(inputs, [2, 3])
      slicing = array_ops.slice(reshape, [0, 0], [-1, -1])
      transpose = array_ops.transpose(slicing)
      [match_result] = list(matcher.match_graph(g))
      self.assertEqual(match_result.get_tensor(reshape_pattern), reshape)
      self.assertEqual(match_result.get_tensor('slice'), slicing)
      self.assertEqual(match_result.get_op('transpose'), transpose.op)


if __name__ == '__main__':
  googletest.main()