1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Implementation of image ops.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.compat import compat 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import ops 27 from tensorflow.python.framework import random_seed 28 from tensorflow.python.framework import tensor_shape 29 from tensorflow.python.framework import tensor_util 30 from tensorflow.python.ops import array_ops 31 from tensorflow.python.ops import check_ops 32 from tensorflow.python.ops import control_flow_ops 33 from tensorflow.python.ops import gen_image_ops 34 from tensorflow.python.ops import gen_nn_ops 35 from tensorflow.python.ops import math_ops 36 from tensorflow.python.ops import nn 37 from tensorflow.python.ops import nn_ops 38 from tensorflow.python.ops import random_ops 39 from tensorflow.python.ops import string_ops 40 from tensorflow.python.ops import variables 41 from tensorflow.python.util import deprecation 42 from tensorflow.python.util.tf_export import tf_export 43 44 ops.NotDifferentiable('RandomCrop') 45 # TODO(b/31222613): This op may be differentiable, and there may be 46 # latent bugs here. 47 ops.NotDifferentiable('RGBToHSV') 48 # TODO(b/31222613): This op may be differentiable, and there may be 49 # latent bugs here. 50 ops.NotDifferentiable('HSVToRGB') 51 ops.NotDifferentiable('DrawBoundingBoxes') 52 ops.NotDifferentiable('SampleDistortedBoundingBox') 53 ops.NotDifferentiable('SampleDistortedBoundingBoxV2') 54 # TODO(bsteiner): Implement the gradient function for extract_glimpse 55 # TODO(b/31222613): This op may be differentiable, and there may be 56 # latent bugs here. 57 ops.NotDifferentiable('ExtractGlimpse') 58 ops.NotDifferentiable('NonMaxSuppression') 59 ops.NotDifferentiable('NonMaxSuppressionV2') 60 ops.NotDifferentiable('NonMaxSuppressionWithOverlaps') 61 62 63 # pylint: disable=invalid-name 64 def _assert(cond, ex_type, msg): 65 """A polymorphic assert, works with tensors and boolean expressions. 66 67 If `cond` is not a tensor, behave like an ordinary assert statement, except 68 that a empty list is returned. If `cond` is a tensor, return a list 69 containing a single TensorFlow assert op. 70 71 Args: 72 cond: Something evaluates to a boolean value. May be a tensor. 73 ex_type: The exception class to use. 74 msg: The error message. 75 76 Returns: 77 A list, containing at most one assert op. 78 """ 79 if _is_tensor(cond): 80 return [control_flow_ops.Assert(cond, [msg])] 81 else: 82 if not cond: 83 raise ex_type(msg) 84 else: 85 return [] 86 87 88 def _is_tensor(x): 89 """Returns `True` if `x` is a symbolic tensor-like object. 90 91 Args: 92 x: A python object to check. 93 94 Returns: 95 `True` if `x` is a `tf.Tensor` or `tf.Variable`, otherwise `False`. 96 """ 97 return isinstance(x, (ops.Tensor, variables.Variable)) 98 99 100 def _ImageDimensions(image, rank): 101 """Returns the dimensions of an image tensor. 102 103 Args: 104 image: A rank-D Tensor. For 3-D of shape: `[height, width, channels]`. 105 rank: The expected rank of the image 106 107 Returns: 108 A list of corresponding to the dimensions of the 109 input image. Dimensions that are statically known are python integers, 110 otherwise they are integer scalar tensors. 111 """ 112 if image.get_shape().is_fully_defined(): 113 return image.get_shape().as_list() 114 else: 115 static_shape = image.get_shape().with_rank(rank).as_list() 116 dynamic_shape = array_ops.unstack(array_ops.shape(image), rank) 117 return [ 118 s if s is not None else d for s, d in zip(static_shape, dynamic_shape) 119 ] 120 121 122 def _Check3DImage(image, require_static=True): 123 """Assert that we are working with properly shaped image. 124 125 Args: 126 image: 3-D Tensor of shape [height, width, channels] 127 require_static: If `True`, requires that all dimensions of `image` are 128 known and non-zero. 129 130 Raises: 131 ValueError: if `image.shape` is not a 3-vector. 132 133 Returns: 134 An empty list, if `image` has fully defined dimensions. Otherwise, a list 135 containing an assert op is returned. 136 """ 137 try: 138 image_shape = image.get_shape().with_rank(3) 139 except ValueError: 140 raise ValueError( 141 "'image' (shape %s) must be three-dimensional." % image.shape) 142 if require_static and not image_shape.is_fully_defined(): 143 raise ValueError("'image' (shape %s) must be fully defined." % image_shape) 144 if any(x == 0 for x in image_shape): 145 raise ValueError("all dims of 'image.shape' must be > 0: %s" % image_shape) 146 if not image_shape.is_fully_defined(): 147 return [ 148 check_ops.assert_positive( 149 array_ops.shape(image), 150 ["all dims of 'image.shape' " 151 'must be > 0.']) 152 ] 153 else: 154 return [] 155 156 157 def _Assert3DImage(image): 158 """Assert that we are working with a properly shaped image. 159 160 Performs the check statically if possible (i.e. if the shape 161 is statically known). Otherwise adds a control dependency 162 to an assert op that checks the dynamic shape. 163 164 Args: 165 image: 3-D Tensor of shape [height, width, channels] 166 167 Raises: 168 ValueError: if `image.shape` is not a 3-vector. 169 170 Returns: 171 If the shape of `image` could be verified statically, `image` is 172 returned unchanged, otherwise there will be a control dependency 173 added that asserts the correct dynamic shape. 174 """ 175 return control_flow_ops.with_dependencies( 176 _Check3DImage(image, require_static=False), image) 177 178 179 def _AssertAtLeast3DImage(image): 180 """Assert that we are working with a properly shaped image. 181 182 Performs the check statically if possible (i.e. if the shape 183 is statically known). Otherwise adds a control dependency 184 to an assert op that checks the dynamic shape. 185 186 Args: 187 image: >= 3-D Tensor of size [*, height, width, depth] 188 189 Raises: 190 ValueError: if image.shape is not a [>= 3] vector. 191 192 Returns: 193 If the shape of `image` could be verified statically, `image` is 194 returned unchanged, otherwise there will be a control dependency 195 added that asserts the correct dynamic shape. 196 """ 197 return control_flow_ops.with_dependencies( 198 _CheckAtLeast3DImage(image, require_static=False), image) 199 200 201 def _CheckAtLeast3DImage(image, require_static=True): 202 """Assert that we are working with properly shaped image. 203 204 Args: 205 image: >= 3-D Tensor of size [*, height, width, depth] 206 require_static: If `True`, requires that all dimensions of `image` are 207 known and non-zero. 208 209 Raises: 210 ValueError: if image.shape is not a [>= 3] vector. 211 212 Returns: 213 An empty list, if `image` has fully defined dimensions. Otherwise, a list 214 containing an assert op is returned. 215 """ 216 try: 217 if image.get_shape().ndims is None: 218 image_shape = image.get_shape().with_rank(3) 219 else: 220 image_shape = image.get_shape().with_rank_at_least(3) 221 except ValueError: 222 raise ValueError("'image' must be at least three-dimensional.") 223 if require_static and not image_shape.is_fully_defined(): 224 raise ValueError('\'image\' must be fully defined.') 225 if any(x == 0 for x in image_shape): 226 raise ValueError( 227 'all dims of \'image.shape\' must be > 0: %s' % image_shape) 228 if not image_shape.is_fully_defined(): 229 return [ 230 check_ops.assert_positive( 231 array_ops.shape(image), 232 ["all dims of 'image.shape' " 233 'must be > 0.']) 234 ] 235 else: 236 return [] 237 238 239 def fix_image_flip_shape(image, result): 240 """Set the shape to 3 dimensional if we don't know anything else. 241 242 Args: 243 image: original image size 244 result: flipped or transformed image 245 246 Returns: 247 An image whose shape is at least None,None,None. 248 """ 249 250 image_shape = image.get_shape() 251 if image_shape == tensor_shape.unknown_shape(): 252 result.set_shape([None, None, None]) 253 else: 254 result.set_shape(image_shape) 255 return result 256 257 258 @tf_export('image.random_flip_up_down') 259 def random_flip_up_down(image, seed=None): 260 """Randomly flips an image vertically (upside down). 261 262 With a 1 in 2 chance, outputs the contents of `image` flipped along the first 263 dimension, which is `height`. Otherwise output the image as-is. 264 265 Args: 266 image: 4-D Tensor of shape `[batch, height, width, channels]` or 267 3-D Tensor of shape `[height, width, channels]`. 268 seed: A Python integer. Used to create a random seed. See 269 `tf.set_random_seed` 270 for behavior. 271 272 Returns: 273 A tensor of the same type and shape as `image`. 274 Raises: 275 ValueError: if the shape of `image` not supported. 276 """ 277 return _random_flip(image, 0, seed, 'random_flip_up_down') 278 279 280 @tf_export('image.random_flip_left_right') 281 def random_flip_left_right(image, seed=None): 282 """Randomly flip an image horizontally (left to right). 283 284 With a 1 in 2 chance, outputs the contents of `image` flipped along the 285 second dimension, which is `width`. Otherwise output the image as-is. 286 287 Args: 288 image: 4-D Tensor of shape `[batch, height, width, channels]` or 289 3-D Tensor of shape `[height, width, channels]`. 290 seed: A Python integer. Used to create a random seed. See 291 `tf.set_random_seed` 292 for behavior. 293 294 Returns: 295 A tensor of the same type and shape as `image`. 296 297 Raises: 298 ValueError: if the shape of `image` not supported. 299 """ 300 return _random_flip(image, 1, seed, 'random_flip_left_right') 301 302 303 def _random_flip(image, flip_index, seed, scope_name): 304 """Randomly (50% chance) flip an image along axis `flip_index`. 305 306 Args: 307 image: 4-D Tensor of shape `[batch, height, width, channels]` or 308 3-D Tensor of shape `[height, width, channels]`. 309 flip_index: Dimension along which to flip image. Vertical: 0, Horizontal: 1 310 seed: A Python integer. Used to create a random seed. See 311 `tf.set_random_seed` 312 for behavior. 313 scope_name: Name of the scope in which the ops are added. 314 315 Returns: 316 A tensor of the same type and shape as `image`. 317 318 Raises: 319 ValueError: if the shape of `image` not supported. 320 """ 321 with ops.name_scope(None, scope_name, [image]) as scope: 322 image = ops.convert_to_tensor(image, name='image') 323 image = _AssertAtLeast3DImage(image) 324 shape = image.get_shape() 325 if shape.ndims == 3 or shape.ndims is None: 326 uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed) 327 mirror_cond = math_ops.less(uniform_random, .5) 328 result = control_flow_ops.cond( 329 mirror_cond, 330 lambda: array_ops.reverse(image, [flip_index]), 331 lambda: image, 332 name=scope 333 ) 334 return fix_image_flip_shape(image, result) 335 elif shape.ndims == 4: 336 batch_size = array_ops.shape(image)[0] 337 uniform_random = random_ops.random_uniform( 338 [batch_size], 0, 1.0, seed=seed 339 ) 340 flips = math_ops.round( 341 array_ops.reshape(uniform_random, [batch_size, 1, 1, 1]) 342 ) 343 flips = math_ops.cast(flips, image.dtype) 344 flipped_input = array_ops.reverse(image, [flip_index + 1]) 345 return flips * flipped_input + (1 - flips) * image 346 else: 347 raise ValueError('\'image\' must have either 3 or 4 dimensions.') 348 349 350 @tf_export('image.flip_left_right') 351 def flip_left_right(image): 352 """Flip an image horizontally (left to right). 353 354 Outputs the contents of `image` flipped along the width dimension. 355 356 See also `reverse()`. 357 358 Args: 359 image: 4-D Tensor of shape `[batch, height, width, channels]` or 360 3-D Tensor of shape `[height, width, channels]`. 361 362 Returns: 363 A tensor of the same type and shape as `image`. 364 365 Raises: 366 ValueError: if the shape of `image` not supported. 367 """ 368 return _flip(image, 1, 'flip_left_right') 369 370 371 @tf_export('image.flip_up_down') 372 def flip_up_down(image): 373 """Flip an image vertically (upside down). 374 375 Outputs the contents of `image` flipped along the height dimension. 376 377 See also `reverse()`. 378 379 Args: 380 image: 4-D Tensor of shape `[batch, height, width, channels]` or 381 3-D Tensor of shape `[height, width, channels]`. 382 383 Returns: 384 A tensor of the same type and shape as `image`. 385 386 Raises: 387 ValueError: if the shape of `image` not supported. 388 """ 389 return _flip(image, 0, 'flip_up_down') 390 391 392 def _flip(image, flip_index, scope_name): 393 """Flip an image either horizontally or vertically. 394 395 Outputs the contents of `image` flipped along the dimension `flip_index`. 396 397 See also `reverse()`. 398 399 Args: 400 image: 4-D Tensor of shape `[batch, height, width, channels]` or 401 3-D Tensor of shape `[height, width, channels]`. 402 flip_index: 0 For vertical, 1 for horizontal. 403 404 Returns: 405 A tensor of the same type and shape as `image`. 406 407 Raises: 408 ValueError: if the shape of `image` not supported. 409 """ 410 with ops.name_scope(None, scope_name, [image]): 411 image = ops.convert_to_tensor(image, name='image') 412 image = _AssertAtLeast3DImage(image) 413 shape = image.get_shape() 414 if shape.ndims == 3 or shape.ndims is None: 415 return fix_image_flip_shape(image, array_ops.reverse(image, [flip_index])) 416 elif shape.ndims == 4: 417 return array_ops.reverse(image, [flip_index+1]) 418 else: 419 raise ValueError('\'image\' must have either 3 or 4 dimensions.') 420 421 422 @tf_export('image.rot90') 423 def rot90(image, k=1, name=None): 424 """Rotate image(s) counter-clockwise by 90 degrees. 425 426 Args: 427 image: 4-D Tensor of shape `[batch, height, width, channels]` or 428 3-D Tensor of shape `[height, width, channels]`. 429 k: A scalar integer. The number of times the image is rotated by 90 degrees. 430 name: A name for this operation (optional). 431 432 Returns: 433 A rotated tensor of the same type and shape as `image`. 434 435 Raises: 436 ValueError: if the shape of `image` not supported. 437 """ 438 with ops.name_scope(name, 'rot90', [image, k]) as scope: 439 image = ops.convert_to_tensor(image, name='image') 440 image = _AssertAtLeast3DImage(image) 441 k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k') 442 k.get_shape().assert_has_rank(0) 443 k = math_ops.mod(k, 4) 444 445 shape = image.get_shape() 446 if shape.ndims == 3 or shape.ndims is None: 447 return _rot90_3D(image, k, scope) 448 elif shape.ndims == 4: 449 return _rot90_4D(image, k, scope) 450 else: 451 raise ValueError('\'image\' must have either 3 or 4 dimensions.') 452 453 454 def _rot90_3D(image, k, name_scope): 455 """Rotate image counter-clockwise by 90 degrees `k` times. 456 457 Args: 458 image: 3-D Tensor of shape `[height, width, channels]`. 459 k: A scalar integer. The number of times the image is rotated by 90 degrees. 460 name_scope: A valid TensorFlow name scope. 461 462 Returns: 463 A 3-D tensor of the same type and shape as `image`. 464 465 """ 466 467 def _rot90(): 468 return array_ops.transpose(array_ops.reverse_v2(image, [1]), [1, 0, 2]) 469 470 def _rot180(): 471 return array_ops.reverse_v2(image, [0, 1]) 472 473 def _rot270(): 474 return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]), [1]) 475 476 cases = [(math_ops.equal(k, 1), _rot90), (math_ops.equal(k, 2), _rot180), 477 (math_ops.equal(k, 3), _rot270)] 478 479 result = control_flow_ops.case( 480 cases, default=lambda: image, exclusive=True, name=name_scope) 481 result.set_shape([None, None, image.get_shape()[2]]) 482 return result 483 484 485 def _rot90_4D(images, k, name_scope): 486 """Rotate batch of images counter-clockwise by 90 degrees `k` times. 487 488 Args: 489 images: 4-D Tensor of shape `[height, width, channels]`. 490 k: A scalar integer. The number of times the images are rotated by 90 491 degrees. 492 name_scope: A valid TensorFlow name scope. 493 494 Returns: 495 A 4-D tensor of the same type and shape as `images`. 496 497 """ 498 499 def _rot90(): 500 return array_ops.transpose(array_ops.reverse_v2(images, [2]), [0, 2, 1, 3]) 501 502 def _rot180(): 503 return array_ops.reverse_v2(images, [1, 2]) 504 def _rot270(): 505 return array_ops.reverse_v2(array_ops.transpose(images, [0, 2, 1, 3]), [2]) 506 507 cases = [(math_ops.equal(k, 1), _rot90), (math_ops.equal(k, 2), _rot180), 508 (math_ops.equal(k, 3), _rot270)] 509 510 result = control_flow_ops.case( 511 cases, default=lambda: images, exclusive=True, name=name_scope) 512 shape = result.get_shape() 513 result.set_shape([shape[0], None, None, shape[3]]) 514 return result 515 516 517 @tf_export(v1=['image.transpose', 'image.transpose_image']) 518 def transpose_image(image): 519 return transpose(image=image, name=None) 520 521 522 @tf_export('image.transpose', v1=[]) 523 def transpose(image, name=None): 524 """Transpose image(s) by swapping the height and width dimension. 525 526 Args: 527 image: 4-D Tensor of shape `[batch, height, width, channels]` or 528 3-D Tensor of shape `[height, width, channels]`. 529 name: A name for this operation (optional). 530 531 Returns: 532 If `image` was 4-D, a 4-D float Tensor of shape 533 `[batch, width, height, channels]` 534 If `image` was 3-D, a 3-D float Tensor of shape 535 `[width, height, channels]` 536 537 Raises: 538 ValueError: if the shape of `image` not supported. 539 """ 540 with ops.name_scope(name, 'transpose', [image]): 541 image = ops.convert_to_tensor(image, name='image') 542 image = _AssertAtLeast3DImage(image) 543 shape = image.get_shape() 544 if shape.ndims == 3 or shape.ndims is None: 545 return array_ops.transpose(image, [1, 0, 2], name=name) 546 elif shape.ndims == 4: 547 return array_ops.transpose(image, [0, 2, 1, 3], name=name) 548 else: 549 raise ValueError('\'image\' must have either 3 or 4 dimensions.') 550 551 552 @tf_export('image.central_crop') 553 def central_crop(image, central_fraction): 554 """Crop the central region of the image(s). 555 556 Remove the outer parts of an image but retain the central region of the image 557 along each dimension. If we specify central_fraction = 0.5, this function 558 returns the region marked with "X" in the below diagram. 559 560 -------- 561 | | 562 | XXXX | 563 | XXXX | 564 | | where "X" is the central 50% of the image. 565 -------- 566 567 This function works on either a single image (`image` is a 3-D Tensor), or a 568 batch of images (`image` is a 4-D Tensor). 569 570 Args: 571 image: Either a 3-D float Tensor of shape [height, width, depth], or a 4-D 572 Tensor of shape [batch_size, height, width, depth]. 573 central_fraction: float (0, 1], fraction of size to crop 574 575 Raises: 576 ValueError: if central_crop_fraction is not within (0, 1]. 577 578 Returns: 579 3-D / 4-D float Tensor, as per the input. 580 """ 581 with ops.name_scope(None, 'central_crop', [image]): 582 image = ops.convert_to_tensor(image, name='image') 583 if central_fraction <= 0.0 or central_fraction > 1.0: 584 raise ValueError('central_fraction must be within (0, 1]') 585 if central_fraction == 1.0: 586 return image 587 588 _AssertAtLeast3DImage(image) 589 rank = image.get_shape().ndims 590 if rank != 3 and rank != 4: 591 raise ValueError('`image` should either be a Tensor with rank = 3 or ' 592 'rank = 4. Had rank = {}.'.format(rank)) 593 594 # Helper method to return the `idx`-th dimension of `tensor`, along with 595 # a boolean signifying if the dimension is dynamic. 596 def _get_dim(tensor, idx): 597 static_shape = tensor.get_shape().dims[idx].value 598 if static_shape is not None: 599 return static_shape, False 600 return array_ops.shape(tensor)[idx], True 601 602 # Get the height, width, depth (and batch size, if the image is a 4-D 603 # tensor). 604 if rank == 3: 605 img_h, dynamic_h = _get_dim(image, 0) 606 img_w, dynamic_w = _get_dim(image, 1) 607 img_d = image.get_shape()[2] 608 else: 609 img_bs = image.get_shape()[0] 610 img_h, dynamic_h = _get_dim(image, 1) 611 img_w, dynamic_w = _get_dim(image, 2) 612 img_d = image.get_shape()[3] 613 614 # Compute the bounding boxes for the crop. The type and value of the 615 # bounding boxes depend on the `image` tensor's rank and whether / not the 616 # dimensions are statically defined. 617 if dynamic_h: 618 img_hd = math_ops.cast(img_h, dtypes.float64) 619 bbox_h_start = math_ops.cast( 620 (img_hd - img_hd * central_fraction) / 2, dtypes.int32) 621 else: 622 img_hd = float(img_h) 623 bbox_h_start = int((img_hd - img_hd * central_fraction) / 2) 624 625 if dynamic_w: 626 img_wd = math_ops.cast(img_w, dtypes.float64) 627 bbox_w_start = math_ops.cast( 628 (img_wd - img_wd * central_fraction) / 2, dtypes.int32) 629 else: 630 img_wd = float(img_w) 631 bbox_w_start = int((img_wd - img_wd * central_fraction) / 2) 632 633 bbox_h_size = img_h - bbox_h_start * 2 634 bbox_w_size = img_w - bbox_w_start * 2 635 636 if rank == 3: 637 bbox_begin = array_ops.stack([bbox_h_start, bbox_w_start, 0]) 638 bbox_size = array_ops.stack([bbox_h_size, bbox_w_size, -1]) 639 else: 640 bbox_begin = array_ops.stack([0, bbox_h_start, bbox_w_start, 0]) 641 bbox_size = array_ops.stack([-1, bbox_h_size, bbox_w_size, -1]) 642 643 image = array_ops.slice(image, bbox_begin, bbox_size) 644 645 # Reshape the `image` tensor to the desired size. 646 if rank == 3: 647 image.set_shape([ 648 None if dynamic_h else bbox_h_size, 649 None if dynamic_w else bbox_w_size, 650 img_d 651 ]) 652 else: 653 image.set_shape([ 654 img_bs, 655 None if dynamic_h else bbox_h_size, 656 None if dynamic_w else bbox_w_size, 657 img_d 658 ]) 659 return image 660 661 662 @tf_export('image.pad_to_bounding_box') 663 def pad_to_bounding_box(image, offset_height, offset_width, target_height, 664 target_width): 665 """Pad `image` with zeros to the specified `height` and `width`. 666 667 Adds `offset_height` rows of zeros on top, `offset_width` columns of 668 zeros on the left, and then pads the image on the bottom and right 669 with zeros until it has dimensions `target_height`, `target_width`. 670 671 This op does nothing if `offset_*` is zero and the image already has size 672 `target_height` by `target_width`. 673 674 Args: 675 image: 4-D Tensor of shape `[batch, height, width, channels]` or 676 3-D Tensor of shape `[height, width, channels]`. 677 offset_height: Number of rows of zeros to add on top. 678 offset_width: Number of columns of zeros to add on the left. 679 target_height: Height of output image. 680 target_width: Width of output image. 681 682 Returns: 683 If `image` was 4-D, a 4-D float Tensor of shape 684 `[batch, target_height, target_width, channels]` 685 If `image` was 3-D, a 3-D float Tensor of shape 686 `[target_height, target_width, channels]` 687 688 Raises: 689 ValueError: If the shape of `image` is incompatible with the `offset_*` or 690 `target_*` arguments, or either `offset_height` or `offset_width` is 691 negative. 692 """ 693 with ops.name_scope(None, 'pad_to_bounding_box', [image]): 694 image = ops.convert_to_tensor(image, name='image') 695 696 is_batch = True 697 image_shape = image.get_shape() 698 if image_shape.ndims == 3: 699 is_batch = False 700 image = array_ops.expand_dims(image, 0) 701 elif image_shape.ndims is None: 702 is_batch = False 703 image = array_ops.expand_dims(image, 0) 704 image.set_shape([None] * 4) 705 elif image_shape.ndims != 4: 706 raise ValueError('\'image\' must have either 3 or 4 dimensions.') 707 708 assert_ops = _CheckAtLeast3DImage(image, require_static=False) 709 batch, height, width, depth = _ImageDimensions(image, rank=4) 710 711 after_padding_width = target_width - offset_width - width 712 713 after_padding_height = target_height - offset_height - height 714 715 assert_ops += _assert(offset_height >= 0, ValueError, 716 'offset_height must be >= 0') 717 assert_ops += _assert(offset_width >= 0, ValueError, 718 'offset_width must be >= 0') 719 assert_ops += _assert(after_padding_width >= 0, ValueError, 720 'width must be <= target - offset') 721 assert_ops += _assert(after_padding_height >= 0, ValueError, 722 'height must be <= target - offset') 723 image = control_flow_ops.with_dependencies(assert_ops, image) 724 725 # Do not pad on the depth dimensions. 726 paddings = array_ops.reshape( 727 array_ops.stack([ 728 0, 0, offset_height, after_padding_height, offset_width, 729 after_padding_width, 0, 0 730 ]), [4, 2]) 731 padded = array_ops.pad(image, paddings) 732 733 padded_shape = [ 734 None if _is_tensor(i) else i 735 for i in [batch, target_height, target_width, depth] 736 ] 737 padded.set_shape(padded_shape) 738 739 if not is_batch: 740 padded = array_ops.squeeze(padded, axis=[0]) 741 742 return padded 743 744 745 @tf_export('image.crop_to_bounding_box') 746 def crop_to_bounding_box(image, offset_height, offset_width, target_height, 747 target_width): 748 """Crops an image to a specified bounding box. 749 750 This op cuts a rectangular part out of `image`. The top-left corner of the 751 returned image is at `offset_height, offset_width` in `image`, and its 752 lower-right corner is at 753 `offset_height + target_height, offset_width + target_width`. 754 755 Args: 756 image: 4-D Tensor of shape `[batch, height, width, channels]` or 757 3-D Tensor of shape `[height, width, channels]`. 758 offset_height: Vertical coordinate of the top-left corner of the result in 759 the input. 760 offset_width: Horizontal coordinate of the top-left corner of the result in 761 the input. 762 target_height: Height of the result. 763 target_width: Width of the result. 764 765 Returns: 766 If `image` was 4-D, a 4-D float Tensor of shape 767 `[batch, target_height, target_width, channels]` 768 If `image` was 3-D, a 3-D float Tensor of shape 769 `[target_height, target_width, channels]` 770 771 Raises: 772 ValueError: If the shape of `image` is incompatible with the `offset_*` or 773 `target_*` arguments, or either `offset_height` or `offset_width` is 774 negative, or either `target_height` or `target_width` is not positive. 775 """ 776 with ops.name_scope(None, 'crop_to_bounding_box', [image]): 777 image = ops.convert_to_tensor(image, name='image') 778 779 is_batch = True 780 image_shape = image.get_shape() 781 if image_shape.ndims == 3: 782 is_batch = False 783 image = array_ops.expand_dims(image, 0) 784 elif image_shape.ndims is None: 785 is_batch = False 786 image = array_ops.expand_dims(image, 0) 787 image.set_shape([None] * 4) 788 elif image_shape.ndims != 4: 789 raise ValueError('\'image\' must have either 3 or 4 dimensions.') 790 791 assert_ops = _CheckAtLeast3DImage(image, require_static=False) 792 793 batch, height, width, depth = _ImageDimensions(image, rank=4) 794 795 assert_ops += _assert(offset_width >= 0, ValueError, 796 'offset_width must be >= 0.') 797 assert_ops += _assert(offset_height >= 0, ValueError, 798 'offset_height must be >= 0.') 799 assert_ops += _assert(target_width > 0, ValueError, 800 'target_width must be > 0.') 801 assert_ops += _assert(target_height > 0, ValueError, 802 'target_height must be > 0.') 803 assert_ops += _assert(width >= (target_width + offset_width), ValueError, 804 'width must be >= target + offset.') 805 assert_ops += _assert(height >= (target_height + offset_height), ValueError, 806 'height must be >= target + offset.') 807 image = control_flow_ops.with_dependencies(assert_ops, image) 808 809 cropped = array_ops.slice( 810 image, array_ops.stack([0, offset_height, offset_width, 0]), 811 array_ops.stack([-1, target_height, target_width, -1])) 812 813 cropped_shape = [ 814 None if _is_tensor(i) else i 815 for i in [batch, target_height, target_width, depth] 816 ] 817 cropped.set_shape(cropped_shape) 818 819 if not is_batch: 820 cropped = array_ops.squeeze(cropped, axis=[0]) 821 822 return cropped 823 824 825 @tf_export('image.resize_image_with_crop_or_pad') 826 def resize_image_with_crop_or_pad(image, target_height, target_width): 827 """Crops and/or pads an image to a target width and height. 828 829 Resizes an image to a target width and height by either centrally 830 cropping the image or padding it evenly with zeros. 831 832 If `width` or `height` is greater than the specified `target_width` or 833 `target_height` respectively, this op centrally crops along that dimension. 834 If `width` or `height` is smaller than the specified `target_width` or 835 `target_height` respectively, this op centrally pads with 0 along that 836 dimension. 837 838 Args: 839 image: 4-D Tensor of shape `[batch, height, width, channels]` or 840 3-D Tensor of shape `[height, width, channels]`. 841 target_height: Target height. 842 target_width: Target width. 843 844 Raises: 845 ValueError: if `target_height` or `target_width` are zero or negative. 846 847 Returns: 848 Cropped and/or padded image. 849 If `images` was 4-D, a 4-D float Tensor of shape 850 `[batch, new_height, new_width, channels]`. 851 If `images` was 3-D, a 3-D float Tensor of shape 852 `[new_height, new_width, channels]`. 853 """ 854 with ops.name_scope(None, 'resize_image_with_crop_or_pad', [image]): 855 image = ops.convert_to_tensor(image, name='image') 856 image_shape = image.get_shape() 857 is_batch = True 858 if image_shape.ndims == 3: 859 is_batch = False 860 image = array_ops.expand_dims(image, 0) 861 elif image_shape.ndims is None: 862 is_batch = False 863 image = array_ops.expand_dims(image, 0) 864 image.set_shape([None] * 4) 865 elif image_shape.ndims != 4: 866 raise ValueError('\'image\' must have either 3 or 4 dimensions.') 867 868 assert_ops = _CheckAtLeast3DImage(image, require_static=False) 869 assert_ops += _assert(target_width > 0, ValueError, 870 'target_width must be > 0.') 871 assert_ops += _assert(target_height > 0, ValueError, 872 'target_height must be > 0.') 873 874 image = control_flow_ops.with_dependencies(assert_ops, image) 875 # `crop_to_bounding_box` and `pad_to_bounding_box` have their own checks. 876 # Make sure our checks come first, so that error messages are clearer. 877 if _is_tensor(target_height): 878 target_height = control_flow_ops.with_dependencies( 879 assert_ops, target_height) 880 if _is_tensor(target_width): 881 target_width = control_flow_ops.with_dependencies(assert_ops, 882 target_width) 883 884 def max_(x, y): 885 if _is_tensor(x) or _is_tensor(y): 886 return math_ops.maximum(x, y) 887 else: 888 return max(x, y) 889 890 def min_(x, y): 891 if _is_tensor(x) or _is_tensor(y): 892 return math_ops.minimum(x, y) 893 else: 894 return min(x, y) 895 896 def equal_(x, y): 897 if _is_tensor(x) or _is_tensor(y): 898 return math_ops.equal(x, y) 899 else: 900 return x == y 901 902 _, height, width, _ = _ImageDimensions(image, rank=4) 903 width_diff = target_width - width 904 offset_crop_width = max_(-width_diff // 2, 0) 905 offset_pad_width = max_(width_diff // 2, 0) 906 907 height_diff = target_height - height 908 offset_crop_height = max_(-height_diff // 2, 0) 909 offset_pad_height = max_(height_diff // 2, 0) 910 911 # Maybe crop if needed. 912 cropped = crop_to_bounding_box(image, offset_crop_height, offset_crop_width, 913 min_(target_height, height), 914 min_(target_width, width)) 915 916 # Maybe pad if needed. 917 resized = pad_to_bounding_box(cropped, offset_pad_height, offset_pad_width, 918 target_height, target_width) 919 920 # In theory all the checks below are redundant. 921 if resized.get_shape().ndims is None: 922 raise ValueError('resized contains no shape.') 923 924 _, resized_height, resized_width, _ = _ImageDimensions(resized, rank=4) 925 926 assert_ops = [] 927 assert_ops += _assert( 928 equal_(resized_height, target_height), ValueError, 929 'resized height is not correct.') 930 assert_ops += _assert( 931 equal_(resized_width, target_width), ValueError, 932 'resized width is not correct.') 933 934 resized = control_flow_ops.with_dependencies(assert_ops, resized) 935 936 if not is_batch: 937 resized = array_ops.squeeze(resized, axis=[0]) 938 939 return resized 940 941 942 @tf_export(v1=['image.ResizeMethod']) 943 class ResizeMethodV1(object): 944 BILINEAR = 0 945 NEAREST_NEIGHBOR = 1 946 BICUBIC = 2 947 AREA = 3 948 949 950 @tf_export('image.ResizeMethod', v1=[]) 951 class ResizeMethod(object): 952 BILINEAR = 'bilinear' 953 NEAREST_NEIGHBOR = 'nearest' 954 BICUBIC = 'bicubic' 955 AREA = 'area' 956 LANCZOS3 = 'lanczos3' 957 LANCZOS5 = 'lanczos5' 958 GAUSSIAN = 'gaussian' 959 MITCHELLCUBIC = 'mitchellcubic' 960 961 962 def _resize_images_common(images, resizer_fn, size, preserve_aspect_ratio, name, 963 skip_resize_if_same): 964 """Core functionality for v1 and v2 resize functions.""" 965 with ops.name_scope(name, 'resize', [images, size]): 966 images = ops.convert_to_tensor(images, name='images') 967 if images.get_shape().ndims is None: 968 raise ValueError('\'images\' contains no shape.') 969 # TODO(shlens): Migrate this functionality to the underlying Op's. 970 is_batch = True 971 if images.get_shape().ndims == 3: 972 is_batch = False 973 images = array_ops.expand_dims(images, 0) 974 elif images.get_shape().ndims != 4: 975 raise ValueError('\'images\' must have either 3 or 4 dimensions.') 976 977 _, height, width, _ = images.get_shape().as_list() 978 979 try: 980 size = ops.convert_to_tensor(size, dtypes.int32, name='size') 981 except (TypeError, ValueError): 982 raise ValueError('\'size\' must be a 1-D int32 Tensor') 983 if not size.get_shape().is_compatible_with([2]): 984 raise ValueError('\'size\' must be a 1-D Tensor of 2 elements: ' 985 'new_height, new_width') 986 size_const_as_shape = tensor_util.constant_value_as_shape(size) 987 new_height_const = size_const_as_shape.dims[0].value 988 new_width_const = size_const_as_shape.dims[1].value 989 990 if preserve_aspect_ratio: 991 # Get the current shapes of the image, even if dynamic. 992 _, current_height, current_width, _ = _ImageDimensions(images, rank=4) 993 994 # do the computation to find the right scale and height/width. 995 scale_factor_height = ( 996 math_ops.cast(new_height_const, dtypes.float32) / 997 math_ops.cast(current_height, dtypes.float32)) 998 scale_factor_width = ( 999 math_ops.cast(new_width_const, dtypes.float32) / 1000 math_ops.cast(current_width, dtypes.float32)) 1001 scale_factor = math_ops.minimum(scale_factor_height, scale_factor_width) 1002 scaled_height_const = math_ops.cast( 1003 math_ops.round( 1004 scale_factor * math_ops.cast(current_height, dtypes.float32)), 1005 dtypes.int32) 1006 scaled_width_const = math_ops.cast( 1007 math_ops.round( 1008 scale_factor * math_ops.cast(current_width, dtypes.float32)), 1009 dtypes.int32) 1010 1011 # NOTE: Reset the size and other constants used later. 1012 size = ops.convert_to_tensor([scaled_height_const, scaled_width_const], 1013 dtypes.int32, name='size') 1014 size_const_as_shape = tensor_util.constant_value_as_shape(size) 1015 new_height_const = size_const_as_shape.dims[0].value 1016 new_width_const = size_const_as_shape.dims[1].value 1017 1018 # If we can determine that the height and width will be unmodified by this 1019 # transformation, we avoid performing the resize. 1020 if skip_resize_if_same and all( 1021 x is not None 1022 for x in [new_width_const, width, new_height_const, height]) and ( 1023 width == new_width_const and height == new_height_const): 1024 if not is_batch: 1025 images = array_ops.squeeze(images, axis=[0]) 1026 return images 1027 1028 images = resizer_fn(images, size) 1029 1030 # NOTE(mrry): The shape functions for the resize ops cannot unpack 1031 # the packed values in `new_size`, so set the shape here. 1032 images.set_shape([None, new_height_const, new_width_const, None]) 1033 1034 if not is_batch: 1035 images = array_ops.squeeze(images, axis=[0]) 1036 return images 1037 1038 1039 @tf_export(v1=['image.resize_images', 'image.resize']) 1040 def resize_images(images, 1041 size, 1042 method=ResizeMethodV1.BILINEAR, 1043 align_corners=False, 1044 preserve_aspect_ratio=False, 1045 name=None): 1046 """Resize `images` to `size` using the specified `method`. 1047 1048 Resized images will be distorted if their original aspect ratio is not 1049 the same as `size`. To avoid distortions see 1050 `tf.image.resize_image_with_pad`. 1051 1052 `method` can be one of: 1053 1054 * <b>`ResizeMethod.BILINEAR`</b>: [Bilinear interpolation.]( 1055 https://en.wikipedia.org/wiki/Bilinear_interpolation) 1056 * <b>`ResizeMethod.NEAREST_NEIGHBOR`</b>: [Nearest neighbor interpolation.]( 1057 https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation) 1058 * <b>`ResizeMethod.BICUBIC`</b>: [Bicubic interpolation.]( 1059 https://en.wikipedia.org/wiki/Bicubic_interpolation) 1060 * <b>`ResizeMethod.AREA`</b>: Area interpolation. 1061 1062 The return value has the same type as `images` if `method` is 1063 `ResizeMethod.NEAREST_NEIGHBOR`. It will also have the same type as `images` 1064 if the size of `images` can be statically determined to be the same as `size`, 1065 because `images` is returned in this case. Otherwise, the return value has 1066 type `float32`. 1067 1068 Args: 1069 images: 4-D Tensor of shape `[batch, height, width, channels]` or 3-D Tensor 1070 of shape `[height, width, channels]`. 1071 size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The new 1072 size for the images. 1073 method: ResizeMethod. Defaults to `ResizeMethod.BILINEAR`. 1074 align_corners: bool. If True, the centers of the 4 corner pixels of the 1075 input and output tensors are aligned, preserving the values at the corner 1076 pixels. Defaults to `False`. 1077 preserve_aspect_ratio: Whether to preserve the aspect ratio. If this is set, 1078 then `images` will be resized to a size that fits in `size` while 1079 preserving the aspect ratio of the original image. Scales up the image if 1080 `size` is bigger than the current size of the `image`. Defaults to False. 1081 name: A name for this operation (optional). 1082 1083 Raises: 1084 ValueError: if the shape of `images` is incompatible with the 1085 shape arguments to this function 1086 ValueError: if `size` has invalid shape or type. 1087 ValueError: if an unsupported resize method is specified. 1088 1089 Returns: 1090 If `images` was 4-D, a 4-D float Tensor of shape 1091 `[batch, new_height, new_width, channels]`. 1092 If `images` was 3-D, a 3-D float Tensor of shape 1093 `[new_height, new_width, channels]`. 1094 """ 1095 1096 def resize_fn(images_t, new_size): 1097 """Legacy resize core function, passed to _resize_images_common.""" 1098 if method == ResizeMethodV1.BILINEAR or method == ResizeMethod.BILINEAR: 1099 return gen_image_ops.resize_bilinear( 1100 images_t, new_size, align_corners=align_corners) 1101 elif (method == ResizeMethodV1.NEAREST_NEIGHBOR or 1102 method == ResizeMethod.NEAREST_NEIGHBOR): 1103 return gen_image_ops.resize_nearest_neighbor( 1104 images_t, new_size, align_corners=align_corners) 1105 elif method == ResizeMethodV1.BICUBIC or method == ResizeMethod.BICUBIC: 1106 return gen_image_ops.resize_bicubic( 1107 images_t, new_size, align_corners=align_corners) 1108 elif method == ResizeMethodV1.AREA or method == ResizeMethod.AREA: 1109 return gen_image_ops.resize_area( 1110 images_t, new_size, align_corners=align_corners) 1111 else: 1112 raise ValueError('Resize method is not implemented.') 1113 1114 return _resize_images_common( 1115 images, 1116 resize_fn, 1117 size, 1118 preserve_aspect_ratio=preserve_aspect_ratio, 1119 name=name, 1120 skip_resize_if_same=True) 1121 1122 1123 @tf_export('image.resize', v1=[]) 1124 def resize_images_v2(images, 1125 size, 1126 method=ResizeMethod.BILINEAR, 1127 preserve_aspect_ratio=False, 1128 antialias=False, 1129 name=None): 1130 """Resize `images` to `size` using the specified `method`. 1131 1132 Resized images will be distorted if their original aspect ratio is not 1133 the same as `size`. To avoid distortions see 1134 `tf.image.resize_with_pad`. 1135 1136 When 'antialias' is true, the sampling filter will anti-alias the input image 1137 as well as interpolate. When downsampling an image with [anti-aliasing]( 1138 https://en.wikipedia.org/wiki/Spatial_anti-aliasing) the sampling filter 1139 kernel is scaled in order to properly anti-alias the input image signal. 1140 'antialias' has no effect when upsampling an image. 1141 1142 * <b>`bilinear`</b>: [Bilinear interpolation.]( 1143 https://en.wikipedia.org/wiki/Bilinear_interpolation) If 'antialias' is 1144 true, becomes a hat/tent filter function with radius 1 when downsampling. 1145 * <b>`lanczos3`</b>: [Lanczos kernel]( 1146 https://en.wikipedia.org/wiki/Lanczos_resampling) with radius 3. 1147 High-quality practical filter but may have some ringing especially on 1148 synthetic images. 1149 * <b>`lanczos5`</b>: [Lanczos kernel] ( 1150 https://en.wikipedia.org/wiki/Lanczos_resampling) with radius 5. 1151 Very-high-quality filter but may have stronger ringing. 1152 * <b>`bicubic`</b>: [Cubic interpolant]( 1153 https://en.wikipedia.org/wiki/Bicubic_interpolation) of Keys. Equivalent to 1154 Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, 1155 particularly when upsampling. 1156 * <b>`gaussian`</b>: [Gaussian kernel]( 1157 https://en.wikipedia.org/wiki/Gaussian_filter) with radius 3, 1158 sigma = 1.5 / 3.] 1159 * <b>`nearest`</b>: [Nearest neighbor interpolation.]( 1160 https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation) 1161 'antialias' has no effect when used with nearest neighbor interpolation. 1162 * <b>`area`</b>: Anti-aliased resampling with area interpolation. 1163 'antialias' has no effect when used with area interpolation; it 1164 always anti-aliases. 1165 * <b>`mitchellcubic`</b>: Mitchell-Netravali Cubic non-interpolating filter. 1166 For synthetic images (especially those lacking proper prefiltering), less 1167 ringing than Keys cubic kernel but less sharp. 1168 1169 Note that near image edges the filtering kernel may be partially outside the 1170 image boundaries. For these pixels, only input pixels inside the image will be 1171 included in the filter sum, and the output value will be appropriately 1172 normalized. 1173 1174 The return value has the same type as `images` if `method` is 1175 `ResizeMethod.NEAREST_NEIGHBOR`. Otherwise, the return value has type 1176 `float32`. 1177 1178 Args: 1179 images: 4-D Tensor of shape `[batch, height, width, channels]` or 3-D Tensor 1180 of shape `[height, width, channels]`. 1181 size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The new 1182 size for the images. 1183 method: ResizeMethod. Defaults to `bilinear`. 1184 preserve_aspect_ratio: Whether to preserve the aspect ratio. If this is set, 1185 then `images` will be resized to a size that fits in `size` while 1186 preserving the aspect ratio of the original image. Scales up the image if 1187 `size` is bigger than the current size of the `image`. Defaults to False. 1188 antialias: Whether to use an anti-aliasing filter when downsampling an 1189 image. 1190 name: A name for this operation (optional). 1191 1192 Raises: 1193 ValueError: if the shape of `images` is incompatible with the 1194 shape arguments to this function 1195 ValueError: if `size` has invalid shape or type. 1196 ValueError: if an unsupported resize method is specified. 1197 1198 Returns: 1199 If `images` was 4-D, a 4-D float Tensor of shape 1200 `[batch, new_height, new_width, channels]`. 1201 If `images` was 3-D, a 3-D float Tensor of shape 1202 `[new_height, new_width, channels]`. 1203 """ 1204 1205 def resize_fn(images_t, new_size): 1206 """Resize core function, passed to _resize_images_common.""" 1207 scale_and_translate_methods = [ 1208 ResizeMethod.LANCZOS3, ResizeMethod.LANCZOS5, ResizeMethod.GAUSSIAN, 1209 ResizeMethod.MITCHELLCUBIC 1210 ] 1211 1212 def resize_with_scale_and_translate(method): 1213 scale = ( 1214 math_ops.cast(new_size, dtype=dtypes.float32) / 1215 math_ops.cast(array_ops.shape(images_t)[1:3], dtype=dtypes.float32)) 1216 return gen_image_ops.scale_and_translate( 1217 images_t, 1218 new_size, 1219 scale, 1220 array_ops.zeros([2]), 1221 kernel_type=method, 1222 antialias=antialias) 1223 1224 if method == ResizeMethod.BILINEAR: 1225 if antialias: 1226 return resize_with_scale_and_translate('triangle') 1227 else: 1228 return gen_image_ops.resize_bilinear( 1229 images_t, new_size, half_pixel_centers=True) 1230 elif method == ResizeMethod.NEAREST_NEIGHBOR: 1231 return gen_image_ops.resize_nearest_neighbor( 1232 images_t, new_size, half_pixel_centers=True) 1233 elif method == ResizeMethod.BICUBIC: 1234 if antialias: 1235 return resize_with_scale_and_translate('keyscubic') 1236 else: 1237 return gen_image_ops.resize_bicubic( 1238 images_t, new_size, half_pixel_centers=True) 1239 elif method == ResizeMethod.AREA: 1240 return gen_image_ops.resize_area(images_t, new_size) 1241 elif method in scale_and_translate_methods: 1242 return resize_with_scale_and_translate(method) 1243 else: 1244 raise ValueError('Resize method is not implemented.') 1245 1246 return _resize_images_common( 1247 images, 1248 resize_fn, 1249 size, 1250 preserve_aspect_ratio=preserve_aspect_ratio, 1251 name=name, 1252 skip_resize_if_same=False) 1253 1254 1255 def _resize_image_with_pad_common(image, target_height, target_width, 1256 resize_fn): 1257 """Core functionality for v1 and v2 resize_image_with_pad functions.""" 1258 with ops.name_scope(None, 'resize_image_with_pad', [image]): 1259 image = ops.convert_to_tensor(image, name='image') 1260 image_shape = image.get_shape() 1261 is_batch = True 1262 if image_shape.ndims == 3: 1263 is_batch = False 1264 image = array_ops.expand_dims(image, 0) 1265 elif image_shape.ndims is None: 1266 is_batch = False 1267 image = array_ops.expand_dims(image, 0) 1268 image.set_shape([None] * 4) 1269 elif image_shape.ndims != 4: 1270 raise ValueError('\'image\' must have either 3 or 4 dimensions.') 1271 1272 assert_ops = _CheckAtLeast3DImage(image, require_static=False) 1273 assert_ops += _assert(target_width > 0, ValueError, 1274 'target_width must be > 0.') 1275 assert_ops += _assert(target_height > 0, ValueError, 1276 'target_height must be > 0.') 1277 1278 image = control_flow_ops.with_dependencies(assert_ops, image) 1279 1280 def max_(x, y): 1281 if _is_tensor(x) or _is_tensor(y): 1282 return math_ops.maximum(x, y) 1283 else: 1284 return max(x, y) 1285 1286 _, height, width, _ = _ImageDimensions(image, rank=4) 1287 1288 # convert values to float, to ease divisions 1289 f_height = math_ops.cast(height, dtype=dtypes.float64) 1290 f_width = math_ops.cast(width, dtype=dtypes.float64) 1291 f_target_height = math_ops.cast(target_height, dtype=dtypes.float64) 1292 f_target_width = math_ops.cast(target_width, dtype=dtypes.float64) 1293 1294 # Find the ratio by which the image must be adjusted 1295 # to fit within the target 1296 ratio = max_(f_width / f_target_width, f_height / f_target_height) 1297 resized_height_float = f_height / ratio 1298 resized_width_float = f_width / ratio 1299 resized_height = math_ops.cast( 1300 math_ops.floor(resized_height_float), dtype=dtypes.int32) 1301 resized_width = math_ops.cast( 1302 math_ops.floor(resized_width_float), dtype=dtypes.int32) 1303 1304 padding_height = (f_target_height - resized_height_float) / 2 1305 padding_width = (f_target_width - resized_width_float) / 2 1306 f_padding_height = math_ops.floor(padding_height) 1307 f_padding_width = math_ops.floor(padding_width) 1308 p_height = max_(0, math_ops.cast(f_padding_height, dtype=dtypes.int32)) 1309 p_width = max_(0, math_ops.cast(f_padding_width, dtype=dtypes.int32)) 1310 1311 # Resize first, then pad to meet requested dimensions 1312 resized = resize_fn(image, [resized_height, resized_width]) 1313 1314 padded = pad_to_bounding_box(resized, p_height, p_width, target_height, 1315 target_width) 1316 1317 if padded.get_shape().ndims is None: 1318 raise ValueError('padded contains no shape.') 1319 1320 _ImageDimensions(padded, rank=4) 1321 1322 if not is_batch: 1323 padded = array_ops.squeeze(padded, axis=[0]) 1324 1325 return padded 1326 1327 1328 @tf_export(v1=['image.resize_image_with_pad']) 1329 def resize_image_with_pad_v1(image, 1330 target_height, 1331 target_width, 1332 method=ResizeMethodV1.BILINEAR, 1333 align_corners=False): 1334 """Resizes and pads an image to a target width and height. 1335 1336 Resizes an image to a target width and height by keeping 1337 the aspect ratio the same without distortion. If the target 1338 dimensions don't match the image dimensions, the image 1339 is resized and then padded with zeroes to match requested 1340 dimensions. 1341 1342 Args: 1343 image: 4-D Tensor of shape `[batch, height, width, channels]` or 3-D Tensor 1344 of shape `[height, width, channels]`. 1345 target_height: Target height. 1346 target_width: Target width. 1347 method: Method to use for resizing image. See `resize_images()` 1348 align_corners: bool. If True, the centers of the 4 corner pixels of the 1349 input and output tensors are aligned, preserving the values at the corner 1350 pixels. Defaults to `False`. 1351 1352 Raises: 1353 ValueError: if `target_height` or `target_width` are zero or negative. 1354 1355 Returns: 1356 Resized and padded image. 1357 If `images` was 4-D, a 4-D float Tensor of shape 1358 `[batch, new_height, new_width, channels]`. 1359 If `images` was 3-D, a 3-D float Tensor of shape 1360 `[new_height, new_width, channels]`. 1361 """ 1362 1363 def _resize_fn(im, new_size): 1364 return resize_images(im, new_size, method, align_corners=align_corners) 1365 1366 return _resize_image_with_pad_common(image, target_height, target_width, 1367 _resize_fn) 1368 1369 1370 @tf_export('image.resize_with_pad', v1=[]) 1371 def resize_image_with_pad_v2(image, 1372 target_height, 1373 target_width, 1374 method=ResizeMethod.BILINEAR, 1375 antialias=False): 1376 """Resizes and pads an image to a target width and height. 1377 1378 Resizes an image to a target width and height by keeping 1379 the aspect ratio the same without distortion. If the target 1380 dimensions don't match the image dimensions, the image 1381 is resized and then padded with zeroes to match requested 1382 dimensions. 1383 1384 Args: 1385 image: 4-D Tensor of shape `[batch, height, width, channels]` or 3-D Tensor 1386 of shape `[height, width, channels]`. 1387 target_height: Target height. 1388 target_width: Target width. 1389 method: Method to use for resizing image. See `image.resize()` 1390 antialias: Whether to use anti-aliasing when resizing. See 'image.resize()'. 1391 1392 Raises: 1393 ValueError: if `target_height` or `target_width` are zero or negative. 1394 1395 Returns: 1396 Resized and padded image. 1397 If `images` was 4-D, a 4-D float Tensor of shape 1398 `[batch, new_height, new_width, channels]`. 1399 If `images` was 3-D, a 3-D float Tensor of shape 1400 `[new_height, new_width, channels]`. 1401 """ 1402 1403 def _resize_fn(im, new_size): 1404 return resize_images_v2(im, new_size, method, antialias=antialias) 1405 1406 return _resize_image_with_pad_common(image, target_height, target_width, 1407 _resize_fn) 1408 1409 1410 @tf_export('image.per_image_standardization') 1411 def per_image_standardization(image): 1412 """Linearly scales `image` to have zero mean and unit variance. 1413 1414 This op computes `(x - mean) / adjusted_stddev`, where `mean` is the average 1415 of all values in image, and 1416 `adjusted_stddev = max(stddev, 1.0/sqrt(image.NumElements()))`. 1417 1418 `stddev` is the standard deviation of all values in `image`. It is capped 1419 away from zero to protect against division by 0 when handling uniform images. 1420 1421 Args: 1422 image: An n-D Tensor where the last 3 dimensions are `[height, width, 1423 channels]`. 1424 1425 Returns: 1426 The standardized image with same shape as `image`. 1427 1428 Raises: 1429 ValueError: if the shape of 'image' is incompatible with this function. 1430 """ 1431 with ops.name_scope(None, 'per_image_standardization', [image]) as scope: 1432 image = ops.convert_to_tensor(image, name='image') 1433 image = _AssertAtLeast3DImage(image) 1434 num_pixels = math_ops.reduce_prod(array_ops.shape(image)[-3:]) 1435 1436 image = math_ops.cast(image, dtype=dtypes.float32) 1437 image_mean = math_ops.reduce_mean(image, axis=[-1, -2, -3], keepdims=True) 1438 1439 variance = ( 1440 math_ops.reduce_mean( 1441 math_ops.square(image), axis=[-1, -2, -3], keepdims=True) - 1442 math_ops.square(image_mean)) 1443 variance = gen_nn_ops.relu(variance) 1444 stddev = math_ops.sqrt(variance) 1445 1446 # Apply a minimum normalization that protects us against uniform images. 1447 min_stddev = math_ops.rsqrt(math_ops.cast(num_pixels, dtypes.float32)) 1448 pixel_value_scale = math_ops.maximum(stddev, min_stddev) 1449 pixel_value_offset = image_mean 1450 1451 image = math_ops.subtract(image, pixel_value_offset) 1452 image = math_ops.div(image, pixel_value_scale, name=scope) 1453 return image 1454 1455 1456 @tf_export('image.random_brightness') 1457 def random_brightness(image, max_delta, seed=None): 1458 """Adjust the brightness of images by a random factor. 1459 1460 Equivalent to `adjust_brightness()` using a `delta` randomly picked in the 1461 interval `[-max_delta, max_delta)`. 1462 1463 Args: 1464 image: An image or images to adjust. 1465 max_delta: float, must be non-negative. 1466 seed: A Python integer. Used to create a random seed. See 1467 `tf.set_random_seed` 1468 for behavior. 1469 1470 Returns: 1471 The brightness-adjusted image(s). 1472 1473 Raises: 1474 ValueError: if `max_delta` is negative. 1475 """ 1476 if max_delta < 0: 1477 raise ValueError('max_delta must be non-negative.') 1478 1479 delta = random_ops.random_uniform([], -max_delta, max_delta, seed=seed) 1480 return adjust_brightness(image, delta) 1481 1482 1483 @tf_export('image.random_contrast') 1484 def random_contrast(image, lower, upper, seed=None): 1485 """Adjust the contrast of an image or images by a random factor. 1486 1487 Equivalent to `adjust_contrast()` but uses a `contrast_factor` randomly 1488 picked in the interval `[lower, upper]`. 1489 1490 Args: 1491 image: An image tensor with 3 or more dimensions. 1492 lower: float. Lower bound for the random contrast factor. 1493 upper: float. Upper bound for the random contrast factor. 1494 seed: A Python integer. Used to create a random seed. See 1495 `tf.set_random_seed` for behavior. 1496 1497 Returns: 1498 The contrast-adjusted image(s). 1499 1500 Raises: 1501 ValueError: if `upper <= lower` or if `lower < 0`. 1502 """ 1503 if upper <= lower: 1504 raise ValueError('upper must be > lower.') 1505 1506 if lower < 0: 1507 raise ValueError('lower must be non-negative.') 1508 1509 # Generate an a float in [lower, upper] 1510 contrast_factor = random_ops.random_uniform([], lower, upper, seed=seed) 1511 return adjust_contrast(image, contrast_factor) 1512 1513 1514 @tf_export('image.adjust_brightness') 1515 def adjust_brightness(image, delta): 1516 """Adjust the brightness of RGB or Grayscale images. 1517 1518 This is a convenience method that converts RGB images to float 1519 representation, adjusts their brightness, and then converts them back to the 1520 original data type. If several adjustments are chained, it is advisable to 1521 minimize the number of redundant conversions. 1522 1523 The value `delta` is added to all components of the tensor `image`. `image` is 1524 converted to `float` and scaled appropriately if it is in fixed-point 1525 representation, and `delta` is converted to the same data type. For regular 1526 images, `delta` should be in the range `[0,1)`, as it is added to the image in 1527 floating point representation, where pixel values are in the `[0,1)` range. 1528 1529 Args: 1530 image: RGB image or images to adjust. 1531 delta: A scalar. Amount to add to the pixel values. 1532 1533 Returns: 1534 A brightness-adjusted tensor of the same shape and type as `image`. 1535 """ 1536 with ops.name_scope(None, 'adjust_brightness', [image, delta]) as name: 1537 image = ops.convert_to_tensor(image, name='image') 1538 # Remember original dtype to so we can convert back if needed 1539 orig_dtype = image.dtype 1540 1541 if orig_dtype in [dtypes.float16, dtypes.float32]: 1542 flt_image = image 1543 else: 1544 flt_image = convert_image_dtype(image, dtypes.float32) 1545 1546 adjusted = math_ops.add( 1547 flt_image, math_ops.cast(delta, flt_image.dtype), name=name) 1548 1549 return convert_image_dtype(adjusted, orig_dtype, saturate=True) 1550 1551 1552 @tf_export('image.adjust_contrast') 1553 def adjust_contrast(images, contrast_factor): 1554 """Adjust contrast of RGB or grayscale images. 1555 1556 This is a convenience method that converts RGB images to float 1557 representation, adjusts their contrast, and then converts them back to the 1558 original data type. If several adjustments are chained, it is advisable to 1559 minimize the number of redundant conversions. 1560 1561 `images` is a tensor of at least 3 dimensions. The last 3 dimensions are 1562 interpreted as `[height, width, channels]`. The other dimensions only 1563 represent a collection of images, such as `[batch, height, width, channels].` 1564 1565 Contrast is adjusted independently for each channel of each image. 1566 1567 For each channel, this Op computes the mean of the image pixels in the 1568 channel and then adjusts each component `x` of each pixel to 1569 `(x - mean) * contrast_factor + mean`. 1570 1571 Args: 1572 images: Images to adjust. At least 3-D. 1573 contrast_factor: A float multiplier for adjusting contrast. 1574 1575 Returns: 1576 The contrast-adjusted image or images. 1577 """ 1578 with ops.name_scope(None, 'adjust_contrast', 1579 [images, contrast_factor]) as name: 1580 images = ops.convert_to_tensor(images, name='images') 1581 # Remember original dtype to so we can convert back if needed 1582 orig_dtype = images.dtype 1583 1584 if orig_dtype in (dtypes.float16, dtypes.float32): 1585 flt_images = images 1586 else: 1587 flt_images = convert_image_dtype(images, dtypes.float32) 1588 1589 adjusted = gen_image_ops.adjust_contrastv2( 1590 flt_images, contrast_factor=contrast_factor, name=name) 1591 1592 return convert_image_dtype(adjusted, orig_dtype, saturate=True) 1593 1594 1595 @tf_export('image.adjust_gamma') 1596 def adjust_gamma(image, gamma=1, gain=1): 1597 """Performs Gamma Correction on the input image. 1598 1599 Also known as Power Law Transform. This function transforms the 1600 input image pixelwise according to the equation `Out = In**gamma` 1601 after scaling each pixel to the range 0 to 1. 1602 1603 Args: 1604 image : A Tensor. 1605 gamma : A scalar or tensor. Non negative real number. 1606 gain : A scalar or tensor. The constant multiplier. 1607 1608 Returns: 1609 A Tensor. Gamma corrected output image. 1610 1611 Raises: 1612 ValueError: If gamma is negative. 1613 1614 Notes: 1615 For gamma greater than 1, the histogram will shift towards left and 1616 the output image will be darker than the input image. 1617 For gamma less than 1, the histogram will shift towards right and 1618 the output image will be brighter than the input image. 1619 1620 References: 1621 [1] http://en.wikipedia.org/wiki/Gamma_correction 1622 """ 1623 1624 with ops.name_scope(None, 'adjust_gamma', [image, gamma, gain]) as name: 1625 # Convert pixel value to DT_FLOAT for computing adjusted image. 1626 img = ops.convert_to_tensor(image, name='img', dtype=dtypes.float32) 1627 # Keep image dtype for computing the scale of corresponding dtype. 1628 image = ops.convert_to_tensor(image, name='image') 1629 1630 assert_op = _assert(gamma >= 0, ValueError, 1631 'Gamma should be a non-negative real number.') 1632 if assert_op: 1633 gamma = control_flow_ops.with_dependencies(assert_op, gamma) 1634 1635 # scale = max(dtype) - min(dtype). 1636 scale = constant_op.constant( 1637 image.dtype.limits[1] - image.dtype.limits[0], dtype=dtypes.float32) 1638 # According to the definition of gamma correction. 1639 adjusted_img = (img / scale)**gamma * scale * gain 1640 1641 return adjusted_img 1642 1643 1644 @tf_export('image.convert_image_dtype') 1645 def convert_image_dtype(image, dtype, saturate=False, name=None): 1646 """Convert `image` to `dtype`, scaling its values if needed. 1647 1648 Images that are represented using floating point values are expected to have 1649 values in the range [0,1). Image data stored in integer data types are 1650 expected to have values in the range `[0,MAX]`, where `MAX` is the largest 1651 positive representable number for the data type. 1652 1653 This op converts between data types, scaling the values appropriately before 1654 casting. 1655 1656 Note that converting from floating point inputs to integer types may lead to 1657 over/underflow problems. Set saturate to `True` to avoid such problem in 1658 problematic conversions. If enabled, saturation will clip the output into the 1659 allowed range before performing a potentially dangerous cast (and only before 1660 performing such a cast, i.e., when casting from a floating point to an integer 1661 type, and when casting from a signed to an unsigned type; `saturate` has no 1662 effect on casts between floats, or on casts that increase the type's range). 1663 1664 Args: 1665 image: An image. 1666 dtype: A `DType` to convert `image` to. 1667 saturate: If `True`, clip the input before casting (if necessary). 1668 name: A name for this operation (optional). 1669 1670 Returns: 1671 `image`, converted to `dtype`. 1672 """ 1673 image = ops.convert_to_tensor(image, name='image') 1674 if dtype == image.dtype: 1675 return array_ops.identity(image, name=name) 1676 1677 with ops.name_scope(name, 'convert_image', [image]) as name: 1678 # Both integer: use integer multiplication in the larger range 1679 if image.dtype.is_integer and dtype.is_integer: 1680 scale_in = image.dtype.max 1681 scale_out = dtype.max 1682 if scale_in > scale_out: 1683 # Scaling down, scale first, then cast. The scaling factor will 1684 # cause in.max to be mapped to above out.max but below out.max+1, 1685 # so that the output is safely in the supported range. 1686 scale = (scale_in + 1) // (scale_out + 1) 1687 scaled = math_ops.div(image, scale) 1688 1689 if saturate: 1690 return math_ops.saturate_cast(scaled, dtype, name=name) 1691 else: 1692 return math_ops.cast(scaled, dtype, name=name) 1693 else: 1694 # Scaling up, cast first, then scale. The scale will not map in.max to 1695 # out.max, but converting back and forth should result in no change. 1696 if saturate: 1697 cast = math_ops.saturate_cast(image, dtype) 1698 else: 1699 cast = math_ops.cast(image, dtype) 1700 scale = (scale_out + 1) // (scale_in + 1) 1701 return math_ops.multiply(cast, scale, name=name) 1702 elif image.dtype.is_floating and dtype.is_floating: 1703 # Both float: Just cast, no possible overflows in the allowed ranges. 1704 # Note: We're ignoreing float overflows. If your image dynamic range 1705 # exceeds float range you're on your own. 1706 return math_ops.cast(image, dtype, name=name) 1707 else: 1708 if image.dtype.is_integer: 1709 # Converting to float: first cast, then scale. No saturation possible. 1710 cast = math_ops.cast(image, dtype) 1711 scale = 1. / image.dtype.max 1712 return math_ops.multiply(cast, scale, name=name) 1713 else: 1714 # Converting from float: first scale, then cast 1715 scale = dtype.max + 0.5 # avoid rounding problems in the cast 1716 scaled = math_ops.multiply(image, scale) 1717 if saturate: 1718 return math_ops.saturate_cast(scaled, dtype, name=name) 1719 else: 1720 return math_ops.cast(scaled, dtype, name=name) 1721 1722 1723 @tf_export('image.rgb_to_grayscale') 1724 def rgb_to_grayscale(images, name=None): 1725 """Converts one or more images from RGB to Grayscale. 1726 1727 Outputs a tensor of the same `DType` and rank as `images`. The size of the 1728 last dimension of the output is 1, containing the Grayscale value of the 1729 pixels. 1730 1731 Args: 1732 images: The RGB tensor to convert. Last dimension must have size 3 and 1733 should contain RGB values. 1734 name: A name for the operation (optional). 1735 1736 Returns: 1737 The converted grayscale image(s). 1738 """ 1739 with ops.name_scope(name, 'rgb_to_grayscale', [images]) as name: 1740 images = ops.convert_to_tensor(images, name='images') 1741 # Remember original dtype to so we can convert back if needed 1742 orig_dtype = images.dtype 1743 flt_image = convert_image_dtype(images, dtypes.float32) 1744 1745 # Reference for converting between RGB and grayscale. 1746 # https://en.wikipedia.org/wiki/Luma_%28video%29 1747 rgb_weights = [0.2989, 0.5870, 0.1140] 1748 gray_float = math_ops.tensordot(flt_image, rgb_weights, [-1, -1]) 1749 gray_float = array_ops.expand_dims(gray_float, -1) 1750 return convert_image_dtype(gray_float, orig_dtype, name=name) 1751 1752 1753 @tf_export('image.grayscale_to_rgb') 1754 def grayscale_to_rgb(images, name=None): 1755 """Converts one or more images from Grayscale to RGB. 1756 1757 Outputs a tensor of the same `DType` and rank as `images`. The size of the 1758 last dimension of the output is 3, containing the RGB value of the pixels. 1759 1760 Args: 1761 images: The Grayscale tensor to convert. Last dimension must be size 1. 1762 name: A name for the operation (optional). 1763 1764 Returns: 1765 The converted grayscale image(s). 1766 """ 1767 with ops.name_scope(name, 'grayscale_to_rgb', [images]) as name: 1768 images = ops.convert_to_tensor(images, name='images') 1769 rank_1 = array_ops.expand_dims(array_ops.rank(images) - 1, 0) 1770 shape_list = ([array_ops.ones(rank_1, dtype=dtypes.int32)] + 1771 [array_ops.expand_dims(3, 0)]) 1772 multiples = array_ops.concat(shape_list, 0) 1773 rgb = array_ops.tile(images, multiples, name=name) 1774 rgb.set_shape(images.get_shape()[:-1].concatenate([3])) 1775 return rgb 1776 1777 1778 # pylint: disable=invalid-name 1779 @tf_export('image.random_hue') 1780 def random_hue(image, max_delta, seed=None): 1781 """Adjust the hue of RGB images by a random factor. 1782 1783 Equivalent to `adjust_hue()` but uses a `delta` randomly 1784 picked in the interval `[-max_delta, max_delta]`. 1785 1786 `max_delta` must be in the interval `[0, 0.5]`. 1787 1788 Args: 1789 image: RGB image or images. Size of the last dimension must be 3. 1790 max_delta: float. Maximum value for the random delta. 1791 seed: An operation-specific seed. It will be used in conjunction with the 1792 graph-level seed to determine the real seeds that will be used in this 1793 operation. Please see the documentation of set_random_seed for its 1794 interaction with the graph-level random seed. 1795 1796 Returns: 1797 Adjusted image(s), same shape and DType as `image`. 1798 1799 Raises: 1800 ValueError: if `max_delta` is invalid. 1801 """ 1802 if max_delta > 0.5: 1803 raise ValueError('max_delta must be <= 0.5.') 1804 1805 if max_delta < 0: 1806 raise ValueError('max_delta must be non-negative.') 1807 1808 delta = random_ops.random_uniform([], -max_delta, max_delta, seed=seed) 1809 return adjust_hue(image, delta) 1810 1811 1812 @tf_export('image.adjust_hue') 1813 def adjust_hue(image, delta, name=None): 1814 """Adjust hue of RGB images. 1815 1816 This is a convenience method that converts an RGB image to float 1817 representation, converts it to HSV, add an offset to the hue channel, converts 1818 back to RGB and then back to the original data type. If several adjustments 1819 are chained it is advisable to minimize the number of redundant conversions. 1820 1821 `image` is an RGB image. The image hue is adjusted by converting the 1822 image(s) to HSV and rotating the hue channel (H) by 1823 `delta`. The image is then converted back to RGB. 1824 1825 `delta` must be in the interval `[-1, 1]`. 1826 1827 Args: 1828 image: RGB image or images. Size of the last dimension must be 3. 1829 delta: float. How much to add to the hue channel. 1830 name: A name for this operation (optional). 1831 1832 Returns: 1833 Adjusted image(s), same shape and DType as `image`. 1834 """ 1835 with ops.name_scope(name, 'adjust_hue', [image]) as name: 1836 image = ops.convert_to_tensor(image, name='image') 1837 # Remember original dtype to so we can convert back if needed 1838 orig_dtype = image.dtype 1839 if orig_dtype in (dtypes.float16, dtypes.float32): 1840 flt_image = image 1841 else: 1842 flt_image = convert_image_dtype(image, dtypes.float32) 1843 1844 rgb_altered = gen_image_ops.adjust_hue(flt_image, delta) 1845 1846 return convert_image_dtype(rgb_altered, orig_dtype) 1847 1848 1849 # pylint: disable=invalid-name 1850 @tf_export('image.random_jpeg_quality') 1851 def random_jpeg_quality(image, min_jpeg_quality, max_jpeg_quality, seed=None): 1852 """Randomly changes jpeg encoding quality for inducing jpeg noise. 1853 1854 `min_jpeg_quality` must be in the interval `[0, 100]` and less than 1855 `max_jpeg_quality`. 1856 `max_jpeg_quality` must be in the interval `[0, 100]`. 1857 1858 Args: 1859 image: RGB image or images. Size of the last dimension must be 3. 1860 min_jpeg_quality: Minimum jpeg encoding quality to use. 1861 max_jpeg_quality: Maximum jpeg encoding quality to use. 1862 seed: An operation-specific seed. It will be used in conjunction 1863 with the graph-level seed to determine the real seeds that will be 1864 used in this operation. Please see the documentation of 1865 set_random_seed for its interaction with the graph-level random seed. 1866 1867 Returns: 1868 Adjusted image(s), same shape and DType as `image`. 1869 1870 Raises: 1871 ValueError: if `min_jpeg_quality` or `max_jpeg_quality` is invalid. 1872 """ 1873 if (min_jpeg_quality < 0 or max_jpeg_quality < 0 or 1874 min_jpeg_quality > 100 or max_jpeg_quality > 100): 1875 raise ValueError('jpeg encoding range must be between 0 and 100.') 1876 1877 if min_jpeg_quality >= max_jpeg_quality: 1878 raise ValueError('`min_jpeg_quality` must be less than `max_jpeg_quality`.') 1879 1880 np.random.seed(seed) 1881 jpeg_quality = np.random.randint(min_jpeg_quality, max_jpeg_quality) 1882 return adjust_jpeg_quality(image, jpeg_quality) 1883 1884 1885 @tf_export('image.adjust_jpeg_quality') 1886 def adjust_jpeg_quality(image, jpeg_quality, name=None): 1887 """Adjust jpeg encoding quality of an RGB image. 1888 1889 This is a convenience method that adjusts jpeg encoding quality of an 1890 RGB image. 1891 1892 `image` is an RGB image. The image's encoding quality is adjusted 1893 to `jpeg_quality`. 1894 `jpeg_quality` must be in the interval `[0, 100]`. 1895 1896 Args: 1897 image: RGB image or images. Size of the last dimension must be 3. 1898 jpeg_quality: int. jpeg encoding quality. 1899 name: A name for this operation (optional). 1900 1901 Returns: 1902 Adjusted image(s), same shape and DType as `image`. 1903 """ 1904 with ops.name_scope(name, 'adjust_jpeg_quality', [image]) as name: 1905 image = ops.convert_to_tensor(image, name='image') 1906 # Remember original dtype to so we can convert back if needed 1907 orig_dtype = image.dtype 1908 # Convert to uint8 1909 image = convert_image_dtype(image, dtypes.uint8) 1910 # Encode image to jpeg with given jpeg quality 1911 image = gen_image_ops.encode_jpeg(image, quality=jpeg_quality) 1912 # Decode jpeg image 1913 image = gen_image_ops.decode_jpeg(image) 1914 # Convert back to original dtype and return 1915 return convert_image_dtype(image, orig_dtype) 1916 1917 1918 @tf_export('image.random_saturation') 1919 def random_saturation(image, lower, upper, seed=None): 1920 """Adjust the saturation of RGB images by a random factor. 1921 1922 Equivalent to `adjust_saturation()` but uses a `saturation_factor` randomly 1923 picked in the interval `[lower, upper]`. 1924 1925 Args: 1926 image: RGB image or images. Size of the last dimension must be 3. 1927 lower: float. Lower bound for the random saturation factor. 1928 upper: float. Upper bound for the random saturation factor. 1929 seed: An operation-specific seed. It will be used in conjunction with the 1930 graph-level seed to determine the real seeds that will be used in this 1931 operation. Please see the documentation of set_random_seed for its 1932 interaction with the graph-level random seed. 1933 1934 Returns: 1935 Adjusted image(s), same shape and DType as `image`. 1936 1937 Raises: 1938 ValueError: if `upper <= lower` or if `lower < 0`. 1939 """ 1940 if upper <= lower: 1941 raise ValueError('upper must be > lower.') 1942 1943 if lower < 0: 1944 raise ValueError('lower must be non-negative.') 1945 1946 # Pick a float in [lower, upper] 1947 saturation_factor = random_ops.random_uniform([], lower, upper, seed=seed) 1948 return adjust_saturation(image, saturation_factor) 1949 1950 1951 @tf_export('image.adjust_saturation') 1952 def adjust_saturation(image, saturation_factor, name=None): 1953 """Adjust saturation of RGB images. 1954 1955 This is a convenience method that converts RGB images to float 1956 representation, converts them to HSV, add an offset to the saturation channel, 1957 converts back to RGB and then back to the original data type. If several 1958 adjustments are chained it is advisable to minimize the number of redundant 1959 conversions. 1960 1961 `image` is an RGB image or images. The image saturation is adjusted by 1962 converting the images to HSV and multiplying the saturation (S) channel by 1963 `saturation_factor` and clipping. The images are then converted back to RGB. 1964 1965 Args: 1966 image: RGB image or images. Size of the last dimension must be 3. 1967 saturation_factor: float. Factor to multiply the saturation by. 1968 name: A name for this operation (optional). 1969 1970 Returns: 1971 Adjusted image(s), same shape and DType as `image`. 1972 """ 1973 with ops.name_scope(name, 'adjust_saturation', [image]) as name: 1974 image = ops.convert_to_tensor(image, name='image') 1975 # Remember original dtype to so we can convert back if needed 1976 orig_dtype = image.dtype 1977 if orig_dtype in (dtypes.float16, dtypes.float32): 1978 flt_image = image 1979 else: 1980 flt_image = convert_image_dtype(image, dtypes.float32) 1981 1982 adjusted = gen_image_ops.adjust_saturation(flt_image, saturation_factor) 1983 1984 return convert_image_dtype(adjusted, orig_dtype) 1985 1986 1987 @tf_export('io.is_jpeg', 'image.is_jpeg', v1=['io.is_jpeg', 'image.is_jpeg']) 1988 def is_jpeg(contents, name=None): 1989 r"""Convenience function to check if the 'contents' encodes a JPEG image. 1990 1991 Args: 1992 contents: 0-D `string`. The encoded image bytes. 1993 name: A name for the operation (optional) 1994 1995 Returns: 1996 A scalar boolean tensor indicating if 'contents' may be a JPEG image. 1997 is_jpeg is susceptible to false positives. 1998 """ 1999 # Normal JPEGs start with \xff\xd8\xff\xe0 2000 # JPEG with EXIF stats with \xff\xd8\xff\xe1 2001 # Use \xff\xd8\xff to cover both. 2002 with ops.name_scope(name, 'is_jpeg'): 2003 substr = string_ops.substr(contents, 0, 3) 2004 return math_ops.equal(substr, b'\xff\xd8\xff', name=name) 2005 2006 2007 def _is_png(contents, name=None): 2008 r"""Convenience function to check if the 'contents' encodes a PNG image. 2009 2010 Args: 2011 contents: 0-D `string`. The encoded image bytes. 2012 name: A name for the operation (optional) 2013 2014 Returns: 2015 A scalar boolean tensor indicating if 'contents' may be a PNG image. 2016 is_png is susceptible to false positives. 2017 """ 2018 with ops.name_scope(name, 'is_png'): 2019 substr = string_ops.substr(contents, 0, 3) 2020 return math_ops.equal(substr, b'\211PN', name=name) 2021 2022 tf_export('io.decode_and_crop_jpeg', 'image.decode_and_crop_jpeg', 2023 v1=['io.decode_and_crop_jpeg', 'image.decode_and_crop_jpeg'])( 2024 gen_image_ops.decode_and_crop_jpeg) 2025 2026 tf_export('io.decode_bmp', 'image.decode_bmp', 2027 v1=['io.decode_bmp', 'image.decode_bmp'])(gen_image_ops.decode_bmp) 2028 tf_export('io.decode_gif', 'image.decode_gif', 2029 v1=['io.decode_gif', 'image.decode_gif'])(gen_image_ops.decode_gif) 2030 tf_export('io.decode_jpeg', 'image.decode_jpeg', 2031 v1=['io.decode_jpeg', 'image.decode_jpeg'])(gen_image_ops.decode_jpeg) 2032 tf_export('io.decode_png', 'image.decode_png', 2033 v1=['io.decode_png', 'image.decode_png'])(gen_image_ops.decode_png) 2034 2035 tf_export('io.encode_jpeg', 'image.encode_jpeg', 2036 v1=['io.encode_jpeg', 'image.encode_jpeg'])(gen_image_ops.encode_jpeg) 2037 tf_export('io.extract_jpeg_shape', 'image.extract_jpeg_shape', 2038 v1=['io.extract_jpeg_shape', 'image.extract_jpeg_shape'])( 2039 gen_image_ops.extract_jpeg_shape) 2040 2041 2042 @tf_export('io.decode_image', 'image.decode_image', 2043 v1=['io.decode_image', 'image.decode_image']) 2044 def decode_image(contents, channels=None, dtype=dtypes.uint8, name=None): 2045 """Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`, 2046 and `decode_png`. 2047 2048 Detects whether an image is a BMP, GIF, JPEG, or PNG, and performs the 2049 appropriate operation to convert the input bytes `string` into a `Tensor` 2050 of type `dtype`. 2051 2052 Note: `decode_gif` returns a 4-D array `[num_frames, height, width, 3]`, as 2053 opposed to `decode_bmp`, `decode_jpeg` and `decode_png`, which return 3-D 2054 arrays `[height, width, num_channels]`. Make sure to take this into account 2055 when constructing your graph if you are intermixing GIF files with BMP, JPEG, 2056 and/or PNG files. 2057 2058 Args: 2059 contents: 0-D `string`. The encoded image bytes. 2060 channels: An optional `int`. Defaults to `0`. Number of color channels for 2061 the decoded image. 2062 dtype: The desired DType of the returned `Tensor`. 2063 name: A name for the operation (optional) 2064 2065 Returns: 2066 `Tensor` with type `dtype` and shape `[height, width, num_channels]` for 2067 BMP, JPEG, and PNG images and shape `[num_frames, height, width, 3]` for 2068 GIF images. 2069 2070 Raises: 2071 ValueError: On incorrect number of channels. 2072 """ 2073 with ops.name_scope(name, 'decode_image'): 2074 if channels not in (None, 0, 1, 3, 4): 2075 raise ValueError('channels must be in (None, 0, 1, 3, 4)') 2076 substr = string_ops.substr(contents, 0, 3) 2077 2078 def _bmp(): 2079 """Decodes a GIF image.""" 2080 signature = string_ops.substr(contents, 0, 2) 2081 # Create assert op to check that bytes are BMP decodable 2082 is_bmp = math_ops.equal(signature, 'BM', name='is_bmp') 2083 decode_msg = 'Unable to decode bytes as JPEG, PNG, GIF, or BMP' 2084 assert_decode = control_flow_ops.Assert(is_bmp, [decode_msg]) 2085 bmp_channels = 0 if channels is None else channels 2086 good_channels = math_ops.not_equal(bmp_channels, 1, name='check_channels') 2087 channels_msg = 'Channels must be in (None, 0, 3) when decoding BMP images' 2088 assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) 2089 with ops.control_dependencies([assert_decode, assert_channels]): 2090 return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype) 2091 2092 def _gif(): 2093 # Create assert to make sure that channels is not set to 1 2094 # Already checked above that channels is in (None, 0, 1, 3) 2095 2096 gif_channels = 0 if channels is None else channels 2097 good_channels = math_ops.logical_and( 2098 math_ops.not_equal(gif_channels, 1, name='check_gif_channels'), 2099 math_ops.not_equal(gif_channels, 4, name='check_gif_channels')) 2100 channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images' 2101 assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) 2102 with ops.control_dependencies([assert_channels]): 2103 return convert_image_dtype(gen_image_ops.decode_gif(contents), dtype) 2104 2105 def check_gif(): 2106 # Create assert op to check that bytes are GIF decodable 2107 is_gif = math_ops.equal(substr, b'\x47\x49\x46', name='is_gif') 2108 return control_flow_ops.cond(is_gif, _gif, _bmp, name='cond_gif') 2109 2110 def _png(): 2111 """Decodes a PNG image.""" 2112 return convert_image_dtype( 2113 gen_image_ops.decode_png(contents, channels, 2114 dtype=dtypes.uint8 2115 if dtype == dtypes.uint8 2116 else dtypes.uint16), dtype) 2117 2118 def check_png(): 2119 """Checks if an image is PNG.""" 2120 return control_flow_ops.cond( 2121 _is_png(contents), _png, check_gif, name='cond_png') 2122 2123 def _jpeg(): 2124 """Decodes a jpeg image.""" 2125 jpeg_channels = 0 if channels is None else channels 2126 good_channels = math_ops.not_equal( 2127 jpeg_channels, 4, name='check_jpeg_channels') 2128 channels_msg = ('Channels must be in (None, 0, 1, 3) when decoding JPEG ' 2129 'images') 2130 assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) 2131 with ops.control_dependencies([assert_channels]): 2132 return convert_image_dtype( 2133 gen_image_ops.decode_jpeg(contents, channels), dtype) 2134 2135 # Decode normal JPEG images (start with \xff\xd8\xff\xe0) 2136 # as well as JPEG images with EXIF data (start with \xff\xd8\xff\xe1). 2137 return control_flow_ops.cond( 2138 is_jpeg(contents), _jpeg, check_png, name='cond_jpeg') 2139 2140 2141 @tf_export('image.total_variation') 2142 def total_variation(images, name=None): 2143 """Calculate and return the total variation for one or more images. 2144 2145 The total variation is the sum of the absolute differences for neighboring 2146 pixel-values in the input images. This measures how much noise is in the 2147 images. 2148 2149 This can be used as a loss-function during optimization so as to suppress 2150 noise in images. If you have a batch of images, then you should calculate 2151 the scalar loss-value as the sum: 2152 `loss = tf.reduce_sum(tf.image.total_variation(images))` 2153 2154 This implements the anisotropic 2-D version of the formula described here: 2155 2156 https://en.wikipedia.org/wiki/Total_variation_denoising 2157 2158 Args: 2159 images: 4-D Tensor of shape `[batch, height, width, channels]` or 2160 3-D Tensor of shape `[height, width, channels]`. 2161 2162 name: A name for the operation (optional). 2163 2164 Raises: 2165 ValueError: if images.shape is not a 3-D or 4-D vector. 2166 2167 Returns: 2168 The total variation of `images`. 2169 2170 If `images` was 4-D, return a 1-D float Tensor of shape `[batch]` with the 2171 total variation for each image in the batch. 2172 If `images` was 3-D, return a scalar float with the total variation for 2173 that image. 2174 """ 2175 2176 with ops.name_scope(name, 'total_variation'): 2177 ndims = images.get_shape().ndims 2178 2179 if ndims == 3: 2180 # The input is a single image with shape [height, width, channels]. 2181 2182 # Calculate the difference of neighboring pixel-values. 2183 # The images are shifted one pixel along the height and width by slicing. 2184 pixel_dif1 = images[1:, :, :] - images[:-1, :, :] 2185 pixel_dif2 = images[:, 1:, :] - images[:, :-1, :] 2186 2187 # Sum for all axis. (None is an alias for all axis.) 2188 sum_axis = None 2189 elif ndims == 4: 2190 # The input is a batch of images with shape: 2191 # [batch, height, width, channels]. 2192 2193 # Calculate the difference of neighboring pixel-values. 2194 # The images are shifted one pixel along the height and width by slicing. 2195 pixel_dif1 = images[:, 1:, :, :] - images[:, :-1, :, :] 2196 pixel_dif2 = images[:, :, 1:, :] - images[:, :, :-1, :] 2197 2198 # Only sum for the last 3 axis. 2199 # This results in a 1-D tensor with the total variation for each image. 2200 sum_axis = [1, 2, 3] 2201 else: 2202 raise ValueError('\'images\' must be either 3 or 4-dimensional.') 2203 2204 # Calculate the total variation by taking the absolute value of the 2205 # pixel-differences and summing over the appropriate axis. 2206 tot_var = ( 2207 math_ops.reduce_sum(math_ops.abs(pixel_dif1), axis=sum_axis) + 2208 math_ops.reduce_sum(math_ops.abs(pixel_dif2), axis=sum_axis)) 2209 2210 return tot_var 2211 2212 2213 @tf_export('image.sample_distorted_bounding_box', v1=[]) 2214 def sample_distorted_bounding_box_v2(image_size, 2215 bounding_boxes, 2216 seed=0, 2217 min_object_covered=0.1, 2218 aspect_ratio_range=None, 2219 area_range=None, 2220 max_attempts=None, 2221 use_image_if_no_bounding_boxes=None, 2222 name=None): 2223 """Generate a single randomly distorted bounding box for an image. 2224 2225 Bounding box annotations are often supplied in addition to ground-truth labels 2226 in image recognition or object localization tasks. A common technique for 2227 training such a system is to randomly distort an image while preserving 2228 its content, i.e. *data augmentation*. This Op outputs a randomly distorted 2229 localization of an object, i.e. bounding box, given an `image_size`, 2230 `bounding_boxes` and a series of constraints. 2231 2232 The output of this Op is a single bounding box that may be used to crop the 2233 original image. The output is returned as 3 tensors: `begin`, `size` and 2234 `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the 2235 image. The latter may be supplied to `tf.image.draw_bounding_boxes` to 2236 visualize what the bounding box looks like. 2237 2238 Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. 2239 The bounding box coordinates are floats in `[0.0, 1.0]` relative to the width 2240 and height of the underlying image. 2241 2242 For example, 2243 2244 ```python 2245 # Generate a single distorted bounding box. 2246 begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( 2247 tf.shape(image), 2248 bounding_boxes=bounding_boxes, 2249 min_object_covered=0.1) 2250 2251 # Draw the bounding box in an image summary. 2252 image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), 2253 bbox_for_draw) 2254 tf.summary.image('images_with_box', image_with_box) 2255 2256 # Employ the bounding box to distort the image. 2257 distorted_image = tf.slice(image, begin, size) 2258 ``` 2259 2260 Note that if no bounding box information is available, setting 2261 `use_image_if_no_bounding_boxes = true` will assume there is a single implicit 2262 bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is 2263 false and no bounding boxes are supplied, an error is raised. 2264 2265 Args: 2266 image_size: A `Tensor`. Must be one of the following types: `uint8`, `int8`, 2267 `int16`, `int32`, `int64`. 2268 1-D, containing `[height, width, channels]`. 2269 bounding_boxes: A `Tensor` of type `float32`. 2270 3-D with shape `[batch, N, 4]` describing the N bounding boxes 2271 associated with the image. 2272 seed: An optional `int`. Defaults to `0`. 2273 If `seed` is set to non-zero, the random number generator is seeded by 2274 the given `seed`. Otherwise, it is seeded by a random seed. 2275 min_object_covered: A Tensor of type `float32`. Defaults to `0.1`. 2276 The cropped area of the image must contain at least this 2277 fraction of any bounding box supplied. The value of this parameter should 2278 be non-negative. In the case of 0, the cropped area does not need to 2279 overlap any of the bounding boxes supplied. 2280 aspect_ratio_range: An optional list of `floats`. Defaults to `[0.75, 2281 1.33]`. 2282 The cropped area of the image must have an aspect `ratio = 2283 width / height` within this range. 2284 area_range: An optional list of `floats`. Defaults to `[0.05, 1]`. 2285 The cropped area of the image must contain a fraction of the 2286 supplied image within this range. 2287 max_attempts: An optional `int`. Defaults to `100`. 2288 Number of attempts at generating a cropped region of the image 2289 of the specified constraints. After `max_attempts` failures, return the 2290 entire image. 2291 use_image_if_no_bounding_boxes: An optional `bool`. Defaults to `False`. 2292 Controls behavior if no bounding boxes supplied. 2293 If true, assume an implicit bounding box covering the whole input. If 2294 false, raise an error. 2295 name: A name for the operation (optional). 2296 2297 Returns: 2298 A tuple of `Tensor` objects (begin, size, bboxes). 2299 2300 begin: A `Tensor`. Has the same type as `image_size`. 1-D, containing 2301 `[offset_height, offset_width, 0]`. Provide as input to 2302 `tf.slice`. 2303 size: A `Tensor`. Has the same type as `image_size`. 1-D, containing 2304 `[target_height, target_width, -1]`. Provide as input to 2305 `tf.slice`. 2306 bboxes: A `Tensor` of type `float32`. 3-D with shape `[1, 1, 4]` containing 2307 the distorted bounding box. 2308 Provide as input to `tf.image.draw_bounding_boxes`. 2309 """ 2310 seed1, seed2 = random_seed.get_seed(seed) if seed else (0, 0) 2311 return sample_distorted_bounding_box( 2312 image_size, bounding_boxes, seed1, seed2, min_object_covered, 2313 aspect_ratio_range, area_range, max_attempts, 2314 use_image_if_no_bounding_boxes, name) 2315 2316 2317 @tf_export(v1=['image.sample_distorted_bounding_box']) 2318 @deprecation.deprecated(date=None, instructions='`seed2` arg is deprecated.' 2319 'Use sample_distorted_bounding_box_v2 instead.') 2320 def sample_distorted_bounding_box(image_size, 2321 bounding_boxes, 2322 seed=None, 2323 seed2=None, 2324 min_object_covered=0.1, 2325 aspect_ratio_range=None, 2326 area_range=None, 2327 max_attempts=None, 2328 use_image_if_no_bounding_boxes=None, 2329 name=None): 2330 """Generate a single randomly distorted bounding box for an image. 2331 2332 Bounding box annotations are often supplied in addition to ground-truth labels 2333 in image recognition or object localization tasks. A common technique for 2334 training such a system is to randomly distort an image while preserving 2335 its content, i.e. *data augmentation*. This Op outputs a randomly distorted 2336 localization of an object, i.e. bounding box, given an `image_size`, 2337 `bounding_boxes` and a series of constraints. 2338 2339 The output of this Op is a single bounding box that may be used to crop the 2340 original image. The output is returned as 3 tensors: `begin`, `size` and 2341 `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the 2342 image. The latter may be supplied to `tf.image.draw_bounding_boxes` to 2343 visualize 2344 what the bounding box looks like. 2345 2346 Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. 2347 The 2348 bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and 2349 height of the underlying image. 2350 2351 For example, 2352 2353 ```python 2354 # Generate a single distorted bounding box. 2355 begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box( 2356 tf.shape(image), 2357 bounding_boxes=bounding_boxes, 2358 min_object_covered=0.1) 2359 2360 # Draw the bounding box in an image summary. 2361 image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), 2362 bbox_for_draw) 2363 tf.summary.image('images_with_box', image_with_box) 2364 2365 # Employ the bounding box to distort the image. 2366 distorted_image = tf.slice(image, begin, size) 2367 ``` 2368 2369 Note that if no bounding box information is available, setting 2370 `use_image_if_no_bounding_boxes = true` will assume there is a single implicit 2371 bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is 2372 false and no bounding boxes are supplied, an error is raised. 2373 2374 Args: 2375 image_size: A `Tensor`. Must be one of the following types: `uint8`, `int8`, 2376 `int16`, `int32`, `int64`. 2377 1-D, containing `[height, width, channels]`. 2378 bounding_boxes: A `Tensor` of type `float32`. 2379 3-D with shape `[batch, N, 4]` describing the N bounding boxes 2380 associated with the image. 2381 seed: An optional `int`. Defaults to `0`. 2382 If either `seed` or `seed2` are set to non-zero, the random number 2383 generator is seeded by the given `seed`. Otherwise, it is seeded by a 2384 random 2385 seed. 2386 seed2: An optional `int`. Defaults to `0`. 2387 A second seed to avoid seed collision. 2388 min_object_covered: A Tensor of type `float32`. Defaults to `0.1`. 2389 The cropped area of the image must contain at least this 2390 fraction of any bounding box supplied. The value of this parameter should 2391 be 2392 non-negative. In the case of 0, the cropped area does not need to overlap 2393 any of the bounding boxes supplied. 2394 aspect_ratio_range: An optional list of `floats`. Defaults to `[0.75, 2395 1.33]`. 2396 The cropped area of the image must have an aspect ratio = 2397 width / height within this range. 2398 area_range: An optional list of `floats`. Defaults to `[0.05, 1]`. 2399 The cropped area of the image must contain a fraction of the 2400 supplied image within this range. 2401 max_attempts: An optional `int`. Defaults to `100`. 2402 Number of attempts at generating a cropped region of the image 2403 of the specified constraints. After `max_attempts` failures, return the 2404 entire 2405 image. 2406 use_image_if_no_bounding_boxes: An optional `bool`. Defaults to `False`. 2407 Controls behavior if no bounding boxes supplied. 2408 If true, assume an implicit bounding box covering the whole input. If 2409 false, 2410 raise an error. 2411 name: A name for the operation (optional). 2412 2413 Returns: 2414 A tuple of `Tensor` objects (begin, size, bboxes). 2415 2416 begin: A `Tensor`. Has the same type as `image_size`. 1-D, containing 2417 `[offset_height, offset_width, 0]`. Provide as input to 2418 `tf.slice`. 2419 size: A `Tensor`. Has the same type as `image_size`. 1-D, containing 2420 `[target_height, target_width, -1]`. Provide as input to 2421 `tf.slice`. 2422 bboxes: A `Tensor` of type `float32`. 3-D with shape `[1, 1, 4]` containing 2423 the distorted bounding box. 2424 Provide as input to `tf.image.draw_bounding_boxes`. 2425 """ 2426 with ops.name_scope(name, 'sample_distorted_bounding_box'): 2427 return gen_image_ops.sample_distorted_bounding_box_v2( 2428 image_size, 2429 bounding_boxes, 2430 seed=seed, 2431 seed2=seed2, 2432 min_object_covered=min_object_covered, 2433 aspect_ratio_range=aspect_ratio_range, 2434 area_range=area_range, 2435 max_attempts=max_attempts, 2436 use_image_if_no_bounding_boxes=use_image_if_no_bounding_boxes, 2437 name=name) 2438 2439 2440 @tf_export('image.non_max_suppression') 2441 def non_max_suppression(boxes, 2442 scores, 2443 max_output_size, 2444 iou_threshold=0.5, 2445 score_threshold=float('-inf'), 2446 name=None): 2447 """Greedily selects a subset of bounding boxes in descending order of score. 2448 2449 Prunes away boxes that have high intersection-over-union (IOU) overlap 2450 with previously selected boxes. Bounding boxes are supplied as 2451 `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the coordinates of any 2452 diagonal pair of box corners and the coordinates can be provided as normalized 2453 (i.e., lying in the interval `[0, 1]`) or absolute. Note that this algorithm 2454 is agnostic to where the origin is in the coordinate system. Note that this 2455 algorithm is invariant to orthogonal transformations and translations 2456 of the coordinate system; thus translating or reflections of the coordinate 2457 system result in the same boxes being selected by the algorithm. 2458 The output of this operation is a set of integers indexing into the input 2459 collection of bounding boxes representing the selected boxes. The bounding 2460 box coordinates corresponding to the selected indices can then be obtained 2461 using the `tf.gather` operation. For example: 2462 ```python 2463 selected_indices = tf.image.non_max_suppression( 2464 boxes, scores, max_output_size, iou_threshold) 2465 selected_boxes = tf.gather(boxes, selected_indices) 2466 ``` 2467 2468 Args: 2469 boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`. 2470 scores: A 1-D float `Tensor` of shape `[num_boxes]` representing a single 2471 score corresponding to each box (each row of boxes). 2472 max_output_size: A scalar integer `Tensor` representing the maximum number 2473 of boxes to be selected by non max suppression. 2474 iou_threshold: A float representing the threshold for deciding whether boxes 2475 overlap too much with respect to IOU. 2476 score_threshold: A float representing the threshold for deciding when to 2477 remove boxes based on score. 2478 name: A name for the operation (optional). 2479 2480 Returns: 2481 selected_indices: A 1-D integer `Tensor` of shape `[M]` representing the 2482 selected indices from the boxes tensor, where `M <= max_output_size`. 2483 """ 2484 with ops.name_scope(name, 'non_max_suppression'): 2485 iou_threshold = ops.convert_to_tensor(iou_threshold, name='iou_threshold') 2486 score_threshold = ops.convert_to_tensor( 2487 score_threshold, name='score_threshold') 2488 return gen_image_ops.non_max_suppression_v3(boxes, scores, max_output_size, 2489 iou_threshold, score_threshold) 2490 2491 2492 @tf_export('image.non_max_suppression_padded') 2493 def non_max_suppression_padded(boxes, 2494 scores, 2495 max_output_size, 2496 iou_threshold=0.5, 2497 score_threshold=float('-inf'), 2498 pad_to_max_output_size=False, 2499 name=None): 2500 """Greedily selects a subset of bounding boxes in descending order of score. 2501 2502 Performs algorithmically equivalent operation to tf.image.non_max_suppression, 2503 with the addition of an optional parameter which zero-pads the output to 2504 be of size `max_output_size`. 2505 The output of this operation is a tuple containing the set of integers 2506 indexing into the input collection of bounding boxes representing the selected 2507 boxes and the number of valid indices in the index set. The bounding box 2508 coordinates corresponding to the selected indices can then be obtained using 2509 the `tf.slice` and `tf.gather` operations. For example: 2510 ```python 2511 selected_indices_padded, num_valid = tf.image.non_max_suppression_padded( 2512 boxes, scores, max_output_size, iou_threshold, 2513 score_threshold, pad_to_max_output_size=True) 2514 selected_indices = tf.slice( 2515 selected_indices_padded, tf.constant([0]), num_valid) 2516 selected_boxes = tf.gather(boxes, selected_indices) 2517 ``` 2518 2519 Args: 2520 boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`. 2521 scores: A 1-D float `Tensor` of shape `[num_boxes]` representing a single 2522 score corresponding to each box (each row of boxes). 2523 max_output_size: A scalar integer `Tensor` representing the maximum number 2524 of boxes to be selected by non max suppression. 2525 iou_threshold: A float representing the threshold for deciding whether boxes 2526 overlap too much with respect to IOU. 2527 score_threshold: A float representing the threshold for deciding when to 2528 remove boxes based on score. 2529 pad_to_max_output_size: bool. If True, size of `selected_indices` output 2530 is padded to `max_output_size`. 2531 name: A name for the operation (optional). 2532 2533 Returns: 2534 selected_indices: A 1-D integer `Tensor` of shape `[M]` representing the 2535 selected indices from the boxes tensor, where `M <= max_output_size`. 2536 valid_outputs: A scalar integer `Tensor` denoting how many elements in 2537 `selected_indices` are valid. Valid elements occur first, then padding. 2538 """ 2539 with ops.name_scope(name, 'non_max_suppression_padded'): 2540 iou_threshold = ops.convert_to_tensor(iou_threshold, name='iou_threshold') 2541 score_threshold = ops.convert_to_tensor( 2542 score_threshold, name='score_threshold') 2543 if compat.forward_compatible(2018, 8, 7) or pad_to_max_output_size: 2544 return gen_image_ops.non_max_suppression_v4( 2545 boxes, scores, max_output_size, iou_threshold, score_threshold, 2546 pad_to_max_output_size) 2547 else: 2548 return gen_image_ops.non_max_suppression_v3( 2549 boxes, scores, max_output_size, iou_threshold, score_threshold) 2550 2551 2552 @tf_export('image.non_max_suppression_overlaps') 2553 def non_max_suppression_with_overlaps(overlaps, 2554 scores, 2555 max_output_size, 2556 overlap_threshold=0.5, 2557 score_threshold=float('-inf'), 2558 name=None): 2559 """Greedily selects a subset of bounding boxes in descending order of score. 2560 2561 Prunes away boxes that have high overlap with previously selected boxes. 2562 N-by-n overlap values are supplied as square matrix. 2563 The output of this operation is a set of integers indexing into the input 2564 collection of bounding boxes representing the selected boxes. The bounding 2565 box coordinates corresponding to the selected indices can then be obtained 2566 using the `tf.gather` operation. For example: 2567 ```python 2568 selected_indices = tf.image.non_max_suppression_overlaps( 2569 overlaps, scores, max_output_size, iou_threshold) 2570 selected_boxes = tf.gather(boxes, selected_indices) 2571 ``` 2572 2573 Args: 2574 overlaps: A 2-D float `Tensor` of shape `[num_boxes, num_boxes]`. 2575 scores: A 1-D float `Tensor` of shape `[num_boxes]` representing a single 2576 score corresponding to each box (each row of boxes). 2577 max_output_size: A scalar integer `Tensor` representing the maximum number 2578 of boxes to be selected by non max suppression. 2579 overlap_threshold: A float representing the threshold for deciding whether 2580 boxes overlap too much with respect to the provided overlap values. 2581 score_threshold: A float representing the threshold for deciding when to 2582 remove boxes based on score. 2583 name: A name for the operation (optional). 2584 2585 Returns: 2586 selected_indices: A 1-D integer `Tensor` of shape `[M]` representing the 2587 selected indices from the overlaps tensor, where `M <= max_output_size`. 2588 """ 2589 with ops.name_scope(name, 'non_max_suppression_overlaps'): 2590 overlap_threshold = ops.convert_to_tensor( 2591 overlap_threshold, name='overlap_threshold') 2592 # pylint: disable=protected-access 2593 return gen_image_ops.non_max_suppression_with_overlaps( 2594 overlaps, scores, max_output_size, overlap_threshold, score_threshold) 2595 # pylint: enable=protected-access 2596 2597 2598 _rgb_to_yiq_kernel = [[0.299, 0.59590059, 2599 0.2115], [0.587, -0.27455667, -0.52273617], 2600 [0.114, -0.32134392, 0.31119955]] 2601 2602 2603 @tf_export('image.rgb_to_yiq') 2604 def rgb_to_yiq(images): 2605 """Converts one or more images from RGB to YIQ. 2606 2607 Outputs a tensor of the same shape as the `images` tensor, containing the YIQ 2608 value of the pixels. 2609 The output is only well defined if the value in images are in [0,1]. 2610 2611 Args: 2612 images: 2-D or higher rank. Image data to convert. Last dimension must be 2613 size 3. 2614 2615 Returns: 2616 images: tensor with the same shape as `images`. 2617 """ 2618 images = ops.convert_to_tensor(images, name='images') 2619 kernel = ops.convert_to_tensor( 2620 _rgb_to_yiq_kernel, dtype=images.dtype, name='kernel') 2621 ndims = images.get_shape().ndims 2622 return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) 2623 2624 2625 _yiq_to_rgb_kernel = [[1, 1, 1], [0.95598634, -0.27201283, -1.10674021], 2626 [0.6208248, -0.64720424, 1.70423049]] 2627 2628 2629 @tf_export('image.yiq_to_rgb') 2630 def yiq_to_rgb(images): 2631 """Converts one or more images from YIQ to RGB. 2632 2633 Outputs a tensor of the same shape as the `images` tensor, containing the RGB 2634 value of the pixels. 2635 The output is only well defined if the Y value in images are in [0,1], 2636 I value are in [-0.5957,0.5957] and Q value are in [-0.5226,0.5226]. 2637 2638 Args: 2639 images: 2-D or higher rank. Image data to convert. Last dimension must be 2640 size 3. 2641 2642 Returns: 2643 images: tensor with the same shape as `images`. 2644 """ 2645 images = ops.convert_to_tensor(images, name='images') 2646 kernel = ops.convert_to_tensor( 2647 _yiq_to_rgb_kernel, dtype=images.dtype, name='kernel') 2648 ndims = images.get_shape().ndims 2649 return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) 2650 2651 2652 _rgb_to_yuv_kernel = [[0.299, -0.14714119, 2653 0.61497538], [0.587, -0.28886916, -0.51496512], 2654 [0.114, 0.43601035, -0.10001026]] 2655 2656 2657 @tf_export('image.rgb_to_yuv') 2658 def rgb_to_yuv(images): 2659 """Converts one or more images from RGB to YUV. 2660 2661 Outputs a tensor of the same shape as the `images` tensor, containing the YUV 2662 value of the pixels. 2663 The output is only well defined if the value in images are in [0,1]. 2664 2665 Args: 2666 images: 2-D or higher rank. Image data to convert. Last dimension must be 2667 size 3. 2668 2669 Returns: 2670 images: tensor with the same shape as `images`. 2671 """ 2672 images = ops.convert_to_tensor(images, name='images') 2673 kernel = ops.convert_to_tensor( 2674 _rgb_to_yuv_kernel, dtype=images.dtype, name='kernel') 2675 ndims = images.get_shape().ndims 2676 return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) 2677 2678 2679 _yuv_to_rgb_kernel = [[1, 1, 1], [0, -0.394642334, 2.03206185], 2680 [1.13988303, -0.58062185, 0]] 2681 2682 2683 @tf_export('image.yuv_to_rgb') 2684 def yuv_to_rgb(images): 2685 """Converts one or more images from YUV to RGB. 2686 2687 Outputs a tensor of the same shape as the `images` tensor, containing the RGB 2688 value of the pixels. 2689 The output is only well defined if the Y value in images are in [0,1], 2690 U and V value are in [-0.5,0.5]. 2691 2692 Args: 2693 images: 2-D or higher rank. Image data to convert. Last dimension must be 2694 size 3. 2695 2696 Returns: 2697 images: tensor with the same shape as `images`. 2698 """ 2699 images = ops.convert_to_tensor(images, name='images') 2700 kernel = ops.convert_to_tensor( 2701 _yuv_to_rgb_kernel, dtype=images.dtype, name='kernel') 2702 ndims = images.get_shape().ndims 2703 return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) 2704 2705 2706 def _verify_compatible_image_shapes(img1, img2): 2707 """Checks if two image tensors are compatible for applying SSIM or PSNR. 2708 2709 This function checks if two sets of images have ranks at least 3, and if the 2710 last three dimensions match. 2711 2712 Args: 2713 img1: Tensor containing the first image batch. 2714 img2: Tensor containing the second image batch. 2715 2716 Returns: 2717 A tuple containing: the first tensor shape, the second tensor shape, and a 2718 list of control_flow_ops.Assert() ops implementing the checks. 2719 2720 Raises: 2721 ValueError: When static shape check fails. 2722 """ 2723 shape1 = img1.get_shape().with_rank_at_least(3) 2724 shape2 = img2.get_shape().with_rank_at_least(3) 2725 shape1[-3:].assert_is_compatible_with(shape2[-3:]) 2726 2727 if shape1.ndims is not None and shape2.ndims is not None: 2728 for dim1, dim2 in zip(reversed(shape1.dims[:-3]), 2729 reversed(shape2.dims[:-3])): 2730 if not (dim1 == 1 or dim2 == 1 or dim1.is_compatible_with(dim2)): 2731 raise ValueError( 2732 'Two images are not compatible: %s and %s' % (shape1, shape2)) 2733 2734 # Now assign shape tensors. 2735 shape1, shape2 = array_ops.shape_n([img1, img2]) 2736 2737 # TODO(sjhwang): Check if shape1[:-3] and shape2[:-3] are broadcastable. 2738 checks = [] 2739 checks.append(control_flow_ops.Assert( 2740 math_ops.greater_equal(array_ops.size(shape1), 3), 2741 [shape1, shape2], summarize=10)) 2742 checks.append(control_flow_ops.Assert( 2743 math_ops.reduce_all(math_ops.equal(shape1[-3:], shape2[-3:])), 2744 [shape1, shape2], summarize=10)) 2745 return shape1, shape2, checks 2746 2747 2748 @tf_export('image.psnr') 2749 def psnr(a, b, max_val, name=None): 2750 """Returns the Peak Signal-to-Noise Ratio between a and b. 2751 2752 This is intended to be used on signals (or images). Produces a PSNR value for 2753 each image in batch. 2754 2755 The last three dimensions of input are expected to be [height, width, depth]. 2756 2757 Example: 2758 2759 ```python 2760 # Read images from file. 2761 im1 = tf.decode_png('path/to/im1.png') 2762 im2 = tf.decode_png('path/to/im2.png') 2763 # Compute PSNR over tf.uint8 Tensors. 2764 psnr1 = tf.image.psnr(im1, im2, max_val=255) 2765 2766 # Compute PSNR over tf.float32 Tensors. 2767 im1 = tf.image.convert_image_dtype(im1, tf.float32) 2768 im2 = tf.image.convert_image_dtype(im2, tf.float32) 2769 psnr2 = tf.image.psnr(im1, im2, max_val=1.0) 2770 # psnr1 and psnr2 both have type tf.float32 and are almost equal. 2771 ``` 2772 2773 Arguments: 2774 a: First set of images. 2775 b: Second set of images. 2776 max_val: The dynamic range of the images (i.e., the difference between the 2777 maximum the and minimum allowed values). 2778 name: Namespace to embed the computation in. 2779 2780 Returns: 2781 The scalar PSNR between a and b. The returned tensor has type `tf.float32` 2782 and shape [batch_size, 1]. 2783 """ 2784 with ops.name_scope(name, 'PSNR', [a, b]): 2785 # Need to convert the images to float32. Scale max_val accordingly so that 2786 # PSNR is computed correctly. 2787 max_val = math_ops.cast(max_val, a.dtype) 2788 max_val = convert_image_dtype(max_val, dtypes.float32) 2789 a = convert_image_dtype(a, dtypes.float32) 2790 b = convert_image_dtype(b, dtypes.float32) 2791 mse = math_ops.reduce_mean(math_ops.squared_difference(a, b), [-3, -2, -1]) 2792 psnr_val = math_ops.subtract( 2793 20 * math_ops.log(max_val) / math_ops.log(10.0), 2794 np.float32(10 / np.log(10)) * math_ops.log(mse), 2795 name='psnr') 2796 2797 _, _, checks = _verify_compatible_image_shapes(a, b) 2798 with ops.control_dependencies(checks): 2799 return array_ops.identity(psnr_val) 2800 2801 _SSIM_K1 = 0.01 2802 _SSIM_K2 = 0.03 2803 2804 2805 def _ssim_helper(x, y, reducer, max_val, compensation=1.0): 2806 r"""Helper function for computing SSIM. 2807 2808 SSIM estimates covariances with weighted sums. The default parameters 2809 use a biased estimate of the covariance: 2810 Suppose `reducer` is a weighted sum, then the mean estimators are 2811 \mu_x = \sum_i w_i x_i, 2812 \mu_y = \sum_i w_i y_i, 2813 where w_i's are the weighted-sum weights, and covariance estimator is 2814 cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y) 2815 with assumption \sum_i w_i = 1. This covariance estimator is biased, since 2816 E[cov_{xy}] = (1 - \sum_i w_i ^ 2) Cov(X, Y). 2817 For SSIM measure with unbiased covariance estimators, pass as `compensation` 2818 argument (1 - \sum_i w_i ^ 2). 2819 2820 Arguments: 2821 x: First set of images. 2822 y: Second set of images. 2823 reducer: Function that computes 'local' averages from set of images. 2824 For non-covolutional version, this is usually tf.reduce_mean(x, [1, 2]), 2825 and for convolutional version, this is usually tf.nn.avg_pool or 2826 tf.nn.conv2d with weighted-sum kernel. 2827 max_val: The dynamic range (i.e., the difference between the maximum 2828 possible allowed value and the minimum allowed value). 2829 compensation: Compensation factor. See above. 2830 2831 Returns: 2832 A pair containing the luminance measure, and the contrast-structure measure. 2833 """ 2834 c1 = (_SSIM_K1 * max_val) ** 2 2835 c2 = (_SSIM_K2 * max_val) ** 2 2836 2837 # SSIM luminance measure is 2838 # (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1). 2839 mean0 = reducer(x) 2840 mean1 = reducer(y) 2841 num0 = mean0 * mean1 * 2.0 2842 den0 = math_ops.square(mean0) + math_ops.square(mean1) 2843 luminance = (num0 + c1) / (den0 + c1) 2844 2845 # SSIM contrast-structure measure is 2846 # (2 * cov_{xy} + c2) / (cov_{xx} + cov_{yy} + c2). 2847 # Note that `reducer` is a weighted sum with weight w_k, \sum_i w_i = 1, then 2848 # cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y) 2849 # = \sum_i w_i x_i y_i - (\sum_i w_i x_i) (\sum_j w_j y_j). 2850 num1 = reducer(x * y) * 2.0 2851 den1 = reducer(math_ops.square(x) + math_ops.square(y)) 2852 c2 *= compensation 2853 cs = (num1 - num0 + c2) / (den1 - den0 + c2) 2854 2855 # SSIM score is the product of the luminance and contrast-structure measures. 2856 return luminance, cs 2857 2858 2859 def _fspecial_gauss(size, sigma): 2860 """Function to mimic the 'fspecial' gaussian MATLAB function.""" 2861 size = ops.convert_to_tensor(size, dtypes.int32) 2862 sigma = ops.convert_to_tensor(sigma) 2863 2864 coords = math_ops.cast(math_ops.range(size), sigma.dtype) 2865 coords -= math_ops.cast(size - 1, sigma.dtype) / 2.0 2866 2867 g = math_ops.square(coords) 2868 g *= -0.5 / math_ops.square(sigma) 2869 2870 g = array_ops.reshape(g, shape=[1, -1]) + array_ops.reshape(g, shape=[-1, 1]) 2871 g = array_ops.reshape(g, shape=[1, -1]) # For tf.nn.softmax(). 2872 g = nn_ops.softmax(g) 2873 return array_ops.reshape(g, shape=[size, size, 1, 1]) 2874 2875 2876 def _ssim_per_channel(img1, img2, max_val=1.0): 2877 """Computes SSIM index between img1 and img2 per color channel. 2878 2879 This function matches the standard SSIM implementation from: 2880 Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image 2881 quality assessment: from error visibility to structural similarity. IEEE 2882 transactions on image processing. 2883 2884 Details: 2885 - 11x11 Gaussian filter of width 1.5 is used. 2886 - k1 = 0.01, k2 = 0.03 as in the original paper. 2887 2888 Args: 2889 img1: First image batch. 2890 img2: Second image batch. 2891 max_val: The dynamic range of the images (i.e., the difference between the 2892 maximum the and minimum allowed values). 2893 2894 Returns: 2895 A pair of tensors containing and channel-wise SSIM and contrast-structure 2896 values. The shape is [..., channels]. 2897 """ 2898 filter_size = constant_op.constant(11, dtype=dtypes.int32) 2899 filter_sigma = constant_op.constant(1.5, dtype=img1.dtype) 2900 2901 shape1, shape2 = array_ops.shape_n([img1, img2]) 2902 checks = [ 2903 control_flow_ops.Assert(math_ops.reduce_all(math_ops.greater_equal( 2904 shape1[-3:-1], filter_size)), [shape1, filter_size], summarize=8), 2905 control_flow_ops.Assert(math_ops.reduce_all(math_ops.greater_equal( 2906 shape2[-3:-1], filter_size)), [shape2, filter_size], summarize=8)] 2907 2908 # Enforce the check to run before computation. 2909 with ops.control_dependencies(checks): 2910 img1 = array_ops.identity(img1) 2911 2912 # TODO(sjhwang): Try to cache kernels and compensation factor. 2913 kernel = _fspecial_gauss(filter_size, filter_sigma) 2914 kernel = array_ops.tile(kernel, multiples=[1, 1, shape1[-1], 1]) 2915 2916 # The correct compensation factor is `1.0 - tf.reduce_sum(tf.square(kernel))`, 2917 # but to match MATLAB implementation of MS-SSIM, we use 1.0 instead. 2918 compensation = 1.0 2919 2920 # TODO(sjhwang): Try FFT. 2921 # TODO(sjhwang): Gaussian kernel is separable in space. Consider applying 2922 # 1-by-n and n-by-1 Gaussain filters instead of an n-by-n filter. 2923 def reducer(x): 2924 shape = array_ops.shape(x) 2925 x = array_ops.reshape(x, shape=array_ops.concat([[-1], shape[-3:]], 0)) 2926 y = nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID') 2927 return array_ops.reshape(y, array_ops.concat([shape[:-3], 2928 array_ops.shape(y)[1:]], 0)) 2929 2930 luminance, cs = _ssim_helper(img1, img2, reducer, max_val, compensation) 2931 2932 # Average over the second and the third from the last: height, width. 2933 axes = constant_op.constant([-3, -2], dtype=dtypes.int32) 2934 ssim_val = math_ops.reduce_mean(luminance * cs, axes) 2935 cs = math_ops.reduce_mean(cs, axes) 2936 return ssim_val, cs 2937 2938 2939 @tf_export('image.ssim') 2940 def ssim(img1, img2, max_val): 2941 """Computes SSIM index between img1 and img2. 2942 2943 This function is based on the standard SSIM implementation from: 2944 Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image 2945 quality assessment: from error visibility to structural similarity. IEEE 2946 transactions on image processing. 2947 2948 Note: The true SSIM is only defined on grayscale. This function does not 2949 perform any colorspace transform. (If input is already YUV, then it will 2950 compute YUV SSIM average.) 2951 2952 Details: 2953 - 11x11 Gaussian filter of width 1.5 is used. 2954 - k1 = 0.01, k2 = 0.03 as in the original paper. 2955 2956 The image sizes must be at least 11x11 because of the filter size. 2957 2958 Example: 2959 2960 ```python 2961 # Read images from file. 2962 im1 = tf.decode_png('path/to/im1.png') 2963 im2 = tf.decode_png('path/to/im2.png') 2964 # Compute SSIM over tf.uint8 Tensors. 2965 ssim1 = tf.image.ssim(im1, im2, max_val=255) 2966 2967 # Compute SSIM over tf.float32 Tensors. 2968 im1 = tf.image.convert_image_dtype(im1, tf.float32) 2969 im2 = tf.image.convert_image_dtype(im2, tf.float32) 2970 ssim2 = tf.image.ssim(im1, im2, max_val=1.0) 2971 # ssim1 and ssim2 both have type tf.float32 and are almost equal. 2972 ``` 2973 2974 Args: 2975 img1: First image batch. 2976 img2: Second image batch. 2977 max_val: The dynamic range of the images (i.e., the difference between the 2978 maximum the and minimum allowed values). 2979 2980 Returns: 2981 A tensor containing an SSIM value for each image in batch. Returned SSIM 2982 values are in range (-1, 1], when pixel values are non-negative. Returns 2983 a tensor with shape: broadcast(img1.shape[:-3], img2.shape[:-3]). 2984 """ 2985 _, _, checks = _verify_compatible_image_shapes(img1, img2) 2986 with ops.control_dependencies(checks): 2987 img1 = array_ops.identity(img1) 2988 2989 # Need to convert the images to float32. Scale max_val accordingly so that 2990 # SSIM is computed correctly. 2991 max_val = math_ops.cast(max_val, img1.dtype) 2992 max_val = convert_image_dtype(max_val, dtypes.float32) 2993 img1 = convert_image_dtype(img1, dtypes.float32) 2994 img2 = convert_image_dtype(img2, dtypes.float32) 2995 ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val) 2996 # Compute average over color channels. 2997 return math_ops.reduce_mean(ssim_per_channel, [-1]) 2998 2999 3000 # Default values obtained by Wang et al. 3001 _MSSSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333) 3002 3003 3004 @tf_export('image.ssim_multiscale') 3005 def ssim_multiscale(img1, img2, max_val, power_factors=_MSSSIM_WEIGHTS): 3006 """Computes the MS-SSIM between img1 and img2. 3007 3008 This function assumes that `img1` and `img2` are image batches, i.e. the last 3009 three dimensions are [height, width, channels]. 3010 3011 Note: The true SSIM is only defined on grayscale. This function does not 3012 perform any colorspace transform. (If input is already YUV, then it will 3013 compute YUV SSIM average.) 3014 3015 Original paper: Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale 3016 structural similarity for image quality assessment." Signals, Systems and 3017 Computers, 2004. 3018 3019 Arguments: 3020 img1: First image batch. 3021 img2: Second image batch. Must have the same rank as img1. 3022 max_val: The dynamic range of the images (i.e., the difference between the 3023 maximum the and minimum allowed values). 3024 power_factors: Iterable of weights for each of the scales. The number of 3025 scales used is the length of the list. Index 0 is the unscaled 3026 resolution's weight and each increasing scale corresponds to the image 3027 being downsampled by 2. Defaults to (0.0448, 0.2856, 0.3001, 0.2363, 3028 0.1333), which are the values obtained in the original paper. 3029 3030 Returns: 3031 A tensor containing an MS-SSIM value for each image in batch. The values 3032 are in range [0, 1]. Returns a tensor with shape: 3033 broadcast(img1.shape[:-3], img2.shape[:-3]). 3034 """ 3035 # Shape checking. 3036 shape1 = img1.get_shape().with_rank_at_least(3) 3037 shape2 = img2.get_shape().with_rank_at_least(3) 3038 shape1[-3:].merge_with(shape2[-3:]) 3039 3040 with ops.name_scope(None, 'MS-SSIM', [img1, img2]): 3041 shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2) 3042 with ops.control_dependencies(checks): 3043 img1 = array_ops.identity(img1) 3044 3045 # Need to convert the images to float32. Scale max_val accordingly so that 3046 # SSIM is computed correctly. 3047 max_val = math_ops.cast(max_val, img1.dtype) 3048 max_val = convert_image_dtype(max_val, dtypes.float32) 3049 img1 = convert_image_dtype(img1, dtypes.float32) 3050 img2 = convert_image_dtype(img2, dtypes.float32) 3051 3052 imgs = [img1, img2] 3053 shapes = [shape1, shape2] 3054 3055 # img1 and img2 are assumed to be a (multi-dimensional) batch of 3056 # 3-dimensional images (height, width, channels). `heads` contain the batch 3057 # dimensions, and `tails` contain the image dimensions. 3058 heads = [s[:-3] for s in shapes] 3059 tails = [s[-3:] for s in shapes] 3060 3061 divisor = [1, 2, 2, 1] 3062 divisor_tensor = constant_op.constant(divisor[1:], dtype=dtypes.int32) 3063 3064 def do_pad(images, remainder): 3065 padding = array_ops.expand_dims(remainder, -1) 3066 padding = array_ops.pad(padding, [[1, 0], [1, 0]]) 3067 return [array_ops.pad(x, padding, mode='SYMMETRIC') for x in images] 3068 3069 mcs = [] 3070 for k in range(len(power_factors)): 3071 with ops.name_scope(None, 'Scale%d' % k, imgs): 3072 if k > 0: 3073 # Avg pool takes rank 4 tensors. Flatten leading dimensions. 3074 flat_imgs = [ 3075 array_ops.reshape(x, array_ops.concat([[-1], t], 0)) 3076 for x, t in zip(imgs, tails) 3077 ] 3078 3079 remainder = tails[0] % divisor_tensor 3080 need_padding = math_ops.reduce_any(math_ops.not_equal(remainder, 0)) 3081 # pylint: disable=cell-var-from-loop 3082 padded = control_flow_ops.cond(need_padding, 3083 lambda: do_pad(flat_imgs, remainder), 3084 lambda: flat_imgs) 3085 # pylint: enable=cell-var-from-loop 3086 3087 downscaled = [nn_ops.avg_pool(x, ksize=divisor, strides=divisor, 3088 padding='VALID') 3089 for x in padded] 3090 tails = [x[1:] for x in array_ops.shape_n(downscaled)] 3091 imgs = [ 3092 array_ops.reshape(x, array_ops.concat([h, t], 0)) 3093 for x, h, t in zip(downscaled, heads, tails) 3094 ] 3095 3096 # Overwrite previous ssim value since we only need the last one. 3097 ssim_per_channel, cs = _ssim_per_channel(*imgs, max_val=max_val) 3098 mcs.append(nn_ops.relu(cs)) 3099 3100 # Remove the cs score for the last scale. In the MS-SSIM calculation, 3101 # we use the l(p) at the highest scale. l(p) * cs(p) is ssim(p). 3102 mcs.pop() # Remove the cs score for the last scale. 3103 mcs_and_ssim = array_ops.stack(mcs + [nn_ops.relu(ssim_per_channel)], 3104 axis=-1) 3105 # Take weighted geometric mean across the scale axis. 3106 ms_ssim = math_ops.reduce_prod(math_ops.pow(mcs_and_ssim, power_factors), 3107 [-1]) 3108 3109 return math_ops.reduce_mean(ms_ssim, [-1]) # Avg over color channels. 3110 3111 3112 @tf_export('image.image_gradients') 3113 def image_gradients(image): 3114 """Returns image gradients (dy, dx) for each color channel. 3115 3116 Both output tensors have the same shape as the input: [batch_size, h, w, 3117 d]. The gradient values are organized so that [I(x+1, y) - I(x, y)] is in 3118 location (x, y). That means that dy will always have zeros in the last row, 3119 and dx will always have zeros in the last column. 3120 3121 Arguments: 3122 image: Tensor with shape [batch_size, h, w, d]. 3123 3124 Returns: 3125 Pair of tensors (dy, dx) holding the vertical and horizontal image 3126 gradients (1-step finite difference). 3127 3128 Raises: 3129 ValueError: If `image` is not a 4D tensor. 3130 """ 3131 if image.get_shape().ndims != 4: 3132 raise ValueError('image_gradients expects a 4D tensor ' 3133 '[batch_size, h, w, d], not %s.', image.get_shape()) 3134 image_shape = array_ops.shape(image) 3135 batch_size, height, width, depth = array_ops.unstack(image_shape) 3136 dy = image[:, 1:, :, :] - image[:, :-1, :, :] 3137 dx = image[:, :, 1:, :] - image[:, :, :-1, :] 3138 3139 # Return tensors with same size as original image by concatenating 3140 # zeros. Place the gradient [I(x+1,y) - I(x,y)] on the base pixel (x, y). 3141 shape = array_ops.stack([batch_size, 1, width, depth]) 3142 dy = array_ops.concat([dy, array_ops.zeros(shape, image.dtype)], 1) 3143 dy = array_ops.reshape(dy, image_shape) 3144 3145 shape = array_ops.stack([batch_size, height, 1, depth]) 3146 dx = array_ops.concat([dx, array_ops.zeros(shape, image.dtype)], 2) 3147 dx = array_ops.reshape(dx, image_shape) 3148 3149 return dy, dx 3150 3151 3152 @tf_export('image.sobel_edges') 3153 def sobel_edges(image): 3154 """Returns a tensor holding Sobel edge maps. 3155 3156 Arguments: 3157 image: Image tensor with shape [batch_size, h, w, d] and type float32 or 3158 float64. The image(s) must be 2x2 or larger. 3159 3160 Returns: 3161 Tensor holding edge maps for each channel. Returns a tensor with shape 3162 [batch_size, h, w, d, 2] where the last two dimensions hold [[dy[0], dx[0]], 3163 [dy[1], dx[1]], ..., [dy[d-1], dx[d-1]]] calculated using the Sobel filter. 3164 """ 3165 # Define vertical and horizontal Sobel filters. 3166 static_image_shape = image.get_shape() 3167 image_shape = array_ops.shape(image) 3168 kernels = [[[-1, -2, -1], [0, 0, 0], [1, 2, 1]], 3169 [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]] 3170 num_kernels = len(kernels) 3171 kernels = np.transpose(np.asarray(kernels), (1, 2, 0)) 3172 kernels = np.expand_dims(kernels, -2) 3173 kernels_tf = constant_op.constant(kernels, dtype=image.dtype) 3174 3175 kernels_tf = array_ops.tile(kernels_tf, [1, 1, image_shape[-1], 1], 3176 name='sobel_filters') 3177 3178 # Use depth-wise convolution to calculate edge maps per channel. 3179 pad_sizes = [[0, 0], [1, 1], [1, 1], [0, 0]] 3180 padded = array_ops.pad(image, pad_sizes, mode='REFLECT') 3181 3182 # Output tensor has shape [batch_size, h, w, d * num_kernels]. 3183 strides = [1, 1, 1, 1] 3184 output = nn.depthwise_conv2d(padded, kernels_tf, strides, 'VALID') 3185 3186 # Reshape to [batch_size, h, w, d, num_kernels]. 3187 shape = array_ops.concat([image_shape, [num_kernels]], 0) 3188 output = array_ops.reshape(output, shape=shape) 3189 output.set_shape(static_image_shape.concatenate([num_kernels])) 3190 return output 3191 3192 3193 def resize_bicubic(images, 3194 size, 3195 align_corners=False, 3196 name=None, 3197 half_pixel_centers=False): 3198 return gen_image_ops.resize_bicubic( 3199 images=images, 3200 size=size, 3201 align_corners=align_corners, 3202 half_pixel_centers=half_pixel_centers, 3203 name=name) 3204 3205 3206 def resize_bilinear(images, 3207 size, 3208 align_corners=False, 3209 name=None, 3210 half_pixel_centers=False): 3211 return gen_image_ops.resize_bilinear( 3212 images=images, 3213 size=size, 3214 align_corners=align_corners, 3215 half_pixel_centers=half_pixel_centers, 3216 name=name) 3217 3218 3219 def resize_nearest_neighbor(images, 3220 size, 3221 align_corners=False, 3222 name=None, 3223 half_pixel_centers=False): 3224 return gen_image_ops.resize_nearest_neighbor( 3225 images=images, 3226 size=size, 3227 align_corners=align_corners, 3228 half_pixel_centers=half_pixel_centers, 3229 name=name) 3230 3231 3232 resize_area_deprecation = deprecation.deprecated( 3233 date=None, 3234 instructions=( 3235 'Use `tf.image.resize(...method=ResizeMethod.AREA...)` instead.')) 3236 tf_export(v1=['image.resize_area'])( 3237 resize_area_deprecation(gen_image_ops.resize_area)) 3238 3239 resize_bicubic_deprecation = deprecation.deprecated( 3240 date=None, 3241 instructions=( 3242 'Use `tf.image.resize(...method=ResizeMethod.BICUBIC...)` instead.')) 3243 tf_export(v1=['image.resize_bicubic'])( 3244 resize_bicubic_deprecation(resize_bicubic)) 3245 3246 resize_bilinear_deprecation = deprecation.deprecated( 3247 date=None, 3248 instructions=( 3249 'Use `tf.image.resize(...method=ResizeMethod.BILINEAR...)` instead.')) 3250 tf_export(v1=['image.resize_bilinear'])( 3251 resize_bilinear_deprecation(resize_bilinear)) 3252 3253 resize_nearest_neighbor_deprecation = deprecation.deprecated( 3254 date=None, 3255 instructions=( 3256 'Use `tf.image.resize(...method=ResizeMethod.NEAREST_NEIGHBOR...)` ' 3257 'instead.')) 3258 tf_export(v1=['image.resize_nearest_neighbor'])( 3259 resize_nearest_neighbor_deprecation(resize_nearest_neighbor)) 3260 3261 3262 @tf_export('image.crop_and_resize', v1=[]) 3263 def crop_and_resize_v2( 3264 image, 3265 boxes, 3266 box_indices, 3267 crop_size, 3268 method='bilinear', 3269 extrapolation_value=0, 3270 name=None): 3271 """Extracts crops from the input image tensor and resizes them. 3272 3273 Extracts crops from the input image tensor and resizes them using bilinear 3274 sampling or nearest neighbor sampling (possibly with aspect ratio change) to a 3275 common output size specified by `crop_size`. This is more general than the 3276 `crop_to_bounding_box` op which extracts a fixed size slice from the input 3277 image and does not allow resizing or aspect ratio change. 3278 3279 Returns a tensor with `crops` from the input `image` at positions defined at 3280 the bounding box locations in `boxes`. The cropped boxes are all resized (with 3281 bilinear or nearest neighbor interpolation) to a fixed 3282 `size = [crop_height, crop_width]`. The result is a 4-D tensor 3283 `[num_boxes, crop_height, crop_width, depth]`. The resizing is corner aligned. 3284 In particular, if `boxes = [[0, 0, 1, 1]]`, the method will give identical 3285 results to using `tf.image.resize_bilinear()` or 3286 `tf.image.resize_nearest_neighbor()`(depends on the `method` argument) with 3287 `align_corners=True`. 3288 3289 Args: 3290 image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. 3291 Both `image_height` and `image_width` need to be positive. 3292 boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor 3293 specifies the coordinates of a box in the `box_ind[i]` image and is 3294 specified in normalized coordinates `[y1, x1, y2, x2]`. A normalized 3295 coordinate value of `y` is mapped to the image coordinate at `y * 3296 (image_height - 1)`, so as the `[0, 1]` interval of normalized image 3297 height is mapped to `[0, image_height - 1]` in image height coordinates. 3298 We do allow `y1` > `y2`, in which case the sampled crop is an up-down 3299 flipped version of the original image. The width dimension is treated 3300 similarly. Normalized coordinates outside the `[0, 1]` range are allowed, 3301 in which case we use `extrapolation_value` to extrapolate the input image 3302 values. 3303 box_indices: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0, 3304 batch)`. The value of `box_ind[i]` specifies the image that the `i`-th box 3305 refers to. 3306 crop_size: A 1-D tensor of 2 elements, `size = [crop_height, crop_width]`. 3307 All cropped image patches are resized to this size. The aspect ratio of 3308 the image content is not preserved. Both `crop_height` and `crop_width` 3309 need to be positive. 3310 method: An optional string specifying the sampling method for resizing. It 3311 can be either `"bilinear"` or `"nearest"` and default to `"bilinear"`. 3312 Currently two sampling methods are supported: Bilinear and Nearest 3313 Neighbor. 3314 extrapolation_value: An optional `float`. Defaults to `0`. Value used for 3315 extrapolation, when applicable. 3316 name: A name for the operation (optional). 3317 3318 Returns: 3319 A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`. 3320 """ 3321 return gen_image_ops.crop_and_resize( 3322 image, boxes, box_indices, crop_size, method, extrapolation_value, name) 3323 3324 3325 @tf_export(v1=['image.crop_and_resize']) 3326 @deprecation.deprecated_args( 3327 None, 'box_ind is deprecated, use box_indices instead', 'box_ind') 3328 def crop_and_resize_v1( # pylint: disable=missing-docstring 3329 image, 3330 boxes, 3331 box_ind=None, 3332 crop_size=None, 3333 method='bilinear', 3334 extrapolation_value=0, 3335 name=None, 3336 box_indices=None): 3337 box_ind = deprecation.deprecated_argument_lookup( 3338 "box_indices", box_indices, "box_ind", box_ind) 3339 return gen_image_ops.crop_and_resize( 3340 image, boxes, box_ind, crop_size, method, extrapolation_value, name) 3341 3342 crop_and_resize_v1.__doc__ = gen_image_ops.crop_and_resize.__doc__ 3343 3344 3345 @tf_export(v1=['image.extract_glimpse']) 3346 def extract_glimpse( 3347 input, # pylint: disable=redefined-builtin 3348 size, 3349 offsets, 3350 centered=True, 3351 normalized=True, 3352 uniform_noise=True, 3353 name=None): 3354 """Extracts a glimpse from the input tensor. 3355 3356 Returns a set of windows called glimpses extracted at location 3357 `offsets` from the input tensor. If the windows only partially 3358 overlaps the inputs, the non overlapping areas will be filled with 3359 random noise. 3360 3361 The result is a 4-D tensor of shape `[batch_size, glimpse_height, 3362 glimpse_width, channels]`. The channels and batch dimensions are the 3363 same as that of the input tensor. The height and width of the output 3364 windows are specified in the `size` parameter. 3365 3366 The argument `normalized` and `centered` controls how the windows are built: 3367 3368 * If the coordinates are normalized but not centered, 0.0 and 1.0 3369 correspond to the minimum and maximum of each height and width 3370 dimension. 3371 * If the coordinates are both normalized and centered, they range from 3372 -1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper 3373 left corner, the lower right corner is located at (1.0, 1.0) and the 3374 center is at (0, 0). 3375 * If the coordinates are not normalized they are interpreted as 3376 numbers of pixels. 3377 3378 Args: 3379 input: A `Tensor` of type `float32`. A 4-D float tensor of shape 3380 `[batch_size, height, width, channels]`. 3381 size: A `Tensor` of type `int32`. A 1-D tensor of 2 elements containing the 3382 size of the glimpses to extract. The glimpse height must be specified 3383 first, following by the glimpse width. 3384 offsets: A `Tensor` of type `float32`. A 2-D integer tensor of shape 3385 `[batch_size, 2]` containing the y, x locations of the center of each 3386 window. 3387 centered: An optional `bool`. Defaults to `True`. indicates if the offset 3388 coordinates are centered relative to the image, in which case the (0, 0) 3389 offset is relative to the center of the input images. If false, the (0,0) 3390 offset corresponds to the upper left corner of the input images. 3391 normalized: An optional `bool`. Defaults to `True`. indicates if the offset 3392 coordinates are normalized. 3393 uniform_noise: An optional `bool`. Defaults to `True`. indicates if the 3394 noise should be generated using a uniform distribution or a Gaussian 3395 distribution. 3396 name: A name for the operation (optional). 3397 3398 Returns: 3399 A `Tensor` of type `float32`. 3400 """ 3401 return gen_image_ops.extract_glimpse( 3402 input=input, 3403 size=size, 3404 offsets=offsets, 3405 centered=centered, 3406 normalized=normalized, 3407 uniform_noise=uniform_noise, 3408 name=name) 3409 3410 3411 @tf_export('image.extract_glimpse', v1=[]) 3412 def extract_glimpse_v2( 3413 input, # pylint: disable=redefined-builtin 3414 size, 3415 offsets, 3416 centered=True, 3417 normalized=True, 3418 noise='uniform', 3419 name=None): 3420 """Extracts a glimpse from the input tensor. 3421 3422 Returns a set of windows called glimpses extracted at location 3423 `offsets` from the input tensor. If the windows only partially 3424 overlaps the inputs, the non overlapping areas will be filled with 3425 random noise. 3426 3427 The result is a 4-D tensor of shape `[batch_size, glimpse_height, 3428 glimpse_width, channels]`. The channels and batch dimensions are the 3429 same as that of the input tensor. The height and width of the output 3430 windows are specified in the `size` parameter. 3431 3432 The argument `normalized` and `centered` controls how the windows are built: 3433 3434 * If the coordinates are normalized but not centered, 0.0 and 1.0 3435 correspond to the minimum and maximum of each height and width 3436 dimension. 3437 * If the coordinates are both normalized and centered, they range from 3438 -1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper 3439 left corner, the lower right corner is located at (1.0, 1.0) and the 3440 center is at (0, 0). 3441 * If the coordinates are not normalized they are interpreted as 3442 numbers of pixels. 3443 3444 Args: 3445 input: A `Tensor` of type `float32`. A 4-D float tensor of shape 3446 `[batch_size, height, width, channels]`. 3447 size: A `Tensor` of type `int32`. A 1-D tensor of 2 elements containing the 3448 size of the glimpses to extract. The glimpse height must be specified 3449 first, following by the glimpse width. 3450 offsets: A `Tensor` of type `float32`. A 2-D integer tensor of shape 3451 `[batch_size, 2]` containing the y, x locations of the center of each 3452 window. 3453 centered: An optional `bool`. Defaults to `True`. indicates if the offset 3454 coordinates are centered relative to the image, in which case the (0, 0) 3455 offset is relative to the center of the input images. If false, the (0,0) 3456 offset corresponds to the upper left corner of the input images. 3457 normalized: An optional `bool`. Defaults to `True`. indicates if the offset 3458 coordinates are normalized. 3459 noise: An optional `string`. Defaults to `uniform`. indicates if the noise 3460 should be `uniform` (uniform distribution), `gaussian` (gaussian 3461 distribution), or `zero` (zero padding). 3462 name: A name for the operation (optional). 3463 3464 Returns: 3465 A `Tensor` of type `float32`. 3466 """ 3467 return gen_image_ops.extract_glimpse( 3468 input=input, 3469 size=size, 3470 offsets=offsets, 3471 centered=centered, 3472 normalized=normalized, 3473 noise=noise, 3474 uniform_noise=False, 3475 name=name) 3476 3477 3478 @tf_export('image.combined_non_max_suppression') 3479 def combined_non_max_suppression(boxes, 3480 scores, 3481 max_output_size_per_class, 3482 max_total_size, 3483 iou_threshold=0.5, 3484 score_threshold=float('-inf'), 3485 pad_per_class=False, 3486 name=None): 3487 """Greedily selects a subset of bounding boxes in descending order of score. 3488 3489 This operation performs non_max_suppression on the inputs per batch, across 3490 all classes. 3491 Prunes away boxes that have high intersection-over-union (IOU) overlap 3492 with previously selected boxes. Bounding boxes are supplied as 3493 [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any 3494 diagonal pair of box corners and the coordinates can be provided as normalized 3495 (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm 3496 is agnostic to where the origin is in the coordinate system. Also note that 3497 this algorithm is invariant to orthogonal transformations and translations 3498 of the coordinate system; thus translating or reflections of the coordinate 3499 system result in the same boxes being selected by the algorithm. 3500 The output of this operation is the final boxes, scores and classes tensor 3501 returned after performing non_max_suppression. 3502 3503 Args: 3504 boxes: A 4-D float `Tensor` of shape `[batch_size, num_boxes, q, 4]`. If `q` 3505 is 1 then same boxes are used for all classes otherwise, if `q` is equal 3506 to number of classes, class-specific boxes are used. 3507 scores: A 3-D float `Tensor` of shape `[batch_size, num_boxes, num_classes]` 3508 representing a single score corresponding to each box (each row of boxes). 3509 max_output_size_per_class: A scalar integer `Tensor` representing the 3510 maximum number of boxes to be selected by non max suppression per class 3511 max_total_size: A scalar representing maximum number of boxes retained over 3512 all classes. 3513 iou_threshold: A float representing the threshold for deciding whether boxes 3514 overlap too much with respect to IOU. 3515 score_threshold: A float representing the threshold for deciding when to 3516 remove boxes based on score. 3517 pad_per_class: If false, the output nmsed boxes, scores and classes are 3518 padded/clipped to `max_total_size`. If true, the output nmsed boxes, 3519 scores and classes are padded to be of length 3520 `max_size_per_class`*`num_classes`, unless it exceeds `max_total_size` in 3521 which case it is clipped to `max_total_size`. Defaults to false. 3522 name: A name for the operation (optional). 3523 3524 Returns: 3525 'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor 3526 containing the non-max suppressed boxes. 3527 'nmsed_scores': A [batch_size, max_detections] float32 tensor containing 3528 the scores for the boxes. 3529 'nmsed_classes': A [batch_size, max_detections] float32 tensor 3530 containing the class for boxes. 3531 'valid_detections': A [batch_size] int32 tensor indicating the number of 3532 valid detections per batch item. Only the top valid_detections[i] entries 3533 in nms_boxes[i], nms_scores[i] and nms_class[i] are valid. The rest of the 3534 entries are zero paddings. 3535 """ 3536 with ops.name_scope(name, 'combined_non_max_suppression'): 3537 iou_threshold = ops.convert_to_tensor( 3538 iou_threshold, dtype=dtypes.float32, name='iou_threshold') 3539 score_threshold = ops.convert_to_tensor( 3540 score_threshold, dtype=dtypes.float32, name='score_threshold') 3541 return gen_image_ops.combined_non_max_suppression( 3542 boxes, scores, max_output_size_per_class, max_total_size, iou_threshold, 3543 score_threshold, pad_per_class) 3544