Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 # pylint: disable=g-short-docstring-punctuation
     16 """Asserts and Boolean Checks.
     17 
     18 See the @{$python/check_ops} guide.
     19 
     20 @@assert_negative
     21 @@assert_positive
     22 @@assert_non_negative
     23 @@assert_non_positive
     24 @@assert_equal
     25 @@assert_none_equal
     26 @@assert_near
     27 @@assert_less
     28 @@assert_less_equal
     29 @@assert_greater
     30 @@assert_greater_equal
     31 @@assert_rank
     32 @@assert_rank_at_least
     33 @@assert_rank_in
     34 @@assert_type
     35 @@assert_integer
     36 @@assert_proper_iterable
     37 @@assert_same_float_dtype
     38 @@assert_scalar
     39 @@is_non_decreasing
     40 @@is_numeric_tensor
     41 @@is_strictly_increasing
     42 """
     43 
     44 from __future__ import absolute_import
     45 from __future__ import division
     46 from __future__ import print_function
     47 
     48 import numpy as np
     49 
     50 from tensorflow.python.eager import context
     51 from tensorflow.python.framework import dtypes
     52 from tensorflow.python.framework import errors
     53 from tensorflow.python.framework import ops
     54 from tensorflow.python.framework import sparse_tensor
     55 from tensorflow.python.framework import tensor_util
     56 from tensorflow.python.ops import array_ops
     57 from tensorflow.python.ops import control_flow_ops
     58 from tensorflow.python.ops import math_ops
     59 from tensorflow.python.util import compat
     60 from tensorflow.python.util.tf_export import tf_export
     61 
     62 NUMERIC_TYPES = frozenset(
     63     [dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16, dtypes.int32,
     64      dtypes.int64, dtypes.uint8, dtypes.qint8, dtypes.qint32, dtypes.quint8,
     65      dtypes.complex64])
     66 
     67 __all__ = [
     68     'assert_negative',
     69     'assert_positive',
     70     'assert_proper_iterable',
     71     'assert_non_negative',
     72     'assert_non_positive',
     73     'assert_equal',
     74     'assert_none_equal',
     75     'assert_near',
     76     'assert_integer',
     77     'assert_less',
     78     'assert_less_equal',
     79     'assert_greater',
     80     'assert_greater_equal',
     81     'assert_rank',
     82     'assert_rank_at_least',
     83     'assert_rank_in',
     84     'assert_same_float_dtype',
     85     'assert_scalar',
     86     'assert_type',
     87     'is_non_decreasing',
     88     'is_numeric_tensor',
     89     'is_strictly_increasing',
     90 ]
     91 
     92 
     93 def _maybe_constant_value_string(t):
     94   if not isinstance(t, ops.Tensor):
     95     return str(t)
     96   const_t = tensor_util.constant_value(t)
     97   if const_t is not None:
     98     return str(const_t)
     99   return t
    100 
    101 
    102 def _assert_static(condition, data):
    103   """Raises a InvalidArgumentError with as much information as possible."""
    104   if not condition:
    105     data_static = [_maybe_constant_value_string(x) for x in data]
    106     raise errors.InvalidArgumentError(node_def=None, op=None,
    107                                       message='\n'.join(data_static))
    108 
    109 
    110 def _shape_and_dtype_str(tensor):
    111   """Returns a string containing tensor's shape and dtype."""
    112   return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name)
    113 
    114 
    115 @tf_export('assert_proper_iterable')
    116 def assert_proper_iterable(values):
    117   """Static assert that values is a "proper" iterable.
    118 
    119   `Ops` that expect iterables of `Tensor` can call this to validate input.
    120   Useful since `Tensor`, `ndarray`, byte/text type are all iterables themselves.
    121 
    122   Args:
    123     values:  Object to be checked.
    124 
    125   Raises:
    126     TypeError:  If `values` is not iterable or is one of
    127       `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`.
    128   """
    129   unintentional_iterables = (
    130       (ops.Tensor, sparse_tensor.SparseTensor, np.ndarray)
    131       + compat.bytes_or_text_types
    132   )
    133   if isinstance(values, unintentional_iterables):
    134     raise TypeError(
    135         'Expected argument "values" to be a "proper" iterable.  Found: %s' %
    136         type(values))
    137 
    138   if not hasattr(values, '__iter__'):
    139     raise TypeError(
    140         'Expected argument "values" to be iterable.  Found: %s' % type(values))
    141 
    142 
    143 @tf_export('assert_negative')
    144 def assert_negative(x, data=None, summarize=None, message=None, name=None):
    145   """Assert the condition `x < 0` holds element-wise.
    146 
    147   Example of adding a dependency to an operation:
    148 
    149   ```python
    150   with tf.control_dependencies([tf.assert_negative(x)]):
    151     output = tf.reduce_sum(x)
    152   ```
    153 
    154   Negative means, for every element `x[i]` of `x`, we have `x[i] < 0`.
    155   If `x` is empty this is trivially satisfied.
    156 
    157   Args:
    158     x:  Numeric `Tensor`.
    159     data:  The tensors to print out if the condition is False.  Defaults to
    160       error message and first few entries of `x`.
    161     summarize: Print this many entries of each tensor.
    162     message: A string to prefix to the default message.
    163     name: A name for this operation (optional).  Defaults to "assert_negative".
    164 
    165   Returns:
    166     Op raising `InvalidArgumentError` unless `x` is all negative.
    167   """
    168   message = message or ''
    169   with ops.name_scope(name, 'assert_negative', [x, data]):
    170     x = ops.convert_to_tensor(x, name='x')
    171     if data is None:
    172       if context.in_eager_mode():
    173         name = _shape_and_dtype_str(x)
    174       else:
    175         name = x.name
    176       data = [
    177           message,
    178           'Condition x < 0 did not hold element-wise:',
    179           'x (%s) = ' % name, x]
    180     zero = ops.convert_to_tensor(0, dtype=x.dtype)
    181     return assert_less(x, zero, data=data, summarize=summarize)
    182 
    183 
    184 @tf_export('assert_positive')
    185 def assert_positive(x, data=None, summarize=None, message=None, name=None):
    186   """Assert the condition `x > 0` holds element-wise.
    187 
    188   Example of adding a dependency to an operation:
    189 
    190   ```python
    191   with tf.control_dependencies([tf.assert_positive(x)]):
    192     output = tf.reduce_sum(x)
    193   ```
    194 
    195   Positive means, for every element `x[i]` of `x`, we have `x[i] > 0`.
    196   If `x` is empty this is trivially satisfied.
    197 
    198   Args:
    199     x:  Numeric `Tensor`.
    200     data:  The tensors to print out if the condition is False.  Defaults to
    201       error message and first few entries of `x`.
    202     summarize: Print this many entries of each tensor.
    203     message: A string to prefix to the default message.
    204     name: A name for this operation (optional).  Defaults to "assert_positive".
    205 
    206   Returns:
    207     Op raising `InvalidArgumentError` unless `x` is all positive.
    208   """
    209   message = message or ''
    210   with ops.name_scope(name, 'assert_positive', [x, data]):
    211     x = ops.convert_to_tensor(x, name='x')
    212     if data is None:
    213       if context.in_eager_mode():
    214         name = _shape_and_dtype_str(x)
    215       else:
    216         name = x.name
    217       data = [
    218           message, 'Condition x > 0 did not hold element-wise:',
    219           'x (%s) = ' % name, x]
    220     zero = ops.convert_to_tensor(0, dtype=x.dtype)
    221     return assert_less(zero, x, data=data, summarize=summarize)
    222 
    223 
    224 @tf_export('assert_non_negative')
    225 def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
    226   """Assert the condition `x >= 0` holds element-wise.
    227 
    228   Example of adding a dependency to an operation:
    229 
    230   ```python
    231   with tf.control_dependencies([tf.assert_non_negative(x)]):
    232     output = tf.reduce_sum(x)
    233   ```
    234 
    235   Non-negative means, for every element `x[i]` of `x`, we have `x[i] >= 0`.
    236   If `x` is empty this is trivially satisfied.
    237 
    238   Args:
    239     x:  Numeric `Tensor`.
    240     data:  The tensors to print out if the condition is False.  Defaults to
    241       error message and first few entries of `x`.
    242     summarize: Print this many entries of each tensor.
    243     message: A string to prefix to the default message.
    244     name: A name for this operation (optional).
    245       Defaults to "assert_non_negative".
    246 
    247   Returns:
    248     Op raising `InvalidArgumentError` unless `x` is all non-negative.
    249   """
    250   message = message or ''
    251   with ops.name_scope(name, 'assert_non_negative', [x, data]):
    252     x = ops.convert_to_tensor(x, name='x')
    253     if data is None:
    254       if context.in_eager_mode():
    255         name = _shape_and_dtype_str(x)
    256       else:
    257         name = x.name
    258       data = [
    259           message,
    260           'Condition x >= 0 did not hold element-wise:',
    261           'x (%s) = ' % name, x]
    262     zero = ops.convert_to_tensor(0, dtype=x.dtype)
    263     return assert_less_equal(zero, x, data=data, summarize=summarize)
    264 
    265 
    266 @tf_export('assert_non_positive')
    267 def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
    268   """Assert the condition `x <= 0` holds element-wise.
    269 
    270   Example of adding a dependency to an operation:
    271 
    272   ```python
    273   with tf.control_dependencies([tf.assert_non_positive(x)]):
    274     output = tf.reduce_sum(x)
    275   ```
    276 
    277   Non-positive means, for every element `x[i]` of `x`, we have `x[i] <= 0`.
    278   If `x` is empty this is trivially satisfied.
    279 
    280   Args:
    281     x:  Numeric `Tensor`.
    282     data:  The tensors to print out if the condition is False.  Defaults to
    283       error message and first few entries of `x`.
    284     summarize: Print this many entries of each tensor.
    285     message: A string to prefix to the default message.
    286     name: A name for this operation (optional).
    287       Defaults to "assert_non_positive".
    288 
    289   Returns:
    290     Op raising `InvalidArgumentError` unless `x` is all non-positive.
    291   """
    292   message = message or ''
    293   with ops.name_scope(name, 'assert_non_positive', [x, data]):
    294     x = ops.convert_to_tensor(x, name='x')
    295     if data is None:
    296       if context.in_eager_mode():
    297         name = _shape_and_dtype_str(x)
    298       else:
    299         name = x.name
    300       data = [
    301           message,
    302           'Condition x <= 0 did not hold element-wise:'
    303           'x (%s) = ' % name, x]
    304     zero = ops.convert_to_tensor(0, dtype=x.dtype)
    305     return assert_less_equal(x, zero, data=data, summarize=summarize)
    306 
    307 
    308 @tf_export('assert_equal')
    309 def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
    310   """Assert the condition `x == y` holds element-wise.
    311 
    312   Example of adding a dependency to an operation:
    313 
    314   ```python
    315   with tf.control_dependencies([tf.assert_equal(x, y)]):
    316     output = tf.reduce_sum(x)
    317   ```
    318 
    319   This condition holds if for every pair of (possibly broadcast) elements
    320   `x[i]`, `y[i]`, we have `x[i] == y[i]`.
    321   If both `x` and `y` are empty, this is trivially satisfied.
    322 
    323   Args:
    324     x:  Numeric `Tensor`.
    325     y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    326     data:  The tensors to print out if the condition is False.  Defaults to
    327       error message and first few entries of `x`, `y`.
    328     summarize: Print this many entries of each tensor.
    329     message: A string to prefix to the default message.
    330     name: A name for this operation (optional).  Defaults to "assert_equal".
    331 
    332   Returns:
    333     Op that raises `InvalidArgumentError` if `x == y` is False.
    334     @compatibility{eager} returns None
    335 
    336   Raises:
    337     InvalidArgumentError: if the check can be performed immediately and
    338       `x == y` is False. The check can be performed immediately during eager
    339       execution or if `x` and `y` are statically known.
    340   """
    341   message = message or ''
    342   with ops.name_scope(name, 'assert_equal', [x, y, data]):
    343     x = ops.convert_to_tensor(x, name='x')
    344     y = ops.convert_to_tensor(y, name='y')
    345 
    346     if context.in_eager_mode():
    347       eq = math_ops.equal(x, y)
    348       condition = math_ops.reduce_all(eq)
    349       if not condition:
    350         # Prepare a message with first elements of x and y.
    351         summary_msg = ''
    352         # Default to printing 3 elements like control_flow_ops.Assert (used
    353         # by graph mode) does.
    354         summarize = 3 if summarize is None else summarize
    355         if summarize:
    356           # reshape((-1,)) is the fastest way to get a flat array view.
    357           x_np = x.numpy().reshape((-1,))
    358           y_np = y.numpy().reshape((-1,))
    359           x_sum = min(x_np.size, summarize)
    360           y_sum = min(y_np.size, summarize)
    361           summary_msg = ('First %d elements of x:\n%s\n'
    362                          'First %d elements of y:\n%s\n' %
    363                          (x_sum, x_np[:x_sum],
    364                           y_sum, y_np[:y_sum]))
    365 
    366         # Get the values that actually differed and their indices.
    367         mask = math_ops.logical_not(eq)
    368         indices = array_ops.where(mask)
    369         indices_np = indices.numpy()
    370         x_vals = array_ops.boolean_mask(x, mask)
    371         y_vals = array_ops.boolean_mask(y, mask)
    372         summarize = min(summarize, indices_np.shape[0])
    373 
    374         raise errors.InvalidArgumentError(
    375             node_def=None, op=None,
    376             message=('%s\nCondition x == y did not hold.\n'
    377                      'Indices of first %s different values:\n%s\n'
    378                      'Corresponding x values:\n%s\n'
    379                      'Corresponding y values:\n%s\n'
    380                      '%s'
    381                      %
    382                      (message or '',
    383                       summarize, indices_np[:summarize],
    384                       x_vals.numpy().reshape((-1,))[:summarize],
    385                       y_vals.numpy().reshape((-1,))[:summarize],
    386                       summary_msg)))
    387       return
    388 
    389     if data is None:
    390       data = [
    391           message,
    392           'Condition x == y did not hold element-wise:',
    393           'x (%s) = ' % x.name, x,
    394           'y (%s) = ' % y.name, y
    395       ]
    396     condition = math_ops.reduce_all(math_ops.equal(x, y))
    397     x_static = tensor_util.constant_value(x)
    398     y_static = tensor_util.constant_value(y)
    399     if x_static is not None and y_static is not None:
    400       condition_static = (x_static == y_static).all()
    401       _assert_static(condition_static, data)
    402     return control_flow_ops.Assert(condition, data, summarize=summarize)
    403 
    404 
    405 @tf_export('assert_none_equal')
    406 def assert_none_equal(
    407     x, y, data=None, summarize=None, message=None, name=None):
    408   """Assert the condition `x != y` holds for all elements.
    409 
    410   Example of adding a dependency to an operation:
    411 
    412   ```python
    413   with tf.control_dependencies([tf.assert_none_equal(x, y)]):
    414     output = tf.reduce_sum(x)
    415   ```
    416 
    417   This condition holds if for every pair of (possibly broadcast) elements
    418   `x[i]`, `y[i]`, we have `x[i] != y[i]`.
    419   If both `x` and `y` are empty, this is trivially satisfied.
    420 
    421   Args:
    422     x:  Numeric `Tensor`.
    423     y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    424     data:  The tensors to print out if the condition is False.  Defaults to
    425       error message and first few entries of `x`, `y`.
    426     summarize: Print this many entries of each tensor.
    427     message: A string to prefix to the default message.
    428     name: A name for this operation (optional).
    429       Defaults to "assert_none_equal".
    430 
    431   Returns:
    432     Op that raises `InvalidArgumentError` if `x != y` is ever False.
    433   """
    434   message = message or ''
    435   with ops.name_scope(name, 'assert_none_equal', [x, y, data]):
    436     x = ops.convert_to_tensor(x, name='x')
    437     y = ops.convert_to_tensor(y, name='y')
    438     if context.in_eager_mode():
    439       x_name = _shape_and_dtype_str(x)
    440       y_name = _shape_and_dtype_str(y)
    441     else:
    442       x_name = x.name
    443       y_name = y.name
    444 
    445     if data is None:
    446       data = [
    447           message,
    448           'Condition x != y did not hold for every single element:',
    449           'x (%s) = ' % x_name, x,
    450           'y (%s) = ' % y_name, y
    451       ]
    452     condition = math_ops.reduce_all(math_ops.not_equal(x, y))
    453     return control_flow_ops.Assert(condition, data, summarize=summarize)
    454 
    455 
    456 @tf_export('assert_near')
    457 def assert_near(
    458     x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
    459     name=None):
    460   """Assert the condition `x` and `y` are close element-wise.
    461 
    462   Example of adding a dependency to an operation:
    463 
    464   ```python
    465   with tf.control_dependencies([tf.assert_near(x, y)]):
    466     output = tf.reduce_sum(x)
    467   ```
    468 
    469   This condition holds if for every pair of (possibly broadcast) elements
    470   `x[i]`, `y[i]`, we have
    471 
    472   ```tf.abs(x[i] - y[i]) <= atol + rtol * tf.abs(y[i])```.
    473 
    474   If both `x` and `y` are empty, this is trivially satisfied.
    475 
    476   The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest
    477   representable positive number such that `1 + eps != eps`.  This is about
    478   `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
    479   See `numpy.finfo`.
    480 
    481   Args:
    482     x:  Float or complex `Tensor`.
    483     y:  Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`.
    484     rtol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
    485       The relative tolerance.  Default is `10 * eps`.
    486     atol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
    487       The absolute tolerance.  Default is `10 * eps`.
    488     data:  The tensors to print out if the condition is False.  Defaults to
    489       error message and first few entries of `x`, `y`.
    490     summarize: Print this many entries of each tensor.
    491     message: A string to prefix to the default message.
    492     name: A name for this operation (optional).  Defaults to "assert_near".
    493 
    494   Returns:
    495     Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.
    496 
    497   @compatibility(numpy)
    498   Similar to `numpy.assert_allclose`, except tolerance depends on data type.
    499   This is due to the fact that `TensorFlow` is often used with `32bit`, `64bit`,
    500   and even `16bit` data.
    501   @end_compatibility
    502   """
    503   message = message or ''
    504   with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]):
    505     x = ops.convert_to_tensor(x, name='x')
    506     y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)
    507 
    508     eps = np.finfo(x.dtype.as_numpy_dtype).eps
    509     rtol = 10 * eps if rtol is None else rtol
    510     atol = 10 * eps if atol is None else atol
    511 
    512     rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype)
    513     atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype)
    514 
    515     if context.in_eager_mode():
    516       x_name = _shape_and_dtype_str(x)
    517       y_name = _shape_and_dtype_str(y)
    518     else:
    519       x_name = x.name
    520       y_name = y.name
    521 
    522     if data is None:
    523       data = [
    524           message,
    525           'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol),
    526           'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
    527       ]
    528     tol = atol + rtol * math_ops.abs(y)
    529     diff = math_ops.abs(x - y)
    530     condition = math_ops.reduce_all(math_ops.less(diff, tol))
    531     return control_flow_ops.Assert(condition, data, summarize=summarize)
    532 
    533 
    534 @tf_export('assert_less')
    535 def assert_less(x, y, data=None, summarize=None, message=None, name=None):
    536   """Assert the condition `x < y` holds element-wise.
    537 
    538   Example of adding a dependency to an operation:
    539 
    540   ```python
    541   with tf.control_dependencies([tf.assert_less(x, y)]):
    542     output = tf.reduce_sum(x)
    543   ```
    544 
    545   This condition holds if for every pair of (possibly broadcast) elements
    546   `x[i]`, `y[i]`, we have `x[i] < y[i]`.
    547   If both `x` and `y` are empty, this is trivially satisfied.
    548 
    549   Args:
    550     x:  Numeric `Tensor`.
    551     y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    552     data:  The tensors to print out if the condition is False.  Defaults to
    553       error message and first few entries of `x`, `y`.
    554     summarize: Print this many entries of each tensor.
    555     message: A string to prefix to the default message.
    556     name: A name for this operation (optional).  Defaults to "assert_less".
    557 
    558   Returns:
    559     Op that raises `InvalidArgumentError` if `x < y` is False.
    560   """
    561   message = message or ''
    562   with ops.name_scope(name, 'assert_less', [x, y, data]):
    563     x = ops.convert_to_tensor(x, name='x')
    564     y = ops.convert_to_tensor(y, name='y')
    565     if context.in_eager_mode():
    566       x_name = _shape_and_dtype_str(x)
    567       y_name = _shape_and_dtype_str(y)
    568     else:
    569       x_name = x.name
    570       y_name = y.name
    571 
    572     if data is None:
    573       data = [
    574           message,
    575           'Condition x < y did not hold element-wise:',
    576           'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
    577       ]
    578     condition = math_ops.reduce_all(math_ops.less(x, y))
    579     return control_flow_ops.Assert(condition, data, summarize=summarize)
    580 
    581 
    582 @tf_export('assert_less_equal')
    583 def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
    584   """Assert the condition `x <= y` holds element-wise.
    585 
    586   Example of adding a dependency to an operation:
    587 
    588   ```python
    589   with tf.control_dependencies([tf.assert_less_equal(x, y)]):
    590     output = tf.reduce_sum(x)
    591   ```
    592 
    593   This condition holds if for every pair of (possibly broadcast) elements
    594   `x[i]`, `y[i]`, we have `x[i] <= y[i]`.
    595   If both `x` and `y` are empty, this is trivially satisfied.
    596 
    597   Args:
    598     x:  Numeric `Tensor`.
    599     y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    600     data:  The tensors to print out if the condition is False.  Defaults to
    601       error message and first few entries of `x`, `y`.
    602     summarize: Print this many entries of each tensor.
    603     message: A string to prefix to the default message.
    604     name: A name for this operation (optional).  Defaults to "assert_less_equal"
    605 
    606   Returns:
    607     Op that raises `InvalidArgumentError` if `x <= y` is False.
    608   """
    609   message = message or ''
    610   with ops.name_scope(name, 'assert_less_equal', [x, y, data]):
    611     x = ops.convert_to_tensor(x, name='x')
    612     y = ops.convert_to_tensor(y, name='y')
    613     if context.in_eager_mode():
    614       x_name = _shape_and_dtype_str(x)
    615       y_name = _shape_and_dtype_str(y)
    616     else:
    617       x_name = x.name
    618       y_name = y.name
    619 
    620     if data is None:
    621       data = [
    622           message,
    623           'Condition x <= y did not hold element-wise:'
    624           'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
    625       ]
    626     condition = math_ops.reduce_all(math_ops.less_equal(x, y))
    627     return control_flow_ops.Assert(condition, data, summarize=summarize)
    628 
    629 
    630 @tf_export('assert_greater')
    631 def assert_greater(x, y, data=None, summarize=None, message=None, name=None):
    632   """Assert the condition `x > y` holds element-wise.
    633 
    634   Example of adding a dependency to an operation:
    635 
    636   ```python
    637   with tf.control_dependencies([tf.assert_greater(x, y)]):
    638     output = tf.reduce_sum(x)
    639   ```
    640 
    641   This condition holds if for every pair of (possibly broadcast) elements
    642   `x[i]`, `y[i]`, we have `x[i] > y[i]`.
    643   If both `x` and `y` are empty, this is trivially satisfied.
    644 
    645   Args:
    646     x:  Numeric `Tensor`.
    647     y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    648     data:  The tensors to print out if the condition is False.  Defaults to
    649       error message and first few entries of `x`, `y`.
    650     summarize: Print this many entries of each tensor.
    651     message: A string to prefix to the default message.
    652     name: A name for this operation (optional).  Defaults to "assert_greater".
    653 
    654   Returns:
    655     Op that raises `InvalidArgumentError` if `x > y` is False.
    656   """
    657   message = message or ''
    658   with ops.name_scope(name, 'assert_greater', [x, y, data]):
    659     x = ops.convert_to_tensor(x, name='x')
    660     y = ops.convert_to_tensor(y, name='y')
    661     if context.in_eager_mode():
    662       x_name = _shape_and_dtype_str(x)
    663       y_name = _shape_and_dtype_str(y)
    664     else:
    665       x_name = x.name
    666       y_name = y.name
    667 
    668     if data is None:
    669       data = [
    670           message,
    671           'Condition x > y did not hold element-wise:'
    672           'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
    673       ]
    674     condition = math_ops.reduce_all(math_ops.greater(x, y))
    675     return control_flow_ops.Assert(condition, data, summarize=summarize)
    676 
    677 
    678 @tf_export('assert_greater_equal')
    679 def assert_greater_equal(x, y, data=None, summarize=None, message=None,
    680                          name=None):
    681   """Assert the condition `x >= y` holds element-wise.
    682 
    683   Example of adding a dependency to an operation:
    684 
    685   ```python
    686   with tf.control_dependencies([tf.assert_greater_equal(x, y)]):
    687     output = tf.reduce_sum(x)
    688   ```
    689 
    690   This condition holds if for every pair of (possibly broadcast) elements
    691   `x[i]`, `y[i]`, we have `x[i] >= y[i]`.
    692   If both `x` and `y` are empty, this is trivially satisfied.
    693 
    694   Args:
    695     x:  Numeric `Tensor`.
    696     y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    697     data:  The tensors to print out if the condition is False.  Defaults to
    698       error message and first few entries of `x`, `y`.
    699     summarize: Print this many entries of each tensor.
    700     message: A string to prefix to the default message.
    701     name: A name for this operation (optional).  Defaults to
    702       "assert_greater_equal"
    703 
    704   Returns:
    705     Op that raises `InvalidArgumentError` if `x >= y` is False.
    706   """
    707   message = message or ''
    708   with ops.name_scope(name, 'assert_greater_equal', [x, y, data]):
    709     x = ops.convert_to_tensor(x, name='x')
    710     y = ops.convert_to_tensor(y, name='y')
    711     if context.in_eager_mode():
    712       x_name = _shape_and_dtype_str(x)
    713       y_name = _shape_and_dtype_str(y)
    714     else:
    715       x_name = x.name
    716       y_name = y.name
    717 
    718     if data is None:
    719       data = [
    720           message,
    721           'Condition x >= y did not hold element-wise:'
    722           'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
    723       ]
    724     condition = math_ops.reduce_all(math_ops.greater_equal(x, y))
    725     return control_flow_ops.Assert(condition, data, summarize=summarize)
    726 
    727 
