Home | History | Annotate | Download | only in graph_editor
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Various function for graph editing."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.graph_editor import reroute
     22 from tensorflow.contrib.graph_editor import select
     23 from tensorflow.contrib.graph_editor import subgraph
     24 from tensorflow.contrib.graph_editor import util
     25 from tensorflow.python.ops import array_ops as tf_array_ops
     26 
     27 __all__ = [
     28     "detach_control_inputs",
     29     "detach_control_outputs",
     30     "detach_inputs",
     31     "detach_outputs",
     32     "detach",
     33     "connect",
     34     "bypass",
     35 ]
     36 
     37 
     38 def detach_control_inputs(sgv):
     39   """Detach all the external control inputs of the subgraph sgv.
     40 
     41   Args:
     42     sgv: the subgraph view to be detached. This argument is converted to a
     43       subgraph using the same rules as the function subgraph.make_view.
     44   """
     45   sgv = subgraph.make_view(sgv)
     46   for op in sgv.ops:
     47     cops = [cop for cop in op.control_inputs if cop not in sgv.ops]
     48     reroute.remove_control_inputs(op, cops)
     49 
     50 
     51 def detach_control_outputs(sgv, control_outputs):
     52   """Detach all the external control outputs of the subgraph sgv.
     53 
     54   Args:
     55     sgv: the subgraph view to be detached. This argument is converted to a
     56       subgraph using the same rules as the function subgraph.make_view.
     57     control_outputs: a util.ControlOutputs instance.
     58   """
     59   if not isinstance(control_outputs, util.ControlOutputs):
     60     raise TypeError("Expected a util.ControlOutputs, got: {}",
     61                     type(control_outputs))
     62   control_outputs.update()
     63   sgv = subgraph.make_view(sgv)
     64   for op in sgv.ops:
     65     for cop in control_outputs.get(op):
     66       if cop not in sgv.ops:
     67         reroute.remove_control_inputs(cop, op)
     68 
     69 
     70 def detach_inputs(sgv, control_inputs=False):
     71   """Detach the inputs of a subgraph view.
     72 
     73   Args:
     74     sgv: the subgraph view to be detached. This argument is converted to a
     75       subgraph using the same rules as the function subgraph.make_view.
     76       Note that sgv is modified in place.
     77     control_inputs: if True control_inputs are also detached.
     78   Returns:
     79     A tuple `(sgv, input_placeholders)` where
     80       `sgv` is a new subgraph view of the detached subgraph;
     81       `input_placeholders` is a list of the created input placeholders.
     82   Raises:
     83     StandardError: if sgv cannot be converted to a SubGraphView using
     84       the same rules than the function subgraph.make_view.
     85   """
     86   sgv = subgraph.make_view(sgv)
     87 
     88   with sgv.graph.as_default():
     89     input_placeholders = [
     90         tf_array_ops.placeholder(
     91             dtype=input_t.dtype, name=util.placeholder_name(input_t))
     92         for input_t in sgv.inputs
     93     ]
     94 
     95   reroute.swap_inputs(sgv, input_placeholders)
     96   if control_inputs:
     97     detach_control_inputs(sgv)
     98   return sgv, input_placeholders
     99 
    100 
    101 def detach_outputs(sgv, control_outputs=None):
    102   """Detach the output of a subgraph view.
    103 
    104   Args:
    105     sgv: the subgraph view to be detached. This argument is converted to a
    106       subgraph using the same rules as the function subgraph.make_view.
    107       Note that sgv is modified in place.
    108     control_outputs: a util.ControlOutputs instance or None. If not None the
    109       control outputs are also detached.
    110   Returns:
    111     A tuple `(sgv, output_placeholders)` where
    112       `sgv` is a new subgraph view of the detached subgraph;
    113       `output_placeholders` is a list of the created output placeholders.
    114   Raises:
    115     StandardError: if sgv cannot be converted to a SubGraphView using
    116       the same rules than the function subgraph.make_view.
    117   """
    118   sgv = subgraph.make_view(sgv)
    119   # only select outputs with consumers
    120   sgv_ = sgv.remap_outputs([output_id
    121                             for output_id, output_t in enumerate(sgv.outputs)
    122                             if output_t.consumers()])
    123   # create consumer subgraph and remap
    124   consumers_sgv = subgraph.SubGraphView(sgv_.consumers())
    125   consumers_sgv = consumers_sgv.remap_inputs(
    126       [input_id for input_id, input_t in enumerate(consumers_sgv.inputs)
    127        if input_t in sgv_.outputs])
    128 
    129   with sgv_.graph.as_default():
    130     output_placeholders = [
    131         util.make_placeholder_from_tensor(input_t)
    132         for input_t in consumers_sgv.inputs
    133     ]
    134 
    135   reroute.swap_outputs(sgv_, output_placeholders)
    136   if control_outputs is not None:
    137     detach_control_outputs(sgv_, control_outputs)
    138   return sgv_, output_placeholders
    139 
    140 
    141 def detach(sgv, control_inputs=False, control_outputs=None, control_ios=None):
    142   """Detach both the inputs and the outputs of a subgraph view.
    143 
    144   Args:
    145     sgv: the subgraph view to be detached. This argument is converted to a
    146       subgraph using the same rules as the function subgraph.make_view.
    147       Note that sgv is modified in place.
    148     control_inputs: A boolean indicating whether control inputs are enabled.
    149     control_outputs: An instance of util.ControlOutputs or None. If not None,
    150       control outputs are enabled.
    151     control_ios:  An instance of util.ControlOutputs or None. If not None, both
    152       control inputs and control outputs are enabled. This is equivalent to set
    153       control_inputs to True and control_outputs to the util.ControlOutputs
    154       instance.
    155   Returns:
    156     A tuple `(sgv, detached_inputs, detached_outputs)` where:
    157     `sgv` is a new subgraph view of the detached subgraph;
    158     `detach_inputs` is a list of the created input placeholders;
    159     `detach_outputs` is a list of the created output placeholders.
    160   Raises:
    161     StandardError: if sgv cannot be converted to a SubGraphView using
    162       the same rules than the function subgraph.make_view.
    163   """
    164   control_inputs, control_outputs = select.check_cios(control_inputs,
    165                                                       control_outputs,
    166                                                       control_ios)
    167   _, detached_inputs = detach_inputs(sgv, control_inputs)
    168   _, detached_outputs = detach_outputs(sgv, control_outputs)
    169   return sgv, detached_inputs, detached_outputs
    170 
    171 
    172 def connect(sgv0, sgv1, disconnect_first=False):
    173   """Connect the outputs of sgv0 to the inputs of sgv1.
    174 
    175   Args:
    176     sgv0: the first subgraph to have its outputs swapped. This argument is
    177       converted to a subgraph using the same rules as the function
    178       subgraph.make_view.
    179       Note that sgv0 is modified in place.
    180     sgv1: the second subgraph to have its outputs swapped. This argument is
    181       converted to a subgraph using the same rules as the function
    182       subgraph.make_view.
    183       Note that sgv1 is modified in place.
    184     disconnect_first: if True the current outputs of sgv0 are disconnected.
    185   Returns:
    186     A tuple `(sgv0, sgv1)` of the now connected subgraphs.
    187   Raises:
    188     StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
    189       the same rules than the function subgraph.make_view.
    190   """
    191   sgv0 = subgraph.make_view(sgv0)
    192   sgv1 = subgraph.make_view(sgv1)
    193   util.check_graphs(sgv0, sgv1)
    194   if disconnect_first:
    195     detach_outputs(sgv0)
    196   sgv0_outputs = subgraph.SubGraphView(passthrough_ts=sgv0.outputs)
    197   reroute.reroute_inputs(sgv0_outputs, sgv1)
    198   return sgv0, sgv1
    199 
    200 
    201 def bypass(sgv):
    202   """Bypass the given subgraph by connecting its inputs to its outputs.
    203 
    204   Args:
    205     sgv: the subgraph view to be bypassed. This argument is converted to a
    206       subgraph using the same rules than the function subgraph.make_view.
    207       Note that sgv is modified in place.
    208   Returns:
    209     A tuple `(sgv, detached_inputs)` where:
    210       `sgv` is a new subgraph view of the bypassed subgraph;
    211       `detached_inputs` is a list of the created input placeholders.
    212   Raises:
    213     StandardError: if sgv cannot be converted to a SubGraphView using
    214       the same rules than the function subgraph.make_view.
    215   """
    216   # TODO(fkp): allows to plug sgv.inputs to individual sgv.outputs consumers
    217   sgv = subgraph.make_view(sgv)
    218   sgv_inputs = list(sgv.inputs)
    219   sgv, detached_inputs = detach_inputs(sgv)
    220   reroute.reroute_ts(sgv_inputs, sgv.outputs)
    221   return sgv, detached_inputs
    222