Home | History | Annotate | Download | only in eager
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Execution Callbacks for Eager Mode."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import functools
     22 
     23 import numpy as np
     24 
     25 from tensorflow.python import pywrap_tensorflow
     26 from tensorflow.python.eager import context
     27 from tensorflow.python.eager import core
     28 from tensorflow.python.eager import execute
     29 from tensorflow.python.platform import tf_logging as logging
     30 
     31 _DEFAULT_CALLBACK_ACTION = "raise"
     32 _VALID_CALLBACK_ACTIONS = (None, "ignore", "print", "raise", "warn")
     33 
     34 
     35 # TODO(cais): Consider moving this exception class to errors_impl.py.
     36 class InfOrNanError(Exception):
     37   """Exception for inf and/or nan being present in tensor."""
     38 
     39   def __init__(self,
     40                op_type,
     41                op_name,
     42                output_index,
     43                num_outputs,
     44                value):
     45     """Constructor of InfOrNanError.
     46 
     47     Args:
     48       op_type: Type name of the op that generated the tensor that generated the
     49         `inf`(s) or `nan`(s) (e.g., `Div`).
     50       op_name: Name of the op that generated the tensor with `inf`(s) or
     51         `nan`(s). This name is set by client and can be `None` if it is unset.
     52       output_index: The 0-based output index of the tensor that contains
     53         `inf`(s) or `nan`(s).
     54       num_outputs: Total number of outputs of the operation.
     55       value: The tensor value that contains `inf`(s) or `nan`(s).
     56     """
     57     self._op_type = op_type
     58     self._op_name = op_name
     59     self._output_index = output_index
     60     self._num_outputs = num_outputs
     61     self._value = value
     62 
     63     self._total_count = np.size(value)
     64     self._inf_count = np.count_nonzero(np.isinf(value))
     65     self._nan_count = np.count_nonzero(np.isnan(value))
     66 
     67     super(InfOrNanError, self).__init__(self._get_error_message())
     68 
     69   def _get_error_message(self):
     70     """Get the error message describing this InfOrNanError object."""
     71     name_str = (("'%s'" % self._op_name) if self._op_name is not None
     72                 else str(self._op_name))
     73     msg = "Output %d of %d of TFE operation %s (name: %s) contains " % (
     74         self._output_index + 1, self._num_outputs, self._op_type, name_str)
     75     if self._inf_count and self._nan_count:
     76       msg += "%d inf(s) and %d nan(s) " % (self._inf_count, self._nan_count)
     77     elif self._inf_count:
     78       msg += "%d inf(s) " % self._inf_count
     79     else:
     80       msg += "%d nan(s) " % self._nan_count
     81     msg += "out of a total of %d element(s). Tensor value: %s" % (
     82         self._total_count, self._value)
     83     return msg
     84 
     85   @property
     86   def op_type(self):
     87     return self._op_type
     88 
     89   @property
     90   def op_name(self):
     91     return self._op_name
     92 
     93   @property
     94   def output_index(self):
     95     return self._output_index
     96 
     97   @property
     98   def num_outputs(self):
     99     return self._num_outputs
    100 
    101   @property
    102   def value(self):
    103     return self._value
    104 
    105 
    106 def inf_nan_callback(op_type,
    107                      inputs,
    108                      attrs,
    109                      outputs,
    110                      op_name,
    111                      check_inf=True,
    112                      check_nan=True,
    113                      action=_DEFAULT_CALLBACK_ACTION):
    114   """An execution callback that checks for `inf`s and `nan`s in output tensors.
    115 
    116   This callback can be used with `tfe.add_execute_callback` to check for invalid
    117   numeric values. E.g.,
    118   ```python
    119   tfe.add_execute_callback(tfe.inf_nan_callback)
    120   ```
    121 
    122   Args:
    123     op_type: Name of the TFE operation type (e.g., `MatMul`).
    124     inputs: The `list` of input tensors to the operation, currently unused by
    125       this callback.
    126     attrs: Attributes of the TFE operation, as a tuple of alternating attribute
    127       names and attribute values.
    128     outputs: The `list` of output tensors from the operation, checked by this
    129       callback for `inf` and `nan` values.
    130     op_name: Name of the TFE operation. This name is set by client and can be
    131       `None` if it unset.
    132     check_inf: (`bool`) Whether this callback should check for `inf` values in
    133       the output tensor values.
    134     check_nan: (`bool`) Whether this callback should check for `nan` values in
    135       the output tensor values.
    136     action: (`str`) Action to be taken by the callback when `inf` or `nan`
    137       values are detected. Possible values {"raise", "warn", "print"}
    138       `"raise"`: Raise a `InfOrNanError`.
    139       `"warn"`: Log a warning using `tf.logging.warn`.
    140       `"print"`: Print a message to `sys.stdout`.
    141 
    142   Raises:
    143     InfOrNanError: iff `inf` or `nan` values are seen in any of `outputs` and
    144       `action` is `"raise"`.
    145     ValueError: iff the value of `action` is invalid.
    146   """
    147   del attrs, inputs  # Not used.
    148 
    149   ctx = context.get_default_context()
    150 
    151   for index, output in enumerate(outputs):
    152     if not output.dtype.is_numpy_compatible:
    153       continue
    154 
    155     numpy_dtype = output.dtype.as_numpy_dtype
    156     if (np.issubdtype(numpy_dtype, np.floating) or
    157         np.issubdtype(numpy_dtype, np.complex) or
    158         np.issubdtype(numpy_dtype, np.integer)):
    159       try:
    160         check_numerics_op_attrs = (
    161             "message", "Eager-mode inf/nan check",
    162             "T", outputs[0].dtype.as_datatype_enum)
    163         # TODO(cais): Consider moving this into execute.py.
    164         # pylint: disable=protected-access
    165         pywrap_tensorflow.TFE_Py_Execute(
    166             ctx._handle, output.device, "CheckNumerics", [output],
    167             check_numerics_op_attrs, 1)
    168         # pylint: enable=protected-access
    169       except core._NotOkStatusException:  # pylint: disable=protected-access
    170         value = output.numpy()
    171         inf_detected = np.any(np.isinf(value)) and check_inf
    172         nan_detected = np.any(np.isnan(value)) and check_nan
    173         if not inf_detected and not nan_detected:
    174           continue
    175 
    176         error = InfOrNanError(op_type, op_name, index, len(outputs), value)
    177         if action == "print":
    178           print("Warning: %s" % str(error))
    179         elif action == "warn":
    180           logging.warn(str(error))
    181         elif action == "raise":
    182           raise error
    183         else:
    184           raise ValueError(
    185               "Invalid action for inf_nan_callback: %s. Valid actions are: "
    186               "{print | warn | raise}" % action)
    187 
    188 
    189 def inf_callback(op_type,
    190                  inputs,
    191                  attrs,
    192                  outputs,
    193                  op_name,
    194                  action=_DEFAULT_CALLBACK_ACTION):
    195   """A specialization of `inf_nan_callback` that checks for `inf`s only."""
    196   inf_nan_callback(
    197       op_type,
    198       inputs,
    199       attrs,
    200       outputs,
    201       op_name,
    202       check_inf=True,
    203       check_nan=False,
    204       action=action)
    205 
    206 
    207 def nan_callback(op_type,
    208                  inputs,
    209                  attrs,
    210                  outputs,
    211                  op_name,
    212                  action=_DEFAULT_CALLBACK_ACTION):
    213   """A specialization of `inf_nan_callback` that checks for `nan`s only."""
    214   inf_nan_callback(
    215       op_type,
    216       inputs,
    217       attrs,
    218       outputs,
    219       op_name,
    220       check_inf=False,
    221       check_nan=True,
    222       action=action)
    223 
    224 
    225 def add_execution_callback(callback):
    226   """Add an execution callback to the default eager context.
    227 
    228   An execution callback is invoked immediately after an eager operation or
    229   function has finished execution, providing access to the op's type, name
    230   input and output tensors. Multiple execution callbacks can be added, in
    231   which case the callbacks will be invoked in the order in which they are
    232   added. To clear all execution callbacks that have been added, use
    233   `clear_execution_callbacks()`.
    234 
    235   Example:
    236   ```python
    237   def print_even_callback(op_type, op_name, attrs, inputs, outputs):
    238     # A callback that prints only the even output values.
    239     if outputs[0].numpy() % 2 == 0:
    240       print("Even output from %s: %s" % (op_name or op_type,  outputs))
    241   tfe.add_execution_callback(print_even_callback)
    242 
    243   x = tf.pow(2.0, 3.0) - 3.0
    244   y = tf.multiply(x, tf.add(1.0, 5.0))
    245   # When the line above is run, you will see all intermediate outputs that are
    246   # even numbers printed to the console.
    247 
    248   tfe.clear_execution_callbacks()
    249   ```
    250 
    251   Args:
    252     callback: a callable of the signature
    253       `f(op_type, op_name, attrs, inputs, outputs)`.
    254       `op_type` is the type of the operation that was just executed (e.g.,
    255         `MatMul`).
    256       `op_name` is the name of the operation that has was just executed. This
    257         name is set by the client who created the operation and can be `None` if
    258         it is unset.
    259       `attrs` contains the attributes of the operation as a `tuple` of
    260         alternating attribute name and attribute value.
    261       `inputs` is the `list` of input `Tensor`(s) to the op.
    262       `outputs` is the `list` of output `Tensor`(s) from the op.
    263        Return value(s) from the callback are ignored.
    264   """
    265   execute.execute = execute.execute_with_callbacks
    266   context.get_default_context().add_post_execution_callback(callback)
    267 
    268 
    269 def clear_execution_callbacks():
    270   """Clear all execution callbacks from the default eager context."""
    271   context.get_default_context().clear_post_execution_callbacks()
    272 
    273 
    274 def seterr(inf_or_nan=None):
    275   """Set how abnormal conditions are handled by the default eager context.
    276 
    277   Example:
    278   ```python
    279   tfe.seterr(inf_or_nan="raise")
    280   a = tf.constant(10.0)
    281   b = tf.constant(0.0)
    282   try:
    283     c = a / b  # <-- Raises InfOrNanError.
    284   except Exception as e:
    285     print("Caught Exception: %s" % e)
    286 
    287   tfe.seterr(inf_or_nan="ignore")
    288   c = a / b  # <-- Does NOT raise exception anymore.
    289   ```
    290 
    291   Args:
    292     inf_or_nan: Set action for infinity (`inf`) and NaN (`nan`) values.
    293       Possible values: `{"ignore", "print", "raise", "warn"}`.
    294       `"ignore"`: take no action when `inf` values appear.
    295       `"print"`: print a warning to `stdout`.
    296       `"raise"`: raise an `InfOrNanError`.
    297       `"warn"`: print a warning using `tf.logging.warn`.
    298       A value of `None` leads to no change in the action of the condition.
    299 
    300   Returns:
    301     A dictionary of old actions.
    302 
    303   Raises:
    304     ValueError: If the value of any keyword arguments is invalid.
    305   """
    306   if inf_or_nan not in _VALID_CALLBACK_ACTIONS:
    307     raise ValueError(
    308         "Invalid action value for inf_or_nan: %s. "
    309         "Valid actions are %s." % (inf_or_nan, _VALID_CALLBACK_ACTIONS))
    310 
    311   old_settings = {"inf_or_nan": "ignore"}
    312   default_context = context.get_default_context()
    313 
    314   carryover_callbacks = []
    315   for callback in default_context.post_execution_callbacks:
    316     # Check whether the callback is inf_nan_callback or a partial object of
    317     # inf_nan_callback.
    318     if (callback == inf_nan_callback or
    319         isinstance(callback, functools.partial) and
    320         callback.func == inf_nan_callback):
    321       if callback == inf_nan_callback:
    322         old_settings["inf_or_nan"] = _DEFAULT_CALLBACK_ACTION
    323       else:
    324         old_settings["inf_or_nan"] = callback.keywords.get(
    325             "action", _DEFAULT_CALLBACK_ACTION)
    326     elif inf_or_nan is not None:
    327       carryover_callbacks.append(callback)
    328 
    329   if inf_or_nan is not None:
    330     default_context.clear_post_execution_callbacks()
    331     for callback in carryover_callbacks:
    332       default_context.add_post_execution_callback(callback)
    333     if inf_or_nan != "ignore":
    334       default_context.add_post_execution_callback(
    335           functools.partial(inf_nan_callback, action=inf_or_nan))
    336 
    337   return old_settings
    338