def _assert_rank_condition(
    x, rank, static_condition, dynamic_condition, data, summarize):
  """Assert `x` has a rank that satisfies a given condition.

  When both the value of `rank` and the rank of `x` are known at
  graph-construction time, the condition is evaluated in Python: a failure
  raises `ValueError` and a success returns a `no_op`.  Otherwise an `Assert`
  op that evaluates `dynamic_condition` at run time is returned.

  Args:
    x:  Numeric `Tensor`.
    rank:  Scalar `int32` `Tensor`.
    static_condition:   A python function that takes `[actual_rank, given_rank]`
      and returns `True` if the condition is satisfied, `False` otherwise.
    dynamic_condition:  An `op` that takes [actual_rank, given_rank]
      and returns a boolean `Tensor` that is `True` if the condition is
      satisfied, `False` otherwise.
    data:  The tensors to print out if the condition is false.  This argument
      has no default here; the public wrappers build the default message.
    summarize: Print this many entries of each tensor.

  Returns:
    Op raising `InvalidArgumentError` if `x` fails dynamic_condition, or a
    `no_op` named 'static_checks_determined_all_ok' when the static check
    already passed.

  Raises:
    ValueError:  If static checks determine `x` fails static_condition.  In
      that case the error is `ValueError('Static rank condition failed',
      actual_rank, given_rank)`, so callers can match `args[0]` and rebuild a
      friendlier message.  Also raised with 'Rank must be a scalar.' when
      `rank` is statically known but is not a scalar.
    TypeError:  If `rank` is not `int32` (raised by `assert_type`).
  """
  assert_type(rank, dtypes.int32)

  # Attempt to determine the given rank statically.
  rank_static = tensor_util.constant_value(rank)
  if rank_static is not None:
    # A statically known rank is always validated, whatever x's rank is.
    if rank_static.ndim != 0:
      raise ValueError('Rank must be a scalar.')

    x_rank_static = x.get_shape().ndims
    # Both ranks known: decide now and skip building a runtime Assert.
    if x_rank_static is not None:
      if not static_condition(x_rank_static, rank_static):
        raise ValueError(
            'Static rank condition failed', x_rank_static, rank_static)
      return control_flow_ops.no_op(name='static_checks_determined_all_ok')

  condition = dynamic_condition(array_ops.rank(x), rank)

  # Add the condition that `rank` must have rank zero.  Prevents the bug where
  # someone does assert_rank(x, [n]), rather than assert_rank(x, n).
  # Only needed when `rank` was not statically known (otherwise it was checked
  # above).  The literal 0 is statically known, so the nested assert_rank call
  # resolves its own rank check without recursing again.
  if rank_static is None:
    this_data = ['Rank must be a scalar. Received rank: ', rank]
    rank_check = assert_rank(rank, 0, data=this_data)
    # Chain rank_check as a control dependency of the condition.
    condition = control_flow_ops.with_dependencies([rank_check], condition)

  return control_flow_ops.Assert(condition, data, summarize=summarize)
    774 
    775 
    776 @tf_export('assert_rank')
    777 def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
    778   """Assert `x` has rank equal to `rank`.
    779 
    780   Example of adding a dependency to an operation:
    781 
    782   ```python
    783   with tf.control_dependencies([tf.assert_rank(x, 2)]):
    784     output = tf.reduce_sum(x)
    785   ```
    786 
    787   Args:
    788     x:  Numeric `Tensor`.
    789     rank:  Scalar integer `Tensor`.
    790     data:  The tensors to print out if the condition is False.  Defaults to
    791       error message and first few entries of `x`.
    792     summarize: Print this many entries of each tensor.
    793     message: A string to prefix to the default message.
    794     name: A name for this operation (optional).  Defaults to "assert_rank".
    795 
    796   Returns:
    797     Op raising `InvalidArgumentError` unless `x` has specified rank.
    798     If static checks determine `x` has correct rank, a `no_op` is returned.
    799 
    800   Raises:
    801     ValueError:  If static checks determine `x` has wrong rank.
    802   """
    803   with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])):
    804     x = ops.convert_to_tensor(x, name='x')
    805     rank = ops.convert_to_tensor(rank, name='rank')
    806     message = message or ''
    807 
    808     static_condition = lambda actual_rank, given_rank: actual_rank == given_rank
    809     dynamic_condition = math_ops.equal
    810 
    811     if context.in_eager_mode():
    812       name = ''
    813     else:
    814       name = x.name
    815 
    816     if data is None:
    817       data = [
    818           message,
    819           'Tensor %s must have rank' % name, rank, 'Received shape: ',
    820           array_ops.shape(x)
    821       ]
    822 
    823     try:
    824       assert_op = _assert_rank_condition(x, rank, static_condition,
    825                                          dynamic_condition, data, summarize)
    826 
    827     except ValueError as e:
    828       if e.args[0] == 'Static rank condition failed':
    829         raise ValueError(
    830             '%s.  Tensor %s must have rank %d.  Received rank %d, shape %s' %
    831             (message, name, e.args[2], e.args[1], x.get_shape()))
    832       else:
    833         raise
    834 
    835   return assert_op
    836 
    837 
    838 @tf_export('assert_rank_at_least')
    839 def assert_rank_at_least(
    840     x, rank, data=None, summarize=None, message=None, name=None):
    841   """Assert `x` has rank equal to `rank` or higher.
    842 
    843   Example of adding a dependency to an operation:
    844 
    845   ```python
    846   with tf.control_dependencies([tf.assert_rank_at_least(x, 2)]):
    847     output = tf.reduce_sum(x)
    848   ```
    849 
    850   Args:
    851     x:  Numeric `Tensor`.
    852     rank:  Scalar `Tensor`.
    853     data:  The tensors to print out if the condition is False.  Defaults to
    854       error message and first few entries of `x`.
    855     summarize: Print this many entries of each tensor.
    856     message: A string to prefix to the default message.
    857     name: A name for this operation (optional).
    858       Defaults to "assert_rank_at_least".
    859 
    860   Returns:
    861     Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    862     If static checks determine `x` has correct rank, a `no_op` is returned.
    863 
    864   Raises:
    865     ValueError:  If static checks determine `x` has wrong rank.
    866   """
    867   with ops.name_scope(
    868       name, 'assert_rank_at_least', (x, rank) + tuple(data or [])):
    869     x = ops.convert_to_tensor(x, name='x')
    870     rank = ops.convert_to_tensor(rank, name='rank')
    871     message = message or ''
    872 
    873     static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank
    874     dynamic_condition = math_ops.greater_equal
    875 
    876     if context.in_eager_mode():
    877       name = ''
    878     else:
    879       name = x.name
    880 
    881     if data is None:
    882       data = [
    883           message,
    884           'Tensor %s must have rank at least' % name, rank,
    885           'Received shape: ', array_ops.shape(x)
    886       ]
    887 
    888     try:
    889       assert_op = _assert_rank_condition(x, rank, static_condition,
    890                                          dynamic_condition, data, summarize)
    891 
    892     except ValueError as e:
    893       if e.args[0] == 'Static rank condition failed':
    894         raise ValueError(
    895             '%s.  Tensor %s must have rank at least %d.  Received rank %d, '
    896             'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
    897       else:
    898         raise
    899 
    900   return assert_op
    901 
    902 
    903 def _static_rank_in(actual_rank, given_ranks):
    904   return actual_rank in given_ranks
    905 
    906 
    907 def _dynamic_rank_in(actual_rank, given_ranks):
    908   if len(given_ranks) < 1:
    909     return ops.convert_to_tensor(False)
    910   result = math_ops.equal(given_ranks[0], actual_rank)
    911   for given_rank in given_ranks[1:]:
    912     result = math_ops.logical_or(
    913         result, math_ops.equal(given_rank, actual_rank))
    914   return result
    915 
    916 
    917 def _assert_ranks_condition(
    918     x, ranks, static_condition, dynamic_condition, data, summarize):
    919   """Assert `x` has a rank that satisfies a given condition.
    920 
    921   Args:
    922     x:  Numeric `Tensor`.
    923     ranks:  Scalar `Tensor`.
    924     static_condition:   A python function that takes
    925       `[actual_rank, given_ranks]` and returns `True` if the condition is
    926       satisfied, `False` otherwise.
    927     dynamic_condition:  An `op` that takes [actual_rank, given_ranks]
    928       and return `True` if the condition is satisfied, `False` otherwise.
    929     data:  The tensors to print out if the condition is false.  Defaults to
    930       error message and first few entries of `x`.
    931     summarize: Print this many entries of each tensor.
    932 
    933   Returns:
    934     Op raising `InvalidArgumentError` if `x` fails dynamic_condition.
    935 
    936   Raises:
    937     ValueError:  If static checks determine `x` fails static_condition.
    938   """
    939   for rank in ranks:
    940     assert_type(rank, dtypes.int32)
    941 
    942   # Attempt to statically defined rank.
    943   ranks_static = tuple([tensor_util.constant_value(rank) for rank in ranks])
    944   if not any(r is None for r in ranks_static):
    945     for rank_static in ranks_static:
    946       if rank_static.ndim != 0:
    947         raise ValueError('Rank must be a scalar.')
    948 
    949     x_rank_static = x.get_shape().ndims
    950     if x_rank_static is not None:
    951       if not static_condition(x_rank_static, ranks_static):
    952         raise ValueError(
    953             'Static rank condition failed', x_rank_static, ranks_static)
    954       return control_flow_ops.no_op(name='static_checks_determined_all_ok')
    955 
    956   condition = dynamic_condition(array_ops.rank(x), ranks)
    957 
    958   # Add the condition that `rank` must have rank zero.  Prevents the bug where
    959   # someone does assert_rank(x, [n]), rather than assert_rank(x, n).
    960   for rank, rank_static in zip(ranks, ranks_static):
    961     if rank_static is None:
    962       this_data = ['Rank must be a scalar. Received rank: ', rank]
    963       rank_check = assert_rank(rank, 0, data=this_data)
    964       condition = control_flow_ops.with_dependencies([rank_check], condition)
    965 
    966   return control_flow_ops.Assert(condition, data, summarize=summarize)
    967 
    968 
    969 @tf_export('assert_rank_in')
    970 def assert_rank_in(
    971     x, ranks, data=None, summarize=None, message=None, name=None):
    972   """Assert `x` has rank in `ranks`.
    973 
    974   Example of adding a dependency to an operation:
    975 
    976   ```python
    977   with tf.control_dependencies([tf.assert_rank_in(x, (2, 4))]):
    978     output = tf.reduce_sum(x)
    979   ```
    980 
    981   Args:
    982     x:  Numeric `Tensor`.
    983     ranks:  Iterable of scalar `Tensor` objects.
    984     data:  The tensors to print out if the condition is False.  Defaults to
    985       error message and first few entries of `x`.
    986     summarize: Print this many entries of each tensor.
    987     message: A string to prefix to the default message.
    988     name: A name for this operation (optional).
    989       Defaults to "assert_rank_in".
    990 
    991   Returns:
    992     Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    993     If static checks determine `x` has matching rank, a `no_op` is returned.
    994 
    995   Raises:
    996     ValueError:  If static checks determine `x` has mismatched rank.
    997   """
    998   with ops.name_scope(
    999       name, 'assert_rank_in', (x,) + tuple(ranks) + tuple(data or [])):
   1000     x = ops.convert_to_tensor(x, name='x')
   1001     ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks])
   1002     message = message or ''
   1003 
   1004     if context.in_eager_mode():
   1005       name = ''
   1006     else:
   1007       name = x.name
   1008 
   1009     if data is None:
   1010       data = [
   1011           message, 'Tensor %s must have rank in' % name
   1012       ] + list(ranks) + [
   1013           'Received shape: ', array_ops.shape(x)
   1014       ]
   1015 
   1016     try:
   1017       assert_op = _assert_ranks_condition(x, ranks, _static_rank_in,
   1018                                           _dynamic_rank_in, data, summarize)
   1019 
   1020     except ValueError as e:
   1021       if e.args[0] == 'Static rank condition failed':
   1022         raise ValueError(
   1023             '%s.  Tensor %s must have rank in %s.  Received rank %d, '
   1024             'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
   1025       else:
   1026         raise
   1027 
   1028   return assert_op
   1029 
   1030 
   1031 @tf_export('assert_integer')
   1032 def assert_integer(x, message=None, name=None):
   1033   """Assert that `x` is of integer dtype.
   1034 
   1035   Example of adding a dependency to an operation:
   1036 
   1037   ```python
   1038   with tf.control_dependencies([tf.assert_integer(x)]):
   1039     output = tf.reduce_sum(x)
   1040   ```
   1041 
   1042   Args:
   1043     x: `Tensor` whose basetype is integer and is not quantized.
   1044     message: A string to prefix to the default message.
   1045     name: A name for this operation (optional).  Defaults to "assert_integer".
   1046 
   1047   Raises:
   1048     TypeError:  If `x.dtype` is anything other than non-quantized integer.
   1049 
   1050   Returns:
   1051     A `no_op` that does nothing.  Type can be determined statically.
   1052   """
   1053   message = message or ''
   1054   with ops.name_scope(name, 'assert_integer', [x]):
   1055     x = ops.convert_to_tensor(x, name='x')
   1056     if not x.dtype.is_integer:
   1057       if context.in_eager_mode():
   1058         name = 'tensor'
   1059       else:
   1060         name = x.name
   1061       err_msg = (
   1062           '%s  Expected "x" to be integer type.  Found: %s of dtype %s'
   1063           % (message, name, x.dtype))
   1064       raise TypeError(err_msg)
   1065 
   1066     return control_flow_ops.no_op('statically_determined_was_integer')
   1067 
   1068 
   1069 @tf_export('assert_type')
   1070 def assert_type(tensor, tf_type, message=None, name=None):
   1071   """Statically asserts that the given `Tensor` is of the specified type.
   1072 
   1073   Args:
   1074     tensor: A tensorflow `Tensor`.
   1075     tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
   1076       etc).
   1077     message: A string to prefix to the default message.
   1078     name:  A name to give this `Op`.  Defaults to "assert_type"
   1079 
   1080   Raises:
   1081     TypeError: If the tensors data type doesn't match `tf_type`.
   1082 
   1083   Returns:
   1084     A `no_op` that does nothing.  Type can be determined statically.
   1085   """
   1086   message = message or ''
   1087   with ops.name_scope(name, 'assert_type', [tensor]):
   1088     tensor = ops.convert_to_tensor(tensor, name='tensor')
   1089     if tensor.dtype != tf_type:
   1090       if context.in_graph_mode():
   1091         raise TypeError(
   1092             '%s  %s must be of type %s' % (message, tensor.name, tf_type))
   1093       else:
   1094         raise TypeError(
   1095             '%s tensor must be of type %s' % (message, tf_type))
   1096 
   1097     return control_flow_ops.no_op('statically_determined_correct_type')
   1098 
   1099 
   1100 # pylint: disable=line-too-long
   1101 def _get_diff_for_monotonic_comparison(x):
   1102   """Gets the difference x[1:] - x[:-1]."""
   1103   x = array_ops.reshape(x, [-1])
   1104   if not is_numeric_tensor(x):
   1105     raise TypeError('Expected x to be numeric, instead found: %s' % x)
   1106 
   1107   # If x has less than 2 elements, there is nothing to compare.  So return [].
   1108   is_shorter_than_two = math_ops.less(array_ops.size(x), 2)
   1109   short_result = lambda: ops.convert_to_tensor([], dtype=x.dtype)
   1110 
   1111   # With 2 or more elements, return x[1:] - x[:-1]
   1112   s_len = array_ops.shape(x) - 1
   1113   diff = lambda: array_ops.strided_slice(x, [1], [1] + s_len)- array_ops.strided_slice(x, [0], s_len)
   1114   return control_flow_ops.cond(is_shorter_than_two, short_result, diff)
   1115 
   1116 
   1117 @tf_export('is_numeric_tensor')
   1118 def is_numeric_tensor(tensor):
   1119   return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES
   1120 
   1121 
   1122 @tf_export('is_non_decreasing')
   1123 def is_non_decreasing(x, name=None):
   1124   """Returns `True` if `x` is non-decreasing.
   1125 
   1126   Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
   1127   is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`.
   1128   If `x` has less than two elements, it is trivially non-decreasing.
   1129 
   1130   See also:  `is_strictly_increasing`
   1131 
   1132   Args:
   1133     x: Numeric `Tensor`.
   1134     name: A name for this operation (optional).  Defaults to "is_non_decreasing"
   1135 
   1136   Returns:
   1137     Boolean `Tensor`, equal to `True` iff `x` is non-decreasing.
   1138 
   1139   Raises:
   1140     TypeError: if `x` is not a numeric tensor.
   1141   """
   1142   with ops.name_scope(name, 'is_non_decreasing', [x]):
   1143     diff = _get_diff_for_monotonic_comparison(x)
   1144     # When len(x) = 1, diff = [], less_equal = [], and reduce_all([]) = True.
   1145     zero = ops.convert_to_tensor(0, dtype=diff.dtype)
   1146     return math_ops.reduce_all(math_ops.less_equal(zero, diff))
   1147 
   1148 
   1149 @tf_export('is_strictly_increasing')
   1150 def is_strictly_increasing(x, name=None):
   1151   """Returns `True` if `x` is strictly increasing.
   1152 
   1153   Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
   1154   is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`.
   1155   If `x` has less than two elements, it is trivially strictly increasing.
   1156 
   1157   See also:  `is_non_decreasing`
   1158 
   1159   Args:
   1160     x: Numeric `Tensor`.
   1161     name: A name for this operation (optional).
   1162       Defaults to "is_strictly_increasing"
   1163 
   1164   Returns:
   1165     Boolean `Tensor`, equal to `True` iff `x` is strictly increasing.
   1166 
   1167   Raises:
   1168     TypeError: if `x` is not a numeric tensor.
   1169   """
   1170   with ops.name_scope(name, 'is_strictly_increasing', [x]):
   1171     diff = _get_diff_for_monotonic_comparison(x)
   1172     # When len(x) = 1, diff = [], less = [], and reduce_all([]) = True.
   1173     zero = ops.convert_to_tensor(0, dtype=diff.dtype)
   1174     return math_ops.reduce_all(math_ops.less(zero, diff))
   1175 
   1176 
   1177 def _assert_same_base_type(items, expected_type=None):
   1178   r"""Asserts all items are of the same base type.
   1179 
   1180   Args:
   1181     items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`,
   1182         `Operation`, or `IndexedSlices`). Can include `None` elements, which
   1183         will be ignored.
   1184     expected_type: Expected type. If not specified, assert all items are
   1185         of the same base type.
   1186 
   1187   Returns:
   1188     Validated type, or none if neither expected_type nor items provided.
   1189 
   1190   Raises:
   1191     ValueError: If any types do not match.
   1192   """
   1193   original_item_str = None
   1194   for item in items:
   1195     if item is not None:
   1196       item_type = item.dtype.base_dtype
   1197       if not expected_type:
   1198         expected_type = item_type
   1199         original_item_str = item.name if hasattr(item, 'name') else str(item)
   1200       elif expected_type != item_type:
   1201         raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % (
   1202             item.name if hasattr(item, 'name') else str(item),
   1203             item_type, expected_type,
   1204             (' as %s' % original_item_str) if original_item_str else ''))
   1205   return expected_type
   1206 
   1207 
   1208 @tf_export('assert_same_float_dtype')
   1209 def assert_same_float_dtype(tensors=None, dtype=None):
   1210   """Validate and return float type based on `tensors` and `dtype`.
   1211 
   1212   For ops such as matrix multiplication, inputs and weights must be of the
   1213   same float type. This function validates that all `tensors` are the same type,
   1214   validates that type is `dtype` (if supplied), and returns the type. Type must
   1215   be a floating point type. If neither `tensors` nor `dtype` is supplied,
   1216   the function will return `dtypes.float32`.
   1217 
   1218   Args:
   1219     tensors: Tensors of input values. Can include `None` elements, which will be
   1220         ignored.
   1221     dtype: Expected type.
   1222   Returns:
   1223     Validated type.
   1224   Raises:
   1225     ValueError: if neither `tensors` nor `dtype` is supplied, or result is not
   1226         float, or the common type of the inputs is not a floating point type.
   1227   """
   1228   if tensors:
   1229     dtype = _assert_same_base_type(tensors, dtype)
   1230   if not dtype:
   1231     dtype = dtypes.float32
   1232   elif not dtype.is_floating:
   1233     raise ValueError('Expected floating point type, got %s.' % dtype)
   1234   return dtype
   1235 
   1236 
   1237 @tf_export('assert_scalar')
   1238 def assert_scalar(tensor, name=None):
   1239   with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope:
   1240     tensor = ops.convert_to_tensor(tensor, name=name_scope)
   1241     shape = tensor.get_shape()
   1242     if shape.ndims != 0:
   1243       if context.in_eager_mode():
   1244         raise ValueError('Expected scalar shape, saw shape: %s.'
   1245                          % (shape,))
   1246       else:
   1247         raise ValueError('Expected scalar shape for %s, saw shape: %s.'
   1248                          % (tensor.name, shape))
   1249     return tensor
   1250