Home | History | Annotate | Download | only in lib
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """TensorFlow Debugger (tfdbg) Utilities."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import re
     22 
     23 from six.moves import xrange  # pylint: disable=redefined-builtin
     24 
     25 
     26 def add_debug_tensor_watch(run_options,
     27                            node_name,
     28                            output_slot=0,
     29                            debug_ops="DebugIdentity",
     30                            debug_urls=None,
     31                            tolerate_debug_op_creation_failures=False,
     32                            global_step=-1):
     33   """Add watch on a `Tensor` to `RunOptions`.
     34 
     35   N.B.:
     36     1. Under certain circumstances, the `Tensor` may not get actually watched
     37       (e.g., if the node of the `Tensor` is constant-folded during runtime).
     38     2. For debugging purposes, the `parallel_iteration` attribute of all
     39       `tf.while_loop`s in the graph are set to 1 to prevent any node from
     40       being executed multiple times concurrently. This change does not affect
     41       subsequent non-debugged runs of the same `tf.while_loop`s.
     42 
     43   Args:
     44     run_options: An instance of `config_pb2.RunOptions` to be modified.
     45     node_name: (`str`) name of the node to watch.
     46     output_slot: (`int`) output slot index of the tensor from the watched node.
     47     debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s). Can be a
     48       `list` of `str` or a single `str`. The latter case is equivalent to a
     49       `list` of `str` with only one element.
     50       For debug op types with customizable attributes, each debug op string can
     51       optionally contain a list of attribute names, in the syntax of:
     52         debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...)
     53     debug_urls: (`str` or `list` of `str`) URL(s) to send debug values to,
     54       e.g., `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`.
     55     tolerate_debug_op_creation_failures: (`bool`) Whether to tolerate debug op
     56       creation failures by not throwing exceptions.
     57     global_step: (`int`) Optional global_step count for this debug tensor
     58       watch.
     59   """
     60 
     61   watch_opts = run_options.debug_options.debug_tensor_watch_opts
     62   run_options.debug_options.global_step = global_step
     63 
     64   watch = watch_opts.add()
     65   watch.tolerate_debug_op_creation_failures = (
     66       tolerate_debug_op_creation_failures)
     67   watch.node_name = node_name
     68   watch.output_slot = output_slot
     69 
     70   if isinstance(debug_ops, str):
     71     debug_ops = [debug_ops]
     72 
     73   watch.debug_ops.extend(debug_ops)
     74 
     75   if debug_urls:
     76     if isinstance(debug_urls, str):
     77       debug_urls = [debug_urls]
     78 
     79     watch.debug_urls.extend(debug_urls)
     80 
     81 
     82 def watch_graph(run_options,
     83                 graph,
     84                 debug_ops="DebugIdentity",
     85                 debug_urls=None,
     86                 node_name_regex_whitelist=None,
     87                 op_type_regex_whitelist=None,
     88                 tensor_dtype_regex_whitelist=None,
     89                 tolerate_debug_op_creation_failures=False,
     90                 global_step=-1):
     91   """Add debug watches to `RunOptions` for a TensorFlow graph.
     92 
     93   To watch all `Tensor`s on the graph, let both `node_name_regex_whitelist`
     94   and `op_type_regex_whitelist` be the default (`None`).
     95 
     96   N.B.:
     97     1. Under certain circumstances, the `Tensor` may not get actually watched
     98       (e.g., if the node of the `Tensor` is constant-folded during runtime).
     99     2. For debugging purposes, the `parallel_iteration` attribute of all
    100       `tf.while_loop`s in the graph are set to 1 to prevent any node from
    101       being executed multiple times concurrently. This change does not affect
    102       subsequent non-debugged runs of the same `tf.while_loop`s.
    103 
    104 
    105   Args:
    106     run_options: An instance of `config_pb2.RunOptions` to be modified.
    107     graph: An instance of `ops.Graph`.
    108     debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s) to use.
    109     debug_urls: URLs to send debug values to. Can be a list of strings,
    110       a single string, or None. The case of a single string is equivalent to
    111       a list consisting of a single string, e.g., `file:///tmp/tfdbg_dump_1`,
    112       `grpc://localhost:12345`.
    113       For debug op types with customizable attributes, each debug op name string
    114       can optionally contain a list of attribute names, in the syntax of:
    115         debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...)
    116     node_name_regex_whitelist: Regular-expression whitelist for node_name,
    117       e.g., `"(weight_[0-9]+|bias_.*)"`
    118     op_type_regex_whitelist: Regular-expression whitelist for the op type of
    119       nodes, e.g., `"(Variable|Add)"`.
    120       If both `node_name_regex_whitelist` and `op_type_regex_whitelist`
    121       are set, the two filtering operations will occur in a logical `AND`
    122       relation. In other words, a node will be included if and only if it
    123       hits both whitelists.
    124     tensor_dtype_regex_whitelist: Regular-expression whitelist for Tensor
    125       data type, e.g., `"^int.*"`.
    126       This whitelist operates in logical `AND` relations to the two whitelists
    127       above.
    128     tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
    129       failures (e.g., due to dtype incompatibility) are to be tolerated by not
    130       throwing exceptions.
    131     global_step: (`int`) Optional global_step count for this debug tensor
    132       watch.
    133   """
    134 
    135   if isinstance(debug_ops, str):
    136     debug_ops = [debug_ops]
    137 
    138   node_name_pattern = (re.compile(node_name_regex_whitelist)
    139                        if node_name_regex_whitelist else None)
    140   op_type_pattern = (re.compile(op_type_regex_whitelist)
    141                      if op_type_regex_whitelist else None)
    142   tensor_dtype_pattern = (re.compile(tensor_dtype_regex_whitelist)
    143                           if tensor_dtype_regex_whitelist else None)
    144 
    145   ops = graph.get_operations()
    146   for op in ops:
    147     # Skip nodes without any output tensors.
    148     if not op.outputs:
    149       continue
    150 
    151     node_name = op.name
    152     op_type = op.type
    153 
    154     if node_name_pattern and not node_name_pattern.match(node_name):
    155       continue
    156     if op_type_pattern and not op_type_pattern.match(op_type):
    157       continue
    158 
    159     for slot in xrange(len(op.outputs)):
    160       if (tensor_dtype_pattern and
    161           not tensor_dtype_pattern.match(op.outputs[slot].dtype.name)):
    162         continue
    163 
    164       add_debug_tensor_watch(
    165           run_options,
    166           node_name,
    167           output_slot=slot,
    168           debug_ops=debug_ops,
    169           debug_urls=debug_urls,
    170           tolerate_debug_op_creation_failures=(
    171               tolerate_debug_op_creation_failures),
    172           global_step=global_step)
    173 
    174 
    175 def watch_graph_with_blacklists(run_options,
    176                                 graph,
    177                                 debug_ops="DebugIdentity",
    178                                 debug_urls=None,
    179                                 node_name_regex_blacklist=None,
    180                                 op_type_regex_blacklist=None,
    181                                 tensor_dtype_regex_blacklist=None,
    182                                 tolerate_debug_op_creation_failures=False,
    183                                 global_step=-1):
    184   """Add debug tensor watches, blacklisting nodes and op types.
    185 
    186   This is similar to `watch_graph()`, but the node names and op types are
    187   blacklisted, instead of whitelisted.
    188 
    189   N.B.:
    190     1. Under certain circumstances, the `Tensor` may not get actually watched
    191       (e.g., if the node of the `Tensor` is constant-folded during runtime).
    192     2. For debugging purposes, the `parallel_iteration` attribute of all
    193       `tf.while_loop`s in the graph are set to 1 to prevent any node from
    194       being executed multiple times concurrently. This change does not affect
    195       subsequent non-debugged runs of the same `tf.while_loop`s.
    196 
    197   Args:
    198     run_options: An instance of `config_pb2.RunOptions` to be modified.
    199     graph: An instance of `ops.Graph`.
    200     debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s) to use.
    201       See the documentation of `watch_graph` for more details.
    202     debug_urls: URL(s) to send debug values to, e.g.,
    203       `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`.
    204     node_name_regex_blacklist: Regular-expression blacklist for node_name.
    205       This should be a string, e.g., `"(weight_[0-9]+|bias_.*)"`.
    206     op_type_regex_blacklist: Regular-expression blacklist for the op type of
    207       nodes, e.g., `"(Variable|Add)"`.
    208       If both node_name_regex_blacklist and op_type_regex_blacklist
    209       are set, the two filtering operations will occur in a logical `OR`
    210       relation. In other words, a node will be excluded if it hits either of
    211       the two blacklists; a node will be included if and only if it hits
    212       neither of the blacklists.
    213     tensor_dtype_regex_blacklist: Regular-expression blacklist for Tensor
    214       data type, e.g., `"^int.*"`.
    215       This blacklist operates in logical `OR` relations to the two whitelists
    216       above.
    217     tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
    218       failures (e.g., due to dtype incompatibility) are to be tolerated by not
    219       throwing exceptions.
    220     global_step: (`int`) Optional global_step count for this debug tensor
    221       watch.
    222   """
    223 
    224   if isinstance(debug_ops, str):
    225     debug_ops = [debug_ops]
    226 
    227   node_name_pattern = (re.compile(node_name_regex_blacklist) if
    228                        node_name_regex_blacklist else None)
    229   op_type_pattern = (re.compile(op_type_regex_blacklist) if
    230                      op_type_regex_blacklist else None)
    231   tensor_dtype_pattern = (re.compile(tensor_dtype_regex_blacklist) if
    232                           tensor_dtype_regex_blacklist else None)
    233 
    234   ops = graph.get_operations()
    235   for op in ops:
    236     # Skip nodes without any output tensors.
    237     if not op.outputs:
    238       continue
    239 
    240     node_name = op.name
    241     op_type = op.type
    242 
    243     if node_name_pattern and node_name_pattern.match(node_name):
    244       continue
    245     if op_type_pattern and op_type_pattern.match(op_type):
    246       continue
    247 
    248     for slot in xrange(len(op.outputs)):
    249       if (tensor_dtype_pattern and
    250           tensor_dtype_pattern.match(op.outputs[slot].dtype.name)):
    251         continue
    252 
    253       add_debug_tensor_watch(
    254           run_options,
    255           node_name,
    256           output_slot=slot,
    257           debug_ops=debug_ops,
    258           debug_urls=debug_urls,
    259           tolerate_debug_op_creation_failures=(
    260               tolerate_debug_op_creation_failures),
    261           global_step=global_step)
    262