1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """TensorFlow Debugger (tfdbg) Utilities.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import re 22 23 from six.moves import xrange # pylint: disable=redefined-builtin 24 25 26 def add_debug_tensor_watch(run_options, 27 node_name, 28 output_slot=0, 29 debug_ops="DebugIdentity", 30 debug_urls=None, 31 tolerate_debug_op_creation_failures=False, 32 global_step=-1): 33 """Add watch on a `Tensor` to `RunOptions`. 34 35 N.B.: 36 1. Under certain circumstances, the `Tensor` may not get actually watched 37 (e.g., if the node of the `Tensor` is constant-folded during runtime). 38 2. For debugging purposes, the `parallel_iteration` attribute of all 39 `tf.while_loop`s in the graph are set to 1 to prevent any node from 40 being executed multiple times concurrently. This change does not affect 41 subsequent non-debugged runs of the same `tf.while_loop`s. 42 43 Args: 44 run_options: An instance of `config_pb2.RunOptions` to be modified. 45 node_name: (`str`) name of the node to watch. 46 output_slot: (`int`) output slot index of the tensor from the watched node. 47 debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s). Can be a 48 `list` of `str` or a single `str`. The latter case is equivalent to a 49 `list` of `str` with only one element. 50 For debug op types with customizable attributes, each debug op string can 51 optionally contain a list of attribute names, in the syntax of: 52 debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...) 53 debug_urls: (`str` or `list` of `str`) URL(s) to send debug values to, 54 e.g., `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`. 55 tolerate_debug_op_creation_failures: (`bool`) Whether to tolerate debug op 56 creation failures by not throwing exceptions. 57 global_step: (`int`) Optional global_step count for this debug tensor 58 watch. 59 """ 60 61 watch_opts = run_options.debug_options.debug_tensor_watch_opts 62 run_options.debug_options.global_step = global_step 63 64 watch = watch_opts.add() 65 watch.tolerate_debug_op_creation_failures = ( 66 tolerate_debug_op_creation_failures) 67 watch.node_name = node_name 68 watch.output_slot = output_slot 69 70 if isinstance(debug_ops, str): 71 debug_ops = [debug_ops] 72 73 watch.debug_ops.extend(debug_ops) 74 75 if debug_urls: 76 if isinstance(debug_urls, str): 77 debug_urls = [debug_urls] 78 79 watch.debug_urls.extend(debug_urls) 80 81 82 def watch_graph(run_options, 83 graph, 84 debug_ops="DebugIdentity", 85 debug_urls=None, 86 node_name_regex_whitelist=None, 87 op_type_regex_whitelist=None, 88 tensor_dtype_regex_whitelist=None, 89 tolerate_debug_op_creation_failures=False, 90 global_step=-1): 91 """Add debug watches to `RunOptions` for a TensorFlow graph. 92 93 To watch all `Tensor`s on the graph, let both `node_name_regex_whitelist` 94 and `op_type_regex_whitelist` be the default (`None`). 95 96 N.B.: 97 1. Under certain circumstances, the `Tensor` may not get actually watched 98 (e.g., if the node of the `Tensor` is constant-folded during runtime). 99 2. For debugging purposes, the `parallel_iteration` attribute of all 100 `tf.while_loop`s in the graph are set to 1 to prevent any node from 101 being executed multiple times concurrently. This change does not affect 102 subsequent non-debugged runs of the same `tf.while_loop`s. 103 104 105 Args: 106 run_options: An instance of `config_pb2.RunOptions` to be modified. 107 graph: An instance of `ops.Graph`. 108 debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s) to use. 109 debug_urls: URLs to send debug values to. Can be a list of strings, 110 a single string, or None. The case of a single string is equivalent to 111 a list consisting of a single string, e.g., `file:///tmp/tfdbg_dump_1`, 112 `grpc://localhost:12345`. 113 For debug op types with customizable attributes, each debug op name string 114 can optionally contain a list of attribute names, in the syntax of: 115 debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...) 116 node_name_regex_whitelist: Regular-expression whitelist for node_name, 117 e.g., `"(weight_[0-9]+|bias_.*)"` 118 op_type_regex_whitelist: Regular-expression whitelist for the op type of 119 nodes, e.g., `"(Variable|Add)"`. 120 If both `node_name_regex_whitelist` and `op_type_regex_whitelist` 121 are set, the two filtering operations will occur in a logical `AND` 122 relation. In other words, a node will be included if and only if it 123 hits both whitelists. 124 tensor_dtype_regex_whitelist: Regular-expression whitelist for Tensor 125 data type, e.g., `"^int.*"`. 126 This whitelist operates in logical `AND` relations to the two whitelists 127 above. 128 tolerate_debug_op_creation_failures: (`bool`) whether debug op creation 129 failures (e.g., due to dtype incompatibility) are to be tolerated by not 130 throwing exceptions. 131 global_step: (`int`) Optional global_step count for this debug tensor 132 watch. 133 """ 134 135 if isinstance(debug_ops, str): 136 debug_ops = [debug_ops] 137 138 node_name_pattern = (re.compile(node_name_regex_whitelist) 139 if node_name_regex_whitelist else None) 140 op_type_pattern = (re.compile(op_type_regex_whitelist) 141 if op_type_regex_whitelist else None) 142 tensor_dtype_pattern = (re.compile(tensor_dtype_regex_whitelist) 143 if tensor_dtype_regex_whitelist else None) 144 145 ops = graph.get_operations() 146 for op in ops: 147 # Skip nodes without any output tensors. 148 if not op.outputs: 149 continue 150 151 node_name = op.name 152 op_type = op.type 153 154 if node_name_pattern and not node_name_pattern.match(node_name): 155 continue 156 if op_type_pattern and not op_type_pattern.match(op_type): 157 continue 158 159 for slot in xrange(len(op.outputs)): 160 if (tensor_dtype_pattern and 161 not tensor_dtype_pattern.match(op.outputs[slot].dtype.name)): 162 continue 163 164 add_debug_tensor_watch( 165 run_options, 166 node_name, 167 output_slot=slot, 168 debug_ops=debug_ops, 169 debug_urls=debug_urls, 170 tolerate_debug_op_creation_failures=( 171 tolerate_debug_op_creation_failures), 172 global_step=global_step) 173 174 175 def watch_graph_with_blacklists(run_options, 176 graph, 177 debug_ops="DebugIdentity", 178 debug_urls=None, 179 node_name_regex_blacklist=None, 180 op_type_regex_blacklist=None, 181 tensor_dtype_regex_blacklist=None, 182 tolerate_debug_op_creation_failures=False, 183 global_step=-1): 184 """Add debug tensor watches, blacklisting nodes and op types. 185 186 This is similar to `watch_graph()`, but the node names and op types are 187 blacklisted, instead of whitelisted. 188 189 N.B.: 190 1. Under certain circumstances, the `Tensor` may not get actually watched 191 (e.g., if the node of the `Tensor` is constant-folded during runtime). 192 2. For debugging purposes, the `parallel_iteration` attribute of all 193 `tf.while_loop`s in the graph are set to 1 to prevent any node from 194 being executed multiple times concurrently. This change does not affect 195 subsequent non-debugged runs of the same `tf.while_loop`s. 196 197 Args: 198 run_options: An instance of `config_pb2.RunOptions` to be modified. 199 graph: An instance of `ops.Graph`. 200 debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s) to use. 201 See the documentation of `watch_graph` for more details. 202 debug_urls: URL(s) to send debug values to, e.g., 203 `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`. 204 node_name_regex_blacklist: Regular-expression blacklist for node_name. 205 This should be a string, e.g., `"(weight_[0-9]+|bias_.*)"`. 206 op_type_regex_blacklist: Regular-expression blacklist for the op type of 207 nodes, e.g., `"(Variable|Add)"`. 208 If both node_name_regex_blacklist and op_type_regex_blacklist 209 are set, the two filtering operations will occur in a logical `OR` 210 relation. In other words, a node will be excluded if it hits either of 211 the two blacklists; a node will be included if and only if it hits 212 neither of the blacklists. 213 tensor_dtype_regex_blacklist: Regular-expression blacklist for Tensor 214 data type, e.g., `"^int.*"`. 215 This blacklist operates in logical `OR` relations to the two whitelists 216 above. 217 tolerate_debug_op_creation_failures: (`bool`) whether debug op creation 218 failures (e.g., due to dtype incompatibility) are to be tolerated by not 219 throwing exceptions. 220 global_step: (`int`) Optional global_step count for this debug tensor 221 watch. 222 """ 223 224 if isinstance(debug_ops, str): 225 debug_ops = [debug_ops] 226 227 node_name_pattern = (re.compile(node_name_regex_blacklist) if 228 node_name_regex_blacklist else None) 229 op_type_pattern = (re.compile(op_type_regex_blacklist) if 230 op_type_regex_blacklist else None) 231 tensor_dtype_pattern = (re.compile(tensor_dtype_regex_blacklist) if 232 tensor_dtype_regex_blacklist else None) 233 234 ops = graph.get_operations() 235 for op in ops: 236 # Skip nodes without any output tensors. 237 if not op.outputs: 238 continue 239 240 node_name = op.name 241 op_type = op.type 242 243 if node_name_pattern and node_name_pattern.match(node_name): 244 continue 245 if op_type_pattern and op_type_pattern.match(op_type): 246 continue 247 248 for slot in xrange(len(op.outputs)): 249 if (tensor_dtype_pattern and 250 tensor_dtype_pattern.match(op.outputs[slot].dtype.name)): 251 continue 252 253 add_debug_tensor_watch( 254 run_options, 255 node_name, 256 output_slot=slot, 257 debug_ops=debug_ops, 258 debug_urls=debug_urls, 259 tolerate_debug_op_creation_failures=( 260 tolerate_debug_op_creation_failures), 261 global_step=global_step) 262