Home | History | Annotate | Download | only in tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Simple graph matching functions."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from six import string_types
     22 
     23 from tensorflow.contrib.graph_editor import select
     24 from tensorflow.python.framework import ops as tf_ops
     25 
# Public API of this module: `op_type` builds op-type predicates, and
# `OpMatcher` matches a tf.Operation (and optionally its inputs, control
# inputs and output consumers) against filters.
__all__ = [
    "op_type",
    "OpMatcher",
]
     30 
     31 
     32 def _make_graph_match(graph_match):
     33   """Convert to a OpMatcher instance."""
     34   if graph_match is None:
     35     return None
     36   if not isinstance(graph_match, OpMatcher):
     37     graph_match = OpMatcher(graph_match)
     38   return graph_match
     39 
     40 
     41 def op_type(op_types, op=None):
     42   """Check if an op is of the given type.
     43 
     44   Args:
     45     op_types: tuple of strings containing the types to check against.
     46       For instance: ("Add", "Const")
     47     op: the operation to check (or None).
     48   Returns:
     49     if op is not None, return True if the op is of the correct type.
     50     if op is None, return a lambda function which does the type checking.
     51   """
     52   if isinstance(op_types, string_types):
     53     op_types = (op_types)
     54   if op is None:
     55     return lambda op: op.node_def.op in op_types
     56   else:
     57     return op.node_def.op in op_types
     58 
     59 
     60 class OpMatcher(object):
     61   """Graph match class."""
     62 
     63   def __init__(self, positive_filter):
     64     """Graph match constructor."""
     65     self.positive_filters = []
     66     self.input_op_matches = None
     67     self.control_input_op_matches = None
     68     self.output_op_matches = None
     69     positive_filter = self._finalize_positive_filter(positive_filter)
     70     self.positive_filters.append(positive_filter)
     71 
     72   def _finalize_positive_filter(self, elem):
     73     """Convert to a filter function."""
     74     if select.can_be_regex(elem):
     75       regex_ = select.make_regex(elem)
     76       return lambda op, regex=regex_: regex.search(op.name) is not None
     77     elif isinstance(elem, tf_ops.Operation):
     78       return lambda op, match_op=elem: op is match_op
     79     elif callable(elem):
     80       return elem
     81     elif elem is True:
     82       return lambda op: True
     83     else:
     84       raise ValueError("Cannot finalize the positive filter: {}".format(elem))
     85 
     86   def __call__(self, op):
     87     """Evaluate if the op matches or not."""
     88     if not isinstance(op, tf_ops.Operation):
     89       raise TypeError("Expect tf.Operation, got: {}".format(type(op)))
     90     for positive_filter in self.positive_filters:
     91       if not positive_filter(op):
     92         return False
     93     if self.input_op_matches is not None:
     94       if len(op.inputs) != len(self.input_op_matches):
     95         return False
     96       for input_t, input_op_match in zip(op.inputs, self.input_op_matches):
     97         if input_op_match is None:
     98           continue
     99         if not input_op_match(input_t.op):
    100           return False
    101     if self.control_input_op_matches is not None:
    102       if len(op.control_inputs) != len(self.control_input_op_matches):
    103         return False
    104       for cinput_op, cinput_op_match in zip(op.control_inputs,
    105                                             self.control_input_op_matches):
    106         if cinput_op_match is None:
    107           continue
    108         if not cinput_op_match(cinput_op):
    109           return False
    110     if self.output_op_matches is not None:
    111       if len(op.outputs) != len(self.output_op_matches):
    112         return False
    113       for output_t, output_op_matches in zip(op.outputs,
    114                                              self.output_op_matches):
    115         if output_op_matches is None:
    116           continue
    117         if len(output_t.consumers()) != len(output_op_matches):
    118           return False
    119         for consumer_op, consumer_op_match in zip(output_t.consumers(),
    120                                                   output_op_matches):
    121           if consumer_op_match is None:
    122             continue
    123           if not consumer_op_match(consumer_op):
    124             return False
    125     return True
    126 
    127   def input_ops(self, *args):
    128     """Add input matches."""
    129     if self.input_op_matches is not None:
    130       raise ValueError("input_op_matches is already set.")
    131     self.input_op_matches = []
    132     for input_match in args:
    133       self.input_op_matches.append(_make_graph_match(input_match))
    134     return self
    135 
    136   def control_input_ops(self, *args):
    137     """Add input matches."""
    138     if self.control_input_op_matches is not None:
    139       raise ValueError("control_input_op_matches is already set.")
    140     self.control_input_op_matches = []
    141     for input_match in args:
    142       self.control_input_op_matches.append(_make_graph_match(input_match))
    143     return self
    144 
    145   def output_ops(self, *args):
    146     """Add output matches."""
    147     if self.output_op_matches is not None:
    148       raise ValueError("output_op_matches is already set.")
    149     self.output_op_matches = []
    150     for consumer_op_matches in args:
    151       if consumer_op_matches is None:
    152         self.output_op_matches.append(None)
    153       if not isinstance(consumer_op_matches, list):
    154         consumer_op_matches = [consumer_op_matches]
    155       consumer_op_matches = [_make_graph_match(consumer_op_match)
    156                              for consumer_op_match in consumer_op_matches]
    157       self.output_op_matches.append(consumer_op_matches)
    158     return self
    159