# tensorflow/python/ops/ragged — (code-viewer navigation header removed)
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Private convenience functions for RaggedTensors.
     16 
     17 None of these methods are exposed in the main "ragged" package.
     18 """
     19 
     20 from __future__ import absolute_import
     21 from __future__ import division
     22 from __future__ import print_function
     23 
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import check_ops
     28 from tensorflow.python.ops import gen_ragged_math_ops
     29 from tensorflow.python.ops import math_ops
     30 
     31 
     32 def convert_to_int_tensor(tensor, name, dtype=dtypes.int32):
     33   """Converts the given value to an integer Tensor."""
     34   tensor = ops.convert_to_tensor(tensor, name=name, preferred_dtype=dtype)
     35   if tensor.dtype.is_integer:
     36     tensor = math_ops.cast(tensor, dtype)
     37   else:
     38     raise TypeError(
     39         "%s must be an integer tensor; dtype=%s" % (name, tensor.dtype))
     40   return tensor
     41 
     42 
     43 def get_positive_axis(axis, ndims):
     44   """Validate an `axis` parameter, and normalize it to be positive.
     45 
     46   If `ndims` is known (i.e., not `None`), then check that `axis` is in the
     47   range `-ndims <= axis < ndims`, and return `axis` (if `axis >= 0`) or
     48   `axis + ndims` (otherwise).
     49   If `ndims` is not known, and `axis` is positive, then return it as-is.
     50   If `ndims` is not known, and `axis` is negative, then report an error.
     51 
     52   Args:
     53     axis: An integer constant
     54     ndims: An integer constant, or `None`
     55 
     56   Returns:
     57     The normalized `axis` value.
     58 
     59   Raises:
     60     ValueError: If `axis` is out-of-bounds, or if `axis` is negative and
     61       `ndims is None`.
     62   """
     63   if not isinstance(axis, int):
     64     raise TypeError("axis must be an int; got %s" % type(axis).__name__)
     65   if ndims is not None:
     66     if 0 <= axis < ndims:
     67       return axis
     68     elif -ndims <= axis < 0:
     69       return axis + ndims
     70     else:
     71       raise ValueError(
     72           "axis=%s out of bounds: expected %s<=axis<%s" % (axis, -ndims, ndims))
     73   elif axis < 0:
     74     raise ValueError("axis may only be negative if ndims is statically known.")
     75   return axis
     76 
     77 
     78 def assert_splits_match(nested_splits_lists):
     79   """Checks that the given splits lists are identical.
     80 
     81   Performs static tests to ensure that the given splits lists are identical,
     82   and returns a list of control dependency op tensors that check that they are
     83   fully identical.
     84 
     85   Args:
     86     nested_splits_lists: A list of nested_splits_lists, where each split_list is
     87       a list of `splits` tensors from a `RaggedTensor`, ordered from outermost
     88       ragged dimension to innermost ragged dimension.
     89 
     90   Returns:
     91     A list of control dependency op tensors.
     92   Raises:
     93     ValueError: If the splits are not identical.
     94   """
     95   error_msg = "Inputs must have identical ragged splits"
     96   for splits_list in nested_splits_lists:
     97     if len(splits_list) != len(nested_splits_lists[0]):
     98       raise ValueError(error_msg)
     99   return [
    100       check_ops.assert_equal(s1, s2, message=error_msg)
    101       for splits_list in nested_splits_lists[1:]
    102       for (s1, s2) in zip(nested_splits_lists[0], splits_list)
    103   ]
    104 
    105 
    106 # This op is intended to exactly match the semantics of numpy.repeat, with
    107 # one exception: numpy.repeat has special (and somewhat non-intuitive) behavior
    108 # when axis is not specified.  Rather than implement that special behavior, we
    109 # simply make `axis` be a required argument.
    110 #
    111 # External (OSS) `tf.repeat` feature request:
    112 # https://github.com/tensorflow/tensorflow/issues/8246
def repeat(data, repeats, axis, name=None):
  """Repeats elements of `data`.

  Args:
    data: An `N`-dimensional tensor.
    repeats: A 1-D integer tensor specifying how many times each element in
      `axis` should be repeated.  `len(repeats)` must equal `data.shape[axis]`.
      Supports broadcasting from a scalar value.
    axis: `int`.  The axis along which to repeat values.  Must be less than
      `max(N, 1)`.
    name: A name for the operation.

  Returns:
    A tensor with `max(N, 1)` dimensions.  Has the same shape as `data`,
    except that dimension `axis` has size `sum(repeats)`.

  #### Examples:
    ```python
    >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
    ['a', 'a', 'a', 'c', 'c']
    >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
    [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
    >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
    [[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
    ```
  """
  # `axis` must be a static python int (needed below for shape manipulation).
  if not isinstance(axis, int):
    raise TypeError("axis must be an int; got %s" % type(axis).__name__)

  with ops.name_scope(name, "Repeat", [data, repeats]):
    data = ops.convert_to_tensor(data, name="data")
    repeats = convert_to_int_tensor(repeats, name="repeats")
    # `repeats` must be a scalar or a vector.
    repeats.shape.with_rank_at_most(1)

    # If `data` is a scalar, then upgrade it to a vector.
    data = _with_nonzero_rank(data)
    data_shape = array_ops.shape(data)

    # If `axis` is negative, then convert it to a positive value.
    axis = get_positive_axis(axis, data.shape.ndims)

    # Check data Tensor shapes: `repeats` must have one value per slice of
    # `data` along `axis` (when both sizes are statically known).
    if repeats.shape.ndims == 1:
      data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])

    # If we know that `repeats` is a scalar, then we can just tile & reshape.
    if repeats.shape.ndims == 0:
      expanded = array_ops.expand_dims(data, axis + 1)
      tiled = tile_one_dimension(expanded, axis + 1, repeats)
      # Collapse the tiled dimension back into `axis` (its new size is
      # inferred via -1, i.e. original_size * repeats).
      result_shape = array_ops.concat(
          [data_shape[:axis], [-1], data_shape[axis + 1:]], axis=0)
      return array_ops.reshape(tiled, result_shape)

    # Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
    if repeats.shape.ndims != axis + 1:
      repeats_shape = array_ops.shape(repeats)
      repeats_ndims = array_ops.rank(repeats)
      broadcast_shape = array_ops.concat(
          [data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0)
      repeats = array_ops.broadcast_to(repeats, broadcast_shape)
      # Record the new static rank; the dimension sizes are dynamic.
      repeats.set_shape([None] * (axis + 1))

    # Create a "sequence mask" based on `repeats`, where slices across `axis`
    # contain one `True` value for each repetition.  E.g., if
    # `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`.
    # The `maximum(0, ...)` guards the degenerate case of an empty `repeats`,
    # where `reduce_max` would not return a usable mask length.
    max_repeat = math_ops.maximum(0, math_ops.reduce_max(repeats))
    mask = array_ops.sequence_mask(repeats, max_repeat)

    # Add a new dimension around each value that needs to be repeated, and
    # then tile that new dimension to match the maximum number of repetitions.
    expanded = array_ops.expand_dims(data, axis + 1)
    tiled = tile_one_dimension(expanded, axis + 1, max_repeat)

    # Use `boolean_mask` to discard the extra repeated values.  This also
    # flattens all dimensions up through `axis`.
    masked = array_ops.boolean_mask(tiled, mask)

    # Reshape the output tensor to add the outer dimensions back.
    if axis == 0:
      # Nothing was flattened away; `masked` already has the right shape.
      result = masked
    else:
      result_shape = array_ops.concat(
          [data_shape[:axis], [-1], data_shape[axis + 1:]], axis=0)
      result = array_ops.reshape(masked, result_shape)

    # Preserve shape information: only an empty `repeats` gives a statically
    # known size (0) for the repeated axis; otherwise it is dynamic.
    if data.shape.ndims is not None:
      new_axis_size = 0 if repeats.shape[0] == 0 else None
      result.set_shape(data.shape[:axis].concatenate(
          [new_axis_size]).concatenate(data.shape[axis + 1:]))

    return result
    205 
    206 
    207 def tile_one_dimension(data, axis, multiple):
    208   """Tiles a single dimension of a tensor."""
    209   # Assumes axis is a nonnegative int.
    210   if data.shape.ndims is not None:
    211     multiples = [1] * data.shape.ndims
    212     multiples[axis] = multiple
    213   else:
    214     ones = array_ops.ones(array_ops.rank(data), dtypes.int32)
    215     multiples = array_ops.concat([ones[:axis], [multiple], ones[axis + 1:]],
    216                                  axis=0)
    217   return array_ops.tile(data, multiples)
    218 
    219 
    220 def _with_nonzero_rank(data):
    221   """If `data` is scalar, then add a dimension; otherwise return as-is."""
    222   if data.shape.ndims is not None:
    223     if data.shape.ndims == 0:
    224       return array_ops.stack([data])
    225     else:
    226       return data
    227   else:
    228     data_shape = array_ops.shape(data)
    229     data_ndims = array_ops.rank(data)
    230     return array_ops.reshape(
    231         data,
    232         array_ops.concat([[1], data_shape], axis=0)[-data_ndims:])
    233 
    234 
    235 def lengths_to_splits(lengths):
    236   """Returns splits corresponding to the given lengths."""
    237   return array_ops.concat([[0], math_ops.cumsum(lengths)], axis=-1)
    238 
    239 
def repeat_ranges(params, splits, repeats):
  """Repeats each range of `params` (as specified by `splits`) `repeats` times.

  Let the `i`th range of `params` be defined as
  `params[splits[i]:splits[i + 1]]`.  Then this function returns a tensor
  containing range 0 repeated `repeats[0]` times, followed by range 1 repeated
  `repeats[1]`, ..., followed by the last range repeated `repeats[-1]` times.

  Args:
    params: The `Tensor` whose values should be repeated.
    splits: A splits tensor indicating the ranges of `params` that should be
      repeated.
    repeats: The number of times each range should be repeated.  Supports
      broadcasting from a scalar value.

  Returns:
    A `Tensor` with the same rank and type as `params`.

  #### Example:
    ```python
    >>> repeat_ranges(['a', 'b', 'c'], [0, 2, 3], 3)
    ['a', 'b', 'a', 'b', 'a', 'b', 'c', 'c', 'c']
    ```
  """
  # Divide `splits` into starts and limits, and repeat them `repeats` times.
  if repeats.shape.ndims != 0:
    repeated_starts = repeat(splits[:-1], repeats, axis=0)
    repeated_limits = repeat(splits[1:], repeats, axis=0)
  else:
    # Optimization: we can just call repeat once, and then slice the result.
    # With scalar `repeats` r and m+1 split values, `repeated_splits` is
    # [s0]*r + [s1]*r + ... + [sm]*r (length (m+1)*r), so dropping the last
    # r entries yields the repeated starts and dropping the first r entries
    # yields the repeated limits.
    # NOTE(review): the subtraction `n_splits - repeats` mixes an int64 count
    # with `repeats`' dtype — presumably callers pass int64 (or python int)
    # repeats here; verify against callers.
    repeated_splits = repeat(splits, repeats, axis=0)
    n_splits = array_ops.shape(repeated_splits, out_type=dtypes.int64)[0]
    repeated_starts = repeated_splits[:n_splits - repeats]
    repeated_limits = repeated_splits[repeats:]

  # Get indices for each range from starts to limits, and use those to gather
  # the values in the desired repetition pattern.
  one = array_ops.ones((), repeated_starts.dtype)
  offsets = gen_ragged_math_ops.ragged_range(
      repeated_starts, repeated_limits, one)
  return array_ops.gather(params, offsets.rt_dense_values)
    281