Home | History | Annotate | Download | only in python
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for graph_matcher."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.framework.python import ops as contrib_ops
     22 from tensorflow.contrib.layers.python.layers import initializers
     23 from tensorflow.contrib.layers.python.layers import layers
     24 from tensorflow.contrib.quantize.python import graph_matcher
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.framework import test_util
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops import init_ops
     30 from tensorflow.python.ops import math_ops
     31 from tensorflow.python.ops import nn_ops
     32 from tensorflow.python.platform import googletest
     33 
     34 
     35 class GraphMatcherTest(test_util.TensorFlowTestCase):
     36 
     37   def test_conv_layer(self):
     38     g = ops.Graph()
     39     with g.as_default():
     40       inputs = array_ops.placeholder(dtypes.float32, shape=[8, 5, 5, 3])
     41 
     42     with contrib_ops.arg_scope(
     43         [layers.batch_norm], fused=True, is_training=True, trainable=True):
     44       return layers.convolution(
     45           inputs,
     46           num_outputs=16,
     47           kernel_size=3,
     48           stride=1,
     49           padding='VALID',
     50           activation_fn=nn_ops.relu,
     51           normalizer_fn=layers.batch_norm,
     52           normalizer_params={},
     53           weights_initializer=initializers.xavier_initializer(),
     54           weights_regularizer=None,
     55           biases_initializer=init_ops.zeros_initializer(),
     56           biases_regularizer=None,
     57           reuse=None,
     58           trainable=True,
     59           scope=None)
     60 
     61     inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs')
     62     relu_pattern = graph_matcher.OpTypePattern(
     63         'Relu',
     64         name='relu',
     65         inputs=[
     66             graph_matcher.OpTypePattern(
     67                 'FusedBatchNorm',
     68                 inputs=[
     69                     graph_matcher.OpTypePattern(
     70                         'Conv2D', inputs=[inputs_pattern, '*']), '*', '*', '*',
     71                     '*'
     72                 ])
     73         ])
     74     matcher = graph_matcher.GraphMatcher(relu_pattern)
     75     match_results = list(matcher.match_graph(g))
     76     self.assertEqual(1, len(match_results))
     77     match_result = match_results[0]
     78     self.assertEqual(match_result.get_tensor(inputs_pattern), inputs)
     79     self.assertEqual(match_result.get_tensor('inputs'), inputs)
     80 
     81   def test_multiple_outputs(self):
     82     #   -         +
     83     #  / \y0   y1/ \
     84     # x    split    z
     85     #       |
     86     #       y         (nodes are ops; edges are going up)
     87     g = ops.Graph()
     88     with g.as_default():
     89       x = array_ops.placeholder(dtypes.float32, shape=[1], name='x')
     90       y = array_ops.placeholder(dtypes.float32, shape=[2], name='y')
     91       y0, y1 = array_ops.split(y, num_or_size_splits=2, axis=0)
     92       z = array_ops.placeholder(dtypes.float32, shape=[1], name='z')
     93       math_ops.add(x, y0)
     94       math_ops.subtract(y1, z)
     95 
     96     y1_pattern = graph_matcher.OpTypePattern('*')
     97     minus_pattern = graph_matcher.OpTypePattern('Sub', inputs=[y1_pattern, '*'])
     98     matcher = graph_matcher.GraphMatcher(minus_pattern)
     99 
    100     match_results = list(matcher.match_graph(g))
    101     self.assertEqual(1, len(match_results))
    102     match_result = match_results[0]
    103 
    104     self.assertEqual(y0.op, y1.op)
    105     self.assertEqual(match_result.get_op(y1_pattern), y1.op)
    106     self.assertEqual(match_result.get_tensor(y1_pattern), y1)
    107 
    108   def test_oneof_type_pattern(self):
    109     #   -   +
    110     #  / \ / \
    111     # x   y   z
    112     g = ops.Graph()
    113     with g.as_default():
    114       x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
    115       y = array_ops.placeholder(dtypes.float32, shape=[], name='y')
    116       z = array_ops.placeholder(dtypes.float32, shape=[], name='z')
    117       plus = x + y
    118       minus = y - z
    119 
    120     add_or_sub_pattern = graph_matcher.OpTypePattern(
    121         'Add|Sub', inputs=['*', '*'])
    122     matcher = graph_matcher.GraphMatcher(add_or_sub_pattern)
    123     self.assertEqual([
    124         match_result.get_op(add_or_sub_pattern)
    125         for match_result in matcher.match_graph(g)
    126     ], [plus.op, minus.op])
    127 
    128   def test_oneof_pattern(self):
    129     reshape_pattern = graph_matcher.OpTypePattern('Reshape')
    130     transpose_pattern = graph_matcher.OneofPattern([
    131         graph_matcher.OpTypePattern(
    132             'Transpose',
    133             name='transpose',
    134             inputs=[
    135                 graph_matcher.OpTypePattern(
    136                     'Slice', name='slice', inputs=[reshape_pattern, '*', '*']),
    137                 '*'
    138             ]),
    139         graph_matcher.OpTypePattern(
    140             'Transpose', name='transpose', inputs=[reshape_pattern, '*'])
    141     ])
    142 
    143     matcher = graph_matcher.GraphMatcher(transpose_pattern)
    144 
    145     g = ops.Graph()
    146     with g.as_default():
    147       inputs = array_ops.placeholder(dtypes.float32, shape=[6])
    148       reshape = array_ops.reshape(inputs, [2, 3])
    149       transpose = array_ops.transpose(reshape)
    150       [match_result] = list(matcher.match_graph(g))
    151       self.assertEqual(match_result.get_tensor(reshape_pattern), reshape)
    152       self.assertEqual(match_result.get_tensor('slice'), None)
    153       self.assertEqual(match_result.get_op('transpose'), transpose.op)
    154 
    155     g = ops.Graph()
    156     with g.as_default():
    157       inputs = array_ops.placeholder(dtypes.float32, shape=[6])
    158       reshape = array_ops.reshape(inputs, [2, 3])
    159       slicing = array_ops.slice(reshape, [0, 0], [-1, -1])
    160       transpose = array_ops.transpose(slicing)
    161       [match_result] = list(matcher.match_graph(g))
    162       self.assertEqual(match_result.get_tensor(reshape_pattern), reshape)
    163       self.assertEqual(match_result.get_tensor('slice'), slicing)
    164       self.assertEqual(match_result.get_op('transpose'), transpose.op)
    165 
    166 
    167 if __name__ == '__main__':
    168   googletest.main()
    169