# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Framework of debug wrapper sessions.

A debug wrapper session is a wrapper around a TensorFlow Python Session.
The wrapper preserves the Session interface, most importantly the run() method,
while providing abilities to:
  a) Intercept a run() call to a wrapped session and insert debug tensor watches
     according to externally-specified debug URLs.

  b) Release control to an external (i.e., non-Session) object before and after
     the run() call, so that the external object can perform actions such as
     launching a UI to let users inspect the intermediate tensors and partition
     graphs from the run() call.

  c) (To be implemented) Intercept a run() call and give control to DebugStepper
     to let it perform stepping / continuing-to actions on the graph.

  d) (To be implemented in a future CL) Enter an instruction loop to let an
     external object (e.g., remote client) launch run() and cont() calls
     remotely.

*** The lifetime of a debug wrapper session: ***

1) The wrapper session is created by calling the constructor with a
   wrapped (normal) session as the argument:
     wrapper = FooDebugWrapperSession(sess)
   wherein FooDebugWrapperSession is a concrete subclass implementing the
   abstract BaseDebugWrapperSession class below.

2) Near the end of the constructor call, the on_session_init() callback is
   invoked, with an OnSessionInitRequest object as the argument. The object
   carries the wrapped (normal) session object.

3) The callback handles the request and returns an OnSessionInitResponse
   object with an action field, directing the wrapper session what to do next.

If the action field in the OnSessionInitResponse is PROCEED, the constructor
returns. Control is released back to the caller of the constructor, which can
invoke the run() method of the wrapper session with the same syntax as a
non-wrapped session, e.g.:
  wrapper.run(fetches, feed_dict=feeds, options=run_options)

Below, A1 - A2 is the lifetime of a wrapper run() call if the action is
PROCEED:

  A1) Right at the start of each run() call, the on_run_start() callback is
      invoked, with an OnRunStartRequest object carrying information such as
      the fetches, the feed dict, the run options and run metadata used in
      this run call, along with a count of how many run calls have occurred
      on this wrapper session. The callback then returns an OnRunStartResponse
      object, whose action field directs what the wrapper session will
      actually do for the run() call.

      If the action is DEBUG_RUN, a debugged (tensor-watched) run will ensue,
      with the debug URLs supplied in the debug_urls field of the response.
      These can be file:// or grpc:// URLs, for example.

      If the action is NON_DEBUG_RUN, a non-debug (normal) run will ensue.

      If the action is INVOKE_STEPPER, no run() call will be issued to the
      wrapped session. Instead, a DebugStepper (i.e., "continuation
      debugger") will be used to perform stepping / continue-to actions on
      the graph.

      TODO(cais): The event loop for the DebugStepper will request additional
         callbacks including on_cont_start() and on_cont_end(). Add those.

  A2) Right before the run() returns, the on_run_end() callback is invoked,
      with an OnRunEndRequest object as the argument, which carries information
      including the actual action performed in the wrapper run() call and the
      run_metadata from the run() call.

However, if the action field in OnSessionInitResponse is
REMOTE_INSTR_LOOP, the constructor will automatically invoke an instruction loop
that gives control to a remote caller. (This is not implemented yet: the
constructor currently raises NotImplementedError for this action.)

In the remote instruction loop, the following steps will happen:

  B1) Callback on_instr_start() is invoked. The callback will return an
      OnInstrStartResponse object with an action field which can order one of
      the following actions:
        i) a run() call with fetches, feeds and debug_urls specified.
       ii) a DebugStepper cont() call with target specified.
      iii) value overrides in the cached tensors from the DebugStepper.
       iv) exit the instruction loop.

  B2) The wrapper session carries out the action specified above.

  B3) If still in the instruction loop, the wrapper session invokes the
      on_instr_end() callback. After the on_instr_end() callback returns, jump
      back to B1.

