Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Support for sorting tensors.
     16 
     17 @@sort
     18 """
     19 
     20 from __future__ import absolute_import
     21 from __future__ import division
     22 from __future__ import print_function
     23 
     24 from tensorflow.python.framework import ops as framework_ops
     25 from tensorflow.python.framework import tensor_util
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import math_ops
     28 from tensorflow.python.ops import nn_ops
     29 
     30 
     31 def sort(values, axis=-1, direction='ASCENDING', name=None):
     32   """Sorts a tensor.
     33 
     34   Args:
     35     values: 1-D or higher numeric `Tensor`.
     36     axis: The axis along which to sort. The default is -1, which sorts the last
     37         axis.
     38     direction: The direction in which to sort the values (`'ASCENDING'` or
     39         `'DESCENDING'`).
     40     name: Optional name for the operation.
     41 
     42   Returns:
     43     A `Tensor` with the same dtype and shape as `values`, with the elements
     44         sorted along the given `axis`.
     45 
     46   Raises:
     47     ValueError: If axis is not a constant scalar, or the direction is invalid.
     48   """
     49   with framework_ops.name_scope(name, 'sort'):
     50     if direction not in _SORT_IMPL:
     51       raise ValueError('%s should be one of %s' %
     52                        (direction, ', '.join(sorted(_SORT_IMPL.keys()))))
     53     # Axis must be an integer, not a Tensor.
     54     axis = framework_ops.convert_to_tensor(axis, name='axis')
     55     axis_static = tensor_util.constant_value(axis)
     56     if axis.shape.ndims != 0 or axis_static is None:
     57       raise ValueError('axis must be a constant scalar')
     58     axis_static = int(axis_static)  # Avoids NumPy casting error
     59 
     60     values = framework_ops.convert_to_tensor(values, name='values')
     61 
     62     return _SORT_IMPL[direction](values, axis_static)
     63 
     64 
     65 def _descending_sort(values, axis):
     66   """Sorts values in reverse using `top_k`.
     67 
     68   Args:
     69     values: Tensor of numeric values.
     70     axis: Index of the axis which values should be sorted along.
     71 
     72   Returns:
     73     The sorted values.
     74   """
     75   k = array_ops.shape(values)[axis]
     76   rank = array_ops.rank(values)
     77   # Fast path: sorting the last axis.
     78   if axis == -1 or axis + 1 == values.get_shape().ndims:
     79     return nn_ops.top_k(values, k)[0]
     80 
     81   # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`.
     82   if axis < 0:
     83     # Make axis a Tensor with the real axis index if needed.
     84     axis += rank
     85   transposition = array_ops.concat(
     86       [
     87           # Axes up to axis are unchanged.
     88           math_ops.range(axis),
     89           # Swap axis and rank - 1.
     90           [rank - 1],
     91           # Axes in [axis + 1, rank - 1) are unchanged.
     92           math_ops.range(axis + 1, rank - 1),
     93           # Swap axis and rank - 1.
     94           [axis]
     95       ],
     96       axis=0)
     97   top_k_input = array_ops.transpose(values, transposition)
     98   values, unused_indices = nn_ops.top_k(top_k_input, k)
     99   # transposition contains a single cycle of length 2 (swapping 2 elements),
    100   # so it is an involution (it is its own inverse).
    101   return array_ops.transpose(values, transposition)
    102 
    103 
    104 def _ascending_sort(values, axis):
    105   # Negate the values to get the ascending order from descending sort.
    106   values_or_indices = _descending_sort(-values, axis)
    107   return -values_or_indices
    108 
    109 
    110 _SORT_IMPL = {
    111     'ASCENDING': _ascending_sort,
    112     'DESCENDING': _descending_sort,
    113 }
    114