Home | History | Annotate | Download | only in ops
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Connects all half, float and double tensors to CheckNumericsOp."""
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 from tensorflow.python.eager import context
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.framework import ops
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.ops import control_flow_ops
     27 from tensorflow.python.util.tf_export import tf_export
     28 
     29 
     30 @tf_export("verify_tensor_all_finite")
     31 def verify_tensor_all_finite(t, msg, name=None):
     32   """Assert that the tensor does not contain any NaN's or Inf's.
     33 
     34   Args:
     35     t: Tensor to check.
     36     msg: Message to log on failure.
     37     name: A name for this operation (optional).
     38 
     39   Returns:
     40     Same tensor as `t`.
     41   """
     42   with ops.name_scope(name, "VerifyFinite", [t]) as name:
     43     t = ops.convert_to_tensor(t, name="t")
     44     with ops.colocate_with(t):
     45       verify_input = array_ops.check_numerics(t, message=msg)
     46       out = control_flow_ops.with_dependencies([verify_input], t)
     47   return out
     48 
     49 
     50 @tf_export("add_check_numerics_ops")
     51 def add_check_numerics_ops():
     52   """Connect a `check_numerics` to every floating point tensor.
     53 
     54   `check_numerics` operations themselves are added for each `half`, `float`,
     55   or `double` tensor in the graph. For all ops in the graph, the
     56   `check_numerics` op for all of its (`half`, `float`, or `double`) inputs
     57   is guaranteed to run before the `check_numerics` op on any of its outputs.
     58 
     59   Note: This API is not compatible with the use of @{tf.cond} or
     60   @{tf.while_loop}, and will raise a `ValueError` if you attempt to call it
     61   in such a graph.
     62 
     63   Returns:
     64     A `group` op depending on all `check_numerics` ops added.
     65 
     66   Raises:
     67     ValueError: If the graph contains any numeric operations in a control flow
     68       structure.
     69     RuntimeError: If called with eager execution enabled.
     70 
     71   @compatibility(eager)
     72   Not compatible with eager execution. To check for `Inf`s and `NaN`s under
     73   eager execution, call tfe.seterr(inf_or_nan='raise') once before executing
     74   the checked operations.
     75   @enc_compatibility
     76   """
     77   if context.in_eager_mode():
     78     raise RuntimeError(
     79         "add_check_numerics_ops() is not compatible with eager execution. "
     80         "To check for Inf's and NaN's under eager execution, call "
     81         "tfe.seterr(inf_or_nan='raise') once before executing the "
     82         "checked operations.")
     83 
     84   check_op = []
     85   # This code relies on the ordering of ops in get_operations().
     86   # The producer of a tensor always comes before that tensor's consumer in
     87   # this list. This is true because get_operations() returns ops in the order
     88   # added, and an op can only be added after its inputs are added.
     89   for op in ops.get_default_graph().get_operations():
     90     for output in op.outputs:
     91       if output.dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
     92         if op._get_control_flow_context() is not None:  # pylint: disable=protected-access
     93           raise ValueError("`tf.add_check_numerics_ops() is not compatible "
     94                            "with TensorFlow control flow operations such as "
     95                            "`tf.cond()` or `tf.while_loop()`.")
     96 
     97         message = op.name + ":" + str(output.value_index)
     98         with ops.control_dependencies(check_op):
     99           check_op = [array_ops.check_numerics(output, message=message)]
    100   return control_flow_ops.group(*check_op)
    101