# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and methods for processing debugger-decorated graphs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import op_def_registry
from tensorflow.python.platform import tf_logging as logging


def parse_node_or_tensor_name(name):
  """Get the node name from a string that can be a node or tensor name.

  Args:
    name: An input node name (e.g., "node_a") or tensor name (e.g.,
      "node_a:0"), as a str.

  Returns:
    1) The node name, as a str. If the input name is a tensor name, i.e.,
      contains a colon, the final colon and the following output slot
      will be stripped.
    2) If the input name is a tensor name, the output slot, as an int. If
      the input name is not a tensor name, None.
  """

  if ":" in name and not name.endswith(":"):
    node_name = name[:name.rfind(":")]
    output_slot = int(name[name.rfind(":") + 1:])

    return node_name, output_slot
  else:
    return name, None

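# Illustrative usage of parse_node_or_tensor_name (editor's sketch; the tensor
# names below are hypothetical, not from the original module):
#
#   parse_node_or_tensor_name("hidden/MatMul:1")  # -> ("hidden/MatMul", 1)
#   parse_node_or_tensor_name("hidden/MatMul")    # -> ("hidden/MatMul", None)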

def get_node_name(element_name):
  """Get the node name from the name of a graph element (node or tensor)."""
  node_name, _ = parse_node_or_tensor_name(element_name)
  return node_name


def get_output_slot(element_name):
  """Get the output slot number from the name of a graph element.

  If element_name is a node name without an output slot at the end, 0 is
  assumed.

  Args:
    element_name: (`str`) name of the graph element in question.

  Returns:
    (`int`) output slot number.
  """
  _, output_slot = parse_node_or_tensor_name(element_name)
  return output_slot if output_slot is not None else 0

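# Illustrative usage of the two helpers above (editor's sketch; the names are
# hypothetical):
#
#   get_node_name("hidden/BiasAdd:1")    # -> "hidden/BiasAdd"
#   get_output_slot("hidden/BiasAdd:1")  # -> 1
#   get_output_slot("hidden/BiasAdd")    # -> 0  (no slot in the name)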

def is_copy_node(node_name):
  """Determine whether a node name is that of a debug Copy node.

  Such nodes are inserted by TensorFlow core upon request in
  RunOptions.debug_options.debug_tensor_watch_opts.

  Args:
    node_name: Name of the node.

  Returns:
    A bool indicating whether the input argument is the name of a debug Copy
    node.
  """
  return node_name.startswith("__copy_")


def is_debug_node(node_name):
  """Determine whether a node name is that of a debug node.

  Such nodes are inserted by TensorFlow core upon request in
  RunOptions.debug_options.debug_tensor_watch_opts.

  Args:
    node_name: Name of the node.

  Returns:
    A bool indicating whether the input argument is the name of a debug node.
  """
  return node_name.startswith("__dbg_")

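# Both predicates above are plain prefix checks. Illustrative examples
# (editor's sketch; the node names are hypothetical but follow the "__copy_"
# and "__dbg_" naming conventions used by the debugger):
#
#   is_copy_node("__copy_hidden/MatMul_0")                   # -> True
#   is_debug_node("__dbg_hidden/MatMul:0_0_DebugIdentity")   # -> True
#   is_debug_node("hidden/MatMul")                           # -> False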

def parse_debug_node_name(node_name):
  """Parse the name of a debug node.

  Args:
    node_name: Name of the debug node.

  Returns:
    1. Name of the watched node, as a str.
    2. Output slot index of the watched tensor, as an int.
    3. Index of the debug node, as an int.
    4. Name of the debug op, as a str, e.g., "DebugIdentity".

  Raises:
    ValueError: If the input node name is not a valid debug node name.
  """
  prefix = "__dbg_"

  name = node_name
  if not name.startswith(prefix):
    raise ValueError("Invalid prefix in debug node name: '%s'" % node_name)

  name = name[len(prefix):]

  if name.count("_") < 2:
    raise ValueError("Invalid debug node name: '%s'" % node_name)

  debug_op = name[name.rindex("_") + 1:]
  name = name[:name.rindex("_")]

  debug_op_index = int(name[name.rindex("_") + 1:])
  name = name[:name.rindex("_")]

  if name.count(":") != 1:
    raise ValueError("Invalid tensor name in debug node name: '%s'" % node_name)

  watched_node_name = name[:name.index(":")]
  watched_output_slot = int(name[name.index(":") + 1:])

  return watched_node_name, watched_output_slot, debug_op_index, debug_op

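# Illustrative parse of a debug node name (editor's sketch). Debug node names
# follow the pattern "__dbg_<watched_tensor>_<debug_op_index>_<debug_op>"; the
# concrete name below is made up:
#
#   parse_debug_node_name("__dbg_hidden/MatMul:0_0_DebugIdentity")
#   # -> ("hidden/MatMul", 0, 0, "DebugIdentity")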

class GraphTracingReachedDestination(Exception):
  pass


class DFSGraphTracer(object):
  """Graph input tracer using depth-first search."""

  def __init__(self,
               input_lists,
               skip_node_names=None,
               destination_node_name=None):
    """Constructor of DFSGraphTracer.

    Args:
      input_lists: A list of dicts. Each dict is an adjacency (input) map, with
        the recipient node name as the key and the list of its input node names
        as the value.
      skip_node_names: Optional: a list of node names to skip tracing.
      destination_node_name: Optional: destination node name. If not `None`, it
        should be the name of a destination node as a str and the graph tracing
        will raise GraphTracingReachedDestination as soon as the node has been
        reached.

    Raises:
      GraphTracingReachedDestination: if destination_node_name is not None and
        the specified node is reached (raised from `trace`).
    """

    self._input_lists = input_lists
    # Treat a `None` skip list as empty so the membership checks in `trace()`
    # do not fail.
    self._skip_node_names = skip_node_names or []

    self._inputs = []
    self._visited_nodes = []
    self._depth_count = 0
    self._depth_list = []

    self._destination_node_name = destination_node_name

  def trace(self, graph_element_name):
    """Trace inputs.

    Args:
      graph_element_name: Name of the node or an output tensor of the node, as
        a str.

    Raises:
      GraphTracingReachedDestination: if destination_node_name of this tracer
        object is not None and the specified node is reached.
    """
    self._depth_count += 1

    node_name = get_node_name(graph_element_name)
    if node_name == self._destination_node_name:
      raise GraphTracingReachedDestination()

    if node_name in self._skip_node_names:
      return
    if node_name in self._visited_nodes:
      return

    self._visited_nodes.append(node_name)

    for input_list in self._input_lists:
      if node_name not in input_list:
        continue
      for inp in input_list[node_name]:
        if get_node_name(inp) in self._visited_nodes:
          continue
        self._inputs.append(inp)
        self._depth_list.append(self._depth_count)
        self.trace(inp)

    self._depth_count -= 1

  def inputs(self):
    """Get the list of input names traced so far."""
    return self._inputs

  def depth_list(self):
    """Get the depth at which each input in `inputs()` was found."""
    return self._depth_list

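# Minimal illustration of DFSGraphTracer (editor's sketch; the adjacency map
# below is made up). Tracing "c" visits its transitive inputs depth-first and
# records the depth at which each input was found:
#
#   tracer = DFSGraphTracer([{"c": ["b"], "b": ["a"]}], skip_node_names=[])
#   tracer.trace("c")
#   tracer.inputs()      # -> ["b", "a"]
#   tracer.depth_list()  # -> [1, 2]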

def _infer_device_name(graph_def):
  """Infer device name from a partition GraphDef."""
  device_name = None
  for node in graph_def.node:
    if node.device:
      device_name = node.device
      break
  if device_name is None:
    logging.warn(
        "Failed to infer device name from partition GraphDef: none of the "
        "nodes of the GraphDef has a non-empty device name.")
  return device_name


class DebugGraph(object):
  """Represents a debugger-decorated graph."""

  def __init__(self, debug_graph_def, device_name=None):
    self._debug_graph_def = debug_graph_def
    self._non_debug_graph_def = None

    self._node_attributes = {}
    self._node_inputs = {}
    self._node_reversed_ref_inputs = {}
    self._node_ctrl_inputs = {}
    self._node_recipients = {}
    self._node_ctrl_recipients = {}
    self._node_devices = {}
    self._node_op_types = {}
    self._copy_send_nodes = []
    self._ref_args = {}

    self._device_name = device_name
    if not self._device_name:
      self._device_name = _infer_device_name(debug_graph_def)

    for node in debug_graph_def.node:
      self._process_debug_graph_node(node)

    self._prune_non_control_edges_of_debug_ops()
    self._prune_control_edges_of_debug_ops()
    self._prune_nodes_from_input_and_recipient_maps(self._get_copy_nodes())

    self._populate_recipient_maps()

  def _process_debug_graph_node(self, node):
    """Process a node from the debug GraphDef.

    Args:
      node: (NodeDef) A partition-graph node to be processed.

    Raises:
      ValueError: If duplicate node names are encountered.
    """
    if is_debug_node(node.name):
      # This is a debug node inserted by the debugger. Do not include it in
      # the reconstructed (non-debug) graph maps.
      return

    if node.name in self._node_inputs:
      raise ValueError("Duplicate node name on device %s: '%s'" %
                       (self._device_name, node.name))

    self._node_attributes[node.name] = node.attr

    self._node_inputs[node.name] = []
    self._node_ctrl_inputs[node.name] = []
    self._node_recipients[node.name] = []
    self._node_ctrl_recipients[node.name] = []

    if node.name not in self._node_devices:
      self._node_devices[node.name] = set()
    self._node_devices[node.name].add(
        node.device if node.device else self._device_name)
    self._node_op_types[node.name] = node.op
    self._ref_args[node.name] = self._get_ref_args(node)

    for inp in node.input:
      if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"):
        self._copy_send_nodes.append(node.name)

      if inp.startswith("^"):
        cinp = inp[1:]
        self._node_ctrl_inputs[node.name].append(cinp)
      else:
        self._node_inputs[node.name].append(inp)

  def _get_ref_args(self, node):
    """Determine the ref-type output arguments of a node.

    Args:
      node: A `NodeDef`.

    Returns:
      A list of the output tensor names (as strs) that are ref-type.
    """
    op_def = op_def_registry.get_registered_ops().get(node.op)
    ref_args = []
    if op_def:
      for i, output_arg in enumerate(op_def.output_arg):
        if output_arg.is_ref:
          arg_name = node.name if i == 0 else ("%s:%d" % (node.name, i))
          ref_args.append(arg_name)
    return ref_args
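  # Editor's note (illustrative, not from the original module): for a node
  # whose op type declares ref-type outputs (e.g., "VariableV2", whose single
  # output is a ref), _get_ref_args returns ["<node_name>"]; a second ref
  # output would be reported as "<node_name>:1". Ops without ref outputs yield
  # an empty list.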

  def _get_copy_nodes(self):
    """Find all Copy nodes in the loaded graph."""
    copy_nodes = []
    for node in self._node_inputs:
      if is_copy_node(node):
        copy_nodes.append(node)
    return copy_nodes

  def _prune_non_control_edges_of_debug_ops(self):
    """Prune (non-control) edges related to debug ops.

    Prune the Copy ops and associated _Send ops inserted by the debugger from
    the non-control input and output-recipient maps, replacing them with the
    original inputs and recipients.
    """
    for node in self._node_inputs:
      inputs = self._node_inputs[node]

      for i in xrange(len(inputs)):
        inp = inputs[i]
        if is_copy_node(inp):
          # Find the input to the Copy node, which should be the original
          # input to the node.
          orig_inp = self._node_inputs[inp][0]
          inputs[i] = orig_inp

  def _prune_control_edges_of_debug_ops(self):
    """Prune control edges related to the debug ops."""
    for node in self._node_ctrl_inputs:
      ctrl_inputs = self._node_ctrl_inputs[node]
      debug_op_inputs = []
      for ctrl_inp in ctrl_inputs:
        if is_debug_node(ctrl_inp):
          debug_op_inputs.append(ctrl_inp)
      for debug_op_inp in debug_op_inputs:
        ctrl_inputs.remove(debug_op_inp)

  def _populate_recipient_maps(self):
    """Populate the map from node name to recipient(s) of its output(s).

    This method also populates the input map based on reversed ref edges.
    """
    for node in self._node_inputs:
      inputs = self._node_inputs[node]
      for inp in inputs:
        inp = get_node_name(inp)
        if inp not in self._node_recipients:
          self._node_recipients[inp] = []
        self._node_recipients[inp].append(node)

        if inp in self._ref_args:
          if inp not in self._node_reversed_ref_inputs:
            self._node_reversed_ref_inputs[inp] = []
          self._node_reversed_ref_inputs[inp].append(node)

    for node in self._node_ctrl_inputs:
      ctrl_inputs = self._node_ctrl_inputs[node]
      for ctrl_inp in ctrl_inputs:
        if ctrl_inp in self._copy_send_nodes:
          continue

        if ctrl_inp not in self._node_ctrl_recipients:
          self._node_ctrl_recipients[ctrl_inp] = []
        self._node_ctrl_recipients[ctrl_inp].append(node)

  def _prune_nodes_from_input_and_recipient_maps(self, nodes_to_prune):
    """Prune nodes out of input and recipient maps.

    Args:
      nodes_to_prune: (`list` of `str`) Names of the nodes to be pruned.
    """
    for node in nodes_to_prune:
      del self._node_inputs[node]
      del self._node_ctrl_inputs[node]
      del self._node_recipients[node]
      del self._node_ctrl_recipients[node]

  def _reconstruct_non_debug_graph_def(self):
    """Reconstruct non-debug GraphDef.

    Non-debug GraphDef means the original GraphDef without the Copy* and
    Debug* nodes inserted by the debugger.
    """
    if self._non_debug_graph_def:
      return

    self._non_debug_graph_def = graph_pb2.GraphDef()
    for node in self._debug_graph_def.node:
      if is_copy_node(node.name) or is_debug_node(node.name):
        continue

      new_node = self._non_debug_graph_def.node.add()
      new_node.CopyFrom(node)

      # Redo the list of inputs, because in _debug_graph_def, the list can
      # contain Copy* and Debug* nodes inserted by the debugger. Those will
      # be replaced with the original inputs here.
      del new_node.input[:]
      for inp in self._node_inputs[node.name]:
        new_node.input.append(inp)
      for ctrl_inp in self._node_ctrl_inputs[node.name]:
        new_node.input.append("^" + ctrl_inp)

  @property
  def device_name(self):
    return self._device_name

  @property
  def debug_graph_def(self):
    """The debugger-decorated GraphDef."""
    return self._debug_graph_def

  @property
  def non_debug_graph_def(self):
    """The GraphDef without the Copy* and Debug* nodes added by the debugger."""
    self._reconstruct_non_debug_graph_def()
    return self._non_debug_graph_def

  @property
  def node_devices(self):
    return self._node_devices

  @property
  def node_op_types(self):
    return self._node_op_types

  @property
  def node_attributes(self):
    return self._node_attributes

  @property
  def node_inputs(self):
    return self._node_inputs

  @property
  def node_ctrl_inputs(self):
    return self._node_ctrl_inputs

  @property
  def node_reversed_ref_inputs(self):
    return self._node_reversed_ref_inputs

  @property
  def node_recipients(self):
    return self._node_recipients

  @property
  def node_ctrl_recipients(self):
    return self._node_ctrl_recipients


def reconstruct_non_debug_graph_def(debug_graph_def):
  """Reconstruct original (non-debugger-decorated) partition GraphDef.

  This method strips the input `tf.GraphDef` of the Copy* and Debug*-type nodes
  inserted by the debugger.

  The reconstructed partition graph is identical to the original (i.e.,
  non-debugger-decorated) partition graph except in the following respects:
    1) The exact names of the runtime-inserted internal nodes may differ.
       These include _Send, _Recv, _HostSend, _HostRecv, _Retval ops.
    2) As a consequence of 1, the nodes that receive input directly from such
       send- and recv-type ops will have different input names.
    3) The parallel_iteration attribute of while-loop Enter ops is set to 1.

  Args:
    debug_graph_def: The debugger-decorated `tf.GraphDef`, with the
      debugger-inserted Copy* and Debug* nodes.

  Returns:
    The reconstructed `tf.GraphDef` stripped of the debugger-inserted nodes.
  """
  return DebugGraph(debug_graph_def).non_debug_graph_def
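
# ------------------------------------------------------------------------------
# Illustrative, non-normative usage sketch (editor's addition, not part of the
# original module). It assumes a TF 1.x-era installation in which this module
# lives at tensorflow.python.debug.lib.debug_graphs; the node names, op types
# and debug node name below are made up:
#
#   from tensorflow.core.framework import graph_pb2
#   from tensorflow.python.debug.lib import debug_graphs
#
#   graph_def = graph_pb2.GraphDef()
#   graph_def.node.add(name="a", op="Const")
#   graph_def.node.add(name="b", op="Identity", input=["a"])
#   # A debugger-inserted watch node, named per the "__dbg_" convention.
#   graph_def.node.add(name="__dbg_a:0_0_DebugIdentity", op="DebugIdentity",
#                      input=["a:0"])
#
#   non_debug = debug_graphs.reconstruct_non_debug_graph_def(graph_def)
#   # non_debug contains only nodes "a" and "b"; the "__dbg_*" node is gone.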