Home | History | Annotate | Download | only in ragged
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for ragged_concat_ops.concat."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from absl.testing import parameterized
     22 
     23 from tensorflow.python.eager import context
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import errors
     27 from tensorflow.python.framework import test_util
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops.ragged import ragged_concat_ops
     30 from tensorflow.python.ops.ragged import ragged_factory_ops
     31 from tensorflow.python.ops.ragged import ragged_test_util
     32 from tensorflow.python.platform import googletest
     33 
     34 
     35 @test_util.run_all_in_graph_and_eager_modes
     36 class RaggedConcatOpTest(ragged_test_util.RaggedTensorTestCase,
     37                          parameterized.TestCase):
     38 
     39   def _rt_inputs_to_tensors(self, rt_inputs, ragged_ranks=None):
     40     if ragged_ranks is None:
     41       ragged_ranks = [None] * len(rt_inputs)
     42     return [  # pylint: disable=g-long-ternary
     43         ragged_factory_ops.constant(rt_input, ragged_rank=rrank)
     44         if rrank != 0 else constant_op.constant(rt_input)
     45         for (rt_input, rrank) in zip(rt_inputs, ragged_ranks)
     46     ]
     47 
     48   @parameterized.parameters(
     49       dict(
     50           descr='Two rank-2 inputs with empty value axis=1',
     51           rt_inputs=([[]], [[]]),
     52           axis=1,
     53           expected=[[]]),
     54       dict(
     55           descr='Two rank-2 inputs (ragged_rank=1), axis=0',
     56           rt_inputs=(
     57               [['a00', 'a01'], [], ['a20', 'a21']],   # shape=(3, None)
     58               [['b00'], ['b10']]),                    # shape=(2, None)
     59           axis=0,
     60           expected=[[b'a00', b'a01'], [], [b'a20', b'a21'], [b'b00'],
     61                     [b'b10']]),
     62       dict(
     63           descr='Two rank-2 inputs (ragged_rank=1), axis=1',
     64           rt_inputs=(
     65               [['a00', 'a01'], [], ['a20', 'a21', 'a22']],   # shape=(3, None)
     66               [['b00'], ['b10', 'b11', 'b12'], ['b20']]),    # shape=(3, None)
     67           axis=1,
     68           expected=[
     69               [b'a00', b'a01', b'b00'],
     70               [b'b10', b'b11', b'b12'],
     71               [b'a20', b'a21', b'a22', b'b20']]),
     72       dict(
     73           descr='Two rank-2 inputs (ragged_rank=1), axis=-2',
     74           rt_inputs=(
     75               [['a00', 'a01'], [], ['a20', 'a21']],   # shape=(3, None)
     76               [['b00'], ['b10']]),                    # shape=(2, None)
     77           axis=-2,
     78           expected=[[b'a00', b'a01'], [], [b'a20', b'a21'], [b'b00'],
     79                     [b'b10']]),
     80       dict(
     81           descr='Two rank-2 inputs (ragged_rank=1), axis=-1',
     82           rt_inputs=(
     83               [['a00', 'a01'], [], ['a20', 'a21', 'a22']],   # shape=(3, None)
     84               [['b00'], ['b10', 'b11', 'b12'], ['b20']]),    # shape=(3, None)
     85           axis=-1,
     86           expected=[
     87               [b'a00', b'a01', b'b00'],
     88               [b'b10', b'b11', b'b12'],
     89               [b'a20', b'a21', b'a22', b'b20']],
     90           expected_shape=[3, None]),
     91       dict(
     92           descr='Three rank-2 inputs (ragged_rank=1), axis=0',
     93           rt_inputs=(
     94               [['a00', 'a01'], [], ['a20', 'a21', 'a22']],   # shape=(3, None)
     95               [['b00'], ['b10']],                            # shape=(2, None)
     96               [['c00'], ['c10', 'c11'], ['c21']]),           # shape=(3, None)
     97           axis=0,
     98           expected=[[b'a00', b'a01'], [], [b'a20', b'a21', b'a22'], [b'b00'],
     99                     [b'b10'], [b'c00'], [b'c10', b'c11'], [b'c21']]),
    100       dict(
    101           descr='Three rank-2 inputs (ragged_rank=1), axis=1',
    102           rt_inputs=(
    103               [['a00', 'a01'], [], ['a20', 'a21', 'a22']],   # shape=(3, None)
    104               [['b00'], ['b10', 'b11', 'b12'], ['b20']],     # shape=(3, None)
    105               [[], ['c10', 'c11'], ['c20', 'c21']]),         # shape=(3, None)
    106           axis=1,
    107           expected=[
    108               [b'a00', b'a01', b'b00'],
    109               [b'b10', b'b11', b'b12', b'c10', b'c11'],
    110               [b'a20', b'a21', b'a22', b'b20', b'c20', b'c21']]),
    111       dict(
    112           descr='Three rank-3 inputs (ragged_rank=2), axis=0',
    113           rt_inputs=(
    114               [[['a000', 'a001'], ['a010']],
    115                [['a100', 'a101', 'a102'], ['a110', 'a111']]],
    116               [[['b000']], [['b100', 'b101'], ['b110']]],
    117               [[], [['c100', 'c101', 'c102', 'c103']], [[], ['c210', 'c211']]]),
    118           axis=0,
    119           expected=[
    120               [[b'a000', b'a001'], [b'a010']],
    121               [[b'a100', b'a101', b'a102'], [b'a110', b'a111']],
    122               [[b'b000']],
    123               [[b'b100', b'b101'], [b'b110']],
    124               [],
    125               [[b'c100', b'c101', b'c102', b'c103']],
    126               [[], [b'c210', b'c211']]]),
    127       dict(
    128           descr='Three rank-3 inputs (ragged_rank=2), axis=1',
    129           rt_inputs=(
    130               [[['a000', 'a001'], ['a010']],
    131                [['a100', 'a101', 'a102'], ['a110', 'a111']]],
    132               [[['b000']], [['b100', 'b101'], ['b110']]],
    133               [[], [[], ['c110', 'c111']]]),
    134           axis=1,
    135           expected=[
    136               [[b'a000', b'a001'], [b'a010'], [b'b000']],
    137               [[b'a100', b'a101', b'a102'], [b'a110', b'a111'],
    138                [b'b100', b'b101'], [b'b110'], [], [b'c110', b'c111']]]),
    139       dict(
    140           descr='Three rank-3 inputs (ragged_rank=2), axis=2',
    141           rt_inputs=(
    142               [[['a000', 'a001'], ['a010']],
    143                [['a100', 'a101', 'a102'], ['a110', 'a111']]],
    144               [[[], ['b010', 'b011']], [['b100', 'b101'], ['b110']]],
    145               [[['c000'], ['c010']], [[], ['c110', 'c111']]]),
    146           axis=2,
    147           expected=[
    148               [[b'a000', b'a001', b'c000'],
    149                [b'a010', b'b010', b'b011', b'c010']],
    150               [[b'a100', b'a101', b'a102', b'b100', b'b101'],
    151                [b'a110', b'a111', b'b110', b'c110', b'c111']]]),
    152       dict(
    153           descr='Three rank-3 inputs (ragged_rank=2), axis=-1',
    154           rt_inputs=(
    155               [[['a000', 'a001'], ['a010']],
    156                [['a100', 'a101', 'a102'], ['a110', 'a111']]],
    157               [[[], ['b010', 'b011']], [['b100', 'b101'], ['b110']]],
    158               [[['c000'], ['c010']], [[], ['c110', 'c111']]]),
    159           axis=-1,
    160           expected=[
    161               [[b'a000', b'a001', b'c000'],
    162                [b'a010', b'b010', b'b011', b'c010']],
    163               [[b'a100', b'a101', b'a102', b'b100', b'b101'],
    164                [b'a110', b'a111', b'b110', b'c110', b'c111']]]),
    165       dict(
    166           descr='ragged_concat([uniform, ragged, uniform], axis=1)',
    167           ragged_ranks=[0, 1, 0],
    168           rt_inputs=(
    169               [['0('], ['1('], ['2(']],                   # shape=(3, 1)
    170               [['b00'], ['b10', 'b11', 'b12'], ['b20']],  # shape=(3, None)
    171               [[')0'], [')1'], [')2']]),                  # shape=(3, 1)
    172           axis=1,
    173           expected=[
    174               [b'0(', b'b00', b')0'],
    175               [b'1(', b'b10', b'b11', b'b12', b')1'],
    176               [b'2(', b'b20', b')2']]),
    177       dict(
    178           descr='ragged_concat([uniform, uniform], axis=0)',
    179           ragged_ranks=[0, 0],
    180           rt_inputs=(
    181               [['a00', 'a01'], ['a10', 'a11'], ['a20', 'a21']],  # shape=(3, 2)
    182               [['b00', 'b01', 'b02'], ['b10', 'b11', 'b12']]),   # shape=(2, 3)
    183           axis=0,
    184           expected=[
    185               [b'a00', b'a01'], [b'a10', b'a11'], [b'a20', b'a21'],
    186               [b'b00', b'b01', b'b02'], [b'b10', b'b11', b'b12']],
    187           expected_ragged_rank=1),
    188       dict(
    189           descr='ragged_concat([uniform, ragged], axis=0)',
    190           ragged_ranks=[0, 1],
    191           rt_inputs=(
    192               [['a00', 'a01'], ['a10', 'a11'], ['a20', 'a21']],  # shape=(3, 2)
    193               [['b00', 'b01', 'b02'], ['b10', 'b11', 'b12']]),   # shape=(2, 3)
    194           axis=0,
    195           expected=[
    196               [b'a00', b'a01'], [b'a10', b'a11'], [b'a20', b'a21'],
    197               [b'b00', b'b01', b'b02'], [b'b10', b'b11', b'b12']]),
    198       dict(
    199           descr='ragged_concat([uniform, ragged], axis=0) with rank-3 inputs',
    200           ragged_ranks=[0, 2],
    201           rt_inputs=(
    202               [[[0, 1], [2, 3]], [[4, 5], [6, 7]]],  # shape = (2, 2, 2)
    203               [[[8], [8, 8]]]),                      # shape = (2, None, None)
    204           axis=0,
    205           expected=[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], [8, 8]]]),
    206       dict(
    207           descr='Two rank-3 inputs with ragged_rank=1, axis=-1',
    208           ragged_ranks=[1, 1],
    209           rt_inputs=(
    210               [[[0, 1], [2, 3], [4, 5]], [], [[6, 7], [8, 9]]],
    211               [[[9, 8], [7, 6], [5, 4]], [], [[3, 2], [1, 0]]]),
    212           axis=-1,
    213           expected=[
    214               [[0, 1, 9, 8], [2, 3, 7, 6], [4, 5, 5, 4]], [],
    215               [[6, 7, 3, 2], [8, 9, 1, 0]]],
    216           expected_ragged_rank=1),
    217       dict(
    218           descr='ragged_concat([vector, vector], axis=0)',
    219           ragged_ranks=[0, 0],
    220           rt_inputs=([1, 2, 3], [4, 5, 6]),
    221           axis=0,
    222           expected=[1, 2, 3, 4, 5, 6]),
    223       dict(
    224           descr='One input (so ragged_conat is a noop)',
    225           rt_inputs=([['a00', 'a01'], [], ['a20', 'a21']],),
    226           axis=0,
    227           expected=[[b'a00', b'a01'], [], [b'a20', b'a21']]),
    228   )   # pyformat: disable
    229   def testRaggedConcat(self,
    230                        descr,
    231                        rt_inputs,
    232                        axis,
    233                        expected,
    234                        ragged_ranks=None,
    235                        expected_ragged_rank=None,
    236                        expected_shape=None):
    237     rt_inputs = self._rt_inputs_to_tensors(rt_inputs, ragged_ranks)
    238     concatenated = ragged_concat_ops.concat(rt_inputs, axis)
    239     if expected_ragged_rank is not None:
    240       self.assertEqual(concatenated.ragged_rank, expected_ragged_rank)
    241     if expected_shape is not None:
    242       self.assertEqual(concatenated.shape.as_list(), expected_shape)
    243     self.assertRaggedEqual(concatenated, expected)
    244 
    245   @parameterized.parameters(
    246       dict(
    247           rt_inputs=(),
    248           axis=0,
    249           error=ValueError,
    250           message=r'rt_inputs may not be empty\.'),
    251       dict(
    252           rt_inputs=([[1, 2]], [[3, 4]]),
    253           axis=r'foo',
    254           error=TypeError,
    255           message='axis must be an int'),
    256       dict(
    257           rt_inputs=([[1, 2]], [[3, 4]]),
    258           axis=-3,
    259           error=ValueError,
    260           message='axis=-3 out of bounds: expected -2<=axis<2'),
    261       dict(
    262           rt_inputs=([[1, 2]], [[3, 4]]),
    263           axis=2,
    264           error=ValueError,
    265           message='axis=2 out of bounds: expected -2<=axis<2'),
    266       dict(
    267           ragged_ranks=(0, 0),
    268           rt_inputs=([[1, 2]], [[3, 4], [5, 6]]),
    269           axis=1,
    270           error=(ValueError, errors.InvalidArgumentError)),
    271   )
    272   def testStaticError(self,
    273                       rt_inputs,
    274                       axis,
    275                       error,
    276                       message=None,
    277                       ragged_ranks=None):
    278     rt_inputs = self._rt_inputs_to_tensors(rt_inputs, ragged_ranks)
    279     self.assertRaisesRegexp(error, message, ragged_concat_ops.concat, rt_inputs,
    280                             axis)
    281 
    282   @parameterized.parameters([
    283       dict(
    284           ragged_ranks=(1, 1),
    285           rt_inputs=([[1, 2]], [[3, 4], [5, 6]]),
    286           axis=1,
    287           error=errors.InvalidArgumentError,
    288           message='Input tensors have incompatible shapes'),
    289   ])
    290   def testRuntimeError(self, rt_inputs, axis, error, message,
    291                        ragged_ranks=None):
    292     if context.executing_eagerly():
    293       return
    294     rt_inputs = [
    295         array_ops.placeholder_with_default(rt, shape=None) for rt in rt_inputs
    296     ]
    297     concatenated = ragged_concat_ops.concat(rt_inputs, axis)
    298     with self.assertRaisesRegexp(error, message):
    299       self.evaluate(concatenated)
    300 
    301   def testNegativeAxisWithUnknownRankError(self):
    302     if context.executing_eagerly():
    303       return
    304     rt_inputs = [
    305         array_ops.placeholder(dtypes.int64),
    306         array_ops.placeholder(dtypes.int64)
    307     ]
    308     self.assertRaisesRegexp(
    309         ValueError, r'axis may only be negative if ndims is statically known.',
    310         ragged_concat_ops.concat, rt_inputs, -1)
    311 
    312   def testSingleTensorInput(self):
    313     """Tests ragged_concat with a single tensor input.
    314 
    315     Usually, we pass a list of values in for rt_inputs.  However, you can
    316     also pass in a single value (as with tf.concat), in which case it simply
    317     returns that tensor.  This test exercises that path.
    318     """
    319     rt_inputs = ragged_factory_ops.constant([[1, 2], [3, 4]])
    320     concatenated = ragged_concat_ops.concat(rt_inputs, 0)
    321     self.assertRaggedEqual(concatenated, [[1, 2], [3, 4]])
    322 
    323 
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  googletest.main()
    326