      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Framework of debug wrapper sessions.
     16 
     17 A debug wrapper session is a wrapper around a TensorFlow Python Session.
     18 The wrapper preserves the Session interface, most importantly the run() method,
     19 while providing abilities to:
     20 a) Intercept a run() call to a wrapped session and insert debug tensor watches
     21    according to externally-specified debug URLs.
     22 
     23 b) Release control to an external (i.e., non-Session) object before and after
     24    the run() call, so that the external object can perform actions such as
     25    launching a UI to let users inspect the intermediate tensors and partition
     26    graphs from the run() call.
     27 
     28 c) (To be implemented) Intercept a run() call and give control to DebugStepper
     29    to let it perform stepping / continuing-to actions on the graph.
     30 
     31 d) (To be implemented in a future CL) Enter an instruction loop to let an
     32    external object (e.g., remote client) launch run() and cont() calls
     33    remotely.
     34 
     35 *** The lifetime of a debug wrapper session: ***
     36 
     37 1) The wrapper session is created by calling the constructor with a
     38    wrapped (normal) session as the argument:
     39      wrapper = FooDebugWrapperSession(sess)
     40    wherein FooDebugWrapperSession is a concrete subclass implementing the
     41    abstract BaseDebugWrapperSession class below.
     42 
     43 2) Near the end of the constructor call, the on_session_init() callback is
     44    invoked, with an OnSessionInitRequest object as the argument. The object
     45    carries the wrapped (normal) session object.
     46 
     47 3) The callback handles the request and returns an OnSessionInitResponse
     48    object with an action field, directing the wrapper session what to do next.
     49 
     50 If the action field in the OnSessionInitResponse is PROCEED, the constructor
     51 returns. Control is released back to the caller of the constructor, which can
     52 invoke the run() method of the wrapper session with the same syntax as a
     53 non-wrapped session, e.g.:
     54   wrapper.run(fetches, feed_dict=feeds, options=run_options)
     55 
     56 Below, A1 - A2 describe the lifetime of a wrapper run() call if the action is
     57 PROCEED:
     58 
     59 A1) Right at the start of each run() call, the on_run_start() callback is
     60     invoked, with an OnRunStartRequest object carrying information such as
     61     the fetches, the feed dict, the run options and run metadata used in
     62     this run call, along with a count of how many run calls have occurred
     63     on this wrapper session. The callback then returns an OnRunStartResponse
     64     object, whose action field directs what the wrapper session actually
     65     does for the run() call.
     66 
     67     If the action is DEBUG_RUN, a debugged (tensor-watched) run will ensue,
     68     with the debug URLs supplied in the debug_urls field of the response.
     69     These can be file:// or grpc:// URLs, for example.
     70 
     71     If the action is NON_DEBUG_RUN, a non-debug (normal) run will ensue.
     72 
     73     If the action is INVOKE_STEPPER, no run() call will be issued to the
     74     wrapped session. But instead, a DebugStepper (i.e., "continuation
     75     debugger") will be used to perform stepping / continue-to actions on
     76     the graph.
     77 
     78 TODO(cais): The event loop for the DebugStepper will request additional
     79    callbacks including on_cont_start() and on_cont_end(). Add those.
     80 
     81 A2) Right before the run() returns, the on_run_end() callback is invoked,
     82     with an OnRunEndRequest object as the argument, which carries information
     83     including the actual action performed in the wrapper run() call and the
     84     run_metadata from the run() call.
     85 
     86 However, if the action field in OnSessionInitResponse is
     87 REMOTE_INSTR_LOOP, the constructor will automatically invoke an instruction loop
     88 that gives the control to a remote caller.
     89 
     90 In the remote instruction loop, the following steps will happen:
     91 
     92 B1) Callback on_instr_start() is invoked. The callback will return an
     93     OnInstrStartResponse object with an action field which can order one of
     94     the following actions:
     95         i) a run() call with fetches, feeds and debug_urls specified.
     96        ii) a DebugStepper cont() call with target specified.
     97       iii) value overrides in the cached tensors from the DebugStepper.
     98        iv) exit the instruction loop.
     99 
    100 B2) The wrapper session carries out the action specified above.
    101 
    102 B3) If still in the instruction loop, the wrapper session invokes the
    103     on_instr_end() callback. After the on_instr_end() callback returns, jump
    104     back to B1.
    105 
    106 TODO(cais): Implement the instruction loop in B1 - B3.
    107 
    108 """
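The callback order described in the module docstring (steps 1-3, then A1-A2) can be illustrated with a stand-alone sketch. It deliberately avoids importing TensorFlow: `FakeSession`, `SketchWrapper` and `LoggingWrapper` are hypothetical stand-ins invented for this illustration, the callbacks exchange plain strings rather than the request/response objects defined in this file, and only the PROCEED / NON_DEBUG_RUN path is modeled.

```python
# Stand-alone sketch of the callback lifecycle; no TensorFlow required.
# FakeSession, SketchWrapper and LoggingWrapper are hypothetical stand-ins,
# not classes from this module.

PROCEED = "proceed"
NON_DEBUG_RUN = "non_debug_run"


class FakeSession(object):
  """Stand-in for a real session: run() just echoes the fetches."""

  def run(self, fetches, feed_dict=None):
    return fetches


class SketchWrapper(object):
  """Mimics steps 1-3 and A1-A2 of the docstring with plain strings."""

  def __init__(self, sess):
    self._sess = sess
    self._run_call_count = 0
    # Steps 2-3: the callback is invoked near the end of the constructor.
    if self.on_session_init(sess) != PROCEED:
      raise ValueError("Only PROCEED is handled in this sketch.")

  def run(self, fetches, feed_dict=None):
    self._run_call_count += 1
    # A1: ask the callback what to do with this run() call.
    action = self.on_run_start(fetches, feed_dict, self._run_call_count)
    if action != NON_DEBUG_RUN:
      raise ValueError("Only NON_DEBUG_RUN is handled in this sketch.")
    retvals = self._sess.run(fetches, feed_dict=feed_dict)
    # A2: report the performed action right before returning.
    self.on_run_end(action)
    return retvals


class LoggingWrapper(SketchWrapper):
  """Concrete subclass that records the order of callback invocations."""

  def __init__(self, sess):
    self.events = []
    super(LoggingWrapper, self).__init__(sess)

  def on_session_init(self, sess):
    self.events.append("on_session_init")
    return PROCEED

  def on_run_start(self, fetches, feed_dict, run_call_count):
    self.events.append("on_run_start #%d" % run_call_count)
    return NON_DEBUG_RUN

  def on_run_end(self, performed_action):
    self.events.append("on_run_end (%s)" % performed_action)


wrapper = LoggingWrapper(FakeSession())
result = wrapper.run("fetch_a")
print(result)
print(wrapper.events)
```

The recorded event list shows the ordering: the session-init callback fires once during construction, and each run() call is bracketed by the run-start and run-end callbacks.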
    109 
    110 from __future__ import absolute_import
    111 from __future__ import division
    112 from __future__ import print_function
    113 
    114 import abc
    115 import re
    116 import threading
    117 
    118 from tensorflow.core.protobuf import config_pb2
    119 from tensorflow.python.client import session
    120 from tensorflow.python.debug.lib import debug_utils
    121 from tensorflow.python.debug.lib import stepper
    122 from tensorflow.python.framework import errors
    123 from tensorflow.python.framework import ops
    124 from tensorflow.python.platform import tf_logging
    125 from tensorflow.python.training import monitored_session
    126 from tensorflow.python.util import nest
    127 
    128 
    129 # Helper function.
    130 def _check_type(obj, expected_types):
    131   """Check if an object is of the expected type.
    132 
    133   Args:
    134     obj: The object being checked.
    135     expected_types: (`type` or a tuple of `type`s) The expected `type`(s)
    136       of obj.
    137 
    138   Raises:
    139     TypeError: If obj is not an instance of expected_types.
    140   """
    141   if not isinstance(obj, expected_types):
    142     raise TypeError("Expected type %s; got type %s" %
    143                     (expected_types, type(obj)))
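As a small illustration of how the helper behaves, the sketch below copies `_check_type` so that it runs on its own and exercises it with a single type and with a tuple of types. The sample values are made up for the example. Because the helper delegates to `isinstance`, multiple expected types have to be passed as a tuple.

```python
def _check_type(obj, expected_types):
  """Copy of the helper above so that the sketch is self-contained."""
  if not isinstance(obj, expected_types):
    raise TypeError("Expected type %s; got type %s" %
                    (expected_types, type(obj)))


# Both calls pass silently: the object matches the expected type(s).
_check_type("debug_run", str)
_check_type(3, (int, float))

# A bare string is rejected where a list (e.g., of debug URLs) is expected.
message = None
try:
  _check_type("file:///tmp/dump", list)
except TypeError as e:
  message = str(e)
  print(message)
```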
    144 
    145 
    146 class OnSessionInitRequest(object):
    147   """Request to an on-session-init callback.
    148 
    149   This callback is invoked during the __init__ call to a debug-wrapper session.
    150   """
    151 
    152   def __init__(self, sess):
    153     """Constructor.
    154 
    155     Args:
    156       sess: A tensorflow Session object.
    157     """
    158 
    159     _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
    160     self.session = sess
    161 
    162 
    163 class OnSessionInitAction(object):
    164   """Enum-like values for possible actions to take on session init."""
    165 
    166   # Proceed, without special actions, in the wrapper session initialization.
    167   # What action the wrapper session performs next is determined by the caller
    168   # of the wrapper session. E.g., it can call run().
    169   PROCEED = "proceed"
    170 
    171   # Instead of letting the caller of the wrapper session determine what actions
    172   # the wrapper session will perform next, enter a loop to receive instructions
    173   # from a remote client.
    174   # For example, TensorBoard visual debugger can use this action so that it can
    175   # launch session.run() calls remotely.
    176   REMOTE_INSTR_LOOP = "remote_instr_loop"
    177 
    178 
    179 class OnSessionInitResponse(object):
    180   """Response from an on-session-init callback."""
    181 
    182   def __init__(self, action):
    183     """Constructor.
    184 
    185     Args:
    186       action: (`OnSessionInitAction`) Debugger action to take on session init.
    187     """
    188     _check_type(action, str)
    189     self.action = action
    190 
    191 
    192 class OnRunStartRequest(object):
    193   """Request to an on-run-start callback.
    194 
    195   This callback is invoked during a run() call of the debug-wrapper
    196   session, immediately after the run() call counter is incremented.
    197   """
    198 
    199   def __init__(self, fetches, feed_dict, run_options, run_metadata,
    200                run_call_count, is_callable_runner=False):
    201     """Constructor of `OnRunStartRequest`.
    202 
    203     Args:
    204       fetches: Fetch targets of the run() call.
    205       feed_dict: The feed dictionary to the run() call.
    206       run_options: RunOptions input to the run() call.
    207       run_metadata: RunMetadata input to the run() call.
    208         The above four arguments are identical to the input arguments to the
    209         run() method of a non-wrapped TensorFlow session.
    210       run_call_count: 1-based count of how many run calls (including this one)
    211         have been invoked.
    212       is_callable_runner: (bool) whether a runner returned by
    213         Session.make_callable is being run.
    214     """
    215     self.fetches = fetches
    216     self.feed_dict = feed_dict
    217     self.run_options = run_options
    218     self.run_metadata = run_metadata
    219     self.run_call_count = run_call_count
    220     self.is_callable_runner = is_callable_runner
    221 
    222 
    223 class OnRunStartAction(object):
    224   """Enum-like values for possible actions to take on start of a run() call."""
    225 
    226   # Run once with debug tensor-watching.
    227   DEBUG_RUN = "debug_run"
    228 
    229   # Run once with profiler.
    230   PROFILE_RUN = "profile_run"
    231 
    232   # Run without debug tensor-watching.
    233   NON_DEBUG_RUN = "non_debug_run"
    234 
    235   # Instead of running the fetches as a whole, as would normally happen, invoke
    236   # the (to-be-implemented) debug stepper.
    237   # TODO(cais): Remove "to-be-implemented".
    238   INVOKE_STEPPER = "invoke_stepper"
    239 
    240 
    241 class OnRunStartResponse(object):
    242   """Response from an on-run-start callback.
    243 
    244   The implementation of the callback uses this response object to specify
    245   what action the debug-wrapper session actually takes on the run() call.
    246   """
    247 
    248   def __init__(self,
    249                action,
    250                debug_urls,
    251                debug_ops="DebugIdentity",
    252                node_name_regex_whitelist=None,
    253                op_type_regex_whitelist=None,
    254                tensor_dtype_regex_whitelist=None,
    255                tolerate_debug_op_creation_failures=False):
    256     """Constructor of `OnRunStartResponse`.
    257 
    258     Args:
    259       action: (`OnRunStartAction`) the action actually taken by the wrapped
    260         session for the run() call.
    261       debug_urls: (`list` of `str`) debug_urls used in watching the tensors
    262         during the run() call.
    263       debug_ops: (`str` or `list` of `str`) Debug op(s) to be used by the
    264         debugger.
    265       node_name_regex_whitelist: Regular-expression whitelist for node
    266         name.
    267       op_type_regex_whitelist: Regular-expression whitelist for op type.
    268       tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor
    269         dtype.
    270       tolerate_debug_op_creation_failures: Whether debug op creation failures
    271         are to be tolerated.
    272     """
    273 
    274     _check_type(action, str)
    275     self.action = action
    276 
    277     _check_type(debug_urls, list)
    278     self.debug_urls = debug_urls
    279 
    280     self.debug_ops = debug_ops
    281 
    282     self.node_name_regex_whitelist = node_name_regex_whitelist
    283     self.op_type_regex_whitelist = op_type_regex_whitelist
    284     self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist
    285     self.tolerate_debug_op_creation_failures = (
    286         tolerate_debug_op_creation_failures)
    287 
    288 
    289 class OnRunEndRequest(object):
    290   """Request to an on-run-end callback.
    291 
    292   The callback is invoked immediately before the wrapped run() call ends.
    293   """
    294 
    295   def __init__(self,
    296                performed_action,
    297                run_metadata=None,
    298                client_graph_def=None,
    299                tf_error=None):
    300     """Constructor for `OnRunEndRequest`.
    301 
    302     Args:
    303       performed_action: (`OnRunStartAction`) Actually-performed action by the
    304         debug-wrapper session.
    305       run_metadata: run_metadata output from the run() call (if any).
    306       client_graph_def: (GraphDef) GraphDef from the client side, i.e., from
    307         the python front end of TensorFlow. Can be obtained with
    308         session.graph.as_graph_def().
    309       tf_error: (errors.OpError subtypes) TensorFlow OpError that occurred
    310         during the run (if any).
    311     """
    312 
    313     _check_type(performed_action, str)
    314     self.performed_action = performed_action
    315 
    316     if run_metadata is not None:
    317       _check_type(run_metadata, config_pb2.RunMetadata)
    318     self.run_metadata = run_metadata
    319     self.client_graph_def = client_graph_def
    320     self.tf_error = tf_error
    321 
    322 
    323 class OnRunEndResponse(object):
    324   """Response from an on-run-end callback."""
    325 
    326   def __init__(self):
    327 
    328     # Currently only a placeholder.
    329     pass
    330 
    331 
    332 class BaseDebugWrapperSession(session.SessionInterface):
    333   """Base class of debug-wrapper session classes.
    334 
    335   Concrete classes that inherit from this class need to implement the abstract
    336   methods such as on_session_init, on_run_start and on_run_end.
    337   """
    338 
    339   # TODO(cais): Add on_cont_start and on_cont_end callbacks once the stepper
    340   # is available.
    341 
    342   def __init__(self, sess, thread_name_filter=None,
    343                pass_through_operrors=False):
    344     """Constructor of `BaseDebugWrapperSession`.
    345 
    346     Args:
    347       sess: An (unwrapped) TensorFlow session instance. It should be a subtype
    348         of `BaseSession` or `tf.MonitoredSession`.
    349       thread_name_filter: Regular-expression filter (whitelist) for name(s) of
    350         thread(s) on which the wrapper session will be active. This regular
    351         expression is used in a start-anchored fashion on the thread name, i.e.,
    352         by applying the `match` method of the compiled pattern. The default
    353         `None` means that the wrapper session will be active on all threads.
    354         E.g., r"MainThread$", r"QueueRunnerThread.*".
    355       pass_through_operrors: If True, OpErrors from a debugged run are
    356         re-raised. By default, they are captured and sent to on_run_end().
    357 
    358     Raises:
    359       ValueError: On invalid `OnSessionInitAction` value.
    360       NotImplementedError: If the `OnSessionInitAction` is REMOTE_INSTR_LOOP.
    361     """
    362 
    363     _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
    364 
    365     # The session being wrapped.
    366     self._sess = sess
    367     self._thread_name_filter_pattern = (re.compile(thread_name_filter)
    368                                         if thread_name_filter else None)
    369     # TODO(cais/kstevens): Unittest this pass through feature.
    370     self._pass_through_operrors = pass_through_operrors
    371 
    372     # Keeps track of number of run calls that have been performed on this
    373     # debug-wrapper session. The count can be used for purposes such as
    374     # displaying the state of the Session in a UI and determining a run
    375     # number-dependent debug URL.
    376     self._run_call_count = 0
    377 
    378     # Invoke on-session-init callback.
    379     response = self.on_session_init(OnSessionInitRequest(self._sess))
    380     _check_type(response, OnSessionInitResponse)
    381 
    382     if response.action == OnSessionInitAction.PROCEED:
    383       pass
    384     elif response.action == OnSessionInitAction.REMOTE_INSTR_LOOP:
    385       # TODO(cais): Implement REMOTE_INSTR_LOOP
    386       raise NotImplementedError(
    387           "OnSessionInitAction REMOTE_INSTR_LOOP has not been "
    388           "implemented.")
    389     else:
    390       raise ValueError(
    391           "Invalid OnSessionInitAction value: %s" % response.action)
    392 
    393     self._default_session_context_manager = None
    394 
    395   @property
    396   def graph(self):
    397     return self._sess.graph
    398 
    399   @property
    400   def graph_def(self):
    401     return self._sess.graph_def
    402 
    403   @property
    404   def sess_str(self):
    405     return self._sess.sess_str
    406 
    407   @property
    408   def session(self):
    409     return self._sess
    410 
    411   def run(self,
    412           fetches,
    413           feed_dict=None,
    414           options=None,
    415           run_metadata=None,
    416           callable_runner=None,
    417           callable_runner_args=None):
    418     """Wrapper around Session.run() that inserts tensor watch options.
    419 
    420     Args:
    421       fetches: Same as the `fetches` arg to regular `Session.run()`.
    422       feed_dict: Same as the `feed_dict` arg to regular `Session.run()`.
    423       options: Same as the `options` arg to regular `Session.run()`.
    424       run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
    425       callable_runner: A `callable` returned by `Session.make_callable()`.
    426         If not `None`, `fetches` and `feed_dict` must both be `None`.
    427       callable_runner_args: An optional list of arguments to `callable_runner`.
    428 
    429     Returns:
    430       Simply forwards the output of the wrapped `Session.run()` call.
    431 
    432     Raises:
    433       ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner`
    434         is not `None` and `fetches` or `feed_dict` (or both) is also given.
    435     """
    436     if not callable_runner:
    437       self.increment_run_call_count()
    438     else:
    439       if fetches or feed_dict:
    440         raise ValueError(
    441             "callable_runner and fetches/feed_dict are mutually exclusive, but "
    442             "are used simultaneously.")
    443 
    444     empty_fetches = not nest.flatten(fetches)
    445     if empty_fetches:
    446       tf_logging.info(
    447           "Due to empty fetches, tfdbg Session wrapper is letting a "
    448           "Session.run pass through without any debugging actions.")
    449     if self._is_disabled_thread() or empty_fetches:
    450       if callable_runner:
    451         return callable_runner(*callable_runner_args)
    452       else:
    453         return self._sess.run(fetches,
    454                               feed_dict=feed_dict,
    455                               options=options,
    456                               run_metadata=run_metadata)
    457 
    458     # Invoke on-run-start callback and obtain response.
    459     run_start_resp = self.on_run_start(
    460         OnRunStartRequest(fetches, feed_dict, options, run_metadata,
    461                           self._run_call_count,
    462                           is_callable_runner=bool(callable_runner)))
    463     _check_type(run_start_resp, OnRunStartResponse)
    464 
    465     if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
    466       # Decorate RunOption to fill in debugger tensor watch specifications.
    467       decorated_run_options = options or config_pb2.RunOptions()
    468       run_metadata = run_metadata or config_pb2.RunMetadata()
    469 
    470       self._decorate_run_options_for_debug(
    471           decorated_run_options,
    472           run_start_resp.debug_urls,
    473           debug_ops=run_start_resp.debug_ops,
    474           node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
    475           op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
    476           tensor_dtype_regex_whitelist=(
    477               run_start_resp.tensor_dtype_regex_whitelist),
    478           tolerate_debug_op_creation_failures=(
    479               run_start_resp.tolerate_debug_op_creation_failures))
    480 
    481       # Invoke the run() method of the wrapped Session. Catch any TensorFlow
    482       # runtime errors.
    483       tf_error = None
    484       try:
    485         if callable_runner:
    486           retvals = callable_runner(*callable_runner_args,
    487                                     options=decorated_run_options,
    488                                     run_metadata=run_metadata)
    489         else:
    490           retvals = self._sess.run(fetches,
    491                                    feed_dict=feed_dict,
    492                                    options=decorated_run_options,
    493                                    run_metadata=run_metadata)
    494       except errors.OpError as op_error:
    495         if self._pass_through_operrors:
    496           raise op_error
    497         tf_error = op_error
    498         retvals = op_error
    499 
    500       run_end_req = OnRunEndRequest(
    501           run_start_resp.action,
    502           run_metadata=run_metadata,
    503           client_graph_def=self._sess.graph.as_graph_def(),
    504           tf_error=tf_error)
    505 
    506     elif run_start_resp.action == OnRunStartAction.PROFILE_RUN:
    507       decorated_run_options = options or config_pb2.RunOptions()
    508       run_metadata = run_metadata or config_pb2.RunMetadata()
    509       self._decorate_run_options_for_profile(decorated_run_options)
    510       if callable_runner:
    511         retvals = callable_runner(*callable_runner_args,
    512                                   options=decorated_run_options,
    513                                   run_metadata=run_metadata)
    514       else:
    515         retvals = self._sess.run(fetches,
    516                                  feed_dict=feed_dict,
    517                                  options=decorated_run_options,
    518                                  run_metadata=run_metadata)
    519       run_end_req = OnRunEndRequest(
    520           run_start_resp.action,
    521           run_metadata=run_metadata,
    522           client_graph_def=self._sess.graph.as_graph_def())
    523     elif (run_start_resp.action == OnRunStartAction.NON_DEBUG_RUN or
    524           run_start_resp.action == OnRunStartAction.INVOKE_STEPPER):
    525       if callable_runner:
    526         raise NotImplementedError(
    527             "Non-debug and stepper modes are not implemented for callables "
    528             "created by Session.make_callable().")
    529 
    530       if run_start_resp.action == OnRunStartAction.INVOKE_STEPPER:
    531         with stepper.NodeStepper(
    532             self._sess, fetches, feed_dict) as node_stepper:
    533           retvals = self.invoke_node_stepper(
    534               node_stepper, restore_variable_values_on_exit=True)
    535 
    536       # Invoke run() method of the wrapped session.
    537       retvals = self._sess.run(
    538           fetches,
    539           feed_dict=feed_dict,
    540           options=options,
    541           run_metadata=run_metadata)
    542 
    543       # Prepare arg for the on-run-end callback.
    544       run_end_req = OnRunEndRequest(run_start_resp.action)
    545     else:
    546       raise ValueError(
    547           "Invalid OnRunStartAction value: %s" % run_start_resp.action)
    548 
    549     # Invoke on-run-end callback and obtain response.
    550     run_end_resp = self.on_run_end(run_end_req)
    551     _check_type(run_end_resp, OnRunEndResponse)
    552     # Currently run_end_resp is only a placeholder. No action is taken on it.
    553 
    554     return retvals
    555 
    556   def _is_disabled_thread(self):
    557     thread_name = threading.current_thread().name or ""
    558     return (self._thread_name_filter_pattern and
    559             not self._thread_name_filter_pattern.match(thread_name))
    560 
    561   def run_step_fn(self, step_fn):
    562     return step_fn(
    563         monitored_session.MonitoredSession.StepContext(self._sess, self.run))
    564 
    565   def partial_run_setup(self, fetches, feeds=None):
    566     """Sets up the feeds and fetches for partial runs in the session."""
    567     raise NotImplementedError(
    568         "partial_run_setup is not implemented for debug-wrapper sessions.")
    569 
    570   def partial_run(self, handle, fetches, feed_dict=None):
    571     raise NotImplementedError(
    572         "partial_run is not implemented for debug-wrapper sessions.")
    573 
    574   def list_devices(self, *args, **kwargs):
    575     return self._sess.list_devices(*args, **kwargs)
    576 
    577   def reset(self, *args, **kwargs):
    578     return self._sess.reset(*args, **kwargs)
    579 
    580   def make_callable(self,
    581                     fetches,
    582                     feed_list=None,
    583                     accept_options=False):
    584     runner = self._sess.make_callable(
    585         fetches, feed_list=feed_list, accept_options=True)
    586     def wrapped_runner(*runner_args, **kwargs):
    587       return self.run(None,
    588                       feed_dict=None,
    589                       options=kwargs.get("options", None),
    590                       run_metadata=kwargs.get("run_metadata", None),
    591                       callable_runner=runner,
    592                       callable_runner_args=runner_args)
    593 
    594     return wrapped_runner
    595 
    596   @property
    597   def run_call_count(self):
    598     return self._run_call_count
    599 
    600   def increment_run_call_count(self):
    601     self._run_call_count += 1
    602 
    603   def _decorate_run_options_for_debug(
    604       self,
    605       run_options,
    606       debug_urls,
    607       debug_ops="DebugIdentity",
    608       node_name_regex_whitelist=None,
    609       op_type_regex_whitelist=None,
    610       tensor_dtype_regex_whitelist=None,
    611       tolerate_debug_op_creation_failures=False):
    612     """Modify a RunOptions object for debug tensor watching.
    613 
    614     Specifies request for outputting partition graphs. Adds
    615     debug_tensor_watch_opts with proper debug URLs.
    616 
    617     Args:
    618       run_options: (RunOptions) the RunOptions object to be modified in place.
    619       debug_urls: (list of str) debug URLs to be entered in run_options.
    620         debug_tensor_watch_opts.
    621       debug_ops: (str or list of str) debug op(s) to be used by the debugger.
    622       node_name_regex_whitelist: Regular-expression whitelist for node
    623         name.
    624       op_type_regex_whitelist: Regular-expression whitelist for op type.
    625       tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor
    626         dtype.
    627       tolerate_debug_op_creation_failures: Whether debug op creation failures
    628         are to be tolerated.
    629     """
    630 
    631     run_options.output_partition_graphs = True
    632     debug_utils.watch_graph(
    633         run_options,
    634         self._sess.graph,
    635         debug_urls=debug_urls,
    636         debug_ops=debug_ops,
    637         node_name_regex_whitelist=node_name_regex_whitelist,
    638         op_type_regex_whitelist=op_type_regex_whitelist,
    639         tensor_dtype_regex_whitelist=tensor_dtype_regex_whitelist,
    640         tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures)
    641 
    642   def _decorate_run_options_for_profile(self, run_options):
    643     """Modify a RunOptions object for profiling TensorFlow graph execution.
    644 
    645     Args:
    646       run_options: (RunOptions) the RunOptions object to be modified in place.
    647     """
    648 
    649     run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
    650 
    651   @abc.abstractmethod
    652   def on_session_init(self, request):
    653     """Callback invoked during construction of the debug-wrapper session.
    654 
    655     This is a blocking callback.
    656     The invocation happens right before the constructor ends.
    657 
    658     Args:
    659       request: (`OnSessionInitRequest`) callback request carrying information
    660         such as the session being wrapped.
    661 
    662     Returns:
    663       An instance of `OnSessionInitResponse`.
    664     """
    665 
    666   @abc.abstractmethod
    667   def on_run_start(self, request):
    668     """Callback invoked on run() calls to the debug-wrapper session.
    669 
    670     This is a blocking callback.
    671     The invocation happens after the wrapper's run() call is entered,
    672     after an increment of run call counter.
    673 
    674     Args:
    675       request: (`OnRunStartRequest`) callback request object carrying
    676         information about the run call such as the fetches, feed dict, run
    677         options, run metadata, and how many `run()` calls to this wrapper
    678         session have occurred.
    679 
    680     Returns:
    681       An instance of `OnRunStartResponse`, carrying
    682         1) the action that the wrapper session is to perform (e.g., run
    683           with or without debug tensor watching, or invoke the stepper).
    684         2) the debug URLs used to watch the tensors.
    685     """
    686 
    687   @abc.abstractmethod
    688   def on_run_end(self, request):
    689     """Callback invoked at the end of run() calls to the debug-wrapper session.
    690 
    691     This is a blocking callback.
    692     The invocation happens right before the wrapper exits its run() call.
    693 
    694     Args:
    695       request: (`OnRunEndRequest`) callback request object carrying information
    696         such as the actual action performed by the session wrapper for the
    697         run() call.
    698 
    699     Returns:
    700       An instance of `OnRunEndResponse`.
    701     """
    702 
    703   def as_default(self):
    704     return ops.default_session(self)
    705 
    706   def __enter__(self):
    707     if self._default_session_context_manager is None:
    708       self._default_session_context_manager = self.as_default()
    709     return self._default_session_context_manager.__enter__()
    710 
    711   def __exit__(self, exec_type, exec_value, exec_tb):
    712     self._default_session_context_manager.__exit__(
    713         exec_type, exec_value, exec_tb)
    714 
    715   def __del__(self):
    716     if hasattr(self._sess, "__del__"):
    717       self._sess.__del__()
    718 
    719   def close(self):
    720     self._sess.close()
    721 
    722   # TODO(cais): Add _node_name_regex_whitelist and
    723   #   _node_op_type_regex_whitelist.
    724 
    725   @abc.abstractmethod
    726   def invoke_node_stepper(self,
    727                           node_stepper,
    728                           restore_variable_values_on_exit=True):
    729     """Callback invoked when the client intends to step through graph nodes.
    730 
    731     Args:
    732       node_stepper: (stepper.NodeStepper) An instance of NodeStepper to be used
    733         in this stepping session.
    734       restore_variable_values_on_exit: (bool) Whether any variables whose values
    735         have been altered during this node-stepper invocation should be restored
    736         to their old values when this invocation ends.
    737 
    738     Returns:
    739       The same return values as the `Session.run()` call on the same fetches as
    740         the NodeStepper.
    741     """
    742 
    743   def should_stop(self):
    744     if hasattr(self._sess, "should_stop"):
    745       return self._sess.should_stop()
    746     else:
    747       raise ValueError(
    748           "The wrapped session %r does not have a method called 'should_stop'. "
    749           "Do you intend to wrap a tf.MonitoredSession instead?" % self._sess)
    750 
    751 
    752 class WatchOptions(object):
    753   """Type for return values of watch_fn."""
    754 
    755   def __init__(self,
    756                debug_ops=None,
    757                node_name_regex_whitelist=None,
    758                op_type_regex_whitelist=None,
    759                tensor_dtype_regex_whitelist=None,
    760                tolerate_debug_op_creation_failures=False):
    761     """Constructor of WatchOptions: Debug watch options.
    762 
    763     Used as return values of `watch_fn`s.
    764 
    765     Args:
    766       debug_ops: (`str` or `list` of `str`) Debug ops to be used.
    767       node_name_regex_whitelist: Regular-expression whitelist for node_name,
    768         e.g., `"(weight_[0-9]+|bias_.*)"`
    769       op_type_regex_whitelist: Regular-expression whitelist for the op type of
    770         nodes, e.g., `"(Variable|Add)"`.
    771         If both `node_name_regex_whitelist` and `op_type_regex_whitelist`
    772         are set, the two filtering operations will occur in a logical `AND`
    773         relation. In other words, a node will be included if and only if it
    774         hits both whitelists.
    775       tensor_dtype_regex_whitelist: Regular-expression whitelist for Tensor
    776         data type, e.g., `"^int.*"`.
    777         This whitelist is combined with the two whitelists above in a logical
    778         `AND` relation.
    779       tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
    780         failures (e.g., due to dtype incompatibility) are to be tolerated by not
    781         throwing exceptions.
    782     """
    783     if debug_ops:
    784       self.debug_ops = debug_ops
    785     else:
    786       self.debug_ops = ["DebugIdentity"]
    787     self.node_name_regex_whitelist = node_name_regex_whitelist
    788     self.op_type_regex_whitelist = op_type_regex_whitelist
    789     self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist
    790     self.tolerate_debug_op_creation_failures = (
    791         tolerate_debug_op_creation_failures)
    792 
    793   def __repr__(self):
    794     return ("WatchOptions(debug_ops=%r, node_name_regex_whitelist=%r, "
    795             "op_type_regex_whitelist=%r, tensor_dtype_regex_whitelist=%r, "
    796             "tolerate_debug_op_creation_failures=%r)" % (
    797                 self.debug_ops, self.node_name_regex_whitelist,
    798                 self.op_type_regex_whitelist, self.tensor_dtype_regex_whitelist,
    799                 self.tolerate_debug_op_creation_failures))
    800 
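# Illustrative sketch only, not part of this module: one way the three regex
# whitelists documented in WatchOptions could combine in a logical AND.
# The helper name `passes_whitelists` and the use of `re.match` are
# assumptions made for this example; the actual filtering is performed
# elsewhere in tfdbg and may anchor its patterns differently.

```python
import re


def passes_whitelists(node_name, op_type, dtype_name,
                      node_name_regex_whitelist=None,
                      op_type_regex_whitelist=None,
                      tensor_dtype_regex_whitelist=None):
  """Hypothetical helper: True iff the node hits every whitelist that is set."""
  checks = [
      (node_name_regex_whitelist, node_name),
      (op_type_regex_whitelist, op_type),
      (tensor_dtype_regex_whitelist, dtype_name),
  ]
  # A whitelist left as None imposes no constraint on the node.
  return all(re.match(pattern, value) is not None
             for pattern, value in checks if pattern is not None)
```

# With both a node-name and an op-type whitelist set, a node named "weight_1"
# is kept only if its op type also matches, which is the AND relation
# described in the WatchOptions docstring.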
    801 
    802 class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
    803   """Base class for non-interactive (i.e., non-CLI) debug wrapper sessions."""
    804 
    805   def __init__(self, sess, watch_fn=None, thread_name_filter=None,
    806                pass_through_operrors=False):
    807     """Constructor of NonInteractiveDebugWrapperSession.
    808 
    809     Args:
    810       sess: The TensorFlow `Session` object being wrapped.
    811       watch_fn: (`Callable`) A Callable that maps the fetches and feeds of a
    812         debugged `Session.run()` call to `WatchOptions`.
    813         * Args:
    814           * `fetches`: the fetches to the `Session.run()` call.
    815           * `feeds`: the feeds to the `Session.run()` call.
    816 
    817         * Returns:
    818          (`tf_debug.WatchOptions`) An object containing debug options including
    819            the debug ops to use, the node names, op types and/or tensor data
    820            types to watch, etc. See the documentation of `tf_debug.WatchOptions`
    821            for more details.
    822       thread_name_filter: Regular-expression whitelist for threads on which the
    823         wrapper session will be active. See doc of `BaseDebugWrapperSession` for
    824         more details.
    825       pass_through_operrors: If true, all captured OpErrors will be
    826         propagated. By default, the wrapper captures all OpErrors.
    827     Raises:
    828       TypeError: If a non-None `watch_fn` is specified and it is not callable.
    829     """
    830 
    831     BaseDebugWrapperSession.__init__(
    832         self, sess, thread_name_filter=thread_name_filter,
    833         pass_through_operrors=pass_through_operrors)
    834 
    835     self._watch_fn = None
    836     if watch_fn is not None:
    837       if not callable(watch_fn):
    838         raise TypeError("watch_fn is not callable")
    839       self._watch_fn = watch_fn
    840 
    841   def on_session_init(self, request):
    842     """See doc of BaseDebugWrapperSession.on_session_init."""
    843 
    844     return OnSessionInitResponse(OnSessionInitAction.PROCEED)
    845 
    846   @abc.abstractmethod
    847   def prepare_run_debug_urls(self, fetches, feed_dict):
    848     """Abstract method to be implemented by concrete subclasses.
    849 
    850     This method prepares the run-specific debug URL(s).
    851 
    852     Args:
    853       fetches: Same as the `fetches` argument to `Session.run()`
    854       feed_dict: Same as the `feed_dict` argument to `Session.run()`
    855 
    856     Returns:
    857       debug_urls: (`str` or `list` of `str`) Debug URLs to be used in
    858         this `Session.run()` call.
    859     """
    860 
    861   def on_run_start(self, request):
    862     """See doc of BaseDebugWrapperSession.on_run_start."""
    863 
    864     debug_urls, watch_opts = self._prepare_run_watch_config(
    865         request.fetches, request.feed_dict)
    866 
    867     return OnRunStartResponse(
    868         OnRunStartAction.DEBUG_RUN,
    869         debug_urls,
    870         debug_ops=watch_opts.debug_ops,
    871         node_name_regex_whitelist=watch_opts.node_name_regex_whitelist,
    872         op_type_regex_whitelist=watch_opts.op_type_regex_whitelist,
    873         tensor_dtype_regex_whitelist=watch_opts.tensor_dtype_regex_whitelist,
    874         tolerate_debug_op_creation_failures=(
    875             watch_opts.tolerate_debug_op_creation_failures))
    876 
    877   def _prepare_run_watch_config(self, fetches, feed_dict):
    878     """Get the debug URLs and watch options for the current run() call.
    879 
    880     Args:
    881       fetches: Same as the `fetches` argument to `Session.run()`.
    882       feed_dict: Same as the `feed_dict` argument to `Session.run()`.
    883 
    884     Returns:
    885       debug_urls: (str or list of str) Debug URLs for the current run() call.
    886         Currently, the list consists of only one URL that is a file:// URL.
    887       watch_options: (WatchOptions) The return value of a watch_fn, containing
    888         options including debug_ops and whitelists.
    889     """
    890 
    891     debug_urls = self.prepare_run_debug_urls(fetches, feed_dict)
    892     if self._watch_fn is None:
    893       watch_options = WatchOptions()
    894     else:
    895       watch_options = self._watch_fn(fetches, feed_dict)
    896       if isinstance(watch_options, tuple):
    897         # For legacy return type (tuples).
    898         watch_options = WatchOptions(*watch_options)
    899 
    900     return debug_urls, watch_options
    901 
    902   def on_run_end(self, request):
    903     """See doc of BaseDebugWrapperSession.on_run_end."""
    904 
    905     return OnRunEndResponse()
    906 
    907   def invoke_node_stepper(self,
    908                           node_stepper,
    909                           restore_variable_values_on_exit=True):
    910     """See doc of BaseDebugWrapperSession.invoke_node_stepper."""
    911 
    912     raise NotImplementedError(
    913         "NonInteractiveDebugWrapperSession does not support node-stepper mode.")
    914
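# Illustrative sketch only, not part of this module: how a `watch_fn` return
# value is resolved into watch options, mirroring the branching in
# `_prepare_run_watch_config` above. `WatchOptionsSketch` is a stand-in that
# copies only the constructor argument order of `WatchOptions` so the example
# runs without TensorFlow; `resolve_watch_options`, `legacy_watch_fn`, and
# `modern_watch_fn` are hypothetical names introduced for this example.

```python
class WatchOptionsSketch(object):
  """Stand-in with the same constructor argument order as WatchOptions."""

  def __init__(self,
               debug_ops=None,
               node_name_regex_whitelist=None,
               op_type_regex_whitelist=None,
               tensor_dtype_regex_whitelist=None,
               tolerate_debug_op_creation_failures=False):
    # Fall back to the default debug op when none is given.
    self.debug_ops = debug_ops if debug_ops else ["DebugIdentity"]
    self.node_name_regex_whitelist = node_name_regex_whitelist
    self.op_type_regex_whitelist = op_type_regex_whitelist
    self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist
    self.tolerate_debug_op_creation_failures = (
        tolerate_debug_op_creation_failures)


def resolve_watch_options(watch_fn, fetches, feed_dict):
  """Returns default options, or the (possibly legacy) watch_fn result."""
  if watch_fn is None:
    return WatchOptionsSketch()
  watch_options = watch_fn(fetches, feed_dict)
  if isinstance(watch_options, tuple):
    # Legacy return type: a tuple in constructor-argument order.
    watch_options = WatchOptionsSketch(*watch_options)
  return watch_options


def legacy_watch_fn(fetches, feed_dict):
  """Hypothetical watch_fn using the legacy tuple return type."""
  del fetches, feed_dict  # Unused in this sketch.
  return ["DebugNumericSummary"], r"weight_.*", None


def modern_watch_fn(fetches, feed_dict):
  """Hypothetical watch_fn returning an options object directly."""
  del fetches, feed_dict  # Unused in this sketch.
  return WatchOptionsSketch(debug_ops=["DebugNumericSummary"],
                            node_name_regex_whitelist=r"weight_.*")
```

# Both watch functions resolve to equivalent options; the tuple form relies on
# positional order, so the object form is less error-prone when only some
# fields are set.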