1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 # pylint: disable=g-short-docstring-punctuation 16 """Asserts and Boolean Checks. 17 18 See the @{$python/check_ops} guide. 19 20 @@assert_negative 21 @@assert_positive 22 @@assert_non_negative 23 @@assert_non_positive 24 @@assert_equal 25 @@assert_none_equal 26 @@assert_near 27 @@assert_less 28 @@assert_less_equal 29 @@assert_greater 30 @@assert_greater_equal 31 @@assert_rank 32 @@assert_rank_at_least 33 @@assert_rank_in 34 @@assert_type 35 @@assert_integer 36 @@assert_proper_iterable 37 @@assert_same_float_dtype 38 @@assert_scalar 39 @@is_non_decreasing 40 @@is_numeric_tensor 41 @@is_strictly_increasing 42 """ 43 44 from __future__ import absolute_import 45 from __future__ import division 46 from __future__ import print_function 47 48 import numpy as np 49 50 from tensorflow.python.eager import context 51 from tensorflow.python.framework import dtypes 52 from tensorflow.python.framework import errors 53 from tensorflow.python.framework import ops 54 from tensorflow.python.framework import sparse_tensor 55 from tensorflow.python.framework import tensor_util 56 from tensorflow.python.ops import array_ops 57 from tensorflow.python.ops import control_flow_ops 58 from tensorflow.python.ops import math_ops 59 from tensorflow.python.util import compat 60 from tensorflow.python.util.tf_export import tf_export 61 62 NUMERIC_TYPES = frozenset( 63 [dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16, dtypes.int32, 64 dtypes.int64, dtypes.uint8, dtypes.qint8, dtypes.qint32, dtypes.quint8, 65 dtypes.complex64]) 66 67 __all__ = [ 68 'assert_negative', 69 'assert_positive', 70 'assert_proper_iterable', 71 'assert_non_negative', 72 'assert_non_positive', 73 'assert_equal', 74 'assert_none_equal', 75 'assert_near', 76 'assert_integer', 77 'assert_less', 78 'assert_less_equal', 79 'assert_greater', 80 'assert_greater_equal', 81 'assert_rank', 82 'assert_rank_at_least', 83 'assert_rank_in', 84 'assert_same_float_dtype', 85 'assert_scalar', 86 'assert_type', 87 'is_non_decreasing', 88 'is_numeric_tensor', 89 'is_strictly_increasing', 90 ] 91 92 93 def _maybe_constant_value_string(t): 94 if not isinstance(t, ops.Tensor): 95 return str(t) 96 const_t = tensor_util.constant_value(t) 97 if const_t is not None: 98 return str(const_t) 99 return t 100 101 102 def _assert_static(condition, data): 103 """Raises a InvalidArgumentError with as much information as possible.""" 104 if not condition: 105 data_static = [_maybe_constant_value_string(x) for x in data] 106 raise errors.InvalidArgumentError(node_def=None, op=None, 107 message='\n'.join(data_static)) 108 109 110 def _shape_and_dtype_str(tensor): 111 """Returns a string containing tensor's shape and dtype.""" 112 return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name) 113 114 115 @tf_export('assert_proper_iterable') 116 def assert_proper_iterable(values): 117 """Static assert that values is a "proper" iterable. 118 119 `Ops` that expect iterables of `Tensor` can call this to validate input. 120 Useful since `Tensor`, `ndarray`, byte/text type are all iterables themselves. 121 122 Args: 123 values: Object to be checked. 124 125 Raises: 126 TypeError: If `values` is not iterable or is one of 127 `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`. 128 """ 129 unintentional_iterables = ( 130 (ops.Tensor, sparse_tensor.SparseTensor, np.ndarray) 131 + compat.bytes_or_text_types 132 ) 133 if isinstance(values, unintentional_iterables): 134 raise TypeError( 135 'Expected argument "values" to be a "proper" iterable. Found: %s' % 136 type(values)) 137 138 if not hasattr(values, '__iter__'): 139 raise TypeError( 140 'Expected argument "values" to be iterable. Found: %s' % type(values)) 141 142 143 @tf_export('assert_negative') 144 def assert_negative(x, data=None, summarize=None, message=None, name=None): 145 """Assert the condition `x < 0` holds element-wise. 146 147 Example of adding a dependency to an operation: 148 149 ```python 150 with tf.control_dependencies([tf.assert_negative(x)]): 151 output = tf.reduce_sum(x) 152 ``` 153 154 Negative means, for every element `x[i]` of `x`, we have `x[i] < 0`. 155 If `x` is empty this is trivially satisfied. 156 157 Args: 158 x: Numeric `Tensor`. 159 data: The tensors to print out if the condition is False. Defaults to 160 error message and first few entries of `x`. 161 summarize: Print this many entries of each tensor. 162 message: A string to prefix to the default message. 163 name: A name for this operation (optional). Defaults to "assert_negative". 164 165 Returns: 166 Op raising `InvalidArgumentError` unless `x` is all negative. 167 """ 168 message = message or '' 169 with ops.name_scope(name, 'assert_negative', [x, data]): 170 x = ops.convert_to_tensor(x, name='x') 171 if data is None: 172 if context.in_eager_mode(): 173 name = _shape_and_dtype_str(x) 174 else: 175 name = x.name 176 data = [ 177 message, 178 'Condition x < 0 did not hold element-wise:', 179 'x (%s) = ' % name, x] 180 zero = ops.convert_to_tensor(0, dtype=x.dtype) 181 return assert_less(x, zero, data=data, summarize=summarize) 182 183 184 @tf_export('assert_positive') 185 def assert_positive(x, data=None, summarize=None, message=None, name=None): 186 """Assert the condition `x > 0` holds element-wise. 187 188 Example of adding a dependency to an operation: 189 190 ```python 191 with tf.control_dependencies([tf.assert_positive(x)]): 192 output = tf.reduce_sum(x) 193 ``` 194 195 Positive means, for every element `x[i]` of `x`, we have `x[i] > 0`. 196 If `x` is empty this is trivially satisfied. 197 198 Args: 199 x: Numeric `Tensor`. 200 data: The tensors to print out if the condition is False. Defaults to 201 error message and first few entries of `x`. 202 summarize: Print this many entries of each tensor. 203 message: A string to prefix to the default message. 204 name: A name for this operation (optional). Defaults to "assert_positive". 205 206 Returns: 207 Op raising `InvalidArgumentError` unless `x` is all positive. 208 """ 209 message = message or '' 210 with ops.name_scope(name, 'assert_positive', [x, data]): 211 x = ops.convert_to_tensor(x, name='x') 212 if data is None: 213 if context.in_eager_mode(): 214 name = _shape_and_dtype_str(x) 215 else: 216 name = x.name 217 data = [ 218 message, 'Condition x > 0 did not hold element-wise:', 219 'x (%s) = ' % name, x] 220 zero = ops.convert_to_tensor(0, dtype=x.dtype) 221 return assert_less(zero, x, data=data, summarize=summarize) 222 223 224 @tf_export('assert_non_negative') 225 def assert_non_negative(x, data=None, summarize=None, message=None, name=None): 226 """Assert the condition `x >= 0` holds element-wise. 227 228 Example of adding a dependency to an operation: 229 230 ```python 231 with tf.control_dependencies([tf.assert_non_negative(x)]): 232 output = tf.reduce_sum(x) 233 ``` 234 235 Non-negative means, for every element `x[i]` of `x`, we have `x[i] >= 0`. 236 If `x` is empty this is trivially satisfied. 237 238 Args: 239 x: Numeric `Tensor`. 240 data: The tensors to print out if the condition is False. Defaults to 241 error message and first few entries of `x`. 242 summarize: Print this many entries of each tensor. 243 message: A string to prefix to the default message. 244 name: A name for this operation (optional). 245 Defaults to "assert_non_negative". 246 247 Returns: 248 Op raising `InvalidArgumentError` unless `x` is all non-negative. 249 """ 250 message = message or '' 251 with ops.name_scope(name, 'assert_non_negative', [x, data]): 252 x = ops.convert_to_tensor(x, name='x') 253 if data is None: 254 if context.in_eager_mode(): 255 name = _shape_and_dtype_str(x) 256 else: 257 name = x.name 258 data = [ 259 message, 260 'Condition x >= 0 did not hold element-wise:', 261 'x (%s) = ' % name, x] 262 zero = ops.convert_to_tensor(0, dtype=x.dtype) 263 return assert_less_equal(zero, x, data=data, summarize=summarize) 264 265 266 @tf_export('assert_non_positive') 267 def assert_non_positive(x, data=None, summarize=None, message=None, name=None): 268 """Assert the condition `x <= 0` holds element-wise. 269 270 Example of adding a dependency to an operation: 271 272 ```python 273 with tf.control_dependencies([tf.assert_non_positive(x)]): 274 output = tf.reduce_sum(x) 275 ``` 276 277 Non-positive means, for every element `x[i]` of `x`, we have `x[i] <= 0`. 278 If `x` is empty this is trivially satisfied. 279 280 Args: 281 x: Numeric `Tensor`. 282 data: The tensors to print out if the condition is False. Defaults to 283 error message and first few entries of `x`. 284 summarize: Print this many entries of each tensor. 285 message: A string to prefix to the default message. 286 name: A name for this operation (optional). 287 Defaults to "assert_non_positive". 288 289 Returns: 290 Op raising `InvalidArgumentError` unless `x` is all non-positive. 291 """ 292 message = message or '' 293 with ops.name_scope(name, 'assert_non_positive', [x, data]): 294 x = ops.convert_to_tensor(x, name='x') 295 if data is None: 296 if context.in_eager_mode(): 297 name = _shape_and_dtype_str(x) 298 else: 299 name = x.name 300 data = [ 301 message, 302 'Condition x <= 0 did not hold element-wise:' 303 'x (%s) = ' % name, x] 304 zero = ops.convert_to_tensor(0, dtype=x.dtype) 305 return assert_less_equal(x, zero, data=data, summarize=summarize) 306 307 308 @tf_export('assert_equal') 309 def assert_equal(x, y, data=None, summarize=None, message=None, name=None): 310 """Assert the condition `x == y` holds element-wise. 311 312 Example of adding a dependency to an operation: 313 314 ```python 315 with tf.control_dependencies([tf.assert_equal(x, y)]): 316 output = tf.reduce_sum(x) 317 ``` 318 319 This condition holds if for every pair of (possibly broadcast) elements 320 `x[i]`, `y[i]`, we have `x[i] == y[i]`. 321 If both `x` and `y` are empty, this is trivially satisfied. 322 323 Args: 324 x: Numeric `Tensor`. 325 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 326 data: The tensors to print out if the condition is False. Defaults to 327 error message and first few entries of `x`, `y`. 328 summarize: Print this many entries of each tensor. 329 message: A string to prefix to the default message. 330 name: A name for this operation (optional). Defaults to "assert_equal". 331 332 Returns: 333 Op that raises `InvalidArgumentError` if `x == y` is False. 334 @compatibility{eager} returns None 335 336 Raises: 337 InvalidArgumentError: if the check can be performed immediately and 338 `x == y` is False. The check can be performed immediately during eager 339 execution or if `x` and `y` are statically known. 340 """ 341 message = message or '' 342 with ops.name_scope(name, 'assert_equal', [x, y, data]): 343 x = ops.convert_to_tensor(x, name='x') 344 y = ops.convert_to_tensor(y, name='y') 345 346 if context.in_eager_mode(): 347 eq = math_ops.equal(x, y) 348 condition = math_ops.reduce_all(eq) 349 if not condition: 350 # Prepare a message with first elements of x and y. 351 summary_msg = '' 352 # Default to printing 3 elements like control_flow_ops.Assert (used 353 # by graph mode) does. 354 summarize = 3 if summarize is None else summarize 355 if summarize: 356 # reshape((-1,)) is the fastest way to get a flat array view. 357 x_np = x.numpy().reshape((-1,)) 358 y_np = y.numpy().reshape((-1,)) 359 x_sum = min(x_np.size, summarize) 360 y_sum = min(y_np.size, summarize) 361 summary_msg = ('First %d elements of x:\n%s\n' 362 'First %d elements of y:\n%s\n' % 363 (x_sum, x_np[:x_sum], 364 y_sum, y_np[:y_sum])) 365 366 # Get the values that actually differed and their indices. 367 mask = math_ops.logical_not(eq) 368 indices = array_ops.where(mask) 369 indices_np = indices.numpy() 370 x_vals = array_ops.boolean_mask(x, mask) 371 y_vals = array_ops.boolean_mask(y, mask) 372 summarize = min(summarize, indices_np.shape[0]) 373 374 raise errors.InvalidArgumentError( 375 node_def=None, op=None, 376 message=('%s\nCondition x == y did not hold.\n' 377 'Indices of first %s different values:\n%s\n' 378 'Corresponding x values:\n%s\n' 379 'Corresponding y values:\n%s\n' 380 '%s' 381 % 382 (message or '', 383 summarize, indices_np[:summarize], 384 x_vals.numpy().reshape((-1,))[:summarize], 385 y_vals.numpy().reshape((-1,))[:summarize], 386 summary_msg))) 387 return 388 389 if data is None: 390 data = [ 391 message, 392 'Condition x == y did not hold element-wise:', 393 'x (%s) = ' % x.name, x, 394 'y (%s) = ' % y.name, y 395 ] 396 condition = math_ops.reduce_all(math_ops.equal(x, y)) 397 x_static = tensor_util.constant_value(x) 398 y_static = tensor_util.constant_value(y) 399 if x_static is not None and y_static is not None: 400 condition_static = (x_static == y_static).all() 401 _assert_static(condition_static, data) 402 return control_flow_ops.Assert(condition, data, summarize=summarize) 403 404 405 @tf_export('assert_none_equal') 406 def assert_none_equal( 407 x, y, data=None, summarize=None, message=None, name=None): 408 """Assert the condition `x != y` holds for all elements. 409 410 Example of adding a dependency to an operation: 411 412 ```python 413 with tf.control_dependencies([tf.assert_none_equal(x, y)]): 414 output = tf.reduce_sum(x) 415 ``` 416 417 This condition holds if for every pair of (possibly broadcast) elements 418 `x[i]`, `y[i]`, we have `x[i] != y[i]`. 419 If both `x` and `y` are empty, this is trivially satisfied. 420 421 Args: 422 x: Numeric `Tensor`. 423 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 424 data: The tensors to print out if the condition is False. Defaults to 425 error message and first few entries of `x`, `y`. 426 summarize: Print this many entries of each tensor. 427 message: A string to prefix to the default message. 428 name: A name for this operation (optional). 429 Defaults to "assert_none_equal". 430 431 Returns: 432 Op that raises `InvalidArgumentError` if `x != y` is ever False. 433 """ 434 message = message or '' 435 with ops.name_scope(name, 'assert_none_equal', [x, y, data]): 436 x = ops.convert_to_tensor(x, name='x') 437 y = ops.convert_to_tensor(y, name='y') 438 if context.in_eager_mode(): 439 x_name = _shape_and_dtype_str(x) 440 y_name = _shape_and_dtype_str(y) 441 else: 442 x_name = x.name 443 y_name = y.name 444 445 if data is None: 446 data = [ 447 message, 448 'Condition x != y did not hold for every single element:', 449 'x (%s) = ' % x_name, x, 450 'y (%s) = ' % y_name, y 451 ] 452 condition = math_ops.reduce_all(math_ops.not_equal(x, y)) 453 return control_flow_ops.Assert(condition, data, summarize=summarize) 454 455 456 @tf_export('assert_near') 457 def assert_near( 458 x, y, rtol=None, atol=None, data=None, summarize=None, message=None, 459 name=None): 460 """Assert the condition `x` and `y` are close element-wise. 461 462 Example of adding a dependency to an operation: 463 464 ```python 465 with tf.control_dependencies([tf.assert_near(x, y)]): 466 output = tf.reduce_sum(x) 467 ``` 468 469 This condition holds if for every pair of (possibly broadcast) elements 470 `x[i]`, `y[i]`, we have 471 472 ```tf.abs(x[i] - y[i]) <= atol + rtol * tf.abs(y[i])```. 473 474 If both `x` and `y` are empty, this is trivially satisfied. 475 476 The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest 477 representable positive number such that `1 + eps != eps`. This is about 478 `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`. 479 See `numpy.finfo`. 480 481 Args: 482 x: Float or complex `Tensor`. 483 y: Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`. 484 rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 485 The relative tolerance. Default is `10 * eps`. 486 atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 487 The absolute tolerance. Default is `10 * eps`. 488 data: The tensors to print out if the condition is False. Defaults to 489 error message and first few entries of `x`, `y`. 490 summarize: Print this many entries of each tensor. 491 message: A string to prefix to the default message. 492 name: A name for this operation (optional). Defaults to "assert_near". 493 494 Returns: 495 Op that raises `InvalidArgumentError` if `x` and `y` are not close enough. 496 497 @compatibility(numpy) 498 Similar to `numpy.assert_allclose`, except tolerance depends on data type. 499 This is due to the fact that `TensorFlow` is often used with `32bit`, `64bit`, 500 and even `16bit` data. 501 @end_compatibility 502 """ 503 message = message or '' 504 with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]): 505 x = ops.convert_to_tensor(x, name='x') 506 y = ops.convert_to_tensor(y, name='y', dtype=x.dtype) 507 508 eps = np.finfo(x.dtype.as_numpy_dtype).eps 509 rtol = 10 * eps if rtol is None else rtol 510 atol = 10 * eps if atol is None else atol 511 512 rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype) 513 atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype) 514 515 if context.in_eager_mode(): 516 x_name = _shape_and_dtype_str(x) 517 y_name = _shape_and_dtype_str(y) 518 else: 519 x_name = x.name 520 y_name = y.name 521 522 if data is None: 523 data = [ 524 message, 525 'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol), 526 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y 527 ] 528 tol = atol + rtol * math_ops.abs(y) 529 diff = math_ops.abs(x - y) 530 condition = math_ops.reduce_all(math_ops.less(diff, tol)) 531 return control_flow_ops.Assert(condition, data, summarize=summarize) 532 533 534 @tf_export('assert_less') 535 def assert_less(x, y, data=None, summarize=None, message=None, name=None): 536 """Assert the condition `x < y` holds element-wise. 537 538 Example of adding a dependency to an operation: 539 540 ```python 541 with tf.control_dependencies([tf.assert_less(x, y)]): 542 output = tf.reduce_sum(x) 543 ``` 544 545 This condition holds if for every pair of (possibly broadcast) elements 546 `x[i]`, `y[i]`, we have `x[i] < y[i]`. 547 If both `x` and `y` are empty, this is trivially satisfied. 548 549 Args: 550 x: Numeric `Tensor`. 551 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 552 data: The tensors to print out if the condition is False. Defaults to 553 error message and first few entries of `x`, `y`. 554 summarize: Print this many entries of each tensor. 555 message: A string to prefix to the default message. 556 name: A name for this operation (optional). Defaults to "assert_less". 557 558 Returns: 559 Op that raises `InvalidArgumentError` if `x < y` is False. 560 """ 561 message = message or '' 562 with ops.name_scope(name, 'assert_less', [x, y, data]): 563 x = ops.convert_to_tensor(x, name='x') 564 y = ops.convert_to_tensor(y, name='y') 565 if context.in_eager_mode(): 566 x_name = _shape_and_dtype_str(x) 567 y_name = _shape_and_dtype_str(y) 568 else: 569 x_name = x.name 570 y_name = y.name 571 572 if data is None: 573 data = [ 574 message, 575 'Condition x < y did not hold element-wise:', 576 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y 577 ] 578 condition = math_ops.reduce_all(math_ops.less(x, y)) 579 return control_flow_ops.Assert(condition, data, summarize=summarize) 580 581 582 @tf_export('assert_less_equal') 583 def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None): 584 """Assert the condition `x <= y` holds element-wise. 585 586 Example of adding a dependency to an operation: 587 588 ```python 589 with tf.control_dependencies([tf.assert_less_equal(x, y)]): 590 output = tf.reduce_sum(x) 591 ``` 592 593 This condition holds if for every pair of (possibly broadcast) elements 594 `x[i]`, `y[i]`, we have `x[i] <= y[i]`. 595 If both `x` and `y` are empty, this is trivially satisfied. 596 597 Args: 598 x: Numeric `Tensor`. 599 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 600 data: The tensors to print out if the condition is False. Defaults to 601 error message and first few entries of `x`, `y`. 602 summarize: Print this many entries of each tensor. 603 message: A string to prefix to the default message. 604 name: A name for this operation (optional). Defaults to "assert_less_equal" 605 606 Returns: 607 Op that raises `InvalidArgumentError` if `x <= y` is False. 608 """ 609 message = message or '' 610 with ops.name_scope(name, 'assert_less_equal', [x, y, data]): 611 x = ops.convert_to_tensor(x, name='x') 612 y = ops.convert_to_tensor(y, name='y') 613 if context.in_eager_mode(): 614 x_name = _shape_and_dtype_str(x) 615 y_name = _shape_and_dtype_str(y) 616 else: 617 x_name = x.name 618 y_name = y.name 619 620 if data is None: 621 data = [ 622 message, 623 'Condition x <= y did not hold element-wise:' 624 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y 625 ] 626 condition = math_ops.reduce_all(math_ops.less_equal(x, y)) 627 return control_flow_ops.Assert(condition, data, summarize=summarize) 628 629 630 @tf_export('assert_greater') 631 def assert_greater(x, y, data=None, summarize=None, message=None, name=None): 632 """Assert the condition `x > y` holds element-wise. 633 634 Example of adding a dependency to an operation: 635 636 ```python 637 with tf.control_dependencies([tf.assert_greater(x, y)]): 638 output = tf.reduce_sum(x) 639 ``` 640 641 This condition holds if for every pair of (possibly broadcast) elements 642 `x[i]`, `y[i]`, we have `x[i] > y[i]`. 643 If both `x` and `y` are empty, this is trivially satisfied. 644 645 Args: 646 x: Numeric `Tensor`. 647 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 648 data: The tensors to print out if the condition is False. Defaults to 649 error message and first few entries of `x`, `y`. 650 summarize: Print this many entries of each tensor. 651 message: A string to prefix to the default message. 652 name: A name for this operation (optional). Defaults to "assert_greater". 653 654 Returns: 655 Op that raises `InvalidArgumentError` if `x > y` is False. 656 """ 657 message = message or '' 658 with ops.name_scope(name, 'assert_greater', [x, y, data]): 659 x = ops.convert_to_tensor(x, name='x') 660 y = ops.convert_to_tensor(y, name='y') 661 if context.in_eager_mode(): 662 x_name = _shape_and_dtype_str(x) 663 y_name = _shape_and_dtype_str(y) 664 else: 665 x_name = x.name 666 y_name = y.name 667 668 if data is None: 669 data = [ 670 message, 671 'Condition x > y did not hold element-wise:' 672 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y 673 ] 674 condition = math_ops.reduce_all(math_ops.greater(x, y)) 675 return control_flow_ops.Assert(condition, data, summarize=summarize) 676 677 678 @tf_export('assert_greater_equal') 679 def assert_greater_equal(x, y, data=None, summarize=None, message=None, 680 name=None): 681 """Assert the condition `x >= y` holds element-wise. 682 683 Example of adding a dependency to an operation: 684 685 ```python 686 with tf.control_dependencies([tf.assert_greater_equal(x, y)]): 687 output = tf.reduce_sum(x) 688 ``` 689 690 This condition holds if for every pair of (possibly broadcast) elements 691 `x[i]`, `y[i]`, we have `x[i] >= y[i]`. 692 If both `x` and `y` are empty, this is trivially satisfied. 693 694 Args: 695 x: Numeric `Tensor`. 696 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 697 data: The tensors to print out if the condition is False. Defaults to 698 error message and first few entries of `x`, `y`. 699 summarize: Print this many entries of each tensor. 700 message: A string to prefix to the default message. 701 name: A name for this operation (optional). Defaults to 702 "assert_greater_equal" 703 704 Returns: 705 Op that raises `InvalidArgumentError` if `x >= y` is False. 706 """ 707 message = message or '' 708 with ops.name_scope(name, 'assert_greater_equal', [x, y, data]): 709 x = ops.convert_to_tensor(x, name='x') 710 y = ops.convert_to_tensor(y, name='y') 711 if context.in_eager_mode(): 712 x_name = _shape_and_dtype_str(x) 713 y_name = _shape_and_dtype_str(y) 714 else: 715 x_name = x.name 716 y_name = y.name 717 718 if data is None: 719 data = [ 720 message, 721 'Condition x >= y did not hold element-wise:' 722 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y 723 ] 724 condition = math_ops.reduce_all(math_ops.greater_equal(x, y)) 725 return control_flow_ops.Assert(condition, data, summarize=summarize) 726 727 728 def _assert_rank_condition( 729 x, rank, static_condition, dynamic_condition, data, summarize): 730 """Assert `x` has a rank that satisfies a given condition. 731 732 Args: 733 x: Numeric `Tensor`. 734 rank: Scalar `Tensor`. 735 static_condition: A python function that takes `[actual_rank, given_rank]` 736 and returns `True` if the condition is satisfied, `False` otherwise. 737 dynamic_condition: An `op` that takes [actual_rank, given_rank] 738 and return `True` if the condition is satisfied, `False` otherwise. 739 data: The tensors to print out if the condition is false. Defaults to 740 error message and first few entries of `x`. 741 summarize: Print this many entries of each tensor. 742 743 Returns: 744 Op raising `InvalidArgumentError` if `x` fails dynamic_condition. 745 746 Raises: 747 ValueError: If static checks determine `x` fails static_condition. 748 """ 749 assert_type(rank, dtypes.int32) 750 751 # Attempt to statically defined rank. 752 rank_static = tensor_util.constant_value(rank) 753 if rank_static is not None: 754 if rank_static.ndim != 0: 755 raise ValueError('Rank must be a scalar.') 756 757 x_rank_static = x.get_shape().ndims 758 if x_rank_static is not None: 759 if not static_condition(x_rank_static, rank_static): 760 raise ValueError( 761 'Static rank condition failed', x_rank_static, rank_static) 762 return control_flow_ops.no_op(name='static_checks_determined_all_ok') 763 764 condition = dynamic_condition(array_ops.rank(x), rank) 765 766 # Add the condition that `rank` must have rank zero. Prevents the bug where 767 # someone does assert_rank(x, [n]), rather than assert_rank(x, n). 768 if rank_static is None: 769 this_data = ['Rank must be a scalar. Received rank: ', rank] 770 rank_check = assert_rank(rank, 0, data=this_data) 771 condition = control_flow_ops.with_dependencies([rank_check], condition) 772 773 return control_flow_ops.Assert(condition, data, summarize=summarize) 774 775 776 @tf_export('assert_rank') 777 def assert_rank(x, rank, data=None, summarize=None, message=None, name=None): 778 """Assert `x` has rank equal to `rank`. 779 780 Example of adding a dependency to an operation: 781 782 ```python 783 with tf.control_dependencies([tf.assert_rank(x, 2)]): 784 output = tf.reduce_sum(x) 785 ``` 786 787 Args: 788 x: Numeric `Tensor`. 789 rank: Scalar integer `Tensor`. 790 data: The tensors to print out if the condition is False. Defaults to 791 error message and first few entries of `x`. 792 summarize: Print this many entries of each tensor. 793 message: A string to prefix to the default message. 794 name: A name for this operation (optional). Defaults to "assert_rank". 795 796 Returns: 797 Op raising `InvalidArgumentError` unless `x` has specified rank. 798 If static checks determine `x` has correct rank, a `no_op` is returned. 799 800 Raises: 801 ValueError: If static checks determine `x` has wrong rank. 802 """ 803 with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])): 804 x = ops.convert_to_tensor(x, name='x') 805 rank = ops.convert_to_tensor(rank, name='rank') 806 message = message or '' 807 808 static_condition = lambda actual_rank, given_rank: actual_rank == given_rank 809 dynamic_condition = math_ops.equal 810 811 if context.in_eager_mode(): 812 name = '' 813 else: 814 name = x.name 815 816 if data is None: 817 data = [ 818 message, 819 'Tensor %s must have rank' % name, rank, 'Received shape: ', 820 array_ops.shape(x) 821 ] 822 823 try: 824 assert_op = _assert_rank_condition(x, rank, static_condition, 825 dynamic_condition, data, summarize) 826 827 except ValueError as e: 828 if e.args[0] == 'Static rank condition failed': 829 raise ValueError( 830 '%s. Tensor %s must have rank %d. Received rank %d, shape %s' % 831 (message, name, e.args[2], e.args[1], x.get_shape())) 832 else: 833 raise 834 835 return assert_op 836 837 838 @tf_export('assert_rank_at_least') 839 def assert_rank_at_least( 840 x, rank, data=None, summarize=None, message=None, name=None): 841 """Assert `x` has rank equal to `rank` or higher. 842 843 Example of adding a dependency to an operation: 844 845 ```python 846 with tf.control_dependencies([tf.assert_rank_at_least(x, 2)]): 847 output = tf.reduce_sum(x) 848 ``` 849 850 Args: 851 x: Numeric `Tensor`. 852 rank: Scalar `Tensor`. 853 data: The tensors to print out if the condition is False. Defaults to 854 error message and first few entries of `x`. 855 summarize: Print this many entries of each tensor. 856 message: A string to prefix to the default message. 857 name: A name for this operation (optional). 858 Defaults to "assert_rank_at_least". 859 860 Returns: 861 Op raising `InvalidArgumentError` unless `x` has specified rank or higher. 862 If static checks determine `x` has correct rank, a `no_op` is returned. 863 864 Raises: 865 ValueError: If static checks determine `x` has wrong rank. 866 """ 867 with ops.name_scope( 868 name, 'assert_rank_at_least', (x, rank) + tuple(data or [])): 869 x = ops.convert_to_tensor(x, name='x') 870 rank = ops.convert_to_tensor(rank, name='rank') 871 message = message or '' 872 873 static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank 874 dynamic_condition = math_ops.greater_equal 875 876 if context.in_eager_mode(): 877 name = '' 878 else: 879 name = x.name 880 881 if data is None: 882 data = [ 883 message, 884 'Tensor %s must have rank at least' % name, rank, 885 'Received shape: ', array_ops.shape(x) 886 ] 887 888 try: 889 assert_op = _assert_rank_condition(x, rank, static_condition, 890 dynamic_condition, data, summarize) 891 892 except ValueError as e: 893 if e.args[0] == 'Static rank condition failed': 894 raise ValueError( 895 '%s. Tensor %s must have rank at least %d. Received rank %d, ' 896 'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape())) 897 else: 898 raise 899 900 return assert_op 901 902 903 def _static_rank_in(actual_rank, given_ranks): 904 return actual_rank in given_ranks 905 906 907 def _dynamic_rank_in(actual_rank, given_ranks): 908 if len(given_ranks) < 1: 909 return ops.convert_to_tensor(False) 910 result = math_ops.equal(given_ranks[0], actual_rank) 911 for given_rank in given_ranks[1:]: 912 result = math_ops.logical_or( 913 result, math_ops.equal(given_rank, actual_rank)) 914 return result 915 916 917 def _assert_ranks_condition( 918 x, ranks, static_condition, dynamic_condition, data, summarize): 919 """Assert `x` has a rank that satisfies a given condition. 920 921 Args: 922 x: Numeric `Tensor`. 923 ranks: Scalar `Tensor`. 924 static_condition: A python function that takes 925 `[actual_rank, given_ranks]` and returns `True` if the condition is 926 satisfied, `False` otherwise. 927 dynamic_condition: An `op` that takes [actual_rank, given_ranks] 928 and return `True` if the condition is satisfied, `False` otherwise. 929 data: The tensors to print out if the condition is false. Defaults to 930 error message and first few entries of `x`. 931 summarize: Print this many entries of each tensor. 932 933 Returns: 934 Op raising `InvalidArgumentError` if `x` fails dynamic_condition. 935 936 Raises: 937 ValueError: If static checks determine `x` fails static_condition. 938 """ 939 for rank in ranks: 940 assert_type(rank, dtypes.int32) 941 942 # Attempt to statically defined rank. 943 ranks_static = tuple([tensor_util.constant_value(rank) for rank in ranks]) 944 if not any(r is None for r in ranks_static): 945 for rank_static in ranks_static: 946 if rank_static.ndim != 0: 947 raise ValueError('Rank must be a scalar.') 948 949 x_rank_static = x.get_shape().ndims 950 if x_rank_static is not None: 951 if not static_condition(x_rank_static, ranks_static): 952 raise ValueError( 953 'Static rank condition failed', x_rank_static, ranks_static) 954 return control_flow_ops.no_op(name='static_checks_determined_all_ok') 955 956 condition = dynamic_condition(array_ops.rank(x), ranks) 957 958 # Add the condition that `rank` must have rank zero. Prevents the bug where 959 # someone does assert_rank(x, [n]), rather than assert_rank(x, n). 960 for rank, rank_static in zip(ranks, ranks_static): 961 if rank_static is None: 962 this_data = ['Rank must be a scalar. Received rank: ', rank] 963 rank_check = assert_rank(rank, 0, data=this_data) 964 condition = control_flow_ops.with_dependencies([rank_check], condition) 965 966 return control_flow_ops.Assert(condition, data, summarize=summarize) 967 968 969 @tf_export('assert_rank_in') 970 def assert_rank_in( 971 x, ranks, data=None, summarize=None, message=None, name=None): 972 """Assert `x` has rank in `ranks`. 973 974 Example of adding a dependency to an operation: 975 976 ```python 977 with tf.control_dependencies([tf.assert_rank_in(x, (2, 4))]): 978 output = tf.reduce_sum(x) 979 ``` 980 981 Args: 982 x: Numeric `Tensor`. 983 ranks: Iterable of scalar `Tensor` objects. 984 data: The tensors to print out if the condition is False. Defaults to 985 error message and first few entries of `x`. 986 summarize: Print this many entries of each tensor. 987 message: A string to prefix to the default message. 988 name: A name for this operation (optional). 989 Defaults to "assert_rank_in". 990 991 Returns: 992 Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`. 993 If static checks determine `x` has matching rank, a `no_op` is returned. 994 995 Raises: 996 ValueError: If static checks determine `x` has mismatched rank. 997 """ 998 with ops.name_scope( 999 name, 'assert_rank_in', (x,) + tuple(ranks) + tuple(data or [])): 1000 x = ops.convert_to_tensor(x, name='x') 1001 ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks]) 1002 message = message or '' 1003 1004 if context.in_eager_mode(): 1005 name = '' 1006 else: 1007 name = x.name 1008 1009 if data is None: 1010 data = [ 1011 message, 'Tensor %s must have rank in' % name 1012 ] + list(ranks) + [ 1013 'Received shape: ', array_ops.shape(x) 1014 ] 1015 1016 try: 1017 assert_op = _assert_ranks_condition(x, ranks, _static_rank_in, 1018 _dynamic_rank_in, data, summarize) 1019 1020 except ValueError as e: 1021 if e.args[0] == 'Static rank condition failed': 1022 raise ValueError( 1023 '%s. Tensor %s must have rank in %s. Received rank %d, ' 1024 'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape())) 1025 else: 1026 raise 1027 1028 return assert_op 1029 1030 1031 @tf_export('assert_integer') 1032 def assert_integer(x, message=None, name=None): 1033 """Assert that `x` is of integer dtype. 1034 1035 Example of adding a dependency to an operation: 1036 1037 ```python 1038 with tf.control_dependencies([tf.assert_integer(x)]): 1039 output = tf.reduce_sum(x) 1040 ``` 1041 1042 Args: 1043 x: `Tensor` whose basetype is integer and is not quantized. 1044 message: A string to prefix to the default message. 1045 name: A name for this operation (optional). Defaults to "assert_integer". 1046 1047 Raises: 1048 TypeError: If `x.dtype` is anything other than non-quantized integer. 1049 1050 Returns: 1051 A `no_op` that does nothing. Type can be determined statically. 1052 """ 1053 message = message or '' 1054 with ops.name_scope(name, 'assert_integer', [x]): 1055 x = ops.convert_to_tensor(x, name='x') 1056 if not x.dtype.is_integer: 1057 if context.in_eager_mode(): 1058 name = 'tensor' 1059 else: 1060 name = x.name 1061 err_msg = ( 1062 '%s Expected "x" to be integer type. Found: %s of dtype %s' 1063 % (message, name, x.dtype)) 1064 raise TypeError(err_msg) 1065 1066 return control_flow_ops.no_op('statically_determined_was_integer') 1067 1068 1069 @tf_export('assert_type') 1070 def assert_type(tensor, tf_type, message=None, name=None): 1071 """Statically asserts that the given `Tensor` is of the specified type. 1072 1073 Args: 1074 tensor: A tensorflow `Tensor`. 1075 tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`, 1076 etc). 1077 message: A string to prefix to the default message. 1078 name: A name to give this `Op`. Defaults to "assert_type" 1079 1080 Raises: 1081 TypeError: If the tensors data type doesn't match `tf_type`. 1082 1083 Returns: 1084 A `no_op` that does nothing. Type can be determined statically. 1085 """ 1086 message = message or '' 1087 with ops.name_scope(name, 'assert_type', [tensor]): 1088 tensor = ops.convert_to_tensor(tensor, name='tensor') 1089 if tensor.dtype != tf_type: 1090 if context.in_graph_mode(): 1091 raise TypeError( 1092 '%s %s must be of type %s' % (message, tensor.name, tf_type)) 1093 else: 1094 raise TypeError( 1095 '%s tensor must be of type %s' % (message, tf_type)) 1096 1097 return control_flow_ops.no_op('statically_determined_correct_type') 1098 1099 1100 # pylint: disable=line-too-long 1101 def _get_diff_for_monotonic_comparison(x): 1102 """Gets the difference x[1:] - x[:-1].""" 1103 x = array_ops.reshape(x, [-1]) 1104 if not is_numeric_tensor(x): 1105 raise TypeError('Expected x to be numeric, instead found: %s' % x) 1106 1107 # If x has less than 2 elements, there is nothing to compare. So return []. 1108 is_shorter_than_two = math_ops.less(array_ops.size(x), 2) 1109 short_result = lambda: ops.convert_to_tensor([], dtype=x.dtype) 1110 1111 # With 2 or more elements, return x[1:] - x[:-1] 1112 s_len = array_ops.shape(x) - 1 1113 diff = lambda: array_ops.strided_slice(x, [1], [1] + s_len)- array_ops.strided_slice(x, [0], s_len) 1114 return control_flow_ops.cond(is_shorter_than_two, short_result, diff) 1115 1116 1117 @tf_export('is_numeric_tensor') 1118 def is_numeric_tensor(tensor): 1119 return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES 1120 1121 1122 @tf_export('is_non_decreasing') 1123 def is_non_decreasing(x, name=None): 1124 """Returns `True` if `x` is non-decreasing. 1125 1126 Elements of `x` are compared in row-major order. The tensor `[x[0],...]` 1127 is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`. 1128 If `x` has less than two elements, it is trivially non-decreasing. 1129 1130 See also: `is_strictly_increasing` 1131 1132 Args: 1133 x: Numeric `Tensor`. 1134 name: A name for this operation (optional). Defaults to "is_non_decreasing" 1135 1136 Returns: 1137 Boolean `Tensor`, equal to `True` iff `x` is non-decreasing. 1138 1139 Raises: 1140 TypeError: if `x` is not a numeric tensor. 1141 """ 1142 with ops.name_scope(name, 'is_non_decreasing', [x]): 1143 diff = _get_diff_for_monotonic_comparison(x) 1144 # When len(x) = 1, diff = [], less_equal = [], and reduce_all([]) = True. 1145 zero = ops.convert_to_tensor(0, dtype=diff.dtype) 1146 return math_ops.reduce_all(math_ops.less_equal(zero, diff)) 1147 1148 1149 @tf_export('is_strictly_increasing') 1150 def is_strictly_increasing(x, name=None): 1151 """Returns `True` if `x` is strictly increasing. 1152 1153 Elements of `x` are compared in row-major order. The tensor `[x[0],...]` 1154 is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`. 1155 If `x` has less than two elements, it is trivially strictly increasing. 1156 1157 See also: `is_non_decreasing` 1158 1159 Args: 1160 x: Numeric `Tensor`. 1161 name: A name for this operation (optional). 1162 Defaults to "is_strictly_increasing" 1163 1164 Returns: 1165 Boolean `Tensor`, equal to `True` iff `x` is strictly increasing. 1166 1167 Raises: 1168 TypeError: if `x` is not a numeric tensor. 1169 """ 1170 with ops.name_scope(name, 'is_strictly_increasing', [x]): 1171 diff = _get_diff_for_monotonic_comparison(x) 1172 # When len(x) = 1, diff = [], less = [], and reduce_all([]) = True. 1173 zero = ops.convert_to_tensor(0, dtype=diff.dtype) 1174 return math_ops.reduce_all(math_ops.less(zero, diff)) 1175 1176 1177 def _assert_same_base_type(items, expected_type=None): 1178 r"""Asserts all items are of the same base type. 1179 1180 Args: 1181 items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`, 1182 `Operation`, or `IndexedSlices`). Can include `None` elements, which 1183 will be ignored. 1184 expected_type: Expected type. If not specified, assert all items are 1185 of the same base type. 1186 1187 Returns: 1188 Validated type, or none if neither expected_type nor items provided. 1189 1190 Raises: 1191 ValueError: If any types do not match. 1192 """ 1193 original_item_str = None 1194 for item in items: 1195 if item is not None: 1196 item_type = item.dtype.base_dtype 1197 if not expected_type: 1198 expected_type = item_type 1199 original_item_str = item.name if hasattr(item, 'name') else str(item) 1200 elif expected_type != item_type: 1201 raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % ( 1202 item.name if hasattr(item, 'name') else str(item), 1203 item_type, expected_type, 1204 (' as %s' % original_item_str) if original_item_str else '')) 1205 return expected_type 1206 1207 1208 @tf_export('assert_same_float_dtype') 1209 def assert_same_float_dtype(tensors=None, dtype=None): 1210 """Validate and return float type based on `tensors` and `dtype`. 1211 1212 For ops such as matrix multiplication, inputs and weights must be of the 1213 same float type. This function validates that all `tensors` are the same type, 1214 validates that type is `dtype` (if supplied), and returns the type. Type must 1215 be a floating point type. If neither `tensors` nor `dtype` is supplied, 1216 the function will return `dtypes.float32`. 1217 1218 Args: 1219 tensors: Tensors of input values. Can include `None` elements, which will be 1220 ignored. 1221 dtype: Expected type. 1222 Returns: 1223 Validated type. 1224 Raises: 1225 ValueError: if neither `tensors` nor `dtype` is supplied, or result is not 1226 float, or the common type of the inputs is not a floating point type. 1227 """ 1228 if tensors: 1229 dtype = _assert_same_base_type(tensors, dtype) 1230 if not dtype: 1231 dtype = dtypes.float32 1232 elif not dtype.is_floating: 1233 raise ValueError('Expected floating point type, got %s.' % dtype) 1234 return dtype 1235 1236 1237 @tf_export('assert_scalar') 1238 def assert_scalar(tensor, name=None): 1239 with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope: 1240 tensor = ops.convert_to_tensor(tensor, name=name_scope) 1241 shape = tensor.get_shape() 1242 if shape.ndims != 0: 1243 if context.in_eager_mode(): 1244 raise ValueError('Expected scalar shape, saw shape: %s.' 1245 % (shape,)) 1246 else: 1247 raise ValueError('Expected scalar shape for %s, saw shape: %s.' 1248 % (tensor.name, shape)) 1249 return tensor 1250