# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Connects all half, float and double tensors to CheckNumericsOp."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("verify_tensor_all_finite")
def verify_tensor_all_finite(t, msg, name=None):
  """Assert that the tensor does not contain any NaN's or Inf's.

  Args:
    t: Tensor to check.
    msg: Message to log on failure.
    name: A name for this operation (optional).

  Returns:
    Same tensor as `t`.
  """
  with ops.name_scope(name, "VerifyFinite", [t]) as name:
    t = ops.convert_to_tensor(t, name="t")
    # Keep the check colocated with the tensor it inspects.
    with ops.colocate_with(t):
      verify_input = array_ops.check_numerics(t, message=msg)
      # Hand back `t` gated on the check, so consumers of the returned tensor
      # force the check to run first.
      out = control_flow_ops.with_dependencies([verify_input], t)
  return out


@tf_export("add_check_numerics_ops")
def add_check_numerics_ops():
  """Connect a `check_numerics` to every floating point tensor.

  `check_numerics` operations themselves are added for each `half`, `float`,
  or `double` tensor in the graph. For all ops in the graph, the
  `check_numerics` op for all of its (`half`, `float`, or `double`) inputs
  is guaranteed to run before the `check_numerics` op on any of its outputs.

  Note: This API is not compatible with the use of @{tf.cond} or
  @{tf.while_loop}, and will raise a `ValueError` if you attempt to call it
  in such a graph.

  Returns:
    A `group` op depending on all `check_numerics` ops added.

  Raises:
    ValueError: If the graph contains any numeric operations in a control flow
      structure.
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Not compatible with eager execution. To check for `Inf`s and `NaN`s under
  eager execution, call tfe.seterr(inf_or_nan='raise') once before executing
  the checked operations.
  @end_compatibility
  """
  if context.in_eager_mode():
    raise RuntimeError(
        "add_check_numerics_ops() is not compatible with eager execution. "
        "To check for Inf's and NaN's under eager execution, call "
        "tfe.seterr(inf_or_nan='raise') once before executing the "
        "checked operations.")

  # Only outputs of these dtypes receive a check. Built once here instead of
  # once per op output inside the loop below.
  checked_dtypes = (dtypes.float16, dtypes.float32, dtypes.float64)

  check_op = []
  # This code relies on the ordering of ops in get_operations().
  # The producer of a tensor always comes before that tensor's consumer in
  # this list. This is true because get_operations() returns ops in the order
  # added, and an op can only be added after its inputs are added.
  for op in ops.get_default_graph().get_operations():
    for output in op.outputs:
      if output.dtype in checked_dtypes:
        if op._get_control_flow_context() is not None:  # pylint: disable=protected-access
          raise ValueError("`tf.add_check_numerics_ops()` is not compatible "
                           "with TensorFlow control flow operations such as "
                           "`tf.cond()` or `tf.while_loop()`.")

        message = op.name + ":" + str(output.value_index)
        # Chain each new check behind the previous one; `check_op` only ever
        # holds the most recently created check, so the checks form a single
        # dependency chain in graph order.
        with ops.control_dependencies(check_op):
          check_op = [array_ops.check_numerics(output, message=message)]
  # The last check transitively depends on every earlier check, so grouping
  # it is enough to run them all. `check_op` is empty if nothing matched.
  return control_flow_ops.group(*check_op)