# File: tensorflow/contrib/graph_editor/util.py
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Utility functions for the graph_editor.
     16 """
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import collections
     23 import re
     24 from six import iteritems
     25 from tensorflow.python.framework import ops as tf_ops
     26 from tensorflow.python.ops import array_ops as tf_array_ops
     27 
     28 __all__ = [
     29     "make_list_of_op",
     30     "get_tensors",
     31     "make_list_of_t",
     32     "get_generating_ops",
     33     "get_consuming_ops",
     34     "ControlOutputs",
     35     "placeholder_name",
     36     "make_placeholder_from_tensor",
     37     "make_placeholder_from_dtype_and_shape",
     38 ]
     39 
     40 
     41 def concatenate_unique(la, lb):
     42   """Add all the elements of `lb` to `la` if they are not there already.
     43 
     44   The elements added to `la` maintain ordering with respect to `lb`.
     45 
     46   Args:
     47     la: List of Python objects.
     48     lb: List of Python objects.
     49   Returns:
     50     `la`: The list `la` with missing elements from `lb`.
     51   """
     52   la_set = set(la)
     53   for l in lb:
     54     if l not in la_set:
     55       la.append(l)
     56       la_set.add(l)
     57   return la
     58 
     59 
     60 # TODO(fkp): very generic code, it should be moved in a more generic place.
     61 class ListView(object):
     62   """Immutable list wrapper.
     63 
     64   This class is strongly inspired by the one in tf.Operation.
     65   """
     66 
     67   def __init__(self, list_):
     68     if not isinstance(list_, list):
     69       raise TypeError("Expected a list, got: {}.".format(type(list_)))
     70     self._list = list_
     71 
     72   def __iter__(self):
     73     return iter(self._list)
     74 
     75   def __len__(self):
     76     return len(self._list)
     77 
     78   def __bool__(self):
     79     return bool(self._list)
     80 
     81   # Python 3 wants __bool__, Python 2.7 wants __nonzero__
     82   __nonzero__ = __bool__
     83 
     84   def __getitem__(self, i):
     85     return self._list[i]
     86 
     87   def __add__(self, other):
     88     if not isinstance(other, list):
     89       other = list(other)
     90     return list(self) + other
     91 
     92 
     93 # TODO(fkp): very generic code, it should be moved in a more generic place.
     94 def is_iterable(obj):
     95   """Return true if the object is iterable."""
     96   if isinstance(obj, tf_ops.Tensor):
     97     return False
     98   try:
     99     _ = iter(obj)
    100   except Exception:  # pylint: disable=broad-except
    101     return False
    102   return True
    103 
    104 
    105 def flatten_tree(tree, leaves=None):
    106   """Flatten a tree into a list.
    107 
    108   Args:
    109     tree: iterable or not. If iterable, its elements (child) can also be
    110       iterable or not.
    111     leaves: list to which the tree leaves are appended (None by default).
    112   Returns:
    113     A list of all the leaves in the tree.
    114   """
    115   if leaves is None:
    116     leaves = []
    117   if isinstance(tree, dict):
    118     for _, child in iteritems(tree):
    119       flatten_tree(child, leaves)
    120   elif is_iterable(tree):
    121     for child in tree:
    122       flatten_tree(child, leaves)
    123   else:
    124     leaves.append(tree)
    125   return leaves
    126 
    127 
    128 def transform_tree(tree, fn, iterable_type=tuple):
    129   """Transform all the nodes of a tree.
    130 
    131   Args:
    132     tree: iterable or not. If iterable, its elements (child) can also be
    133       iterable or not.
    134     fn: function to apply to each leaves.
    135     iterable_type: type use to construct the resulting tree for unknown
    136       iterable, typically `list` or `tuple`.
    137   Returns:
    138     A tree whose leaves has been transformed by `fn`.
    139     The hierarchy of the output tree mimics the one of the input tree.
    140   """
    141   if is_iterable(tree):
    142     if isinstance(tree, dict):
    143       res = tree.__new__(type(tree))
    144       res.__init__(
    145           (k, transform_tree(child, fn)) for k, child in iteritems(tree))
    146       return res
    147     elif isinstance(tree, tuple):
    148       # NamedTuple?
    149       if hasattr(tree, "_asdict"):
    150         res = tree.__new__(type(tree), **transform_tree(tree._asdict(), fn))
    151       else:
    152         res = tree.__new__(type(tree),
    153                            (transform_tree(child, fn) for child in tree))
    154       return res
    155     elif isinstance(tree, collections.Sequence):
    156       res = tree.__new__(type(tree))
    157       res.__init__(transform_tree(child, fn) for child in tree)
    158       return res
    159     else:
    160       return iterable_type(transform_tree(child, fn) for child in tree)
    161   else:
    162     return fn(tree)
    163 
    164 
    165 def check_graphs(*args):
    166   """Check that all the element in args belong to the same graph.
    167 
    168   Args:
    169     *args: a list of object with a obj.graph property.
    170   Raises:
    171     ValueError: if all the elements do not belong to the same graph.
    172   """
    173   graph = None
    174   for i, sgv in enumerate(args):
    175     if graph is None and sgv.graph is not None:
    176       graph = sgv.graph
    177     elif sgv.graph is not None and sgv.graph is not graph:
    178       raise ValueError("Argument[{}]: Wrong graph!".format(i))
    179 
    180 
    181 def get_unique_graph(tops, check_types=None, none_if_empty=False):
    182   """Return the unique graph used by the all the elements in tops.
    183 
    184   Args:
    185     tops: list of elements to check (usually a list of tf.Operation and/or
    186       tf.Tensor). Or a tf.Graph.
    187     check_types: check that the element in tops are of given type(s). If None,
    188       the types (tf.Operation, tf.Tensor) are used.
    189     none_if_empty: don't raise an error if tops is an empty list, just return
    190       None.
    191   Returns:
    192     The unique graph used by all the tops.
    193   Raises:
    194     TypeError: if tops is not a iterable of tf.Operation.
    195     ValueError: if the graph is not unique.
    196   """
    197   if isinstance(tops, tf_ops.Graph):
    198     return tops
    199   if not is_iterable(tops):
    200     raise TypeError("{} is not iterable".format(type(tops)))
    201   if check_types is None:
    202     check_types = (tf_ops.Operation, tf_ops.Tensor)
    203   elif not is_iterable(check_types):
    204     check_types = (check_types,)
    205   g = None
    206   for op in tops:
    207     if not isinstance(op, check_types):
    208       raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str(
    209           t) for t in check_types]), type(op)))
    210     if g is None:
    211       g = op.graph
    212     elif g is not op.graph:
    213       raise ValueError("Operation {} does not belong to given graph".format(op))
    214   if g is None and not none_if_empty:
    215     raise ValueError("Can't find the unique graph of an empty list")
    216   return g
    217 
    218 
    219 def make_list_of_op(ops, check_graph=True, allow_graph=True, ignore_ts=False):
    220   """Convert ops to a list of `tf.Operation`.
    221 
    222   Args:
    223     ops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single
    224       operation.
    225     check_graph: if `True` check if all the operations belong to the same graph.
    226     allow_graph: if `False` a `tf.Graph` cannot be converted.
    227     ignore_ts: if True, silently ignore `tf.Tensor`.
    228   Returns:
    229     A newly created list of `tf.Operation`.
    230   Raises:
    231     TypeError: if ops cannot be converted to a list of `tf.Operation` or,
    232      if `check_graph` is `True`, if all the ops do not belong to the
    233      same graph.
    234   """
    235   if isinstance(ops, tf_ops.Graph):
    236     if allow_graph:
    237       return ops.get_operations()
    238     else:
    239       raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
    240   else:
    241     if not is_iterable(ops):
    242       ops = [ops]
    243     if not ops:
    244       return []
    245     if check_graph:
    246       check_types = None if ignore_ts else tf_ops.Operation
    247       get_unique_graph(ops, check_types=check_types)
    248     return [op for op in ops if isinstance(op, tf_ops.Operation)]
    249 
    250 
    251 # TODO(fkp): move this function in tf.Graph?
    252 def get_tensors(graph):
    253   """get all the tensors which are input or output of an op in the graph.
    254 
    255   Args:
    256     graph: a `tf.Graph`.
    257   Returns:
    258     A list of `tf.Tensor`.
    259   Raises:
    260     TypeError: if graph is not a `tf.Graph`.
    261   """
    262   if not isinstance(graph, tf_ops.Graph):
    263     raise TypeError("Expected a graph, got: {}".format(type(graph)))
    264   ts = []
    265   for op in graph.get_operations():
    266     ts += op.outputs
    267   return ts
    268 
    269 
    270 def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
    271   """Convert ts to a list of `tf.Tensor`.
    272 
    273   Args:
    274     ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
    275     check_graph: if `True` check if all the tensors belong to the same graph.
    276     allow_graph: if `False` a `tf.Graph` cannot be converted.
    277     ignore_ops: if `True`, silently ignore `tf.Operation`.
    278   Returns:
    279     A newly created list of `tf.Tensor`.
    280   Raises:
    281     TypeError: if `ts` cannot be converted to a list of `tf.Tensor` or,
    282      if `check_graph` is `True`, if all the ops do not belong to the same graph.
    283   """
    284   if isinstance(ts, tf_ops.Graph):
    285     if allow_graph:
    286       return get_tensors(ts)
    287     else:
    288       raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
    289   else:
    290     if not is_iterable(ts):
    291       ts = [ts]
    292     if not ts:
    293       return []
    294     if check_graph:
    295       check_types = None if ignore_ops else tf_ops.Tensor
    296       get_unique_graph(ts, check_types=check_types)
    297     return [t for t in ts if isinstance(t, tf_ops.Tensor)]
    298 
    299 
    300 def get_generating_ops(ts):
    301   """Return all the generating ops of the tensors in `ts`.
    302 
    303   Args:
    304     ts: a list of `tf.Tensor`
    305   Returns:
    306     A list of all the generating `tf.Operation` of the tensors in `ts`.
    307   Raises:
    308     TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
    309   """
    310   ts = make_list_of_t(ts, allow_graph=False)
    311   return [t.op for t in ts]
    312 
    313 
    314 def get_consuming_ops(ts):
    315   """Return all the consuming ops of the tensors in ts.
    316 
    317   Args:
    318     ts: a list of `tf.Tensor`
    319   Returns:
    320     A list of all the consuming `tf.Operation` of the tensors in `ts`.
    321   Raises:
    322     TypeError: if ts cannot be converted to a list of `tf.Tensor`.
    323   """
    324   ts = make_list_of_t(ts, allow_graph=False)
    325   ops = []
    326   for t in ts:
    327     for op in t.consumers():
    328       if op not in ops:
    329         ops.append(op)
    330   return ops
    331 
    332 
    333 class ControlOutputs(object):
    334   """The control outputs topology."""
    335 
    336   def __init__(self, graph):
    337     """Create a dictionary of control-output dependencies.
    338 
    339     Args:
    340       graph: a `tf.Graph`.
    341     Returns:
    342       A dictionary where a key is a `tf.Operation` instance and the
    343          corresponding value is a list of all the ops which have the key
    344          as one of their control-input dependencies.
    345     Raises:
    346       TypeError: graph is not a `tf.Graph`.
    347     """
    348     if not isinstance(graph, tf_ops.Graph):
    349       raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
    350     self._control_outputs = {}
    351     self._graph = graph
    352     self._version = None
    353     self._build()
    354 
    355   def update(self):
    356     """Update the control outputs if the graph has changed."""
    357     if self._version != self._graph.version:
    358       self._build()
    359     return self
    360 
    361   def _build(self):
    362     """Build the control outputs dictionary."""
    363     self._control_outputs.clear()
    364     ops = self._graph.get_operations()
    365     for op in ops:
    366       for control_input in op.control_inputs:
    367         if control_input not in self._control_outputs:
    368           self._control_outputs[control_input] = []
    369         if op not in self._control_outputs[control_input]:
    370           self._control_outputs[control_input].append(op)
    371     self._version = self._graph.version
    372 
    373   def get_all(self):
    374     return self._control_outputs
    375 
    376   def get(self, op):
    377     """return the control outputs of op."""
    378     if op in self._control_outputs:
    379       return self._control_outputs[op]
    380     else:
    381       return ()
    382 
    383   @property
    384   def graph(self):
    385     return self._graph
    386 
    387 
    388 def scope_finalize(scope):
    389   if scope and scope[-1] != "/":
    390     scope += "/"
    391   return scope
    392 
    393 
    394 def scope_dirname(scope):
    395   slash = scope.rfind("/")
    396   if slash == -1:
    397     return ""
    398   return scope[:slash + 1]
    399 
    400 
    401 def scope_basename(scope):
    402   slash = scope.rfind("/")
    403   if slash == -1:
    404     return scope
    405   return scope[slash + 1:]
    406 
    407 
    408 def placeholder_name(t=None, scope=None):
    409   """Create placeholder name for the graph editor.
    410 
    411   Args:
    412     t: optional tensor on which the placeholder operation's name will be based
    413       on
    414     scope: absolute scope with which to prefix the placeholder's name. None
    415       means that the scope of t is preserved. "" means the root scope.
    416   Returns:
    417     A new placeholder name prefixed by "geph". Note that "geph" stands for
    418       Graph Editor PlaceHolder. This convention allows to quickly identify the
    419       placeholder generated by the Graph Editor.
    420   Raises:
    421     TypeError: if t is not None or a tf.Tensor.
    422   """
    423   if scope is not None:
    424     scope = scope_finalize(scope)
    425   if t is not None:
    426     if not isinstance(t, tf_ops.Tensor):
    427       raise TypeError("Expected a tf.Tenfor, got: {}".format(type(t)))
    428     op_dirname = scope_dirname(t.op.name)
    429     op_basename = scope_basename(t.op.name)
    430     if scope is None:
    431       scope = op_dirname
    432 
    433     if op_basename.startswith("geph__"):
    434       ph_name = op_basename
    435     else:
    436       ph_name = "geph__{}_{}".format(op_basename, t.value_index)
    437 
    438     return scope + ph_name
    439   else:
    440     if scope is None:
    441       scope = ""
    442     return scope + "geph"
    443 
    444 
    445 def make_placeholder_from_tensor(t, scope=None):
    446   """Create a `tf.placeholder` for the Graph Editor.
    447 
    448   Note that the correct graph scope must be set by the calling function.
    449 
    450   Args:
    451     t: a `tf.Tensor` whose name will be used to create the placeholder
    452       (see function placeholder_name).
    453     scope: absolute scope within which to create the placeholder. None
    454       means that the scope of `t` is preserved. `""` means the root scope.
    455   Returns:
    456     A newly created `tf.placeholder`.
    457   Raises:
    458     TypeError: if `t` is not `None` or a `tf.Tensor`.
    459   """
    460   return tf_array_ops.placeholder(
    461       dtype=t.dtype, shape=t.get_shape(), name=placeholder_name(
    462           t, scope=scope))
    463 
    464 
    465 def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
    466   """Create a tf.placeholder for the Graph Editor.
    467 
    468   Note that the correct graph scope must be set by the calling function.
    469   The placeholder is named using the function placeholder_name (with no
    470   tensor argument).
    471 
    472   Args:
    473     dtype: the tensor type.
    474     shape: the tensor shape (optional).
    475     scope: absolute scope within which to create the placeholder. None
    476       means that the scope of t is preserved. "" means the root scope.
    477   Returns:
    478     A newly created tf.placeholder.
    479   """
    480   return tf_array_ops.placeholder(
    481       dtype=dtype, shape=shape, name=placeholder_name(scope=scope))
    482 
    483 
    484 _INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$")
    485 
    486 
    487 def get_predefined_collection_names():
    488   """Return all the predefined collection names."""
    489   return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys)
    490           if not _INTERNAL_VARIABLE_RE.match(key)]
    491 
    492 
    493 def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""):
    494   """Find corresponding op/tensor in a different graph.
    495 
    496   Args:
    497     target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph.
    498     dst_graph: The graph in which the corresponding graph element must be found.
    499     dst_scope: A scope which is prepended to the name to look for.
    500     src_scope: A scope which is removed from the original of `target` name.
    501 
    502   Returns:
    503     The corresponding tf.Tensor` or a `tf.Operation`.
    504 
    505   Raises:
    506     ValueError: if `src_name` does not start with `src_scope`.
    507     TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation`
    508     KeyError: If the corresponding graph element cannot be found.
    509   """
    510   src_name = target.name
    511   if src_scope:
    512     src_scope = scope_finalize(src_scope)
    513     if not src_name.startswidth(src_scope):
    514       raise ValueError("{} does not start with {}".format(src_name, src_scope))
    515     src_name = src_name[len(src_scope):]
    516 
    517   dst_name = src_name
    518   if dst_scope:
    519     dst_scope = scope_finalize(dst_scope)
    520     dst_name = dst_scope + dst_name
    521 
    522   if isinstance(target, tf_ops.Tensor):
    523     return dst_graph.get_tensor_by_name(dst_name)
    524   if isinstance(target, tf_ops.Operation):
    525     return dst_graph.get_operation_by_name(dst_name)
    526   raise TypeError("Expected tf.Tensor or tf.Operation, got: {}", type(target))
    527 
    528 
    529 def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""):
    530   """Find corresponding ops/tensors in a different graph.
    531 
    532   `targets` is a Python tree, that is, a nested structure of iterable
    533   (list, tupple, dictionary) whose leaves are instances of
    534   `tf.Tensor` or `tf.Operation`
    535 
    536   Args:
    537     targets: A Python tree containing `tf.Tensor` or `tf.Operation`
    538       belonging to the original graph.
    539     dst_graph: The graph in which the corresponding graph element must be found.
    540     dst_scope: A scope which is prepended to the name to look for.
    541     src_scope: A scope which is removed from the original of `top` name.
    542 
    543   Returns:
    544     A Python tree containin the corresponding tf.Tensor` or a `tf.Operation`.
    545 
    546   Raises:
    547     ValueError: if `src_name` does not start with `src_scope`.
    548     TypeError: if `top` is not a `tf.Tensor` or a `tf.Operation`
    549     KeyError: If the corresponding graph element cannot be found.
    550   """
    551   def func(top):
    552     return find_corresponding_elem(top, dst_graph, dst_scope, src_scope)
    553   return transform_tree(targets, func)
    554