Home | History | Annotate | Download | only in eager
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Execution Callbacks for Eager Mode."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import contextlib
     22 import functools
     23 import enum  # pylint: disable=g-bad-import-order
     24 
     25 import numpy as np
     26 
     27 from tensorflow.python import pywrap_tensorflow
     28 from tensorflow.python.eager import context
     29 from tensorflow.python.eager import core
     30 from tensorflow.python.eager import execute
     31 from tensorflow.python.platform import tf_logging as logging
     32 
     33 
     34 class ExecutionCallback(enum.Enum):
     35   """Valid callback actions.
     36 
     37   These can be passed to `seterr` or `errstate` to create callbacks when
     38   specific events occur (e.g. an operation produces `NaN`s).
     39 
     40   IGNORE: take no action.
     41   PRINT:  print a warning to `stdout`.
     42   RAISE:  raise an error (e.g. `InfOrNanError`).
     43   WARN:   print a warning using `tf.logging.warn`.
     44   """
     45 
     46   IGNORE = "ignore"
     47   PRINT = "print"
     48   RAISE = "raise"
     49   WARN = "warn"
     50 
     51 _DEFAULT_CALLBACK_ACTION = ExecutionCallback.RAISE
     52 
     53 
     54 # TODO(cais): Consider moving this exception class to errors_impl.py.
     55 class InfOrNanError(Exception):
     56   """Exception for inf and/or nan being present in tensor."""
     57 
     58   def __init__(self,
     59                op_type,
     60                op_name,
     61                output_index,
     62                num_outputs,
     63                value):
     64     """Constructor of InfOrNanError.
     65 
     66     Args:
     67       op_type: Type name of the op that generated the tensor with
     68         `inf`(s) or `nan`(s) (e.g., `Div`).
     69       op_name: Name of the op that generated the tensor with `inf`(s) or
     70         `nan`(s). This name is set by client and can be `None` if it is unset.
     71       output_index: The 0-based output index of the tensor that contains
     72         `inf`(s) or `nan`(s).
     73       num_outputs: Total number of outputs of the operation.
     74       value: The tensor value that contains `inf`(s) or `nan`(s).
     75     """
     76     self._op_type = op_type
     77     self._op_name = op_name
     78     self._output_index = output_index
     79     self._num_outputs = num_outputs
     80     self._value = value
     81 
     82     self._total_count = np.size(value)
     83     self._inf_count = np.count_nonzero(np.isinf(value))
     84     self._nan_count = np.count_nonzero(np.isnan(value))
     85 
     86     super(InfOrNanError, self).__init__(self._get_error_message())
     87 
     88   def _get_error_message(self):
     89     """Get the error message describing this InfOrNanError object."""
     90     name_str = (("'%s'" % self._op_name) if self._op_name is not None
     91                 else str(self._op_name))
     92     msg = "Output %d of %d of TFE operation %s (name: %s) contains " % (
     93         self._output_index + 1, self._num_outputs, self._op_type, name_str)
     94     if self._inf_count and self._nan_count:
     95       msg += "%d inf(s) and %d nan(s) " % (self._inf_count, self._nan_count)
     96     elif self._inf_count:
     97       msg += "%d inf(s) " % self._inf_count
     98     else:
     99       msg += "%d nan(s) " % self._nan_count
    100     msg += "out of a total of %d element(s). Tensor value: %s" % (
    101         self._total_count, self._value)
    102     return msg
    103 
    104   @property
    105   def op_type(self):
    106     return self._op_type
    107 
    108   @property
    109   def op_name(self):
    110     return self._op_name
    111 
    112   @property
    113   def output_index(self):
    114     return self._output_index
    115 
    116   @property
    117   def num_outputs(self):
    118     return self._num_outputs
    119 
    120   @property
    121   def value(self):
    122     return self._value
    123 
    124 
    125 def inf_nan_callback(op_type,
    126                      inputs,
    127                      attrs,
    128                      outputs,
    129                      op_name,
    130                      check_inf=True,
    131                      check_nan=True,
    132                      action=_DEFAULT_CALLBACK_ACTION):
    133   """An execution callback that checks for `inf`s and `nan`s in output tensors.
    134 
    135   This callback can be used with `tfe.add_execute_callback` to check for invalid
    136   numeric values. E.g.,
    137   ```python
    138   tfe.add_execute_callback(tfe.inf_nan_callback)
    139   ```
    140 
    141   Args:
    142     op_type: Name of the TFE operation type (e.g., `MatMul`).
    143     inputs: The `list` of input tensors to the operation, currently unused by
    144       this callback.
    145     attrs: Attributes of the TFE operation, as a tuple of alternating attribute
    146       names and attribute values.
    147     outputs: The `list` of output tensors from the operation, checked by this
    148       callback for `inf` and `nan` values.
    149     op_name: Name of the TFE operation. This name is set by client and can be
    150       `None` if it unset.
    151     check_inf: (`bool`) Whether this callback should check for `inf` values in
    152       the output tensor values.
    153     check_nan: (`bool`) Whether this callback should check for `nan` values in
    154       the output tensor values.
    155     action: (`ExecutionCallback`) Action to be taken by the callback when
    156       `inf` or `nan` values are detected.
    157 
    158   Raises:
    159     InfOrNanError: iff `inf` or `nan` values are seen in any of `outputs` and
    160       `action` is `"raise"`.
    161     ValueError: iff the value of `action` is invalid.
    162   """
    163   del attrs, inputs  # Not used.
    164 
    165   action = ExecutionCallback(action)
    166   ctx = context.context()
    167 
    168   for index, output in enumerate(outputs):
    169     if not output.dtype.is_numpy_compatible:
    170       continue
    171 
    172     numpy_dtype = output.dtype.as_numpy_dtype
    173     if (np.issubdtype(numpy_dtype, np.floating) or
    174         np.issubdtype(numpy_dtype, np.complex) or
    175         np.issubdtype(numpy_dtype, np.integer)):
    176       try:
    177         check_numerics_op_attrs = (
    178             "message", "Eager-mode inf/nan check",
    179             "T", outputs[0].dtype.as_datatype_enum)
    180         # TODO(cais): Consider moving this into execute.py.
    181         # pylint: disable=protected-access
    182         pywrap_tensorflow.TFE_Py_Execute(
    183             ctx._handle, output.device, "CheckNumerics", [output],
    184             check_numerics_op_attrs, 1)
    185         # pylint: enable=protected-access
    186       except core._NotOkStatusException:  # pylint: disable=protected-access
    187         value = output.numpy()
    188         inf_detected = np.any(np.isinf(value)) and check_inf
    189         nan_detected = np.any(np.isnan(value)) and check_nan
    190         if not inf_detected and not nan_detected:
    191           continue
    192 
    193         error = InfOrNanError(op_type, op_name, index, len(outputs), value)
    194         if action == ExecutionCallback.PRINT:
    195           print("Warning: %s" % str(error))
    196         elif action == ExecutionCallback.WARN:
    197           logging.warn(str(error))
    198         elif action == ExecutionCallback.RAISE:
    199           raise error
    200         else:
    201           raise ValueError(
    202               "Invalid action for inf_nan_callback: %s. Valid actions are: "
    203               "{PRINT | WARN | RAISE}" % action)
    204 
    205 
    206 def inf_callback(op_type,
    207                  inputs,
    208                  attrs,
    209                  outputs,
    210                  op_name,
    211                  action=_DEFAULT_CALLBACK_ACTION):
    212   """A specialization of `inf_nan_callback` that checks for `inf`s only."""
    213   inf_nan_callback(
    214       op_type,
    215       inputs,
    216       attrs,
    217       outputs,
    218       op_name,
    219       check_inf=True,
    220       check_nan=False,
    221       action=action)
    222 
    223 
    224 def nan_callback(op_type,
    225                  inputs,
    226                  attrs,
    227                  outputs,
    228                  op_name,
    229                  action=_DEFAULT_CALLBACK_ACTION):
    230   """A specialization of `inf_nan_callback` that checks for `nan`s only."""
    231   inf_nan_callback(
    232       op_type,
    233       inputs,
    234       attrs,
    235       outputs,
    236       op_name,
    237       check_inf=False,
    238       check_nan=True,
    239       action=action)
    240 
    241 
    242 def add_execution_callback(callback):
    243   """Add an execution callback to the default eager context.
    244 
    245   An execution callback is invoked immediately after an eager operation or
    246   function has finished execution, providing access to the op's type, name
    247   input and output tensors. Multiple execution callbacks can be added, in
    248   which case the callbacks will be invoked in the order in which they are
    249   added. To clear all execution callbacks that have been added, use
    250   `clear_execution_callbacks()`.
    251 
    252   Example:
    253   ```python
    254   def print_even_callback(op_type, op_name, attrs, inputs, outputs):
    255     # A callback that prints only the even output values.
    256     if outputs[0].numpy() % 2 == 0:
    257       print("Even output from %s: %s" % (op_name or op_type,  outputs))
    258   tfe.add_execution_callback(print_even_callback)
    259 
    260   x = tf.pow(2.0, 3.0) - 3.0
    261   y = tf.multiply(x, tf.add(1.0, 5.0))
    262   # When the line above is run, you will see all intermediate outputs that are
    263   # even numbers printed to the console.
    264 
    265   tfe.clear_execution_callbacks()
    266   ```
    267 
    268   Args:
    269     callback: a callable of the signature
    270       `f(op_type, op_name, attrs, inputs, outputs)`.
    271       `op_type` is the type of the operation that was just executed (e.g.,
    272         `MatMul`).
    273       `op_name` is the name of the operation that was just executed. This
    274         name is set by the client who created the operation and can be `None` if
    275         it is unset.
    276       `attrs` contains the attributes of the operation as a `tuple` of
    277         alternating attribute name and attribute value.
    278       `inputs` is the `list` of input `Tensor`(s) to the op.
    279       `outputs` is the `list` of output `Tensor`(s) from the op.
    280        Return value(s) from the callback are ignored.
    281   """
    282   execute.execute = execute.execute_with_callbacks
    283   context.context().add_post_execution_callback(callback)
    284 
    285 
    286 def clear_execution_callbacks():
    287   """Clear all execution callbacks from the default eager context."""
    288   context.context().clear_post_execution_callbacks()
    289 
    290 
    291 def seterr(inf_or_nan=None):
    292   """Set how abnormal conditions are handled by the default eager context.
    293 
    294   Example:
    295   ```python
    296   tfe.seterr(inf_or_nan=ExecutionCallback.RAISE)
    297   a = tf.constant(10.0)
    298   b = tf.constant(0.0)
    299   try:
    300     c = a / b  # <-- Raises InfOrNanError.
    301   except Exception as e:
    302     print("Caught Exception: %s" % e)
    303 
    304   tfe.seterr(inf_or_nan=ExecutionCallback.IGNORE)
    305   c = a / b  # <-- Does NOT raise exception anymore.
    306   ```
    307 
    308   Args:
    309     inf_or_nan: An `ExecutionCallback` determining the action for infinity
    310       (`inf`) and NaN (`nan`) values. A value of `None` leads to no change in
    311       the action of the condition.
    312 
    313   Returns:
    314     A dictionary of old actions.
    315 
    316   Raises:
    317     ValueError: If the value of any keyword arguments is invalid.
    318   """
    319   inf_or_nan = ExecutionCallback(inf_or_nan) if inf_or_nan is not None else None
    320   old_settings = {"inf_or_nan": ExecutionCallback.IGNORE}
    321   default_context = context.context()
    322 
    323   carryover_callbacks = []
    324   for callback in default_context.post_execution_callbacks:
    325     # Check whether the callback is inf_nan_callback or a partial object of
    326     # inf_nan_callback.
    327     if (callback == inf_nan_callback or
    328         isinstance(callback, functools.partial) and
    329         callback.func == inf_nan_callback):
    330       if callback == inf_nan_callback:
    331         old_settings["inf_or_nan"] = _DEFAULT_CALLBACK_ACTION
    332       else:
    333         old_settings["inf_or_nan"] = callback.keywords.get(
    334             "action", _DEFAULT_CALLBACK_ACTION)
    335     elif inf_or_nan is not None:
    336       carryover_callbacks.append(callback)
    337 
    338   if inf_or_nan is not None:
    339     default_context.clear_post_execution_callbacks()
    340     for callback in carryover_callbacks:
    341       default_context.add_post_execution_callback(callback)
    342     if inf_or_nan != ExecutionCallback.IGNORE:
    343       default_context.add_post_execution_callback(
    344           functools.partial(inf_nan_callback, action=inf_or_nan))
    345 
    346   return old_settings
    347 
    348 
    349 @contextlib.contextmanager
    350 def errstate(inf_or_nan=None):
    351   """Context manager setting error state.
    352 
    353   Example:
    354   ```
    355   c = tf.log(0.)  # -inf
    356 
    357   with errstate(inf_or_nan=ExecutionCallback.RAISE):
    358     tf.log(0.)  # <-- Raises InfOrNanError.
    359   ```
    360 
    361   Args:
    362     inf_or_nan: An `ExecutionCallback` determining the action for infinity
    363       (`inf`) and NaN (`nan`) values. A value of `None` leads to no change in
    364       the action of the condition.
    365 
    366   Yields:
    367     None.
    368 
    369   Raises:
    370     ValueError: If the value of any keyword arguments is invalid.
    371   """
    372   if not context.executing_eagerly():
    373     yield
    374   else:
    375     old_settings = seterr(inf_or_nan=inf_or_nan)
    376     yield
    377     seterr(**old_settings)
    378