TODO(cais): Implement the instruction loop in B1 - B3.
107 108 """ 109 110 from __future__ import absolute_import 111 from __future__ import division 112 from __future__ import print_function 113 114 import abc 115 import re 116 import threading 117 118 from tensorflow.core.protobuf import config_pb2 119 from tensorflow.python.client import session 120 from tensorflow.python.debug.lib import debug_utils 121 from tensorflow.python.debug.lib import stepper 122 from tensorflow.python.framework import errors 123 from tensorflow.python.framework import ops 124 from tensorflow.python.platform import tf_logging 125 from tensorflow.python.training import monitored_session 126 from tensorflow.python.util import nest 127 128 129 # Helper function. 130 def _check_type(obj, expected_types): 131 """Check if an object is of the expected type. 132 133 Args: 134 obj: The object being checked. 135 expected_types: (`type` or an iterable of `type`s) The expected `type`(s) 136 of obj. 137 138 Raises: 139 TypeError: If obj is not an instance of expected_type. 140 """ 141 if not isinstance(obj, expected_types): 142 raise TypeError("Expected type %s; got type %s" % 143 (expected_types, type(obj))) 144 145 146 class OnSessionInitRequest(object): 147 """Request to an on-session-init callback. 148 149 This callback is invoked during the __init__ call to a debug-wrapper session. 150 """ 151 152 def __init__(self, sess): 153 """Constructor. 154 155 Args: 156 sess: A tensorflow Session object. 157 """ 158 159 _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession)) 160 self.session = sess 161 162 163 class OnSessionInitAction(object): 164 """Enum-like values for possible action to take on session init.""" 165 166 # Proceed, without special actions, in the wrapper session initialization. 167 # What action the wrapper session performs next is determined by the caller 168 # of the wrapper session. E.g., it can call run(). 
169 PROCEED = "proceed" 170 171 # Instead of letting the caller of the wrapper session determine what actions 172 # the wrapper session will perform next, enter a loop to receive instructions 173 # from a remote client. 174 # For example, TensorBoard visual debugger can use this action so that it can 175 # launch session.run() calls remotely. 176 REMOTE_INSTR_LOOP = "remote_instr_loop" 177 178 179 class OnSessionInitResponse(object): 180 """Response from an on-session-init callback.""" 181 182 def __init__(self, action): 183 """Constructor. 184 185 Args: 186 action: (`OnSessionInitAction`) Debugger action to take on session init. 187 """ 188 _check_type(action, str) 189 self.action = action 190 191 192 class OnRunStartRequest(object): 193 """Request to an on-run-start callback. 194 195 This callback is invoked during a run() call of the debug-wrapper 196 session, immediately after the run() call counter is incremented. 197 """ 198 199 def __init__(self, fetches, feed_dict, run_options, run_metadata, 200 run_call_count, is_callable_runner=False): 201 """Constructor of `OnRunStartRequest`. 202 203 Args: 204 fetches: Fetch targets of the run() call. 205 feed_dict: The feed dictionary to the run() call. 206 run_options: RunOptions input to the run() call. 207 run_metadata: RunMetadata input to the run() call. 208 The above four arguments are identical to the input arguments to the 209 run() method of a non-wrapped TensorFlow session. 210 run_call_count: 1-based count of how many run calls (including this one) 211 has been invoked. 212 is_callable_runner: (bool) whether a runner returned by 213 Session.make_callable is being run. 
214 """ 215 self.fetches = fetches 216 self.feed_dict = feed_dict 217 self.run_options = run_options 218 self.run_metadata = run_metadata 219 self.run_call_count = run_call_count 220 self.is_callable_runner = is_callable_runner 221 222 223 class OnRunStartAction(object): 224 """Enum-like values for possible action to take on start of a run() call.""" 225 226 # Run once with debug tensor-watching. 227 DEBUG_RUN = "debug_run" 228 229 # Run once with profiler. 230 PROFILE_RUN = "profile_run" 231 232 # Run without debug tensor-watching. 233 NON_DEBUG_RUN = "non_debug_run" 234 235 # Instead of running the fetches as a whole, as would normally happen, invoke 236 # the (to-be-implemented) debug stepper. 237 # TODO(cais): Remove "to-be-implemented". 238 INVOKE_STEPPER = "invoke_stepper" 239 240 241 class OnRunStartResponse(object): 242 """Request from an on-run-start callback. 243 244 The caller of the callback can use this response object to specify what 245 action the debug-wrapper session actually takes on the run() call. 246 """ 247 248 def __init__(self, 249 action, 250 debug_urls, 251 debug_ops="DebugIdentity", 252 node_name_regex_whitelist=None, 253 op_type_regex_whitelist=None, 254 tensor_dtype_regex_whitelist=None, 255 tolerate_debug_op_creation_failures=False): 256 """Constructor of `OnRunStartResponse`. 257 258 Args: 259 action: (`OnRunStartAction`) the action actually taken by the wrapped 260 session for the run() call. 261 debug_urls: (`list` of `str`) debug_urls used in watching the tensors 262 during the run() call. 263 debug_ops: (`str` or `list` of `str`) Debug op(s) to be used by the 264 debugger. 265 node_name_regex_whitelist: Regular-expression whitelist for node 266 name. 267 op_type_regex_whitelist: Regular-expression whitelist for op type. 268 tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor 269 dtype. 270 tolerate_debug_op_creation_failures: Whether debug op creation failures 271 are to be tolerated. 
272 """ 273 274 _check_type(action, str) 275 self.action = action 276 277 _check_type(debug_urls, list) 278 self.debug_urls = debug_urls 279 280 self.debug_ops = debug_ops 281 282 self.node_name_regex_whitelist = node_name_regex_whitelist 283 self.op_type_regex_whitelist = op_type_regex_whitelist 284 self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist 285 self.tolerate_debug_op_creation_failures = ( 286 tolerate_debug_op_creation_failures) 287 288 289 class OnRunEndRequest(object): 290 """Request to an on-run-end callback. 291 292 The callback is invoked immediately before the wrapped run() call ends. 293 """ 294 295 def __init__(self, 296 performed_action, 297 run_metadata=None, 298 client_graph_def=None, 299 tf_error=None): 300 """Constructor for `OnRunEndRequest`. 301 302 Args: 303 performed_action: (`OnRunStartAction`) Actually-performed action by the 304 debug-wrapper session. 305 run_metadata: run_metadata output from the run() call (if any). 306 client_graph_def: (GraphDef) GraphDef from the client side, i.e., from 307 the python front end of TensorFlow. Can be obtained with 308 session.graph.as_graph_def(). 309 tf_error: (errors.OpError subtypes) TensorFlow OpError that occurred 310 during the run (if any). 311 """ 312 313 _check_type(performed_action, str) 314 self.performed_action = performed_action 315 316 if run_metadata is not None: 317 _check_type(run_metadata, config_pb2.RunMetadata) 318 self.run_metadata = run_metadata 319 self.client_graph_def = client_graph_def 320 self.tf_error = tf_error 321 322 323 class OnRunEndResponse(object): 324 """Response from an on-run-end callback.""" 325 326 def __init__(self): 327 328 # Currently only a placeholder. 329 pass 330 331 332 class BaseDebugWrapperSession(session.SessionInterface): 333 """Base class of debug-wrapper session classes. 334 335 Concrete classes that inherit from this class need to implement the abstract 336 methods such as on_session_init, on_run_start and on_run_end. 
337 """ 338 339 # TODO(cais): Add on_cont_start and on_cont_end callbacks once the stepper is 340 # is available. 341 342 def __init__(self, sess, thread_name_filter=None, 343 pass_through_operrors=False): 344 """Constructor of `BaseDebugWrapperSession`. 345 346 Args: 347 sess: An (unwrapped) TensorFlow session instance. It should be a subtype 348 of `BaseSession` or `tf.MonitoredSession`. 349 thread_name_filter: Regular-expression filter (whitelist) for name(s) of 350 thread(s) on which the wrapper session will be active. This regular 351 expression is used in a start-anchored fashion on the thread name, i.e., 352 by applying the `match` method of the compiled pattern. The default 353 `None` means that the wrapper session will be active on all threads. 354 E.g., r"MainThread$", r"QueueRunnerThread.*". 355 pass_through_operrors: If True, all captured OpErrors will be 356 propagated. By default this captures all OpErrors. 357 358 Raises: 359 ValueError: On invalid `OnSessionInitAction` value. 360 NotImplementedError: If a non-DirectSession sess object is received. 361 """ 362 363 _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession)) 364 365 # The session being wrapped. 366 self._sess = sess 367 self._thread_name_filter_pattern = (re.compile(thread_name_filter) 368 if thread_name_filter else None) 369 # TODO(cais/kstevens): Unittest this pass through feature. 370 self._pass_through_operrors = pass_through_operrors 371 372 # Keeps track of number of run calls that have been performed on this 373 # debug-wrapper session. The count can be used for purposes such as 374 # displaying the state of the Session in a UI and determining a run 375 # number-dependent debug URL. 376 self._run_call_count = 0 377 378 # Invoke on-session-init callback. 
379 response = self.on_session_init(OnSessionInitRequest(self._sess)) 380 _check_type(response, OnSessionInitResponse) 381 382 if response.action == OnSessionInitAction.PROCEED: 383 pass 384 elif response.action == OnSessionInitAction.REMOTE_INSTR_LOOP: 385 # TODO(cais): Implement REMOTE_INSTR_LOOP 386 raise NotImplementedError( 387 "OnSessionInitAction REMOTE_INSTR_LOOP has not been " 388 "implemented.") 389 else: 390 raise ValueError( 391 "Invalid OnSessionInitAction value: %s" % response.action) 392 393 self._default_session_context_manager = None 394 395 @property 396 def graph(self): 397 return self._sess.graph 398 399 @property 400 def graph_def(self): 401 return self._sess.graph_def 402 403 @property 404 def sess_str(self): 405 return self._sess.sess_str 406 407 @property 408 def session(self): 409 return self._sess 410 411 def run(self, 412 fetches, 413 feed_dict=None, 414 options=None, 415 run_metadata=None, 416 callable_runner=None, 417 callable_runner_args=None): 418 """Wrapper around Session.run() that inserts tensor watch options. 419 420 Args: 421 fetches: Same as the `fetches` arg to regular `Session.run()`. 422 feed_dict: Same as the `feed_dict` arg to regular `Session.run()`. 423 options: Same as the `options` arg to regular `Session.run()`. 424 run_metadata: Same as the `run_metadata` arg to regular `Session.run()`. 425 callable_runner: A `callable` returned by `Session.make_callable()`. 426 If not `None`, `fetches` and `feed_dict` must both be `None`. 427 callable_runner_args: An optional list of arguments to `callable_runner`. 428 429 Returns: 430 Simply forwards the output of the wrapped `Session.run()` call. 431 432 Raises: 433 ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner` 434 is not `None` and either or both of `fetches` and `feed_dict` is `None`. 
435 """ 436 if not callable_runner: 437 self.increment_run_call_count() 438 else: 439 if fetches or feed_dict: 440 raise ValueError( 441 "callable_runner and fetches/feed_dict are mutually exclusive, but " 442 "are used simultaneously.") 443 444 empty_fetches = not nest.flatten(fetches) 445 if empty_fetches: 446 tf_logging.info( 447 "Due to empty fetches, tfdbg Session wrapper is letting a " 448 "Session.run pass through without any debugging actions.") 449 if self._is_disabled_thread() or empty_fetches: 450 if callable_runner: 451 return callable_runner(*callable_runner_args) 452 else: 453 return self._sess.run(fetches, 454 feed_dict=feed_dict, 455 options=options, 456 run_metadata=run_metadata) 457 458 # Invoke on-run-start callback and obtain response. 459 run_start_resp = self.on_run_start( 460 OnRunStartRequest(fetches, feed_dict, options, run_metadata, 461 self._run_call_count, 462 is_callable_runner=bool(callable_runner))) 463 _check_type(run_start_resp, OnRunStartResponse) 464 465 if run_start_resp.action == OnRunStartAction.DEBUG_RUN: 466 # Decorate RunOption to fill in debugger tensor watch specifications. 467 decorated_run_options = options or config_pb2.RunOptions() 468 run_metadata = run_metadata or config_pb2.RunMetadata() 469 470 self._decorate_run_options_for_debug( 471 decorated_run_options, 472 run_start_resp.debug_urls, 473 debug_ops=run_start_resp.debug_ops, 474 node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist, 475 op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist, 476 tensor_dtype_regex_whitelist=( 477 run_start_resp.tensor_dtype_regex_whitelist), 478 tolerate_debug_op_creation_failures=( 479 run_start_resp.tolerate_debug_op_creation_failures)) 480 481 # Invoke the run() method of the wrapped Session. Catch any TensorFlow 482 # runtime errors. 
483 tf_error = None 484 try: 485 if callable_runner: 486 retvals = callable_runner(*callable_runner_args, 487 options=decorated_run_options, 488 run_metadata=run_metadata) 489 else: 490 retvals = self._sess.run(fetches, 491 feed_dict=feed_dict, 492 options=decorated_run_options, 493 run_metadata=run_metadata) 494 except errors.OpError as op_error: 495 if self._pass_through_operrors: 496 raise op_error 497 tf_error = op_error 498 retvals = op_error 499 500 run_end_req = OnRunEndRequest( 501 run_start_resp.action, 502 run_metadata=run_metadata, 503 client_graph_def=self._sess.graph.as_graph_def(), 504 tf_error=tf_error) 505 506 elif run_start_resp.action == OnRunStartAction.PROFILE_RUN: 507 decorated_run_options = options or config_pb2.RunOptions() 508 run_metadata = run_metadata or config_pb2.RunMetadata() 509 self._decorate_run_options_for_profile(decorated_run_options) 510 if callable_runner: 511 retvals = callable_runner(*callable_runner_args, 512 options=decorated_run_options, 513 run_metadata=run_metadata) 514 else: 515 retvals = self._sess.run(fetches, 516 feed_dict=feed_dict, 517 options=decorated_run_options, 518 run_metadata=run_metadata) 519 run_end_req = OnRunEndRequest( 520 run_start_resp.action, 521 run_metadata=run_metadata, 522 client_graph_def=self._sess.graph.as_graph_def()) 523 elif (run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN or 524 run_start_resp.action == OnRunStartAction.INVOKE_STEPPER): 525 if callable_runner: 526 raise NotImplementedError( 527 "Stepper mode is not implemented for callables created by " 528 "Session.make_callable().") 529 530 if run_start_resp.action == OnRunStartAction.INVOKE_STEPPER: 531 with stepper.NodeStepper( 532 self._sess, fetches, feed_dict) as node_stepper: 533 retvals = self.invoke_node_stepper( 534 node_stepper, restore_variable_values_on_exit=True) 535 536 # Invoke run() method of the wrapped session. 
537 retvals = self._sess.run( 538 fetches, 539 feed_dict=feed_dict, 540 options=options, 541 run_metadata=run_metadata) 542 543 # Prepare arg for the on-run-end callback. 544 run_end_req = OnRunEndRequest(run_start_resp.action) 545 else: 546 raise ValueError( 547 "Invalid OnRunStartAction value: %s" % run_start_resp.action) 548 549 # Invoke on-run-end callback and obtain response. 550 run_end_resp = self.on_run_end(run_end_req) 551 _check_type(run_end_resp, OnRunEndResponse) 552 # Currently run_end_resp is only a placeholder. No action is taken on it. 553 554 return retvals 555 556 def _is_disabled_thread(self): 557 thread_name = threading.current_thread().name or "" 558 return (self._thread_name_filter_pattern and 559 not self._thread_name_filter_pattern.match(thread_name)) 560 561 def run_step_fn(self, step_fn): 562 return step_fn( 563 monitored_session.MonitoredSession.StepContext(self._sess, self.run)) 564 565 def partial_run_setup(self, fetches, feeds=None): 566 """Sets up the feeds and fetches for partial runs in the session.""" 567 raise NotImplementedError( 568 "partial_run_setup is not implemented for debug-wrapper sessions.") 569 570 def partial_run(self, handle, fetches, feed_dict=None): 571 raise NotImplementedError( 572 "partial_run is not implemented for debug-wrapper sessions.") 573 574 def list_devices(self, *args, **kwargs): 575 return self._sess.list_devices(*args, **kwargs) 576 577 def reset(self, *args, **kwargs): 578 return self._sess.reset(*args, **kwargs) 579 580 def make_callable(self, 581 fetches, 582 feed_list=None, 583 accept_options=False): 584 runner = self._sess.make_callable( 585 fetches, feed_list=feed_list, accept_options=True) 586 def wrapped_runner(*runner_args, **kwargs): 587 return self.run(None, 588 feed_dict=None, 589 options=kwargs.get("options", None), 590 run_metadata=kwargs.get("run_metadata", None), 591 callable_runner=runner, 592 callable_runner_args=runner_args) 593 594 return wrapped_runner 595 596 @property 597 def 
run_call_count(self): 598 return self._run_call_count 599 600 def increment_run_call_count(self): 601 self._run_call_count += 1 602 603 def _decorate_run_options_for_debug( 604 self, 605 run_options, 606 debug_urls, 607 debug_ops="DebugIdentity", 608 node_name_regex_whitelist=None, 609 op_type_regex_whitelist=None, 610 tensor_dtype_regex_whitelist=None, 611 tolerate_debug_op_creation_failures=False): 612 """Modify a RunOptions object for debug tensor watching. 613 614 Specifies request for outputting partition graphs. Adds 615 debug_tensor_watch_opts with proper debug URLs. 616 617 Args: 618 run_options: (RunOptions) the modified RunOptions object. 619 debug_urls: (list of str) debug URLs to be entered in run_options. 620 debug_tensor_watch_opts. 621 debug_ops: (str or list of str) debug op(s) to be used by the debugger. 622 node_name_regex_whitelist: Regular-expression whitelist for node 623 name. 624 op_type_regex_whitelist: Regular-expression whitelist for op type. 625 tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor 626 dtype. 627 tolerate_debug_op_creation_failures: Whether debug op creation failures 628 are to be tolerated. 629 """ 630 631 run_options.output_partition_graphs = True 632 debug_utils.watch_graph( 633 run_options, 634 self._sess.graph, 635 debug_urls=debug_urls, 636 debug_ops=debug_ops, 637 node_name_regex_whitelist=node_name_regex_whitelist, 638 op_type_regex_whitelist=op_type_regex_whitelist, 639 tensor_dtype_regex_whitelist=tensor_dtype_regex_whitelist, 640 tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures) 641 642 def _decorate_run_options_for_profile(self, run_options): 643 """Modify a RunOptions object for profiling TensorFlow graph execution. 644 645 Args: 646 run_options: (RunOptions) the modified RunOptions object. 
647 """ 648 649 run_options.trace_level = config_pb2.RunOptions.FULL_TRACE 650 651 @abc.abstractmethod 652 def on_session_init(self, request): 653 """Callback invoked during construction of the debug-wrapper session. 654 655 This is a blocking callback. 656 The invocation happens right before the constructor ends. 657 658 Args: 659 request: (`OnSessionInitRequest`) callback request carrying information 660 such as the session being wrapped. 661 662 Returns: 663 An instance of `OnSessionInitResponse`. 664 """ 665 666 @abc.abstractmethod 667 def on_run_start(self, request): 668 """Callback invoked on run() calls to the debug-wrapper session. 669 670 This is a blocking callback. 671 The invocation happens after the wrapper's run() call is entered, 672 after an increment of run call counter. 673 674 Args: 675 request: (`OnRunStartRequest`) callback request object carrying 676 information about the run call such as the fetches, feed dict, run 677 options, run metadata, and how many `run()` calls to this wrapper 678 session have occurred. 679 680 Returns: 681 An instance of `OnRunStartResponse`, carrying information to 682 1) direct the wrapper session to perform a specified action (e.g., run 683 with or without debug tensor watching, invoking the stepper.) 684 2) debug URLs used to watch the tensors. 685 """ 686 687 @abc.abstractmethod 688 def on_run_end(self, request): 689 """Callback invoked on run() calls to the debug-wrapper session. 690 691 This is a blocking callback. 692 The invocation happens right before the wrapper exits its run() call. 693 694 Args: 695 request: (`OnRunEndRequest`) callback request object carrying information 696 such as the actual action performed by the session wrapper for the 697 run() call. 698 699 Returns: 700 An instance of `OnRunStartResponse`. 
701 """ 702 703 def as_default(self): 704 return ops.default_session(self) 705 706 def __enter__(self): 707 if self._default_session_context_manager is None: 708 self._default_session_context_manager = self.as_default() 709 return self._default_session_context_manager.__enter__() 710 711 def __exit__(self, exec_type, exec_value, exec_tb): 712 self._default_session_context_manager.__exit__( 713 exec_type, exec_value, exec_tb) 714 715 def __del__(self): 716 if hasattr(self._sess, "__del__"): 717 self._sess.__del__() 718 719 def close(self): 720 self._sess.close() 721 722 # TODO(cais): Add _node_name_regex_whitelist and 723 # _node_op_type_regex_whitelist. 724 725 @abc.abstractmethod 726 def invoke_node_stepper(self, 727 node_stepper, 728 restore_variable_values_on_exit=True): 729 """Callback invoked when the client intends to step through graph nodes. 730 731 Args: 732 node_stepper: (stepper.NodeStepper) An instance of NodeStepper to be used 733 in this stepping session. 734 restore_variable_values_on_exit: (bool) Whether any variables whose values 735 have been altered during this node-stepper invocation should be restored 736 to their old values when this invocation ends. 737 738 Returns: 739 The same return values as the `Session.run()` call on the same fetches as 740 the NodeStepper. 741 """ 742 743 def should_stop(self): 744 if hasattr(self._sess, "should_stop"): 745 return self._sess.should_stop() 746 else: 747 raise ValueError( 748 "The wrapped session %r does not have a method called 'should_stop'. " 749 "Do you intend to wrap a tf.MonitoredSession instead?" % self._sess) 750 751 752 class WatchOptions(object): 753 """Type for return values of watch_fn.""" 754 755 def __init__(self, 756 debug_ops=None, 757 node_name_regex_whitelist=None, 758 op_type_regex_whitelist=None, 759 tensor_dtype_regex_whitelist=None, 760 tolerate_debug_op_creation_failures=False): 761 """Constructor of WatchOptions: Debug watch options. 762 763 Used as return values of `watch_fn`s. 
764 765 Args: 766 debug_ops: (`str` or `list of str`) Debug ops to be used. 767 node_name_regex_whitelist: Regular-expression whitelist for node_name, 768 e.g., `"(weight_[0-9]+|bias_.*)"` 769 op_type_regex_whitelist: Regular-expression whitelist for the op type of 770 nodes, e.g., `"(Variable|Add)"`. 771 If both `node_name_regex_whitelist` and `op_type_regex_whitelist` 772 are set, the two filtering operations will occur in a logical `AND` 773 relation. In other words, a node will be included if and only if it 774 hits both whitelists. 775 tensor_dtype_regex_whitelist: Regular-expression whitelist for Tensor 776 data type, e.g., `"^int.*"`. 777 This whitelist operates in logical `AND` relations to the two whitelists 778 above. 779 tolerate_debug_op_creation_failures: (`bool`) whether debug op creation 780 failures (e.g., due to dtype incompatibility) are to be tolerated by not 781 throwing exceptions. 782 """ 783 if debug_ops: 784 self.debug_ops = debug_ops 785 else: 786 self.debug_ops = ["DebugIdentity"] 787 self.node_name_regex_whitelist = node_name_regex_whitelist 788 self.op_type_regex_whitelist = op_type_regex_whitelist 789 self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist 790 self.tolerate_debug_op_creation_failures = ( 791 tolerate_debug_op_creation_failures) 792 793 def __repr__(self): 794 return ("WatchOptions(debug_ops=%r, node_name_regex_whitelist=%r, " 795 "op_type_regex_whitelist=%r, tensor_dtype_regex_whitelist=%r, " 796 "tolerate_debug_op_creation_failures=%r)" % ( 797 self.debug_ops, self.node_name_regex_whitelist, 798 self.op_type_regex_whitelist, self.tensor_dtype_regex_whitelist, 799 self.tolerate_debug_op_creation_failures)) 800 801 802 class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession): 803 """Base class for non-interactive (i.e., non-CLI) debug wrapper sessions.""" 804 805 def __init__(self, sess, watch_fn=None, thread_name_filter=None, 806 pass_through_operrors=False): 807 """Constructor of 
NonInteractiveDebugWrapperSession. 808 809 Args: 810 sess: The TensorFlow `Session` object being wrapped. 811 watch_fn: (`Callable`) A Callable that maps the fetches and feeds of a 812 debugged `Session.run()` call to `WatchOptions.` 813 * Args: 814 * `fetches`: the fetches to the `Session.run()` call. 815 * `feeds`: the feeds to the `Session.run()` call. 816 817 * Returns: 818 (`tf_debug.WatchOptions`) An object containing debug options including 819 the debug ops to use, the node names, op types and/or tensor data 820 types to watch, etc. See the documentation of `tf_debug.WatchOptions` 821 for more details. 822 thread_name_filter: Regular-expression white list for threads on which the 823 wrapper session will be active. See doc of `BaseDebugWrapperSession` for 824 more details. 825 pass_through_operrors: If true, all captured OpErrors will be 826 propagated. By default this captures all OpErrors. 827 Raises: 828 TypeError: If a non-None `watch_fn` is specified and it is not callable. 829 """ 830 831 BaseDebugWrapperSession.__init__( 832 self, sess, thread_name_filter=thread_name_filter, 833 pass_through_operrors=pass_through_operrors) 834 835 self._watch_fn = None 836 if watch_fn is not None: 837 if not callable(watch_fn): 838 raise TypeError("watch_fn is not callable") 839 self._watch_fn = watch_fn 840 841 def on_session_init(self, request): 842 """See doc of BaseDebugWrapperSession.on_run_start.""" 843 844 return OnSessionInitResponse(OnSessionInitAction.PROCEED) 845 846 @abc.abstractmethod 847 def prepare_run_debug_urls(self, fetches, feed_dict): 848 """Abstract method to be implemented by concrete subclasses. 849 850 This method prepares the run-specific debug URL(s). 851 852 Args: 853 fetches: Same as the `fetches` argument to `Session.run()` 854 feed_dict: Same as the `feed_dict` argument to `Session.run()` 855 856 Returns: 857 debug_urls: (`str` or `list` of `str`) Debug URLs to be used in 858 this `Session.run()` call. 
859 """ 860 861 def on_run_start(self, request): 862 """See doc of BaseDebugWrapperSession.on_run_start.""" 863 864 debug_urls, watch_opts = self._prepare_run_watch_config( 865 request.fetches, request.feed_dict) 866 867 return OnRunStartResponse( 868 OnRunStartAction.DEBUG_RUN, 869 debug_urls, 870 debug_ops=watch_opts.debug_ops, 871 node_name_regex_whitelist=watch_opts.node_name_regex_whitelist, 872 op_type_regex_whitelist=watch_opts.op_type_regex_whitelist, 873 tensor_dtype_regex_whitelist=watch_opts.tensor_dtype_regex_whitelist, 874 tolerate_debug_op_creation_failures=( 875 watch_opts.tolerate_debug_op_creation_failures)) 876 877 def _prepare_run_watch_config(self, fetches, feed_dict): 878 """Get the debug_urls, and node/op whitelists for the current run() call. 879 880 Args: 881 fetches: Same as the `fetches` argument to `Session.run()`. 882 feed_dict: Same as the `feed_dict argument` to `Session.run()`. 883 884 Returns: 885 debug_urls: (str or list of str) Debug URLs for the current run() call. 886 Currently, the list consists of only one URL that is a file:// URL. 887 watch_options: (WatchOptions) The return value of a watch_fn, containing 888 options including debug_ops, and whitelists. 889 """ 890 891 debug_urls = self.prepare_run_debug_urls(fetches, feed_dict) 892 if self._watch_fn is None: 893 watch_options = WatchOptions() 894 else: 895 watch_options = self._watch_fn(fetches, feed_dict) 896 if isinstance(watch_options, tuple): 897 # For legacy return type (tuples). 
898 watch_options = WatchOptions(*watch_options) 899 900 return debug_urls, watch_options 901 902 def on_run_end(self, request): 903 """See doc of BaseDebugWrapperSession.on_run_end.""" 904 905 return OnRunEndResponse() 906 907 def invoke_node_stepper(self, 908 node_stepper, 909 restore_variable_values_on_exit=True): 910 """See doc of BaseDebugWrapperSession.invoke_node_stepper.""" 911 912 raise NotImplementedError( 913 "NonInteractiveDebugWrapperSession does not support node-stepper mode.") 914