# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Support for sorting tensors.

@@sort
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops as framework_ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops


def sort(values, axis=-1, direction='ASCENDING', name=None):
  """Sorts a tensor.

  Args:
    values: 1-D or higher numeric `Tensor`.
    axis: The axis along which to sort. The default is -1, which sorts the last
      axis. Must be statically known (a Python integer or a constant scalar).
    direction: The direction in which to sort the values (`'ASCENDING'` or
      `'DESCENDING'`).
    name: Optional name for the operation.

  Returns:
    A `Tensor` with the same dtype and shape as `values`, with the elements
    sorted along the given `axis`.

  Raises:
    ValueError: If axis is not a constant scalar, or the direction is invalid.
  """
  with framework_ops.name_scope(name, 'sort'):
    if direction not in _SORT_IMPL:
      raise ValueError('%s should be one of %s' %
                       (direction, ', '.join(sorted(_SORT_IMPL.keys()))))
    # Axis must be an integer, not a Tensor. Round-trip it through a tensor so
    # that constant scalar tensors are accepted as well as Python integers.
    axis = framework_ops.convert_to_tensor(axis, name='axis')
    axis_static = tensor_util.constant_value(axis)
    if axis.shape.ndims != 0 or axis_static is None:
      raise ValueError('axis must be a constant scalar')
    axis_static = int(axis_static)  # Avoids NumPy casting error

    values = framework_ops.convert_to_tensor(values, name='values')

    return _SORT_IMPL[direction](values, axis_static)


def _descending_sort(values, axis):
  """Sorts values in reverse using `top_k`.

  `top_k` only operates on the last axis, so any other axis is swapped with the
  last one before the call and swapped back afterwards.

  Args:
    values: Tensor of numeric values.
    axis: Python integer index of the axis which values should be sorted along.
      May be negative to count from the end.

  Returns:
    The sorted values, with the same shape as `values`.
  """
  # Asking top_k for every element along the axis yields a full sort.
  k = array_ops.shape(values)[axis]
  static_rank = values.get_shape().ndims

  if static_rank is not None:
    # The rank is known at graph-construction time, so the axis can be
    # normalized and the permutation built in plain Python.
    if axis < 0:
      axis += static_rank
    # Fast path: sorting the last axis.
    if axis == static_rank - 1:
      return nn_ops.top_k(values, k)[0]
    transposition = list(range(static_rank))
    transposition[axis], transposition[-1] = (transposition[-1],
                                              transposition[axis])
  else:
    # Fast path: -1 is the last axis whatever the rank turns out to be.
    if axis == -1:
      return nn_ops.top_k(values, k)[0]
    rank = array_ops.rank(values)
    if axis < 0:
      # Make axis a Tensor with the real axis index.
      axis += rank
    # Build the permutation that swaps `axis` and `rank - 1` from the identity
    # permutation. Unlike concatenating ranges around the two swapped entries,
    # this stays a valid permutation (the identity) when `axis` turns out to
    # be the last axis at run time.
    positions = math_ops.range(rank)
    last = rank - 1
    at_axis = math_ops.cast(math_ops.equal(positions, axis), positions.dtype)
    at_last = math_ops.cast(math_ops.equal(positions, last), positions.dtype)
    transposition = positions + (last - axis) * (at_axis - at_last)

  top_k_input = array_ops.transpose(values, transposition)
  sorted_values = nn_ops.top_k(top_k_input, k)[0]
  # transposition contains a single cycle of length 2 (swapping 2 elements),
  # so it is an involution (it is its own inverse).
  return array_ops.transpose(sorted_values, transposition)


def _ascending_sort(values, axis):
  """Sorts values in ascending order by reversing a descending sort.

  The values are mapped through an order-reversing transformation, sorted in
  descending order, and mapped back. The transformation depends on the dtype
  so that it never leaves the representable range.

  Args:
    values: Tensor of numeric values.
    axis: Python integer index of the axis which values should be sorted along.

  Returns:
    The sorted values, with the same shape as `values`.
  """
  dtype = values.dtype
  if dtype.is_unsigned:
    # Unsigned values cannot be negated; reflect them around dtype.max.
    offset = dtype.max
    return offset - _descending_sort(offset - values, axis)
  if dtype.is_integer:
    # -x - 1 maps the signed range onto itself (min <-> max), whereas plain
    # negation has no representable result for the minimum value.
    return -_descending_sort(-values - 1, axis) - 1
  # Negate the values to get the ascending order from descending sort.
  return -_descending_sort(-values, axis)


_SORT_IMPL = {
    'ASCENDING': _ascending_sort,
    'DESCENDING': _descending_sort,
}